feat(ai-sdk): Added Support for Google Provider in AI SDK (#2771)
This commit is contained in:
@@ -7,7 +7,7 @@ class Mem0ClassSelector {
|
||||
provider_wrapper: string;
|
||||
config: Mem0ProviderSettings;
|
||||
provider_config?: ProviderSettings;
|
||||
static supportedProviders = ["openai", "anthropic", "cohere", "groq"];
|
||||
static supportedProviders = ["openai", "anthropic", "cohere", "groq", "google"];
|
||||
|
||||
constructor(modelId: string, config: Mem0ProviderSettings, provider_config?: ProviderSettings) {
|
||||
this.modelId = modelId;
|
||||
|
||||
@@ -1,104 +1,82 @@
|
||||
import { LanguageModelV1, LanguageModelV1CallOptions, LanguageModelV1Prompt } from "ai";
|
||||
import { LanguageModelV1, LanguageModelV1CallOptions } from "ai";
|
||||
import { Mem0ProviderSettings } from "./mem0-provider";
|
||||
import { createOpenAI, OpenAIProviderSettings } from "@ai-sdk/openai";
|
||||
import { CohereProviderSettings, createCohere } from "@ai-sdk/cohere";
|
||||
import { AnthropicProviderSettings, createAnthropic } from "@ai-sdk/anthropic";
|
||||
import { createGoogleGenerativeAI, GoogleGenerativeAIProviderSettings } from "@ai-sdk/google";
|
||||
import { createGroq, GroqProviderSettings } from "@ai-sdk/groq";
|
||||
|
||||
export type Provider = ReturnType<typeof createOpenAI> | ReturnType<typeof createCohere> | ReturnType<typeof createAnthropic> | ReturnType<typeof createGroq> | any;
|
||||
export type ProviderSettings = OpenAIProviderSettings | CohereProviderSettings | AnthropicProviderSettings | GroqProviderSettings;
|
||||
|
||||
const convertMessagesToMem0Format = (messages: LanguageModelV1Prompt) => {
|
||||
return messages.map((message) => {
|
||||
// If the content is a string, return it as is
|
||||
if (typeof message.content === "string") {
|
||||
return message;
|
||||
}
|
||||
|
||||
// Flatten the content array into a single string
|
||||
if (Array.isArray(message.content)) {
|
||||
message.content = message.content
|
||||
.map((contentItem) => {
|
||||
if ("text" in contentItem) {
|
||||
return contentItem.text;
|
||||
}
|
||||
return "";
|
||||
})
|
||||
.join(" ");
|
||||
}
|
||||
|
||||
const contentText = message.content;
|
||||
|
||||
return {
|
||||
role: message.role,
|
||||
content: contentText,
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
// Define a private provider field
|
||||
class Mem0AITextGenerator implements LanguageModelV1 {
|
||||
readonly specificationVersion = "v1";
|
||||
readonly defaultObjectGenerationMode = "json";
|
||||
readonly supportsImageUrls = false;
|
||||
readonly modelId: string;
|
||||
readonly provider = "mem0";
|
||||
|
||||
provider: Provider;
|
||||
provider_config?: ProviderSettings;
|
||||
config: Mem0ProviderSettings;
|
||||
private languageModel: any; // Use any type to avoid version conflicts
|
||||
|
||||
constructor(modelId: string, config: Mem0ProviderSettings, provider_config: ProviderSettings) {
|
||||
this.modelId = modelId;
|
||||
|
||||
switch (config.provider) {
|
||||
case "openai":
|
||||
this.provider = createOpenAI({
|
||||
apiKey: config?.apiKey,
|
||||
...provider_config,
|
||||
}).languageModel;
|
||||
if(config?.modelType === "completion"){
|
||||
this.provider = createOpenAI({
|
||||
this.languageModel = createOpenAI({
|
||||
apiKey: config?.apiKey,
|
||||
...provider_config,
|
||||
}).completion;
|
||||
}else if(config?.modelType === "chat"){
|
||||
this.provider = createOpenAI({
|
||||
...provider_config as OpenAIProviderSettings,
|
||||
}).completion(modelId);
|
||||
} else if(config?.modelType === "chat"){
|
||||
this.languageModel = createOpenAI({
|
||||
apiKey: config?.apiKey,
|
||||
...provider_config,
|
||||
}).chat;
|
||||
...provider_config as OpenAIProviderSettings,
|
||||
}).chat(modelId);
|
||||
} else {
|
||||
this.languageModel = createOpenAI({
|
||||
apiKey: config?.apiKey,
|
||||
...provider_config as OpenAIProviderSettings,
|
||||
}).languageModel(modelId);
|
||||
}
|
||||
break;
|
||||
case "cohere":
|
||||
this.provider = createCohere({
|
||||
this.languageModel = createCohere({
|
||||
apiKey: config?.apiKey,
|
||||
...provider_config,
|
||||
});
|
||||
...provider_config as CohereProviderSettings,
|
||||
})(modelId);
|
||||
break;
|
||||
case "anthropic":
|
||||
this.provider = createAnthropic({
|
||||
this.languageModel = createAnthropic({
|
||||
apiKey: config?.apiKey,
|
||||
...provider_config,
|
||||
}).languageModel;
|
||||
...provider_config as AnthropicProviderSettings,
|
||||
}).languageModel(modelId);
|
||||
break;
|
||||
case "groq":
|
||||
this.provider = createGroq({
|
||||
this.languageModel = createGroq({
|
||||
apiKey: config?.apiKey,
|
||||
...provider_config,
|
||||
});
|
||||
...provider_config as GroqProviderSettings,
|
||||
})(modelId);
|
||||
break;
|
||||
case "google":
|
||||
this.languageModel = createGoogleGenerativeAI({
|
||||
apiKey: config?.apiKey,
|
||||
...provider_config as GoogleGenerativeAIProviderSettings,
|
||||
})(modelId);
|
||||
break;
|
||||
default:
|
||||
throw new Error("Invalid provider");
|
||||
}
|
||||
this.modelId = modelId;
|
||||
this.provider_config = provider_config;
|
||||
this.config = config!;
|
||||
}
|
||||
|
||||
|
||||
doGenerate(options: LanguageModelV1CallOptions): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
|
||||
return this.provider(this.modelId, this.provider_config).doGenerate(options);
|
||||
async doGenerate(options: LanguageModelV1CallOptions): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
|
||||
const result = await this.languageModel.doGenerate(options);
|
||||
return result as Awaited<ReturnType<LanguageModelV1['doGenerate']>>;
|
||||
}
|
||||
|
||||
doStream(options: LanguageModelV1CallOptions): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
|
||||
return this.provider(this.modelId, this.provider_config).doStream(options);
|
||||
async doStream(options: LanguageModelV1CallOptions): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
|
||||
const result = await this.languageModel.doStream(options);
|
||||
return result as Awaited<ReturnType<LanguageModelV1['doStream']>>;
|
||||
}
|
||||
}
|
||||
|
||||
export type ProviderSettings = OpenAIProviderSettings | CohereProviderSettings | AnthropicProviderSettings | GroqProviderSettings | GoogleGenerativeAIProviderSettings;
|
||||
export default Mem0AITextGenerator;
|
||||
|
||||
Reference in New Issue
Block a user