[Feature] OpenAI Function Calling (#1224)
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Callable, Dict, Optional, Type, Union
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
from langchain.schema import BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
@@ -12,8 +14,12 @@ from embedchain.llm.base import BaseLlm
|
||||
|
||||
@register_deserializable
|
||||
class OpenAILlm(BaseLlm):
|
||||
def __init__(self, config: Optional[BaseLlmConfig] = None, functions: Optional[dict[str, Any]] = None):
|
||||
self.functions = functions
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[BaseLlmConfig] = None,
|
||||
tools: Optional[Union[Dict[str, Any], Type[BaseModel], Callable[..., Any], BaseTool]] = None,
|
||||
):
|
||||
self.tools = tools
|
||||
super().__init__(config=config)
|
||||
|
||||
def get_llm_model_answer(self, prompt) -> str:
|
||||
@@ -38,21 +44,27 @@ class OpenAILlm(BaseLlm):
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
|
||||
llm = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
|
||||
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
|
||||
else:
|
||||
llm = ChatOpenAI(**kwargs, api_key=api_key)
|
||||
chat = ChatOpenAI(**kwargs, api_key=api_key)
|
||||
if self.tools:
|
||||
return self._query_function_call(chat, self.tools, messages)
|
||||
|
||||
if self.functions is not None:
|
||||
from langchain.chains.openai_functions import create_openai_fn_runnable
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
return chat.invoke(messages).content
|
||||
|
||||
structured_prompt = ChatPromptTemplate.from_messages(messages)
|
||||
runnable = create_openai_fn_runnable(functions=self.functions, prompt=structured_prompt, llm=llm)
|
||||
fn_res = runnable.invoke(
|
||||
{
|
||||
"input": prompt,
|
||||
}
|
||||
)
|
||||
messages.append(AIMessage(content=json.dumps(fn_res)))
|
||||
def _query_function_call(
|
||||
self,
|
||||
chat: ChatOpenAI,
|
||||
tools: Optional[Union[Dict[str, Any], Type[BaseModel], Callable[..., Any], BaseTool]],
|
||||
messages: list[BaseMessage],
|
||||
) -> str:
|
||||
from langchain.output_parsers.openai_tools import JsonOutputToolsParser
|
||||
from langchain_core.utils.function_calling import \
|
||||
convert_to_openai_tool
|
||||
|
||||
return llm(messages).content
|
||||
openai_tools = [convert_to_openai_tool(tools)]
|
||||
chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())
|
||||
try:
|
||||
return json.dumps(chat.invoke(messages)[0])
|
||||
except IndexError:
|
||||
return "Input could not be mapped to the function!"
|
||||
|
||||
Reference in New Issue
Block a user