80 lines
2.5 KiB
Lua
80 lines
2.5 KiB
Lua
local tools_module = require("chatgpt_nvim.tools")
|
|
|
|
-- Module table; public functions are attached to it and it is returned at the end of the file.
local M = {}
|
|
|
|
--- Report whether a shell command line invokes a potentially destructive program.
-- The command is split into words on whitespace and on the shell separators
-- `;`, `&`, `|`, `(`, `)` and backtick, so chained commands such as
-- `ls && rm x` or `ls;rm x` are inspected as well. A word matches when its
-- basename (so `/bin/rm` counts as `rm`) is one of the listed programs.
-- This is a heuristic used to force a confirmation prompt, not a sandbox:
-- quoting, variables and aliases are not interpreted.
-- @param cmd command line; nil/false yield false, other non-strings go through tostring
-- @treturn boolean true if a destructive program name was found
local function is_destructive_command(cmd)
  if not cmd then
    return false
  end
  cmd = tostring(cmd)

  -- Set lookup replaces the previous patterns, whose `^` (inside a capture)
  -- and `$` (inside a set) were literal characters rather than anchors.
  local destructive = { rm = true, sudo = true, mv = true, cp = true }

  for word in cmd:gmatch("[^%s;&|()`]+") do
    local program = word:match("[^/]+$") or word
    if destructive[program] then
      return true
    end
  end

  return false
end
|
|
|
|
--- Decide whether a tool call may run, asking the user when required.
-- Auto-acceptance is looked up in `conf.tool_auto_accept[tool_call.tool]`
-- (a missing table means nothing is auto-accepted). An `execute_command`
-- call whose command is_destructive_command flags is always prompted for,
-- even when auto-accept is enabled for that tool.
-- @tparam table tool_call tool request (fields read: tool, path, command, replacements)
-- @tparam table conf plugin configuration (field read: tool_auto_accept)
-- @treturn boolean true if the call was auto-accepted or confirmed by the user
local function prompt_user_tool_accept(tool_call, conf)
  local auto_accept_map = (conf and conf.tool_auto_accept) or {}
  local auto_accept = auto_accept_map[tool_call.tool]

  -- Destructive shell commands are never auto-accepted.
  if auto_accept
      and tool_call.tool == "execute_command"
      and is_destructive_command(tool_call.command) then
    auto_accept = false
  end

  if auto_accept then
    return true
  end

  -- Describe the request so the user can judge it.
  local parts = { ("Tool request: %s\n"):format(tool_call.tool or "unknown") }
  if tool_call.path then
    parts[#parts + 1] = ("Path: %s\n"):format(tool_call.path)
  end
  if tool_call.command then
    parts[#parts + 1] = ("Command: %s\n"):format(tool_call.command)
  end
  if tool_call.replacements then
    parts[#parts + 1] = ("Replacements: %s\n"):format(vim.inspect(tool_call.replacements))
  end
  parts[#parts + 1] = "Accept this tool request? [y/N]: "

  -- Force a screen redraw so the user sees the prompt properly.
  vim.cmd("redraw")

  -- vim.fn.input raises "Keyboard interrupt" when the user presses <C-c>;
  -- treat that (or any non-string result) as a rejection instead of
  -- letting the error abort the caller.
  local ok, ans = pcall(vim.fn.input, table.concat(parts))
  if not ok or type(ans) ~= "string" then
    return false
  end

  ans = ans:match("^%s*(.-)%s*$"):lower()
  return ans == "y" or ans == "yes"
end
|
|
|
|
--- Run a batch of model-requested tool calls and collect their output.
-- Each call is first passed to prompt_user_tool_accept. A rejected call, an
-- unknown tool name, or a tool whose `run` raises an error each contributes
-- a short message instead of aborting the batch, so results of the calls
-- that did run are still reported.
-- @tparam table tools list of tool-call tables (nil is treated as empty)
-- @tparam table conf plugin configuration; `prompt_char_limit` caps the
--   combined output length (8000 characters when unset)
-- @param is_subpath_fn forwarded unchanged to each tool's `run`
-- @param read_file_fn forwarded unchanged to each tool's `run`
-- @treturn string tool messages joined by blank lines, or a notice asking
--   for smaller steps when the joined text exceeds the limit
local function handle_tool_calls(tools, conf, is_subpath_fn, read_file_fn)
  local messages = {}

  for _, call in ipairs(tools or {}) do
    local tool_name = call.tool or "nil"

    if not prompt_user_tool_accept(call, conf) then
      messages[#messages + 1] = ("Tool [%s] was rejected by user."):format(tool_name)
    else
      local tool_impl = tools_module.tools_by_name[call.tool]
      if not tool_impl then
        messages[#messages + 1] = ("Unknown tool type: '%s'"):format(tool_name)
      else
        -- A failing tool must not discard the results of calls that already ran.
        local ok, result = pcall(tool_impl.run, call, conf, prompt_user_tool_accept, is_subpath_fn, read_file_fn)
        if not ok then
          messages[#messages + 1] = ("Tool [%s] failed: %s"):format(tool_name, tostring(result))
        elseif result ~= nil then
          -- tostring guards table.concat against non-string results.
          messages[#messages + 1] = tostring(result)
        end
      end
    end
  end

  local combined = table.concat(messages, "\n\n")
  local limit = (conf and conf.prompt_char_limit) or 8000
  if #combined > limit then
    return ("The combined tool output is too large (%d chars). Please break down the operations into smaller steps."):format(#combined)
  end
  return combined
end
|
|
|
|
-- Public API: only handle_tool_calls is exported; the other functions in this file stay local.
M.handle_tool_calls = handle_tool_calls
|
|
-- The module table is the value require() hands to callers.
return M
|