(Feature) Vercel AI SDK (#2024)

This commit is contained in:
Saket Aryan
2024-11-19 23:53:58 +05:30
committed by GitHub
parent a02597ed59
commit 13374a12e9
70 changed files with 4074 additions and 0 deletions

2
vercel-ai-sdk/.gitattributes vendored Normal file
View File

@@ -0,0 +1,2 @@
# Auto detect text files and perform LF normalization
* text=auto

10
vercel-ai-sdk/.gitignore vendored Normal file
View File

@@ -0,0 +1,10 @@
**/.env
**/node_modules
**/.DS_Store
# Ignore test-related files
**/coverage.data
**/coverage/
# Build files
**/dist

228
vercel-ai-sdk/README.md Normal file
View File

@@ -0,0 +1,228 @@
# Mem0 AI SDK Provider
The **Mem0 AI SDK Provider** is a community-maintained library developed by [Mem0](https://mem0.ai/) to integrate with the Vercel AI SDK. This library brings enhanced AI interaction capabilities to your applications by introducing persistent memory functionality. With Mem0, language model conversations gain memory, enabling more contextualized and personalized responses based on past interactions.
Discover more of **Mem0** on [GitHub](https://github.com/mem0ai).
Explore the [Mem0 Documentation](https://docs.mem0.ai/overview) to gain deeper control and flexibility in managing your memories.
For detailed information on using the Vercel AI SDK, refer to Vercels [API Reference](https://sdk.vercel.ai/docs/reference) and [Documentation](https://sdk.vercel.ai/docs).
## Features
- 🧠 Persistent memory storage for AI conversations
- 🔄 Seamless integration with Vercel AI SDK
- 🚀 Support for multiple LLM providers
- 📝 Rich message format support
- ⚡ Streaming capabilities
- 🔍 Context-aware responses
## Installation
```bash
npm install @mem0/vercel-ai-provider
```
## Before We Begin
### Setting Up Mem0
1. Obtain your [Mem0 API Key](https://app.mem0.ai/dashboard/api-keys) from the Mem0 dashboard.
2. Initialize the Mem0 Client:
```typescript
import { createMem0 } from "@mem0/vercel-ai-provider";
const mem0 = createMem0({
provider: "openai",
mem0ApiKey: "m0-xxx",
apiKey: "openai-api-key",
config: {
compatibility: "strict",
// Additional model-specific configuration options can be added here.
},
});
```
### Note
By default, the `openai` provider is used, so specifying it is optional:
```typescript
const mem0 = createMem0();
```
For better security, consider setting `MEM0_API_KEY` and `OPENAI_API_KEY` as environment variables.
3. Add Memories to Enhance Context:
```typescript
import { LanguageModelV1Prompt } from "ai";
import { addMemories } from "@mem0/vercel-ai-provider";
const messages: LanguageModelV1Prompt = [
{
role: "user",
content: [
{ type: "text", text: "I love red cars." },
{ type: "text", text: "I like Toyota Cars." },
{ type: "text", text: "I prefer SUVs." },
],
},
];
await addMemories(messages, { user_id: "borat" });
```
These memories are now stored in your profile. You can view and manage them on the [Mem0 Dashboard](https://app.mem0.ai/dashboard/users).
### Note:
For standalone features, such as `addMemories` and `retrieveMemories`,
you must either set `MEM0_API_KEY` as an environment variable or pass it directly in the function call.
Example:
```typescript
await addMemories(messages, { user_id: "borat", mem0ApiKey: "m0-xxx" });
await retrieveMemories(prompt, { user_id: "borat", mem0ApiKey: "m0-xxx" });
```
## Usage Examples
### 1. Basic Text Generation with Memory Context
```typescript
import { generateText } from "ai";
import { createMem0 } from "@mem0/vercel-ai-provider";
const mem0 = createMem0();
const { text } = await generateText({
model: mem0("gpt-4-turbo", {
user_id: "borat",
}),
prompt: "Suggest me a good car to buy!",
});
```
### 2. Combining OpenAI Provider with Memory Utils
```typescript
import { generateText } from "ai";
import { openai } from "@ai-sdk/openai";
import { retrieveMemories } from "@mem0/vercel-ai-provider";
const prompt = "Suggest me a good car to buy.";
const memories = await retrieveMemories(prompt, { user_id: "borat" });
const { text } = await generateText({
model: openai("gpt-4-turbo"),
prompt: prompt,
system: memories,
});
```
### 3. Structured Message Format with Memory
```typescript
import { generateText } from "ai";
import { createMem0 } from "@mem0/vercel-ai-provider";
const mem0 = createMem0();
const { text } = await generateText({
model: mem0("gpt-4-turbo", {
user_id: "borat",
}),
messages: [
{
role: "user",
content: [
{ type: "text", text: "Suggest me a good car to buy." },
{ type: "text", text: "Why is it better than the other cars for me?" },
{ type: "text", text: "Give options for every price range." },
],
},
],
});
```
### 4. Advanced Memory Integration with OpenAI
```typescript
import { generateText, LanguageModelV1Prompt } from "ai";
import { openai } from "@ai-sdk/openai";
import { retrieveMemories } from "@mem0/vercel-ai-provider";
// New format using system parameter for memory context
const messages: LanguageModelV1Prompt = [
{
role: "user",
content: [
{ type: "text", text: "Suggest me a good car to buy." },
{ type: "text", text: "Why is it better than the other cars for me?" },
{ type: "text", text: "Give options for every price range." },
],
},
];
const memories = await retrieveMemories(messages, { user_id: "borat" });
const { text } = await generateText({
model: openai("gpt-4-turbo"),
messages: messages,
system: memories,
});
```
### 5. Streaming Responses with Memory Context
```typescript
import { streamText } from "ai";
import { createMem0 } from "@mem0/vercel-ai-provider";
const mem0 = createMem0();
const { textStream } = await streamText({
model: mem0("gpt-4-turbo", {
user_id: "borat",
}),
prompt:
"Suggest me a good car to buy! Why is it better than the other cars for me? Give options for every price range.",
});
for await (const textPart of textStream) {
process.stdout.write(textPart);
}
```
## Core Functions
- `createMem0()`: Initializes a new mem0 provider instance with optional configuration
- `retrieveMemories()`: Enriches prompts with relevant memories
- `addMemories()`: Add memories to your profile
## Configuration Options
```typescript
const mem0 = createMem0({
config: {
...
// Additional model-specific configuration options can be added here.
},
});
```
## Best Practices
1. **User Identification**: Always provide a unique `user_id` identifier for consistent memory retrieval
2. **Context Management**: Use appropriate context window sizes to balance performance and memory
3. **Error Handling**: Implement proper error handling for memory operations
4. **Memory Cleanup**: Regularly clean up unused memory contexts to optimize performance
We also have support for `agent_id`, `app_id`, and `run_id`. Refer [Docs](https://docs.mem0.ai/api-reference/memory/add-memories).
## Notes
- Requires proper API key configuration for underlying providers (e.g., OpenAI)
- Memory features depend on proper user identification via `user_id`
- Supports both streaming and non-streaming responses
- Compatible with all Vercel AI SDK features and patterns

View File

@@ -0,0 +1,105 @@
import dotenv from "dotenv";
import { createMem0 } from "../src";
dotenv.config();
export interface Provider {
name: string;
activeModel: string;
apiKey: string | undefined;
}
export const testConfig = {
apiKey: process.env.MEM0_API_KEY,
userId: "mem0-ai-sdk-test-user-1134774",
deleteId: "",
providers: [
{
name: "openai",
activeModel: "gpt-4-turbo",
apiKey: process.env.OPENAI_API_KEY,
}
,
{
name: "anthropic",
activeModel: "claude-3-5-sonnet-20240620",
apiKey: process.env.ANTHROPIC_API_KEY,
},
// {
// name: "groq",
// activeModel: "gemma2-9b-it",
// apiKey: process.env.GROQ_API_KEY,
// },
{
name: "cohere",
activeModel: "command-r-plus",
apiKey: process.env.COHERE_API_KEY,
}
],
models: {
openai: "gpt-4-turbo",
anthropic: "claude-3-haiku-20240307",
groq: "gemma2-9b-it",
cohere: "command-r-plus"
},
apiKeys: {
openai: process.env.OPENAI_API_KEY,
anthropic: process.env.ANTHROPIC_API_KEY,
groq: process.env.GROQ_API_KEY,
cohere: process.env.COHERE_API_KEY,
},
createTestClient: (provider: Provider) => {
return createMem0({
provider: provider.name,
mem0ApiKey: process.env.MEM0_API_KEY,
apiKey: provider.apiKey,
});
},
fetchDeleteId: async function () {
const options = {
method: 'GET',
headers: {
Authorization: `Token ${this.apiKey}`,
},
};
try {
const response = await fetch('https://api.mem0.ai/v1/entities/', options);
const data = await response.json();
const entity = data.results.find((item: any) => item.name === this.userId);
if (entity) {
this.deleteId = entity.id;
} else {
console.error("No matching entity found for userId:", this.userId);
}
} catch (error) {
console.error("Error fetching deleteId:", error);
throw error;
}
},
deleteUser: async function () {
if (!this.deleteId) {
console.error("deleteId is not set. Ensure fetchDeleteId is called first.");
return;
}
const options = {
method: 'DELETE',
headers: {
Authorization: `Token ${this.apiKey}`,
},
};
try {
const response = await fetch(`https://api.mem0.ai/v1/entities/user/${this.deleteId}/`, options);
if (!response.ok) {
throw new Error(`Failed to delete user: ${response.statusText}`);
}
await response.json();
} catch (error) {
console.error("Error deleting user:", error);
throw error;
}
},
};

View File

@@ -0,0 +1,6 @@
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
globalTeardown: './teardown.ts',
};

View File

@@ -0,0 +1,5 @@
{
"watch": ["src"],
"ext": ".ts,.js",
"exec": "ts-node ./example/index.ts"
}

View File

@@ -0,0 +1,69 @@
{
"name": "@mem0/vercel-ai-provider",
"version": "0.0.7",
"description": "Vercel AI Provider for providing memory to LLMs",
"main": "./dist/index.js",
"module": "./dist/index.mjs",
"types": "./dist/index.d.ts",
"files": [
"dist/**/*"
],
"scripts": {
"build": "tsup",
"clean": "rm -rf dist",
"dev": "nodemon",
"lint": "eslint \"./**/*.ts*\"",
"type-check": "tsc --noEmit",
"prettier-check": "prettier --check \"./**/*.ts*\"",
"test": "jest",
"test:edge": "vitest --config vitest.edge.config.js --run",
"test:node": "vitest --config vitest.node.config.js --run"
},
"keywords": [
"ai",
"vercel-ai"
],
"author": "Saket Aryan <saketaryan2002@gmail.com>",
"license": "Apache-2.0",
"dependencies": {
"@ai-sdk/anthropic": "^0.0.54",
"@ai-sdk/cohere": "^0.0.28",
"@ai-sdk/groq": "^0.0.3",
"@ai-sdk/openai": "^0.0.71",
"@ai-sdk/provider": "^0.0.26",
"@ai-sdk/provider-utils": "^1.0.22",
"ai": "^3.4.31",
"dotenv": "^16.4.5",
"partial-json": "0.1.7",
"ts-node": "^10.9.2",
"zod": "^3.0.0"
},
"devDependencies": {
"@edge-runtime/vm": "^3.2.0",
"@types/jest": "^29.5.14",
"@types/node": "^18.19.46",
"jest": "^29.7.0",
"nodemon": "^3.1.7",
"ts-jest": "^29.2.5",
"tsup": "^8.3.0",
"typescript": "5.5.4"
},
"peerDependencies": {
"zod": "^3.0.0"
},
"peerDependenciesMeta": {
"zod": {
"optional": true
}
},
"engines": {
"node": ">=18"
},
"publishConfig": {
"access": "public"
},
"directories": {
"example": "example",
"test": "tests"
}
}

View File

@@ -0,0 +1,4 @@
export * from './mem0-facade'
export type { Mem0Provider, Mem0ProviderSettings } from './mem0-provider'
export { createMem0, mem0 } from './mem0-provider'
export {addMemories, retrieveMemories, searchMemories } from './mem0-utils'

View File

@@ -0,0 +1,150 @@
/* eslint-disable camelcase */
import {
LanguageModelV1,
LanguageModelV1CallOptions,
LanguageModelV1CallWarning,
LanguageModelV1FinishReason,
LanguageModelV1FunctionToolCall,
LanguageModelV1LogProbs,
LanguageModelV1ProviderMetadata,
LanguageModelV1StreamPart,
} from "@ai-sdk/provider";
import { Mem0ChatModelId, Mem0ChatSettings } from "./mem0-chat-settings";
import { Mem0ClassSelector } from "./mem0-provider-selector";
import { filterStream } from "./stream-utils";
import { Mem0Config } from "./mem0-chat-settings";
import { OpenAIProviderSettings } from "@ai-sdk/openai";
import { Mem0ProviderSettings } from "./mem0-provider";
interface Mem0ChatConfig {
baseURL: string;
fetch?: typeof fetch;
headers: () => Record<string, string | undefined>;
provider: string;
organization?: string;
project?: string;
name?: string;
apiKey?: string;
mem0_api_key?: string;
}
export class Mem0ChatLanguageModel implements LanguageModelV1 {
readonly specificationVersion = "v1";
readonly defaultObjectGenerationMode = "json";
readonly supportsImageUrls = false;
constructor(
public readonly modelId: Mem0ChatModelId,
public readonly settings: Mem0ChatSettings,
public readonly config: Mem0ChatConfig,
public readonly provider_config?: OpenAIProviderSettings
) {
this.provider = config.provider;
}
provider: string;
supportsStructuredOutputs?: boolean | undefined;
async doGenerate(options: LanguageModelV1CallOptions): Promise<{
text?: string;
toolCalls?: Array<LanguageModelV1FunctionToolCall>;
finishReason: LanguageModelV1FinishReason;
usage: { promptTokens: number; completionTokens: number };
rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
rawResponse?: { headers?: Record<string, string> };
response?: { id?: string; timestamp?: Date; modelId?: string };
warnings?: LanguageModelV1CallWarning[];
providerMetadata?: LanguageModelV1ProviderMetadata;
logprobs?: LanguageModelV1LogProbs;
}> {
try {
const provider = this.config.provider;
const mem0_api_key = this.config.mem0_api_key;
const settings: Mem0ProviderSettings = {
provider: provider,
mem0ApiKey: mem0_api_key,
apiKey: this.config.apiKey,
modelType: "chat"
}
const selector = new Mem0ClassSelector(this.modelId, settings,this.provider_config);
let messagesPrompts = options.prompt;
const model = selector.createProvider();
const user_id = this.settings.user_id;
const app_id = this.settings.app_id;
const agent_id = this.settings.agent_id;
const run_id = this.settings.run_id;
const org_name = this.settings.org_name;
const project_name = this.settings.project_name;
const apiKey = mem0_api_key;
const config: Mem0Config = {user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey};
const ans = await model.generateText(messagesPrompts, config);
return {
text: ans.text,
finishReason: ans.finishReason,
usage: ans.usage,
rawCall: {
rawPrompt: options.prompt,
rawSettings: {},
},
response: ans.response,
warnings: ans.warnings,
};
} catch (error) {
// Handle errors properly
console.error("Error in doGenerate:", error);
throw new Error("Failed to generate response.");
}
}
async doStream(options: LanguageModelV1CallOptions): Promise<{
stream: ReadableStream<LanguageModelV1StreamPart>;
rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
rawResponse?: { headers?: Record<string, string> };
warnings?: LanguageModelV1CallWarning[];
}> {
try {
const provider = this.config.provider;
const mem0_api_key = this.config.mem0_api_key;
const settings: Mem0ProviderSettings = {
provider: provider,
mem0ApiKey: mem0_api_key,
apiKey: this.config.apiKey,
modelType: "chat"
}
const selector = new Mem0ClassSelector(this.modelId, settings,this.provider_config);
let messagesPrompts = options.prompt;
const model = selector.createProvider();
const user_id = this.settings.user_id;
const app_id = this.settings.app_id;
const agent_id = this.settings.agent_id;
const run_id = this.settings.run_id;
const org_name = this.settings.org_name;
const project_name = this.settings.project_name;
const apiKey = mem0_api_key;
const config: Mem0Config = {user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey};
const response = await model.streamText(messagesPrompts, config);
// @ts-ignore
const filteredStream = await filterStream(response.originalStream);
return {
// @ts-ignore
stream: filteredStream,
rawCall: {
rawPrompt: options.prompt,
rawSettings: {},
},
...response,
};
} catch (error) {
console.error("Error in doStream:", error);
throw new Error("Streaming failed or method not implemented.");
}
}
}

View File

@@ -0,0 +1,36 @@
import { OpenAIChatSettings } from "@ai-sdk/openai/internal";
export type Mem0ChatModelId =
| "o1-preview"
| "o1-mini"
| "gpt-4o"
| "gpt-4o-2024-05-13"
| "gpt-4o-2024-08-06"
| "gpt-4o-audio-preview"
| "gpt-4o-audio-preview-2024-10-01"
| "gpt-4o-mini"
| "gpt-4o-mini-2024-07-18"
| "gpt-4-turbo"
| "gpt-4-turbo-2024-04-09"
| "gpt-4-turbo-preview"
| "gpt-4-0125-preview"
| "gpt-4-1106-preview"
| "gpt-4"
| "gpt-4-0613"
| "gpt-3.5-turbo-0125"
| "gpt-3.5-turbo"
| "gpt-3.5-turbo-1106"
| (string & NonNullable<unknown>);
export interface Mem0ChatSettings extends OpenAIChatSettings {
user_id?: string;
app_id?: string;
agent_id?: string;
run_id?: string;
org_name?: string;
project_name?: string;
mem0ApiKey?: string;
structuredOutputs?: boolean;
}
export interface Mem0Config extends Mem0ChatSettings {}

View File

@@ -0,0 +1,150 @@
/* eslint-disable camelcase */
import {
LanguageModelV1,
LanguageModelV1CallOptions,
LanguageModelV1CallWarning,
LanguageModelV1FinishReason,
LanguageModelV1FunctionToolCall,
LanguageModelV1LogProbs,
LanguageModelV1ProviderMetadata,
LanguageModelV1StreamPart,
} from "@ai-sdk/provider";
import { Mem0ChatModelId, Mem0ChatSettings } from "./mem0-chat-settings";
import { Mem0ClassSelector } from "./mem0-provider-selector";
import { filterStream } from "./stream-utils";
import { Mem0Config } from "./mem0-completion-settings";
import { OpenAIProviderSettings } from "@ai-sdk/openai";
import { Mem0ProviderSettings } from "./mem0-provider";
interface Mem0CompletionConfig {
baseURL: string;
fetch?: typeof fetch;
headers: () => Record<string, string | undefined>;
provider: string;
organization?: string;
project?: string;
name?: string;
apiKey?: string;
mem0_api_key?: string;
}
export class Mem0CompletionLanguageModel implements LanguageModelV1 {
readonly specificationVersion = "v1";
readonly defaultObjectGenerationMode = "json";
readonly supportsImageUrls = false;
constructor(
public readonly modelId: Mem0ChatModelId,
public readonly settings: Mem0ChatSettings,
public readonly config: Mem0CompletionConfig,
public readonly provider_config?: OpenAIProviderSettings
) {
this.provider = config.provider;
}
provider: string;
supportsStructuredOutputs?: boolean | undefined;
async doGenerate(options: LanguageModelV1CallOptions): Promise<{
text?: string;
toolCalls?: Array<LanguageModelV1FunctionToolCall>;
finishReason: LanguageModelV1FinishReason;
usage: { promptTokens: number; completionTokens: number };
rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
rawResponse?: { headers?: Record<string, string> };
response?: { id?: string; timestamp?: Date; modelId?: string };
warnings?: LanguageModelV1CallWarning[];
providerMetadata?: LanguageModelV1ProviderMetadata;
logprobs?: LanguageModelV1LogProbs;
}> {
try {
const provider = this.config.provider;
const mem0_api_key = this.config.mem0_api_key;
const settings: Mem0ProviderSettings = {
provider: provider,
mem0ApiKey: mem0_api_key,
apiKey: this.config.apiKey,
modelType: "completion"
}
const selector = new Mem0ClassSelector(this.modelId, settings,this.provider_config);
let messagesPrompts = options.prompt;
const model = selector.createProvider();
const user_id = this.settings.user_id;
const app_id = this.settings.app_id;
const agent_id = this.settings.agent_id;
const run_id = this.settings.run_id;
const org_name = this.settings.org_name;
const project_name = this.settings.project_name;
const apiKey = mem0_api_key;
const config: Mem0Config = {user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey, modelType: "completion"};
const ans = await model.generateText(messagesPrompts, config);
return {
text: ans.text,
finishReason: ans.finishReason,
usage: ans.usage,
rawCall: {
rawPrompt: options.prompt,
rawSettings: {},
},
response: ans.response,
warnings: ans.warnings,
};
} catch (error) {
// Handle errors properly
console.error("Error in doGenerate:", error);
throw new Error("Failed to generate response.");
}
}
async doStream(options: LanguageModelV1CallOptions): Promise<{
stream: ReadableStream<LanguageModelV1StreamPart>;
rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
rawResponse?: { headers?: Record<string, string> };
warnings?: LanguageModelV1CallWarning[];
}> {
try {
const provider = this.config.provider;
const mem0_api_key = this.config.mem0_api_key;
const settings: Mem0ProviderSettings = {
provider: provider,
mem0ApiKey: mem0_api_key,
apiKey: this.config.apiKey,
modelType: "completion"
}
const selector = new Mem0ClassSelector(this.modelId, settings,this.provider_config);
let messagesPrompts = options.prompt;
const model = selector.createProvider();
const user_id = this.settings.user_id;
const app_id = this.settings.app_id;
const agent_id = this.settings.agent_id;
const run_id = this.settings.run_id;
const org_name = this.settings.org_name;
const project_name = this.settings.project_name;
const apiKey = mem0_api_key;
const config: Mem0Config = {user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey, modelType: "completion"};
const response = await model.streamText(messagesPrompts, config);
// @ts-ignore
const filteredStream = await filterStream(response.originalStream);
return {
// @ts-ignore
stream: filteredStream,
rawCall: {
rawPrompt: options.prompt,
rawSettings: {},
},
...response,
};
} catch (error) {
console.error("Error in doStream:", error);
throw new Error("Streaming failed or method not implemented.");
}
}
}

View File

@@ -0,0 +1,19 @@
import { OpenAICompletionSettings } from "@ai-sdk/openai/internal";
export type Mem0CompletionModelId =
| "gpt-3.5-turbo"
| (string & NonNullable<unknown>);
export interface Mem0CompletionSettings extends OpenAICompletionSettings {
user_id?: string;
app_id?: string;
agent_id?: string;
run_id?: string;
org_name?: string;
project_name?: string;
mem0ApiKey?: string;
structuredOutputs?: boolean;
modelType?: string;
}
export interface Mem0Config extends Mem0CompletionSettings {}

View File

@@ -0,0 +1,36 @@
import { withoutTrailingSlash } from '@ai-sdk/provider-utils'
import { Mem0ChatLanguageModel } from './mem0-chat-language-model'
import { Mem0ChatModelId, Mem0ChatSettings } from './mem0-chat-settings'
import { Mem0ProviderSettings } from './mem0-provider'
export class Mem0 {
readonly baseURL: string
readonly headers?: Record<string, string>
constructor(options: Mem0ProviderSettings = {
provider: 'openai',
}) {
this.baseURL =
withoutTrailingSlash(options.baseURL) ?? 'http://127.0.0.1:11434/api'
this.headers = options.headers
}
private get baseConfig() {
return {
baseURL: this.baseURL,
headers: () => ({
...this.headers,
}),
}
}
chat(modelId: Mem0ChatModelId, settings: Mem0ChatSettings = {}) {
return new Mem0ChatLanguageModel(modelId, settings, {
provider: 'openai',
...this.baseConfig,
})
}
}

View File

@@ -0,0 +1,148 @@
/* eslint-disable camelcase */
import {
LanguageModelV1,
LanguageModelV1CallOptions,
LanguageModelV1CallWarning,
LanguageModelV1FinishReason,
LanguageModelV1FunctionToolCall,
LanguageModelV1LogProbs,
LanguageModelV1ProviderMetadata,
LanguageModelV1StreamPart,
} from "@ai-sdk/provider";
import { Mem0ChatModelId, Mem0ChatSettings } from "./mem0-chat-settings";
import { Mem0ClassSelector } from "./mem0-provider-selector";
import { filterStream } from "./stream-utils";
import { Mem0Config } from "./mem0-chat-settings";
import { OpenAIProviderSettings } from "@ai-sdk/openai";
import { Mem0ProviderSettings } from "./mem0-provider";
interface Mem0ChatConfig {
baseURL: string;
fetch?: typeof fetch;
headers: () => Record<string, string | undefined>;
provider: string;
organization?: string;
project?: string;
name?: string;
apiKey?: string;
mem0_api_key?: string;
}
export class Mem0GenericLanguageModel implements LanguageModelV1 {
readonly specificationVersion = "v1";
readonly defaultObjectGenerationMode = "json";
readonly supportsImageUrls = false;
constructor(
public readonly modelId: Mem0ChatModelId,
public readonly settings: Mem0ChatSettings,
public readonly config: Mem0ChatConfig,
public readonly provider_config?: OpenAIProviderSettings
) {
this.provider = config.provider;
}
provider: string;
supportsStructuredOutputs?: boolean | undefined;
async doGenerate(options: LanguageModelV1CallOptions): Promise<{
text?: string;
toolCalls?: Array<LanguageModelV1FunctionToolCall>;
finishReason: LanguageModelV1FinishReason;
usage: { promptTokens: number; completionTokens: number };
rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
rawResponse?: { headers?: Record<string, string> };
response?: { id?: string; timestamp?: Date; modelId?: string };
warnings?: LanguageModelV1CallWarning[];
providerMetadata?: LanguageModelV1ProviderMetadata;
logprobs?: LanguageModelV1LogProbs;
}> {
try {
const provider = this.config.provider;
const mem0_api_key = this.config.mem0_api_key;
const settings: Mem0ProviderSettings = {
provider: provider,
mem0ApiKey: mem0_api_key,
apiKey: this.config.apiKey,
}
const selector = new Mem0ClassSelector(this.modelId, settings,this.provider_config);
let messagesPrompts = options.prompt;
const model = selector.createProvider();
const user_id = this.settings.user_id;
const app_id = this.settings.app_id;
const agent_id = this.settings.agent_id;
const run_id = this.settings.run_id;
const org_name = this.settings.org_name;
const project_name = this.settings.project_name;
const apiKey = mem0_api_key;
const config: Mem0Config = {user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey};
const ans = await model.generateText(messagesPrompts, config);
return {
text: ans.text,
finishReason: ans.finishReason,
usage: ans.usage,
rawCall: {
rawPrompt: options.prompt,
rawSettings: {},
},
response: ans.response,
warnings: ans.warnings,
};
} catch (error) {
// Handle errors properly
console.error("Error in doGenerate:", error);
throw new Error("Failed to generate response.");
}
}
async doStream(options: LanguageModelV1CallOptions): Promise<{
stream: ReadableStream<LanguageModelV1StreamPart>;
rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
rawResponse?: { headers?: Record<string, string> };
warnings?: LanguageModelV1CallWarning[];
}> {
try {
const provider = this.config.provider;
const mem0_api_key = this.config.mem0_api_key;
const settings: Mem0ProviderSettings = {
provider: provider,
mem0ApiKey: mem0_api_key,
apiKey: this.config.apiKey,
}
const selector = new Mem0ClassSelector(this.modelId, settings,this.provider_config);
let messagesPrompts = options.prompt;
const model = selector.createProvider();
const user_id = this.settings.user_id;
const app_id = this.settings.app_id;
const agent_id = this.settings.agent_id;
const run_id = this.settings.run_id;
const org_name = this.settings.org_name;
const project_name = this.settings.project_name;
const apiKey = mem0_api_key;
const config: Mem0Config = {user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey};
const response = await model.streamText(messagesPrompts, config);
// @ts-ignore
const filteredStream = await filterStream(response.originalStream);
return {
// @ts-ignore
stream: filteredStream,
rawCall: {
rawPrompt: options.prompt,
rawSettings: {},
},
...response,
};
} catch (error) {
console.error("Error in doStream:", error);
throw new Error("Streaming failed or method not implemented.");
}
}
}

View File

@@ -0,0 +1,34 @@
import { OpenAIProviderSettings } from "@ai-sdk/openai";
import { Mem0ProviderSettings } from "./mem0-provider";
import Mem0AITextGenerator, { ProviderSettings } from "./provider-response-provider";
class Mem0ClassSelector {
modelId: string;
provider_wrapper: string;
model: string;
config: Mem0ProviderSettings;
provider_config?: ProviderSettings;
static supportedProviders = ["openai", "anthropic", "cohere", "groq"];
constructor(modelId: string, config: Mem0ProviderSettings, provider_config?: ProviderSettings) {
this.modelId = modelId;
this.provider_wrapper = config.provider || "openai";
this.model = this.modelId;
this.provider_config = provider_config;
if(config) this.config = config;
else this.config = {
provider: this.provider_wrapper,
};
// Check if provider_wrapper is supported
if (!Mem0ClassSelector.supportedProviders.includes(this.provider_wrapper)) {
throw new Error(`Model not supported: ${this.provider_wrapper}`);
}
}
createProvider() {
return new Mem0AITextGenerator(this.provider_wrapper, this.model, this.config , this.provider_config || {});
}
}
export { Mem0ClassSelector };

View File

@@ -0,0 +1,145 @@
import { LanguageModelV1, ProviderV1 } from '@ai-sdk/provider'
import { withoutTrailingSlash } from '@ai-sdk/provider-utils'
import { Mem0ChatLanguageModel } from './mem0-chat-language-model'
import { Mem0ChatModelId, Mem0ChatSettings } from './mem0-chat-settings'
import { OpenAIProviderSettings } from '@ai-sdk/openai'
import { Mem0CompletionModelId, Mem0CompletionSettings } from './mem0-completion-settings'
import { Mem0GenericLanguageModel } from './mem0-generic-language-model'
import { Mem0CompletionLanguageModel } from './mem0-completion-language-model'
export interface Mem0Provider extends ProviderV1 {
(modelId: Mem0ChatModelId, settings?: Mem0ChatSettings): LanguageModelV1
chat(
modelId: Mem0ChatModelId,
settings?: Mem0ChatSettings,
): LanguageModelV1
languageModel(
modelId: Mem0ChatModelId,
settings?: Mem0ChatSettings,
): LanguageModelV1
completion(
modelId: Mem0CompletionModelId,
settings?: Mem0CompletionSettings,
): LanguageModelV1
}
export interface Mem0ProviderSettings extends OpenAIProviderSettings {
baseURL?: string
/**
* Custom fetch implementation. You can use it as a middleware to intercept
* requests or to provide a custom fetch implementation for e.g. testing
*/
fetch?: typeof fetch
/**
* @internal
*/
generateId?: () => string
/**
* Custom headers to include in the requests.
*/
headers?: Record<string, string>
organization?: string;
project?: string;
name?: string;
mem0ApiKey?: string;
apiKey?: string;
provider?: string;
config?: OpenAIProviderSettings;
modelType?: "completion" | "chat";
}
export function createMem0(
options: Mem0ProviderSettings = {
provider: "openai",
},
): Mem0Provider {
const baseURL =
withoutTrailingSlash(options.baseURL) ?? 'http://127.0.0.1:11434/api'
const getHeaders = () => ({
...options.headers,
})
const createGenericModel = (
modelId: Mem0ChatModelId,
settings: Mem0ChatSettings = {},
) =>
new Mem0GenericLanguageModel(modelId, settings, {
baseURL,
fetch: options.fetch,
headers: getHeaders,
provider: options.provider || "openai",
organization: options.organization,
project: options.project,
name: options.name,
mem0_api_key: options.mem0ApiKey,
apiKey: options.apiKey,
}, options.config)
const createChatModel = (
modelId: Mem0ChatModelId,
settings: Mem0ChatSettings = {},
) =>
new Mem0ChatLanguageModel(modelId, settings, {
baseURL,
fetch: options.fetch,
headers: getHeaders,
provider: options.provider || "openai",
organization: options.organization,
project: options.project,
name: options.name,
mem0_api_key: options.mem0ApiKey,
apiKey: options.apiKey,
}, options.config)
const createCompletionModel = (
modelId: Mem0CompletionModelId,
settings: Mem0CompletionSettings = {}
) =>
new Mem0CompletionLanguageModel(
modelId,
settings,
{
baseURL,
fetch: options.fetch,
headers: getHeaders,
provider: options.provider || "openai",
organization: options.organization,
project: options.project,
name: options.name,
mem0_api_key: options.mem0ApiKey,
apiKey: options.apiKey
},
options.config
);
const provider = function (
modelId: Mem0ChatModelId,
settings?: Mem0ChatSettings,
) {
if (new.target) {
throw new Error(
'The Mem0 model function cannot be called with the new keyword.',
)
}
return createGenericModel(modelId, settings)
}
provider.chat = createChatModel
provider.completion = createCompletionModel
provider.languageModel = createChatModel
return provider as unknown as Mem0Provider
}
export const mem0 = createMem0()

View File

@@ -0,0 +1,114 @@
import { LanguageModelV1Prompt } from 'ai';
import { Mem0Config } from './mem0-chat-settings';
if (typeof process !== 'undefined' && process.env && process.env.NODE_ENV !== 'production') {
// Dynamically import dotenv only in non-production environments
import('dotenv').then((dotenv) => dotenv.config());
}
const tokenIsPresent = (config?: Mem0Config)=>{
if(!config && !config!.mem0ApiKey && (typeof process !== 'undefined' && process.env && !process.env.MEM0_API_KEY)){
throw Error("MEM0_API_KEY is not present. Please set env MEM0_API_KEY as the value of your API KEY.");
}
}
interface Message {
role: string;
content: string | Array<{type: string, text: string}>;
}
const flattenPrompt = (prompt: LanguageModelV1Prompt) => {
return prompt.map((part) => {
if (part.role === "user") {
return part.content
.filter((obj) => obj.type === 'text')
.map((obj) => obj.text)
.join(" ");
}
return "";
}).join(" ");
}
const searchInternalMemories = async (query: string, config?: Mem0Config, top_k: number = 5)=> {
tokenIsPresent(config);
const filters = {
OR: [
{
user_id: config&&config.user_id,
},
{
app_id: config&&config.app_id,
},
{
agent_id: config&&config.agent_id,
},
{
run_id: config&&config.run_id,
},
],
};
const options = {
method: 'POST',
headers: {Authorization: `Token ${(config&&config.mem0ApiKey) || (typeof process !== 'undefined' && process.env && process.env.MEM0_API_KEY) || ""}`, 'Content-Type': 'application/json'},
body: JSON.stringify({query, filters, top_k, version: "v2", org_name: config&&config.org_name, project_name: config&&config.project_name}),
};
const response = await fetch('https://api.mem0.ai/v2/memories/search/', options);
const data = await response.json();
return data;
}
const addMemories = async (messages: LanguageModelV1Prompt, config?: Mem0Config)=>{
tokenIsPresent(config);
const message = flattenPrompt(messages);
const response = await updateMemories([
{ role: "user", content: message },
{ role: "assistant", content: "Thank You!" },
], config);
return response;
}
const updateMemories = async (messages: Array<Message>, config?: Mem0Config)=>{
tokenIsPresent(config);
const options = {
method: 'POST',
headers: {Authorization: `Token ${(config&&config.mem0ApiKey) || (typeof process !== 'undefined' && process.env && process.env.MEM0_API_KEY) || ""}`, 'Content-Type': 'application/json'},
body: JSON.stringify({messages, ...config}),
};
const response = await fetch('https://api.mem0.ai/v1/memories/', options);
const data = await response.json();
return data;
}
const retrieveMemories = async (prompt: LanguageModelV1Prompt | string, config?: Mem0Config)=>{
tokenIsPresent(config);
const message = typeof prompt === 'string' ? prompt : flattenPrompt(prompt);
const systemPrompt = "These are the memories I have stored. Give more weightage to the question by users and try to answer that first. You have to modify your answer based on the memories I have provided. If the memories are irrelevant you can ignore them. Also don't reply to this section of the prompt, or the memories, they are only for your reference. The System prompt starts after text System Message: \n\n";
const memories = await searchInternalMemories(message, config);
let memoriesText = "";
try{
// @ts-ignore
memoriesText = memories.map((memory: any)=>{
return `Memory: ${memory.memory}\n\n`;
}).join("\n\n");
}catch(e){
console.error("Error while parsing memories");
// console.log(e);
}
return `System Message: ${systemPrompt} ${memoriesText}`;
}
const searchMemories = async (prompt: LanguageModelV1Prompt | string, config?: Mem0Config)=>{
tokenIsPresent(config);
const message = typeof prompt === 'string' ? prompt : flattenPrompt(prompt);
let memories = [];
try{
// @ts-ignore
memories = await searchInternalMemories(message, config);
}
catch(e){
console.error("Error while searching memories");
}
return memories;
}
export {addMemories, updateMemories, retrieveMemories, flattenPrompt, searchMemories};

View File

@@ -0,0 +1,113 @@
import { createOpenAI, OpenAIProviderSettings } from "@ai-sdk/openai";
import { generateText as aiGenerateText, streamText as aiStreamText, LanguageModelV1Prompt } from "ai";
import { updateMemories, retrieveMemories, flattenPrompt } from "./mem0-utils";
import { Mem0Config } from "./mem0-chat-settings";
import { Mem0ProviderSettings } from "./mem0-provider";
import { CohereProviderSettings, createCohere } from "@ai-sdk/cohere";
import { AnthropicProviderSettings, createAnthropic } from "@ai-sdk/anthropic";
import { createGroq, GroqProviderSettings } from "@ai-sdk/groq";
export type Provider = ReturnType<typeof createOpenAI> | ReturnType<typeof createCohere> | ReturnType<typeof createAnthropic> | ReturnType<typeof createGroq> | any;
export type ProviderSettings = OpenAIProviderSettings | CohereProviderSettings | AnthropicProviderSettings | GroqProviderSettings;
class Mem0AITextGenerator {
provider: Provider;
model: string;
provider_config?: ProviderSettings;
config: Mem0ProviderSettings;
constructor(provider: string, model: string, config: Mem0ProviderSettings, provider_config: ProviderSettings) {
switch (provider) {
case "openai":
this.provider = createOpenAI({
apiKey: config?.apiKey,
...provider_config,
});
if(config?.modelType === "completion"){
this.provider = createOpenAI({
apiKey: config?.apiKey,
...provider_config,
}).completion;
}else if(config?.modelType === "chat"){
this.provider = createOpenAI({
apiKey: config?.apiKey,
...provider_config,
}).chat;
}
break;
case "cohere":
this.provider = createCohere({
apiKey: config?.apiKey,
...provider_config,
});
break;
case "anthropic":
this.provider = createAnthropic({
apiKey: config?.apiKey,
...provider_config,
});
break;
case "groq":
this.provider = createGroq({
apiKey: config?.apiKey,
...provider_config,
});
break;
default:
throw new Error("Invalid provider");
}
this.model = model;
this.provider_config = provider_config;
this.config = config!;
}
async generateText(prompt: LanguageModelV1Prompt, config: Mem0Config) {
try {
const flattenPromptResponse = flattenPrompt(prompt);
const newPrompt = await retrieveMemories(prompt, config);
const response = await aiGenerateText({
// @ts-ignore
model: this.provider(this.model),
messages: prompt,
system: newPrompt
});
await updateMemories([
{ role: "user", content: flattenPromptResponse },
{ role: "assistant", content: response.text },
], config);
return response;
} catch (error) {
console.error("Error generating text:", error);
throw error;
}
}
async streamText(prompt: LanguageModelV1Prompt, config: Mem0Config) {
try {
const flattenPromptResponse = flattenPrompt(prompt);
const newPrompt = await retrieveMemories(prompt, config);
await updateMemories([
{ role: "user", content: flattenPromptResponse },
{ role: "assistant", content: "Thank You!" },
], config);
const response = await aiStreamText({
// @ts-ignore
model: this.provider(this.model),
messages: prompt,
system: newPrompt
});
return response;
} catch (error) {
console.error("Error generating text:", error);
throw error;
}
}
}
export default Mem0AITextGenerator;

View File

@@ -0,0 +1,28 @@
async function filterStream(originalStream: ReadableStream) {
const reader = originalStream.getReader();
const filteredStream = new ReadableStream({
async start(controller) {
while (true) {
const { done, value } = await reader.read();
if (done) {
controller.close();
break;
}
try {
const chunk = JSON.parse(value);
if (chunk.type !== "step-finish") {
controller.enqueue(value);
}
} catch (error) {
if (!(value.type==='step-finish')) {
controller.enqueue(value);
}
}
}
}
});
return filteredStream;
}
export { filterStream };

12
vercel-ai-sdk/teardown.ts Normal file
View File

@@ -0,0 +1,12 @@
import { testConfig } from './config/test-config';
export default async function () {
console.log("Running global teardown...");
try {
await testConfig.fetchDeleteId();
await testConfig.deleteUser();
console.log("User deleted successfully after all tests.");
} catch (error) {
console.error("Failed to delete user after all tests:", error);
}
}

View File

@@ -0,0 +1,110 @@
import dotenv from "dotenv";
dotenv.config();
import { generateObject } from "ai";
import { testConfig } from "../config/test-config";
import { z } from "zod";
interface Provider {
name: string;
activeModel: string;
apiKey: string | undefined;
}
const provider: Provider = {
name: "anthropic",
activeModel: "claude-3-5-sonnet-20240620",
apiKey: process.env.ANTHROPIC_API_KEY,
}
describe("ANTHROPIC Structured Outputs", () => {
const { userId } = testConfig;
let mem0: ReturnType<typeof testConfig.createTestClient>;
jest.setTimeout(30000);
beforeEach(() => {
mem0 = testConfig.createTestClient(provider);
});
describe("ANTHROPIC Object Generation Tests", () => {
// Test 1: Generate a car preference object
it("should generate a car preference object with name and steps", async () => {
const { object } = await generateObject({
model: mem0(provider.activeModel, {
user_id: userId,
}),
schema: z.object({
car: z.object({
name: z.string(),
steps: z.array(z.string()),
}),
}),
prompt: "Which car would I like?",
});
expect(object.car).toBeDefined();
expect(typeof object.car.name).toBe("string");
expect(Array.isArray(object.car.steps)).toBe(true);
expect(object.car.steps.every((step) => typeof step === "string")).toBe(true);
});
// Test 2: Generate an array of car objects
it("should generate an array of three car objects with name, class, and description", async () => {
const { object } = await generateObject({
model: mem0(provider.activeModel, {
user_id: userId,
}),
output: "array",
schema: z.object({
name: z.string(),
class: z.string(),
description: z.string(),
}),
prompt: "Write name of three cars that I would like.",
});
expect(Array.isArray(object)).toBe(true);
expect(object.length).toBe(3);
object.forEach((car) => {
expect(car).toHaveProperty("name");
expect(typeof car.name).toBe("string");
expect(car).toHaveProperty("class");
expect(typeof car.class).toBe("string");
expect(car).toHaveProperty("description");
expect(typeof car.description).toBe("string");
});
});
// Test 3: Generate an enum for movie genre classification
it("should classify the genre of a movie plot", async () => {
const { object } = await generateObject({
model: mem0(provider.activeModel, {
user_id: userId,
}),
output: "enum",
enum: ["action", "comedy", "drama", "horror", "sci-fi"],
prompt: 'Classify the genre of this movie plot: "A group of astronauts travel through a wormhole in search of a new habitable planet for humanity."',
});
expect(object).toBeDefined();
expect(object).toBe("sci-fi");
});
// Test 4: Generate an object of car names without schema
it("should generate an object with car names", async () => {
const { object } = await generateObject({
model: mem0(provider.activeModel, {
user_id: userId,
}),
output: "no-schema",
prompt: "Write name of 3 cars that I would like.",
});
const carObject = object as { cars: string[] };
expect(carObject).toBeDefined();
expect(Array.isArray(carObject.cars)).toBe(true);
expect(carObject.cars.length).toBe(3);
expect(carObject.cars.every((car) => typeof car === "string")).toBe(true);
});
});
});

View File

@@ -0,0 +1,61 @@
import dotenv from "dotenv";
dotenv.config();
import { retrieveMemories } from "../src";
import { generateText, LanguageModelV1Prompt } from "ai";
import { testConfig } from "../config/test-config";
import { createAnthropic } from "@ai-sdk/anthropic";
describe("ANTHROPIC Functions", () => {
const { userId } = testConfig;
jest.setTimeout(30000);
let anthropic: any;
beforeEach(() => {
anthropic = createAnthropic({
apiKey: process.env.ANTHROPIC_API_KEY,
});
});
it("should retrieve memories and generate text using ANTHROPIC provider", async () => {
const messages: LanguageModelV1Prompt = [
{
role: "user",
content: [
{ type: "text", text: "Suggest me a good car to buy." },
{ type: "text", text: " Write only the car name and it's color." },
],
},
];
// Retrieve memories based on previous messages
const memories = await retrieveMemories(messages, { user_id: userId });
const { text } = await generateText({
// @ts-ignore
model: anthropic("claude-3-haiku-20240307"),
messages: messages,
system: memories,
});
// Expect text to be a string
expect(typeof text).toBe('string');
expect(text.length).toBeGreaterThan(0);
});
it("should generate text using ANTHROPIC provider with memories", async () => {
const prompt = "Suggest me a good car to buy.";
const memories = await retrieveMemories(prompt, { user_id: userId });
const { text } = await generateText({
// @ts-ignore
model: anthropic("claude-3-haiku-20240307"),
prompt: prompt,
system: memories
});
expect(typeof text).toBe('string');
expect(text.length).toBeGreaterThan(0);
});
});

View File

@@ -0,0 +1,60 @@
import dotenv from "dotenv";
dotenv.config();
import { retrieveMemories } from "../src";
import { generateText, LanguageModelV1Prompt } from "ai";
import { testConfig } from "../config/test-config";
import { createCohere } from "@ai-sdk/cohere";
describe("COHERE Functions", () => {
const { userId } = testConfig;
jest.setTimeout(30000);
let cohere: any;
beforeEach(() => {
cohere = createCohere({
apiKey: process.env.COHERE_API_KEY,
});
});
it("should retrieve memories and generate text using COHERE provider", async () => {
const messages: LanguageModelV1Prompt = [
{
role: "user",
content: [
{ type: "text", text: "Suggest me a good car to buy." },
{ type: "text", text: " Write only the car name and it's color." },
],
},
];
// Retrieve memories based on previous messages
const memories = await retrieveMemories(messages, { user_id: userId });
const { text } = await generateText({
// @ts-ignore
model: cohere("command-r-plus"),
messages: messages,
system: memories,
});
// Expect text to be a string
expect(typeof text).toBe('string');
expect(text.length).toBeGreaterThan(0);
});
it("should generate text using COHERE provider with memories", async () => {
const prompt = "Suggest me a good car to buy.";
const memories = await retrieveMemories(prompt, { user_id: userId });
const { text } = await generateText({
// @ts-ignore
model: cohere("command-r-plus"),
prompt: prompt,
system: memories
});
expect(typeof text).toBe('string');
expect(text.length).toBeGreaterThan(0);
});
});

View File

@@ -0,0 +1,86 @@
import { generateText, LanguageModelV1Prompt, streamText } from "ai";
import { addMemories } from "../src";
import { testConfig } from "../config/test-config";
interface Provider {
name: string;
activeModel: string;
apiKey: string | undefined;
}
describe.each(testConfig.providers)('TESTS: Generate/Stream Text with model %s', (provider: Provider) => {
const { userId } = testConfig;
let mem0: ReturnType<typeof testConfig.createTestClient>;
jest.setTimeout(50000);
beforeEach(() => {
mem0 = testConfig.createTestClient(provider);
});
beforeAll(async () => {
// Add some test memories before all tests
const messages: LanguageModelV1Prompt = [
{
role: "user",
content: [
{ type: "text", text: "I love red cars." },
{ type: "text", text: "I like Toyota Cars." },
{ type: "text", text: "I prefer SUVs." },
],
}
];
await addMemories(messages, { user_id: userId });
});
it("should generate text using mem0 model", async () => {
const { text } = await generateText({
model: mem0(provider.activeModel, {
user_id: userId,
}),
prompt: "Suggest me a good car to buy!",
});
expect(typeof text).toBe('string');
expect(text.length).toBeGreaterThan(0);
});
it("should generate text using provider with memories", async () => {
const { text } = await generateText({
model: mem0(provider.activeModel, {
user_id: userId,
}),
messages: [
{
role: "user",
content: [
{ type: "text", text: "Suggest me a good car to buy." },
{ type: "text", text: "Write only the car name and it's color." },
],
}
],
});
// Expect text to be a string
expect(typeof text).toBe('string');
expect(text.length).toBeGreaterThan(0);
});
it("should stream text using Mem0 provider", async () => {
const { textStream } = await streamText({
model: mem0(provider.activeModel, {
user_id: userId, // Use the uniform userId
}),
prompt: "Suggest me a good car to buy! Write only the car name and it's color.",
});
// Collect streamed text parts
let streamedText = '';
for await (const textPart of textStream) {
streamedText += textPart;
}
// Ensure the streamed text is a string
expect(typeof streamedText).toBe('string');
expect(streamedText.length).toBeGreaterThan(0);
});
});

View File

@@ -0,0 +1,61 @@
import dotenv from "dotenv";
dotenv.config();
import { retrieveMemories } from "../src";
import { generateText, LanguageModelV1Prompt } from "ai";
import { testConfig } from "../config/test-config";
import { createGroq } from "@ai-sdk/groq";
describe("GROQ Functions", () => {
const { userId } = testConfig;
jest.setTimeout(30000);
let groq: any;
beforeEach(() => {
groq = createGroq({
apiKey: process.env.GROQ_API_KEY,
});
});
it("should retrieve memories and generate text using GROQ provider", async () => {
const messages: LanguageModelV1Prompt = [
{
role: "user",
content: [
{ type: "text", text: "Suggest me a good car to buy." },
{ type: "text", text: " Write only the car name and it's color." },
],
},
];
// Retrieve memories based on previous messages
const memories = await retrieveMemories(messages, { user_id: userId });
const { text } = await generateText({
// @ts-ignore
model: groq("gemma2-9b-it"),
messages: messages,
system: memories,
});
// Expect text to be a string
expect(typeof text).toBe('string');
expect(text.length).toBeGreaterThan(0);
});
it("should generate text using GROQ provider with memories", async () => {
const prompt = "Suggest me a good car to buy.";
const memories = await retrieveMemories(prompt, { user_id: userId });
const { text } = await generateText({
// @ts-ignore
model: groq("gemma2-9b-it"),
prompt: prompt,
system: memories
});
expect(typeof text).toBe('string');
expect(text.length).toBeGreaterThan(0);
});
});

View File

@@ -0,0 +1,75 @@
import { addMemories, retrieveMemories } from "../src";
import { LanguageModelV1Prompt } from "ai";
import { testConfig } from "../config/test-config";
describe("Memory Core Functions", () => {
const { userId } = testConfig;
jest.setTimeout(10000);
describe("addMemories", () => {
it("should successfully add memories and return correct format", async () => {
const messages: LanguageModelV1Prompt = [
{
role: "user",
content: [
{ type: "text", text: "I love red cars." },
{ type: "text", text: "I like Toyota Cars." },
{ type: "text", text: "I prefer SUVs." },
],
}
];
const response = await addMemories(messages, { user_id: userId });
expect(Array.isArray(response)).toBe(true);
response.forEach((memory: { event: any; }) => {
expect(memory).toHaveProperty('id');
expect(memory).toHaveProperty('data');
expect(memory).toHaveProperty('event');
expect(memory.event).toBe('ADD');
});
});
});
describe("retrieveMemories", () => {
beforeEach(async () => {
// Add some test memories before each retrieval test
const messages: LanguageModelV1Prompt = [
{
role: "user",
content: [
{ type: "text", text: "I love red cars." },
{ type: "text", text: "I like Toyota Cars." },
{ type: "text", text: "I prefer SUVs." },
],
}
];
await addMemories(messages, { user_id: userId });
});
it("should retrieve memories with string prompt", async () => {
const prompt = "Which car would I prefer?";
const response = await retrieveMemories(prompt, { user_id: userId });
expect(typeof response).toBe('string');
expect(response.match(/Memory:/g)?.length).toBeGreaterThan(2);
});
it("should retrieve memories with array of prompts", async () => {
const messages: LanguageModelV1Prompt = [
{
role: "user",
content: [
{ type: "text", text: "Which car would I prefer?" },
{ type: "text", text: "Suggest me some cars" },
],
}
];
const response = await retrieveMemories(messages, { user_id: userId });
expect(typeof response).toBe('string');
expect(response.match(/Memory:/g)?.length).toBeGreaterThan(2);
});
});
});

View File

@@ -0,0 +1,110 @@
import dotenv from "dotenv";
dotenv.config();
import { generateObject } from "ai";
import { testConfig } from "../config/test-config";
import { z } from "zod";
interface Provider {
name: string;
activeModel: string;
apiKey: string | undefined;
}
const provider: Provider = {
name: "openai",
activeModel: "gpt-4-turbo",
apiKey: process.env.OPENAI_API_KEY,
}
describe("OPENAI Structured Outputs", () => {
const { userId } = testConfig;
let mem0: ReturnType<typeof testConfig.createTestClient>;
jest.setTimeout(30000);
beforeEach(() => {
mem0 = testConfig.createTestClient(provider);
});
describe("openai Object Generation Tests", () => {
// Test 1: Generate a car preference object
it("should generate a car preference object with name and steps", async () => {
const { object } = await generateObject({
model: mem0(provider.activeModel, {
user_id: userId,
}),
schema: z.object({
car: z.object({
name: z.string(),
steps: z.array(z.string()),
}),
}),
prompt: "Which car would I like?",
});
expect(object.car).toBeDefined();
expect(typeof object.car.name).toBe("string");
expect(Array.isArray(object.car.steps)).toBe(true);
expect(object.car.steps.every((step) => typeof step === "string")).toBe(true);
});
// Test 2: Generate an array of car objects
it("should generate an array of three car objects with name, class, and description", async () => {
const { object } = await generateObject({
model: mem0(provider.activeModel, {
user_id: userId,
}),
output: "array",
schema: z.object({
name: z.string(),
class: z.string().describe('Cars should be "SUV", "Sedan", or "Hatchback"'),
description: z.string(),
}),
prompt: "Write name of three cars that I would like.",
});
expect(Array.isArray(object)).toBe(true);
expect(object.length).toBe(3);
object.forEach((car) => {
expect(car).toHaveProperty("name");
expect(typeof car.name).toBe("string");
expect(car).toHaveProperty("class");
expect(typeof car.class).toBe("string");
expect(car).toHaveProperty("description");
expect(typeof car.description).toBe("string");
});
});
// Test 3: Generate an enum for movie genre classification
it("should classify the genre of a movie plot", async () => {
const { object } = await generateObject({
model: mem0(provider.activeModel, {
user_id: userId,
}),
output: "enum",
enum: ["action", "comedy", "drama", "horror", "sci-fi"],
prompt: 'Classify the genre of this movie plot: "A group of astronauts travel through a wormhole in search of a new habitable planet for humanity."',
});
expect(object).toBeDefined();
expect(object).toBe("sci-fi");
});
// Test 4: Generate an object of car names without schema
it("should generate an object with car names", async () => {
const { object } = await generateObject({
model: mem0(provider.activeModel, {
user_id: userId,
}),
output: "no-schema",
prompt: "Write name of 3 cars that I would like.",
});
const carObject = object as { cars: string[] };
expect(carObject).toBeDefined();
expect(Array.isArray(carObject.cars)).toBe(true);
expect(carObject.cars.length).toBe(3);
expect(carObject.cars.every((car) => typeof car === "string")).toBe(true);
});
});
});

View File

@@ -0,0 +1,58 @@
import dotenv from "dotenv";
dotenv.config();
import { retrieveMemories } from "../src";
import { generateText, LanguageModelV1Prompt } from "ai";
import { testConfig } from "../config/test-config";
import { createOpenAI } from "@ai-sdk/openai";
describe("OPENAI Functions", () => {
const { userId } = testConfig;
jest.setTimeout(30000);
let openai: any;
beforeEach(() => {
openai = createOpenAI({
apiKey: process.env.OPENAI_API_KEY,
});
});
it("should retrieve memories and generate text using OpenAI provider", async () => {
const messages: LanguageModelV1Prompt = [
{
role: "user",
content: [
{ type: "text", text: "Suggest me a good car to buy." },
{ type: "text", text: " Write only the car name and it's color." },
],
},
];
// Retrieve memories based on previous messages
const memories = await retrieveMemories(messages, { user_id: userId });
const { text } = await generateText({
model: openai("gpt-4-turbo"),
messages: messages,
system: memories,
});
// Expect text to be a string
expect(typeof text).toBe('string');
expect(text.length).toBeGreaterThan(0);
});
it("should generate text using openai provider with memories", async () => {
const prompt = "Suggest me a good car to buy.";
const memories = await retrieveMemories(prompt, { user_id: userId });
const { text } = await generateText({
model: openai("gpt-4-turbo"),
prompt: prompt,
system: memories
});
expect(typeof text).toBe('string');
expect(text.length).toBeGreaterThan(0);
});
});

View File

@@ -0,0 +1,77 @@
import { generateText, streamText } from "ai";
import { testConfig } from "../config/test-config";
interface Provider {
name: string;
activeModel: string;
apiKey: string | undefined;
}
describe.each(testConfig.providers)('TEXT/STREAM PROPERTIES: Tests with model %s', (provider: Provider) => {
const { userId } = testConfig;
let mem0: ReturnType<typeof testConfig.createTestClient>;
jest.setTimeout(50000);
beforeEach(() => {
mem0 = testConfig.createTestClient(provider);
});
it("should stream text with onChunk handler", async () => {
const chunkTexts: string[] = [];
const { textStream } = await streamText({
model: mem0(provider.activeModel, {
user_id: userId, // Use the uniform userId
}),
prompt: "Write only the name of the car I prefer and its color.",
onChunk({ chunk }) {
if (chunk.type === "text-delta") {
// Store chunk text for assertions
chunkTexts.push(chunk.textDelta);
}
},
});
// Wait for the stream to complete
for await (const _ of textStream) {
}
// Ensure chunks are collected
expect(chunkTexts.length).toBeGreaterThan(0);
expect(chunkTexts.every((text) => typeof text === "string")).toBe(true);
});
it("should call onFinish handler without throwing an error", async () => {
await streamText({
model: mem0(provider.activeModel, {
user_id: userId, // Use the uniform userId
}),
prompt: "Write only the name of the car I prefer and its color.",
onFinish({ text, finishReason, usage }) {
},
});
});
it("should generate fullStream with expected usage", async () => {
const {
text, // combined text
usage, // combined usage of all steps
} = await generateText({
model: mem0(provider.activeModel), // Ensure the model name is correct
maxSteps: 5, // Enable multi-step calls
experimental_continueSteps: true,
prompt:
"Suggest me some good cars to buy. Each response MUST HAVE at least 200 words.",
});
// Ensure text is a string
expect(typeof text).toBe("string");
// Check usage
// promptTokens is a number, so we use toBeCloseTo instead of toBe and it should be in the range 155 to 165
expect(usage.promptTokens).toBeGreaterThanOrEqual(100);
expect(usage.promptTokens).toBeLessThanOrEqual(500);
expect(usage.completionTokens).toBeGreaterThanOrEqual(250); // Check completion tokens are above 250
expect(usage.totalTokens).toBeGreaterThan(400); // Check total tokens are above 400
});
});

View File

@@ -0,0 +1,29 @@
{
"$schema": "https://json.schemastore.org/tsconfig",
"compilerOptions": {
"composite": false,
"declaration": true,
"declarationMap": true,
"esModuleInterop": true,
"forceConsistentCasingInFileNames": true,
"inlineSources": false,
"isolatedModules": true,
"moduleResolution": "node",
"noUnusedLocals": false,
"noUnusedParameters": false,
"preserveWatchOutput": true,
"skipLibCheck": true,
"strict": true,
"types": ["@types/node", "jest"],
"jsx": "react-jsx",
"lib": ["dom", "ES2021"],
"module": "ESNext",
"target": "ES2018",
"stripInternal": true,
"paths": {
"@/*": ["./src/*"]
}
},
"include": ["."],
"exclude": ["dist", "build", "node_modules"]
}

View File

@@ -0,0 +1,10 @@
import { defineConfig } from 'tsup'
export default defineConfig([
{
dts: true,
entry: ['src/index.ts'],
format: ['cjs', 'esm'],
sourcemap: true,
},
])