Formatting and version bump -> 0.1.107 (#2927)
This commit is contained in:
@@ -18,22 +18,15 @@ class SarvamLLM(LLMBase):
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"Sarvam API key is required. Set SARVAM_API_KEY environment variable "
|
||||
"or provide api_key in config."
|
||||
"Sarvam API key is required. Set SARVAM_API_KEY environment variable " "or provide api_key in config."
|
||||
)
|
||||
|
||||
# Set base URL - use config value or environment or default
|
||||
self.base_url = (
|
||||
getattr(self.config, 'sarvam_base_url', None) or
|
||||
os.getenv("SARVAM_API_BASE") or
|
||||
"https://api.sarvam.ai/v1"
|
||||
getattr(self.config, "sarvam_base_url", None) or os.getenv("SARVAM_API_BASE") or "https://api.sarvam.ai/v1"
|
||||
)
|
||||
|
||||
def generate_response(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
response_format=None
|
||||
) -> str:
|
||||
def generate_response(self, messages: List[Dict[str, str]], response_format=None) -> str:
|
||||
"""
|
||||
Generate a response based on the given messages using Sarvam-M.
|
||||
|
||||
@@ -47,10 +40,7 @@ class SarvamLLM(LLMBase):
|
||||
"""
|
||||
url = f"{self.base_url}/chat/completions"
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
|
||||
# Prepare the request payload
|
||||
params = {
|
||||
@@ -74,10 +64,7 @@ class SarvamLLM(LLMBase):
|
||||
params["model"] = self.config.model.get("name", "sarvam-m")
|
||||
|
||||
# Add Sarvam-specific parameters
|
||||
sarvam_specific_params = [
|
||||
'reasoning_effort', 'frequency_penalty', 'presence_penalty',
|
||||
'seed', 'stop', 'n'
|
||||
]
|
||||
sarvam_specific_params = ["reasoning_effort", "frequency_penalty", "presence_penalty", "seed", "stop", "n"]
|
||||
|
||||
for param in sarvam_specific_params:
|
||||
if param in self.config.model:
|
||||
@@ -89,12 +76,12 @@ class SarvamLLM(LLMBase):
|
||||
|
||||
result = response.json()
|
||||
|
||||
if 'choices' in result and len(result['choices']) > 0:
|
||||
return result['choices'][0]['message']['content']
|
||||
if "choices" in result and len(result["choices"]) > 0:
|
||||
return result["choices"][0]["message"]["content"]
|
||||
else:
|
||||
raise ValueError("No response choices found in Sarvam API response")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise RuntimeError(f"Sarvam API request failed: {e}")
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Unexpected response format from Sarvam API: {e}")
|
||||
raise ValueError(f"Unexpected response format from Sarvam API: {e}")
|
||||
|
||||
Reference in New Issue
Block a user