diff --git a/mem0/configs/llms/base.py b/mem0/configs/llms/base.py
index ec288605..02974402 100644
--- a/mem0/configs/llms/base.py
+++ b/mem0/configs/llms/base.py
@@ -12,6 +12,8 @@ class BaseLlmConfig(ABC):
         temperature: float = 0,
         max_tokens: int = 3000,
         top_p: float = 1,
+
+        # Ollama specific
         base_url: Optional[str] = None
     ):
         """
@@ -35,4 +37,6 @@ class BaseLlmConfig(ABC):
         self.temperature = temperature
         self.max_tokens = max_tokens
         self.top_p = top_p
-        self.base_url = base_url
\ No newline at end of file
+
+        # Ollama specific
+        self.base_url = base_url
diff --git a/mem0/embeddings/base.py b/mem0/embeddings/base.py
index 388ddf6a..8693c4e1 100644
--- a/mem0/embeddings/base.py
+++ b/mem0/embeddings/base.py
@@ -1,7 +1,21 @@
+from typing import Optional
 from abc import ABC, abstractmethod
 
+from mem0.configs.embeddings.base import BaseEmbedderConfig
+
 
 class EmbeddingBase(ABC):
+    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
+        """Initialize a base embedding class
+
+        :param config: Embedder configuration option class, defaults to None
+        :type config: Optional[BaseEmbedderConfig], optional
+        """
+        if config is None:
+            self.config = BaseEmbedderConfig()
+        else:
+            self.config = config
+
     @abstractmethod
     def embed(self, text):
         """
diff --git a/mem0/embeddings/ollama.py b/mem0/embeddings/ollama.py
index 00289d4e..904d8506 100644
--- a/mem0/embeddings/ollama.py
+++ b/mem0/embeddings/ollama.py
@@ -1,20 +1,33 @@
-import ollama
+from typing import Optional
+
+from mem0.configs.embeddings.base import BaseEmbedderConfig
 from mem0.embeddings.base import EmbeddingBase
 
+try:
+    from ollama import Client
+except ImportError:
+    raise ImportError("Ollama requires extra dependencies. Install with `pip install ollama`") from None
+
 
 class OllamaEmbedding(EmbeddingBase):
-    def __init__(self, model="nomic-embed-text"):
-        self.model = model
+    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model = "nomic-embed-text"
+        if not self.config.embedding_dims:
+            self.config.embedding_dims = 512
+
+        self.client = Client(host=self.config.base_url)
         self._ensure_model_exists()
-        self.dims = 512
 
     def _ensure_model_exists(self):
         """
         Ensure the specified model exists locally. If not, pull it from Ollama.
         """
-        model_list = [m["name"] for m in ollama.list()["models"]]
-        if not any(m.startswith(self.model) for m in model_list):
-            ollama.pull(self.model)
+        local_models = self.client.list()["models"]
+        if not any(m.get("name", "").startswith(self.config.model) for m in local_models):
+            self.client.pull(self.config.model)
 
     def embed(self, text):
         """
@@ -26,5 +39,5 @@ class OllamaEmbedding(EmbeddingBase):
         Returns:
             list: The embedding vector.
         """
-        response = ollama.embeddings(model=self.model, prompt=text)
+        response = self.client.embeddings(model=self.config.model, prompt=text)
         return response["embedding"]
diff --git a/mem0/embeddings/openai.py b/mem0/embeddings/openai.py
index 61968397..c10c00e1 100644
--- a/mem0/embeddings/openai.py
+++ b/mem0/embeddings/openai.py
@@ -1,13 +1,20 @@
+from typing import Optional
 from openai import OpenAI
 
+from mem0.configs.embeddings.base import BaseEmbedderConfig
 from mem0.embeddings.base import EmbeddingBase
 
 
 class OpenAIEmbedding(EmbeddingBase):
-    def __init__(self, model="text-embedding-3-small"):
+    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model = "text-embedding-3-small"
+        if not self.config.embedding_dims:
+            self.config.embedding_dims = 1536
+
         self.client = OpenAI()
-        self.model = model
-        self.dims = 1536
 
     def embed(self, text):
         """
@@ -21,7 +28,7 @@ class OpenAIEmbedding(EmbeddingBase):
         """
         text = text.replace("\n", " ")
         return (
-            self.client.embeddings.create(input=[text], model=self.model)
+            self.client.embeddings.create(input=[text], model=self.config.model)
             .data[0]
             .embedding
         )
diff --git a/mem0/memory/telemetry.py b/mem0/memory/telemetry.py
index b894e257..f0fba296 100644
--- a/mem0/memory/telemetry.py
+++ b/mem0/memory/telemetry.py
@@ -48,7 +48,7 @@ telemetry = AnonymousTelemetry(
 def capture_event(event_name, memory_instance, additional_data=None):
     event_data = {
         "collection": memory_instance.collection_name,
-        "vector_size": memory_instance.embedding_model.dims,
+        "vector_size": memory_instance.embedding_model.config.embedding_dims,
         "history_store": "sqlite",
         "vector_store": f"{memory_instance.vector_store.__class__.__module__}.{memory_instance.vector_store.__class__.__name__}",
         "llm": f"{memory_instance.llm.__class__.__module__}.{memory_instance.llm.__class__.__name__}",
diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py
index 16076e38..9c72efd6 100644
--- a/mem0/utils/factory.py
+++ b/mem0/utils/factory.py
@@ -33,8 +33,7 @@ class LlmFactory:
 class EmbedderFactory:
     provider_to_class = {
         "openai": "mem0.embeddings.openai.OpenAIEmbedding",
-        "ollama": "mem0.embeddings.ollama.OllamaEmbedding",
-        "huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding"
+        "ollama": "mem0.embeddings.ollama.OllamaEmbedding"
     }
 
     @classmethod
diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py
index c019f3d0..abfb7e36 100644
--- a/mem0/vector_stores/configs.py
+++ b/mem0/vector_stores/configs.py
@@ -73,6 +73,10 @@ class VectorStoreConfig(BaseModel):
         if isinstance(v, dict):
             if provider == "qdrant":
+                if "path" not in v:
+                    v["path"] = "/tmp/qdrant"
                 return QdrantConfig(**v)
             elif provider == "chromadb":
+                if "path" not in v:
+                    v["path"] = "/tmp/chromadb"
                 return ChromaDbConfig(**v)
         return v