diff --git a/docs/components/llms/models/gemini.mdx b/docs/components/llms/models/gemini.mdx index 7a502ad5..89b576f9 100644 --- a/docs/components/llms/models/gemini.mdx +++ b/docs/components/llms/models/gemini.mdx @@ -4,7 +4,11 @@ title: Gemini -To use Gemini model, you have to set the `GEMINI_API_KEY` environment variable. You can obtain the Gemini API key from the [Google AI Studio](https://aistudio.google.com/app/apikey) +To use the Gemini model, set the `GEMINI_API_KEY` environment variable. You can obtain the Gemini API key from [Google AI Studio](https://aistudio.google.com/app/apikey). + +> **Note:** As of the latest release, Mem0 uses the new `google.genai` SDK instead of the deprecated `google.generativeai`. All message formatting and model interaction now use the updated `types` module from `google.genai`. + +> **Note:** Some Gemini models are being deprecated and will retire soon. It is recommended to migrate to the latest stable models like `"gemini-2.0-flash-001"` or `"gemini-2.0-flash-lite-001"` to ensure ongoing support and improvements. ## Usage @@ -12,28 +16,32 @@ To use Gemini model, you have to set the `GEMINI_API_KEY` environment variable. import os from mem0 import Memory -os.environ["OPENAI_API_KEY"] = "your-api-key" # used for embedding model -os.environ["GEMINI_API_KEY"] = "your-api-key" +os.environ["OPENAI_API_KEY"] = "your-openai-api-key" # Used for embedding model +os.environ["GEMINI_API_KEY"] = "your-gemini-api-key" config = { "llm": { "provider": "gemini", "config": { - "model": "gemini-1.5-flash-latest", + "model": "gemini-2.0-flash-001", "temperature": 0.2, "max_tokens": 2000, + "top_p": 1.0 } } } m = Memory.from_config(config) + messages = [ {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"}, - {"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."}, - {"role": "user", "content": "I’m not a big fan of thriller movies but I love sci-fi movies."}, - {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."} + {"role": "assistant", "content": "How about thriller movies? They can be quite engaging."}, + {"role": "user", "content": "I’m not a big fan of thrillers, but I love sci-fi movies."}, + {"role": "assistant", "content": "Got it! I'll avoid thrillers and suggest sci-fi movies instead."} ] + m.add(messages, user_id="alice", metadata={"category": "movies"}) + ``` ## Config diff --git a/docs/open-source/graph_memory/overview.mdx b/docs/open-source/graph_memory/overview.mdx index 2ff372d7..d1bb4712 100644 --- a/docs/open-source/graph_memory/overview.mdx +++ b/docs/open-source/graph_memory/overview.mdx @@ -238,16 +238,24 @@ The Mem0's graph supports the following operations: ### Add Memories -If you are using Mem0 with Graph Memory, it is recommended to pass `user_id`. Use `userId` in NodeSDK. +Mem0 with Graph Memory supports both `user_id` and `agent_id` parameters. You can use either or both to organize your memories. Use `userId` and `agentId` in NodeSDK. ```python Python +# Using only user_id m.add("I like pizza", user_id="alice") + +# Using both user_id and agent_id +m.add("I like pizza", user_id="alice", agent_id="food-assistant") ``` ```typescript TypeScript +// Using only userId memory.add("I like pizza", { userId: "alice" }); + +// Using both userId and agentId +memory.add("I like pizza", { userId: "alice", agentId: "food-assistant" }); ``` ```json Output @@ -260,11 +268,19 @@ memory.add("I like pizza", { userId: "alice" }); ```python Python +# Get all memories for a user m.get_all(user_id="alice") + +# Get all memories for a specific agent belonging to a user +m.get_all(user_id="alice", agent_id="food-assistant") ``` ```typescript TypeScript +// Get all memories for a user memory.getAll({ userId: "alice" }); + +// Get all memories for a specific agent belonging to a user +memory.getAll({ userId: "alice", agentId: "food-assistant" }); ``` ```json Output @@ -277,7 +293,8 @@ memory.getAll({ userId: "alice" }); 'metadata': None, 'created_at': '2024-08-20T14:09:27.588719-07:00', 'updated_at': None, - 'user_id': 'alice' + 'user_id': 'alice', + 'agent_id': 'food-assistant' } ], 'entities': [ @@ -295,11 +312,19 @@ memory.getAll({ userId: "alice" }); ```python Python +# Search memories for a user m.search("tell me my name.", user_id="alice") + +# Search memories for a specific agent belonging to a user +m.search("tell me my name.", user_id="alice", agent_id="food-assistant") ``` ```typescript TypeScript +// Search memories for a user memory.search("tell me my name.", { userId: "alice" }); + +// Search memories for a specific agent belonging to a user +memory.search("tell me my name.", { userId: "alice", agentId: "food-assistant" }); ``` ```json Output @@ -312,7 +337,8 @@ memory.search("tell me my name.", { userId: "alice" }); 'metadata': None, 'created_at': '2024-08-20T14:09:27.588719-07:00', 'updated_at': None, - 'user_id': 'alice' + 'user_id': 'alice', + 'agent_id': 'food-assistant' } ], 'entities': [ @@ -331,11 +357,19 @@ memory.search("tell me my name.", { userId: "alice" }); ```python Python +# Delete all memories for a user m.delete_all(user_id="alice") + +# Delete all memories for a specific agent belonging to a user +m.delete_all(user_id="alice", agent_id="food-assistant") ``` ```typescript TypeScript +// Delete all memories for a user memory.deleteAll({ userId: "alice" }); + +// Delete all memories for a specific agent belonging to a user +memory.deleteAll({ userId: "alice", agentId: "food-assistant" }); ``` @@ -516,6 +550,42 @@ memory.search("Who is spiderman?", { userId: "alice123" }); > **Note:** The Graph Memory implementation is not standalone. You will be adding/retrieving memories to the vector store and the graph store simultaneously. +## Using Multiple Agents with Graph Memory + +When working with multiple agents, you can use the `agent_id` parameter to organize memories by both user and agent. This allows you to: + +1. Create agent-specific knowledge graphs +2. Share common knowledge between agents +3. Isolate sensitive or specialized information to specific agents + +### Example: Multi-Agent Setup + + +```python Python +# Add memories for different agents +m.add("I prefer Italian cuisine", user_id="bob", agent_id="food-assistant") +m.add("I'm allergic to peanuts", user_id="bob", agent_id="health-assistant") +m.add("I live in Seattle", user_id="bob") # Shared across all agents + +# Search within specific agent context +food_preferences = m.search("What food do I like?", user_id="bob", agent_id="food-assistant") +health_info = m.search("What are my allergies?", user_id="bob", agent_id="health-assistant") +location = m.search("Where do I live?", user_id="bob") # Searches across all agents +``` + +```typescript TypeScript +// Add memories for different agents +memory.add("I prefer Italian cuisine", { userId: "bob", agentId: "food-assistant" }); +memory.add("I'm allergic to peanuts", { userId: "bob", agentId: "health-assistant" }); +memory.add("I live in Seattle", { userId: "bob" }); // Shared across all agents + +// Search within specific agent context +const foodPreferences = memory.search("What food do I like?", { userId: "bob", agentId: "food-assistant" }); +const healthInfo = memory.search("What are my allergies?", { userId: "bob", agentId: "health-assistant" }); +const location = memory.search("Where do I live?", { userId: "bob" }); // Searches across all agents +``` + + If you want to use a managed version of Mem0, please check out [Mem0](https://mem0.dev/pd). If you have any questions, please feel free to reach out to us using one of the following methods: - \ No newline at end of file + diff --git a/embedchain/poetry.lock b/embedchain/poetry.lock index f187bdbb..e53e71b4 100644 --- a/embedchain/poetry.lock +++ b/embedchain/poetry.lock @@ -2552,7 +2552,7 @@ azure = ["adlfs (>=2024.2.0)"] clip = ["open-clip", "pillow", "torch"] dev = ["pre-commit", "ruff"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] -embeddings = ["awscli (>=1.29.57)", "boto3 (>=1.28.57)", "botocore (>=1.31.57)", "cohere", "google-generativeai", "huggingface-hub", "instructorembedding", "open-clip-torch", "openai (>=1.6.1)", "pillow", "sentence-transformers", "torch"] +embeddings = ["awscli (>=1.29.57)", "boto3 (>=1.28.57)", "botocore (>=1.31.57)", "cohere", "google-generativeai", "huggingface-hub", "instructorembedding", "open-clip-torch", "openai (>=1.6.1)", "pillow", "sentence-transformers", "torch", "google-genai"] tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19)", "pytest", "pytest-asyncio", "pytest-mock", "pytz", "tantivy"] [[package]] @@ -7129,7 +7129,7 @@ cffi = ["cffi (>=1.11)"] aws = ["langchain-aws"] elasticsearch = ["elasticsearch"] gmail = ["google-api-core", "google-api-python-client", "google-auth", "google-auth-httplib2", "google-auth-oauthlib", "requests"] -google = ["google-generativeai"] +google = ["google-generativeai", "google-genai"] googledrive = ["google-api-python-client", "google-auth-httplib2", "google-auth-oauthlib"] lancedb = ["lancedb"] llama2 = ["replicate"] diff --git a/mem0/llms/gemini.py b/mem0/llms/gemini.py index 7881cf05..3c48c5da 100644 --- a/mem0/llms/gemini.py +++ b/mem0/llms/gemini.py @@ -2,9 +2,9 @@ import os from typing import Dict, List, Optional try: - import google.generativeai as genai - from google.generativeai import GenerativeModel, protos - from google.generativeai.types import content_types + from google import genai + from google.genai import types + except ImportError: raise ImportError( "The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'." @@ -22,66 +22,71 @@ class GeminiLLM(LLMBase): self.config.model = "gemini-1.5-flash-latest" api_key = self.config.api_key or os.getenv("GEMINI_API_KEY") - genai.configure(api_key=api_key) - self.client = GenerativeModel(model_name=self.config.model) + self.client_gemini = genai.Client( + api_key=api_key, + ) def _parse_response(self, response, tools): """ Process the response based on whether tools are used or not. Args: - response: The raw response from API. + response: The raw response from the API. tools: The list of tools provided in the request. Returns: str or dict: The processed response. """ + candidate = response.candidates[0] + content = candidate.content.parts[0].text if candidate.content.parts else None + if tools: processed_response = { - "content": (content if (content := response.candidates[0].content.parts[0].text) else None), + "content": content, "tool_calls": [], } - for part in response.candidates[0].content.parts: - if fn := part.function_call: - if isinstance(fn, protos.FunctionCall): - fn_call = type(fn).to_dict(fn) - processed_response["tool_calls"].append({"name": fn_call["name"], "arguments": fn_call["args"]}) - continue - processed_response["tool_calls"].append({"name": fn.name, "arguments": fn.args}) + for part in candidate.content.parts: + fn = getattr(part, "function_call", None) + if fn: + processed_response["tool_calls"].append({ + "name": fn.name, + "arguments": fn.args, + }) return processed_response - else: - return response.candidates[0].content.parts[0].text - def _reformat_messages(self, messages: List[Dict[str, str]]): + return content + + + def _reformat_messages(self, messages: List[Dict[str, str]]) -> List[types.Content]: """ - Reformat messages for Gemini. + Reformat messages for Gemini using google.genai.types. Args: messages: The list of messages provided in the request. Returns: - list: The list of messages in the required format. + list: A list of types.Content objects with proper role and parts. """ new_messages = [] for message in messages: if message["role"] == "system": content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"] - else: content = message["content"] new_messages.append( - { - "parts": content, - "role": "model" if message["role"] == "model" else "user", - } + types.Content( + role="model" if message["role"] == "model" else "user", + parts=[types.Part(text=content)] + ) ) return new_messages + def _reformat_tools(self, tools: Optional[List[Dict]]): """ Reformat tools for Gemini. @@ -126,6 +131,7 @@ class GeminiLLM(LLMBase): tools: Optional[List[Dict]] = None, tool_choice: str = "auto", ): + """ Generate a response based on the given messages using Gemini. @@ -149,23 +155,37 @@ class GeminiLLM(LLMBase): params["response_mime_type"] = "application/json" if "schema" in response_format: params["response_schema"] = response_format["schema"] + + tool_config = None if tool_choice: - tool_config = content_types.to_tool_config( - { - "function_calling_config": { - "mode": tool_choice, - "allowed_function_names": ( - [tool["function"]["name"] for tool in tools] if tool_choice == "any" else None - ), - } - } + tool_config = types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode=tool_choice.upper(), # Assuming 'any' should become 'ANY', etc. + allowed_function_names=[ + tool["function"]["name"] for tool in tools + ] if tool_choice == "any" else None + ) ) - response = self.client.generate_content( - contents=self._reformat_messages(messages), - tools=self._reformat_tools(tools), - generation_config=genai.GenerationConfig(**params), - tool_config=tool_config, - ) + print(f"Tool config: {tool_config}") + print(f"Params: {params}" ) + print(f"Messages: {messages}") + print(f"Tools: {tools}") + print(f"Reformatted messages: {self._reformat_messages(messages)}") + print(f"Reformatted tools: {self._reformat_tools(tools)}") + + response = self.client_gemini.models.generate_content( + model=self.config.model, + contents=self._reformat_messages(messages), + config=types.GenerateContentConfig( + temperature= self.config.temperature, + max_output_tokens= self.config.max_tokens, + top_p= self.config.top_p, + tools=self._reformat_tools(tools), + tool_config=tool_config, + + ), + ) + print(f"Response test: {response}") return self._parse_response(response, tools) diff --git a/mem0/memory/graph_memory.py b/mem0/memory/graph_memory.py index ff50c221..5156668a 100644 --- a/mem0/memory/graph_memory.py +++ b/mem0/memory/graph_memory.py @@ -80,8 +80,8 @@ class MemoryGraph: # TODO: Batch queries with APOC plugin # TODO: Add more filter support - deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"]) - added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map) + deleted_entities = self._delete_entities(to_be_deleted, filters) + added_entities = self._add_entities(to_be_added, filters, entity_type_map) return {"deleted_entities": deleted_entities, "added_entities": added_entities} @@ -122,32 +122,35 @@ class MemoryGraph: return search_results def delete_all(self, filters): - cypher = f""" - MATCH (n {self.node_label} {{user_id: $user_id}}) - DETACH DELETE n - """ - params = {"user_id": filters["user_id"]} + if filters.get("agent_id"): + cypher = f""" + MATCH (n {self.node_label} {{user_id: $user_id, agent_id: $agent_id}}) + DETACH DELETE n + """ + params = {"user_id": filters["user_id"], "agent_id": filters["agent_id"]} + else: + cypher = f""" + MATCH (n {self.node_label} {{user_id: $user_id}}) + DETACH DELETE n + """ + params = {"user_id": filters["user_id"]} self.graph.query(cypher, params=params) - def get_all(self, filters, limit=100): - """ - Retrieves all nodes and relationships from the graph database based on optional filtering criteria. - Args: - filters (dict): A dictionary containing filters to be applied during the retrieval. - limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100. - Returns: - list: A list of dictionaries, each containing: - - 'contexts': The base data store response for each memory. - - 'entities': A list of strings representing the nodes and relationships - """ - # return all nodes and relationships + def get_all(self, filters, limit=100): + agent_filter = "" + params = {"user_id": filters["user_id"], "limit": limit} + if filters.get("agent_id"): + agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id" + params["agent_id"] = filters["agent_id"] + query = f""" MATCH (n {self.node_label} {{user_id: $user_id}})-[r]->(m {self.node_label} {{user_id: $user_id}}) + WHERE 1=1 {agent_filter} RETURN n.name AS source, type(r) AS relationship, m.name AS target LIMIT $limit """ - results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit}) + results = self.graph.query(query, params=params) final_results = [] for result in results: @@ -163,6 +166,7 @@ class MemoryGraph: return final_results + def _retrieve_nodes_from_data(self, data, filters): """Extracts all the entities mentioned in the query.""" _tools = [EXTRACT_ENTITIES_TOOL] @@ -197,23 +201,27 @@ class MemoryGraph: return entity_type_map def _establish_nodes_relations_from_data(self, data, filters, entity_type_map): - """Eshtablish relations among the extracted nodes.""" + """Establish relations among the extracted nodes.""" + + # Compose user identification string for prompt + user_identity = f"user_id: {filters['user_id']}" + if filters.get("agent_id"): + user_identity += f", agent_id: {filters['agent_id']}" + if self.config.graph_store.custom_prompt: + system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity) + # Add the custom prompt line if configured + system_content = system_content.replace( + "CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}" + ) messages = [ - { - "role": "system", - "content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]).replace( - "CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}" - ), - }, + {"role": "system", "content": system_content}, {"role": "user", "content": data}, ] else: + system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity) messages = [ - { - "role": "system", - "content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]), - }, + {"role": "system", "content": system_content}, {"role": "user", "content": f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}"}, ] @@ -227,8 +235,8 @@ class MemoryGraph: ) entities = [] - if extracted_entities["tool_calls"]: - entities = extracted_entities["tool_calls"][0]["arguments"]["entities"] + if extracted_entities.get("tool_calls"): + entities = extracted_entities["tool_calls"][0].get("arguments", {}).get("entities", []) entities = self._remove_spaces_from_entities(entities) logger.debug(f"Extracted entities: {entities}") @@ -237,32 +245,43 @@ class MemoryGraph: def _search_graph_db(self, node_list, filters, limit=100): """Search similar nodes among and their respective incoming and outgoing relations.""" result_relations = [] + agent_filter = "" + if filters.get("agent_id"): + agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id" + for node in node_list: n_embedding = self.embedding_model.embed(node) cypher_query = f""" MATCH (n {self.node_label}) WHERE n.embedding IS NOT NULL AND n.user_id = $user_id + {agent_filter} WITH n, round(2 * vector.similarity.cosine(n.embedding, $n_embedding) - 1, 4) AS similarity // denormalize for backward compatibility WHERE similarity >= $threshold - CALL (n) {{ - MATCH (n)-[r]->(m) + CALL {{ + MATCH (n)-[r]->(m) + WHERE m.user_id = $user_id {agent_filter.replace("n.", "m.")} RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id UNION - MATCH (m)-[r]->(n) + MATCH (m)-[r]->(n) + WHERE m.user_id = $user_id {agent_filter.replace("n.", "m.")} RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id }} - WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity //deduplicate + WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity ORDER BY similarity DESC LIMIT $limit """ + params = { "n_embedding": n_embedding, "threshold": self.threshold, "user_id": filters["user_id"], "limit": limit, } + if filters.get("agent_id"): + params["agent_id"] = filters["agent_id"] + ans = self.graph.query(cypher_query, params=params) result_relations.extend(ans) @@ -271,7 +290,13 @@ class MemoryGraph: def _get_delete_entities_from_search_output(self, search_output, data, filters): """Get the entities to be deleted from the search output.""" search_output_string = format_entities(search_output) - system_prompt, user_prompt = get_delete_messages(search_output_string, data, filters["user_id"]) + + # Compose user identification string for prompt + user_identity = f"user_id: {filters['user_id']}" + if filters.get("agent_id"): + user_identity += f", agent_id: {filters['agent_id']}" + + system_prompt, user_prompt = get_delete_messages(search_output_string, data, user_identity) _tools = [DELETE_MEMORY_TOOL_GRAPH] if self.llm_provider in ["azure_openai_structured", "openai_structured"]: @@ -288,44 +313,59 @@ class MemoryGraph: ) to_be_deleted = [] - for item in memory_updates["tool_calls"]: - if item["name"] == "delete_graph_memory": - to_be_deleted.append(item["arguments"]) - # in case if it is not in the correct format + for item in memory_updates.get("tool_calls", []): + if item.get("name") == "delete_graph_memory": + to_be_deleted.append(item.get("arguments")) + # Clean entities formatting to_be_deleted = self._remove_spaces_from_entities(to_be_deleted) logger.debug(f"Deleted relationships: {to_be_deleted}") return to_be_deleted - def _delete_entities(self, to_be_deleted, user_id): + def _delete_entities(self, to_be_deleted, filters): """Delete the entities from the graph.""" + user_id = filters["user_id"] + agent_id = filters.get("agent_id", None) results = [] + for item in to_be_deleted: source = item["source"] destination = item["destination"] relationship = item["relationship"] + # Build the agent filter for the query + agent_filter = "" + params = { + "source_name": source, + "dest_name": destination, + "user_id": user_id, + } + + if agent_id: + agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id" + params["agent_id"] = agent_id + # Delete the specific relationship between nodes cypher = f""" MATCH (n {self.node_label} {{name: $source_name, user_id: $user_id}}) -[r:{relationship}]-> (m {self.node_label} {{name: $dest_name, user_id: $user_id}}) + WHERE 1=1 {agent_filter} DELETE r RETURN n.name AS source, m.name AS target, type(r) AS relationship """ - params = { - "source_name": source, - "dest_name": destination, - "user_id": user_id, - } + result = self.graph.query(cypher, params=params) results.append(result) + return results - def _add_entities(self, to_be_added, user_id, entity_type_map): + def _add_entities(self, to_be_added, filters, entity_type_map): """Add the new entities to the graph. Merge the nodes if they already exist.""" + user_id = filters["user_id"] + agent_id = filters.get("agent_id", None) results = [] for item in to_be_added: # entities @@ -346,65 +386,80 @@ class MemoryGraph: dest_embedding = self.embedding_model.embed(destination) # search for the nodes with the closest embeddings - source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9) - destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9) + source_node_search_result = self._search_source_node(source_embedding, filters, threshold=0.9) + destination_node_search_result = self._search_destination_node(dest_embedding, filters, threshold=0.9) # TODO: Create a cypher query and common params for all the cases if not destination_node_search_result and source_node_search_result: - cypher = f""" - MATCH (source) - WHERE elementId(source) = $source_id - SET source.mentions = coalesce(source.mentions, 0) + 1 - WITH source - MERGE (destination {destination_label} {{name: $destination_name, user_id: $user_id}}) - ON CREATE SET - destination.created = timestamp(), - destination.mentions = 1 - {destination_extra_set} - ON MATCH SET - destination.mentions = coalesce(destination.mentions, 0) + 1 - WITH source, destination - CALL db.create.setNodeVectorProperty(destination, 'embedding', $destination_embedding) - WITH source, destination - MERGE (source)-[r:{relationship}]->(destination) - ON CREATE SET - r.created = timestamp(), - r.mentions = 1 - ON MATCH SET - r.mentions = coalesce(r.mentions, 0) + 1 - RETURN source.name AS source, type(r) AS relationship, destination.name AS target - """ + # Build destination MERGE properties + merge_props = ["name: $destination_name", "user_id: $user_id"] + if agent_id: + merge_props.append("agent_id: $agent_id") + merge_props_str = ", ".join(merge_props) + cypher = f""" + MATCH (source) + WHERE elementId(source) = $source_id + SET source.mentions = coalesce(source.mentions, 0) + 1 + WITH source + MERGE (destination {destination_label} {{{merge_props_str}}}) + ON CREATE SET + destination.created = timestamp(), + destination.mentions = 1 + {destination_extra_set} + ON MATCH SET + destination.mentions = coalesce(destination.mentions, 0) + 1 + WITH source, destination + CALL db.create.setNodeVectorProperty(destination, 'embedding', $destination_embedding) + WITH source, destination + MERGE (source)-[r:{relationship}]->(destination) + ON CREATE SET + r.created = timestamp(), + r.mentions = 1 + ON MATCH SET + r.mentions = coalesce(r.mentions, 0) + 1 + RETURN source.name AS source, type(r) AS relationship, destination.name AS target + """ + params = { "source_id": source_node_search_result[0]["elementId(source_candidate)"], "destination_name": destination, "destination_embedding": dest_embedding, "user_id": user_id, } + if agent_id: + params["agent_id"] = agent_id + elif destination_node_search_result and not source_node_search_result: + # Build source MERGE properties + merge_props = ["name: $source_name", "user_id: $user_id"] + if agent_id: + merge_props.append("agent_id: $agent_id") + merge_props_str = ", ".join(merge_props) + cypher = f""" - MATCH (destination) - WHERE elementId(destination) = $destination_id - SET destination.mentions = coalesce(destination.mentions, 0) + 1 - WITH destination - MERGE (source {source_label} {{name: $source_name, user_id: $user_id}}) - ON CREATE SET - source.created = timestamp(), - source.mentions = 1 - {source_extra_set} - ON MATCH SET - source.mentions = coalesce(source.mentions, 0) + 1 - WITH source, destination - CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding) - WITH source, destination - MERGE (source)-[r:{relationship}]->(destination) - ON CREATE SET - r.created = timestamp(), - r.mentions = 1 - ON MATCH SET - r.mentions = coalesce(r.mentions, 0) + 1 - RETURN source.name AS source, type(r) AS relationship, destination.name AS target - """ + MATCH (destination) + WHERE elementId(destination) = $destination_id + SET destination.mentions = coalesce(destination.mentions, 0) + 1 + WITH destination + MERGE (source {source_label} {{{merge_props_str}}}) + ON CREATE SET + source.created = timestamp(), + source.mentions = 1 + {source_extra_set} + ON MATCH SET + source.mentions = coalesce(source.mentions, 0) + 1 + WITH source, destination + CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding) + WITH source, destination + MERGE (source)-[r:{relationship}]->(destination) + ON CREATE SET + r.created = timestamp(), + r.mentions = 1 + ON MATCH SET + r.mentions = coalesce(r.mentions, 0) + 1 + RETURN source.name AS source, type(r) AS relationship, destination.name AS target + """ params = { "destination_id": destination_node_search_result[0]["elementId(destination_candidate)"], @@ -412,53 +467,68 @@ class MemoryGraph: "source_embedding": source_embedding, "user_id": user_id, } + if agent_id: + params["agent_id"] = agent_id + elif source_node_search_result and destination_node_search_result: cypher = f""" - MATCH (source) - WHERE elementId(source) = $source_id - SET source.mentions = coalesce(source.mentions, 0) + 1 - WITH source - MATCH (destination) - WHERE elementId(destination) = $destination_id - SET destination.mentions = coalesce(destination.mentions) + 1 - MERGE (source)-[r:{relationship}]->(destination) - ON CREATE SET - r.created_at = timestamp(), - r.updated_at = timestamp(), - r.mentions = 1 - ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1 - - - RETURN source.name AS source, type(r) AS relationship, destination.name AS target - """ + MATCH (source) + WHERE elementId(source) = $source_id + SET source.mentions = coalesce(source.mentions, 0) + 1 + WITH source + MATCH (destination) + WHERE elementId(destination) = $destination_id + SET destination.mentions = coalesce(destination.mentions, 0) + 1 + MERGE (source)-[r:{relationship}]->(destination) + ON CREATE SET + r.created_at = timestamp(), + r.updated_at = timestamp(), + r.mentions = 1 + ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1 + RETURN source.name AS source, type(r) AS relationship, destination.name AS target + """ + params = { "source_id": source_node_search_result[0]["elementId(source_candidate)"], "destination_id": destination_node_search_result[0]["elementId(destination_candidate)"], "user_id": user_id, } + if agent_id: + params["agent_id"] = agent_id + else: + # Build dynamic MERGE props for both source and destination + source_props = ["name: $source_name", "user_id: $user_id"] + dest_props = ["name: $dest_name", "user_id: $user_id"] + if agent_id: + source_props.append("agent_id: $agent_id") + dest_props.append("agent_id: $agent_id") + source_props_str = ", ".join(source_props) + dest_props_str = ", ".join(dest_props) + cypher = f""" - MERGE (source {source_label} {{name: $source_name, user_id: $user_id}}) - ON CREATE SET source.created = timestamp(), - source.mentions = 1 - {source_extra_set} - ON MATCH SET source.mentions = coalesce(source.mentions, 0) + 1 - WITH source - CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding) - WITH source - MERGE (destination {destination_label} {{name: $dest_name, user_id: $user_id}}) - ON CREATE SET destination.created = timestamp(), - destination.mentions = 1 - {destination_extra_set} - ON MATCH SET destination.mentions = coalesce(destination.mentions, 0) + 1 - WITH source, destination - CALL db.create.setNodeVectorProperty(destination, 'embedding', $source_embedding) - WITH source, destination - MERGE (source)-[rel:{relationship}]->(destination) - ON CREATE SET rel.created = timestamp(), rel.mentions = 1 - ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1 - RETURN source.name AS source, type(rel) AS relationship, destination.name AS target - """ + MERGE (source {source_label} {{{source_props_str}}}) + ON CREATE SET source.created = timestamp(), + source.mentions = 1 + {source_extra_set} + ON MATCH SET source.mentions = coalesce(source.mentions, 0) + 1 + WITH source + CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding) + WITH source + MERGE (destination {destination_label} {{{dest_props_str}}}) + ON CREATE SET destination.created = timestamp(), + destination.mentions = 1 + {destination_extra_set} + ON MATCH SET destination.mentions = coalesce(destination.mentions, 0) + 1 + WITH source, destination + CALL db.create.setNodeVectorProperty(destination, 'embedding', $dest_embedding) + WITH source, destination + MERGE (source)-[rel:{relationship}]->(destination) + ON CREATE SET rel.created = timestamp(), rel.mentions = 1 + ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1 + RETURN source.name AS source, type(rel) AS relationship, destination.name AS target + """ + params = { "source_name": source, "dest_name": destination, @@ -466,6 +536,8 @@ class MemoryGraph: "dest_embedding": dest_embedding, "user_id": user_id, } + if agent_id: + params["agent_id"] = agent_id result = self.graph.query(cypher, params=params) results.append(result) return results @@ -477,11 +549,16 @@ class MemoryGraph: item["destination"] = item["destination"].lower().replace(" ", "_") return entity_list - def _search_source_node(self, source_embedding, user_id, threshold=0.9): + def _search_source_node(self, source_embedding, filters, threshold=0.9): + agent_filter = "" + if filters.get("agent_id"): + agent_filter = "AND source_candidate.agent_id = $agent_id" + cypher = f""" MATCH (source_candidate {self.node_label}) WHERE source_candidate.embedding IS NOT NULL AND source_candidate.user_id = $user_id + {agent_filter} WITH source_candidate, round(2 * vector.similarity.cosine(source_candidate.embedding, $source_embedding) - 1, 4) AS source_similarity // denormalize for backward compatibility @@ -496,18 +573,26 @@ class MemoryGraph: params = { "source_embedding": source_embedding, - "user_id": user_id, + "user_id": filters["user_id"], "threshold": threshold, } + if filters.get("agent_id"): + params["agent_id"] = filters["agent_id"] result = self.graph.query(cypher, params=params) return result - def _search_destination_node(self, destination_embedding, user_id, threshold=0.9): + + def _search_destination_node(self, destination_embedding, filters, threshold=0.9): + agent_filter = "" + if filters.get("agent_id"): + agent_filter = "AND destination_candidate.agent_id = $agent_id" + cypher = f""" MATCH (destination_candidate {self.node_label}) WHERE destination_candidate.embedding IS NOT NULL AND destination_candidate.user_id = $user_id + {agent_filter} WITH destination_candidate, round(2 * vector.similarity.cosine(destination_candidate.embedding, $destination_embedding) - 1, 4) AS destination_similarity // denormalize for backward compatibility @@ -520,11 +605,14 @@ class MemoryGraph: RETURN elementId(destination_candidate) """ + params = { "destination_embedding": destination_embedding, - "user_id": user_id, + "user_id": filters["user_id"], "threshold": threshold, } + if filters.get("agent_id"): + params["agent_id"] = filters["agent_id"] result = self.graph.query(cypher, params=params) return result diff --git a/mem0/memory/memgraph_memory.py b/mem0/memory/memgraph_memory.py index 5a7cf6ec..9b289f78 100644 --- a/mem0/memory/memgraph_memory.py +++ b/mem0/memory/memgraph_memory.py @@ -118,11 +118,19 @@ class MemoryGraph: return search_results def delete_all(self, filters): - cypher = """ - MATCH (n {user_id: $user_id}) - DETACH DELETE n - """ - params = {"user_id": filters["user_id"]} + """Delete all nodes and relationships for a user or specific agent.""" + if filters.get("agent_id"): + cypher = """ + MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id}) + DETACH DELETE n + """ + params = {"user_id": filters["user_id"], "agent_id": filters["agent_id"]} + else: + cypher = """ + MATCH (n:Entity {user_id: $user_id}) + DETACH DELETE n + """ + params = {"user_id": filters["user_id"]} self.graph.query(cypher, params=params) def get_all(self, filters, limit=100): @@ -131,20 +139,31 @@ class MemoryGraph: Args: filters (dict): A dictionary containing filters to be applied during the retrieval. + Supports 'user_id' (required) and 'agent_id' (optional). limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100. Returns: list: A list of dictionaries, each containing: - - 'contexts': The base data store response for each memory. - - 'entities': A list of strings representing the nodes and relationships + - 'source': The source node name. + - 'relationship': The relationship type. + - 'target': The target node name. """ - - # return all nodes and relationships - query = """ - MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity {user_id: $user_id}) - RETURN n.name AS source, type(r) AS relationship, m.name AS target - LIMIT $limit - """ - results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit}) + # Build query based on whether agent_id is provided + if filters.get("agent_id"): + query = """ + MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})-[r]->(m:Entity {user_id: $user_id, agent_id: $agent_id}) + RETURN n.name AS source, type(r) AS relationship, m.name AS target + LIMIT $limit + """ + params = {"user_id": filters["user_id"], "agent_id": filters["agent_id"], "limit": limit} + else: + query = """ + MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity {user_id: $user_id}) + RETURN n.name AS source, type(r) AS relationship, m.name AS target + LIMIT $limit + """ + params = {"user_id": filters["user_id"], "limit": limit} + + results = self.graph.query(query, params=params) final_results = [] for result in results: @@ -241,33 +260,65 @@ class MemoryGraph: for node in node_list: n_embedding = self.embedding_model.embed(node) - cypher_query = """ - MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity) - WHERE n.embedding IS NOT NULL - WITH collect(n) AS nodes1, collect(m) AS nodes2, r - CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2) - YIELD node1, node2, similarity - WITH node1, node2, similarity, r - WHERE similarity >= $threshold - RETURN node1.user_id AS source, id(node1) AS source_id, type(r) AS relationship, id(r) AS relation_id, node2.user_id AS destination, id(node2) AS destination_id, similarity - UNION - MATCH (n:Entity {user_id: $user_id})<-[r]-(m:Entity) - WHERE n.embedding IS NOT NULL - WITH collect(n) AS nodes1, collect(m) AS nodes2, r - CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2) - YIELD node1, node2, similarity - WITH node1, node2, similarity, r - WHERE similarity >= $threshold - RETURN node2.name AS source, id(node2) AS source_id, type(r) AS relationship, id(r) AS relation_id, node1.name AS destination, id(node1) AS destination_id, similarity - ORDER BY similarity DESC - LIMIT $limit; - """ - params = { - "n_embedding": n_embedding, - "threshold": self.threshold, - "user_id": filters["user_id"], - "limit": limit, - } + # Build query based on whether agent_id is provided + if filters.get("agent_id"): + cypher_query = """ + MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})-[r]->(m:Entity) + WHERE n.embedding IS NOT NULL + WITH collect(n) AS nodes1, collect(m) AS nodes2, r + CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2) + YIELD node1, node2, similarity + WITH node1, node2, similarity, r + WHERE similarity >= $threshold + RETURN node1.name AS source, id(node1) AS source_id, type(r) AS relationship, id(r) AS relation_id, node2.name AS destination, id(node2) AS destination_id, similarity + UNION + MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})<-[r]-(m:Entity) + WHERE n.embedding IS NOT NULL + WITH collect(n) AS nodes1, collect(m) AS nodes2, r + CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2) + YIELD node1, node2, similarity + WITH node1, node2, similarity, r + WHERE similarity >= $threshold + RETURN node2.name AS source, id(node2) AS source_id, type(r) AS relationship, id(r) AS relation_id, node1.name AS destination, id(node1) AS destination_id, similarity + ORDER BY similarity DESC + LIMIT $limit; + """ + params = { + "n_embedding": n_embedding, + "threshold": self.threshold, + "user_id": filters["user_id"], + "agent_id": filters["agent_id"], + "limit": limit, + } + else: + cypher_query = """ + MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity) + WHERE n.embedding IS NOT NULL + WITH collect(n) AS nodes1, collect(m) AS nodes2, r + CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2) + YIELD node1, node2, similarity + WITH node1, node2, similarity, r + WHERE similarity >= $threshold + RETURN node1.name AS source, id(node1) AS source_id, type(r) AS relationship, id(r) AS relation_id, node2.name AS destination, id(node2) AS destination_id, similarity + UNION + MATCH (n:Entity {user_id: $user_id})<-[r]-(m:Entity) + WHERE n.embedding IS NOT NULL + WITH collect(n) AS nodes1, collect(m) AS nodes2, r + CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2) + YIELD node1, node2, similarity + WITH node1, node2, similarity, r + WHERE similarity >= $threshold + RETURN node2.name AS source, id(node2) AS source_id, type(r) AS relationship, id(r) AS relation_id, node1.name AS destination, id(node1) AS destination_id, similarity + ORDER BY similarity DESC + LIMIT $limit; + """ + params = { + "n_embedding": n_embedding, + "threshold": self.threshold, + "user_id": filters["user_id"], + "limit": limit, + } + ans = self.graph.query(cypher_query, params=params) result_relations.extend(ans) @@ -300,38 +351,54 @@ class MemoryGraph: logger.debug(f"Deleted relationships: {to_be_deleted}") return to_be_deleted - def _delete_entities(self, to_be_deleted, user_id): + def _delete_entities(self, to_be_deleted, filters): """Delete the entities from the graph.""" + user_id = filters["user_id"] + agent_id = filters.get("agent_id", None) results = [] + for item in to_be_deleted: source = item["source"] destination = item["destination"] relationship = item["relationship"] + # Build the agent filter for the query + agent_filter = "" + params = { + "source_name": source, + "dest_name": destination, + "user_id": user_id, + } + + if agent_id: + agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id" + params["agent_id"] = agent_id + # Delete the specific relationship between nodes cypher = f""" MATCH (n:Entity {{name: $source_name, user_id: $user_id}}) -[r:{relationship}]-> - (m {{name: $dest_name, user_id: $user_id}}) + (m:Entity {{name: $dest_name, user_id: $user_id}}) + WHERE 1=1 {agent_filter} DELETE r RETURN n.name AS source, m.name AS target, type(r) AS relationship """ - params = { - "source_name": source, - "dest_name": destination, - "user_id": user_id, - } + result = self.graph.query(cypher, params=params) results.append(result) + return results # added Entity label to all nodes for vector search to work - def _add_entities(self, to_be_added, user_id, entity_type_map): + def _add_entities(self, to_be_added, filters, entity_type_map): """Add the new entities to the graph. Merge the nodes if they already exist.""" + user_id = filters["user_id"] + agent_id = filters.get("agent_id", None) results = [] + for item in to_be_added: # entities source = item["source"] @@ -346,18 +413,21 @@ class MemoryGraph: source_embedding = self.embedding_model.embed(source) dest_embedding = self.embedding_model.embed(destination) - # search for the nodes with the closest embeddings; this is basically - # comparison of one embedding to all embeddings in a graph -> vector - # search with cosine similarity metric - source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9) - destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9) + # search for the nodes with the closest embeddings + source_node_search_result = self._search_source_node(source_embedding, filters, threshold=0.9) + destination_node_search_result = self._search_destination_node(dest_embedding, filters, threshold=0.9) + # Prepare agent_id for node creation + agent_id_clause = "" + if agent_id: + agent_id_clause = ", agent_id: $agent_id" + # TODO: Create a cypher query and common params for all the cases if not destination_node_search_result and source_node_search_result: cypher = f""" MATCH (source:Entity) WHERE id(source) = $source_id - MERGE (destination:{destination_type}:Entity {{name: $destination_name, user_id: $user_id}}) + MERGE (destination:{destination_type}:Entity {{name: $destination_name, user_id: $user_id{agent_id_clause}}}) ON CREATE SET destination.created = timestamp(), destination.embedding = $destination_embedding, @@ -374,11 +444,14 @@ class MemoryGraph: "destination_embedding": dest_embedding, "user_id": user_id, } + if agent_id: + params["agent_id"] = agent_id + elif destination_node_search_result and not source_node_search_result: cypher = f""" MATCH (destination:Entity) WHERE id(destination) = $destination_id - MERGE (source:{source_type}:Entity {{name: $source_name, user_id: $user_id}}) + MERGE (source:{source_type}:Entity {{name: $source_name, user_id: $user_id{agent_id_clause}}}) ON CREATE SET source.created = timestamp(), source.embedding = $source_embedding, @@ -395,6 +468,9 @@ class MemoryGraph: "source_embedding": source_embedding, "user_id": user_id, } + if agent_id: + params["agent_id"] = agent_id + elif source_node_search_result and destination_node_search_result: cypher = f""" MATCH (source:Entity) @@ -412,12 +488,15 @@ class MemoryGraph: "destination_id": destination_node_search_result[0]["id(destination_candidate)"], "user_id": user_id, } + if agent_id: + params["agent_id"] = agent_id + else: cypher = f""" - MERGE (n:{source_type}:Entity {{name: $source_name, user_id: $user_id}}) + MERGE (n:{source_type}:Entity {{name: $source_name, user_id: $user_id{agent_id_clause}}}) ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding, n:Entity ON MATCH SET n.embedding = $source_embedding - MERGE (m:{destination_type}:Entity {{name: $dest_name, user_id: $user_id}}) + MERGE (m:{destination_type}:Entity {{name: $dest_name, user_id: $user_id{agent_id_clause}}}) ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding, m:Entity ON MATCH SET m.embedding = $dest_embedding MERGE (n)-[rel:{relationship}]->(m) @@ -431,6 +510,9 @@ class MemoryGraph: "dest_embedding": dest_embedding, "user_id": user_id, } + if agent_id: + params["agent_id"] = agent_id + result = self.graph.query(cypher, params=params) results.append(result) return results @@ -442,37 +524,80 @@ class MemoryGraph: item["destination"] = item["destination"].lower().replace(" ", "_") return entity_list - def _search_source_node(self, source_embedding, user_id, threshold=0.9): - cypher = """ - CALL vector_search.search("memzero", 1, $source_embedding) - YIELD distance, node, similarity - WITH node AS source_candidate, similarity - WHERE source_candidate.user_id = $user_id AND similarity >= $threshold - RETURN id(source_candidate); - """ - - params = { - "source_embedding": source_embedding, - "user_id": user_id, - "threshold": threshold, - } + def _search_source_node(self, source_embedding, filters, threshold=0.9): + """Search for source nodes with similar embeddings.""" + user_id = filters["user_id"] + agent_id = filters.get("agent_id", None) + + if agent_id: + cypher = """ + CALL vector_search.search("memzero", 1, $source_embedding) + YIELD distance, node, similarity + WITH node AS source_candidate, similarity + WHERE source_candidate.user_id = $user_id + AND source_candidate.agent_id = $agent_id + AND similarity >= $threshold + RETURN id(source_candidate); + """ + params = { + "source_embedding": source_embedding, + "user_id": user_id, + "agent_id": agent_id, + "threshold": threshold, + } + else: + cypher = """ + CALL vector_search.search("memzero", 1, $source_embedding) + YIELD distance, node, similarity + WITH node AS source_candidate, similarity + WHERE source_candidate.user_id = $user_id + AND similarity >= $threshold + RETURN id(source_candidate); + """ + params = { + "source_embedding": source_embedding, + "user_id": user_id, + "threshold": threshold, + } result = self.graph.query(cypher, params=params) return result - def _search_destination_node(self, destination_embedding, user_id, threshold=0.9): - cypher = """ - CALL vector_search.search("memzero", 1, $destination_embedding) - YIELD distance, node, similarity - WITH node AS destination_candidate, similarity - WHERE node.user_id = $user_id AND similarity >= $threshold - RETURN id(destination_candidate); - """ - params = { - "destination_embedding": destination_embedding, - "user_id": user_id, - "threshold": threshold, - } + def _search_destination_node(self, destination_embedding, filters, threshold=0.9): + """Search for destination nodes with similar embeddings.""" + user_id = filters["user_id"] + agent_id = filters.get("agent_id", None) + + if agent_id: + cypher = """ + CALL vector_search.search("memzero", 1, $destination_embedding) + YIELD distance, node, similarity + WITH node AS destination_candidate, similarity + WHERE node.user_id = $user_id + AND node.agent_id = $agent_id + AND similarity >= $threshold + RETURN id(destination_candidate); + """ + params = { + "destination_embedding": destination_embedding, + "user_id": user_id, + "agent_id": agent_id, + "threshold": threshold, + } + else: + cypher = """ + CALL vector_search.search("memzero", 1, $destination_embedding) + YIELD distance, node, similarity + WITH node AS destination_candidate, similarity + WHERE node.user_id = $user_id + AND similarity >= $threshold + RETURN id(destination_candidate); + """ + params = { + "destination_embedding": destination_embedding, + "user_id": user_id, + "threshold": threshold, + } result = self.graph.query(cypher, params=params) return result diff --git a/pyproject.toml b/pyproject.toml index c3b86732..312ba754 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,8 @@ llms = [ "ollama>=0.1.0", "vertexai>=0.1.0", "google-generativeai>=0.3.0", + "google-genai>=1.0.0", + ] extras = [ "boto3>=1.34.0",