OpenAI function calling support (#1011)
This commit is contained in:
@@ -60,9 +60,129 @@ llm:
|
|||||||
top_p: 1
|
top_p: 1
|
||||||
stream: false
|
stream: false
|
||||||
```
|
```
|
||||||
|
|
||||||
</CodeGroup>
|
</CodeGroup>
|
||||||
|
|
||||||
|
### Function Calling
|
||||||
|
To enable [function calling](https://platform.openai.com/docs/guides/function-calling) in your application using embedchain and OpenAI, you need to pass functions into `OpenAILlm` class as an array of functions. Here are several ways in which you can achieve that:
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
<Accordion title="Using Pydantic Models">
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from embedchain import Pipeline as App
|
||||||
|
from embedchain.llm.openai import OpenAILlm
|
||||||
|
import requests
|
||||||
|
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||||
|
|
||||||
|
os.environ["OPENAI_API_KEY"] = "sk-xxx"
|
||||||
|
|
||||||
|
class QA(BaseModel):
|
||||||
|
"""
|
||||||
|
A question and answer pair.
|
||||||
|
"""
|
||||||
|
|
||||||
|
question: str = Field(
|
||||||
|
..., description="The question.", example="What is a mountain?"
|
||||||
|
)
|
||||||
|
answer: str = Field(
|
||||||
|
..., description="The answer.", example="A mountain is a hill."
|
||||||
|
)
|
||||||
|
person_who_is_asking: str = Field(
|
||||||
|
..., description="The person who is asking the question.", example="John"
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("question")
|
||||||
|
def question_must_end_with_a_question_mark(cls, v):
|
||||||
|
"""
|
||||||
|
Validate that the question ends with a question mark.
|
||||||
|
"""
|
||||||
|
if not v.endswith("?"):
|
||||||
|
raise ValueError("question must end with a question mark")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("answer")
|
||||||
|
def answer_must_end_with_a_period(cls, v):
|
||||||
|
"""
|
||||||
|
Validate that the answer ends with a period.
|
||||||
|
"""
|
||||||
|
if not v.endswith("."):
|
||||||
|
raise ValueError("answer must end with a period")
|
||||||
|
return v
|
||||||
|
|
||||||
|
llm = OpenAILlm(config=None,functions=[QA])
|
||||||
|
app = App(llm=llm)
|
||||||
|
|
||||||
|
result = app.query("Hey I am Sid. What is a mountain? A mountain is a hill.")
|
||||||
|
|
||||||
|
print(result)
|
||||||
|
```
|
||||||
|
</Accordion>
|
||||||
|
|
||||||
|
<Accordion title="Using OpenAI JSON schema">
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from embedchain import Pipeline as App
|
||||||
|
from embedchain.llm.openai import OpenAILlm
|
||||||
|
import requests
|
||||||
|
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||||
|
|
||||||
|
os.environ["OPENAI_API_KEY"] = "sk-xxx"
|
||||||
|
|
||||||
|
json_schema = {
|
||||||
|
"name": "get_qa",
|
||||||
|
"description": "A question and answer pair and the user who is asking the question.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"question": {"type": "string", "description": "The question."},
|
||||||
|
"answer": {"type": "string", "description": "The answer."},
|
||||||
|
"person_who_is_asking": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The person who is asking the question.",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["question", "answer", "person_who_is_asking"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
llm = OpenAILlm(config=None,functions=[json_schema])
|
||||||
|
app = App(llm=llm)
|
||||||
|
|
||||||
|
result = app.query("Hey I am Sid. What is a mountain? A mountain is a hill.")
|
||||||
|
|
||||||
|
print(result)
|
||||||
|
```
|
||||||
|
</Accordion>
|
||||||
|
<Accordion title="Using actual python functions">
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from embedchain import Pipeline as App
|
||||||
|
from embedchain.llm.openai import OpenAILlm
|
||||||
|
import requests
|
||||||
|
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||||
|
|
||||||
|
os.environ["OPENAI_API_KEY"] = "sk-xxx"
|
||||||
|
|
||||||
|
def find_info_of_pokemon(pokemon: str):
|
||||||
|
"""
|
||||||
|
Find the information of the given pokemon.
|
||||||
|
Args:
|
||||||
|
pokemon: The pokemon.
|
||||||
|
"""
|
||||||
|
req = requests.get(f"https://pokeapi.co/api/v2/pokemon/{pokemon}")
|
||||||
|
if req.status_code == 404:
|
||||||
|
raise ValueError("pokemon not found")
|
||||||
|
return req.json()
|
||||||
|
|
||||||
|
llm = OpenAILlm(config=None,functions=[find_info_of_pokemon])
|
||||||
|
app = App(llm=llm)
|
||||||
|
|
||||||
|
result = app.query("Tell me more about the pokemon pikachu.")
|
||||||
|
|
||||||
|
print(result)
|
||||||
|
```
|
||||||
|
</Accordion>
|
||||||
|
|
||||||
## Google AI
|
## Google AI
|
||||||
|
|
||||||
To use Google AI model, you have to set the `GOOGLE_API_KEY` environment variable. You can obtain the Google API key from the [Google Maker Suite](https://makersuite.google.com/app/apikey)
|
To use Google AI model, you have to set the `GOOGLE_API_KEY` environment variable. You can obtain the Google API key from the [Google Maker Suite](https://makersuite.google.com/app/apikey)
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
from typing import Optional
|
import json
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
from langchain.schema import HumanMessage, SystemMessage
|
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||||
|
|
||||||
from embedchain.config import BaseLlmConfig
|
from embedchain.config import BaseLlmConfig
|
||||||
from embedchain.helpers.json_serializable import register_deserializable
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
@@ -10,14 +11,15 @@ from embedchain.llm.base import BaseLlm
|
|||||||
|
|
||||||
@register_deserializable
|
@register_deserializable
|
||||||
class OpenAILlm(BaseLlm):
|
class OpenAILlm(BaseLlm):
|
||||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
def __init__(self, config: Optional[BaseLlmConfig] = None, functions: Optional[Dict[str, Any]] = None):
|
||||||
|
self.functions = functions
|
||||||
super().__init__(config=config)
|
super().__init__(config=config)
|
||||||
|
|
||||||
def get_llm_model_answer(self, prompt) -> str:
|
def get_llm_model_answer(self, prompt) -> str:
|
||||||
response = OpenAILlm._get_answer(prompt, self.config)
|
response = self._get_answer(prompt, self.config)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
|
def _get_answer(self, prompt: str, config: BaseLlmConfig) -> str:
|
||||||
messages = []
|
messages = []
|
||||||
if config.system_prompt:
|
if config.system_prompt:
|
||||||
messages.append(SystemMessage(content=config.system_prompt))
|
messages.append(SystemMessage(content=config.system_prompt))
|
||||||
@@ -31,11 +33,23 @@ class OpenAILlm(BaseLlm):
|
|||||||
if config.top_p:
|
if config.top_p:
|
||||||
kwargs["model_kwargs"]["top_p"] = config.top_p
|
kwargs["model_kwargs"]["top_p"] = config.top_p
|
||||||
if config.stream:
|
if config.stream:
|
||||||
from langchain.callbacks.streaming_stdout import \
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||||
StreamingStdOutCallbackHandler
|
|
||||||
|
|
||||||
callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
|
callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
|
||||||
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks)
|
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks)
|
||||||
else:
|
else:
|
||||||
chat = ChatOpenAI(**kwargs)
|
chat = ChatOpenAI(**kwargs)
|
||||||
|
if self.functions is not None:
|
||||||
|
from langchain.chains.openai_functions import create_openai_fn_runnable
|
||||||
|
from langchain.prompts import ChatPromptTemplate
|
||||||
|
|
||||||
|
structured_prompt = ChatPromptTemplate.from_messages(messages)
|
||||||
|
runnable = create_openai_fn_runnable(functions=self.functions, prompt=structured_prompt, llm=chat)
|
||||||
|
fn_res = runnable.invoke(
|
||||||
|
{
|
||||||
|
"input": prompt,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
messages.append(AIMessage(content=json.dumps(fn_res)))
|
||||||
|
|
||||||
return chat(messages).content
|
return chat(messages).content
|
||||||
|
|||||||
@@ -50,24 +50,24 @@ def test_get_llm_model_answer_empty_prompt(config, mocker):
|
|||||||
|
|
||||||
def test_get_llm_model_answer_with_streaming(config, mocker):
|
def test_get_llm_model_answer_with_streaming(config, mocker):
|
||||||
config.stream = True
|
config.stream = True
|
||||||
mocked_jinachat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
|
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
|
||||||
|
|
||||||
llm = OpenAILlm(config)
|
llm = OpenAILlm(config)
|
||||||
llm.get_llm_model_answer("Test query")
|
llm.get_llm_model_answer("Test query")
|
||||||
|
|
||||||
mocked_jinachat.assert_called_once()
|
mocked_openai_chat.assert_called_once()
|
||||||
callbacks = [callback[1]["callbacks"] for callback in mocked_jinachat.call_args_list]
|
callbacks = [callback[1]["callbacks"] for callback in mocked_openai_chat.call_args_list]
|
||||||
assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
|
assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
|
||||||
|
|
||||||
|
|
||||||
def test_get_llm_model_answer_without_system_prompt(config, mocker):
|
def test_get_llm_model_answer_without_system_prompt(config, mocker):
|
||||||
config.system_prompt = None
|
config.system_prompt = None
|
||||||
mocked_jinachat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
|
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
|
||||||
|
|
||||||
llm = OpenAILlm(config)
|
llm = OpenAILlm(config)
|
||||||
llm.get_llm_model_answer("Test query")
|
llm.get_llm_model_answer("Test query")
|
||||||
|
|
||||||
mocked_jinachat.assert_called_once_with(
|
mocked_openai_chat.assert_called_once_with(
|
||||||
model=config.model,
|
model=config.model,
|
||||||
temperature=config.temperature,
|
temperature=config.temperature,
|
||||||
max_tokens=config.max_tokens,
|
max_tokens=config.max_tokens,
|
||||||
|
|||||||
Reference in New Issue
Block a user