feat: add mistral AI as LLM provider (#2496)
Co-authored-by: Saket Aryan <94069182+whysosaket@users.noreply.github.com>
This commit is contained in:
78
mem0-ts/src/oss/examples/llms/mistral-example.ts
Normal file
78
mem0-ts/src/oss/examples/llms/mistral-example.ts
Normal file
@@ -0,0 +1,78 @@
|
||||
import dotenv from "dotenv";
|
||||
import { MistralLLM } from "../../src/llms/mistral";
|
||||
|
||||
// Load environment variables
|
||||
dotenv.config();
|
||||
|
||||
async function testMistral() {
|
||||
// Check for API key
|
||||
if (!process.env.MISTRAL_API_KEY) {
|
||||
console.error("MISTRAL_API_KEY environment variable is required");
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
console.log("Testing Mistral LLM implementation...");
|
||||
|
||||
// Initialize MistralLLM
|
||||
const mistral = new MistralLLM({
|
||||
apiKey: process.env.MISTRAL_API_KEY,
|
||||
model: "mistral-tiny-latest", // You can change to other models like mistral-small-latest
|
||||
});
|
||||
|
||||
try {
|
||||
// Test simple chat completion
|
||||
console.log("Testing simple chat completion:");
|
||||
const chatResponse = await mistral.generateChat([
|
||||
{ role: "system", content: "You are a helpful assistant." },
|
||||
{ role: "user", content: "What is the capital of France?" },
|
||||
]);
|
||||
|
||||
console.log("Chat response:");
|
||||
console.log(`Role: ${chatResponse.role}`);
|
||||
console.log(`Content: ${chatResponse.content}\n`);
|
||||
|
||||
// Test with functions/tools
|
||||
console.log("Testing tool calling:");
|
||||
const tools = [
|
||||
{
|
||||
type: "function",
|
||||
function: {
|
||||
name: "get_weather",
|
||||
description: "Get the current weather in a given location",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
location: {
|
||||
type: "string",
|
||||
description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
unit: {
|
||||
type: "string",
|
||||
enum: ["celsius", "fahrenheit"],
|
||||
description: "The unit of temperature",
|
||||
},
|
||||
},
|
||||
required: ["location"],
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const toolResponse = await mistral.generateResponse(
|
||||
[
|
||||
{ role: "system", content: "You are a helpful assistant." },
|
||||
{ role: "user", content: "What's the weather like in Paris, France?" },
|
||||
],
|
||||
undefined,
|
||||
tools,
|
||||
);
|
||||
|
||||
console.log("Tool response:", toolResponse);
|
||||
|
||||
console.log("\n✅ All tests completed successfully");
|
||||
} catch (error) {
|
||||
console.error("Error testing Mistral LLM:", error);
|
||||
}
|
||||
}
|
||||
|
||||
testMistral().catch(console.error);
|
||||
@@ -12,6 +12,7 @@ export * from "./llms/openai_structured";
|
||||
export * from "./llms/anthropic";
|
||||
export * from "./llms/groq";
|
||||
export * from "./llms/ollama";
|
||||
export * from "./llms/mistral";
|
||||
export * from "./vector_stores/base";
|
||||
export * from "./vector_stores/memory";
|
||||
export * from "./vector_stores/qdrant";
|
||||
|
||||
112
mem0-ts/src/oss/src/llms/mistral.ts
Normal file
112
mem0-ts/src/oss/src/llms/mistral.ts
Normal file
@@ -0,0 +1,112 @@
|
||||
import { Mistral } from "@mistralai/mistralai";
|
||||
import { LLM, LLMResponse } from "./base";
|
||||
import { LLMConfig, Message } from "../types";
|
||||
|
||||
export class MistralLLM implements LLM {
|
||||
private client: Mistral;
|
||||
private model: string;
|
||||
|
||||
constructor(config: LLMConfig) {
|
||||
if (!config.apiKey) {
|
||||
throw new Error("Mistral API key is required");
|
||||
}
|
||||
this.client = new Mistral({
|
||||
apiKey: config.apiKey,
|
||||
});
|
||||
this.model = config.model || "mistral-tiny-latest";
|
||||
}
|
||||
|
||||
// Helper function to convert content to string
|
||||
private contentToString(content: any): string {
|
||||
if (typeof content === "string") {
|
||||
return content;
|
||||
}
|
||||
if (Array.isArray(content)) {
|
||||
// Handle ContentChunk array - extract text content
|
||||
return content
|
||||
.map((chunk) => {
|
||||
if (chunk.type === "text") {
|
||||
return chunk.text;
|
||||
} else {
|
||||
return JSON.stringify(chunk);
|
||||
}
|
||||
})
|
||||
.join("");
|
||||
}
|
||||
return String(content || "");
|
||||
}
|
||||
|
||||
async generateResponse(
|
||||
messages: Message[],
|
||||
responseFormat?: { type: string },
|
||||
tools?: any[],
|
||||
): Promise<string | LLMResponse> {
|
||||
const response = await this.client.chat.complete({
|
||||
model: this.model,
|
||||
messages: messages.map((msg) => ({
|
||||
role: msg.role as "system" | "user" | "assistant",
|
||||
content:
|
||||
typeof msg.content === "string"
|
||||
? msg.content
|
||||
: JSON.stringify(msg.content),
|
||||
})),
|
||||
...(tools && { tools }),
|
||||
...(responseFormat && { response_format: responseFormat }),
|
||||
});
|
||||
|
||||
if (!response || !response.choices || response.choices.length === 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
const message = response.choices[0].message;
|
||||
|
||||
if (!message) {
|
||||
return "";
|
||||
}
|
||||
|
||||
if (message.toolCalls && message.toolCalls.length > 0) {
|
||||
return {
|
||||
content: this.contentToString(message.content),
|
||||
role: message.role || "assistant",
|
||||
toolCalls: message.toolCalls.map((call) => ({
|
||||
name: call.function.name,
|
||||
arguments:
|
||||
typeof call.function.arguments === "string"
|
||||
? call.function.arguments
|
||||
: JSON.stringify(call.function.arguments),
|
||||
})),
|
||||
};
|
||||
}
|
||||
|
||||
return this.contentToString(message.content);
|
||||
}
|
||||
|
||||
async generateChat(messages: Message[]): Promise<LLMResponse> {
|
||||
const formattedMessages = messages.map((msg) => ({
|
||||
role: msg.role as "system" | "user" | "assistant",
|
||||
content:
|
||||
typeof msg.content === "string"
|
||||
? msg.content
|
||||
: JSON.stringify(msg.content),
|
||||
}));
|
||||
|
||||
const response = await this.client.chat.complete({
|
||||
model: this.model,
|
||||
messages: formattedMessages,
|
||||
});
|
||||
|
||||
if (!response || !response.choices || response.choices.length === 0) {
|
||||
return {
|
||||
content: "",
|
||||
role: "assistant",
|
||||
};
|
||||
}
|
||||
|
||||
const message = response.choices[0].message;
|
||||
|
||||
return {
|
||||
content: this.contentToString(message.content),
|
||||
role: message.role || "assistant",
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import { OpenAILLM } from "../llms/openai";
|
||||
import { OpenAIStructuredLLM } from "../llms/openai_structured";
|
||||
import { AnthropicLLM } from "../llms/anthropic";
|
||||
import { GroqLLM } from "../llms/groq";
|
||||
import { MistralLLM } from "../llms/mistral";
|
||||
import { MemoryVectorStore } from "../vector_stores/memory";
|
||||
import {
|
||||
EmbeddingConfig,
|
||||
@@ -55,6 +56,8 @@ export class LLMFactory {
|
||||
return new OllamaLLM(config);
|
||||
case "google":
|
||||
return new GoogleLLM(config);
|
||||
case "mistral":
|
||||
return new MistralLLM(config);
|
||||
default:
|
||||
throw new Error(`Unsupported LLM provider: ${provider}`);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user