[Feature Improvement] Update JSON Loader to support loading data from more sources (#898)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel
2023-11-03 10:00:27 -07:00
committed by GitHub
parent e2546a653d
commit 53037b5ed8
6 changed files with 166 additions and 67 deletions


@@ -2,9 +2,26 @@
title: '📃 JSON'
---
To add any json file, use the data_type as `json`. Headers are included for each line, so if you have an `age` column, `18` will be added as `age: 18`.
Here are the supported sources for loading `json`:

1. URL - valid url to a json file that ends with the ".json" extension.
2. Local file - valid path to a local json file that ends with the ".json" extension.
3. String - valid json string (e.g. `app.add('{"foo": "bar"}')`)
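All three source types go through the same `add` call. A minimal sketch (the url and the local path below are placeholders, and `OPENAI_API_KEY` is assumed to be set for the app):

```python
import os

from embedchain.apps.app import App

os.environ["OPENAI_API_KEY"] = "xxx"

app = App()

# 1. remote url ending in ".json" (placeholder url)
app.add("https://example.com/temp.json", data_type="json")

# 2. local file path ending in ".json"
app.add("temp.json", data_type="json")

# 3. raw JSON string
app.add('{"foo": "bar"}', data_type="json")
```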
If you would like to add other data structures (e.g. list, dict etc.), serialize them to a valid JSON string first:
```python
import json

# a dict serialized to a valid JSON string
a = {"foo": "bar"}
valid_json_string_data = json.dumps(a, indent=0)

# a list of dicts works the same way
b = [{"foo": "bar"}]
valid_json_string_data = json.dumps(b, indent=0)
```
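The stringified result is then added exactly like the raw JSON string above. A minimal sketch, reusing `valid_json_string_data` from the snippet and assuming `OPENAI_API_KEY` is set:

```python
from embedchain.apps.app import App

app = App()

# add the stringified dict/list as a json data source
app.add(valid_json_string_data, data_type="json")
```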
Example:

```python
import os
from embedchain.apps.app import App
@@ -25,8 +42,8 @@ response = app.query("What is the net worth of Elon Musk as of October 2023?")
print(response)
"As of October 2023, Elon Musk's net worth is $255.2 billion."
```
temp.json
```json
{
    "question": "What is your net worth, Elon Musk?",
    "answer": "As of October 2023, Elon Musk's net worth is $255.2 billion, making him one of the wealthiest individuals in the world."
}
```


@@ -20,7 +20,7 @@ from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import (DataType, DirectDataType,
                                         IndirectDataType, SpecialDataType)
from embedchain.telemetry.posthog import AnonymousTelemetry
-from embedchain.utils import detect_datatype
+from embedchain.utils import detect_datatype, is_valid_json_string
from embedchain.vectordb.base import BaseVectorDB
load_dotenv()
@@ -175,11 +175,27 @@ class EmbedChain(JSONSerializable):
        if data_type:
            try:
                data_type = DataType(data_type)
+                if data_type == DataType.JSON:
+                    if isinstance(source, str):
+                        if not is_valid_json_string(source):
+                            raise ValueError(
+                                f"Invalid json input: {source}",
+                                "Provide the correct JSON formatted source, \
+                                refer `https://docs.embedchain.ai/data-sources/json`",
+                            )
+                    elif not isinstance(source, str):
+                        raise ValueError(
+                            "Invalid content input. \
+                            If you want to upload (list, dict, etc.), do \
+                            `json.dumps(data, indent=0)` and add the stringified JSON. \
+                            Check - `https://docs.embedchain.ai/data-sources/json`"
+                        )
            except ValueError:
                raise ValueError(
                    f"Invalid data_type: '{data_type}'.",
                    f"Please use one of the following: {[data_type.value for data_type in DataType]}",
                ) from None

        if not data_type:
            data_type = detect_datatype(source)
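
In practice the new branch means structured data has to be stringified before it reaches `add`. A small usage sketch of the behaviour enforced above (assumes an `App` instance with a configured `OPENAI_API_KEY`; the payload is illustrative):

```python
import json

from embedchain import App

app = App()

# a valid JSON string passes the check
app.add('{"foo": "bar"}', data_type="json")

# a malformed string raises ValueError("Invalid json input: ...")
# app.add('{"foo": bar', data_type="json")

# a dict or list is rejected outright; stringify it first
payload = [{"foo": "bar"}]
app.add(json.dumps(payload, indent=0), data_type="json")
```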
@@ -287,6 +303,10 @@ class EmbedChain(JSONSerializable):
        # These types have an indirect source reference
        # As long as the reference is the same, they can be updated.
        where = {"url": src}
+        if chunker.data_type == DataType.JSON and is_valid_json_string(src):
+            url = hashlib.sha256((src).encode("utf-8")).hexdigest()
+            where = {"url": url}
+
        if self.config.id is not None:
            where.update({"app_id": self.config.id})
@@ -368,6 +388,10 @@ class EmbedChain(JSONSerializable):
        # get existing ids, and discard doc if any common id exist.
        where = {"url": src}
+        if chunker.data_type == DataType.JSON and is_valid_json_string(src):
+            url = hashlib.sha256((src).encode("utf-8")).hexdigest()
+            where = {"url": url}
+
        # if data type is qna_pair, we check for question
        if chunker.data_type == DataType.QNA_PAIR:
            where = {"question": src[0]}


@@ -6,33 +6,37 @@ import re

import requests

from embedchain.loaders.base_loader import BaseLoader
-from embedchain.utils import clean_string
+from embedchain.utils import clean_string, is_valid_json_string

VALID_URL_PATTERN = "^https:\/\/[0-9A-z.]+.[0-9A-z.]+.[a-z]+\/.*\.json$"


class JSONLoader(BaseLoader):
    @staticmethod
-    def load_data(content):
-        """Load a json file. Each data point is a key value pair."""
+    def _get_llama_hub_loader():
        try:
            from llama_hub.jsondata.base import \
-                JSONDataReader as LLHBUBJSONLoader
-        except ImportError:
+                JSONDataReader as LLHUBJSONLoader
+        except ImportError as e:
            raise Exception(
-                f"Couldn't import the required packages to load {content}, \
-                Do `pip install --upgrade 'embedchain[json]`"
+                f"Failed to install required packages: {e}, \
+                install them using `pip install --upgrade 'embedchain[json]'`"
            )
-        loader = LLHBUBJSONLoader()
+        return LLHUBJSONLoader()

-        if not isinstance(content, str):
-            print(f"Invaid content input. Provide the correct path to the json file saved locally in {content}")
+    @staticmethod
+    def load_data(content):
+        """Load a json file. Each data point is a key value pair."""
+        loader = JSONLoader._get_llama_hub_loader()

        data = []
        data_content = []

-        # Load json data from various sources. TODO: add support for dictionary
+        content_url_str = content
+        # Load json data from various sources.
        if os.path.isfile(content):
            with open(content, "r", encoding="utf-8") as json_file:
                json_data = json.load(json_file)
@@ -45,13 +49,17 @@ class JSONLoader(BaseLoader):
                    f"Loading data from the given url: {content} failed. \
                    Make sure the url is working."
                )
+        elif is_valid_json_string(content):
+            json_data = content
+            content_url_str = hashlib.sha256((content).encode("utf-8")).hexdigest()
        else:
            raise ValueError(f"Invalid content to load json data from: {content}")

        docs = loader.load_data(json_data)
        for doc in docs:
            doc_content = clean_string(doc.text)
-            data.append({"content": doc_content, "meta_data": {"url": content}})
+            data.append({"content": doc_content, "meta_data": {"url": content_url_str}})
            data_content.append(doc_content)
-        doc_id = hashlib.sha256((content + ", ".join(data_content)).encode()).hexdigest()
+        doc_id = hashlib.sha256((content_url_str + ", ".join(data_content)).encode()).hexdigest()
        return {"doc_id": doc_id, "data": data}


@@ -1,3 +1,4 @@
+import json
import logging
import os
import re
@@ -261,6 +262,24 @@ def detect_datatype(source: Any) -> DataType:
    # TODO: check if source is gmail query

+    # check if the source is valid json string
+    if is_valid_json_string(source):
+        logging.debug(f"Source of `{formatted_source}` detected as `json`.")
+        return DataType.JSON
+
    # Use text as final fallback.
    logging.debug(f"Source of `{formatted_source}` detected as `text`.")
    return DataType.TEXT


+# check if the source is valid json string
+def is_valid_json_string(source: str):
+    try:
+        _ = json.loads(source)
+        return True
+    except json.JSONDecodeError:
+        logging.error(
+            "Insert valid string format of JSON. \
+            Check the docs to see the supported formats - `https://docs.embedchain.ai/data-sources/json`"
+        )
+        return False
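
A quick sanity check of the helper and the new detection branch (a sketch; the invalid string mirrors the one used in the loader tests below):

```python
from embedchain.models.data_type import DataType
from embedchain.utils import detect_datatype, is_valid_json_string

assert is_valid_json_string('{"foo": "bar"}') is True
assert is_valid_json_string("123: 345}") is False  # logs an error and returns False

assert detect_datatype('{"foo": "bar"}') == DataType.JSON
assert detect_datatype("just some plain text") == DataType.TEXT  # final fallback
```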


@@ -1,61 +1,65 @@
import os
-import unittest
-from unittest.mock import patch
+import pytest
+from chromadb.api.models.Collection import Collection

from embedchain import App
from embedchain.config import AppConfig, ChromaDbConfig
+from embedchain.embedchain import EmbedChain
+from embedchain.llm.base import BaseLlm

-os.environ["OPENAI_API_KEY"] = "test-api-key"
-class TestChromaDbHostsLoglevel(unittest.TestCase):
+os.environ["OPENAI_API_KEY"] = "test_key"


+@pytest.fixture
+def app_instance():
+    config = AppConfig(log_level="DEBUG", collect_metrics=False)
+    return App(config)


-    @patch("chromadb.api.models.Collection.Collection.add")
-    @patch("embedchain.embedchain.EmbedChain.retrieve_from_database")
-    @patch("embedchain.llm.base.BaseLlm.get_answer_from_llm")
-    @patch("embedchain.llm.base.BaseLlm.get_llm_model_answer")
-    def test_whole_app(
-        self,
-        _mock_add,
-        _mock_ec_retrieve_from_database,
-        _mock_get_answer_from_llm,
-        mock_ec_get_llm_model_answer,
-    ):
-        """
-        Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
-        """
-        config = AppConfig(log_level="DEBUG", collect_metrics=False)
-        app = App(config)
+def test_whole_app(app_instance, mocker):
-        knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing"
+    knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing"

+    mocker.patch.object(EmbedChain, "add")
+    mocker.patch.object(EmbedChain, "retrieve_from_database")
+    mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge)
+    mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge)
+    mocker.patch.object(BaseLlm, "generate_prompt")

-        app.add(knowledge, data_type="text")
+    app_instance.add(knowledge, data_type="text")
+    app_instance.query("What text did I give you?")
+    app_instance.chat("What text did I give you?")
-        app.query("What text did I give you?")
-        app.chat("What text did I give you?")

+    assert BaseLlm.generate_prompt.call_count == 2
+    app_instance.reset()
-        self.assertEqual(mock_ec_get_llm_model_answer.call_args[1]["documents"], [knowledge])


-    def test_add_after_reset(self):
-        """
-        Test if the `App` instance is correctly reconstructed after a reset.
-        """
-        config = AppConfig(log_level="DEBUG", collect_metrics=False)
-        chroma_config = {"allow_reset": True}
-        app = App(config=config, db_config=ChromaDbConfig(**chroma_config))
-        app.reset()
+def test_add_after_reset(app_instance, mocker):
+    config = AppConfig(log_level="DEBUG", collect_metrics=False)
+    chroma_config = {"allow_reset": True}

-        # Make sure the client is still healthy
-        app.db.client.heartbeat()
-        # Make sure the collection exists, and can be added to
-        app.db.collection.add(
-            embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2]],
-            metadatas=[
-                {"chapter": "3", "verse": "16"},
-                {"chapter": "3", "verse": "5"},
-                {"chapter": "29", "verse": "11"},
-            ],
-            ids=["id1", "id2", "id3"],
-        )
+    app_instance = App(config=config, db_config=ChromaDbConfig(**chroma_config))
+    app_instance.reset()
-        app.reset()

+    app_instance.db.client.heartbeat()
+    mocker.patch.object(Collection, "add")
+    app_instance.db.collection.add(
+        embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2]],
+        metadatas=[
+            {"chapter": "3", "verse": "16"},
+            {"chapter": "3", "verse": "5"},
+            {"chapter": "29", "verse": "11"},
+        ],
+        ids=["id1", "id2", "id3"],
+    )
+    app_instance.reset()


+def test_add_with_incorrect_content(app_instance, mocker):
+    content = [{"foo": "bar"}]
+    with pytest.raises(ValueError):
+        app_instance.add(content, data_type="json")


@@ -40,7 +40,7 @@ def test_load_data(mocker):
def test_load_data_url(mocker):
    content = "https://example.com/posts.json"
-    mocker.patch("os.path.isfile", return_value=False)  # Mocking os.path.isfile to simulate a URL case
+    mocker.patch("os.path.isfile", return_value=False)
    mocker.patch(
        "llama_hub.jsondata.base.JSONDataReader.load_data",
        return_value=[Document(text="content1"), Document(text="content2")],
@@ -68,11 +68,11 @@ def test_load_data_url(mocker):
    assert result["doc_id"] == expected_doc_id


-def test_load_data_invalid_content(mocker):
+def test_load_data_invalid_string_content(mocker):
    mocker.patch("os.path.isfile", return_value=False)
    mocker.patch("requests.get")

-    content = "123"
+    content = "123: 345}"
    with pytest.raises(ValueError, match="Invalid content to load json data from"):
        JSONLoader.load_data(content)
@@ -89,3 +89,30 @@ def test_load_data_invalid_url(mocker):
    with pytest.raises(ValueError, match=f"Invalid content to load json data from: {content}"):
        JSONLoader.load_data(content)


+def test_load_data_from_json_string(mocker):
+    content = '{"foo": "bar"}'
+    content_url_str = hashlib.sha256((content).encode("utf-8")).hexdigest()
+    mocker.patch("os.path.isfile", return_value=False)
+    mocker.patch(
+        "llama_hub.jsondata.base.JSONDataReader.load_data",
+        return_value=[Document(text="content1"), Document(text="content2")],
+    )
+
+    result = JSONLoader.load_data(content)
+
+    assert "doc_id" in result
+    assert "data" in result
+    expected_data = [
+        {"content": "content1", "meta_data": {"url": content_url_str}},
+        {"content": "content2", "meta_data": {"url": content_url_str}},
+    ]
+    assert result["data"] == expected_data
+    expected_doc_id = hashlib.sha256((content_url_str + ", ".join(["content1", "content2"])).encode()).hexdigest()
+    assert result["doc_id"] == expected_doc_id