Distance metric change and PGVectorScale support (#1703)
This commit is contained in:
@@ -36,4 +36,5 @@ Here's the parameters available for configuring pgvector:
|
|||||||
| `user` | User name to connect to the database | `None` |
|
| `user` | User name to connect to the database | `None` |
|
||||||
| `password` | Password to connect to the database | `None` |
|
| `password` | Password to connect to the database | `None` |
|
||||||
| `host` | The host where the Postgres server is running | `None` |
|
| `host` | The host where the Postgres server is running | `None` |
|
||||||
| `port` | The port where the Postgres server is running | `None` |
|
| `port` | The port where the Postgres server is running | `None` |
|
||||||
|
| `diskann` | Whether to use diskann for vector similarity search (requires pgvectorscale) | `True` |
|
||||||
@@ -14,6 +14,7 @@ class PGVectorConfig(BaseModel):
|
|||||||
password: Optional[str] = Field(None, description="Database password")
|
password: Optional[str] = Field(None, description="Database password")
|
||||||
host: Optional[str] = Field(None, description="Database host. Default is localhost")
|
host: Optional[str] = Field(None, description="Database host. Default is localhost")
|
||||||
port: Optional[int] = Field(None, description="Database port. Default is 1536")
|
port: Optional[int] = Field(None, description="Database port. Default is 1536")
|
||||||
|
diskann: Optional[bool] = Field(True, description="Use diskann for approximate nearest neighbors search")
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
def check_auth_and_connection(cls, values):
|
def check_auth_and_connection(cls, values):
|
||||||
|
|||||||
@@ -39,7 +39,8 @@ class VectorStoreConfig(BaseModel):
|
|||||||
raise ValueError(f"Invalid config type for provider {provider}")
|
raise ValueError(f"Invalid config type for provider {provider}")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
if "path" not in config:
|
# also check if path is in the allowed keys for the pydantic model, and whether config extra fields are allowed
|
||||||
|
if "path" not in config and "path" in config_class.__annotations__:
|
||||||
config["path"] = f"/tmp/{provider}"
|
config["path"] = f"/tmp/{provider}"
|
||||||
|
|
||||||
self.config = config_class(**config)
|
self.config = config_class(**config)
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ class OutputData(BaseModel):
|
|||||||
|
|
||||||
class PGVector(VectorStoreBase):
|
class PGVector(VectorStoreBase):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, dbname, collection_name, embedding_model_dims, user, password, host, port
|
self, dbname, collection_name, embedding_model_dims, user, password, host, port, diskann
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the PGVector database.
|
Initialize the PGVector database.
|
||||||
@@ -35,8 +35,10 @@ class PGVector(VectorStoreBase):
|
|||||||
password (str): Database password
|
password (str): Database password
|
||||||
host (str, optional): Database host
|
host (str, optional): Database host
|
||||||
port (int, optional): Database port
|
port (int, optional): Database port
|
||||||
|
diskann (bool, optional): Use DiskANN for faster search
|
||||||
"""
|
"""
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
|
self.use_diskann = diskann
|
||||||
|
|
||||||
self.conn = psycopg2.connect(
|
self.conn = psycopg2.connect(
|
||||||
dbname=dbname, user=user, password=password, host=host, port=port
|
dbname=dbname, user=user, password=password, host=host, port=port
|
||||||
@@ -50,6 +52,7 @@ class PGVector(VectorStoreBase):
|
|||||||
def create_col(self, embedding_model_dims):
|
def create_col(self, embedding_model_dims):
|
||||||
"""
|
"""
|
||||||
Create a new collection (table in PostgreSQL).
|
Create a new collection (table in PostgreSQL).
|
||||||
|
Also creates a DiskANN index when diskann is enabled, the embedding dimension is below 2000, and the vectorscale extension is installed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name (str): Name of the collection.
|
name (str): Name of the collection.
|
||||||
@@ -64,6 +67,19 @@ class PGVector(VectorStoreBase):
|
|||||||
);
|
);
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.use_diskann and embedding_model_dims < 2000:
|
||||||
|
# Check if vectorscale extension is installed
|
||||||
|
self.cur.execute("SELECT * FROM pg_extension WHERE extname = 'vectorscale'")
|
||||||
|
if self.cur.fetchone():
|
||||||
|
# Create DiskANN index if extension is installed for faster search
|
||||||
|
self.cur.execute(f"""
|
||||||
|
CREATE INDEX IF NOT EXISTS {self.collection_name}_vector_idx
|
||||||
|
ON {self.collection_name}
|
||||||
|
USING diskann (vector);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
|
||||||
def insert(self, vectors, payloads=None, ids=None):
|
def insert(self, vectors, payloads=None, ids=None):
|
||||||
@@ -114,13 +130,13 @@ class PGVector(VectorStoreBase):
|
|||||||
|
|
||||||
self.cur.execute(
|
self.cur.execute(
|
||||||
f"""
|
f"""
|
||||||
SELECT id, vector <-> %s::vector AS distance, payload
|
SELECT id, vector <=> %s::vector AS distance, payload
|
||||||
FROM {self.collection_name}
|
FROM {self.collection_name}
|
||||||
{filter_clause}
|
{filter_clause}
|
||||||
ORDER BY distance
|
ORDER BY distance
|
||||||
LIMIT %s
|
LIMIT %s
|
||||||
""",
|
""",
|
||||||
(query, *filter_params, limit),
|
(query, *filter_params, limit),
|
||||||
)
|
)
|
||||||
|
|
||||||
results = self.cur.fetchall()
|
results = self.cur.fetchall()
|
||||||
|
|||||||
Reference in New Issue
Block a user