Added support for Ollama for local model inference. (#1045)

Co-authored-by: Deshraj Yadav <deshraj@gatech.edu>
Sukkrit Sharma authored on 2023-12-22 05:10:00 +05:30; committed by GitHub
parent 210fe9bb80
commit 0f73ba9677
11 changed files with 348 additions and 1 deletion


@@ -13,6 +13,7 @@ class LlmFactory:
         "azure_openai": "embedchain.llm.azure_openai.AzureOpenAILlm",
         "cohere": "embedchain.llm.cohere.CohereLlm",
         "gpt4all": "embedchain.llm.gpt4all.GPT4ALLLlm",
+        "ollama": "embedchain.llm.ollama.OllamaLlm",
         "huggingface": "embedchain.llm.huggingface.HuggingFaceLlm",
         "jina": "embedchain.llm.jina.JinaLlm",
         "llama2": "embedchain.llm.llama2.Llama2Llm",

embedchain/llm/ollama.py (new file, 34 lines)

@@ -0,0 +1,34 @@
from typing import Iterable, Optional, Union

from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms.ollama import Ollama

from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm


@register_deserializable
class OllamaLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config=config)
        # Fall back to llama2 when no model is configured.
        if self.config.model is None:
            self.config.model = "llama2"

    def get_llm_model_answer(self, prompt):
        return self._get_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_answer(prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
        # Stream tokens to stdout as they arrive, or print the full answer at once.
        callback_manager = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]
        llm = Ollama(
            model=config.model,
            system=config.system_prompt,
            temperature=config.temperature,
            top_p=config.top_p,
            callback_manager=CallbackManager(callback_manager),
        )
        return llm(prompt)
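
The class can also be exercised directly. A minimal sketch, assuming a local Ollama server is running on its default port and the model has already been pulled (`ollama pull llama2`):

    from embedchain.config import BaseLlmConfig
    from embedchain.llm.ollama import OllamaLlm

    # Requires a running Ollama server with the llama2 model pulled locally.
    config = BaseLlmConfig(model="llama2", temperature=0.7, stream=False)
    llm = OllamaLlm(config=config)
    print(llm.get_llm_model_answer("Why is the sky blue?"))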


@@ -6,4 +6,5 @@ class Providers(Enum):
     ANTHROPHIC = "ANTHPROPIC"
     VERTEX_AI = "VERTEX_AI"
     GPT4ALL = "GPT4ALL"
+    OLLAMA = "OLLAMA"
     AZURE_OPENAI = "AZURE_OPENAI"


@@ -385,6 +385,7 @@ def validate_config(config_data):
     "huggingface",
     "cohere",
     "gpt4all",
+    "ollama",
     "jina",
     "llama2",
     "vertexai",