Added support for Ollama in TS SDK (#2345)

Co-authored-by: Dev Khant <devkhant24@gmail.com>
This commit is contained in:
Gaurav Agerwala
2025-03-19 09:08:19 -07:00
committed by GitHub
parent 540ada489b
commit 2ffe9922f3
7 changed files with 299 additions and 7 deletions

View File

@@ -31,9 +31,10 @@
"dist"
],
"scripts": {
"clean": "rm -rf dist",
"build": "npm run clean && prettier --check . && tsup",
"dev": "nodemon",
"clean": "rimraf dist",
"build": "npm run clean && npx prettier --check . && npx tsup",
"dev": "npx nodemon",
"start": "npx ts-node src/oss/examples/basic.ts",
"test": "jest",
"test:ts": "jest --config jest.config.js",
"test:watch": "jest --config jest.config.js --watch",
@@ -74,7 +75,9 @@
"dotenv": "^16.4.5",
"fix-tsup-cjs": "^1.2.0",
"jest": "^29.7.0",
"nodemon": "^3.0.1",
"prettier": "^3.5.2",
"rimraf": "^5.0.5",
"ts-jest": "^29.2.6",
"ts-node": "^10.9.2",
"tsup": "^8.3.0",
@@ -96,7 +99,8 @@
"groq-sdk": "0.3.0",
"pg": "8.11.3",
"redis": "4.7.0",
"sqlite3": "5.1.7"
"sqlite3": "5.1.7",
"ollama": "^0.5.14"
},
"peerDependenciesMeta": {
"posthog-node": {

View File

@@ -116,6 +116,36 @@ async function runTests(memory: Memory) {
}
}
async function demoLocalMemory() {
console.log("\n=== Testing In-Memory Vector Store with Ollama===\n");
const memory = new Memory({
version: "v1.1",
embedder: {
provider: "ollama",
config: {
model: "nomic-embed-text:latest",
},
},
vectorStore: {
provider: "memory",
config: {
collectionName: "memories",
dimension: 768, // 768 is the dimension of the nomic-embed-text model
},
},
llm: {
provider: "ollama",
config: {
model: "llama3.1:8b",
},
},
// historyDbPath: "memory.db",
});
await runTests(memory);
}
async function demoMemoryStore() {
console.log("\n=== Testing In-Memory Vector Store ===\n");
@@ -346,6 +376,9 @@ async function main() {
// Test in-memory store
await demoMemoryStore();
// Test in-memory store with Ollama
await demoLocalMemory();
// Test graph memory if Neo4j environment variables are set
if (
process.env.NEO4J_URL &&
@@ -384,4 +417,4 @@ async function main() {
}
}
// main();
main();

View File

@@ -0,0 +1,93 @@
import { Memory } from "../src";
import { Ollama } from "ollama";
import * as readline from "readline";
const memory = new Memory({
embedder: {
provider: "ollama",
config: {
model: "nomic-embed-text:latest",
},
},
vectorStore: {
provider: "memory",
config: {
collectionName: "memories",
dimension: 768, // since we are using nomic-embed-text
},
},
llm: {
provider: "ollama",
config: {
model: "llama3.1:8b",
},
},
historyDbPath: "local-llms.db",
});
async function chatWithMemories(message: string, userId = "default_user") {
const relevantMemories = await memory.search(message, { userId: userId });
const memoriesStr = relevantMemories.results
.map((entry) => `- ${entry.memory}`)
.join("\n");
const systemPrompt = `You are a helpful AI. Answer the question based on query and memories.
User Memories:
${memoriesStr}`;
const messages = [
{ role: "system", content: systemPrompt },
{ role: "user", content: message },
];
const ollama = new Ollama();
const response = await ollama.chat({
model: "llama3.1:8b",
messages: messages,
});
const assistantResponse = response.message.content || "";
messages.push({ role: "assistant", content: assistantResponse });
await memory.add(messages, { userId: userId });
return assistantResponse;
}
async function main() {
const rl = readline.createInterface({
input: process.stdin,
output: process.stdout,
});
console.log("Chat with AI (type 'exit' to quit)");
const askQuestion = (): Promise<string> => {
return new Promise((resolve) => {
rl.question("You: ", (input) => {
resolve(input.trim());
});
});
};
try {
while (true) {
const userInput = await askQuestion();
if (userInput.toLowerCase() === "exit") {
console.log("Goodbye!");
rl.close();
break;
}
const response = await chatWithMemories(userInput, "sample_user");
console.log(`AI: ${response}`);
}
} catch (error) {
console.error("An error occurred:", error);
rl.close();
}
}
main().catch(console.error);

View File

@@ -0,0 +1,52 @@
import { Ollama } from "ollama";
import { Embedder } from "./base";
import { EmbeddingConfig } from "../types";
import { logger } from "../utils/logger";
export class OllamaEmbedder implements Embedder {
private ollama: Ollama;
private model: string;
// Using this variable to avoid calling the Ollama server multiple times
private initialized: boolean = false;
constructor(config: EmbeddingConfig) {
this.ollama = new Ollama({
host: config.url || "http://localhost:11434",
});
this.model = config.model || "nomic-embed-text:latest";
this.ensureModelExists().catch((err) => {
logger.error(`Error ensuring model exists: ${err}`);
});
}
async embed(text: string): Promise<number[]> {
try {
await this.ensureModelExists();
} catch (err) {
logger.error(`Error ensuring model exists: ${err}`);
}
const response = await this.ollama.embeddings({
model: this.model,
prompt: text,
});
return response.embedding;
}
async embedBatch(texts: string[]): Promise<number[][]> {
const response = await Promise.all(texts.map((text) => this.embed(text)));
return response;
}
private async ensureModelExists(): Promise<boolean> {
if (this.initialized) {
return true;
}
const local_models = await this.ollama.list();
if (!local_models.models.find((m: any) => m.name === this.model)) {
logger.info(`Pulling model ${this.model}...`);
await this.ollama.pull({ model: this.model });
}
this.initialized = true;
return true;
}
}

View File

@@ -0,0 +1,104 @@
import { Ollama } from "ollama";
import { LLM, LLMResponse } from "./base";
import { LLMConfig, Message } from "../types";
import { logger } from "../utils/logger";
export class OllamaLLM implements LLM {
private ollama: Ollama;
private model: string;
// Using this variable to avoid calling the Ollama server multiple times
private initialized: boolean = false;
constructor(config: LLMConfig) {
this.ollama = new Ollama({
host: config.config?.url || "http://localhost:11434",
});
this.model = config.model || "llama3.1:8b";
this.ensureModelExists().catch((err) => {
logger.error(`Error ensuring model exists: ${err}`);
});
}
async generateResponse(
messages: Message[],
responseFormat?: { type: string },
tools?: any[],
): Promise<string | LLMResponse> {
try {
await this.ensureModelExists();
} catch (err) {
logger.error(`Error ensuring model exists: ${err}`);
}
const completion = await this.ollama.chat({
model: this.model,
messages: messages.map((msg) => {
const role = msg.role as "system" | "user" | "assistant";
return {
role,
content:
typeof msg.content === "string"
? msg.content
: JSON.stringify(msg.content),
};
}),
...(responseFormat?.type === "json_object" && { format: "json" }),
...(tools && { tools, tool_choice: "auto" }),
});
const response = completion.message;
if (response.tool_calls) {
return {
content: response.content || "",
role: response.role,
toolCalls: response.tool_calls.map((call) => ({
name: call.function.name,
arguments: JSON.stringify(call.function.arguments),
})),
};
}
return response.content || "";
}
async generateChat(messages: Message[]): Promise<LLMResponse> {
try {
await this.ensureModelExists();
} catch (err) {
logger.error(`Error ensuring model exists: ${err}`);
}
const completion = await this.ollama.chat({
messages: messages.map((msg) => {
const role = msg.role as "system" | "user" | "assistant";
return {
role,
content:
typeof msg.content === "string"
? msg.content
: JSON.stringify(msg.content),
};
}),
model: this.model,
});
const response = completion.message;
return {
content: response.content || "",
role: response.role,
};
}
private async ensureModelExists(): Promise<boolean> {
if (this.initialized) {
return true;
}
const local_models = await this.ollama.list();
if (!local_models.models.find((m: any) => m.name === this.model)) {
logger.info(`Pulling model ${this.model}...`);
await this.ollama.pull({ model: this.model });
}
this.initialized = true;
return true;
}
}

View File

@@ -13,8 +13,9 @@ export interface Message {
}
export interface EmbeddingConfig {
apiKey: string;
apiKey?: string;
model?: string;
url?: string;
}
export interface VectorStoreConfig {

View File

@@ -1,4 +1,5 @@
import { OpenAIEmbedder } from "../embeddings/openai";
import { OllamaEmbedder } from "../embeddings/ollama";
import { OpenAILLM } from "../llms/openai";
import { OpenAIStructuredLLM } from "../llms/openai_structured";
import { AnthropicLLM } from "../llms/anthropic";
@@ -10,12 +11,14 @@ import { LLM } from "../llms/base";
import { VectorStore } from "../vector_stores/base";
import { Qdrant } from "../vector_stores/qdrant";
import { RedisDB } from "../vector_stores/redis";
import { OllamaLLM } from "../llms/ollama";
export class EmbedderFactory {
static create(provider: string, config: EmbeddingConfig): Embedder {
switch (provider.toLowerCase()) {
case "openai":
return new OpenAIEmbedder(config);
case "ollama":
return new OllamaEmbedder(config);
default:
throw new Error(`Unsupported embedder provider: ${provider}`);
}
@@ -33,6 +36,8 @@ export class LLMFactory {
return new AnthropicLLM(config);
case "groq":
return new GroqLLM(config);
case "ollama":
return new OllamaLLM(config);
default:
throw new Error(`Unsupported LLM provider: ${provider}`);
}