Adds Azure OpenAI LLM to Mem0 TS SDK (#2536)
@@ -1,7 +1,7 @@
 // @ts-nocheck
 import type { TelemetryClient, TelemetryOptions } from "./telemetry.types";
 
-let version = "2.1.12";
+let version = "2.1.16";
 
 // Safely check for process.env in different environments
 let MEM0_TELEMETRY = true;
@@ -42,6 +42,9 @@ export class ConfigManager {
           model:
             userConfig.llm?.config?.model ||
             DEFAULT_MEMORY_CONFIG.llm.config.model,
+          modelProperties:
+            userConfig.llm?.config?.modelProperties ||
+            DEFAULT_MEMORY_CONFIG.llm.config.modelProperties,
         },
       },
       historyDbPath:
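For clarity, a minimal sketch of the fallback added above, not part of the commit: a user-supplied modelProperties object takes precedence, otherwise the SDK default applies. The default value is an assumption, since DEFAULT_MEMORY_CONFIG is not shown in this diff.

// Sketch only (not part of this commit). DEFAULT_MODEL_PROPERTIES stands in
// for DEFAULT_MEMORY_CONFIG.llm.config.modelProperties, whose real value is
// not visible in this diff.
const DEFAULT_MODEL_PROPERTIES: Record<string, any> | undefined = undefined;
const userConfig = {
  llm: {
    config: {
      modelProperties: { endpoint: "https://<resource>.openai.azure.com" },
    },
  },
};
// User-supplied properties win; otherwise fall back to the default.
const modelProperties =
  userConfig.llm?.config?.modelProperties || DEFAULT_MODEL_PROPERTIES;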
mem0-ts/src/oss/src/llms/azure.ts (new file, 82 lines)
@@ -0,0 +1,82 @@
+import { AzureOpenAI } from "openai";
+import { LLM, LLMResponse } from "./base";
+import { LLMConfig, Message } from "../types";
+
+export class AzureOpenAILLM implements LLM {
+  private client: AzureOpenAI;
+  private model: string;
+
+  constructor(config: LLMConfig) {
+    if (!config.apiKey || !config.modelProperties?.endpoint) {
+      throw new Error("Azure OpenAI requires both API key and endpoint");
+    }
+
+    const { endpoint, ...rest } = config.modelProperties;
+
+    this.client = new AzureOpenAI({
+      apiKey: config.apiKey,
+      endpoint: endpoint as string,
+      ...rest,
+    });
+    this.model = config.model || "gpt-4";
+  }
+
+  async generateResponse(
+    messages: Message[],
+    responseFormat?: { type: string },
+    tools?: any[],
+  ): Promise<string | LLMResponse> {
+    const completion = await this.client.chat.completions.create({
+      messages: messages.map((msg) => {
+        const role = msg.role as "system" | "user" | "assistant";
+        return {
+          role,
+          content:
+            typeof msg.content === "string"
+              ? msg.content
+              : JSON.stringify(msg.content),
+        };
+      }),
+      model: this.model,
+      response_format: responseFormat as { type: "text" | "json_object" },
+      ...(tools && { tools, tool_choice: "auto" }),
+    });
+
+    const response = completion.choices[0].message;
+
+    if (response.tool_calls) {
+      return {
+        content: response.content || "",
+        role: response.role,
+        toolCalls: response.tool_calls.map((call) => ({
+          name: call.function.name,
+          arguments: call.function.arguments,
+        })),
+      };
+    }
+
+    return response.content || "";
+  }
+
+  async generateChat(messages: Message[]): Promise<LLMResponse> {
+    const completion = await this.client.chat.completions.create({
+      messages: messages.map((msg) => {
+        const role = msg.role as "system" | "user" | "assistant";
+        return {
+          role,
+          content:
+            typeof msg.content === "string"
+              ? msg.content
+              : JSON.stringify(msg.content),
+        };
+      }),
+      model: this.model,
+    });
+
+    const response = completion.choices[0].message;
+    return {
+      content: response.content || "",
+      role: response.role,
+    };
+  }
+}
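For reviewers, a hedged usage sketch of the new class, not part of the commit. The Message shape { role, content } is inferred from the mapping in generateResponse; the endpoint and key are placeholders, and apiVersion is an assumption that the underlying AzureOpenAI client may require.

// Usage sketch (illustrative only). Values are placeholders; apiVersion is
// an assumption, forwarded to the AzureOpenAI client via the ...rest spread.
import { AzureOpenAILLM } from "./llms/azure";

async function demo() {
  const llm = new AzureOpenAILLM({
    apiKey: process.env.AZURE_OPENAI_API_KEY!,
    model: "gpt-4", // Azure deployment name; defaults to "gpt-4" when omitted
    modelProperties: {
      endpoint: "https://<your-resource>.openai.azure.com", // required by the constructor
      apiVersion: "2024-02-01", // assumed; passed through ...rest
    },
  });
  // Returns a plain string, or an LLMResponse when the model makes tool calls.
  const reply = await llm.generateResponse([
    { role: "user", content: "Say hello." },
  ]);
  console.log(reply);
}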
@@ -39,6 +39,7 @@ export interface LLMConfig {
   config?: Record<string, any>;
   apiKey?: string;
   model?: string;
+  modelProperties?: Record<string, any>;
 }
 
 export interface Neo4jConfig {
@@ -127,6 +128,7 @@ export const MemoryConfigSchema = z.object({
     config: z.object({
       apiKey: z.string(),
       model: z.string().optional(),
+      modelProperties: z.record(z.string(), z.any()).optional(),
     }),
   }),
   historyDbPath: z.string().optional(),
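A quick sketch of what the extended schema accepts, illustrative only: the inner z.object fields are reproduced from the lines above, while the rest of MemoryConfigSchema is assumed and not shown here.

// Illustrative only: the llm config fragment the extended schema accepts.
import { z } from "zod";

const llmConfigSchema = z.object({
  apiKey: z.string(),
  model: z.string().optional(),
  modelProperties: z.record(z.string(), z.any()).optional(),
});

llmConfigSchema.parse({
  apiKey: "<azure-openai-key>",
  model: "gpt-4",
  modelProperties: { endpoint: "https://<resource>.openai.azure.com" },
}); // passes; modelProperties may hold arbitrary provider-specific keys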
@@ -25,6 +25,7 @@ import { SupabaseHistoryManager } from "../storage/SupabaseHistoryManager";
 import { HistoryManager } from "../storage/base";
 import { GoogleEmbedder } from "../embeddings/google";
 import { GoogleLLM } from "../llms/google";
+import { AzureOpenAILLM } from "../llms/azure";
 
 export class EmbedderFactory {
   static create(provider: string, config: EmbeddingConfig): Embedder {
@@ -56,6 +57,8 @@ export class LLMFactory {
       return new OllamaLLM(config);
     case "google":
       return new GoogleLLM(config);
+    case "azure_openai":
+      return new AzureOpenAILLM(config);
     case "mistral":
       return new MistralLLM(config);
     default:
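The new provider can then be selected through the factory. A sketch, with the caveat that LLMFactory.create's signature is inferred from the EmbedderFactory.create signature shown above, not confirmed by this diff:

// Sketch only: LLMFactory.create(provider, config) is inferred, and the
// endpoint and key values are placeholders.
const llm = LLMFactory.create("azure_openai", {
  apiKey: process.env.AZURE_OPENAI_API_KEY!,
  model: "gpt-4",
  modelProperties: { endpoint: "https://<resource>.openai.azure.com" },
});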
@@ -4,7 +4,7 @@ import type {
   TelemetryEventData,
 } from "./telemetry.types";
 
-let version = "2.1.15";
+let version = "2.1.16";
 
 // Safely check for process.env in different environments
 let MEM0_TELEMETRY = true;