[Bug fix] Fix typos, static methods and other sanity improvements in the package (#1129)

This commit is contained in:
Sandra Serrano
2024-01-08 19:47:46 +01:00
committed by GitHub
parent 62c0c52e31
commit 2496ed133e
41 changed files with 133 additions and 103 deletions

View File

@@ -4,9 +4,7 @@ from typing import Any, Dict, Generator, List, Optional
from langchain.schema import BaseMessage as LCBaseMessage
from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base import (DEFAULT_PROMPT,
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
DOCS_SITE_PROMPT_TEMPLATE)
from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, DOCS_SITE_PROMPT_TEMPLATE
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage
@@ -76,7 +74,7 @@ class BaseLlm(JSONSerializable):
:return: The prompt
:rtype: str
"""
context_string = (" | ").join(contexts)
context_string = " | ".join(contexts)
web_search_result = kwargs.get("web_search_result", "")
if web_search_result:
context_string = self._append_search_and_context(context_string, web_search_result)
@@ -110,7 +108,8 @@ class BaseLlm(JSONSerializable):
prompt = self.config.prompt.substitute(context=context_string, query=input_query)
return prompt
def _append_search_and_context(self, context: str, web_search_result: str) -> str:
@staticmethod
def _append_search_and_context(context: str, web_search_result: str) -> str:
"""Append web search context to existing context
:param context: Existing context
@@ -134,7 +133,8 @@ class BaseLlm(JSONSerializable):
"""
return self.get_llm_model_answer(prompt)
def access_search_and_get_results(self, input_query: str):
@staticmethod
def access_search_and_get_results(input_query: str):
"""
Search the internet for additional context
@@ -153,7 +153,8 @@ class BaseLlm(JSONSerializable):
logging.info(f"Access search to get answers for {input_query}")
return search.run(input_query)
def _stream_response(self, answer: Any) -> Generator[Any, Any, None]:
@staticmethod
def _stream_response(answer: Any) -> Generator[Any, Any, None]:
"""Generator to be used as streaming response
:param answer: Answer chunk from llm

View File

@@ -44,7 +44,7 @@ class GoogleLlm(BaseLlm):
"temperature": self.config.temperature or 0.5,
}
if self.config.top_p >= 0.0 and self.config.top_p <= 1.0:
if 0.0 <= self.config.top_p <= 1.0:
generation_config_params["top_p"] = self.config.top_p
else:
raise ValueError("`top_p` must be > 0.0 and < 1.0")

View File

@@ -48,7 +48,7 @@ class HuggingFaceLlm(BaseLlm):
"max_new_tokens": config.max_tokens,
}
if config.top_p > 0.0 and config.top_p < 1.0:
if 0.0 < config.top_p < 1.0:
model_kwargs["top_p"] = config.top_p
else:
raise ValueError("`top_p` must be > 0.0 and < 1.0")

View File

@@ -20,7 +20,8 @@ class OllamaLlm(BaseLlm):
def get_llm_model_answer(self, prompt):
return self._get_answer(prompt=prompt, config=self.config)
def _get_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
@staticmethod
def _get_answer(prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
callback_manager = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]
llm = Ollama(