Skip to content

Commit d522d6e

Browse files
authored
{session/inmemory,session/redis}: fix event filtering to ensure at least one user message (#737)
1 parent ade1bd5 commit d522d6e

File tree

3 files changed

+33
-29
lines changed

3 files changed

+33
-29
lines changed

session/inmemory/service.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,7 @@ func (s *SessionService) updateStoredSession(sess *session.Session, e *event.Eve
723723
sess.EventMu.Lock()
724724
sess.Events = append(sess.Events, *e)
725725
if s.opts.sessionEventLimit > 0 && len(sess.Events) > s.opts.sessionEventLimit {
726-
sess.Events = sess.Events[len(sess.Events)-s.opts.sessionEventLimit:]
726+
sess.ApplyEventFiltering(session.WithEventNum(s.opts.sessionEventLimit))
727727
}
728728
sess.EventMu.Unlock()
729729
}

session/redis/service.go

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -743,15 +743,7 @@ func (s *Service) getEventsList(
743743
) ([][]event.Event, error) {
744744
pipe := s.redisClient.Pipeline()
745745
for _, key := range sessionKeys {
746-
zrangeBy := &redis.ZRangeBy{
747-
Min: fmt.Sprintf("%d", afterTime.UnixNano()),
748-
Max: fmt.Sprintf("%d", time.Now().UnixNano()),
749-
}
750-
if limit > 0 {
751-
zrangeBy.Offset = 0
752-
zrangeBy.Count = int64(limit)
753-
}
754-
pipe.ZRevRangeByScore(ctx, getEventKey(key), zrangeBy)
746+
pipe.ZRange(ctx, getEventKey(key), 0, -1)
755747
}
756748
cmds, err := pipe.Exec(ctx)
757749
if err != nil && err != redis.Nil {
@@ -768,14 +760,14 @@ func (s *Service) getEventsList(
768760
if err != nil {
769761
return nil, fmt.Errorf("process event cmd failed: %w", err)
770762
}
771-
772-
// reverse events to get chronological order (oldest first)
773-
if len(events) > 1 {
774-
for i, j := 0, len(events)-1; i < j; i, j = i+1, j-1 {
775-
events[i], events[j] = events[j], events[i]
776-
}
763+
sess := &session.Session{
764+
Events: events,
777765
}
778-
sessEventsList = append(sessEventsList, events)
766+
if limit <= 0 {
767+
limit = s.opts.sessionEventLimit
768+
}
769+
sess.ApplyEventFiltering(session.WithEventNum(limit), session.WithEventTime(afterTime))
770+
sessEventsList = append(sessEventsList, sess.Events)
779771
}
780772
return sessEventsList, nil
781773
}
@@ -998,9 +990,6 @@ func (s *Service) addEvent(ctx context.Context, key session.Key, event *event.Ev
998990
Score: float64(event.Timestamp.UnixNano()),
999991
Member: eventBytes,
1000992
})
1001-
if s.opts.sessionEventLimit > 0 {
1002-
txPipe.ZRemRangeByRank(ctx, getEventKey(key), 0, -(int64(s.opts.sessionEventLimit) + 1))
1003-
}
1004993
// Set TTL for session state and event list if configured
1005994
if s.sessionTTL > 0 {
1006995
txPipe.Expire(ctx, getEventKey(key), s.sessionTTL)

session/redis/service_test.go

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,10 +1400,10 @@ func TestGetSession_EventFiltering_Integration(t *testing.T) {
14001400
require.NotNil(t, retrievedSess)
14011401

14021402
// Should have 2 events (from event3 onwards)
1403-
assert.Equal(t, 4, len(retrievedSess.Events), "Should filter out assistant events before first user event")
1404-
assert.Equal(t, "event1", retrievedSess.Events[0].ID, "First event should be the user event")
1405-
assert.Equal(t, model.RoleAssistant, retrievedSess.Events[0].Response.Choices[0].Message.Role)
1406-
assert.Equal(t, "event2", retrievedSess.Events[1].ID, "Second event should be the subsequent assistant event")
1403+
assert.Equal(t, 2, len(retrievedSess.Events), "Should filter out assistant events before first user event")
1404+
assert.Equal(t, "event3", retrievedSess.Events[0].ID, "First event should be the user event")
1405+
assert.Equal(t, model.RoleUser, retrievedSess.Events[0].Response.Choices[0].Message.Role)
1406+
assert.Equal(t, "event4", retrievedSess.Events[1].ID, "Second event should be the subsequent assistant event")
14071407

14081408
// Test ListSessions - should apply same filtering
14091409
sessionList, err := service.ListSessions(context.Background(), session.UserKey{
@@ -1414,9 +1414,9 @@ func TestGetSession_EventFiltering_Integration(t *testing.T) {
14141414
require.Len(t, sessionList, 1)
14151415

14161416
// Should have same filtering as GetSession
1417-
assert.Equal(t, 4, len(sessionList[0].Events), "ListSessions should also filter events")
1418-
assert.Equal(t, "event1", sessionList[0].Events[0].ID, "First event should be the user event")
1419-
assert.Equal(t, model.RoleAssistant, sessionList[0].Events[0].Response.Choices[0].Message.Role)
1417+
assert.Equal(t, 2, len(sessionList[0].Events), "ListSessions should also filter events")
1418+
assert.Equal(t, "event3", sessionList[0].Events[0].ID, "First event should be the user event")
1419+
assert.Equal(t, model.RoleUser, sessionList[0].Events[0].Response.Choices[0].Message.Role)
14201420
}
14211421

14221422
func TestGetSession_AllAssistantEvents_Integration(t *testing.T) {
@@ -1447,6 +1447,21 @@ func TestGetSession_AllAssistantEvents_Integration(t *testing.T) {
14471447
Choices: []model.Choice{
14481448
{
14491449
Index: 0,
1450+
Message: model.Message{
1451+
Role: model.RoleUser,
1452+
Content: "user message 1",
1453+
},
1454+
},
1455+
},
1456+
},
1457+
},
1458+
{
1459+
ID: "event1",
1460+
Timestamp: baseTime.Add(-3 * time.Hour),
1461+
Response: &model.Response{
1462+
Choices: []model.Choice{
1463+
{
1464+
Index: 1,
14501465
Message: model.Message{
14511466
Role: model.RoleAssistant,
14521467
Content: "Assistant message 1",
@@ -1461,7 +1476,7 @@ func TestGetSession_AllAssistantEvents_Integration(t *testing.T) {
14611476
Response: &model.Response{
14621477
Choices: []model.Choice{
14631478
{
1464-
Index: 0,
1479+
Index: 2,
14651480
Message: model.Message{
14661481
Role: model.RoleAssistant,
14671482
Content: "Assistant message 2",
@@ -1484,7 +1499,7 @@ func TestGetSession_AllAssistantEvents_Integration(t *testing.T) {
14841499
require.NotNil(t, retrievedSess)
14851500

14861501
// Should have no events since all are from assistant
1487-
assert.Equal(t, 2, len(retrievedSess.Events), "Should filter out all assistant events when no user events exist")
1502+
assert.Equal(t, 3, len(retrievedSess.Events), "Should filter out all assistant events when no user events exist")
14881503
}
14891504

14901505
func TestService_Close_MultipleTimes(t *testing.T) {

0 commit comments

Comments
 (0)