From 7de8d8519915e186c39657b1bc4366a0b7f078b7 Mon Sep 17 00:00:00 2001 From: Deven Patel Date: Wed, 8 Nov 2023 23:50:46 -0800 Subject: [PATCH] [Feature] Add Postgres data loader (#918) Co-authored-by: Deven Patel --- docs/data-sources/overview.mdx | 1 + docs/data-sources/postgres.mdx | 64 ++++++++++++++++++ embedchain/chunkers/postgres.py | 22 +++++++ embedchain/data_formatter/data_formatter.py | 43 +++++++++--- embedchain/embedchain.py | 28 ++++---- embedchain/loaders/json.py | 11 ++++ embedchain/loaders/postgres.py | 73 +++++++++++++++++++++ embedchain/models/data_type.py | 2 + pyproject.toml | 4 ++ tests/chunkers/test_chunkers.py | 2 + tests/embedchain/test_embedchain.py | 2 +- tests/loaders/test_postgres.py | 60 +++++++++++++++++ 12 files changed, 285 insertions(+), 27 deletions(-) create mode 100644 docs/data-sources/postgres.mdx create mode 100644 embedchain/chunkers/postgres.py create mode 100644 embedchain/loaders/postgres.py create mode 100644 tests/loaders/test_postgres.py diff --git a/docs/data-sources/overview.mdx b/docs/data-sources/overview.mdx index d5d5cca8..12e1b45f 100644 --- a/docs/data-sources/overview.mdx +++ b/docs/data-sources/overview.mdx @@ -20,6 +20,7 @@ Embedchain comes with built-in support for various data sources. We handle the c +
diff --git a/docs/data-sources/postgres.mdx b/docs/data-sources/postgres.mdx new file mode 100644 index 00000000..9cb5d0e6 --- /dev/null +++ b/docs/data-sources/postgres.mdx @@ -0,0 +1,64 @@ +--- +title: '🐘 Postgres' +--- + +1. Set up the Postgres loader by configuring the postgres db. +```Python +from embedchain.loaders.postgres import PostgresLoader + +config = { + "host": "host_address", + "port": "port_number", + "dbname": "database_name", + "user": "username", + "password": "password", +} + +""" +config = { + "url": "your_postgres_url" +} +""" + +postgres_loader = PostgresLoader(config=config) + +``` + +You can either set up the loader by passing the PostgreSQL URL or by providing the config data. +For more details on how to set up with a valid URL and config, check postgres [documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING:~:text=34.1.1.%C2%A0Connection%20Strings-,%23,-Several%20libpq%20functions). + +NOTE: If you provide the `url` field in config, all other fields will be ignored. + +2. Once you have set up the loader, you can create an app and load data using the above postgres loader +```Python +import os +from embedchain.pipeline import Pipeline as App + +os.environ["OPENAI_API_KEY"] = "sk-xxx" + +app = App() + +question = "What is Elon Musk's net worth?" +response = app.query(question) +# Answer: As of September 2021, Elon Musk's net worth is estimated to be around $250 billion, making him one of the wealthiest individuals in the world. However, please note that net worth can fluctuate over time due to various factors such as stock market changes and business ventures. + +app.add("SELECT * FROM table_name;", data_type='postgres', loader=postgres_loader) +# Adds `(1, 'What is your net worth, Elon Musk?', "As of October 2023, Elon Musk's net worth is $255.2 billion.")` + +response = app.query(question) +# Answer: As of October 2023, Elon Musk's net worth is $255.2 billion. 
+``` + +NOTE: The `add` function of the app will accept any executable query to load data. DO NOT pass `CREATE` or `INSERT` queries to the `add` function, as they return no rows and therefore add no data. + +3. We automatically create a chunker to chunk your postgres data. However, if you wish to provide your own chunker class, here is how you can do that: +```Python + +from embedchain.chunkers.postgres import PostgresChunker +from embedchain.config.add_config import ChunkerConfig + +postgres_chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len) +postgres_chunker = PostgresChunker(config=postgres_chunker_config) + +app.add("SELECT * FROM table_name;", data_type='postgres', loader=postgres_loader, chunker=postgres_chunker) +``` \ No newline at end of file diff --git a/embedchain/chunkers/postgres.py b/embedchain/chunkers/postgres.py new file mode 100644 index 00000000..168b6fcd --- /dev/null +++ b/embedchain/chunkers/postgres.py @@ -0,0 +1,22 @@ +from typing import Optional + +from langchain.text_splitter import RecursiveCharacterTextSplitter + +from embedchain.chunkers.base_chunker import BaseChunker +from embedchain.config.add_config import ChunkerConfig +from embedchain.helper.json_serializable import register_deserializable + + +@register_deserializable +class PostgresChunker(BaseChunker): + """Chunker for postgres.""" + + def __init__(self, config: Optional[ChunkerConfig] = None): + if config is None: + config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len) + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=config.chunk_size, + chunk_overlap=config.chunk_overlap, + length_function=config.length_function, + ) + super().__init__(text_splitter) diff --git a/embedchain/data_formatter/data_formatter.py b/embedchain/data_formatter/data_formatter.py index 37d475e5..d85aac5e 100644 --- a/embedchain/data_formatter/data_formatter.py +++ b/embedchain/data_formatter/data_formatter.py @@ -1,4 +1,5 
@@ from importlib import import_module +from typing import Any, Dict from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config import AddConfig @@ -15,7 +16,7 @@ class DataFormatter(JSONSerializable): .add or .add_local method call """ - def __init__(self, data_type: DataType, config: AddConfig): + def __init__(self, data_type: DataType, config: AddConfig, kwargs: Dict[str, Any]): """ Initialize a dataformatter, set data type and chunker based on datatype. @@ -24,15 +25,15 @@ class DataFormatter(JSONSerializable): :param config: AddConfig instance with nested loader and chunker config attributes. :type config: AddConfig """ - self.loader = self._get_loader(data_type=data_type, config=config.loader) - self.chunker = self._get_chunker(data_type=data_type, config=config.chunker) + self.loader = self._get_loader(data_type=data_type, config=config.loader, kwargs=kwargs) + self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, kwargs=kwargs) def _lazy_load(self, module_path: str): module_path, class_name = module_path.rsplit(".", 1) module = import_module(module_path) return getattr(module, class_name) - def _get_loader(self, data_type: DataType, config: LoaderConfig) -> BaseLoader: + def _get_loader(self, data_type: DataType, config: LoaderConfig, kwargs: Dict[str, Any]) -> BaseLoader: """ Returns the appropriate data loader for the given data type. 
@@ -63,13 +64,28 @@ class DataFormatter(JSONSerializable): DataType.GMAIL: "embedchain.loaders.gmail.GmailLoader", DataType.NOTION: "embedchain.loaders.notion.NotionLoader", } + + custom_loaders = set( + [ + DataType.POSTGRES, + ] + ) + if data_type in loaders: loader_class: type = self._lazy_load(loaders[data_type]) return loader_class() - else: - raise ValueError(f"Unsupported data type: {data_type}") + elif data_type in custom_loaders: + loader_class: type = kwargs.get("loader", None) + if loader_class is not None: + return loader_class - def _get_chunker(self, data_type: DataType, config: ChunkerConfig) -> BaseChunker: + raise ValueError( + f"Cant find the loader for {data_type}.\ + We recommend to pass the loader to use data_type: {data_type},\ + check `https://docs.embedchain.ai/data-sources/overview`." + ) + + def _get_chunker(self, data_type: DataType, config: ChunkerConfig, kwargs: Dict[str, Any]) -> BaseChunker: """Returns the appropriate chunker for the given data type (updated for lazy loading).""" chunker_classes = { DataType.YOUTUBE_VIDEO: "embedchain.chunkers.youtube_video.YoutubeVideoChunker", @@ -89,12 +105,21 @@ class DataFormatter(JSONSerializable): DataType.OPENAPI: "embedchain.chunkers.openapi.OpenAPIChunker", DataType.GMAIL: "embedchain.chunkers.gmail.GmailChunker", DataType.NOTION: "embedchain.chunkers.notion.NotionChunker", + DataType.POSTGRES: "embedchain.chunkers.postgres.PostgresChunker", } if data_type in chunker_classes: - chunker_class = self._lazy_load(chunker_classes[data_type]) + if "chunker" in kwargs: + chunker_class = kwargs.get("chunker") + else: + chunker_class = self._lazy_load(chunker_classes[data_type]) + chunker = chunker_class(config) chunker.set_data_type(data_type) return chunker else: - raise ValueError(f"Unsupported data type: {data_type}") + raise ValueError( + f"Cant find the chunker for {data_type}.\ + We recommend to pass the chunker to use data_type: {data_type},\ + check 
`https://docs.embedchain.ai/data-sources/overview`." + ) diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 6b9fade8..b1a31a36 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -137,6 +137,7 @@ class EmbedChain(JSONSerializable): metadata: Optional[Dict[str, Any]] = None, config: Optional[AddConfig] = None, dry_run=False, + **kwargs: Dict[str, Any], ): """ Adds the data from the given URL to the vector db. @@ -180,21 +181,6 @@ class EmbedChain(JSONSerializable): if data_type: try: data_type = DataType(data_type) - if data_type == DataType.JSON: - if isinstance(source, str): - if not is_valid_json_string(source): - raise ValueError( - f"Invalid json input: {source}", - "Provide the correct JSON formatted source, \ - refer `https://docs.embedchain.ai/data-sources/json`", - ) - elif not isinstance(source, str): - raise ValueError( - "Invaid content input. \ - If you want to upload (list, dict, etc.), do \ - `json.dump(data, indent=0)` and add the stringified JSON. \ - Check - `https://docs.embedchain.ai/data-sources/json`" - ) except ValueError: raise ValueError( f"Invalid data_type: '{data_type}'.", @@ -218,8 +204,9 @@ class EmbedChain(JSONSerializable): print(f"Data with hash {source_hash} already exists. Skipping addition.") return source_hash - data_formatter = DataFormatter(data_type, config) self.user_asks.append([source, data_type.value, metadata]) + + data_formatter = DataFormatter(data_type, config, kwargs) documents, metadatas, _ids, new_chunks = self.load_and_embed( data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run ) @@ -265,6 +252,7 @@ class EmbedChain(JSONSerializable): data_type: Optional[DataType] = None, metadata: Optional[Dict[str, Any]] = None, config: Optional[AddConfig] = None, + **kwargs: Dict[str, Any], ): """ Adds the data from the given URL to the vector db. 
@@ -290,7 +278,13 @@ class EmbedChain(JSONSerializable): logging.warning( "The `add_local` method is deprecated and will be removed in future versions. Please use the `add` method for both local and remote files." # noqa: E501 ) - return self.add(source=source, data_type=data_type, metadata=metadata, config=config) + return self.add( + source=source, + data_type=data_type, + metadata=metadata, + config=config, + kwargs=kwargs, + ) def _get_existing_doc_id(self, chunker: BaseChunker, src: Any): """ diff --git a/embedchain/loaders/json.py b/embedchain/loaders/json.py index 058f8fd0..be12d734 100644 --- a/embedchain/loaders/json.py +++ b/embedchain/loaders/json.py @@ -25,10 +25,21 @@ class JSONLoader(BaseLoader): return LLHUBJSONLoader() + @staticmethod + def _check_content(content): + if not isinstance(content, str): + raise ValueError( + "Invaid content input. \ + If you want to upload (list, dict, etc.), do \ + `json.dump(data, indent=0)` and add the stringified JSON. \ + Check - `https://docs.embedchain.ai/data-sources/json`" + ) + @staticmethod def load_data(content): """Load a json file. Each data point is a key value pair.""" + JSONLoader._check_content(content) loader = JSONLoader._get_llama_hub_loader() data = [] diff --git a/embedchain/loaders/postgres.py b/embedchain/loaders/postgres.py new file mode 100644 index 00000000..8c6bafba --- /dev/null +++ b/embedchain/loaders/postgres.py @@ -0,0 +1,73 @@ +import hashlib +import logging +from typing import Any, Dict, Optional + +from embedchain.loaders.base_loader import BaseLoader + + +class PostgresLoader(BaseLoader): + def __init__(self, config: Optional[Dict[str, Any]] = None): + super().__init__() + if not config: + raise ValueError(f"Must provide the valid config. 
Received: {config}") + + self.connection = None + self.cursor = None + self._setup_loader(config=config) + + def _setup_loader(self, config: Dict[str, Any]): + try: + import psycopg + except ImportError as e: + raise ImportError( + "Unable to import required packages. \ + Run `pip install --upgrade 'embedchain[postgres]'`" + ) from e + + config_info = "" + if "url" in config: + config_info = config.get("url") + else: + conn_params = [] + for key, value in config.items(): + conn_params.append(f"{key}={value}") + config_info = " ".join(conn_params) + + logging.info(f"Connecting to postrgres sql: {config_info}") + self.connection = psycopg.connect(conninfo=config_info) + self.cursor = self.connection.cursor() + + def _check_query(self, query): + if not isinstance(query, str): + raise ValueError( + f"Invalid postgres query: {query}", + "Provide the valid source to add from postgres, \ + make sure you are following `https://docs.embedchain.ai/data-sources/postgres`", + ) + + def load_data(self, query): + self._check_query(query) + try: + data = [] + data_content = [] + self.cursor.execute(query) + results = self.cursor.fetchall() + for result in results: + doc_content = str(result) + data.append({"content": doc_content, "meta_data": {"url": f"postgres_query-({query})"}}) + data_content.append(doc_content) + doc_id = hashlib.sha256((query + ", ".join(data_content)).encode()).hexdigest() + return { + "doc_id": doc_id, + "data": data, + } + except Exception as e: + raise ValueError(f"Failed to load data using query={query} with: {e}") + + def close_connection(self): + if self.cursor: + self.cursor.close() + self.cursor = None + if self.connection: + self.connection.close() + self.connection = None diff --git a/embedchain/models/data_type.py b/embedchain/models/data_type.py index 31f07cdf..76dffaf1 100644 --- a/embedchain/models/data_type.py +++ b/embedchain/models/data_type.py @@ -29,6 +29,7 @@ class IndirectDataType(Enum): JSON = "json" OPENAPI = "openapi" GMAIL = "gmail" 
+ POSTGRES = "postgres" class SpecialDataType(Enum): @@ -57,3 +58,4 @@ class DataType(Enum): JSON = IndirectDataType.JSON.value OPENAPI = IndirectDataType.OPENAPI.value GMAIL = IndirectDataType.GMAIL.value + POSTGRES = IndirectDataType.POSTGRES.value diff --git a/pyproject.toml b/pyproject.toml index 392f62ec..48d9ad4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,6 +130,9 @@ pymilvus = { version = "2.3.1", optional = true } google-cloud-aiplatform = { version = "^1.26.1", optional = true } replicate = { version = "^0.15.4", optional = true } schema = "^0.7.5" +psycopg = { version = "^3.1.12", optional = true } +psycopg-binary = { version = "^3.1.12", optional = true } +psycopg-pool = { version = "^3.1.8", optional = true } [tool.poetry.group.dev.dependencies] black = "^23.3.0" @@ -184,6 +187,7 @@ gmail = [ "google-api-core", ] json = ["llama-hub"] +postgres = ["psycopg", "psycopg-binary", "psycopg-pool"] [tool.poetry.group.docs.dependencies] diff --git a/tests/chunkers/test_chunkers.py b/tests/chunkers/test_chunkers.py index 5d73f5c0..f38f2751 100644 --- a/tests/chunkers/test_chunkers.py +++ b/tests/chunkers/test_chunkers.py @@ -6,6 +6,7 @@ from embedchain.chunkers.mdx import MdxChunker from embedchain.chunkers.notion import NotionChunker from embedchain.chunkers.openapi import OpenAPIChunker from embedchain.chunkers.pdf_file import PdfFileChunker +from embedchain.chunkers.postgres import PostgresChunker from embedchain.chunkers.qna_pair import QnaPairChunker from embedchain.chunkers.sitemap import SitemapChunker from embedchain.chunkers.table import TableChunker @@ -33,6 +34,7 @@ chunker_common_config = { JSONChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len}, OpenAPIChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len}, GmailChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len}, + PostgresChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len}, } diff --git 
a/tests/embedchain/test_embedchain.py b/tests/embedchain/test_embedchain.py index f4671697..8d57dca2 100644 --- a/tests/embedchain/test_embedchain.py +++ b/tests/embedchain/test_embedchain.py @@ -63,5 +63,5 @@ def test_add_after_reset(app_instance, mocker): def test_add_with_incorrect_content(app_instance, mocker): content = [{"foo": "bar"}] - with pytest.raises(ValueError): + with pytest.raises(TypeError): app_instance.add(content, data_type="json") diff --git a/tests/loaders/test_postgres.py b/tests/loaders/test_postgres.py new file mode 100644 index 00000000..e0cdbfc5 --- /dev/null +++ b/tests/loaders/test_postgres.py @@ -0,0 +1,60 @@ +from unittest.mock import MagicMock + +import psycopg +import pytest + +from embedchain.loaders.postgres import PostgresLoader + + +@pytest.fixture +def postgres_loader(mocker): + with mocker.patch.object(psycopg, "connect"): + config = {"url": "postgres://user:password@localhost:5432/database"} + loader = PostgresLoader(config=config) + yield loader + + +def test_postgres_loader_initialization(postgres_loader): + assert postgres_loader.connection is not None + assert postgres_loader.cursor is not None + + +def test_postgres_loader_invalid_config(): + with pytest.raises(ValueError, match="Must provide the valid config. 
Received: None"): + PostgresLoader(config=None) + + +def test_load_data(postgres_loader, monkeypatch): + mock_cursor = MagicMock() + monkeypatch.setattr(postgres_loader, "cursor", mock_cursor) + + query = "SELECT * FROM table" + mock_cursor.fetchall.return_value = [(1, "data1"), (2, "data2")] + + result = postgres_loader.load_data(query) + + assert "doc_id" in result + assert "data" in result + assert len(result["data"]) == 2 + assert result["data"][0]["meta_data"]["url"] == f"postgres_query-({query})" + assert result["data"][1]["meta_data"]["url"] == f"postgres_query-({query})" + assert mock_cursor.execute.called_with(query) + + +def test_load_data_exception(postgres_loader, monkeypatch): + mock_cursor = MagicMock() + monkeypatch.setattr(postgres_loader, "cursor", mock_cursor) + + _ = "SELECT * FROM table" + mock_cursor.execute.side_effect = Exception("Mocked exception") + + with pytest.raises( + ValueError, match=r"Failed to load data using query=SELECT \* FROM table with: Mocked exception" + ): + postgres_loader.load_data("SELECT * FROM table") + + +def test_close_connection(postgres_loader): + postgres_loader.close_connection() + assert postgres_loader.cursor is None + assert postgres_loader.connection is None