Graph Support for NodeSDK (#2298)

This commit is contained in:
Saket Aryan
2025-03-05 12:52:50 +05:30
committed by GitHub
parent 23dbce4f59
commit 6fdc63504a
21 changed files with 1676 additions and 112 deletions

View File

@@ -16,10 +16,11 @@ Users can add a customized prompt that will be used to extract specific entities
This allows for more targeted and relevant information extraction based on the user's needs.
Here's an example of how to add a customized prompt:
```python
from mem0 import Memory
<CodeGroup>
```python Python
from mem0 import Memory
config = {
config = {
"graph_store": {
"provider": "neo4j",
"config": {
@@ -29,10 +30,29 @@ config = {
},
"custom_prompt": "Please only extract entities containing sports related relationships and nothing else.",
}
}
}
m = Memory.from_config(config_dict=config)
```
m = Memory.from_config(config_dict=config)
```
```typescript TypeScript
import { Memory } from "mem0ai/oss";
const config = {
graphStore: {
provider: "neo4j",
config: {
url: "neo4j+s://xxx",
username: "neo4j",
password: "xxx",
},
customPrompt: "Please only extract entities containing sports related relationships and nothing else.",
}
}
const memory = new Memory(config);
```
</CodeGroup>
If you want to use a managed version of Mem0, please check out [Mem0](https://mem0.dev/pd). If you have any questions, please feel free to reach out to us using one of the following methods:

View File

@@ -9,14 +9,24 @@ Mem0 now supports **Graph Memory**.
With Graph Memory, users can now create and utilize complex relationships between pieces of information, allowing for more nuanced and context-aware responses.
This integration enables users to leverage the strengths of both vector-based and graph-based approaches, resulting in more accurate and comprehensive information retrieval and generation.
<Note>
NodeSDK now supports Graph Memory. 🎉
</Note>
## Installation
To use Mem0 with Graph Memory support, install it using pip:
```bash
<CodeGroup>
```bash Python
pip install "mem0ai[graph]"
```
```bash TypeScript
npm install mem0ai
```
</CodeGroup>
This command installs Mem0 along with the necessary dependencies for graph functionality.
Try Graph Memory on Google Colab.
@@ -42,6 +52,9 @@ Currently, we support Neo4j as a graph store provider. You can setup [Neo4j](htt
<Note>If you are using Neo4j locally, then you need to install [APOC plugins](https://neo4j.com/labs/apoc/4.1/installation/).</Note>
<Note>
If you are using NodeSDK, you need to pass `enableGraph` as `true` in the `config` object.
</Note>
Users can also customize the LLM for Graph Memory from the [Supported LLM list](https://docs.mem0.ai/components/llms/overview) with three levels of configuration:
@@ -53,7 +66,7 @@ Here's how you can do it:
<CodeGroup>
```python Basic
```python Python
from mem0 import Memory
config = {
@@ -70,9 +83,25 @@ config = {
m = Memory.from_config(config_dict=config)
```
```python Advanced (Custom LLM)
from mem0 import Memory
```typescript TypeScript
import { Memory } from "mem0ai/oss";
const config = {
enableGraph: true,
graphStore: {
provider: "neo4j",
config: {
url: "neo4j+s://xxx",
username: "neo4j",
password: "xxx",
}
}
}
const memory = new Memory(config);
```
```python Python (Advanced)
config = {
"llm": {
"provider": "openai",
@@ -101,6 +130,37 @@ config = {
m = Memory.from_config(config_dict=config)
```
```typescript TypeScript (Advanced)
const config = {
llm: {
provider: "openai",
config: {
model: "gpt-4o",
temperature: 0.2,
max_tokens: 2000,
}
},
enableGraph: true,
graphStore: {
provider: "neo4j",
config: {
url: "neo4j+s://xxx",
username: "neo4j",
password: "xxx",
},
llm: {
provider: "openai",
config: {
model: "gpt-4o-mini",
temperature: 0.0,
}
}
}
}
const memory = new Memory(config);
```
</CodeGroup>
## Graph Operations
@@ -109,14 +169,18 @@ The Mem0's graph supports the following operations:
### Add Memories
<Note>
If you are using Mem0 with Graph Memory, it is recommended to pass `user_id`. The default value of `user_id` (in case of graph memory) is `user`.
If you are using Mem0 with Graph Memory, it is recommended to pass `user_id`. Use `userId` in NodeSDK.
</Note>
<CodeGroup>
```python Code
```python Python
m.add("I like pizza", user_id="alice")
```
```typescript TypeScript
memory.add("I like pizza", { userId: "alice" });
```
```json Output
{'message': 'ok'}
```
@@ -126,10 +190,14 @@ m.add("I like pizza", user_id="alice")
### Get all memories
<CodeGroup>
```python Code
```python Python
m.get_all(user_id="alice")
```
```typescript TypeScript
memory.getAll({ userId: "alice" });
```
```json Output
{
'memories': [
@@ -157,10 +225,14 @@ m.get_all(user_id="alice")
### Search Memories
<CodeGroup>
```python Code
```python Python
m.search("tell me my name.", user_id="alice")
```
```typescript TypeScript
memory.search("tell me my name.", { userId: "alice" });
```
```json Output
{
'memories': [
@@ -187,10 +259,16 @@ m.search("tell me my name.", user_id="alice")
### Delete all Memories
```python
<CodeGroup>
```python Python
m.delete_all(user_id="alice")
```
```typescript TypeScript
memory.deleteAll({ userId: "alice" });
```
</CodeGroup>
# Example Usage
Here's an example of how to use Mem0's graph operations:
@@ -206,64 +284,110 @@ Below are the steps to add memories and visualize the graph:
<Steps>
<Step title="Add memory 'I like going to hikes'">
```python
<CodeGroup>
```python Python
m.add("I like going to hikes", user_id="alice123")
```
```typescript TypeScript
memory.add("I like going to hikes", { userId: "alice123" });
```
</CodeGroup>
![Graph Memory Visualization](/images/graph_memory/graph_example1.png)
</Step>
<Step title="Add memory 'I love to play badminton'">
```python
<CodeGroup>
```python Python
m.add("I love to play badminton", user_id="alice123")
```
```typescript TypeScript
memory.add("I love to play badminton", { userId: "alice123" });
```
</CodeGroup>
![Graph Memory Visualization](/images/graph_memory/graph_example2.png)
</Step>
<Step title="Add memory 'I hate playing badminton'">
```python
<CodeGroup>
```python Python
m.add("I hate playing badminton", user_id="alice123")
```
```typescript TypeScript
memory.add("I hate playing badminton", { userId: "alice123" });
```
</CodeGroup>
![Graph Memory Visualization](/images/graph_memory/graph_example3.png)
</Step>
<Step title="Add memory 'My friend name is john and john has a dog named tommy'">
```python
<CodeGroup>
```python Python
m.add("My friend name is john and john has a dog named tommy", user_id="alice123")
```
```typescript TypeScript
memory.add("My friend name is john and john has a dog named tommy", { userId: "alice123" });
```
</CodeGroup>
![Graph Memory Visualization](/images/graph_memory/graph_example4.png)
</Step>
<Step title="Add memory 'My name is Alice'">
```python
<CodeGroup>
```python Python
m.add("My name is Alice", user_id="alice123")
```
```typescript TypeScript
memory.add("My name is Alice", { userId: "alice123" });
```
</CodeGroup>
![Graph Memory Visualization](/images/graph_memory/graph_example5.png)
</Step>
<Step title="Add memory 'John loves to hike and Harry loves to hike as well'">
```python
<CodeGroup>
```python Python
m.add("John loves to hike and Harry loves to hike as well", user_id="alice123")
```
```typescript TypeScript
memory.add("John loves to hike and Harry loves to hike as well", { userId: "alice123" });
```
</CodeGroup>
![Graph Memory Visualization](/images/graph_memory/graph_example6.png)
</Step>
<Step title="Add memory 'My friend peter is the spiderman'">
```python
<CodeGroup>
```python Python
m.add("My friend peter is the spiderman", user_id="alice123")
```
```typescript TypeScript
memory.add("My friend peter is the spiderman", { userId: "alice123" });
```
</CodeGroup>
![Graph Memory Visualization](/images/graph_memory/graph_example7.png)
</Step>
@@ -274,10 +398,14 @@ m.add("My friend peter is the spiderman", user_id="alice123")
### Search Memories
<CodeGroup>
```python Code
```python Python
m.search("What is my name?", user_id="alice123")
```
```typescript TypeScript
memory.search("What is my name?", { userId: "alice123" });
```
```json Output
{
'memories': [...],
@@ -297,10 +425,14 @@ Below graph visualization shows what nodes and relationships are fetched from th
![Graph Memory Visualization](/images/graph_memory/graph_example8.png)
<CodeGroup>
```python Code
```python Python
m.search("Who is spiderman?", user_id="alice123")
```
```typescript TypeScript
memory.search("Who is spiderman?", { userId: "alice123" });
```
```json Output
{
'memories': [...],

View File

@@ -1,6 +1,6 @@
{
"name": "mem0ai",
"version": "2.0.1",
"version": "2.1.0",
"description": "The Memory Layer For Your AI Apps",
"main": "./dist/index.js",
"module": "./dist/index.mjs",
@@ -82,17 +82,18 @@
},
"dependencies": {
"axios": "1.7.7",
"neo4j-driver": "^5.28.1",
"openai": "4.28.0",
"uuid": "9.0.1",
"zod": "3.22.4"
},
"peerDependencies": {
"@anthropic-ai/sdk": "0.18.0",
"groq-sdk": "0.3.0",
"@qdrant/js-client-rest": "1.13.0",
"@types/jest": "29.5.14",
"@types/pg": "8.11.0",
"@types/sqlite3": "3.1.11",
"groq-sdk": "0.3.0",
"pg": "8.11.3",
"redis": "4.7.0",
"sqlite3": "5.1.7"

View File

@@ -29,7 +29,9 @@ async function runTests(memory: Memory) {
console.log("\nAdding a single memory...");
const result1 = await memory.add(
"Hi, my name is John and I am a software engineer.",
"user123",
{
userId: "john",
},
);
console.log("Added memory:", result1);
@@ -40,7 +42,9 @@ async function runTests(memory: Memory) {
{ role: "user", content: "What is your favorite city?" },
{ role: "assistant", content: "I love Paris, it is my favorite city." },
],
"user123",
{
userId: "john",
},
);
console.log("Added messages:", result2);
@@ -53,7 +57,9 @@ async function runTests(memory: Memory) {
content: "I love New York, it is my favorite city.",
},
],
"user123",
{
userId: "john",
},
);
console.log("Updated messages:", result3);
@@ -75,15 +81,16 @@ async function runTests(memory: Memory) {
// Get all memories
console.log("\nGetting all memories...");
const allMemories = await memory.getAll("user123");
const allMemories = await memory.getAll({
userId: "john",
});
console.log("All memories:", allMemories);
// Search for memories
console.log("\nSearching memories...");
const searchResult = await memory.search(
"What do you know about Paris?",
"user123",
);
const searchResult = await memory.search("What do you know about Paris?", {
userId: "john",
});
console.log("Search results:", searchResult);
// Get memory history
@@ -255,10 +262,103 @@ async function demoRedis() {
await runTests(memory);
}
async function demoGraphMemory() {
console.log("\n=== Testing Graph Memory Store ===\n");
const memory = new Memory({
version: "v1.1",
embedder: {
provider: "openai",
config: {
apiKey: process.env.OPENAI_API_KEY || "",
model: "text-embedding-3-small",
},
},
vectorStore: {
provider: "memory",
config: {
collectionName: "memories",
dimension: 1536,
},
},
llm: {
provider: "openai",
config: {
apiKey: process.env.OPENAI_API_KEY || "",
model: "gpt-4-turbo-preview",
},
},
graphStore: {
provider: "neo4j",
config: {
url: process.env.NEO4J_URL || "neo4j://localhost:7687",
username: process.env.NEO4J_USERNAME || "neo4j",
password: process.env.NEO4J_PASSWORD || "password",
},
llm: {
provider: "openai",
config: {
model: "gpt-4-turbo-preview",
},
},
},
historyDbPath: "memory.db",
});
try {
// Reset all memories
await memory.reset();
// Add memories with relationships
const result = await memory.add(
[
{
role: "user",
content: "Alice is Bob's sister and works as a doctor.",
},
{
role: "assistant",
content:
"I understand that Alice and Bob are siblings and Alice is a medical professional.",
},
{ role: "user", content: "Bob is married to Carol who is a teacher." },
],
{
userId: "john",
},
);
console.log("Added memories with relationships:", result);
// Search for connected information
const searchResult = await memory.search(
"Tell me about Bob's family connections",
{
userId: "john",
},
);
console.log("Search results with graph relationships:", searchResult);
} catch (error) {
console.error("Error in graph memory demo:", error);
}
}
async function main() {
// Test in-memory store
await demoMemoryStore();
// Test graph memory if Neo4j environment variables are set
if (
process.env.NEO4J_URL &&
process.env.NEO4J_USERNAME &&
process.env.NEO4J_PASSWORD
) {
await demoGraphMemory();
} else {
console.log(
"\nSkipping Graph Memory test - Neo4j environment variables not set",
);
}
// Test PGVector store if environment variables are set
if (process.env.PGVECTOR_DB) {
await demoPGVector();

View File

@@ -23,5 +23,20 @@ export const DEFAULT_MEMORY_CONFIG: MemoryConfig = {
model: "gpt-4-turbo-preview",
},
},
enableGraph: false,
graphStore: {
provider: "neo4j",
config: {
url: process.env.NEO4J_URL || "neo4j://localhost:7687",
username: process.env.NEO4J_USERNAME || "neo4j",
password: process.env.NEO4J_PASSWORD || "password",
},
llm: {
provider: "openai",
config: {
model: "gpt-4-turbo-preview",
},
},
},
historyDbPath: "memory.db",
};

View File

@@ -47,7 +47,11 @@ export class ConfigManager {
historyDbPath:
userConfig.historyDbPath || DEFAULT_MEMORY_CONFIG.historyDbPath,
customPrompt: userConfig.customPrompt,
graphStore: userConfig.graphStore,
graphStore: {
...DEFAULT_MEMORY_CONFIG.graphStore,
...userConfig.graphStore,
},
enableGraph: userConfig.enableGraph || DEFAULT_MEMORY_CONFIG.enableGraph,
};
// Validate the merged config

View File

@@ -0,0 +1,30 @@
import { LLMConfig } from "../types";
/**
 * Connection settings for the Neo4j instance backing the graph store.
 * Fields are nullable so partially-populated configs (e.g. built from
 * environment variables) can be constructed first and validated later
 * via `validateNeo4jConfig`.
 */
export interface Neo4jConfig {
  url: string | null;
  username: string | null;
  password: string | null;
}

/** Top-level configuration for the graph memory store. */
export interface GraphStoreConfig {
  /** Graph database backend; only "neo4j" is currently supported. */
  provider: string;
  /** Provider-specific connection settings. */
  config: Neo4jConfig;
  /** Optional LLM override used solely for graph operations. */
  llm?: LLMConfig;
  /** Optional custom instructions for entity extraction. */
  customPrompt?: string;
}

/**
 * Asserts that a Neo4j config carries all mandatory credentials.
 *
 * @throws Error when `url`, `username`, or `password` is null/empty.
 */
export function validateNeo4jConfig(config: Neo4jConfig): void {
  const missingAny = !config.url || !config.username || !config.password;
  if (missingAny) {
    throw new Error("Please provide 'url', 'username' and 'password'.");
  }
}

/**
 * Validates a graph store config by dispatching to the provider-specific
 * validator.
 *
 * @throws Error for unsupported providers or incomplete provider configs.
 */
export function validateGraphStoreConfig(config: GraphStoreConfig): void {
  switch (config.provider) {
    case "neo4j":
      validateNeo4jConfig(config.config);
      break;
    default:
      throw new Error(`Unsupported graph store provider: ${config.provider}`);
  }
}

View File

@@ -0,0 +1,213 @@
// OpenAI function-calling ("tools") definitions used by the graph memory
// store. Each constant follows the Chat Completions tool schema:
// { type: "function", function: { name, description, parameters } }.

// Arguments for the add/update/delete graph-memory tools. The node-type
// fields are optional because only "add_graph_memory" requires them.
export interface GraphToolParameters {
  source: string;
  destination: string;
  relationship: string;
  source_type?: string;
  destination_type?: string;
}

// Arguments produced by the "extract_entities" tool.
export interface GraphEntitiesParameters {
  entities: Array<{
    entity: string;
    entity_type: string;
  }>;
}

// Arguments produced by the "establish_relationships" tool.
export interface GraphRelationsParameters {
  entities: Array<{
    source: string;
    relationship: string;
    destination: string;
  }>;
}

// Tool: rewrite the relationship label on an existing edge between two nodes.
export const UPDATE_MEMORY_TOOL_GRAPH = {
  type: "function",
  function: {
    name: "update_graph_memory",
    description:
      "Update the relationship key of an existing graph memory based on new information.",
    parameters: {
      type: "object",
      properties: {
        source: {
          type: "string",
          description:
            "The identifier of the source node in the relationship to be updated.",
        },
        destination: {
          type: "string",
          description:
            "The identifier of the destination node in the relationship to be updated.",
        },
        relationship: {
          type: "string",
          description:
            "The new or updated relationship between the source and destination nodes.",
        },
      },
      required: ["source", "destination", "relationship"],
      additionalProperties: false,
    },
  },
};

// Tool: create a new edge (and node categories) in the knowledge graph.
// Note: node types are required here, unlike in the update tool above.
export const ADD_MEMORY_TOOL_GRAPH = {
  type: "function",
  function: {
    name: "add_graph_memory",
    description: "Add a new graph memory to the knowledge graph.",
    parameters: {
      type: "object",
      properties: {
        source: {
          type: "string",
          description:
            "The identifier of the source node in the new relationship.",
        },
        destination: {
          type: "string",
          description:
            "The identifier of the destination node in the new relationship.",
        },
        relationship: {
          type: "string",
          description:
            "The type of relationship between the source and destination nodes.",
        },
        source_type: {
          type: "string",
          description: "The type or category of the source node.",
        },
        destination_type: {
          type: "string",
          description: "The type or category of the destination node.",
        },
      },
      required: [
        "source",
        "destination",
        "relationship",
        "source_type",
        "destination_type",
      ],
      additionalProperties: false,
    },
  },
};

// Tool: explicit "do nothing" escape hatch so the LLM can signal that no
// graph mutation is needed.
export const NOOP_TOOL = {
  type: "function",
  function: {
    name: "noop",
    description: "No operation should be performed to the graph entities.",
    parameters: {
      type: "object",
      properties: {},
      required: [],
      additionalProperties: false,
    },
  },
};

// Tool: produce (source, relationship, destination) triples among entities
// already identified in the text.
export const RELATIONS_TOOL = {
  type: "function",
  function: {
    name: "establish_relationships",
    description:
      "Establish relationships among the entities based on the provided text.",
    parameters: {
      type: "object",
      properties: {
        entities: {
          type: "array",
          items: {
            type: "object",
            properties: {
              source: {
                type: "string",
                description: "The source entity of the relationship.",
              },
              relationship: {
                type: "string",
                description:
                  "The relationship between the source and destination entities.",
              },
              destination: {
                type: "string",
                description: "The destination entity of the relationship.",
              },
            },
            required: ["source", "relationship", "destination"],
            additionalProperties: false,
          },
        },
      },
      required: ["entities"],
      additionalProperties: false,
    },
  },
};

// Tool: list the entities (with types) mentioned in a piece of text.
export const EXTRACT_ENTITIES_TOOL = {
  type: "function",
  function: {
    name: "extract_entities",
    description: "Extract entities and their types from the text.",
    parameters: {
      type: "object",
      properties: {
        entities: {
          type: "array",
          items: {
            type: "object",
            properties: {
              entity: {
                type: "string",
                description: "The name or identifier of the entity.",
              },
              entity_type: {
                type: "string",
                description: "The type or category of the entity.",
              },
            },
            required: ["entity", "entity_type"],
            additionalProperties: false,
          },
          description: "An array of entities with their types.",
        },
      },
      required: ["entities"],
      additionalProperties: false,
    },
  },
};

// Tool: remove an existing edge between two nodes.
export const DELETE_MEMORY_TOOL_GRAPH = {
  type: "function",
  function: {
    name: "delete_graph_memory",
    description: "Delete the relationship between two nodes.",
    parameters: {
      type: "object",
      properties: {
        source: {
          type: "string",
          description: "The identifier of the source node in the relationship.",
        },
        relationship: {
          type: "string",
          description:
            "The existing relationship between the source and destination nodes that needs to be deleted.",
        },
        destination: {
          type: "string",
          description:
            "The identifier of the destination node in the relationship.",
        },
      },
      required: ["source", "relationship", "destination"],
      additionalProperties: false,
    },
  },
};

View File

@@ -0,0 +1,114 @@
// System prompt for reconciling newly extracted relations with the relations
// already stored in the graph. NOTE(review): `{existing_memories}` and
// `{new_memories}` look like substitution placeholders, but the substitution
// site is not visible here — confirm against the caller before relying on it.
export const UPDATE_GRAPH_PROMPT = `
You are an AI expert specializing in graph memory management and optimization. Your task is to analyze existing graph memories alongside new information, and update the relationships in the memory list to ensure the most accurate, current, and coherent representation of knowledge.
Input:
1. Existing Graph Memories: A list of current graph memories, each containing source, target, and relationship information.
2. New Graph Memory: Fresh information to be integrated into the existing graph structure.
Guidelines:
1. Identification: Use the source and target as primary identifiers when matching existing memories with new information.
2. Conflict Resolution:
- If new information contradicts an existing memory:
a) For matching source and target but differing content, update the relationship of the existing memory.
b) If the new memory provides more recent or accurate information, update the existing memory accordingly.
3. Comprehensive Review: Thoroughly examine each existing graph memory against the new information, updating relationships as necessary. Multiple updates may be required.
4. Consistency: Maintain a uniform and clear style across all memories. Each entry should be concise yet comprehensive.
5. Semantic Coherence: Ensure that updates maintain or improve the overall semantic structure of the graph.
6. Temporal Awareness: If timestamps are available, consider the recency of information when making updates.
7. Relationship Refinement: Look for opportunities to refine relationship descriptions for greater precision or clarity.
8. Redundancy Elimination: Identify and merge any redundant or highly similar relationships that may result from the update.
Memory Format:
source -- RELATIONSHIP -- destination
Task Details:
======= Existing Graph Memories:=======
{existing_memories}
======= New Graph Memory:=======
{new_memories}
Output:
Provide a list of update instructions, each specifying the source, target, and the new relationship to be set. Only include memories that require updates.
`;
// System prompt for extracting relationships among already-identified
// entities. Fix: "eshtablishing" -> "establishing" (typo in text sent to the
// LLM). NOTE(review): "CUSTOM_PROMPT" and "USER_ID" look like placeholder
// tokens substituted by the caller — confirm at the call sites.
export const EXTRACT_RELATIONS_PROMPT = `
You are an advanced algorithm designed to extract structured information from text to construct knowledge graphs. Your goal is to capture comprehensive and accurate information. Follow these key principles:
1. Extract only explicitly stated information from the text.
2. Establish relationships among the entities provided.
3. Use "USER_ID" as the source entity for any self-references (e.g., "I," "me," "my," etc.) in user messages.
CUSTOM_PROMPT
Relationships:
- Use consistent, general, and timeless relationship types.
- Example: Prefer "professor" over "became_professor."
- Relationships should only be established among the entities explicitly mentioned in the user message.
Entity Consistency:
- Ensure that relationships are coherent and logically align with the context of the message.
- Maintain consistent naming for entities across the extracted data.
Strive to construct a coherent and easily understandable knowledge graph by establishing all the relationships among the entities and adherence to the user's context.
Adhere strictly to these guidelines to ensure high-quality knowledge graph extraction.
`;
// System prompt for deciding which existing graph relations to delete in
// light of new information. The literal "USER_ID" token is a placeholder
// replaced with the real user id by `getDeleteMessages` below.
// Fix: "their is" -> "there is" (twice) in text sent to the LLM.
export const DELETE_RELATIONS_SYSTEM_PROMPT = `
You are a graph memory manager specializing in identifying, managing, and optimizing relationships within graph-based memories. Your primary task is to analyze a list of existing relationships and determine which ones should be deleted based on the new information provided.
Input:
1. Existing Graph Memories: A list of current graph memories, each containing source, relationship, and destination information.
2. New Text: The new information to be integrated into the existing graph structure.
3. Use "USER_ID" as node for any self-references (e.g., "I," "me," "my," etc.) in user messages.
Guidelines:
1. Identification: Use the new information to evaluate existing relationships in the memory graph.
2. Deletion Criteria: Delete a relationship only if it meets at least one of these conditions:
- Outdated or Inaccurate: The new information is more recent or accurate.
- Contradictory: The new information conflicts with or negates the existing information.
3. DO NOT DELETE if there is a possibility of same type of relationship but different destination nodes.
4. Comprehensive Analysis:
- Thoroughly examine each existing relationship against the new information and delete as necessary.
- Multiple deletions may be required based on the new information.
5. Semantic Integrity:
- Ensure that deletions maintain or improve the overall semantic structure of the graph.
- Avoid deleting relationships that are NOT contradictory/outdated to the new information.
6. Temporal Awareness: Prioritize recency when timestamps are available.
7. Necessity Principle: Only DELETE relationships that must be deleted and are contradictory/outdated to the new information to maintain an accurate and coherent memory graph.
Note: DO NOT DELETE if there is a possibility of same type of relationship but different destination nodes.
For example:
Existing Memory: alice -- loves_to_eat -- pizza
New Information: Alice also loves to eat burger.
Do not delete in the above example because there is a possibility that Alice loves to eat both pizza and burger.
Memory Format:
source -- relationship -- destination
Provide a list of deletion instructions, each specifying the relationship to be deleted.
`;

/**
 * Builds the [system, user] message pair for relation-deletion analysis.
 *
 * @param existingMemoriesString - Serialized relations ("src -- rel -- dst" lines).
 * @param data - The new information being integrated.
 * @param userId - Identifier substituted for every "USER_ID" placeholder.
 * @returns A [systemPrompt, userPrompt] tuple.
 */
export function getDeleteMessages(
  existingMemoriesString: string,
  data: string,
  userId: string,
): [string, string] {
  // split/join replaces EVERY occurrence of the placeholder; the previous
  // String.replace(string, ...) form only substitutes the first match
  // (unlike Python's str.replace, which this code was ported from).
  return [
    DELETE_RELATIONS_SYSTEM_PROMPT.split("USER_ID").join(userId),
    `Here are the existing memories: ${existingMemoriesString} \n\n New Information: ${data}`,
  ];
}
/**
 * Renders graph relations as newline-separated
 * "source -- relationship -- destination" triples.
 *
 * @param entities - Relation triples to serialize.
 * @returns One line per relation; empty string for an empty list.
 */
export function formatEntities(
  entities: Array<{
    source: string;
    relationship: string;
    destination: string;
  }>,
): string {
  const lines: string[] = [];
  for (const { source, relationship, destination } of entities) {
    lines.push(`${source} -- ${relationship} -- ${destination}`);
  }
  return lines.join("\n");
}

View File

@@ -3,12 +3,17 @@ import { Message } from "../types";
export interface LLMResponse {
content: string;
role: string;
toolCalls?: Array<{
name: string;
arguments: string;
}>;
}
export interface LLM {
generateResponse(
messages: Message[],
responseFormat?: { type: string },
): Promise<string>;
messages: Array<{ role: string; content: string }>,
response_format: { type: string },
tools?: any[],
): Promise<any>;
generateChat(messages: Message[]): Promise<LLMResponse>;
}

View File

@@ -14,7 +14,8 @@ export class OpenAILLM implements LLM {
async generateResponse(
messages: Message[],
responseFormat?: { type: string },
): Promise<string> {
tools?: any[],
): Promise<string | LLMResponse> {
const completion = await this.openai.chat.completions.create({
messages: messages.map((msg) => ({
role: msg.role as "system" | "user" | "assistant",
@@ -22,8 +23,23 @@ export class OpenAILLM implements LLM {
})),
model: this.model,
response_format: responseFormat as { type: "text" | "json_object" },
...(tools && { tools, tool_choice: "auto" }),
});
return completion.choices[0].message.content || "";
const response = completion.choices[0].message;
if (response.tool_calls) {
return {
content: response.content || "",
role: response.role,
toolCalls: response.tool_calls.map((call) => ({
name: call.function.name,
arguments: call.function.arguments,
})),
};
}
return response.content || "";
}
async generateChat(messages: Message[]): Promise<LLMResponse> {

View File

@@ -3,48 +3,74 @@ import { LLM, LLMResponse } from "./base";
import { LLMConfig, Message } from "../types";
export class OpenAIStructuredLLM implements LLM {
private client: OpenAI;
private openai: OpenAI;
private model: string;
constructor(config: LLMConfig) {
const apiKey = config.apiKey || process.env.OPENAI_API_KEY;
if (!apiKey) {
throw new Error("OpenAI API key is required");
}
const baseUrl = process.env.OPENAI_API_BASE || "https://api.openai.com/v1";
this.client = new OpenAI({ apiKey, baseURL: baseUrl });
this.model = config.model || "gpt-4-0125-preview";
this.openai = new OpenAI({ apiKey: config.apiKey });
this.model = config.model || "gpt-4-turbo-preview";
}
async generateResponse(
messages: Message[],
responseFormat?: { type: string },
): Promise<string> {
const response = await this.client.chat.completions.create({
model: this.model,
responseFormat?: { type: string } | null,
tools?: any[],
): Promise<string | LLMResponse> {
const completion = await this.openai.chat.completions.create({
messages: messages.map((msg) => ({
role: msg.role as "system" | "user" | "assistant",
content: msg.content,
})),
response_format: responseFormat as { type: "text" | "json_object" },
model: this.model,
...(tools
? {
tools: tools.map((tool) => ({
type: "function",
function: {
name: tool.function.name,
description: tool.function.description,
parameters: tool.function.parameters,
},
})),
tool_choice: "auto" as const,
}
: responseFormat
? {
response_format: {
type: responseFormat.type as "text" | "json_object",
},
}
: {}),
});
return response.choices[0].message.content || "";
const response = completion.choices[0].message;
if (response.tool_calls) {
return {
content: response.content || "",
role: response.role,
toolCalls: response.tool_calls.map((call) => ({
name: call.function.name,
arguments: call.function.arguments,
})),
};
}
return response.content || "";
}
async generateChat(messages: Message[]): Promise<LLMResponse> {
const response = await this.client.chat.completions.create({
model: this.model,
const completion = await this.openai.chat.completions.create({
messages: messages.map((msg) => ({
role: msg.role as "system" | "user" | "assistant",
content: msg.content,
})),
model: this.model,
});
const message = response.choices[0].message;
const response = completion.choices[0].message;
return {
content: message.content || "",
role: message.role,
content: response.content || "",
role: response.role,
};
}
}

View File

@@ -0,0 +1,675 @@
import neo4j, { Driver } from "neo4j-driver";
import { BM25 } from "../utils/bm25";
import { GraphStoreConfig } from "../graphs/configs";
import { MemoryConfig } from "../types";
import { EmbedderFactory, LLMFactory } from "../utils/factory";
import { Embedder } from "../embeddings/base";
import { LLM } from "../llms/base";
import {
DELETE_MEMORY_TOOL_GRAPH,
EXTRACT_ENTITIES_TOOL,
RELATIONS_TOOL,
} from "../graphs/tools";
import { EXTRACT_RELATIONS_PROMPT, getDeleteMessages } from "../graphs/utils";
import { logger } from "../utils/logger";
// A single relationship row returned from the graph-database search,
// including node/edge identifiers and a similarity score.
interface SearchOutput {
  source: string;
  source_id: string;
  relationship: string;
  relation_id: string;
  destination: string;
  destination_id: string;
  similarity: number;
}

// One function/tool invocation requested by the LLM.
interface ToolCall {
  name: string;
  // JSON-encoded argument payload; parsed with JSON.parse by consumers.
  arguments: string;
}

// Minimal shape of an LLM reply that may carry tool calls.
interface LLMResponse {
  toolCalls?: ToolCall[];
}

// OpenAI-style function tool definition passed to the LLM.
interface Tool {
  type: string;
  function: {
    name: string;
    description: string;
    parameters: Record<string, any>;
  };
}

// Result of a graph add() operation: what was removed, what was added,
// and the raw relations that were extracted.
interface GraphMemoryResult {
  deleted_entities: any[];
  added_entities: any[];
  relations?: any[];
}
export class MemoryGraph {
private config: MemoryConfig;
private graph: Driver;
private embeddingModel: Embedder;
private llm: LLM;
private structuredLlm: LLM;
private llmProvider: string;
private threshold: number;
constructor(config: MemoryConfig) {
this.config = config;
if (
!config.graphStore?.config?.url ||
!config.graphStore?.config?.username ||
!config.graphStore?.config?.password
) {
throw new Error("Neo4j configuration is incomplete");
}
this.graph = neo4j.driver(
config.graphStore.config.url,
neo4j.auth.basic(
config.graphStore.config.username,
config.graphStore.config.password,
),
);
this.embeddingModel = EmbedderFactory.create(
this.config.embedder.provider,
this.config.embedder.config,
);
this.llmProvider = "openai";
if (this.config.llm?.provider) {
this.llmProvider = this.config.llm.provider;
}
if (this.config.graphStore?.llm?.provider) {
this.llmProvider = this.config.graphStore.llm.provider;
}
this.llm = LLMFactory.create(this.llmProvider, this.config.llm.config);
this.structuredLlm = LLMFactory.create(
"openai_structured",
this.config.llm.config,
);
this.threshold = 0.7;
}
async add(
data: string,
filters: Record<string, any>,
): Promise<GraphMemoryResult> {
const entityTypeMap = await this._retrieveNodesFromData(data, filters);
const toBeAdded = await this._establishNodesRelationsFromData(
data,
filters,
entityTypeMap,
);
const searchOutput = await this._searchGraphDb(
Object.keys(entityTypeMap),
filters,
);
const toBeDeleted = await this._getDeleteEntitiesFromSearchOutput(
searchOutput,
data,
filters,
);
const deletedEntities = await this._deleteEntities(
toBeDeleted,
filters["userId"],
);
const addedEntities = await this._addEntities(
toBeAdded,
filters["userId"],
entityTypeMap,
);
return {
deleted_entities: deletedEntities,
added_entities: addedEntities,
relations: toBeAdded,
};
}
async search(query: string, filters: Record<string, any>, limit = 100) {
const entityTypeMap = await this._retrieveNodesFromData(query, filters);
const searchOutput = await this._searchGraphDb(
Object.keys(entityTypeMap),
filters,
);
if (!searchOutput.length) {
return [];
}
const searchOutputsSequence = searchOutput.map((item) => [
item.source,
item.relationship,
item.destination,
]);
const bm25 = new BM25(searchOutputsSequence);
const tokenizedQuery = query.split(" ");
const rerankedResults = bm25.search(tokenizedQuery).slice(0, 5);
const searchResults = rerankedResults.map((item) => ({
source: item[0],
relationship: item[1],
destination: item[2],
}));
logger.info(`Returned ${searchResults.length} search results`);
return searchResults;
}
/**
 * Removes every node (and any attached relationship, via DETACH DELETE)
 * belonging to a single user.
 *
 * @param filters Must contain `userId`.
 */
async deleteAll(filters: Record<string, any>) {
  const session = this.graph.session();
  try {
    const cypher = "MATCH (n {user_id: $user_id}) DETACH DELETE n";
    await session.run(cypher, { user_id: filters["userId"] });
  } finally {
    // Always release the driver session, even if the query throws.
    await session.close();
  }
}
/**
 * Lists relationships stored for a user.
 *
 * @param filters Must contain `userId`.
 * @param limit Maximum number of relationships to return.
 * @returns `{ source, relationship, target }` triples.
 */
async getAll(filters: Record<string, any>, limit = 100) {
  const session = this.graph.session();
  try {
    const cypher = `
      MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id})
      RETURN n.name AS source, type(r) AS relationship, m.name AS target
      LIMIT toInteger($limit)
    `;
    const queryResult = await session.run(cypher, {
      user_id: filters["userId"],
      // Neo4j rejects non-integer limits; coerce defensively.
      limit: Math.floor(Number(limit)),
    });
    const relationships = queryResult.records.map((record) => ({
      source: record.get("source"),
      relationship: record.get("relationship"),
      target: record.get("target"),
    }));
    logger.info(`Retrieved ${relationships.length} relationships`);
    return relationships;
  } finally {
    await session.close();
  }
}
/**
 * Extracts entities and their types from `data` via the structured LLM
 * and the `extract_entities` tool.
 *
 * @param data Text to extract entities from.
 * @param filters Must contain `userId`, substituted into the prompt as
 *   the source entity for self references.
 * @returns Map of normalized entity name -> normalized entity type
 *   (lower-cased, spaces replaced with underscores). Empty on tool
 *   failure.
 */
private async _retrieveNodesFromData(
  data: string,
  filters: Record<string, any>,
) {
  const tools = [EXTRACT_ENTITIES_TOOL] as Tool[];
  const llmOutput = await this.structuredLlm.generateResponse(
    [
      {
        role: "system",
        content: `You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use ${filters["userId"]} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.`,
      },
      { role: "user", content: data },
    ],
    { type: "json_object" },
    tools,
  );
  const rawEntityMap: Record<string, string> = {};
  try {
    if (typeof llmOutput !== "string" && llmOutput.toolCalls) {
      for (const toolCall of llmOutput.toolCalls) {
        if (toolCall.name !== "extract_entities") {
          continue;
        }
        const parsedArgs = JSON.parse(toolCall.arguments);
        for (const entry of parsedArgs.entities) {
          rawEntityMap[entry.entity] = entry.entity_type;
        }
      }
    }
  } catch (e) {
    // Best effort: a malformed tool payload yields an empty/partial map.
    logger.error(`Error in search tool: ${e}`);
  }
  // Normalize names so they are safe to use as graph identifiers.
  const entityTypeMap: Record<string, string> = {};
  for (const [entity, entityType] of Object.entries(rawEntityMap)) {
    entityTypeMap[entity.toLowerCase().replace(/ /g, "_")] = entityType
      .toLowerCase()
      .replace(/ /g, "_");
  }
  logger.debug(`Entity type map: ${JSON.stringify(entityTypeMap)}`);
  return entityTypeMap;
}
/**
 * Derives source--relationship--destination triples from `data` via the
 * structured LLM and the relations tool.
 *
 * NOTE(review): when a graph-store customPrompt is configured, the user
 * message contains only the raw text, without the entity list used in the
 * default branch — confirm this asymmetry is intentional.
 *
 * @param data Text to extract relations from.
 * @param filters Must contain `userId` (substituted into the prompt).
 * @param entityTypeMap Entities already extracted from the text.
 * @returns Normalized relation triples (spaces replaced, lower-cased).
 */
private async _establishNodesRelationsFromData(
  data: string,
  filters: Record<string, any>,
  entityTypeMap: Record<string, string>,
) {
  const customPrompt = this.config.graphStore?.customPrompt;
  let messages;
  if (customPrompt) {
    const systemContent =
      EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["userId"]).replace(
        "CUSTOM_PROMPT",
        `4. ${customPrompt}`,
      ) + "\nPlease provide your response in JSON format.";
    messages = [
      { role: "system", content: systemContent },
      { role: "user", content: data },
    ];
  } else {
    const systemContent =
      EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["userId"]) +
      "\nPlease provide your response in JSON format.";
    messages = [
      { role: "system", content: systemContent },
      {
        role: "user",
        content: `List of entities: ${Object.keys(entityTypeMap)}. \n\nText: ${data}`,
      },
    ];
  }
  const tools = [RELATIONS_TOOL] as Tool[];
  const llmOutput = await this.structuredLlm.generateResponse(
    messages,
    { type: "json_object" },
    tools,
  );
  // Only the first tool call is consulted for the relation list.
  let entities: any[] = [];
  if (typeof llmOutput !== "string" && llmOutput.toolCalls) {
    const firstCall = llmOutput.toolCalls[0];
    if (firstCall && firstCall.arguments) {
      entities = JSON.parse(firstCall.arguments).entities || [];
    }
  }
  entities = this._removeSpacesFromEntities(entities);
  logger.debug(`Extracted entities: ${JSON.stringify(entities)}`);
  return entities;
}
/**
 * Finds stored relations connected to nodes similar to any name in
 * `nodeList`.
 *
 * For each name, embeds it and runs a Cypher query that computes cosine
 * similarity (dot product over the two L2 norms, rounded to 4 decimals)
 * against every embedded node for the user; nodes above `this.threshold`
 * contribute both their outgoing (first UNION arm) and incoming (second
 * arm) relationships.
 *
 * NOTE(review): `ORDER BY` / `LIMIT` appear after the second UNION arm —
 * confirm the target Neo4j version applies them to the combined result
 * rather than only the second arm. Also note the limit applies per node
 * in `nodeList`, not to the overall result.
 *
 * @param nodeList Entity names to search for.
 * @param filters Must contain `userId`.
 * @param limit Maximum rows per node query.
 * @returns Matched relations with element ids and similarity scores.
 */
private async _searchGraphDb(
  nodeList: string[],
  filters: Record<string, any>,
  limit = 100,
): Promise<SearchOutput[]> {
  const resultRelations: SearchOutput[] = [];
  const session = this.graph.session();
  try {
    for (const node of nodeList) {
      const nEmbedding = await this.embeddingModel.embed(node);
      const cypher = `
      MATCH (n)
      WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
      WITH n,
      round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
      (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
      sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity
      WHERE similarity >= $threshold
      MATCH (n)-[r]->(m)
      RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id, similarity
      UNION
      MATCH (n)
      WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
      WITH n,
      round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
      (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
      sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity
      WHERE similarity >= $threshold
      MATCH (m)-[r]->(n)
      RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
      ORDER BY similarity DESC
      LIMIT toInteger($limit)
      `;
      const result = await session.run(cypher, {
        n_embedding: nEmbedding,
        threshold: this.threshold,
        user_id: filters["userId"],
        limit: Math.floor(Number(limit)),
      });
      // elementId() values come back as driver objects; stringify for the
      // SearchOutput shape.
      resultRelations.push(
        ...result.records.map((record) => ({
          source: record.get("source"),
          source_id: record.get("source_id").toString(),
          relationship: record.get("relationship"),
          relation_id: record.get("relation_id").toString(),
          destination: record.get("destination"),
          destination_id: record.get("destination_id").toString(),
          similarity: record.get("similarity"),
        })),
      );
    }
  } finally {
    await session.close();
  }
  return resultRelations;
}
/**
 * Asks the structured LLM (delete_graph_memory tool) which of the
 * existing relations should be deleted given the new text.
 *
 * @param searchOutput Existing relations pulled from the graph.
 * @param data The new text being added.
 * @param filters Must contain `userId` (used in the prompt).
 * @returns Normalized `{ source, relationship, destination }` triples to
 *   delete.
 */
private async _getDeleteEntitiesFromSearchOutput(
  searchOutput: SearchOutput[],
  data: string,
  filters: Record<string, any>,
) {
  // Render the existing relations as one "a -- rel -- b" line each.
  const existingRelationsText = searchOutput
    .map(
      (item) =>
        `${item.source} -- ${item.relationship} -- ${item.destination}`,
    )
    .join("\n");
  const [systemPrompt, userPrompt] = getDeleteMessages(
    existingRelationsText,
    data,
    filters["userId"],
  );
  const tools = [DELETE_MEMORY_TOOL_GRAPH] as Tool[];
  const llmOutput = await this.structuredLlm.generateResponse(
    [
      { role: "system", content: systemPrompt },
      { role: "user", content: userPrompt },
    ],
    { type: "json_object" },
    tools,
  );
  const toBeDeleted: any[] = [];
  if (typeof llmOutput !== "string" && llmOutput.toolCalls) {
    for (const toolCall of llmOutput.toolCalls) {
      if (toolCall.name === "delete_graph_memory") {
        toBeDeleted.push(JSON.parse(toolCall.arguments));
      }
    }
  }
  const cleanedToBeDeleted = this._removeSpacesFromEntities(toBeDeleted);
  logger.debug(
    `Deleted relationships: ${JSON.stringify(cleanedToBeDeleted)}`,
  );
  return cleanedToBeDeleted;
}
/**
 * Deletes the given relationships (edges only, not the nodes) for a user.
 *
 * Relationship types cannot be parameterized in Cypher, so
 * `${relationship}` is interpolated directly into the query; the values
 * come from _getDeleteEntitiesFromSearchOutput, which normalizes them via
 * _removeSpacesFromEntities before returning.
 *
 * @param toBeDeleted `{ source, relationship, destination }` triples.
 * @param userId Scopes the match to one user's nodes.
 * @returns One array of driver records per deleted triple.
 */
private async _deleteEntities(toBeDeleted: any[], userId: string) {
  const results: any[] = [];
  const session = this.graph.session();
  try {
    for (const item of toBeDeleted) {
      const { source, destination, relationship } = item;
      const cypher = `
      MATCH (n {name: $source_name, user_id: $user_id})
      -[r:${relationship}]->
      (m {name: $dest_name, user_id: $user_id})
      DELETE r
      RETURN
      n.name AS source,
      m.name AS target,
      type(r) AS relationship
      `;
      const result = await session.run(cypher, {
        source_name: source,
        dest_name: destination,
        user_id: userId,
      });
      results.push(result.records);
    }
  } finally {
    await session.close();
  }
  return results;
}
/**
 * Persists relation triples into the graph, reusing existing nodes when a
 * sufficiently similar embedded node already exists for the user.
 *
 * For each triple, both endpoint embeddings are computed and matched
 * against existing nodes (threshold 0.9 in _searchSourceNode /
 * _searchDestinationNode); the four branches below cover every
 * combination of "endpoint already exists / must be created".
 *
 * Node labels and relationship types cannot be parameterized in Cypher,
 * hence the `${sourceType}` / `${destinationType}` / `${relationship}`
 * interpolation; these values are normalized upstream
 * (_retrieveNodesFromData and _removeSpacesFromEntities).
 *
 * @param toBeAdded `{ source, relationship, destination }` triples.
 * @param userId Owner of the created nodes/relationships.
 * @param entityTypeMap Entity name -> type, used as the node label
 *   ("unknown" when missing).
 * @returns One array of driver records per processed triple.
 */
private async _addEntities(
  toBeAdded: any[],
  userId: string,
  entityTypeMap: Record<string, string>,
) {
  const results: any[] = [];
  const session = this.graph.session();
  try {
    for (const item of toBeAdded) {
      const { source, destination, relationship } = item;
      const sourceType = entityTypeMap[source] || "unknown";
      const destinationType = entityTypeMap[destination] || "unknown";
      const sourceEmbedding = await this.embeddingModel.embed(source);
      const destEmbedding = await this.embeddingModel.embed(destination);
      // Look for existing nodes similar to each endpoint.
      const sourceNodeSearchResult = await this._searchSourceNode(
        sourceEmbedding,
        userId,
      );
      const destinationNodeSearchResult = await this._searchDestinationNode(
        destEmbedding,
        userId,
      );
      let cypher: string;
      let params: Record<string, any>;
      if (
        destinationNodeSearchResult.length === 0 &&
        sourceNodeSearchResult.length > 0
      ) {
        // Source exists, destination is new: create destination and link.
        cypher = `
        MATCH (source)
        WHERE elementId(source) = $source_id
        MERGE (destination:${destinationType} {name: $destination_name, user_id: $user_id})
        ON CREATE SET
        destination.created = timestamp(),
        destination.embedding = $destination_embedding
        MERGE (source)-[r:${relationship}]->(destination)
        ON CREATE SET
        r.created = timestamp()
        RETURN source.name AS source, type(r) AS relationship, destination.name AS target
        `;
        params = {
          source_id: sourceNodeSearchResult[0].elementId,
          destination_name: destination,
          destination_embedding: destEmbedding,
          user_id: userId,
        };
      } else if (
        destinationNodeSearchResult.length > 0 &&
        sourceNodeSearchResult.length === 0
      ) {
        // Destination exists, source is new: create source and link.
        cypher = `
        MATCH (destination)
        WHERE elementId(destination) = $destination_id
        MERGE (source:${sourceType} {name: $source_name, user_id: $user_id})
        ON CREATE SET
        source.created = timestamp(),
        source.embedding = $source_embedding
        MERGE (source)-[r:${relationship}]->(destination)
        ON CREATE SET
        r.created = timestamp()
        RETURN source.name AS source, type(r) AS relationship, destination.name AS target
        `;
        params = {
          destination_id: destinationNodeSearchResult[0].elementId,
          source_name: source,
          source_embedding: sourceEmbedding,
          user_id: userId,
        };
      } else if (
        sourceNodeSearchResult.length > 0 &&
        destinationNodeSearchResult.length > 0
      ) {
        // Both endpoints exist: only ensure the relationship.
        // NOTE(review): this branch sets r.created_at / r.updated_at while
        // every other branch sets r.created — confirm which property name
        // downstream consumers expect.
        cypher = `
        MATCH (source)
        WHERE elementId(source) = $source_id
        MATCH (destination)
        WHERE elementId(destination) = $destination_id
        MERGE (source)-[r:${relationship}]->(destination)
        ON CREATE SET
        r.created_at = timestamp(),
        r.updated_at = timestamp()
        RETURN source.name AS source, type(r) AS relationship, destination.name AS target
        `;
        params = {
          source_id: sourceNodeSearchResult[0]?.elementId,
          destination_id: destinationNodeSearchResult[0]?.elementId,
          user_id: userId,
        };
      } else {
        // Neither endpoint exists: create (or merge) both and link them.
        // Embeddings are refreshed even when the node already matched by
        // name (ON MATCH SET).
        cypher = `
        MERGE (n:${sourceType} {name: $source_name, user_id: $user_id})
        ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding
        ON MATCH SET n.embedding = $source_embedding
        MERGE (m:${destinationType} {name: $dest_name, user_id: $user_id})
        ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding
        ON MATCH SET m.embedding = $dest_embedding
        MERGE (n)-[rel:${relationship}]->(m)
        ON CREATE SET rel.created = timestamp()
        RETURN n.name AS source, type(rel) AS relationship, m.name AS target
        `;
        params = {
          source_name: source,
          dest_name: destination,
          source_embedding: sourceEmbedding,
          dest_embedding: destEmbedding,
          user_id: userId,
        };
      }
      const result = await session.run(cypher, params);
      results.push(result.records);
    }
  } finally {
    await session.close();
  }
  return results;
}
/**
 * Returns a copy of the relation triples with `source`, `relationship`,
 * and `destination` lower-cased and spaces replaced by underscores, so
 * they are safe to splice into Cypher as labels/relationship types.
 * Other properties on each item are preserved unchanged.
 */
private _removeSpacesFromEntities(entityList: any[]) {
  const normalize = (value: string) => value.toLowerCase().replace(/ /g, "_");
  return entityList.map((item) => ({
    ...item,
    source: normalize(item.source),
    relationship: normalize(item.relationship),
    destination: normalize(item.destination),
  }));
}
/**
 * Finds the single most similar existing node for a candidate source
 * entity, by cosine similarity of embeddings computed in Cypher.
 *
 * NOTE(review): this is a near-duplicate of _searchDestinationNode;
 * consider extracting a shared helper.
 *
 * @param sourceEmbedding Embedding of the candidate source entity.
 * @param userId Scopes the search to one user's nodes.
 * @param threshold Minimum cosine similarity (rounded to 4 decimals) for
 *   a node to count as a match.
 * @returns A 0- or 1-element array of `{ elementId }`.
 */
private async _searchSourceNode(
  sourceEmbedding: number[],
  userId: string,
  threshold = 0.9,
) {
  const session = this.graph.session();
  try {
    // Cosine similarity = dot product / (||a|| * ||b||), computed inline.
    const cypher = `
    MATCH (source_candidate)
    WHERE source_candidate.embedding IS NOT NULL
    AND source_candidate.user_id = $user_id
    WITH source_candidate,
    round(
    reduce(dot = 0.0, i IN range(0, size(source_candidate.embedding)-1) |
    dot + source_candidate.embedding[i] * $source_embedding[i]) /
    (sqrt(reduce(l2 = 0.0, i IN range(0, size(source_candidate.embedding)-1) |
    l2 + source_candidate.embedding[i] * source_candidate.embedding[i])) *
    sqrt(reduce(l2 = 0.0, i IN range(0, size($source_embedding)-1) |
    l2 + $source_embedding[i] * $source_embedding[i])))
    , 4) AS source_similarity
    WHERE source_similarity >= $threshold
    WITH source_candidate, source_similarity
    ORDER BY source_similarity DESC
    LIMIT 1
    RETURN elementId(source_candidate) as element_id
    `;
    const params = {
      source_embedding: sourceEmbedding,
      user_id: userId,
      threshold,
    };
    const result = await session.run(cypher, params);
    return result.records.map((record) => ({
      elementId: record.get("element_id").toString(),
    }));
  } finally {
    await session.close();
  }
}
/**
 * Finds the single most similar existing node for a candidate destination
 * entity, by cosine similarity of embeddings computed in Cypher.
 *
 * NOTE(review): this is a near-duplicate of _searchSourceNode; consider
 * extracting a shared helper.
 *
 * @param destinationEmbedding Embedding of the candidate destination.
 * @param userId Scopes the search to one user's nodes.
 * @param threshold Minimum cosine similarity (rounded to 4 decimals) for
 *   a node to count as a match.
 * @returns A 0- or 1-element array of `{ elementId }`.
 */
private async _searchDestinationNode(
  destinationEmbedding: number[],
  userId: string,
  threshold = 0.9,
) {
  const session = this.graph.session();
  try {
    // Cosine similarity = dot product / (||a|| * ||b||), computed inline.
    const cypher = `
    MATCH (destination_candidate)
    WHERE destination_candidate.embedding IS NOT NULL
    AND destination_candidate.user_id = $user_id
    WITH destination_candidate,
    round(
    reduce(dot = 0.0, i IN range(0, size(destination_candidate.embedding)-1) |
    dot + destination_candidate.embedding[i] * $destination_embedding[i]) /
    (sqrt(reduce(l2 = 0.0, i IN range(0, size(destination_candidate.embedding)-1) |
    l2 + destination_candidate.embedding[i] * destination_candidate.embedding[i])) *
    sqrt(reduce(l2 = 0.0, i IN range(0, size($destination_embedding)-1) |
    l2 + $destination_embedding[i] * $destination_embedding[i])))
    , 4) AS destination_similarity
    WHERE destination_similarity >= $threshold
    WITH destination_candidate, destination_similarity
    ORDER BY destination_similarity DESC
    LIMIT 1
    RETURN elementId(destination_candidate) as element_id
    `;
    const params = {
      destination_embedding: destinationEmbedding,
      user_id: userId,
      threshold,
    };
    const result = await session.run(cypher, params);
    return result.records.map((record) => ({
      elementId: record.get("element_id").toString(),
    }));
  } finally {
    await session.close();
  }
}
}

View File

@@ -24,6 +24,7 @@ import { Embedder } from "../embeddings/base";
import { LLM } from "../llms/base";
import { VectorStore } from "../vector_stores/base";
import { ConfigManager } from "../config/manager";
import { MemoryGraph } from "./graph_memory";
import {
AddMemoryOptions,
SearchMemoryOptions,
@@ -40,6 +41,8 @@ export class Memory {
private db: SQLiteManager;
private collectionName: string;
private apiVersion: string;
private graphMemory?: MemoryGraph;
private enableGraph: boolean;
constructor(config: Partial<MemoryConfig> = {}) {
// Merge and validate config
@@ -61,6 +64,12 @@ export class Memory {
this.db = new SQLiteManager(this.config.historyDbPath || ":memory:");
this.collectionName = this.config.vectorStore.config.collectionName;
this.apiVersion = this.config.version || "v1.0";
this.enableGraph = this.config.enableGraph || false;
// Initialize graph memory if configured
if (this.enableGraph && this.config.graphStore) {
this.graphMemory = new MemoryGraph(this.config);
}
}
static fromConfig(configDict: Record<string, any>): Memory {
@@ -100,13 +109,30 @@ export class Memory {
? (messages as Message[])
: [{ role: "user", content: messages }];
// Add to vector store
const vectorStoreResult = await this.addToVectorStore(
parsedMessages,
metadata,
filters,
);
return { results: vectorStoreResult };
// Add to graph store if available
let graphResult;
if (this.graphMemory) {
try {
graphResult = await this.graphMemory.add(
parsedMessages.map((m) => m.content).join("\n"),
filters,
);
} catch (error) {
console.error("Error adding to graph memory:", error);
}
}
return {
results: vectorStoreResult,
relations: graphResult?.relations,
};
}
private async addToVectorStore(
@@ -284,6 +310,7 @@ export class Memory {
);
}
// Search vector store
const queryEmbedding = await this.embedder.embed(query);
const memories = await this.vectorStore.search(
queryEmbedding,
@@ -291,6 +318,16 @@ export class Memory {
filters,
);
// Search graph store if available
let graphResults;
if (this.graphMemory) {
try {
graphResults = await this.graphMemory.search(query, filters);
} catch (error) {
console.error("Error searching graph memory:", error);
}
}
const excludedKeys = new Set([
"userId",
"agentId",
@@ -315,7 +352,10 @@ export class Memory {
...(mem.payload.runId && { runId: mem.payload.runId }),
}));
return { results };
return {
results,
relations: graphResults,
};
}
async update(memoryId: string, data: string): Promise<{ message: string }> {
@@ -360,6 +400,9 @@ export class Memory {
async reset(): Promise<void> {
await this.db.reset();
await this.vectorStore.deleteCol();
if (this.graphMemory) {
await this.graphMemory.deleteAll({ userId: "default" });
}
this.vectorStore = VectorStoreFactory.create(
this.config.vectorStore.provider,
this.config.vectorStore.config,

View File

@@ -14,7 +14,6 @@ export interface AddMemoryOptions extends Entity {
}
export interface SearchMemoryOptions extends Entity {
query: string;
limit?: number;
filters?: SearchFilters;
}

View File

@@ -34,7 +34,7 @@ export function getFactRetrievalMessages(
Input: Me favourite movies are Inception and Interstellar.
Output: {"facts" : ["Favourite movies are Inception and Interstellar"]}
Return the facts and preferences in a json format as shown above.
Return the facts and preferences in a JSON format as shown above. You MUST return a valid JSON object with a 'facts' key containing an array of strings.
Remember the following:
- Today's date is ${new Date().toISOString().split("T")[0]}.
@@ -43,17 +43,17 @@ export function getFactRetrievalMessages(
- If the user asks where you fetched my information, answer that you found from publicly available sources on internet.
- If you do not find anything relevant in the below conversation, you can return an empty list corresponding to the "facts" key.
- Create the facts based on the user and assistant messages only. Do not pick anything from the system messages.
- Make sure to return the response in the format mentioned in the examples. The response should be in json with a key as "facts" and corresponding value will be a list of strings.
- Make sure to return the response in the JSON format mentioned in the examples. The response should be in JSON with a key as "facts" and corresponding value will be a list of strings.
- DO NOT RETURN ANYTHING ELSE OTHER THAN THE JSON FORMAT.
- DO NOT ADD ANY ADDITIONAL TEXT OR CODEBLOCK IN THE JSON FIELDS WHICH MAKE IT INVALUD SUCH AS "\`\`\`json" OR "\`\`\`".
- DO NOT ADD ANY ADDITIONAL TEXT OR CODEBLOCK IN THE JSON FIELDS WHICH MAKE IT INVALID SUCH AS "\`\`\`json" OR "\`\`\`".
- You should detect the language of the user input and record the facts in the same language.
- For basic factual statements, break them down into individual facts if they contain multiple pieces of information.
Following is a conversation between the user and the assistant. You have to extract the relevant facts and preferences about the user, if any, from the conversation and return them in the json format as shown above.
Following is a conversation between the user and the assistant. You have to extract the relevant facts and preferences about the user, if any, from the conversation and return them in the JSON format as shown above.
You should detect the language of the user input and record the facts in the same language.
`;
const userPrompt = `Following is a conversation between the user and the assistant. You have to extract the relevant facts and preferences about the user, if any, from the conversation and return them in the json format as shown above.\n\nInput:\n${parsedMessages}`;
const userPrompt = `Following is a conversation between the user and the assistant. You have to extract the relevant facts and preferences about the user, if any, from the conversation and return them in the JSON format as shown above.\n\nInput:\n${parsedMessages}`;
return [systemPrompt, userPrompt];
}
@@ -218,14 +218,14 @@ export function getUpdateMemoryMessages(
${JSON.stringify(newRetrievedFacts, null, 2)}
Follow the instruction mentioned below:
- Do not return anything from the custom few shot prompts provided above.
- Do not return anything from the custom few shot example prompts provided above.
- If the current memory is empty, then you have to add the new retrieved facts to the memory.
- You should return the updated memory in only JSON format as shown below. The memory key should be the same if no changes are made.
- If there is an addition, generate a new key and add the new memory corresponding to it.
- If there is a deletion, the memory key-value pair should be removed from the memory.
- If there is an update, the ID key should remain the same and only the value needs to be updated.
- DO NOT RETURN ANYTHING ELSE OTHER THAN THE JSON FORMAT.
- DO NOT ADD ANY ADDITIONAL TEXT OR CODEBLOCK IN THE JSON FIELDS WHICH MAKE IT INVALUD SUCH AS "\`\`\`json" OR "\`\`\`".
- DO NOT ADD ANY ADDITIONAL TEXT OR CODEBLOCK IN THE JSON FIELDS WHICH MAKE IT INVALID SUCH AS "\`\`\`json" OR "\`\`\`".
Do not return anything except the JSON format.`;
}

View File

@@ -17,12 +17,23 @@ export interface VectorStoreConfig {
}
export interface LLMConfig {
apiKey: string;
provider?: string;
config?: Record<string, any>;
apiKey?: string;
model?: string;
}
export interface Neo4jConfig {
url: string;
username: string;
password: string;
}
export interface GraphStoreConfig {
config?: any;
provider: string;
config: Neo4jConfig;
llm?: LLMConfig;
customPrompt?: string;
}
export interface MemoryConfig {
@@ -42,6 +53,7 @@ export interface MemoryConfig {
historyDbPath?: string;
customPrompt?: string;
graphStore?: GraphStoreConfig;
enableGraph?: boolean;
}
export interface MemoryItem {
@@ -99,9 +111,22 @@ export const MemoryConfigSchema = z.object({
}),
historyDbPath: z.string().optional(),
customPrompt: z.string().optional(),
enableGraph: z.boolean().optional(),
graphStore: z
.object({
config: z.any().optional(),
provider: z.string(),
config: z.object({
url: z.string(),
username: z.string(),
password: z.string(),
}),
llm: z
.object({
provider: z.string(),
config: z.record(z.string(), z.any()),
})
.optional(),
customPrompt: z.string().optional(),
})
.optional(),
});

View File

@@ -0,0 +1,64 @@
/**
 * Minimal Okapi BM25 ranker over a pre-tokenized corpus.
 *
 * Documents are arrays of tokens; `search` returns the documents sorted
 * by descending BM25 score against a tokenized query.
 */
export class BM25 {
  private documents: string[][];
  private k1: number;
  private b: number;
  private avgDocLength: number;
  private docFreq: Map<string, number>;
  private docLengths: number[];
  private idf: Map<string, number>;

  /**
   * @param documents Corpus, one token array per document.
   * @param k1 Term-frequency saturation parameter.
   * @param b Length-normalization strength (0 = none, 1 = full).
   */
  constructor(documents: string[][], k1 = 1.5, b = 0.75) {
    this.documents = documents;
    this.k1 = k1;
    this.b = b;
    this.docLengths = documents.map((doc) => doc.length);
    const totalLength = this.docLengths.reduce((sum, len) => sum + len, 0);
    // Bug fix: an empty corpus previously produced 0/0 = NaN here, which
    // poisoned every subsequent score.
    this.avgDocLength =
      documents.length > 0 ? totalLength / documents.length : 0;
    this.docFreq = new Map();
    this.idf = new Map();
    this.computeIdf();
  }

  /** Precomputes inverse document frequency for every term in the corpus. */
  private computeIdf() {
    const N = this.documents.length;
    // Count in how many documents each term appears (set semantics: a
    // term repeated inside one document counts once).
    for (const doc of this.documents) {
      const terms = new Set(doc);
      for (const term of terms) {
        this.docFreq.set(term, (this.docFreq.get(term) || 0) + 1);
      }
    }
    // Standard BM25 IDF with the +1 inside the log to keep it positive.
    for (const [term, freq] of this.docFreq) {
      this.idf.set(term, Math.log((N - freq + 0.5) / (freq + 0.5) + 1));
    }
  }

  /** BM25 score of one document (at `index`) against the query tokens. */
  private score(query: string[], doc: string[], index: number): number {
    let score = 0;
    const docLength = this.docLengths[index];
    // Bug fix: guard the length normalization against a zero average
    // (all-empty corpus), which previously divided by zero.
    const lengthRatio =
      this.avgDocLength > 0 ? docLength / this.avgDocLength : 1;
    for (const term of query) {
      const tf = doc.filter((t) => t === term).length;
      const idf = this.idf.get(term) || 0;
      score +=
        (idf * tf * (this.k1 + 1)) /
        (tf + this.k1 * (1 - this.b + this.b * lengthRatio));
    }
    return score;
  }

  /**
   * Ranks the corpus against `query`.
   * @param query Tokenized query terms.
   * @returns All documents, sorted by descending BM25 score.
   */
  search(query: string[]): string[][] {
    const scores = this.documents.map((doc, idx) => ({
      doc,
      score: this.score(query, doc, idx),
    }));
    return scores.sort((a, b) => b.score - a.score).map((item) => item.doc);
  }
}

View File

@@ -24,7 +24,7 @@ export class EmbedderFactory {
export class LLMFactory {
static create(provider: string, config: LLMConfig): LLM {
switch (provider.toLowerCase()) {
switch (provider) {
case "openai":
return new OpenAILLM(config);
case "openai_structured":

View File

@@ -0,0 +1,13 @@
/**
 * Minimal leveled logger used across the SDK.
 * All levels take a single pre-formatted message string.
 */
export interface Logger {
  info: (message: string) => void;
  error: (message: string) => void;
  debug: (message: string) => void;
  warn: (message: string) => void;
}

// Prefix each line with its upper-cased level tag, e.g. "[INFO] ...".
const withLevel = (level: string, message: string): string =>
  `[${level}] ${message}`;

/** Console-backed default implementation of {@link Logger}. */
export const logger: Logger = {
  info: (message) => console.log(withLevel("INFO", message)),
  error: (message) => console.error(withLevel("ERROR", message)),
  debug: (message) => console.debug(withLevel("DEBUG", message)),
  warn: (message) => console.warn(withLevel("WARN", message)),
};

View File

@@ -1,5 +1,7 @@
import { VectorStore } from "./base";
import { SearchFilters, VectorStoreConfig, VectorStoreResult } from "../types";
import sqlite3 from "sqlite3";
import path from "path";
interface MemoryVector {
id: string;
@@ -8,12 +10,55 @@ interface MemoryVector {
}
export class MemoryVectorStore implements VectorStore {
private vectors: Map<string, MemoryVector>;
private db: sqlite3.Database;
private dimension: number;
private dbPath: string;
constructor(config: VectorStoreConfig) {
this.vectors = new Map();
this.dimension = config.dimension || 1536; // Default OpenAI dimension
this.dbPath = path.join(process.cwd(), "vector_store.db");
if (config.dbPath) {
this.dbPath = config.dbPath;
}
this.db = new sqlite3.Database(this.dbPath);
this.init().catch(console.error);
}
private async init() {
await this.run(`
CREATE TABLE IF NOT EXISTS vectors (
id TEXT PRIMARY KEY,
vector BLOB NOT NULL,
payload TEXT NOT NULL
)
`);
}
private async run(sql: string, params: any[] = []): Promise<void> {
return new Promise((resolve, reject) => {
this.db.run(sql, params, (err) => {
if (err) reject(err);
else resolve();
});
});
}
private async all(sql: string, params: any[] = []): Promise<any[]> {
return new Promise((resolve, reject) => {
this.db.all(sql, params, (err, rows) => {
if (err) reject(err);
else resolve(rows);
});
});
}
private async getOne(sql: string, params: any[] = []): Promise<any> {
return new Promise((resolve, reject) => {
this.db.get(sql, params, (err, row) => {
if (err) reject(err);
else resolve(row);
});
});
}
private cosineSimilarity(a: number[], b: number[]): number {
@@ -46,11 +91,11 @@ export class MemoryVectorStore implements VectorStore {
`Vector dimension mismatch. Expected ${this.dimension}, got ${vectors[i].length}`,
);
}
this.vectors.set(ids[i], {
id: ids[i],
vector: vectors[i],
payload: payloads[i],
});
const vectorBuffer = Buffer.from(new Float32Array(vectors[i]).buffer);
await this.run(
`INSERT OR REPLACE INTO vectors (id, vector, payload) VALUES (?, ?, ?)`,
[ids[i], vectorBuffer, JSON.stringify(payloads[i])],
);
}
}
@@ -65,13 +110,23 @@ export class MemoryVectorStore implements VectorStore {
);
}
const rows = await this.all(`SELECT * FROM vectors`);
const results: VectorStoreResult[] = [];
for (const vector of this.vectors.values()) {
if (this.filterVector(vector, filters)) {
const score = this.cosineSimilarity(query, vector.vector);
for (const row of rows) {
const vector = new Float32Array(row.vector.buffer);
const payload = JSON.parse(row.payload);
const memoryVector: MemoryVector = {
id: row.id,
vector: Array.from(vector),
payload,
};
if (this.filterVector(memoryVector, filters)) {
const score = this.cosineSimilarity(query, Array.from(vector));
results.push({
id: vector.id,
payload: vector.payload,
id: memoryVector.id,
payload: memoryVector.payload,
score,
});
}
@@ -82,11 +137,15 @@ export class MemoryVectorStore implements VectorStore {
}
async get(vectorId: string): Promise<VectorStoreResult | null> {
const vector = this.vectors.get(vectorId);
if (!vector) return null;
const row = await this.getOne(`SELECT * FROM vectors WHERE id = ?`, [
vectorId,
]);
if (!row) return null;
const payload = JSON.parse(row.payload);
return {
id: vector.id,
payload: vector.payload,
id: row.id,
payload,
};
}
@@ -100,36 +159,46 @@ export class MemoryVectorStore implements VectorStore {
`Vector dimension mismatch. Expected ${this.dimension}, got ${vector.length}`,
);
}
const existing = this.vectors.get(vectorId);
if (!existing) throw new Error(`Vector with ID ${vectorId} not found`);
this.vectors.set(vectorId, {
id: vectorId,
vector,
payload,
});
const vectorBuffer = Buffer.from(new Float32Array(vector).buffer);
await this.run(`UPDATE vectors SET vector = ?, payload = ? WHERE id = ?`, [
vectorBuffer,
JSON.stringify(payload),
vectorId,
]);
}
async delete(vectorId: string): Promise<void> {
this.vectors.delete(vectorId);
await this.run(`DELETE FROM vectors WHERE id = ?`, [vectorId]);
}
async deleteCol(): Promise<void> {
this.vectors.clear();
await this.run(`DROP TABLE IF EXISTS vectors`);
await this.init();
}
async list(
filters?: SearchFilters,
limit: number = 100,
): Promise<[VectorStoreResult[], number]> {
const rows = await this.all(`SELECT * FROM vectors`);
const results: VectorStoreResult[] = [];
for (const vector of this.vectors.values()) {
if (this.filterVector(vector, filters)) {
for (const row of rows) {
const payload = JSON.parse(row.payload);
const memoryVector: MemoryVector = {
id: row.id,
vector: Array.from(new Float32Array(row.vector.buffer)),
payload,
};
if (this.filterVector(memoryVector, filters)) {
results.push({
id: vector.id,
payload: vector.payload,
id: memoryVector.id,
payload: memoryVector.payload,
});
}
}
return [results.slice(0, limit), results.length];
}
}