Add config option for vertex embedding tasks (#2266)

This commit is contained in:
Wonbin Kim
2025-02-28 18:50:05 +09:00
committed by GitHub
parent 8143f86be6
commit 6acb00731d
14 changed files with 141 additions and 48 deletions

View File

@@ -57,7 +57,10 @@ Here's a comprehensive list of all parameters that can be used across different
| `model_kwargs` | Key-Value arguments for the Huggingface embedding model | | `model_kwargs` | Key-Value arguments for the Huggingface embedding model |
| `azure_kwargs` | Key-Value arguments for the AzureOpenAI embedding model | | `azure_kwargs` | Key-Value arguments for the AzureOpenAI embedding model |
| `openai_base_url` | Base URL for OpenAI API | OpenAI | | `openai_base_url` | Base URL for OpenAI API | OpenAI |
| `vertex_credentials_json` | Path to the Google Cloud credentials JSON file for VertexAI | | `vertex_credentials_json` | Path to the Google Cloud credentials JSON file for VertexAI | VertexAI |
| `memory_add_embedding_type` | The type of embedding to use for the add memory action | VertexAI |
| `memory_update_embedding_type` | The type of embedding to use for the update memory action | VertexAI |
| `memory_search_embedding_type` | The type of embedding to use for the search memory action | VertexAI |
## Supported Embedding Models ## Supported Embedding Models

View File

@@ -16,7 +16,10 @@ config = {
"embedder": { "embedder": {
"provider": "vertexai", "provider": "vertexai",
"config": { "config": {
"model": "text-embedding-004" "model": "text-embedding-004",
"memory_add_embedding_type": "RETRIEVAL_DOCUMENT",
"memory_update_embedding_type": "RETRIEVAL_DOCUMENT",
"memory_search_embedding_type": "RETRIEVAL_QUERY"
} }
} }
} }
@@ -24,6 +27,13 @@ config = {
m = Memory.from_config(config) m = Memory.from_config(config)
m.add("I'm visiting Paris", user_id="john") m.add("I'm visiting Paris", user_id="john")
``` ```
The embedding types can be one of the following:
- SEMANTIC_SIMILARITY
- CLASSIFICATION
- CLUSTERING
- RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, QUESTION_ANSWERING, FACT_VERIFICATION
- CODE_RETRIEVAL_QUERY
Check out the [Vertex AI documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types#supported_task_types) for more information.
### Config ### Config
@@ -34,3 +44,6 @@ Here are the parameters available for configuring the Vertex AI embedder:
| `model` | The name of the Vertex AI embedding model to use | `text-embedding-004` | | `model` | The name of the Vertex AI embedding model to use | `text-embedding-004` |
| `vertex_credentials_json` | Path to the Google Cloud credentials JSON file | `None` | | `vertex_credentials_json` | Path to the Google Cloud credentials JSON file | `None` |
| `embedding_dims` | Dimensions of the embedding model | `256` | | `embedding_dims` | Dimensions of the embedding model | `256` |
| `memory_add_embedding_type` | The type of embedding to use for the add memory action | `RETRIEVAL_DOCUMENT` |
| `memory_update_embedding_type` | The type of embedding to use for the update memory action | `RETRIEVAL_DOCUMENT` |
| `memory_search_embedding_type` | The type of embedding to use for the search memory action | `RETRIEVAL_QUERY` |

View File

@@ -27,6 +27,9 @@ class BaseEmbedderConfig(ABC):
http_client_proxies: Optional[Union[Dict, str]] = None, http_client_proxies: Optional[Union[Dict, str]] = None,
# VertexAI specific # VertexAI specific
vertex_credentials_json: Optional[str] = None, vertex_credentials_json: Optional[str] = None,
memory_add_embedding_type: Optional[str] = None,
memory_update_embedding_type: Optional[str] = None,
memory_search_embedding_type: Optional[str] = None,
): ):
""" """
Initializes a configuration class instance for the Embeddings. Initializes a configuration class instance for the Embeddings.
@@ -47,6 +50,14 @@ class BaseEmbedderConfig(ABC):
:type azure_kwargs: Optional[Dict[str, Any]], defaults a dict inside init :type azure_kwargs: Optional[Dict[str, Any]], defaults a dict inside init
:param http_client_proxies: The proxy server settings used to create self.http_client, defaults to None :param http_client_proxies: The proxy server settings used to create self.http_client, defaults to None
:type http_client_proxies: Optional[Dict | str], optional :type http_client_proxies: Optional[Dict | str], optional
:param vertex_credentials_json: The path to the Vertex AI credentials JSON file, defaults to None
:type vertex_credentials_json: Optional[str], optional
:param memory_add_embedding_type: The type of embedding to use for the add memory action, defaults to None
:type memory_add_embedding_type: Optional[str], optional
:param memory_update_embedding_type: The type of embedding to use for the update memory action, defaults to None
:type memory_update_embedding_type: Optional[str], optional
:param memory_search_embedding_type: The type of embedding to use for the search memory action, defaults to None
:type memory_search_embedding_type: Optional[str], optional
""" """
self.model = model self.model = model
@@ -68,3 +79,6 @@ class BaseEmbedderConfig(ABC):
# VertexAI specific # VertexAI specific
self.vertex_credentials_json = vertex_credentials_json self.vertex_credentials_json = vertex_credentials_json
self.memory_add_embedding_type = memory_add_embedding_type
self.memory_update_embedding_type = memory_update_embedding_type
self.memory_search_embedding_type = memory_search_embedding_type

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Optional from typing import Literal, Optional
from openai import AzureOpenAI from openai import AzureOpenAI
@@ -26,13 +26,13 @@ class AzureOpenAIEmbedding(EmbeddingBase):
default_headers=default_headers, default_headers=default_headers,
) )
def embed(self, text): def embed(self, text, memory_action:Optional[Literal["add", "search", "update"]] = None):
""" """
Get the embedding for the given text using OpenAI. Get the embedding for the given text using OpenAI.
Args: Args:
text (str): The text to embed. text (str): The text to embed.
memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
Returns: Returns:
list: The embedding vector. list: The embedding vector.
""" """

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Literal, Optional
from mem0.configs.embeddings.base import BaseEmbedderConfig from mem0.configs.embeddings.base import BaseEmbedderConfig
@@ -18,13 +18,13 @@ class EmbeddingBase(ABC):
self.config = config self.config = config
@abstractmethod @abstractmethod
def embed(self, text): def embed(self, text, memory_action:Optional[Literal["add", "search", "update"]]):
""" """
Get the embedding for the given text. Get the embedding for the given text.
Args: Args:
text (str): The text to embed. text (str): The text to embed.
memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
Returns: Returns:
list: The embedding vector. list: The embedding vector.
""" """

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Optional from typing import Literal, Optional
import google.generativeai as genai import google.generativeai as genai
@@ -18,11 +18,12 @@ class GoogleGenAIEmbedding(EmbeddingBase):
genai.configure(api_key=api_key) genai.configure(api_key=api_key)
def embed(self, text): def embed(self, text, memory_action:Optional[Literal["add", "search", "update"]] = None):
""" """
Get the embedding for the given text using Google Generative AI. Get the embedding for the given text using Google Generative AI.
Args: Args:
text (str): The text to embed. text (str): The text to embed.
memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
Returns: Returns:
list: The embedding vector. list: The embedding vector.
""" """

View File

@@ -1,4 +1,4 @@
from typing import Optional from typing import Literal, Optional
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
@@ -16,13 +16,13 @@ class HuggingFaceEmbedding(EmbeddingBase):
self.config.embedding_dims = self.config.embedding_dims or self.model.get_sentence_embedding_dimension() self.config.embedding_dims = self.config.embedding_dims or self.model.get_sentence_embedding_dimension()
def embed(self, text): def embed(self, text, memory_action:Optional[Literal["add", "search", "update"]] = None):
""" """
Get the embedding for the given text using Hugging Face. Get the embedding for the given text using Hugging Face.
Args: Args:
text (str): The text to embed. text (str): The text to embed.
memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
Returns: Returns:
list: The embedding vector. list: The embedding vector.
""" """

View File

@@ -1,6 +1,6 @@
import subprocess import subprocess
import sys import sys
from typing import Optional from typing import Literal, Optional
from mem0.configs.embeddings.base import BaseEmbedderConfig from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase from mem0.embeddings.base import EmbeddingBase
@@ -39,13 +39,13 @@ class OllamaEmbedding(EmbeddingBase):
if not any(model.get("name") == self.config.model for model in local_models): if not any(model.get("name") == self.config.model for model in local_models):
self.client.pull(self.config.model) self.client.pull(self.config.model)
def embed(self, text): def embed(self, text, memory_action:Optional[Literal["add", "search", "update"]] = None):
""" """
Get the embedding for the given text using Ollama. Get the embedding for the given text using Ollama.
Args: Args:
text (str): The text to embed. text (str): The text to embed.
memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
Returns: Returns:
list: The embedding vector. list: The embedding vector.
""" """

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Optional from typing import Literal, Optional
from openai import OpenAI from openai import OpenAI
@@ -18,13 +18,13 @@ class OpenAIEmbedding(EmbeddingBase):
base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE")
self.client = OpenAI(api_key=api_key, base_url=base_url) self.client = OpenAI(api_key=api_key, base_url=base_url)
def embed(self, text): def embed(self, text, memory_action:Optional[Literal["add", "search", "update"]] = None):
""" """
Get the embedding for the given text using OpenAI. Get the embedding for the given text using OpenAI.
Args: Args:
text (str): The text to embed. text (str): The text to embed.
memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
Returns: Returns:
list: The embedding vector. list: The embedding vector.
""" """

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Optional from typing import Literal, Optional
from together import Together from together import Together
@@ -17,13 +17,13 @@ class TogetherEmbedding(EmbeddingBase):
self.config.embedding_dims = self.config.embedding_dims or 768 self.config.embedding_dims = self.config.embedding_dims or 768
self.client = Together(api_key=api_key) self.client = Together(api_key=api_key)
def embed(self, text): def embed(self, text, memory_action:Optional[Literal["add", "search", "update"]] = None):
""" """
Get the embedding for the given text using OpenAI. Get the embedding for the given text using OpenAI.
Args: Args:
text (str): The text to embed. text (str): The text to embed.
memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
Returns: Returns:
list: The embedding vector. list: The embedding vector.
""" """

View File

@@ -1,7 +1,7 @@
import os import os
from typing import Optional from typing import Literal, Optional
from vertexai.language_models import TextEmbeddingModel from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
from mem0.configs.embeddings.base import BaseEmbedderConfig from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase from mem0.embeddings.base import EmbeddingBase
@@ -14,6 +14,12 @@ class VertexAIEmbedding(EmbeddingBase):
self.config.model = self.config.model or "text-embedding-004" self.config.model = self.config.model or "text-embedding-004"
self.config.embedding_dims = self.config.embedding_dims or 256 self.config.embedding_dims = self.config.embedding_dims or 256
self.embedding_types = {
"add": self.config.memory_add_embedding_type or "RETRIEVAL_DOCUMENT",
"update": self.config.memory_update_embedding_type or "RETRIEVAL_DOCUMENT",
"search": self.config.memory_search_embedding_type or "RETRIEVAL_QUERY"
}
credentials_path = self.config.vertex_credentials_json credentials_path = self.config.vertex_credentials_json
if credentials_path: if credentials_path:
@@ -25,16 +31,24 @@ class VertexAIEmbedding(EmbeddingBase):
self.model = TextEmbeddingModel.from_pretrained(self.config.model) self.model = TextEmbeddingModel.from_pretrained(self.config.model)
def embed(self, text): def embed(self, text, memory_action:Optional[Literal["add", "search", "update"]] = None):
""" """
Get the embedding for the given text using Vertex AI. Get the embedding for the given text using Vertex AI.
Args: Args:
text (str): The text to embed. text (str): The text to embed.
memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
Returns: Returns:
list: The embedding vector. list: The embedding vector.
""" """
embeddings = self.model.get_embeddings(texts=[text], output_dimensionality=self.config.embedding_dims) embedding_type = "SEMANTIC_SIMILARITY"
if memory_action is not None:
if memory_action not in self.embedding_types:
raise ValueError(f"Invalid memory action: {memory_action}")
embedding_type = self.embedding_types[memory_action]
text_input = TextEmbeddingInput(text=text, task_type=embedding_type)
embeddings = self.model.get_embeddings(texts=[text_input], output_dimensionality=self.config.embedding_dims)
return embeddings[0].values return embeddings[0].values

View File

@@ -9,7 +9,7 @@ from typing import Any, Dict
import pytz import pytz
from pydantic import ValidationError from pydantic import ValidationError
from mem0.memory.utils import parse_vision_messages
from mem0.configs.base import MemoryConfig, MemoryItem from mem0.configs.base import MemoryConfig, MemoryItem
from mem0.configs.prompts import get_update_memory_messages from mem0.configs.prompts import get_update_memory_messages
from mem0.memory.base import MemoryBase from mem0.memory.base import MemoryBase
@@ -19,6 +19,7 @@ from mem0.memory.telemetry import capture_event
from mem0.memory.utils import ( from mem0.memory.utils import (
get_fact_retrieval_messages, get_fact_retrieval_messages,
parse_messages, parse_messages,
parse_vision_messages,
remove_code_blocks, remove_code_blocks,
) )
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
@@ -167,7 +168,7 @@ class Memory(MemoryBase):
retrieved_old_memory = [] retrieved_old_memory = []
new_message_embeddings = {} new_message_embeddings = {}
for new_mem in new_retrieved_facts: for new_mem in new_retrieved_facts:
messages_embeddings = self.embedding_model.embed(new_mem) messages_embeddings = self.embedding_model.embed(new_mem, "add")
new_message_embeddings[new_mem] = messages_embeddings new_message_embeddings[new_mem] = messages_embeddings
existing_memories = self.vector_store.search( existing_memories = self.vector_store.search(
query=messages_embeddings, query=messages_embeddings,
@@ -446,7 +447,7 @@ class Memory(MemoryBase):
return original_memories return original_memories
def _search_vector_store(self, query, filters, limit): def _search_vector_store(self, query, filters, limit):
embeddings = self.embedding_model.embed(query) embeddings = self.embedding_model.embed(query, "search")
memories = self.vector_store.search(query=embeddings, limit=limit, filters=filters) memories = self.vector_store.search(query=embeddings, limit=limit, filters=filters)
excluded_keys = { excluded_keys = {
@@ -494,7 +495,7 @@ class Memory(MemoryBase):
""" """
capture_event("mem0.update", self, {"memory_id": memory_id}) capture_event("mem0.update", self, {"memory_id": memory_id})
existing_embeddings = {data: self.embedding_model.embed(data)} existing_embeddings = {data: self.embedding_model.embed(data, "update")}
self._update_memory(memory_id, data, existing_embeddings) self._update_memory(memory_id, data, existing_embeddings)
return {"message": "Memory updated successfully!"} return {"message": "Memory updated successfully!"}
@@ -562,7 +563,7 @@ class Memory(MemoryBase):
if data in existing_embeddings: if data in existing_embeddings:
embeddings = existing_embeddings[data] embeddings = existing_embeddings[data]
else: else:
embeddings = self.embedding_model.embed(data) embeddings = self.embedding_model.embed(data, "add")
memory_id = str(uuid.uuid4()) memory_id = str(uuid.uuid4())
metadata = metadata or {} metadata = metadata or {}
metadata["data"] = data metadata["data"] = data
@@ -603,7 +604,7 @@ class Memory(MemoryBase):
if data in existing_embeddings: if data in existing_embeddings:
embeddings = existing_embeddings[data] embeddings = existing_embeddings[data]
else: else:
embeddings = self.embedding_model.embed(data) embeddings = self.embedding_model.embed(data, "update")
self.vector_store.update( self.vector_store.update(
vector_id=memory_id, vector_id=memory_id,
vector=embeddings, vector=embeddings,

View File

@@ -1,5 +1,7 @@
import pytest
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest
from mem0.embeddings.vertexai import VertexAIEmbedding from mem0.embeddings.vertexai import VertexAIEmbedding
@@ -20,13 +22,23 @@ def mock_os_environ():
@pytest.fixture @pytest.fixture
def mock_config(): def mock_config():
with patch("mem0.configs.embeddings.base.BaseEmbedderConfig") as mock_config: with patch("mem0.configs.embeddings.base.BaseEmbedderConfig") as mock_config:
mock_config.vertex_credentials_json = None mock_config.return_value.vertex_credentials_json = "/path/to/credentials.json"
yield mock_config yield mock_config
@pytest.fixture
def mock_embedding_types():
return ["SEMANTIC_SIMILARITY", "CLASSIFICATION", "CLUSTERING", "RETRIEVAL_DOCUMENT", "RETRIEVAL_QUERY", "QUESTION_ANSWERING", "FACT_VERIFICATION", "CODE_RETRIEVAL_QUERY"]
@pytest.fixture
def mock_text_embedding_input():
with patch("mem0.embeddings.vertexai.TextEmbeddingInput") as mock_input:
yield mock_input
@patch("mem0.embeddings.vertexai.TextEmbeddingModel") @patch("mem0.embeddings.vertexai.TextEmbeddingModel")
def test_embed_default_model(mock_text_embedding_model, mock_os_environ, mock_config): def test_embed_default_model(mock_text_embedding_model, mock_os_environ, mock_config, mock_text_embedding_input):
mock_config.vertex_credentials_json = "/path/to/credentials.json"
mock_config.return_value.model = "text-embedding-004" mock_config.return_value.model = "text-embedding-004"
mock_config.return_value.embedding_dims = 256 mock_config.return_value.embedding_dims = 256
@@ -37,16 +49,16 @@ def test_embed_default_model(mock_text_embedding_model, mock_os_environ, mock_co
mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [mock_embedding] mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [mock_embedding]
embedder.embed("Hello world") embedder.embed("Hello world")
mock_text_embedding_input.assert_called_once_with(text="Hello world", task_type="SEMANTIC_SIMILARITY")
mock_text_embedding_model.from_pretrained.assert_called_once_with("text-embedding-004") mock_text_embedding_model.from_pretrained.assert_called_once_with("text-embedding-004")
mock_text_embedding_model.from_pretrained.return_value.get_embeddings.assert_called_once_with( mock_text_embedding_model.from_pretrained.return_value.get_embeddings.assert_called_once_with(
texts=["Hello world"], output_dimensionality=256 texts=[mock_text_embedding_input("Hello world")], output_dimensionality=256
) )
@patch("mem0.embeddings.vertexai.TextEmbeddingModel") @patch("mem0.embeddings.vertexai.TextEmbeddingModel")
def test_embed_custom_model(mock_text_embedding_model, mock_os_environ, mock_config): def test_embed_custom_model(mock_text_embedding_model, mock_os_environ, mock_config, mock_text_embedding_input):
mock_config.vertex_credentials_json = "/path/to/credentials.json"
mock_config.return_value.model = "custom-embedding-model" mock_config.return_value.model = "custom-embedding-model"
mock_config.return_value.embedding_dims = 512 mock_config.return_value.embedding_dims = 512
@@ -58,18 +70,42 @@ def test_embed_custom_model(mock_text_embedding_model, mock_os_environ, mock_con
mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [mock_embedding] mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [mock_embedding]
result = embedder.embed("Test embedding") result = embedder.embed("Test embedding")
mock_text_embedding_input.assert_called_once_with(text="Test embedding", task_type="SEMANTIC_SIMILARITY")
mock_text_embedding_model.from_pretrained.assert_called_with("custom-embedding-model") mock_text_embedding_model.from_pretrained.assert_called_with("custom-embedding-model")
mock_text_embedding_model.from_pretrained.return_value.get_embeddings.assert_called_once_with( mock_text_embedding_model.from_pretrained.return_value.get_embeddings.assert_called_once_with(
texts=["Test embedding"], output_dimensionality=512 texts=[mock_text_embedding_input("Test embedding")], output_dimensionality=512
) )
assert result == [0.4, 0.5, 0.6] assert result == [0.4, 0.5, 0.6]
@patch("mem0.embeddings.vertexai.TextEmbeddingModel")
def test_embed_with_memory_action(mock_text_embedding_model, mock_os_environ, mock_config, mock_embedding_types, mock_text_embedding_input):
mock_config.return_value.model = "text-embedding-004"
mock_config.return_value.embedding_dims = 256
for embedding_type in mock_embedding_types:
mock_config.return_value.memory_add_embedding_type = embedding_type
mock_config.return_value.memory_update_embedding_type = embedding_type
mock_config.return_value.memory_search_embedding_type = embedding_type
config = mock_config()
embedder = VertexAIEmbedding(config)
mock_text_embedding_model.from_pretrained.assert_called_with("text-embedding-004")
for memory_action in ["add", "update", "search"]:
embedder.embed("Hello world", memory_action=memory_action)
mock_text_embedding_input.assert_called_with(text="Hello world", task_type=embedding_type)
mock_text_embedding_model.from_pretrained.return_value.get_embeddings.assert_called_with(
texts=[mock_text_embedding_input("Hello world", embedding_type)], output_dimensionality=256
)
@patch("mem0.embeddings.vertexai.os") @patch("mem0.embeddings.vertexai.os")
def test_credentials_from_environment(mock_os, mock_text_embedding_model, mock_config): def test_credentials_from_environment(mock_os, mock_text_embedding_model, mock_config):
mock_os.getenv.return_value = "/path/to/env/credentials.json"
mock_config.vertex_credentials_json = None mock_config.vertex_credentials_json = None
config = mock_config() config = mock_config()
VertexAIEmbedding(config) VertexAIEmbedding(config)
@@ -90,7 +126,6 @@ def test_missing_credentials(mock_os, mock_text_embedding_model, mock_config):
@patch("mem0.embeddings.vertexai.TextEmbeddingModel") @patch("mem0.embeddings.vertexai.TextEmbeddingModel")
def test_embed_with_different_dimensions(mock_text_embedding_model, mock_os_environ, mock_config): def test_embed_with_different_dimensions(mock_text_embedding_model, mock_os_environ, mock_config):
mock_config.vertex_credentials_json = "/path/to/credentials.json"
mock_config.return_value.embedding_dims = 1024 mock_config.return_value.embedding_dims = 1024
config = mock_config() config = mock_config()
@@ -102,3 +137,15 @@ def test_embed_with_different_dimensions(mock_text_embedding_model, mock_os_envi
result = embedder.embed("Large embedding test") result = embedder.embed("Large embedding test")
assert result == [0.1] * 1024 assert result == [0.1] * 1024
@patch("mem0.embeddings.vertexai.TextEmbeddingModel")
def test_invalid_memory_action(mock_text_embedding_model, mock_config):
mock_config.return_value.model = "text-embedding-004"
mock_config.return_value.embedding_dims = 256
config = mock_config()
embedder = VertexAIEmbedding(config)
with pytest.raises(ValueError):
embedder.embed("Hello world", memory_action="invalid_action")

View File

@@ -119,7 +119,7 @@ def test_search(memory_instance, version, enable_graph):
memory_instance.vector_store.search.assert_called_once_with( memory_instance.vector_store.search.assert_called_once_with(
query=[0.1, 0.2, 0.3], limit=100, filters={"user_id": "test_user"} query=[0.1, 0.2, 0.3], limit=100, filters={"user_id": "test_user"}
) )
memory_instance.embedding_model.embed.assert_called_once_with("test query") memory_instance.embedding_model.embed.assert_called_once_with("test query", "search")
if enable_graph: if enable_graph:
memory_instance.graph.search.assert_called_once_with("test query", {"user_id": "test_user"}, 100) memory_instance.graph.search.assert_called_once_with("test query", {"user_id": "test_user"}, 100)