[Feature] Add MySQL Loader (#920)
Co-authored-by: Deven Patel <deven298@yahoo.com> Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
22
embedchain/chunkers/mysql.py
Normal file
22
embedchain/chunkers/mysql.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class MySQLChunker(BaseChunker):
|
||||
"""Chunker for json."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
@@ -68,6 +68,7 @@ class DataFormatter(JSONSerializable):
|
||||
custom_loaders = set(
|
||||
[
|
||||
DataType.POSTGRES,
|
||||
DataType.MYSQL,
|
||||
DataType.SLACK,
|
||||
]
|
||||
)
|
||||
@@ -107,6 +108,7 @@ class DataFormatter(JSONSerializable):
|
||||
DataType.GMAIL: "embedchain.chunkers.gmail.GmailChunker",
|
||||
DataType.NOTION: "embedchain.chunkers.notion.NotionChunker",
|
||||
DataType.POSTGRES: "embedchain.chunkers.postgres.PostgresChunker",
|
||||
DataType.MYSQL: "embedchain.chunkers.mysql.MySQLChunker",
|
||||
DataType.SLACK: "embedchain.chunkers.slack.SlackChunker",
|
||||
}
|
||||
|
||||
|
||||
64
embedchain/loaders/mysql.py
Normal file
64
embedchain/loaders/mysql.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.utils import clean_string
|
||||
|
||||
|
||||
class MySQLLoader(BaseLoader):
|
||||
def __init__(self, config: Optional[Dict[str, Any]]):
|
||||
super().__init__()
|
||||
if not config:
|
||||
raise ValueError(
|
||||
f"Invalid sql config: {config}.",
|
||||
"Provide the correct config, refer `https://docs.embedchain.ai/data-sources/mysql`.",
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.connection = None
|
||||
self.cursor = None
|
||||
self._setup_loader(config=config)
|
||||
|
||||
def _setup_loader(self, config: Dict[str, Any]):
|
||||
try:
|
||||
import mysql.connector as sqlconnector
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Unable to import required packages for MySQL loader. Run `pip install --upgrade 'embedchain[mysql]'`." # noqa: E501
|
||||
) from e
|
||||
|
||||
try:
|
||||
self.connection = sqlconnector.connection.MySQLConnection(**config)
|
||||
self.cursor = self.connection.cursor()
|
||||
except (sqlconnector.Error, IOError) as err:
|
||||
logging.info(f"Connection failed: {err}")
|
||||
raise ValueError(
|
||||
f"Unable to connect with the given config: {config}.",
|
||||
"Please provide the correct configuration to load data from you MySQL DB. \
|
||||
Refer `https://docs.embedchain.ai/data-sources/mysql`.",
|
||||
)
|
||||
|
||||
def _check_query(self, query):
|
||||
if not isinstance(query, str):
|
||||
raise ValueError(
|
||||
f"Invalid mysql query: {query}",
|
||||
"Provide the valid query to add from mysql, \
|
||||
make sure you are following `https://docs.embedchain.ai/data-sources/mysql`",
|
||||
)
|
||||
|
||||
def load_data(self, query):
|
||||
self._check_query(query=query)
|
||||
data = []
|
||||
data_content = []
|
||||
self.cursor.execute(query)
|
||||
rows = self.cursor.fetchall()
|
||||
for row in rows:
|
||||
doc_content = clean_string(str(row))
|
||||
data.append({"content": doc_content, "meta_data": {"url": query}})
|
||||
data_content.append(doc_content)
|
||||
doc_id = hashlib.sha256((query + ", ".join(data_content)).encode()).hexdigest()
|
||||
return {
|
||||
"doc_id": doc_id,
|
||||
"data": data,
|
||||
}
|
||||
@@ -52,7 +52,7 @@ class PostgresLoader(BaseLoader):
|
||||
results = self.cursor.fetchall()
|
||||
for result in results:
|
||||
doc_content = str(result)
|
||||
data.append({"content": doc_content, "meta_data": {"url": f"postgres_query-({query})"}})
|
||||
data.append({"content": doc_content, "meta_data": {"url": query}})
|
||||
data_content.append(doc_content)
|
||||
doc_id = hashlib.sha256((query + ", ".join(data_content)).encode()).hexdigest()
|
||||
return {
|
||||
|
||||
@@ -30,6 +30,7 @@ class IndirectDataType(Enum):
|
||||
OPENAPI = "openapi"
|
||||
GMAIL = "gmail"
|
||||
POSTGRES = "postgres"
|
||||
MYSQL = "mysql"
|
||||
SLACK = "slack"
|
||||
|
||||
|
||||
@@ -60,4 +61,5 @@ class DataType(Enum):
|
||||
OPENAPI = IndirectDataType.OPENAPI.value
|
||||
GMAIL = IndirectDataType.GMAIL.value
|
||||
POSTGRES = IndirectDataType.POSTGRES.value
|
||||
MYSQL = IndirectDataType.MYSQL.value
|
||||
SLACK = IndirectDataType.SLACK.value
|
||||
|
||||
Reference in New Issue
Block a user