Distance metric change and PGVectorScale support (#1703)

This commit is contained in:
Tibor Sloboda
2024-08-27 13:26:01 +02:00
committed by GitHub
parent e8004537c1
commit a94bd11a76
4 changed files with 24 additions and 5 deletions

View File

@@ -36,4 +36,5 @@ Here's the parameters available for configuring pgvector:
| `user` | User name to connect to the database | `None` | | `user` | User name to connect to the database | `None` |
| `password` | Password to connect to the database | `None` | | `password` | Password to connect to the database | `None` |
| `host` | The host where the Postgres server is running | `None` | | `host` | The host where the Postgres server is running | `None` |
| `port` | The port where the Postgres server is running | `None` | | `port` | The port where the Postgres server is running | `None` |
| `diskann` | Whether to use diskann for vector similarity search (requires pgvectorscale) | `True` |

View File

@@ -14,6 +14,7 @@ class PGVectorConfig(BaseModel):
password: Optional[str] = Field(None, description="Database password") password: Optional[str] = Field(None, description="Database password")
host: Optional[str] = Field(None, description="Database host. Default is localhost") host: Optional[str] = Field(None, description="Database host. Default is localhost")
port: Optional[int] = Field(None, description="Database port. Default is 1536") port: Optional[int] = Field(None, description="Database port. Default is 1536")
diskann: Optional[bool] = Field(True, description="Use diskann for approximate nearest neighbors search")
@model_validator(mode="before") @model_validator(mode="before")
def check_auth_and_connection(cls, values): def check_auth_and_connection(cls, values):

View File

@@ -39,7 +39,8 @@ class VectorStoreConfig(BaseModel):
raise ValueError(f"Invalid config type for provider {provider}") raise ValueError(f"Invalid config type for provider {provider}")
return self return self
if "path" not in config: # also check if path in allowed keys for pydantic model, and whether config extra fields are allowed
if "path" not in config and "path" in config_class.__annotations__:
config["path"] = f"/tmp/{provider}" config["path"] = f"/tmp/{provider}"
self.config = config_class(**config) self.config = config_class(**config)

View File

@@ -22,7 +22,7 @@ class OutputData(BaseModel):
class PGVector(VectorStoreBase): class PGVector(VectorStoreBase):
def __init__( def __init__(
self, dbname, collection_name, embedding_model_dims, user, password, host, port self, dbname, collection_name, embedding_model_dims, user, password, host, port, diskann
): ):
""" """
Initialize the PGVector database. Initialize the PGVector database.
@@ -35,8 +35,10 @@ class PGVector(VectorStoreBase):
password (str): Database password password (str): Database password
host (str, optional): Database host host (str, optional): Database host
port (int, optional): Database port port (int, optional): Database port
diskann (bool, optional): Use DiskANN for faster search
""" """
self.collection_name = collection_name self.collection_name = collection_name
self.use_diskann = diskann
self.conn = psycopg2.connect( self.conn = psycopg2.connect(
dbname=dbname, user=user, password=password, host=host, port=port dbname=dbname, user=user, password=password, host=host, port=port
@@ -50,6 +52,7 @@ class PGVector(VectorStoreBase):
def create_col(self, embedding_model_dims): def create_col(self, embedding_model_dims):
""" """
Create a new collection (table in PostgreSQL). Create a new collection (table in PostgreSQL).
Will also initialize DiskANN index if the extension is installed.
Args: Args:
name (str): Name of the collection. name (str): Name of the collection.
@@ -64,6 +67,19 @@ class PGVector(VectorStoreBase):
); );
""" """
) )
if self.use_diskann and embedding_model_dims < 2000:
# Check if vectorscale extension is installed
self.cur.execute("SELECT * FROM pg_extension WHERE extname = 'vectorscale'")
if self.cur.fetchone():
# Create DiskANN index if extension is installed for faster search
self.cur.execute(f"""
CREATE INDEX IF NOT EXISTS {self.collection_name}_vector_idx
ON {self.collection_name}
USING diskann (vector);
"""
)
self.conn.commit() self.conn.commit()
def insert(self, vectors, payloads=None, ids=None): def insert(self, vectors, payloads=None, ids=None):
@@ -114,13 +130,13 @@ class PGVector(VectorStoreBase):
self.cur.execute( self.cur.execute(
f""" f"""
SELECT id, vector <-> %s::vector AS distance, payload SELECT id, vector <=> %s::vector AS distance, payload
FROM {self.collection_name} FROM {self.collection_name}
{filter_clause} {filter_clause}
ORDER BY distance ORDER BY distance
LIMIT %s LIMIT %s
""", """,
(query, *filter_params, limit), (query, *filter_params, limit),
) )
results = self.cur.fetchall() results = self.cur.fetchall()