|
2 | 2 | package main |
3 | 3 |
|
4 | 4 | import ( |
5 | | - "flag" |
6 | | - "strings" |
7 | 5 | "fmt" |
8 | 6 | "os" |
9 | | - "path/filepath" |
10 | 7 | "runtime" |
| 8 | + "strings" |
| 9 | + |
| 10 | + "github.com/spf13/cobra" |
| 11 | + "github.com/spf13/viper" |
11 | 12 | "github.com/yanjiulab/packetforge/pkg/engine" |
12 | 13 | "github.com/yanjiulab/packetforge/pkg/packet" |
13 | 14 | "github.com/yanjiulab/packetforge/pkg/pdl" |
@@ -92,68 +93,112 @@ func VersionString() string { |
92 | 93 | ) |
93 | 94 | } |
94 | 95 |
|
95 | | - |
96 | 96 | func main() { |
97 | | - pdlDir := flag.String("proto", "proto", "Protocol definition directory (.pdl files)") |
98 | | - pslFile := flag.String("stream", "", "Packet stream language file (required)") |
99 | | - iface := flag.String("iface", "lo", "Network interface to send packets (e.g. eth0, lo)") |
100 | | - dryRun := flag.Bool("dry-run", false, "Parse and build packets only, do not actually send") |
101 | | - printVersion := flag.Bool("version", false, "print version and exit") |
102 | | - flag.Parse() |
103 | | - |
104 | | - if *printVersion { |
105 | | - fmt.Println(VersionString()) |
106 | | - return |
| 97 | + rootCmd := newRootCmd() |
| 98 | + if err := rootCmd.Execute(); err != nil { |
| 99 | + fmt.Fprintln(os.Stderr, err) |
| 100 | + os.Exit(1) |
107 | 101 | } |
| 102 | +} |
108 | 103 |
|
109 | | - if *pslFile == "" { |
110 | | - fmt.Fprintf(os.Stderr, "Usage: %s -stream <script.psl> [-proto dir] [-iface interface] [-dry-run]\n", filepath.Base(os.Args[0])) |
111 | | - flag.Usage() |
112 | | - os.Exit(1) |
| 104 | +func newRootCmd() *cobra.Command { |
| 105 | + rootCmd := &cobra.Command{ |
| 106 | + Use: "pf", |
| 107 | + Short: "PacketForge protocol registry and packet sender", |
| 108 | + RunE: func(cmd *cobra.Command, args []string) error { |
| 109 | + pslFile := viper.GetString("stream") |
| 110 | + if pslFile == "" { |
| 111 | + return fmt.Errorf("required flag \"stream\" not set") |
| 112 | + } |
| 113 | + |
| 114 | + return run(pslFile, viper.GetString("proto"), viper.GetString("iface"), viper.GetBool("dry-run"), viper.GetBool("builtin-proto")) |
| 115 | + }, |
113 | 116 | } |
114 | 117 |
|
| 118 | + rootCmd.Version = VersionString() |
| 119 | + rootCmd.SetVersionTemplate("{{.Version}}\n") |
| 120 | + |
| 121 | + flags := rootCmd.Flags() |
| 122 | + flags.StringP("proto", "p", "proto", "Protocol definition directory (.pdl files), optional") |
| 123 | + flags.StringP("stream", "s", "", "Packet stream language file (required)") |
| 124 | + flags.StringP("iface", "i", "lo", "Network interface to send packets (e.g. eth0, lo)") |
| 125 | + flags.BoolP("dry-run", "d", false, "Parse and build packets only, do not actually send") |
| 126 | + flags.BoolP("builtin-proto", "b", true, "Load built-in common protocols first (eth/vlan/arp/arp_request/arp_reply/ip/ipv6/icmp/icmp6/ndp_ns/ndp_na/udp/tcp)") |
| 127 | + |
| 128 | + _ = viper.BindPFlag("proto", flags.Lookup("proto")) |
| 129 | + _ = viper.BindPFlag("stream", flags.Lookup("stream")) |
| 130 | + _ = viper.BindPFlag("iface", flags.Lookup("iface")) |
| 131 | + _ = viper.BindPFlag("dry-run", flags.Lookup("dry-run")) |
| 132 | + _ = viper.BindPFlag("builtin-proto", flags.Lookup("builtin-proto")) |
| 133 | + viper.SetEnvPrefix("PF") |
| 134 | + viper.AutomaticEnv() |
| 135 | + rootCmd.AddCommand(newBuiltinCmd()) |
| 136 | + |
| 137 | + return rootCmd |
| 138 | +} |
| 139 | + |
| 140 | +func newBuiltinCmd() *cobra.Command { |
| 141 | + return &cobra.Command{ |
| 142 | + Use: "builtin", |
| 143 | + Short: "Show builtin protocol list", |
| 144 | + Run: func(cmd *cobra.Command, args []string) { |
| 145 | + for _, name := range pdl.BuiltinCommonProtocolNames() { |
| 146 | + fmt.Println(name) |
| 147 | + } |
| 148 | + }, |
| 149 | + } |
| 150 | +} |
| 151 | + |
| 152 | +func run(pslFile, pdlDir, iface string, dryRun bool, builtinProto bool) error { |
115 | 153 | // 1. Load PDL protocols |
116 | 154 | reg := pdl.NewRegistry() |
117 | | - if err := reg.LoadPDLDir(*pdlDir); err != nil { |
118 | | - fmt.Fprintf(os.Stderr, "Load PDL: %v\n", err) |
119 | | - os.Exit(1) |
| 155 | + if builtinProto { |
| 156 | + if err := reg.LoadBuiltinCommonProtocols(); err != nil { |
| 157 | + return fmt.Errorf("load builtin protocols: %w", err) |
| 158 | + } |
| 159 | + } |
| 160 | + if pdlDir != "" { |
| 161 | + if _, err := os.Stat(pdlDir); err == nil { |
| 162 | + if err := reg.LoadPDLDir(pdlDir); err != nil { |
| 163 | + return fmt.Errorf("load PDL dir: %w", err) |
| 164 | + } |
| 165 | + } else if !os.IsNotExist(err) || pdlDir != "proto" { |
| 166 | + return fmt.Errorf("read proto dir %q: %w", pdlDir, err) |
| 167 | + } |
120 | 168 | } |
121 | 169 |
|
122 | 170 | // 2. Parse PSL script |
123 | | - pslData, err := os.ReadFile(*pslFile) |
| 171 | + pslData, err := os.ReadFile(pslFile) |
124 | 172 | if err != nil { |
125 | | - fmt.Fprintf(os.Stderr, "Read PSL: %v\n", err) |
126 | | - os.Exit(1) |
| 173 | + return fmt.Errorf("read PSL: %w", err) |
127 | 174 | } |
128 | 175 | parser := psl.NewParser(string(pslData)) |
129 | 176 | script, err := parser.ParseScript() |
130 | 177 | if err != nil { |
131 | | - fmt.Fprintf(os.Stderr, "Parse PSL: %v\n", err) |
132 | | - os.Exit(1) |
| 178 | + return fmt.Errorf("parse PSL: %w", err) |
133 | 179 | } |
134 | 180 |
|
135 | 181 | builder := packet.NewBuilder(reg) |
136 | 182 |
|
137 | 183 | sendFn := func(data []byte) error { |
138 | | - if *dryRun { |
| 184 | + if dryRun { |
139 | 185 | fmt.Printf("[dry-run] Send %d bytes:\n%s\n", len(data), FormatTCPDump(data, 0)) |
140 | 186 | return nil |
141 | 187 | } |
142 | 188 | return nil |
143 | 189 | } |
144 | 190 |
|
145 | | - if !*dryRun { |
146 | | - sender, err := packet.NewSender(*iface) |
| 191 | + if !dryRun { |
| 192 | + sender, err := packet.NewSender(iface) |
147 | 193 | if err != nil { |
148 | | - fmt.Fprintf(os.Stderr, "Create sender: %v\n", err) |
149 | | - os.Exit(1) |
| 194 | + return fmt.Errorf("create sender: %w", err) |
150 | 195 | } |
151 | 196 | defer sender.Close() |
152 | 197 | sendFn = sender.Send |
153 | 198 | } |
154 | 199 |
|
155 | 200 | if err := engine.Run(script, builder, sendFn); err != nil { |
156 | | - fmt.Fprintf(os.Stderr, "Run: %v\n", err) |
157 | | - os.Exit(1) |
| 201 | + return fmt.Errorf("run: %w", err) |
158 | 202 | } |
| 203 | + return nil |
159 | 204 | } |
0 commit comments