diff --git a/docs/data-sources/json.mdx b/docs/data-sources/json.mdx index 268009e9..5f243070 100644 --- a/docs/data-sources/json.mdx +++ b/docs/data-sources/json.mdx @@ -2,9 +2,26 @@ title: '📃 JSON' --- -To add any json file, use the data_type as `json`. `json` allows remote urls and conventional file paths. Headers are included for each line, so if you have an `age` column, `18` will be added as `age: 18`. Eg: +To add any json file, use the data_type as `json`. Headers are included for each line, so if you have an `age` column, `18` will be added as `age: 18`. Eg: -```python +Here are the supported sources for loading `json`: +``` +1. URL - valid url to a json file that ends with the ".json" extension. +2. Local file - valid path to a local json file that ends with the ".json" extension. +3. String - valid json string (e.g. app.add('{"foo": "bar"}')) +``` + +If you would like to add other data structures (e.g. list, dict, etc.), convert them to a JSON string first: +``` + import json + a = {"foo": "bar"} + valid_json_string_data = json.dumps(a, indent=0) + + b = [{"foo": "bar"}] + valid_json_string_data = json.dumps(b, indent=0) +``` +Example: +``` import os from embedchain.apps.app import App @@ -25,8 +42,8 @@ response = app.query("What is the net worth of Elon Musk as of October 2023?") print(response) "As of October 2023, Elon Musk's net worth is $255.2 billion." ``` - -```temp.json +temp.json +``` { "question": "What is your net worth, Elon Musk?", "answer": "As of October 2023, Elon Musk's net worth is $255.2 billion, making him one of the wealthiest individuals in the world." 
diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index b1e2a454..0577bc3b 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -20,7 +20,7 @@ from embedchain.loaders.base_loader import BaseLoader from embedchain.models.data_type import (DataType, DirectDataType, IndirectDataType, SpecialDataType) from embedchain.telemetry.posthog import AnonymousTelemetry -from embedchain.utils import detect_datatype +from embedchain.utils import detect_datatype, is_valid_json_string from embedchain.vectordb.base import BaseVectorDB load_dotenv() @@ -175,11 +175,27 @@ class EmbedChain(JSONSerializable): if data_type: try: data_type = DataType(data_type) + if data_type == DataType.JSON: + if isinstance(source, str): + if not is_valid_json_string(source): + raise ValueError( + f"Invalid json input: {source}", + "Provide the correct JSON formatted source, \ + refer `https://docs.embedchain.ai/data-sources/json`", + ) + elif not isinstance(source, str): + raise ValueError( + "Invaid content input. \ + If you want to upload (list, dict, etc.), do \ + `json.dump(data, indent=0)` and add the stringified JSON. \ + Check - `https://docs.embedchain.ai/data-sources/json`" + ) except ValueError: raise ValueError( f"Invalid data_type: '{data_type}'.", f"Please use one of the following: {[data_type.value for data_type in DataType]}", ) from None + if not data_type: data_type = detect_datatype(source) @@ -287,6 +303,10 @@ class EmbedChain(JSONSerializable): # These types have a indirect source reference # As long as the reference is the same, they can be updated. where = {"url": src} + if chunker.data_type == DataType.JSON and is_valid_json_string(src): + url = hashlib.sha256((src).encode("utf-8")).hexdigest() + where = {"url": url} + if self.config.id is not None: where.update({"app_id": self.config.id}) @@ -368,6 +388,10 @@ class EmbedChain(JSONSerializable): # get existing ids, and discard doc if any common id exist. 
where = {"url": src} + if chunker.data_type == DataType.JSON and is_valid_json_string(src): + url = hashlib.sha256((src).encode("utf-8")).hexdigest() + where = {"url": url} + # if data type is qna_pair, we check for question if chunker.data_type == DataType.QNA_PAIR: where = {"question": src[0]} diff --git a/embedchain/loaders/json.py b/embedchain/loaders/json.py index 9ecc77f1..058f8fd0 100644 --- a/embedchain/loaders/json.py +++ b/embedchain/loaders/json.py @@ -6,33 +6,37 @@ import re import requests from embedchain.loaders.base_loader import BaseLoader -from embedchain.utils import clean_string +from embedchain.utils import clean_string, is_valid_json_string VALID_URL_PATTERN = "^https:\/\/[0-9A-z.]+.[0-9A-z.]+.[a-z]+\/.*\.json$" class JSONLoader(BaseLoader): @staticmethod - def load_data(content): - """Load a json file. Each data point is a key value pair.""" + def _get_llama_hub_loader(): try: from llama_hub.jsondata.base import \ - JSONDataReader as LLHBUBJSONLoader - except ImportError: + JSONDataReader as LLHUBJSONLoader + except ImportError as e: raise Exception( - f"Couldn't import the required packages to load {content}, \ - Do `pip install --upgrade 'embedchain[json]`" + f"Failed to install required packages: {e}, \ + install them using `pip install --upgrade 'embedchain[json]`" ) - loader = LLHBUBJSONLoader() + return LLHUBJSONLoader() - if not isinstance(content, str): - print(f"Invaid content input. Provide the correct path to the json file saved locally in {content}") + @staticmethod + def load_data(content): + """Load a json file. Each data point is a key value pair.""" + + loader = JSONLoader._get_llama_hub_loader() data = [] data_content = [] - # Load json data from various sources. TODO: add support for dictionary + content_url_str = content + + # Load json data from various sources. 
if os.path.isfile(content): with open(content, "r", encoding="utf-8") as json_file: json_data = json.load(json_file) @@ -45,13 +49,17 @@ class JSONLoader(BaseLoader): f"Loading data from the given url: {content} failed. \ Make sure the url is working." ) + elif is_valid_json_string(content): + json_data = content + content_url_str = hashlib.sha256((content).encode("utf-8")).hexdigest() else: raise ValueError(f"Invalid content to load json data from: {content}") docs = loader.load_data(json_data) for doc in docs: doc_content = clean_string(doc.text) - data.append({"content": doc_content, "meta_data": {"url": content}}) + data.append({"content": doc_content, "meta_data": {"url": content_url_str}}) data_content.append(doc_content) - doc_id = hashlib.sha256((content + ", ".join(data_content)).encode()).hexdigest() + + doc_id = hashlib.sha256((content_url_str + ", ".join(data_content)).encode()).hexdigest() return {"doc_id": doc_id, "data": data} diff --git a/embedchain/utils.py b/embedchain/utils.py index b2b30cf2..6f6f6bf1 100644 --- a/embedchain/utils.py +++ b/embedchain/utils.py @@ -1,3 +1,4 @@ +import json import logging import os import re @@ -261,6 +262,24 @@ def detect_datatype(source: Any) -> DataType: # TODO: check if source is gmail query + # check if the source is valid json string + if is_valid_json_string(source): + logging.debug(f"Source of `{formatted_source}` detected as `json`.") + return DataType.JSON + # Use text as final fallback. logging.debug(f"Source of `{formatted_source}` detected as `text`.") return DataType.TEXT + + +# check if the source is valid json string +def is_valid_json_string(source: str): + try: + _ = json.loads(source) + return True + except json.JSONDecodeError: + logging.error( + "Insert valid string format of JSON. 
\ + Check the docs to see the supported formats - `https://docs.embedchain.ai/data-sources/json`" + ) + return False diff --git a/tests/embedchain/test_embedchain.py b/tests/embedchain/test_embedchain.py index 3c5ffc7e..a377957c 100644 --- a/tests/embedchain/test_embedchain.py +++ b/tests/embedchain/test_embedchain.py @@ -1,61 +1,65 @@ import os -import unittest -from unittest.mock import patch + +import pytest +from chromadb.api.models.Collection import Collection from embedchain import App from embedchain.config import AppConfig, ChromaDbConfig +from embedchain.embedchain import EmbedChain +from embedchain.llm.base import BaseLlm + +os.environ["OPENAI_API_KEY"] = "test-api-key" -class TestChromaDbHostsLoglevel(unittest.TestCase): - os.environ["OPENAI_API_KEY"] = "test_key" +@pytest.fixture +def app_instance(): + config = AppConfig(log_level="DEBUG", collect_metrics=False) + return App(config) - @patch("chromadb.api.models.Collection.Collection.add") - @patch("embedchain.embedchain.EmbedChain.retrieve_from_database") - @patch("embedchain.llm.base.BaseLlm.get_answer_from_llm") - @patch("embedchain.llm.base.BaseLlm.get_llm_model_answer") - def test_whole_app( - self, - _mock_add, - _mock_ec_retrieve_from_database, - _mock_get_answer_from_llm, - mock_ec_get_llm_model_answer, - ): - """ - Test if the `App` instance is initialized without a config that does not contain default hosts and ports. 
- """ - config = AppConfig(log_level="DEBUG", collect_metrics=False) - app = App(config) +def test_whole_app(app_instance, mocker): + knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing" - knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing" + mocker.patch.object(EmbedChain, "add") + mocker.patch.object(EmbedChain, "retrieve_from_database") + mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge) + mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge) + mocker.patch.object(BaseLlm, "generate_prompt") - app.add(knowledge, data_type="text") + app_instance.add(knowledge, data_type="text") + app_instance.query("What text did I give you?") + app_instance.chat("What text did I give you?") - app.query("What text did I give you?") - app.chat("What text did I give you?") + assert BaseLlm.generate_prompt.call_count == 2 + app_instance.reset() - self.assertEqual(mock_ec_get_llm_model_answer.call_args[1]["documents"], [knowledge]) - def test_add_after_reset(self): - """ - Test if the `App` instance is correctly reconstructed after a reset. 
- """ - config = AppConfig(log_level="DEBUG", collect_metrics=False) - chroma_config = {"allow_reset": True} - app = App(config=config, db_config=ChromaDbConfig(**chroma_config)) - app.reset() +def test_add_after_reset(app_instance, mocker): + config = AppConfig(log_level="DEBUG", collect_metrics=False) + chroma_config = {"allow_reset": True} - # Make sure the client is still healthy - app.db.client.heartbeat() - # Make sure the collection exists, and can be added to - app.db.collection.add( - embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2]], - metadatas=[ - {"chapter": "3", "verse": "16"}, - {"chapter": "3", "verse": "5"}, - {"chapter": "29", "verse": "11"}, - ], - ids=["id1", "id2", "id3"], - ) + app_instance = App(config=config, db_config=ChromaDbConfig(**chroma_config)) + app_instance.reset() - app.reset() + app_instance.db.client.heartbeat() + + mocker.patch.object(Collection, "add") + + app_instance.db.collection.add( + embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2]], + metadatas=[ + {"chapter": "3", "verse": "16"}, + {"chapter": "3", "verse": "5"}, + {"chapter": "29", "verse": "11"}, + ], + ids=["id1", "id2", "id3"], + ) + + app_instance.reset() + + +def test_add_with_incorrect_content(app_instance, mocker): + content = [{"foo": "bar"}] + + with pytest.raises(ValueError): + app_instance.add(content, data_type="json") diff --git a/tests/loaders/test_json.py b/tests/loaders/test_json.py index 085bcadd..9f21d402 100644 --- a/tests/loaders/test_json.py +++ b/tests/loaders/test_json.py @@ -40,7 +40,7 @@ def test_load_data(mocker): def test_load_data_url(mocker): content = "https://example.com/posts.json" - mocker.patch("os.path.isfile", return_value=False) # Mocking os.path.isfile to simulate a URL case + mocker.patch("os.path.isfile", return_value=False) mocker.patch( "llama_hub.jsondata.base.JSONDataReader.load_data", return_value=[Document(text="content1"), Document(text="content2")], @@ -68,11 +68,11 @@ def 
test_load_data_url(mocker): assert result["doc_id"] == expected_doc_id -def test_load_data_invalid_content(mocker): +def test_load_data_invalid_string_content(mocker): mocker.patch("os.path.isfile", return_value=False) mocker.patch("requests.get") - content = "123" + content = "123: 345}" with pytest.raises(ValueError, match="Invalid content to load json data from"): JSONLoader.load_data(content) @@ -89,3 +89,30 @@ def test_load_data_invalid_url(mocker): with pytest.raises(ValueError, match=f"Invalid content to load json data from: {content}"): JSONLoader.load_data(content) + + +def test_load_data_from_json_string(mocker): + content = '{"foo": "bar"}' + + content_url_str = hashlib.sha256((content).encode("utf-8")).hexdigest() + + mocker.patch("os.path.isfile", return_value=False) + mocker.patch( + "llama_hub.jsondata.base.JSONDataReader.load_data", + return_value=[Document(text="content1"), Document(text="content2")], + ) + + result = JSONLoader.load_data(content) + + assert "doc_id" in result + assert "data" in result + + expected_data = [ + {"content": "content1", "meta_data": {"url": content_url_str}}, + {"content": "content2", "meta_data": {"url": content_url_str}}, + ] + + assert result["data"] == expected_data + + expected_doc_id = hashlib.sha256((content_url_str + ", ".join(["content1", "content2"])).encode()).hexdigest() + assert result["doc_id"] == expected_doc_id