Skip to content

Commit 5954925

Browse files
committed
some simple working chat
1 parent 7a31d6f commit 5954925

9 files changed

+366
-124
lines changed

llama/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ message(STATUS "GGML_INCLUDE_DIR: ${GGML_INCLUDE_DIR}")
1313

1414
configure_file(config.hpp.in config.hpp)
1515

16-
set(SOURCES src/main.cpp src/ApplicationLogic.cpp)
16+
set(SOURCES src/main.cpp src/ApplicationLogic.cpp ../src/helpers/llvm.cpp)
1717

1818
if(MSVC)
1919
set(BNAME ${LLAMA_BINARY_NAME})

llama/src/ApplicationLogic.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ std::string ApplicationLogic::LlamaGenerate(const std::string& prompt) {
398398
printf("%s", piece.c_str());
399399
fflush(stdout);
400400
response += piece;
401-
this->currentMessage->AppendOrCreateLastAssistantAnswer(piece);
401+
this->currentMessage->UpdateOrCreateAssistantAnswer(piece);
402402
this->UpdateCurrentSession();
403403

404404
// prepare the next batch with the sampled token
@@ -407,3 +407,6 @@ std::string ApplicationLogic::LlamaGenerate(const std::string& prompt) {
407407

408408
return response;
409409
}
410+
void ApplicationLogic::UpdateCurrentSession() {
411+
this->sharedMemoryManager->write(this->currentMessage->toString());
412+
}

llama/src/ApplicationLogic.h

+1-7
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,12 @@ class ApplicationLogic {
8181
std::vector<llama_chat_message> messages = {};
8282
std::vector<char> formatted = {};
8383
int prev_len = 0;
84-
8584
bool loadModel();
8685
void unloadModel();
87-
8886
bool loadContext();
8987
void unloadContext();
90-
9188
void generateText();
92-
93-
inline void UpdateCurrentSession() {
94-
this->sharedMemoryManager->write(this->currentMessage->toString());
95-
}
89+
void UpdateCurrentSession();
9690

9791
std::string LlamaGenerate(const std::string& prompt);
9892
};

src/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ set(SOURCES
1616
helpers/QueueItem.cpp
1717
helpers/DataViewListManager.cpp
1818
helpers/ModelUiManager.cpp
19+
helpers/llvm.cpp
1920
)
2021

2122
add_executable(${PROJECT_BINARY_NAME} ${SOURCES})

src/helpers/llvm.cpp

+187
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
#include "llvm.h"
2+
3+
void sd_gui_utils::llvmMessage::Update(const llvmMessage& other) {
4+
std::lock_guard<std::mutex> lock(mutex);
5+
if (this->id != other.id) {
6+
throw std::runtime_error("Message IDs do not match");
7+
}
8+
this->updated_at = other.updated_at;
9+
this->messages = other.messages;
10+
this->model_path = other.model_path;
11+
this->status_message = other.status_message;
12+
this->title = other.title;
13+
this->status = other.status;
14+
this->command = other.command;
15+
this->ngl = other.ngl;
16+
this->n_ctx = other.n_ctx;
17+
this->n_batch = other.n_batch;
18+
this->n_threads = other.n_threads;
19+
this->next_message_id = other.next_message_id;
20+
}
21+
std::map<uint64_t, sd_gui_utils::llvmText> sd_gui_utils::llvmMessage::GetMessages() {
22+
std::lock_guard<std::mutex> lock(mutex);
23+
return this->messages;
24+
}
25+
void sd_gui_utils::llvmMessage::InsertMessage(const llvmText& message, uint64_t id) {
26+
std::lock_guard<std::mutex> lock(mutex);
27+
if (id == 0) {
28+
id = this->next_message_id++;
29+
}
30+
this->messages[id] = message;
31+
}
32+
sd_gui_utils::llvmText sd_gui_utils::llvmMessage::GetMessage(uint64_t id) {
33+
std::lock_guard<std::mutex> lock(mutex);
34+
if (this->messages.contains(id)) {
35+
return this->messages[id];
36+
}
37+
return llvmText();
38+
}
39+
uint64_t sd_gui_utils::llvmMessage::GetNextMessageId() {
40+
std::lock_guard<std::mutex> lock(mutex);
41+
return next_message_id;
42+
}
43+
void sd_gui_utils::llvmMessage::SetId() {
44+
std::lock_guard<std::mutex> lock(mutex);
45+
this->id = GenerateId();
46+
}
47+
uint64_t sd_gui_utils::llvmMessage::GetId() {
48+
std::lock_guard<std::mutex> lock(mutex);
49+
return this->id;
50+
}
51+
const uint64_t sd_gui_utils::llvmMessage::GetUpdatedAt() const {
52+
std::lock_guard<std::mutex> lock(mutex);
53+
return this->updated_at;
54+
}
55+
bool sd_gui_utils::llvmMessage::CheckUpdatedAt(const uint64_t& updated_at) {
56+
std::lock_guard<std::mutex> lock(mutex);
57+
return this->updated_at == updated_at;
58+
}
59+
void sd_gui_utils::llvmMessage::SetCommandType(llvmCommand cmd) {
60+
std::lock_guard<std::mutex> lock(mutex);
61+
this->command = cmd;
62+
}
63+
sd_gui_utils::llvmCommand sd_gui_utils::llvmMessage::GetCommandType() {
64+
std::lock_guard<std::mutex> lock(mutex);
65+
return this->command;
66+
}
67+
void sd_gui_utils::llvmMessage::SetModelPath(const std::string& path) {
68+
std::lock_guard<std::mutex> lock(mutex);
69+
this->model_path = path;
70+
}
71+
const std::string sd_gui_utils::llvmMessage::GetModelPath() {
72+
std::lock_guard<std::mutex> lock(mutex);
73+
return this->model_path;
74+
}
75+
void sd_gui_utils::llvmMessage::SetStatus(llvmstatus status) {
76+
std::lock_guard<std::mutex> lock(mutex);
77+
this->status = status;
78+
}
79+
sd_gui_utils::llvmstatus sd_gui_utils::llvmMessage::GetStatus() {
80+
std::lock_guard<std::mutex> lock(mutex);
81+
return this->status;
82+
}
83+
void sd_gui_utils::llvmMessage::SetTitle(const std::string& title) {
84+
std::lock_guard<std::mutex> lock(mutex);
85+
this->title = title;
86+
}
87+
void sd_gui_utils::llvmMessage::SetNgl(int ngl) {
88+
std::lock_guard<std::mutex> lock(mutex);
89+
this->ngl = ngl;
90+
}
91+
int sd_gui_utils::llvmMessage::GetNgl() {
92+
std::lock_guard<std::mutex> lock(mutex);
93+
return this->ngl;
94+
}
95+
void sd_gui_utils::llvmMessage::SetNCtx(int n_ctx) {
96+
std::lock_guard<std::mutex> lock(mutex);
97+
this->n_ctx = n_ctx;
98+
}
99+
int sd_gui_utils::llvmMessage::GetNctx() {
100+
std::lock_guard<std::mutex> lock(mutex);
101+
return this->n_ctx;
102+
}
103+
void sd_gui_utils::llvmMessage::SetNThreads(int n_threads) {
104+
std::lock_guard<std::mutex> lock(mutex);
105+
this->n_threads = n_threads;
106+
}
107+
int sd_gui_utils::llvmMessage::GetNThreads() {
108+
std::lock_guard<std::mutex> lock(mutex);
109+
return this->n_threads;
110+
}
111+
void sd_gui_utils::llvmMessage::SetNBatch(int n_batch) {
112+
std::lock_guard<std::mutex> lock(mutex);
113+
this->n_batch = n_batch;
114+
}
115+
int sd_gui_utils::llvmMessage::GetNBatch() {
116+
std::lock_guard<std::mutex> lock(mutex);
117+
return this->n_batch;
118+
}
119+
void sd_gui_utils::llvmMessage::SetStatusMessage(const std::string& msg) {
120+
std::lock_guard<std::mutex> lock(mutex);
121+
this->status_message = msg;
122+
this->updated_at = this->GenerateId();
123+
}
124+
std::string sd_gui_utils::llvmMessage::GetStatusMessage() {
125+
std::lock_guard<std::mutex> lock(mutex);
126+
return this->status_message;
127+
}
128+
int sd_gui_utils::llvmMessage::GenerateId() {
129+
return std::chrono::duration_cast<std::chrono::milliseconds>(
130+
std::chrono::system_clock::now().time_since_epoch())
131+
.count();
132+
}
133+
const std::string sd_gui_utils::llvmMessage::GetLatestUserPrompt() {
134+
std::lock_guard<std::mutex> lock(mutex);
135+
for (const auto& p : this->messages) {
136+
if (p.second.sender == llvmTextSender::USER) {
137+
return p.second.text;
138+
}
139+
}
140+
return "";
141+
}
142+
const sd_gui_utils::llvmText sd_gui_utils::llvmMessage::GetLatestMessage() {
143+
std::lock_guard<std::mutex> lock(mutex);
144+
if (this->messages.empty()) {
145+
return llvmText();
146+
}
147+
auto last = this->messages.rbegin();
148+
return last->second;
149+
}
150+
151+
void sd_gui_utils::llvmMessage::AppendUserPrompt(const std::string& str) {
152+
std::lock_guard<std::mutex> lock(mutex);
153+
auto t = llvmText{llvmTextSender::USER};
154+
t.UpdateText(str);
155+
this->messages[this->next_message_id++] = t;
156+
this->updated_at = this->GenerateId();
157+
}
158+
159+
std::string sd_gui_utils::llvmMessage::toString() {
160+
std::lock_guard<std::mutex> lock(mutex);
161+
try {
162+
nlohmann::json j = *this;
163+
return j.dump();
164+
} catch (const std::exception& e) {
165+
std::cerr << "Error converting to string: " << e.what() << __FILE__ << ": " << __LINE__ << std::endl;
166+
return "";
167+
}
168+
}
169+
170+
void sd_gui_utils::llvmMessage::UpdateOrCreateAssistantAnswer(const std::string& str) {
171+
std::lock_guard<std::mutex> lock(mutex);
172+
173+
if (this->messages.empty()) {
174+
return;
175+
}
176+
177+
auto last = this->messages.rbegin();
178+
if (last->second.sender == llvmTextSender::ASSISTANT) {
179+
last->second.UpdateText(str);
180+
} else {
181+
auto t = llvmText{llvmTextSender::ASSISTANT};
182+
t.UpdateText(str);
183+
this->messages[this->next_message_id++] = t;
184+
}
185+
186+
this->updated_at = this->GenerateId();
187+
};

0 commit comments

Comments
 (0)