Skip to content

Commit ea0a849

Browse files
authored
feat(inference): add waiter model (scaleway#2534)
1 parent b4babe8 commit ea0a849

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

api/inference/v1/inference_utils.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,47 @@ func (s *API) WaitForDeployment(req *WaitForDeploymentRequest, opts ...scw.Reque
5757
}
5858
return deployment.(*Deployment), nil
5959
}
60+
61+
type WaitForCustomModelRequest struct {
62+
ModelID string
63+
Region scw.Region
64+
Timeout *time.Duration
65+
RetryInterval *time.Duration
66+
}
67+
68+
func (s *API) WaitForCustomModel(req WaitForCustomModelRequest, opts ...scw.RequestOption) (*Model, error) {
69+
timeout := defaultTimeout
70+
if req.Timeout != nil {
71+
timeout = *req.Timeout
72+
}
73+
retryInterval := defaultRetryInterval
74+
if req.RetryInterval != nil {
75+
retryInterval = *req.RetryInterval
76+
}
77+
78+
terminalStatus := map[ModelStatus]struct{}{
79+
ModelStatusReady: {},
80+
ModelStatusError: {},
81+
}
82+
83+
model, err := async.WaitSync(&async.WaitSyncConfig{
84+
Get: func() (interface{}, bool, error) {
85+
model, err := s.GetModel(&GetModelRequest{
86+
Region: req.Region,
87+
ModelID: req.ModelID,
88+
}, opts...)
89+
if err != nil {
90+
return nil, false, err
91+
}
92+
_, isTerminal := terminalStatus[model.Status]
93+
return model, isTerminal, nil
94+
},
95+
IntervalStrategy: async.LinearIntervalStrategy(retryInterval),
96+
Timeout: timeout,
97+
})
98+
if err != nil {
99+
return nil, errors.Wrap(err, "waiting for model failed")
100+
}
101+
102+
return model.(*Model), nil
103+
}

0 commit comments

Comments
 (0)