(Feature) Vercel AI SDK (#2024)

This commit is contained in:
Saket Aryan
2024-11-19 23:53:58 +05:30
committed by GitHub
parent a02597ed59
commit 13374a12e9
70 changed files with 4074 additions and 0 deletions

View File

@@ -0,0 +1,110 @@
import dotenv from "dotenv";
dotenv.config();
import { generateObject } from "ai";
import { testConfig } from "../config/test-config";
import { z } from "zod";
// Shape of a provider entry used to construct the mem0 test client.
interface Provider {
// Provider identifier (e.g. "anthropic").
name: string;
// Model id the tests run against.
activeModel: string;
// API key read from the environment; undefined when the variable is unset.
apiKey: string | undefined;
}
// Anthropic provider configuration for this structured-output test suite.
// NOTE(review): apiKey is undefined when ANTHROPIC_API_KEY is not set — presumably
// the client call fails at runtime in that case; confirm desired behavior.
const provider: Provider = {
name: "anthropic",
activeModel: "claude-3-5-sonnet-20240620",
apiKey: process.env.ANTHROPIC_API_KEY,
}
// End-to-end checks that generateObject() produces structured output through the
// mem0 Anthropic client in all four output modes: schema object, array, enum,
// and no-schema. Each test performs a live model call.
describe("ANTHROPIC Structured Outputs", () => {
const { userId } = testConfig;
let mem0: ReturnType<typeof testConfig.createTestClient>;
// Generous timeout: every test round-trips to the provider.
jest.setTimeout(30000);
beforeEach(() => {
// Fresh client per test so cases stay independent.
mem0 = testConfig.createTestClient(provider);
});
describe("ANTHROPIC Object Generation Tests", () => {
// Test 1: Generate a car preference object
it("should generate a car preference object with name and steps", async () => {
const { object } = await generateObject({
model: mem0(provider.activeModel, {
user_id: userId,
}),
// Zod schema constrains the output to { car: { name, steps[] } }.
schema: z.object({
car: z.object({
name: z.string(),
steps: z.array(z.string()),
}),
}),
prompt: "Which car would I like?",
});
expect(object.car).toBeDefined();
expect(typeof object.car.name).toBe("string");
expect(Array.isArray(object.car.steps)).toBe(true);
expect(object.car.steps.every((step) => typeof step === "string")).toBe(true);
});
// Test 2: Generate an array of car objects
it("should generate an array of three car objects with name, class, and description", async () => {
const { object } = await generateObject({
model: mem0(provider.activeModel, {
user_id: userId,
}),
// "array" output mode: the schema describes one element; the result is an array.
output: "array",
schema: z.object({
name: z.string(),
class: z.string(),
description: z.string(),
}),
prompt: "Write name of three cars that I would like.",
});
expect(Array.isArray(object)).toBe(true);
// NOTE(review): exact-count assertion relies on the model honoring "three" — potentially flaky.
expect(object.length).toBe(3);
object.forEach((car) => {
expect(car).toHaveProperty("name");
expect(typeof car.name).toBe("string");
expect(car).toHaveProperty("class");
expect(typeof car.class).toBe("string");
expect(car).toHaveProperty("description");
expect(typeof car.description).toBe("string");
});
});
// Test 3: Generate an enum for movie genre classification
it("should classify the genre of a movie plot", async () => {
const { object } = await generateObject({
model: mem0(provider.activeModel, {
user_id: userId,
}),
// "enum" output mode: the result must be one of the listed values.
output: "enum",
enum: ["action", "comedy", "drama", "horror", "sci-fi"],
prompt: 'Classify the genre of this movie plot: "A group of astronauts travel through a wormhole in search of a new habitable planet for humanity."',
});
expect(object).toBeDefined();
expect(object).toBe("sci-fi");
});
// Test 4: Generate an object of car names without schema
it("should generate an object with car names", async () => {
const { object } = await generateObject({
model: mem0(provider.activeModel, {
user_id: userId,
}),
// "no-schema" mode returns untyped JSON; the cast below is unchecked at runtime.
output: "no-schema",
prompt: "Write name of 3 cars that I would like.",
});
const carObject = object as { cars: string[] };
expect(carObject).toBeDefined();
expect(Array.isArray(carObject.cars)).toBe(true);
// NOTE(review): assumes the model emits a top-level "cars" array of exactly 3 — verify stability.
expect(carObject.cars.length).toBe(3);
expect(carObject.cars.every((car) => typeof car === "string")).toBe(true);
});
});
});

View File

@@ -0,0 +1,61 @@
import dotenv from "dotenv";
dotenv.config();
import { retrieveMemories } from "../src";
import { generateText, LanguageModelV1Prompt } from "ai";
import { testConfig } from "../config/test-config";
import { createAnthropic } from "@ai-sdk/anthropic";
// Exercises retrieveMemories() feeding a plain @ai-sdk/anthropic client:
// memories are fetched for the input and injected as the system prompt.
describe("ANTHROPIC Functions", () => {
  const { userId } = testConfig;
  jest.setTimeout(30000);

  let anthropic: any;

  // Rebuild the Anthropic client before every test so no state leaks between cases.
  beforeEach(() => {
    anthropic = createAnthropic({
      apiKey: process.env.ANTHROPIC_API_KEY,
    });
  });

  it("should retrieve memories and generate text using ANTHROPIC provider", async () => {
    const conversation: LanguageModelV1Prompt = [
      {
        role: "user",
        content: [
          { type: "text", text: "Suggest me a good car to buy." },
          { type: "text", text: " Write only the car name and it's color." },
        ],
      },
    ];

    // Look up memories relevant to the conversation, then pass them as system context.
    const memoryContext = await retrieveMemories(conversation, { user_id: userId });

    const result = await generateText({
      // @ts-ignore
      model: anthropic("claude-3-haiku-20240307"),
      messages: conversation,
      system: memoryContext,
    });

    // The call must yield a non-empty string.
    expect(typeof result.text).toBe("string");
    expect(result.text.length).toBeGreaterThan(0);
  });

  it("should generate text using ANTHROPIC provider with memories", async () => {
    const prompt = "Suggest me a good car to buy.";
    const memoryContext = await retrieveMemories(prompt, { user_id: userId });

    const result = await generateText({
      // @ts-ignore
      model: anthropic("claude-3-haiku-20240307"),
      prompt,
      system: memoryContext,
    });

    expect(typeof result.text).toBe("string");
    expect(result.text.length).toBeGreaterThan(0);
  });
});

View File

@@ -0,0 +1,60 @@
import dotenv from "dotenv";
dotenv.config();
import { retrieveMemories } from "../src";
import { generateText, LanguageModelV1Prompt } from "ai";
import { testConfig } from "../config/test-config";
import { createCohere } from "@ai-sdk/cohere";
// Exercises retrieveMemories() feeding a plain @ai-sdk/cohere client:
// memories are fetched for the input and injected as the system prompt.
describe("COHERE Functions", () => {
  const { userId } = testConfig;
  jest.setTimeout(30000);

  let cohere: any;

  // Rebuild the Cohere client before every test so no state leaks between cases.
  beforeEach(() => {
    cohere = createCohere({
      apiKey: process.env.COHERE_API_KEY,
    });
  });

  it("should retrieve memories and generate text using COHERE provider", async () => {
    const conversation: LanguageModelV1Prompt = [
      {
        role: "user",
        content: [
          { type: "text", text: "Suggest me a good car to buy." },
          { type: "text", text: " Write only the car name and it's color." },
        ],
      },
    ];

    // Look up memories relevant to the conversation, then pass them as system context.
    const memoryContext = await retrieveMemories(conversation, { user_id: userId });

    const result = await generateText({
      // @ts-ignore
      model: cohere("command-r-plus"),
      messages: conversation,
      system: memoryContext,
    });

    // The call must yield a non-empty string.
    expect(typeof result.text).toBe("string");
    expect(result.text.length).toBeGreaterThan(0);
  });

  it("should generate text using COHERE provider with memories", async () => {
    const prompt = "Suggest me a good car to buy.";
    const memoryContext = await retrieveMemories(prompt, { user_id: userId });

    const result = await generateText({
      // @ts-ignore
      model: cohere("command-r-plus"),
      prompt,
      system: memoryContext,
    });

    expect(typeof result.text).toBe("string");
    expect(result.text.length).toBeGreaterThan(0);
  });
});

View File

@@ -0,0 +1,86 @@
import { generateText, LanguageModelV1Prompt, streamText } from "ai";
import { addMemories } from "../src";
import { testConfig } from "../config/test-config";
// Shape of a provider entry used to construct the mem0 test client.
interface Provider {
// Provider identifier (e.g. "anthropic").
name: string;
// Model id the tests run against.
activeModel: string;
// API key read from the environment; undefined when the variable is unset.
apiKey: string | undefined;
}
// Runs the generate/stream smoke tests against every configured provider.
describe.each(testConfig.providers)('TESTS: Generate/Stream Text with model %s', (provider: Provider) => {
  const { userId } = testConfig;
  jest.setTimeout(50000);

  let mem0: ReturnType<typeof testConfig.createTestClient>;

  // Seed memories once so every generation test has preferences to draw on.
  beforeAll(async () => {
    const seed: LanguageModelV1Prompt = [
      {
        role: "user",
        content: [
          { type: "text", text: "I love red cars." },
          { type: "text", text: "I like Toyota Cars." },
          { type: "text", text: "I prefer SUVs." },
        ],
      }
    ];
    await addMemories(seed, { user_id: userId });
  });

  // A fresh mem0 client per test keeps cases independent.
  beforeEach(() => {
    mem0 = testConfig.createTestClient(provider);
  });

  it("should generate text using mem0 model", async () => {
    const result = await generateText({
      model: mem0(provider.activeModel, {
        user_id: userId,
      }),
      prompt: "Suggest me a good car to buy!",
    });

    expect(typeof result.text).toBe('string');
    expect(result.text.length).toBeGreaterThan(0);
  });

  it("should generate text using provider with memories", async () => {
    const result = await generateText({
      model: mem0(provider.activeModel, {
        user_id: userId,
      }),
      messages: [
        {
          role: "user",
          content: [
            { type: "text", text: "Suggest me a good car to buy." },
            { type: "text", text: "Write only the car name and it's color." },
          ],
        }
      ],
    });

    // The call must yield a non-empty string.
    expect(typeof result.text).toBe('string');
    expect(result.text.length).toBeGreaterThan(0);
  });

  it("should stream text using Mem0 provider", async () => {
    const { textStream } = await streamText({
      model: mem0(provider.activeModel, {
        user_id: userId, // Use the uniform userId
      }),
      prompt: "Suggest me a good car to buy! Write only the car name and it's color.",
    });

    // Accumulate every streamed chunk, then join into the full response.
    const pieces: string[] = [];
    for await (const piece of textStream) {
      pieces.push(piece);
    }
    const streamedText = pieces.join('');

    expect(typeof streamedText).toBe('string');
    expect(streamedText.length).toBeGreaterThan(0);
  });
});

View File

@@ -0,0 +1,61 @@
import dotenv from "dotenv";
dotenv.config();
import { retrieveMemories } from "../src";
import { generateText, LanguageModelV1Prompt } from "ai";
import { testConfig } from "../config/test-config";
import { createGroq } from "@ai-sdk/groq";
// Exercises retrieveMemories() feeding a plain @ai-sdk/groq client:
// memories are fetched for the input and injected as the system prompt.
describe("GROQ Functions", () => {
  const { userId } = testConfig;
  jest.setTimeout(30000);

  let groq: any;

  // Rebuild the Groq client before every test so no state leaks between cases.
  beforeEach(() => {
    groq = createGroq({
      apiKey: process.env.GROQ_API_KEY,
    });
  });

  it("should retrieve memories and generate text using GROQ provider", async () => {
    const conversation: LanguageModelV1Prompt = [
      {
        role: "user",
        content: [
          { type: "text", text: "Suggest me a good car to buy." },
          { type: "text", text: " Write only the car name and it's color." },
        ],
      },
    ];

    // Look up memories relevant to the conversation, then pass them as system context.
    const memoryContext = await retrieveMemories(conversation, { user_id: userId });

    const result = await generateText({
      // @ts-ignore
      model: groq("gemma2-9b-it"),
      messages: conversation,
      system: memoryContext,
    });

    // The call must yield a non-empty string.
    expect(typeof result.text).toBe("string");
    expect(result.text.length).toBeGreaterThan(0);
  });

  it("should generate text using GROQ provider with memories", async () => {
    const prompt = "Suggest me a good car to buy.";
    const memoryContext = await retrieveMemories(prompt, { user_id: userId });

    const result = await generateText({
      // @ts-ignore
      model: groq("gemma2-9b-it"),
      prompt,
      system: memoryContext,
    });

    expect(typeof result.text).toBe("string");
    expect(result.text.length).toBeGreaterThan(0);
  });
});

View File

@@ -0,0 +1,75 @@
import { addMemories, retrieveMemories } from "../src";
import { LanguageModelV1Prompt } from "ai";
import { testConfig } from "../config/test-config";
// Core contract tests for addMemories() / retrieveMemories().
describe("Memory Core Functions", () => {
  const { userId } = testConfig;
  jest.setTimeout(10000);

  // Shared fixture builder: one user turn stating three car preferences.
  const buildPreferenceMessages = (): LanguageModelV1Prompt => [
    {
      role: "user",
      content: [
        { type: "text", text: "I love red cars." },
        { type: "text", text: "I like Toyota Cars." },
        { type: "text", text: "I prefer SUVs." },
      ],
    }
  ];

  describe("addMemories", () => {
    it("should successfully add memories and return correct format", async () => {
      const response = await addMemories(buildPreferenceMessages(), { user_id: userId });

      // Each returned entry must be an ADD event with id and data attached.
      expect(Array.isArray(response)).toBe(true);
      for (const memory of response as { event: any; }[]) {
        expect(memory).toHaveProperty('id');
        expect(memory).toHaveProperty('data');
        expect(memory).toHaveProperty('event');
        expect(memory.event).toBe('ADD');
      }
    });
  });

  describe("retrieveMemories", () => {
    // Seed memories so each retrieval test has something to find.
    beforeEach(async () => {
      await addMemories(buildPreferenceMessages(), { user_id: userId });
    });

    it("should retrieve memories with string prompt", async () => {
      const response = await retrieveMemories("Which car would I prefer?", { user_id: userId });

      expect(typeof response).toBe('string');
      // At least three "Memory:" entries should come back for the seeded preferences.
      expect(response.match(/Memory:/g)?.length).toBeGreaterThan(2);
    });

    it("should retrieve memories with array of prompts", async () => {
      const query: LanguageModelV1Prompt = [
        {
          role: "user",
          content: [
            { type: "text", text: "Which car would I prefer?" },
            { type: "text", text: "Suggest me some cars" },
          ],
        }
      ];

      const response = await retrieveMemories(query, { user_id: userId });

      expect(typeof response).toBe('string');
      expect(response.match(/Memory:/g)?.length).toBeGreaterThan(2);
    });
  });
});

View File

@@ -0,0 +1,110 @@
import dotenv from "dotenv";
dotenv.config();
import { generateObject } from "ai";
import { testConfig } from "../config/test-config";
import { z } from "zod";
// Shape of a provider entry used to construct the mem0 test client.
interface Provider {
// Provider identifier (e.g. "openai").
name: string;
// Model id the tests run against.
activeModel: string;
// API key read from the environment; undefined when the variable is unset.
apiKey: string | undefined;
}
// OpenAI provider configuration for this structured-output test suite.
// NOTE(review): apiKey is undefined when OPENAI_API_KEY is not set — presumably
// the client call fails at runtime in that case; confirm desired behavior.
const provider: Provider = {
name: "openai",
activeModel: "gpt-4-turbo",
apiKey: process.env.OPENAI_API_KEY,
}
// End-to-end checks that generateObject() produces structured output through the
// mem0 OpenAI client in all four output modes: schema object, array, enum,
// and no-schema. Each test performs a live model call.
describe("OPENAI Structured Outputs", () => {
const { userId } = testConfig;
let mem0: ReturnType<typeof testConfig.createTestClient>;
// Generous timeout: every test round-trips to the provider.
jest.setTimeout(30000);
beforeEach(() => {
// Fresh client per test so cases stay independent.
mem0 = testConfig.createTestClient(provider);
});
describe("openai Object Generation Tests", () => {
// Test 1: Generate a car preference object
it("should generate a car preference object with name and steps", async () => {
const { object } = await generateObject({
model: mem0(provider.activeModel, {
user_id: userId,
}),
// Zod schema constrains the output to { car: { name, steps[] } }.
schema: z.object({
car: z.object({
name: z.string(),
steps: z.array(z.string()),
}),
}),
prompt: "Which car would I like?",
});
expect(object.car).toBeDefined();
expect(typeof object.car.name).toBe("string");
expect(Array.isArray(object.car.steps)).toBe(true);
expect(object.car.steps.every((step) => typeof step === "string")).toBe(true);
});
// Test 2: Generate an array of car objects
it("should generate an array of three car objects with name, class, and description", async () => {
const { object } = await generateObject({
model: mem0(provider.activeModel, {
user_id: userId,
}),
// "array" output mode: the schema describes one element; the result is an array.
output: "array",
schema: z.object({
name: z.string(),
class: z.string().describe('Cars should be "SUV", "Sedan", or "Hatchback"'),
description: z.string(),
}),
prompt: "Write name of three cars that I would like.",
});
expect(Array.isArray(object)).toBe(true);
// NOTE(review): exact-count assertion relies on the model honoring "three" — potentially flaky.
expect(object.length).toBe(3);
object.forEach((car) => {
expect(car).toHaveProperty("name");
expect(typeof car.name).toBe("string");
expect(car).toHaveProperty("class");
expect(typeof car.class).toBe("string");
expect(car).toHaveProperty("description");
expect(typeof car.description).toBe("string");
});
});
// Test 3: Generate an enum for movie genre classification
it("should classify the genre of a movie plot", async () => {
const { object } = await generateObject({
model: mem0(provider.activeModel, {
user_id: userId,
}),
// "enum" output mode: the result must be one of the listed values.
output: "enum",
enum: ["action", "comedy", "drama", "horror", "sci-fi"],
prompt: 'Classify the genre of this movie plot: "A group of astronauts travel through a wormhole in search of a new habitable planet for humanity."',
});
expect(object).toBeDefined();
expect(object).toBe("sci-fi");
});
// Test 4: Generate an object of car names without schema
it("should generate an object with car names", async () => {
const { object } = await generateObject({
model: mem0(provider.activeModel, {
user_id: userId,
}),
// "no-schema" mode returns untyped JSON; the cast below is unchecked at runtime.
output: "no-schema",
prompt: "Write name of 3 cars that I would like.",
});
const carObject = object as { cars: string[] };
expect(carObject).toBeDefined();
expect(Array.isArray(carObject.cars)).toBe(true);
// NOTE(review): assumes the model emits a top-level "cars" array of exactly 3 — verify stability.
expect(carObject.cars.length).toBe(3);
expect(carObject.cars.every((car) => typeof car === "string")).toBe(true);
});
});
});

View File

@@ -0,0 +1,58 @@
import dotenv from "dotenv";
dotenv.config();
import { retrieveMemories } from "../src";
import { generateText, LanguageModelV1Prompt } from "ai";
import { testConfig } from "../config/test-config";
import { createOpenAI } from "@ai-sdk/openai";
// Exercises retrieveMemories() feeding a plain @ai-sdk/openai client:
// memories are fetched for the input and injected as the system prompt.
describe("OPENAI Functions", () => {
  const { userId } = testConfig;
  jest.setTimeout(30000);

  let openai: any;

  // Rebuild the OpenAI client before every test so no state leaks between cases.
  beforeEach(() => {
    openai = createOpenAI({
      apiKey: process.env.OPENAI_API_KEY,
    });
  });

  it("should retrieve memories and generate text using OpenAI provider", async () => {
    const conversation: LanguageModelV1Prompt = [
      {
        role: "user",
        content: [
          { type: "text", text: "Suggest me a good car to buy." },
          { type: "text", text: " Write only the car name and it's color." },
        ],
      },
    ];

    // Look up memories relevant to the conversation, then pass them as system context.
    const memoryContext = await retrieveMemories(conversation, { user_id: userId });

    const result = await generateText({
      model: openai("gpt-4-turbo"),
      messages: conversation,
      system: memoryContext,
    });

    // The call must yield a non-empty string.
    expect(typeof result.text).toBe("string");
    expect(result.text.length).toBeGreaterThan(0);
  });

  it("should generate text using openai provider with memories", async () => {
    const prompt = "Suggest me a good car to buy.";
    const memoryContext = await retrieveMemories(prompt, { user_id: userId });

    const result = await generateText({
      model: openai("gpt-4-turbo"),
      prompt,
      system: memoryContext,
    });

    expect(typeof result.text).toBe("string");
    expect(result.text.length).toBeGreaterThan(0);
  });
});

View File

@@ -0,0 +1,77 @@
import { generateText, streamText } from "ai";
import { testConfig } from "../config/test-config";
// Shape of a provider entry used to construct the mem0 test client.
interface Provider {
// Provider identifier (e.g. "openai").
name: string;
// Model id the tests run against.
activeModel: string;
// API key read from the environment; undefined when the variable is unset.
apiKey: string | undefined;
}
// Property-level checks on streamText()/generateText() callbacks and usage
// accounting, run against every configured provider.
describe.each(testConfig.providers)('TEXT/STREAM PROPERTIES: Tests with model %s', (provider: Provider) => {
  const { userId } = testConfig;
  jest.setTimeout(50000);

  let mem0: ReturnType<typeof testConfig.createTestClient>;

  // A fresh mem0 client per test keeps cases independent.
  beforeEach(() => {
    mem0 = testConfig.createTestClient(provider);
  });

  it("should stream text with onChunk handler", async () => {
    const collected: string[] = [];

    const { textStream } = await streamText({
      model: mem0(provider.activeModel, {
        user_id: userId, // Use the uniform userId
      }),
      prompt: "Write only the name of the car I prefer and its color.",
      onChunk({ chunk }) {
        // Only text-delta chunks carry generated text; record those for assertions.
        if (chunk.type === "text-delta") {
          collected.push(chunk.textDelta);
        }
      },
    });

    // Drain the stream so every onChunk callback has fired before asserting.
    for await (const _ of textStream) {
    }

    expect(collected.length).toBeGreaterThan(0);
    expect(collected.every((text) => typeof text === "string")).toBe(true);
  });

  it("should call onFinish handler without throwing an error", async () => {
    // Only verifies the call completes; the handler body is intentionally empty.
    await streamText({
      model: mem0(provider.activeModel, {
        user_id: userId, // Use the uniform userId
      }),
      prompt: "Write only the name of the car I prefer and its color.",
      onFinish({ text, finishReason, usage }) {
      },
    });
  });

  it("should generate fullStream with expected usage", async () => {
    const result = await generateText({
      model: mem0(provider.activeModel), // Ensure the model name is correct
      maxSteps: 5, // Enable multi-step calls
      experimental_continueSteps: true,
      prompt:
        "Suggest me some good cars to buy. Each response MUST HAVE at least 200 words.",
    });

    // Combined text across steps must be a string.
    expect(typeof result.text).toBe("string");

    // Usage should land in a loose, provider-agnostic band rather than an exact value.
    expect(result.usage.promptTokens).toBeGreaterThanOrEqual(100);
    expect(result.usage.promptTokens).toBeLessThanOrEqual(500);
    expect(result.usage.completionTokens).toBeGreaterThanOrEqual(250); // Check completion tokens are above 250
    expect(result.usage.totalTokens).toBeGreaterThan(400); // Check total tokens are above 400
  });
});