feat: add embedchain javascript package (#576)

This commit is contained in:
Taranjeet Singh
2023-09-06 17:22:44 -07:00
committed by GitHub
parent f582d70031
commit 3c3d98b9c3
44 changed files with 20073 additions and 0 deletions

View File

@@ -0,0 +1,66 @@
import { EmbedChainApp } from '../embedchain';
// Jest requires mock functions referenced from inside a `jest.mock` factory
// to be prefixed with `mock` — the factory below is hoisted above these
// declarations at runtime, and Jest only permits lazily-bound `mock*` names.
const mockAdd = jest.fn();
const mockAddLocal = jest.fn();
const mockQuery = jest.fn();

// Replace the real EmbedChainApp with a stub whose add/addLocal/query are the
// mocks above, so no network, database, or OpenAI calls happen in this test.
jest.mock('../embedchain', () => {
  return {
    EmbedChainApp: jest.fn().mockImplementation(() => {
      return {
        add: mockAdd,
        addLocal: mockAddLocal,
        query: mockQuery,
      };
    }),
  };
});

describe('Test App', () => {
  beforeEach(() => {
    // Reset call history so each test's toHaveBeenCalledWith checks are isolated.
    jest.clearAllMocks();
  });

  it('tests the App', async () => {
    // Canned answer returned by the mocked query().
    mockQuery.mockResolvedValue(
      'Naval argues that humans possess the unique capacity to understand explanations or concepts to the maximum extent possible in this physical reality.'
    );
    // NOTE(review): `await` on a constructor call is a no-op (the mock returns
    // a plain object); the real app exposes `initApp` for awaiting readiness —
    // confirm the intended usage.
    const navalChatBot = await new EmbedChainApp(undefined, false);
    // Embed Online Resources
    await navalChatBot.add('web_page', 'https://nav.al/feedback');
    await navalChatBot.add('web_page', 'https://nav.al/agi');
    await navalChatBot.add(
      'pdf_file',
      'https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf'
    );
    // Embed Local Resources
    await navalChatBot.addLocal('qna_pair', [
      'Who is Naval Ravikant?',
      'Naval Ravikant is an Indian-American entrepreneur and investor.',
    ]);
    const result = await navalChatBot.query(
      'What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?'
    );
    // Verify each mocked method received the exact arguments passed above.
    expect(mockAdd).toHaveBeenCalledWith('web_page', 'https://nav.al/feedback');
    expect(mockAdd).toHaveBeenCalledWith('web_page', 'https://nav.al/agi');
    expect(mockAdd).toHaveBeenCalledWith(
      'pdf_file',
      'https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf'
    );
    expect(mockAddLocal).toHaveBeenCalledWith('qna_pair', [
      'Who is Naval Ravikant?',
      'Naval Ravikant is an Indian-American entrepreneur and investor.',
    ]);
    expect(mockQuery).toHaveBeenCalledWith(
      'What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?'
    );
    expect(result).toBe(
      'Naval argues that humans possess the unique capacity to understand explanations or concepts to the maximum extent possible in this physical reality.'
    );
  });
});

View File

@@ -0,0 +1,44 @@
import { createHash } from 'crypto';
import type { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';
import type { BaseLoader } from '../loaders';
import type { Input, LoaderResult } from '../models';
import type { ChunkResult } from '../models/ChunkResult';
/**
 * Splits loader output into chunks suitable for embedding, giving every
 * chunk a content-derived id so duplicates can be detected later.
 */
class BaseChunker {
  textSplitter: RecursiveCharacterTextSplitter;

  constructor(textSplitter: RecursiveCharacterTextSplitter) {
    this.textSplitter = textSplitter;
  }

  /**
   * Loads data from `src` via `loader` and splits each returned piece of
   * content into chunks.
   *
   * Chunk ids are sha256(chunk text + source url), so identical chunks from
   * the same source always map to the same id.
   */
  async createChunks(loader: BaseLoader, src: Input): Promise<ChunkResult> {
    const result: ChunkResult = { documents: [], ids: [], metadatas: [] };
    const loaded: LoaderResult = await loader.loadData(src);
    // Split every loaded piece concurrently; pushes for one piece happen
    // together, so each index stays aligned across the three arrays.
    await Promise.all(
      loaded.map(async ({ content, metaData }) => {
        const pieces: string[] = await this.textSplitter.splitText(content);
        for (const piece of pieces) {
          const chunkId = createHash('sha256')
            .update(piece + metaData.url)
            .digest('hex');
          result.ids.push(chunkId);
          result.documents.push(piece);
          result.metadatas.push(metaData);
        }
      })
    );
    return result;
  }
}
export { BaseChunker };

View File

@@ -0,0 +1,26 @@
import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';
import { BaseChunker } from './BaseChunker';
/** Configuration accepted by the text splitter. */
interface TextSplitterChunkParams {
  chunkSize: number;
  chunkOverlap: number;
  keepSeparator: boolean;
}

// PDFs tend to contain long continuous passages, so use large chunks.
const TEXT_SPLITTER_CHUNK_PARAMS: TextSplitterChunkParams = {
  chunkSize: 1000,
  chunkOverlap: 0,
  keepSeparator: false,
};

/** Chunker specialized for PDF document text. */
class PdfFileChunker extends BaseChunker {
  constructor() {
    super(new RecursiveCharacterTextSplitter(TEXT_SPLITTER_CHUNK_PARAMS));
  }
}

export { PdfFileChunker };

View File

@@ -0,0 +1,26 @@
import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';
import { BaseChunker } from './BaseChunker';
/** Configuration accepted by the text splitter. */
interface TextSplitterChunkParams {
  chunkSize: number;
  chunkOverlap: number;
  keepSeparator: boolean;
}

// Q&A pairs are short, so use the smallest chunk size of the chunkers.
const TEXT_SPLITTER_CHUNK_PARAMS: TextSplitterChunkParams = {
  chunkSize: 300,
  chunkOverlap: 0,
  keepSeparator: false,
};

/** Chunker specialized for local question/answer pairs. */
class QnaPairChunker extends BaseChunker {
  constructor() {
    super(new RecursiveCharacterTextSplitter(TEXT_SPLITTER_CHUNK_PARAMS));
  }
}

export { QnaPairChunker };

View File

@@ -0,0 +1,26 @@
import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';
import { BaseChunker } from './BaseChunker';
/** Configuration accepted by the text splitter. */
interface TextSplitterChunkParams {
  chunkSize: number;
  chunkOverlap: number;
  keepSeparator: boolean;
}

// Web page text sits between PDFs and Q&A pairs in typical length.
const TEXT_SPLITTER_CHUNK_PARAMS: TextSplitterChunkParams = {
  chunkSize: 500,
  chunkOverlap: 0,
  keepSeparator: false,
};

/** Chunker specialized for scraped web page text. */
class WebPageChunker extends BaseChunker {
  constructor() {
    super(new RecursiveCharacterTextSplitter(TEXT_SPLITTER_CHUNK_PARAMS));
  }
}

export { WebPageChunker };

View File

@@ -0,0 +1,6 @@
// Barrel file: single entry point for all chunker classes.
export { BaseChunker } from './BaseChunker';
export { PdfFileChunker } from './PdfFile';
export { QnaPairChunker } from './QnaPair';
export { WebPageChunker } from './WebPage';

View File

@@ -0,0 +1,317 @@
/* eslint-disable max-classes-per-file */
import type { Collection } from 'chromadb';
import type { QueryResponse } from 'chromadb/dist/main/types';
import * as fs from 'fs';
import { Document } from 'langchain/document';
import OpenAI from 'openai';
import * as path from 'path';
import { v4 as uuidv4 } from 'uuid';
import type { BaseChunker } from './chunkers';
import { PdfFileChunker, QnaPairChunker, WebPageChunker } from './chunkers';
import type { BaseLoader } from './loaders';
import { LocalQnaPairLoader, PdfFileLoader, WebPageLoader } from './loaders';
import type {
DataDict,
DataType,
FormattedResult,
Input,
LocalInput,
Metadata,
Method,
RemoteInput,
} from './models';
import { ChromaDB } from './vectordb';
import type { BaseVectorDB } from './vectordb/BaseVectorDb';
// Shared OpenAI client; the API key is read from the environment once at
// module load, so OPENAI_API_KEY must be set before this module is imported.
const openai = new OpenAI({
  apiKey: process.env.OPENAI_API_KEY,
});
/**
 * Core embedchain pipeline: loads data, chunks it, stores embeddings in a
 * vector database, and answers queries with an OpenAI chat model using
 * retrieved context.
 */
class EmbedChain {
  dbClient: any;

  // TODO: Definitely assign
  collection!: Collection;

  // History of every (dataType, input) the user has added.
  userAsks: [DataType, Input][] = [];

  // Resolves once the vector database is initialized; await before use.
  initApp: Promise<void>;

  collectMetrics: boolean;

  sId: string; // sessionId

  /**
   * @param db Optional vector database; defaults to a local ChromaDB.
   * @param collectMetrics Whether to send anonymous usage telemetry.
   */
  constructor(db?: BaseVectorDB, collectMetrics: boolean = true) {
    if (!db) {
      this.initApp = this.setupChroma();
    } else {
      this.initApp = this.setupOther(db);
    }
    this.collectMetrics = collectMetrics;
    // Send anonymous telemetry. Fire-and-forget with a catch so a telemetry
    // failure (e.g. package.json read/parse throwing synchronously inside the
    // async method) can never surface as an unhandled rejection.
    this.sId = uuidv4();
    this.sendTelemetryEvent('init').catch(() => {});
  }

  /** Initializes the default ChromaDB client and collection. */
  async setupChroma(): Promise<void> {
    const db = new ChromaDB();
    await db.initDb;
    this.dbClient = db.client;
    if (db.collection) {
      this.collection = db.collection;
    } else {
      // TODO: Add proper error handling
      console.error('No collection');
    }
  }

  /** Initializes a caller-supplied vector database. */
  async setupOther(db: BaseVectorDB): Promise<void> {
    await db.initDb;
    // TODO: Figure out how we can initialize an unknown database.
    // this.dbClient = db.client;
    // this.collection = db.collection;
    this.userAsks = [];
  }

  /** Returns the loader responsible for the given data type. */
  static getLoader(dataType: DataType) {
    const loaders: { [t in DataType]: BaseLoader } = {
      pdf_file: new PdfFileLoader(),
      web_page: new WebPageLoader(),
      qna_pair: new LocalQnaPairLoader(),
    };
    return loaders[dataType];
  }

  /** Returns the chunker responsible for the given data type. */
  static getChunker(dataType: DataType) {
    const chunkers: { [t in DataType]: BaseChunker } = {
      pdf_file: new PdfFileChunker(),
      web_page: new WebPageChunker(),
      qna_pair: new QnaPairChunker(),
    };
    return chunkers[dataType];
  }

  /**
   * Adds the data from the given remote URL to the vector db.
   * @param dataType The kind of resource (e.g. 'web_page', 'pdf_file').
   * @param url The remote location of the data.
   */
  public async add(dataType: DataType, url: RemoteInput) {
    const loader = EmbedChain.getLoader(dataType);
    const chunker = EmbedChain.getChunker(dataType);
    this.userAsks.push([dataType, url]);
    const { documents, countNewChunks } = await this.loadAndEmbed(
      loader,
      chunker,
      url
    );
    if (this.collectMetrics) {
      const wordCount = documents.reduce(
        (sum, document) => sum + document.split(' ').length,
        0
      );
      // Fire-and-forget telemetry (see constructor).
      this.sendTelemetryEvent('add', {
        data_type: dataType,
        word_count: wordCount,
        chunks_count: countNewChunks,
      }).catch(() => {});
    }
  }

  /**
   * Adds locally supplied data (e.g. a question/answer pair) to the vector db.
   * @param dataType The kind of resource (e.g. 'qna_pair').
   * @param content The local data to embed.
   */
  public async addLocal(dataType: DataType, content: LocalInput) {
    const loader = EmbedChain.getLoader(dataType);
    const chunker = EmbedChain.getChunker(dataType);
    this.userAsks.push([dataType, content]);
    const { documents, countNewChunks } = await this.loadAndEmbed(
      loader,
      chunker,
      content
    );
    if (this.collectMetrics) {
      const wordCount = documents.reduce(
        (sum, document) => sum + document.split(' ').length,
        0
      );
      // Fire-and-forget telemetry (see constructor).
      this.sendTelemetryEvent('add_local', {
        data_type: dataType,
        word_count: wordCount,
        chunks_count: countNewChunks,
      }).catch(() => {});
    }
  }

  /**
   * Loads data from `src`, chunks it, drops chunks whose ids already exist in
   * the collection, and stores the remainder.
   * @returns The newly stored documents/metadatas/ids and the count of chunks
   *          actually added (as measured by the collection size delta).
   */
  protected async loadAndEmbed(
    loader: BaseLoader, // was `any`; chunker.createChunks expects a BaseLoader
    chunker: BaseChunker,
    src: Input
  ): Promise<{
    documents: string[];
    metadatas: Metadata[];
    ids: string[];
    countNewChunks: number;
  }> {
    const embeddingsData = await chunker.createChunks(loader, src);
    let { documents, ids, metadatas } = embeddingsData;
    // Deduplicate against chunks already stored under the same ids.
    const existingDocs = await this.collection.get({ ids });
    const existingIds = new Set(existingDocs.ids);
    if (existingIds.size > 0) {
      const dataDict: DataDict = {};
      for (let i = 0; i < ids.length; i += 1) {
        const id = ids[i];
        if (!existingIds.has(id)) {
          dataDict[id] = { doc: documents[i], meta: metadatas[i] };
        }
      }
      if (Object.keys(dataDict).length === 0) {
        console.log(`All data from ${src} already exists in the database.`);
        return { documents: [], metadatas: [], ids: [], countNewChunks: 0 };
      }
      ids = Object.keys(dataDict);
      const dataValues = Object.values(dataDict);
      documents = dataValues.map(({ doc }) => doc);
      metadatas = dataValues.map(({ meta }) => meta);
    }
    const countBeforeAddition = await this.count();
    await this.collection.add({ documents, metadatas, ids });
    const countNewChunks = (await this.count()) - countBeforeAddition;
    console.log(
      `Successfully saved ${src}. New chunks count: ${countNewChunks}`
    );
    return { documents, metadatas, ids, countNewChunks };
  }

  /** Converts a raw Chroma query response into (Document, distance) pairs. */
  static async formatResult(
    results: QueryResponse
  ): Promise<FormattedResult[]> {
    return results.documents[0].map((document: any, index: number) => {
      const metadata = results.metadatas[0][index] || {};
      // TODO: Add proper error handling
      const distance = results.distances ? results.distances[0][index] : null;
      return [new Document({ pageContent: document, metadata }), distance];
    });
  }

  /** Sends `prompt` to gpt-3.5-turbo and returns the completion text. */
  static async getOpenAiAnswer(prompt: string) {
    const messages: OpenAI.Chat.CreateChatCompletionRequestMessage[] = [
      { role: 'user', content: prompt },
    ];
    const response = await openai.chat.completions.create({
      model: 'gpt-3.5-turbo',
      messages,
      temperature: 0,
      max_tokens: 1000,
      top_p: 1,
    });
    return (
      response.choices[0].message?.content ?? 'Response could not be processed.'
    );
  }

  /** Retrieves the single most similar stored chunk for the query. */
  protected async retrieveFromDatabase(inputQuery: string) {
    const result = await this.collection.query({
      nResults: 1,
      queryTexts: [inputQuery],
    });
    const resultFormatted = await EmbedChain.formatResult(result);
    // Guard against an empty database: indexing [0][0] on an empty result
    // would otherwise throw an opaque TypeError.
    if (resultFormatted.length === 0) {
      throw new Error('No relevant documents found in the database.');
    }
    const content = resultFormatted[0][0].pageContent;
    return content;
  }

  /** Builds the LLM prompt from the query and the retrieved context. */
  static generatePrompt(inputQuery: string, context: any) {
    const prompt = `Use the following pieces of context to answer the query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n${context}\nQuery: ${inputQuery}\nHelpful Answer:`;
    return prompt;
  }

  /** Indirection point for swapping LLM backends; currently OpenAI only. */
  static async getAnswerFromLlm(prompt: string) {
    const answer = await EmbedChain.getOpenAiAnswer(prompt);
    return answer;
  }

  /**
   * Answers `inputQuery` using the most similar stored context and the LLM.
   */
  public async query(inputQuery: string) {
    const context = await this.retrieveFromDatabase(inputQuery);
    const prompt = EmbedChain.generatePrompt(inputQuery, context);
    const answer = await EmbedChain.getAnswerFromLlm(prompt);
    // Fire-and-forget telemetry (see constructor).
    this.sendTelemetryEvent('query').catch(() => {});
    return answer;
  }

  /**
   * Returns the prompt that query() would send to the LLM, without calling it.
   */
  public async dryRun(inputQuery: string) {
    const context = await this.retrieveFromDatabase(inputQuery);
    const prompt = EmbedChain.generatePrompt(inputQuery, context);
    return prompt;
  }

  /**
   * Count the number of embeddings.
   * @returns {Promise<number>}: The number of embeddings.
   */
  public count(): Promise<number> {
    return this.collection.count();
  }

  /**
   * Posts an anonymous telemetry event, retrying up to three times.
   * No-op when collectMetrics is false. Never throws on HTTP failure; errors
   * are only logged.
   */
  protected async sendTelemetryEvent(method: Method, extraMetadata?: object) {
    if (!this.collectMetrics) {
      return;
    }
    const url = 'https://api.embedchain.ai/api/v1/telemetry/';
    // Read package version from filesystem (because it's not in the ts root dir)
    const packageJsonPath = path.join(__dirname, '..', 'package.json');
    const packageJson = JSON.parse(fs.readFileSync(packageJsonPath, 'utf8'));
    const metadata = {
      s_id: this.sId,
      version: packageJson.version,
      method,
      language: 'js',
      ...extraMetadata,
    };
    const maxRetries = 3;
    // Retry the fetch
    for (let i = 0; i < maxRetries; i += 1) {
      try {
        // eslint-disable-next-line no-await-in-loop
        const response = await fetch(url, {
          method: 'POST',
          body: JSON.stringify({ metadata }),
        });
        if (response.ok) {
          // Break out of the loop if the request was successful
          break;
        } else {
          // Log the unsuccessful response (optional)
          console.error(
            `Telemetry: Attempt ${i + 1} failed with status:`,
            response.status
          );
        }
      } catch (error) {
        // Log the error (optional)
        console.error(`Telemetry: Attempt ${i + 1} failed with error:`, error);
      }
      // If this was the last attempt, throw an error or handle the failure
      if (i === maxRetries - 1) {
        console.error('Telemetry: Max retries reached');
      }
    }
  }
}
/**
 * The EmbedChain app.
 *
 * Public surface (all inherited from EmbedChain):
 * - add(dataType, url): adds the data from the given URL to the vector db.
 * - addLocal(dataType, content): adds locally supplied data to the vector db.
 * - query(query): finds an answer to the given query using the vector
 *   database and the LLM.
 * - dryRun(query): returns the prompt that query() would send to the LLM.
 */
class EmbedChainApp extends EmbedChain {
  // Intentionally empty: all behavior is inherited from EmbedChain.
}
export { EmbedChainApp };

View File

@@ -0,0 +1,7 @@
import { EmbedChainApp } from './embedchain';
/**
 * Creates an EmbedChainApp and waits for its database setup to complete
 * before handing it to the caller.
 */
export const App = async (): Promise<EmbedChainApp> => {
  const instance = new EmbedChainApp();
  await instance.initApp;
  return instance;
};

View File

@@ -0,0 +1,5 @@
import type { Input, LoaderResult } from '../models';
/** Common contract for all data loaders. */
export abstract class BaseLoader {
  // Fetches `src` and returns its content paired with source metadata.
  abstract loadData(src: Input): Promise<LoaderResult>;
}

View File

@@ -0,0 +1,21 @@
import type { LoaderResult, QnaPair } from '../models';
import { BaseLoader } from './BaseLoader';
/** Loads a locally supplied question/answer pair as a single document. */
class LocalQnaPairLoader extends BaseLoader {
  // eslint-disable-next-line class-methods-use-this
  async loadData(content: QnaPair): Promise<LoaderResult> {
    const [question, answer] = content;
    // Local data has no source URL, so tag it with a fixed 'local' marker.
    return [
      {
        content: `Q: ${question}\nA: ${answer}`,
        metaData: { url: 'local' },
      },
    ];
  }
}
export { LocalQnaPairLoader };

View File

@@ -0,0 +1,58 @@
import type { TextContent } from 'pdfjs-dist/types/src/display/api';
import type { LoaderResult, Metadata } from '../models';
import { cleanString } from '../utils';
import { BaseLoader } from './BaseLoader';
const pdfjsLib = require('pdfjs-dist');
/** A single PDF page's extracted text. */
interface Page {
  page_content: string;
}

/** Loads and cleans the text content of a PDF document. */
class PdfFileLoader extends BaseLoader {
  /** Extracts the text of every page of the PDF at `url`, in page order. */
  static async getPagesFromPdf(url: string): Promise<Page[]> {
    const pdf = await pdfjsLib.getDocument(url).promise;
    // Fetch all pages concurrently; pdf.js pages are 1-indexed.
    const pagePromises = Array.from({ length: pdf.numPages }, async (_, index) => {
      const page = await pdf.getPage(index + 1);
      const textContent: TextContent = await page.getTextContent();
      // Items without a `str` field contribute nothing to the page text.
      const joined: string = textContent.items
        .map((item) => ('str' in item ? item.str : ''))
        .join(' ');
      return { page_content: joined };
    });
    return Promise.all(pagePromises);
  }

  // eslint-disable-next-line class-methods-use-this
  async loadData(url: string): Promise<LoaderResult> {
    const pages: Page[] = await PdfFileLoader.getPagesFromPdf(url);
    if (!pages.length) {
      throw new Error('No data found');
    }
    // Normalize each page's text and tag it with its source URL.
    return pages.map((page) => {
      const metaData: Metadata = { url };
      return {
        content: cleanString(page.page_content),
        metaData,
      };
    });
  }
}
export { PdfFileLoader };

View File

@@ -0,0 +1,51 @@
import axios from 'axios';
import { JSDOM } from 'jsdom';
import { cleanString } from '../utils';
import { BaseLoader } from './BaseLoader';
class WebPageLoader extends BaseLoader {
// eslint-disable-next-line class-methods-use-this
async loadData(url: string) {
const response = await axios.get(url);
const html = response.data;
const dom = new JSDOM(html);
const { document } = dom.window;
const unwantedTags = [
'nav',
'aside',
'form',
'header',
'noscript',
'svg',
'canvas',
'footer',
'script',
'style',
];
unwantedTags.forEach((tagName) => {
const elements = document.getElementsByTagName(tagName);
Array.from(elements).forEach((element) => {
// eslint-disable-next-line no-param-reassign
(element as HTMLElement).textContent = ' ';
});
});
const output = [];
let content = document.body.textContent;
if (!content) {
throw new Error('Web page content is empty.');
}
content = cleanString(content);
const metaData = {
url,
};
output.push({
content,
metaData,
});
return output;
}
}
export { WebPageLoader };

View File

@@ -0,0 +1,6 @@
// Barrel file: single entry point for all loader classes.
export { BaseLoader } from './BaseLoader';
export { LocalQnaPairLoader } from './LocalQnaPair';
export { PdfFileLoader } from './PdfFile';
export { WebPageLoader } from './WebPage';

View File

@@ -0,0 +1,7 @@
import type { Metadata } from './Metadata';
// Output of a chunker: three parallel arrays where index i gives a chunk's
// text, its sha256-derived id, and the metadata of the source it came from.
export type ChunkResult = {
  documents: string[];
  ids: string[];
  metadatas: Metadata[];
};

View File

@@ -0,0 +1,10 @@
import type { ChunkResult } from './ChunkResult';
// One chunk's text plus its source metadata.
type Data = {
  doc: ChunkResult['documents'][0];
  meta: ChunkResult['metadatas'][0];
};
// Chunks keyed by chunk id; used to drop ids that already exist in the db.
export type DataDict = {
  [id: string]: Data;
};

View File

@@ -0,0 +1 @@
// Supported data sources; selects the matching loader/chunker pair.
export type DataType = 'pdf_file' | 'web_page' | 'qna_pair';

View File

@@ -0,0 +1,3 @@
import type { Document } from 'langchain/document';
// A retrieved document paired with its query distance (null when the
// database returned no distances).
export type FormattedResult = [Document, number | null];

View File

@@ -0,0 +1,7 @@
import type { QnaPair } from './QnAPair';
// A URL string pointing at a remote resource (web page or PDF).
export type RemoteInput = string;
// Locally supplied data; currently a question/answer tuple.
export type LocalInput = QnaPair;
// Any data source accepted by the loaders.
export type Input = RemoteInput | LocalInput;

View File

@@ -0,0 +1,3 @@
import type { Metadata } from './Metadata';
// Normalized loader output: one entry per document with its source metadata.
// `content` is always a cleaned string (every loader builds it via
// cleanString or a template literal), so type it as `string` rather than
// `any` to catch misuse at compile time.
export type LoaderResult = { content: string; metaData: Metadata }[];

View File

@@ -0,0 +1,3 @@
// Provenance of a piece of content; 'local' marks locally supplied data.
export type Metadata = {
  url: string;
};

View File

@@ -0,0 +1 @@
// Telemetry event names reported by EmbedChain.
export type Method = 'init' | 'query' | 'add' | 'add_local';

View File

@@ -0,0 +1,4 @@
type Question = string;
type Answer = string;
// A question and its answer, in that order.
export type QnaPair = [Question, Answer];

View File

@@ -0,0 +1,21 @@
// Barrel file re-exporting all model types. `export type` keeps these as
// type-only re-exports, which is required under `isolatedModules`: every
// symbol here is erased at compile time, and a plain value re-export of a
// type cannot be compiled file-by-file.
export type { DataDict } from './DataDict';
export type { DataType } from './DataType';
export type { FormattedResult } from './FormattedResult';
export type { Input, LocalInput, RemoteInput } from './Input';
export type { LoaderResult } from './LoaderResult';
export type { Metadata } from './Metadata';
export type { Method } from './Method';
export type { QnaPair } from './QnAPair';

View File

@@ -0,0 +1,26 @@
/**
 * This function takes in a string and performs a series of text cleaning
 * operations, in order: newline replacement, backslash removal, hash ('#')
 * replacement, collapsing of repeated punctuation, and a final whitespace
 * trim/collapse.
 *
 * @param text The text to be cleaned. This is expected to be a string.
 * @returns The cleaned text after all the cleaning operations have been performed.
 */
export function cleanString(text: string): string {
  // Replacement of newline characters:
  let cleanedText = text.replace(/\n/g, ' ');
  // Removing backslashes:
  cleanedText = cleanedText.replace(/\\/g, '');
  // Replacing hash characters:
  cleanedText = cleanedText.replace(/#/g, ' ');
  // Eliminating consecutive non-alphanumeric characters:
  // This regex identifies consecutive non-alphanumeric characters (i.e., not a word character [a-zA-Z0-9_] and not a whitespace) in the string
  // and replaces each group of such characters with a single occurrence of that character.
  // For example, "!!! hello !!!" would become "! hello !".
  cleanedText = cleanedText.replace(/([^\w\s])\1*/g, '$1');
  // Stripping and reducing multiple spaces to single — done LAST so that
  // spaces introduced by the replacements above (e.g. '#' -> ' ') are also
  // collapsed. The original ran this step early and could return output
  // containing double spaces (e.g. "a ## b" -> "a    b").
  return cleanedText.trim().replace(/\s+/g, ' ');
}

View File

@@ -0,0 +1,14 @@
/**
 * Common base for vector database wrappers. Kicks off asynchronous
 * initialization from the constructor and exposes it as `initDb` so callers
 * can await readiness.
 */
class BaseVectorDB {
  initDb: Promise<void>;

  constructor() {
    // Subclasses supply the real connection logic by overriding
    // getClientAndCollection(); the promise is stored, not awaited, here.
    this.initDb = this.getClientAndCollection();
  }

  /**
   * Establishes the client and collection. Must be overridden; the base
   * implementation always rejects.
   */
  // eslint-disable-next-line class-methods-use-this
  protected async getClientAndCollection(): Promise<void> {
    throw new Error('getClientAndCollection() method is not implemented');
  }
}
export { BaseVectorDB };

View File

@@ -0,0 +1,38 @@
import type { Collection } from 'chromadb';
import { ChromaClient, OpenAIEmbeddingFunction } from 'chromadb';
import { BaseVectorDB } from './BaseVectorDb';
// Shared OpenAI embedding function; the API key is read from the environment
// once at module load.
const embedder = new OpenAIEmbeddingFunction({
  openai_api_key: process.env.OPENAI_API_KEY ?? '',
});

/** Vector database backed by a local ChromaDB server. */
class ChromaDB extends BaseVectorDB {
  client: ChromaClient | undefined;

  collection: Collection | null = null;

  // NOTE: the previous explicit constructor only called super() and was
  // flagged by eslint as useless; the implicit constructor is identical.

  /**
   * Connects to the local ChromaDB server and gets — or, if it does not
   * exist yet, creates — the `embedchain_store` collection.
   */
  protected async getClientAndCollection(): Promise<void> {
    this.client = new ChromaClient({ path: 'http://localhost:8000' });
    try {
      this.collection = await this.client.getCollection({
        name: 'embedchain_store',
        embeddingFunction: embedder,
      });
    } catch {
      // getCollection failed (collection missing or not fetchable), so
      // `this.collection` is still null — create the collection.
      this.collection = await this.client.createCollection({
        name: 'embedchain_store',
        embeddingFunction: embedder,
      });
    }
  }
}
export { ChromaDB };

View File

@@ -0,0 +1,3 @@
// Barrel file: single entry point for vector database implementations.
export { ChromaDB } from './ChromaDb';