Skip to content

Commit cee01a5

Browse files
committed
feat(command): add model command to get/set the model
1 parent 2d75114 commit cee01a5

File tree

5 files changed

+278
-2
lines changed

5 files changed

+278
-2
lines changed

lua/copilot/api/init.lua

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,4 +154,14 @@ end
154154
---@alias copilot_window_show_document { uri: string, external?: boolean, takeFocus?: boolean, selection?: boolean }
155155
---@alias copilot_window_show_document_result { success: boolean }
156156

157+
---@alias copilot_model { id: string, modelName: string, scopes: string[], preview?: boolean, default?: boolean }
158+
---@alias copilot_models_data copilot_model[]
159+
160+
---@return any|nil err
161+
---@return copilot_models_data data
162+
---@return table ctx
163+
function M.get_models(client, callback)
164+
return M.request(client, "copilot/models", {}, callback)
165+
end
166+
157167
return M

lua/copilot/client/config.lua

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,11 @@ function M.prepare_client_config(overrides, client)
124124
end
125125

126126
require("copilot.nes").setup(lsp_client)
127+
128+
-- Validate configured model on startup
129+
if config.copilot_model and config.copilot_model ~= "" then
130+
require("copilot.model").validate_current()
131+
end
127132
end)
128133
end,
129134
on_exit = function(code, _, client_id)

lua/copilot/client/utils.lua

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ function M.get_workspace_configurations()
2929
filetypes = vim.tbl_deep_extend("keep", filetypes, client_ft.internal_filetypes)
3030
end
3131

32-
local copilot_model = config and config.copilot_model ~= "" and config.copilot_model or ""
32+
-- Use model module to get the current model (supports runtime override)
33+
local model = require("copilot.model")
34+
local copilot_model = model.get_current_model()
3335

3436
---@type string[]
3537
local disabled_filetypes = vim.tbl_filter(function(ft)

lua/copilot/model.lua

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
local c = require("copilot.client")
2+
local api = require("copilot.api")
3+
local config = require("copilot.config")
4+
local logger = require("copilot.logger")
5+
6+
local M = {}
7+
8+
--- Runtime override of the model (not persisted to user config)
9+
--- When set, this takes precedence over config.copilot_model
10+
---@type string|nil
11+
M.selected_model = nil
12+
13+
--- Get the currently active model ID
14+
---@return string
15+
function M.get_current_model()
16+
return M.selected_model or config.copilot_model or ""
17+
end
18+
19+
--- Filter models that support completions
20+
---@param models copilot_model[]
21+
---@return copilot_model[]
22+
local function get_completion_models(models)
23+
return vim.tbl_filter(function(m)
24+
return vim.tbl_contains(m.scopes or {}, "completion")
25+
end, models)
26+
end
27+
28+
--- Format a model for display
29+
---@param model copilot_model
30+
---@return string
31+
local function format_model(model, show_id)
32+
local parts = { model.modelName }
33+
if show_id then
34+
table.insert(parts, "[" .. model.id .. "]")
35+
end
36+
local annotations = {}
37+
38+
if model.default then
39+
table.insert(annotations, "default")
40+
end
41+
if model.preview then
42+
table.insert(annotations, "preview")
43+
end
44+
45+
if #annotations > 0 then
46+
table.insert(parts, "(" .. table.concat(annotations, ", ") .. ")")
47+
end
48+
49+
return table.concat(parts, " ")
50+
end
51+
52+
--- Apply the selected model by notifying the LSP server
53+
---@param model_id string
54+
local function apply_model(model_id)
55+
M.selected_model = model_id
56+
57+
local client = c.get()
58+
if client then
59+
local utils = require("copilot.client.utils")
60+
local configurations = utils.get_workspace_configurations()
61+
api.notify_change_configuration(client, configurations)
62+
logger.debug("Model changed to: " .. model_id)
63+
end
64+
end
65+
66+
--- Interactive model selection using vim.ui.select
67+
---@param opts? { force?: boolean, args?: string }
68+
function M.select(opts)
69+
opts = opts or {}
70+
71+
local client = c.get()
72+
if not client then
73+
logger.notify("Copilot client not running")
74+
return
75+
end
76+
77+
coroutine.wrap(function()
78+
local err, models = api.get_models(client)
79+
if err then
80+
logger.notify("Failed to get models: " .. vim.inspect(err))
81+
return
82+
end
83+
84+
if not models or #models == 0 then
85+
logger.notify("No models available")
86+
return
87+
end
88+
89+
local completion_models = get_completion_models(models)
90+
if #completion_models == 0 then
91+
logger.notify("No completion models available")
92+
return
93+
end
94+
95+
local current_model = M.get_current_model()
96+
if #completion_models == 1 then
97+
local model = completion_models[1]
98+
local model_name = format_model(model)
99+
logger.notify("Only one completion model available: " .. model_name)
100+
if model.id ~= current_model then
101+
apply_model(model.id)
102+
logger.notify("Copilot model set to: " .. model_name)
103+
else
104+
logger.notify("Copilot model is already set to: " .. model_name)
105+
end
106+
return
107+
end
108+
109+
-- Sort models: default first, then by name
110+
table.sort(completion_models, function(a, b)
111+
if a.default and not b.default then
112+
return true
113+
end
114+
if b.default and not a.default then
115+
return false
116+
end
117+
return a.modelName < b.modelName
118+
end)
119+
120+
vim.ui.select(completion_models, {
121+
prompt = "Select Copilot completion model:",
122+
format_item = function(model)
123+
local display = format_model(model)
124+
if model.id == current_model then
125+
display = display .. " [current]"
126+
end
127+
return display
128+
end,
129+
}, function(selected)
130+
if not selected then
131+
return
132+
end
133+
134+
apply_model(selected.id)
135+
logger.notify("Copilot model set to: " .. format_model(selected))
136+
end)
137+
end)()
138+
end
139+
140+
--- List available completion models
141+
---@param opts? { force?: boolean, args?: string }
142+
function M.list(opts)
143+
opts = opts or {}
144+
145+
local client = c.get()
146+
if not client then
147+
logger.notify("Copilot client not running")
148+
return
149+
end
150+
151+
coroutine.wrap(function()
152+
local err, models = api.get_models(client)
153+
if err then
154+
logger.notify("Failed to get models: " .. vim.inspect(err))
155+
return
156+
end
157+
158+
if not models or #models == 0 then
159+
logger.notify("No models available")
160+
return
161+
end
162+
163+
local completion_models = get_completion_models(models)
164+
if #completion_models == 0 then
165+
logger.notify("No completion models available")
166+
return
167+
end
168+
169+
local current_model = M.get_current_model()
170+
local lines = { "Available completion models:" }
171+
172+
for _, model in ipairs(completion_models) do
173+
local line = " " .. format_model(model, true)
174+
if model.id == current_model then
175+
line = line .. " <- current"
176+
end
177+
table.insert(lines, line)
178+
end
179+
180+
logger.notify(table.concat(lines, "\n"))
181+
end)()
182+
end
183+
184+
--- Show the current model
185+
---@param opts? { force?: boolean, args?: string }
186+
function M.get(opts)
187+
opts = opts or {}
188+
189+
local current = M.get_current_model()
190+
if current == "" then
191+
logger.notify("No model configured (using server default)")
192+
else
193+
logger.notify("Current model: " .. current)
194+
end
195+
end
196+
197+
--- Set the model programmatically
198+
---@param opts { model?: string, force?: boolean, args?: string }
199+
function M.set(opts)
200+
opts = opts or {}
201+
202+
local model_id = opts.model or opts.args
203+
if not model_id or model_id == "" then
204+
logger.notify("Usage: :Copilot model set <model-id>")
205+
return
206+
end
207+
208+
apply_model(model_id)
209+
logger.notify("Copilot model set to: " .. model_id)
210+
end
211+
212+
--- Validate the currently configured model against available models
213+
--- Called on startup to warn if the configured model is invalid
214+
function M.validate_current()
215+
local configured_model = config.copilot_model
216+
if not configured_model or configured_model == "" then
217+
return -- No model configured, nothing to validate
218+
end
219+
220+
local client = c.get()
221+
if not client then
222+
return
223+
end
224+
225+
coroutine.wrap(function()
226+
local err, models = api.get_models(client)
227+
if err then
228+
logger.debug("Failed to validate model: " .. vim.inspect(err))
229+
return
230+
end
231+
232+
if not models or #models == 0 then
233+
return
234+
end
235+
236+
local completion_models = get_completion_models(models)
237+
local valid_ids = vim.tbl_map(function(m)
238+
return m.id
239+
end, completion_models)
240+
241+
if not vim.tbl_contains(valid_ids, configured_model) then
242+
local valid_list = table.concat(valid_ids, ", ")
243+
logger.warn(
244+
string.format(
245+
"Configured copilot_model '%s' is not a valid completion model. Available: %s",
246+
configured_model,
247+
valid_list
248+
)
249+
)
250+
else
251+
logger.debug("Configured model '" .. configured_model .. "' is valid")
252+
end
253+
end)()
254+
end
255+
256+
return M

plugin/copilot.lua

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
local completion_store = {
2-
[""] = { "auth", "attach", "detach", "disable", "enable", "panel", "status", "suggestion", "toggle", "version" },
2+
[""] = { "auth", "attach", "detach", "disable", "enable", "model", "panel", "status", "suggestion", "toggle", "version" },
33
auth = { "signin", "signout", "info" },
4+
model = { "select", "list", "get", "set" },
45
panel = { "accept", "jump_next", "jump_prev", "open", "refresh", "toggle", "close", "is_open" },
56
suggestion = {
67
"accept",
@@ -34,6 +35,8 @@ vim.api.nvim_create_user_command("Copilot", function(opts)
3435
if not action_name then
3536
if mod_name == "auth" then
3637
action_name = "signin"
38+
elseif mod_name == "model" then
39+
action_name = "get"
3740
elseif mod_name == "panel" then
3841
action_name = "open"
3942
elseif mod_name == "suggestion" then

0 commit comments

Comments
 (0)