[Feature] Add Postgres data loader (#918)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2023-11-08 23:50:46 -08:00
committed by GitHub
parent f7dd65a3de
commit 7de8d85199
12 changed files with 285 additions and 27 deletions

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable
@register_deserializable
class PostgresChunker(BaseChunker):
    """Chunker for postgres.

    Splits text produced by the postgres loader into overlapping chunks
    with a recursive character splitter.
    """

    def __init__(self, config: Optional[ChunkerConfig] = None):
        # Fall back to the default chunking parameters when no config is supplied.
        effective = config if config is not None else ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=effective.chunk_size,
            chunk_overlap=effective.chunk_overlap,
            length_function=effective.length_function,
        )
        super().__init__(splitter)

View File

@@ -1,4 +1,5 @@
from importlib import import_module
from typing import Any, Dict
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig
@@ -15,7 +16,7 @@ class DataFormatter(JSONSerializable):
.add or .add_local method call
"""
def __init__(self, data_type: DataType, config: AddConfig):
def __init__(self, data_type: DataType, config: AddConfig, kwargs: Dict[str, Any]):
"""
Initialize a dataformatter, set data type and chunker based on datatype.
@@ -24,15 +25,15 @@ class DataFormatter(JSONSerializable):
:param config: AddConfig instance with nested loader and chunker config attributes.
:type config: AddConfig
"""
self.loader = self._get_loader(data_type=data_type, config=config.loader)
self.chunker = self._get_chunker(data_type=data_type, config=config.chunker)
self.loader = self._get_loader(data_type=data_type, config=config.loader, kwargs=kwargs)
self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, kwargs=kwargs)
def _lazy_load(self, module_path: str):
module_path, class_name = module_path.rsplit(".", 1)
module = import_module(module_path)
return getattr(module, class_name)
def _get_loader(self, data_type: DataType, config: LoaderConfig) -> BaseLoader:
def _get_loader(self, data_type: DataType, config: LoaderConfig, kwargs: Dict[str, Any]) -> BaseLoader:
"""
Returns the appropriate data loader for the given data type.
@@ -63,13 +64,28 @@ class DataFormatter(JSONSerializable):
DataType.GMAIL: "embedchain.loaders.gmail.GmailLoader",
DataType.NOTION: "embedchain.loaders.notion.NotionLoader",
}
custom_loaders = set(
[
DataType.POSTGRES,
]
)
if data_type in loaders:
loader_class: type = self._lazy_load(loaders[data_type])
return loader_class()
else:
raise ValueError(f"Unsupported data type: {data_type}")
elif data_type in custom_loaders:
loader_class: type = kwargs.get("loader", None)
if loader_class is not None:
return loader_class
def _get_chunker(self, data_type: DataType, config: ChunkerConfig) -> BaseChunker:
raise ValueError(
f"Can't find the loader for {data_type}.\
We recommend to pass the loader to use data_type: {data_type},\
check `https://docs.embedchain.ai/data-sources/overview`."
)
def _get_chunker(self, data_type: DataType, config: ChunkerConfig, kwargs: Dict[str, Any]) -> BaseChunker:
"""Returns the appropriate chunker for the given data type (updated for lazy loading)."""
chunker_classes = {
DataType.YOUTUBE_VIDEO: "embedchain.chunkers.youtube_video.YoutubeVideoChunker",
@@ -89,12 +105,21 @@ class DataFormatter(JSONSerializable):
DataType.OPENAPI: "embedchain.chunkers.openapi.OpenAPIChunker",
DataType.GMAIL: "embedchain.chunkers.gmail.GmailChunker",
DataType.NOTION: "embedchain.chunkers.notion.NotionChunker",
DataType.POSTGRES: "embedchain.chunkers.postgres.PostgresChunker",
}
if data_type in chunker_classes:
chunker_class = self._lazy_load(chunker_classes[data_type])
if "chunker" in kwargs:
chunker_class = kwargs.get("chunker")
else:
chunker_class = self._lazy_load(chunker_classes[data_type])
chunker = chunker_class(config)
chunker.set_data_type(data_type)
return chunker
else:
raise ValueError(f"Unsupported data type: {data_type}")
raise ValueError(
f"Can't find the chunker for {data_type}.\
We recommend to pass the chunker to use data_type: {data_type},\
check `https://docs.embedchain.ai/data-sources/overview`."
)

View File

@@ -137,6 +137,7 @@ class EmbedChain(JSONSerializable):
metadata: Optional[Dict[str, Any]] = None,
config: Optional[AddConfig] = None,
dry_run=False,
**kwargs: Dict[str, Any],
):
"""
Adds the data from the given URL to the vector db.
@@ -180,21 +181,6 @@ class EmbedChain(JSONSerializable):
if data_type:
try:
data_type = DataType(data_type)
if data_type == DataType.JSON:
if isinstance(source, str):
if not is_valid_json_string(source):
raise ValueError(
f"Invalid json input: {source}",
"Provide the correct JSON formatted source, \
refer `https://docs.embedchain.ai/data-sources/json`",
)
elif not isinstance(source, str):
raise ValueError(
"Invaid content input. \
If you want to upload (list, dict, etc.), do \
`json.dump(data, indent=0)` and add the stringified JSON. \
Check - `https://docs.embedchain.ai/data-sources/json`"
)
except ValueError:
raise ValueError(
f"Invalid data_type: '{data_type}'.",
@@ -218,8 +204,9 @@ class EmbedChain(JSONSerializable):
print(f"Data with hash {source_hash} already exists. Skipping addition.")
return source_hash
data_formatter = DataFormatter(data_type, config)
self.user_asks.append([source, data_type.value, metadata])
data_formatter = DataFormatter(data_type, config, kwargs)
documents, metadatas, _ids, new_chunks = self.load_and_embed(
data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
)
@@ -265,6 +252,7 @@ class EmbedChain(JSONSerializable):
data_type: Optional[DataType] = None,
metadata: Optional[Dict[str, Any]] = None,
config: Optional[AddConfig] = None,
**kwargs: Dict[str, Any],
):
"""
Adds the data from the given URL to the vector db.
@@ -290,7 +278,13 @@ class EmbedChain(JSONSerializable):
logging.warning(
"The `add_local` method is deprecated and will be removed in future versions. Please use the `add` method for both local and remote files." # noqa: E501
)
return self.add(source=source, data_type=data_type, metadata=metadata, config=config)
return self.add(
source=source,
data_type=data_type,
metadata=metadata,
config=config,
kwargs=kwargs,
)
def _get_existing_doc_id(self, chunker: BaseChunker, src: Any):
"""

View File

@@ -25,10 +25,21 @@ class JSONLoader(BaseLoader):
return LLHUBJSONLoader()
@staticmethod
def _check_content(content):
    """Validate that *content* is a string; raise ValueError otherwise.

    JSON input must be passed as a stringified document; lists/dicts are
    rejected with a pointer to the docs.

    :raises ValueError: if content is not a str.
    """
    if not isinstance(content, str):
        # Fixed typo in user-facing message: "Invaid" -> "Invalid".
        raise ValueError(
            "Invalid content input. \
            If you want to upload (list, dict, etc.), do \
            `json.dump(data, indent=0)` and add the stringified JSON. \
            Check - `https://docs.embedchain.ai/data-sources/json`"
        )
@staticmethod
def load_data(content):
"""Load a json file. Each data point is a key value pair."""
JSONLoader._check_content(content)
loader = JSONLoader._get_llama_hub_loader()
data = []

View File

@@ -0,0 +1,73 @@
import hashlib
import logging
from typing import Any, Dict, Optional
from embedchain.loaders.base_loader import BaseLoader
class PostgresLoader(BaseLoader):
    """Load rows from a Postgres database by executing SQL queries.

    The config must either contain a `url` key with a full conninfo string,
    or individual libpq key/value parameters (host, port, dbname, user,
    password, ...).
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__()
        if not config:
            raise ValueError(f"Must provide the valid config. Received: {config}")

        self.connection = None
        self.cursor = None
        self._setup_loader(config=config)

    def _setup_loader(self, config: Dict[str, Any]):
        """Open a psycopg connection and cursor from the given config.

        :raises ImportError: if the optional `psycopg` dependency is missing.
        """
        try:
            import psycopg
        except ImportError as e:
            raise ImportError(
                "Unable to import required packages. \
                Run `pip install --upgrade 'embedchain[postgres]'`"
            ) from e

        if "url" in config:
            config_info = config.get("url")
        else:
            conn_params = [f"{key}={value}" for key, value in config.items()]
            config_info = " ".join(conn_params)

        # Never log the raw conninfo: it can contain credentials. Mask the
        # password and any full url before logging. (Also fixes the
        # "postrgres" typo in the original message.)
        redacted = " ".join(
            f"{key}=***" if key in ("password", "url") else f"{key}={value}" for key, value in config.items()
        )
        logging.info(f"Connecting to postgres sql: {redacted}")
        self.connection = psycopg.connect(conninfo=config_info)
        self.cursor = self.connection.cursor()

    def _check_query(self, query):
        """Validate that *query* is a SQL string; raise ValueError otherwise."""
        if not isinstance(query, str):
            raise ValueError(
                f"Invalid postgres query: {query}",
                "Provide the valid source to add from postgres, \
                make sure you are following `https://docs.embedchain.ai/data-sources/postgres`",
            )

    def load_data(self, query):
        """Execute *query* and return a dict with a content hash and row data.

        :param query: SQL statement to execute.
        :return: {"doc_id": sha256 of query+rows, "data": [{"content", "meta_data"}, ...]}
        :raises ValueError: if the query is not a string or execution fails.
        """
        self._check_query(query)
        try:
            data = []
            data_content = []
            self.cursor.execute(query)
            results = self.cursor.fetchall()
            for result in results:
                doc_content = str(result)
                data.append({"content": doc_content, "meta_data": {"url": f"postgres_query-({query})"}})
                data_content.append(doc_content)
            # doc_id is deterministic over the query text plus all row contents.
            doc_id = hashlib.sha256((query + ", ".join(data_content)).encode()).hexdigest()
            return {
                "doc_id": doc_id,
                "data": data,
            }
        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise ValueError(f"Failed to load data using query={query} with: {e}") from e

    def close_connection(self):
        """Close the cursor and connection, idempotently."""
        if self.cursor:
            self.cursor.close()
            self.cursor = None
        if self.connection:
            self.connection.close()
            self.connection = None

View File

@@ -29,6 +29,7 @@ class IndirectDataType(Enum):
JSON = "json"
OPENAPI = "openapi"
GMAIL = "gmail"
POSTGRES = "postgres"
class SpecialDataType(Enum):
@@ -57,3 +58,4 @@ class DataType(Enum):
JSON = IndirectDataType.JSON.value
OPENAPI = IndirectDataType.OPENAPI.value
GMAIL = IndirectDataType.GMAIL.value
POSTGRES = IndirectDataType.POSTGRES.value