Skip to content

Commit

Permalink
Merge pull request #60 from MarcoMandar/main
Browse files Browse the repository at this point in the history
sqlite_vss issue
  • Loading branch information
sirkitree authored Oct 28, 2024
2 parents 28716e7 + e96764a commit 9a456bb
Show file tree
Hide file tree
Showing 8 changed files with 592 additions and 166 deletions.
399 changes: 399 additions & 0 deletions package-lock.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
"puppeteer-extra": "^3.3.6",
"puppeteer-extra-plugin-capsolver": "^2.0.1",
"sql.js": "^1.10.2",
"sqlite-vec": "^0.1.4-alpha.2",
"sqlite-vss": "^0.1.2",
"srt": "^0.0.3",
"systeminformation": "^5.23.5",
Expand Down
122 changes: 60 additions & 62 deletions src/adapters/sqlite.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { v4 } from "uuid";
import { load } from "../adapters/sqlite/sqlite_vss.ts";
// import { load } from "../adapters/sqlite/sqlite_vss.ts";
import { load } from "../adapters/sqlite/sqlite_vec.ts";

import { DatabaseAdapter } from "../core/database.ts";
import {
Expand Down Expand Up @@ -43,10 +44,10 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {

async getParticipantUserState(
roomId: UUID,
userId: UUID,
userId: UUID
): Promise<"FOLLOWED" | "MUTED" | null> {
const stmt = this.db.prepare(
"SELECT userState FROM participants WHERE roomId = ? AND userId = ?",
"SELECT userState FROM participants WHERE roomId = ? AND userId = ?"
);
const res = stmt.get(roomId, userId) as
| { userState: "FOLLOWED" | "MUTED" | null }
Expand All @@ -57,10 +58,10 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {
async setParticipantUserState(
roomId: UUID,
userId: UUID,
state: "FOLLOWED" | "MUTED" | null,
state: "FOLLOWED" | "MUTED" | null
): Promise<void> {
const stmt = this.db.prepare(
"UPDATE participants SET userState = ? WHERE roomId = ? AND userId = ?",
"UPDATE participants SET userState = ? WHERE roomId = ? AND userId = ?"
);
stmt.run(state, roomId, userId);
}
Expand All @@ -73,7 +74,7 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {
// Check if the 'accounts' table exists as a representative table
const tableExists = this.db
.prepare(
"SELECT name FROM sqlite_master WHERE type='table' AND name='accounts'",
"SELECT name FROM sqlite_master WHERE type='table' AND name='accounts'"
)
.get();

Expand Down Expand Up @@ -107,7 +108,7 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {
account.username,
account.email,
account.avatarUrl,
JSON.stringify(account.details),
JSON.stringify(account.details)
);
return true;
} catch (error) {
Expand Down Expand Up @@ -198,7 +199,7 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {
roomId: memory.roomId,
match_threshold: 0.95, // 5% similarity threshold
count: 1,
},
}
);

isUnique = similarMemories.length === 0;
Expand All @@ -209,19 +210,17 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {

// Insert the memory with the appropriate 'unique' value
const sql = `INSERT OR REPLACE INTO memories (id, type, content, embedding, userId, roomId, \`unique\`, createdAt) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`;
this.db
.prepare(sql)
.run(
memory.id ?? v4(),
tableName,
content,
JSON.stringify(memory.embedding ?? embeddingZeroVector),
memory.userId,
memory.roomId,
isUnique ? 1 : 0,
createdAt,
);
}
this.db.prepare(sql).run(
memory.id ?? v4(),
tableName,
content,
new Float32Array(memory.embedding ?? embeddingZeroVector), // Store as Float32Array
memory.userId,
memory.roomId,
isUnique ? 1 : 0,
createdAt
);
}

async searchMemories(params: {
tableName: string;
Expand All @@ -231,37 +230,36 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {
match_count: number;
unique: boolean;
}): Promise<Memory[]> {
const queryParams = [
new Float32Array(params.embedding), // Ensure embedding is Float32Array
params.tableName,
params.roomId,
params.match_count,
];

let sql = `
SELECT *, (1 - vss_distance_l2(embedding, ?)) AS similarity
FROM memories
WHERE type = ?
AND roomId = ?`;
SELECT *, vec_distance_L2(embedding, ?) AS similarity
FROM memories
WHERE type = ?`;

if (params.unique) {
sql += " AND `unique` = 1";
}

sql += ` ORDER BY similarity DESC LIMIT ?`;
const queryParams = [
JSON.stringify(params.embedding),
params.tableName,
params.roomId,
params.match_count,
];
sql += ` ORDER BY similarity ASC LIMIT ?`; // ASC for lower distance
// Updated queryParams order matches the placeholders

const memories = this.db.prepare(sql).all(...queryParams) as (Memory & {
similarity: number;
})[];
return memories.map((memory) => {
return {
...memory,
createdAt:
typeof memory.createdAt === "string"
? Date.parse(memory.createdAt as string)
: memory.createdAt,
content: JSON.parse(memory.content as unknown as string),
};
});
return memories.map((memory) => ({
...memory,
createdAt:
typeof memory.createdAt === "string"
? Date.parse(memory.createdAt as string)
: memory.createdAt,
content: JSON.parse(memory.content as unknown as string),
}));
}

async searchMemoriesByEmbedding(
Expand All @@ -272,18 +270,18 @@ AND roomId = ?`;
roomId?: UUID;
unique?: boolean;
tableName: string;
},
}
): Promise<Memory[]> {
const queryParams = [
JSON.stringify(embedding),
params.tableName,
// JSON.stringify(embedding),
new Float32Array(embedding),
params.tableName,
];

let sql = `
SELECT *, (1 - vss_distance_l2(embedding, ?)) AS similarity
SELECT *, vec_distance_L2(embedding, ?) AS similarity
FROM memories
WHERE type = ?`; // AND vss_search(embedding, ?)
WHERE type = ?`;

if (params.unique) {
sql += " AND `unique` = 1";
Expand Down Expand Up @@ -329,21 +327,21 @@ AND roomId = ?`;
SELECT *
FROM memories
WHERE type = ?
AND vss_search(${opts.query_field_name}, ?)
ORDER BY vss_search(${opts.query_field_name}, ?) DESC
AND vec_distance_L2(${opts.query_field_name}, ?) <= ?
ORDER BY vec_distance_L2(${opts.query_field_name}, ?) ASC
LIMIT ?
`;
const memories = this.db
.prepare(sql)
.all(
opts.query_table_name,
opts.query_input,
opts.query_input,
opts.query_match_count,
) as Memory[];
const memories = this.db.prepare(sql).all(
opts.query_table_name,
new Float32Array(opts.query_input.split(",").map(Number)), // Convert string to Float32Array
opts.query_input,
new Float32Array(opts.query_input.split(",").map(Number))
) as Memory[];

return memories.map((memory) => ({
embedding: JSON.parse(memory.embedding as unknown as string),
embedding: Array.from(
new Float32Array(memory.embedding as unknown as Buffer)
), // Convert Buffer to number[]
levenshtein_score: 0,
}));
}
Expand All @@ -370,7 +368,7 @@ AND roomId = ?`;
JSON.stringify(params.body),
params.userId,
params.roomId,
params.type,
params.type
);
}

Expand Down Expand Up @@ -444,7 +442,7 @@ AND roomId = ?`;
async countMemories(
roomId: UUID,
unique = true,
tableName = "",
tableName = ""
): Promise<number> {
if (!tableName) {
throw new Error("tableName is required");
Expand Down Expand Up @@ -514,7 +512,7 @@ AND roomId = ?`;
goal.userId,
goal.name,
goal.status,
JSON.stringify(goal.objectives),
JSON.stringify(goal.objectives)
);
}

Expand Down Expand Up @@ -610,7 +608,7 @@ AND roomId = ?`;
params.userA,
params.userB,
params.userB,
params.userA,
params.userA
) as Relationship) || null
);
}
Expand Down
21 changes: 21 additions & 0 deletions src/adapters/sqlite/sqlite_vec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import * as sqliteVec from "sqlite-vec";
import { Database } from "better-sqlite3";

/**
 * Registers the sqlite-vec extension on an open database connection.
 *
 * Delegates to `sqliteVec.load(db)`. A confirmation line is written to the
 * console once that call returns; if anything inside the guarded section
 * throws, the failure is reported via `console.error` and the same error is
 * re-thrown so the caller can decide how to react.
 *
 * @param db - The database connection to load the extension into.
 * @throws Whatever error was raised while loading, unchanged.
 */
export function loadVecExtensions(db: Database): void {
  try {
    sqliteVec.load(db);
    console.log("sqlite-vec extensions loaded successfully.");
  } catch (loadFailure) {
    // Surface the failure in the logs, then let it propagate untouched.
    console.error("Failed to load sqlite-vec extensions:", loadFailure);
    throw loadFailure;
  }
}

/**
 * Entry point for loading the vector extension; forwards directly to
 * {@link loadVecExtensions}.
 *
 * @param db - A better-sqlite3 `Database` instance.
 */
export function load(db: Database): void {
  return loadVecExtensions(db);
}
Loading

0 comments on commit 9a456bb

Please sign in to comment.