Modified the return statement for ADD call | Added tests to main.py and graph_memory.py (#1812)
This commit is contained in:
@@ -101,6 +101,8 @@ class MemoryGraph:
|
||||
elif item['name'] == "noop":
|
||||
continue
|
||||
|
||||
returned_entities = []
|
||||
|
||||
for item in to_be_added:
|
||||
source = item['source'].lower().replace(" ", "_")
|
||||
source_type = item['source_type'].lower().replace(" ", "_")
|
||||
@@ -108,6 +110,12 @@ class MemoryGraph:
|
||||
destination = item['destination'].lower().replace(" ", "_")
|
||||
destination_type = item['destination_type'].lower().replace(" ", "_")
|
||||
|
||||
returned_entities.append({
|
||||
"source" : source,
|
||||
"relationship" : relation,
|
||||
"target" : destination
|
||||
})
|
||||
|
||||
# Create embeddings
|
||||
source_embedding = self.embedding_model.embed(source)
|
||||
dest_embedding = self.embedding_model.embed(destination)
|
||||
@@ -137,6 +145,7 @@ class MemoryGraph:
|
||||
|
||||
logger.info(f"Added {len(to_be_added)} new memories to the graph")
|
||||
|
||||
return returned_entities
|
||||
|
||||
def _search(self, query, filters):
|
||||
_tools = [SEARCH_TOOL]
|
||||
@@ -155,8 +164,10 @@ class MemoryGraph:
|
||||
|
||||
for item in search_results['tool_calls']:
|
||||
if item['name'] == "search":
|
||||
node_list.extend(item['arguments']['nodes'])
|
||||
relation_list.extend(item['arguments']['relations'])
|
||||
try:
|
||||
node_list.extend(item['arguments']['nodes'])
|
||||
except Exception as e:
|
||||
logger.error(f"Error in search tool: {e}")
|
||||
|
||||
node_list = list(set(node_list))
|
||||
relation_list = list(set(relation_list))
|
||||
@@ -228,8 +239,8 @@ class MemoryGraph:
|
||||
for item in reranked_results:
|
||||
search_results.append({
|
||||
"source": item[0],
|
||||
"relation": item[1],
|
||||
"destination": item[2]
|
||||
"relationship": item[1],
|
||||
"target": item[2]
|
||||
})
|
||||
|
||||
logger.info(f"Returned {len(search_results)} search results")
|
||||
|
||||
Reference in New Issue
Block a user