diff --git a/embedchain/config/apps/AppConfig.py b/embedchain/config/apps/AppConfig.py index a75a8acc..615f8f2d 100644 --- a/embedchain/config/apps/AppConfig.py +++ b/embedchain/config/apps/AppConfig.py @@ -1,6 +1,12 @@ import os -from chromadb.utils import embedding_functions +try: + from chromadb.utils import embedding_functions +except RuntimeError: + from embedchain.utils import use_pysqlite3 + + use_pysqlite3() + from chromadb.utils import embedding_functions from .BaseAppConfig import BaseAppConfig diff --git a/embedchain/utils.py b/embedchain/utils.py index 02ee2deb..7a4f2b1d 100644 --- a/embedchain/utils.py +++ b/embedchain/utils.py @@ -1,3 +1,4 @@ +import logging import re import string @@ -45,3 +46,22 @@ def is_readable(s): """ printable_ratio = sum(c in string.printable for c in s) / len(s) return printable_ratio > 0.95 # 95% of characters are printable + + +def use_pysqlite3(): + """ + Swap std-lib sqlite3 with pysqlite3. + """ + import platform + + if platform.system() == "Linux": + # According to the Chroma team, this patch only works on Linux + import subprocess + import sys + + subprocess.check_call([sys.executable, "-m", "pip", "install", "pysqlite3-binary"]) + + __import__("pysqlite3") + sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") + # Don't be surprised if this doesn't log as you expect, because the logger is instantiated after the import + logging.info("Swapped std-lib sqlite3 with pysqlite3") diff --git a/embedchain/vectordb/chroma_db.py b/embedchain/vectordb/chroma_db.py index ff236a7b..61ade9af 100644 --- a/embedchain/vectordb/chroma_db.py +++ b/embedchain/vectordb/chroma_db.py @@ -1,6 +1,12 @@ import logging -import chromadb +try: + import chromadb +except RuntimeError: + from embedchain.utils import use_pysqlite3 + + use_pysqlite3() + import chromadb from chromadb.config import Settings from embedchain.vectordb.base_vector_db import BaseVectorDB