Skip to content

Commit 753ffc4

Browse files
committed
feat(restore-point): make restore points work
1 parent 399bcdb commit 753ffc4

File tree

7 files changed

+54
-27
lines changed

7 files changed

+54
-27
lines changed

lua/opencode/event_manager.lua

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,10 @@ local state = require('opencode.state')
8888

8989
--- @class ServerStoppedEvent
9090

91-
--- @alias EventName
91+
--- @class RestorePointCreatedEvent
92+
--- @field restore_point RestorePoint
93+
94+
--- @alias OpencodeEventName
9295
--- | "installation.updated"
9396
--- | "lsp.client.diagnostics"
9497
--- | "message.updated"
@@ -109,6 +112,7 @@ local state = require('opencode.state')
109112
--- | "custom.server_starting"
110113
--- | "custom.server_ready"
111114
--- | "custom.server_stopped"
115+
--- | "custom.restore_point.created"
112116

113117
--- @class EventManager
114118
--- @field events table<string, function[]> Event listener registry
@@ -150,7 +154,8 @@ end
150154
--- @overload fun(self: EventManager, event_name: "custom.server_starting", callback: fun(data: ServerStartingEvent): nil)
151155
--- @overload fun(self: EventManager, event_name: "custom.server_ready", callback: fun(data: ServerReadyEvent): nil)
152156
--- @overload fun(self: EventManager, event_name: "custom.server_stopped", callback: fun(data: ServerStoppedEvent): nil)
153-
--- @param event_name EventName The event name to listen for
157+
--- @overload fun(self: EventManager, event_name: "custom.restore_point.created", callback: fun(data: RestorePointCreatedEvent): nil)
158+
--- @param event_name OpencodeEventName The event name to listen for
154159
--- @param callback function Callback function to execute when event is triggered
155160
function EventManager:subscribe(event_name, callback)
156161
if not self.events[event_name] then
@@ -180,7 +185,8 @@ end
180185
--- @overload fun(self: EventManager, event_name: "custom.server_starting", callback: fun(data: ServerStartingEvent): nil)
181186
--- @overload fun(self: EventManager, event_name: "custom.server_ready", callback: fun(data: ServerReadyEvent): nil)
182187
--- @overload fun(self: EventManager, event_name: "custom.server_stopped", callback: fun(data: ServerStoppedEvent): nil)
183-
--- @param event_name EventName The event name
188+
--- @overload fun(self: EventManager, event_name: "custom.restore_point.created", callback: fun(data: RestorePointCreatedEvent): nil)
189+
--- @param event_name OpencodeEventName The event name
184190
--- @param callback function The callback function to remove
185191
function EventManager:unsubscribe(event_name, callback)
186192
local listeners = self.events[event_name]
@@ -197,14 +203,21 @@ function EventManager:unsubscribe(event_name, callback)
197203
end
198204

199205
--- Emit an event to all subscribers
200-
--- @param event_name EventName The event name
206+
--- @param event_name OpencodeEventName The event name
201207
--- @param data any Data to pass to event listeners
202208
function EventManager:emit(event_name, data)
203209
local listeners = self.events[event_name]
204210
if not listeners then
205211
return
206212
end
207213

214+
local event = { type = event_name, properties = data }
215+
216+
if require('opencode.config').debug.capture_streamed_events then
217+
table.insert(self.captured_events, vim.deepcopy(event))
218+
end
219+
220+
-- schedule events to allow for similar pieces of state to be updated
208221
for _, callback in ipairs(listeners) do
209222
pcall(callback, data)
210223
end
@@ -269,22 +282,11 @@ function EventManager:_subscribe_to_server_events(server)
269282
local api_client = state.api_client
270283

271284
local emitter = function(event)
272-
-- schedule events to allow for similar pieces of state to be updated
273285
vim.schedule(function()
274286
self:emit(event.type, event.properties)
275287
end)
276288
end
277289

278-
if require('opencode.config').debug.capture_streamed_events then
279-
local _emitter = emitter
280-
emitter = function(event)
281-
-- make a deepcopy to make sure we're saving a clean copy
282-
-- (we modify event in renderer)
283-
table.insert(self.captured_events, vim.deepcopy(event))
284-
_emitter(event)
285-
end
286-
end
287-
288290
self.server_subscription = api_client:subscribe_to_events(nil, emitter)
289291
end
290292

@@ -312,7 +314,7 @@ function EventManager:get_event_names()
312314
end
313315

314316
--- Get number of subscribers for an event
315-
--- @param event_name EventName The event name
317+
--- @param event_name OpencodeEventName The event name
316318
--- @return number Number of subscribers
317319
function EventManager:get_subscriber_count(event_name)
318320
local listeners = self.events[event_name]

lua/opencode/snapshot.lua

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ function M.save_restore_point(snapshot_id, from_snapshot_id, deleted_files)
9393
return nil
9494
end
9595

96-
state.append('restore_points', snapshot)
96+
state.event_manager:emit('custom.restore_point.created', { restore_point = snapshot })
9797
return snapshot
9898
end
9999

@@ -254,7 +254,7 @@ function M.restore_file(snapshot_id, file_path)
254254
end
255255

256256
---@param from_snapshot_id string
257-
---@return RestorePoint[]
257+
---@return RestorePoint[]|nil
258258
function M.get_restore_points_by_parent(from_snapshot_id)
259259
local restore_points = M.get_restore_points()
260260
restore_points = vim.tbl_filter(function(item)
@@ -263,6 +263,9 @@ function M.get_restore_points_by_parent(from_snapshot_id)
263263
table.sort(restore_points, function(a, b)
264264
return a.created_at > b.created_at
265265
end)
266+
if #restore_points == 0 then
267+
return nil
268+
end
266269
return restore_points
267270
end
268271

lua/opencode/state.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ local config = require('opencode.config')
2020
---@field last_output number
2121
---@field last_sent_context any
2222
---@field active_session Session|nil
23-
---@field restore_points table<string, any>
23+
---@field restore_points RestorePoint[]
2424
---@field current_model string|nil
2525
---@field current_model_info table|nil
2626
---@field messages OpencodeMessage[]|nil

lua/opencode/ui/formatter.lua

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ end
156156
---@param output Output Output object to write to
157157
---@param part MessagePart
158158
function M._format_patch(output, part)
159-
local restore_points = snapshot.get_restore_points_by_parent(part.hash)
159+
local restore_points = snapshot.get_restore_points_by_parent(part.hash) or {}
160160
M._format_action(output, icons.get('snapshot') .. ' Created Snapshot', vim.trim(part.hash:sub(1, 8)))
161161
local snapshot_header_line = output:get_line_count()
162162

@@ -200,15 +200,15 @@ function M._format_patch(output, part)
200200
output:add_action({
201201
text = 'Restore [A]ll',
202202
type = 'diff_restore_snapshot_all',
203-
args = { part.hash },
203+
args = { restore_point.id },
204204
key = 'A',
205205
display_line = restore_line,
206206
range = { from = restore_line, to = restore_line },
207207
})
208208
output:add_action({
209209
text = '[R]estore file',
210210
type = 'diff_restore_snapshot_file',
211-
args = { part.hash },
211+
args = { restore_point.id },
212212
key = 'R',
213213
display_line = restore_line,
214214
range = { from = restore_line, to = restore_line },

lua/opencode/ui/output_window.lua

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,6 @@ function M.setup(windows)
5151

5252
M.update_dimensions(windows)
5353
M.setup_keymaps(windows)
54-
state.subscribe('restore_points', function(_, new_val, old_val)
55-
-- FIXME: restore points
56-
-- local outout_renderer = require('opencode.ui.output_renderer')
57-
-- outout_renderer.render(state.windows, true)
58-
end)
5954
end
6055

6156
function M.update_dimensions(windows)

lua/opencode/ui/render_state.lua

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,20 @@ function RenderState:get_part_by_call_id(call_id, message_id)
7373
return nil
7474
end
7575

76+
---Get part ID by snapshot_id and message ID
77+
---@param snapshot_id string Call ID
78+
---@return MessagePart? part Part if found
79+
function RenderState:get_part_by_snapshot_id(snapshot_id)
80+
for _, rendered_message in pairs(self._messages) do
81+
for _, part in ipairs(rendered_message.message.parts) do
82+
if part.type == 'patch' and part.hash == snapshot_id then
83+
return part
84+
end
85+
end
86+
end
87+
return nil
88+
end
89+
7690
---Ensure line index is up to date
7791
function RenderState:_ensure_line_index()
7892
if not self._line_index_valid then

lua/opencode/ui/renderer.lua

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ function M._setup_event_subscriptions(subscribe)
7777
state.event_manager[method](state.event_manager, 'permission.updated', M.on_permission_updated)
7878
state.event_manager[method](state.event_manager, 'permission.replied', M.on_permission_replied)
7979
state.event_manager[method](state.event_manager, 'file.edited', M.on_file_edited)
80+
state.event_manager[method](state.event_manager, 'custom.restore_point.created', M.on_restore_points)
8081
end
8182

8283
---Unsubscribe from local state and server subscriptions
@@ -635,6 +636,18 @@ function M.on_file_edited(properties)
635636
vim.cmd('checktime')
636637
end
637638

639+
---@param properties RestorePointCreatedEvent
640+
function M.on_restore_points(properties)
641+
state.append('restore_points', properties.restore_point)
642+
if not properties or not properties.restore_point or not properties.restore_point.from_snapshot_id then
643+
return
644+
end
645+
local part = M._render_state:get_part_by_snapshot_id(properties.restore_point.from_snapshot_id)
646+
if part then
647+
M.on_part_updated({ part = part })
648+
end
649+
end
650+
638651
---Find part ID by call ID and message ID
639652
---Useful for finding a part for a permission
640653
---@param call_id string Call ID to search for

0 commit comments

Comments
 (0)