|
9 | 9 | "bytes"
|
10 | 10 | "context"
|
11 | 11 | "crypto/rand"
|
| 12 | + "crypto/tls" |
12 | 13 | "embed"
|
13 | 14 | "encoding/base64"
|
14 | 15 | "encoding/json"
|
@@ -169,18 +170,78 @@ func Run() error {
|
169 | 170 | if err := srv.Start(); err != nil {
|
170 | 171 | return err
|
171 | 172 | }
|
172 |
| - localClient, _ = srv.LocalClient() |
173 | 173 |
|
174 |
| - l80, err := srv.Listen("tcp", ":80") |
| 174 | + // create tsNet server and wait for it to be ready & connected. |
| 175 | + localClient, _ = srv.LocalClient() |
| 176 | + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) |
| 177 | + defer cancel() |
| 178 | + _, err = srv.Up(ctx) |
175 | 179 | if err != nil {
|
176 | 180 | return err
|
177 | 181 | }
|
178 | 182 |
|
179 |
| - log.Printf("Serving http://%s/ ...", *hostname) |
180 |
| - if err := http.Serve(l80, serveHandler()); err != nil { |
181 |
| - return err |
| 183 | + enableTLS := len(srv.CertDomains()) > 0 |
| 184 | + if enableTLS { |
| 185 | + // warm the certificate cache for all cert domains to prevent users waiting |
| 186 | + // on ACME challenges in-line on their first request. |
| 187 | + for _, d := range srv.CertDomains() { |
| 188 | + log.Printf("Provisioning TLS certificate for %s ...", d) |
| 189 | + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) |
| 190 | + defer cancel() |
| 191 | + |
| 192 | + _, _, err := localClient.CertPair(ctx, d) |
| 193 | + if err != nil { |
| 194 | + return err |
| 195 | + } |
| 196 | + } |
| 197 | + |
| 198 | + redirectFqdn := srv.CertDomains()[0] |
| 199 | + // HTTP listener that redirects to our HTTPS listener. |
| 200 | + log.Println("Listening on :80") |
| 201 | + httpListener, err := srv.Listen("tcp", ":80") |
| 202 | + if err != nil { |
| 203 | + return err |
| 204 | + } |
| 205 | + go func() error { |
| 206 | + log.Printf("Serving http://%s/ ...", *hostname) |
| 207 | + if err := http.Serve(httpListener, redirectHandler(redirectFqdn)); err != nil { |
| 208 | + return err |
| 209 | + } |
| 210 | + return nil |
| 211 | + }() |
| 212 | + |
| 213 | + log.Println("Listening on :443") |
| 214 | + httpsListener, err := srv.Listen("tcp", ":443") |
| 215 | + if err != nil { |
| 216 | + return err |
| 217 | + } |
| 218 | + s := http.Server{ |
| 219 | + Addr: ":443", |
| 220 | + Handler: serveHandler(), |
| 221 | + TLSConfig: &tls.Config{ |
| 222 | + GetCertificate: localClient.GetCertificate, |
| 223 | + }, |
| 224 | + } |
| 225 | + |
| 226 | + log.Printf("Serving https://%s/\n", redirectFqdn) |
| 227 | + if err := s.ServeTLS(httpsListener, "", ""); err != nil { |
| 228 | + return err |
| 229 | + } |
| 230 | + return nil |
| 231 | + } else { |
| 232 | + // no TLS, just serve on :80 |
| 233 | + log.Println("Listening on :80") |
| 234 | + httpListener, err := srv.Listen("tcp", ":80") |
| 235 | + if err != nil { |
| 236 | + return err |
| 237 | + } |
| 238 | + log.Printf("Serving http://%s/ ...", *hostname) |
| 239 | + if err := http.Serve(httpListener, serveHandler()); err != nil { |
| 240 | + return err |
| 241 | + } |
| 242 | + return nil |
182 | 243 | }
|
183 |
| - return nil |
| 244 | + |
184 | 245 | }
|
185 | 246 |
|
186 | 247 | var (
|
@@ -286,6 +347,16 @@ func deleteLinkStats(link *Link) {
|
286 | 347 | db.DeleteStats(link.Short)
|
287 | 348 | }
|
288 | 349 |
|
| 350 | +// redirectHandler returns the http.Handler for serving all plaintext HTTP |
| 351 | +// requests. It redirects all requests to the HTTPs version of the same URL. |
| 352 | +func redirectHandler(hostname string) http.Handler { |
| 353 | + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 354 | + path := r.URL.Path |
| 355 | + newUrl := fmt.Sprintf("https://%s%s", hostname, path) |
| 356 | + http.Redirect(w, r, newUrl, http.StatusMovedPermanently) |
| 357 | + }) |
| 358 | +} |
| 359 | + |
289 | 360 | // serverHandler returns the main http.Handler for serving all requests.
|
290 | 361 | func serveHandler() http.Handler {
|
291 | 362 | mux := http.NewServeMux()
|
|
0 commit comments