This repository has been archived by the owner on May 10, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 155
/
chat.cpp
274 lines (240 loc) · 9.61 KB
/
chat.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
#include "chat.h"
#include "llm.h"
#include "network.h"
#include "download.h"
// Builds a chat with a newly generated unique id, the translated default name "New Chat",
// its own ChatModel and ChatLLM (both created with this object passed to their constructors)
// and the current time in seconds since the epoch as creation date. The body only wires
// signals and slots together.
Chat::Chat(QObject *parent)
    : QObject(parent)
    , m_id(Network::globalInstance()->generateUniqueId())
    , m_name(tr("New Chat"))
    , m_chatModel(new ChatModel(this))
    , m_responseInProgress(false)
    , m_creationDate(QDateTime::currentSecsSinceEpoch())
    , m_llmodel(new ChatLLM(this))
{
    const Qt::ConnectionType direct = Qt::DirectConnection;
    const Qt::ConnectionType queued = Qt::QueuedConnection;
    const Qt::ConnectionType blocking = Qt::BlockingQueuedConnection;
    ChatLLM *const llm = m_llmodel;

    // Sender and receiver are expected to live in the same thread: any change to the
    // downloaded models or to the current model name invalidates the model list.
    connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::modelListChanged, direct);
    connect(this, &Chat::modelNameChanged, this, &Chat::modelListChanged, direct);

    // Sender and receiver are expected to live in different threads.
    // Notifications coming from the ChatLLM ...
    connect(llm, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, queued);
    connect(llm, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, queued);
    connect(llm, &ChatLLM::responseStarted, this, &Chat::responseStarted, queued);
    connect(llm, &ChatLLM::responseStopped, this, &Chat::responseStopped, queued);
    connect(llm, &ChatLLM::modelNameChanged, this, &Chat::handleModelNameChanged, queued);
    connect(llm, &ChatLLM::modelLoadingError, this, &Chat::modelLoadingError, queued);
    connect(llm, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, queued);
    connect(llm, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, queued);

    // ... and fire-and-forget requests going to the ChatLLM.
    connect(this, &Chat::promptRequested, llm, &ChatLLM::prompt, queued);
    connect(this, &Chat::modelNameChangeRequested, llm, &ChatLLM::modelNameChangeRequested, queued);
    connect(this, &Chat::loadDefaultModelRequested, llm, &ChatLLM::loadDefaultModel, queued);
    connect(this, &Chat::loadModelRequested, llm, &ChatLLM::loadModel, queued);
    connect(this, &Chat::unloadModelRequested, llm, &ChatLLM::unloadModel, queued);
    connect(this, &Chat::reloadModelRequested, llm, &ChatLLM::reloadModel, queued);
    connect(this, &Chat::generateNameRequested, llm, &ChatLLM::generateName, queued);

    // These requests block the emitting (gui) thread until the ChatLLM slot has returned,
    // so the slots on the other side must respond quickly.
    connect(this, &Chat::regenerateResponseRequested, llm, &ChatLLM::regenerateResponse, blocking);
    connect(this, &Chat::resetResponseRequested, llm, &ChatLLM::resetResponse, blocking);
    connect(this, &Chat::resetContextRequested, llm, &ChatLLM::resetContext, blocking);
}
void Chat::reset()
{
stopGenerating();
// Erase our current on disk representation as we're completely resetting the chat along with id
LLM::globalInstance()->chatListModel()->removeChatFile(this);
emit resetContextRequested(); // blocking queued connection
m_id = Network::globalInstance()->generateUniqueId();
emit idChanged();
// NOTE: We deliberately do no reset the name or creation date to indictate that this was originally
// an older chat that was reset for another purpose. Resetting this data will lead to the chat
// name label changing back to 'New Chat' and showing up in the chat model list as a 'New Chat'
// further down in the list. This might surprise the user. In the future, we me might get rid of
// the "reset context" button in the UI. Right now, by changing the model in the combobox dropdown
// we effectively do a reset context. We *have* to do this right now when switching between different
// types of models. The only way to get rid of that would be a very long recalculate where we rebuild
// the context if we switch between different types of models. Probably the right way to fix this
// is to allow switching models but throwing up a dialog warning users if we switch between types
// of models that a long recalculation will ensue.
m_chatModel->clear();
}
// Reports whether the underlying ChatLLM currently has a model loaded.
bool Chat::isModelLoaded() const
{
    const bool loaded = m_llmodel->isModelLoaded();
    return loaded;
}
void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict,
int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
int32_t repeat_penalty_tokens)
{
emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch,
repeat_penalty, repeat_penalty_tokens, LLM::globalInstance()->threadCount());
}
void Chat::regenerateResponse()
{
emit regenerateResponseRequested(); // blocking queued connection
}
void Chat::stopGenerating()
{
m_llmodel->stopGenerating();
}
// Returns the response text currently held by the ChatLLM.
QString Chat::response() const
{
    QString text = m_llmodel->response();
    return text;
}
void Chat::handleResponseChanged()
{
const int index = m_chatModel->count() - 1;
m_chatModel->updateValue(index, response());
emit responseChanged();
}
void Chat::responseStarted()
{
m_responseInProgress = true;
emit responseInProgressChanged();
}
void Chat::responseStopped()
{
m_responseInProgress = false;
emit responseInProgressChanged();
if (m_llmodel->generatedName().isEmpty())
emit generateNameRequested();
if (chatModel()->count() < 3)
Network::globalInstance()->sendChatStarted();
}
// Returns the model name reported by the ChatLLM.
QString Chat::modelName() const
{
    QString current = m_llmodel->modelName();
    return current;
}
void Chat::setModelName(const QString &modelName)
{
// doesn't block but will unload old model and load new one which the gui can see through changes
// to the isModelLoaded property
emit modelNameChangeRequested(modelName);
}
void Chat::newPromptResponsePair(const QString &prompt)
{
m_chatModel->appendPrompt(tr("Prompt: "), prompt);
m_chatModel->appendResponse(tr("Response: "), prompt);
emit resetResponseRequested(); // blocking queued connection
}
// Reports the ChatLLM's isRecalc() state.
bool Chat::isRecalc() const
{
    const bool recalculating = m_llmodel->isRecalc();
    return recalculating;
}
void Chat::loadDefaultModel()
{
emit loadDefaultModelRequested();
}
void Chat::loadModel(const QString &modelName)
{
emit loadModelRequested(modelName);
}
void Chat::unloadModel()
{
stopGenerating();
emit unloadModelRequested();
}
void Chat::reloadModel()
{
emit reloadModelRequested(m_savedModelName);
}
void Chat::generatedNameChanged()
{
// Only use the first three words maximum and remove newlines and extra spaces
QString gen = m_llmodel->generatedName().simplified();
QStringList words = gen.split(' ', Qt::SkipEmptyParts);
int wordCount = qMin(3, words.size());
m_name = words.mid(0, wordCount).join(' ');
emit nameChanged();
}
void Chat::handleRecalculating()
{
Network::globalInstance()->sendRecalculatingContext(m_chatModel->count());
emit recalcChanged();
}
void Chat::handleModelNameChanged()
{
m_savedModelName = modelName();
emit modelNameChanged();
}
// Writes this chat to stream: creation date, id, name, user name and saved model name,
// followed by the ChatLLM state and the chat model. version is passed through to both
// nested serializers.
// Returns false as soon as one of the nested serializers fails, otherwise whether the
// stream is still in the Ok state.
bool Chat::serialize(QDataStream &stream, int version) const
{
    stream << m_creationDate << m_id << m_name << m_userName << m_savedModelName;

    // Short-circuit evaluation keeps the order: ChatLLM first, then the chat model, then
    // the final stream status check.
    return m_llmodel->serialize(stream, version)
        && m_chatModel->serialize(stream, version)
        && stream.status() == QDataStream::Ok;
}
// Restores this chat from stream, reading the fields in the order serialize() writes them
// and emitting idChanged(), nameChanged() and chatModelChanged() along the way.
// Returns false if the data is a pre-version-2 gpt4all-j chat or if a nested deserializer
// fails; otherwise returns whether the stream is still in the Ok state. Fields read before
// a failure remain modified.
bool Chat::deserialize(QDataStream &stream, int version)
{
    stream >> m_creationDate >> m_id;
    Q_EMIT idChanged();

    stream >> m_name >> m_userName;
    Q_EMIT nameChanged();

    stream >> m_savedModelName;

    // Prior to version 2 gptj models had a bug that fixed the kv_cache to F32 instead of F16,
    // so unfortunately those chats cannot be deserialized.
    const bool legacyGptj = version < 2 && m_savedModelName.contains("gpt4all-j");
    if (legacyGptj)
        return false;

    // ChatLLM state first, then the chat model; the second is skipped if the first fails.
    const bool restored = m_llmodel->deserialize(stream, version)
        && m_chatModel->deserialize(stream, version);
    if (!restored)
        return false;

    Q_EMIT chatModelChanged();
    return stream.status() == QDataStream::Ok;
}
QList<QString> Chat::modelList() const
{
// Build a model list from exepath and from the localpath
QList<QString> list;
QString exePath = QCoreApplication::applicationDirPath() + QDir::separator();
QString localPath = Download::globalInstance()->downloadLocalModelsPath();
{
QDir dir(exePath);
dir.setNameFilters(QStringList() << "ggml-*.bin");
QStringList fileNames = dir.entryList();
for (QString f : fileNames) {
QString filePath = exePath + f;
QFileInfo info(filePath);
QString name = info.completeBaseName().remove(0, 5);
if (info.exists()) {
if (name == modelName())
list.prepend(name);
else
list.append(name);
}
}
}
if (localPath != exePath) {
QDir dir(localPath);
dir.setNameFilters(QStringList() << "ggml-*.bin");
QStringList fileNames = dir.entryList();
for (QString f : fileNames) {
QString filePath = localPath + f;
QFileInfo info(filePath);
QString name = info.completeBaseName().remove(0, 5);
if (info.exists() && !list.contains(name)) { // don't allow duplicates
if (name == modelName())
list.prepend(name);
else
list.append(name);
}
}
}
if (list.isEmpty()) {
if (exePath != localPath) {
qWarning() << "ERROR: Could not find any applicable models in"
<< exePath << "nor" << localPath;
} else {
qWarning() << "ERROR: Could not find any applicable models in"
<< exePath;
}
return QList<QString>();
}
return list;
}