chore: linting (#414)

This commit is contained in:
cachho
2023-08-10 22:23:42 +02:00
committed by GitHub
parent 77e223be52
commit f0abfea55d
4 changed files with 11 additions and 10 deletions

View File

@@ -52,8 +52,7 @@ class CustomAppConfig(BaseAppConfig):
super().__init__( super().__init__(
log_level=log_level, log_level=log_level,
embedding_fn=CustomAppConfig.embedding_function( embedding_fn=CustomAppConfig.embedding_function(
embedding_function=embedding_fn, model=embedding_fn_model, embedding_function=embedding_fn, model=embedding_fn_model, deployment_name=deployment_name
deployment_name=deployment_name
), ),
db=db, db=db,
host=host, host=host,

View File

@@ -46,12 +46,13 @@ class DataFormatter:
"sitemap": SitemapLoader(), "sitemap": SitemapLoader(),
"docs_site": DocsSiteLoader(), "docs_site": DocsSiteLoader(),
} }
lazy_loaders = ("notion", ) lazy_loaders = ("notion",)
if data_type in loaders: if data_type in loaders:
return loaders[data_type] return loaders[data_type]
elif data_type in lazy_loaders: elif data_type in lazy_loaders:
if data_type == "notion": if data_type == "notion":
from embedchain.loaders.notion import NotionLoader from embedchain.loaders.notion import NotionLoader
return NotionLoader() return NotionLoader()
else: else:
raise ValueError(f"Unsupported data type: {data_type}") raise ValueError(f"Unsupported data type: {data_type}")

View File

@@ -65,7 +65,7 @@ def use_pysqlite3():
import datetime import datetime
import subprocess import subprocess
import sys import sys
subprocess.check_call( subprocess.check_call(
[sys.executable, "-m", "pip", "install", "pysqlite3-binary", "--quiet", "--disable-pip-version-check"] [sys.executable, "-m", "pip", "install", "pysqlite3-binary", "--quiet", "--disable-pip-version-check"]
) )
@@ -86,6 +86,6 @@ def use_pysqlite3():
print( print(
f"{current_time} [embedchain] [ERROR]", f"{current_time} [embedchain] [ERROR]",
"Failed to swap std-lib sqlite3 with pysqlite3 for ChromaDb compatibility.", "Failed to swap std-lib sqlite3 with pysqlite3 for ChromaDb compatibility.",
f"Error:", "Error:",
e e,
) )

View File

@@ -73,6 +73,7 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None) self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None) self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
class TestChromaDbDuplicateHandling: class TestChromaDbDuplicateHandling:
def test_duplicates_throw_warning(self, caplog): def test_duplicates_throw_warning(self, caplog):
""" """
@@ -101,8 +102,8 @@ class TestChromaDbDuplicateHandling:
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.set_collection("test_collection_2") app.set_collection("test_collection_2")
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
assert "Insert of existing embedding ID: 0" not in caplog.text # not assert "Insert of existing embedding ID: 0" not in caplog.text # not
assert "Add of existing embedding ID: 0" not in caplog.text # not assert "Add of existing embedding ID: 0" not in caplog.text # not
class TestChromaDbCollection(unittest.TestCase): class TestChromaDbCollection(unittest.TestCase):
@@ -197,9 +198,9 @@ class TestChromaDbCollection(unittest.TestCase):
app2.collection.add(embeddings=[0, 0, 0], ids=["0"]) app2.collection.add(embeddings=[0, 0, 0], ids=["0"])
# Swap names and test # Swap names and test
app1.set_collection('test_collection_2') app1.set_collection("test_collection_2")
self.assertEqual(app1.count(), 1) self.assertEqual(app1.count(), 1)
app2.set_collection('test_collection_1') app2.set_collection("test_collection_1")
self.assertEqual(app2.count(), 3) self.assertEqual(app2.count(), 3)
def test_ids_share_collections(self): def test_ids_share_collections(self):