[Feature] Add Postgres data loader (#918)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel
2023-11-08 23:50:46 -08:00
committed by GitHub
parent f7dd65a3de
commit 7de8d85199
12 changed files with 285 additions and 27 deletions

View File

@@ -20,6 +20,7 @@ Embedchain comes with built-in support for various data sources. We handle the c
<Card title="🙌 OpenAPI" href="/data-sources/openapi"></Card>
<Card title="🎥📺 youtube video" href="/data-sources/youtube-video"></Card>
<Card title="📬 Gmail" href="/data-sources/gmail"></Card>
<Card title="🐘 Postgres" href="/data-sources/postgres"></Card>
</CardGroup>
<br />

View File

@@ -0,0 +1,64 @@
---
title: '🐘 Postgres'
---
1. Set up the Postgres loader by configuring the Postgres connection.
```Python
from embedchain.loaders.postgres import PostgresLoader

config = {
    "host": "host_address",
    "port": "port_number",
    "dbname": "database_name",
    "user": "username",
    "password": "password",
}

# Alternatively, pass a single connection URL instead of individual fields:
# config = {
#     "url": "your_postgres_url"
# }

postgres_loader = PostgresLoader(config=config)
```
You can set up the loader either by passing a PostgreSQL connection URL or by providing the individual connection fields shown above.
For details on valid URLs and connection parameters, check the PostgreSQL [documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING:~:text=34.1.1.%C2%A0Connection%20Strings-,%23,-Several%20libpq%20functions).
NOTE: if you provide the `url` field in the config, all other fields are ignored; otherwise the fields are combined into a connection string, as sketched below.
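A minimal sketch of that mapping (it mirrors `PostgresLoader._setup_loader` in this commit; the values are placeholders):
```Python
config = {
    "host": "localhost",
    "port": "5432",
    "dbname": "mydb",
    "user": "me",
    "password": "secret",
}
conninfo = " ".join(f"{key}={value}" for key, value in config.items())
# -> "host=localhost port=5432 dbname=mydb user=me password=secret"
```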
2. Once the loader is set up, you can create an app and load data with it:
```Python
import os
from embedchain.pipeline import Pipeline as App
os.environ["OPENAI_API_KEY"] = "sk-xxx"
app = App()
question = "What is Elon Musk's net worth?"
response = app.query(question)
# Answer: As of September 2021, Elon Musk's net worth is estimated to be around $250 billion, making him one of the wealthiest individuals in the world. However, please note that net worth can fluctuate over time due to various factors such as stock market changes and business ventures.
app.add("SELECT * FROM table_name;", data_type='postgres', loader=postgres_loader)
# Adds `(1, 'What is your net worth, Elon Musk?', "As of October 2023, Elon Musk's net worth is $255.2 billion.")`
response = app.query(question)
# Answer: As of October 2023, Elon Musk's net worth is $255.2 billion.
```
NOTE: The app's `add` function accepts any executable query to load data. Do NOT pass `CREATE` or `INSERT` statements to `add`: they return no rows, so nothing would be added.
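The loader holds an open `psycopg` connection that is reused across `add` calls. When you are done loading, you can release it with the loader's `close_connection` method:
```Python
# Release the cursor and connection held by the loader.
postgres_loader.close_connection()
```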
3. We automatically create a chunker for your Postgres data. However, if you wish to provide your own chunker class, here is how you can do that:
```Python
from embedchain.chunkers.postgres import PostgresChunker
from embedchain.config.add_config import ChunkerConfig
postgres_chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
postgres_chunker = PostgresChunker(config=postgres_chunker_config)
app.add("SELECT * FROM table_name;", data_type='postgres', loader=postgres_loader, chunker=postgres_chunker)
```

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable
@register_deserializable
class PostgresChunker(BaseChunker):
"""Chunker for postgres."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)
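For intuition, the chunker delegates entirely to LangChain's `RecursiveCharacterTextSplitter`, so each stringified row coming out of the Postgres loader is split into chunks of at most `chunk_size` characters. A standalone sketch of the default behavior (not part of this commit):
```Python
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0, length_function=len)
# Rows loaded from Postgres arrive as their string representation.
chunks = splitter.split_text(str((1, "What is your net worth, Elon Musk?")))
```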

View File

@@ -1,4 +1,5 @@
from importlib import import_module
from typing import Any, Dict
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig
@@ -15,7 +16,7 @@ class DataFormatter(JSONSerializable):
.add or .add_local method call
"""
def __init__(self, data_type: DataType, config: AddConfig):
def __init__(self, data_type: DataType, config: AddConfig, kwargs: Dict[str, Any]):
"""
Initialize a dataformatter, set data type and chunker based on datatype.
@@ -24,15 +25,15 @@ class DataFormatter(JSONSerializable):
:param config: AddConfig instance with nested loader and chunker config attributes.
:type config: AddConfig
"""
self.loader = self._get_loader(data_type=data_type, config=config.loader)
self.chunker = self._get_chunker(data_type=data_type, config=config.chunker)
self.loader = self._get_loader(data_type=data_type, config=config.loader, kwargs=kwargs)
self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, kwargs=kwargs)
def _lazy_load(self, module_path: str):
module_path, class_name = module_path.rsplit(".", 1)
module = import_module(module_path)
return getattr(module, class_name)
def _get_loader(self, data_type: DataType, config: LoaderConfig) -> BaseLoader:
def _get_loader(self, data_type: DataType, config: LoaderConfig, kwargs: Dict[str, Any]) -> BaseLoader:
"""
Returns the appropriate data loader for the given data type.
@@ -63,13 +64,28 @@ class DataFormatter(JSONSerializable):
DataType.GMAIL: "embedchain.loaders.gmail.GmailLoader",
DataType.NOTION: "embedchain.loaders.notion.NotionLoader",
}
custom_loaders = set(
[
DataType.POSTGRES,
]
)
if data_type in loaders:
loader_class: type = self._lazy_load(loaders[data_type])
return loader_class()
else:
raise ValueError(f"Unsupported data type: {data_type}")
elif data_type in custom_loaders:
loader = kwargs.get("loader", None)
if loader is not None:
return loader
def _get_chunker(self, data_type: DataType, config: ChunkerConfig) -> BaseChunker:
raise ValueError(
f"Can't find a loader for {data_type}. \
We recommend passing a loader for data_type: {data_type}; \
check `https://docs.embedchain.ai/data-sources/overview`."
)
def _get_chunker(self, data_type: DataType, config: ChunkerConfig, kwargs: Dict[str, Any]) -> BaseChunker:
"""Returns the appropriate chunker for the given data type (updated for lazy loading)."""
chunker_classes = {
DataType.YOUTUBE_VIDEO: "embedchain.chunkers.youtube_video.YoutubeVideoChunker",
@@ -89,12 +105,21 @@ class DataFormatter(JSONSerializable):
DataType.OPENAPI: "embedchain.chunkers.openapi.OpenAPIChunker",
DataType.GMAIL: "embedchain.chunkers.gmail.GmailChunker",
DataType.NOTION: "embedchain.chunkers.notion.NotionChunker",
DataType.POSTGRES: "embedchain.chunkers.postgres.PostgresChunker",
}
if data_type in chunker_classes:
chunker_class = self._lazy_load(chunker_classes[data_type])
if "chunker" in kwargs:
chunker_class = kwargs.get("chunker")
else:
chunker_class = self._lazy_load(chunker_classes[data_type])
chunker = chunker_class(config)
chunker.set_data_type(data_type)
return chunker
else:
raise ValueError(f"Unsupported data type: {data_type}")
raise ValueError(
f"Can't find a chunker for {data_type}. \
We recommend passing a chunker for data_type: {data_type}; \
check `https://docs.embedchain.ai/data-sources/overview`."
)
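In short: built-in data types are resolved lazily from dotted module paths, while custom types such as `DataType.POSTGRES` must receive a loader (and optionally a chunker) through `kwargs`. A minimal sketch of the lazy-load half, using names from the diff above:
```Python
from importlib import import_module

def lazy_load(module_path: str):
    # "embedchain.chunkers.postgres.PostgresChunker" -> the PostgresChunker class
    module_path, class_name = module_path.rsplit(".", 1)
    return getattr(import_module(module_path), class_name)

chunker_class = lazy_load("embedchain.chunkers.postgres.PostgresChunker")
```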

View File

@@ -137,6 +137,7 @@ class EmbedChain(JSONSerializable):
metadata: Optional[Dict[str, Any]] = None,
config: Optional[AddConfig] = None,
dry_run=False,
**kwargs: Dict[str, Any],
):
"""
Adds the data from the given URL to the vector db.
@@ -180,21 +181,6 @@ class EmbedChain(JSONSerializable):
if data_type:
try:
data_type = DataType(data_type)
if data_type == DataType.JSON:
if isinstance(source, str):
if not is_valid_json_string(source):
raise ValueError(
f"Invalid json input: {source}",
"Provide the correct JSON formatted source, \
refer `https://docs.embedchain.ai/data-sources/json`",
)
elif not isinstance(source, str):
raise ValueError(
"Invaid content input. \
If you want to upload (list, dict, etc.), do \
`json.dump(data, indent=0)` and add the stringified JSON. \
Check - `https://docs.embedchain.ai/data-sources/json`"
)
except ValueError:
raise ValueError(
f"Invalid data_type: '{data_type}'.",
@@ -218,8 +204,9 @@ class EmbedChain(JSONSerializable):
print(f"Data with hash {source_hash} already exists. Skipping addition.")
return source_hash
data_formatter = DataFormatter(data_type, config)
self.user_asks.append([source, data_type.value, metadata])
data_formatter = DataFormatter(data_type, config, kwargs)
documents, metadatas, _ids, new_chunks = self.load_and_embed(
data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
)
@@ -265,6 +252,7 @@ class EmbedChain(JSONSerializable):
data_type: Optional[DataType] = None,
metadata: Optional[Dict[str, Any]] = None,
config: Optional[AddConfig] = None,
**kwargs: Dict[str, Any],
):
"""
Adds the data from the given URL to the vector db.
@@ -290,7 +278,13 @@ class EmbedChain(JSONSerializable):
logging.warning(
"The `add_local` method is deprecated and will be removed in future versions. Please use the `add` method for both local and remote files." # noqa: E501
)
return self.add(source=source, data_type=data_type, metadata=metadata, config=config)
return self.add(
source=source,
data_type=data_type,
metadata=metadata,
config=config,
**kwargs,
)
def _get_existing_doc_id(self, chunker: BaseChunker, src: Any):
"""

View File

@@ -25,10 +25,21 @@ class JSONLoader(BaseLoader):
return LLHUBJSONLoader()
@staticmethod
def _check_content(content):
if not isinstance(content, str):
raise ValueError(
"Invaid content input. \
If you want to upload (list, dict, etc.), do \
`json.dump(data, indent=0)` and add the stringified JSON. \
Check - `https://docs.embedchain.ai/data-sources/json`"
)
@staticmethod
def load_data(content):
"""Load a json file. Each data point is a key value pair."""
JSONLoader._check_content(content)
loader = JSONLoader._get_llama_hub_loader()
data = []

View File

@@ -0,0 +1,73 @@
import hashlib
import logging
from typing import Any, Dict, Optional
from embedchain.loaders.base_loader import BaseLoader
class PostgresLoader(BaseLoader):
def __init__(self, config: Optional[Dict[str, Any]] = None):
super().__init__()
if not config:
raise ValueError(f"Must provide the valid config. Received: {config}")
self.connection = None
self.cursor = None
self._setup_loader(config=config)
def _setup_loader(self, config: Dict[str, Any]):
try:
import psycopg
except ImportError as e:
raise ImportError(
"Unable to import required packages. \
Run `pip install --upgrade 'embedchain[postgres]'`"
) from e
config_info = ""
if "url" in config:
config_info = config.get("url")
else:
conn_params = []
for key, value in config.items():
conn_params.append(f"{key}={value}")
config_info = " ".join(conn_params)
logging.info(f"Connecting to postrgres sql: {config_info}")
self.connection = psycopg.connect(conninfo=config_info)
self.cursor = self.connection.cursor()
def _check_query(self, query):
if not isinstance(query, str):
raise ValueError(
f"Invalid postgres query: {query}",
"Provide the valid source to add from postgres, \
make sure you are following `https://docs.embedchain.ai/data-sources/postgres`",
)
def load_data(self, query):
self._check_query(query)
try:
data = []
data_content = []
self.cursor.execute(query)
results = self.cursor.fetchall()
for result in results:
doc_content = str(result)
data.append({"content": doc_content, "meta_data": {"url": f"postgres_query-({query})"}})
data_content.append(doc_content)
doc_id = hashlib.sha256((query + ", ".join(data_content)).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": data,
}
except Exception as e:
raise ValueError(f"Failed to load data using query={query} with: {e}")
def close_connection(self):
if self.cursor:
self.cursor.close()
self.cursor = None
if self.connection:
self.connection.close()
self.connection = None
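End to end, the loader can also be used on its own (a sketch; the connection URL and table are hypothetical):
```Python
from embedchain.loaders.postgres import PostgresLoader

loader = PostgresLoader(config={"url": "postgresql://user:password@localhost:5432/mydb"})
result = loader.load_data("SELECT id, question, answer FROM faqs;")
print(result["doc_id"])     # sha256 over the query plus all row contents
print(len(result["data"]))  # one document per fetched row
loader.close_connection()
```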

View File

@@ -29,6 +29,7 @@ class IndirectDataType(Enum):
JSON = "json"
OPENAPI = "openapi"
GMAIL = "gmail"
POSTGRES = "postgres"
class SpecialDataType(Enum):
@@ -57,3 +58,4 @@ class DataType(Enum):
JSON = IndirectDataType.JSON.value
OPENAPI = IndirectDataType.OPENAPI.value
GMAIL = IndirectDataType.GMAIL.value
POSTGRES = IndirectDataType.POSTGRES.value
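With the new member in place, the plain string accepted by `add` resolves to the enum via the `DataType(data_type)` conversion shown in `embedchain.py` above (module path assumed from this repo's layout):
```Python
from embedchain.models.data_type import DataType

assert DataType("postgres") is DataType.POSTGRES
```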

View File

@@ -130,6 +130,9 @@ pymilvus = { version = "2.3.1", optional = true }
google-cloud-aiplatform = { version = "^1.26.1", optional = true }
replicate = { version = "^0.15.4", optional = true }
schema = "^0.7.5"
psycopg = { version = "^3.1.12", optional = true }
psycopg-binary = { version = "^3.1.12", optional = true }
psycopg-pool = { version = "^3.1.8", optional = true }
[tool.poetry.group.dev.dependencies]
black = "^23.3.0"
@@ -184,6 +187,7 @@ gmail = [
"google-api-core",
]
json = ["llama-hub"]
postgres = ["psycopg", "psycopg-binary", "psycopg-pool"]
[tool.poetry.group.docs.dependencies]

View File

@@ -6,6 +6,7 @@ from embedchain.chunkers.mdx import MdxChunker
from embedchain.chunkers.notion import NotionChunker
from embedchain.chunkers.openapi import OpenAPIChunker
from embedchain.chunkers.pdf_file import PdfFileChunker
from embedchain.chunkers.postgres import PostgresChunker
from embedchain.chunkers.qna_pair import QnaPairChunker
from embedchain.chunkers.sitemap import SitemapChunker
from embedchain.chunkers.table import TableChunker
@@ -33,6 +34,7 @@ chunker_common_config = {
JSONChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
OpenAPIChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
GmailChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
PostgresChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
}

View File

@@ -63,5 +63,5 @@ def test_add_after_reset(app_instance, mocker):
def test_add_with_incorrect_content(app_instance, mocker):
content = [{"foo": "bar"}]
with pytest.raises(ValueError):
with pytest.raises(TypeError):
app_instance.add(content, data_type="json")

View File

@@ -0,0 +1,60 @@
from unittest.mock import MagicMock
import psycopg
import pytest
from embedchain.loaders.postgres import PostgresLoader
@pytest.fixture
def postgres_loader(mocker):
mocker.patch.object(psycopg, "connect")
config = {"url": "postgres://user:password@localhost:5432/database"}
loader = PostgresLoader(config=config)
yield loader
def test_postgres_loader_initialization(postgres_loader):
assert postgres_loader.connection is not None
assert postgres_loader.cursor is not None
def test_postgres_loader_invalid_config():
with pytest.raises(ValueError, match="Must provide the valid config. Received: None"):
PostgresLoader(config=None)
def test_load_data(postgres_loader, monkeypatch):
mock_cursor = MagicMock()
monkeypatch.setattr(postgres_loader, "cursor", mock_cursor)
query = "SELECT * FROM table"
mock_cursor.fetchall.return_value = [(1, "data1"), (2, "data2")]
result = postgres_loader.load_data(query)
assert "doc_id" in result
assert "data" in result
assert len(result["data"]) == 2
assert result["data"][0]["meta_data"]["url"] == f"postgres_query-({query})"
assert result["data"][1]["meta_data"]["url"] == f"postgres_query-({query})"
mock_cursor.execute.assert_called_once_with(query)
def test_load_data_exception(postgres_loader, monkeypatch):
mock_cursor = MagicMock()
monkeypatch.setattr(postgres_loader, "cursor", mock_cursor)
_ = "SELECT * FROM table"
mock_cursor.execute.side_effect = Exception("Mocked exception")
with pytest.raises(
ValueError, match=r"Failed to load data using query=SELECT \* FROM table with: Mocked exception"
):
postgres_loader.load_data("SELECT * FROM table")
def test_close_connection(postgres_loader):
postgres_loader.close_connection()
assert postgres_loader.cursor is None
assert postgres_loader.connection is None
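A complementary test for the keyword/value config path could look like this (a sketch, not part of this commit; it relies on the `conninfo` construction shown in `_setup_loader`):
```Python
def test_postgres_loader_config_params(mocker):
    # psycopg.connect is mocked; assert the dict is flattened into a libpq string.
    mock_connect = mocker.patch.object(psycopg, "connect")
    config = {"host": "localhost", "port": "5432", "dbname": "db", "user": "u", "password": "p"}
    PostgresLoader(config=config)
    mock_connect.assert_called_once_with(conninfo="host=localhost port=5432 dbname=db user=u password=p")
```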