Update Vercel AI SDK to support tools call (#2383)
This commit is contained in:
@@ -2,32 +2,16 @@
|
||||
import {
|
||||
LanguageModelV1,
|
||||
LanguageModelV1CallOptions,
|
||||
LanguageModelV1CallWarning,
|
||||
LanguageModelV1FinishReason,
|
||||
LanguageModelV1FunctionToolCall,
|
||||
LanguageModelV1LogProbs,
|
||||
LanguageModelV1ProviderMetadata,
|
||||
LanguageModelV1StreamPart,
|
||||
LanguageModelV1Message,
|
||||
} from "@ai-sdk/provider";
|
||||
|
||||
import { Mem0ChatModelId, Mem0ChatSettings } from "./mem0-chat-settings";
|
||||
import { Mem0ChatConfig, Mem0ChatModelId, Mem0ChatSettings, Mem0ConfigSettings, Mem0StreamResponse } from "./mem0-types";
|
||||
import { Mem0ClassSelector } from "./mem0-provider-selector";
|
||||
import { filterStream } from "./stream-utils";
|
||||
import { Mem0Config } from "./mem0-chat-settings";
|
||||
import { OpenAIProviderSettings } from "@ai-sdk/openai";
|
||||
import { Mem0ProviderSettings } from "./mem0-provider";
|
||||
import { addMemories, getMemories, retrieveMemories } from "./mem0-utils";
|
||||
|
||||
|
||||
interface Mem0ChatConfig {
|
||||
baseURL: string;
|
||||
fetch?: typeof fetch;
|
||||
headers: () => Record<string, string | undefined>;
|
||||
provider: string;
|
||||
organization?: string;
|
||||
project?: string;
|
||||
name?: string;
|
||||
apiKey?: string;
|
||||
mem0_api_key?: string;
|
||||
const generateRandomId = () => {
|
||||
return Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15);
|
||||
}
|
||||
|
||||
export class Mem0GenericLanguageModel implements LanguageModelV1 {
|
||||
@@ -39,60 +23,119 @@ export class Mem0GenericLanguageModel implements LanguageModelV1 {
|
||||
public readonly modelId: Mem0ChatModelId,
|
||||
public readonly settings: Mem0ChatSettings,
|
||||
public readonly config: Mem0ChatConfig,
|
||||
public readonly provider_config?: OpenAIProviderSettings
|
||||
public readonly provider_config?: Mem0ProviderSettings
|
||||
) {
|
||||
this.provider = config.provider;
|
||||
this.provider = config.provider ?? "openai";
|
||||
}
|
||||
|
||||
provider: string;
|
||||
supportsStructuredOutputs?: boolean | undefined;
|
||||
|
||||
async doGenerate(options: LanguageModelV1CallOptions): Promise<{
|
||||
text?: string;
|
||||
toolCalls?: Array<LanguageModelV1FunctionToolCall>;
|
||||
finishReason: LanguageModelV1FinishReason;
|
||||
usage: { promptTokens: number; completionTokens: number };
|
||||
rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
|
||||
rawResponse?: { headers?: Record<string, string> };
|
||||
response?: { id?: string; timestamp?: Date; modelId?: string };
|
||||
warnings?: LanguageModelV1CallWarning[];
|
||||
providerMetadata?: LanguageModelV1ProviderMetadata;
|
||||
logprobs?: LanguageModelV1LogProbs;
|
||||
}> {
|
||||
private async processMemories(messagesPrompts: LanguageModelV1Message[], mem0Config: Mem0ConfigSettings) {
|
||||
const memories = await getMemories(messagesPrompts, mem0Config);
|
||||
// Add New Memories
|
||||
await addMemories(messagesPrompts, mem0Config);
|
||||
const mySystemPrompt = "These are the memories I have stored. Give more weightage to the question by users and try to answer that first. You have to modify your answer based on the memories I have provided. If the memories are irrelevant you can ignore them. Also don't reply to this section of the prompt, or the memories, they are only for your reference. The System prompt starts after text System Message: \n\n";
|
||||
|
||||
let memoriesText = "";
|
||||
try {
|
||||
// @ts-ignore
|
||||
memoriesText = memories.map((memory: any) => {
|
||||
return `Memory: ${memory.memory}\n\n`;
|
||||
}).join("\n\n");
|
||||
} catch(e) {
|
||||
console.error("Error while parsing memories");
|
||||
}
|
||||
|
||||
const memoriesPrompt = `System Message: ${mySystemPrompt} ${memoriesText}`;
|
||||
|
||||
// System Prompt - The memories go as a system prompt
|
||||
const systemPrompt: LanguageModelV1Message = {
|
||||
role: "system",
|
||||
content: memoriesPrompt
|
||||
};
|
||||
|
||||
// Add the system prompt to the beginning of the messages if there are memories
|
||||
if (memories.length > 0) {
|
||||
messagesPrompts.unshift(systemPrompt);
|
||||
}
|
||||
|
||||
return { memories, messagesPrompts };
|
||||
}
|
||||
|
||||
async doGenerate(options: LanguageModelV1CallOptions): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
|
||||
try {
|
||||
const provider = this.config.provider;
|
||||
const mem0_api_key = this.config.mem0_api_key;
|
||||
const mem0_api_key = this.config.mem0ApiKey;
|
||||
|
||||
const settings: Mem0ProviderSettings = {
|
||||
provider: provider,
|
||||
mem0ApiKey: mem0_api_key,
|
||||
apiKey: this.config.apiKey,
|
||||
}
|
||||
const selector = new Mem0ClassSelector(this.modelId, settings,this.provider_config);
|
||||
|
||||
const mem0Config: Mem0ConfigSettings = {
|
||||
mem0ApiKey: mem0_api_key,
|
||||
...this.config.mem0Config,
|
||||
...this.settings,
|
||||
}
|
||||
|
||||
const selector = new Mem0ClassSelector(this.modelId, settings, this.provider_config);
|
||||
|
||||
let messagesPrompts = options.prompt;
|
||||
|
||||
// Process memories and update prompts
|
||||
const { memories, messagesPrompts: updatedPrompts } = await this.processMemories(messagesPrompts, mem0Config);
|
||||
|
||||
const model = selector.createProvider();
|
||||
const user_id = this.settings.user_id;
|
||||
const app_id = this.settings.app_id;
|
||||
const agent_id = this.settings.agent_id;
|
||||
const run_id = this.settings.run_id;
|
||||
const org_name = this.settings.org_name;
|
||||
const project_name = this.settings.project_name;
|
||||
const apiKey = mem0_api_key;
|
||||
|
||||
const config: Mem0Config = {user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey};
|
||||
|
||||
const ans = await model.generateText(messagesPrompts, config);
|
||||
const ans = await model.doGenerate({
|
||||
...options,
|
||||
prompt: updatedPrompts,
|
||||
});
|
||||
|
||||
// If there are no memories, return the original response
|
||||
if (!memories || memories.length === 0) {
|
||||
return ans;
|
||||
}
|
||||
|
||||
// Create sources array with existing sources
|
||||
const sources = [...(ans.sources || [])];
|
||||
|
||||
// Add a combined source with all memories
|
||||
if (Array.isArray(memories) && memories.length > 0) {
|
||||
sources.push({
|
||||
title: "Mem0 Memories",
|
||||
sourceType: "url",
|
||||
id: "mem0-" + generateRandomId(),
|
||||
url: "https://app.mem0.ai",
|
||||
providerMetadata: {
|
||||
mem0: {
|
||||
memories: memories,
|
||||
memoriesText: memories.map((memory: any) => memory.memory).join("\n\n")
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Add individual memory sources for more detailed information
|
||||
memories.forEach((memory: any) => {
|
||||
sources.push({
|
||||
title: memory.title || "Memory",
|
||||
sourceType: "url",
|
||||
id: "mem0-memory-" + generateRandomId(),
|
||||
url: "https://app.mem0.ai",
|
||||
providerMetadata: {
|
||||
mem0: {
|
||||
memory: memory,
|
||||
memoryText: memory.memory
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
return {
|
||||
text: ans.text,
|
||||
finishReason: ans.finishReason,
|
||||
usage: ans.usage,
|
||||
rawCall: {
|
||||
rawPrompt: options.prompt,
|
||||
rawSettings: {},
|
||||
},
|
||||
response: ans.response,
|
||||
warnings: ans.warnings,
|
||||
...ans,
|
||||
sources
|
||||
};
|
||||
} catch (error) {
|
||||
// Handle errors properly
|
||||
@@ -101,44 +144,108 @@ export class Mem0GenericLanguageModel implements LanguageModelV1 {
|
||||
}
|
||||
}
|
||||
|
||||
async doStream(options: LanguageModelV1CallOptions): Promise<{
|
||||
stream: ReadableStream<LanguageModelV1StreamPart>;
|
||||
rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
|
||||
rawResponse?: { headers?: Record<string, string> };
|
||||
warnings?: LanguageModelV1CallWarning[];
|
||||
}> {
|
||||
async doStream(options: LanguageModelV1CallOptions): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
|
||||
try {
|
||||
const provider = this.config.provider;
|
||||
const mem0_api_key = this.config.mem0_api_key;
|
||||
const mem0_api_key = this.config.mem0ApiKey;
|
||||
|
||||
const settings: Mem0ProviderSettings = {
|
||||
provider: provider,
|
||||
mem0ApiKey: mem0_api_key,
|
||||
apiKey: this.config.apiKey,
|
||||
modelType: this.config.modelType,
|
||||
}
|
||||
const selector = new Mem0ClassSelector(this.modelId, settings,this.provider_config);
|
||||
|
||||
const mem0Config: Mem0ConfigSettings = {
|
||||
mem0ApiKey: mem0_api_key,
|
||||
...this.config.mem0Config,
|
||||
...this.settings,
|
||||
}
|
||||
|
||||
const selector = new Mem0ClassSelector(this.modelId, settings, this.provider_config);
|
||||
|
||||
let messagesPrompts = options.prompt;
|
||||
|
||||
// Process memories and update prompts
|
||||
const { memories, messagesPrompts: updatedPrompts } = await this.processMemories(messagesPrompts, mem0Config);
|
||||
|
||||
const model = selector.createProvider();
|
||||
const user_id = this.settings.user_id;
|
||||
const app_id = this.settings.app_id;
|
||||
const agent_id = this.settings.agent_id;
|
||||
const run_id = this.settings.run_id;
|
||||
const org_name = this.settings.org_name;
|
||||
const project_name = this.settings.project_name;
|
||||
|
||||
const apiKey = mem0_api_key;
|
||||
const streamResponse = await model.doStream({
|
||||
...options,
|
||||
prompt: updatedPrompts,
|
||||
});
|
||||
|
||||
const config: Mem0Config = {user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey};
|
||||
const response = await model.streamText(messagesPrompts, config);
|
||||
// @ts-ignore
|
||||
const filteredStream = await filterStream(response.originalStream);
|
||||
return {
|
||||
// @ts-ignore
|
||||
stream: filteredStream,
|
||||
rawCall: {
|
||||
rawPrompt: options.prompt,
|
||||
rawSettings: {},
|
||||
// If there are no memories, return the original stream
|
||||
if (!memories || memories.length === 0) {
|
||||
return streamResponse;
|
||||
}
|
||||
|
||||
// Create a new stream that includes memory sources
|
||||
const originalStream = streamResponse.stream;
|
||||
|
||||
// Create a transform stream that adds memory sources at the beginning
|
||||
const transformStream = new TransformStream({
|
||||
start(controller) {
|
||||
// Add source chunks for each memory at the beginning
|
||||
try {
|
||||
if (Array.isArray(memories) && memories.length > 0) {
|
||||
// Create a single source that contains all memories
|
||||
controller.enqueue({
|
||||
type: 'source',
|
||||
source: {
|
||||
title: "Mem0 Memories",
|
||||
sourceType: "url",
|
||||
id: "mem0-" + generateRandomId(),
|
||||
url: "https://app.mem0.ai",
|
||||
providerMetadata: {
|
||||
mem0: {
|
||||
memories: memories,
|
||||
memoriesText: memories.map((memory: any) => memory.memory).join("\n\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Also add individual memory sources for more detailed information
|
||||
memories.forEach((memory: any) => {
|
||||
controller.enqueue({
|
||||
type: 'source',
|
||||
source: {
|
||||
title: memory.title || "Memory",
|
||||
sourceType: "url",
|
||||
id: "mem0-memory-" + generateRandomId(),
|
||||
url: "https://app.mem0.ai",
|
||||
providerMetadata: {
|
||||
mem0: {
|
||||
memory: memory,
|
||||
memoryText: memory.memory
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error adding memory sources:", error);
|
||||
}
|
||||
},
|
||||
...response,
|
||||
transform(chunk, controller) {
|
||||
// Pass through all chunks from the original stream
|
||||
controller.enqueue(chunk);
|
||||
}
|
||||
});
|
||||
|
||||
// Pipe the original stream through our transform stream
|
||||
const enhancedStream = originalStream.pipeThrough(transformStream);
|
||||
|
||||
// Return a new stream response with our enhanced stream
|
||||
return {
|
||||
stream: enhancedStream,
|
||||
rawCall: streamResponse.rawCall,
|
||||
rawResponse: streamResponse.rawResponse,
|
||||
request: streamResponse.request,
|
||||
warnings: streamResponse.warnings
|
||||
};
|
||||
} catch (error) {
|
||||
console.error("Error in doStream:", error);
|
||||
|
||||
Reference in New Issue
Block a user