Skip to content

Commit f8d002a

Browse files
committed
fix: install instant-model by id rather than modelID
1 parent e008336 commit f8d002a

File tree

6 files changed

+53
-58
lines changed

6 files changed

+53
-58
lines changed

pkg/apis/llm/llm.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ type ModelInfo struct {
101101
DisplayName string `json:"display_name"`
102102
// Tag of the instant (quick-install) model, e.g. 7b
103103
Tag string `json:"tag"`
104+
// LLM type of the instant (quick-install) model
105+
LlmType string `json:"llm_type"`
104106
}
105107

106108
type LLMPerformQuickModelsInput struct {

pkg/llm/models/instantmodel.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -576,15 +576,15 @@ func (model *SInstantModel) PerformEnable(
576576
return nil, errors.Wrapf(errors.ErrInvalidStatus, "cannot enable model of status %s", model.Status)
577577
}
578578
// check duplicate
579-
{
580-
existing, err := GetInstantModelManager().findInstantModel(model.ModelId, model.ModelTag, true)
581-
if err != nil {
582-
return nil, errors.Wrap(err, "findInstantModel")
583-
}
584-
if existing != nil && existing.Id != model.Id {
585-
return nil, errors.Wrapf(errors.ErrDuplicateId, "model of modelId %s tag %s has been enabled", model.ModelId, model.ModelTag)
586-
}
587-
}
579+
// {
580+
// existing, err := GetInstantModelManager().findInstantModel(model.ModelId, model.ModelTag, true)
581+
// if err != nil {
582+
// return nil, errors.Wrap(err, "findInstantModel")
583+
// }
584+
// if existing != nil && existing.Id != model.Id {
585+
// return nil, errors.Wrapf(errors.ErrDuplicateId, "model of modelId %s tag %s has been enabled", model.ModelId, model.ModelTag)
586+
// }
587+
// }
588588
_, err := db.Update(model, func() error {
589589
model.SEnabledResourceBase.SetEnabled(true)
590590
return nil

pkg/llm/models/llm_instant_model_sync.go

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -250,12 +250,13 @@ func (llm *SLLM) PerformQuickModels(ctx context.Context, userCred mcclient.Token
250250
errs = append(errs, errors.Wrap(err, "FetchByIdOrName"))
251251
}
252252
} else {
253-
instApp := instModelObj.(*SInstantModel)
254-
input.Models[i].Id = instApp.Id
255-
input.Models[i].ModelId = instApp.ModelId
256-
input.Models[i].Tag = instApp.ModelTag
253+
instMdl := instModelObj.(*SInstantModel)
254+
input.Models[i].Id = instMdl.Id
255+
input.Models[i].ModelId = instMdl.ModelId
256+
input.Models[i].Tag = instMdl.ModelTag
257+
input.Models[i].LlmType = instMdl.LlmType
257258
if input.Method == apis.QuickModelInstall {
258-
toInstallSizeGb += float64(instApp.GetActualSizeMb()) * 1024 * 1024 / 1000 / 1000 / 1000
259+
toInstallSizeGb += float64(instMdl.GetActualSizeMb()) * 1024 * 1024 / 1000 / 1000 / 1000
259260
}
260261
}
261262
} else {
@@ -269,8 +270,12 @@ func (llm *SLLM) PerformQuickModels(ctx context.Context, userCred mcclient.Token
269270
input.Models[i].Id = mdl.Id
270271
input.Models[i].Tag = mdl.ModelTag
271272
input.Models[i].ModelId = mdl.ModelId
273+
input.Models[i].LlmType = mdl.LlmType
272274
}
273275
}
276+
if !apis.IsLLMContainerType(input.Models[i].LlmType) || apis.LLMContainerType(input.Models[i].LlmType) != llm.GetLLMContainerDriver().GetType() {
277+
errs = append(errs, errors.Wrapf(httperrors.ErrInvalidStatus, "model %s is not of type %s", input.Models[i].ModelId, llm.GetLLMContainerDriver().GetType()))
278+
}
274279
}
275280
if len(errs) > 0 {
276281
return nil, errors.NewAggregate(errs)
@@ -610,7 +615,7 @@ type mdlFullNameInfo struct {
610615
IsMounted bool
611616
}
612617

613-
func (llm *SLLM) UpdateMountedModelFullNames(ctx context.Context, mdlinfos []string, isReset bool, imageId string, skuId string) error {
618+
func (llm *SLLM) UpdateMountedModelFullNames(ctx context.Context, userCred mcclient.TokenCredential, mdlinfos []string, isReset bool, imageId string, skuId string) error {
614619
mdlFullNameInfos := make(map[string]*mdlFullNameInfo)
615620
for i := range mdlinfos {
616621
parts := strings.Split(mdlinfos[i], "@")
@@ -636,15 +641,19 @@ func (llm *SLLM) UpdateMountedModelFullNames(ctx context.Context, mdlinfos []str
636641
}
637642
}
638643
for i := range sku.MountedModels {
639-
parts := strings.Split(sku.MountedModels[i], "@")
640-
if !isReset && slices.Contains(deletedModelIds, parts[0]) {
641-
// if not reset, and the package is deleted, skip it
644+
instMdl, err := GetInstantModelManager().FetchByIdOrName(ctx, userCred, sku.MountedModels[i])
645+
if err != nil {
646+
return errors.Wrap(err, "FetchByIdOrName")
647+
}
648+
instantModle := instMdl.(*SInstantModel)
649+
if !isReset && slices.Contains(deletedModelIds, instantModle.ModelId) {
650+
// if not reset, and the model is deleted, skip it
642651
continue
643652
}
644-
if _, ok := mdlFullNameInfos[parts[0]]; !ok {
645-
mdlFullNameInfos[parts[0]] = &mdlFullNameInfo{
646-
ModelId: parts[0],
647-
ModelFullName: parts[1],
653+
if _, ok := mdlFullNameInfos[instantModle.ModelId]; !ok {
654+
mdlFullNameInfos[instantModle.ModelId] = &mdlFullNameInfo{
655+
ModelId: instantModle.ModelId,
656+
ModelFullName: instantModle.ModelName + ":" + instantModle.ModelTag,
648657
IsMounted: false,
649658
}
650659
}

pkg/llm/models/llm_pod.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func GetLLMPodCreateInput(
2626

2727
// generate post overlay info
2828
{
29-
err = llm.UpdateMountedModelFullNames(ctx, nil, true, input.LLMImageId, input.LLMSkuId)
29+
err = llm.UpdateMountedModelFullNames(ctx, userCred, nil, true, input.LLMImageId, input.LLMSkuId)
3030
if err != nil {
3131
return nil, errors.Wrap(err, "UpdateMountedModelFullNames")
3232
}

pkg/llm/models/llm_sku.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,20 @@ func (sku *SLLMSku) ValidateUpdateData(ctx context.Context, userCred mcclient.To
200200
return input, errors.Wrap(err, "validate LLMSkuBaseUpdateInput")
201201
}
202202

203+
if input.MountedModels != nil {
204+
for i, mdl := range input.MountedModels {
205+
instMdl, err := GetInstantModelManager().FetchByIdOrName(ctx, userCred, mdl)
206+
if err != nil {
207+
return input, errors.Wrapf(err, "validate mounted model %s", mdl)
208+
}
209+
instantModle := instMdl.(*SInstantModel)
210+
if instantModle.LlmType != sku.LLMType {
211+
return input, errors.Wrapf(httperrors.ErrInvalidStatus, "mounted model %s is not of type %s", mdl, sku.LLMType)
212+
}
213+
input.MountedModels[i] = instantModle.GetId()
214+
}
215+
}
216+
203217
if input.LLMImageId != "" {
204218
imgObj, err := validators.ValidateModel(ctx, userCred, GetLLMImageManager(), &input.LLMImageId)
205219
if err != nil {

pkg/mcclient/options/llm/llm.go

Lines changed: 5 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"strings"
55

66
"yunion.io/x/jsonutils"
7-
"yunion.io/x/pkg/util/regutils"
87

98
api "yunion.io/x/onecloud/pkg/apis/llm"
109
"yunion.io/x/onecloud/pkg/mcclient/options"
@@ -135,47 +134,18 @@ func (opts *LLMSaveInstantModelOptions) Params() (jsonutils.JSONObject, error) {
135134
type LLMQuickModelsOptions struct {
136135
LLMIdOptions
137136

138-
MODEL []string `help:"model id and optional display name in the format of modelId[@modelName:modelTag], e.g. 6f48b936a09f or 6f48b936a09f@qwen2:0.5b"`
137+
MODEL []string `help:"model id of instant model, e.g. qwen3:0.6b-251202 or 7f72b5a1-4049-43db-8e91-8dee736ae1ac"`
139138

140139
Method string `help:"install or uninstall" choices:"install|uninstall"`
141140
}
142141

143142
func (opts *LLMQuickModelsOptions) Params() (jsonutils.JSONObject, error) {
144143
params := api.LLMPerformQuickModelsInput{}
145-
for _, mdlFul := range opts.MODEL {
146-
var mdl api.ModelInfo
147-
148-
var idPart string
149-
var nameAndTagPart string
150-
151-
if idx := strings.Index(mdlFul, "@"); idx >= 0 {
152-
idPart = mdlFul[:idx]
153-
nameAndTagPart = mdlFul[idx+1:]
154-
155-
if idxTag := strings.LastIndex(nameAndTagPart, ":"); idxTag >= 0 {
156-
mdl.DisplayName = nameAndTagPart[:idxTag]
157-
mdl.Tag = nameAndTagPart[idxTag+1:]
158-
} else {
159-
mdl.DisplayName = nameAndTagPart
160-
}
161-
} else {
162-
idPart = mdlFul
163-
164-
if idxTag := strings.LastIndex(idPart, ":"); idxTag >= 0 {
165-
mdl.Tag = idPart[idxTag+1:]
166-
idPart = idPart[:idxTag]
167-
}
168-
}
169-
170-
if regutils.MatchUUID(idPart) {
171-
mdl.Id = idPart
172-
} else {
173-
mdl.ModelId = idPart
174-
}
175-
176-
params.Models = append(params.Models, mdl)
144+
for _, mdl := range opts.MODEL {
145+
params.Models = append(params.Models, api.ModelInfo{
146+
Id: mdl,
147+
})
177148
}
178-
179149
if len(opts.Method) > 0 {
180150
params.Method = api.TQuickModelMethod(opts.Method)
181151
}

0 commit comments

Comments
 (0)