Skip to content

Commit 8d0562c

Browse files
authored
feat: 允许自定义Header并支持变量 (#168)
* feat: 允许自定义Header并支持变量 * feat: 缓存header序列化数据 * feat: 调整自定义请求头变量
1 parent 44b6770 commit 8d0562c

File tree

14 files changed

+643
-56
lines changed

14 files changed

+643
-56
lines changed

internal/channel/anthropic_channel.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
app_errors "gpt-load/internal/errors"
99
"gpt-load/internal/models"
10+
"gpt-load/internal/utils"
1011
"io"
1112
"net/http"
1213
"net/url"
@@ -73,7 +74,7 @@ func (ch *AnthropicChannel) ExtractModel(c *gin.Context, bodyBytes []byte) strin
7374
}
7475

7576
// ValidateKey checks if the given API key is valid by making a messages request.
76-
func (ch *AnthropicChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
77+
func (ch *AnthropicChannel) ValidateKey(ctx context.Context, apiKey *models.APIKey, group *models.Group) (bool, error) {
7778
upstreamURL := ch.getUpstreamURL()
7879
if upstreamURL == nil {
7980
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
@@ -104,10 +105,16 @@ func (ch *AnthropicChannel) ValidateKey(ctx context.Context, key string) (bool,
104105
if err != nil {
105106
return false, fmt.Errorf("failed to create validation request: %w", err)
106107
}
107-
req.Header.Set("x-api-key", key)
108+
req.Header.Set("x-api-key", apiKey.KeyValue)
108109
req.Header.Set("anthropic-version", "2023-06-01")
109110
req.Header.Set("Content-Type", "application/json")
110111

112+
// Apply custom header rules if available
113+
if len(group.HeaderRuleList) > 0 {
114+
headerCtx := utils.NewHeaderVariableContext(group, apiKey)
115+
utils.ApplyHeaderRules(req, group.HeaderRuleList, headerCtx)
116+
}
117+
111118
resp, err := ch.HTTPClient.Do(req)
112119
if err != nil {
113120
return false, fmt.Errorf("failed to send validation request: %w", err)

internal/channel/channel.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,5 @@ type ChannelProxy interface {
3333
ExtractModel(c *gin.Context, bodyBytes []byte) string
3434

3535
// ValidateKey checks if the given API key is valid.
36-
ValidateKey(ctx context.Context, key string) (bool, error)
36+
ValidateKey(ctx context.Context, apiKey *models.APIKey, group *models.Group) (bool, error)
3737
}

internal/channel/gemini_channel.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
app_errors "gpt-load/internal/errors"
99
"gpt-load/internal/models"
10+
"gpt-load/internal/utils"
1011
"io"
1112
"net/http"
1213
"net/url"
@@ -95,7 +96,7 @@ func (ch *GeminiChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
9596
}
9697

9798
// ValidateKey checks if the given API key is valid by making a generateContent request.
98-
func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
99+
func (ch *GeminiChannel) ValidateKey(ctx context.Context, apiKey *models.APIKey, group *models.Group) (bool, error) {
99100
upstreamURL := ch.getUpstreamURL()
100101
if upstreamURL == nil {
101102
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
@@ -106,7 +107,7 @@ func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, err
106107
if err != nil {
107108
return false, fmt.Errorf("failed to create gemini validation path: %w", err)
108109
}
109-
reqURL += "?key=" + key
110+
reqURL += "?key=" + apiKey.KeyValue
110111

111112
payload := gin.H{
112113
"contents": []gin.H{
@@ -126,6 +127,12 @@ func (ch *GeminiChannel) ValidateKey(ctx context.Context, key string) (bool, err
126127
}
127128
req.Header.Set("Content-Type", "application/json")
128129

130+
// Apply custom header rules if available
131+
if len(group.HeaderRuleList) > 0 {
132+
headerCtx := utils.NewHeaderVariableContext(group, apiKey)
133+
utils.ApplyHeaderRules(req, group.HeaderRuleList, headerCtx)
134+
}
135+
129136
resp, err := ch.HTTPClient.Do(req)
130137
if err != nil {
131138
return false, fmt.Errorf("failed to send validation request: %w", err)

internal/channel/openai_channel.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
app_errors "gpt-load/internal/errors"
99
"gpt-load/internal/models"
10+
"gpt-load/internal/utils"
1011
"io"
1112
"net/http"
1213
"net/url"
@@ -72,7 +73,7 @@ func (ch *OpenAIChannel) ExtractModel(c *gin.Context, bodyBytes []byte) string {
7273
}
7374

7475
// ValidateKey checks if the given API key is valid by making a chat completion request.
75-
func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, error) {
76+
func (ch *OpenAIChannel) ValidateKey(ctx context.Context, apiKey *models.APIKey, group *models.Group) (bool, error) {
7677
upstreamURL := ch.getUpstreamURL()
7778
if upstreamURL == nil {
7879
return false, fmt.Errorf("no upstream URL configured for channel %s", ch.Name)
@@ -103,9 +104,15 @@ func (ch *OpenAIChannel) ValidateKey(ctx context.Context, key string) (bool, err
103104
if err != nil {
104105
return false, fmt.Errorf("failed to create validation request: %w", err)
105106
}
106-
req.Header.Set("Authorization", "Bearer "+key)
107+
req.Header.Set("Authorization", "Bearer "+apiKey.KeyValue)
107108
req.Header.Set("Content-Type", "application/json")
108109

110+
// Apply custom header rules if available
111+
if len(group.HeaderRuleList) > 0 {
112+
headerCtx := utils.NewHeaderVariableContext(group, apiKey)
113+
utils.ApplyHeaderRules(req, group.HeaderRuleList, headerCtx)
114+
}
115+
109116
resp, err := ch.HTTPClient.Do(req)
110117
if err != nil {
111118
return false, fmt.Errorf("failed to send validation request: %w", err)

internal/handler/group_handler.go

Lines changed: 143 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package handler
44
import (
55
"encoding/json"
66
"fmt"
7+
"net/http"
78
"net/url"
89
"sync"
910

@@ -154,9 +155,25 @@ func (s *Server) validateAndCleanConfig(configMap map[string]any) (map[string]an
154155
return finalMap, nil
155156
}
156157

158+
// GroupCreateRequest defines the payload for creating a group.
159+
type GroupCreateRequest struct {
160+
Name string `json:"name"`
161+
DisplayName string `json:"display_name"`
162+
Description string `json:"description"`
163+
Upstreams json.RawMessage `json:"upstreams"`
164+
ChannelType string `json:"channel_type"`
165+
Sort int `json:"sort"`
166+
TestModel string `json:"test_model"`
167+
ValidationEndpoint string `json:"validation_endpoint"`
168+
ParamOverrides map[string]any `json:"param_overrides"`
169+
Config map[string]any `json:"config"`
170+
HeaderRules []models.HeaderRule `json:"header_rules"`
171+
ProxyKeys string `json:"proxy_keys"`
172+
}
173+
157174
// CreateGroup handles the creation of a new group.
158175
func (s *Server) CreateGroup(c *gin.Context) {
159-
var req models.Group
176+
var req GroupCreateRequest
160177
if err := c.ShouldBindJSON(&req); err != nil {
161178
response.Error(c, app_errors.NewAPIError(app_errors.ErrInvalidJSON, err.Error()))
162179
return
@@ -182,7 +199,7 @@ func (s *Server) CreateGroup(c *gin.Context) {
182199
return
183200
}
184201

185-
cleanedUpstreams, err := validateAndCleanUpstreams(json.RawMessage(req.Upstreams))
202+
cleanedUpstreams, err := validateAndCleanUpstreams(req.Upstreams)
186203
if err != nil {
187204
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, err.Error()))
188205
return
@@ -200,6 +217,48 @@ func (s *Server) CreateGroup(c *gin.Context) {
200217
return
201218
}
202219

220+
// Validate and normalize header rules if provided
221+
var headerRulesJSON datatypes.JSON
222+
if len(req.HeaderRules) > 0 {
223+
normalizedHeaderRules := make([]models.HeaderRule, 0)
224+
seenKeys := make(map[string]bool)
225+
226+
for _, rule := range req.HeaderRules {
227+
key := strings.TrimSpace(rule.Key)
228+
if key == "" {
229+
continue
230+
}
231+
232+
// Normalize to canonical form
233+
canonicalKey := http.CanonicalHeaderKey(key)
234+
235+
// Check for duplicate keys
236+
if seenKeys[canonicalKey] {
237+
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, fmt.Sprintf("Duplicate header key: %s", canonicalKey)))
238+
return
239+
}
240+
seenKeys[canonicalKey] = true
241+
242+
normalizedHeaderRules = append(normalizedHeaderRules, models.HeaderRule{
243+
Key: canonicalKey,
244+
Value: rule.Value,
245+
Action: rule.Action,
246+
})
247+
}
248+
249+
if len(normalizedHeaderRules) > 0 {
250+
headerRulesBytes, err := json.Marshal(normalizedHeaderRules)
251+
if err != nil {
252+
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to process header rules: %v", err)))
253+
return
254+
}
255+
headerRulesJSON = headerRulesBytes
256+
}
257+
}
258+
if headerRulesJSON == nil {
259+
headerRulesJSON = datatypes.JSON("[]")
260+
}
261+
203262
group := models.Group{
204263
Name: name,
205264
DisplayName: strings.TrimSpace(req.DisplayName),
@@ -211,6 +270,7 @@ func (s *Server) CreateGroup(c *gin.Context) {
211270
ValidationEndpoint: validationEndpoint,
212271
ParamOverrides: req.ParamOverrides,
213272
Config: cleanedConfig,
273+
HeaderRules: headerRulesJSON,
214274
ProxyKeys: strings.TrimSpace(req.ProxyKeys),
215275
}
216276

@@ -244,17 +304,18 @@ func (s *Server) ListGroups(c *gin.Context) {
244304
// GroupUpdateRequest defines the payload for updating a group.
245305
// Using a dedicated struct avoids issues with zero values being ignored by GORM's Update.
246306
type GroupUpdateRequest struct {
247-
Name *string `json:"name,omitempty"`
248-
DisplayName *string `json:"display_name,omitempty"`
249-
Description *string `json:"description,omitempty"`
250-
Upstreams json.RawMessage `json:"upstreams"`
251-
ChannelType *string `json:"channel_type,omitempty"`
252-
Sort *int `json:"sort"`
253-
TestModel string `json:"test_model"`
254-
ValidationEndpoint *string `json:"validation_endpoint,omitempty"`
255-
ParamOverrides map[string]any `json:"param_overrides"`
256-
Config map[string]any `json:"config"`
257-
ProxyKeys *string `json:"proxy_keys,omitempty"`
307+
Name *string `json:"name,omitempty"`
308+
DisplayName *string `json:"display_name,omitempty"`
309+
Description *string `json:"description,omitempty"`
310+
Upstreams json.RawMessage `json:"upstreams"`
311+
ChannelType *string `json:"channel_type,omitempty"`
312+
Sort *int `json:"sort"`
313+
TestModel string `json:"test_model"`
314+
ValidationEndpoint *string `json:"validation_endpoint,omitempty"`
315+
ParamOverrides map[string]any `json:"param_overrides"`
316+
Config map[string]any `json:"config"`
317+
HeaderRules []models.HeaderRule `json:"header_rules"`
318+
ProxyKeys *string `json:"proxy_keys,omitempty"`
258319
}
259320

260321
// UpdateGroup handles updating an existing group.
@@ -357,6 +418,48 @@ func (s *Server) UpdateGroup(c *gin.Context) {
357418
group.ProxyKeys = strings.TrimSpace(*req.ProxyKeys)
358419
}
359420

421+
// Handle header rules update
422+
if req.HeaderRules != nil {
423+
var headerRulesJSON datatypes.JSON
424+
normalizedHeaderRules := make([]models.HeaderRule, 0)
425+
seenKeys := make(map[string]bool)
426+
427+
for _, rule := range req.HeaderRules {
428+
key := strings.TrimSpace(rule.Key)
429+
if key == "" {
430+
continue
431+
}
432+
433+
// Normalize to canonical form
434+
canonicalKey := http.CanonicalHeaderKey(key)
435+
436+
// Check for duplicate keys
437+
if seenKeys[canonicalKey] {
438+
response.Error(c, app_errors.NewAPIError(app_errors.ErrValidation, fmt.Sprintf("Duplicate header key: %s", canonicalKey)))
439+
return
440+
}
441+
seenKeys[canonicalKey] = true
442+
443+
normalizedHeaderRules = append(normalizedHeaderRules, models.HeaderRule{
444+
Key: canonicalKey,
445+
Value: rule.Value,
446+
Action: rule.Action,
447+
})
448+
}
449+
450+
if len(normalizedHeaderRules) > 0 {
451+
headerRulesBytes, err := json.Marshal(normalizedHeaderRules)
452+
if err != nil {
453+
response.Error(c, app_errors.NewAPIError(app_errors.ErrInternalServer, fmt.Sprintf("Failed to process header rules: %v", err)))
454+
return
455+
}
456+
headerRulesJSON = headerRulesBytes
457+
} else {
458+
headerRulesJSON = datatypes.JSON("[]")
459+
}
460+
group.HeaderRules = headerRulesJSON
461+
}
462+
360463
// Save the updated group object
361464
if err := tx.Save(&group).Error; err != nil {
362465
response.Error(c, app_errors.ParseDBError(err))
@@ -376,22 +479,23 @@ func (s *Server) UpdateGroup(c *gin.Context) {
376479

377480
// GroupResponse defines the structure for a group response, excluding sensitive or large fields.
378481
type GroupResponse struct {
379-
ID uint `json:"id"`
380-
Name string `json:"name"`
381-
Endpoint string `json:"endpoint"`
382-
DisplayName string `json:"display_name"`
383-
Description string `json:"description"`
384-
Upstreams datatypes.JSON `json:"upstreams"`
385-
ChannelType string `json:"channel_type"`
386-
Sort int `json:"sort"`
387-
TestModel string `json:"test_model"`
388-
ValidationEndpoint string `json:"validation_endpoint"`
389-
ParamOverrides datatypes.JSONMap `json:"param_overrides"`
390-
Config datatypes.JSONMap `json:"config"`
391-
ProxyKeys string `json:"proxy_keys"`
392-
LastValidatedAt *time.Time `json:"last_validated_at"`
393-
CreatedAt time.Time `json:"created_at"`
394-
UpdatedAt time.Time `json:"updated_at"`
482+
ID uint `json:"id"`
483+
Name string `json:"name"`
484+
Endpoint string `json:"endpoint"`
485+
DisplayName string `json:"display_name"`
486+
Description string `json:"description"`
487+
Upstreams datatypes.JSON `json:"upstreams"`
488+
ChannelType string `json:"channel_type"`
489+
Sort int `json:"sort"`
490+
TestModel string `json:"test_model"`
491+
ValidationEndpoint string `json:"validation_endpoint"`
492+
ParamOverrides datatypes.JSONMap `json:"param_overrides"`
493+
Config datatypes.JSONMap `json:"config"`
494+
HeaderRules []models.HeaderRule `json:"header_rules"`
495+
ProxyKeys string `json:"proxy_keys"`
496+
LastValidatedAt *time.Time `json:"last_validated_at"`
497+
CreatedAt time.Time `json:"created_at"`
498+
UpdatedAt time.Time `json:"updated_at"`
395499
}
396500

397501
// newGroupResponse creates a new GroupResponse from a models.Group.
@@ -406,6 +510,15 @@ func (s *Server) newGroupResponse(group *models.Group) *GroupResponse {
406510
}
407511
}
408512

513+
// Parse header rules from JSON
514+
var headerRules []models.HeaderRule
515+
if len(group.HeaderRules) > 0 {
516+
if err := json.Unmarshal(group.HeaderRules, &headerRules); err != nil {
517+
logrus.WithError(err).Error("Failed to unmarshal header rules")
518+
headerRules = make([]models.HeaderRule, 0)
519+
}
520+
}
521+
409522
return &GroupResponse{
410523
ID: group.ID,
411524
Name: group.Name,
@@ -419,6 +532,7 @@ func (s *Server) newGroupResponse(group *models.Group) *GroupResponse {
419532
ValidationEndpoint: group.ValidationEndpoint,
420533
ParamOverrides: group.ParamOverrides,
421534
Config: group.Config,
535+
HeaderRules: headerRules,
422536
ProxyKeys: group.ProxyKeys,
423537
LastValidatedAt: group.LastValidatedAt,
424538
CreatedAt: group.CreatedAt,

internal/keypool/validator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func (s *KeyValidator) ValidateSingleKey(key *models.APIKey, group *models.Group
5959
return false, fmt.Errorf("failed to get channel for group %s: %w", group.Name, err)
6060
}
6161

62-
isValid, validationErr := ch.ValidateKey(ctx, key.KeyValue)
62+
isValid, validationErr := ch.ValidateKey(ctx, key, group)
6363

6464
s.keypoolProvider.UpdateStatus(key, group, isValid)
6565

internal/models/types.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ type GroupConfig struct {
3939
KeyValidationTimeoutSeconds *int `json:"key_validation_timeout_seconds,omitempty"`
4040
}
4141

42+
// HeaderRule defines a single rule for header manipulation.
43+
type HeaderRule struct {
44+
Key string `json:"key"`
45+
Value string `json:"value"`
46+
Action string `json:"action"` // "set" or "remove"
47+
}
48+
4249
// Group 对应 groups 表
4350
type Group struct {
4451
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
@@ -55,13 +62,15 @@ type Group struct {
5562
TestModel string `gorm:"type:varchar(255);not null" json:"test_model"`
5663
ParamOverrides datatypes.JSONMap `gorm:"type:json" json:"param_overrides"`
5764
Config datatypes.JSONMap `gorm:"type:json" json:"config"`
65+
HeaderRules datatypes.JSON `gorm:"type:json" json:"header_rules"`
5866
APIKeys []APIKey `gorm:"foreignKey:GroupID" json:"api_keys"`
5967
LastValidatedAt *time.Time `json:"last_validated_at"`
6068
CreatedAt time.Time `json:"created_at"`
6169
UpdatedAt time.Time `json:"updated_at"`
6270

6371
// For cache
64-
ProxyKeysMap map[string]struct{} `gorm:"-" json:"-"`
72+
ProxyKeysMap map[string]struct{} `gorm:"-" json:"-"`
73+
HeaderRuleList []HeaderRule `gorm:"-" json:"-"`
6574
}
6675

6776
// APIKey 对应 api_keys 表

0 commit comments

Comments
 (0)