Skip to content

Commit 5722cf2

Browse files
committed
golink: address PR feedback
* use `srv.ListenTLS` API instead of DIY'ing it. * DRY up http & https listener code. * use type safe URL generation for redirect handler * use status API to determine HTTPS capabilities directly. * handle http only case gracefully. Signed-off-by: Patrick O'Doherty <[email protected]>
1 parent 23f9f96 commit 5722cf2

File tree

1 file changed

+41
-55
lines changed

1 file changed

+41
-55
lines changed

golink.go

Lines changed: 41 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"bytes"
1010
"context"
1111
"crypto/rand"
12-
"crypto/tls"
1312
"embed"
1413
"encoding/base64"
1514
"encoding/json"
@@ -159,6 +158,7 @@ func Run() error {
159158
return errors.New("--hostname, if specified, cannot be empty")
160159
}
161160

161+
// create tsNet server and wait for it to be ready & connected.
162162
srv := &tsnet.Server{
163163
ControlURL: *controlURL,
164164
Hostname: *hostname,
@@ -171,77 +171,65 @@ func Run() error {
171171
return err
172172
}
173173

174-
// create tsNet server and wait for it to be ready & connected.
175174
localClient, _ = srv.LocalClient()
175+
out:
176+
for {
177+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
178+
defer cancel()
179+
status, err := srv.Up(ctx)
180+
if err == nil && status != nil {
181+
break out
182+
}
183+
}
184+
176185
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
177186
defer cancel()
178-
_, err = srv.Up(ctx)
187+
status, err := localClient.Status(ctx)
179188
if err != nil {
180189
return err
181190
}
191+
enableTLS := status.Self.HasCap(tailcfg.CapabilityHTTPS)
192+
dnsName := status.Self.DNSName
182193

183-
enableTLS := len(srv.CertDomains()) > 0
194+
var httpHandler http.Handler
195+
var httpsHandler http.Handler
184196
if enableTLS {
185-
// warm the certificate cache for all cert domains to prevent users waiting
186-
// on ACME challenges in-line on their first request.
187-
for _, d := range srv.CertDomains() {
188-
log.Printf("Provisioning TLS certificate for %s ...", d)
189-
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
190-
defer cancel()
191-
192-
_, _, err := localClient.CertPair(ctx, d)
193-
if err != nil {
194-
return err
195-
}
196-
}
197+
redirectFqdn := strings.TrimSuffix(dnsName, ".")
198+
httpHandler = redirectHandler(redirectFqdn)
199+
httpsHandler = HSTS(serveHandler())
200+
} else {
201+
httpHandler = serveHandler()
202+
httpsHandler = nil
203+
}
197204

198-
redirectFqdn := srv.CertDomains()[0]
199-
// HTTP listener that redirects to our HTTPS listener.
200-
log.Println("Listening on :80")
201-
httpListener, err := srv.Listen("tcp", ":80")
205+
if httpsHandler != nil {
206+
log.Println("Listening on :443")
207+
httpsListener, err := srv.ListenTLS("tcp", ":443")
202208
if err != nil {
203209
return err
204210
}
205211
go func() error {
206-
log.Printf("Serving http://%s/ ...", *hostname)
207-
if err := http.Serve(httpListener, redirectHandler(redirectFqdn)); err != nil {
212+
log.Printf("Serving https://%s/ ...", strings.TrimSuffix(dnsName, "."))
213+
if err := http.Serve(httpsListener, httpsHandler); err != nil {
208214
return err
209215
}
210216
return nil
211217
}()
218+
}
212219

213-
log.Println("Listening on :443")
214-
httpsListener, err := srv.Listen("tcp", ":443")
215-
if err != nil {
216-
return err
217-
}
218-
s := http.Server{
219-
Addr: ":443",
220-
Handler: HSTS(serveHandler()),
221-
TLSConfig: &tls.Config{
222-
GetCertificate: localClient.GetCertificate,
223-
},
224-
}
225-
226-
log.Printf("Serving https://%s/\n", redirectFqdn)
227-
if err := s.ServeTLS(httpsListener, "", ""); err != nil {
228-
return err
229-
}
230-
return nil
231-
} else {
232-
// no TLS, just serve on :80
233-
log.Println("Listening on :80")
234-
httpListener, err := srv.Listen("tcp", ":80")
235-
if err != nil {
236-
return err
237-
}
238-
log.Printf("Serving http://%s/ ...", *hostname)
239-
if err := http.Serve(httpListener, serveHandler()); err != nil {
240-
return err
241-
}
242-
return nil
220+
// HTTP handler that either serves primary handler or redirects to HTTPS
221+
// depending on availability of TLS.
222+
log.Println("Listening on :80")
223+
httpListener, err := srv.Listen("tcp", ":80")
224+
if err != nil {
225+
return err
226+
}
227+
log.Printf("Serving http://%s/ ...", *hostname)
228+
if err := http.Serve(httpListener, httpHandler); err != nil {
229+
return err
243230
}
244231

232+
return nil
245233
}
246234

247235
var (
@@ -351,9 +339,7 @@ func deleteLinkStats(link *Link) {
351339
// requests. It redirects all requests to the HTTPs version of the same URL.
352340
func redirectHandler(hostname string) http.Handler {
353341
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
354-
path := r.URL.Path
355-
newUrl := fmt.Sprintf("https://%s%s", hostname, path)
356-
http.Redirect(w, r, newUrl, http.StatusMovedPermanently)
342+
http.Redirect(w, r, (&url.URL{Scheme: "https", Host: hostname, Path: r.URL.Path}).String(), http.StatusMovedPermanently)
357343
})
358344
}
359345

0 commit comments

Comments
 (0)