feat: add Mistral AI as LLM provider (#2496)

Co-authored-by: Saket Aryan <94069182+whysosaket@users.noreply.github.com>
Author: Achraf Dev
Date: 2025-04-11 15:32:44 +01:00
Committed-by: GitHub
Parent: 942727fec6
Commit: d9236de4ed

8 changed files with 269 additions and 12 deletions


@@ -0,0 +1,78 @@
import dotenv from "dotenv";
import { MistralLLM } from "../../src/llms/mistral";

// Load environment variables
dotenv.config();

async function testMistral() {
  // Check for API key
  if (!process.env.MISTRAL_API_KEY) {
    console.error("MISTRAL_API_KEY environment variable is required");
    process.exit(1);
  }

  console.log("Testing Mistral LLM implementation...");

  // Initialize MistralLLM
  const mistral = new MistralLLM({
    apiKey: process.env.MISTRAL_API_KEY,
    model: "mistral-tiny-latest", // You can change to other models like mistral-small-latest
  });

  try {
    // Test simple chat completion
    console.log("Testing simple chat completion:");
    const chatResponse = await mistral.generateChat([
      { role: "system", content: "You are a helpful assistant." },
      { role: "user", content: "What is the capital of France?" },
    ]);
    console.log("Chat response:");
    console.log(`Role: ${chatResponse.role}`);
    console.log(`Content: ${chatResponse.content}\n`);

    // Test with functions/tools
    console.log("Testing tool calling:");
    const tools = [
      {
        type: "function",
        function: {
          name: "get_weather",
          description: "Get the current weather in a given location",
          parameters: {
            type: "object",
            properties: {
              location: {
                type: "string",
                description: "The city and state, e.g. San Francisco, CA",
              },
              unit: {
                type: "string",
                enum: ["celsius", "fahrenheit"],
                description: "The unit of temperature",
              },
            },
            required: ["location"],
          },
        },
      },
    ];
    const toolResponse = await mistral.generateResponse(
      [
        { role: "system", content: "You are a helpful assistant." },
        { role: "user", content: "What's the weather like in Paris, France?" },
      ],
      undefined, // no response format constraint
      tools,
    );
    console.log("Tool response:", toolResponse);

    console.log("\n✅ All tests completed successfully");
  } catch (error) {
    console.error("Error testing Mistral LLM:", error);
  }
}

testMistral().catch(console.error);
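
To try this script locally, set MISTRAL_API_KEY in a .env file (or the shell environment) and run it with a TypeScript runner such as ts-node or tsx, e.g. npx tsx <path-to-this-example>.ts; the diff does not show the file's path, so substitute its actual location in the repo.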


@@ -12,6 +12,7 @@ export * from "./llms/openai_structured";
export * from "./llms/anthropic";
export * from "./llms/groq";
export * from "./llms/ollama";
export * from "./llms/mistral";
export * from "./vector_stores/base";
export * from "./vector_stores/memory";
export * from "./vector_stores/qdrant";


@@ -0,0 +1,112 @@
import { Mistral } from "@mistralai/mistralai";
import { LLM, LLMResponse } from "./base";
import { LLMConfig, Message } from "../types";

export class MistralLLM implements LLM {
  private client: Mistral;
  private model: string;

  constructor(config: LLMConfig) {
    if (!config.apiKey) {
      throw new Error("Mistral API key is required");
    }
    this.client = new Mistral({
      apiKey: config.apiKey,
    });
    this.model = config.model || "mistral-tiny-latest";
  }

  // Helper function to convert content to string
  private contentToString(content: any): string {
    if (typeof content === "string") {
      return content;
    }
    if (Array.isArray(content)) {
      // Handle ContentChunk array - extract text content
      return content
        .map((chunk) => {
          if (chunk.type === "text") {
            return chunk.text;
          } else {
            return JSON.stringify(chunk);
          }
        })
        .join("");
    }
    return String(content || "");
  }
  async generateResponse(
    messages: Message[],
    responseFormat?: { type: string },
    tools?: any[],
  ): Promise<string | LLMResponse> {
    const response = await this.client.chat.complete({
      model: this.model,
      messages: messages.map((msg) => ({
        role: msg.role as "system" | "user" | "assistant",
        content:
          typeof msg.content === "string"
            ? msg.content
            : JSON.stringify(msg.content),
      })),
      ...(tools && { tools }),
      // The v1 Mistral SDK expects camelCase request fields; a snake_case
      // response_format key would be silently ignored.
      ...(responseFormat && { responseFormat }),
    });

    if (!response || !response.choices || response.choices.length === 0) {
      return "";
    }

    const message = response.choices[0].message;
    if (!message) {
      return "";
    }

    // If the model requested tool calls, surface them alongside any content
    if (message.toolCalls && message.toolCalls.length > 0) {
      return {
        content: this.contentToString(message.content),
        role: message.role || "assistant",
        toolCalls: message.toolCalls.map((call) => ({
          name: call.function.name,
          arguments:
            typeof call.function.arguments === "string"
              ? call.function.arguments
              : JSON.stringify(call.function.arguments),
        })),
      };
    }

    return this.contentToString(message.content);
  }
  async generateChat(messages: Message[]): Promise<LLMResponse> {
    const formattedMessages = messages.map((msg) => ({
      role: msg.role as "system" | "user" | "assistant",
      content:
        typeof msg.content === "string"
          ? msg.content
          : JSON.stringify(msg.content),
    }));

    const response = await this.client.chat.complete({
      model: this.model,
      messages: formattedMessages,
    });

    if (!response || !response.choices || response.choices.length === 0) {
      return {
        content: "",
        role: "assistant",
      };
    }

    const message = response.choices[0].message;
    return {
      content: this.contentToString(message.content),
      role: message.role || "assistant",
    };
  }
}


@@ -4,6 +4,7 @@ import { OpenAILLM } from "../llms/openai";
import { OpenAIStructuredLLM } from "../llms/openai_structured";
import { AnthropicLLM } from "../llms/anthropic";
import { GroqLLM } from "../llms/groq";
import { MistralLLM } from "../llms/mistral";
import { MemoryVectorStore } from "../vector_stores/memory";
import {
  EmbeddingConfig,
@@ -55,6 +56,8 @@ export class LLMFactory {
        return new OllamaLLM(config);
      case "google":
        return new GoogleLLM(config);
      case "mistral":
        return new MistralLLM(config);
      default:
        throw new Error(`Unsupported LLM provider: ${provider}`);
    }
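
For reference, a minimal usage sketch of the new factory branch. It assumes LLMFactory exposes a static create(provider, config) method and that the import path below matches the repo layout; neither is confirmed by this diff, so treat both as assumptions:

import { LLMFactory } from "./utils/factory"; // import path assumed, not shown in the diff

// "mistral" now routes to the new MistralLLM provider via the case added above.
const llm = LLMFactory.create("mistral", {
  apiKey: process.env.MISTRAL_API_KEY!, // assumes the key is set in the environment
  model: "mistral-small-latest",
});

// generateResponse returns a string, or an LLMResponse when tool calls are present.
const answer = await llm.generateResponse([
  { role: "user", content: "What is the capital of France?" },
]);
console.log(answer);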