Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 49 additions & 21 deletions primaryHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ func NewPrimaryHandler(logger *zap.Logger, manager device.Manager, v *viper.Vipe
inboundTimeout = getInboundTimeout(v)
apiHandler = r.PathPrefix(fmt.Sprintf("%s/{version:%s|%s}", baseURI, v2, version)).Subrouter()

authConstructor = NoOpConstructor
authEnforcer = NoOpConstructor
jwtAuthEnforcer = NoOpConstructor
basicAuthEnforcer = NoOpConstructor

deviceAuthRules = bascule.Validators{} //auth rules for device registration endpoints
serviceAuthRules = bascule.Validators{} //auth rules for everything else
Expand All @@ -152,7 +152,11 @@ func NewPrimaryHandler(logger *zap.Logger, manager device.Manager, v *viper.Vipe
return nil, err
}

authConstructorOptions := []basculehttp.COption{
jwtAuthConstructorOptions := []basculehttp.COption{
basculehttp.WithCLogger(getLogger),
basculehttp.WithCErrorResponseFunc(listener.OnErrorResponse),
}
basicAuthConstructorOptions := []basculehttp.COption{
basculehttp.WithCLogger(getLogger),
basculehttp.WithCErrorResponseFunc(listener.OnErrorResponse),
}
Expand Down Expand Up @@ -200,7 +204,7 @@ func NewPrimaryHandler(logger *zap.Logger, manager device.Manager, v *viper.Vipe
resolver.AddListener(cml)
resolver.AddListener(czl)

authConstructorOptions = append(authConstructorOptions, basculehttp.WithTokenFactory("Bearer", basculehttp.BearerTokenFactory{
jwtAuthConstructorOptions = append(jwtAuthConstructorOptions, basculehttp.WithTokenFactory("Bearer", basculehttp.BearerTokenFactory{
DefaultKeyID: DefaultKeyID,
Resolver: resolver,
Parser: bascule.DefaultJWTParser,
Expand All @@ -219,7 +223,7 @@ func NewPrimaryHandler(logger *zap.Logger, manager device.Manager, v *viper.Vipe
userPassMap := buildUserPassMap(logger, v.GetStringSlice(ServiceBasicAuthConfigKey))

if len(userPassMap) > 0 {
authConstructorOptions = append(authConstructorOptions,
basicAuthConstructorOptions = append(basicAuthConstructorOptions,
basculehttp.WithTokenFactory("Basic", basculehttp.BasicTokenFactory(userPassMap)))

serviceAuthRules = append(serviceAuthRules, basculechecks.AllowAll())
Expand Down Expand Up @@ -257,31 +261,52 @@ func NewPrimaryHandler(logger *zap.Logger, manager device.Manager, v *viper.Vipe
wrpRouterHandler = withDeviceAccessCheck(logger, wrpRouterHandler, deviceAccessCheck)
}

authConstructor = basculehttp.NewConstructor(authConstructorOptions...)
authConstructorLegacy := basculehttp.NewConstructor(append([]basculehttp.COption{
jwtAuthConstructor := basculehttp.NewConstructor(jwtAuthConstructorOptions...)
basicAuthConstructor := basculehttp.NewConstructor(basicAuthConstructorOptions...)
jwtAuthConstructorLegacy := basculehttp.NewConstructor(append([]basculehttp.COption{
basculehttp.WithCErrorHTTPResponseFunc(basculehttp.LegacyOnErrorHTTPResponse),
}, authConstructorOptions...)...)
}, jwtAuthConstructorOptions...)...)
basicAuthConstructorLegacy := basculehttp.NewConstructor(append([]basculehttp.COption{
basculehttp.WithCErrorHTTPResponseFunc(basculehttp.LegacyOnErrorHTTPResponse),
}, basicAuthConstructorOptions...)...)

authEnforcer = basculehttp.NewEnforcer(
jwtAuthEnforcer = basculehttp.NewEnforcer(
basculehttp.WithELogger(getLogger),
basculehttp.WithRules("Basic", serviceAuthRules),
basculehttp.WithRules("Bearer", deviceAuthRules),
basculehttp.WithEErrorResponseFunc(listener.OnErrorResponse),
)
basicAuthEnforcer = basculehttp.NewEnforcer(
basculehttp.WithELogger(getLogger),
basculehttp.WithRules("Basic", serviceAuthRules),
basculehttp.WithEErrorResponseFunc(listener.OnErrorResponse),
)
jwtAuthChain := alice.New(setLogger(logger), jwtAuthConstructor, jwtAuthEnforcer, basculehttp.NewListenerDecorator(listener))
basicAuthChain := alice.New(setLogger(logger), basicAuthConstructor, basicAuthEnforcer, basculehttp.NewListenerDecorator(listener))
jwtAuthChainV2 := alice.New(setLogger(logger), jwtAuthConstructorLegacy, jwtAuthEnforcer, basculehttp.NewListenerDecorator(listener))
basicAuthChainV2 := alice.New(setLogger(logger), basicAuthConstructorLegacy, basicAuthEnforcer, basculehttp.NewListenerDecorator(listener))

authChain := alice.New(setLogger(logger), authConstructor, authEnforcer, basculehttp.NewListenerDecorator(listener))
authChainV2 := alice.New(setLogger(logger), authConstructorLegacy, authEnforcer, basculehttp.NewListenerDecorator(listener))

versionCompatibleAuth := alice.New(func(next http.Handler) http.Handler {
versionCompatibleJWTAuth := alice.New(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(r http.ResponseWriter, req *http.Request) {
vars := mux.Vars(req)
if vars != nil {
if vars["version"] == v2 {
jwtAuthChainV2.Then(next).ServeHTTP(r, req)
return
}
}
jwtAuthChain.Then(next).ServeHTTP(r, req)
})
})
versionCompatibleBasicAuth := alice.New(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(r http.ResponseWriter, req *http.Request) {
vars := mux.Vars(req)
if vars != nil {
if vars["version"] == v2 {
authChainV2.Then(next).ServeHTTP(r, req)
basicAuthChainV2.Then(next).ServeHTTP(r, req)
return
}
}
authChain.Then(next).ServeHTTP(r, req)
basicAuthChain.Then(next).ServeHTTP(r, req)
})
})

Expand All @@ -290,12 +315,12 @@ func NewPrimaryHandler(logger *zap.Logger, manager device.Manager, v *viper.Vipe
xtimeout.NewConstructor(xtimeout.Options{
Timeout: inboundTimeout,
})).
Extend(versionCompatibleAuth).
Extend(versionCompatibleBasicAuth).
Then(wrphttp.NewHTTPHandler(wrpRouterHandler)),
).Methods("POST", "PATCH")

apiHandler.Handle("/devices",
versionCompatibleAuth.Then(&device.ListHandler{
versionCompatibleBasicAuth.Then(&device.ListHandler{
Logger: logger,
Registry: manager,
})).Methods("GET")
Expand Down Expand Up @@ -342,11 +367,14 @@ func NewPrimaryHandler(logger *zap.Logger, manager device.Manager, v *viper.Vipe
r.Handle(
fmt.Sprintf("%s/{version:%s|%s}/device", baseURI, v2, version),
deviceConnectChain.
Extend(versionCompatibleAuth).
Extend(versionCompatibleJWTAuth).
Append(DeviceMetadataMiddleware(getLogger)).
Then(connectHandler),
).HeadersRegexp("Authorization", ".*")

// Openfail handler. Limited for devices to connect, nothing else.
// Services wanted to interact with either sets of devices (with a themis jwt or without)
// are ALWAYS required to basic auth.
r.Handle(
fmt.Sprintf("%s/{version:%s|%s}/device", baseURI, v2, version),
deviceConnectChain.
Expand All @@ -357,7 +385,7 @@ func NewPrimaryHandler(logger *zap.Logger, manager device.Manager, v *viper.Vipe
r.Handle(
fmt.Sprintf("%s/{version:%s|%s}/device", baseURI, v2, version),
deviceConnectChain.
Extend(versionCompatibleAuth).
Extend(versionCompatibleJWTAuth).
Append(DeviceMetadataMiddleware(getLogger)).
Then(connectHandler),
)
Expand All @@ -367,7 +395,7 @@ func NewPrimaryHandler(logger *zap.Logger, manager device.Manager, v *viper.Vipe
"/device/{deviceID}/stat",
alice.New(
device.UseID.FromPath("deviceID")).
Extend(versionCompatibleAuth).
Extend(versionCompatibleBasicAuth).
Then(&device.StatHandler{
Logger: logger,
Registry: manager,
Expand Down