[Bug Fix]: Fix test cases and update version to 0.1.93 (#1303)

This commit is contained in:
Deshraj Yadav
2024-03-04 18:35:01 -08:00
committed by GitHub
parent 11f4ce8fb6
commit 4428768eaa
6 changed files with 6 additions and 3 deletions

View File

@@ -97,7 +97,6 @@ class BaseLlmConfig(BaseConfig):
endpoint: Optional[str] = None,
model_kwargs: Optional[dict[str, Any]] = None,
local: Optional[bool] = False,
base_url: Optional[str] = None,
):
"""
Initializes a configuration class instance for the LLM.
@@ -172,7 +171,6 @@ class BaseLlmConfig(BaseConfig):
self.endpoint = endpoint
self.model_kwargs = model_kwargs
self.local = local
self.base_url = base_url
if isinstance(prompt, str):
prompt = Template(prompt)

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "embedchain"
version = "0.1.92"
version = "0.1.93"
description = "Simplest open source retrieval(RAG) framework"
authors = [
"Taranjeet Singh <taranjeet@embedchain.ai>",

View File

@@ -30,6 +30,7 @@ def mock_data():
@pytest.fixture
def mock_answer_relevance_metric(monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
monkeypatch.setenv("OPENAI_API_BASE", "test_api_base")
metric = AnswerRelevance()
return metric

View File

@@ -73,6 +73,7 @@ def test_get_llm_model_answer_without_system_prompt(config, mocker):
max_tokens=config.max_tokens,
model_kwargs={"top_p": config.top_p},
api_key=os.environ["OPENAI_API_KEY"],
base_url=os.environ["OPENAI_API_BASE"],
)
@@ -98,6 +99,7 @@ def test_get_llm_model_answer_with_tools(config, mocker, mock_return, expected):
max_tokens=config.max_tokens,
model_kwargs={"top_p": config.top_p},
api_key=os.environ["OPENAI_API_KEY"],
base_url=os.environ["OPENAI_API_BASE"],
)
mocked_convert_to_openai_tool.assert_called_once_with({"test": "test"})
mocked_json_output_tools_parser.assert_called_once()

View File

@@ -14,6 +14,7 @@ from embedchain.vectordb.chroma import ChromaDB
@pytest.fixture
def app():
os.environ["OPENAI_API_KEY"] = "test_api_key"
os.environ["OPENAI_API_BASE"] = "test_api_base"
return App()

View File

@@ -26,6 +26,7 @@ class TestFactories:
def test_llm_factory_create(self, provider_name, config_data, expected_class):
os.environ["ANTHROPIC_API_KEY"] = "test_api_key"
os.environ["OPENAI_API_KEY"] = "test_api_key"
os.environ["OPENAI_API_BASE"] = "test_api_base"
llm_instance = LlmFactory.create(provider_name, config_data)
assert isinstance(llm_instance, expected_class)