Skip to content

feat: add CoT support for DeepSeek-R1 (only for reference) #228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
3 changes: 2 additions & 1 deletion lua/gp/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ local config = {
-- secret : "sk-...",
-- secret = os.getenv("env_name.."),
openai = {
disable = false,
disable = true,
endpoint = "https://api.openai.com/v1/chat/completions",
-- secret = os.getenv("OPENAI_API_KEY"),
},
Expand Down Expand Up @@ -103,6 +103,7 @@ local config = {
disable = true,
},
{
provider = "openai",
name = "ChatGPT4o",
chat = true,
command = false,
Expand Down
55 changes: 40 additions & 15 deletions lua/gp/dispatcher.lua
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ end
---@param handler function # response handler
---@param on_exit function | nil # optional on_exit handler
---@param callback function | nil # optional callback handler
local query = function(buf, provider, payload, handler, on_exit, callback)
---@param is_reasoning boolean # whether model is reasoning model
local query = function(buf, provider, payload, handler, on_exit, callback, is_reasoning)
-- make sure handler is a function
if type(handler) ~= "function" then
logger.error(
Expand Down Expand Up @@ -241,9 +242,15 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
qt.raw_response = qt.raw_response .. line .. "\n"
end
line = line:gsub("^data: ", "")

local content = ""
local reasoning_content = ""

if line:match("choices") and line:match("delta") and line:match("content") then
line = vim.json.decode(line)
if line.choices[1] and line.choices[1].delta and line.choices[1].delta.reasoning_content then
reasoning_content = line.choices[1].delta.reasoning_content
end
if line.choices[1] and line.choices[1].delta and line.choices[1].delta.content then
content = line.choices[1].delta.content
end
Expand All @@ -267,10 +274,15 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
end
end


if content and type(content) == "string" then
if reasoning_content ~= "" and type(reasoning_content) == "string" then
handler(qid, reasoning_content, true)
elseif content ~= "" and type(content) == "string" then
if is_reasoning then
handler(qid, "\n</details>\n</think>\n", false)
is_reasoning = false
end
qt.response = qt.response .. content
handler(qid, content)
handler(qid, content, false)
end
end
end
Expand Down Expand Up @@ -314,11 +326,16 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
end
end


if qt.response == "" then
logger.error(qt.provider .. " response is empty: \n" .. vim.inspect(qt.raw_response))
if is_reasoning then
handler(qid, "\n", false)
handler(qid, "\n</details>\n</think>\n", false)
is_reasoning = false
end

-- if qt.response == "" then
-- logger.error(qt.provider .. " response is empty: \n" .. vim.inspect(qt.raw_response))
-- end

-- optional on_exit handler
if type(on_exit) == "function" then
on_exit(qid)
Expand Down Expand Up @@ -396,7 +413,7 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
end

local temp_file = D.query_dir ..
"/" .. logger.now() .. "." .. string.format("%x", math.random(0, 0xFFFFFF)) .. ".json"
"/" .. logger.now() .. "." .. string.format("%x", math.random(0, 0xFFFFFF)) .. ".json"
helpers.table_to_file(payload, temp_file)

local curl_params = vim.deepcopy(D.config.curl_params or {})
Expand Down Expand Up @@ -428,16 +445,17 @@ end
---@param handler function # response handler
---@param on_exit function | nil # optional on_exit handler
---@param callback function | nil # optional callback handler
D.query = function(buf, provider, payload, handler, on_exit, callback)
---@param is_reasoning boolean # whether the model is reasoning model
D.query = function(buf, provider, payload, handler, on_exit, callback, is_reasoning)
if provider == "copilot" then
return vault.run_with_secret(provider, function()
vault.refresh_copilot_bearer(function()
query(buf, provider, payload, handler, on_exit, callback)
query(buf, provider, payload, handler, on_exit, callback, is_reasoning)
end)
end)
end
vault.run_with_secret(provider, function()
query(buf, provider, payload, handler, on_exit, callback)
query(buf, provider, payload, handler, on_exit, callback, is_reasoning)
end)
end

Expand Down Expand Up @@ -466,7 +484,7 @@ D.create_handler = function(buf, win, line, first_undojoin, prefix, cursor)
})

local response = ""
return vim.schedule_wrap(function(qid, chunk)
return vim.schedule_wrap(function(qid, chunk, is_reasoning)
local qt = tasker.get_query(qid)
if not qt then
return
Expand Down Expand Up @@ -506,6 +524,13 @@ D.create_handler = function(buf, win, line, first_undojoin, prefix, cursor)
lines[i] = prefix .. l
end

-- prepend prefix > to each line inside CoT
if is_reasoning then
for i, l in ipairs(lines) do
lines[i] = "> " .. l
end
end

local unfinished_lines = {}
for i = finished_lines + 1, #lines do
table.insert(unfinished_lines, lines[i])
Expand All @@ -514,9 +539,9 @@ D.create_handler = function(buf, win, line, first_undojoin, prefix, cursor)
vim.api.nvim_buf_set_lines(buf, first_line + finished_lines, first_line + finished_lines, false, unfinished_lines)

local new_finished_lines = math.max(0, #lines - 1)
for i = finished_lines, new_finished_lines do
vim.api.nvim_buf_add_highlight(buf, qt.ns_id, hl_handler_group, first_line + i, 0, -1)
end
-- for i = finished_lines, new_finished_lines do
-- vim.api.nvim_buf_add_highlight(buf, qt.ns_id, hl_handler_group, first_line + i, 0, -1)
-- end
finished_lines = new_finished_lines

local end_line = first_line + #vim.split(response, "\n")
Expand Down
82 changes: 58 additions & 24 deletions lua/gp/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1031,22 +1031,37 @@ M.chat_respond = function(params)
agent_suffix = M.render.template(agent_suffix, { ["{{agent}}"] = agent_name })

local old_default_user_prefix = "🗨:"
local in_cot_block = false -- Flag to track if we're inside a CoT block

for index = start_index, end_index do
local line = lines[index]
if line:sub(1, #M.config.chat_user_prefix) == M.config.chat_user_prefix then
table.insert(messages, { role = role, content = content })
role = "user"
content = line:sub(#M.config.chat_user_prefix + 1)
elseif line:sub(1, #old_default_user_prefix) == old_default_user_prefix then
table.insert(messages, { role = role, content = content })
role = "user"
content = line:sub(#old_default_user_prefix + 1)
elseif line:sub(1, #agent_prefix) == agent_prefix then
table.insert(messages, { role = role, content = content })
role = "assistant"
content = ""
elseif role ~= "" then
content = content .. "\n" .. line

if line:match("^<think>$") then
in_cot_block = true
end

-- Skip lines if we're inside a CoT block
if not in_cot_block then
-- Original logic for handling chat messages
if line:sub(1, #M.config.chat_user_prefix) == M.config.chat_user_prefix then
table.insert(messages, { role = role, content = content })
role = "user"
content = line:sub(#M.config.chat_user_prefix + 1)
elseif line:sub(1, #old_default_user_prefix) == old_default_user_prefix then
table.insert(messages, { role = role, content = content })
role = "user"
content = line:sub(#old_default_user_prefix + 1)
elseif line:sub(1, #agent_prefix) == agent_prefix then
table.insert(messages, { role = role, content = content })
role = "assistant"
content = ""
elseif role ~= "" then
content = content .. "\n" .. line
end
end

if line:match("^</think>$") then
in_cot_block = false
end
end
-- insert last message not handled in loop
Expand All @@ -1063,6 +1078,8 @@ M.chat_respond = function(params)
-- make it multiline again if it contains escaped newlines
content = content:gsub("\\n", "\n")
messages[1] = { role = "system", content = content }
else
table.remove(messages, 1)
end

-- strip whitespace from ends of content
Expand All @@ -1074,12 +1091,23 @@ M.chat_respond = function(params)
local last_content_line = M.helpers.last_content_line(buf)
vim.api.nvim_buf_set_lines(buf, last_content_line, last_content_line, false, { "", agent_prefix .. agent_suffix, "" })

local offset = 0
local is_reasoning = false
-- Add CoT for DeepSeekReasoner
if string.match(agent_name, "^DeepSeekReasoner") then
vim.api.nvim_buf_set_lines(buf, last_content_line + 3, last_content_line + 3, false,
{ "<think>", "<details>", "<summary>CoT</summary>", "" })
offset = 1
is_reasoning = true
end

-- call the model and write response
M.dispatcher.query(
buf,
headers.provider or agent.provider,
M.dispatcher.prepare_payload(messages, headers.model or agent.model, headers.provider or agent.provider),
M.dispatcher.create_handler(buf, win, M.helpers.last_content_line(buf), true, "", not M.config.chat_free_cursor),
M.dispatcher.create_handler(buf, win, M.helpers.last_content_line(buf) + offset, true, "",
not M.config.chat_free_cursor),
vim.schedule_wrap(function(qid)
local qt = M.tasker.get_query(qid)
if not qt then
Expand Down Expand Up @@ -1125,31 +1153,36 @@ M.chat_respond = function(params)
topic_handler,
vim.schedule_wrap(function()
-- get topic from invisible buffer
local topic = vim.api.nvim_buf_get_lines(topic_buf, 0, -1, false)[1]
-- instead of the first line, read from the last two lines so that the CoT is skipped
local topic = vim.api.nvim_buf_get_lines(topic_buf, -3, -1, false)[1]
-- close invisible buffer
vim.api.nvim_buf_delete(topic_buf, { force = true })
-- strip whitespace from ends of topic
topic = topic:gsub("^%s*(.-)%s*$", "%1")
-- strip dot from end of topic
topic = topic:gsub("%.$", "")

-- if topic is empty do not replace it
if topic == "" then
-- if topic is empty or too long do not replace it
if topic == "" or #topic > 50 then
return
end

-- replace topic in current buffer
M.helpers.undojoin(buf)
vim.api.nvim_buf_set_lines(buf, 0, 1, false, { "# topic: " .. topic })
end)
end),
nil,
false
)
end
if not M.config.chat_free_cursor then
local line = vim.api.nvim_buf_line_count(buf)
M.helpers.cursor_to_line(line, buf, win)
end
vim.cmd("doautocmd User GpDone")
end)
end),
nil,
is_reasoning
)
end

Expand Down Expand Up @@ -1813,9 +1846,9 @@ M.Prompt = function(params, target, agent, template, prompt, whisper, callback)
end

-- select from first_line to last_line
vim.api.nvim_win_set_cursor(0, { start + 1, 0 })
vim.api.nvim_command("normal! V")
vim.api.nvim_win_set_cursor(0, { finish + 1, 0 })
-- vim.api.nvim_win_set_cursor(0, { start + 1, 0 })
-- vim.api.nvim_command("normal! V")
-- vim.api.nvim_win_set_cursor(0, { finish + 1, 0 })
end

-- prepare messages
Expand Down Expand Up @@ -1935,7 +1968,8 @@ M.Prompt = function(params, target, agent, template, prompt, whisper, callback)
on_exit(qid)
vim.cmd("doautocmd User GpDone")
end),
callback
callback,
false
)
end

Expand Down