-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathai-utils.lua
78 lines (65 loc) · 1.87 KB
/
ai-utils.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
---Options understood by this module; an instance lives in `M.plugin_config`.
---@class ai.Config
---@field api_key string|nil Key appended to the request URLs; cached here after the first lookup through `api_key_getter`.
---@field api_key_getter nil|fun():string Callback used to obtain the key when `api_key` is not set.

local M = {
  ---@type ai.Config
  plugin_config = {},
}
---Return the Gemini API key, resolving it lazily.
---
---Uses `M.plugin_config.api_key` when it is set. Otherwise it calls
---`M.plugin_config.api_key_getter()` and caches the result in
---`M.plugin_config.api_key`, so the getter runs only once.
---Raises an error when no key can be obtained.
---@return string api_key
local get_api_key = function()
  local config = M.plugin_config
  local api_key = config.api_key
  if api_key then
    return api_key
  end

  local getter = config.api_key_getter
  if getter == nil then
    -- Fail with an actionable message instead of "attempt to call a nil value".
    error("ai-utils: no API key configured; set plugin_config.api_key or plugin_config.api_key_getter", 2)
  end

  api_key = getter()
  if api_key == nil then
    -- Without this check the nil would only surface later as an opaque
    -- string-concatenation error while the request URL is being built.
    error("ai-utils: plugin_config.api_key_getter returned no API key", 2)
  end

  config.api_key = api_key
  return api_key
end
---Request a single (non-streamed) completion from the Gemini
---`generateContent` endpoint.
---
---The HTTP request is issued through plenary.curl. When it completes,
---`on_exit` is invoked from inside `vim.schedule`. On success it receives the
---text of the first part of the first candidate. On any failure (curl error,
---undecodable body, API error object, response without text) the error is
---shown with `vim.notify` and `on_exit` receives `nil` plus a message. Callers
---awaiting the callback therefore never hang.
---@param post_data table Request payload; encoded with `vim.fn.json_encode` and sent as the JSON body.
---@param on_exit fun(text: string|nil, err: string|nil)
M.llm_run = function(post_data, on_exit)
  local API_KEY = get_api_key()
  local curl = require('plenary.curl')
  -- NOTE(review): this uses the experimental "gemini-2.0-flash-exp" model,
  -- whereas the streamed variant targets "gemini-2.0-flash"; confirm intent.
  local url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent?key=" ..
    API_KEY

  -- Walk a chain of keys, returning nil as soon as a non-table is reached
  -- (this also covers vim.NIL values produced for JSON null).
  local function dig(value, ...)
    for i = 1, select("#", ...) do
      if type(value) ~= "table" then
        return nil
      end
      value = value[select(i, ...)]
    end
    return value
  end

  -- Show the failure to the user and hand it to the caller.
  local function fail(message)
    vim.notify("ai-utils: " .. message, vim.log.levels.ERROR)
    on_exit(nil, message)
  end

  curl.post(url, {
    body = vim.fn.json_encode(post_data),
    headers = {
      content_type = "application/json",
    },
    callback = function(res)
      vim.schedule(function()
        local ok, decoded = pcall(vim.fn.json_decode, dig(res, "body") or "")
        if not ok or type(decoded) ~= "table" then
          fail(("could not decode response (HTTP status %s)"):format(tostring(dig(res, "status"))))
          return
        end

        local api_error = decoded.error
        if api_error ~= nil then
          local detail = dig(api_error, "message")
          fail("API error: " .. (detail and tostring(detail) or vim.inspect(api_error)))
          return
        end

        local text = dig(decoded, "candidates", 1, "content", "parts", 1, "text")
        if type(text) ~= "string" then
          fail("response contained no text")
          return
        end
        on_exit(text)
      end)
    end,
    -- Assumes plenary.curl's `on_error` option, which is invoked (instead of
    -- `callback`) when curl exits non-zero. The URL is deliberately not
    -- included in the message because it carries the API key.
    on_error = function(err)
      vim.schedule(function()
        fail(("request failed (curl exit code %s)"):format(tostring(dig(err, "exit"))))
      end)
    end,
  })
end
---Stream a completion from the Gemini `streamGenerateContent` endpoint.
---
---The code assumes the response body is one JSON array whose elements arrive
---piecewise through plenary.curl's `stream` callback (TODO confirm against the
---API docs). Each newly completed element's generated text is forwarded to
---`on_next_line` exactly once. Elements without text are skipped, and
---`{"error": ...}` elements are reported through `vim.notify`.
---@param post_data table Request payload; encoded with `vim.fn.json_encode`.
---@param on_next_line fun(text: string) Called via `vim.defer_fn` for each streamed text fragment.
---@param on_end nil|fun() Called via `vim.defer_fn` when the request finishes, including when curl itself fails.
M.llm_run_streamed = function(post_data, on_next_line, on_end)
  local API_KEY = get_api_key()
  local curl = require('plenary.curl')
  local url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:streamGenerateContent?key=" ..
    API_KEY

  -- Walk a chain of keys, returning nil as soon as a non-table is reached
  -- (this also covers vim.NIL values produced for JSON null).
  local function dig(value, ...)
    for i = 1, select("#", ...) do
      if type(value) ~= "table" then
        return nil
      end
      value = value[select(i, ...)]
    end
    return value
  end

  -- Notify from the main loop. The stream callbacks presumably run in a fast
  -- event context (the original deferred its work for this reason), where
  -- vim.notify may not be called directly.
  local function report(message)
    vim.schedule(function()
      vim.notify("ai-utils: " .. message, vim.log.levels.ERROR)
    end)
  end

  local function finish()
    if on_end then
      vim.defer_fn(on_end, 0)
    end
  end

  local partial_chunk = ""
  -- Number of array elements already handled. The whole buffer is re-decoded
  -- on every chunk, so without this counter an element would be re-emitted
  -- whenever a chunk arrives that does not complete a new one.
  local emitted = 0
  local saw_api_error = false

  curl.post(url, {
    body = vim.fn.json_encode(post_data),
    headers = {
      content_type = "application/json",
    },
    stream = function(_, chunk)
      if chunk == nil then
        return
      end
      partial_chunk = partial_chunk .. chunk
      -- Speculatively closing the array only parses once the buffer ends on a
      -- complete element; otherwise wait for more data.
      local ok, data = pcall(vim.json.decode, partial_chunk .. "]")
      if not ok or type(data) ~= "table" then
        return
      end
      for i = emitted + 1, #data do
        local element = data[i]
        local api_error = dig(element, "error")
        if api_error ~= nil then
          saw_api_error = true
          local detail = dig(api_error, "message")
          report("API error: " .. (detail and tostring(detail) or vim.inspect(api_error)))
        else
          local text = dig(element, "candidates", 1, "content", "parts", 1, "text")
          if type(text) == "string" then
            vim.defer_fn(function()
              on_next_line(text)
            end, 0)
          end
        end
      end
      emitted = math.max(emitted, #data)
    end,
    callback = function(res)
      local status = tonumber(dig(res, "status"))
      if status and status >= 400 and not saw_api_error then
        report(("request failed with HTTP status %s"):format(tostring(status)))
      end
      finish()
    end,
    -- Assumes plenary.curl's `on_error` option, which is invoked (instead of
    -- `callback`) when curl exits non-zero; still signal completion.
    on_error = function(err)
      report(("request failed (curl exit code %s)"):format(tostring(dig(err, "exit"))))
      finish()
    end,
  })
end
---Coroutine-friendly variant of `M.llm_run`, produced by `plenary.async.wrap`.
---The `2` is wrap's argument count: `M.llm_run` takes two arguments, the last
---being its callback. Per plenary's docs, calling this from an async context
---returns the values `M.llm_run` passes to that callback.
---TODO: think about this later
M.llm_run_async = require('plenary.async').wrap(M.llm_run, 2)

return M