Skip to content

Commit

Permalink
small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
a4004 committed Feb 15, 2024
1 parent 82daf26 commit e0ec83d
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 81 deletions.
156 changes: 77 additions & 79 deletions src/modules/assistant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { ArgType, Bot, Command } from "orange-bot-base";
import { generate_no_context, generate_with_context } from "./gpt/openai.js";
import { getLogger } from "orange-common-lib";
import { APIEmbed, ActionRowBuilder, ButtonBuilder, ButtonStyle, EmbedBuilder, Message } from "discord.js";
import { allowUser, getOraUser, updateOraUser, createOraUser, calculateCost, ora_user, initDb } from "./gpt/costmgr.js";
import { allowUser, getOraUser, updateOraUser, createOraUser, calculateCost, ora_user, initDb, resetAllDailyCaps } from "./gpt/costmgr.js";

const logger = getLogger("assistant");

Expand All @@ -23,8 +23,10 @@ const command = {

const context_map: Map<string, string> = new Map();

export default function (bot: Bot) {
initDb();
export default async function (bot: Bot) {
await initDb();

setTimeout(() => resetAllDailyCaps(), 24 * 60 * 60 * 1000);

bot.client.on("interactionCreate", async interaction => {
if (interaction.isButton() && interaction.customId.startsWith("ora_")) {
Expand Down Expand Up @@ -198,98 +200,94 @@ export default function (bot: Bot) {


bot.client.on("messageCreate", async msg => {
if (bot.client.user && (msg.reference && msg.reference.messageId && context_map.has(msg.reference?.messageId) || (msg.mentions.has(bot.client.user) && msg.content.startsWith(`<@${bot.client.user.id}>`)))) {
var existing_account: ora_user;
logger.log(msg.author.id);

const allowed = await allowUser(msg.author.id);

if (!allowed) {
return;
}
if (!bot.client.user) {
logger.warn("bot.client.user not set! Cannot reply to AI request!");
return;
}

existing_account = await getOraUser(msg.author.id) as ora_user;

msg.channel.sendTyping();
const is_replying_to_context = msg.reference && msg.reference.messageId && context_map.has(msg.reference?.messageId);
const is_starting_new_ctx = msg.content.startsWith(`<@${bot.client.user.id}>`);

const user = msg.author.displayName;
const id = msg.author.id;
if (!is_replying_to_context && !is_starting_new_ctx) {
return;
}

if (!(await allowUser(msg.author.id))) {
return;
}

var result: { response?: string, thread_id?: string, input_tokens?: number, output_tokens?: number, new_context?: boolean };
const existing_account = await getOraUser(msg.author.id) as ora_user;
msg.channel.sendTyping();

if (msg.reference && msg.reference.messageId && context_map.has(msg.reference.messageId)) {
logger.info(`Using previous context: ${context_map.get(msg.reference.messageId)}`);
result = await generate_with_context(context_map.get(msg.reference.messageId)!, user, id, msg.content.replace(/<@\d+>/g, '').trim(), "asst_c053PWqAKmuUgJ0whEjGpJzG");
} else {
result = await generate_no_context(user, id, msg.content.replace(/<@\d+>/g, '').trim(), "asst_c053PWqAKmuUgJ0whEjGpJzG");
logger.info(`Using new context: ${result.thread_id}`);
}
const user = msg.author.displayName;
const id = msg.author.id;

//const result = await generate_no_context(user, id, msg.content.replace(/<@\d+>/g, '').trim(), "asst_c053PWqAKmuUgJ0whEjGpJzG");
var result: { response?: string, thread_id?: string, input_tokens?: number, output_tokens?: number, new_context?: boolean };

const sys_prompt_tokens = 269 + 25;
const input_tokens = result.input_tokens ?? 0;
const output_tokens = result.output_tokens ?? 0;
if (msg.reference && msg.reference.messageId && context_map.has(msg.reference.messageId)) {
logger.info(`Using previous context: ${context_map.get(msg.reference.messageId)}`);
result = await generate_with_context(context_map.get(msg.reference.messageId)!, user, id, msg.content.replace(/<@\d+>/g, '').trim(), "asst_c053PWqAKmuUgJ0whEjGpJzG");
} else {
result = await generate_no_context(user, id, msg.content.replace(/<@\d+>/g, '').trim(), "asst_c053PWqAKmuUgJ0whEjGpJzG");
logger.info(`Using new context: ${result.thread_id}`);
}

const { total_tokens, input_cost, output_cost, total_cost } = calculateCost(sys_prompt_tokens, input_tokens, output_tokens);
const sys_prompt_tokens = 269 + 25;
const input_tokens = result.input_tokens ?? 0;
const output_tokens = result.output_tokens ?? 0;

const cost_info = `Input cost: $${input_cost.toFixed(4)} (${input_tokens} tokens)\n` +
`Output cost: $${output_cost.toFixed(4)} (${output_tokens} tokens)\n` +
`Total cost: $${(total_cost).toFixed(4)} (${total_tokens} total tokens)`;
logger.info(cost_info);
const { total_tokens, input_cost, output_cost, total_cost } = calculateCost(sys_prompt_tokens, input_tokens, output_tokens);

if (result.response) {
var embeds: APIEmbed[] = [];
const cost_info = `Input cost: $${input_cost.toFixed(4)} (${input_tokens} tokens)\n` +
`Output cost: $${output_cost.toFixed(4)} (${output_tokens} tokens)\n` +
`Total cost: $${(total_cost).toFixed(4)} (${total_tokens} total tokens)`;
logger.info(cost_info);

if (result.new_context) {
embeds.push(
{
color: 0xffff00,
description: `This dialogue is in a new context window.`
}
);
}
if (result.response) {
var embeds: APIEmbed[] = [];

if (process.env.OPENAI_SHOW_PRICE && process.env.OPENAI_SHOW_PRICE === "true") {
embeds.push(
{
description: `This request: **$${total_cost.toFixed(4)}** (**${total_tokens}** tokens)\n` +
`**$${(total_cost * 60).toFixed(2)}**/h\u00A0\u00A0` +
`**$${(total_cost * 60 * 24).toFixed(2)}**/d\u00A0\u00A0` +
`**$${(total_cost * 60 * 24 * 7).toFixed(2)}**/w\u00A0\u00A0` +
`**$${(total_cost * 60 * 24 * 30).toFixed(2)}**/m\u00A0\u00A0` +
`**$${(total_cost * 60 * 24 * 365).toFixed(2)}**/y`,
footer: { text: `${input_tokens} tokens in \u00A0\u00A0 ${output_tokens} tokens out \u00A0\u00A0 ${total_tokens} tokens/min` },
}
);
}

const reply = await msg.reply({ content: result.response, embeds: embeds });

if (result.thread_id) {
context_map.set(reply.id, result.thread_id);
setTimeout(() => context_map.delete(reply.id), 3600000);
}
if (result.new_context) {
embeds.push( { color: 0xffff00, description: `This dialogue is in a new context window.` } );
}
else {
msg.reply({
embeds: [{
title: "Could not generate response",
description: "The response I received from OpenAI was `undefined`.",
}]
});

if (process.env.OPENAI_SHOW_PRICE && process.env.OPENAI_SHOW_PRICE === "true") {
embeds.push(
{
description: `This request: **$${total_cost.toFixed(4)}** (**${total_tokens}** tokens)\n` +
`**$${(total_cost * 60).toFixed(2)}**/h\u00A0\u00A0` +
`**$${(total_cost * 60 * 24).toFixed(2)}**/d\u00A0\u00A0` +
`**$${(total_cost * 60 * 24 * 7).toFixed(2)}**/w\u00A0\u00A0` +
`**$${(total_cost * 60 * 24 * 30).toFixed(2)}**/m\u00A0\u00A0` +
`**$${(total_cost * 60 * 24 * 365).toFixed(2)}**/y`,
footer: { text: `${input_tokens} tokens in \u00A0\u00A0 ${output_tokens} tokens out \u00A0\u00A0 ${total_tokens} tokens/min` },
}
);
}

if (existing_account) {
updateOraUser(id,
{
total_requests: existing_account.total_requests + 1,
total_tokens: existing_account.total_tokens + total_tokens,
daily_cost: existing_account.daily_cost + total_cost,
total_cost: existing_account.total_cost + total_cost
});
const reply = await msg.reply({ content: result.response, embeds: embeds });

if (result.thread_id) {
context_map.set(reply.id, result.thread_id);
setTimeout(() => context_map.delete(reply.id), 3600000);
}
}
else {
msg.reply({
embeds: [{
title: "Could not generate response",
description: "The response I received from OpenAI was `undefined`.",
}]
});
}

if (existing_account) {
updateOraUser(id,
{
total_requests: existing_account.total_requests + 1,
total_tokens: existing_account.total_tokens + total_tokens,
daily_cost: existing_account.daily_cost + total_cost,
total_cost: existing_account.total_cost + total_cost
});
}
});
};
17 changes: 15 additions & 2 deletions src/modules/gpt/costmgr.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type ora_user = {
updated: Date
};

function initDb() {
async function initDb() {
logger.info(`Connecting to pocketbase...`);
pb = new pocketbase(`https://${process.env.PB_DOMAIN!}`);

Expand All @@ -37,6 +37,11 @@ function initDb() {
logger.error(err);
setTimeout(initDb, 5000);
});

while (!pb.authStore.isValid) {
await sleep(1000);
continue;
}
}

async function getOraUser(user_id: string): Promise<ora_user | undefined> {
Expand All @@ -51,6 +56,14 @@ async function getOraUser(user_id: string): Promise<ora_user | undefined> {
}
}

async function resetAllDailyCaps() {
const users = await pb.collection("ora_users").getFullList<ora_user>();
for (const user of users) {
logger.log(`Resetting daily cost cap for user ${user.user_id}...`);
await updateOraUser(user.user_id, { daily_cost: 0 });
}
}

async function createOraUser(user_id: string, name: string): Promise<string | undefined> {
try {
const user = await pb.collection("ora_users").create<ora_user>(
Expand Down Expand Up @@ -127,4 +140,4 @@ function calculateCost(sys_prompt_tokens: number, input_tokens: number, output_t
return { total_tokens, input_cost, output_cost, total_cost };
}

export { initDb, allowUser, getOraUser, updateOraUser, createOraUser, calculateCost, ora_user };
export { initDb, allowUser, getOraUser, updateOraUser, createOraUser, calculateCost, ora_user, resetAllDailyCaps };

0 comments on commit e0ec83d

Please # to comment.