Distance metric change and PGVectorScale support (#1703)
This commit is contained in:
@@ -22,7 +22,7 @@ class OutputData(BaseModel):
|
||||
|
||||
class PGVector(VectorStoreBase):
|
||||
def __init__(
|
||||
self, dbname, collection_name, embedding_model_dims, user, password, host, port
|
||||
self, dbname, collection_name, embedding_model_dims, user, password, host, port, diskann
|
||||
):
|
||||
"""
|
||||
Initialize the PGVector database.
|
||||
@@ -35,8 +35,10 @@ class PGVector(VectorStoreBase):
|
||||
password (str): Database password
|
||||
host (str, optional): Database host
|
||||
port (int, optional): Database port
|
||||
diskann (bool, optional): Use DiskANN for faster search
|
||||
"""
|
||||
self.collection_name = collection_name
|
||||
self.use_diskann = diskann
|
||||
|
||||
self.conn = psycopg2.connect(
|
||||
dbname=dbname, user=user, password=password, host=host, port=port
|
||||
@@ -50,6 +52,7 @@ class PGVector(VectorStoreBase):
|
||||
def create_col(self, embedding_model_dims):
|
||||
"""
|
||||
Create a new collection (table in PostgreSQL).
|
||||
Will also initialize DiskANN index if the extension is installed.
|
||||
|
||||
Args:
|
||||
name (str): Name of the collection.
|
||||
@@ -64,6 +67,19 @@ class PGVector(VectorStoreBase):
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
if self.use_diskann and embedding_model_dims < 2000:
|
||||
# Check if vectorscale extension is installed
|
||||
self.cur.execute("SELECT * FROM pg_extension WHERE extname = 'vectorscale'")
|
||||
if self.cur.fetchone():
|
||||
# Create DiskANN index if extension is installed for faster search
|
||||
self.cur.execute(f"""
|
||||
CREATE INDEX IF NOT EXISTS {self.collection_name}_vector_idx
|
||||
ON {self.collection_name}
|
||||
USING diskann (vector);
|
||||
"""
|
||||
)
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
def insert(self, vectors, payloads=None, ids=None):
|
||||
@@ -114,13 +130,13 @@ class PGVector(VectorStoreBase):
|
||||
|
||||
self.cur.execute(
|
||||
f"""
|
||||
SELECT id, vector <-> %s::vector AS distance, payload
|
||||
SELECT id, vector <=> %s::vector AS distance, payload
|
||||
FROM {self.collection_name}
|
||||
{filter_clause}
|
||||
ORDER BY distance
|
||||
LIMIT %s
|
||||
""",
|
||||
(query, *filter_params, limit),
|
||||
(query, *filter_params, limit),
|
||||
)
|
||||
|
||||
results = self.cur.fetchall()
|
||||
|
||||
Reference in New Issue
Block a user