Update Vercel AI SDK to support tools call (#2383)
vercel-ai-sdk/README.md
@@ -231,4 +231,4 @@ We also have support for `agent_id`, `app_id`, and `run_id`. Refer [Docs](https:
 - Requires proper API key configuration for underlying providers (e.g., OpenAI)
 - Memory features depend on proper user identification via `user_id`
 - Supports both streaming and non-streaming responses
-- Compatible with all Vercel AI SDK features and patterns
+- Compatible with all Vercel AI SDK features and patterns
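Editor's note: for context on the README claims above (streaming and non-streaming support), a minimal usage sketch against this package's public API. The model id, user id, and env-var handling are illustrative assumptions, not part of this diff:

```ts
import { generateText, streamText } from "ai";
import { createMem0 } from "@mem0/vercel-ai-provider";

async function main() {
  // Assumes OPENAI_API_KEY and MEM0_API_KEY are set; user_id scopes memories.
  const mem0 = createMem0({ provider: "openai" });

  // Non-streaming
  const { text } = await generateText({
    model: mem0("gpt-4o", { user_id: "alice" }),
    prompt: "Suggest me a good car to buy.",
  });
  console.log(text);

  // Streaming
  const result = streamText({
    model: mem0("gpt-4o", { user_id: "alice" }),
    prompt: "Suggest me a good car to buy.",
  });
  for await (const chunk of result.textStream) process.stdout.write(chunk);
}

main().catch(console.error);
```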
vercel-ai-sdk/package.json
@@ -1,6 +1,6 @@
 {
   "name": "@mem0/vercel-ai-provider",
-  "version": "0.0.10",
+  "version": "0.0.14",
   "description": "Vercel AI Provider for providing memory to LLMs",
   "main": "./dist/index.js",
   "module": "./dist/index.mjs",
@@ -26,16 +26,16 @@
   "author": "Saket Aryan <saketaryan2002@gmail.com>",
   "license": "Apache-2.0",
   "dependencies": {
-    "@ai-sdk/anthropic": "^0.0.54",
-    "@ai-sdk/cohere": "^0.0.28",
-    "@ai-sdk/groq": "^0.0.3",
-    "@ai-sdk/openai": "^0.0.71",
-    "@ai-sdk/provider": "^0.0.26",
-    "@ai-sdk/provider-utils": "^1.0.22",
-    "ai": "^3.4.31",
+    "@ai-sdk/anthropic": "1.1.12",
+    "@ai-sdk/cohere": "1.1.12",
+    "@ai-sdk/groq": "1.1.11",
+    "@ai-sdk/openai": "1.1.15",
+    "@ai-sdk/provider": "1.0.9",
+    "@ai-sdk/provider-utils": "2.1.10",
+    "ai": "4.1.46",
     "dotenv": "^16.4.5",
    "mem0ai": "^1.0.29",
     "partial-json": "0.1.7",
     "ts-node": "^10.9.2",
     "zod": "^3.0.0"
   },
   "devDependencies": {
@@ -45,8 +45,9 @@
     "jest": "^29.7.0",
     "nodemon": "^3.1.7",
     "ts-jest": "^29.2.5",
+    "ts-node": "^10.9.2",
     "tsup": "^8.3.0",
-    "typescript": "5.5.4"
+    "typescript": "^5.5.4"
   },
   "peerDependencies": {
     "zod": "^3.0.0"
vercel-ai-sdk/src/index.ts
@@ -1,4 +1,5 @@
 export * from './mem0-facade'
 export type { Mem0Provider, Mem0ProviderSettings } from './mem0-provider'
 export { createMem0, mem0 } from './mem0-provider'
+export type { Mem0ConfigSettings, Mem0ChatConfig, Mem0ChatSettings } from './mem0-types'
 export {addMemories, retrieveMemories, searchMemories, getMemories } from './mem0-utils'
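Editor's note: the memory helpers exported above can also be used standalone, outside the provider wrapper. A hedged sketch under the config shape shown later in this diff (`Mem0ConfigSettings`); the values are illustrative:

```ts
import { generateText } from "ai";
import { openai } from "@ai-sdk/openai";
import { addMemories, retrieveMemories } from "@mem0/vercel-ai-provider";

async function answerWithMemories(question: string) {
  const config = { user_id: "alice", mem0ApiKey: process.env.MEM0_API_KEY };

  // retrieveMemories returns a system-prompt string built from stored memories.
  const system = await retrieveMemories(question, config);

  const { text } = await generateText({
    model: openai("gpt-4o"),
    system,
    prompt: question,
  });

  // Persist the new user turn so future calls can recall it.
  await addMemories(
    [{ role: "user", content: [{ type: "text", text: question }] }],
    config
  );
  return text;
}
```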
vercel-ai-sdk/src/mem0-chat-language-model.ts (deleted)
@@ -1,150 +0,0 @@
-/* eslint-disable camelcase */
-import {
-  LanguageModelV1,
-  LanguageModelV1CallOptions,
-  LanguageModelV1CallWarning,
-  LanguageModelV1FinishReason,
-  LanguageModelV1FunctionToolCall,
-  LanguageModelV1LogProbs,
-  LanguageModelV1ProviderMetadata,
-  LanguageModelV1StreamPart,
-} from "@ai-sdk/provider";
-
-import { Mem0ChatModelId, Mem0ChatSettings } from "./mem0-chat-settings";
-import { Mem0ClassSelector } from "./mem0-provider-selector";
-import { filterStream } from "./stream-utils";
-import { Mem0Config } from "./mem0-chat-settings";
-import { OpenAIProviderSettings } from "@ai-sdk/openai";
-import { Mem0ProviderSettings } from "./mem0-provider";
-
-interface Mem0ChatConfig {
-  baseURL: string;
-  fetch?: typeof fetch;
-  headers: () => Record<string, string | undefined>;
-  provider: string;
-  organization?: string;
-  project?: string;
-  name?: string;
-  apiKey?: string;
-  mem0_api_key?: string;
-}
-
-export class Mem0ChatLanguageModel implements LanguageModelV1 {
-  readonly specificationVersion = "v1";
-  readonly defaultObjectGenerationMode = "json";
-  readonly supportsImageUrls = false;
-
-  constructor(
-    public readonly modelId: Mem0ChatModelId,
-    public readonly settings: Mem0ChatSettings,
-    public readonly config: Mem0ChatConfig,
-    public readonly provider_config?: OpenAIProviderSettings
-  ) {
-    this.provider = config.provider;
-  }
-
-  provider: string;
-  supportsStructuredOutputs?: boolean | undefined;
-
-  async doGenerate(options: LanguageModelV1CallOptions): Promise<{
-    text?: string;
-    toolCalls?: Array<LanguageModelV1FunctionToolCall>;
-    finishReason: LanguageModelV1FinishReason;
-    usage: { promptTokens: number; completionTokens: number };
-    rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
-    rawResponse?: { headers?: Record<string, string> };
-    response?: { id?: string; timestamp?: Date; modelId?: string };
-    warnings?: LanguageModelV1CallWarning[];
-    providerMetadata?: LanguageModelV1ProviderMetadata;
-    logprobs?: LanguageModelV1LogProbs;
-  }> {
-    try {
-      const provider = this.config.provider;
-      const mem0_api_key = this.config.mem0_api_key;
-      const settings: Mem0ProviderSettings = {
-        provider: provider,
-        mem0ApiKey: mem0_api_key,
-        apiKey: this.config.apiKey,
-        modelType: "chat"
-      }
-      const selector = new Mem0ClassSelector(this.modelId, settings,this.provider_config);
-      let messagesPrompts = options.prompt;
-      const model = selector.createProvider();
-      const user_id = this.settings.user_id;
-      const app_id = this.settings.app_id;
-      const agent_id = this.settings.agent_id;
-      const run_id = this.settings.run_id;
-      const org_name = this.settings.org_name;
-      const project_name = this.settings.project_name;
-      const apiKey = mem0_api_key;
-
-      const config: Mem0Config = {user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey};
-
-      const ans = await model.generateText(messagesPrompts, config);
-
-      return {
-        text: ans.text,
-        finishReason: ans.finishReason,
-        usage: ans.usage,
-        rawCall: {
-          rawPrompt: options.prompt,
-          rawSettings: {},
-        },
-        response: ans.response,
-        warnings: ans.warnings,
-      };
-    } catch (error) {
-      // Handle errors properly
-      console.error("Error in doGenerate:", error);
-      throw new Error("Failed to generate response.");
-    }
-  }
-
-  async doStream(options: LanguageModelV1CallOptions): Promise<{
-    stream: ReadableStream<LanguageModelV1StreamPart>;
-    rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
-    rawResponse?: { headers?: Record<string, string> };
-    warnings?: LanguageModelV1CallWarning[];
-  }> {
-    try {
-      const provider = this.config.provider;
-      const mem0_api_key = this.config.mem0_api_key;
-      const settings: Mem0ProviderSettings = {
-        provider: provider,
-        mem0ApiKey: mem0_api_key,
-        apiKey: this.config.apiKey,
-        modelType: "chat"
-      }
-      const selector = new Mem0ClassSelector(this.modelId, settings,this.provider_config);
-      let messagesPrompts = options.prompt;
-      const model = selector.createProvider();
-      const user_id = this.settings.user_id;
-      const app_id = this.settings.app_id;
-      const agent_id = this.settings.agent_id;
-      const run_id = this.settings.run_id;
-      const org_name = this.settings.org_name;
-      const project_name = this.settings.project_name;
-
-      const apiKey = mem0_api_key;
-
-      const config: Mem0Config = {user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey};
-      const response = await model.streamText(messagesPrompts, config);
-      // @ts-ignore
-      const filteredStream = await filterStream(response.originalStream);
-      return {
-        // @ts-ignore
-        stream: filteredStream,
-        rawCall: {
-          rawPrompt: options.prompt,
-          rawSettings: {},
-        },
-        ...response,
-      };
-    } catch (error) {
-      console.error("Error in doStream:", error);
-      throw new Error("Streaming failed or method not implemented.");
-    }
-  }
-}
vercel-ai-sdk/src/mem0-chat-settings.ts (deleted)
@@ -1,43 +0,0 @@
-import { OpenAIChatSettings } from "@ai-sdk/openai/internal";
-
-export type Mem0ChatModelId =
-  | "o1-preview"
-  | "o1-mini"
-  | "gpt-4o"
-  | "gpt-4o-2024-05-13"
-  | "gpt-4o-2024-08-06"
-  | "gpt-4o-audio-preview"
-  | "gpt-4o-audio-preview-2024-10-01"
-  | "gpt-4o-mini"
-  | "gpt-4o-mini-2024-07-18"
-  | "gpt-4-turbo"
-  | "gpt-4-turbo-2024-04-09"
-  | "gpt-4-turbo-preview"
-  | "gpt-4-0125-preview"
-  | "gpt-4-1106-preview"
-  | "gpt-4"
-  | "gpt-4-0613"
-  | "gpt-3.5-turbo-0125"
-  | "gpt-3.5-turbo"
-  | "gpt-3.5-turbo-1106"
-  | (string & NonNullable<unknown>);
-
-export interface Mem0ChatSettings extends OpenAIChatSettings {
-  user_id?: string;
-  app_id?: string;
-  agent_id?: string;
-  run_id?: string;
-  org_name?: string;
-  project_name?: string;
-  mem0ApiKey?: string;
-  structuredOutputs?: boolean;
-  org_id?: string;
-  project_id?: string;
-  metadata?: Record<string, any>;
-  filters?: Record<string, any>;
-  infer?: boolean;
-  page?: number;
-  page_size?: number;
-}
-
-export interface Mem0Config extends Mem0ChatSettings {}
vercel-ai-sdk/src/mem0-completion-language-model.ts (deleted)
@@ -1,150 +0,0 @@
-/* eslint-disable camelcase */
-import {
-  LanguageModelV1,
-  LanguageModelV1CallOptions,
-  LanguageModelV1CallWarning,
-  LanguageModelV1FinishReason,
-  LanguageModelV1FunctionToolCall,
-  LanguageModelV1LogProbs,
-  LanguageModelV1ProviderMetadata,
-  LanguageModelV1StreamPart,
-} from "@ai-sdk/provider";
-
-import { Mem0ChatModelId, Mem0ChatSettings } from "./mem0-chat-settings";
-import { Mem0ClassSelector } from "./mem0-provider-selector";
-import { filterStream } from "./stream-utils";
-import { Mem0Config } from "./mem0-completion-settings";
-import { OpenAIProviderSettings } from "@ai-sdk/openai";
-import { Mem0ProviderSettings } from "./mem0-provider";
-
-interface Mem0CompletionConfig {
-  baseURL: string;
-  fetch?: typeof fetch;
-  headers: () => Record<string, string | undefined>;
-  provider: string;
-  organization?: string;
-  project?: string;
-  name?: string;
-  apiKey?: string;
-  mem0_api_key?: string;
-}
-
-export class Mem0CompletionLanguageModel implements LanguageModelV1 {
-  readonly specificationVersion = "v1";
-  readonly defaultObjectGenerationMode = "json";
-  readonly supportsImageUrls = false;
-
-  constructor(
-    public readonly modelId: Mem0ChatModelId,
-    public readonly settings: Mem0ChatSettings,
-    public readonly config: Mem0CompletionConfig,
-    public readonly provider_config?: OpenAIProviderSettings
-  ) {
-    this.provider = config.provider;
-  }
-
-  provider: string;
-  supportsStructuredOutputs?: boolean | undefined;
-
-  async doGenerate(options: LanguageModelV1CallOptions): Promise<{
-    text?: string;
-    toolCalls?: Array<LanguageModelV1FunctionToolCall>;
-    finishReason: LanguageModelV1FinishReason;
-    usage: { promptTokens: number; completionTokens: number };
-    rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
-    rawResponse?: { headers?: Record<string, string> };
-    response?: { id?: string; timestamp?: Date; modelId?: string };
-    warnings?: LanguageModelV1CallWarning[];
-    providerMetadata?: LanguageModelV1ProviderMetadata;
-    logprobs?: LanguageModelV1LogProbs;
-  }> {
-    try {
-      const provider = this.config.provider;
-      const mem0_api_key = this.config.mem0_api_key;
-      const settings: Mem0ProviderSettings = {
-        provider: provider,
-        mem0ApiKey: mem0_api_key,
-        apiKey: this.config.apiKey,
-        modelType: "completion"
-      }
-      const selector = new Mem0ClassSelector(this.modelId, settings,this.provider_config);
-      let messagesPrompts = options.prompt;
-      const model = selector.createProvider();
-      const user_id = this.settings.user_id;
-      const app_id = this.settings.app_id;
-      const agent_id = this.settings.agent_id;
-      const run_id = this.settings.run_id;
-      const org_name = this.settings.org_name;
-      const project_name = this.settings.project_name;
-      const apiKey = mem0_api_key;
-
-      const config: Mem0Config = {user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey, modelType: "completion"};
-
-      const ans = await model.generateText(messagesPrompts, config);
-
-      return {
-        text: ans.text,
-        finishReason: ans.finishReason,
-        usage: ans.usage,
-        rawCall: {
-          rawPrompt: options.prompt,
-          rawSettings: {},
-        },
-        response: ans.response,
-        warnings: ans.warnings,
-      };
-    } catch (error) {
-      // Handle errors properly
-      console.error("Error in doGenerate:", error);
-      throw new Error("Failed to generate response.");
-    }
-  }
-
-  async doStream(options: LanguageModelV1CallOptions): Promise<{
-    stream: ReadableStream<LanguageModelV1StreamPart>;
-    rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
-    rawResponse?: { headers?: Record<string, string> };
-    warnings?: LanguageModelV1CallWarning[];
-  }> {
-    try {
-      const provider = this.config.provider;
-      const mem0_api_key = this.config.mem0_api_key;
-      const settings: Mem0ProviderSettings = {
-        provider: provider,
-        mem0ApiKey: mem0_api_key,
-        apiKey: this.config.apiKey,
-        modelType: "completion"
-      }
-      const selector = new Mem0ClassSelector(this.modelId, settings,this.provider_config);
-      let messagesPrompts = options.prompt;
-      const model = selector.createProvider();
-      const user_id = this.settings.user_id;
-      const app_id = this.settings.app_id;
-      const agent_id = this.settings.agent_id;
-      const run_id = this.settings.run_id;
-      const org_name = this.settings.org_name;
-      const project_name = this.settings.project_name;
-
-      const apiKey = mem0_api_key;
-
-      const config: Mem0Config = {user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey, modelType: "completion"};
-      const response = await model.streamText(messagesPrompts, config);
-      // @ts-ignore
-      const filteredStream = await filterStream(response.originalStream);
-      return {
-        // @ts-ignore
-        stream: filteredStream,
-        rawCall: {
-          rawPrompt: options.prompt,
-          rawSettings: {},
-        },
-        ...response,
-      };
-    } catch (error) {
-      console.error("Error in doStream:", error);
-      throw new Error("Streaming failed or method not implemented.");
-    }
-  }
-}
vercel-ai-sdk/src/mem0-completion-settings.ts (deleted)
@@ -1,19 +0,0 @@
-import { OpenAICompletionSettings } from "@ai-sdk/openai/internal";
-
-export type Mem0CompletionModelId =
-  | "gpt-3.5-turbo"
-  | (string & NonNullable<unknown>);
-
-export interface Mem0CompletionSettings extends OpenAICompletionSettings {
-  user_id?: string;
-  app_id?: string;
-  agent_id?: string;
-  run_id?: string;
-  org_name?: string;
-  project_name?: string;
-  mem0ApiKey?: string;
-  structuredOutputs?: boolean;
-  modelType?: string;
-}
-
-export interface Mem0Config extends Mem0CompletionSettings {}
vercel-ai-sdk/src/mem0-facade.ts
@@ -1,13 +1,12 @@
 import { withoutTrailingSlash } from '@ai-sdk/provider-utils'

-import { Mem0ChatLanguageModel } from './mem0-chat-language-model'
-import { Mem0ChatModelId, Mem0ChatSettings } from './mem0-chat-settings'
+import { Mem0GenericLanguageModel } from './mem0-generic-language-model'
+import { Mem0ChatModelId, Mem0ChatSettings } from './mem0-types'
 import { Mem0ProviderSettings } from './mem0-provider'

 export class Mem0 {
   readonly baseURL: string

-  readonly headers?: Record<string, string>
+  readonly headers?: any

   constructor(options: Mem0ProviderSettings = {
     provider: 'openai',
@@ -21,15 +20,22 @@ export class Mem0 {
   private get baseConfig() {
     return {
       baseURL: this.baseURL,
-      headers: () => ({
-        ...this.headers,
-      }),
+      headers: this.headers,
     }
   }

   chat(modelId: Mem0ChatModelId, settings: Mem0ChatSettings = {}) {
-    return new Mem0ChatLanguageModel(modelId, settings, {
+    return new Mem0GenericLanguageModel(modelId, settings, {
       provider: 'openai',
+      modelType: 'chat',
+      ...this.baseConfig,
+    })
+  }
+
+  completion(modelId: Mem0ChatModelId, settings: Mem0ChatSettings = {}) {
+    return new Mem0GenericLanguageModel(modelId, settings, {
+      provider: 'openai',
+      modelType: 'completion',
       ...this.baseConfig,
     })
   }
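Editor's note: after this refactor both facade methods construct the same Mem0GenericLanguageModel and differ only in `modelType`. A small sketch of the resulting call sites; the model ids and user id are placeholders, not taken from this diff:

```ts
import { Mem0 } from "@mem0/vercel-ai-provider";

const facade = new Mem0({ provider: "openai" });

// Both return a Mem0GenericLanguageModel; only modelType differs.
const chatModel = facade.chat("gpt-4o", { user_id: "alice" });
const completionModel = facade.completion("gpt-3.5-turbo-instruct", { user_id: "alice" });
```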
vercel-ai-sdk/src/mem0-generic-language-model.ts
@@ -2,32 +2,16 @@
 import {
   LanguageModelV1,
   LanguageModelV1CallOptions,
   LanguageModelV1CallWarning,
   LanguageModelV1FinishReason,
   LanguageModelV1FunctionToolCall,
   LanguageModelV1LogProbs,
   LanguageModelV1ProviderMetadata,
   LanguageModelV1StreamPart,
+  LanguageModelV1Message,
 } from "@ai-sdk/provider";

-import { Mem0ChatModelId, Mem0ChatSettings } from "./mem0-chat-settings";
+import { Mem0ChatConfig, Mem0ChatModelId, Mem0ChatSettings, Mem0ConfigSettings, Mem0StreamResponse } from "./mem0-types";
 import { Mem0ClassSelector } from "./mem0-provider-selector";
 import { filterStream } from "./stream-utils";
-import { Mem0Config } from "./mem0-chat-settings";
 import { OpenAIProviderSettings } from "@ai-sdk/openai";
 import { Mem0ProviderSettings } from "./mem0-provider";
+import { addMemories, getMemories, retrieveMemories } from "./mem0-utils";

-interface Mem0ChatConfig {
-  baseURL: string;
-  fetch?: typeof fetch;
-  headers: () => Record<string, string | undefined>;
-  provider: string;
-  organization?: string;
-  project?: string;
-  name?: string;
-  apiKey?: string;
-  mem0_api_key?: string;
-}
+const generateRandomId = () => {
+  return Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15);
+}

 export class Mem0GenericLanguageModel implements LanguageModelV1 {
@@ -39,60 +23,119 @@ export class Mem0GenericLanguageModel implements LanguageModelV1 {
     public readonly modelId: Mem0ChatModelId,
     public readonly settings: Mem0ChatSettings,
     public readonly config: Mem0ChatConfig,
-    public readonly provider_config?: OpenAIProviderSettings
+    public readonly provider_config?: Mem0ProviderSettings
   ) {
-    this.provider = config.provider;
+    this.provider = config.provider ?? "openai";
   }

   provider: string;
   supportsStructuredOutputs?: boolean | undefined;

-  async doGenerate(options: LanguageModelV1CallOptions): Promise<{
-    text?: string;
-    toolCalls?: Array<LanguageModelV1FunctionToolCall>;
-    finishReason: LanguageModelV1FinishReason;
-    usage: { promptTokens: number; completionTokens: number };
-    rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
-    rawResponse?: { headers?: Record<string, string> };
-    response?: { id?: string; timestamp?: Date; modelId?: string };
-    warnings?: LanguageModelV1CallWarning[];
-    providerMetadata?: LanguageModelV1ProviderMetadata;
-    logprobs?: LanguageModelV1LogProbs;
-  }> {
+  private async processMemories(messagesPrompts: LanguageModelV1Message[], mem0Config: Mem0ConfigSettings) {
+    const memories = await getMemories(messagesPrompts, mem0Config);
+    // Add New Memories
+    await addMemories(messagesPrompts, mem0Config);
+    const mySystemPrompt = "These are the memories I have stored. Give more weightage to the question by users and try to answer that first. You have to modify your answer based on the memories I have provided. If the memories are irrelevant you can ignore them. Also don't reply to this section of the prompt, or the memories, they are only for your reference. The System prompt starts after text System Message: \n\n";
+
+    let memoriesText = "";
+    try {
+      // @ts-ignore
+      memoriesText = memories.map((memory: any) => {
+        return `Memory: ${memory.memory}\n\n`;
+      }).join("\n\n");
+    } catch(e) {
+      console.error("Error while parsing memories");
+    }
+
+    const memoriesPrompt = `System Message: ${mySystemPrompt} ${memoriesText}`;
+
+    // System Prompt - The memories go as a system prompt
+    const systemPrompt: LanguageModelV1Message = {
+      role: "system",
+      content: memoriesPrompt
+    };
+
+    // Add the system prompt to the beginning of the messages if there are memories
+    if (memories.length > 0) {
+      messagesPrompts.unshift(systemPrompt);
+    }
+
+    return { memories, messagesPrompts };
+  }
+
+  async doGenerate(options: LanguageModelV1CallOptions): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
     try {
       const provider = this.config.provider;
-      const mem0_api_key = this.config.mem0_api_key;
+      const mem0_api_key = this.config.mem0ApiKey;
+
       const settings: Mem0ProviderSettings = {
         provider: provider,
         mem0ApiKey: mem0_api_key,
         apiKey: this.config.apiKey,
       }
-      const selector = new Mem0ClassSelector(this.modelId, settings,this.provider_config);
+
+      const mem0Config: Mem0ConfigSettings = {
+        mem0ApiKey: mem0_api_key,
+        ...this.config.mem0Config,
+        ...this.settings,
+      }
+
+      const selector = new Mem0ClassSelector(this.modelId, settings, this.provider_config);
+
       let messagesPrompts = options.prompt;
+
+      // Process memories and update prompts
+      const { memories, messagesPrompts: updatedPrompts } = await this.processMemories(messagesPrompts, mem0Config);
+
       const model = selector.createProvider();
-      const user_id = this.settings.user_id;
-      const app_id = this.settings.app_id;
-      const agent_id = this.settings.agent_id;
-      const run_id = this.settings.run_id;
-      const org_name = this.settings.org_name;
-      const project_name = this.settings.project_name;
-      const apiKey = mem0_api_key;
-
-      const config: Mem0Config = {user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey};
-
-      const ans = await model.generateText(messagesPrompts, config);
+      const ans = await model.doGenerate({
+        ...options,
+        prompt: updatedPrompts,
+      });
+
+      // If there are no memories, return the original response
+      if (!memories || memories.length === 0) {
+        return ans;
+      }
+
+      // Create sources array with existing sources
+      const sources = [...(ans.sources || [])];
+
+      // Add a combined source with all memories
+      if (Array.isArray(memories) && memories.length > 0) {
+        sources.push({
+          title: "Mem0 Memories",
+          sourceType: "url",
+          id: "mem0-" + generateRandomId(),
+          url: "https://app.mem0.ai",
+          providerMetadata: {
+            mem0: {
+              memories: memories,
+              memoriesText: memories.map((memory: any) => memory.memory).join("\n\n")
+            }
+          }
+        });
+
+        // Add individual memory sources for more detailed information
+        memories.forEach((memory: any) => {
+          sources.push({
+            title: memory.title || "Memory",
+            sourceType: "url",
+            id: "mem0-memory-" + generateRandomId(),
+            url: "https://app.mem0.ai",
+            providerMetadata: {
+              mem0: {
+                memory: memory,
+                memoryText: memory.memory
+              }
+            }
+          });
+        });
+      }

       return {
         text: ans.text,
         finishReason: ans.finishReason,
         usage: ans.usage,
         rawCall: {
           rawPrompt: options.prompt,
           rawSettings: {},
         },
         response: ans.response,
         warnings: ans.warnings,
+        ...ans,
+        sources
       };
     } catch (error) {
       // Handle errors properly
@@ -101,44 +144,108 @@ export class Mem0GenericLanguageModel implements LanguageModelV1 {
     }
   }

-  async doStream(options: LanguageModelV1CallOptions): Promise<{
-    stream: ReadableStream<LanguageModelV1StreamPart>;
-    rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
-    rawResponse?: { headers?: Record<string, string> };
-    warnings?: LanguageModelV1CallWarning[];
-  }> {
+  async doStream(options: LanguageModelV1CallOptions): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
     try {
       const provider = this.config.provider;
-      const mem0_api_key = this.config.mem0_api_key;
+      const mem0_api_key = this.config.mem0ApiKey;
+
       const settings: Mem0ProviderSettings = {
         provider: provider,
         mem0ApiKey: mem0_api_key,
         apiKey: this.config.apiKey,
         modelType: this.config.modelType,
       }
-      const selector = new Mem0ClassSelector(this.modelId, settings,this.provider_config);
+
+      const mem0Config: Mem0ConfigSettings = {
+        mem0ApiKey: mem0_api_key,
+        ...this.config.mem0Config,
+        ...this.settings,
+      }
+
+      const selector = new Mem0ClassSelector(this.modelId, settings, this.provider_config);
+
       let messagesPrompts = options.prompt;
+
+      // Process memories and update prompts
+      const { memories, messagesPrompts: updatedPrompts } = await this.processMemories(messagesPrompts, mem0Config);
+
       const model = selector.createProvider();
-      const user_id = this.settings.user_id;
-      const app_id = this.settings.app_id;
-      const agent_id = this.settings.agent_id;
-      const run_id = this.settings.run_id;
-      const org_name = this.settings.org_name;
-      const project_name = this.settings.project_name;
-
-      const apiKey = mem0_api_key;
+      const streamResponse = await model.doStream({
+        ...options,
+        prompt: updatedPrompts,
+      });

-      const config: Mem0Config = {user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey};
-      const response = await model.streamText(messagesPrompts, config);
-      // @ts-ignore
-      const filteredStream = await filterStream(response.originalStream);
-      return {
-        // @ts-ignore
-        stream: filteredStream,
-        rawCall: {
-          rawPrompt: options.prompt,
-          rawSettings: {},
-        },
-        ...response,
-      };
+      // If there are no memories, return the original stream
+      if (!memories || memories.length === 0) {
+        return streamResponse;
+      }
+
+      // Create a new stream that includes memory sources
+      const originalStream = streamResponse.stream;
+
+      // Create a transform stream that adds memory sources at the beginning
+      const transformStream = new TransformStream({
+        start(controller) {
+          // Add source chunks for each memory at the beginning
+          try {
+            if (Array.isArray(memories) && memories.length > 0) {
+              // Create a single source that contains all memories
+              controller.enqueue({
+                type: 'source',
+                source: {
+                  title: "Mem0 Memories",
+                  sourceType: "url",
+                  id: "mem0-" + generateRandomId(),
+                  url: "https://app.mem0.ai",
+                  providerMetadata: {
+                    mem0: {
+                      memories: memories,
+                      memoriesText: memories.map((memory: any) => memory.memory).join("\n\n")
+                    }
+                  }
+                }
+              });
+
+              // Also add individual memory sources for more detailed information
+              memories.forEach((memory: any) => {
+                controller.enqueue({
+                  type: 'source',
+                  source: {
+                    title: memory.title || "Memory",
+                    sourceType: "url",
+                    id: "mem0-memory-" + generateRandomId(),
+                    url: "https://app.mem0.ai",
+                    providerMetadata: {
+                      mem0: {
+                        memory: memory,
+                        memoryText: memory.memory
+                      }
+                    }
+                  }
+                });
+              });
+            }
+          } catch (error) {
+            console.error("Error adding memory sources:", error);
+          }
+        },
+        transform(chunk, controller) {
+          // Pass through all chunks from the original stream
+          controller.enqueue(chunk);
+        }
+      });
+
+      // Pipe the original stream through our transform stream
+      const enhancedStream = originalStream.pipeThrough(transformStream);
+
+      // Return a new stream response with our enhanced stream
+      return {
+        stream: enhancedStream,
+        rawCall: streamResponse.rawCall,
+        rawResponse: streamResponse.rawResponse,
+        request: streamResponse.request,
+        warnings: streamResponse.warnings
+      };
     } catch (error) {
       console.error("Error in doStream:", error);
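Editor's note: the doStream change above injects memory sources by piping the provider stream through a TransformStream whose start() callback runs before any upstream chunk is forwarded. A self-contained sketch of that pattern, generic over the chunk type rather than the provider's exact types:

```ts
// Prepend synthetic chunks to any ReadableStream without consuming it eagerly.
function withPrependedChunks<T>(source: ReadableStream<T>, extra: T[]): ReadableStream<T> {
  return source.pipeThrough(
    new TransformStream<T, T>({
      start(controller) {
        // start() fires once, before the first upstream chunk arrives.
        for (const chunk of extra) controller.enqueue(chunk);
      },
      transform(chunk, controller) {
        controller.enqueue(chunk); // pass original chunks through unchanged
      },
    })
  );
}
```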
vercel-ai-sdk/src/mem0-provider-selector.ts
@@ -1,11 +1,10 @@
-import { OpenAIProviderSettings } from "@ai-sdk/openai";
 import { Mem0ProviderSettings } from "./mem0-provider";
 import Mem0AITextGenerator, { ProviderSettings } from "./provider-response-provider";
+import { LanguageModelV1 } from "ai";

 class Mem0ClassSelector {
   modelId: string;
-  provider_wrapper: string;
   model: string;
   config: Mem0ProviderSettings;
   provider_config?: ProviderSettings;
   static supportedProviders = ["openai", "anthropic", "cohere", "groq"];
@@ -13,7 +12,6 @@ class Mem0ClassSelector {
   constructor(modelId: string, config: Mem0ProviderSettings, provider_config?: ProviderSettings) {
     this.modelId = modelId;
-    this.provider_wrapper = config.provider || "openai";
     this.model = this.modelId;
     this.provider_config = provider_config;
     if(config) this.config = config;
     else this.config = {
@@ -26,8 +24,8 @@ class Mem0ClassSelector {
     }
   }

-  createProvider() {
-    return new Mem0AITextGenerator(this.provider_wrapper, this.model, this.config , this.provider_config || {});
+  createProvider(): LanguageModelV1 {
+    return new Mem0AITextGenerator(this.modelId, this.config , this.provider_config || {});
   }
 }
vercel-ai-sdk/src/mem0-provider.ts
@@ -1,145 +1,145 @@
-import { LanguageModelV1, ProviderV1 } from '@ai-sdk/provider'
-import { withoutTrailingSlash } from '@ai-sdk/provider-utils'
-
-import { Mem0ChatLanguageModel } from './mem0-chat-language-model'
-import { Mem0ChatModelId, Mem0ChatSettings } from './mem0-chat-settings'
-import { OpenAIProviderSettings } from '@ai-sdk/openai'
-import { Mem0CompletionModelId, Mem0CompletionSettings } from './mem0-completion-settings'
-import { Mem0GenericLanguageModel } from './mem0-generic-language-model'
-import { Mem0CompletionLanguageModel } from './mem0-completion-language-model'
-
+import { LanguageModelV1, ProviderV1 } from "@ai-sdk/provider";
+import { loadApiKey, withoutTrailingSlash } from "@ai-sdk/provider-utils";
+import { Mem0ChatModelId, Mem0ChatSettings, Mem0Config } from "./mem0-types";
+import { OpenAIProviderSettings } from "@ai-sdk/openai";
+import { Mem0GenericLanguageModel } from "./mem0-generic-language-model";
+import { OpenAIChatSettings } from "@ai-sdk/openai/internal";
+import { AnthropicMessagesSettings } from "@ai-sdk/anthropic/internal";
+import { AnthropicProviderSettings } from "@ai-sdk/anthropic";

 export interface Mem0Provider extends ProviderV1 {
-  (modelId: Mem0ChatModelId, settings?: Mem0ChatSettings): LanguageModelV1
-
-  chat(
-    modelId: Mem0ChatModelId,
-    settings?: Mem0ChatSettings,
-  ): LanguageModelV1
+  (modelId: Mem0ChatModelId, settings?: Mem0ChatSettings): LanguageModelV1;
+
+  chat(modelId: Mem0ChatModelId, settings?: Mem0ChatSettings): LanguageModelV1;
+  completion(modelId: Mem0ChatModelId, settings?: Mem0ChatSettings): LanguageModelV1;

   languageModel(
     modelId: Mem0ChatModelId,
-    settings?: Mem0ChatSettings,
-  ): LanguageModelV1
-
-  completion(
-    modelId: Mem0CompletionModelId,
-    settings?: Mem0CompletionSettings,
-  ): LanguageModelV1
+    settings?: Mem0ChatSettings
+  ): LanguageModelV1;
 }

-export interface Mem0ProviderSettings extends OpenAIProviderSettings {
-  baseURL?: string
+export interface Mem0ProviderSettings
+  extends OpenAIChatSettings,
+    AnthropicMessagesSettings {
+  baseURL?: string;
   /**
   * Custom fetch implementation. You can use it as a middleware to intercept
   * requests or to provide a custom fetch implementation for e.g. testing
   */
-  fetch?: typeof fetch
+  fetch?: typeof fetch;
   /**
   * @internal
   */
-  generateId?: () => string
+  generateId?: () => string;
   /**
   * Custom headers to include in the requests.
   */
-  headers?: Record<string, string>
-  organization?: string;
-  project?: string;
+  headers?: Record<string, string>;
   name?: string;
   mem0ApiKey?: string;
   apiKey?: string;
   provider?: string;
-  config?: OpenAIProviderSettings;
   modelType?: "completion" | "chat";
+  mem0Config?: Mem0Config;
+
+  /**
+  * The configuration for the provider.
+  */
+  config?: OpenAIProviderSettings | AnthropicProviderSettings;
 }

 export function createMem0(
   options: Mem0ProviderSettings = {
     provider: "openai",
-  },
+  }
 ): Mem0Provider {
   const baseURL =
-    withoutTrailingSlash(options.baseURL) ?? 'http://127.0.0.1:11434/api'
-
+    withoutTrailingSlash(options.baseURL) ?? "http://api.openai.com";
   const getHeaders = () => ({
     ...options.headers,
-  })
+  });

   const createGenericModel = (
     modelId: Mem0ChatModelId,
-    settings: Mem0ChatSettings = {},
+    settings: Mem0ChatSettings = {}
   ) =>
-    new Mem0GenericLanguageModel(modelId, settings, {
-      baseURL,
-      fetch: options.fetch,
-      headers: getHeaders,
-      provider: options.provider || "openai",
-      organization: options.organization,
-      project: options.project,
-      name: options.name,
-      mem0_api_key: options.mem0ApiKey,
-      apiKey: options.apiKey,
-    }, options.config)
+    new Mem0GenericLanguageModel(
+      modelId,
+      settings,
+      {
+        baseURL,
+        fetch: options.fetch,
+        headers: getHeaders(),
+        provider: options.provider || "openai",
+        name: options.name,
+        mem0ApiKey: options.mem0ApiKey,
+        apiKey: options.apiKey,
+        mem0Config: options.mem0Config,
+      },
+      options.config
+    );
+
+  const createCompletionModel = (
+    modelId: Mem0ChatModelId,
+    settings: Mem0ChatSettings = {}
+  ) =>
+    new Mem0GenericLanguageModel(
+      modelId,
+      settings,
+      {
+        baseURL,
+        fetch: options.fetch,
+        headers: getHeaders(),
+        provider: options.provider || "openai",
+        name: options.name,
+        mem0ApiKey: options.mem0ApiKey,
+        apiKey: options.apiKey,
+        mem0Config: options.mem0Config,
+        modelType: "completion",
+      },
+      options.config
+    );

   const createChatModel = (
     modelId: Mem0ChatModelId,
-    settings: Mem0ChatSettings = {},
+    settings: Mem0ChatSettings = {}
   ) =>
-    new Mem0ChatLanguageModel(modelId, settings, {
-      baseURL,
-      fetch: options.fetch,
-      headers: getHeaders,
-      provider: options.provider || "openai",
-      organization: options.organization,
-      project: options.project,
-      name: options.name,
-      mem0_api_key: options.mem0ApiKey,
-      apiKey: options.apiKey,
-    }, options.config)
-
-  const createCompletionModel = (
-    modelId: Mem0CompletionModelId,
-    settings: Mem0CompletionSettings = {}
-  ) =>
-    new Mem0CompletionLanguageModel(
-      modelId,
-      settings,
-      {
-        baseURL,
-        fetch: options.fetch,
-        headers: getHeaders,
-        provider: options.provider || "openai",
-        organization: options.organization,
-        project: options.project,
-        name: options.name,
-        mem0_api_key: options.mem0ApiKey,
-        apiKey: options.apiKey
-      },
-      options.config
-    );
+    new Mem0GenericLanguageModel(
+      modelId,
+      settings,
+      {
+        baseURL,
+        fetch: options.fetch,
+        headers: getHeaders(),
+        provider: options.provider || "openai",
+        name: options.name,
+        mem0ApiKey: options.mem0ApiKey,
+        apiKey: options.apiKey,
+        mem0Config: options.mem0Config,
+        modelType: "completion",
+      },
+      options.config
+    );

   const provider = function (
     modelId: Mem0ChatModelId,
-    settings?: Mem0ChatSettings,
+    settings: Mem0ChatSettings = {}
   ) {
     if (new.target) {
       throw new Error(
-        'The Mem0 model function cannot be called with the new keyword.',
-      )
+        "The Mem0 model function cannot be called with the new keyword."
+      );
     }

-    return createGenericModel(modelId, settings)
-  }
+    return createGenericModel(modelId, settings);
+  };

-  provider.chat = createChatModel
-  provider.completion = createCompletionModel
-  provider.languageModel = createChatModel
+  provider.languageModel = createGenericModel;
+  provider.completion = createCompletionModel;
+  provider.chat = createChatModel;

-  return provider as unknown as Mem0Provider
+  return provider as unknown as Mem0Provider;
 }

-export const mem0 = createMem0()
+export const mem0 = createMem0();
vercel-ai-sdk/src/mem0-types.ts (new file, 38 lines)
@@ -0,0 +1,38 @@
+import { Mem0ProviderSettings } from "./mem0-provider";
+import { OpenAIChatSettings } from "@ai-sdk/openai/internal";
+import { AnthropicMessagesSettings } from "@ai-sdk/anthropic/internal";
+import {
+  LanguageModelV1,
+  LanguageModelV1CallOptions,
+  LanguageModelV1Message,
+} from "@ai-sdk/provider";
+
+export type Mem0ChatModelId =
+  | (string & NonNullable<unknown>);
+
+export interface Mem0ConfigSettings {
+  user_id?: string;
+  app_id?: string;
+  agent_id?: string;
+  run_id?: string;
+  org_name?: string;
+  project_name?: string;
+  org_id?: string;
+  project_id?: string;
+  metadata?: Record<string, any>;
+  filters?: Record<string, any>;
+  infer?: boolean;
+  page?: number;
+  page_size?: number;
+  mem0ApiKey?: string;
+  top_k?: number;
+}
+
+export interface Mem0ChatConfig extends Mem0ConfigSettings, Mem0ProviderSettings {}
+
+export interface Mem0Config extends Mem0ConfigSettings {}
+export interface Mem0ChatSettings extends OpenAIChatSettings, AnthropicMessagesSettings, Mem0ConfigSettings {}
+
+export interface Mem0StreamResponse extends Awaited<ReturnType<LanguageModelV1['doStream']>> {
+  memories: any;
+}
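Editor's note: a hedged sketch of how the settings type above composes in practice. The field names come from the interface in this diff; every value shown is illustrative, and all fields are optional:

```ts
import type { Mem0ConfigSettings } from "@mem0/vercel-ai-provider";

const config: Mem0ConfigSettings = {
  user_id: "alice",
  org_id: "org_123",             // scopes memories to an organization
  filters: { category: "cars" }, // narrows memory search
  top_k: 5,                      // cap on memories returned by search
  mem0ApiKey: process.env.MEM0_API_KEY,
};
```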
vercel-ai-sdk/src/mem0-utils.ts
@@ -1,16 +1,6 @@
 import { LanguageModelV1Prompt } from 'ai';
-import { Mem0Config } from './mem0-chat-settings';
-if (typeof process !== 'undefined' && process.env && process.env.NODE_ENV !== 'production') {
-  // Dynamically import dotenv only in non-production environments
-  import('dotenv').then((dotenv) => dotenv.config());
-}
-
-const tokenIsPresent = (config?: Mem0Config)=>{
-  if(!config && !config!.mem0ApiKey && (typeof process !== 'undefined' && process.env && !process.env.MEM0_API_KEY)){
-    throw Error("MEM0_API_KEY is not present. Please set env MEM0_API_KEY as the value of your API KEY.");
-  }
-}
-
+import { Mem0ConfigSettings } from './mem0-types';
+import { loadApiKey } from '@ai-sdk/provider-utils';
 interface Message {
   role: string;
   content: string | Array<{type: string, text: string}>;
@@ -50,36 +40,7 @@ const convertToMem0Format = (messages: LanguageModelV1Prompt) => {
   }
 })};

-function convertMessagesToMem0Format(messages: LanguageModelV1Prompt) {
-  return messages.map((message) => {
-    // If the content is a string, return it as is
-    if (typeof message.content === "string") {
-      return message;
-    }
-
-    // Flatten the content array into a single string
-    if (Array.isArray(message.content)) {
-      message.content = message.content
-        .map((contentItem) => {
-          if ("text" in contentItem) {
-            return contentItem.text;
-          }
-          return "";
-        })
-        .join(" ");
-    }
-
-    const contentText = message.content;
-
-    return {
-      role: message.role,
-      content: contentText,
-    };
-  });
-}
-
-const searchInternalMemories = async (query: string, config?: Mem0Config, top_k: number = 5)=> {
-  tokenIsPresent(config);
+const searchInternalMemories = async (query: string, config?: Mem0ConfigSettings, top_k: number = 5)=> {
   const filters = {
     OR: [
       {
@@ -104,16 +65,20 @@ const searchInternalMemories = async (query: string, config?: Mem0Config, top_k:
   }
   const options = {
     method: 'POST',
-    headers: {Authorization: `Token ${(config&&config.mem0ApiKey) || (typeof process !== 'undefined' && process.env && process.env.MEM0_API_KEY) || ""}`, 'Content-Type': 'application/json'},
-    body: JSON.stringify({query, filters, top_k, version: "v2", ...org_project_filters}),
+    // headers: {Authorization: `Token ${(config&&config.mem0ApiKey) || (typeof process !== 'undefined' && process.env && process.env.MEM0_API_KEY) || ""}`, 'Content-Type': 'application/json'},
+    headers: {Authorization: `Token ${loadApiKey({
+      apiKey: (config&&config.mem0ApiKey),
+      environmentVariableName: "MEM0_API_KEY",
+      description: "Mem0",
+    })}`, 'Content-Type': 'application/json'},
+    body: JSON.stringify({query, filters, top_k: config&&config.top_k || top_k, version: "v2", ...org_project_filters}),
   };
   const response = await fetch('https://api.mem0.ai/v2/memories/search/', options);
   const data = await response.json();
   return data;
 }

-const addMemories = async (messages: LanguageModelV1Prompt, config?: Mem0Config)=>{
-  tokenIsPresent(config);
+const addMemories = async (messages: LanguageModelV1Prompt, config?: Mem0ConfigSettings)=>{
   let finalMessages: Array<Message> = [];
   if (typeof messages === "string") {
     finalMessages = [{ role: "user", content: messages }];
@@ -124,11 +89,14 @@ const addMemories = async (messages: LanguageModelV1Prompt, config?: Mem0Config)
   return response;
 }

-const updateMemories = async (messages: Array<Message>, config?: Mem0Config)=>{
-  tokenIsPresent(config);
+const updateMemories = async (messages: Array<Message>, config?: Mem0ConfigSettings)=>{
   const options = {
     method: 'POST',
-    headers: {Authorization: `Token ${(config&&config.mem0ApiKey) || (typeof process !== 'undefined' && process.env && process.env.MEM0_API_KEY) || ""}`, 'Content-Type': 'application/json'},
+    headers: {Authorization: `Token ${loadApiKey({
+      apiKey: (config&&config.mem0ApiKey),
+      environmentVariableName: "MEM0_API_KEY",
+      description: "Mem0",
+    })}`, 'Content-Type': 'application/json'},
     body: JSON.stringify({messages, ...config}),
   };

@@ -137,8 +105,7 @@ const updateMemories = async (messages: Array<Message>, config?: Mem0Config)=>{
   return data;
 }

-const retrieveMemories = async (prompt: LanguageModelV1Prompt | string, config?: Mem0Config)=>{
-  tokenIsPresent(config);
+const retrieveMemories = async (prompt: LanguageModelV1Prompt | string, config?: Mem0ConfigSettings)=>{
   const message = typeof prompt === 'string' ? prompt : flattenPrompt(prompt);
   const systemPrompt = "These are the memories I have stored. Give more weightage to the question by users and try to answer that first. You have to modify your answer based on the memories I have provided. If the memories are irrelevant you can ignore them. Also don't reply to this section of the prompt, or the memories, they are only for your reference. The System prompt starts after text System Message: \n\n";
   const memories = await searchInternalMemories(message, config);
@@ -158,8 +125,7 @@ const retrieveMemories = async (prompt: LanguageModelV1Prompt | string, config?:
   return `System Message: ${systemPrompt} ${memoriesText}`;
 }

-const getMemories = async (prompt: LanguageModelV1Prompt | string, config?: Mem0Config)=>{
-  tokenIsPresent(config);
+const getMemories = async (prompt: LanguageModelV1Prompt | string, config?: Mem0ConfigSettings)=>{
   const message = typeof prompt === 'string' ? prompt : flattenPrompt(prompt);
   let memories = [];
   try{
@@ -172,8 +138,7 @@ const getMemories = async (prompt: LanguageModelV1Prompt | string, config?: Mem0
   return memories;
 }

-const searchMemories = async (prompt: LanguageModelV1Prompt | string, config?: Mem0Config)=>{
-  tokenIsPresent(config);
+const searchMemories = async (prompt: LanguageModelV1Prompt | string, config?: Mem0ConfigSettings)=>{
   const message = typeof prompt === 'string' ? prompt : flattenPrompt(prompt);
   let memories = [];
   try{
@@ -186,4 +151,4 @@ const searchMemories = async (prompt: LanguageModelV1Prompt | string, config?: M
   return memories;
 }

-export {addMemories, updateMemories, retrieveMemories, flattenPrompt, searchMemories, convertMessagesToMem0Format, getMemories};
+export {addMemories, updateMemories, retrieveMemories, flattenPrompt, searchMemories, getMemories};
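Editor's note: the switch to `loadApiKey` above replaces the hand-rolled `process.env` fallback and the `tokenIsPresent` check. A sketch of its contract, using the same call shape as this diff:

```ts
import { loadApiKey } from "@ai-sdk/provider-utils";

// Returns the explicit key when provided, otherwise reads MEM0_API_KEY from
// the environment; throws a descriptive error if neither source is set.
const key = loadApiKey({
  apiKey: undefined, // an explicit key passed in config takes precedence
  environmentVariableName: "MEM0_API_KEY",
  description: "Mem0",
});
```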
vercel-ai-sdk/src/provider-response-provider.ts
@@ -1,8 +1,6 @@
-import { createOpenAI, OpenAIProviderSettings } from "@ai-sdk/openai";
-import { generateText as aiGenerateText, streamText as aiStreamText, LanguageModelV1Prompt } from "ai";
-import { updateMemories, retrieveMemories, flattenPrompt, convertMessagesToMem0Format } from "./mem0-utils";
-import { Mem0Config } from "./mem0-chat-settings";
+import { LanguageModelV1, LanguageModelV1CallOptions, LanguageModelV1Prompt } from "ai";
 import { Mem0ProviderSettings } from "./mem0-provider";
+import { createOpenAI, OpenAIProviderSettings } from "@ai-sdk/openai";
 import { CohereProviderSettings, createCohere } from "@ai-sdk/cohere";
 import { AnthropicProviderSettings, createAnthropic } from "@ai-sdk/anthropic";
 import { createGroq, GroqProviderSettings } from "@ai-sdk/groq";
@@ -10,19 +8,51 @@ import { createGroq, GroqProviderSettings } from "@ai-sdk/groq";
 export type Provider = ReturnType<typeof createOpenAI> | ReturnType<typeof createCohere> | ReturnType<typeof createAnthropic> | ReturnType<typeof createGroq> | any;
 export type ProviderSettings = OpenAIProviderSettings | CohereProviderSettings | AnthropicProviderSettings | GroqProviderSettings;

-class Mem0AITextGenerator {
+const convertMessagesToMem0Format = (messages: LanguageModelV1Prompt) => {
+  return messages.map((message) => {
+    // If the content is a string, return it as is
+    if (typeof message.content === "string") {
+      return message;
+    }
+
+    // Flatten the content array into a single string
+    if (Array.isArray(message.content)) {
+      message.content = message.content
+        .map((contentItem) => {
+          if ("text" in contentItem) {
+            return contentItem.text;
+          }
+          return "";
+        })
+        .join(" ");
+    }
+
+    const contentText = message.content;
+
+    return {
+      role: message.role,
+      content: contentText,
+    };
+  });
+}
+
+class Mem0AITextGenerator implements LanguageModelV1 {
+  readonly specificationVersion = "v1";
+  readonly defaultObjectGenerationMode = "json";
+  readonly supportsImageUrls = false;
+  readonly modelId: string;
+
   provider: Provider;
-  model: string;
   provider_config?: ProviderSettings;
   config: Mem0ProviderSettings;

-  constructor(provider: string, model: string, config: Mem0ProviderSettings, provider_config: ProviderSettings) {
-    switch (provider) {
+  constructor(modelId: string, config: Mem0ProviderSettings, provider_config: ProviderSettings) {
+    switch (config.provider) {
       case "openai":
         this.provider = createOpenAI({
           apiKey: config?.apiKey,
           ...provider_config,
-        });
+        }).languageModel;
+        if(config?.modelType === "completion"){
+          this.provider = createOpenAI({
+            apiKey: config?.apiKey,
@@ -45,7 +75,7 @@ class Mem0AITextGenerator {
         this.provider = createAnthropic({
           apiKey: config?.apiKey,
           ...provider_config,
-        });
+        }).languageModel;
         break;
       case "groq":
         this.provider = createGroq({
@@ -56,56 +86,18 @@ class Mem0AITextGenerator {
       default:
         throw new Error("Invalid provider");
     }
-    this.model = model;
+    this.modelId = modelId;
     this.provider_config = provider_config;
     this.config = config!;
   }

-  async generateText(prompt: LanguageModelV1Prompt, config: Mem0Config) {
-    try {
-      const flattenPromptResponse = flattenPrompt(prompt);
-      const newPrompt = await retrieveMemories(prompt, config);
-      const response = await aiGenerateText({
-        // @ts-ignore
-        model: this.provider(this.model),
-        messages: prompt,
-        system: newPrompt
-      });
-
-      const mem0Prompts = convertMessagesToMem0Format(prompt);
-
-      await updateMemories(mem0Prompts as any, config);
-
-      return response;
-    } catch (error) {
-      console.error("Error generating text:", error);
-      throw error;
-    }
-  }
+  doGenerate(options: LanguageModelV1CallOptions): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
+    return this.provider(this.modelId, this.provider_config).doGenerate(options);
+  }

-  async streamText(prompt: LanguageModelV1Prompt, config: Mem0Config) {
-    try {
-      const flattenPromptResponse = flattenPrompt(prompt);
-      const newPrompt = await retrieveMemories(prompt, config);
-
-      await updateMemories([
-        { role: "user", content: flattenPromptResponse },
-        { role: "assistant", content: "Thank You!" },
-      ], config);
-
-      const response = await aiStreamText({
-        // @ts-ignore
-        model: this.provider(this.model),
-        messages: prompt,
-        system: newPrompt
-      });
-
-      return response;
-    } catch (error) {
-      console.error("Error generating text:", error);
-      throw error;
-    }
-  }
+  doStream(options: LanguageModelV1CallOptions): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
+    return this.provider(this.modelId, this.provider_config).doStream(options);
+  }
 }
@@ -1,110 +0,0 @@
-import dotenv from "dotenv";
-dotenv.config();
-
-import { generateObject } from "ai";
-import { testConfig } from "../config/test-config";
-import { z } from "zod";
-
-interface Provider {
-  name: string;
-  activeModel: string;
-  apiKey: string | undefined;
-}
-
-const provider: Provider = {
-  name: "anthropic",
-  activeModel: "claude-3-5-sonnet-20240620",
-  apiKey: process.env.ANTHROPIC_API_KEY,
-}
-describe("ANTHROPIC Structured Outputs", () => {
-  const { userId } = testConfig;
-  let mem0: ReturnType<typeof testConfig.createTestClient>;
-  jest.setTimeout(30000);
-
-  beforeEach(() => {
-    mem0 = testConfig.createTestClient(provider);
-  });
-
-  describe("ANTHROPIC Object Generation Tests", () => {
-    // Test 1: Generate a car preference object
-    it("should generate a car preference object with name and steps", async () => {
-      const { object } = await generateObject({
-        model: mem0(provider.activeModel, {
-          user_id: userId,
-        }),
-        schema: z.object({
-          car: z.object({
-            name: z.string(),
-            steps: z.array(z.string()),
-          }),
-        }),
-        prompt: "Which car would I like?",
-      });
-
-      expect(object.car).toBeDefined();
-      expect(typeof object.car.name).toBe("string");
-      expect(Array.isArray(object.car.steps)).toBe(true);
-      expect(object.car.steps.every((step) => typeof step === "string")).toBe(true);
-    });
-
-    // Test 2: Generate an array of car objects
-    it("should generate an array of three car objects with name, class, and description", async () => {
-      const { object } = await generateObject({
-        model: mem0(provider.activeModel, {
-          user_id: userId,
-        }),
-        output: "array",
-        schema: z.object({
-          name: z.string(),
-          class: z.string(),
-          description: z.string(),
-        }),
-        prompt: "Write name of three cars that I would like.",
-      });
-
-      expect(Array.isArray(object)).toBe(true);
-      expect(object.length).toBe(3);
-      object.forEach((car) => {
-        expect(car).toHaveProperty("name");
-        expect(typeof car.name).toBe("string");
-        expect(car).toHaveProperty("class");
-        expect(typeof car.class).toBe("string");
-        expect(car).toHaveProperty("description");
-        expect(typeof car.description).toBe("string");
-      });
-    });
-
-    // Test 3: Generate an enum for movie genre classification
-    it("should classify the genre of a movie plot", async () => {
-      const { object } = await generateObject({
-        model: mem0(provider.activeModel, {
-          user_id: userId,
-        }),
-        output: "enum",
-        enum: ["action", "comedy", "drama", "horror", "sci-fi"],
-        prompt: 'Classify the genre of this movie plot: "A group of astronauts travel through a wormhole in search of a new habitable planet for humanity."',
-      });
-
-      expect(object).toBeDefined();
-      expect(object).toBe("sci-fi");
-    });
-
-    // Test 4: Generate an object of car names without schema
-    it("should generate an object with car names", async () => {
-      const { object } = await generateObject({
-        model: mem0(provider.activeModel, {
-          user_id: userId,
-        }),
-        output: "no-schema",
-        prompt: "Write name of 3 cars that I would like.",
-      });
-
-      const carObject = object as { cars: string[] };
-
-      expect(carObject).toBeDefined();
-      expect(Array.isArray(carObject.cars)).toBe(true);
-      expect(carObject.cars.length).toBe(3);
-      expect(carObject.cars.every((car) => typeof car === "string")).toBe(true);
-    });
-  });
-});
59
vercel-ai-sdk/tests/mem0-provider-tests/mem0-cohere.test.ts
Normal file
59
vercel-ai-sdk/tests/mem0-provider-tests/mem0-cohere.test.ts
Normal file
@@ -0,0 +1,59 @@
import dotenv from "dotenv";
dotenv.config();

import { createMem0, retrieveMemories } from "../../src";
import { generateText, LanguageModelV1Prompt } from "ai";
import { testConfig } from "../../config/test-config";
import { createCohere } from "@ai-sdk/cohere";

describe("COHERE MEM0 Tests", () => {
  const { userId } = testConfig;
  jest.setTimeout(30000);
  let mem0: any;

  beforeEach(() => {
    mem0 = createMem0({
      provider: "cohere",
      apiKey: process.env.COHERE_API_KEY,
      mem0Config: {
        user_id: userId
      }
    });
  });

  it("should retrieve memories and generate text using COHERE provider", async () => {
    const messages: LanguageModelV1Prompt = [
      {
        role: "user",
        content: [
          { type: "text", text: "Suggest me a good car to buy." },
          { type: "text", text: " Write only the car name and it's color." },
        ],
      },
    ];

    const { text } = await generateText({
      // @ts-ignore
      model: mem0("command-r-plus"),
      messages: messages
    });

    // Expect text to be a string
    expect(typeof text).toBe('string');
    expect(text.length).toBeGreaterThan(0);
  });

  it("should generate text using COHERE provider with memories", async () => {
    const prompt = "Suggest me a good car to buy.";

    const { text } = await generateText({
      // @ts-ignore
      model: mem0("command-r-plus"),
      prompt: prompt
    });

    expect(typeof text).toBe('string');
    expect(text.length).toBeGreaterThan(0);
  });
});
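Note: `retrieveMemories` is imported above but the test exercises only the provider wrapper. The standalone flow fetches a memory string first and passes it as the `system` prompt of a plain provider model; a hedged sketch of that call shape, mirroring the integration tests later in this diff (the user id is a placeholder):

// --- editor's sketch, not part of the diff ---
// Standalone retrieveMemories + a plain Cohere model.
import { generateText } from "ai";
import { createCohere } from "@ai-sdk/cohere";
import { retrieveMemories } from "../../src";

async function suggestCar(userId: string) {
  const cohere = createCohere({ apiKey: process.env.COHERE_API_KEY });
  const prompt = "Suggest me a good car to buy.";
  // retrieveMemories resolves to a system-prompt string; guard the
  // empty case the same way the updated tests below do.
  const memories = await retrieveMemories(prompt, { user_id: userId });
  return generateText({
    model: cohere("command-r-plus"),
    prompt,
    system: memories.length > 0 ? memories : "No Memories Found",
  });
}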
60 vercel-ai-sdk/tests/mem0-provider-tests/mem0-groq.test.ts Normal file
@@ -0,0 +1,60 @@
import dotenv from "dotenv";
dotenv.config();

import { createMem0, retrieveMemories } from "../../src";
import { generateText, LanguageModelV1Prompt } from "ai";
import { testConfig } from "../../config/test-config";
import { createGroq } from "@ai-sdk/groq";

describe("GROQ MEM0 Tests", () => {
  const { userId } = testConfig;
  jest.setTimeout(30000);

  let mem0: any;

  beforeEach(() => {
    mem0 = createMem0({
      provider: "groq",
      apiKey: process.env.GROQ_API_KEY,
      mem0Config: {
        user_id: userId
      }
    });
  });

  it("should retrieve memories and generate text using GROQ provider", async () => {
    const messages: LanguageModelV1Prompt = [
      {
        role: "user",
        content: [
          { type: "text", text: "Suggest me a good car to buy." },
          { type: "text", text: " Write only the car name and it's color." },
        ],
      },
    ];

    const { text } = await generateText({
      // @ts-ignore
      model: mem0("llama3-8b-8192"),
      messages: messages
    });

    // Expect text to be a string
    expect(typeof text).toBe('string');
    expect(text.length).toBeGreaterThan(0);
  });

  it("should generate text using GROQ provider with memories", async () => {
    const prompt = "Suggest me a good car to buy.";

    const { text } = await generateText({
      // @ts-ignore
      model: mem0("llama3-8b-8192"),
      prompt: prompt
    });

    expect(typeof text).toBe('string');
    expect(text.length).toBeGreaterThan(0);
  });
});
@@ -2,7 +2,7 @@ import dotenv from "dotenv";
dotenv.config();

import { generateObject } from "ai";
import { testConfig } from "../config/test-config";
import { testConfig } from "../../config/test-config";
import { z } from "zod";

interface Provider {
55 vercel-ai-sdk/tests/mem0-provider-tests/mem0-openai.test.ts Normal file
@@ -0,0 +1,55 @@
import dotenv from "dotenv";
dotenv.config();

import { createMem0 } from "../../src";
import { generateText, LanguageModelV1Prompt } from "ai";
import { testConfig } from "../../config/test-config";

describe("OPENAI MEM0 Tests", () => {
  const { userId } = testConfig;
  jest.setTimeout(30000);
  let mem0: any;

  beforeEach(() => {
    mem0 = createMem0({
      provider: "openai",
      apiKey: process.env.OPENAI_API_KEY,
      mem0Config: {
        user_id: userId
      }
    });
  });

  it("should retrieve memories and generate text using Mem0 OpenAI provider", async () => {
    const messages: LanguageModelV1Prompt = [
      {
        role: "user",
        content: [
          { type: "text", text: "Suggest me a good car to buy." },
          { type: "text", text: " Write only the car name and it's color." },
        ],
      },
    ];

    const { text } = await generateText({
      model: mem0("gpt-4-turbo"),
      messages: messages
    });

    // Expect text to be a string
    expect(typeof text).toBe('string');
    expect(text.length).toBeGreaterThan(0);
  });

  it("should generate text using openai provider with memories", async () => {
    const prompt = "Suggest me a good car to buy.";

    const { text } = await generateText({
      model: mem0("gpt-4-turbo"),
      prompt: prompt
    });

    expect(typeof text).toBe('string');
    expect(text.length).toBeGreaterThan(0);
  });
});
@@ -0,0 +1,59 @@
import dotenv from "dotenv";
dotenv.config();

import { createMem0, retrieveMemories } from "../../src";
import { generateText, LanguageModelV1Prompt } from "ai";
import { testConfig } from "../../config/test-config";
import { createAnthropic } from "@ai-sdk/anthropic";

describe("ANTHROPIC MEM0 Tests", () => {
  const { userId } = testConfig;
  jest.setTimeout(30000);

  let mem0: any;

  beforeEach(() => {
    mem0 = createMem0({
      provider: "anthropic",
      apiKey: process.env.ANTHROPIC_API_KEY,
      mem0Config: {
        user_id: userId
      }
    });
  });

  it("should retrieve memories and generate text using ANTHROPIC provider", async () => {
    const messages: LanguageModelV1Prompt = [
      {
        role: "user",
        content: [
          { type: "text", text: "Suggest me a good car to buy." },
          { type: "text", text: " Write only the car name and it's color." },
        ],
      },
    ];

    const { text } = await generateText({
      // @ts-ignore
      model: mem0("claude-3-haiku-20240307"),
      messages: messages,
    });

    // Expect text to be a string
    expect(typeof text).toBe('string');
    expect(text.length).toBeGreaterThan(0);
  });

  it("should generate text using ANTHROPIC provider with memories", async () => {
    const prompt = "Suggest me a good car to buy.";

    const { text } = await generateText({
      // @ts-ignore
      model: mem0("claude-3-haiku-20240307"),
      prompt: prompt,
    });

    expect(typeof text).toBe('string');
    expect(text.length).toBeGreaterThan(0);
  });
});
91 vercel-ai-sdk/tests/mem0-toolcalls.test.ts Normal file
@@ -0,0 +1,91 @@
import dotenv from "dotenv";
dotenv.config();

import { addMemories, createMem0 } from "../src";
import { generateText, tool } from "ai";
import { testConfig } from "../config/test-config";
import { z } from "zod";

describe("Tool Calls Tests", () => {
  const { userId } = testConfig;
  jest.setTimeout(30000);

  beforeEach(async () => {
    await addMemories([{
      role: "user",
      content: [{ type: "text", text: "I live in Mumbai" }],
    }], { user_id: userId });
  });

  it("should Execute a Tool Call Using OpenAI", async () => {
    const mem0OpenAI = createMem0({
      provider: "openai",
      apiKey: process.env.OPENAI_API_KEY,
      mem0Config: {
        user_id: userId,
      },
    });

    const result = await generateText({
      model: mem0OpenAI("gpt-4o"),
      tools: {
        weather: tool({
          description: "Get the weather in a location",
          parameters: z.object({
            location: z
              .string()
              .describe("The location to get the weather for"),
          }),
          execute: async ({ location }) => ({
            location,
            temperature: 72 + Math.floor(Math.random() * 21) - 10,
          }),
        }),
      },
      prompt: "What is the temperature in the city that I live in?",
    });

    // @ts-ignore
    const text = result.response.messages[1].content[0].result.location;

    // Expect text to be a string
    expect(typeof text).toBe("string");
    expect(text.length).toBeGreaterThan(0);
  });

  it("should Execute a Tool Call Using Anthropic", async () => {
    const mem0Anthropic = createMem0({
      provider: "anthropic",
      apiKey: process.env.ANTHROPIC_API_KEY,
      mem0Config: {
        user_id: userId,
      },
    });

    const result = await generateText({
      model: mem0Anthropic("claude-3-haiku-20240307"),
      tools: {
        weather: tool({
          description: "Get the weather in a location",
          parameters: z.object({
            location: z
              .string()
              .describe("The location to get the weather for"),
          }),
          execute: async ({ location }) => ({
            location,
            temperature: 72 + Math.floor(Math.random() * 21) - 10,
          }),
        }),
      },
      prompt: "What is the temperature in the city that I live in?",
    });

    // @ts-ignore
    const text = result.response.messages[1].content[0].result.location;

    // Expect text to be a string
    expect(typeof text).toBe("string");
    expect(text.length).toBeGreaterThan(0);
  });
});
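One observation on the extraction in both tool-call tests: `result.response.messages[1].content[0].result.location` depends on exact message ordering, hence the `@ts-ignore`. A hedged alternative sketch using the `toolResults` array that `generateText` exposes in ai v4 (illustration, not what this commit ships; `result` is the same value as in the tests above):

// --- editor's sketch, not part of the diff ---
// Read executed tool outputs from result.toolResults instead of
// indexing into raw response messages.
const weatherResult = result.toolResults.find(
  (toolResult) => toolResult.toolName === "weather",
);
expect(weatherResult).toBeDefined();
// `.result` holds each tool's execute() return value.
expect(typeof weatherResult?.result.location).toBe("string");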
@@ -4,7 +4,7 @@ import { testConfig } from "../config/test-config";

describe("Memory Core Functions", () => {
  const { userId } = testConfig;
  jest.setTimeout(10000);
  jest.setTimeout(20000);

  describe("addMemories", () => {
    it("should successfully add memories and return correct format", async () => {
@@ -57,7 +57,7 @@ describe.each(testConfig.providers)('TEXT/STREAM PROPERTIES: Tests with model %s
      text, // combined text
      usage, // combined usage of all steps
    } = await generateText({
      model: mem0(provider.activeModel), // Ensure the model name is correct
      model: mem0.completion(provider.activeModel), // Ensure the model name is correct
      maxSteps: 5, // Enable multi-step calls
      experimental_continueSteps: true,
      prompt:
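For context on the `maxSteps` line above: multi-step calls let `generateText` loop model turns (and tool executions) until the model finishes or hits the cap, and the combined `text`/`usage` destructured here aggregate across those turns. A hedged sketch of inspecting the per-step results via the `steps` field of the ai v4 result (provider setup and model id are placeholders):

// --- editor's sketch, not part of the diff ---
// Run a multi-step generation and inspect each step.
import { generateText } from "ai";
import { createMem0 } from "@mem0/vercel-ai-provider";

async function multiStepDemo() {
  const mem0 = createMem0({ provider: "openai" }); // placeholder setup
  const result = await generateText({
    model: mem0.completion("gpt-4o"), // placeholder model id
    maxSteps: 5, // cap the number of model turns
    experimental_continueSteps: true,
    prompt: "Write a long story, continuing across steps if needed.",
  });
  for (const step of result.steps) {
    // Each step records its own text and finish reason.
    console.log(step.finishReason, step.text.slice(0, 80));
  }
}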
@@ -1,12 +1,12 @@
import dotenv from "dotenv";
dotenv.config();

import { retrieveMemories } from "../src";
import { retrieveMemories } from "../../src";
import { generateText, LanguageModelV1Prompt } from "ai";
import { testConfig } from "../config/test-config";
import { testConfig } from "../../config/test-config";
import { createAnthropic } from "@ai-sdk/anthropic";

describe("ANTHROPIC Functions", () => {
describe("ANTHROPIC Integration Tests", () => {
  const { userId } = testConfig;
  jest.setTimeout(30000);
@@ -36,7 +36,7 @@ describe("ANTHROPIC Functions", () => {
      // @ts-ignore
      model: anthropic("claude-3-haiku-20240307"),
      messages: messages,
      system: memories,
      system: memories.length > 0 ? memories : "No Memories Found"
    });

    // Expect text to be a string
@@ -52,7 +52,7 @@ describe("ANTHROPIC Functions", () => {
      // @ts-ignore
      model: anthropic("claude-3-haiku-20240307"),
      prompt: prompt,
      system: memories
      system: memories.length > 0 ? memories : "No Memories Found"
    });

    expect(typeof text).toBe('string');
@@ -1,12 +1,12 @@
import dotenv from "dotenv";
dotenv.config();

import { retrieveMemories } from "../src";
import { retrieveMemories } from "../../src";
import { generateText, LanguageModelV1Prompt } from "ai";
import { testConfig } from "../config/test-config";
import { testConfig } from "../../config/test-config";
import { createCohere } from "@ai-sdk/cohere";

describe("COHERE Functions", () => {
describe("COHERE Integration Tests", () => {
  const { userId } = testConfig;
  jest.setTimeout(30000);
  let cohere: any;
@@ -1,12 +1,12 @@
import dotenv from "dotenv";
dotenv.config();

import { retrieveMemories } from "../src";
import { retrieveMemories } from "../../src";
import { generateText, LanguageModelV1Prompt } from "ai";
import { testConfig } from "../config/test-config";
import { testConfig } from "../../config/test-config";
import { createGroq } from "@ai-sdk/groq";

describe("GROQ Functions", () => {
describe("GROQ Integration Tests", () => {
  const { userId } = testConfig;
  jest.setTimeout(30000);
@@ -34,7 +34,7 @@ describe("GROQ Functions", () => {

    const { text } = await generateText({
      // @ts-ignore
      model: groq("gemma2-9b-it"),
      model: groq("llama3-8b-8192"),
      messages: messages,
      system: memories,
    });
@@ -50,7 +50,7 @@ describe("GROQ Functions", () => {

    const { text } = await generateText({
      // @ts-ignore
      model: groq("gemma2-9b-it"),
      model: groq("llama3-8b-8192"),
      prompt: prompt,
      system: memories
    });
@@ -1,12 +1,12 @@
import dotenv from "dotenv";
dotenv.config();

import { retrieveMemories } from "../src";
import { retrieveMemories } from "../../src";
import { generateText, LanguageModelV1Prompt } from "ai";
import { testConfig } from "../config/test-config";
import { testConfig } from "../../config/test-config";
import { createOpenAI } from "@ai-sdk/openai";

describe("OPENAI Functions", () => {
describe("OPENAI Integration Tests", () => {
  const { userId } = testConfig;
  jest.setTimeout(30000);
  let openai: any;
@@ -8,7 +8,7 @@
    "forceConsistentCasingInFileNames": true,
    "inlineSources": false,
    "isolatedModules": true,
    "moduleResolution": "node",
    "moduleResolution": "node16",
    "noUnusedLocals": false,
    "noUnusedParameters": false,
    "preserveWatchOutput": true,
@@ -17,7 +17,7 @@
    "types": ["@types/node", "jest"],
    "jsx": "react-jsx",
    "lib": ["dom", "ES2021"],
    "module": "ESNext",
    "module": "Node16",
    "target": "ES2018",
    "stripInternal": true,
    "paths": {