Add support for Langchain VectorStores (#2518)

This commit is contained in:
Dev Khant
2025-04-11 13:37:18 +05:30
committed by GitHub
parent 8b789adb15
commit 19d7beef43
9 changed files with 386 additions and 5 deletions

View File

@@ -39,14 +39,14 @@ def test_create_col(faiss_instance, mock_faiss_index):
# Test creating a collection with euclidean distance
with patch('faiss.IndexFlatL2', return_value=mock_faiss_index) as mock_index_flat_l2:
with patch('faiss.write_index'):
faiss_instance.create_col(name="new_collection", vector_size=256)
mock_index_flat_l2.assert_called_once_with(256)
faiss_instance.create_col(name="new_collection")
mock_index_flat_l2.assert_called_once_with(faiss_instance.embedding_model_dims)
# Test creating a collection with inner product distance
with patch('faiss.IndexFlatIP', return_value=mock_faiss_index) as mock_index_flat_ip:
with patch('faiss.write_index'):
faiss_instance.create_col(name="new_collection", vector_size=256, distance="inner_product")
mock_index_flat_ip.assert_called_once_with(256)
faiss_instance.create_col(name="new_collection", distance="inner_product")
mock_index_flat_ip.assert_called_once_with(faiss_instance.embedding_model_dims)
def test_insert(faiss_instance, mock_faiss_index):

View File

@@ -0,0 +1,101 @@
from unittest.mock import Mock, patch
import pytest
from langchain_community.vectorstores import VectorStore
from mem0.vector_stores.langchain import Langchain
@pytest.fixture
def mock_langchain_client():
with patch("langchain_community.vectorstores.VectorStore") as mock_client:
yield mock_client
@pytest.fixture
def langchain_instance(mock_langchain_client):
mock_client = Mock(spec=VectorStore)
return Langchain(client=mock_client, collection_name="test_collection")
def test_insert_vectors(langchain_instance):
# Test data
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
payloads = [{"data": "text1", "name": "vector1"}, {"data": "text2", "name": "vector2"}]
ids = ["id1", "id2"]
# Test with add_embeddings method
langchain_instance.client.add_embeddings = Mock()
langchain_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
langchain_instance.client.add_embeddings.assert_called_once_with(
embeddings=vectors,
metadatas=payloads,
ids=ids
)
# Test with add_texts method
delattr(langchain_instance.client, "add_embeddings") # Remove attribute completely
langchain_instance.client.add_texts = Mock()
langchain_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
langchain_instance.client.add_texts.assert_called_once_with(
texts=["text1", "text2"],
metadatas=payloads,
ids=ids
)
# Test with empty payloads
langchain_instance.client.add_texts.reset_mock()
langchain_instance.insert(vectors=vectors, payloads=None, ids=ids)
langchain_instance.client.add_texts.assert_called_once_with(
texts=["", ""],
metadatas=None,
ids=ids
)
def test_search_vectors(langchain_instance):
# Mock search results
mock_docs = [
Mock(metadata={"name": "vector1"}, id="id1"),
Mock(metadata={"name": "vector2"}, id="id2")
]
langchain_instance.client.similarity_search_by_vector.return_value = mock_docs
# Test search without filters
vectors = [[0.1, 0.2, 0.3]]
results = langchain_instance.search(query="", vectors=vectors, limit=2)
langchain_instance.client.similarity_search_by_vector.assert_called_once_with(
embedding=vectors,
k=2
)
assert len(results) == 2
assert results[0].id == "id1"
assert results[0].payload == {"name": "vector1"}
assert results[1].id == "id2"
assert results[1].payload == {"name": "vector2"}
# Test search with filters
filters = {"name": "vector1"}
langchain_instance.search(query="", vectors=vectors, limit=2, filters=filters)
langchain_instance.client.similarity_search_by_vector.assert_called_with(
embedding=vectors,
k=2,
filter=filters
)
def test_get_vector(langchain_instance):
# Mock get result
mock_doc = Mock(metadata={"name": "vector1"}, id="id1")
langchain_instance.client.get_by_ids.return_value = [mock_doc]
# Test get existing vector
result = langchain_instance.get("id1")
langchain_instance.client.get_by_ids.assert_called_once_with(["id1"])
assert result is not None
assert result.id == "id1"
assert result.payload == {"name": "vector1"}
# Test get non-existent vector
langchain_instance.client.get_by_ids.return_value = []
result = langchain_instance.get("non_existent_id")
assert result is None