[Feature] OpenAI Function Calling (#1224)

This commit is contained in:
UnMonsieur
2024-02-12 02:58:11 +01:00
committed by GitHub
parent 38e212c721
commit 41bd258b93
31 changed files with 259 additions and 213 deletions

View File

@@ -1,9 +1,11 @@
import json
import os
from typing import Any, Optional
from typing import Any, Callable, Dict, Optional, Type, Union
from langchain.chat_models import ChatOpenAI
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from langchain.schema import BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
@@ -12,8 +14,12 @@ from embedchain.llm.base import BaseLlm
@register_deserializable
class OpenAILlm(BaseLlm):
def __init__(
    self,
    config: Optional[BaseLlmConfig] = None,
    tools: Optional[Union[Dict[str, Any], Type[BaseModel], Callable[..., Any], BaseTool]] = None,
):
    """Initialize the OpenAI LLM wrapper.

    Args:
        config: Optional LLM configuration; forwarded unchanged to ``BaseLlm``.
        tools: Optional tool specification to expose to the model via OpenAI
            function/tool calling. Accepts any form that
            ``convert_to_openai_tool`` understands: a raw dict schema, a
            pydantic ``BaseModel`` subclass, a plain callable, or a LangChain
            ``BaseTool``.
    """
    # Bind the tool spec before delegating, so the attribute exists by the
    # time BaseLlm.__init__ runs (it may touch subclass state).
    self.tools = tools
    super().__init__(config=config)
def get_llm_model_answer(self, prompt) -> str:
@@ -38,21 +44,27 @@ class OpenAILlm(BaseLlm):
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
llm = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
else:
llm = ChatOpenAI(**kwargs, api_key=api_key)
chat = ChatOpenAI(**kwargs, api_key=api_key)
if self.tools:
return self._query_function_call(chat, self.tools, messages)
if self.functions is not None:
from langchain.chains.openai_functions import create_openai_fn_runnable
from langchain.prompts import ChatPromptTemplate
return chat.invoke(messages).content
structured_prompt = ChatPromptTemplate.from_messages(messages)
runnable = create_openai_fn_runnable(functions=self.functions, prompt=structured_prompt, llm=llm)
fn_res = runnable.invoke(
{
"input": prompt,
}
)
messages.append(AIMessage(content=json.dumps(fn_res)))
def _query_function_call(
    self,
    chat: ChatOpenAI,
    tools: Optional[Union[Dict[str, Any], Type[BaseModel], Callable[..., Any], BaseTool]],
    messages: list[BaseMessage],
) -> str:
    """Run the chat model with OpenAI tool calling and return the tool call as JSON.

    Args:
        chat: The configured ``ChatOpenAI`` model to invoke.
        tools: A single tool spec (dict schema, pydantic model, callable, or
            ``BaseTool``) to expose to the model.
        messages: The conversation history to send.

    Returns:
        A JSON string of the first parsed tool call, or a fixed error message
        when the model emits no tool call at all.
    """
    from langchain.output_parsers.openai_tools import JsonOutputToolsParser
    from langchain_core.utils.function_calling import convert_to_openai_tool

    # Normalize the user-supplied tool into the OpenAI tool schema; the API
    # expects a list even for a single tool.
    openai_tools = [convert_to_openai_tool(tools)]
    chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())
    try:
        # The parser yields a list of tool calls; surface only the first one.
        return json.dumps(chat.invoke(messages)[0])
    except IndexError:
        # Model answered without calling the tool — nothing to map.
        return "Input could not be mapped to the function!"