From 05f9607282a8f7d39a00c03cd836d8f5b9b21671 Mon Sep 17 00:00:00 2001
From: Saket Aryan
Date: Fri, 11 Apr 2025 20:09:20 +0530
Subject: [PATCH] Adds Azure OpenAI LLM to Mem0 TS SDK (#2536)

---
 docs/changelog/overview.mdx                  |  1 +
 docs/components/llms/models/azure_openai.mdx | 38 ++++++++-
 mem0-ts/src/client/telemetry.ts              |  2 +-
 mem0-ts/src/oss/src/config/manager.ts        |  3 +
 mem0-ts/src/oss/src/llms/azure.ts            | 82 ++++++++++++++++++++
 mem0-ts/src/oss/src/types/index.ts           |  2 +
 mem0-ts/src/oss/src/utils/factory.ts         |  3 +
 mem0-ts/src/oss/src/utils/telemetry.ts       |  2 +-
 8 files changed, 129 insertions(+), 4 deletions(-)
 create mode 100644 mem0-ts/src/oss/src/llms/azure.ts

diff --git a/docs/changelog/overview.mdx b/docs/changelog/overview.mdx
index 9fca91c9..2652a357 100644
--- a/docs/changelog/overview.mdx
+++ b/docs/changelog/overview.mdx
@@ -129,6 +129,7 @@ mode: "wide"
 
 **New Features:**
 
+- **Azure OpenAI:** Added Azure OpenAI LLM support to the TypeScript SDK
 - **Mistral LLM:** Added Mistral LLM integration in OSS
 
 **Improvements:**
diff --git a/docs/components/llms/models/azure_openai.mdx b/docs/components/llms/models/azure_openai.mdx
index ede726b6..4b333590 100644
--- a/docs/components/llms/models/azure_openai.mdx
+++ b/docs/components/llms/models/azure_openai.mdx
@@ -2,6 +2,8 @@
 title: Azure OpenAI
 ---
 
+Mem0 now supports Azure OpenAI models in the TypeScript SDK.
+
 To use Azure OpenAI models, you have to set the `LLM_AZURE_OPENAI_API_KEY`, `LLM_AZURE_ENDPOINT`, `LLM_AZURE_DEPLOYMENT` and `LLM_AZURE_API_VERSION` environment variables. You can obtain the Azure API key from [Azure](https://azure.microsoft.com/).
 
 > **Note**: The following are currently unsupported with reasoning models: `Parallel tool calling`, `temperature`, `top_p`, `presence_penalty`, `frequency_penalty`, `logprobs`, `top_logprobs`, `logit_bias`, `max_tokens`
@@ -9,7 +11,8 @@ To use Azure OpenAI models, you have to set the `LLM_AZURE_OPENAI_API_KEY`, `LLM
 
 ## Usage
 
-```python
+<CodeGroup>
+```python Python
 import os
 from mem0 import Memory
 
@@ -48,7 +51,38 @@ messages = [
 m.add(messages, user_id="alice", metadata={"category": "movies"})
 ```
 
-We also support the new [OpenAI structured-outputs](https://platform.openai.com/docs/guides/structured-outputs/introduction) model.
+```typescript TypeScript
+import { Memory } from 'mem0ai/oss';
+
+const config = {
+  llm: {
+    provider: 'azure_openai',
+    config: {
+      apiKey: process.env.AZURE_OPENAI_API_KEY || '',
+      modelProperties: {
+        endpoint: 'https://your-api-base-url',
+        deployment: 'your-deployment-name',
+        modelName: 'your-model-name',
+        apiVersion: 'version-to-use',
+        // Any other parameters you want to pass to the Azure OpenAI API
+      },
+    },
+  },
+};
+
+const memory = new Memory(config);
+const messages = [
+  {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
+  {"role": "assistant", "content": "How about a thriller movie? They can be quite engaging."},
+  {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
+  {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
+];
+await memory.add(messages, { userId: "alice", metadata: { category: "movies" } });
+```
+</CodeGroup>
+
+We also support the new [OpenAI structured-outputs](https://platform.openai.com/docs/guides/structured-outputs/introduction) model. The TypeScript SDK does not yet support `azure_openai_structured`.
 
 ```python
 import os
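For readers who prefer not to hard-code credentials, the same TypeScript configuration can be assembled from the environment variables the Python section documents. The following is a sketch, not part of the patch: reusing the `LLM_AZURE_*` variable names in Node is an assumption, and only `provider`, `apiKey`, `model`, and `modelProperties` come from the SDK itself.

```typescript
// Sketch: building the Azure OpenAI llm config from environment variables.
// The LLM_AZURE_* names are borrowed from the Python docs above (assumption);
// adapt them to however your deployment stores secrets.
import { Memory } from "mem0ai/oss";

const memory = new Memory({
  llm: {
    provider: "azure_openai",
    config: {
      apiKey: process.env.LLM_AZURE_OPENAI_API_KEY || "",
      // AzureOpenAILLM falls back to "gpt-4" when `model` is omitted.
      model: process.env.LLM_AZURE_DEPLOYMENT,
      modelProperties: {
        endpoint: process.env.LLM_AZURE_ENDPOINT,
        deployment: process.env.LLM_AZURE_DEPLOYMENT,
        apiVersion: process.env.LLM_AZURE_API_VERSION,
      },
    },
  },
});

await memory.add(
  [{ role: "user", content: "I prefer sci-fi over thrillers." }],
  { userId: "alice" },
);
```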
diff --git a/mem0-ts/src/client/telemetry.ts b/mem0-ts/src/client/telemetry.ts
index ab56cb8f..3691a201 100644
--- a/mem0-ts/src/client/telemetry.ts
+++ b/mem0-ts/src/client/telemetry.ts
@@ -1,7 +1,7 @@
 // @ts-nocheck
 import type { TelemetryClient, TelemetryOptions } from "./telemetry.types";
 
-let version = "2.1.12";
+let version = "2.1.16";
 
 // Safely check for process.env in different environments
 let MEM0_TELEMETRY = true;
diff --git a/mem0-ts/src/oss/src/config/manager.ts b/mem0-ts/src/oss/src/config/manager.ts
index fa8daed7..1e50f355 100644
--- a/mem0-ts/src/oss/src/config/manager.ts
+++ b/mem0-ts/src/oss/src/config/manager.ts
@@ -42,6 +42,9 @@ export class ConfigManager {
           model:
             userConfig.llm?.config?.model ||
             DEFAULT_MEMORY_CONFIG.llm.config.model,
+          modelProperties:
+            userConfig.llm?.config?.modelProperties ||
+            DEFAULT_MEMORY_CONFIG.llm.config.modelProperties,
         },
       },
       historyDbPath:
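The `ConfigManager` hunk above only threads the new `modelProperties` bag through the merge with the package defaults. A minimal standalone sketch of that fallback pattern follows; the default values shown here are hypothetical, since the patch does not reveal what `DEFAULT_MEMORY_CONFIG` contains.

```typescript
// Sketch of the merge behaviour: user-supplied values win, defaults fill gaps,
// and modelProperties is passed through untouched for the provider to consume.
interface LLMUserConfig {
  apiKey?: string;
  model?: string;
  modelProperties?: Record<string, any>;
}

// Hypothetical defaults; DEFAULT_MEMORY_CONFIG's real contents are not in this patch.
const DEFAULTS: LLMUserConfig = { model: "gpt-4" };

function mergeLLMConfig(user?: LLMUserConfig): LLMUserConfig {
  return {
    apiKey: user?.apiKey || DEFAULTS.apiKey,
    model: user?.model || DEFAULTS.model,
    modelProperties: user?.modelProperties || DEFAULTS.modelProperties,
  };
}

// mergeLLMConfig({ modelProperties: { endpoint: "https://...", apiVersion: "2024-02-01" } })
// -> { apiKey: undefined, model: "gpt-4", modelProperties: { endpoint: ..., apiVersion: ... } }
```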
diff --git a/mem0-ts/src/oss/src/llms/azure.ts b/mem0-ts/src/oss/src/llms/azure.ts
new file mode 100644
index 00000000..ff960abc
--- /dev/null
+++ b/mem0-ts/src/oss/src/llms/azure.ts
@@ -0,0 +1,82 @@
+import { AzureOpenAI } from "openai";
+import { LLM, LLMResponse } from "./base";
+import { LLMConfig, Message } from "../types";
+
+export class AzureOpenAILLM implements LLM {
+  private client: AzureOpenAI;
+  private model: string;
+
+  constructor(config: LLMConfig) {
+    if (!config.apiKey || !config.modelProperties?.endpoint) {
+      throw new Error("Azure OpenAI requires both API key and endpoint");
+    }
+
+    const { endpoint, ...rest } = config.modelProperties;
+
+    this.client = new AzureOpenAI({
+      apiKey: config.apiKey,
+      endpoint: endpoint as string,
+      ...rest,
+    });
+    this.model = config.model || "gpt-4";
+  }
+
+  async generateResponse(
+    messages: Message[],
+    responseFormat?: { type: string },
+    tools?: any[],
+  ): Promise<string | LLMResponse> {
+    const completion = await this.client.chat.completions.create({
+      messages: messages.map((msg) => {
+        const role = msg.role as "system" | "user" | "assistant";
+        return {
+          role,
+          content:
+            typeof msg.content === "string"
+              ? msg.content
+              : JSON.stringify(msg.content),
+        };
+      }),
+      model: this.model,
+      response_format: responseFormat as { type: "text" | "json_object" },
+      ...(tools && { tools, tool_choice: "auto" }),
+    });
+
+    const response = completion.choices[0].message;
+
+    if (response.tool_calls) {
+      return {
+        content: response.content || "",
+        role: response.role,
+        toolCalls: response.tool_calls.map((call) => ({
+          name: call.function.name,
+          arguments: call.function.arguments,
+        })),
+      };
+    }
+
+    return response.content || "";
+  }
+
+  async generateChat(messages: Message[]): Promise<LLMResponse> {
+    const completion = await this.client.chat.completions.create({
+      messages: messages.map((msg) => {
+        const role = msg.role as "system" | "user" | "assistant";
+        return {
+          role,
+          content:
+            typeof msg.content === "string"
+              ? msg.content
+              : JSON.stringify(msg.content),
+        };
+      }),
+      model: this.model,
+    });
+
+    const response = completion.choices[0].message;
+    return {
+      content: response.content || "",
+      role: response.role,
+    };
+  }
+}
diff --git a/mem0-ts/src/oss/src/types/index.ts b/mem0-ts/src/oss/src/types/index.ts
index 9bd356bf..3c15d668 100644
--- a/mem0-ts/src/oss/src/types/index.ts
+++ b/mem0-ts/src/oss/src/types/index.ts
@@ -39,6 +39,7 @@ export interface LLMConfig {
   config?: Record<string, any>;
   apiKey?: string;
   model?: string;
+  modelProperties?: Record<string, any>;
 }
 
 export interface Neo4jConfig {
@@ -127,6 +128,7 @@ export const MemoryConfigSchema = z.object({
     config: z.object({
       apiKey: z.string(),
       model: z.string().optional(),
+      modelProperties: z.record(z.string(), z.any()).optional(),
     }),
   }),
   historyDbPath: z.string().optional(),
diff --git a/mem0-ts/src/oss/src/utils/factory.ts b/mem0-ts/src/oss/src/utils/factory.ts
index 85ba7f83..f5dc222f 100644
--- a/mem0-ts/src/oss/src/utils/factory.ts
+++ b/mem0-ts/src/oss/src/utils/factory.ts
@@ -25,6 +25,7 @@ import { SupabaseHistoryManager } from "../storage/SupabaseHistoryManager";
 import { HistoryManager } from "../storage/base";
 import { GoogleEmbedder } from "../embeddings/google";
 import { GoogleLLM } from "../llms/google";
+import { AzureOpenAILLM } from "../llms/azure";
 
 export class EmbedderFactory {
   static create(provider: string, config: EmbeddingConfig): Embedder {
@@ -56,6 +57,8 @@ export class LLMFactory {
         return new OllamaLLM(config);
       case "google":
         return new GoogleLLM(config);
+      case "azure_openai":
+        return new AzureOpenAILLM(config);
       case "mistral":
         return new MistralLLM(config);
       default:
diff --git a/mem0-ts/src/oss/src/utils/telemetry.ts b/mem0-ts/src/oss/src/utils/telemetry.ts
index a478db02..7edff41d 100644
--- a/mem0-ts/src/oss/src/utils/telemetry.ts
+++ b/mem0-ts/src/oss/src/utils/telemetry.ts
@@ -4,7 +4,7 @@ import type {
   TelemetryEventData,
 } from "./telemetry.types";
 
-let version = "2.1.15";
+let version = "2.1.16";
 
 // Safely check for process.env in different environments
 let MEM0_TELEMETRY = true;
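Putting the pieces together, the new class can also be exercised on its own; `Memory` reaches the same code path through the `LLMFactory.create("azure_openai", config)` case added above. The sketch below assumes it runs inside the `mem0-ts/src/oss/src` tree, and the endpoint, deployment name, and API version are placeholders, not values from the patch.

```typescript
// Sketch: calling AzureOpenAILLM directly with the config shape this patch defines.
import { AzureOpenAILLM } from "./llms/azure";

async function main() {
  const llm = new AzureOpenAILLM({
    apiKey: process.env.AZURE_OPENAI_API_KEY || "",
    model: "your-deployment-name", // for Azure, the model is usually the deployment name
    modelProperties: {
      endpoint: "https://your-resource.openai.azure.com", // placeholder
      deployment: "your-deployment-name",
      apiVersion: "2024-02-01", // placeholder API version
    },
  });

  // generateChat always resolves to { content, role }.
  const chat = await llm.generateChat([
    { role: "user", content: "Summarise what mem0 does in one sentence." },
  ]);
  console.log(chat.content);

  // generateResponse returns a plain string unless the model called a tool,
  // in which case it returns { content, role, toolCalls }.
  const answer = await llm.generateResponse([
    { role: "user", content: "Reply with the word ok." },
  ]);
  console.log(typeof answer === "string" ? answer : answer.toolCalls);
}

main().catch(console.error);
```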