Fix online feat and add docs (#1387)

Dev Khant
2024-06-07 12:03:16 +05:30
committed by GitHub
parent b0e436d9c4
commit fd07513004
7 changed files with 115 additions and 81 deletions

View File

@@ -89,6 +89,7 @@ class BaseLlmConfig(BaseConfig):
         max_tokens: int = 1000,
         top_p: float = 1,
         stream: bool = False,
+        online: bool = False,
         deployment_name: Optional[str] = None,
         system_prompt: Optional[str] = None,
         where: dict[str, Any] = None,
@@ -129,6 +130,8 @@ class BaseLlmConfig(BaseConfig):
         :type top_p: float, optional
         :param stream: Control if response is streamed back to user, defaults to False
         :type stream: bool, optional
+        :param online: Controls whether to use internet for answering query, defaults to False
+        :type online: bool, optional
         :param deployment_name: t.b.a., defaults to None
         :type deployment_name: Optional[str], optional
         :param system_prompt: System prompt string, defaults to None
@@ -181,6 +184,7 @@ class BaseLlmConfig(BaseConfig):
         self.http_async_client = http_async_client
         self.local = local
         self.default_headers = default_headers
+        self.online = online
         if isinstance(prompt, str):
             prompt = Template(prompt)

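For reference, a minimal sketch (not part of the diff) of setting the new flag programmatically, assuming the BaseLlmConfig signature shown above; values other than online are illustrative defaults taken from that signature:

from embedchain.config import BaseLlmConfig

# online defaults to False; enabling it asks the LLM layer to include web search results
llm_config = BaseLlmConfig(
    max_tokens=1000,
    top_p=1,
    stream=False,
    online=True,
)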
View File

@@ -6,9 +6,7 @@ from typing import Any, Optional, Union
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
-from embedchain.cache import (adapt, get_gptcache_session,
-                              gptcache_data_convert,
-                              gptcache_update_cache_callback)
+from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.base_app_config import BaseAppConfig
@@ -18,8 +16,7 @@ from embedchain.embedder.base import BaseEmbedder
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import (DataType, DirectDataType,
-                                         IndirectDataType, SpecialDataType)
+from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
 from embedchain.utils.misc import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB
@@ -97,13 +94,13 @@ class EmbedChain(JSONSerializable):
     @property
     def online(self):
-        return self.llm.online
+        return self.llm.config.online
     @online.setter
     def online(self, value):
         if not isinstance(value, bool):
             raise ValueError(f"Boolean value expected but got {type(value)}.")
-        self.llm.online = value
+        self.llm.config.online = value
     def add(
         self,

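A short usage sketch of the property above, assuming a default App (an EmbedChain subclass) can be constructed; the setter still validates the type but now writes through to the LLM config:

from embedchain import App

app = App()
app.online = True    # stored on app.llm.config.online after this change
print(app.online)    # True, read back from the same config field
# app.online = "yes" # would raise ValueError: Boolean value expected but got <class 'str'>.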
View File

@@ -5,9 +5,7 @@ from typing import Any, Optional
 from langchain.schema import BaseMessage as LCBaseMessage
 from embedchain.config import BaseLlmConfig
-from embedchain.config.llm.base import (DEFAULT_PROMPT,
-                                        DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
-                                        DOCS_SITE_PROMPT_TEMPLATE)
+from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, DOCS_SITE_PROMPT_TEMPLATE
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.memory.base import ChatHistory
 from embedchain.memory.message import ChatMessage
@@ -29,7 +27,6 @@ class BaseLlm(JSONSerializable):
         self.memory = ChatHistory()
         self.is_docs_site_instance = False
-        self.online = False
         self.history: Any = None
     def get_llm_model_answer(self):
@@ -213,7 +210,7 @@ class BaseLlm(JSONSerializable):
             self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
             self.config.number_documents = 5
         k = {}
-        if self.online:
+        if self.config.online:
             k["web_search_result"] = self.access_search_and_get_results(input_query)
         prompt = self.generate_prompt(input_query, contexts, **k)
         logger.info(f"Prompt: {prompt}")
@@ -268,7 +265,7 @@ class BaseLlm(JSONSerializable):
             self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
             self.config.number_documents = 5
         k = {}
-        if self.online:
+        if self.config.online:
             k["web_search_result"] = self.access_search_and_get_results(input_query)
         prompt = self.generate_prompt(input_query, contexts, **k)

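With the flag now read from the config, the query path above can be exercised as follows (hypothetical trace, assuming a default App and an API key for the configured LLM provider):

from embedchain import App

app = App()
app.online = True
answer = app.query("What is the latest Python release?")
# Inside BaseLlm: k["web_search_result"] = self.access_search_and_get_results(input_query)
#                 prompt = self.generate_prompt(input_query, contexts, **k)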
View File

@@ -419,6 +419,7 @@ def validate_config(config_data):
                     Optional("max_tokens"): int,
                     Optional("top_p"): Or(float, int),
                     Optional("stream"): bool,
+                    Optional("online"): bool,
                     Optional("template"): str,
                     Optional("prompt"): str,
                     Optional("system_prompt"): str,
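Since the validator now accepts the key, online can also be supplied through a config dict or YAML file; a hedged example, assuming App.from_config(config=...) routes the dict through validate_config:

from embedchain import App

config = {
    "llm": {
        "provider": "openai",
        "config": {
            "model": "gpt-3.5-turbo",
            "stream": False,
            "online": True,  # now passes validation instead of being rejected
        },
    },
}
app = App.from_config(config=config)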