diff --git a/docs/components/llms.mdx b/docs/components/llms.mdx
index 802ba270..a48145c5 100644
--- a/docs/components/llms.mdx
+++ b/docs/components/llms.mdx
@@ -60,9 +60,129 @@ llm:
     top_p: 1
     stream: false
 ```
-
+### Function Calling
+To enable [function calling](https://platform.openai.com/docs/guides/function-calling) in your application using embedchain and OpenAI, pass a list of functions to the `OpenAILlm` class. A function can be a pydantic model, a JSON schema, or a plain Python function, as the following examples show:
+
+```python
+import os
+
+from pydantic import BaseModel, Field, field_validator
+
+from embedchain import Pipeline as App
+from embedchain.llm.openai import OpenAILlm
+
+os.environ["OPENAI_API_KEY"] = "sk-xxx"
+
+
+class QA(BaseModel):
+    """
+    A question and answer pair.
+    """
+
+    question: str = Field(
+        ..., description="The question.", example="What is a mountain?"
+    )
+    answer: str = Field(
+        ..., description="The answer.", example="A mountain is a hill."
+    )
+    person_who_is_asking: str = Field(
+        ..., description="The person who is asking the question.", example="John"
+    )
+
+    @field_validator("question")
+    def question_must_end_with_a_question_mark(cls, v):
+        """
+        Validate that the question ends with a question mark.
+        """
+        if not v.endswith("?"):
+            raise ValueError("question must end with a question mark")
+        return v
+
+    @field_validator("answer")
+    def answer_must_end_with_a_period(cls, v):
+        """
+        Validate that the answer ends with a period.
+        """
+        if not v.endswith("."):
+            raise ValueError("answer must end with a period")
+        return v
+
+
+llm = OpenAILlm(config=None, functions=[QA])
+app = App(llm=llm)
+
+result = app.query("Hey I am Sid. What is a mountain? A mountain is a hill.")
+
+print(result)
+```
+
+```python
+import os
+
+from embedchain import Pipeline as App
+from embedchain.llm.openai import OpenAILlm
+
+os.environ["OPENAI_API_KEY"] = "sk-xxx"
+
+json_schema = {
+    "name": "get_qa",
+    "description": "A question and answer pair and the user who is asking the question.",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "question": {"type": "string", "description": "The question."},
+            "answer": {"type": "string", "description": "The answer."},
+            "person_who_is_asking": {
+                "type": "string",
+                "description": "The person who is asking the question.",
+            },
+        },
+        "required": ["question", "answer", "person_who_is_asking"],
+    },
+}
+
+llm = OpenAILlm(config=None, functions=[json_schema])
+app = App(llm=llm)
+
+result = app.query("Hey I am Sid. What is a mountain? A mountain is a hill.")
+
+print(result)
+```
+
+```python
+import os
+
+import requests
+
+from embedchain import Pipeline as App
+from embedchain.llm.openai import OpenAILlm
+
+os.environ["OPENAI_API_KEY"] = "sk-xxx"
+
+
+def find_info_of_pokemon(pokemon: str):
+    """
+    Find the information of the given pokemon.
+
+    Args:
+        pokemon: The pokemon.
+    """
+    req = requests.get(f"https://pokeapi.co/api/v2/pokemon/{pokemon}")
+    if req.status_code == 404:
+        raise ValueError("pokemon not found")
+    return req.json()
+
+
+llm = OpenAILlm(config=None, functions=[find_info_of_pokemon])
+app = App(llm=llm)
+
+result = app.query("Tell me more about the pokemon pikachu.")
+
+print(result)
+```
+
 ## Google AI
 To use Google AI model, you have to set the `GOOGLE_API_KEY` environment variable.
 You can obtain the Google API key from the [Google Maker Suite](https://makersuite.google.com/app/apikey)
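Worth noting for reviewers: since `functions` is an ordinary keyword argument on `OpenAILlm`, it composes with a regular `BaseLlmConfig` too. A minimal sketch, assuming `BaseLlmConfig` accepts the same `model` and `temperature` keys as the YAML config shown in the docs (the `multiply` helper is purely illustrative):

```python
import os

from embedchain import Pipeline as App
from embedchain.config import BaseLlmConfig
from embedchain.llm.openai import OpenAILlm

os.environ["OPENAI_API_KEY"] = "sk-xxx"


def multiply(a: int, b: int) -> int:
    """
    Multiply two integers.

    Args:
        a: The first integer.
        b: The second integer.
    """
    return a * b


# Illustrative config values; any BaseLlmConfig should work here.
config = BaseLlmConfig(model="gpt-3.5-turbo", temperature=0.0)
llm = OpenAILlm(config=config, functions=[multiply])
app = App(llm=llm)

print(app.query("What is 6 times 7?"))
```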
diff --git a/embedchain/llm/openai.py b/embedchain/llm/openai.py
index 9efa019b..5ee192d7 100644
--- a/embedchain/llm/openai.py
+++ b/embedchain/llm/openai.py
@@ -1,7 +1,8 @@
-from typing import Optional
+import json
+from typing import Any, List, Optional
 
 from langchain.chat_models import ChatOpenAI
-from langchain.schema import HumanMessage, SystemMessage
+from langchain.schema import AIMessage, HumanMessage, SystemMessage
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -10,14 +11,15 @@ from embedchain.llm.base import BaseLlm
 
 @register_deserializable
 class OpenAILlm(BaseLlm):
-    def __init__(self, config: Optional[BaseLlmConfig] = None):
+    def __init__(self, config: Optional[BaseLlmConfig] = None, functions: Optional[List[Any]] = None):
+        self.functions = functions
         super().__init__(config=config)
 
     def get_llm_model_answer(self, prompt) -> str:
-        response = OpenAILlm._get_answer(prompt, self.config)
+        response = self._get_answer(prompt, self.config)
         return response
 
-    def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
+    def _get_answer(self, prompt: str, config: BaseLlmConfig) -> str:
         messages = []
         if config.system_prompt:
             messages.append(SystemMessage(content=config.system_prompt))
@@ -31,11 +33,23 @@ class OpenAILlm(BaseLlm):
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
         if config.stream:
-            from langchain.callbacks.streaming_stdout import \
-                StreamingStdOutCallbackHandler
+            from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
             callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
             chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks)
         else:
             chat = ChatOpenAI(**kwargs)
+
+        if self.functions is not None:
+            from langchain.chains.openai_functions import create_openai_fn_runnable
+            from langchain.prompts import ChatPromptTemplate
+
+            structured_prompt = ChatPromptTemplate.from_messages(messages)
+            runnable = create_openai_fn_runnable(functions=self.functions, prompt=structured_prompt, llm=chat)
+            fn_res = runnable.invoke(
+                {
+                    "input": prompt,
+                }
+            )
+            messages.append(AIMessage(content=json.dumps(fn_res)))
+
         return chat(messages).content
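To make the new branch in `_get_answer` easier to review: when `functions` is set, the prompt is first routed through langchain's `create_openai_fn_runnable`, the structured result is serialized into an `AIMessage`, and a second chat call then produces the natural-language answer. A standalone sketch of that flow, using the same langchain calls as the diff (the `find_capital` function and the model name are illustrative, not part of the change):

```python
import json

from langchain.chains.openai_functions import create_openai_fn_runnable
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import AIMessage, HumanMessage


def find_capital(country: str):
    """
    Find the capital of the given country.

    Args:
        country: The country name.
    """
    ...  # never executed: only the function's schema is sent to the model


chat = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
messages = [HumanMessage(content="What is the capital of France?")]

# First call: the model picks a function and fills in its arguments.
prompt = ChatPromptTemplate.from_messages(messages)
runnable = create_openai_fn_runnable(functions=[find_capital], prompt=prompt, llm=chat)
fn_res = runnable.invoke({"input": "What is the capital of France?"})

# Second call: the structured result is fed back for a plain-text answer.
messages.append(AIMessage(content=json.dumps(fn_res)))
print(chat(messages).content)
```

One caveat: `json.dumps(fn_res)` assumes the runnable returns a JSON-serializable dict; when `functions` contains pydantic models, langchain's output parser may hand back a model instance instead, which would need converting before serialization.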
diff --git a/tests/llm/test_openai.py b/tests/llm/test_openai.py
index fc823337..370a310a 100644
--- a/tests/llm/test_openai.py
+++ b/tests/llm/test_openai.py
@@ -50,24 +50,24 @@ def test_get_llm_model_answer_empty_prompt(config, mocker):
 
 def test_get_llm_model_answer_with_streaming(config, mocker):
     config.stream = True
-    mocked_jinachat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
+    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
 
     llm = OpenAILlm(config)
     llm.get_llm_model_answer("Test query")
 
-    mocked_jinachat.assert_called_once()
-    callbacks = [callback[1]["callbacks"] for callback in mocked_jinachat.call_args_list]
+    mocked_openai_chat.assert_called_once()
+    callbacks = [callback[1]["callbacks"] for callback in mocked_openai_chat.call_args_list]
     assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
 
 
 def test_get_llm_model_answer_without_system_prompt(config, mocker):
     config.system_prompt = None
-    mocked_jinachat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
+    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
 
     llm = OpenAILlm(config)
     llm.get_llm_model_answer("Test query")
 
-    mocked_jinachat.assert_called_once_with(
+    mocked_openai_chat.assert_called_once_with(
         model=config.model,
         temperature=config.temperature,
         max_tokens=config.max_tokens,
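The updated tests still only cover the streaming and system-prompt paths. A hypothetical test for the new `functions` branch might look like the sketch below; since `_get_answer` imports `create_openai_fn_runnable` lazily, it has to be patched at its source module (the test name and mocked return value are made up for illustration):

```python
def test_get_llm_model_answer_with_functions(config, mocker):
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    # Patch at the source module, because _get_answer imports it at call time.
    mocked_runnable = mocker.patch("langchain.chains.openai_functions.create_openai_fn_runnable")
    mocked_runnable.return_value.invoke.return_value = {"answer": "A mountain is a hill."}

    llm = OpenAILlm(config, functions=[{"name": "get_qa"}])
    llm.get_llm_model_answer("What is a mountain?")

    mocked_runnable.assert_called_once()
    mocked_openai_chat.assert_called_once()
```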