Formatting (#2750)

This commit is contained in:
Dev Khant
2025-05-22 01:17:29 +05:30
committed by GitHub
parent dff91154a7
commit d85fcda037
71 changed files with 1391 additions and 1823 deletions

View File

@@ -1,9 +1,11 @@
import asyncio
import warnings
from google.adk.agents import Agent
from google.adk.sessions import InMemorySessionService
from google.adk.runners import Runner
from google.adk.sessions import InMemorySessionService
from google.genai import types
from mem0 import MemoryClient
warnings.filterwarnings("ignore", category=DeprecationWarning)
@@ -19,14 +21,14 @@ def save_patient_info(information: str) -> dict:
print(f"Storing patient information: {information[:30]}...")
# Get user_id from session state or use default
user_id = getattr(save_patient_info, 'user_id', 'default_user')
user_id = getattr(save_patient_info, "user_id", "default_user")
# Store in Mem0
response = mem0_client.add(
mem0_client.add(
[{"role": "user", "content": information}],
user_id=user_id,
run_id="healthcare_session",
metadata={"type": "patient_information"}
metadata={"type": "patient_information"},
)
return {"status": "success", "message": "Information saved"}
@@ -37,7 +39,7 @@ def retrieve_patient_info(query: str) -> str:
print(f"Searching for patient information: {query}")
# Get user_id from session state or use default
user_id = getattr(retrieve_patient_info, 'user_id', 'default_user')
user_id = getattr(retrieve_patient_info, "user_id", "default_user")
# Search Mem0
results = mem0_client.search(
@@ -45,7 +47,7 @@ def retrieve_patient_info(query: str) -> str:
user_id=user_id,
run_id="healthcare_session",
limit=5,
threshold=0.7 # Higher threshold for more relevant results
threshold=0.7, # Higher threshold for more relevant results
)
if not results:
@@ -65,7 +67,7 @@ def schedule_appointment(date: str, time: str, reason: str) -> dict:
"status": "success",
"appointment_id": appointment_id,
"confirmation": f"Appointment scheduled for {date} at {time} for {reason}",
"message": "Please arrive 15 minutes early to complete paperwork."
"message": "Please arrive 15 minutes early to complete paperwork.",
}
@@ -89,7 +91,7 @@ IMPORTANT GUIDELINES:
- For serious symptoms, always recommend consulting a healthcare professional.
- Keep all patient information confidential.
""",
tools=[save_patient_info, retrieve_patient_info, schedule_appointment]
tools=[save_patient_info, retrieve_patient_info, schedule_appointment],
)
# Set Up Session and Runner
@@ -101,18 +103,10 @@ USER_ID = "Alex"
SESSION_ID = "session_001"
# Create a session
session = session_service.create_session(
app_name=APP_NAME,
user_id=USER_ID,
session_id=SESSION_ID
)
session = session_service.create_session(app_name=APP_NAME, user_id=USER_ID, session_id=SESSION_ID)
# Create the runner
runner = Runner(
agent=healthcare_agent,
app_name=APP_NAME,
session_service=session_service
)
runner = Runner(agent=healthcare_agent, app_name=APP_NAME, session_service=session_service)
# Interact with the Healthcare Assistant
@@ -121,21 +115,14 @@ async def call_agent_async(query, runner, user_id, session_id):
print(f"\n>>> Patient: {query}")
# Format the user's message
content = types.Content(
role='user',
parts=[types.Part(text=query)]
)
content = types.Content(role="user", parts=[types.Part(text=query)])
# Set user_id for tools to access
save_patient_info.user_id = user_id
retrieve_patient_info.user_id = user_id
# Run the agent
async for event in runner.run_async(
user_id=user_id,
session_id=session_id,
new_message=content
):
async for event in runner.run_async(user_id=user_id, session_id=session_id, new_message=content):
if event.is_final_response():
if event.content and event.content.parts:
response = event.content.parts[0].text
@@ -152,7 +139,7 @@ async def run_conversation():
"Hi, I'm Alex. I've been having headaches for the past week, and I have a penicillin allergy.",
runner=runner,
user_id=USER_ID,
session_id=SESSION_ID
session_id=SESSION_ID,
)
# Request for health information
@@ -160,7 +147,7 @@ async def run_conversation():
"Can you tell me more about what might be causing my headaches?",
runner=runner,
user_id=USER_ID,
session_id=SESSION_ID
session_id=SESSION_ID,
)
# Schedule an appointment
@@ -168,15 +155,12 @@ async def run_conversation():
"I think I should see a doctor. Can you help me schedule an appointment for next Monday at 2pm?",
runner=runner,
user_id=USER_ID,
session_id=SESSION_ID
session_id=SESSION_ID,
)
# Test memory - should remember patient name, symptoms, and allergy
await call_agent_async(
"What medications should I avoid for my headaches?",
runner=runner,
user_id=USER_ID,
session_id=SESSION_ID
"What medications should I avoid for my headaches?", runner=runner, user_id=USER_ID, session_id=SESSION_ID
)
@@ -191,37 +175,28 @@ async def interactive_mode():
session_id = f"session_{hash(patient_id) % 1000:03d}"
# Create session for this user
session = session_service.create_session(
app_name=APP_NAME,
user_id=patient_id,
session_id=session_id
)
session_service.create_session(app_name=APP_NAME, user_id=patient_id, session_id=session_id)
print(f"\nStarting conversation with patient ID: {patient_id}")
print("Type your message and press Enter.")
while True:
user_input = input("\n>>> Patient: ").strip()
if user_input.lower() in ['exit', 'quit', 'bye']:
if user_input.lower() in ["exit", "quit", "bye"]:
print("Ending conversation. Thank you!")
break
await call_agent_async(
user_input,
runner=runner,
user_id=patient_id,
session_id=session_id
)
await call_agent_async(user_input, runner=runner, user_id=patient_id, session_id=session_id)
# Main execution
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Healthcare Assistant with Memory')
parser.add_argument('--demo', action='store_true', help='Run the demo conversation')
parser.add_argument('--interactive', action='store_true', help='Run in interactive mode')
parser.add_argument('--patient-id', type=str, default=USER_ID, help='Patient ID for the conversation')
parser = argparse.ArgumentParser(description="Healthcare Assistant with Memory")
parser.add_argument("--demo", action="store_true", help="Run the demo conversation")
parser.add_argument("--interactive", action="store_true", help="Run in interactive mode")
parser.add_argument("--patient-id", type=str, default=USER_ID, help="Patient ID for the conversation")
args = parser.parse_args()
if args.demo:
@@ -231,5 +206,3 @@ if __name__ == "__main__":
else:
# Default to demo mode if no arguments provided
asyncio.run(run_conversation())