Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 92 additions & 28 deletions model/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,49 +171,66 @@ var group2model2channels map[string]map[string][]*Channel
var channelSyncLock sync.RWMutex

func InitChannelCache() {
start := time.Now()

// Get all enabled channels
newChannelId2channel := make(map[int]*Channel)
var channels []*Channel
DB.Where("status = ?", ChannelStatusEnabled).Find(&channels)
for _, channel := range channels {
newChannelId2channel[channel.Id] = channel
Comment on lines 177 to 181
Copy link

Copilot AI Aug 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable newChannelId2channel is created but never used in the updated InitChannelCache() function. This appears to be leftover code from the refactoring.

Suggested change
newChannelId2channel := make(map[int]*Channel)
var channels []*Channel
DB.Where("status = ?", ChannelStatusEnabled).Find(&channels)
for _, channel := range channels {
newChannelId2channel[channel.Id] = channel
var channels []*Channel
DB.Where("status = ?", ChannelStatusEnabled).Find(&channels)
for _, channel := range channels {
// no longer populating newChannelId2channel

Copilot uses AI. Check for mistakes.
}

// Build cache based on abilities table (more accurate)
newGroup2model2channels := buildChannelCacheFromAbilities(channels)

channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
channelSyncLock.Unlock()

duration := time.Since(start)
logger.SysLog(fmt.Sprintf("channels synced from database in %v, loaded %d channels", duration, len(channels)))
}

func buildChannelCacheFromAbilities(channels []*Channel) map[string]map[string][]*Channel {
// Create channel ID to channel mapping
channelMap := make(map[int]*Channel)
for _, channel := range channels {
channelMap[channel.Id] = channel
}

// Get all enabled abilities
var abilities []*Ability
DB.Find(&abilities)
groups := make(map[string]bool)
DB.Where("enabled = ?", true).Find(&abilities)

// Build cache based on abilities (ensures consistency)
result := make(map[string]map[string][]*Channel)

for _, ability := range abilities {
groups[ability.Group] = true
}
newGroup2model2channels := make(map[string]map[string][]*Channel)
for group := range groups {
newGroup2model2channels[group] = make(map[string][]*Channel)
}
for _, channel := range channels {
groups := strings.Split(channel.Group, ",")
for _, group := range groups {
models := strings.Split(channel.Models, ",")
for _, model := range models {
if _, ok := newGroup2model2channels[group][model]; !ok {
newGroup2model2channels[group][model] = make([]*Channel, 0)
}
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
}
channel, exists := channelMap[ability.ChannelId]
if !exists {
// Channel is disabled, skip this ability
continue
}

if result[ability.Group] == nil {
result[ability.Group] = make(map[string][]*Channel)
}

result[ability.Group][ability.Model] = append(result[ability.Group][ability.Model], channel)
}

// sort by priority
for group, model2channels := range newGroup2model2channels {
for model, channels := range model2channels {
sort.Slice(channels, func(i, j int) bool {
return channels[i].GetPriority() > channels[j].GetPriority()
// Sort channels by priority within each group-model combination
for group, models := range result {
for model, channelList := range models {
sort.Slice(channelList, func(i, j int) bool {
return channelList[i].GetPriority() > channelList[j].GetPriority()
})
newGroup2model2channels[group][model] = channels
result[group][model] = channelList
}
}

channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
channelSyncLock.Unlock()
logger.SysLog("channels synced from database")
return result
}

func SyncChannelCache(frequency int) {
Expand All @@ -224,6 +241,53 @@ func SyncChannelCache(frequency int) {
}
}

// InvalidateChannelCache forces immediate cache refresh for a specific channel
func InvalidateChannelCache(channelId int) {
if !config.MemoryCacheEnabled {
return
}

logger.SysLog(fmt.Sprintf("invalidating cache for channel #%d", channelId))

// Force immediate cache rebuild
InitChannelCache()
}

// InvalidateGroupModelCache clears Redis cache for specific group
func InvalidateGroupModelCache(group string) {
if !common.RedisEnabled {
return
}

cacheKey := fmt.Sprintf("group_models:%s", group)
err := common.RedisDel(cacheKey)
if err != nil {
logger.SysError(fmt.Sprintf("failed to invalidate group models cache for %s: %s", group, err.Error()))
} else {
logger.SysLog(fmt.Sprintf("invalidated group models cache for %s", group))
}
}

// InvalidateUserCache clears user-related Redis caches
func InvalidateUserCache(userId int) {
if !common.RedisEnabled {
return
}

patterns := []string{
fmt.Sprintf("user_group:%d", userId),
fmt.Sprintf("user_quota:%d", userId),
fmt.Sprintf("user_enabled:%d", userId),
}

for _, pattern := range patterns {
err := common.RedisDel(pattern)
if err != nil {
logger.SysError(fmt.Sprintf("failed to invalidate cache %s: %s", pattern, err.Error()))
}
}
}

func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
if !config.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model, ignoreFirstPriority)
Expand Down
116 changes: 116 additions & 0 deletions model/cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package model

import (
"testing"
"time"

"github.com/songquanpeng/one-api/common/config"
"github.com/stretchr/testify/assert"
)

// TestCacheSync tests the cache synchronization functionality
func TestCacheSync(t *testing.T) {
// Enable memory cache for testing
config.MemoryCacheEnabled = true

// Initialize database connection for testing
// Note: This assumes test database is configured
InitDB()

// Test channel cache initialization
t.Run("TestInitChannelCache", func(t *testing.T) {
// Initialize cache
InitChannelCache()

// Verify cache is built
channelSyncLock.RLock()
cacheExists := group2model2channels != nil
channelSyncLock.RUnlock()

assert.True(t, cacheExists, "Channel cache should be initialized")
})

// Test cache invalidation
t.Run("TestInvalidateChannelCache", func(t *testing.T) {
// Initialize cache first
InitChannelCache()

// Test cache invalidation
InvalidateChannelCache(1)

// Cache should be rebuilt after invalidation
channelSyncLock.RLock()
cacheExists := group2model2channels != nil
channelSyncLock.RUnlock()

assert.True(t, cacheExists, "Cache should exist after invalidation")
})

// Test abilities-based cache building
t.Run("TestBuildChannelCacheFromAbilities", func(t *testing.T) {
// Create test channels
channels := []*Channel{
{
Id: 1,
Status: ChannelStatusEnabled,
Group: "default",
Models: "gpt-3.5-turbo",
Priority: &[]int64{0}[0],
Comment on lines +52 to +58
Copy link

Copilot AI Aug 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This inline slice creation and dereferencing pattern &[]int64{0}[0] is unnecessarily complex. Consider using a helper variable: priority := int64(0); Priority: &priority,

Suggested change
channels := []*Channel{
{
Id: 1,
Status: ChannelStatusEnabled,
Group: "default",
Models: "gpt-3.5-turbo",
Priority: &[]int64{0}[0],
priority := int64(0)
channels := []*Channel{
{
Id: 1,
Status: ChannelStatusEnabled,
Group: "default",
Models: "gpt-3.5-turbo",
Priority: &priority,

Copilot uses AI. Check for mistakes.
},
}

// Build cache from abilities
cache := buildChannelCacheFromAbilities(channels)

assert.NotNil(t, cache, "Cache should not be nil")
})
}

// TestChannelStatusUpdate tests channel status update with cache sync
func TestChannelStatusUpdate(t *testing.T) {
// Enable memory cache for testing
config.MemoryCacheEnabled = true

// Test channel status update
t.Run("TestUpdateChannelStatusById", func(t *testing.T) {
// This test requires a test database with actual data
// For now, we'll just test that the function doesn't panic
assert.NotPanics(t, func() {
UpdateChannelStatusById(999, ChannelStatusEnabled)
}, "UpdateChannelStatusById should not panic")
})
}

// BenchmarkCacheInit benchmarks cache initialization performance
func BenchmarkCacheInit(b *testing.B) {
config.MemoryCacheEnabled = true
InitDB()

b.ResetTimer()
for i := 0; i < b.N; i++ {
InitChannelCache()
}
}

// TestCacheConsistency verifies cache consistency after updates
func TestCacheConsistency(t *testing.T) {
config.MemoryCacheEnabled = true

t.Run("TestCacheConsistencyAfterUpdate", func(t *testing.T) {
// Initialize cache
InitChannelCache()

// Simulate channel status change
InvalidateChannelCache(1)

// Wait a bit for cache to rebuild
time.Sleep(100 * time.Millisecond)

// Verify cache is still consistent
channelSyncLock.RLock()
cacheExists := group2model2channels != nil
channelSyncLock.RUnlock()

assert.True(t, cacheExists, "Cache should remain consistent after updates")
})
}
66 changes: 64 additions & 2 deletions model/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package model
import (
"encoding/json"
"fmt"
"strings"

"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
Expand Down Expand Up @@ -136,13 +137,47 @@ func (channel *Channel) Insert() error {

func (channel *Channel) Update() error {
var err error

// Get old channel info for cache invalidation
var oldChannel Channel
err = DB.Where("id = ?", channel.Id).First(&oldChannel).Error
if err != nil {
return err
}

err = DB.Model(channel).Updates(channel).Error
if err != nil {
return err
}
DB.Model(channel).First(channel, "id = ?", channel.Id)
err = channel.UpdateAbilities()
return err
if err != nil {
return err
}
Comment on lines +154 to +156
Copy link

Copilot AI Aug 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing conditional check for UpdateAbilities() error. The error variable err is assigned but there's no if err != nil check before line 154.

Copilot uses AI. Check for mistakes.

// Invalidate caches if channel configuration changed
InvalidateChannelCache(channel.Id)

// Invalidate affected group caches
allGroups := make(map[string]bool)
// Add old groups
for _, group := range strings.Split(oldChannel.Group, ",") {
allGroups[strings.TrimSpace(group)] = true
}
// Add new groups
for _, group := range strings.Split(channel.Group, ",") {
allGroups[strings.TrimSpace(group)] = true
}

// Invalidate all affected groups
for group := range allGroups {
if group != "" {
InvalidateGroupModelCache(group)
}
}

logger.SysLog(fmt.Sprintf("updated channel #%d configuration and invalidated caches", channel.Id))
return nil
}

func (channel *Channel) UpdateResponseTime(responseTime int64) {
Expand Down Expand Up @@ -188,14 +223,41 @@ func (channel *Channel) LoadConfig() (ChannelConfig, error) {
}

func UpdateChannelStatusById(id int, status int) {
err := UpdateAbilityStatus(id, status == ChannelStatusEnabled)
// Get channel info for cache invalidation
var channel Channel
err := DB.Where("id = ?", id).First(&channel).Error
if err != nil {
logger.SysError("failed to get channel info: " + err.Error())
return
}

// Update abilities status
err = UpdateAbilityStatus(id, status == ChannelStatusEnabled)
if err != nil {
logger.SysError("failed to update ability status: " + err.Error())
}

// Update channel status
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
if err != nil {
logger.SysError("failed to update channel status: " + err.Error())
return
}

// Immediately invalidate all related caches
InvalidateChannelCache(id)

// Invalidate affected group model caches
groups := strings.Split(channel.Group, ",")
for _, group := range groups {
InvalidateGroupModelCache(strings.TrimSpace(group))
}

statusStr := "disabled"
if status == ChannelStatusEnabled {
statusStr = "enabled"
}
logger.SysLog(fmt.Sprintf("updated channel #%d status to %s and invalidated caches", id, statusStr))
}

func UpdateChannelUsedQuota(id int, quota int64) {
Expand Down