Add support for anthropic (#1819)

This commit is contained in:
Dev Khant
2024-09-07 02:12:22 +05:30
committed by GitHub
parent 965f7a3735
commit d32ae1a0b1
13 changed files with 82 additions and 84 deletions

View File

@@ -11,7 +11,7 @@ os.environ["ANTHROPIC_API_KEY"] = "your-api-key"
config = { config = {
"llm": { "llm": {
"provider": "litellm", "provider": "anthropic",
"config": { "config": {
"model": "claude-3-opus-20240229", "model": "claude-3-opus-20240229",
"temperature": 0.1, "temperature": 0.1,

67
mem0/llms/anthropic.py Normal file
View File

@@ -0,0 +1,67 @@
import subprocess
import sys
import os
import json
from typing import Dict, List, Optional
try:
import anthropic
except ImportError:
raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig
class AnthropicLLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config)
if not self.config.model:
self.config.model = "claude-3-5-sonnet-20240620"
api_key = self.config.api_key or os.getenv("ANTHROPIC_API_KEY")
self.client = anthropic.Anthropic(api_key=api_key)
def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
"""
Generate a response based on the given messages using Anthropic.
Args:
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
Returns:
str: The generated response.
"""
# Separate system message from other messages
system_message = ""
filtered_messages = []
for message in messages:
if message['role'] == 'system':
system_message = message['content']
else:
filtered_messages.append(message)
params = {
"model": self.config.model,
"messages": filtered_messages,
"system": system_message,
"temperature": self.config.temperature,
"max_tokens": self.config.max_tokens,
"top_p": self.config.top_p,
}
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice
response = self.client.messages.create(**params)
return response.content[0].text

View File

@@ -7,16 +7,7 @@ from typing import Dict, List, Optional, Any
try: try:
import boto3 import boto3
except ImportError: except ImportError:
user_input = input("The 'boto3' library is required. Install it now? [y/N]: ") raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")
if user_input.lower() == 'y':
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "boto3"])
import boto3
except subprocess.CalledProcessError:
print("Failed to install 'boto3'. Please install it manually using 'pip install boto3'")
sys.exit(1)
else:
raise ImportError("The required 'boto3' library is not installed.")
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig

View File

@@ -87,7 +87,7 @@ class AzureOpenAILLM(LLMBase):
} }
if response_format: if response_format:
params["response_format"] = response_format params["response_format"] = response_format
if tools: if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools params["tools"] = tools
params["tool_choice"] = tool_choice params["tool_choice"] = tool_choice

View File

@@ -7,16 +7,7 @@ from typing import Dict, List, Optional
try: try:
from groq import Groq from groq import Groq
except ImportError: except ImportError:
user_input = input("The 'groq' library is required. Install it now? [y/N]: ") raise ImportError("The 'groq' library is required. Please install it using 'pip install groq'.")
if user_input.lower() == 'y':
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "groq"])
from groq import Groq
except subprocess.CalledProcessError:
print("Failed to install 'groq'. Please install it manually using 'pip install groq'.")
sys.exit(1)
else:
raise ImportError("The required 'groq' library is not installed.")
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
@@ -28,7 +19,6 @@ class GroqLLM(LLMBase):
if not self.config.model: if not self.config.model:
self.config.model = "llama3-70b-8192" self.config.model = "llama3-70b-8192"
self.client = Groq()
api_key = self.config.api_key or os.getenv("GROQ_API_KEY") api_key = self.config.api_key or os.getenv("GROQ_API_KEY")
self.client = Groq(api_key=api_key) self.client = Groq(api_key=api_key)

View File

@@ -6,16 +6,7 @@ from typing import Dict, List, Optional
try: try:
import litellm import litellm
except ImportError: except ImportError:
user_input = input("The 'litellm' library is required. Install it now? [y/N]: ") raise ImportError("The 'litellm' library is required. Please install it using 'pip install litellm'.")
if user_input.lower() == 'y':
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "litellm"])
import litellm
except subprocess.CalledProcessError:
print("Failed to install 'litellm'. Please install it manually using 'pip install litellm'.")
sys.exit(1)
else:
raise ImportError("The required 'litellm' library is not installed.")
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
@@ -91,7 +82,7 @@ class LiteLLM(LLMBase):
} }
if response_format: if response_format:
params["response_format"] = response_format params["response_format"] = response_format
if tools: if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools params["tools"] = tools
params["tool_choice"] = tool_choice params["tool_choice"] = tool_choice

View File

@@ -5,17 +5,7 @@ from typing import Dict, List, Optional
try: try:
from ollama import Client from ollama import Client
except ImportError: except ImportError:
user_input = input("The 'ollama' library is required. Install it now? [y/N]: ") raise ImportError("The 'ollama' library is required. Please install it using 'pip install ollama'.")
if user_input.lower() == 'y':
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "ollama"])
from ollama import Client
except subprocess.CalledProcessError:
print("Failed to install 'ollama'. Please install it manually using 'pip install ollama'.")
sys.exit(1)
else:
print("The required 'ollama' library is not installed.")
sys.exit(1)
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig

View File

@@ -100,7 +100,7 @@ class OpenAILLM(LLMBase):
if response_format: if response_format:
params["response_format"] = response_format params["response_format"] = response_format
if tools: if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools params["tools"] = tools
params["tool_choice"] = tool_choice params["tool_choice"] = tool_choice

View File

@@ -7,17 +7,7 @@ from typing import Dict, List, Optional
try: try:
from together import Together from together import Together
except ImportError: except ImportError:
user_input = input("The 'together' library is required. Install it now? [y/N]: ") raise ImportError("The 'together' library is required. Please install it using 'pip install together'.")
if user_input.lower() == 'y':
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "together"])
from together import Together
except subprocess.CalledProcessError:
print("Failed to install 'together'. Please install it manually using 'pip install together'.")
sys.exit(1)
else:
print("The required 'together' library is not installed.")
sys.exit(1)
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
@@ -29,7 +19,6 @@ class TogetherLLM(LLMBase):
if not self.config.model: if not self.config.model:
self.config.model = "mistralai/Mixtral-8x7B-Instruct-v0.1" self.config.model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
self.client = Together()
api_key = self.config.api_key or os.getenv("TOGETHER_API_KEY") api_key = self.config.api_key or os.getenv("TOGETHER_API_KEY")
self.client = Together(api_key=api_key) self.client = Together(api_key=api_key)
@@ -92,7 +81,7 @@ class TogetherLLM(LLMBase):
} }
if response_format: if response_format:
params["response_format"] = response_format params["response_format"] = response_format
if tools: if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools params["tools"] = tools
params["tool_choice"] = tool_choice params["tool_choice"] = tool_choice

View File

@@ -1,3 +1,5 @@
# TODO: Remove these tools if no issues are found for new memory addition logic
ADD_MEMORY_TOOL = { ADD_MEMORY_TOOL = {
"type": "function", "type": "function",
"function": { "function": {

View File

@@ -20,6 +20,7 @@ class LlmFactory:
"litellm": "mem0.llms.litellm.LiteLLM", "litellm": "mem0.llms.litellm.LiteLLM",
"azure_openai": "mem0.llms.azure_openai.AzureOpenAILLM", "azure_openai": "mem0.llms.azure_openai.AzureOpenAILLM",
"openai_structured": "mem0.llms.openai_structured.OpenAIStructuredLLM", "openai_structured": "mem0.llms.openai_structured.OpenAIStructuredLLM",
"anthropic": "mem0.llms.anthropic.AnthropicLLM"
} }
@classmethod @classmethod

View File

@@ -9,18 +9,7 @@ try:
import chromadb import chromadb
from chromadb.config import Settings from chromadb.config import Settings
except ImportError: except ImportError:
user_input = input("The 'chromadb' library is required. Install it now? [y/N]: ") raise ImportError("The 'chromadb' library is required. Please install it using 'pip install chromadb'.")
if user_input.lower() == 'y':
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "chromadb"])
import chromadb
from chromadb.config import Settings
except subprocess.CalledProcessError:
print("Failed to install 'chromadb'. Please install it manually using 'pip install chromadb'.")
sys.exit(1)
else:
print("The required 'chromadb' library is not installed.")
sys.exit(1)
from mem0.vector_stores.base import VectorStoreBase from mem0.vector_stores.base import VectorStoreBase

View File

@@ -9,19 +9,7 @@ try:
import psycopg2 import psycopg2
from psycopg2.extras import execute_values from psycopg2.extras import execute_values
except ImportError: except ImportError:
user_input = input("The 'psycopg2' library is required. Install it now? [y/N]: ") raise ImportError("The 'psycopg2' library is required. Please install it using 'pip install psycopg2'.")
if user_input.lower() == 'y':
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "psycopg2"])
import psycopg2
from psycopg2.extras import execute_values
except subprocess.CalledProcessError:
print("Failed to install 'psycopg2'. Please install it manually using 'pip install psycopg2'.")
sys.exit(1)
else:
print("The required 'psycopg2' library is not installed.")
sys.exit(1)
from mem0.vector_stores.base import VectorStoreBase from mem0.vector_stores.base import VectorStoreBase