Skip to content

Commit e7088cf

Browse files
author
mirkobrombin
committed
feat: Add authentication plugin for protected routes
1 parent e7e4c92 commit e7088cf

File tree

3 files changed

+264
-0
lines changed

3 files changed

+264
-0
lines changed

cmd/goup/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ func main() {
1212
// Register your plugins here:
1313
pluginManager.Register(&plugins.CustomHeaderPlugin{})
1414
pluginManager.Register(&plugins.PHPPlugin{})
15+
pluginManager.Register(&plugins.AuthPlugin{})
1516

1617
cli.Execute()
1718
}

plugins/auth.go

Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
package plugins
2+
3+
import (
4+
"encoding/base64"
5+
"errors"
6+
"net/http"
7+
"strings"
8+
"sync"
9+
"time"
10+
11+
"github.com/mirkobrombin/goup/internal/config"
12+
"github.com/mirkobrombin/goup/internal/server/middleware"
13+
log "github.com/sirupsen/logrus"
14+
)
15+
16+
// AuthPlugin provides HTTP Basic Authentication for protected paths.
17+
type AuthPlugin struct{}
18+
19+
// Name returns the name of the plugin.
20+
func (p *AuthPlugin) Name() string {
21+
return "AuthPlugin"
22+
}
23+
24+
// AuthPluginConfig represents the configuration for the AuthPlugin.
25+
type AuthPluginConfig struct {
26+
// URL paths to protect with authentication.
27+
ProtectedPaths []string `json:"protected_paths"`
28+
// username:password pairs for authentication.
29+
Credentials map[string]string `json:"credentials"`
30+
// Session expiration in seconds.
31+
// -1 means sessions never expire. Maximum allowed is 86400 seconds (24 hours).
32+
SessionExpiration int `json:"session_expiration"`
33+
}
34+
35+
// session represents an authenticated session.
36+
type session struct {
37+
Username string
38+
Expiry time.Time
39+
}
40+
41+
// AuthPluginState internal state.
42+
type AuthPluginState struct {
43+
sessions map[string]session
44+
mu sync.RWMutex
45+
}
46+
47+
// Init registers the authentication middleware globally.
48+
func (p *AuthPlugin) Init(mwManager *middleware.MiddlewareManager) error {
49+
return nil
50+
}
51+
52+
// InitForSite initializes the plugin for a specific site.
53+
func (p *AuthPlugin) InitForSite(mwManager *middleware.MiddlewareManager, logger *log.Logger, conf config.SiteConfig) error {
54+
pluginConfigRaw, ok := conf.PluginConfigs[p.Name()]
55+
if !ok {
56+
logger.Warnf("AuthPlugin config not found for host: %s", conf.Domain)
57+
return nil
58+
}
59+
60+
// Parse plugin configuration.
61+
authConfig := AuthPluginConfig{}
62+
if rawMap, ok := pluginConfigRaw.(map[string]interface{}); ok {
63+
// ProtectedPaths
64+
if paths, ok := rawMap["protected_paths"].([]interface{}); ok {
65+
for _, path := range paths {
66+
if pStr, ok := path.(string); ok {
67+
authConfig.ProtectedPaths = append(authConfig.ProtectedPaths, pStr)
68+
}
69+
}
70+
}
71+
72+
// Credentials
73+
if creds, ok := rawMap["credentials"].(map[string]interface{}); ok {
74+
authConfig.Credentials = make(map[string]string)
75+
for user, pass := range creds {
76+
if passStr, ok := pass.(string); ok {
77+
authConfig.Credentials[user] = passStr
78+
}
79+
}
80+
}
81+
82+
// SessionExpiration
83+
if se, ok := rawMap["session_expiration"].(float64); ok {
84+
authConfig.SessionExpiration = int(se)
85+
}
86+
}
87+
88+
// Validate session expiration
89+
if authConfig.SessionExpiration > 86400 {
90+
return errors.New("session_expiration cannot exceed 86400 seconds (24 hours)")
91+
}
92+
if authConfig.SessionExpiration < -1 {
93+
return errors.New("session_expiration cannot be less than -1")
94+
}
95+
96+
logger.Infof("Initializing AuthPlugin for domain: %s with session_expiration: %d", conf.Domain, authConfig.SessionExpiration)
97+
98+
// Initialization of the plugin state
99+
state := &AuthPluginState{
100+
sessions: make(map[string]session),
101+
}
102+
103+
// Cleanup expired sessions
104+
if authConfig.SessionExpiration != -1 {
105+
go state.cleanupExpiredSessions(time.Minute, logger)
106+
}
107+
108+
mwManager.Use(p.authMiddleware(logger, authConfig, state))
109+
110+
return nil
111+
}
112+
113+
// authMiddleware returns the middleware function for authentication.
114+
func (p *AuthPlugin) authMiddleware(logger *log.Logger, config AuthPluginConfig, state *AuthPluginState) middleware.MiddlewareFunc {
115+
return func(next http.Handler) http.Handler {
116+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
117+
118+
// Check if the path is protected
119+
protected := false
120+
for _, path := range config.ProtectedPaths {
121+
if strings.HasPrefix(r.URL.Path, path) {
122+
protected = true
123+
break
124+
}
125+
}
126+
if !protected {
127+
next.ServeHTTP(w, r)
128+
return
129+
}
130+
131+
ip := getClientIP(r)
132+
133+
// Check if session exists and is valid
134+
if sess, exists := state.getSession(ip); exists {
135+
logger.Infof("Session valid for IP: %s, user: %s", ip, sess.Username)
136+
next.ServeHTTP(w, r)
137+
return
138+
}
139+
140+
// Check for Authorization header
141+
authHeader := r.Header.Get("Authorization")
142+
if authHeader == "" {
143+
unauthorized(w)
144+
return
145+
}
146+
147+
// Parse Basic Auth
148+
username, password, ok := parseBasicAuth(authHeader)
149+
if !ok {
150+
unauthorized(w)
151+
return
152+
}
153+
154+
// Validate credentials
155+
expectedPassword, userExists := config.Credentials[username]
156+
if !userExists || expectedPassword != password {
157+
unauthorized(w)
158+
return
159+
}
160+
161+
// Create session
162+
state.createSession(ip, username, config.SessionExpiration, logger)
163+
164+
logger.Infof("Authenticated IP: %s, user: %s", ip, username)
165+
next.ServeHTTP(w, r)
166+
})
167+
}
168+
}
169+
170+
// getClientIP extracts the client's IP address from the request.
171+
func getClientIP(r *http.Request) string {
172+
if ip := r.Header.Get("X-Real-IP"); ip != "" {
173+
return ip
174+
}
175+
if ips := r.Header.Get("X-Forwarded-For"); ips != "" {
176+
// X-Forwarded-For may contain multiple IPs, take the first one
177+
return strings.Split(ips, ",")[0]
178+
}
179+
// Fallback to RemoteAddr
180+
ip := r.RemoteAddr
181+
if colonIndex := strings.LastIndex(ip, ":"); colonIndex != -1 {
182+
ip = ip[:colonIndex]
183+
}
184+
return ip
185+
}
186+
187+
// parseBasicAuth parses the Basic Authentication header.
188+
func parseBasicAuth(authHeader string) (username, password string, ok bool) {
189+
const prefix = "Basic "
190+
if !strings.HasPrefix(authHeader, prefix) {
191+
return
192+
}
193+
decoded, err := base64.StdEncoding.DecodeString(authHeader[len(prefix):])
194+
if err != nil {
195+
return
196+
}
197+
parts := strings.SplitN(string(decoded), ":", 2)
198+
if len(parts) != 2 {
199+
return
200+
}
201+
username = parts[0]
202+
password = parts[1]
203+
ok = true
204+
return
205+
}
206+
207+
// unauthorized sends a 401 Unauthorized response with the appropriate header.
208+
func unauthorized(w http.ResponseWriter) {
209+
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
210+
http.Error(w, "Unauthorized", http.StatusUnauthorized)
211+
}
212+
213+
// getSession retrieves a session for the given IP, if it exists and is valid.
214+
func (s *AuthPluginState) getSession(ip string) (session, bool) {
215+
s.mu.RLock()
216+
defer s.mu.RUnlock()
217+
sess, exists := s.sessions[ip]
218+
if !exists {
219+
return session{}, false
220+
}
221+
222+
// Check expiration
223+
if !sess.Expiry.IsZero() && sess.Expiry.Before(time.Now()) {
224+
return session{}, false
225+
}
226+
return sess, true
227+
}
228+
229+
// createSession creates a new session for the given IP and username.
230+
func (s *AuthPluginState) createSession(ip, username string, expiration int, logger *log.Logger) {
231+
s.mu.Lock()
232+
defer s.mu.Unlock()
233+
var expiry time.Time
234+
if expiration != -1 {
235+
expiry = time.Now().Add(time.Duration(expiration) * time.Second)
236+
}
237+
s.sessions[ip] = session{
238+
Username: username,
239+
Expiry: expiry,
240+
}
241+
if expiration != -1 {
242+
logger.Infof("Session created for IP: %s, user: %s, expires at: %v", ip, username, expiry)
243+
} else {
244+
logger.Infof("Session created for IP: %s, user: %s, never expires", ip, username)
245+
}
246+
}
247+
248+
// cleanupExpiredSessions periodically removes expired sessions.
249+
func (s *AuthPluginState) cleanupExpiredSessions(interval time.Duration, logger *log.Logger) {
250+
ticker := time.NewTicker(interval)
251+
defer ticker.Stop()
252+
for range ticker.C {
253+
s.mu.Lock()
254+
for ip, sess := range s.sessions {
255+
if !sess.Expiry.IsZero() && sess.Expiry.Before(time.Now()) {
256+
delete(s.sessions, ip)
257+
logger.Infof("Session expired and removed for IP: %s, user: %s", ip, sess.Username)
258+
}
259+
}
260+
s.mu.Unlock()
261+
}
262+
}

public/protected.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
<h1>If you see me, you are logged in</h1>

0 commit comments

Comments
 (0)