Add embeding_dims param to FAISS (#2513)
This commit is contained in:
@@ -12,6 +12,7 @@ class FAISSConfig(BaseModel):
|
|||||||
normalize_L2: bool = Field(
|
normalize_L2: bool = Field(
|
||||||
False, description="Whether to normalize L2 vectors (only applicable for euclidean distance)"
|
False, description="Whether to normalize L2 vectors (only applicable for euclidean distance)"
|
||||||
)
|
)
|
||||||
|
embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ class FAISS(VectorStoreBase):
|
|||||||
path: Optional[str] = None,
|
path: Optional[str] = None,
|
||||||
distance_strategy: str = "euclidean",
|
distance_strategy: str = "euclidean",
|
||||||
normalize_L2: bool = False,
|
normalize_L2: bool = False,
|
||||||
|
embedding_model_dims: int = 1536,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the FAISS vector store.
|
Initialize the FAISS vector store.
|
||||||
@@ -51,6 +52,7 @@ class FAISS(VectorStoreBase):
|
|||||||
self.path = path or f"/tmp/faiss/{collection_name}"
|
self.path = path or f"/tmp/faiss/{collection_name}"
|
||||||
self.distance_strategy = distance_strategy
|
self.distance_strategy = distance_strategy
|
||||||
self.normalize_L2 = normalize_L2
|
self.normalize_L2 = normalize_L2
|
||||||
|
self.embedding_model_dims = embedding_model_dims
|
||||||
|
|
||||||
# Initialize storage structures
|
# Initialize storage structures
|
||||||
self.index = None
|
self.index = None
|
||||||
@@ -145,13 +147,12 @@ class FAISS(VectorStoreBase):
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def create_col(self, name: str, vector_size: int = 1536, distance: str = None):
|
def create_col(self, name: str, distance: str = None):
|
||||||
"""
|
"""
|
||||||
Create a new collection.
|
Create a new collection.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name (str): Name of the collection.
|
name (str): Name of the collection.
|
||||||
vector_size (int, optional): Dimensionality of vectors. Defaults to 1536.
|
|
||||||
distance (str, optional): Distance metric to use. Overrides the distance_strategy
|
distance (str, optional): Distance metric to use. Overrides the distance_strategy
|
||||||
passed during initialization. Defaults to None.
|
passed during initialization. Defaults to None.
|
||||||
|
|
||||||
@@ -162,9 +163,9 @@ class FAISS(VectorStoreBase):
|
|||||||
|
|
||||||
# Create index based on distance strategy
|
# Create index based on distance strategy
|
||||||
if distance_strategy.lower() == "inner_product" or distance_strategy.lower() == "cosine":
|
if distance_strategy.lower() == "inner_product" or distance_strategy.lower() == "cosine":
|
||||||
self.index = faiss.IndexFlatIP(vector_size)
|
self.index = faiss.IndexFlatIP(self.embedding_model_dims)
|
||||||
else:
|
else:
|
||||||
self.index = faiss.IndexFlatL2(vector_size)
|
self.index = faiss.IndexFlatL2(self.embedding_model_dims)
|
||||||
|
|
||||||
self.collection_name = name
|
self.collection_name = name
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "mem0ai"
|
name = "mem0ai"
|
||||||
version = "0.1.85"
|
version = "0.1.86"
|
||||||
description = "Long-term memory for AI Agents"
|
description = "Long-term memory for AI Agents"
|
||||||
authors = ["Mem0 <founders@mem0.ai>"]
|
authors = ["Mem0 <founders@mem0.ai>"]
|
||||||
exclude = [
|
exclude = [
|
||||||
|
|||||||
Reference in New Issue
Block a user