Added custom prompt support (#1849)

This commit is contained in:
Prateek Chhikara
2024-09-10 16:57:32 -07:00
committed by GitHub
parent 5eeeb4e38c
commit ac7b7aa20a
5 changed files with 122 additions and 3 deletions

View File

@@ -56,6 +56,10 @@ class MemoryConfig(BaseModel):
description="The version of the API",
default="v1.0",
)
custom_prompt: Optional[str] = Field(
description="Custom prompt for the memory",
default=None,
)
class AzureConfig(BaseModel):

View File

@@ -28,6 +28,8 @@ logger = logging.getLogger(__name__)
class Memory(MemoryBase):
def __init__(self, config: MemoryConfig = MemoryConfig()):
self.config = config
self.custom_prompt = self.config.custom_prompt
self.embedding_model = EmbedderFactory.create(
self.config.embedder.provider, self.config.embedder.config
)
@@ -131,7 +133,11 @@ class Memory(MemoryBase):
def _add_to_vector_store(self, messages, metadata, filters):
parsed_messages = parse_messages(messages)
system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)
if self.custom_prompt:
system_prompt=self.custom_prompt
user_prompt=f"Input: {parsed_messages}"
else:
system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)
response = self.llm.generate_response(
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],