Fix Embedding Dimension Parameter Not Being Passed (#2304)
This commit is contained in:
committed by
GitHub
parent
dd1f2989bc
commit
8c6d16a6f0
@@ -28,5 +28,5 @@ class GoogleGenAIEmbedding(EmbeddingBase):
|
|||||||
list: The embedding vector.
|
list: The embedding vector.
|
||||||
"""
|
"""
|
||||||
text = text.replace("\n", " ")
|
text = text.replace("\n", " ")
|
||||||
response = genai.embed_content(model=self.config.model, content=text)
|
response = genai.embed_content(model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims)
|
||||||
return response["embedding"]
|
return response["embedding"]
|
||||||
|
|||||||
@@ -29,4 +29,4 @@ class OpenAIEmbedding(EmbeddingBase):
|
|||||||
list: The embedding vector.
|
list: The embedding vector.
|
||||||
"""
|
"""
|
||||||
text = text.replace("\n", " ")
|
text = text.replace("\n", " ")
|
||||||
return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding
|
return self.client.embeddings.create(input=[text], model=self.config.model, dimensions = self.config.embedding_dims).data[0].embedding
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ def mock_genai():
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def config():
|
def config():
|
||||||
return BaseEmbedderConfig(api_key="dummy_api_key", model="test_model")
|
return BaseEmbedderConfig(api_key="dummy_api_key", model="test_model", embedding_dims=786)
|
||||||
|
|
||||||
|
|
||||||
def test_embed_query(mock_genai, config):
|
def test_embed_query(mock_genai, config):
|
||||||
@@ -25,4 +25,4 @@ def test_embed_query(mock_genai, config):
|
|||||||
embedding = embedder.embed(text)
|
embedding = embedder.embed(text)
|
||||||
|
|
||||||
assert embedding == [0.1, 0.2, 0.3, 0.4]
|
assert embedding == [0.1, 0.2, 0.3, 0.4]
|
||||||
mock_genai.assert_called_once_with(model="test_model", content="Hello, world!")
|
mock_genai.assert_called_once_with(model="test_model", content="Hello, world!", output_dimensionality=786)
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ def test_embed_default_model(mock_openai_client):
|
|||||||
|
|
||||||
result = embedder.embed("Hello world")
|
result = embedder.embed("Hello world")
|
||||||
|
|
||||||
mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small")
|
mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small", dimensions = 1536)
|
||||||
assert result == [0.1, 0.2, 0.3]
|
assert result == [0.1, 0.2, 0.3]
|
||||||
|
|
||||||
|
|
||||||
@@ -35,7 +35,7 @@ def test_embed_custom_model(mock_openai_client):
|
|||||||
result = embedder.embed("Test embedding")
|
result = embedder.embed("Test embedding")
|
||||||
|
|
||||||
mock_openai_client.embeddings.create.assert_called_once_with(
|
mock_openai_client.embeddings.create.assert_called_once_with(
|
||||||
input=["Test embedding"], model="text-embedding-2-medium"
|
input=["Test embedding"], model="text-embedding-2-medium", dimensions = 1024
|
||||||
)
|
)
|
||||||
assert result == [0.4, 0.5, 0.6]
|
assert result == [0.4, 0.5, 0.6]
|
||||||
|
|
||||||
@@ -49,7 +49,7 @@ def test_embed_removes_newlines(mock_openai_client):
|
|||||||
|
|
||||||
result = embedder.embed("Hello\nworld")
|
result = embedder.embed("Hello\nworld")
|
||||||
|
|
||||||
mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small")
|
mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small", dimensions = 1536)
|
||||||
assert result == [0.7, 0.8, 0.9]
|
assert result == [0.7, 0.8, 0.9]
|
||||||
|
|
||||||
|
|
||||||
@@ -63,7 +63,7 @@ def test_embed_without_api_key_env_var(mock_openai_client):
|
|||||||
result = embedder.embed("Testing API key")
|
result = embedder.embed("Testing API key")
|
||||||
|
|
||||||
mock_openai_client.embeddings.create.assert_called_once_with(
|
mock_openai_client.embeddings.create.assert_called_once_with(
|
||||||
input=["Testing API key"], model="text-embedding-3-small"
|
input=["Testing API key"], model="text-embedding-3-small", dimensions = 1536
|
||||||
)
|
)
|
||||||
assert result == [1.0, 1.1, 1.2]
|
assert result == [1.0, 1.1, 1.2]
|
||||||
|
|
||||||
@@ -79,6 +79,6 @@ def test_embed_uses_environment_api_key(mock_openai_client, monkeypatch):
|
|||||||
result = embedder.embed("Environment key test")
|
result = embedder.embed("Environment key test")
|
||||||
|
|
||||||
mock_openai_client.embeddings.create.assert_called_once_with(
|
mock_openai_client.embeddings.create.assert_called_once_with(
|
||||||
input=["Environment key test"], model="text-embedding-3-small"
|
input=["Environment key test"], model="text-embedding-3-small", dimensions = 1536
|
||||||
)
|
)
|
||||||
assert result == [1.3, 1.4, 1.5]
|
assert result == [1.3, 1.4, 1.5]
|
||||||
|
|||||||
Reference in New Issue
Block a user