Add support for configurable embedding model (#1627)

Author: Mitul Kataria
Co-authored-by: Dev Khant <devkhant24@gmail.com>
Date: 2024-08-12 18:39:18 +09:00
Committed by: GitHub
Parent: 4aae2b5cca
Commit: 464a188662

8 changed files with 88 additions and 23 deletions

View File

@@ -10,9 +10,10 @@ class BaseEmbedderConfig(ABC):
         self,
         model: Optional[str] = None,
         embedding_dims: Optional[int] = None,
         # Ollama specific
-        base_url: Optional[str] = None
+        base_url: Optional[str] = None,
+        # Huggingface specific
+        model_kwargs: Optional[dict] = None
     ):
         """
         Initializes a configuration class instance for the Embeddings.

@@ -23,10 +24,16 @@ class BaseEmbedderConfig(ABC):
         :type embedding_dims: Optional[int], optional
         :param base_url: Base URL for the Ollama API, defaults to None
         :type base_url: Optional[str], optional
+        :param model_kwargs: key-value arguments for the huggingface embedding model, defaults a dict inside init
+        :type model_kwargs: Optional[Dict[str, Any]], defaults a dict inside init
         """
         self.model = model
         self.embedding_dims = embedding_dims
         # Ollama specific
         self.base_url = base_url
+        # Huggingface specific
+        self.model_kwargs = model_kwargs or {}
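
For reference, a minimal sketch of how the extended config can be constructed; the values are illustrative, and model_kwargs only matters for the Hugging Face backend (it falls back to an empty dict via the "or {}" guard above):

from mem0.configs.embeddings.base import BaseEmbedderConfig

# OpenAI-style config: model_kwargs is left empty by the constructor.
openai_cfg = BaseEmbedderConfig(model="text-embedding-3-small", embedding_dims=1536)

# Hugging Face config: model_kwargs is forwarded verbatim to SentenceTransformer.
hf_cfg = BaseEmbedderConfig(
    model="multi-qa-MiniLM-L6-cos-v1",
    model_kwargs={"device": "cpu"},
)

print(openai_cfg.model_kwargs)  # {}
print(hf_cfg.model_kwargs)      # {'device': 'cpu'}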

View File

@@ -0,0 +1,36 @@
+from typing import Optional
+
+from openai import AzureOpenAI
+
+from mem0.configs.embeddings.base import BaseEmbedderConfig
+from mem0.embeddings.base import EmbeddingBase
+
+
+class AzureOpenAIEmbedding(EmbeddingBase):
+    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
+        super().__init__(config)
+        if self.config.model is None:
+            self.config.model = "text-embedding-3-small"
+        if self.config.embedding_dims is None:
+            self.config.embedding_dims = 1536
+        self.client = AzureOpenAI()
+
+    def embed(self, text):
+        """
+        Get the embedding for the given text using OpenAI.
+
+        Args:
+            text (str): The text to embed.
+
+        Returns:
+            list: The embedding vector.
+        """
+        text = text.replace("\n", " ")
+        return (
+            self.client.embeddings.create(
+                input=[text],
+                model=self.config.model
+            )
+            .data[0]
+            .embedding
+        )
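
The new embedder creates AzureOpenAI() with no arguments, so credentials come from the environment; a rough usage sketch (the endpoint, key, and API version below are placeholders, and the variable names follow the openai SDK's defaults):

import os

from mem0.embeddings.azure_openai import AzureOpenAIEmbedding

# AzureOpenAI() reads its settings from these environment variables.
os.environ["AZURE_OPENAI_API_KEY"] = "<your-key>"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<your-resource>.openai.azure.com/"
os.environ["OPENAI_API_VERSION"] = "2024-02-01"

embedder = AzureOpenAIEmbedding()  # defaults to text-embedding-3-small / 1536 dims
vector = embedder.embed("memory is all you need")
print(len(vector))

Note that on Azure the model value is treated as the deployment name, so the default "text-embedding-3-small" only works if a deployment with that exact name exists.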

View File

@@ -3,19 +3,18 @@ from abc import ABC, abstractmethod

 from mem0.configs.embeddings.base import BaseEmbedderConfig


 class EmbeddingBase(ABC):
-    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
-        """Initialize a base LLM class
+    """Initialized a base embedding class

-        :param config: Embedder configuration option class, defaults to None
-        :type config: Optional[BaseEmbedderConfig], optional
-        """
+    :param config: Embedding configuration option class, defaults to None
+    :type config: Optional[BaseEmbedderConfig], optional
+    """

+    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         if config is None:
             self.config = BaseEmbedderConfig()
         else:
             self.config = config

     @abstractmethod
     def embed(self, text):
         """

View File

@@ -9,13 +9,14 @@ class EmbedderConfig(BaseModel):
         default="openai",
     )
     config: Optional[dict] = Field(
-        description="Configuration for the specific embedding model", default=None
+        description="Configuration for the specific embedding model",
+        default={}
     )

     @field_validator("config")
     def validate_config(cls, v, values):
         provider = values.data.get("provider")
-        if provider in ["openai", "ollama"]:
+        if provider in ["openai", "ollama", "huggingface", "azure_openai"]:
             return v
         else:
             raise ValueError(f"Unsupported embedding provider: {provider}")
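
A quick sketch of how the widened validator behaves; the import path of EmbedderConfig is an assumption here, and the failing case passes config explicitly so the validator actually runs:

from pydantic import ValidationError

from mem0.embeddings.configs import EmbedderConfig  # assumed module path

# Newly accepted providers pass validation.
EmbedderConfig(provider="huggingface", config={"model": "multi-qa-MiniLM-L6-cos-v1"})
EmbedderConfig(provider="azure_openai")  # config now defaults to {}

# Anything outside the supported list is rejected.
try:
    EmbedderConfig(provider="cohere", config={})
except ValidationError as exc:
    print(exc)  # Unsupported embedding provider: cohere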

View File

@@ -1,11 +1,27 @@
-from mem0.embeddings.base import EmbeddingBase
+from typing import Optional
+
 from sentence_transformers import SentenceTransformer

+from mem0.configs.embeddings.base import BaseEmbedderConfig
+from mem0.embeddings.base import EmbeddingBase
+

 class HuggingFaceEmbedding(EmbeddingBase):
-    def __init__(self, model_name="multi-qa-MiniLM-L6-cos-v1"):
-        self.model = SentenceTransformer(model_name)
+    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
+        super().__init__(config)
+
+        if self.config.model is None:
+            self.config.model = "multi-qa-MiniLM-L6-cos-v1"
+
+        self.model = SentenceTransformer(
+            self.config.model,
+            **self.config.model_kwargs
+        )
+
+        if self.config.embedding_dims is None:
+            self.config.embedding_dims = self.model.get_sentence_embedding_dimension()

     def embed(self, text):
         """
         Get the embedding for the given text using Hugging Face.
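
With the config-driven constructor, both the checkpoint name and any SentenceTransformer keyword arguments come from BaseEmbedderConfig; a sketch, where the device kwarg is just one example of what model_kwargs can carry:

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.huggingface import HuggingFaceEmbedding

config = BaseEmbedderConfig(
    model="multi-qa-MiniLM-L6-cos-v1",
    model_kwargs={"device": "cpu"},  # forwarded verbatim to SentenceTransformer(...)
)
embedder = HuggingFaceEmbedding(config)

# embedding_dims was left unset, so it is read off the loaded model
# (384 for this MiniLM checkpoint).
print(embedder.config.embedding_dims)
print(len(embedder.embed("memory is all you need")))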

View File

@@ -1,10 +1,10 @@
 from typing import Optional

 from openai import OpenAI

 from mem0.configs.embeddings.base import BaseEmbedderConfig
 from mem0.embeddings.base import EmbeddingBase


 class OpenAIEmbedding(EmbeddingBase):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         super().__init__(config)

@@ -28,7 +28,10 @@ class OpenAIEmbedding(EmbeddingBase):
         """
         text = text.replace("\n", " ")
         return (
-            self.client.embeddings.create(input=[text], model=self.config.model)
+            self.client.embeddings.create(
+                input=[text],
+                model=self.config.model
+            )
             .data[0]
             .embedding
         )

View File

@@ -28,7 +28,7 @@ setup_config()
 class Memory(MemoryBase):
     def __init__(self, config: MemoryConfig = MemoryConfig()):
         self.config = config
-        self.embedding_model = EmbedderFactory.create(self.config.embedder.provider)
+        self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config)
         self.vector_store = VectorStoreFactory.create(self.config.vector_store.provider, self.config.vector_store.config)
         self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
         self.db = SQLiteManager(self.config.history_db_path)
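
End to end, the embedder section of the memory config now flows through to the factory; a hedged sketch, assuming the MemoryConfig and EmbedderConfig import paths below (adjust to the actual modules):

from mem0 import Memory
from mem0.configs.base import MemoryConfig          # assumed path
from mem0.embeddings.configs import EmbedderConfig  # assumed path

config = MemoryConfig(
    embedder=EmbedderConfig(
        provider="huggingface",
        config={"model": "multi-qa-MiniLM-L6-cos-v1", "model_kwargs": {"device": "cpu"}},
    )
)

# Memory now hands both the provider and the config dict to EmbedderFactory.create(...).
m = Memory(config)
m.add("I prefer tea over coffee", user_id="alice")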

View File

@@ -1,7 +1,7 @@
 import importlib

 from mem0.configs.llms.base import BaseLlmConfig
+from mem0.configs.embeddings.base import BaseEmbedderConfig


 def load_class(class_type):
     module_path, class_name = class_type.rsplit(".", 1)

@@ -33,15 +33,18 @@ class LlmFactory:

 class EmbedderFactory:
     provider_to_class = {
         "openai": "mem0.embeddings.openai.OpenAIEmbedding",
-        "ollama": "mem0.embeddings.ollama.OllamaEmbedding"
+        "ollama": "mem0.embeddings.ollama.OllamaEmbedding",
+        "huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding",
+        "azure_openai": "mem0.embeddings.azure_openai.AzureOpenAIEmbedding",
     }

     @classmethod
-    def create(cls, provider_name):
+    def create(cls, provider_name, config):
         class_type = cls.provider_to_class.get(provider_name)
         if class_type:
-            embedder_instance = load_class(class_type)()
-            return embedder_instance
+            embedder_instance = load_class(class_type)
+            base_config = BaseEmbedderConfig(**config)
+            return embedder_instance(base_config)
         else:
             raise ValueError(f"Unsupported Embedder provider: {provider_name}")
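
The factory now receives the raw config dict and builds the BaseEmbedderConfig itself; for example (the mem0.utils.factory import path is assumed from the repo layout, the rest follows the signature above):

from mem0.utils.factory import EmbedderFactory

# The second argument is unpacked into BaseEmbedderConfig(**config).
embedder = EmbedderFactory.create(
    "huggingface",
    {"model": "multi-qa-MiniLM-L6-cos-v1", "model_kwargs": {"device": "cpu"}},
)
print(type(embedder).__name__)  # HuggingFaceEmbedding

# Unknown providers still raise.
try:
    EmbedderFactory.create("cohere", {})
except ValueError as err:
    print(err)  # Unsupported Embedder provider: cohere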