Adding Gemini (#1862)

This commit is contained in:
Pranav Puranik
2024-09-27 11:46:40 -05:00
committed by GitHub
parent 699741c760
commit aaf8e6e7ff
9 changed files with 379 additions and 4 deletions

View File

@@ -12,7 +12,8 @@ install:
install_all: install_all:
poetry install poetry install
poetry run pip install groq together boto3 litellm ollama chromadb sentence_transformers vertexai poetry run pip install groq together boto3 litellm ollama chromadb sentence_transformers vertexai \
google-generativeai
# Format code with ruff # Format code with ruff
format: format:

View File

@@ -0,0 +1,41 @@
---
title: Gemini
---
To use Gemini embedding models, set the `GOOGLE_API_KEY` environment variables. You can obtain the Gemini API key from [here](https://aistudio.google.com/app/apikey).
### Usage
```python
import os
from mem0 import Memory
os.environ["GOOGLE_API_KEY"] = "key"
config = {
"embedder": {
"provider": "gemini",
"config": {
"model": "models/text-embedding-004"
}
},
"vector_store": {
"provider": "qdrant",
"config": {
"collection_name": "test",
"embedding_model_dims": 768,
}
},
}
m = Memory.from_config(config)
m.add("I'm visiting Paris", user_id="john")
```
### Config
Here are the parameters available for configuring Gemini embedder:
| Parameter | Description | Default Value |
| --- | --- | --- |
| `model` | The name of the embedding model to use | `models/text-embedding-004` |

View File

@@ -13,6 +13,7 @@ See the list of supported embedders below.
<Card title="Azure OpenAI" href="/components/embedders/models/azure_openai"></Card> <Card title="Azure OpenAI" href="/components/embedders/models/azure_openai"></Card>
<Card title="Ollama" href="/components/embedders/models/ollama"></Card> <Card title="Ollama" href="/components/embedders/models/ollama"></Card>
<Card title="Hugging Face" href="/components/embedders/models/huggingface"></Card> <Card title="Hugging Face" href="/components/embedders/models/huggingface"></Card>
<Card title="Gemini" href="/components/embedders/models/gemini"></Card>
<Card title="Vertex AI" href="/components/embedders/models/vertexai"></Card> <Card title="Vertex AI" href="/components/embedders/models/vertexai"></Card>
</CardGroup> </CardGroup>

View File

@@ -125,7 +125,8 @@
"components/embedders/models/openai", "components/embedders/models/openai",
"components/embedders/models/azure_openai", "components/embedders/models/azure_openai",
"components/embedders/models/ollama", "components/embedders/models/ollama",
"components/embedders/models/huggingface" "components/embedders/models/huggingface",
"components/embedders/models/gemini"
] ]
} }
] ]

View File

@@ -13,7 +13,7 @@ class EmbedderConfig(BaseModel):
@field_validator("config") @field_validator("config")
def validate_config(cls, v, values): def validate_config(cls, v, values):
provider = values.data.get("provider") provider = values.data.get("provider")
if provider in ["openai", "ollama", "huggingface", "azure_openai", "vertexai"]: if provider in ["openai", "ollama", "huggingface", "azure_openai", "gemini", "vertexai"]:
return v return v
else: else:
raise ValueError(f"Unsupported embedding provider: {provider}") raise ValueError(f"Unsupported embedding provider: {provider}")

27
mem0/embeddings/gemini.py Normal file
View File

@@ -0,0 +1,27 @@
import os
from typing import Optional
import google.generativeai as genai
from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase
class GoogleGenAIEmbedding(EmbeddingBase):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config)
if self.config.model is None:
self.config.model = "models/text-embedding-004" # embedding-dim = 768
genai.configure(api_key=self.config.api_key or os.getenv("GOOGLE_API_KEY"))
def embed(self, text):
"""
Get the embedding for the given text using Google Generative AI.
Args:
text (str): The text to embed.
Returns:
list: The embedding vector.
"""
text = text.replace("\n", " ")
response = genai.embed_content(model=self.config.model, content=text)
return response['embedding']

View File

@@ -41,6 +41,7 @@ class EmbedderFactory:
"ollama": "mem0.embeddings.ollama.OllamaEmbedding", "ollama": "mem0.embeddings.ollama.OllamaEmbedding",
"huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding", "huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding",
"azure_openai": "mem0.embeddings.azure_openai.AzureOpenAIEmbedding", "azure_openai": "mem0.embeddings.azure_openai.AzureOpenAIEmbedding",
"gemini": "mem0.embeddings.gemini.GoogleGenAIEmbedding",
} }
@classmethod @classmethod

268
poetry.lock generated
View File

@@ -211,6 +211,17 @@ files = [
{file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"},
] ]
[[package]]
name = "cachetools"
version = "5.5.0"
description = "Extensible memoizing collections and decorators"
optional = false
python-versions = ">=3.7"
files = [
{file = "cachetools-5.5.0-py3-none-any.whl", hash = "sha256:02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292"},
{file = "cachetools-5.5.0.tar.gz", hash = "sha256:2cc24fb4cbe39633fb7badd9db9ca6295d766d9c2995f245725a46715d050f2a"},
]
[[package]] [[package]]
name = "certifi" name = "certifi"
version = "2024.7.4" version = "2024.7.4"
@@ -458,6 +469,150 @@ files = [
{file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"}, {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"},
] ]
[[package]]
name = "google-ai-generativelanguage"
version = "0.6.9"
description = "Google Ai Generativelanguage API client library"
optional = false
python-versions = ">=3.7"
files = [
{file = "google_ai_generativelanguage-0.6.9-py3-none-any.whl", hash = "sha256:50360cd80015d1a8cc70952e98560f32fa06ddee2e8e9f4b4b98e431dc561e0b"},
{file = "google_ai_generativelanguage-0.6.9.tar.gz", hash = "sha256:899f1d3a06efa9739f1cd9d2788070178db33c89d4a76f2e8f4da76f649155fa"},
]
[package.dependencies]
google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev"
proto-plus = ">=1.22.3,<2.0.0dev"
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev"
[[package]]
name = "google-api-core"
version = "2.20.0"
description = "Google API client core library"
optional = false
python-versions = ">=3.7"
files = [
{file = "google_api_core-2.20.0-py3-none-any.whl", hash = "sha256:ef0591ef03c30bb83f79b3d0575c3f31219001fc9c5cf37024d08310aeffed8a"},
{file = "google_api_core-2.20.0.tar.gz", hash = "sha256:f74dff1889ba291a4b76c5079df0711810e2d9da81abfdc99957bc961c1eb28f"},
]
[package.dependencies]
google-auth = ">=2.14.1,<3.0.dev0"
googleapis-common-protos = ">=1.56.2,<2.0.dev0"
grpcio = [
{version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
{version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
]
grpcio-status = [
{version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
{version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
]
proto-plus = ">=1.22.3,<2.0.0dev"
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
requests = ">=2.18.0,<3.0.0.dev0"
[package.extras]
grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0)"]
grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
[[package]]
name = "google-api-python-client"
version = "2.146.0"
description = "Google API Client Library for Python"
optional = false
python-versions = ">=3.7"
files = [
{file = "google_api_python_client-2.146.0-py2.py3-none-any.whl", hash = "sha256:b1e62c9889c5ef6022f11d30d7ef23dc55100300f0e8aaf8aa09e8e92540acad"},
{file = "google_api_python_client-2.146.0.tar.gz", hash = "sha256:41f671be10fa077ee5143ee9f0903c14006d39dc644564f4e044ae96b380bf68"},
]
[package.dependencies]
google-api-core = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0.dev0"
google-auth = ">=1.32.0,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0.dev0"
google-auth-httplib2 = ">=0.2.0,<1.0.0"
httplib2 = ">=0.19.0,<1.dev0"
uritemplate = ">=3.0.1,<5"
[[package]]
name = "google-auth"
version = "2.35.0"
description = "Google Authentication Library"
optional = false
python-versions = ">=3.7"
files = [
{file = "google_auth-2.35.0-py2.py3-none-any.whl", hash = "sha256:25df55f327ef021de8be50bad0dfd4a916ad0de96da86cd05661c9297723ad3f"},
{file = "google_auth-2.35.0.tar.gz", hash = "sha256:f4c64ed4e01e8e8b646ef34c018f8bf3338df0c8e37d8b3bba40e7f574a3278a"},
]
[package.dependencies]
cachetools = ">=2.0.0,<6.0"
pyasn1-modules = ">=0.2.1"
rsa = ">=3.1.4,<5"
[package.extras]
aiohttp = ["aiohttp (>=3.6.2,<4.0.0.dev0)", "requests (>=2.20.0,<3.0.0.dev0)"]
enterprise-cert = ["cryptography", "pyopenssl"]
pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"]
reauth = ["pyu2f (>=0.1.5)"]
requests = ["requests (>=2.20.0,<3.0.0.dev0)"]
[[package]]
name = "google-auth-httplib2"
version = "0.2.0"
description = "Google Authentication Library: httplib2 transport"
optional = false
python-versions = "*"
files = [
{file = "google-auth-httplib2-0.2.0.tar.gz", hash = "sha256:38aa7badf48f974f1eb9861794e9c0cb2a0511a4ec0679b1f886d108f5640e05"},
{file = "google_auth_httplib2-0.2.0-py2.py3-none-any.whl", hash = "sha256:b65a0a2123300dd71281a7bf6e64d65a0759287df52729bdd1ae2e47dc311a3d"},
]
[package.dependencies]
google-auth = "*"
httplib2 = ">=0.19.0"
[[package]]
name = "google-generativeai"
version = "0.8.1"
description = "Google Generative AI High level API client library and tools."
optional = false
python-versions = ">=3.9"
files = [
{file = "google_generativeai-0.8.1-py3-none-any.whl", hash = "sha256:b031877f24d51af0945207657c085896a0a886eceec7a1cb7029327b0aa6e2f6"},
]
[package.dependencies]
google-ai-generativelanguage = "0.6.9"
google-api-core = "*"
google-api-python-client = "*"
google-auth = ">=2.15.0"
protobuf = "*"
pydantic = "*"
tqdm = "*"
typing-extensions = "*"
[package.extras]
dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"]
[[package]]
name = "googleapis-common-protos"
version = "1.65.0"
description = "Common protobufs used in Google APIs"
optional = false
python-versions = ">=3.7"
files = [
{file = "googleapis_common_protos-1.65.0-py2.py3-none-any.whl", hash = "sha256:2972e6c496f435b92590fd54045060867f3fe9be2c82ab148fc8885035479a63"},
{file = "googleapis_common_protos-1.65.0.tar.gz", hash = "sha256:334a29d07cddc3aa01dee4988f9afd9b2916ee2ff49d6b757155dc0d197852c0"},
]
[package.dependencies]
protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
[package.extras]
grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"]
[[package]] [[package]]
name = "greenlet" name = "greenlet"
version = "3.0.3" version = "3.0.3"
@@ -587,6 +742,22 @@ files = [
[package.extras] [package.extras]
protobuf = ["grpcio-tools (>=1.64.1)"] protobuf = ["grpcio-tools (>=1.64.1)"]
[[package]]
name = "grpcio-status"
version = "1.62.3"
description = "Status proto mapping for gRPC"
optional = false
python-versions = ">=3.6"
files = [
{file = "grpcio-status-1.62.3.tar.gz", hash = "sha256:289bdd7b2459794a12cf95dc0cb727bd4a1742c37bd823f760236c937e53a485"},
{file = "grpcio_status-1.62.3-py3-none-any.whl", hash = "sha256:f9049b762ba8de6b1086789d8315846e094edac2c50beaf462338b301a8fd4b8"},
]
[package.dependencies]
googleapis-common-protos = ">=1.5.5"
grpcio = ">=1.62.3"
protobuf = ">=4.21.6"
[[package]] [[package]]
name = "grpcio-tools" name = "grpcio-tools"
version = "1.62.2" version = "1.62.2"
@@ -713,6 +884,20 @@ http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"] socks = ["socksio (==1.*)"]
trio = ["trio (>=0.22.0,<0.26.0)"] trio = ["trio (>=0.22.0,<0.26.0)"]
[[package]]
name = "httplib2"
version = "0.22.0"
description = "A comprehensive HTTP client library."
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
files = [
{file = "httplib2-0.22.0-py3-none-any.whl", hash = "sha256:14ae0a53c1ba8f3d37e9e27cf37eabb0fb9980f435ba405d546948b009dd64dc"},
{file = "httplib2-0.22.0.tar.gz", hash = "sha256:d7a10bc5ef5ab08322488bde8c726eeee5c8618723fdb399597ec58f3d82df81"},
]
[package.dependencies]
pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0.2,<3.0.3 || >3.0.3,<4", markers = "python_version > \"3.0\""}
[[package]] [[package]]
name = "httpx" name = "httpx"
version = "0.27.0" version = "0.27.0"
@@ -1281,6 +1466,23 @@ dev = ["black", "flake8", "flake8-print", "isort", "pre-commit"]
sentry = ["django", "sentry-sdk"] sentry = ["django", "sentry-sdk"]
test = ["coverage", "flake8", "freezegun (==0.3.15)", "mock (>=2.0.0)", "pylint", "pytest", "pytest-timeout"] test = ["coverage", "flake8", "freezegun (==0.3.15)", "mock (>=2.0.0)", "pylint", "pytest", "pytest-timeout"]
[[package]]
name = "proto-plus"
version = "1.24.0"
description = "Beautiful, Pythonic protocol buffers."
optional = false
python-versions = ">=3.7"
files = [
{file = "proto-plus-1.24.0.tar.gz", hash = "sha256:30b72a5ecafe4406b0d339db35b56c4059064e69227b8c3bda7462397f966445"},
{file = "proto_plus-1.24.0-py3-none-any.whl", hash = "sha256:402576830425e5f6ce4c2a6702400ac79897dab0b4343821aa5188b0fab81a12"},
]
[package.dependencies]
protobuf = ">=3.19.0,<6.0.0dev"
[package.extras]
testing = ["google-api-core (>=1.31.5)"]
[[package]] [[package]]
name = "protobuf" name = "protobuf"
version = "4.25.4" version = "4.25.4"
@@ -1301,6 +1503,31 @@ files = [
{file = "protobuf-4.25.4.tar.gz", hash = "sha256:0dc4a62cc4052a036ee2204d26fe4d835c62827c855c8a03f29fe6da146b380d"}, {file = "protobuf-4.25.4.tar.gz", hash = "sha256:0dc4a62cc4052a036ee2204d26fe4d835c62827c855c8a03f29fe6da146b380d"},
] ]
[[package]]
name = "pyasn1"
version = "0.6.1"
description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)"
optional = false
python-versions = ">=3.8"
files = [
{file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"},
{file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"},
]
[[package]]
name = "pyasn1-modules"
version = "0.4.1"
description = "A collection of ASN.1-based protocols modules"
optional = false
python-versions = ">=3.8"
files = [
{file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"},
{file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"},
]
[package.dependencies]
pyasn1 = ">=0.4.6,<0.7.0"
[[package]] [[package]]
name = "pydantic" name = "pydantic"
version = "2.8.2" version = "2.8.2"
@@ -1424,6 +1651,20 @@ files = [
[package.dependencies] [package.dependencies]
typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
[[package]]
name = "pyparsing"
version = "3.1.4"
description = "pyparsing module - Classes and methods to define and execute parsing grammars"
optional = false
python-versions = ">=3.6.8"
files = [
{file = "pyparsing-3.1.4-py3-none-any.whl", hash = "sha256:a6a7ee4235a3f944aa1fa2249307708f893fe5717dc603503c6c7969c070fb7c"},
{file = "pyparsing-3.1.4.tar.gz", hash = "sha256:f86ec8d1a83f11977c9a6ea7598e8c27fc5cddfa5b07ea2241edbbde1d7bc032"},
]
[package.extras]
diagrams = ["jinja2", "railroad-diagrams"]
[[package]] [[package]]
name = "pytest" name = "pytest"
version = "8.2.2" version = "8.2.2"
@@ -1621,6 +1862,20 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"] socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "rsa"
version = "4.9"
description = "Pure-Python RSA implementation"
optional = false
python-versions = ">=3.6,<4"
files = [
{file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"},
{file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"},
]
[package.dependencies]
pyasn1 = ">=0.1.3"
[[package]] [[package]]
name = "ruff" name = "ruff"
version = "0.6.5" version = "0.6.5"
@@ -1844,6 +2099,17 @@ files = [
mypy-extensions = ">=0.3.0" mypy-extensions = ">=0.3.0"
typing-extensions = ">=3.7.4" typing-extensions = ">=3.7.4"
[[package]]
name = "uritemplate"
version = "4.1.1"
description = "Implementation of RFC 6570 URI Templates"
optional = false
python-versions = ">=3.6"
files = [
{file = "uritemplate-4.1.1-py2.py3-none-any.whl", hash = "sha256:830c08b8d99bdd312ea4ead05994a38e8936266f84b9a7878232db50b044e02e"},
{file = "uritemplate-4.1.1.tar.gz", hash = "sha256:4346edfc5c3b79f694bccd6d6099a322bbeb628dbf2cd86eea55a456ce5124f0"},
]
[[package]] [[package]]
name = "urllib3" name = "urllib3"
version = "2.2.2" version = "2.2.2"
@@ -1967,4 +2233,4 @@ multidict = ">=4.0"
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.9,<4.0" python-versions = ">=3.9,<4.0"
content-hash = "56197730e020f77ee9824292f34348bbe935b42519b4027f6fb131084b88300b" content-hash = "d0b1ade58154f0acd26f2d8ee6cb05c9ddcd10efa0fc5835af12a82e43628bbe"

View File

@@ -0,0 +1,37 @@
from unittest.mock import patch
import pytest
from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.gemini import GoogleGenAIEmbedding
@pytest.fixture
def mock_genai():
with patch("mem0.embeddings.gemini.genai.embed_content") as mock_genai:
yield mock_genai
@pytest.fixture
def config():
return BaseEmbedderConfig(
api_key="dummy_api_key",
model="test_model"
)
def test_embed_query(mock_genai, config):
mock_embedding_response = {
'embedding': [0.1, 0.2, 0.3, 0.4]
}
mock_genai.return_value = mock_embedding_response
embedder = GoogleGenAIEmbedding(config)
text = "Hello, world!"
embedding = embedder.embed(text)
assert embedding == [0.1, 0.2, 0.3, 0.4]
mock_genai.assert_called_once_with(
model="test_model",
content="Hello, world!"
)