|
4 | 4 | package cmd |
5 | 5 |
|
6 | 6 | import ( |
| 7 | + "encoding/base64" |
7 | 8 | "fmt" |
8 | 9 | "io" |
| 10 | + "net/http" |
9 | 11 | "os" |
| 12 | + "path/filepath" |
10 | 13 | "strings" |
11 | 14 |
|
12 | 15 | "github.com/spf13/cobra" |
13 | | - "github.com/wavetermdev/waveterm/pkg/waveobj" |
| 16 | + "github.com/wavetermdev/waveterm/pkg/util/utilfn" |
14 | 17 | "github.com/wavetermdev/waveterm/pkg/wshrpc" |
15 | 18 | "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" |
16 | 19 | "github.com/wavetermdev/waveterm/pkg/wshutil" |
17 | 20 | ) |
18 | 21 |
|
19 | 22 | var aiCmd = &cobra.Command{ |
20 | | - Use: "ai [-] [message...]", |
21 | | - Short: "Send a message to an AI block", |
| 23 | + Use: "ai [options] [files...]", |
| 24 | + Short: "Append content to Wave AI sidebar prompt", |
| 25 | + Long: `Append content to Wave AI sidebar prompt (does not auto-submit by default) |
| 26 | +
|
| 27 | +Arguments: |
| 28 | + files... Files to attach (use '-' for stdin) |
| 29 | +
|
| 30 | +Examples: |
| 31 | + git diff | wsh ai - # Pipe diff to AI, ask question in UI |
| 32 | + wsh ai main.go # Attach file, ask question in UI |
| 33 | + wsh ai *.go -m "find bugs" # Attach files with message |
| 34 | + wsh ai -s - -m "review" < log.txt # Stdin + message, auto-submit |
| 35 | + wsh ai -n config.json # New chat with file attached`, |
22 | 36 | RunE: aiRun, |
23 | 37 | PreRunE: preRunSetupRpcClient, |
24 | 38 | DisableFlagsInUseLine: true, |
25 | 39 | } |
26 | 40 |
|
27 | | -var aiFileFlags []string |
| 41 | +var aiMessageFlag string |
| 42 | +var aiSubmitFlag bool |
28 | 43 | var aiNewBlockFlag bool |
29 | 44 |
|
30 | 45 | func init() { |
31 | 46 | rootCmd.AddCommand(aiCmd) |
32 | | - aiCmd.Flags().BoolVarP(&aiNewBlockFlag, "new", "n", false, "create a new AI block") |
33 | | - aiCmd.Flags().StringArrayVarP(&aiFileFlags, "file", "f", nil, "attach file content (use '-' for stdin)") |
| 47 | + aiCmd.Flags().StringVarP(&aiMessageFlag, "message", "m", "", "optional message/question to append after files") |
| 48 | + aiCmd.Flags().BoolVarP(&aiSubmitFlag, "submit", "s", false, "submit the prompt immediately after appending") |
| 49 | + aiCmd.Flags().BoolVarP(&aiNewBlockFlag, "new", "n", false, "create a new AI chat instead of using existing") |
34 | 50 | } |
35 | 51 |
|
36 | | -func encodeFile(builder *strings.Builder, file io.Reader, fileName string) error { |
37 | | - data, err := io.ReadAll(file) |
38 | | - if err != nil { |
39 | | - return fmt.Errorf("error reading file: %w", err) |
| 52 | +func detectMimeType(data []byte) string { |
| 53 | + mimeType := http.DetectContentType(data) |
| 54 | + return strings.Split(mimeType, ";")[0] |
| 55 | +} |
| 56 | + |
| 57 | +func getMaxFileSize(mimeType string) (int, string) { |
| 58 | + if mimeType == "application/pdf" { |
| 59 | + return 5 * 1024 * 1024, "5MB" |
40 | 60 | } |
41 | | - // Start delimiter with the file name |
42 | | - builder.WriteString(fmt.Sprintf("\n@@@start file %q\n", fileName)) |
43 | | - // Read the file content and write it to the builder |
44 | | - builder.Write(data) |
45 | | - // End delimiter with the file name |
46 | | - builder.WriteString(fmt.Sprintf("\n@@@end file %q\n\n", fileName)) |
47 | | - return nil |
| 61 | + if strings.HasPrefix(mimeType, "image/") { |
| 62 | + return 7 * 1024 * 1024, "7MB" |
| 63 | + } |
| 64 | + return 200 * 1024, "200KB" |
48 | 65 | } |
49 | 66 |
|
50 | 67 | func aiRun(cmd *cobra.Command, args []string) (rtnErr error) { |
51 | 68 | defer func() { |
52 | 69 | sendActivity("ai", rtnErr == nil) |
53 | 70 | }() |
54 | 71 |
|
55 | | - if len(args) == 0 { |
| 72 | + if len(args) == 0 && aiMessageFlag == "" { |
56 | 73 | OutputHelpMessage(cmd) |
57 | | - return fmt.Errorf("no message provided") |
| 74 | + return fmt.Errorf("no files or message provided") |
58 | 75 | } |
59 | 76 |
|
| 77 | + const maxFileCount = 15 |
| 78 | + const rpcTimeout = 30000 |
| 79 | + |
| 80 | + var allFiles []wshrpc.AIAttachedFile |
60 | 81 | var stdinUsed bool |
61 | | - var message strings.Builder |
62 | 82 |
|
63 | | - // Handle file attachments first |
64 | | - for _, file := range aiFileFlags { |
65 | | - if file == "-" { |
| 83 | + if len(args) > maxFileCount { |
| 84 | + return fmt.Errorf("too many files (maximum %d files allowed)", maxFileCount) |
| 85 | + } |
| 86 | + |
| 87 | + for _, filePath := range args { |
| 88 | + var data []byte |
| 89 | + var fileName string |
| 90 | + var mimeType string |
| 91 | + var err error |
| 92 | + |
| 93 | + if filePath == "-" { |
66 | 94 | if stdinUsed { |
67 | 95 | return fmt.Errorf("stdin (-) can only be used once") |
68 | 96 | } |
69 | 97 | stdinUsed = true |
70 | | - if err := encodeFile(&message, os.Stdin, "<stdin>"); err != nil { |
| 98 | + |
| 99 | + data, err = io.ReadAll(os.Stdin) |
| 100 | + if err != nil { |
71 | 101 | return fmt.Errorf("reading from stdin: %w", err) |
72 | 102 | } |
| 103 | + fileName = "stdin" |
| 104 | + mimeType = "text/plain" |
73 | 105 | } else { |
74 | | - fd, err := os.Open(file) |
| 106 | + fileInfo, err := os.Stat(filePath) |
75 | 107 | if err != nil { |
76 | | - return fmt.Errorf("opening file %s: %w", file, err) |
| 108 | + return fmt.Errorf("accessing file %s: %w", filePath, err) |
77 | 109 | } |
78 | | - defer fd.Close() |
79 | | - if err := encodeFile(&message, fd, file); err != nil { |
80 | | - return fmt.Errorf("reading file %s: %w", file, err) |
| 110 | + if fileInfo.IsDir() { |
| 111 | + return fmt.Errorf("%s is a directory, not a file", filePath) |
81 | 112 | } |
82 | | - } |
83 | | - } |
84 | 113 |
|
85 | | - // Default to "waveai" block |
86 | | - isDefaultBlock := blockArg == "" |
87 | | - if isDefaultBlock { |
88 | | - blockArg = "view@waveai" |
89 | | - } |
90 | | - var fullORef *waveobj.ORef |
91 | | - var err error |
92 | | - if !aiNewBlockFlag { |
93 | | - fullORef, err = resolveSimpleId(blockArg) |
94 | | - } |
95 | | - if (err != nil && isDefaultBlock) || aiNewBlockFlag { |
96 | | - // Create new AI block if default block doesn't exist |
97 | | - data := &wshrpc.CommandCreateBlockData{ |
98 | | - BlockDef: &waveobj.BlockDef{ |
99 | | - Meta: map[string]interface{}{ |
100 | | - waveobj.MetaKey_View: "waveai", |
101 | | - }, |
102 | | - }, |
103 | | - Focused: true, |
| 114 | + data, err = os.ReadFile(filePath) |
| 115 | + if err != nil { |
| 116 | + return fmt.Errorf("reading file %s: %w", filePath, err) |
| 117 | + } |
| 118 | + fileName = filepath.Base(filePath) |
| 119 | + mimeType = detectMimeType(data) |
104 | 120 | } |
105 | 121 |
|
106 | | - newORef, err := wshclient.CreateBlockCommand(RpcClient, *data, &wshrpc.RpcOpts{Timeout: 2000}) |
107 | | - if err != nil { |
108 | | - return fmt.Errorf("creating AI block: %w", err) |
109 | | - } |
110 | | - fullORef = &newORef |
111 | | - // Wait for the block's route to be available |
112 | | - gotRoute, err := wshclient.WaitForRouteCommand(RpcClient, wshrpc.CommandWaitForRouteData{ |
113 | | - RouteId: wshutil.MakeFeBlockRouteId(fullORef.OID), |
114 | | - WaitMs: 4000, |
115 | | - }, &wshrpc.RpcOpts{Timeout: 5000}) |
116 | | - if err != nil { |
117 | | - return fmt.Errorf("waiting for AI block: %w", err) |
| 122 | + isPDF := mimeType == "application/pdf" |
| 123 | + isImage := strings.HasPrefix(mimeType, "image/") |
| 124 | + |
| 125 | + if !isPDF && !isImage { |
| 126 | + mimeType = "text/plain" |
| 127 | + if utilfn.ContainsBinaryData(data) { |
| 128 | + return fmt.Errorf("file %s contains binary data and cannot be uploaded as text", fileName) |
| 129 | + } |
118 | 130 | } |
119 | | - if !gotRoute { |
120 | | - return fmt.Errorf("AI block route could not be established") |
| 131 | + |
| 132 | + maxSize, sizeStr := getMaxFileSize(mimeType) |
| 133 | + if len(data) > maxSize { |
| 134 | + return fmt.Errorf("file %s exceeds maximum size of %s for %s files", fileName, sizeStr, mimeType) |
121 | 135 | } |
122 | | - } else if err != nil { |
123 | | - return fmt.Errorf("resolving block: %w", err) |
| 136 | + |
| 137 | + allFiles = append(allFiles, wshrpc.AIAttachedFile{ |
| 138 | + Name: fileName, |
| 139 | + Type: mimeType, |
| 140 | + Size: len(data), |
| 141 | + Data64: base64.StdEncoding.EncodeToString(data), |
| 142 | + }) |
124 | 143 | } |
125 | 144 |
|
126 | | - // Create the route for this block |
127 | | - route := wshutil.MakeFeBlockRouteId(fullORef.OID) |
| 145 | + tabId := os.Getenv("WAVETERM_TABID") |
| 146 | + if tabId == "" { |
| 147 | + return fmt.Errorf("WAVETERM_TABID environment variable not set") |
| 148 | + } |
| 149 | + |
| 150 | + route := wshutil.MakeTabRouteId(tabId) |
128 | 151 |
|
129 | | - // Then handle main message |
130 | | - if args[0] == "-" { |
131 | | - if stdinUsed { |
132 | | - return fmt.Errorf("stdin (-) can only be used once") |
| 152 | + if aiNewBlockFlag { |
| 153 | + newChatData := wshrpc.CommandWaveAIAddContextData{ |
| 154 | + NewChat: true, |
133 | 155 | } |
134 | | - data, err := io.ReadAll(os.Stdin) |
| 156 | + err := wshclient.WaveAIAddContextCommand(RpcClient, newChatData, &wshrpc.RpcOpts{ |
| 157 | + Route: route, |
| 158 | + Timeout: rpcTimeout, |
| 159 | + }) |
135 | 160 | if err != nil { |
136 | | - return fmt.Errorf("reading from stdin: %w", err) |
137 | | - } |
138 | | - message.Write(data) |
139 | | - |
140 | | - // Also include any remaining arguments (excluding the "-" itself) |
141 | | - if len(args) > 1 { |
142 | | - if message.Len() > 0 { |
143 | | - message.WriteString(" ") |
144 | | - } |
145 | | - message.WriteString(strings.Join(args[1:], " ")) |
| 161 | + return fmt.Errorf("creating new chat: %w", err) |
146 | 162 | } |
147 | | - } else { |
148 | | - message.WriteString(strings.Join(args, " ")) |
149 | 163 | } |
150 | 164 |
|
151 | | - if message.Len() == 0 { |
152 | | - return fmt.Errorf("message is empty") |
153 | | - } |
154 | | - if message.Len() > 50*1024 { |
155 | | - return fmt.Errorf("current max message size is 50k") |
| 165 | + for _, file := range allFiles { |
| 166 | + contextData := wshrpc.CommandWaveAIAddContextData{ |
| 167 | + Files: []wshrpc.AIAttachedFile{file}, |
| 168 | + } |
| 169 | + err := wshclient.WaveAIAddContextCommand(RpcClient, contextData, &wshrpc.RpcOpts{ |
| 170 | + Route: route, |
| 171 | + Timeout: rpcTimeout, |
| 172 | + }) |
| 173 | + if err != nil { |
| 174 | + return fmt.Errorf("adding file %s: %w", file.Name, err) |
| 175 | + } |
156 | 176 | } |
157 | 177 |
|
158 | | - messageData := wshrpc.AiMessageData{ |
159 | | - Message: message.String(), |
160 | | - } |
161 | | - err = wshclient.AiSendMessageCommand(RpcClient, messageData, &wshrpc.RpcOpts{ |
162 | | - Route: route, |
163 | | - Timeout: 2000, |
164 | | - }) |
165 | | - if err != nil { |
166 | | - return fmt.Errorf("sending message: %w", err) |
| 178 | + if aiMessageFlag != "" || aiSubmitFlag { |
| 179 | + finalContextData := wshrpc.CommandWaveAIAddContextData{ |
| 180 | + Text: aiMessageFlag, |
| 181 | + Submit: aiSubmitFlag, |
| 182 | + } |
| 183 | + err := wshclient.WaveAIAddContextCommand(RpcClient, finalContextData, &wshrpc.RpcOpts{ |
| 184 | + Route: route, |
| 185 | + Timeout: rpcTimeout, |
| 186 | + }) |
| 187 | + if err != nil { |
| 188 | + return fmt.Errorf("adding context: %w", err) |
| 189 | + } |
167 | 190 | } |
168 | 191 |
|
169 | 192 | return nil |
|
0 commit comments