Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 48 additions & 8 deletions common/client/init.go
Original file line number Diff line number Diff line change
@@ -1,35 +1,75 @@
package client

import (
"context"
"fmt"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"net/http"
"net/url"
"sync"
"time"

"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/ssrf"
)

var HTTPClient *http.Client
var ImpatientHTTPClient *http.Client

// UserContentRequestHTTPClient is the HTTP client used to fetch user content (images, files, etc.)
// SSRF protection is applied to this client.
var UserContentRequestHTTPClient *http.Client

var userContentTransport *http.Transport
var userContentTransportMu sync.RWMutex

func Init() {
ctx := context.Background()
initUserContentRequestClient(ctx)
initRelayClient(ctx)
}

func initUserContentRequestClient(ctx context.Context) {
var transport *http.Transport

if config.UserContentRequestProxy != "" {
logger.SysLog(fmt.Sprintf("using %s as proxy to fetch user content", config.UserContentRequestProxy))
proxyURL, err := url.Parse(config.UserContentRequestProxy)
if err != nil {
logger.FatalLog(fmt.Sprintf("USER_CONTENT_REQUEST_PROXY set but invalid: %s", config.UserContentRequestProxy))
}
transport := &http.Transport{
transport = &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
UserContentRequestHTTPClient = &http.Client{
Transport: transport,
Timeout: time.Second * time.Duration(config.UserContentRequestTimeout),
}
} else {
UserContentRequestHTTPClient = &http.Client{}
transport = &http.Transport{}
}

protection := ssrf.GetGlobalProtection()
safeTransport := protection.CreateSafeTransport(ctx, transport)

userContentTransportMu.Lock()
userContentTransport = safeTransport
userContentTransportMu.Unlock()

UserContentRequestHTTPClient = &http.Client{
Transport: safeTransport,
Timeout: time.Second * time.Duration(config.UserContentRequestTimeout),
CheckRedirect: protection.CheckRedirect(ctx),
}
}

func CloseUserContentIdleConnections() {
userContentTransportMu.RLock()
transport := userContentTransport
userContentTransportMu.RUnlock()

if transport != nil {
transport.CloseIdleConnections()
}
}

func initRelayClient(ctx context.Context) {
var transport http.RoundTripper
if config.RelayProxy != "" {
logger.SysLog(fmt.Sprintf("using %s as api relay proxy", config.RelayProxy))
Expand Down
20 changes: 20 additions & 0 deletions common/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,23 @@ var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)

var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false)
var TestPrompt = env.String("TEST_PROMPT", "Output only your specific model name with no additional text.")

// SSRF Protection Configuration
var (
// SSRFProtectionEnabled 是否启用SSRF防护
SSRFProtectionEnabled = true
// SSRFAllowPrivateIP 是否允许访问私有IP地址
SSRFAllowPrivateIP = false
// SSRFApplyIPFilterForDomain 对域名启用IP过滤(DNS解析后检查)
SSRFApplyIPFilterForDomain = false
// SSRFDomainListMode 域名列表模式(whitelist/blacklist)
SSRFDomainListMode = "blacklist"
// SSRFDomainList 统一域名列表(逗号分隔,支持通配符如 *.example.com)
SSRFDomainList = []string{}
// SSRFIPListMode IP列表模式(whitelist/blacklist)
SSRFIPListMode = "blacklist"
// SSRFIPList 统一IP列表(逗号分隔,支持CIDR格式如 192.168.1.0/24)
SSRFIPList = []string{}
// SSRFAllowedPorts 允许的端口列表
SSRFAllowedPorts = []int{80, 443, 8080, 8443}
)
3 changes: 2 additions & 1 deletion common/image/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package image
import (
"bytes"
"encoding/base64"
"github.com/songquanpeng/one-api/common/client"
"image"
_ "image/gif"
_ "image/jpeg"
Expand All @@ -13,6 +12,8 @@ import (
"strings"
"sync"

"github.com/songquanpeng/one-api/common/client"

_ "golang.org/x/image/webp"
)

Expand Down
Loading