Formatting (#2750)
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
import asyncio
|
||||
import warnings
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.sessions import InMemorySessionService
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.sessions import InMemorySessionService
|
||||
from google.genai import types
|
||||
|
||||
from mem0 import MemoryClient
|
||||
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
@@ -19,14 +21,14 @@ def save_patient_info(information: str) -> dict:
|
||||
print(f"Storing patient information: {information[:30]}...")
|
||||
|
||||
# Get user_id from session state or use default
|
||||
user_id = getattr(save_patient_info, 'user_id', 'default_user')
|
||||
user_id = getattr(save_patient_info, "user_id", "default_user")
|
||||
|
||||
# Store in Mem0
|
||||
response = mem0_client.add(
|
||||
mem0_client.add(
|
||||
[{"role": "user", "content": information}],
|
||||
user_id=user_id,
|
||||
run_id="healthcare_session",
|
||||
metadata={"type": "patient_information"}
|
||||
metadata={"type": "patient_information"},
|
||||
)
|
||||
|
||||
return {"status": "success", "message": "Information saved"}
|
||||
@@ -37,7 +39,7 @@ def retrieve_patient_info(query: str) -> str:
|
||||
print(f"Searching for patient information: {query}")
|
||||
|
||||
# Get user_id from session state or use default
|
||||
user_id = getattr(retrieve_patient_info, 'user_id', 'default_user')
|
||||
user_id = getattr(retrieve_patient_info, "user_id", "default_user")
|
||||
|
||||
# Search Mem0
|
||||
results = mem0_client.search(
|
||||
@@ -45,7 +47,7 @@ def retrieve_patient_info(query: str) -> str:
|
||||
user_id=user_id,
|
||||
run_id="healthcare_session",
|
||||
limit=5,
|
||||
threshold=0.7 # Higher threshold for more relevant results
|
||||
threshold=0.7, # Higher threshold for more relevant results
|
||||
)
|
||||
|
||||
if not results:
|
||||
@@ -65,7 +67,7 @@ def schedule_appointment(date: str, time: str, reason: str) -> dict:
|
||||
"status": "success",
|
||||
"appointment_id": appointment_id,
|
||||
"confirmation": f"Appointment scheduled for {date} at {time} for {reason}",
|
||||
"message": "Please arrive 15 minutes early to complete paperwork."
|
||||
"message": "Please arrive 15 minutes early to complete paperwork.",
|
||||
}
|
||||
|
||||
|
||||
@@ -89,7 +91,7 @@ IMPORTANT GUIDELINES:
|
||||
- For serious symptoms, always recommend consulting a healthcare professional.
|
||||
- Keep all patient information confidential.
|
||||
""",
|
||||
tools=[save_patient_info, retrieve_patient_info, schedule_appointment]
|
||||
tools=[save_patient_info, retrieve_patient_info, schedule_appointment],
|
||||
)
|
||||
|
||||
# Set Up Session and Runner
|
||||
@@ -101,18 +103,10 @@ USER_ID = "Alex"
|
||||
SESSION_ID = "session_001"
|
||||
|
||||
# Create a session
|
||||
session = session_service.create_session(
|
||||
app_name=APP_NAME,
|
||||
user_id=USER_ID,
|
||||
session_id=SESSION_ID
|
||||
)
|
||||
session = session_service.create_session(app_name=APP_NAME, user_id=USER_ID, session_id=SESSION_ID)
|
||||
|
||||
# Create the runner
|
||||
runner = Runner(
|
||||
agent=healthcare_agent,
|
||||
app_name=APP_NAME,
|
||||
session_service=session_service
|
||||
)
|
||||
runner = Runner(agent=healthcare_agent, app_name=APP_NAME, session_service=session_service)
|
||||
|
||||
|
||||
# Interact with the Healthcare Assistant
|
||||
@@ -121,21 +115,14 @@ async def call_agent_async(query, runner, user_id, session_id):
|
||||
print(f"\n>>> Patient: {query}")
|
||||
|
||||
# Format the user's message
|
||||
content = types.Content(
|
||||
role='user',
|
||||
parts=[types.Part(text=query)]
|
||||
)
|
||||
content = types.Content(role="user", parts=[types.Part(text=query)])
|
||||
|
||||
# Set user_id for tools to access
|
||||
save_patient_info.user_id = user_id
|
||||
retrieve_patient_info.user_id = user_id
|
||||
|
||||
# Run the agent
|
||||
async for event in runner.run_async(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
new_message=content
|
||||
):
|
||||
async for event in runner.run_async(user_id=user_id, session_id=session_id, new_message=content):
|
||||
if event.is_final_response():
|
||||
if event.content and event.content.parts:
|
||||
response = event.content.parts[0].text
|
||||
@@ -152,7 +139,7 @@ async def run_conversation():
|
||||
"Hi, I'm Alex. I've been having headaches for the past week, and I have a penicillin allergy.",
|
||||
runner=runner,
|
||||
user_id=USER_ID,
|
||||
session_id=SESSION_ID
|
||||
session_id=SESSION_ID,
|
||||
)
|
||||
|
||||
# Request for health information
|
||||
@@ -160,7 +147,7 @@ async def run_conversation():
|
||||
"Can you tell me more about what might be causing my headaches?",
|
||||
runner=runner,
|
||||
user_id=USER_ID,
|
||||
session_id=SESSION_ID
|
||||
session_id=SESSION_ID,
|
||||
)
|
||||
|
||||
# Schedule an appointment
|
||||
@@ -168,15 +155,12 @@ async def run_conversation():
|
||||
"I think I should see a doctor. Can you help me schedule an appointment for next Monday at 2pm?",
|
||||
runner=runner,
|
||||
user_id=USER_ID,
|
||||
session_id=SESSION_ID
|
||||
session_id=SESSION_ID,
|
||||
)
|
||||
|
||||
# Test memory - should remember patient name, symptoms, and allergy
|
||||
await call_agent_async(
|
||||
"What medications should I avoid for my headaches?",
|
||||
runner=runner,
|
||||
user_id=USER_ID,
|
||||
session_id=SESSION_ID
|
||||
"What medications should I avoid for my headaches?", runner=runner, user_id=USER_ID, session_id=SESSION_ID
|
||||
)
|
||||
|
||||
|
||||
@@ -191,37 +175,28 @@ async def interactive_mode():
|
||||
session_id = f"session_{hash(patient_id) % 1000:03d}"
|
||||
|
||||
# Create session for this user
|
||||
session = session_service.create_session(
|
||||
app_name=APP_NAME,
|
||||
user_id=patient_id,
|
||||
session_id=session_id
|
||||
)
|
||||
session_service.create_session(app_name=APP_NAME, user_id=patient_id, session_id=session_id)
|
||||
|
||||
print(f"\nStarting conversation with patient ID: {patient_id}")
|
||||
print("Type your message and press Enter.")
|
||||
|
||||
while True:
|
||||
user_input = input("\n>>> Patient: ").strip()
|
||||
if user_input.lower() in ['exit', 'quit', 'bye']:
|
||||
if user_input.lower() in ["exit", "quit", "bye"]:
|
||||
print("Ending conversation. Thank you!")
|
||||
break
|
||||
|
||||
await call_agent_async(
|
||||
user_input,
|
||||
runner=runner,
|
||||
user_id=patient_id,
|
||||
session_id=session_id
|
||||
)
|
||||
await call_agent_async(user_input, runner=runner, user_id=patient_id, session_id=session_id)
|
||||
|
||||
|
||||
# Main execution
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='Healthcare Assistant with Memory')
|
||||
parser.add_argument('--demo', action='store_true', help='Run the demo conversation')
|
||||
parser.add_argument('--interactive', action='store_true', help='Run in interactive mode')
|
||||
parser.add_argument('--patient-id', type=str, default=USER_ID, help='Patient ID for the conversation')
|
||||
parser = argparse.ArgumentParser(description="Healthcare Assistant with Memory")
|
||||
parser.add_argument("--demo", action="store_true", help="Run the demo conversation")
|
||||
parser.add_argument("--interactive", action="store_true", help="Run in interactive mode")
|
||||
parser.add_argument("--patient-id", type=str, default=USER_ID, help="Patient ID for the conversation")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.demo:
|
||||
@@ -231,5 +206,3 @@ if __name__ == "__main__":
|
||||
else:
|
||||
# Default to demo mode if no arguments provided
|
||||
asyncio.run(run_conversation())
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user