Skip to content

Commit 14a85a7

Browse files
authored
feat: Model redirect (#334)
* feat: Model redirect * fix: Hint error * fix: Prefix models error * feat: Override models API * feat: Code style
1 parent 2f6ba02 commit 14a85a7

File tree

22 files changed

+1026
-177
lines changed

22 files changed

+1026
-177
lines changed

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ go 1.23.0
55
toolchain go1.24.3
66

77
require (
8+
github.com/andybalholm/brotli v1.2.0
89
github.com/gin-contrib/gzip v1.2.3
910
github.com/gin-contrib/static v1.1.5
1011
github.com/gin-gonic/gin v1.10.1
@@ -13,6 +14,7 @@ require (
1314
github.com/google/uuid v1.6.0
1415
github.com/jackc/pgx/v5 v5.6.0
1516
github.com/joho/godotenv v1.5.1
17+
github.com/klauspost/compress v1.18.1
1618
github.com/nicksnyder/go-i18n/v2 v2.6.0
1719
github.com/redis/go-redis/v9 v9.5.3
1820
github.com/sirupsen/logrus v1.9.3

go.sum

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
22
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
33
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
44
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
5+
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
6+
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
57
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
68
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
79
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
@@ -77,6 +79,8 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
7779
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
7880
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
7981
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
82+
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
83+
github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0=
8084
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
8185
github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE=
8286
github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
@@ -129,6 +133,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
129133
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
130134
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
131135
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
136+
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
137+
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
132138
go.uber.org/dig v1.19.0 h1:BACLhebsYdpQ7IROQ1AGPjrXcP5dF80U3gKoFzbaq/4=
133139
go.uber.org/dig v1.19.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE=
134140
golang.org/x/arch v0.16.0 h1:foMtLTdyOmIniqWCHjY6+JxuC54XP1fDwx4N0ASyW+U=

internal/channel/base_channel.go

Lines changed: 154 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package channel
22

33
import (
44
"bytes"
5+
"encoding/json"
56
"fmt"
67
"gpt-load/internal/models"
78
"gpt-load/internal/types"
@@ -12,6 +13,7 @@ import (
1213
"strings"
1314
"sync"
1415

16+
"github.com/sirupsen/logrus"
1517
"gorm.io/datatypes"
1618
)
1719

@@ -33,9 +35,11 @@ type BaseChannel struct {
3335
upstreamLock sync.Mutex
3436

3537
// Cached fields from the group for stale check
36-
channelType string
37-
groupUpstreams datatypes.JSON
38-
effectiveConfig *types.SystemSettings
38+
channelType string
39+
groupUpstreams datatypes.JSON
40+
effectiveConfig *types.SystemSettings
41+
modelRedirectRules datatypes.JSONMap
42+
modelRedirectStrict bool
3943
}
4044

4145
// getUpstreamURL selects an upstream URL using a smooth weighted round-robin algorithm.
@@ -107,6 +111,13 @@ func (b *BaseChannel) IsConfigStale(group *models.Group) bool {
107111
if !reflect.DeepEqual(b.effectiveConfig, &group.EffectiveConfig) {
108112
return true
109113
}
114+
// Check for model redirect rules changes
115+
if !reflect.DeepEqual(b.modelRedirectRules, group.ModelRedirectRules) {
116+
return true
117+
}
118+
if b.modelRedirectStrict != group.ModelRedirectStrict {
119+
return true
120+
}
110121
return false
111122
}
112123

@@ -119,3 +130,143 @@ func (b *BaseChannel) GetHTTPClient() *http.Client {
119130
func (b *BaseChannel) GetStreamClient() *http.Client {
120131
return b.StreamClient
121132
}
133+
134+
// ApplyModelRedirect applies model redirection based on the group's redirect rules.
135+
func (b *BaseChannel) ApplyModelRedirect(req *http.Request, bodyBytes []byte, group *models.Group) ([]byte, error) {
136+
if len(group.ModelRedirectMap) == 0 || len(bodyBytes) == 0 {
137+
return bodyBytes, nil
138+
}
139+
140+
var requestData map[string]any
141+
if err := json.Unmarshal(bodyBytes, &requestData); err != nil {
142+
return bodyBytes, nil
143+
}
144+
145+
modelValue, exists := requestData["model"]
146+
if !exists {
147+
return bodyBytes, nil
148+
}
149+
150+
model, ok := modelValue.(string)
151+
if !ok {
152+
return bodyBytes, nil
153+
}
154+
155+
// Direct match without any prefix processing
156+
if targetModel, found := group.ModelRedirectMap[model]; found {
157+
requestData["model"] = targetModel
158+
159+
// Log the redirection for audit
160+
logrus.WithFields(logrus.Fields{
161+
"group": group.Name,
162+
"original_model": model,
163+
"target_model": targetModel,
164+
"channel": "json_body",
165+
}).Debug("Model redirected")
166+
167+
return json.Marshal(requestData)
168+
}
169+
170+
if group.ModelRedirectStrict {
171+
return nil, fmt.Errorf("model '%s' is not configured in redirect rules", model)
172+
}
173+
174+
return bodyBytes, nil
175+
}
176+
177+
// TransformModelList transforms the model list response based on redirect rules.
178+
func (b *BaseChannel) TransformModelList(req *http.Request, bodyBytes []byte, group *models.Group) (map[string]any, error) {
179+
var response map[string]any
180+
if err := json.Unmarshal(bodyBytes, &response); err != nil {
181+
logrus.WithError(err).Debug("Failed to parse model list response, returning empty")
182+
return nil, err
183+
}
184+
185+
dataInterface, exists := response["data"]
186+
if !exists {
187+
return response, nil
188+
}
189+
190+
upstreamModels, ok := dataInterface.([]any)
191+
if !ok {
192+
return response, nil
193+
}
194+
195+
// Build configured source models list (common logic for both modes)
196+
configuredModels := buildConfiguredModels(group.ModelRedirectMap)
197+
198+
// Strict mode: return only configured models (whitelist)
199+
if group.ModelRedirectStrict {
200+
response["data"] = configuredModels
201+
202+
logrus.WithFields(logrus.Fields{
203+
"group": group.Name,
204+
"model_count": len(configuredModels),
205+
"strict_mode": true,
206+
}).Debug("Model list returned (strict mode - configured models only)")
207+
208+
return response, nil
209+
}
210+
211+
// Non-strict mode: merge upstream + configured models (upstream priority)
212+
merged := mergeModelLists(upstreamModels, configuredModels)
213+
response["data"] = merged
214+
215+
logrus.WithFields(logrus.Fields{
216+
"group": group.Name,
217+
"upstream_count": len(upstreamModels),
218+
"configured_count": len(configuredModels),
219+
"merged_count": len(merged),
220+
"strict_mode": false,
221+
}).Debug("Model list merged (non-strict mode)")
222+
223+
return response, nil
224+
}
225+
226+
// buildConfiguredModels builds a list of models from redirect rules
227+
func buildConfiguredModels(redirectMap map[string]string) []any {
228+
if len(redirectMap) == 0 {
229+
return []any{}
230+
}
231+
232+
models := make([]any, 0, len(redirectMap))
233+
for sourceModel := range redirectMap {
234+
models = append(models, map[string]any{
235+
"id": sourceModel,
236+
"object": "model",
237+
"created": 0,
238+
"owned_by": "system",
239+
})
240+
}
241+
return models
242+
}
243+
244+
// mergeModelLists merges upstream and configured model lists
245+
func mergeModelLists(upstream []any, configured []any) []any {
246+
// Create set of upstream model IDs
247+
upstreamIDs := make(map[string]bool)
248+
for _, item := range upstream {
249+
if modelObj, ok := item.(map[string]any); ok {
250+
if modelID, ok := modelObj["id"].(string); ok {
251+
upstreamIDs[modelID] = true
252+
}
253+
}
254+
}
255+
256+
// Start with all upstream models
257+
result := make([]any, len(upstream))
258+
copy(result, upstream)
259+
260+
// Add configured models that don't exist in upstream
261+
for _, item := range configured {
262+
if modelObj, ok := item.(map[string]any); ok {
263+
if modelID, ok := modelObj["id"].(string); ok {
264+
if !upstreamIDs[modelID] {
265+
result = append(result, item)
266+
}
267+
}
268+
}
269+
}
270+
271+
return result
272+
}

internal/channel/channel.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,10 @@ type ChannelProxy interface {
3434

3535
// ValidateKey checks if the given API key is valid.
3636
ValidateKey(ctx context.Context, apiKey *models.APIKey, group *models.Group) (bool, error)
37+
38+
// ApplyModelRedirect applies model redirection based on the group's redirect rules.
39+
ApplyModelRedirect(req *http.Request, bodyBytes []byte, group *models.Group) ([]byte, error)
40+
41+
// TransformModelList transforms the model list response based on redirect rules.
42+
TransformModelList(req *http.Request, bodyBytes []byte, group *models.Group) (map[string]any, error)
3743
}

internal/channel/factory.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -141,14 +141,16 @@ func (f *Factory) newBaseChannel(name string, group *models.Group) (*BaseChannel
141141
streamClient := f.clientManager.GetClient(&streamConfig)
142142

143143
return &BaseChannel{
144-
Name: name,
145-
Upstreams: upstreamInfos,
146-
HTTPClient: httpClient,
147-
StreamClient: streamClient,
148-
TestModel: group.TestModel,
149-
ValidationEndpoint: utils.GetValidationEndpoint(group),
150-
channelType: group.ChannelType,
151-
groupUpstreams: group.Upstreams,
152-
effectiveConfig: &group.EffectiveConfig,
144+
Name: name,
145+
Upstreams: upstreamInfos,
146+
HTTPClient: httpClient,
147+
StreamClient: streamClient,
148+
TestModel: group.TestModel,
149+
ValidationEndpoint: utils.GetValidationEndpoint(group),
150+
channelType: group.ChannelType,
151+
groupUpstreams: group.Upstreams,
152+
effectiveConfig: &group.EffectiveConfig,
153+
modelRedirectRules: group.ModelRedirectRules,
154+
modelRedirectStrict: group.ModelRedirectStrict,
153155
}, nil
154156
}

0 commit comments

Comments
 (0)