[Bug fix] Fix typos, static methods and other sanity improvements in the package (#1129)
@@ -4,9 +4,7 @@ from typing import Any, Dict, Generator, List, Optional
 from langchain.schema import BaseMessage as LCBaseMessage
 
 from embedchain.config import BaseLlmConfig
-from embedchain.config.llm.base import (DEFAULT_PROMPT,
-                                        DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
-                                        DOCS_SITE_PROMPT_TEMPLATE)
+from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, DOCS_SITE_PROMPT_TEMPLATE
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.memory.base import ChatHistory
 from embedchain.memory.message import ChatMessage
@@ -76,7 +74,7 @@ class BaseLlm(JSONSerializable):
         :return: The prompt
         :rtype: str
         """
-        context_string = (" | ").join(contexts)
+        context_string = " | ".join(contexts)
         web_search_result = kwargs.get("web_search_result", "")
         if web_search_result:
             context_string = self._append_search_and_context(context_string, web_search_result)
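The parentheses around the separator literal were redundant; both spellings call str.join on the same string. A minimal sketch with made-up data (not from the commit):

contexts = ["chunk one", "chunk two"]  # hypothetical context snippets
assert (" | ").join(contexts) == " | ".join(contexts)  # identical behaviour, cleaner spelling
print(" | ".join(contexts))  # chunk one | chunk two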
@@ -110,7 +108,8 @@ class BaseLlm(JSONSerializable):
             prompt = self.config.prompt.substitute(context=context_string, query=input_query)
         return prompt
 
-    def _append_search_and_context(self, context: str, web_search_result: str) -> str:
+    @staticmethod
+    def _append_search_and_context(context: str, web_search_result: str) -> str:
         """Append web search context to existing context
 
         :param context: Existing context
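_append_search_and_context never touches self, so the commit turns it into a @staticmethod. A toy sketch of the pattern (illustrative class and formatting, not the real BaseLlm):

class Example:
    @staticmethod
    def _append_search_and_context(context: str, web_search_result: str) -> str:
        # depends only on its arguments, so no instance state is needed
        return f"{context}\nWeb search result: {web_search_result}"

    def build_prompt_context(self, context: str, result: str) -> str:
        # existing call sites via self. keep working after the conversion
        return self._append_search_and_context(context, result)

print(Example._append_search_and_context("ctx", "hit"))  # also callable without an instance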
@@ -134,7 +133,8 @@ class BaseLlm(JSONSerializable):
         """
         return self.get_llm_model_answer(prompt)
 
-    def access_search_and_get_results(self, input_query: str):
+    @staticmethod
+    def access_search_and_get_results(input_query: str):
         """
         Search the internet for additional context
 
@@ -153,7 +153,8 @@ class BaseLlm(JSONSerializable):
         logging.info(f"Access search to get answers for {input_query}")
         return search.run(input_query)
 
-    def _stream_response(self, answer: Any) -> Generator[Any, Any, None]:
+    @staticmethod
+    def _stream_response(answer: Any) -> Generator[Any, Any, None]:
         """Generator to be used as streaming response
 
         :param answer: Answer chunk from llm
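_stream_response is converted for the same reason: it only iterates over the chunks it is given and yields them back. A minimal stand-alone generator sketch under the assumption that chunks are plain strings (hypothetical data, not the library's exact implementation):

from typing import Any, Generator

def stream_response(answer: Any) -> Generator[Any, Any, None]:
    streamed_answer = ""
    for chunk in answer:
        streamed_answer += chunk  # accumulate the full answer if it is needed later
        yield chunk               # hand each piece to the caller as it arrives

for piece in stream_response(["Hel", "lo", "!"]):  # hypothetical chunk list
    print(piece, end="")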
@@ -44,7 +44,7 @@ class GoogleLlm(BaseLlm):
             "temperature": self.config.temperature or 0.5,
         }
 
-        if self.config.top_p >= 0.0 and self.config.top_p <= 1.0:
+        if 0.0 <= self.config.top_p <= 1.0:
             generation_config_params["top_p"] = self.config.top_p
         else:
             raise ValueError("`top_p` must be > 0.0 and < 1.0")
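0.0 <= self.config.top_p <= 1.0 is Python's chained comparison: it reads like the mathematical inequality and evaluates the operand only once. A small sketch with a hypothetical value (the HuggingFace hunk below uses the same idiom with strict bounds):

top_p = 0.8  # hypothetical config value
if 0.0 <= top_p <= 1.0:  # same result as: top_p >= 0.0 and top_p <= 1.0
    print("top_p accepted")
else:
    raise ValueError("`top_p` must be > 0.0 and < 1.0")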
@@ -48,7 +48,7 @@ class HuggingFaceLlm(BaseLlm):
             "max_new_tokens": config.max_tokens,
         }
 
-        if config.top_p > 0.0 and config.top_p < 1.0:
+        if 0.0 < config.top_p < 1.0:
             model_kwargs["top_p"] = config.top_p
         else:
             raise ValueError("`top_p` must be > 0.0 and < 1.0")
@@ -20,7 +20,8 @@ class OllamaLlm(BaseLlm):
     def get_llm_model_answer(self, prompt):
         return self._get_answer(prompt=prompt, config=self.config)
 
-    def _get_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
+    @staticmethod
+    def _get_answer(prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
         callback_manager = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]
 
         llm = Ollama(
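The unchanged context line also shows the callback-selection idiom: a streaming stdout handler when config.stream is set, a plain one otherwise. A minimal sketch of that idiom, assuming the classic langchain import paths (the Ollama construction itself is omitted because its arguments are not shown in this hunk):

from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

stream = True  # hypothetical stand-in for config.stream

# choose the handler list based on whether token-by-token streaming is wanted
callbacks = [StreamingStdOutCallbackHandler()] if stream else [StdOutCallbackHandler()]
print(type(callbacks[0]).__name__)  # StreamingStdOutCallbackHandler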