Skip to content

Commit 725e646

Browse files
Merge remote-tracking branch 'upstream/main' into more_hooks
2 parents 7542cfd + 526e244 commit 725e646

15 files changed

+175
-97
lines changed

README.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,12 @@ require('opencode').setup({
252252
enabled = false, -- Enable debug messages in the output window
253253
},
254254
prompt_guard = nil, -- Optional function that returns boolean to control when prompts can be sent (see Prompt Guard section)
255+
256+
-- User Hooks for custom behavior at certain events
257+
hooks = {
258+
on_file_edited = nil, -- Called after a file is edited by opencode.
259+
on_session_loaded = nil, -- Called after a session is loaded.
260+
},
255261
})
256262
```
257263

@@ -585,6 +591,28 @@ The plugin defines several highlight groups that can be customized to match your
585591

586592
The `prompt_guard` configuration option allows you to control when prompts can be sent to Opencode. This is useful for preventing accidental or unauthorized AI interactions in certain contexts.
587593

594+
## 🪝Custom user hooks
595+
596+
You can define custom functions to be called at specific events in Opencode:
597+
598+
- `on_file_edited`: Called after a file is edited by Opencode.
599+
- `on_session_loaded`: Called after a session is loaded.
600+
601+
```lua
602+
require('opencode').setup({
603+
hooks = {
604+
on_file_edited = function(file_path, edit_type)
605+
-- Custom logic after a file is edited
606+
print("File edited: " .. file_path .. " Type: " .. edit_type)
607+
end,
608+
on_session_loaded = function(session_name)
609+
-- Custom logic after a session is loaded
610+
print("Session loaded: " .. session_name)
611+
end,
612+
},
613+
})
614+
```
615+
588616
### Configuration
589617

590618
Set `prompt_guard` to a function that returns a boolean:

lua/opencode/ui/contextual_actions.lua

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
local state = require('opencode.state')
2-
local keymap = require('opencode.keymap')
32
local output_window = require('opencode.ui.output_window')
43

54
local M = {}
@@ -26,6 +25,13 @@ function M.setup_contextual_actions(windows)
2625
callback = function()
2726
vim.schedule(function()
2827
local line_num = vim.api.nvim_win_get_cursor(0)[1]
28+
29+
if not line_num or line_num <= 0 or not state.windows or not state.windows.output_buf then
30+
return
31+
end
32+
33+
line_num = line_num - 1 -- need api-indexing (e.g. 0 based line #), win_get_cursor returns 1 based line #
34+
2935
local actions = require('opencode.ui.renderer').get_actions_for_line(line_num)
3036
last_line_num = line_num
3137

@@ -34,7 +40,7 @@ function M.setup_contextual_actions(windows)
3440

3541
if actions and #actions > 0 then
3642
dirty = true
37-
M.show_contextual_actions_menu(state.windows.output_buf, line_num, actions, ns_id)
43+
M.show_contextual_actions_menu(state.windows.output_buf, actions, ns_id)
3844
end
3945
end)
4046
end,
@@ -48,6 +54,7 @@ function M.setup_contextual_actions(windows)
4854
if not output_window.mounted() then
4955
return
5056
end
57+
---@cast state.windows { output_buf: integer}
5158
local line_num = vim.api.nvim_win_get_cursor(0)[1]
5259
if last_line_num == line_num and not dirty then
5360
return
@@ -61,15 +68,17 @@ function M.setup_contextual_actions(windows)
6168
group = augroup,
6269
buffer = windows.output_buf,
6370
callback = function()
64-
vim.api.nvim_buf_clear_namespace(state.windows.output_buf, ns_id, 0, -1)
65-
clear_keymaps(state.windows.output_buf)
71+
if state.windows and state.windows.output_buf then
72+
vim.api.nvim_buf_clear_namespace(state.windows.output_buf, ns_id, 0, -1)
73+
clear_keymaps(state.windows.output_buf)
74+
end
6675
last_line_num = nil
6776
dirty = false
6877
end,
6978
})
7079
end
7180

72-
function M.show_contextual_actions_menu(buf, line_num, actions, ns_id)
81+
function M.show_contextual_actions_menu(buf, actions, ns_id)
7382
clear_keymaps(buf)
7483

7584
for _, action in ipairs(actions) do
@@ -80,7 +89,7 @@ function M.show_contextual_actions_menu(buf, line_num, actions, ns_id)
8089
hl_mode = 'combine',
8190
}
8291

83-
vim.api.nvim_buf_set_extmark(buf, ns_id, action.display_line - 1, 0, mark)
92+
vim.api.nvim_buf_set_extmark(buf, ns_id, action.display_line, 0, mark --[[@as vim.api.keyset.set_extmark]])
8493
end
8594
-- Setup key mappings for actions
8695
for _, action in ipairs(actions) do

lua/opencode/ui/formatter.lua

Lines changed: 28 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -154,66 +154,46 @@ function M._format_revert_message(session_data, start_idx)
154154
return output
155155
end
156156

157+
local function add_action(output, text, action_type, args, key, line)
158+
-- actions use api-indexing (e.g. 0 indexed)
159+
line = (line or output:get_line_count()) - 1
160+
output:add_action({
161+
text = text,
162+
type = action_type,
163+
args = args,
164+
key = key,
165+
display_line = line,
166+
range = { from = line, to = line },
167+
})
168+
end
169+
157170
---@param output Output Output object to write to
158171
---@param part OpencodeMessagePart
159172
function M._format_patch(output, part)
173+
if not part.hash then
174+
return
175+
end
176+
160177
local restore_points = snapshot.get_restore_points_by_parent(part.hash) or {}
161178
M._format_action(output, icons.get('snapshot') .. ' Created Snapshot', vim.trim(part.hash:sub(1, 8)))
162-
local snapshot_header_line = output:get_line_count()
163179

164180
-- Anchor all snapshot-level actions to the snapshot header line
165-
output:add_action({
166-
text = '[R]evert file',
167-
type = 'diff_revert_selected_file',
168-
args = { part.hash },
169-
key = 'R',
170-
display_line = snapshot_header_line,
171-
range = { from = snapshot_header_line, to = snapshot_header_line },
172-
})
173-
output:add_action({
174-
text = 'Revert [A]ll',
175-
type = 'diff_revert_all',
176-
args = { part.hash },
177-
key = 'A',
178-
display_line = snapshot_header_line,
179-
range = { from = snapshot_header_line, to = snapshot_header_line },
180-
})
181-
output:add_action({
182-
text = '[D]iff',
183-
type = 'diff_open',
184-
args = { part.hash },
185-
key = 'D',
186-
display_line = snapshot_header_line,
187-
range = { from = snapshot_header_line, to = snapshot_header_line },
188-
})
181+
add_action(output, '[R]evert file', 'diff_revert_selected_file', { part.hash }, 'R')
182+
add_action(output, 'Revert [A]ll', 'diff_revert_all', { part.hash }, 'A')
183+
add_action(output, '[D]iff', 'diff_open', { part.hash }, 'D')
189184

190185
if #restore_points > 0 then
191186
for _, restore_point in ipairs(restore_points) do
192187
output:add_line(
193188
string.format(
194-
' %s Restore point `%s` - %s',
189+
' %s Restore point `%s` - %s ',
195190
icons.get('restore_point'),
196-
restore_point.id:sub(1, 8),
191+
vim.trim(restore_point.id:sub(1, 8)),
197192
util.format_time(restore_point.created_at)
198193
)
199194
)
200-
local restore_line = output:get_line_count()
201-
output:add_action({
202-
text = 'Restore [A]ll',
203-
type = 'diff_restore_snapshot_all',
204-
args = { restore_point.id },
205-
key = 'A',
206-
display_line = restore_line,
207-
range = { from = restore_line, to = restore_line },
208-
})
209-
output:add_action({
210-
text = '[R]estore file',
211-
type = 'diff_restore_snapshot_file',
212-
args = { restore_point.id },
213-
key = 'R',
214-
display_line = restore_line,
215-
range = { from = restore_line, to = restore_line },
216-
})
195+
add_action(output, 'Restore [A]ll', 'diff_restore_snapshot_all', { restore_point.id }, 'A')
196+
add_action(output, '[R]estore file', 'diff_restore_snapshot_file', { restore_point.id }, 'R')
217197
end
218198
end
219199
end
@@ -282,10 +262,10 @@ function M.format_message_header(message)
282262
and (not message.parts or #message.parts == 0)
283263
then
284264
local error = message.info.error
285-
local error_messgage = error.data and error.data.message or vim.inspect(error)
265+
local error_message = error.data and error.data.message or vim.inspect(error)
286266

287267
output:add_line('')
288-
M._format_callout(output, 'ERROR', error_messgage)
268+
M._format_callout(output, 'ERROR', error_message)
289269
end
290270

291271
output:add_line('')
@@ -797,8 +777,8 @@ function M.format_part(part, message, is_last_part)
797777

798778
if is_last_part and role == 'assistant' and message.info.error and message.info.error ~= '' then
799779
local error = message.info.error
800-
local error_messgage = error.data and error.data.message or vim.inspect(error)
801-
M._format_callout(output, 'ERROR', error_messgage)
780+
local error_message = error.data and error.data.message or vim.inspect(error)
781+
M._format_callout(output, 'ERROR', error_message)
802782
output:add_empty_line()
803783
end
804784

lua/opencode/ui/render_state.lua

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ end
7777
---@param snapshot_id string Call ID
7878
---@return OpencodeMessagePart? part Part if found
7979
function RenderState:get_part_by_snapshot_id(snapshot_id)
80-
for _, rendered_message in pairs(self._messages) do
81-
for _, part in ipairs(rendered_message.message.parts) do
80+
for _, rendered_message in pairs(self._messages or {}) do
81+
for _, part in ipairs(rendered_message.message.parts or {}) do
8282
if part.type == 'patch' and part.hash == snapshot_id then
8383
return part
8484
end
@@ -119,7 +119,7 @@ function RenderState:get_message_at_line(line)
119119
end
120120

121121
---Get actions at specific line
122-
---@param line integer Line number (1-indexed)
122+
---@param line integer Line number (0-indexed)
123123
---@return table[] List of actions at that line
124124
function RenderState:get_actions_at_line(line)
125125
self:_ensure_line_index()

lua/opencode/ui/renderer.lua

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -275,10 +275,6 @@ function M._write_formatted_data(formatted_data, part_id, start_line)
275275
return nil
276276
end
277277

278-
if part_id and formatted_data.actions then
279-
M._render_state:add_actions(part_id, formatted_data.actions, target_line)
280-
end
281-
282278
if is_insertion then
283279
output_window.set_lines(new_lines, target_line, target_line)
284280
else
@@ -287,6 +283,15 @@ function M._write_formatted_data(formatted_data, part_id, start_line)
287283
target_line = target_line - 1
288284
output_window.set_lines(extra_newline, target_line)
289285
end
286+
287+
-- update actions and extmarks after the insertion because that may
288+
-- adjust target_line (e.g. when we we're replacing the double newline at
289+
-- the end)
290+
291+
if part_id and formatted_data.actions then
292+
M._render_state:add_actions(part_id, formatted_data.actions, target_line)
293+
end
294+
290295
output_window.set_extmarks(extmarks, target_line)
291296

292297
return {
@@ -413,7 +418,7 @@ function M._replace_part_in_buffer(part_id, formatted_data)
413418
output_window.set_extmarks(formatted_data.extmarks, cached.line_start)
414419

415420
if formatted_data.actions then
416-
M._render_state:add_actions(part_id, formatted_data.actions, cached.line_start)
421+
M._render_state:add_actions(part_id, formatted_data.actions, cached.line_start + 1)
417422
end
418423

419424
M._render_state:update_part_lines(part_id, cached.line_start, new_line_end)
@@ -959,7 +964,7 @@ function M.on_session_changed(_, new, _)
959964
end
960965

961966
---Get all actions available at a specific line
962-
---@param line integer 1-indexed line number
967+
---@param line integer 0-indexed line number
963968
---@return table[] List of actions available at that line
964969
function M.get_actions_for_line(line)
965970
return M._render_state:get_actions_at_line(line)

tests/data/diagnostics.expected.json

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)