From 2f285ea00a5e7676934542c7af9530ea1f5294ac Mon Sep 17 00:00:00 2001
From: Deshraj Yadav
Date: Sun, 11 Feb 2024 16:07:36 -0800
Subject: [PATCH] [Bug fix] Fix history sequence in prompt (#1254)

---
 embedchain/config/llm/base.py |  3 ++-
 embedchain/llm/base.py        | 18 +++++++++++-------
 embedchain/llm/openai.py      | 14 ++++++--------
 embedchain/memory/base.py     |  4 ++--
 pyproject.toml                |  2 +-
 5 files changed, 22 insertions(+), 19 deletions(-)

diff --git a/embedchain/config/llm/base.py b/embedchain/config/llm/base.py
index ab4331e4..3619b924 100644
--- a/embedchain/config/llm/base.py
+++ b/embedchain/config/llm/base.py
@@ -24,7 +24,8 @@ DEFAULT_PROMPT_WITH_HISTORY = """
 
   $context
 
-  History: $history
+  History:
+  $history
 
   Query: $query
 
diff --git a/embedchain/llm/base.py b/embedchain/llm/base.py
index 1e002916..6ae0e064 100644
--- a/embedchain/llm/base.py
+++ b/embedchain/llm/base.py
@@ -5,9 +5,7 @@ from typing import Any, Optional
 from langchain.schema import BaseMessage as LCBaseMessage
 
 from embedchain.config import BaseLlmConfig
-from embedchain.config.llm.base import (DEFAULT_PROMPT,
-                                        DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
-                                        DOCS_SITE_PROMPT_TEMPLATE)
+from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, DOCS_SITE_PROMPT_TEMPLATE
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.memory.base import ChatHistory
 from embedchain.memory.message import ChatMessage
@@ -65,6 +63,14 @@ class BaseLlm(JSONSerializable):
         self.memory.add(app_id=app_id, chat_message=chat_message, session_id=session_id)
         self.update_history(app_id=app_id, session_id=session_id)
 
+    def _format_history(self) -> str:
+        """Format history to be used in prompt
+
+        :return: Formatted history
+        :rtype: str
+        """
+        return "\n".join(self.history)
+
     def generate_prompt(self, input_query: str, contexts: list[str], **kwargs: dict[str, Any]) -> str:
         """
         Generates a prompt based on the given query and context, ready to be
@@ -84,10 +90,8 @@
 
         prompt_contains_history = self.config._validate_prompt_history(self.config.prompt)
         if prompt_contains_history:
-            # Prompt contains history
-            # If there is no history yet, we insert `- no history -`
             prompt = self.config.prompt.substitute(
-                context=context_string, query=input_query, history=self.history or "- no history -"
+                context=context_string, query=input_query, history=self._format_history() or "No history"
            )
         elif self.history and not prompt_contains_history:
             # History is present, but not included in the prompt.
@@ -98,7 +102,7 @@
             ):
                 # swap in the template with history
                 prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
-                    context=context_string, query=input_query, history=self.history
+                    context=context_string, query=input_query, history=self._format_history()
                 )
             else:
                 # If we can't swap in the default, we still proceed but tell users that the history is ignored.
diff --git a/embedchain/llm/openai.py b/embedchain/llm/openai.py
index 639ec919..40f750c0 100644
--- a/embedchain/llm/openai.py
+++ b/embedchain/llm/openai.py
@@ -35,21 +35,19 @@ class OpenAILlm(BaseLlm):
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
         if config.stream:
-            from langchain.callbacks.streaming_stdout import \
-                StreamingStdOutCallbackHandler
+            from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
             callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
-            chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
+            llm = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
         else:
-            chat = ChatOpenAI(**kwargs, api_key=api_key)
+            llm = ChatOpenAI(**kwargs, api_key=api_key)
 
         if self.functions is not None:
-            from langchain.chains.openai_functions import \
-                create_openai_fn_runnable
+            from langchain.chains.openai_functions import create_openai_fn_runnable
             from langchain.prompts import ChatPromptTemplate
 
             structured_prompt = ChatPromptTemplate.from_messages(messages)
-            runnable = create_openai_fn_runnable(functions=self.functions, prompt=structured_prompt, llm=chat)
+            runnable = create_openai_fn_runnable(functions=self.functions, prompt=structured_prompt, llm=llm)
             fn_res = runnable.invoke(
                 {
                     "input": prompt,
@@ -57,4 +55,4 @@
             )
             messages.append(AIMessage(content=json.dumps(fn_res)))
 
-        return chat(messages).content
+        return llm(messages).content
diff --git a/embedchain/memory/base.py b/embedchain/memory/base.py
index 378ad479..75780efb 100644
--- a/embedchain/memory/base.py
+++ b/embedchain/memory/base.py
@@ -92,12 +92,12 @@ class ChatHistory:
         """
 
         if fetch_all:
-            additional_query = "ORDER BY created_at DESC"
+            additional_query = "ORDER BY created_at ASC"
             params = (app_id,)
         else:
             additional_query = """
                 AND session_id=?
-                ORDER BY created_at DESC
+                ORDER BY created_at ASC
                 LIMIT ?
             """
             params = (app_id, session_id, num_rounds)
diff --git a/pyproject.toml b/pyproject.toml
index 2c93221c..00142127 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.76"
+version = "0.1.77"
 description = "Simplest open source retrieval(RAG) framework"
 authors = [
     "Taranjeet Singh ",
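
For context, a minimal standalone sketch of the prompt-history behavior this patch produces, using only the Python standard library. The sample history rows and the simplified template below are illustrative only; the real DEFAULT_PROMPT_WITH_HISTORY also includes the $context block and surrounding instructions.

    from string import Template

    # Rows as the "ORDER BY created_at ASC" change now returns them: oldest first.
    history = [
        "user: What is Embedchain?",
        "assistant: An open source RAG framework.",
    ]

    # Mirrors the new BaseLlm._format_history(): one turn per line,
    # with "No history" substituted when the list is empty.
    formatted_history = "\n".join(history) or "No history"

    # Before this patch the list object itself was interpolated and rows came
    # back newest-first; now the $history block reads in conversation order.
    template = Template("History:\n$history\n\nQuery: $query")
    print(template.substitute(history=formatted_history, query="Who maintains it?"))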