Added support for Ollama in the TS SDK (#2345)
Co-authored-by: Dev Khant <devkhant24@gmail.com>
@@ -31,9 +31,10 @@
     "dist"
   ],
   "scripts": {
-    "clean": "rm -rf dist",
-    "build": "npm run clean && prettier --check . && tsup",
-    "dev": "nodemon",
+    "clean": "rimraf dist",
+    "build": "npm run clean && npx prettier --check . && npx tsup",
+    "dev": "npx nodemon",
+    "start": "npx ts-node src/oss/examples/basic.ts",
     "test": "jest",
     "test:ts": "jest --config jest.config.js",
     "test:watch": "jest --config jest.config.js --watch",
@@ -74,7 +75,9 @@
     "dotenv": "^16.4.5",
     "fix-tsup-cjs": "^1.2.0",
     "jest": "^29.7.0",
     "nodemon": "^3.0.1",
     "prettier": "^3.5.2",
+    "rimraf": "^5.0.5",
     "ts-jest": "^29.2.6",
+    "ts-node": "^10.9.2",
     "tsup": "^8.3.0",
@@ -96,7 +99,8 @@
     "groq-sdk": "0.3.0",
     "pg": "8.11.3",
     "redis": "4.7.0",
-    "sqlite3": "5.1.7"
+    "sqlite3": "5.1.7",
+    "ollama": "^0.5.14"
   },
   "peerDependenciesMeta": {
     "posthog-node": {
@@ -116,6 +116,36 @@ async function runTests(memory: Memory) {
   }
 }
 
+async function demoLocalMemory() {
+  console.log("\n=== Testing In-Memory Vector Store with Ollama ===\n");
+
+  const memory = new Memory({
+    version: "v1.1",
+    embedder: {
+      provider: "ollama",
+      config: {
+        model: "nomic-embed-text:latest",
+      },
+    },
+    vectorStore: {
+      provider: "memory",
+      config: {
+        collectionName: "memories",
+        dimension: 768, // 768 is the dimension of the nomic-embed-text model
+      },
+    },
+    llm: {
+      provider: "ollama",
+      config: {
+        model: "llama3.1:8b",
+      },
+    },
+    // historyDbPath: "memory.db",
+  });
+
+  await runTests(memory);
+}
+
 async function demoMemoryStore() {
   console.log("\n=== Testing In-Memory Vector Store ===\n");
 
@@ -346,6 +376,9 @@ async function main() {
   // Test in-memory store
   await demoMemoryStore();
 
+  // Test in-memory store with Ollama
+  await demoLocalMemory();
+
   // Test graph memory if Neo4j environment variables are set
   if (
     process.env.NEO4J_URL &&
@@ -384,4 +417,4 @@ async function main() {
   }
 }
 
-// main();
+main();
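The demo above hands the Ollama-backed instance to runTests, which is outside this diff. For reviewers who want to exercise the new configuration in isolation, a minimal stand-in only needs the add/search API that local-llms.ts (below) also uses; smokeTest and its user id are illustrative names, not part of the PR.

import { Memory } from "../src";

// Hypothetical stand-in for runTests(): stores one exchange and queries it back.
async function smokeTest(memory: Memory) {
  const messages = [
    { role: "user", content: "I prefer running models locally with Ollama." },
    { role: "assistant", content: "Noted, local models via Ollama it is." },
  ];

  // memory.add() processes the exchange with the configured Ollama LLM.
  await memory.add(messages, { userId: "demo_user" });

  // Retrieval goes through the nomic-embed-text embeddings.
  const results = await memory.search("Which runtime does the user prefer?", {
    userId: "demo_user",
  });
  results.results.forEach((entry) => console.log(`- ${entry.memory}`));
}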
mem0-ts/src/oss/examples/local-llms.ts (new file, 93 lines)
@@ -0,0 +1,93 @@
import { Memory } from "../src";
import { Ollama } from "ollama";
import * as readline from "readline";

const memory = new Memory({
  embedder: {
    provider: "ollama",
    config: {
      model: "nomic-embed-text:latest",
    },
  },
  vectorStore: {
    provider: "memory",
    config: {
      collectionName: "memories",
      dimension: 768, // since we are using nomic-embed-text
    },
  },
  llm: {
    provider: "ollama",
    config: {
      model: "llama3.1:8b",
    },
  },
  historyDbPath: "local-llms.db",
});

async function chatWithMemories(message: string, userId = "default_user") {
  const relevantMemories = await memory.search(message, { userId: userId });

  const memoriesStr = relevantMemories.results
    .map((entry) => `- ${entry.memory}`)
    .join("\n");

  const systemPrompt = `You are a helpful AI. Answer the question based on query and memories.
User Memories:
${memoriesStr}`;

  const messages = [
    { role: "system", content: systemPrompt },
    { role: "user", content: message },
  ];

  const ollama = new Ollama();
  const response = await ollama.chat({
    model: "llama3.1:8b",
    messages: messages,
  });

  const assistantResponse = response.message.content || "";

  messages.push({ role: "assistant", content: assistantResponse });
  await memory.add(messages, { userId: userId });

  return assistantResponse;
}

async function main() {
  const rl = readline.createInterface({
    input: process.stdin,
    output: process.stdout,
  });

  console.log("Chat with AI (type 'exit' to quit)");

  const askQuestion = (): Promise<string> => {
    return new Promise((resolve) => {
      rl.question("You: ", (input) => {
        resolve(input.trim());
      });
    });
  };

  try {
    while (true) {
      const userInput = await askQuestion();

      if (userInput.toLowerCase() === "exit") {
        console.log("Goodbye!");
        rl.close();
        break;
      }

      const response = await chatWithMemories(userInput, "sample_user");
      console.log(`AI: ${response}`);
    }
  } catch (error) {
    console.error("An error occurred:", error);
    rl.close();
  }
}

main().catch(console.error);
mem0-ts/src/oss/src/embeddings/ollama.ts (new file, 52 lines)
@@ -0,0 +1,52 @@
import { Ollama } from "ollama";
import { Embedder } from "./base";
import { EmbeddingConfig } from "../types";
import { logger } from "../utils/logger";

export class OllamaEmbedder implements Embedder {
  private ollama: Ollama;
  private model: string;
  // Using this variable to avoid calling the Ollama server multiple times
  private initialized: boolean = false;

  constructor(config: EmbeddingConfig) {
    this.ollama = new Ollama({
      host: config.url || "http://localhost:11434",
    });
    this.model = config.model || "nomic-embed-text:latest";
    this.ensureModelExists().catch((err) => {
      logger.error(`Error ensuring model exists: ${err}`);
    });
  }

  async embed(text: string): Promise<number[]> {
    try {
      await this.ensureModelExists();
    } catch (err) {
      logger.error(`Error ensuring model exists: ${err}`);
    }
    const response = await this.ollama.embeddings({
      model: this.model,
      prompt: text,
    });
    return response.embedding;
  }

  async embedBatch(texts: string[]): Promise<number[][]> {
    const response = await Promise.all(texts.map((text) => this.embed(text)));
    return response;
  }

  private async ensureModelExists(): Promise<boolean> {
    if (this.initialized) {
      return true;
    }
    const local_models = await this.ollama.list();
    if (!local_models.models.find((m: any) => m.name === this.model)) {
      logger.info(`Pulling model ${this.model}...`);
      await this.ollama.pull({ model: this.model });
    }
    this.initialized = true;
    return true;
  }
}
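The embedder can also be used on its own; a minimal sketch, assuming a local Ollama server on the default port (the relative import path is illustrative and depends on where the calling file lives):

import { OllamaEmbedder } from "../embeddings/ollama"; // illustrative path

async function embedDemo() {
  // Both fields are optional and mirror the class defaults above.
  const embedder = new OllamaEmbedder({
    url: "http://localhost:11434",
    model: "nomic-embed-text:latest",
  });

  // The first call may pull the model via ensureModelExists(), so it can be slow.
  const vector = await embedder.embed("hello from the TS SDK");
  console.log(vector.length); // 768 for nomic-embed-text

  // embedBatch() fans out to embed() for each input text.
  const vectors = await embedder.embedBatch(["alpha", "beta"]);
  console.log(vectors.length); // 2
}

embedDemo().catch(console.error);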
mem0-ts/src/oss/src/llms/ollama.ts (new file, 104 lines)
@@ -0,0 +1,104 @@
import { Ollama } from "ollama";
import { LLM, LLMResponse } from "./base";
import { LLMConfig, Message } from "../types";
import { logger } from "../utils/logger";

export class OllamaLLM implements LLM {
  private ollama: Ollama;
  private model: string;
  // Using this variable to avoid calling the Ollama server multiple times
  private initialized: boolean = false;

  constructor(config: LLMConfig) {
    this.ollama = new Ollama({
      host: config.config?.url || "http://localhost:11434",
    });
    this.model = config.model || "llama3.1:8b";
    this.ensureModelExists().catch((err) => {
      logger.error(`Error ensuring model exists: ${err}`);
    });
  }

  async generateResponse(
    messages: Message[],
    responseFormat?: { type: string },
    tools?: any[],
  ): Promise<string | LLMResponse> {
    try {
      await this.ensureModelExists();
    } catch (err) {
      logger.error(`Error ensuring model exists: ${err}`);
    }

    const completion = await this.ollama.chat({
      model: this.model,
      messages: messages.map((msg) => {
        const role = msg.role as "system" | "user" | "assistant";
        return {
          role,
          content:
            typeof msg.content === "string"
              ? msg.content
              : JSON.stringify(msg.content),
        };
      }),
      ...(responseFormat?.type === "json_object" && { format: "json" }),
      ...(tools && { tools, tool_choice: "auto" }),
    });

    const response = completion.message;

    if (response.tool_calls) {
      return {
        content: response.content || "",
        role: response.role,
        toolCalls: response.tool_calls.map((call) => ({
          name: call.function.name,
          arguments: JSON.stringify(call.function.arguments),
        })),
      };
    }

    return response.content || "";
  }

  async generateChat(messages: Message[]): Promise<LLMResponse> {
    try {
      await this.ensureModelExists();
    } catch (err) {
      logger.error(`Error ensuring model exists: ${err}`);
    }

    const completion = await this.ollama.chat({
      messages: messages.map((msg) => {
        const role = msg.role as "system" | "user" | "assistant";
        return {
          role,
          content:
            typeof msg.content === "string"
              ? msg.content
              : JSON.stringify(msg.content),
        };
      }),
      model: this.model,
    });
    const response = completion.message;
    return {
      content: response.content || "",
      role: response.role,
    };
  }

  private async ensureModelExists(): Promise<boolean> {
    if (this.initialized) {
      return true;
    }
    const local_models = await this.ollama.list();
    if (!local_models.models.find((m: any) => m.name === this.model)) {
      logger.info(`Pulling model ${this.model}...`);
      await this.ollama.pull({ model: this.model });
    }
    this.initialized = true;
    return true;
  }
}
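The LLM wrapper can be driven directly as well. Note that the constructor reads the model from config.model but the host from config.config?.url; the sketch below assumes an LLMConfig shaped accordingly (the full interface is not shown in this diff, hence the cast), with illustrative host, model, and prompts.

import { OllamaLLM } from "../llms/ollama"; // illustrative path

async function llmDemo() {
  // model comes from config.model, the host from config.config?.url (see constructor).
  const llm = new OllamaLLM({
    model: "llama3.1:8b",
    config: { url: "http://localhost:11434" },
  } as any); // cast only because the exact LLMConfig shape is not shown in this diff

  // With no tools and no tool calls in the reply, generateResponse resolves to a string.
  const answer = await llm.generateResponse([
    { role: "system", content: "Answer in one short sentence." },
    { role: "user", content: "What does mem0 store?" },
  ]);
  console.log(answer);

  // responseFormat { type: "json_object" } maps onto Ollama's format: "json" option.
  const json = await llm.generateResponse(
    [{ role: "user", content: 'Return {"ok": true} as JSON.' }],
    { type: "json_object" },
  );
  console.log(json);
}

llmDemo().catch(console.error);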
@@ -13,8 +13,9 @@ export interface Message {
 }
 
 export interface EmbeddingConfig {
-  apiKey: string;
+  apiKey?: string;
   model?: string;
+  url?: string;
 }
 
 export interface VectorStoreConfig {
@@ -1,4 +1,5 @@
 import { OpenAIEmbedder } from "../embeddings/openai";
+import { OllamaEmbedder } from "../embeddings/ollama";
 import { OpenAILLM } from "../llms/openai";
 import { OpenAIStructuredLLM } from "../llms/openai_structured";
 import { AnthropicLLM } from "../llms/anthropic";
@@ -10,12 +11,14 @@ import { LLM } from "../llms/base";
 import { VectorStore } from "../vector_stores/base";
 import { Qdrant } from "../vector_stores/qdrant";
 import { RedisDB } from "../vector_stores/redis";
-
+import { OllamaLLM } from "../llms/ollama";
 export class EmbedderFactory {
   static create(provider: string, config: EmbeddingConfig): Embedder {
     switch (provider.toLowerCase()) {
       case "openai":
         return new OpenAIEmbedder(config);
+      case "ollama":
+        return new OllamaEmbedder(config);
       default:
         throw new Error(`Unsupported embedder provider: ${provider}`);
     }
@@ -33,6 +36,8 @@ export class LLMFactory {
         return new AnthropicLLM(config);
       case "groq":
         return new GroqLLM(config);
+      case "ollama":
+        return new OllamaLLM(config);
       default:
         throw new Error(`Unsupported LLM provider: ${provider}`);
     }
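With these cases registered, callers can stay on the provider-string path instead of importing the Ollama classes directly; LLMFactory.create("ollama", config) follows the same pattern as the second hunk. A minimal sketch against EmbedderFactory, whose signature this diff does show (the import path is illustrative):

import { EmbedderFactory } from "./factory"; // illustrative path

async function factoryDemo() {
  // EmbeddingConfig.apiKey is now optional, so an Ollama config needs no API key.
  const embedder = EmbedderFactory.create("ollama", {
    model: "nomic-embed-text:latest",
    url: "http://localhost:11434",
  });

  const vector = await embedder.embed("created via the factory");
  console.log(vector.length); // 768 for nomic-embed-text

  // Unrecognized provider strings still throw `Unsupported embedder provider: ...`.
}

factoryDemo().catch(console.error);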