Show details for query tokens (#1392)

Dev Khant
2024-07-05 00:10:56 +05:30
committed by GitHub
parent ea09b5f7f0
commit 4880557d51
25 changed files with 1825 additions and 517 deletions

View File

@@ -11,7 +11,7 @@ install:
install_all:
	poetry install --all-extras
-	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama deepgram-sdk==3.2.7 langchain-huggingface psutil
	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama langchain_together==0.1.3 langchain_cohere==0.1.5 deepgram-sdk==3.2.7 langchain-huggingface psutil
install_es:
	poetry install --extras elasticsearch

View File

@@ -209,6 +209,7 @@ Alright, let's dive into what each key means in the yaml config above:
- `top_p` (Float): Controls the diversity of word selection. A higher value (closer to 1) makes word selection more diverse.
- `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
- `online` (Boolean): Controls whether to use internet to get more context for answering query (set to false).
- `token_usage` (Boolean): Controls whether to return token usage (cost details) for queries (set to false).
- `prompt` (String): A prompt for the model to follow when generating responses, requires `$context` and `$query` variables.
- `system_prompt` (String): A system prompt for the model to follow when generating responses, in this case, it's set to the style of William Shakespeare.
- `number_documents` (Integer): Number of documents to pull from the vectordb as context, defaults to 1

View File

@@ -840,6 +840,52 @@ answer = app.query("What is the net worth of Elon Musk today?")
```
</CodeGroup>
## Token Usage
You can get the cost of a query by setting `token_usage` to `True` in the config file. The response will then include the token details: `input_tokens`, `output_tokens`, `total_cost`.
The paid LLMs that currently support token usage are:
- OpenAI
- Vertex AI
- Anthropic
- Cohere
- Together
- Groq
- Mistral AI
- NVIDIA AI
Here is an example of how to use token usage:
<CodeGroup>
```python main.py
import os

from embedchain import App

os.environ["OPENAI_API_KEY"] = "xxx"
app = App.from_config(config_path="config.yaml")
app.add("https://www.forbes.com/profile/elon-musk")
response, token_usage = app.query("what is the net worth of Elon Musk?")
# Elon Musk's net worth is $209.9 billion as of 6/9/24.
# {'input_tokens': 1228, 'output_tokens': 21, 'total_cost (USD)': 0.001884}
response, token_usage = app.chat("Which companies did Elon Musk found?")
# Elon Musk founded six companies, including Tesla, which is an electric car maker, SpaceX, a rocket producer, and the Boring Company, a tunneling startup.
# {'input_tokens': 1616, 'output_tokens': 34, 'total_cost (USD)': 0.002492}
```
```yaml config.yaml
llm:
provider: openai
config:
model: gpt-3.5-turbo
temperature: 0.5
max_tokens: 1000
token_usage: true
```
</CodeGroup>
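For reference, the reported cost is simply the per-token prices from `model_prices_and_context_window.json` applied to the input and output token counts. A minimal sketch (illustrative only, not part of the embedchain API) that reproduces the `total_cost` from the query example above:

```python
import json

# Load the pricing table shipped with this change.
with open("model_prices_and_context_window.json") as f:
    pricing = json.load(f)

entry = pricing["openai/gpt-3.5-turbo"]
input_tokens, output_tokens = 1228, 21  # token counts from the query example above

# cost = input tokens * input price + output tokens * output price
total_cost = round(
    entry["input_cost_per_token"] * input_tokens + entry["output_cost_per_token"] * output_tokens,
    10,
)
print(total_cost)  # ~0.001884 USD, matching the example output
```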
If a model is missing and you'd like to add it to `model_prices_and_context_window.json`, please feel free to open a PR.
<br/ >
<Snippet file="missing-llm-tip.mdx" />

View File

@@ -1,3 +1,4 @@
import json
import logging
import re
from string import Template
@@ -92,6 +93,7 @@ class BaseLlmConfig(BaseConfig):
top_p: float = 1,
stream: bool = False,
online: bool = False,
token_usage: bool = False,
deployment_name: Optional[str] = None,
system_prompt: Optional[str] = None,
where: dict[str, Any] = None,
@@ -135,6 +137,8 @@ class BaseLlmConfig(BaseConfig):
:type stream: bool, optional
:param online: Controls whether to use internet for answering query, defaults to False
:type online: bool, optional
:param token_usage: Controls whether to return token usage in response, defaults to False
:type token_usage: bool, optional
:param deployment_name: t.b.a., defaults to None
:type deployment_name: Optional[str], optional
:param system_prompt: System prompt string, defaults to None
@@ -180,6 +184,8 @@ class BaseLlmConfig(BaseConfig):
self.max_tokens = max_tokens
self.model = model
self.top_p = top_p
self.online = online
self.token_usage = token_usage
self.deployment_name = deployment_name
self.system_prompt = system_prompt
self.query_type = query_type
@@ -197,6 +203,10 @@ class BaseLlmConfig(BaseConfig):
self.online = online
self.api_version = api_version
if token_usage:
f = open("model_prices_and_context_window.json")
self.model_pricing_map = json.load(f)
if isinstance(prompt, str):
prompt = Template(prompt)

View File

@@ -6,9 +6,7 @@ from typing import Any, Optional, Union
from dotenv import load_dotenv
from langchain.docstore.document import Document
-from embedchain.cache import (adapt, get_gptcache_session,
-gptcache_data_convert,
-gptcache_update_cache_callback)
from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
from embedchain.config.base_app_config import BaseAppConfig
@@ -18,8 +16,7 @@ from embedchain.embedder.base import BaseEmbedder
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.llm.base import BaseLlm
from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import (DataType, DirectDataType,
-IndirectDataType, SpecialDataType)
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
from embedchain.utils.misc import detect_datatype, is_valid_json_string
from embedchain.vectordb.base import BaseVectorDB
@@ -478,7 +475,7 @@ class EmbedChain(JSONSerializable):
where: Optional[dict] = None,
citations: bool = False,
**kwargs: dict[str, Any],
-) -> Union[tuple[str, list[tuple[str, dict]]], str]:
) -> Union[tuple[str, list[tuple[str, dict]]], str, dict[str, Any]]:
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query and then passes it to an
@@ -501,7 +498,9 @@ class EmbedChain(JSONSerializable):
:type kwargs: dict[str, Any]
:return: The answer to the query, with citations if the citation flag is True
or the dry run result
-:rtype: str, if citations is False, otherwise tuple[str, list[tuple[str,str,str]]]
:rtype: str, if citations is False and token_usage is False, otherwise if citations is true then
tuple[str, list[tuple[str,str,str]]] and if token_usage is true then
tuple[str, list[tuple[str,str,str]], dict[str, Any]]
"""
contexts = self._retrieve_from_database(
input_query=input_query, config=config, where=where, citations=citations, **kwargs
@@ -524,17 +523,29 @@ class EmbedChain(JSONSerializable):
dry_run=dry_run,
)
else:
-answer = self.llm.query(
-input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-)
if self.llm.config.token_usage:
answer, token_info = self.llm.query(
input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
)
else:
answer = self.llm.query(
input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
)
# Send anonymous telemetry
self.telemetry.capture(event_name="query", properties=self._telemetry_props)
if citations:
if self.llm.config.token_usage:
return {"answer": answer, "contexts": contexts, "usage": token_info}
return answer, contexts
-else:
-return answer
if self.llm.config.token_usage:
return {"answer": answer, "usage": token_info}
logger.warning(
"Starting from v0.1.125 the return type of query method will be changed to tuple containing `answer`."
)
return answer
def chat(
self,
@@ -545,7 +556,7 @@ class EmbedChain(JSONSerializable):
where: Optional[dict[str, str]] = None,
citations: bool = False,
**kwargs: dict[str, Any],
-) -> Union[tuple[str, list[tuple[str, dict]]], str]:
) -> Union[tuple[str, list[tuple[str, dict]]], str, dict[str, Any]]:
"""
Queries the vector database on the given input query.
Gets relevant doc based on the query and then passes it to an
@@ -572,7 +583,9 @@ class EmbedChain(JSONSerializable):
:type kwargs: dict[str, Any]
:return: The answer to the query, with citations if the citation flag is True
or the dry run result
-:rtype: str, if citations is False, otherwise tuple[str, list[tuple[str,str,str]]]
:rtype: str, if citations is False and token_usage is False, otherwise if citations is true then
tuple[str, list[tuple[str,str,str]]] and if token_usage is true then
tuple[str, list[tuple[str,str,str]], dict[str, Any]]
"""
contexts = self._retrieve_from_database(
input_query=input_query, config=config, where=where, citations=citations, **kwargs
@@ -600,9 +613,14 @@ class EmbedChain(JSONSerializable):
)
else:
logger.debug("Cache disabled. Running chat without cache.")
-answer = self.llm.chat(
-input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-)
if self.llm.config.token_usage:
answer, token_info = self.llm.query(
input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
)
else:
answer = self.llm.query(
input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
)
# add conversation in memory
self.llm.add_history(self.config.id, input_query, answer, session_id=session_id)
@@ -611,9 +629,16 @@ class EmbedChain(JSONSerializable):
self.telemetry.capture(event_name="chat", properties=self._telemetry_props) self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
if citations: if citations:
if self.llm.config.token_usage:
return {"answer": answer, "contexts": contexts, "usage": token_info}
return answer, contexts return answer, contexts
else: if self.llm.config.token_usage:
return answer return {"answer": answer, "usage": token_info}
logger.warning(
"Starting from v0.1.125 the return type of query method will be changed to tuple containing `answer`."
)
return answer
def search(self, query, num_documents=3, where=None, raw_filter=None, namespace=None): def search(self, query, num_documents=3, where=None, raw_filter=None, namespace=None):
""" """

View File

@@ -9,10 +9,10 @@ class GPT4AllEmbedder(BaseEmbedder):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config=config)
-from langchain.embeddings import \
-GPT4AllEmbeddings as LangchainGPT4AllEmbeddings
from langchain_community.embeddings import GPT4AllEmbeddings as LangchainGPT4AllEmbeddings
-embeddings = LangchainGPT4AllEmbeddings()
model_name = self.config.model or "all-MiniLM-L6-v2-f16.gguf"
embeddings = LangchainGPT4AllEmbeddings(model_name=model_name)
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
self.set_embedding_fn(embedding_fn=embedding_fn)

View File

@@ -1,6 +1,6 @@
import logging
import os
-from typing import Optional
from typing import Any, Optional
try:
from langchain_anthropic import ChatAnthropic
@@ -21,8 +21,27 @@ class AnthropicLlm(BaseLlm):
if not self.config.api_key and "ANTHROPIC_API_KEY" not in os.environ: if not self.config.api_key and "ANTHROPIC_API_KEY" not in os.environ:
raise ValueError("Please set the ANTHROPIC_API_KEY environment variable or pass it in the config.") raise ValueError("Please set the ANTHROPIC_API_KEY environment variable or pass it in the config.")
def get_llm_model_answer(self, prompt): def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
return AnthropicLlm._get_answer(prompt=prompt, config=self.config) if self.config.token_usage:
response, token_info = self._get_answer(prompt, self.config)
model_name = "anthropic/" + self.config.model
if model_name not in self.config.model_pricing_map:
raise ValueError(
f"Model {model_name} not found in `model_prices_and_context_window.json`. \
You can disable token usage by setting `token_usage` to False."
)
total_cost = (
self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["input_tokens"]
) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["output_tokens"]
response_token_info = {
"prompt_tokens": token_info["input_tokens"],
"completion_tokens": token_info["output_tokens"],
"total_tokens": token_info["input_tokens"] + token_info["output_tokens"],
"total_cost": round(total_cost, 10),
"cost_currency": "USD",
}
return response, response_token_info
return self._get_answer(prompt, self.config)
@staticmethod
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
@@ -34,4 +53,7 @@ class AnthropicLlm(BaseLlm):
messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
-return chat(messages).content
chat_response = chat.invoke(messages)
if config.token_usage:
return chat_response.content, chat_response.response_metadata["token_usage"]
return chat_response.content

View File

@@ -164,7 +164,7 @@ class BaseLlm(JSONSerializable):
return search.run(input_query)
@staticmethod
-def _stream_response(answer: Any) -> Generator[Any, Any, None]:
def _stream_response(answer: Any, token_info: Optional[dict[str, Any]] = None) -> Generator[Any, Any, None]:
"""Generator to be used as streaming response
:param answer: Answer chunk from llm
@@ -177,6 +177,8 @@ class BaseLlm(JSONSerializable):
streamed_answer = streamed_answer + chunk
yield chunk
logger.info(f"Answer: {streamed_answer}")
if token_info:
logger.info(f"Token Info: {token_info}")
def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False):
"""
@@ -219,11 +221,18 @@ class BaseLlm(JSONSerializable):
if dry_run:
return prompt
-answer = self.get_answer_from_llm(prompt)
if self.config.token_usage:
answer, token_info = self.get_answer_from_llm(prompt)
else:
answer = self.get_answer_from_llm(prompt)
if isinstance(answer, str):
logger.info(f"Answer: {answer}")
if self.config.token_usage:
return answer, token_info
return answer
else:
if self.config.token_usage:
return self._stream_response(answer, token_info)
return self._stream_response(answer)
finally:
if config:
@@ -276,13 +285,13 @@ class BaseLlm(JSONSerializable):
if dry_run:
return prompt
-answer = self.get_answer_from_llm(prompt)
answer, token_info = self.get_answer_from_llm(prompt)
if isinstance(answer, str):
logger.info(f"Answer: {answer}")
-return answer
return answer, token_info
else:
# this is a streamed response and needs to be handled differently.
-return self._stream_response(answer)
return self._stream_response(answer, token_info)
finally:
if config:
# Restore previous config

View File

@@ -1,8 +1,8 @@
import importlib
import os
-from typing import Optional
from typing import Any, Optional
-from langchain_community.llms.cohere import Cohere
from langchain_cohere import ChatCohere
from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
@@ -17,27 +17,50 @@ class CohereLlm(BaseLlm):
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The required dependencies for Cohere are not installed."
-'Please install with `pip install --upgrade "embedchain[cohere]"`'
"Please install with `pip install langchain_cohere==1.16.0`"
) from None
super().__init__(config=config)
if not self.config.api_key and "COHERE_API_KEY" not in os.environ:
raise ValueError("Please set the COHERE_API_KEY environment variable or pass it in the config.")
-def get_llm_model_answer(self, prompt):
def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
if self.config.system_prompt:
raise ValueError("CohereLlm does not support `system_prompt`")
-return CohereLlm._get_answer(prompt=prompt, config=self.config)
if self.config.token_usage:
response, token_info = self._get_answer(prompt, self.config)
model_name = "cohere/" + self.config.model
if model_name not in self.config.model_pricing_map:
raise ValueError(
f"Model {model_name} not found in `model_prices_and_context_window.json`. \
You can disable token usage by setting `token_usage` to False."
)
total_cost = (
self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["input_tokens"]
) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["output_tokens"]
response_token_info = {
"prompt_tokens": token_info["input_tokens"],
"completion_tokens": token_info["output_tokens"],
"total_tokens": token_info["input_tokens"] + token_info["output_tokens"],
"total_cost": round(total_cost, 10),
"cost_currency": "USD",
}
return response, response_token_info
return self._get_answer(prompt, self.config)
@staticmethod
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
-api_key = config.api_key or os.getenv("COHERE_API_KEY")
api_key = config.api_key or os.environ["COHERE_API_KEY"]
-llm = Cohere(
-cohere_api_key=api_key,
-model=config.model,
-max_tokens=config.max_tokens,
-temperature=config.temperature,
-p=config.top_p,
-)
-return llm.invoke(prompt)
kwargs = {
"model_name": config.model or "command-r",
"temperature": config.temperature,
"max_tokens": config.max_tokens,
"together_api_key": api_key,
}
chat = ChatCohere(**kwargs)
chat_response = chat.invoke(prompt)
if config.token_usage:
return chat_response.content, chat_response.response_metadata["token_count"]
return chat_response.content

View File

@@ -1,5 +1,5 @@
import os
-from typing import Optional
from typing import Any, Optional
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import HumanMessage, SystemMessage
@@ -22,9 +22,27 @@ class GroqLlm(BaseLlm):
if not self.config.api_key and "GROQ_API_KEY" not in os.environ: if not self.config.api_key and "GROQ_API_KEY" not in os.environ:
raise ValueError("Please set the GROQ_API_KEY environment variable or pass it in the config.") raise ValueError("Please set the GROQ_API_KEY environment variable or pass it in the config.")
def get_llm_model_answer(self, prompt) -> str: def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
response = self._get_answer(prompt, self.config) if self.config.token_usage:
return response response, token_info = self._get_answer(prompt, self.config)
model_name = "groq/" + self.config.model
if model_name not in self.config.model_pricing_map:
raise ValueError(
f"Model {model_name} not found in `model_prices_and_context_window.json`. \
You can disable token usage by setting `token_usage` to False."
)
total_cost = (
self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["prompt_tokens"]
) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["completion_tokens"]
response_token_info = {
"prompt_tokens": token_info["prompt_tokens"],
"completion_tokens": token_info["completion_tokens"],
"total_tokens": token_info["prompt_tokens"] + token_info["completion_tokens"],
"total_cost": round(total_cost, 10),
"cost_currency": "USD",
}
return response, response_token_info
return self._get_answer(prompt, self.config)
def _get_answer(self, prompt: str, config: BaseLlmConfig) -> str:
messages = []
@@ -42,4 +60,8 @@ class GroqLlm(BaseLlm):
chat = ChatGroq(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
else:
chat = ChatGroq(**kwargs)
-return chat.invoke(messages).content
chat_response = chat.invoke(prompt)
if self.config.token_usage:
return chat_response.content, chat_response.response_metadata["token_usage"]
return chat_response.content

View File

@@ -1,5 +1,5 @@
import os
-from typing import Optional
from typing import Any, Optional
from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
@@ -13,8 +13,27 @@ class MistralAILlm(BaseLlm):
if not self.config.api_key and "MISTRAL_API_KEY" not in os.environ: if not self.config.api_key and "MISTRAL_API_KEY" not in os.environ:
raise ValueError("Please set the MISTRAL_API_KEY environment variable or pass it in the config.") raise ValueError("Please set the MISTRAL_API_KEY environment variable or pass it in the config.")
def get_llm_model_answer(self, prompt): def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
return MistralAILlm._get_answer(prompt=prompt, config=self.config) if self.config.token_usage:
response, token_info = self._get_answer(prompt, self.config)
model_name = "mistralai/" + self.config.model
if model_name not in self.config.model_pricing_map:
raise ValueError(
f"Model {model_name} not found in `model_prices_and_context_window.json`. \
You can disable token usage by setting `token_usage` to False."
)
total_cost = (
self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["prompt_tokens"]
) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["completion_tokens"]
response_token_info = {
"prompt_tokens": token_info["prompt_tokens"],
"completion_tokens": token_info["completion_tokens"],
"total_tokens": token_info["prompt_tokens"] + token_info["completion_tokens"],
"total_cost": round(total_cost, 10),
"cost_currency": "USD",
}
return response, response_token_info
return self._get_answer(prompt, self.config)
@staticmethod
def _get_answer(prompt: str, config: BaseLlmConfig):
@@ -47,6 +66,7 @@ class MistralAILlm(BaseLlm):
answer += chunk.content
return answer
else:
-response = client.invoke(**kwargs, input=messages)
-answer = response.content
-return answer
chat_response = client.invoke(**kwargs, input=messages)
if config.token_usage:
return chat_response.content, chat_response.response_metadata["token_usage"]
return chat_response.content

View File

@@ -1,6 +1,6 @@
import os
from collections.abc import Iterable
-from typing import Optional, Union
from typing import Any, Optional, Union
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
@@ -25,8 +25,27 @@ class NvidiaLlm(BaseLlm):
if not self.config.api_key and "NVIDIA_API_KEY" not in os.environ: if not self.config.api_key and "NVIDIA_API_KEY" not in os.environ:
raise ValueError("Please set the NVIDIA_API_KEY environment variable or pass it in the config.") raise ValueError("Please set the NVIDIA_API_KEY environment variable or pass it in the config.")
def get_llm_model_answer(self, prompt): def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
return self._get_answer(prompt=prompt, config=self.config) if self.config.token_usage:
response, token_info = self._get_answer(prompt, self.config)
model_name = "nvidia/" + self.config.model
if model_name not in self.config.model_pricing_map:
raise ValueError(
f"Model {model_name} not found in `model_prices_and_context_window.json`. \
You can disable token usage by setting `token_usage` to False."
)
total_cost = (
self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["input_tokens"]
) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["output_tokens"]
response_token_info = {
"prompt_tokens": token_info["input_tokens"],
"completion_tokens": token_info["output_tokens"],
"total_tokens": token_info["input_tokens"] + token_info["output_tokens"],
"total_cost": round(total_cost, 10),
"cost_currency": "USD",
}
return response, response_token_info
return self._get_answer(prompt, self.config)
@staticmethod
def _get_answer(prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
@@ -43,4 +62,7 @@ class NvidiaLlm(BaseLlm):
if labels:
params["labels"] = labels
llm = ChatNVIDIA(**params, callback_manager=CallbackManager(callback_manager))
-return llm.invoke(prompt).content if labels is None else llm.invoke(prompt, labels=labels).content
chat_response = llm.invoke(prompt) if labels is None else llm.invoke(prompt, labels=labels)
if config.token_usage:
return chat_response.content, chat_response.response_metadata["token_usage"]
return chat_response.content

View File

@@ -23,9 +23,28 @@ class OpenAILlm(BaseLlm):
self.tools = tools
super().__init__(config=config)
-def get_llm_model_answer(self, prompt) -> str:
def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
-response = self._get_answer(prompt, self.config)
-return response
if self.config.token_usage:
response, token_info = self._get_answer(prompt, self.config)
model_name = "openai/" + self.config.model
if model_name not in self.config.model_pricing_map:
raise ValueError(
f"Model {model_name} not found in `model_prices_and_context_window.json`. \
You can disable token usage by setting `token_usage` to False."
)
total_cost = (
self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["prompt_tokens"]
) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["completion_tokens"]
response_token_info = {
"prompt_tokens": token_info["prompt_tokens"],
"completion_tokens": token_info["completion_tokens"],
"total_tokens": token_info["prompt_tokens"] + token_info["completion_tokens"],
"total_cost": round(total_cost, 10),
"cost_currency": "USD",
}
return response, response_token_info
return self._get_answer(prompt, self.config)
def _get_answer(self, prompt: str, config: BaseLlmConfig) -> str:
messages = []
@@ -66,7 +85,10 @@ class OpenAILlm(BaseLlm):
if self.tools:
return self._query_function_call(chat, self.tools, messages)
-return chat.invoke(messages).content
chat_response = chat.invoke(messages)
if self.config.token_usage:
return chat_response.content, chat_response.response_metadata["token_usage"]
return chat_response.content
def _query_function_call(
self,

View File

@@ -1,8 +1,13 @@
import importlib
import os
-from typing import Optional
from typing import Any, Optional
-from langchain_community.llms import Together
try:
from langchain_together import ChatTogether
except ImportError:
raise ImportError(
"Please install the langchain_together package by running `pip install langchain_together==0.1.3`."
)
from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
@@ -24,20 +29,43 @@ class TogetherLlm(BaseLlm):
if not self.config.api_key and "TOGETHER_API_KEY" not in os.environ: if not self.config.api_key and "TOGETHER_API_KEY" not in os.environ:
raise ValueError("Please set the TOGETHER_API_KEY environment variable or pass it in the config.") raise ValueError("Please set the TOGETHER_API_KEY environment variable or pass it in the config.")
def get_llm_model_answer(self, prompt): def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
if self.config.system_prompt: if self.config.system_prompt:
raise ValueError("TogetherLlm does not support `system_prompt`") raise ValueError("TogetherLlm does not support `system_prompt`")
return TogetherLlm._get_answer(prompt=prompt, config=self.config)
if self.config.token_usage:
response, token_info = self._get_answer(prompt, self.config)
model_name = "together/" + self.config.model
if model_name not in self.config.model_pricing_map:
raise ValueError(
f"Model {model_name} not found in `model_prices_and_context_window.json`. \
You can disable token usage by setting `token_usage` to False."
)
total_cost = (
self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["prompt_tokens"]
) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info["completion_tokens"]
response_token_info = {
"prompt_tokens": token_info["prompt_tokens"],
"completion_tokens": token_info["completion_tokens"],
"total_tokens": token_info["prompt_tokens"] + token_info["completion_tokens"],
"total_cost": round(total_cost, 10),
"cost_currency": "USD",
}
return response, response_token_info
return self._get_answer(prompt, self.config)
@staticmethod
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
-api_key = config.api_key or os.getenv("TOGETHER_API_KEY")
api_key = config.api_key or os.environ["TOGETHER_API_KEY"]
-llm = Together(
-together_api_key=api_key,
-model=config.model,
-max_tokens=config.max_tokens,
-temperature=config.temperature,
-top_p=config.top_p,
-)
-return llm.invoke(prompt)
kwargs = {
"model_name": config.model or "mixtral-8x7b-32768",
"temperature": config.temperature,
"max_tokens": config.max_tokens,
"together_api_key": api_key,
}
chat = ChatTogether(**kwargs)
chat_response = chat.invoke(prompt)
if config.token_usage:
return chat_response.content, chat_response.response_metadata["token_usage"]
return chat_response.content

View File

@@ -1,6 +1,6 @@
import importlib
import logging
-from typing import Optional
from typing import Any, Optional
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_google_vertexai import ChatVertexAI
@@ -24,16 +24,35 @@ class VertexAILlm(BaseLlm):
) from None
super().__init__(config=config)
-def get_llm_model_answer(self, prompt):
def get_llm_model_answer(self, prompt) -> tuple[str, Optional[dict[str, Any]]]:
-return VertexAILlm._get_answer(prompt=prompt, config=self.config)
if self.config.token_usage:
response, token_info = self._get_answer(prompt, self.config)
model_name = "vertexai/" + self.config.model
if model_name not in self.config.model_pricing_map:
raise ValueError(
f"Model {model_name} not found in `model_prices_and_context_window.json`. \
You can disable token usage by setting `token_usage` to False."
)
total_cost = (
self.config.model_pricing_map[model_name]["input_cost_per_token"] * token_info["prompt_token_count"]
) + self.config.model_pricing_map[model_name]["output_cost_per_token"] * token_info[
"candidates_token_count"
]
response_token_info = {
"prompt_tokens": token_info["prompt_token_count"],
"completion_tokens": token_info["candidates_token_count"],
"total_tokens": token_info["prompt_token_count"] + token_info["candidates_token_count"],
"total_cost": round(total_cost, 10),
"cost_currency": "USD",
}
return response, response_token_info
return self._get_answer(prompt, self.config)
@staticmethod
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
if config.top_p and config.top_p != 1:
logger.warning("Config option `top_p` is not supported by this model.")
-messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
if config.stream:
callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
llm = ChatVertexAI(
@@ -42,4 +61,8 @@ class VertexAILlm(BaseLlm):
else:
llm = ChatVertexAI(temperature=config.temperature, model=config.model)
-return llm.invoke(messages).content
messages = VertexAILlm._get_messages(prompt)
chat_response = llm.invoke(messages)
if config.token_usage:
return chat_response.content, chat_response.response_metadata["usage_metadata"]
return chat_response.content

View File

@@ -428,6 +428,7 @@ def validate_config(config_data):
Optional("top_p"): Or(float, int), Optional("top_p"): Or(float, int),
Optional("stream"): bool, Optional("stream"): bool,
Optional("online"): bool, Optional("online"): bool,
Optional("token_usage"): bool,
Optional("template"): str, Optional("template"): str,
Optional("prompt"): str, Optional("prompt"): str,
Optional("system_prompt"): str, Optional("system_prompt"): str,

View File

@@ -0,0 +1,803 @@
{
"openai/gpt-4": {
"max_tokens": 4096,
"max_input_tokens": 8192,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00003,
"output_cost_per_token": 0.00006
},
"openai/gpt-4o": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000005,
"output_cost_per_token": 0.000015
},
"openai/gpt-4o-2024-05-13": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000005,
"output_cost_per_token": 0.000015
},
"openai/gpt-4-turbo-preview": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003
},
"openai/gpt-4-0314": {
"max_tokens": 4096,
"max_input_tokens": 8192,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00003,
"output_cost_per_token": 0.00006
},
"openai/gpt-4-0613": {
"max_tokens": 4096,
"max_input_tokens": 8192,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00003,
"output_cost_per_token": 0.00006
},
"openai/gpt-4-32k": {
"max_tokens": 4096,
"max_input_tokens": 32768,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00006,
"output_cost_per_token": 0.00012
},
"openai/gpt-4-32k-0314": {
"max_tokens": 4096,
"max_input_tokens": 32768,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00006,
"output_cost_per_token": 0.00012
},
"openai/gpt-4-32k-0613": {
"max_tokens": 4096,
"max_input_tokens": 32768,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00006,
"output_cost_per_token": 0.00012
},
"openai/gpt-4-turbo": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003
},
"openai/gpt-4-turbo-2024-04-09": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003
},
"openai/gpt-4-1106-preview": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003
},
"openai/gpt-4-0125-preview": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003
},
"openai/gpt-3.5-turbo": {
"max_tokens": 4097,
"max_input_tokens": 16385,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000015,
"output_cost_per_token": 0.000002
},
"openai/gpt-3.5-turbo-0301": {
"max_tokens": 4097,
"max_input_tokens": 4097,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000015,
"output_cost_per_token": 0.000002
},
"openai/gpt-3.5-turbo-0613": {
"input_cost_per_token": 0.0000015,
"output_cost_per_token": 0.000002
},
"openai/gpt-3.5-turbo-1106": {
"max_tokens": 16385,
"max_input_tokens": 16385,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000010,
"output_cost_per_token": 0.0000020
},
"openai/gpt-3.5-turbo-0125": {
"max_tokens": 16385,
"max_input_tokens": 16385,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000015
},
"openai/gpt-3.5-turbo-16k": {
"max_tokens": 16385,
"max_input_tokens": 16385,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000004
},
"openai/gpt-3.5-turbo-16k-0613": {
"max_tokens": 16385,
"max_input_tokens": 16385,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000004
},
"openai/text-embedding-3-large": {
"max_tokens": 8191,
"max_input_tokens": 8191,
"output_vector_size": 3072,
"input_cost_per_token": 0.00000013,
"output_cost_per_token": 0.000000
},
"openai/text-embedding-3-small": {
"max_tokens": 8191,
"max_input_tokens": 8191,
"output_vector_size": 1536,
"input_cost_per_token": 0.00000002,
"output_cost_per_token": 0.000000
},
"openai/text-embedding-ada-002": {
"max_tokens": 8191,
"max_input_tokens": 8191,
"output_vector_size": 1536,
"input_cost_per_token": 0.0000001,
"output_cost_per_token": 0.000000
},
"openai/text-embedding-ada-002-v2": {
"max_tokens": 8191,
"max_input_tokens": 8191,
"input_cost_per_token": 0.0000001,
"output_cost_per_token": 0.000000
},
"openai/babbage-002": {
"max_tokens": 16384,
"max_input_tokens": 16384,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000004,
"output_cost_per_token": 0.0000004
},
"openai/davinci-002": {
"max_tokens": 16384,
"max_input_tokens": 16384,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000002,
"output_cost_per_token": 0.000002
},
"openai/gpt-3.5-turbo-instruct": {
"max_tokens": 4096,
"max_input_tokens": 8192,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000015,
"output_cost_per_token": 0.000002
},
"openai/gpt-3.5-turbo-instruct-0914": {
"max_tokens": 4097,
"max_input_tokens": 8192,
"max_output_tokens": 4097,
"input_cost_per_token": 0.0000015,
"output_cost_per_token": 0.000002
},
"azure/gpt-4o": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000005,
"output_cost_per_token": 0.000015
},
"azure/gpt-4-turbo-2024-04-09": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003
},
"azure/gpt-4-0125-preview": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003
},
"azure/gpt-4-1106-preview": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003
},
"azure/gpt-4-0613": {
"max_tokens": 4096,
"max_input_tokens": 8192,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00003,
"output_cost_per_token": 0.00006
},
"azure/gpt-4-32k-0613": {
"max_tokens": 4096,
"max_input_tokens": 32768,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00006,
"output_cost_per_token": 0.00012
},
"azure/gpt-4-32k": {
"max_tokens": 4096,
"max_input_tokens": 32768,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00006,
"output_cost_per_token": 0.00012
},
"azure/gpt-4": {
"max_tokens": 4096,
"max_input_tokens": 8192,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00003,
"output_cost_per_token": 0.00006
},
"azure/gpt-4-turbo": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003
},
"azure/gpt-4-turbo-vision-preview": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003
},
"azure/gpt-3.5-turbo-16k-0613": {
"max_tokens": 4096,
"max_input_tokens": 16385,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000004
},
"azure/gpt-3.5-turbo-1106": {
"max_tokens": 4096,
"max_input_tokens": 16384,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000015,
"output_cost_per_token": 0.000002
},
"azure/gpt-3.5-turbo-0125": {
"max_tokens": 4096,
"max_input_tokens": 16384,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000015
},
"azure/gpt-3.5-turbo-16k": {
"max_tokens": 4096,
"max_input_tokens": 16385,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000004
},
"azure/gpt-3.5-turbo": {
"max_tokens": 4096,
"max_input_tokens": 4097,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000015
},
"azure/gpt-3.5-turbo-instruct-0914": {
"max_tokens": 4097,
"max_input_tokens": 4097,
"input_cost_per_token": 0.0000015,
"output_cost_per_token": 0.000002
},
"azure/gpt-3.5-turbo-instruct": {
"max_tokens": 4097,
"max_input_tokens": 4097,
"input_cost_per_token": 0.0000015,
"output_cost_per_token": 0.000002
},
"azure/text-embedding-ada-002": {
"max_tokens": 8191,
"max_input_tokens": 8191,
"input_cost_per_token": 0.0000001,
"output_cost_per_token": 0.000000
},
"azure/text-embedding-3-large": {
"max_tokens": 8191,
"max_input_tokens": 8191,
"input_cost_per_token": 0.00000013,
"output_cost_per_token": 0.000000
},
"azure/text-embedding-3-small": {
"max_tokens": 8191,
"max_input_tokens": 8191,
"input_cost_per_token": 0.00000002,
"output_cost_per_token": 0.000000
},
"mistralai/mistral-tiny": {
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000025
},
"mistralai/mistral-small": {
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000003
},
"mistralai/mistral-small-latest": {
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000003
},
"mistralai/mistral-medium": {
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.0000027,
"output_cost_per_token": 0.0000081
},
"mistralai/mistral-medium-latest": {
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.0000027,
"output_cost_per_token": 0.0000081
},
"mistralai/mistral-medium-2312": {
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.0000027,
"output_cost_per_token": 0.0000081
},
"mistralai/mistral-large-latest": {
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000004,
"output_cost_per_token": 0.000012
},
"mistralai/mistral-large-2402": {
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000004,
"output_cost_per_token": 0.000012
},
"mistralai/open-mistral-7b": {
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000025
},
"mistralai/open-mixtral-8x7b": {
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.0000007,
"output_cost_per_token": 0.0000007
},
"mistralai/open-mixtral-8x22b": {
"max_tokens": 8191,
"max_input_tokens": 64000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000002,
"output_cost_per_token": 0.000006
},
"mistralai/codestral-latest": {
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000003
},
"mistralai/codestral-2405": {
"max_tokens": 8191,
"max_input_tokens": 32000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000003
},
"mistralai/mistral-embed": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"input_cost_per_token": 0.0000001,
"output_cost_per_token": 0.0
},
"groq/llama2-70b-4096": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000070,
"output_cost_per_token": 0.00000080
},
"groq/llama3-8b-8192": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.00000010,
"output_cost_per_token": 0.00000010
},
"groq/llama3-70b-8192": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.00000064,
"output_cost_per_token": 0.00000080
},
"groq/mixtral-8x7b-32768": {
"max_tokens": 32768,
"max_input_tokens": 32768,
"max_output_tokens": 32768,
"input_cost_per_token": 0.00000027,
"output_cost_per_token": 0.00000027
},
"groq/gemma-7b-it": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.00000010,
"output_cost_per_token": 0.00000010
},
"anthropic/claude-instant-1": {
"max_tokens": 8191,
"max_input_tokens": 100000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.00000163,
"output_cost_per_token": 0.00000551
},
"anthropic/claude-instant-1.2": {
"max_tokens": 8191,
"max_input_tokens": 100000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000000163,
"output_cost_per_token": 0.000000551
},
"anthropic/claude-2": {
"max_tokens": 8191,
"max_input_tokens": 100000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000008,
"output_cost_per_token": 0.000024
},
"anthropic/claude-2.1": {
"max_tokens": 8191,
"max_input_tokens": 200000,
"max_output_tokens": 8191,
"input_cost_per_token": 0.000008,
"output_cost_per_token": 0.000024
},
"anthropic/claude-3-haiku-20240307": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000125
},
"anthropic/claude-3-opus-20240229": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000075
},
"anthropic/claude-3-sonnet-20240229": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015
},
"vertexai/chat-bison": {
"max_tokens": 4096,
"max_input_tokens": 8192,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125
},
"vertexai/chat-bison@001": {
"max_tokens": 4096,
"max_input_tokens": 8192,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125
},
"vertexai/chat-bison@002": {
"max_tokens": 4096,
"max_input_tokens": 8192,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125
},
"vertexai/chat-bison-32k": {
"max_tokens": 8192,
"max_input_tokens": 32000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125
},
"vertexai/code-bison": {
"max_tokens": 1024,
"max_input_tokens": 6144,
"max_output_tokens": 1024,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125
},
"vertexai/code-bison@001": {
"max_tokens": 1024,
"max_input_tokens": 6144,
"max_output_tokens": 1024,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125
},
"vertexai/code-gecko@001": {
"max_tokens": 64,
"max_input_tokens": 2048,
"max_output_tokens": 64,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125
},
"vertexai/code-gecko@002": {
"max_tokens": 64,
"max_input_tokens": 2048,
"max_output_tokens": 64,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125
},
"vertexai/code-gecko": {
"max_tokens": 64,
"max_input_tokens": 2048,
"max_output_tokens": 64,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125
},
"vertexai/codechat-bison": {
"max_tokens": 1024,
"max_input_tokens": 6144,
"max_output_tokens": 1024,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125
},
"vertexai/codechat-bison@001": {
"max_tokens": 1024,
"max_input_tokens": 6144,
"max_output_tokens": 1024,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125
},
"vertexai/codechat-bison-32k": {
"max_tokens": 8192,
"max_input_tokens": 32000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125
},
"vertexai/gemini-pro": {
"max_tokens": 8192,
"max_input_tokens": 32760,
"max_output_tokens": 8192,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.0000005
},
"vertexai/gemini-1.0-pro": {
"max_tokens": 8192,
"max_input_tokens": 32760,
"max_output_tokens": 8192,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.0000005
},
"vertexai/gemini-1.0-pro-001": {
"max_tokens": 8192,
"max_input_tokens": 32760,
"max_output_tokens": 8192,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.0000005
},
"vertexai/gemini-1.0-pro-002": {
"max_tokens": 8192,
"max_input_tokens": 32760,
"max_output_tokens": 8192,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.0000005
},
"vertexai/gemini-1.5-pro": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.000000625,
"output_cost_per_token": 0.000001875
},
"vertexai/gemini-1.5-flash-001": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
"max_output_tokens": 8192,
"input_cost_per_token": 0,
"output_cost_per_token": 0
},
"vertexai/gemini-1.5-flash-preview-0514": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
"max_output_tokens": 8192,
"input_cost_per_token": 0,
"output_cost_per_token": 0
},
"vertexai/gemini-1.5-pro-001": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.000000625,
"output_cost_per_token": 0.000001875
},
"vertexai/gemini-1.5-pro-preview-0514": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.000000625,
"output_cost_per_token": 0.000001875
},
"vertexai/gemini-1.5-pro-preview-0215": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.000000625,
"output_cost_per_token": 0.000001875
},
"vertexai/gemini-1.5-pro-preview-0409": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.000000625,
"output_cost_per_token": 0.000001875
},
"vertexai/gemini-experimental": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
"max_output_tokens": 8192,
"input_cost_per_token": 0,
"output_cost_per_token": 0
},
"vertexai/gemini-pro-vision": {
"max_tokens": 2048,
"max_input_tokens": 16384,
"max_output_tokens": 2048,
"max_images_per_prompt": 16,
"max_videos_per_prompt": 1,
"max_video_length": 2,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.0000005
},
"vertexai/gemini-1.0-pro-vision": {
"max_tokens": 2048,
"max_input_tokens": 16384,
"max_output_tokens": 2048,
"max_images_per_prompt": 16,
"max_videos_per_prompt": 1,
"max_video_length": 2,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.0000005
},
"vertexai/gemini-1.0-pro-vision-001": {
"max_tokens": 2048,
"max_input_tokens": 16384,
"max_output_tokens": 2048,
"max_images_per_prompt": 16,
"max_videos_per_prompt": 1,
"max_video_length": 2,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.0000005
},
"vertexai/claude-3-sonnet@20240229": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015
},
"vertexai/claude-3-haiku@20240307": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000125
},
"vertexai/claude-3-opus@20240229": {
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000075
},
"cohere/command-r": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000050,
"output_cost_per_token": 0.0000015
},
"cohere/command-light": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000015
},
"cohere/command-r-plus": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015
},
"cohere/command-nightly": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000015
},
"cohere/command": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000015
},
"cohere/command-medium-beta": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000015
},
"cohere/command-xlarge-beta": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000015
},
"together/together-ai-up-to-3b": {
"input_cost_per_token": 0.0000001,
"output_cost_per_token": 0.0000001
},
"together/together-ai-3.1b-7b": {
"input_cost_per_token": 0.0000002,
"output_cost_per_token": 0.0000002
},
"together/together-ai-7.1b-20b": {
"max_tokens": 1000,
"input_cost_per_token": 0.0000004,
"output_cost_per_token": 0.0000004
},
"together/together-ai-20.1b-40b": {
"input_cost_per_token": 0.0000008,
"output_cost_per_token": 0.0000008
},
"together/together-ai-40.1b-70b": {
"input_cost_per_token": 0.0000009,
"output_cost_per_token": 0.0000009
},
"together/mistralai/Mixtral-8x7B-Instruct-v0.1": {
"input_cost_per_token": 0.0000006,
"output_cost_per_token": 0.0000006
}
}
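
The prices above are expressed per token, so a query's `total_cost` is simply `prompt_tokens * input_cost_per_token + completion_tokens * output_cost_per_token` for the model in use. Below is a minimal sketch of that arithmetic — the `MODEL_PRICES` dict and `estimate_cost` helper are illustrative names, not embedchain's actual API — using the `cohere/command-r` entry from the table above; it reproduces the `3.5e-06` cost asserted in the Cohere token-usage test further down.

```python
# Sketch only: illustrates how per-token prices turn token counts into the
# token_usage cost details; names here are hypothetical, not embedchain's API.
MODEL_PRICES = {
    "cohere/command-r": {
        "input_cost_per_token": 0.00000050,
        "output_cost_per_token": 0.0000015,
    },
}


def estimate_cost(model: str, prompt_tokens: int, completion_tokens: int) -> dict:
    """Combine token counts with the per-token prices for `model`."""
    prices = MODEL_PRICES[model]
    total_cost = (
        prompt_tokens * prices["input_cost_per_token"]
        + completion_tokens * prices["output_cost_per_token"]
    )
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
        "total_cost": round(total_cost, 10),  # avoid float noise in the report
        "cost_currency": "USD",
    }


# 1 * 0.0000005 + 2 * 0.0000015 == 3.5e-06, matching the Cohere test below.
print(estimate_cost("cohere/command-r", prompt_tokens=1, completion_tokens=2))
```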

poetry.lock (generated) was also updated; its diff is suppressed because it is too large (912 changed lines not shown).

Changes to pyproject.toml:

@@ -156,6 +156,7 @@ langchain-google-vertexai = { version = "^1.0.6", optional = true }
 sqlalchemy = "^2.0.27"
 alembic = "^1.13.1"
 langchain-cohere = "^0.1.4"
+langchain-community = "^0.2.6"
 
 [tool.poetry.group.dev.dependencies]
 black = "^23.3.0"
@@ -183,9 +184,8 @@ slack = ["slack-sdk", "flask"]
 whatsapp = ["twilio", "flask"]
 weaviate = ["weaviate-client"]
 qdrant = ["qdrant-client"]
-huggingface_hub=["huggingface_hub"]
-cohere = ["cohere"]
 together = ["together"]
+huggingface_hub=["huggingface_hub"]
 milvus = ["pymilvus"]
 dataloaders=[
     "youtube-transcript-api",

Changes to the AnthropicLlm tests (embedchain.llm.anthropic):

@@ -11,7 +11,7 @@ from embedchain.llm.anthropic import AnthropicLlm
 @pytest.fixture
 def anthropic_llm():
     os.environ["ANTHROPIC_API_KEY"] = "test_api_key"
-    config = BaseLlmConfig(temperature=0.5, model="gpt2")
+    config = BaseLlmConfig(temperature=0.5, model="claude-instant-1", token_usage=False)
     return AnthropicLlm(config)
@@ -20,7 +20,7 @@ def test_get_llm_model_answer(anthropic_llm):
         prompt = "Test Prompt"
         response = anthropic_llm.get_llm_model_answer(prompt)
         assert response == "Test Response"
-        mock_method.assert_called_once_with(prompt=prompt, config=anthropic_llm.config)
+        mock_method.assert_called_once_with(prompt, anthropic_llm.config)
 def test_get_messages(anthropic_llm):
@@ -31,3 +31,24 @@ def test_get_messages(anthropic_llm):
         SystemMessage(content="Test System Prompt", additional_kwargs={}),
         HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
     ]
+def test_get_llm_model_answer_with_token_usage(anthropic_llm):
+    test_config = BaseLlmConfig(
+        temperature=anthropic_llm.config.temperature, model=anthropic_llm.config.model, token_usage=True
+    )
+    anthropic_llm.config = test_config
+    with patch.object(
+        AnthropicLlm, "_get_answer", return_value=("Test Response", {"input_tokens": 1, "output_tokens": 2})
+    ) as mock_method:
+        prompt = "Test Prompt"
+        response, token_info = anthropic_llm.get_llm_model_answer(prompt)
+        assert response == "Test Response"
+        assert token_info == {
+            "prompt_tokens": 1,
+            "completion_tokens": 2,
+            "total_tokens": 3,
+            "total_cost": 1.265e-05,
+            "cost_currency": "USD",
+        }
+        mock_method.assert_called_once_with(prompt, anthropic_llm.config)

Changes to the CohereLlm tests (embedchain.llm.cohere):

@@ -9,7 +9,7 @@ from embedchain.llm.cohere import CohereLlm
 @pytest.fixture
 def cohere_llm_config():
     os.environ["COHERE_API_KEY"] = "test_api_key"
-    config = BaseLlmConfig(model="gptd-instruct-tft", max_tokens=50, temperature=0.7, top_p=0.8)
+    config = BaseLlmConfig(model="command-r", max_tokens=100, temperature=0.7, top_p=0.8, token_usage=False)
     yield config
     os.environ.pop("COHERE_API_KEY")
@@ -36,10 +36,35 @@ def test_get_llm_model_answer(cohere_llm_config, mocker):
     assert answer == "Test answer"
+def test_get_llm_model_answer_with_token_usage(cohere_llm_config, mocker):
+    test_config = BaseLlmConfig(
+        temperature=cohere_llm_config.temperature,
+        max_tokens=cohere_llm_config.max_tokens,
+        top_p=cohere_llm_config.top_p,
+        model=cohere_llm_config.model,
+        token_usage=True,
+    )
+    mocker.patch(
+        "embedchain.llm.cohere.CohereLlm._get_answer",
+        return_value=("Test answer", {"input_tokens": 1, "output_tokens": 2}),
+    )
+    llm = CohereLlm(test_config)
+    answer, token_info = llm.get_llm_model_answer("Test query")
+    assert answer == "Test answer"
+    assert token_info == {
+        "prompt_tokens": 1,
+        "completion_tokens": 2,
+        "total_tokens": 3,
+        "total_cost": 3.5e-06,
+        "cost_currency": "USD",
+    }
 def test_get_answer_mocked_cohere(cohere_llm_config, mocker):
-    mocked_cohere = mocker.patch("embedchain.llm.cohere.Cohere")
-    mock_instance = mocked_cohere.return_value
-    mock_instance.invoke.return_value = "Mocked answer"
+    mocked_cohere = mocker.patch("embedchain.llm.cohere.ChatCohere")
+    mocked_cohere.return_value.invoke.return_value.content = "Mocked answer"
     llm = CohereLlm(cohere_llm_config)
     prompt = "Test query"

Changes to the MistralAILlm tests:

@@ -24,7 +24,7 @@ def test_mistralai_llm_init(monkeypatch):
 def test_get_llm_model_answer(monkeypatch, mistralai_llm_config):
-    def mock_get_answer(prompt, config):
+    def mock_get_answer(self, prompt, config):
         return "Generated Text"
     monkeypatch.setattr(MistralAILlm, "_get_answer", mock_get_answer)
@@ -36,7 +36,7 @@ def test_get_llm_model_answer(monkeypatch, mistralai_llm_config):
 def test_get_llm_model_answer_with_system_prompt(monkeypatch, mistralai_llm_config):
     mistralai_llm_config.system_prompt = "Test system prompt"
-    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
+    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
     llm = MistralAILlm(config=mistralai_llm_config)
     result = llm.get_llm_model_answer("test prompt")
@@ -44,7 +44,7 @@ def test_get_llm_model_answer_with_system_prompt(monkeypatch, mistralai_llm_config):
 def test_get_llm_model_answer_empty_prompt(monkeypatch, mistralai_llm_config):
-    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
+    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
     llm = MistralAILlm(config=mistralai_llm_config)
     result = llm.get_llm_model_answer("")
@@ -53,8 +53,35 @@ def test_get_llm_model_answer_empty_prompt(monkeypatch, mistralai_llm_config):
 def test_get_llm_model_answer_without_system_prompt(monkeypatch, mistralai_llm_config):
     mistralai_llm_config.system_prompt = None
-    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda prompt, config: "Generated Text")
+    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
     llm = MistralAILlm(config=mistralai_llm_config)
     result = llm.get_llm_model_answer("test prompt")
     assert result == "Generated Text"
+def test_get_llm_model_answer_with_token_usage(monkeypatch, mistralai_llm_config):
+    test_config = BaseLlmConfig(
+        temperature=mistralai_llm_config.temperature,
+        max_tokens=mistralai_llm_config.max_tokens,
+        top_p=mistralai_llm_config.top_p,
+        model=mistralai_llm_config.model,
+        token_usage=True,
+    )
+    monkeypatch.setattr(
+        MistralAILlm,
+        "_get_answer",
+        lambda self, prompt, config: ("Generated Text", {"prompt_tokens": 1, "completion_tokens": 2}),
+    )
+    llm = MistralAILlm(test_config)
+    answer, token_info = llm.get_llm_model_answer("Test query")
+    assert answer == "Generated Text"
+    assert token_info == {
+        "prompt_tokens": 1,
+        "completion_tokens": 2,
+        "total_tokens": 3,
+        "total_cost": 7.5e-07,
+        "cost_currency": "USD",
+    }

Changes to the OpenAILlm tests (embedchain.llm.openai):

@@ -62,6 +62,35 @@ def test_get_llm_model_answer_empty_prompt(config, mocker):
     mocked_get_answer.assert_called_once_with("", config)
+def test_get_llm_model_answer_with_token_usage(config, mocker):
+    test_config = BaseLlmConfig(
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        top_p=config.top_p,
+        stream=config.stream,
+        system_prompt=config.system_prompt,
+        model=config.model,
+        token_usage=True,
+    )
+    mocked_get_answer = mocker.patch(
+        "embedchain.llm.openai.OpenAILlm._get_answer",
+        return_value=("Test answer", {"prompt_tokens": 1, "completion_tokens": 2}),
+    )
+    llm = OpenAILlm(test_config)
+    answer, token_info = llm.get_llm_model_answer("Test query")
+    assert answer == "Test answer"
+    assert token_info == {
+        "prompt_tokens": 1,
+        "completion_tokens": 2,
+        "total_tokens": 3,
+        "total_cost": 5.5e-06,
+        "cost_currency": "USD",
+    }
+    mocked_get_answer.assert_called_once_with("Test query", test_config)
 def test_get_llm_model_answer_with_streaming(config, mocker):
     config.stream = True
     mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")

Changes to the TogetherLlm tests (embedchain.llm.together):

@@ -9,7 +9,7 @@ from embedchain.llm.together import TogetherLlm
 @pytest.fixture
 def together_llm_config():
     os.environ["TOGETHER_API_KEY"] = "test_api_key"
-    config = BaseLlmConfig(model="togethercomputer/RedPajama-INCITE-7B-Base", max_tokens=50, temperature=0.7, top_p=0.8)
+    config = BaseLlmConfig(model="together-ai-up-to-3b", max_tokens=50, temperature=0.7, top_p=0.8)
     yield config
     os.environ.pop("TOGETHER_API_KEY")
@@ -36,10 +36,36 @@ def test_get_llm_model_answer(together_llm_config, mocker):
     assert answer == "Test answer"
+def test_get_llm_model_answer_with_token_usage(together_llm_config, mocker):
+    test_config = BaseLlmConfig(
+        temperature=together_llm_config.temperature,
+        max_tokens=together_llm_config.max_tokens,
+        top_p=together_llm_config.top_p,
+        model=together_llm_config.model,
+        token_usage=True,
+    )
+    mocker.patch(
+        "embedchain.llm.together.TogetherLlm._get_answer",
+        return_value=("Test answer", {"prompt_tokens": 1, "completion_tokens": 2}),
+    )
+    llm = TogetherLlm(test_config)
+    answer, token_info = llm.get_llm_model_answer("Test query")
+    assert answer == "Test answer"
+    assert token_info == {
+        "prompt_tokens": 1,
+        "completion_tokens": 2,
+        "total_tokens": 3,
+        "total_cost": 3e-07,
+        "cost_currency": "USD",
+    }
 def test_get_answer_mocked_together(together_llm_config, mocker):
-    mocked_together = mocker.patch("embedchain.llm.together.Together")
+    mocked_together = mocker.patch("embedchain.llm.together.ChatTogether")
     mock_instance = mocked_together.return_value
-    mock_instance.invoke.return_value = "Mocked answer"
+    mock_instance.invoke.return_value.content = "Mocked answer"
     llm = TogetherLlm(together_llm_config)
     prompt = "Test query"

Changes to the VertexAILlm tests (embedchain.llm.vertex_ai):

@@ -24,7 +24,32 @@ def test_get_llm_model_answer(vertexai_llm):
         prompt = "Test Prompt"
         response = vertexai_llm.get_llm_model_answer(prompt)
         assert response == "Test Response"
-        mock_method.assert_called_once_with(prompt=prompt, config=vertexai_llm.config)
+        mock_method.assert_called_once_with(prompt, vertexai_llm.config)
+def test_get_llm_model_answer_with_token_usage(vertexai_llm):
+    test_config = BaseLlmConfig(
+        temperature=vertexai_llm.config.temperature,
+        max_tokens=vertexai_llm.config.max_tokens,
+        top_p=vertexai_llm.config.top_p,
+        model=vertexai_llm.config.model,
+        token_usage=True,
+    )
+    vertexai_llm.config = test_config
+    with patch.object(
+        VertexAILlm,
+        "_get_answer",
+        return_value=("Test Response", {"prompt_token_count": 1, "candidates_token_count": 2}),
+    ):
+        response, token_info = vertexai_llm.get_llm_model_answer("Test Query")
+        assert response == "Test Response"
+        assert token_info == {
+            "prompt_tokens": 1,
+            "completion_tokens": 2,
+            "total_tokens": 3,
+            "total_cost": 3.75e-07,
+            "cost_currency": "USD",
+        }
 @patch("embedchain.llm.vertex_ai.ChatVertexAI")