[Bug fix] Fix history sequence in prompt (#1254)
--- a/embedchain/config/llm/base.py
+++ b/embedchain/config/llm/base.py
@@ -24,7 +24,8 @@ DEFAULT_PROMPT_WITH_HISTORY = """
 
   $context
 
-  History: $history
+  History:
+  $history
 
   Query: $query
 
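The change above splits `History: $history` across two lines so a multi-turn history renders as a block beneath the label rather than being appended to it. A minimal sketch of the rendered result, using Python's `string.Template` (the same substitution mechanism `generate_prompt` uses below); the shortened template text is a stand-in for the full `DEFAULT_PROMPT_WITH_HISTORY`:

```python
# Abbreviated stand-in for DEFAULT_PROMPT_WITH_HISTORY; only the layout matters.
from string import Template

template = Template("""$context

History:
$history

Query: $query
""")

turns = ["user: What is embedchain?", "assistant: A RAG framework."]
print(template.substitute(
    context="<retrieved context>",
    query="How do I install it?",
    history="\n".join(turns),  # one turn per line, as _format_history() produces
))
```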
--- a/embedchain/llm/base.py
+++ b/embedchain/llm/base.py
@@ -5,9 +5,7 @@ from typing import Any, Optional
 from langchain.schema import BaseMessage as LCBaseMessage
 
 from embedchain.config import BaseLlmConfig
-from embedchain.config.llm.base import (DEFAULT_PROMPT,
-                                        DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
-                                        DOCS_SITE_PROMPT_TEMPLATE)
+from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, DOCS_SITE_PROMPT_TEMPLATE
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.memory.base import ChatHistory
 from embedchain.memory.message import ChatMessage
@@ -65,6 +63,14 @@ class BaseLlm(JSONSerializable):
         self.memory.add(app_id=app_id, chat_message=chat_message, session_id=session_id)
         self.update_history(app_id=app_id, session_id=session_id)
 
+    def _format_history(self) -> str:
+        """Format history to be used in prompt
+
+        :return: Formatted history
+        :rtype: str
+        """
+        return "\n".join(self.history)
+
     def generate_prompt(self, input_query: str, contexts: list[str], **kwargs: dict[str, Any]) -> str:
         """
         Generates a prompt based on the given query and context, ready to be
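The new `_format_history` helper exists because `self.history` is a list of strings: substituting it directly into the template (as the pre-fix code in the next hunk still does) stringifies the whole list, so the prompt contained text like `['user: hi', 'assistant: hello']`. A quick illustration of the difference:

```python
# Interpolating a raw list stringifies it; joining gives one readable line per turn.
history = ["user: hi", "assistant: hello"]

print(f"History: {history}")
# History: ['user: hi', 'assistant: hello']

print("History:\n" + "\n".join(history))
# History:
# user: hi
# assistant: hello
```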
@@ -84,10 +90,8 @@ class BaseLlm(JSONSerializable):
 
         prompt_contains_history = self.config._validate_prompt_history(self.config.prompt)
         if prompt_contains_history:
-            # Prompt contains history
-            # If there is no history yet, we insert `- no history -`
             prompt = self.config.prompt.substitute(
-                context=context_string, query=input_query, history=self.history or "- no history -"
+                context=context_string, query=input_query, history=self._format_history() or "No history"
            )
         elif self.history and not prompt_contains_history:
             # History is present, but not included in the prompt.
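Note that the new fallback still handles an empty history: `"\n".join([])` evaluates to the empty string, which is falsy, so `self._format_history() or "No history"` substitutes the placeholder. A small check of that behavior, with `format_history` standing in for the method:

```python
# "\n".join([]) == "", and "" is falsy, so `or` falls back to the placeholder.
def format_history(history: list[str]) -> str:
    return "\n".join(history)

assert (format_history([]) or "No history") == "No history"
assert (format_history(["user: hi"]) or "No history") == "user: hi"
```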
@@ -98,7 +102,7 @@ class BaseLlm(JSONSerializable):
             ):
                 # swap in the template with history
                 prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
-                    context=context_string, query=input_query, history=self.history
+                    context=context_string, query=input_query, history=self._format_history()
                 )
             else:
                 # If we can't swap in the default, we still proceed but tell users that the history is ignored.
--- a/embedchain/llm/openai.py
+++ b/embedchain/llm/openai.py
@@ -35,21 +35,19 @@ class OpenAILlm(BaseLlm):
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
         if config.stream:
-            from langchain.callbacks.streaming_stdout import \
-                StreamingStdOutCallbackHandler
+            from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
             callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
-            chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
+            llm = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
         else:
-            chat = ChatOpenAI(**kwargs, api_key=api_key)
+            llm = ChatOpenAI(**kwargs, api_key=api_key)
 
         if self.functions is not None:
-            from langchain.chains.openai_functions import \
-                create_openai_fn_runnable
+            from langchain.chains.openai_functions import create_openai_fn_runnable
             from langchain.prompts import ChatPromptTemplate
 
             structured_prompt = ChatPromptTemplate.from_messages(messages)
-            runnable = create_openai_fn_runnable(functions=self.functions, prompt=structured_prompt, llm=chat)
+            runnable = create_openai_fn_runnable(functions=self.functions, prompt=structured_prompt, llm=llm)
             fn_res = runnable.invoke(
                 {
                     "input": prompt,
@@ -57,4 +55,4 @@ class OpenAILlm(BaseLlm):
             )
             messages.append(AIMessage(content=json.dumps(fn_res)))
 
-        return chat(messages).content
+        return llm(messages).content
--- a/embedchain/memory/base.py
+++ b/embedchain/memory/base.py
@@ -92,12 +92,12 @@ class ChatHistory:
         """
 
         if fetch_all:
-            additional_query = "ORDER BY created_at DESC"
+            additional_query = "ORDER BY created_at ASC"
             params = (app_id,)
         else:
             additional_query = """
                 AND session_id=?
-                ORDER BY created_at DESC
+                ORDER BY created_at ASC
                 LIMIT ?
             """
             params = (app_id, session_id, num_rounds)
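This hunk is the fix the commit title names: `DESC` returned rows newest-first, so the conversation reached the prompt in reverse chronological order; `ASC` restores reading order. A self-contained sketch with an in-memory SQLite table (the schema is a simplified assumption, not the real `ChatHistory` table):

```python
# Simplified, assumed schema; demonstrates only the ORDER BY change.
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE chat_history (question TEXT, created_at INTEGER)")
con.executemany(
    "INSERT INTO chat_history VALUES (?, ?)",
    [("first turn", 1), ("second turn", 2), ("third turn", 3)],
)

desc = [r[0] for r in con.execute("SELECT question FROM chat_history ORDER BY created_at DESC")]
asc = [r[0] for r in con.execute("SELECT question FROM chat_history ORDER BY created_at ASC")]

assert desc == ["third turn", "second turn", "first turn"]  # old behavior: newest-first
assert asc == ["first turn", "second turn", "third turn"]   # new behavior: chronological
```

One trade-off to be aware of: combined with `LIMIT ?`, `ASC` keeps the oldest `num_rounds` rows rather than the most recent; an alternative design would keep `DESC LIMIT` and reverse the fetched rows in Python.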
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.76"
+version = "0.1.77"
 description = "Simplest open source retrieval(RAG) framework"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",