diff --git a/mem0-ts/package.json b/mem0-ts/package.json index 2b0e6fc4..4cc1ffc6 100644 --- a/mem0-ts/package.json +++ b/mem0-ts/package.json @@ -31,9 +31,10 @@ "dist" ], "scripts": { - "clean": "rm -rf dist", - "build": "npm run clean && prettier --check . && tsup", - "dev": "nodemon", + "clean": "rimraf dist", + "build": "npm run clean && npx prettier --check . && npx tsup", + "dev": "npx nodemon", + "start": "npx ts-node src/oss/examples/basic.ts", "test": "jest", "test:ts": "jest --config jest.config.js", "test:watch": "jest --config jest.config.js --watch", @@ -74,7 +75,9 @@ "dotenv": "^16.4.5", "fix-tsup-cjs": "^1.2.0", "jest": "^29.7.0", + "nodemon": "^3.0.1", "prettier": "^3.5.2", + "rimraf": "^5.0.5", "ts-jest": "^29.2.6", "ts-node": "^10.9.2", "tsup": "^8.3.0", @@ -96,7 +99,8 @@ "groq-sdk": "0.3.0", "pg": "8.11.3", "redis": "4.7.0", - "sqlite3": "5.1.7" + "sqlite3": "5.1.7", + "ollama": "^0.5.14" }, "peerDependenciesMeta": { "posthog-node": { diff --git a/mem0-ts/src/oss/examples/basic.ts b/mem0-ts/src/oss/examples/basic.ts index 39c7d1f5..b6c62d19 100644 --- a/mem0-ts/src/oss/examples/basic.ts +++ b/mem0-ts/src/oss/examples/basic.ts @@ -116,6 +116,36 @@ async function runTests(memory: Memory) { } } +async function demoLocalMemory() { + console.log("\n=== Testing In-Memory Vector Store with Ollama===\n"); + + const memory = new Memory({ + version: "v1.1", + embedder: { + provider: "ollama", + config: { + model: "nomic-embed-text:latest", + }, + }, + vectorStore: { + provider: "memory", + config: { + collectionName: "memories", + dimension: 768, // 768 is the dimension of the nomic-embed-text model + }, + }, + llm: { + provider: "ollama", + config: { + model: "llama3.1:8b", + }, + }, + // historyDbPath: "memory.db", + }); + + await runTests(memory); +} + async function demoMemoryStore() { console.log("\n=== Testing In-Memory Vector Store ===\n"); @@ -346,6 +376,9 @@ async function main() { // Test in-memory store await demoMemoryStore(); + // Test in-memory store with Ollama + await demoLocalMemory(); + // Test graph memory if Neo4j environment variables are set if ( process.env.NEO4J_URL && @@ -384,4 +417,4 @@ async function main() { } } -// main(); +main(); diff --git a/mem0-ts/src/oss/examples/local-llms.ts b/mem0-ts/src/oss/examples/local-llms.ts new file mode 100644 index 00000000..29a8812c --- /dev/null +++ b/mem0-ts/src/oss/examples/local-llms.ts @@ -0,0 +1,93 @@ +import { Memory } from "../src"; +import { Ollama } from "ollama"; +import * as readline from "readline"; + +const memory = new Memory({ + embedder: { + provider: "ollama", + config: { + model: "nomic-embed-text:latest", + }, + }, + vectorStore: { + provider: "memory", + config: { + collectionName: "memories", + dimension: 768, // since we are using nomic-embed-text + }, + }, + llm: { + provider: "ollama", + config: { + model: "llama3.1:8b", + }, + }, + historyDbPath: "local-llms.db", +}); + +async function chatWithMemories(message: string, userId = "default_user") { + const relevantMemories = await memory.search(message, { userId: userId }); + + const memoriesStr = relevantMemories.results + .map((entry) => `- ${entry.memory}`) + .join("\n"); + + const systemPrompt = `You are a helpful AI. Answer the question based on query and memories. +User Memories: +${memoriesStr}`; + + const messages = [ + { role: "system", content: systemPrompt }, + { role: "user", content: message }, + ]; + + const ollama = new Ollama(); + const response = await ollama.chat({ + model: "llama3.1:8b", + messages: messages, + }); + + const assistantResponse = response.message.content || ""; + + messages.push({ role: "assistant", content: assistantResponse }); + await memory.add(messages, { userId: userId }); + + return assistantResponse; +} + +async function main() { + const rl = readline.createInterface({ + input: process.stdin, + output: process.stdout, + }); + + console.log("Chat with AI (type 'exit' to quit)"); + + const askQuestion = (): Promise => { + return new Promise((resolve) => { + rl.question("You: ", (input) => { + resolve(input.trim()); + }); + }); + }; + + try { + while (true) { + const userInput = await askQuestion(); + + if (userInput.toLowerCase() === "exit") { + console.log("Goodbye!"); + rl.close(); + break; + } + + const response = await chatWithMemories(userInput, "sample_user"); + console.log(`AI: ${response}`); + } + } catch (error) { + console.error("An error occurred:", error); + rl.close(); + } +} + +main().catch(console.error); diff --git a/mem0-ts/src/oss/src/embeddings/ollama.ts b/mem0-ts/src/oss/src/embeddings/ollama.ts new file mode 100644 index 00000000..a738bd96 --- /dev/null +++ b/mem0-ts/src/oss/src/embeddings/ollama.ts @@ -0,0 +1,52 @@ +import { Ollama } from "ollama"; +import { Embedder } from "./base"; +import { EmbeddingConfig } from "../types"; +import { logger } from "../utils/logger"; + +export class OllamaEmbedder implements Embedder { + private ollama: Ollama; + private model: string; + // Using this variable to avoid calling the Ollama server multiple times + private initialized: boolean = false; + + constructor(config: EmbeddingConfig) { + this.ollama = new Ollama({ + host: config.url || "http://localhost:11434", + }); + this.model = config.model || "nomic-embed-text:latest"; + this.ensureModelExists().catch((err) => { + logger.error(`Error ensuring model exists: ${err}`); + }); + } + + async embed(text: string): Promise { + try { + await this.ensureModelExists(); + } catch (err) { + logger.error(`Error ensuring model exists: ${err}`); + } + const response = await this.ollama.embeddings({ + model: this.model, + prompt: text, + }); + return response.embedding; + } + + async embedBatch(texts: string[]): Promise { + const response = await Promise.all(texts.map((text) => this.embed(text))); + return response; + } + + private async ensureModelExists(): Promise { + if (this.initialized) { + return true; + } + const local_models = await this.ollama.list(); + if (!local_models.models.find((m: any) => m.name === this.model)) { + logger.info(`Pulling model ${this.model}...`); + await this.ollama.pull({ model: this.model }); + } + this.initialized = true; + return true; + } +} diff --git a/mem0-ts/src/oss/src/llms/ollama.ts b/mem0-ts/src/oss/src/llms/ollama.ts new file mode 100644 index 00000000..c076f5d3 --- /dev/null +++ b/mem0-ts/src/oss/src/llms/ollama.ts @@ -0,0 +1,104 @@ +import { Ollama } from "ollama"; +import { LLM, LLMResponse } from "./base"; +import { LLMConfig, Message } from "../types"; +import { logger } from "../utils/logger"; + +export class OllamaLLM implements LLM { + private ollama: Ollama; + private model: string; + // Using this variable to avoid calling the Ollama server multiple times + private initialized: boolean = false; + + constructor(config: LLMConfig) { + this.ollama = new Ollama({ + host: config.config?.url || "http://localhost:11434", + }); + this.model = config.model || "llama3.1:8b"; + this.ensureModelExists().catch((err) => { + logger.error(`Error ensuring model exists: ${err}`); + }); + } + + async generateResponse( + messages: Message[], + responseFormat?: { type: string }, + tools?: any[], + ): Promise { + try { + await this.ensureModelExists(); + } catch (err) { + logger.error(`Error ensuring model exists: ${err}`); + } + + const completion = await this.ollama.chat({ + model: this.model, + messages: messages.map((msg) => { + const role = msg.role as "system" | "user" | "assistant"; + return { + role, + content: + typeof msg.content === "string" + ? msg.content + : JSON.stringify(msg.content), + }; + }), + ...(responseFormat?.type === "json_object" && { format: "json" }), + ...(tools && { tools, tool_choice: "auto" }), + }); + + const response = completion.message; + + if (response.tool_calls) { + return { + content: response.content || "", + role: response.role, + toolCalls: response.tool_calls.map((call) => ({ + name: call.function.name, + arguments: JSON.stringify(call.function.arguments), + })), + }; + } + + return response.content || ""; + } + + async generateChat(messages: Message[]): Promise { + try { + await this.ensureModelExists(); + } catch (err) { + logger.error(`Error ensuring model exists: ${err}`); + } + + const completion = await this.ollama.chat({ + messages: messages.map((msg) => { + const role = msg.role as "system" | "user" | "assistant"; + return { + role, + content: + typeof msg.content === "string" + ? msg.content + : JSON.stringify(msg.content), + }; + }), + model: this.model, + }); + const response = completion.message; + return { + content: response.content || "", + role: response.role, + }; + } + + private async ensureModelExists(): Promise { + if (this.initialized) { + return true; + } + const local_models = await this.ollama.list(); + if (!local_models.models.find((m: any) => m.name === this.model)) { + logger.info(`Pulling model ${this.model}...`); + await this.ollama.pull({ model: this.model }); + } + this.initialized = true; + return true; + } +} diff --git a/mem0-ts/src/oss/src/types/index.ts b/mem0-ts/src/oss/src/types/index.ts index 2dcd5cab..62d2d473 100644 --- a/mem0-ts/src/oss/src/types/index.ts +++ b/mem0-ts/src/oss/src/types/index.ts @@ -13,8 +13,9 @@ export interface Message { } export interface EmbeddingConfig { - apiKey: string; + apiKey?: string; model?: string; + url?: string; } export interface VectorStoreConfig { diff --git a/mem0-ts/src/oss/src/utils/factory.ts b/mem0-ts/src/oss/src/utils/factory.ts index 80f69682..cfdbe1af 100644 --- a/mem0-ts/src/oss/src/utils/factory.ts +++ b/mem0-ts/src/oss/src/utils/factory.ts @@ -1,4 +1,5 @@ import { OpenAIEmbedder } from "../embeddings/openai"; +import { OllamaEmbedder } from "../embeddings/ollama"; import { OpenAILLM } from "../llms/openai"; import { OpenAIStructuredLLM } from "../llms/openai_structured"; import { AnthropicLLM } from "../llms/anthropic"; @@ -10,12 +11,14 @@ import { LLM } from "../llms/base"; import { VectorStore } from "../vector_stores/base"; import { Qdrant } from "../vector_stores/qdrant"; import { RedisDB } from "../vector_stores/redis"; - +import { OllamaLLM } from "../llms/ollama"; export class EmbedderFactory { static create(provider: string, config: EmbeddingConfig): Embedder { switch (provider.toLowerCase()) { case "openai": return new OpenAIEmbedder(config); + case "ollama": + return new OllamaEmbedder(config); default: throw new Error(`Unsupported embedder provider: ${provider}`); } @@ -33,6 +36,8 @@ export class LLMFactory { return new AnthropicLLM(config); case "groq": return new GroqLLM(config); + case "ollama": + return new OllamaLLM(config); default: throw new Error(`Unsupported LLM provider: ${provider}`); }