[Feature] OpenAI Function Calling (#1224)

UnMonsieur authored on 2024-02-12 02:58:11 +01:00, committed by GitHub
parent 38e212c721
commit 41bd258b93
31 changed files with 259 additions and 213 deletions
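
Apart from the new function-calling path in the OpenAI LLM, most hunks below are a mechanical migration of import paths from the deprecated langchain locations to the split-out langchain_community package. A minimal sketch of the recurring pattern, assuming langchain-community is installed:

# Recurring import migration in this commit (illustrative; mirrors the hunks below)
# before: from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.embeddings import HuggingFaceEmbeddings  # after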


@@ -1,6 +1,6 @@
 from typing import Optional
-from langchain.embeddings import HuggingFaceEmbeddings
+from langchain_community.embeddings import HuggingFaceEmbeddings
 from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder


@@ -2,7 +2,7 @@ import os
 from typing import Optional
 from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
-from langchain.embeddings import AzureOpenAIEmbeddings
+from langchain_community.embeddings import AzureOpenAIEmbeddings
 from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder


@@ -1,6 +1,6 @@
 from typing import Optional
-from langchain.embeddings import VertexAIEmbeddings
+from langchain_community.embeddings import VertexAIEmbeddings
 from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder


@@ -19,7 +19,7 @@ class AnthropicLlm(BaseLlm):
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
-        from langchain.chat_models import ChatAnthropic
+        from langchain_community.chat_models import ChatAnthropic
         chat = ChatAnthropic(
             anthropic_api_key=os.environ["ANTHROPIC_API_KEY"], temperature=config.temperature, model=config.model


@@ -1,7 +1,7 @@
 import os
 from typing import Optional
-from langchain.llms import Bedrock
+from langchain_community.llms import Bedrock
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable


@@ -16,7 +16,7 @@ class AzureOpenAILlm(BaseLlm):
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
-        from langchain.chat_models import AzureChatOpenAI
+        from langchain_community.chat_models import AzureChatOpenAI
         if not config.deployment_name:
             raise ValueError("Deployment name must be provided for Azure OpenAI")


@@ -2,7 +2,7 @@ import importlib
 import os
 from typing import Optional
-from langchain.llms.cohere import Cohere
+from langchain_community.llms.cohere import Cohere
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable


@@ -26,7 +26,8 @@ class GPT4ALLLlm(BaseLlm):
     @staticmethod
     def _get_instance(model):
         try:
-            from langchain.llms.gpt4all import GPT4All as LangchainGPT4All
+            from langchain_community.llms.gpt4all import \
+                GPT4All as LangchainGPT4All
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
                 "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`"  # noqa E501


@@ -3,8 +3,8 @@ import logging
 import os
 from typing import Optional
-from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint
-from langchain.llms.huggingface_hub import HuggingFaceHub
+from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
+from langchain_community.llms.huggingface_hub import HuggingFaceHub
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable


@@ -1,8 +1,8 @@
 import os
 from typing import Optional
-from langchain.chat_models import JinaChat
 from langchain.schema import HumanMessage, SystemMessage
+from langchain_community.chat_models import JinaChat
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable


@@ -2,7 +2,7 @@ import importlib
 import os
 from typing import Optional
-from langchain.llms.replicate import Replicate
+from langchain_community.llms.replicate import Replicate
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable


@@ -4,7 +4,7 @@ from typing import Optional, Union
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.stdout import StdOutCallbackHandler
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain.llms.ollama import Ollama
+from langchain_community.llms.ollama import Ollama
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable


@@ -1,9 +1,11 @@
 import json
 import os
-from typing import Any, Optional
+from typing import Any, Callable, Dict, Optional, Type, Union
-from langchain.chat_models import ChatOpenAI
-from langchain.schema import AIMessage, HumanMessage, SystemMessage
+from langchain.schema import BaseMessage, HumanMessage, SystemMessage
+from langchain_core.tools import BaseTool
+from langchain_openai import ChatOpenAI
+from pydantic import BaseModel
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -12,8 +14,12 @@ from embedchain.llm.base import BaseLlm
 @register_deserializable
 class OpenAILlm(BaseLlm):
-    def __init__(self, config: Optional[BaseLlmConfig] = None, functions: Optional[dict[str, Any]] = None):
-        self.functions = functions
+    def __init__(
+        self,
+        config: Optional[BaseLlmConfig] = None,
+        tools: Optional[Union[Dict[str, Any], Type[BaseModel], Callable[..., Any], BaseTool]] = None,
+    ):
+        self.tools = tools
         super().__init__(config=config)
     def get_llm_model_answer(self, prompt) -> str:
@@ -38,21 +44,27 @@ class OpenAILlm(BaseLlm):
             from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
             callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
-            llm = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
+            chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
         else:
-            llm = ChatOpenAI(**kwargs, api_key=api_key)
+            chat = ChatOpenAI(**kwargs, api_key=api_key)
+        if self.tools:
+            return self._query_function_call(chat, self.tools, messages)
-        if self.functions is not None:
-            from langchain.chains.openai_functions import create_openai_fn_runnable
-            from langchain.prompts import ChatPromptTemplate
+        return chat.invoke(messages).content
-            structured_prompt = ChatPromptTemplate.from_messages(messages)
-            runnable = create_openai_fn_runnable(functions=self.functions, prompt=structured_prompt, llm=llm)
-            fn_res = runnable.invoke(
-                {
-                    "input": prompt,
-                }
-            )
-            messages.append(AIMessage(content=json.dumps(fn_res)))
+    def _query_function_call(
+        self,
+        chat: ChatOpenAI,
+        tools: Optional[Union[Dict[str, Any], Type[BaseModel], Callable[..., Any], BaseTool]],
+        messages: list[BaseMessage],
+    ) -> str:
+        from langchain.output_parsers.openai_tools import JsonOutputToolsParser
+        from langchain_core.utils.function_calling import \
+            convert_to_openai_tool
-        return llm(messages).content
+        openai_tools = [convert_to_openai_tool(tools)]
+        chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())
+        try:
+            return json.dumps(chat.invoke(messages)[0])
+        except IndexError:
+            return "Input could not be mapped to the function!"


@@ -2,7 +2,7 @@ import importlib
 import os
 from typing import Optional
-from langchain.llms import Together
+from langchain_community.llms import Together
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable


@@ -24,7 +24,7 @@ class VertexAILlm(BaseLlm):
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
-        from langchain.chat_models import ChatVertexAI
+        from langchain_community.chat_models import ChatVertexAI
         chat = ChatVertexAI(temperature=config.temperature, model=config.model)


@@ -1,7 +1,7 @@
 import hashlib
 try:
-    from langchain.document_loaders import Docx2txtLoader
+    from langchain_community.document_loaders import Docx2txtLoader
 except ImportError:
     raise ImportError(
         'Docx file requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'


@@ -8,8 +8,8 @@ except ImportError:
"Google Drive requires extra dependencies. Install with `pip install embedchain[googledrive]`"
) from None
from langchain.document_loaders import GoogleDriveLoader as Loader
from langchain.document_loaders import UnstructuredFileIOLoader
from langchain_community.document_loaders import GoogleDriveLoader as Loader
from langchain_community.document_loaders import UnstructuredFileIOLoader
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader


@@ -1,7 +1,7 @@
 import hashlib
 try:
-    from langchain.document_loaders import PyPDFLoader
+    from langchain_community.document_loaders import PyPDFLoader
 except ImportError:
     raise ImportError(
         'PDF File requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'


@@ -28,7 +28,7 @@ class RSSFeedLoader(BaseLoader):
     @staticmethod
     def get_rss_content(url: str):
         try:
-            from langchain.document_loaders import \
+            from langchain_community.document_loaders import \
                 RSSFeedLoader as LangchainRSSFeedLoader
         except ImportError:
             raise ImportError(


@@ -10,7 +10,8 @@ class UnstructuredLoader(BaseLoader):
     def load_data(self, url):
         """Load data from an Unstructured file."""
         try:
-            from langchain.document_loaders import UnstructuredFileLoader
+            from langchain_community.document_loaders import \
+                UnstructuredFileLoader
         except ImportError:
             raise ImportError(
                 'Unstructured file requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'  # noqa: E501


@@ -1,7 +1,7 @@
 import hashlib
 try:
-    from langchain.document_loaders import UnstructuredXMLLoader
+    from langchain_community.document_loaders import UnstructuredXMLLoader
 except ImportError:
     raise ImportError(
         'XML file requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'


@@ -1,7 +1,7 @@
 import hashlib
 try:
-    from langchain.document_loaders import YoutubeLoader
+    from langchain_community.document_loaders import YoutubeLoader
 except ImportError:
     raise ImportError(
         'YouTube video requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'


@@ -12,8 +12,8 @@ except ImportError:
"OpenSearch requires extra dependencies. Install with `pip install --upgrade embedchain[opensearch]`"
) from None
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import OpenSearchVectorSearch
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.vectorstores import OpenSearchVectorSearch
from embedchain.config import OpenSearchDBConfig
from embedchain.helpers.json_serializable import register_deserializable