Skip to content

Commit 56a2d4f

Browse files
EconoBenclaude
andcommitted
fix: address roborev findings for API handlers and remote engine
- Use r.Context() instead of context.Background() in all API handlers (handleAggregates, handleSubAggregates, handleFilteredMessages, handleTotalStats, handleFastSearch, handleDeepSearch) - Add doRequestWithContext to remote.Store for context propagation - Plumb context through all remote.Engine HTTP requests for proper cancellation/timeout support - Fix SubAggregate to use opts.TimeGranularity instead of filter - Add validation for view_type in handleFastSearch (return 400 if invalid) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 3b6d06c commit 56a2d4f

File tree

3 files changed

+29
-22
lines changed

3 files changed

+29
-22
lines changed

internal/api/handlers.go

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package api
22

33
import (
4-
"context"
54
"crypto/sha256"
65
"encoding/json"
76
"fmt"
@@ -875,8 +874,7 @@ func (s *Server) handleAggregates(w http.ResponseWriter, r *http.Request) {
875874

876875
opts := parseAggregateOptions(r)
877876

878-
ctx := context.Background()
879-
rows, err := s.engine.Aggregate(ctx, viewType, opts)
877+
rows, err := s.engine.Aggregate(r.Context(), viewType, opts)
880878
if err != nil {
881879
s.logger.Error("aggregate query failed", "view_type", viewTypeStr, "error", err)
882880
writeError(w, http.StatusInternalServerError, "internal_error", "Aggregate query failed")
@@ -917,8 +915,7 @@ func (s *Server) handleSubAggregates(w http.ResponseWriter, r *http.Request) {
917915
filter := parseMessageFilter(r)
918916
opts := parseAggregateOptions(r)
919917

920-
ctx := context.Background()
921-
rows, err := s.engine.SubAggregate(ctx, filter, viewType, opts)
918+
rows, err := s.engine.SubAggregate(r.Context(), filter, viewType, opts)
922919
if err != nil {
923920
s.logger.Error("sub-aggregate query failed", "view_type", viewTypeStr, "error", err)
924921
writeError(w, http.StatusInternalServerError, "internal_error", "Sub-aggregate query failed")
@@ -946,8 +943,7 @@ func (s *Server) handleFilteredMessages(w http.ResponseWriter, r *http.Request)
946943

947944
filter := parseMessageFilter(r)
948945

949-
ctx := context.Background()
950-
messages, err := s.engine.ListMessages(ctx, filter)
946+
messages, err := s.engine.ListMessages(r.Context(), filter)
951947
if err != nil {
952948
s.logger.Error("filtered messages query failed", "error", err)
953949
writeError(w, http.StatusInternalServerError, "internal_error", "Message query failed")
@@ -997,8 +993,7 @@ func (s *Server) handleTotalStats(w http.ResponseWriter, r *http.Request) {
997993
}
998994
}
999995

1000-
ctx := context.Background()
1001-
stats, err := s.engine.GetTotalStats(ctx, opts)
996+
stats, err := s.engine.GetTotalStats(r.Context(), opts)
1002997
if err != nil {
1003998
s.logger.Error("total stats query failed", "error", err)
1004999
writeError(w, http.StatusInternalServerError, "internal_error", "Stats query failed")
@@ -1024,10 +1019,16 @@ func (s *Server) handleFastSearch(w http.ResponseWriter, r *http.Request) {
10241019

10251020
filter := parseMessageFilter(r)
10261021

1027-
// Get view type for stats grouping
1022+
// Get view type for stats grouping (optional, defaults to senders)
10281023
var statsGroupBy query.ViewType
10291024
if v := r.URL.Query().Get("view_type"); v != "" {
1030-
statsGroupBy, _ = parseViewType(v)
1025+
var ok bool
1026+
statsGroupBy, ok = parseViewType(v)
1027+
if !ok {
1028+
writeError(w, http.StatusBadRequest, "invalid_view_type",
1029+
"Invalid view_type. Must be one of: senders, sender_names, recipients, recipient_names, domains, labels, time")
1030+
return
1031+
}
10311032
}
10321033

10331034
offset := filter.Pagination.Offset
@@ -1036,10 +1037,9 @@ func (s *Server) handleFastSearch(w http.ResponseWriter, r *http.Request) {
10361037
limit = 100
10371038
}
10381039

1039-
ctx := context.Background()
10401040
q := search.Parse(queryStr)
10411041

1042-
result, err := s.engine.SearchFastWithStats(ctx, q, queryStr, filter, statsGroupBy, limit, offset)
1042+
result, err := s.engine.SearchFastWithStats(r.Context(), q, queryStr, filter, statsGroupBy, limit, offset)
10431043
if err != nil {
10441044
s.logger.Error("fast search failed", "query", queryStr, "error", err)
10451045
writeError(w, http.StatusInternalServerError, "internal_error", "Search failed")
@@ -1082,10 +1082,9 @@ func (s *Server) handleDeepSearch(w http.ResponseWriter, r *http.Request) {
10821082
limit = 100
10831083
}
10841084

1085-
ctx := context.Background()
10861085
q := search.Parse(queryStr)
10871086

1088-
messages, err := s.engine.Search(ctx, q, limit, offset)
1087+
messages, err := s.engine.Search(r.Context(), q, limit, offset)
10891088
if err != nil {
10901089
s.logger.Error("deep search failed", "query", queryStr, "error", err)
10911090
writeError(w, http.StatusInternalServerError, "internal_error", "Search failed")

internal/remote/engine.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ func (e *Engine) Aggregate(ctx context.Context, groupBy query.ViewType, opts que
351351
params := buildAggregateQuery(groupBy, opts)
352352
path := "/api/v1/aggregates?" + params.Encode()
353353

354-
resp, err := e.store.doRequest("GET", path, nil)
354+
resp, err := e.store.doRequestWithContext(ctx, "GET", path, nil)
355355
if err != nil {
356356
return nil, err
357357
}
@@ -380,13 +380,15 @@ func (e *Engine) SubAggregate(ctx context.Context, filter query.MessageFilter, g
380380
if opts.Limit > 0 {
381381
params.Set("limit", strconv.Itoa(opts.Limit))
382382
}
383+
// Use TimeGranularity from opts, not from filter (fixes roborev finding)
384+
params.Set("time_granularity", timeGranularityToString(opts.TimeGranularity))
383385
if opts.SearchQuery != "" {
384386
params.Set("search_query", opts.SearchQuery)
385387
}
386388

387389
path := "/api/v1/aggregates/sub?" + params.Encode()
388390

389-
resp, err := e.store.doRequest("GET", path, nil)
391+
resp, err := e.store.doRequestWithContext(ctx, "GET", path, nil)
390392
if err != nil {
391393
return nil, err
392394
}
@@ -409,7 +411,7 @@ func (e *Engine) ListMessages(ctx context.Context, filter query.MessageFilter) (
409411
params := buildFilterQuery(filter)
410412
path := "/api/v1/messages/filter?" + params.Encode()
411413

412-
resp, err := e.store.doRequest("GET", path, nil)
414+
resp, err := e.store.doRequestWithContext(ctx, "GET", path, nil)
413415
if err != nil {
414416
return nil, err
415417
}
@@ -499,7 +501,7 @@ func (e *Engine) Search(ctx context.Context, q *search.Query, limit, offset int)
499501

500502
path := "/api/v1/search/deep?" + params.Encode()
501503

502-
resp, err := e.store.doRequest("GET", path, nil)
504+
resp, err := e.store.doRequestWithContext(ctx, "GET", path, nil)
503505
if err != nil {
504506
return nil, err
505507
}
@@ -549,7 +551,7 @@ func (e *Engine) SearchFastWithStats(ctx context.Context, q *search.Query, query
549551

550552
path := "/api/v1/search/fast?" + params.Encode()
551553

552-
resp, err := e.store.doRequest("GET", path, nil)
554+
resp, err := e.store.doRequestWithContext(ctx, "GET", path, nil)
553555
if err != nil {
554556
return nil, err
555557
}
@@ -611,7 +613,7 @@ func (e *Engine) GetTotalStats(ctx context.Context, opts query.StatsOptions) (*q
611613
params := buildStatsQuery(opts)
612614
path := "/api/v1/stats/total?" + params.Encode()
613615

614-
resp, err := e.store.doRequest("GET", path, nil)
616+
resp, err := e.store.doRequestWithContext(ctx, "GET", path, nil)
615617
if err != nil {
616618
return nil, err
617619
}

internal/remote/store.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
package remote
33

44
import (
5+
"context"
56
"encoding/json"
67
"fmt"
78
"io"
@@ -77,9 +78,14 @@ func (s *Store) Close() error {
7778

7879
// doRequest performs an authenticated HTTP request.
7980
func (s *Store) doRequest(method, path string, body io.Reader) (*http.Response, error) {
81+
return s.doRequestWithContext(context.Background(), method, path, body)
82+
}
83+
84+
// doRequestWithContext performs an authenticated HTTP request with context support.
85+
func (s *Store) doRequestWithContext(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
8086
reqURL := s.baseURL + path
8187

82-
req, err := http.NewRequest(method, reqURL, body)
88+
req, err := http.NewRequestWithContext(ctx, method, reqURL, body)
8389
if err != nil {
8490
return nil, fmt.Errorf("create request: %w", err)
8591
}

0 commit comments

Comments
 (0)