Skip to content

Commit 0b39ef3

Browse files
committed
feat: add interactive CLI with TUI, check/batch/server commands, LLM integration, and config persistence
1 parent 25cf66c commit 0b39ef3

File tree

9 files changed

+1899
-2
lines changed

9 files changed

+1899
-2
lines changed

cmd/go-promptguard/batch.go

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
package main
2+
3+
import (
4+
"bufio"
5+
"context"
6+
"encoding/csv"
7+
"encoding/json"
8+
"fmt"
9+
"os"
10+
"path/filepath"
11+
"strings"
12+
"time"
13+
14+
"github.com/mdombrov-33/go-promptguard/detector"
15+
)
16+
17+
type BatchResult struct {
18+
Input string `json:"input"`
19+
Result detector.Result `json:"result"`
20+
ProcessedAt time.Time `json:"processed_at"`
21+
}
22+
23+
type BatchSummary struct {
24+
Total int
25+
Safe int
26+
Unsafe int
27+
HighRisk int
28+
MediumRisk int
29+
LowRisk int
30+
Results []BatchResult
31+
Duration time.Duration
32+
}
33+
34+
func ProcessBatch(filePath string, guard *detector.MultiDetector, progressChan chan<- int) (*BatchSummary, error) {
35+
file, err := os.Open(filePath)
36+
if err != nil {
37+
return nil, err
38+
}
39+
defer file.Close()
40+
41+
inputs, err := readInputFile(file, filePath)
42+
if err != nil {
43+
return nil, err
44+
}
45+
46+
summary := &BatchSummary{
47+
Total: len(inputs),
48+
Results: make([]BatchResult, 0, len(inputs)),
49+
}
50+
51+
startTime := time.Now()
52+
ctx := context.Background()
53+
54+
for i, input := range inputs {
55+
input = strings.TrimSpace(input)
56+
if input == "" {
57+
continue
58+
}
59+
60+
result := guard.Detect(ctx, input)
61+
62+
batchResult := BatchResult{
63+
Input: input,
64+
Result: result,
65+
ProcessedAt: time.Now(),
66+
}
67+
summary.Results = append(summary.Results, batchResult)
68+
69+
if result.Safe {
70+
summary.Safe++
71+
} else {
72+
summary.Unsafe++
73+
if result.RiskScore >= 0.9 {
74+
summary.HighRisk++
75+
} else if result.RiskScore >= 0.7 {
76+
summary.MediumRisk++
77+
} else {
78+
summary.LowRisk++
79+
}
80+
}
81+
82+
if progressChan != nil {
83+
progressChan <- i + 1
84+
}
85+
}
86+
87+
summary.Duration = time.Since(startTime)
88+
return summary, nil
89+
}
90+
91+
func readInputFile(file *os.File, filePath string) ([]string, error) {
92+
var inputs []string
93+
94+
file.Seek(0, 0)
95+
96+
if strings.HasSuffix(strings.ToLower(filePath), ".csv") {
97+
reader := csv.NewReader(file)
98+
records, err := reader.ReadAll()
99+
if err != nil {
100+
return nil, err
101+
}
102+
103+
for i, record := range records {
104+
if i == 0 && isHeaderRow(record) {
105+
continue // Skip header
106+
}
107+
if len(record) > 0 {
108+
inputs = append(inputs, record[0])
109+
}
110+
}
111+
} else {
112+
scanner := bufio.NewScanner(file)
113+
for scanner.Scan() {
114+
inputs = append(inputs, scanner.Text())
115+
}
116+
if err := scanner.Err(); err != nil {
117+
return nil, err
118+
}
119+
}
120+
121+
return inputs, nil
122+
}
123+
124+
func isHeaderRow(record []string) bool {
125+
if len(record) == 0 {
126+
return false
127+
}
128+
headers := []string{"input", "text", "prompt", "message", "content"}
129+
first := strings.ToLower(strings.TrimSpace(record[0]))
130+
for _, h := range headers {
131+
if first == h {
132+
return true
133+
}
134+
}
135+
return false
136+
}
137+
138+
func ExportResults(summary *BatchSummary, outputPath string) error {
139+
dir := filepath.Dir(outputPath)
140+
if dir != "." && dir != "" {
141+
if err := os.MkdirAll(dir, 0755); err != nil {
142+
return err
143+
}
144+
}
145+
146+
if strings.HasSuffix(strings.ToLower(outputPath), ".json") {
147+
return exportJSON(summary, outputPath)
148+
}
149+
return exportCSV(summary, outputPath)
150+
}
151+
152+
func exportJSON(summary *BatchSummary, outputPath string) error {
153+
file, err := os.Create(outputPath)
154+
if err != nil {
155+
return err
156+
}
157+
defer file.Close()
158+
159+
encoder := json.NewEncoder(file)
160+
encoder.SetIndent("", " ")
161+
return encoder.Encode(summary)
162+
}
163+
164+
func exportCSV(summary *BatchSummary, outputPath string) error {
165+
file, err := os.Create(outputPath)
166+
if err != nil {
167+
return err
168+
}
169+
defer file.Close()
170+
171+
writer := csv.NewWriter(file)
172+
defer writer.Flush()
173+
174+
header := []string{"Input", "Safe", "Risk Score", "Confidence", "Detected Patterns"}
175+
if err := writer.Write(header); err != nil {
176+
return err
177+
}
178+
179+
for _, result := range summary.Results {
180+
patterns := []string{}
181+
for _, p := range result.Result.DetectedPatterns {
182+
patterns = append(patterns, fmt.Sprintf("%s(%.2f)", p.Type, p.Score))
183+
}
184+
patternsStr := strings.Join(patterns, "; ")
185+
186+
row := []string{
187+
truncateForCSV(result.Input, 100),
188+
fmt.Sprintf("%v", result.Result.Safe),
189+
fmt.Sprintf("%.2f", result.Result.RiskScore),
190+
fmt.Sprintf("%.2f", result.Result.Confidence),
191+
patternsStr,
192+
}
193+
if err := writer.Write(row); err != nil {
194+
return err
195+
}
196+
}
197+
198+
return nil
199+
}
200+
201+
func truncateForCSV(s string, max int) string {
202+
if len(s) <= max {
203+
return s
204+
}
205+
return s[:max] + "..."
206+
}

cmd/go-promptguard/batch_cmd.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"os"
6+
"time"
7+
8+
"github.com/fatih/color"
9+
"github.com/mdombrov-33/go-promptguard/detector"
10+
"github.com/spf13/cobra"
11+
)
12+
13+
var (
14+
batchThreshold float64
15+
batchOutput string
16+
)
17+
18+
var batchCmd = &cobra.Command{
19+
Use: "batch [input-file]",
20+
Short: "Process multiple inputs from a file",
21+
Long: `Process multiple inputs from TXT or CSV file.
22+
23+
Examples:
24+
# Process text file (one input per line)
25+
go-promptguard batch inputs.txt
26+
27+
# Process CSV file (first column)
28+
go-promptguard batch inputs.csv
29+
30+
# Save results to file
31+
go-promptguard batch inputs.txt --output results.json
32+
go-promptguard batch inputs.txt --output results.csv
33+
34+
# Custom threshold
35+
go-promptguard batch inputs.txt --threshold 0.8`,
36+
Run: runBatch,
37+
}
38+
39+
func init() {
40+
rootCmd.AddCommand(batchCmd)
41+
42+
batchCmd.Flags().Float64VarP(&batchThreshold, "threshold", "t", 0.7, "Risk threshold (0.0-1.0)")
43+
batchCmd.Flags().StringVarP(&batchOutput, "output", "o", "", "Output file (JSON or CSV)")
44+
}
45+
46+
func runBatch(cmd *cobra.Command, args []string) {
47+
if len(args) == 0 {
48+
color.Red("Error: no input file provided")
49+
fmt.Println("\nUsage: go-promptguard batch [input-file]")
50+
fmt.Println(" or: go-promptguard batch inputs.txt --output results.json")
51+
os.Exit(1)
52+
}
53+
54+
inputFile := args[0]
55+
56+
guard := detector.New(
57+
detector.WithThreshold(batchThreshold),
58+
)
59+
60+
color.Cyan("📦 Processing batch file: %s", inputFile)
61+
fmt.Println()
62+
63+
startTime := time.Now()
64+
summary, err := ProcessBatch(inputFile, guard, nil)
65+
if err != nil {
66+
color.Red("Error processing batch: %v", err)
67+
os.Exit(1)
68+
}
69+
duration := time.Since(startTime)
70+
71+
fmt.Println()
72+
color.Green("✓ Batch processing complete")
73+
fmt.Println()
74+
fmt.Printf(" Total: %d\n", summary.Total)
75+
fmt.Printf(" Safe: %d\n", summary.Safe)
76+
fmt.Printf(" Unsafe: %d\n", summary.Unsafe)
77+
if summary.Unsafe > 0 {
78+
fmt.Printf(" High: %d\n", summary.HighRisk)
79+
fmt.Printf(" Medium: %d\n", summary.MediumRisk)
80+
fmt.Printf(" Low: %d\n", summary.LowRisk)
81+
}
82+
fmt.Printf(" Duration: %s\n", duration.Round(time.Millisecond))
83+
fmt.Println()
84+
85+
if batchOutput != "" {
86+
if err := ExportResults(summary, batchOutput); err != nil {
87+
color.Red("Error saving results: %v", err)
88+
os.Exit(1)
89+
}
90+
color.Green("✓ Results saved to: %s", batchOutput)
91+
fmt.Println()
92+
}
93+
}

0 commit comments

Comments
 (0)