|
| 1 | +package mailer |
| 2 | + |
| 3 | +import ( |
| 4 | + "bytes" |
| 5 | + "errors" |
| 6 | + "fmt" |
| 7 | + "net/smtp" |
| 8 | + "slices" |
| 9 | +) |
| 10 | + |
| 11 | +func PlainOrLoginAuth(username, password, host string) smtp.Auth { |
| 12 | + return &plainOrLoginAuth{username: username, password: password, host: host} |
| 13 | +} |
| 14 | + |
| 15 | +func isLocalhost(name string) bool { |
| 16 | + return name == "localhost" || name == "127.0.0.1" || name == "::1" |
| 17 | +} |
| 18 | + |
| 19 | +type plainOrLoginAuth struct { |
| 20 | + username string |
| 21 | + password string |
| 22 | + host string |
| 23 | + authMethod string |
| 24 | +} |
| 25 | + |
| 26 | +func (a *plainOrLoginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { |
| 27 | + // Must have TLS, or else localhost server. |
| 28 | + // Note: If TLS is not true, then we can't trust ANYTHING in ServerInfo. |
| 29 | + // In particular, it doesn't matter if the server advertises PLAIN auth. |
| 30 | + // That might just be the attacker saying |
| 31 | + // "it's ok, you can trust me with your password." |
| 32 | + if !server.TLS && !isLocalhost(server.Name) { |
| 33 | + return "", nil, errors.New("unencrypted connection") |
| 34 | + } |
| 35 | + if server.Name != a.host { |
| 36 | + return "", nil, errors.New("wrong host name") |
| 37 | + } |
| 38 | + if !slices.Contains(server.Auth, "PLAIN") { |
| 39 | + a.authMethod = "LOGIN" |
| 40 | + return a.authMethod, nil, nil |
| 41 | + } else { |
| 42 | + a.authMethod = "PLAIN" |
| 43 | + resp := []byte("\x00" + a.username + "\x00" + a.password) |
| 44 | + return a.authMethod, resp, nil |
| 45 | + } |
| 46 | +} |
| 47 | + |
| 48 | +func (a *plainOrLoginAuth) Next(fromServer []byte, more bool) ([]byte, error) { |
| 49 | + if !more { |
| 50 | + return nil, nil |
| 51 | + } |
| 52 | + |
| 53 | + if a.authMethod == "PLAIN" { |
| 54 | + // We've already sent everything. |
| 55 | + return nil, errors.New("unexpected server challenge") |
| 56 | + } |
| 57 | + |
| 58 | + switch { |
| 59 | + case bytes.Equal(fromServer, []byte("Username:")): |
| 60 | + return []byte(a.username), nil |
| 61 | + case bytes.Equal(fromServer, []byte("Password:")): |
| 62 | + return []byte(a.password), nil |
| 63 | + default: |
| 64 | + return nil, fmt.Errorf("unexpected server challenge: %s", fromServer) |
| 65 | + } |
| 66 | +} |
0 commit comments