(Feature) Vercel AI SDK (#2024)
4
vercel-ai-sdk/src/index.ts
Normal file
@@ -0,0 +1,4 @@
export * from './mem0-facade'
export type { Mem0Provider, Mem0ProviderSettings } from './mem0-provider'
export { createMem0, mem0 } from './mem0-provider'
export { addMemories, retrieveMemories, searchMemories } from './mem0-utils'
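
Taken together, these exports offer two entry points: the AI SDK provider factory (createMem0 / mem0) and the standalone memory helpers. A minimal usage sketch, assuming the package name shown here and OPENAI_API_KEY / MEM0_API_KEY in the environment:

import { generateText } from "ai";
import { createMem0 } from "@mem0/vercel-ai-provider"; // package name assumed

const mem0 = createMem0({ provider: "openai" });
const { text } = await generateText({
  model: mem0("gpt-4o", { user_id: "alice" }), // user_id scopes stored memories
  prompt: "Where did I say I wanted to travel?",
});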
150
vercel-ai-sdk/src/mem0-chat-language-model.ts
Normal file
@@ -0,0 +1,150 @@
/* eslint-disable camelcase */
import {
  LanguageModelV1,
  LanguageModelV1CallOptions,
  LanguageModelV1CallWarning,
  LanguageModelV1FinishReason,
  LanguageModelV1FunctionToolCall,
  LanguageModelV1LogProbs,
  LanguageModelV1ProviderMetadata,
  LanguageModelV1StreamPart,
} from "@ai-sdk/provider";

import { Mem0ChatModelId, Mem0ChatSettings, Mem0Config } from "./mem0-chat-settings";
import { Mem0ClassSelector } from "./mem0-provider-selector";
import { filterStream } from "./stream-utils";
import { OpenAIProviderSettings } from "@ai-sdk/openai";
import { Mem0ProviderSettings } from "./mem0-provider";

interface Mem0ChatConfig {
  baseURL: string;
  fetch?: typeof fetch;
  headers: () => Record<string, string | undefined>;
  provider: string;
  organization?: string;
  project?: string;
  name?: string;
  apiKey?: string;
  mem0_api_key?: string;
}

export class Mem0ChatLanguageModel implements LanguageModelV1 {
  readonly specificationVersion = "v1";
  readonly defaultObjectGenerationMode = "json";
  readonly supportsImageUrls = false;

  constructor(
    public readonly modelId: Mem0ChatModelId,
    public readonly settings: Mem0ChatSettings,
    public readonly config: Mem0ChatConfig,
    public readonly provider_config?: OpenAIProviderSettings,
  ) {
    this.provider = config.provider;
  }

  provider: string;
  supportsStructuredOutputs?: boolean | undefined;

  async doGenerate(options: LanguageModelV1CallOptions): Promise<{
    text?: string;
    toolCalls?: Array<LanguageModelV1FunctionToolCall>;
    finishReason: LanguageModelV1FinishReason;
    usage: { promptTokens: number; completionTokens: number };
    rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
    rawResponse?: { headers?: Record<string, string> };
    response?: { id?: string; timestamp?: Date; modelId?: string };
    warnings?: LanguageModelV1CallWarning[];
    providerMetadata?: LanguageModelV1ProviderMetadata;
    logprobs?: LanguageModelV1LogProbs;
  }> {
    try {
      const provider = this.config.provider;
      const mem0_api_key = this.config.mem0_api_key;
      const settings: Mem0ProviderSettings = {
        provider: provider,
        mem0ApiKey: mem0_api_key,
        apiKey: this.config.apiKey,
        modelType: "chat",
      };
      const selector = new Mem0ClassSelector(this.modelId, settings, this.provider_config);
      const messagesPrompts = options.prompt;
      const model = selector.createProvider();
      const user_id = this.settings.user_id;
      const app_id = this.settings.app_id;
      const agent_id = this.settings.agent_id;
      const run_id = this.settings.run_id;
      const org_name = this.settings.org_name;
      const project_name = this.settings.project_name;
      const apiKey = mem0_api_key;

      const config: Mem0Config = { user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey };

      const ans = await model.generateText(messagesPrompts, config);

      return {
        text: ans.text,
        finishReason: ans.finishReason,
        usage: ans.usage,
        rawCall: {
          rawPrompt: options.prompt,
          rawSettings: {},
        },
        response: ans.response,
        warnings: ans.warnings,
      };
    } catch (error) {
      console.error("Error in doGenerate:", error);
      throw new Error("Failed to generate response.");
    }
  }

  async doStream(options: LanguageModelV1CallOptions): Promise<{
    stream: ReadableStream<LanguageModelV1StreamPart>;
    rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
    rawResponse?: { headers?: Record<string, string> };
    warnings?: LanguageModelV1CallWarning[];
  }> {
    try {
      const provider = this.config.provider;
      const mem0_api_key = this.config.mem0_api_key;
      const settings: Mem0ProviderSettings = {
        provider: provider,
        mem0ApiKey: mem0_api_key,
        apiKey: this.config.apiKey,
        modelType: "chat",
      };
      const selector = new Mem0ClassSelector(this.modelId, settings, this.provider_config);
      const messagesPrompts = options.prompt;
      const model = selector.createProvider();
      const user_id = this.settings.user_id;
      const app_id = this.settings.app_id;
      const agent_id = this.settings.agent_id;
      const run_id = this.settings.run_id;
      const org_name = this.settings.org_name;
      const project_name = this.settings.project_name;

      const apiKey = mem0_api_key;

      const config: Mem0Config = { user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey };
      const response = await model.streamText(messagesPrompts, config);
      // @ts-ignore
      const filteredStream = await filterStream(response.originalStream);
      return {
        // @ts-ignore
        stream: filteredStream,
        rawCall: {
          rawPrompt: options.prompt,
          rawSettings: {},
        },
        ...response,
      };
    } catch (error) {
      console.error("Error in doStream:", error);
      throw new Error("Streaming failed or method not implemented.");
    }
  }
}
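
This class is the LanguageModelV1 implementation the AI SDK drives: generateText ends up in doGenerate and streamText in doStream. A sketch of constructing it directly (normally createMem0 builds this config; the baseURL value here is illustrative):

import { Mem0ChatLanguageModel } from "./mem0-chat-language-model";

const model = new Mem0ChatLanguageModel(
  "gpt-4o",
  { user_id: "user-123" },
  {
    baseURL: "https://api.mem0.ai", // illustrative; createMem0 supplies this
    headers: () => ({}),
    provider: "openai",
    apiKey: process.env.OPENAI_API_KEY,
    mem0_api_key: process.env.MEM0_API_KEY,
  },
);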
36
vercel-ai-sdk/src/mem0-chat-settings.ts
Normal file
@@ -0,0 +1,36 @@
import { OpenAIChatSettings } from "@ai-sdk/openai/internal";

export type Mem0ChatModelId =
  | "o1-preview"
  | "o1-mini"
  | "gpt-4o"
  | "gpt-4o-2024-05-13"
  | "gpt-4o-2024-08-06"
  | "gpt-4o-audio-preview"
  | "gpt-4o-audio-preview-2024-10-01"
  | "gpt-4o-mini"
  | "gpt-4o-mini-2024-07-18"
  | "gpt-4-turbo"
  | "gpt-4-turbo-2024-04-09"
  | "gpt-4-turbo-preview"
  | "gpt-4-0125-preview"
  | "gpt-4-1106-preview"
  | "gpt-4"
  | "gpt-4-0613"
  | "gpt-3.5-turbo-0125"
  | "gpt-3.5-turbo"
  | "gpt-3.5-turbo-1106"
  | (string & NonNullable<unknown>);

export interface Mem0ChatSettings extends OpenAIChatSettings {
  user_id?: string;
  app_id?: string;
  agent_id?: string;
  run_id?: string;
  org_name?: string;
  project_name?: string;
  mem0ApiKey?: string;
  structuredOutputs?: boolean;
}

export interface Mem0Config extends Mem0ChatSettings {}
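
The settings type layers Mem0's scoping fields on top of the inherited OpenAI chat options, so a single object configures both the model and the memory scope. A sketch with placeholder ids:

import type { Mem0ChatSettings } from "./mem0-chat-settings";

const settings: Mem0ChatSettings = {
  user_id: "user-123",      // memories are stored and searched under this user
  app_id: "my-app",         // optional application-level scope
  structuredOutputs: false, // inherited OpenAI chat option
};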
150
vercel-ai-sdk/src/mem0-completion-language-model.ts
Normal file
@@ -0,0 +1,150 @@
/* eslint-disable camelcase */
import {
  LanguageModelV1,
  LanguageModelV1CallOptions,
  LanguageModelV1CallWarning,
  LanguageModelV1FinishReason,
  LanguageModelV1FunctionToolCall,
  LanguageModelV1LogProbs,
  LanguageModelV1ProviderMetadata,
  LanguageModelV1StreamPart,
} from "@ai-sdk/provider";

import { Mem0ChatModelId, Mem0ChatSettings } from "./mem0-chat-settings";
import { Mem0ClassSelector } from "./mem0-provider-selector";
import { filterStream } from "./stream-utils";
import { Mem0Config } from "./mem0-completion-settings";
import { OpenAIProviderSettings } from "@ai-sdk/openai";
import { Mem0ProviderSettings } from "./mem0-provider";

interface Mem0CompletionConfig {
  baseURL: string;
  fetch?: typeof fetch;
  headers: () => Record<string, string | undefined>;
  provider: string;
  organization?: string;
  project?: string;
  name?: string;
  apiKey?: string;
  mem0_api_key?: string;
}

export class Mem0CompletionLanguageModel implements LanguageModelV1 {
  readonly specificationVersion = "v1";
  readonly defaultObjectGenerationMode = "json";
  readonly supportsImageUrls = false;

  constructor(
    public readonly modelId: Mem0ChatModelId,
    public readonly settings: Mem0ChatSettings,
    public readonly config: Mem0CompletionConfig,
    public readonly provider_config?: OpenAIProviderSettings,
  ) {
    this.provider = config.provider;
  }

  provider: string;
  supportsStructuredOutputs?: boolean | undefined;

  async doGenerate(options: LanguageModelV1CallOptions): Promise<{
    text?: string;
    toolCalls?: Array<LanguageModelV1FunctionToolCall>;
    finishReason: LanguageModelV1FinishReason;
    usage: { promptTokens: number; completionTokens: number };
    rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
    rawResponse?: { headers?: Record<string, string> };
    response?: { id?: string; timestamp?: Date; modelId?: string };
    warnings?: LanguageModelV1CallWarning[];
    providerMetadata?: LanguageModelV1ProviderMetadata;
    logprobs?: LanguageModelV1LogProbs;
  }> {
    try {
      const provider = this.config.provider;
      const mem0_api_key = this.config.mem0_api_key;
      const settings: Mem0ProviderSettings = {
        provider: provider,
        mem0ApiKey: mem0_api_key,
        apiKey: this.config.apiKey,
        modelType: "completion",
      };
      const selector = new Mem0ClassSelector(this.modelId, settings, this.provider_config);
      const messagesPrompts = options.prompt;
      const model = selector.createProvider();
      const user_id = this.settings.user_id;
      const app_id = this.settings.app_id;
      const agent_id = this.settings.agent_id;
      const run_id = this.settings.run_id;
      const org_name = this.settings.org_name;
      const project_name = this.settings.project_name;
      const apiKey = mem0_api_key;

      const config: Mem0Config = { user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey, modelType: "completion" };

      const ans = await model.generateText(messagesPrompts, config);

      return {
        text: ans.text,
        finishReason: ans.finishReason,
        usage: ans.usage,
        rawCall: {
          rawPrompt: options.prompt,
          rawSettings: {},
        },
        response: ans.response,
        warnings: ans.warnings,
      };
    } catch (error) {
      console.error("Error in doGenerate:", error);
      throw new Error("Failed to generate response.");
    }
  }

  async doStream(options: LanguageModelV1CallOptions): Promise<{
    stream: ReadableStream<LanguageModelV1StreamPart>;
    rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
    rawResponse?: { headers?: Record<string, string> };
    warnings?: LanguageModelV1CallWarning[];
  }> {
    try {
      const provider = this.config.provider;
      const mem0_api_key = this.config.mem0_api_key;
      const settings: Mem0ProviderSettings = {
        provider: provider,
        mem0ApiKey: mem0_api_key,
        apiKey: this.config.apiKey,
        modelType: "completion",
      };
      const selector = new Mem0ClassSelector(this.modelId, settings, this.provider_config);
      const messagesPrompts = options.prompt;
      const model = selector.createProvider();
      const user_id = this.settings.user_id;
      const app_id = this.settings.app_id;
      const agent_id = this.settings.agent_id;
      const run_id = this.settings.run_id;
      const org_name = this.settings.org_name;
      const project_name = this.settings.project_name;

      const apiKey = mem0_api_key;

      const config: Mem0Config = { user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey, modelType: "completion" };
      const response = await model.streamText(messagesPrompts, config);
      // @ts-ignore
      const filteredStream = await filterStream(response.originalStream);
      return {
        // @ts-ignore
        stream: filteredStream,
        rawCall: {
          rawPrompt: options.prompt,
          rawSettings: {},
        },
        ...response,
      };
    } catch (error) {
      console.error("Error in doStream:", error);
      throw new Error("Streaming failed or method not implemented.");
    }
  }
}
19
vercel-ai-sdk/src/mem0-completion-settings.ts
Normal file
@@ -0,0 +1,19 @@
import { OpenAICompletionSettings } from "@ai-sdk/openai/internal";

export type Mem0CompletionModelId =
  | "gpt-3.5-turbo"
  | (string & NonNullable<unknown>);

export interface Mem0CompletionSettings extends OpenAICompletionSettings {
  user_id?: string;
  app_id?: string;
  agent_id?: string;
  run_id?: string;
  org_name?: string;
  project_name?: string;
  mem0ApiKey?: string;
  structuredOutputs?: boolean;
  modelType?: string;
}

export interface Mem0Config extends Mem0CompletionSettings {}
36
vercel-ai-sdk/src/mem0-facade.ts
Normal file
@@ -0,0 +1,36 @@
import { withoutTrailingSlash } from '@ai-sdk/provider-utils'

import { Mem0ChatLanguageModel } from './mem0-chat-language-model'
import { Mem0ChatModelId, Mem0ChatSettings } from './mem0-chat-settings'
import { Mem0ProviderSettings } from './mem0-provider'

export class Mem0 {
  readonly baseURL: string

  readonly headers?: Record<string, string>

  constructor(options: Mem0ProviderSettings = {
    provider: 'openai',
  }) {
    this.baseURL =
      withoutTrailingSlash(options.baseURL) ?? 'http://127.0.0.1:11434/api'

    this.headers = options.headers
  }

  private get baseConfig() {
    return {
      baseURL: this.baseURL,
      headers: () => ({
        ...this.headers,
      }),
    }
  }

  chat(modelId: Mem0ChatModelId, settings: Mem0ChatSettings = {}) {
    return new Mem0ChatLanguageModel(modelId, settings, {
      provider: 'openai',
      ...this.baseConfig,
    })
  }
}
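
The facade mirrors the functional provider with a small class-based API; note that baseURL falls back to a local default, as in the constructor above. A sketch:

import { Mem0 } from "./mem0-facade";

const client = new Mem0({ provider: "openai" });
const model = client.chat("gpt-4o", { user_id: "user-123" }); // a LanguageModelV1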
148
vercel-ai-sdk/src/mem0-generic-language-model.ts
Normal file
@@ -0,0 +1,148 @@
/* eslint-disable camelcase */
import {
  LanguageModelV1,
  LanguageModelV1CallOptions,
  LanguageModelV1CallWarning,
  LanguageModelV1FinishReason,
  LanguageModelV1FunctionToolCall,
  LanguageModelV1LogProbs,
  LanguageModelV1ProviderMetadata,
  LanguageModelV1StreamPart,
} from "@ai-sdk/provider";

import { Mem0ChatModelId, Mem0ChatSettings, Mem0Config } from "./mem0-chat-settings";
import { Mem0ClassSelector } from "./mem0-provider-selector";
import { filterStream } from "./stream-utils";
import { OpenAIProviderSettings } from "@ai-sdk/openai";
import { Mem0ProviderSettings } from "./mem0-provider";

interface Mem0ChatConfig {
  baseURL: string;
  fetch?: typeof fetch;
  headers: () => Record<string, string | undefined>;
  provider: string;
  organization?: string;
  project?: string;
  name?: string;
  apiKey?: string;
  mem0_api_key?: string;
}

export class Mem0GenericLanguageModel implements LanguageModelV1 {
  readonly specificationVersion = "v1";
  readonly defaultObjectGenerationMode = "json";
  readonly supportsImageUrls = false;

  constructor(
    public readonly modelId: Mem0ChatModelId,
    public readonly settings: Mem0ChatSettings,
    public readonly config: Mem0ChatConfig,
    public readonly provider_config?: OpenAIProviderSettings,
  ) {
    this.provider = config.provider;
  }

  provider: string;
  supportsStructuredOutputs?: boolean | undefined;

  async doGenerate(options: LanguageModelV1CallOptions): Promise<{
    text?: string;
    toolCalls?: Array<LanguageModelV1FunctionToolCall>;
    finishReason: LanguageModelV1FinishReason;
    usage: { promptTokens: number; completionTokens: number };
    rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
    rawResponse?: { headers?: Record<string, string> };
    response?: { id?: string; timestamp?: Date; modelId?: string };
    warnings?: LanguageModelV1CallWarning[];
    providerMetadata?: LanguageModelV1ProviderMetadata;
    logprobs?: LanguageModelV1LogProbs;
  }> {
    try {
      const provider = this.config.provider;
      const mem0_api_key = this.config.mem0_api_key;
      const settings: Mem0ProviderSettings = {
        provider: provider,
        mem0ApiKey: mem0_api_key,
        apiKey: this.config.apiKey,
      };
      const selector = new Mem0ClassSelector(this.modelId, settings, this.provider_config);
      const messagesPrompts = options.prompt;
      const model = selector.createProvider();
      const user_id = this.settings.user_id;
      const app_id = this.settings.app_id;
      const agent_id = this.settings.agent_id;
      const run_id = this.settings.run_id;
      const org_name = this.settings.org_name;
      const project_name = this.settings.project_name;
      const apiKey = mem0_api_key;

      const config: Mem0Config = { user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey };

      const ans = await model.generateText(messagesPrompts, config);

      return {
        text: ans.text,
        finishReason: ans.finishReason,
        usage: ans.usage,
        rawCall: {
          rawPrompt: options.prompt,
          rawSettings: {},
        },
        response: ans.response,
        warnings: ans.warnings,
      };
    } catch (error) {
      console.error("Error in doGenerate:", error);
      throw new Error("Failed to generate response.");
    }
  }

  async doStream(options: LanguageModelV1CallOptions): Promise<{
    stream: ReadableStream<LanguageModelV1StreamPart>;
    rawCall: { rawPrompt: unknown; rawSettings: Record<string, unknown> };
    rawResponse?: { headers?: Record<string, string> };
    warnings?: LanguageModelV1CallWarning[];
  }> {
    try {
      const provider = this.config.provider;
      const mem0_api_key = this.config.mem0_api_key;
      const settings: Mem0ProviderSettings = {
        provider: provider,
        mem0ApiKey: mem0_api_key,
        apiKey: this.config.apiKey,
      };
      const selector = new Mem0ClassSelector(this.modelId, settings, this.provider_config);
      const messagesPrompts = options.prompt;
      const model = selector.createProvider();
      const user_id = this.settings.user_id;
      const app_id = this.settings.app_id;
      const agent_id = this.settings.agent_id;
      const run_id = this.settings.run_id;
      const org_name = this.settings.org_name;
      const project_name = this.settings.project_name;

      const apiKey = mem0_api_key;

      const config: Mem0Config = { user_id, app_id, agent_id, run_id, org_name, project_name, mem0ApiKey: apiKey };
      const response = await model.streamText(messagesPrompts, config);
      // @ts-ignore
      const filteredStream = await filterStream(response.originalStream);
      return {
        // @ts-ignore
        stream: filteredStream,
        rawCall: {
          rawPrompt: options.prompt,
          rawSettings: {},
        },
        ...response,
      };
    } catch (error) {
      console.error("Error in doStream:", error);
      throw new Error("Streaming failed or method not implemented.");
    }
  }
}
34
vercel-ai-sdk/src/mem0-provider-selector.ts
Normal file
@@ -0,0 +1,34 @@
import { Mem0ProviderSettings } from "./mem0-provider";
import Mem0AITextGenerator, { ProviderSettings } from "./provider-response-provider";

class Mem0ClassSelector {
  modelId: string;
  provider_wrapper: string;
  model: string;
  config: Mem0ProviderSettings;
  provider_config?: ProviderSettings;
  static supportedProviders = ["openai", "anthropic", "cohere", "groq"];

  constructor(modelId: string, config: Mem0ProviderSettings, provider_config?: ProviderSettings) {
    this.modelId = modelId;
    this.provider_wrapper = config.provider || "openai";
    this.model = this.modelId;
    this.provider_config = provider_config;
    if (config) this.config = config;
    else this.config = {
      provider: this.provider_wrapper,
    };

    // Check that the requested provider wrapper is supported
    if (!Mem0ClassSelector.supportedProviders.includes(this.provider_wrapper)) {
      throw new Error(`Provider not supported: ${this.provider_wrapper}`);
    }
  }

  createProvider() {
    return new Mem0AITextGenerator(this.provider_wrapper, this.model, this.config, this.provider_config || {});
  }
}

export { Mem0ClassSelector };
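
The selector is the internal switchboard between the language-model classes above and the concrete SDK wrappers. A sketch of how they invoke it:

import { Mem0ClassSelector } from "./mem0-provider-selector";

const selector = new Mem0ClassSelector("gpt-4o", {
  provider: "openai", // must be one of supportedProviders
  apiKey: process.env.OPENAI_API_KEY,
  modelType: "chat",
});
const generator = selector.createProvider(); // a Mem0AITextGenerator instance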
145
vercel-ai-sdk/src/mem0-provider.ts
Normal file
@@ -0,0 +1,145 @@
import { LanguageModelV1, ProviderV1 } from '@ai-sdk/provider'
import { withoutTrailingSlash } from '@ai-sdk/provider-utils'

import { Mem0ChatLanguageModel } from './mem0-chat-language-model'
import { Mem0ChatModelId, Mem0ChatSettings } from './mem0-chat-settings'
import { OpenAIProviderSettings } from '@ai-sdk/openai'
import { Mem0CompletionModelId, Mem0CompletionSettings } from './mem0-completion-settings'
import { Mem0GenericLanguageModel } from './mem0-generic-language-model'
import { Mem0CompletionLanguageModel } from './mem0-completion-language-model'

export interface Mem0Provider extends ProviderV1 {
  (modelId: Mem0ChatModelId, settings?: Mem0ChatSettings): LanguageModelV1

  chat(
    modelId: Mem0ChatModelId,
    settings?: Mem0ChatSettings,
  ): LanguageModelV1

  languageModel(
    modelId: Mem0ChatModelId,
    settings?: Mem0ChatSettings,
  ): LanguageModelV1

  completion(
    modelId: Mem0CompletionModelId,
    settings?: Mem0CompletionSettings,
  ): LanguageModelV1
}

export interface Mem0ProviderSettings extends OpenAIProviderSettings {
  baseURL?: string
  /**
   * Custom fetch implementation. You can use it as a middleware to intercept
   * requests, or to provide a custom fetch implementation for e.g. testing.
   */
  fetch?: typeof fetch
  /**
   * @internal
   */
  generateId?: () => string
  /**
   * Custom headers to include in the requests.
   */
  headers?: Record<string, string>
  organization?: string;
  project?: string;
  name?: string;
  mem0ApiKey?: string;
  apiKey?: string;
  provider?: string;
  config?: OpenAIProviderSettings;
  modelType?: "completion" | "chat";
}

export function createMem0(
  options: Mem0ProviderSettings = {
    provider: "openai",
  },
): Mem0Provider {
  const baseURL =
    withoutTrailingSlash(options.baseURL) ?? 'http://127.0.0.1:11434/api'

  const getHeaders = () => ({
    ...options.headers,
  })

  const createGenericModel = (
    modelId: Mem0ChatModelId,
    settings: Mem0ChatSettings = {},
  ) =>
    new Mem0GenericLanguageModel(modelId, settings, {
      baseURL,
      fetch: options.fetch,
      headers: getHeaders,
      provider: options.provider || "openai",
      organization: options.organization,
      project: options.project,
      name: options.name,
      mem0_api_key: options.mem0ApiKey,
      apiKey: options.apiKey,
    }, options.config)

  const createChatModel = (
    modelId: Mem0ChatModelId,
    settings: Mem0ChatSettings = {},
  ) =>
    new Mem0ChatLanguageModel(modelId, settings, {
      baseURL,
      fetch: options.fetch,
      headers: getHeaders,
      provider: options.provider || "openai",
      organization: options.organization,
      project: options.project,
      name: options.name,
      mem0_api_key: options.mem0ApiKey,
      apiKey: options.apiKey,
    }, options.config)

  const createCompletionModel = (
    modelId: Mem0CompletionModelId,
    settings: Mem0CompletionSettings = {},
  ) =>
    new Mem0CompletionLanguageModel(
      modelId,
      settings,
      {
        baseURL,
        fetch: options.fetch,
        headers: getHeaders,
        provider: options.provider || "openai",
        organization: options.organization,
        project: options.project,
        name: options.name,
        mem0_api_key: options.mem0ApiKey,
        apiKey: options.apiKey,
      },
      options.config,
    )

  const provider = function (
    modelId: Mem0ChatModelId,
    settings?: Mem0ChatSettings,
  ) {
    if (new.target) {
      throw new Error(
        'The Mem0 model function cannot be called with the new keyword.',
      )
    }

    return createGenericModel(modelId, settings)
  }

  provider.chat = createChatModel
  provider.completion = createCompletionModel
  provider.languageModel = createChatModel

  return provider as unknown as Mem0Provider
}

export const mem0 = createMem0()
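
Because options.provider selects the underlying SDK, the same factory can wrap any supported backend. A sketch with Anthropic (the model id is illustrative):

import { generateText } from "ai";
import { createMem0 } from "./mem0-provider";

const mem0Anthropic = createMem0({
  provider: "anthropic",
  apiKey: process.env.ANTHROPIC_API_KEY,
  mem0ApiKey: process.env.MEM0_API_KEY,
});
const { text } = await generateText({
  model: mem0Anthropic("claude-3-5-sonnet-20240620", { user_id: "user-123" }),
  prompt: "Remind me what I asked about last week.",
});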
114
vercel-ai-sdk/src/mem0-utils.ts
Normal file
@@ -0,0 +1,114 @@
import { LanguageModelV1Prompt } from 'ai';
import { Mem0Config } from './mem0-chat-settings';
if (typeof process !== 'undefined' && process.env && process.env.NODE_ENV !== 'production') {
  // Dynamically import dotenv only in non-production environments
  import('dotenv').then((dotenv) => dotenv.config());
}

// Throw early when no Mem0 API key is available from the config or the environment.
const tokenIsPresent = (config?: Mem0Config) => {
  const envKey = typeof process !== 'undefined' && process.env && process.env.MEM0_API_KEY;
  if (!config?.mem0ApiKey && !envKey) {
    throw Error("MEM0_API_KEY is not present. Please set env MEM0_API_KEY as the value of your API KEY.");
  }
}

interface Message {
  role: string;
  content: string | Array<{ type: string, text: string }>;
}

const flattenPrompt = (prompt: LanguageModelV1Prompt) => {
  return prompt.map((part) => {
    if (part.role === "user") {
      return part.content
        .filter((obj) => obj.type === 'text')
        .map((obj) => obj.text)
        .join(" ");
    }
    return "";
  }).join(" ");
}

const searchInternalMemories = async (query: string, config?: Mem0Config, top_k: number = 5) => {
  tokenIsPresent(config);
  const filters = {
    OR: [
      {
        user_id: config && config.user_id,
      },
      {
        app_id: config && config.app_id,
      },
      {
        agent_id: config && config.agent_id,
      },
      {
        run_id: config && config.run_id,
      },
    ],
  };
  const options = {
    method: 'POST',
    headers: {
      Authorization: `Token ${(config && config.mem0ApiKey) || (typeof process !== 'undefined' && process.env && process.env.MEM0_API_KEY) || ""}`,
      'Content-Type': 'application/json',
    },
    body: JSON.stringify({ query, filters, top_k, version: "v2", org_name: config && config.org_name, project_name: config && config.project_name }),
  };
  const response = await fetch('https://api.mem0.ai/v2/memories/search/', options);
  const data = await response.json();
  return data;
}

const addMemories = async (messages: LanguageModelV1Prompt, config?: Mem0Config) => {
  tokenIsPresent(config);
  const message = flattenPrompt(messages);
  const response = await updateMemories([
    { role: "user", content: message },
    { role: "assistant", content: "Thank You!" },
  ], config);
  return response;
}

const updateMemories = async (messages: Array<Message>, config?: Mem0Config) => {
  tokenIsPresent(config);
  const options = {
    method: 'POST',
    headers: {
      Authorization: `Token ${(config && config.mem0ApiKey) || (typeof process !== 'undefined' && process.env && process.env.MEM0_API_KEY) || ""}`,
      'Content-Type': 'application/json',
    },
    body: JSON.stringify({ messages, ...config }),
  };

  const response = await fetch('https://api.mem0.ai/v1/memories/', options);
  const data = await response.json();
  return data;
}

const retrieveMemories = async (prompt: LanguageModelV1Prompt | string, config?: Mem0Config) => {
  tokenIsPresent(config);
  const message = typeof prompt === 'string' ? prompt : flattenPrompt(prompt);
  const systemPrompt = "These are the memories I have stored. Give more weightage to the question by users and try to answer that first. You have to modify your answer based on the memories I have provided. If the memories are irrelevant you can ignore them. Also don't reply to this section of the prompt, or the memories, they are only for your reference. The System prompt starts after text System Message: \n\n";
  const memories = await searchInternalMemories(message, config);
  let memoriesText = "";
  try {
    // @ts-ignore
    memoriesText = memories.map((memory: any) => {
      return `Memory: ${memory.memory}\n\n`;
    }).join("\n\n");
  } catch (e) {
    console.error("Error while parsing memories");
  }
  return `System Message: ${systemPrompt} ${memoriesText}`;
}

const searchMemories = async (prompt: LanguageModelV1Prompt | string, config?: Mem0Config) => {
  tokenIsPresent(config);
  const message = typeof prompt === 'string' ? prompt : flattenPrompt(prompt);
  let memories = [];
  try {
    // @ts-ignore
    memories = await searchInternalMemories(message, config);
  }
  catch (e) {
    console.error("Error while searching memories");
  }
  return memories;
}

export { addMemories, updateMemories, retrieveMemories, flattenPrompt, searchMemories };
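
These helpers also work without the provider wrapper, e.g. to enrich a system prompt for any AI SDK model. A sketch, assuming MEM0_API_KEY is set in the environment:

import { openai } from "@ai-sdk/openai";
import { generateText, LanguageModelV1Prompt } from "ai";
import { addMemories, retrieveMemories } from "./mem0-utils";

const history: LanguageModelV1Prompt = [
  { role: "user", content: [{ type: "text", text: "I am vegetarian and live in Berlin." }] },
];
await addMemories(history, { user_id: "user-123" });

const system = await retrieveMemories("Suggest a restaurant for me.", { user_id: "user-123" });
const { text } = await generateText({
  model: openai("gpt-4o"),
  prompt: "Suggest a restaurant for me.",
  system,
});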
113
vercel-ai-sdk/src/provider-response-provider.ts
Normal file
@@ -0,0 +1,113 @@
import { createOpenAI, OpenAIProviderSettings } from "@ai-sdk/openai";
import { generateText as aiGenerateText, streamText as aiStreamText, LanguageModelV1Prompt } from "ai";
import { updateMemories, retrieveMemories, flattenPrompt } from "./mem0-utils";
import { Mem0Config } from "./mem0-chat-settings";
import { Mem0ProviderSettings } from "./mem0-provider";
import { CohereProviderSettings, createCohere } from "@ai-sdk/cohere";
import { AnthropicProviderSettings, createAnthropic } from "@ai-sdk/anthropic";
import { createGroq, GroqProviderSettings } from "@ai-sdk/groq";

export type Provider = ReturnType<typeof createOpenAI> | ReturnType<typeof createCohere> | ReturnType<typeof createAnthropic> | ReturnType<typeof createGroq> | any;
export type ProviderSettings = OpenAIProviderSettings | CohereProviderSettings | AnthropicProviderSettings | GroqProviderSettings;

class Mem0AITextGenerator {
  provider: Provider;
  model: string;
  provider_config?: ProviderSettings;
  config: Mem0ProviderSettings;

  constructor(provider: string, model: string, config: Mem0ProviderSettings, provider_config: ProviderSettings) {
    switch (provider) {
      case "openai": {
        const openai = createOpenAI({
          apiKey: config?.apiKey,
          ...provider_config,
        });
        // Pick the completion or chat sub-provider when a model type is given.
        if (config?.modelType === "completion") {
          this.provider = openai.completion;
        } else if (config?.modelType === "chat") {
          this.provider = openai.chat;
        } else {
          this.provider = openai;
        }
        break;
      }
      case "cohere":
        this.provider = createCohere({
          apiKey: config?.apiKey,
          ...provider_config,
        });
        break;
      case "anthropic":
        this.provider = createAnthropic({
          apiKey: config?.apiKey,
          ...provider_config,
        });
        break;
      case "groq":
        this.provider = createGroq({
          apiKey: config?.apiKey,
          ...provider_config,
        });
        break;
      default:
        throw new Error("Invalid provider");
    }
    this.model = model;
    this.provider_config = provider_config;
    this.config = config!;
  }

  async generateText(prompt: LanguageModelV1Prompt, config: Mem0Config) {
    try {
      const flattenPromptResponse = flattenPrompt(prompt);
      const newPrompt = await retrieveMemories(prompt, config);
      const response = await aiGenerateText({
        // @ts-ignore
        model: this.provider(this.model),
        messages: prompt,
        system: newPrompt,
      });

      await updateMemories([
        { role: "user", content: flattenPromptResponse },
        { role: "assistant", content: response.text },
      ], config);

      return response;
    } catch (error) {
      console.error("Error generating text:", error);
      throw error;
    }
  }

  async streamText(prompt: LanguageModelV1Prompt, config: Mem0Config) {
    try {
      const flattenPromptResponse = flattenPrompt(prompt);
      const newPrompt = await retrieveMemories(prompt, config);

      await updateMemories([
        { role: "user", content: flattenPromptResponse },
        { role: "assistant", content: "Thank You!" },
      ], config);

      const response = await aiStreamText({
        // @ts-ignore
        model: this.provider(this.model),
        messages: prompt,
        system: newPrompt,
      });

      return response;
    } catch (error) {
      console.error("Error streaming text:", error);
      throw error;
    }
  }
}

export default Mem0AITextGenerator;
28
vercel-ai-sdk/src/stream-utils.ts
Normal file
@@ -0,0 +1,28 @@
// Remove "step-finish" parts from a stream, passing every other chunk through.
async function filterStream(originalStream: ReadableStream) {
  const reader = originalStream.getReader();
  const filteredStream = new ReadableStream({
    async start(controller) {
      while (true) {
        const { done, value } = await reader.read();
        if (done) {
          controller.close();
          break;
        }
        try {
          // String chunks are parsed as JSON to inspect their type.
          const chunk = JSON.parse(value);
          if (chunk.type !== "step-finish") {
            controller.enqueue(value);
          }
        } catch (error) {
          // Non-JSON (object) chunks expose their type directly.
          if (!(value.type === 'step-finish')) {
            controller.enqueue(value);
          }
        }
      }
    },
  });

  return filteredStream;
}

export { filterStream };
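
filterStream strips the AI SDK's "step-finish" parts so downstream consumers see only content chunks. A minimal sketch with a hand-built source stream:

import { filterStream } from "./stream-utils";

const source = new ReadableStream({
  start(controller) {
    controller.enqueue({ type: "text-delta", textDelta: "Hello" });
    controller.enqueue({ type: "step-finish" });
    controller.close();
  },
});

const filtered = await filterStream(source);
const reader = filtered.getReader();
for (let r = await reader.read(); !r.done; r = await reader.read()) {
  console.log(r.value); // only the text-delta chunk survives
}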