Add OpenAI proxy (#1503)

Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
Dev Khant
2024-08-02 20:14:27 +05:30
committed by GitHub
parent 51092b0b64
commit 419dc6598c
18 changed files with 637 additions and 135 deletions

View File

@@ -21,20 +21,20 @@ class OutputData(BaseModel):
class ChromaDB(VectorStoreBase):
def __init__(
self,
collection_name="mem0",
client=None,
host=None,
port=None,
path=None
collection_name,
client,
host,
port,
path
):
"""
Initialize the Qdrant vector store.
Args:
client (QdrantClient, optional): Existing Qdrant client instance. Defaults to None.
host (str, optional): Host address for Qdrant server. Defaults to None.
port (int, optional): Port for Qdrant server. Defaults to None.
path (str, optional): Path for local Qdrant database. Defaults to None.
client (QdrantClient, optional): Existing Qdrant client instance.
host (str, optional): Host address for Qdrant server.
port (int, optional): Port for Qdrant server.
path (str, optional): Path for local Qdrant database.
"""
if client:
self.client = client
@@ -95,7 +95,7 @@ class ChromaDB(VectorStoreBase):
Args:
name (str): Name of the collection.
embedding_fn (function): Embedding function to use.
embedding_fn (function): Embedding function to use. Defaults to None.
"""
# Skip creating collection if already exists
collections = self.list_cols()
@@ -213,7 +213,7 @@ class ChromaDB(VectorStoreBase):
Args:
name (str): Name of the collection.
filters (dict, optional): Filters to apply to the list. Defaults to None.
filters (dict, optional): Filters to apply to the list.
limit (int, optional): Number of vectors to return. Defaults to 100.
Returns:

View File

@@ -3,14 +3,23 @@ from typing import Optional
from pydantic import BaseModel, Field, field_validator, model_validator
def create_default_config(provider: str):
"""Create a default configuration based on the provider."""
if provider == "qdrant":
return QdrantConfig(path="/tmp/qdrant")
elif provider == "chromadb":
return ChromaDbConfig(path="/tmp/chromadb")
else:
raise ValueError(f"Unsupported vector store provider: {provider}")
class QdrantConfig(BaseModel):
collection_name: str = Field(default="mem0", description="Name of the collection")
embedding_model_dims: Optional[int] = Field(
default=1536, description="Dimensions of the embedding model"
)
collection_name: str = Field("mem0", description="Name of the collection")
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
client: Optional[str] = Field(None, description="Existing Qdrant client instance")
host: Optional[str] = Field(None, description="Host address for Qdrant server")
port: Optional[int] = Field(None, description="Port for Qdrant server")
path: Optional[str] = Field(None, description="Path for local Qdrant database")
path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
url: Optional[str] = Field(None, description="Full URL for Qdrant server")
api_key: Optional[str] = Field(None, description="API key for Qdrant server")
@@ -31,18 +40,11 @@ class QdrantConfig(BaseModel):
class ChromaDbConfig(BaseModel):
collection_name: str = Field(
default="mem0", description="Default name for the collection"
)
path: Optional[str] = Field(
default=None, description="Path to the database directory"
)
host: Optional[str] = Field(
default=None, description="Database connection remote host"
)
port: Optional[str] = Field(
default=None, description="Database connection remote port"
)
collection_name: str = Field("mem0", description="Default name for the collection")
client: Optional[str] = Field(None, description="Existing ChromaDB client instance")
path: Optional[str] = Field(None, description="Path to the database directory")
host: Optional[str] = Field(None, description="Database connection remote host")
port: Optional[str] = Field(None, description="Database connection remote port")
@model_validator(mode="before")
def check_host_port_or_path(cls, values):
@@ -59,15 +61,37 @@ class VectorStoreConfig(BaseModel):
)
config: Optional[dict] = Field(
description="Configuration for the specific vector store",
default={},
default=None
)
@field_validator("config")
def validate_config(cls, v, values):
provider = values.data.get("provider")
if provider == "qdrant":
return QdrantConfig(**v.model_dump())
elif provider == "chromadb":
return ChromaDbConfig(**v.model_dump())
else:
raise ValueError(f"Unsupported vector store provider: {provider}")
if v is None:
return create_default_config(provider)
if isinstance(v, dict):
if provider == "qdrant":
return QdrantConfig(**v)
elif provider == "chromadb":
return ChromaDbConfig(**v)
return v
@model_validator(mode="after")
def ensure_config_type(cls, values):
provider = values.provider
config = values.config
if config is None:
values.config = create_default_config(provider)
elif isinstance(config, dict):
if provider == "qdrant":
values.config = QdrantConfig(**config)
elif provider == "chromadb":
values.config = ChromaDbConfig(**config)
elif not isinstance(config, (QdrantConfig, ChromaDbConfig)):
raise ValueError(f"Invalid config type for provider {provider}")
return values

View File

@@ -20,25 +20,25 @@ from mem0.vector_stores.base import VectorStoreBase
class Qdrant(VectorStoreBase):
def __init__(
self,
collection_name="mem0",
embedding_model_dims=1536,
client=None,
host="localhost",
port=6333,
path=None,
url=None,
api_key=None,
collection_name,
embedding_model_dims,
client,
host,
port,
path,
url,
api_key,
):
"""
Initialize the Qdrant vector store.
Args:
client (QdrantClient, optional): Existing Qdrant client instance. Defaults to None.
host (str, optional): Host address for Qdrant server. Defaults to "localhost".
port (int, optional): Port for Qdrant server. Defaults to 6333.
path (str, optional): Path for local Qdrant database. Defaults to None.
url (str, optional): Full URL for Qdrant server. Defaults to None.
api_key (str, optional): API key for Qdrant server. Defaults to None.
client (QdrantClient, optional): Existing Qdrant client instance.
host (str, optional): Host address for Qdrant server.
port (int, optional): Port for Qdrant server.
path (str, optional): Path for local Qdrant database.
url (str, optional): Full URL for Qdrant server.
api_key (str, optional): API key for Qdrant server.
"""
if client:
self.client = client