[Feature] OpenAI Function Calling (#1224)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
from typing import Optional
|
||||
|
||||
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
|
||||
from langchain.embeddings import AzureOpenAIEmbeddings
|
||||
from langchain_community.embeddings import AzureOpenAIEmbeddings
|
||||
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.embeddings import VertexAIEmbeddings
|
||||
from langchain_community.embeddings import VertexAIEmbeddings
|
||||
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
|
||||
@@ -19,7 +19,7 @@ class AnthropicLlm(BaseLlm):
|
||||
|
||||
@staticmethod
|
||||
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
from langchain_community.chat_models import ChatAnthropic
|
||||
|
||||
chat = ChatAnthropic(
|
||||
anthropic_api_key=os.environ["ANTHROPIC_API_KEY"], temperature=config.temperature, model=config.model
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from langchain.llms import Bedrock
|
||||
from langchain_community.llms import Bedrock
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
@@ -16,7 +16,7 @@ class AzureOpenAILlm(BaseLlm):
|
||||
|
||||
@staticmethod
|
||||
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
|
||||
from langchain.chat_models import AzureChatOpenAI
|
||||
from langchain_community.chat_models import AzureChatOpenAI
|
||||
|
||||
if not config.deployment_name:
|
||||
raise ValueError("Deployment name must be provided for Azure OpenAI")
|
||||
|
||||
@@ -2,7 +2,7 @@ import importlib
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from langchain.llms.cohere import Cohere
|
||||
from langchain_community.llms.cohere import Cohere
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
@@ -26,7 +26,8 @@ class GPT4ALLLlm(BaseLlm):
|
||||
@staticmethod
|
||||
def _get_instance(model):
|
||||
try:
|
||||
from langchain.llms.gpt4all import GPT4All as LangchainGPT4All
|
||||
from langchain_community.llms.gpt4all import \
|
||||
GPT4All as LangchainGPT4All
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError(
|
||||
"The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`" # noqa E501
|
||||
|
||||
@@ -3,8 +3,8 @@ import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint
|
||||
from langchain.llms.huggingface_hub import HuggingFaceHub
|
||||
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
|
||||
from langchain_community.llms.huggingface_hub import HuggingFaceHub
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from langchain.chat_models import JinaChat
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
from langchain_community.chat_models import JinaChat
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
@@ -2,7 +2,7 @@ import importlib
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from langchain.llms.replicate import Replicate
|
||||
from langchain_community.llms.replicate import Replicate
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Optional, Union
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain.llms.ollama import Ollama
|
||||
from langchain_community.llms.ollama import Ollama
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Callable, Dict, Optional, Type, Union
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
from langchain.schema import BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
@@ -12,8 +14,12 @@ from embedchain.llm.base import BaseLlm
|
||||
|
||||
@register_deserializable
|
||||
class OpenAILlm(BaseLlm):
|
||||
def __init__(self, config: Optional[BaseLlmConfig] = None, functions: Optional[dict[str, Any]] = None):
|
||||
self.functions = functions
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[BaseLlmConfig] = None,
|
||||
tools: Optional[Union[Dict[str, Any], Type[BaseModel], Callable[..., Any], BaseTool]] = None,
|
||||
):
|
||||
self.tools = tools
|
||||
super().__init__(config=config)
|
||||
|
||||
def get_llm_model_answer(self, prompt) -> str:
|
||||
@@ -38,21 +44,27 @@ class OpenAILlm(BaseLlm):
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
|
||||
llm = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
|
||||
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
|
||||
else:
|
||||
llm = ChatOpenAI(**kwargs, api_key=api_key)
|
||||
chat = ChatOpenAI(**kwargs, api_key=api_key)
|
||||
if self.tools:
|
||||
return self._query_function_call(chat, self.tools, messages)
|
||||
|
||||
if self.functions is not None:
|
||||
from langchain.chains.openai_functions import create_openai_fn_runnable
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
return chat.invoke(messages).content
|
||||
|
||||
structured_prompt = ChatPromptTemplate.from_messages(messages)
|
||||
runnable = create_openai_fn_runnable(functions=self.functions, prompt=structured_prompt, llm=llm)
|
||||
fn_res = runnable.invoke(
|
||||
{
|
||||
"input": prompt,
|
||||
}
|
||||
)
|
||||
messages.append(AIMessage(content=json.dumps(fn_res)))
|
||||
def _query_function_call(
|
||||
self,
|
||||
chat: ChatOpenAI,
|
||||
tools: Optional[Union[Dict[str, Any], Type[BaseModel], Callable[..., Any], BaseTool]],
|
||||
messages: list[BaseMessage],
|
||||
) -> str:
|
||||
from langchain.output_parsers.openai_tools import JsonOutputToolsParser
|
||||
from langchain_core.utils.function_calling import \
|
||||
convert_to_openai_tool
|
||||
|
||||
return llm(messages).content
|
||||
openai_tools = [convert_to_openai_tool(tools)]
|
||||
chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())
|
||||
try:
|
||||
return json.dumps(chat.invoke(messages)[0])
|
||||
except IndexError:
|
||||
return "Input could not be mapped to the function!"
|
||||
|
||||
@@ -2,7 +2,7 @@ import importlib
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from langchain.llms import Together
|
||||
from langchain_community.llms import Together
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
@@ -24,7 +24,7 @@ class VertexAILlm(BaseLlm):
|
||||
|
||||
@staticmethod
|
||||
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
|
||||
from langchain.chat_models import ChatVertexAI
|
||||
from langchain_community.chat_models import ChatVertexAI
|
||||
|
||||
chat = ChatVertexAI(temperature=config.temperature, model=config.model)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import hashlib
|
||||
|
||||
try:
|
||||
from langchain.document_loaders import Docx2txtLoader
|
||||
from langchain_community.document_loaders import Docx2txtLoader
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Docx file requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
|
||||
|
||||
@@ -8,8 +8,8 @@ except ImportError:
|
||||
"Google Drive requires extra dependencies. Install with `pip install embedchain[googledrive]`"
|
||||
) from None
|
||||
|
||||
from langchain.document_loaders import GoogleDriveLoader as Loader
|
||||
from langchain.document_loaders import UnstructuredFileIOLoader
|
||||
from langchain_community.document_loaders import GoogleDriveLoader as Loader
|
||||
from langchain_community.document_loaders import UnstructuredFileIOLoader
|
||||
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import hashlib
|
||||
|
||||
try:
|
||||
from langchain.document_loaders import PyPDFLoader
|
||||
from langchain_community.document_loaders import PyPDFLoader
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'PDF File requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
|
||||
|
||||
@@ -28,7 +28,7 @@ class RSSFeedLoader(BaseLoader):
|
||||
@staticmethod
|
||||
def get_rss_content(url: str):
|
||||
try:
|
||||
from langchain.document_loaders import \
|
||||
from langchain_community.document_loaders import \
|
||||
RSSFeedLoader as LangchainRSSFeedLoader
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
|
||||
@@ -10,7 +10,8 @@ class UnstructuredLoader(BaseLoader):
|
||||
def load_data(self, url):
|
||||
"""Load data from an Unstructured file."""
|
||||
try:
|
||||
from langchain.document_loaders import UnstructuredFileLoader
|
||||
from langchain_community.document_loaders import \
|
||||
UnstructuredFileLoader
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Unstructured file requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`' # noqa: E501
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import hashlib
|
||||
|
||||
try:
|
||||
from langchain.document_loaders import UnstructuredXMLLoader
|
||||
from langchain_community.document_loaders import UnstructuredXMLLoader
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'XML file requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import hashlib
|
||||
|
||||
try:
|
||||
from langchain.document_loaders import YoutubeLoader
|
||||
from langchain_community.document_loaders import YoutubeLoader
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'YouTube video requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
|
||||
|
||||
@@ -12,8 +12,8 @@ except ImportError:
|
||||
"OpenSearch requires extra dependencies. Install with `pip install --upgrade embedchain[opensearch]`"
|
||||
) from None
|
||||
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.vectorstores import OpenSearchVectorSearch
|
||||
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain_community.vectorstores import OpenSearchVectorSearch
|
||||
|
||||
from embedchain.config import OpenSearchDBConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
Reference in New Issue
Block a user