Add embeding_dims param to FAISS (#2513)

This commit is contained in:
Dev Khant
2025-04-07 23:29:55 +05:30
committed by GitHub
parent 2a79add7a5
commit cdb8dcdb9e
3 changed files with 7 additions and 5 deletions

View File

@@ -12,6 +12,7 @@ class FAISSConfig(BaseModel):
normalize_L2: bool = Field(
False, description="Whether to normalize L2 vectors (only applicable for euclidean distance)"
)
embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
@model_validator(mode="before")
@classmethod

View File

@@ -35,6 +35,7 @@ class FAISS(VectorStoreBase):
path: Optional[str] = None,
distance_strategy: str = "euclidean",
normalize_L2: bool = False,
embedding_model_dims: int = 1536,
):
"""
Initialize the FAISS vector store.
@@ -51,6 +52,7 @@ class FAISS(VectorStoreBase):
self.path = path or f"/tmp/faiss/{collection_name}"
self.distance_strategy = distance_strategy
self.normalize_L2 = normalize_L2
self.embedding_model_dims = embedding_model_dims
# Initialize storage structures
self.index = None
@@ -145,13 +147,12 @@ class FAISS(VectorStoreBase):
return results
def create_col(self, name: str, vector_size: int = 1536, distance: str = None):
def create_col(self, name: str, distance: str = None):
"""
Create a new collection.
Args:
name (str): Name of the collection.
vector_size (int, optional): Dimensionality of vectors. Defaults to 1536.
distance (str, optional): Distance metric to use. Overrides the distance_strategy
passed during initialization. Defaults to None.
@@ -162,9 +163,9 @@ class FAISS(VectorStoreBase):
# Create index based on distance strategy
if distance_strategy.lower() == "inner_product" or distance_strategy.lower() == "cosine":
self.index = faiss.IndexFlatIP(vector_size)
self.index = faiss.IndexFlatIP(self.embedding_model_dims)
else:
self.index = faiss.IndexFlatL2(vector_size)
self.index = faiss.IndexFlatL2(self.embedding_model_dims)
self.collection_name = name