[Feature] Add Postgres data loader (#918)
Co-authored-by: Deven Patel <deven298@yahoo.com>
embedchain/chunkers/postgres.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
+from embedchain.helper.json_serializable import register_deserializable
+
+
+@register_deserializable
+class PostgresChunker(BaseChunker):
+    """Chunker for postgres."""
+
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(text_splitter)
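
For orientation, a minimal sketch of constructing the new chunker directly with an explicit ChunkerConfig; the size and overlap values below are illustrative (the commit's own fallback is chunk_size=1000, chunk_overlap=0):

from embedchain.chunkers.postgres import PostgresChunker
from embedchain.config.add_config import ChunkerConfig

# Illustrative values; any combination valid for RecursiveCharacterTextSplitter works.
config = ChunkerConfig(chunk_size=500, chunk_overlap=50, length_function=len)
chunker = PostgresChunker(config=config)
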
embedchain/data_formatter/data_formatter.py
@@ -1,4 +1,5 @@
 from importlib import import_module
+from typing import Any, Dict

 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig
@@ -15,7 +16,7 @@ class DataFormatter(JSONSerializable):
     .add or .add_local method call
     """

-    def __init__(self, data_type: DataType, config: AddConfig):
+    def __init__(self, data_type: DataType, config: AddConfig, kwargs: Dict[str, Any]):
         """
         Initialize a dataformatter, set data type and chunker based on datatype.

@@ -24,15 +25,15 @@ class DataFormatter(JSONSerializable):
         :param config: AddConfig instance with nested loader and chunker config attributes.
         :type config: AddConfig
         """
-        self.loader = self._get_loader(data_type=data_type, config=config.loader)
-        self.chunker = self._get_chunker(data_type=data_type, config=config.chunker)
+        self.loader = self._get_loader(data_type=data_type, config=config.loader, kwargs=kwargs)
+        self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, kwargs=kwargs)

     def _lazy_load(self, module_path: str):
         module_path, class_name = module_path.rsplit(".", 1)
         module = import_module(module_path)
         return getattr(module, class_name)

-    def _get_loader(self, data_type: DataType, config: LoaderConfig) -> BaseLoader:
+    def _get_loader(self, data_type: DataType, config: LoaderConfig, kwargs: Dict[str, Any]) -> BaseLoader:
         """
         Returns the appropriate data loader for the given data type.

@@ -63,13 +64,28 @@ class DataFormatter(JSONSerializable):
             DataType.GMAIL: "embedchain.loaders.gmail.GmailLoader",
             DataType.NOTION: "embedchain.loaders.notion.NotionLoader",
         }
+
+        custom_loaders = set(
+            [
+                DataType.POSTGRES,
+            ]
+        )
+
         if data_type in loaders:
             loader_class: type = self._lazy_load(loaders[data_type])
             return loader_class()
-        else:
-            raise ValueError(f"Unsupported data type: {data_type}")
+        elif data_type in custom_loaders:
+            loader_class: type = kwargs.get("loader", None)
+            if loader_class is not None:
+                return loader_class
+
+        raise ValueError(
+            f"Can't find the loader for {data_type}. \
+                We recommend passing the loader to use for data_type: {data_type}, \
+                    check `https://docs.embedchain.ai/data-sources/overview`."
+        )

-    def _get_chunker(self, data_type: DataType, config: ChunkerConfig) -> BaseChunker:
+    def _get_chunker(self, data_type: DataType, config: ChunkerConfig, kwargs: Dict[str, Any]) -> BaseChunker:
         """Returns the appropriate chunker for the given data type (updated for lazy loading)."""
         chunker_classes = {
             DataType.YOUTUBE_VIDEO: "embedchain.chunkers.youtube_video.YoutubeVideoChunker",
@@ -89,12 +105,21 @@ class DataFormatter(JSONSerializable):
             DataType.OPENAPI: "embedchain.chunkers.openapi.OpenAPIChunker",
             DataType.GMAIL: "embedchain.chunkers.gmail.GmailChunker",
             DataType.NOTION: "embedchain.chunkers.notion.NotionChunker",
+            DataType.POSTGRES: "embedchain.chunkers.postgres.PostgresChunker",
         }

         if data_type in chunker_classes:
-            chunker_class = self._lazy_load(chunker_classes[data_type])
+            if "chunker" in kwargs:
+                chunker_class = kwargs.get("chunker")
+            else:
+                chunker_class = self._lazy_load(chunker_classes[data_type])
+
             chunker = chunker_class(config)
             chunker.set_data_type(data_type)
             return chunker
-        else:
-            raise ValueError(f"Unsupported data type: {data_type}")
+
+        raise ValueError(
+            f"Can't find the chunker for {data_type}. \
+                We recommend passing the chunker to use for data_type: {data_type}, \
+                    check `https://docs.embedchain.ai/data-sources/overview`."
+        )
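
Note the asymmetry between the two override paths above: kwargs["loader"] is returned as-is, so it must be an already-constructed loader instance, while kwargs["chunker"] is later called with the config, so it must be a chunker class. A hedged sketch of driving DataFormatter directly (module paths inferred from the imports in this diff; the connection URL is a placeholder):

from embedchain.config import AddConfig
from embedchain.data_formatter.data_formatter import DataFormatter
from embedchain.loaders.postgres import PostgresLoader
from embedchain.models.data_type import DataType

# "loader" must be an instance; "chunker" (if supplied) must be a class.
loader = PostgresLoader(config={"url": "postgresql://user:pass@localhost:5432/mydb"})
formatter = DataFormatter(DataType.POSTGRES, AddConfig(), {"loader": loader})
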
embedchain/embedchain.py
@@ -137,6 +137,7 @@ class EmbedChain(JSONSerializable):
         metadata: Optional[Dict[str, Any]] = None,
         config: Optional[AddConfig] = None,
         dry_run=False,
+        **kwargs: Dict[str, Any],
     ):
         """
         Adds the data from the given URL to the vector db.
@@ -180,21 +181,6 @@ class EmbedChain(JSONSerializable):
         if data_type:
             try:
                 data_type = DataType(data_type)
-                if data_type == DataType.JSON:
-                    if isinstance(source, str):
-                        if not is_valid_json_string(source):
-                            raise ValueError(
-                                f"Invalid json input: {source}",
-                                "Provide the correct JSON formatted source, \
-                                    refer `https://docs.embedchain.ai/data-sources/json`",
-                            )
-                    elif not isinstance(source, str):
-                        raise ValueError(
-                            "Invaid content input. \
-                                If you want to upload (list, dict, etc.), do \
-                                    `json.dump(data, indent=0)` and add the stringified JSON. \
-                                        Check - `https://docs.embedchain.ai/data-sources/json`"
-                        )
             except ValueError:
                 raise ValueError(
                     f"Invalid data_type: '{data_type}'.",
@@ -218,8 +204,9 @@ class EmbedChain(JSONSerializable):
             print(f"Data with hash {source_hash} already exists. Skipping addition.")
             return source_hash

-        data_formatter = DataFormatter(data_type, config)
         self.user_asks.append([source, data_type.value, metadata])
+
+        data_formatter = DataFormatter(data_type, config, kwargs)
         documents, metadatas, _ids, new_chunks = self.load_and_embed(
             data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
         )
@@ -265,6 +252,7 @@ class EmbedChain(JSONSerializable):
         data_type: Optional[DataType] = None,
         metadata: Optional[Dict[str, Any]] = None,
         config: Optional[AddConfig] = None,
+        **kwargs: Dict[str, Any],
     ):
         """
         Adds the data from the given URL to the vector db.
@@ -290,7 +278,13 @@ class EmbedChain(JSONSerializable):
         logging.warning(
             "The `add_local` method is deprecated and will be removed in future versions. Please use the `add` method for both local and remote files."  # noqa: E501
         )
-        return self.add(source=source, data_type=data_type, metadata=metadata, config=config)
+        return self.add(
+            source=source,
+            data_type=data_type,
+            metadata=metadata,
+            config=config,
+            **kwargs,
+        )

     def _get_existing_doc_id(self, chunker: BaseChunker, src: Any):
         """
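
End to end, the kwargs plumbing lets a caller hand add() a custom loader. A hedged sketch, assuming the package's top-level App export and placeholder credentials:

from embedchain import App
from embedchain.loaders.postgres import PostgresLoader

loader = PostgresLoader(config={"url": "postgresql://user:pass@localhost:5432/mydb"})
app = App()
# The extra keyword lands in **kwargs and reaches DataFormatter._get_loader.
app.add("SELECT id, title FROM articles;", data_type="postgres", loader=loader)
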
embedchain/loaders/json.py
@@ -25,10 +25,21 @@ class JSONLoader(BaseLoader):

         return LLHUBJSONLoader()

+    @staticmethod
+    def _check_content(content):
+        if not isinstance(content, str):
+            raise ValueError(
+                "Invalid content input. \
+                    If you want to upload (list, dict, etc.), do \
+                        `json.dumps(data)` and add the stringified JSON. \
+                            Check - `https://docs.embedchain.ai/data-sources/json`"
+            )
+
     @staticmethod
     def load_data(content):
         """Load a json file. Each data point is a key value pair."""

+        JSONLoader._check_content(content)
         loader = JSONLoader._get_llama_hub_loader()

         data = []
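
The JSON validation that add() previously did inline now lives in JSONLoader._check_content: non-string payloads are rejected, so dicts and lists must be serialized first. A small self-contained sketch:

import json

from embedchain.loaders.json import JSONLoader

JSONLoader._check_content(json.dumps({"q": "hi"}))  # ok: stringified JSON

try:
    JSONLoader._check_content({"q": "hi"})  # a dict, not a string
except ValueError as err:
    print(err)
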
embedchain/loaders/postgres.py (new file, 73 lines)
@@ -0,0 +1,73 @@
+import hashlib
+import logging
+from typing import Any, Dict, Optional
+
+from embedchain.loaders.base_loader import BaseLoader
+
+
+class PostgresLoader(BaseLoader):
+    def __init__(self, config: Optional[Dict[str, Any]] = None):
+        super().__init__()
+        if not config:
+            raise ValueError(f"Must provide a valid config. Received: {config}")
+
+        self.connection = None
+        self.cursor = None
+        self._setup_loader(config=config)
+
+    def _setup_loader(self, config: Dict[str, Any]):
+        try:
+            import psycopg
+        except ImportError as e:
+            raise ImportError(
+                "Unable to import required packages. \
+                    Run `pip install --upgrade 'embedchain[postgres]'`"
+            ) from e
+
+        config_info = ""
+        if "url" in config:
+            config_info = config.get("url")
+        else:
+            conn_params = []
+            for key, value in config.items():
+                conn_params.append(f"{key}={value}")
+            config_info = " ".join(conn_params)
+
+        logging.info(f"Connecting to postgres: {config_info}")
+        self.connection = psycopg.connect(conninfo=config_info)
+        self.cursor = self.connection.cursor()
+
+    def _check_query(self, query):
+        if not isinstance(query, str):
+            raise ValueError(
+                f"Invalid postgres query: {query}",
+                "Provide a valid source to add from postgres, \
+                    make sure you are following `https://docs.embedchain.ai/data-sources/postgres`",
+            )
+
+    def load_data(self, query):
+        self._check_query(query)
+        try:
+            data = []
+            data_content = []
+            self.cursor.execute(query)
+            results = self.cursor.fetchall()
+            for result in results:
+                doc_content = str(result)
+                data.append({"content": doc_content, "meta_data": {"url": f"postgres_query-({query})"}})
+                data_content.append(doc_content)
+            doc_id = hashlib.sha256((query + ", ".join(data_content)).encode()).hexdigest()
+            return {
+                "doc_id": doc_id,
+                "data": data,
+            }
+        except Exception as e:
+            raise ValueError(f"Failed to load data using query={query} with: {e}")
+
+    def close_connection(self):
+        if self.cursor:
+            self.cursor.close()
+            self.cursor = None
+        if self.connection:
+            self.connection.close()
+            self.connection = None
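
Both config shapes accepted by _setup_loader reduce to a conninfo string for psycopg.connect, so the two forms below are equivalent (placeholder credentials, hypothetical table):

from embedchain.loaders.postgres import PostgresLoader

# Form 1: a single connection URL, passed through untouched.
loader = PostgresLoader(config={"url": "postgresql://postgres:secret@localhost:5432/mydb"})

# Form 2: key/value pairs, joined into "host=localhost port=5432 ..." internally.
loader = PostgresLoader(
    config={"host": "localhost", "port": "5432", "dbname": "mydb", "user": "postgres", "password": "secret"}
)

result = loader.load_data("SELECT id, title FROM articles;")
print(result["doc_id"], len(result["data"]))
loader.close_connection()
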
embedchain/models/data_type.py
@@ -29,6 +29,7 @@ class IndirectDataType(Enum):
     JSON = "json"
     OPENAPI = "openapi"
     GMAIL = "gmail"
+    POSTGRES = "postgres"


 class SpecialDataType(Enum):
@@ -57,3 +58,4 @@ class DataType(Enum):
     JSON = IndirectDataType.JSON.value
     OPENAPI = IndirectDataType.OPENAPI.value
     GMAIL = IndirectDataType.GMAIL.value
+    POSTGRES = IndirectDataType.POSTGRES.value
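
With both enum entries in place, the string form used in add(..., data_type="postgres") resolves cleanly:

from embedchain.models.data_type import DataType

assert DataType("postgres") is DataType.POSTGRES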