Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions lua/jumpy/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ M.config = {
"- Do NOT wrap in markdown code fences",
"- Do NOT explain",
}, "\n"),
system_prompt_multi_file = table.concat({
"When multiple files are provided, prefix the SEARCH marker with the file path:",
"<<<< SEARCH path/to/file.lua",
"exact existing lines from that file",
"====",
"replacement lines",
">>>> REPLACE",
"",
"The path must exactly match the path shown in the --- FILE: ... --- header.",
"You may edit any subset of the provided files. Every SEARCH block MUST include a path.",
}, "\n"),
keymaps = {
prompt = "<leader>j",
next_hunk = "]h",
Expand Down
23 changes: 23 additions & 0 deletions lua/jumpy/llm.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,32 @@ local function get_config()
return require("jumpy").config
end

local function build_file_block(path, contents)
return string.format("--- FILE: %s ---\n%s\n--- END FILE ---", path, contents)
end

local function build_messages(context)
local config = get_config()

local tagged = context.tagged_files
if tagged and #tagged > 0 then
local parts = {}
for _, file in ipairs(tagged) do
table.insert(parts, build_file_block(file.path, table.concat(file.lines, "\n")))
end
if context.symbols and context.symbols ~= "" then
table.insert(parts, context.symbols)
end
table.insert(parts, "")
table.insert(parts, "Instruction: " .. context.prompt)
local user_content = table.concat(parts, "\n")
local system = config.system_prompt .. "\n\n" .. config.system_prompt_multi_file
return {
{ role = "system", content = system },
{ role = "user", content = user_content },
}
end

local user_content = string.format(
"File type: %s\n\n--- FILE CONTENTS ---\n%s\n--- END FILE ---%s\n\nInstruction: %s",
context.filetype or "text",
Expand Down
79 changes: 67 additions & 12 deletions lua/jumpy/patch.lua
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,26 @@ local function find_lines(haystack, needle)
return nil
end

local function parse_search_marker(line)
local rest = line:match("^<<<< SEARCH%s*(.*)$")
if rest == nil then
return nil
end
rest = rest:match("^%s*(.-)%s*$")
if rest == "" then
return nil
end
return rest
end

function M.parse(text)
local blocks = {}
local lines = split_lines(text)
local i = 1

while i <= #lines do
if lines[i]:match("^<<<< SEARCH%s*$") then
local path = parse_search_marker(lines[i])
if path ~= nil or lines[i]:match("^<<<< SEARCH%s*$") then
local search_lines = {}
local replace_lines = {}
i = i + 1
Expand All @@ -59,6 +72,7 @@ function M.parse(text)
end

table.insert(blocks, {
path = path,
search = search_lines,
replace = replace_lines,
})
Expand All @@ -69,13 +83,7 @@ function M.parse(text)
return blocks
end

function M.apply(original_lines, response_text)
local blocks = M.parse(response_text)

if #blocks == 0 then
return split_lines(response_text), 0
end

local function apply_blocks(original_lines, blocks)
local lines = {}
for _, l in ipairs(original_lines) do
table.insert(lines, l)
Expand All @@ -87,14 +95,14 @@ function M.apply(original_lines, response_text)
local pos = find_lines(lines, block.search)
if pos then
local new = {}
for i = 1, pos - 1 do
table.insert(new, lines[i])
for j = 1, pos - 1 do
table.insert(new, lines[j])
end
for _, l in ipairs(block.replace) do
table.insert(new, l)
end
for i = pos + #block.search, #lines do
table.insert(new, lines[i])
for j = pos + #block.search, #lines do
table.insert(new, lines[j])
end
lines = new
else
Expand All @@ -105,4 +113,51 @@ function M.apply(original_lines, response_text)
return lines, unmatched
end

function M.apply(original_lines, response_text)
local blocks = M.parse(response_text)

if #blocks == 0 then
return split_lines(response_text), 0
end

return apply_blocks(original_lines, blocks)
end

function M.apply_by_file(files_by_path, response_text, primary_path)
assert(primary_path, "apply_by_file requires a primary_path")

local blocks = M.parse(response_text)

if #blocks == 0 then
if files_by_path[primary_path] then
local lines, unmatched = M.apply(files_by_path[primary_path], response_text)
return { [primary_path] = { lines = lines, unmatched = unmatched } }, unmatched
end
return {}, 0
end

local grouped = {}
for _, block in ipairs(blocks) do
local key = block.path or primary_path
grouped[key] = grouped[key] or {}
table.insert(grouped[key], block)
end

local results = {}
local total_unmatched = 0

for path, file_blocks in pairs(grouped) do
local original = files_by_path[path]
if not original then
total_unmatched = total_unmatched + #file_blocks
else
local lines, unmatched = apply_blocks(original, file_blocks)
results[path] = { lines = lines, unmatched = unmatched }
total_unmatched = total_unmatched + unmatched
end
end

return results, total_unmatched
end

return M
104 changes: 101 additions & 3 deletions lua/jumpy/prompt.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,21 @@ local state = {

local mention_ns = vim.api.nvim_create_namespace("jumpy_mentions")

local function index_tagged_files(tagged_files)
local by_path = {}
for _, file in ipairs(tagged_files) do
by_path[file.path] = file
end
return by_path
end

local function buffer_for_tagged_file(tags, file)
if file.bufnr and vim.api.nvim_buf_is_valid(file.bufnr) then
return file.bufnr
end
return tags.open_buffer(file.abs_path)
end

local function highlight_mentions(buf)
if not vim.api.nvim_buf_is_valid(buf) then
return
Expand Down Expand Up @@ -181,12 +196,35 @@ function M._submit()
end

local source_buf = state.source_buf
local tags = require("jumpy.tags")

local source_lines = state.visual_selection and vim.split(state.visual_selection.text, "\n", { plain = true })
or vim.api.nvim_buf_get_lines(source_buf, 0, -1, false)

local source_name = vim.api.nvim_buf_get_name(source_buf)
local source_rel = source_name ~= "" and tags.rel_path(source_name, tags.project_root()) or "current"

local parsed = tags.parse(prompt_text, {
source = {
path = source_rel,
abs_path = source_name ~= "" and tags.normalize_abs(source_name) or nil,
lines = source_lines,
bufnr = source_buf,
},
})

local cleaned_prompt = parsed.cleaned_prompt
local tagged_files = parsed.tagged

if #parsed.errors > 0 then
for _, err in ipairs(parsed.errors) do
vim.notify("jumpy: " .. err, vim.log.levels.WARN)
end
end

local filetype = vim.bo[source_buf].filetype
local reprompt_idx = state.reprompt_hunk_idx
local is_multi_file = #tagged_files > 1

local llm = require("jumpy.llm")

Expand All @@ -213,7 +251,7 @@ function M._submit()
local context = {
original_lines = hunk.removed_lines,
proposed_lines = hunk.added_lines,
prompt = prompt_text,
prompt = cleaned_prompt,
symbols = symbols,
filetype = filetype,
}
Expand All @@ -227,10 +265,70 @@ function M._submit()
vim.notify("jumpy: hunk updated", vim.log.levels.INFO)
end)
end)
elseif is_multi_file then
local context = {
file_contents = table.concat(source_lines, "\n"),
tagged_files = tagged_files,
primary_path = source_rel,
prompt = cleaned_prompt,
symbols = symbols,
filetype = filetype,
}

llm.request(context, function(response_text)
vim.schedule(function()
local diff = require("jumpy.diff")
local render = require("jumpy.render")
local patch = require("jumpy.patch")

local files_by_path = {}
for _, file in ipairs(tagged_files) do
files_by_path[file.path] = file.lines
end

local results, total_unmatched = patch.apply_by_file(files_by_path, response_text, source_rel)

if total_unmatched > 0 then
vim.notify(string.format("jumpy: %d block(s) could not be matched", total_unmatched), vim.log.levels.WARN)
end

local tagged_by_path = index_tagged_files(tagged_files)
local total_hunks = 0

for path, result in pairs(results) do
local file = tagged_by_path[path]
if file then
local bufnr = buffer_for_tagged_file(tags, file)
if bufnr then
local hunks = diff.compute(file.lines, result.lines)
if #hunks > 0 then
render.show(bufnr, hunks, file.lines, result.lines)
total_hunks = total_hunks + #hunks
end
else
vim.notify("jumpy: could not open " .. path .. ", skipping", vim.log.levels.WARN)
end
end
end

if total_hunks == 0 then
vim.notify("jumpy: no changes proposed", vim.log.levels.INFO)
return
end

vim.notify(
string.format("jumpy: %d hunk(s) proposed across %d file(s)", total_hunks, vim.tbl_count(results)),
vim.log.levels.INFO
)

local nav = require("jumpy.navigate")
nav.next_hunk()
end)
end)
else
local context = {
file_contents = table.concat(source_lines, "\n"),
prompt = prompt_text,
prompt = cleaned_prompt,
symbols = symbols,
filetype = filetype,
}
Expand Down Expand Up @@ -278,7 +376,7 @@ function M._submit()
end

if prompt_text:find("@lsp") then
prompt_text = vim.trim(prompt_text:gsub("%f[%w@]@lsp%f[%W]", ""))
cleaned_prompt = vim.trim(cleaned_prompt:gsub("%f[%w@]@lsp%f[%W]", ""))

context_tools.get_workspace_symbols(tonumber(source_buf) or 0, send_request)
else
Expand Down
Loading
Loading