From 65cffa03690226877be888778f74311499deae2c Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Wed, 19 Mar 2025 21:29:36 +0530 Subject: [PATCH] Fix: made tools support for graph (#2400) --- mem0/llms/azure_openai_structured.py | 38 +++++++++++++++++++++++++- mem0/llms/openai_structured.py | 40 ++++++++++++++++++++++++++-- 2 files changed, 75 insertions(+), 3 deletions(-) diff --git a/mem0/llms/azure_openai_structured.py b/mem0/llms/azure_openai_structured.py index 68ca0122..9e07bf0b 100644 --- a/mem0/llms/azure_openai_structured.py +++ b/mem0/llms/azure_openai_structured.py @@ -50,10 +50,40 @@ class AzureOpenAIStructuredLLM(LLMBase): default_headers=default_headers, ) + def _parse_response(self, response, tools): + """ + Process the response based on whether tools are used or not. + Args: + response: The raw response from API. + tools: The list of tools provided in the request. + Returns: + str or dict: The processed response. + """ + if tools: + processed_response = { + "content": response.choices[0].message.content, + "tool_calls": [], + } + + if response.choices[0].message.tool_calls: + for tool_call in response.choices[0].message.tool_calls: + processed_response["tool_calls"].append( + { + "name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + } + ) + + return processed_response + else: + return response.choices[0].message.content + def generate_response( self, messages: List[Dict[str, str]], response_format: Optional[str] = None, + tools: Optional[List[Dict]] = None, + tool_choice: str = "auto", ) -> str: """ Generates a response using Azure OpenAI based on the provided messages. @@ -61,6 +91,8 @@ class AzureOpenAIStructuredLLM(LLMBase): Args: messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key. response_format (Optional[str]): The desired format of the response. Defaults to None. + tools (Optional[List[Dict]]): A list of dictionaries, each containing a 'name' and 'arguments' key. + tool_choice (str): The choice of tool to use. Defaults to "auto". Returns: str: The generated response from the model. @@ -76,5 +108,9 @@ class AzureOpenAIStructuredLLM(LLMBase): if response_format: params["response_format"] = response_format + if tools: + params["tools"] = tools + params["tool_choice"] = tool_choice + response = self.client.chat.completions.create(**params) - return response.choices[0].message.content + return self._parse_response(response, tools) diff --git a/mem0/llms/openai_structured.py b/mem0/llms/openai_structured.py index bd5b2c62..08c2edfb 100644 --- a/mem0/llms/openai_structured.py +++ b/mem0/llms/openai_structured.py @@ -33,10 +33,42 @@ class OpenAIStructuredLLM(LLMBase): ) self.client = OpenAI(api_key=api_key, base_url=base_url) + def _parse_response(self, response, tools): + """ + Process the response based on whether tools are used or not. + Args: + response: The raw response from API. + tools (list, optional): List of tools that the model can call. + Returns: + str or dict: The processed response. + """ + + if tools: + processed_response = { + "content": response.choices[0].message.content, + "tool_calls": [], + } + + if response.choices[0].message.tool_calls: + for tool_call in response.choices[0].message.tool_calls: + processed_response["tool_calls"].append( + { + "name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + } + ) + + return processed_response + + else: + return response.choices[0].message.content + def generate_response( self, messages: List[Dict[str, str]], response_format: Optional[str] = None, + tools: Optional[List[Dict]] = None, + tool_choice: str = "auto", ) -> str: """ Generates a response using OpenAI based on the provided messages. @@ -44,7 +76,8 @@ class OpenAIStructuredLLM(LLMBase): Args: messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key. response_format (Optional[str]): The desired format of the response. Defaults to None. - + tools (Optional[List[Dict]]): A list of dictionaries, each containing a 'name' and 'arguments' key. + tool_choice (str): The choice of tool to use. Defaults to "auto". Returns: str: The generated response from the model. @@ -57,6 +90,9 @@ class OpenAIStructuredLLM(LLMBase): if response_format: params["response_format"] = response_format + if tools: + params["tools"] = tools + params["tool_choice"] = tool_choice response = self.client.beta.chat.completions.parse(**params) - return response.choices[0].message.content + return self._parse_response(response, tools)