Skip to content

Commit efe77f8

Browse files
davidggphywesmclaude
authored
Add --account flag to search command (#165)
Add `--account` flag to `msgvault search` to scope results to a single account when multiple accounts are synced into the same archive. ``` $ msgvault search "budget proposal" --account alice@work.com # only messages belonging to alice@work.com ``` The flag also works standalone to list all messages for an account: ``` $ msgvault search --account alice@work.com ``` Additional fixes included in this PR: - Reject `--account` in remote mode with a clear error instead of silently ignoring it - Fail fast on invalid queries before opening the database - Add `AccountID` to `Query.IsEmpty()` so it is treated as a valid filter criterion - Extract shared output helpers (`formatSize`, `printJSON`, aggregate formatters) from `search.go` into `output.go` Closes #164. --------- Co-authored-by: Wes McKinney <wesmckinn+git@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 21d4633 commit efe77f8

File tree

5 files changed

+456
-116
lines changed

5 files changed

+456
-116
lines changed

cmd/msgvault/cmd/output.go

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
package cmd
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"os"
7+
"strings"
8+
"text/tabwriter"
9+
"time"
10+
11+
"github.com/spf13/cobra"
12+
"github.com/wesm/msgvault/internal/query"
13+
)
14+
15+
// Common flag variables used across aggregate commands
16+
var (
17+
aggLimit int
18+
aggAfter string
19+
aggBefore string
20+
aggJSON bool
21+
)
22+
23+
// parseCommonFlags converts string flags to AggregateOptions
24+
func parseCommonFlags() (query.AggregateOptions, error) {
25+
opts := query.DefaultAggregateOptions()
26+
27+
if aggLimit > 0 {
28+
opts.Limit = aggLimit
29+
}
30+
31+
if aggAfter != "" {
32+
t, err := time.Parse("2006-01-02", aggAfter)
33+
if err != nil {
34+
return opts, fmt.Errorf("invalid after date: %w", err)
35+
}
36+
opts.After = &t
37+
}
38+
39+
if aggBefore != "" {
40+
t, err := time.Parse("2006-01-02", aggBefore)
41+
if err != nil {
42+
return opts, fmt.Errorf("invalid before date: %w", err)
43+
}
44+
opts.Before = &t
45+
}
46+
47+
return opts, nil
48+
}
49+
50+
// addCommonAggregateFlags adds shared flags to aggregate commands
51+
func addCommonAggregateFlags(cmd *cobra.Command) {
52+
cmd.Flags().IntVarP(
53+
&aggLimit, "limit", "n", 50, "Maximum number of results",
54+
)
55+
cmd.Flags().StringVar(
56+
&aggAfter, "after", "",
57+
"Filter to messages after date (YYYY-MM-DD)",
58+
)
59+
cmd.Flags().StringVar(
60+
&aggBefore, "before", "",
61+
"Filter to messages before date (YYYY-MM-DD)",
62+
)
63+
cmd.Flags().BoolVar(
64+
&aggJSON, "json", false, "Output as JSON",
65+
)
66+
}
67+
68+
// outputAggregateTable prints aggregate results as a table
69+
func outputAggregateTable(
70+
rows []query.AggregateRow, keyHeader string,
71+
) {
72+
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
73+
_, _ = fmt.Fprintf(
74+
w, "%s\tCOUNT\tSIZE\tATT SIZE\n",
75+
strings.ToUpper(keyHeader),
76+
)
77+
_, _ = fmt.Fprintln(
78+
w,
79+
strings.Repeat("─", len(keyHeader))+
80+
"\t─────\t────\t────────",
81+
)
82+
83+
for _, row := range rows {
84+
_, _ = fmt.Fprintf(w, "%s\t%d\t%s\t%s\n",
85+
truncate(row.Key, 40),
86+
row.Count,
87+
formatSize(row.TotalSize),
88+
formatSize(row.AttachmentSize),
89+
)
90+
}
91+
_ = w.Flush()
92+
fmt.Printf("\nShowing %d results\n", len(rows))
93+
}
94+
95+
// outputAggregateJSON prints aggregate results as JSON
96+
func outputAggregateJSON(rows []query.AggregateRow) error {
97+
output := make([]map[string]interface{}, len(rows))
98+
for i, row := range rows {
99+
output[i] = map[string]interface{}{
100+
"key": row.Key,
101+
"count": row.Count,
102+
"total_size": row.TotalSize,
103+
"attachment_size": row.AttachmentSize,
104+
}
105+
}
106+
return printJSON(output)
107+
}
108+
109+
func formatSize(bytes int64) string {
110+
const (
111+
KB = 1024
112+
MB = 1024 * KB
113+
GB = 1024 * MB
114+
)
115+
116+
switch {
117+
case bytes >= GB:
118+
return fmt.Sprintf("%.1fG", float64(bytes)/float64(GB))
119+
case bytes >= MB:
120+
return fmt.Sprintf("%.1fM", float64(bytes)/float64(MB))
121+
case bytes >= KB:
122+
return fmt.Sprintf("%.1fK", float64(bytes)/float64(KB))
123+
default:
124+
return fmt.Sprintf("%dB", bytes)
125+
}
126+
}
127+
128+
func printJSON(v any) error {
129+
enc := json.NewEncoder(os.Stdout)
130+
enc.SetIndent("", " ")
131+
return enc.Encode(v)
132+
}

cmd/msgvault/cmd/search.go

Lines changed: 39 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package cmd
22

33
import (
4-
"encoding/json"
54
"fmt"
65
"os"
76
"strings"
@@ -15,9 +14,10 @@ import (
1514
)
1615

1716
var (
18-
searchLimit int
19-
searchOffset int
20-
searchJSON bool
17+
searchLimit int
18+
searchOffset int
19+
searchJSON bool
20+
searchAccount string
2121
)
2222

2323
var searchCmd = &cobra.Command{
@@ -50,17 +50,22 @@ Examples:
5050
msgvault search subject:meeting after:2024-01-01
5151
msgvault search project report newer_than:30d
5252
msgvault search '"exact phrase"' label:INBOX`,
53-
Args: cobra.MinimumNArgs(1),
53+
Args: cobra.ArbitraryArgs,
5454
RunE: func(cmd *cobra.Command, args []string) error {
5555
// Join all args to form the query (allows unquoted multi-term searches)
5656
queryStr := strings.Join(args, " ")
5757

58-
if queryStr == "" {
59-
return fmt.Errorf("empty search query")
58+
if queryStr == "" && searchAccount == "" {
59+
return fmt.Errorf("provide a search query or --account flag")
6060
}
6161

6262
// Use remote search if configured
6363
if IsRemoteMode() {
64+
if searchAccount != "" {
65+
return fmt.Errorf(
66+
"--account is not supported in remote mode",
67+
)
68+
}
6469
return runRemoteSearch(queryStr)
6570
}
6671

@@ -99,12 +104,13 @@ func runRemoteSearch(queryStr string) error {
99104
func runLocalSearch(cmd *cobra.Command, queryStr string) error {
100105
// Parse the query
101106
q := search.Parse(queryStr)
102-
if q.IsEmpty() {
107+
108+
// Fail fast on invalid queries before touching the database,
109+
// unless --account is set (which requires a DB lookup to resolve).
110+
if searchAccount == "" && q.IsEmpty() {
103111
return fmt.Errorf("empty search query")
104112
}
105113

106-
fmt.Fprintf(os.Stderr, "Searching...")
107-
108114
// Open database
109115
dbPath := cfg.DatabaseDSN()
110116
s, err := store.Open(dbPath)
@@ -117,6 +123,25 @@ func runLocalSearch(cmd *cobra.Command, queryStr string) error {
117123
if err := s.InitSchema(); err != nil {
118124
return fmt.Errorf("init schema: %w", err)
119125
}
126+
127+
// Resolve --account and recheck emptiness.
128+
if searchAccount != "" {
129+
src, err := s.GetSourceByIdentifier(searchAccount)
130+
if err != nil {
131+
return fmt.Errorf("look up account: %w", err)
132+
}
133+
if src == nil {
134+
return fmt.Errorf("account %q not found", searchAccount)
135+
}
136+
q.AccountID = &src.ID
137+
}
138+
139+
if q.IsEmpty() {
140+
return fmt.Errorf("empty search query")
141+
}
142+
143+
fmt.Fprintf(os.Stderr, "Searching...")
144+
120145
if err := ensureFTSIndex(s); err != nil {
121146
return err
122147
}
@@ -177,13 +202,10 @@ func outputRemoteSearchResultsTable(results []store.APIMessage, total int64) err
177202
}
178203

179204
func outputRemoteSearchResultsJSON(results []store.APIMessage, total int64) error {
180-
output := map[string]interface{}{
205+
return printJSON(map[string]interface{}{
181206
"total": total,
182207
"results": results,
183-
}
184-
enc := json.NewEncoder(os.Stdout)
185-
enc.SetIndent("", " ")
186-
return enc.Encode(output)
208+
})
187209
}
188210

189211
func outputSearchResultsJSON(results []query.MessageSummary) error {
@@ -205,35 +227,15 @@ func outputSearchResultsJSON(results []query.MessageSummary) error {
205227
}
206228
}
207229

208-
enc := json.NewEncoder(os.Stdout)
209-
enc.SetIndent("", " ")
210-
return enc.Encode(output)
211-
}
212-
213-
func formatSize(bytes int64) string {
214-
const (
215-
KB = 1024
216-
MB = 1024 * KB
217-
GB = 1024 * MB
218-
)
219-
220-
switch {
221-
case bytes >= GB:
222-
return fmt.Sprintf("%.1fG", float64(bytes)/float64(GB))
223-
case bytes >= MB:
224-
return fmt.Sprintf("%.1fM", float64(bytes)/float64(MB))
225-
case bytes >= KB:
226-
return fmt.Sprintf("%.1fK", float64(bytes)/float64(KB))
227-
default:
228-
return fmt.Sprintf("%dB", bytes)
229-
}
230+
return printJSON(output)
230231
}
231232

232233
func init() {
233234
rootCmd.AddCommand(searchCmd)
234235
searchCmd.Flags().IntVarP(&searchLimit, "limit", "n", 50, "Maximum number of results")
235236
searchCmd.Flags().IntVar(&searchOffset, "offset", 0, "Skip first N results")
236237
searchCmd.Flags().BoolVar(&searchJSON, "json", false, "Output as JSON")
238+
searchCmd.Flags().StringVar(&searchAccount, "account", "", "Limit results to a specific account (email address)")
237239
}
238240

239241
// ensureFTSIndex checks if the FTS search index needs to be built and
@@ -264,81 +266,3 @@ func ensureFTSIndex(s *store.Store) error {
264266
fmt.Fprintf(os.Stderr, "\r [%s] 100%% %d messages indexed.\n", strings.Repeat("=", 30), n)
265267
return nil
266268
}
267-
268-
// Common flag variables used across aggregate commands
269-
var (
270-
aggLimit int
271-
aggAfter string
272-
aggBefore string
273-
aggJSON bool
274-
)
275-
276-
// parseCommonFlags converts string flags to AggregateOptions
277-
func parseCommonFlags() (query.AggregateOptions, error) {
278-
opts := query.DefaultAggregateOptions()
279-
280-
if aggLimit > 0 {
281-
opts.Limit = aggLimit
282-
}
283-
284-
if aggAfter != "" {
285-
t, err := time.Parse("2006-01-02", aggAfter)
286-
if err != nil {
287-
return opts, fmt.Errorf("invalid after date: %w", err)
288-
}
289-
opts.After = &t
290-
}
291-
292-
if aggBefore != "" {
293-
t, err := time.Parse("2006-01-02", aggBefore)
294-
if err != nil {
295-
return opts, fmt.Errorf("invalid before date: %w", err)
296-
}
297-
opts.Before = &t
298-
}
299-
300-
return opts, nil
301-
}
302-
303-
// addCommonAggregateFlags adds shared flags to aggregate commands
304-
func addCommonAggregateFlags(cmd *cobra.Command) {
305-
cmd.Flags().IntVarP(&aggLimit, "limit", "n", 50, "Maximum number of results")
306-
cmd.Flags().StringVar(&aggAfter, "after", "", "Filter to messages after date (YYYY-MM-DD)")
307-
cmd.Flags().StringVar(&aggBefore, "before", "", "Filter to messages before date (YYYY-MM-DD)")
308-
cmd.Flags().BoolVar(&aggJSON, "json", false, "Output as JSON")
309-
}
310-
311-
// outputAggregateTable prints aggregate results as a table
312-
func outputAggregateTable(rows []query.AggregateRow, keyHeader string) {
313-
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
314-
fmt.Fprintf(w, "%s\tCOUNT\tSIZE\tATT SIZE\n", strings.ToUpper(keyHeader))
315-
fmt.Fprintln(w, strings.Repeat("─", len(keyHeader))+"\t─────\t────\t────────")
316-
317-
for _, row := range rows {
318-
fmt.Fprintf(w, "%s\t%d\t%s\t%s\n",
319-
truncate(row.Key, 40),
320-
row.Count,
321-
formatSize(row.TotalSize),
322-
formatSize(row.AttachmentSize),
323-
)
324-
}
325-
w.Flush()
326-
fmt.Printf("\nShowing %d results\n", len(rows))
327-
}
328-
329-
// outputAggregateJSON prints aggregate results as JSON
330-
func outputAggregateJSON(rows []query.AggregateRow) error {
331-
output := make([]map[string]interface{}, len(rows))
332-
for i, row := range rows {
333-
output[i] = map[string]interface{}{
334-
"key": row.Key,
335-
"count": row.Count,
336-
"total_size": row.TotalSize,
337-
"attachment_size": row.AttachmentSize,
338-
}
339-
}
340-
341-
enc := json.NewEncoder(os.Stdout)
342-
enc.SetIndent("", " ")
343-
return enc.Encode(output)
344-
}

0 commit comments

Comments
 (0)