Update Vercel AI SDK to support tools call (#2383)

This commit is contained in:
Saket Aryan
2025-03-26 10:30:44 +05:30
committed by GitHub
parent 366d263e0b
commit 9d0300f774
28 changed files with 763 additions and 803 deletions

View File

@@ -1,8 +1,6 @@
import { createOpenAI, OpenAIProviderSettings } from "@ai-sdk/openai";
import { generateText as aiGenerateText, streamText as aiStreamText, LanguageModelV1Prompt } from "ai";
import { updateMemories, retrieveMemories, flattenPrompt, convertMessagesToMem0Format } from "./mem0-utils";
import { Mem0Config } from "./mem0-chat-settings";
import { LanguageModelV1, LanguageModelV1CallOptions, LanguageModelV1Prompt } from "ai";
import { Mem0ProviderSettings } from "./mem0-provider";
import { createOpenAI, OpenAIProviderSettings } from "@ai-sdk/openai";
import { CohereProviderSettings, createCohere } from "@ai-sdk/cohere";
import { AnthropicProviderSettings, createAnthropic } from "@ai-sdk/anthropic";
import { createGroq, GroqProviderSettings } from "@ai-sdk/groq";
@@ -10,19 +8,51 @@ import { createGroq, GroqProviderSettings } from "@ai-sdk/groq";
export type Provider = ReturnType<typeof createOpenAI> | ReturnType<typeof createCohere> | ReturnType<typeof createAnthropic> | ReturnType<typeof createGroq> | any;
export type ProviderSettings = OpenAIProviderSettings | CohereProviderSettings | AnthropicProviderSettings | GroqProviderSettings;
class Mem0AITextGenerator {
const convertMessagesToMem0Format = (messages: LanguageModelV1Prompt) => {
return messages.map((message) => {
// If the content is a string, return it as is
if (typeof message.content === "string") {
return message;
}
// Flatten the content array into a single string
if (Array.isArray(message.content)) {
message.content = message.content
.map((contentItem) => {
if ("text" in contentItem) {
return contentItem.text;
}
return "";
})
.join(" ");
}
const contentText = message.content;
return {
role: message.role,
content: contentText,
};
});
}
class Mem0AITextGenerator implements LanguageModelV1 {
readonly specificationVersion = "v1";
readonly defaultObjectGenerationMode = "json";
readonly supportsImageUrls = false;
readonly modelId: string;
provider: Provider;
model: string;
provider_config?: ProviderSettings;
config: Mem0ProviderSettings;
constructor(provider: string, model: string, config: Mem0ProviderSettings, provider_config: ProviderSettings) {
switch (provider) {
constructor(modelId: string, config: Mem0ProviderSettings, provider_config: ProviderSettings) {
switch (config.provider) {
case "openai":
this.provider = createOpenAI({
apiKey: config?.apiKey,
...provider_config,
});
}).languageModel;
if(config?.modelType === "completion"){
this.provider = createOpenAI({
apiKey: config?.apiKey,
@@ -45,7 +75,7 @@ class Mem0AITextGenerator {
this.provider = createAnthropic({
apiKey: config?.apiKey,
...provider_config,
});
}).languageModel;
break;
case "groq":
this.provider = createGroq({
@@ -56,56 +86,18 @@ class Mem0AITextGenerator {
default:
throw new Error("Invalid provider");
}
this.model = model;
this.modelId = modelId;
this.provider_config = provider_config;
this.config = config!;
}
async generateText(prompt: LanguageModelV1Prompt, config: Mem0Config) {
try {
const flattenPromptResponse = flattenPrompt(prompt);
const newPrompt = await retrieveMemories(prompt, config);
const response = await aiGenerateText({
// @ts-ignore
model: this.provider(this.model),
messages: prompt,
system: newPrompt
});
const mem0Prompts = convertMessagesToMem0Format(prompt);
await updateMemories(mem0Prompts as any, config);
return response;
} catch (error) {
console.error("Error generating text:", error);
throw error;
}
doGenerate(options: LanguageModelV1CallOptions): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
return this.provider(this.modelId, this.provider_config).doGenerate(options);
}
async streamText(prompt: LanguageModelV1Prompt, config: Mem0Config) {
try {
const flattenPromptResponse = flattenPrompt(prompt);
const newPrompt = await retrieveMemories(prompt, config);
await updateMemories([
{ role: "user", content: flattenPromptResponse },
{ role: "assistant", content: "Thank You!" },
], config);
const response = await aiStreamText({
// @ts-ignore
model: this.provider(this.model),
messages: prompt,
system: newPrompt
});
return response;
} catch (error) {
console.error("Error generating text:", error);
throw error;
}
doStream(options: LanguageModelV1CallOptions): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
return this.provider(this.modelId, this.provider_config).doStream(options);
}
}