diff --git a/mem0/graphs/configs.py b/mem0/graphs/configs.py index e0e98274..50c585a1 100644 --- a/mem0/graphs/configs.py +++ b/mem0/graphs/configs.py @@ -9,6 +9,7 @@ class Neo4jConfig(BaseModel): url: Optional[str] = Field(None, description="Host address for the graph database") username: Optional[str] = Field(None, description="Username for the graph database") password: Optional[str] = Field(None, description="Password for the graph database") + database: Optional[str] = Field(None, description="Database for the graph database") @model_validator(mode="before") def check_host_port_or_path(cls, values): diff --git a/mem0/memory/graph_memory.py b/mem0/memory/graph_memory.py index f6307a73..760d864a 100644 --- a/mem0/memory/graph_memory.py +++ b/mem0/memory/graph_memory.py @@ -33,8 +33,9 @@ class MemoryGraph: self.config.graph_store.config.url, self.config.graph_store.config.username, self.config.graph_store.config.password, + self.config.graph_store.config.database, refresh_schema=False, - driver_config={"notifications_min_severity":"OFF"} + driver_config={"notifications_min_severity":"OFF"}, ) self.embedding_model = EmbedderFactory.create( self.config.embedder.provider, self.config.embedder.config, self.config.vector_store.config