From 33abf772ce338158f2ce0beea9118df4a913723b Mon Sep 17 00:00:00 2001 From: Saket Aryan Date: Tue, 15 Apr 2025 22:02:30 +0530 Subject: [PATCH] Adds Azure OpenAI Embedding Model (#2545) --- docs/changelog.mdx | 6 ++- .../embedders/models/azure_openai.mdx | 33 +++++++++++++++- mem0-ts/src/oss/src/config/manager.ts | 4 ++ mem0-ts/src/oss/src/embeddings/azure.ts | 39 +++++++++++++++++++ mem0-ts/src/oss/src/index.ts | 1 + mem0-ts/src/oss/src/types/index.ts | 2 + mem0-ts/src/oss/src/utils/factory.ts | 3 ++ mem0-ts/src/oss/src/vector_stores/supabase.ts | 23 +++++++---- 8 files changed, 102 insertions(+), 9 deletions(-) create mode 100644 mem0-ts/src/oss/src/embeddings/azure.ts diff --git a/docs/changelog.mdx b/docs/changelog.mdx index 68504fde..807bb496 100644 --- a/docs/changelog.mdx +++ b/docs/changelog.mdx @@ -127,15 +127,19 @@ mode: "wide" - + + **New Features:** - **OSS SDK:** Added support for Langchain LLM - **OSS SDK:** Added support for Langchain Embedder - **OSS SDK:** Added support for Langchain Vector Store +- **OSS SDK:** Added support for Azure OpenAI Embedder + **Improvements:** - **OSS SDK:** Changed `model` in LLM and Embedder to use type any from `string` to use langchain llm models - **OSS SDK:** Added client to vector store config for langchain vector store +- **OSS SDK:** - Updated Azure OpenAI to use new OpenAI SDK diff --git a/docs/components/embedders/models/azure_openai.mdx b/docs/components/embedders/models/azure_openai.mdx index 2a680958..cab58715 100644 --- a/docs/components/embedders/models/azure_openai.mdx +++ b/docs/components/embedders/models/azure_openai.mdx @@ -6,7 +6,8 @@ To use Azure OpenAI embedding models, set the `EMBEDDING_AZURE_OPENAI_API_KEY`, ### Usage -```python + +```python Python import os from mem0 import Memory @@ -46,6 +47,36 @@ messages = [ m.add(messages, user_id="john") ``` +```typescript TypeScript +import { Memory } from 'mem0ai/oss'; + +const config = { + embedder: { + provider: "azure_openai", + config: { + model: "text-embedding-3-large", + modelProperties: { + endpoint: "your-api-base-url", + deployment: "your-deployment-name", + apiVersion: "version-to-use", + } + } + } +} + +const memory = new Memory(config); + +const messages = [ + {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"}, + {"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."}, + {"role": "user", "content": "I’m not a big fan of thriller movies but I love sci-fi movies."}, + {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."} +] + +await memory.add(messages, { userId: "john" }); +``` + + ### Config Here are the parameters available for configuring Azure OpenAI embedder: diff --git a/mem0-ts/src/oss/src/config/manager.ts b/mem0-ts/src/oss/src/config/manager.ts index 4e9cbb46..9371b705 100644 --- a/mem0-ts/src/oss/src/config/manager.ts +++ b/mem0-ts/src/oss/src/config/manager.ts @@ -27,6 +27,10 @@ export class ConfigManager { : defaultConf.apiKey, model: finalModel, url: userConf?.url, + modelProperties: + userConf?.modelProperties !== undefined + ? userConf.modelProperties + : defaultConf.modelProperties, }; })(), }, diff --git a/mem0-ts/src/oss/src/embeddings/azure.ts b/mem0-ts/src/oss/src/embeddings/azure.ts new file mode 100644 index 00000000..b153fa6e --- /dev/null +++ b/mem0-ts/src/oss/src/embeddings/azure.ts @@ -0,0 +1,39 @@ +import { AzureOpenAI } from "openai"; +import { Embedder } from "./base"; +import { EmbeddingConfig } from "../types"; + +export class AzureOpenAIEmbedder implements Embedder { + private client: AzureOpenAI; + private model: string; + + constructor(config: EmbeddingConfig) { + if (!config.apiKey || !config.modelProperties?.endpoint) { + throw new Error("Azure OpenAI requires both API key and endpoint"); + } + + const { endpoint, ...rest } = config.modelProperties; + + this.client = new AzureOpenAI({ + apiKey: config.apiKey, + endpoint: endpoint as string, + ...rest, + }); + this.model = config.model || "text-embedding-3-small"; + } + + async embed(text: string): Promise { + const response = await this.client.embeddings.create({ + model: this.model, + input: text, + }); + return response.data[0].embedding; + } + + async embedBatch(texts: string[]): Promise { + const response = await this.client.embeddings.create({ + model: this.model, + input: texts, + }); + return response.data.map((item) => item.embedding); + } +} diff --git a/mem0-ts/src/oss/src/index.ts b/mem0-ts/src/oss/src/index.ts index c14f3730..98aafd06 100644 --- a/mem0-ts/src/oss/src/index.ts +++ b/mem0-ts/src/oss/src/index.ts @@ -5,6 +5,7 @@ export * from "./embeddings/base"; export * from "./embeddings/openai"; export * from "./embeddings/ollama"; export * from "./embeddings/google"; +export * from "./embeddings/azure"; export * from "./embeddings/langchain"; export * from "./llms/base"; export * from "./llms/openai"; diff --git a/mem0-ts/src/oss/src/types/index.ts b/mem0-ts/src/oss/src/types/index.ts index 0cbd1c19..aa2d5d73 100644 --- a/mem0-ts/src/oss/src/types/index.ts +++ b/mem0-ts/src/oss/src/types/index.ts @@ -16,6 +16,7 @@ export interface EmbeddingConfig { apiKey?: string; model?: string | any; url?: string; + modelProperties?: Record; } export interface VectorStoreConfig { @@ -112,6 +113,7 @@ export const MemoryConfigSchema = z.object({ embedder: z.object({ provider: z.string(), config: z.object({ + modelProperties: z.record(z.string(), z.any()).optional(), apiKey: z.string().optional(), model: z.union([z.string(), z.any()]).optional(), }), diff --git a/mem0-ts/src/oss/src/utils/factory.ts b/mem0-ts/src/oss/src/utils/factory.ts index db9c609f..894963f3 100644 --- a/mem0-ts/src/oss/src/utils/factory.ts +++ b/mem0-ts/src/oss/src/utils/factory.ts @@ -26,6 +26,7 @@ import { HistoryManager } from "../storage/base"; import { GoogleEmbedder } from "../embeddings/google"; import { GoogleLLM } from "../llms/google"; import { AzureOpenAILLM } from "../llms/azure"; +import { AzureOpenAIEmbedder } from "../embeddings/azure"; import { LangchainLLM } from "../llms/langchain"; import { LangchainEmbedder } from "../embeddings/langchain"; import { LangchainVectorStore } from "../vector_stores/langchain"; @@ -39,6 +40,8 @@ export class EmbedderFactory { return new OllamaEmbedder(config); case "google": return new GoogleEmbedder(config); + case "azure_openai": + return new AzureOpenAIEmbedder(config); case "langchain": return new LangchainEmbedder(config); default: diff --git a/mem0-ts/src/oss/src/vector_stores/supabase.ts b/mem0-ts/src/oss/src/vector_stores/supabase.ts index bf2db537..878182e0 100644 --- a/mem0-ts/src/oss/src/vector_stores/supabase.ts +++ b/mem0-ts/src/oss/src/vector_stores/supabase.ts @@ -103,12 +103,16 @@ export class SupabaseDB implements VectorStore { try { // Verify table exists and vector operations work by attempting a test insert const testVector = Array(1536).fill(0); + + // First try to delete any existing test vector try { await this.client.from(this.tableName).delete().eq("id", "test_vector"); - } catch (error) { - console.warn("No test vector to delete, safe to ignore."); + } catch { + // Ignore delete errors - table might not exist yet } - const { error: testError } = await this.client + + // Try to insert the test vector + const { error: insertError } = await this.client .from(this.tableName) .insert({ id: "test_vector", @@ -117,8 +121,9 @@ export class SupabaseDB implements VectorStore { }) .select(); - if (testError) { - console.error("Test insert error:", testError); + // If we get a duplicate key error, that's actually fine - it means the table exists + if (insertError && insertError.code !== "23505") { + console.error("Test insert error:", insertError); throw new Error( `Vector operations failed. Please ensure: 1. The vector extension is enabled @@ -178,8 +183,12 @@ See the SQL migration instructions in the code comments.`, ); } - // Clean up test vector - await this.client.from(this.tableName).delete().eq("id", "test_vector"); + // Clean up test vector - ignore errors here too + try { + await this.client.from(this.tableName).delete().eq("id", "test_vector"); + } catch { + // Ignore delete errors + } console.log("Connected to Supabase successfully"); } catch (error) {