diff --git a/.gitignore b/.gitignore index b27ae77..280ddf7 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,11 @@ raft.db .idea .vscode .claude/ +docs/ +/bin/ +/data/ +/cmd/single/data/ +/cmd/cluster/data/ +/single +/comqtt +/comqtt-cluster diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..e098406 --- /dev/null +++ b/Makefile @@ -0,0 +1,19 @@ +TAILWIND_VERSION ?= v3.4.7 +TAILWIND_BIN := bin/tailwindcss + +.PHONY: dashboard +dashboard: $(TAILWIND_BIN) + $(TAILWIND_BIN) -c dashboard/web/tailwind.config.js \ + -i dashboard/web/input.css \ + -o dashboard/static/tailwind.css --minify + +$(TAILWIND_BIN): + @mkdir -p bin + @case "$$(uname -s)-$$(uname -m)" in \ + Darwin-arm64) URL="https://github.com/tailwindlabs/tailwindcss/releases/download/$(TAILWIND_VERSION)/tailwindcss-macos-arm64";; \ + Darwin-x86_64) URL="https://github.com/tailwindlabs/tailwindcss/releases/download/$(TAILWIND_VERSION)/tailwindcss-macos-x64";; \ + Linux-x86_64) URL="https://github.com/tailwindlabs/tailwindcss/releases/download/$(TAILWIND_VERSION)/tailwindcss-linux-x64";; \ + Linux-aarch64) URL="https://github.com/tailwindlabs/tailwindcss/releases/download/$(TAILWIND_VERSION)/tailwindcss-linux-arm64";; \ + *) echo "unsupported platform: $$(uname -s)-$$(uname -m)" >&2; exit 1;; \ + esac; \ + curl -L -o $@ $$URL && chmod +x $@ diff --git a/cluster/agent.go b/cluster/agent.go index 09c5938..2c694e4 100644 --- a/cluster/agent.go +++ b/cluster/agent.go @@ -201,6 +201,39 @@ func (a *Agent) GetMemberList() []discovery.Member { return a.membership.Members() } +// Leader returns the node name of the current Raft leader, or "" if no +// leader is known. Used by integrations (REST handlers, dashboard) that +// need to display cluster topology. +// +// Tries the raft-reported ID first; falls back to matching the leader's +// host address against the membership list. This handles setups where +// raft.LocalID and the discovery node name differ. +func (a *Agent) Leader() string { + if a.raftPeer == nil { + return "" + } + addr, id := a.raftPeer.GetLeader() + if id != "" { + return id + } + if addr == "" { + return "" + } + host := addr + for i := len(addr) - 1; i >= 0; i-- { + if addr[i] == ':' { + host = addr[:i] + break + } + } + for _, m := range a.membership.Members() { + if m.Addr == host { + return m.Name + } + } + return "" +} + func getRaftPeerAddr(member *discovery.Member) string { // using serf if raftPort, ok := member.Tags[discovery.TagRaftPort]; ok { diff --git a/cluster/rest/rest.go b/cluster/rest/rest.go index f138ddb..bd10200 100644 --- a/cluster/rest/rest.go +++ b/cluster/rest/rest.go @@ -23,16 +23,24 @@ func New(agent *cs.Agent) *rest { func (s *rest) GenHandlers() map[string]rt.Handler { return map[string]rt.Handler{ - "GET /api/v1/node/config": s.viewConfig, - "DELETE /api/v1/node/{name}": s.leave, - "GET /api/v1/cluster/nodes": s.getNodes, - "POST /api/v1/cluster/nodes": s.join, - "POST /api/v1/cluster/peers": s.addRaftPeer, - "DELETE /api/v1/cluster/peers/{name}": s.removeRaftPeer, - "GET /api/v1/cluster/stat/online": s.getOnlineCount, - "GET /api/v1/cluster/clients/{id}": s.getClient, - "POST /api/v1/cluster/blacklist/{id}": s.kickClient, - "DELETE /api/v1/cluster/blacklist/{id}": s.blanchClient, + "GET /api/v1/node/config": s.viewConfig, + "DELETE /api/v1/node/{name}": s.leave, + "GET /api/v1/cluster/nodes": s.getNodes, + "POST /api/v1/cluster/nodes": s.join, + "POST /api/v1/cluster/peers": s.addRaftPeer, + "DELETE /api/v1/cluster/peers/{name}": s.removeRaftPeer, + "GET /api/v1/cluster/stat/online": s.getOnlineCount, + "GET /api/v1/cluster/clients/{id}": s.getClient, + "POST /api/v1/cluster/blacklist/{id}": s.kickClient, + "DELETE /api/v1/cluster/blacklist/{id}": s.blanchClient, + "GET /api/v1/cluster/clients": s.listClients, + "GET /api/v1/cluster/subscriptions": s.listSubscriptions, + "GET /api/v1/cluster/topics": s.topicsTree, + "DELETE /api/v1/cluster/clients/{id}/subscriptions/{topic}": s.unsubscribeClient, + "GET /api/v1/cluster/retained": s.listRetained, + "DELETE /api/v1/cluster/retained/{topic}": s.clearRetained, + "GET /api/v1/cluster/sessions": s.listSessions, + "DELETE /api/v1/cluster/sessions/{id}": s.clearSession, } } @@ -167,6 +175,94 @@ func (s *rest) blanchClient(w http.ResponseWriter, r *http.Request) { rt.Ok(w, rs) } +// listClients fan out client listing to every node in the cluster +// GET api/v1/cluster/clients +func (s *rest) listClients(w http.ResponseWriter, r *http.Request) { + path := rt.MqttListClientsPath + if r.URL.RawQuery != "" { + path += "?" + r.URL.RawQuery + } + urls := genUrls(s.agent.GetMemberList(), path) + rs := fetchM(HttpGet, urls, nil) + rt.Ok(w, rs) +} + +// listSubscriptions fan out subscription listing to every node in the cluster +// GET api/v1/cluster/subscriptions +func (s *rest) listSubscriptions(w http.ResponseWriter, r *http.Request) { + path := rt.MqttListSubscriptionsPath + if r.URL.RawQuery != "" { + path += "?" + r.URL.RawQuery + } + urls := genUrls(s.agent.GetMemberList(), path) + rs := fetchM(HttpGet, urls, nil) + rt.Ok(w, rs) +} + +// topicsTree fan out topic tree retrieval to every node in the cluster +// GET api/v1/cluster/topics +func (s *rest) topicsTree(w http.ResponseWriter, r *http.Request) { + urls := genUrls(s.agent.GetMemberList(), rt.MqttTopicsTreePath) + rs := fetchM(HttpGet, urls, nil) + rt.Ok(w, rs) +} + +// unsubscribeClient fan out unsubscribe to every node in the cluster +// DELETE api/v1/cluster/clients/{id}/subscriptions/{topic} +func (s *rest) unsubscribeClient(w http.ResponseWriter, r *http.Request) { + cid := r.PathValue("id") + topic := r.PathValue("topic") + path := strings.Replace(rt.MqttUnsubscribeClientPath, "{id}", cid, 1) + path = strings.Replace(path, "{topic}", topic, 1) + urls := genUrls(s.agent.GetMemberList(), path) + rs := fetchM(HttpDelete, urls, nil) + rt.Ok(w, rs) +} + +// listRetained fan out retained message listing to every node in the cluster +// GET api/v1/cluster/retained +func (s *rest) listRetained(w http.ResponseWriter, r *http.Request) { + path := rt.MqttListRetainedPath + if r.URL.RawQuery != "" { + path += "?" + r.URL.RawQuery + } + urls := genUrls(s.agent.GetMemberList(), path) + rs := fetchM(HttpGet, urls, nil) + rt.Ok(w, rs) +} + +// clearRetained fan out retained clear to every node in the cluster +// DELETE api/v1/cluster/retained/{topic} +func (s *rest) clearRetained(w http.ResponseWriter, r *http.Request) { + topic := r.PathValue("topic") + path := strings.Replace(rt.MqttClearRetainedPath, "{topic}", topic, 1) + urls := genUrls(s.agent.GetMemberList(), path) + rs := fetchM(HttpDelete, urls, nil) + rt.Ok(w, rs) +} + +// listSessions fan out session listing to every node in the cluster +// GET api/v1/cluster/sessions +func (s *rest) listSessions(w http.ResponseWriter, r *http.Request) { + path := rt.MqttListSessionsPath + if r.URL.RawQuery != "" { + path += "?" + r.URL.RawQuery + } + urls := genUrls(s.agent.GetMemberList(), path) + rs := fetchM(HttpGet, urls, nil) + rt.Ok(w, rs) +} + +// clearSession fan out session clear to every node in the cluster +// DELETE api/v1/cluster/sessions/{id} +func (s *rest) clearSession(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + path := strings.Replace(rt.MqttClearSessionPath, "{id}", id, 1) + urls := genUrls(s.agent.GetMemberList(), path) + rs := fetchM(HttpDelete, urls, nil) + rt.Ok(w, rs) +} + // genUrls generate urls func genUrls(ms []discovery.Member, path string) []string { urls := make([]string, len(ms)) diff --git a/cmd/cluster/main.go b/cmd/cluster/main.go index 415c7bb..1f413ac 100644 --- a/cmd/cluster/main.go +++ b/cmd/cluster/main.go @@ -26,6 +26,7 @@ import ( coredis "github.com/wind-c/comqtt/v2/cluster/storage/redis" "github.com/wind-c/comqtt/v2/config" mqtt "github.com/wind-c/comqtt/v2/mqtt" + "github.com/wind-c/comqtt/v2/dashboard" "github.com/wind-c/comqtt/v2/mqtt/hooks/auth" "github.com/wind-c/comqtt/v2/mqtt/listeners" mqttRt "github.com/wind-c/comqtt/v2/mqtt/rest" @@ -144,6 +145,29 @@ func realMain(ctx context.Context) error { csHls := csRt.New(agent).GenHandlers() mqHls := mqttRt.New(server).GenHandlers() maps.Copy(csHls, mqHls) + + dashCleanup := func() {} + if cfg.Dashboard.Enabled { + dashRedis := redis.NewClient(&redis.Options{ + Addr: cfg.Redis.Options.Addr, + Password: cfg.Redis.Options.Password, + DB: cfg.Redis.Options.DB, + }) + dashRoutes, cleanup, err := dashboard.Routes(dashboard.Options{ + Server: server, + Cluster: true, + ClusterAgent: agent, + Redis: dashRedis, + Secret: cfg.Dashboard.DecodeSecret(), + PasswordExpiryDays: cfg.Dashboard.PasswordExpiryDays, + }) + if err != nil { + return fmt.Errorf("dashboard routes: %w", err) + } + dashCleanup = cleanup + maps.Copy(csHls, dashRoutes) + } + http := listeners.NewHTTP("stats", cfg.Mqtt.HTTP, nil, csHls) onError(server.AddListener(http), "add http listener") @@ -165,6 +189,7 @@ func realMain(ctx context.Context) error { case <-ctx.Done(): server.Log.Warn("caught signal, stopping...") } + dashCleanup() agent.Stop() server.Close() return nil diff --git a/cmd/single/main.go b/cmd/single/main.go index 665fa0b..e5df364 100644 --- a/cmd/single/main.go +++ b/cmd/single/main.go @@ -18,6 +18,7 @@ import ( "github.com/wind-c/comqtt/v2/cluster/log" "github.com/wind-c/comqtt/v2/config" "github.com/wind-c/comqtt/v2/mqtt" + "github.com/wind-c/comqtt/v2/dashboard" "github.com/wind-c/comqtt/v2/mqtt/hooks/auth" "github.com/wind-c/comqtt/v2/mqtt/hooks/storage/badger" "github.com/wind-c/comqtt/v2/mqtt/hooks/storage/bolt" @@ -135,7 +136,26 @@ func realMain(ctx context.Context) error { onError(server.AddListener(ws), "add websocket listener") // add http listener - http := listeners.NewHTTP("stats", cfg.Mqtt.HTTP, nil, rest.New(server).GenHandlers()) + restHandlers := rest.New(server).GenHandlers() + + dashCleanup := func() {} + if cfg.Dashboard.Enabled { + dashRoutes, cleanup, err := dashboard.Routes(dashboard.Options{ + Server: server, + Cluster: false, + Secret: cfg.Dashboard.DecodeSecret(), + PasswordExpiryDays: cfg.Dashboard.PasswordExpiryDays, + }) + if err != nil { + return fmt.Errorf("dashboard routes: %w", err) + } + dashCleanup = cleanup + for path, h := range dashRoutes { + restHandlers[path] = h + } + } + + http := listeners.NewHTTP("stats", cfg.Mqtt.HTTP, nil, restHandlers) onError(server.AddListener(http), "add http listener") errCh := make(chan error, 1) @@ -155,6 +175,7 @@ func realMain(ctx context.Context) error { case <-ctx.Done(): log.Warn("caught signal, stopping...") } + dashCleanup() server.Close() log.Info("main.go finished") return nil diff --git a/config/config.go b/config/config.go index 238f453..f4404ac 100644 --- a/config/config.go +++ b/config/config.go @@ -10,6 +10,7 @@ import ( tls2 "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/base64" "encoding/pem" "errors" "math/big" @@ -67,7 +68,9 @@ var ( ) func New() *Config { - return &Config{} + return &Config{ + Dashboard: Dashboard{Enabled: true}, + } } func Load(yamlFile string) (*Config, error) { @@ -79,7 +82,9 @@ func Load(yamlFile string) (*Config, error) { } func parse(buf []byte) (*Config, error) { - conf := &Config{} + // Pre-populate defaults so omitted YAML keys keep them. yaml.v3 + // overwrites only the fields that are present in the document. + conf := New() err := yaml.Unmarshal(buf, conf) if err != nil { return nil, err @@ -97,9 +102,33 @@ type Config struct { Cluster Cluster `yaml:"cluster"` Redis redis `yaml:"redis"` Log log.Options `yaml:"log"` + Dashboard Dashboard `yaml:"dashboard"` PprofEnable bool `yaml:"pprof-enable"` } +// Dashboard holds the v1 web-dashboard wiring choices. Defaults are applied +// in dashboard.Options.applyDefaults; these YAML fields let operators +// opt out (Enabled=false), pin a session secret across restarts, or change +// the password expiry policy without code changes. +type Dashboard struct { + Enabled bool `yaml:"enabled"` + SessionSecret string `yaml:"session-secret"` + PasswordExpiryDays int `yaml:"password-expiry-days"` +} + +// DecodeSecret returns the SessionSecret as raw bytes. It accepts either a +// base64-encoded string or a raw secret. Returns nil if SessionSecret is +// empty (in which case the dashboard auto-generates and persists one). +func (d *Dashboard) DecodeSecret() []byte { + if d.SessionSecret == "" { + return nil + } + if b, err := base64.StdEncoding.DecodeString(d.SessionSecret); err == nil && len(b) >= 16 { + return b + } + return []byte(d.SessionSecret) +} + type auth struct { Way uint `yaml:"way"` Datasource uint `yaml:"datasource"` diff --git a/dashboard/auth/credstore.go b/dashboard/auth/credstore.go new file mode 100644 index 0000000..407adb0 --- /dev/null +++ b/dashboard/auth/credstore.go @@ -0,0 +1,51 @@ +// dashboard/auth/credstore.go +package auth + +import ( + "context" + "errors" +) + +type Role string + +const ( + RoleAdmin Role = "admin" + RoleViewer Role = "viewer" +) + +func (r Role) Valid() bool { + return r == RoleAdmin || r == RoleViewer +} + +type User struct { + Username string `json:"username"` + Hash string `json:"hash"` + Role Role `json:"role"` + MustChange bool `json:"must_change"` + PasswordSetAt int64 `json:"password_set_at"` + FailedAttempts int `json:"failed_attempts"` + LockedUntil int64 `json:"locked_until"` + CreatedAt int64 `json:"created_at"` +} + +var ( + ErrBadCredentials = errors.New("auth: bad credentials") + ErrUserNotFound = errors.New("auth: user not found") + ErrUserExists = errors.New("auth: user exists") + ErrCannotDeleteLastAdmin = errors.New("auth: cannot delete the last admin") + ErrLocked = errors.New("auth: account locked") +) + +type CredStore interface { + Seed(ctx context.Context, username string) (plaintext string, err error) + Authenticate(ctx context.Context, username, password string) (User, error) + SetPassword(ctx context.Context, username, password string) error + CreateUser(ctx context.Context, username, password string, role Role) error + DeleteUser(ctx context.Context, username string) error + SetRole(ctx context.Context, username string, role Role) error + GetUser(ctx context.Context, username string) (User, error) + ListUsers(ctx context.Context) ([]User, error) + SetLockedUntil(ctx context.Context, username string, until int64) error + IncrementFailures(ctx context.Context, username string) (int, error) + ResetFailures(ctx context.Context, username string) error +} diff --git a/dashboard/auth/credstore_file.go b/dashboard/auth/credstore_file.go new file mode 100644 index 0000000..2bbd19b --- /dev/null +++ b/dashboard/auth/credstore_file.go @@ -0,0 +1,253 @@ +// dashboard/auth/credstore_file.go +package auth + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "errors" + "os" + "path/filepath" + "sync" + "time" + + "golang.org/x/crypto/bcrypt" +) + +type FileStore struct { + path string + mu sync.Mutex +} + +func NewFileStore(path string) (*FileStore, error) { + if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { + return nil, err + } + return &FileStore{path: path}, nil +} + +func (s *FileStore) load() ([]User, error) { + b, err := os.ReadFile(s.path) + if errors.Is(err, os.ErrNotExist) { + return nil, nil + } + if err != nil { + return nil, err + } + var users []User + if err := json.Unmarshal(b, &users); err != nil { + return nil, err + } + return users, nil +} + +func (s *FileStore) save(users []User) error { + b, err := json.MarshalIndent(users, "", " ") + if err != nil { + return err + } + return os.WriteFile(s.path, b, 0o600) +} + +func (s *FileStore) Seed(ctx context.Context, username string) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + users, err := s.load() + if err != nil { + return "", err + } + if len(users) > 0 { + return "", nil + } + pw, err := randomPassword(16) + if err != nil { + return "", err + } + if env := os.Getenv("DASHBOARD_INITIAL_PASSWORD"); env != "" { + pw = env + } + hash, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost) + if err != nil { + return "", err + } + now := time.Now().Unix() + users = append(users, User{ + Username: username, + Hash: string(hash), + Role: RoleAdmin, + MustChange: true, + PasswordSetAt: now, + CreatedAt: now, + }) + if err := s.save(users); err != nil { + return "", err + } + return pw, nil +} + +func (s *FileStore) Authenticate(ctx context.Context, username, password string) (User, error) { + s.mu.Lock() + defer s.mu.Unlock() + users, err := s.load() + if err != nil { + return User{}, err + } + for _, u := range users { + if u.Username != username { + continue + } + if u.LockedUntil > time.Now().Unix() { + return User{}, ErrLocked + } + if err := bcrypt.CompareHashAndPassword([]byte(u.Hash), []byte(password)); err != nil { + return User{}, ErrBadCredentials + } + return u, nil + } + return User{}, ErrBadCredentials +} + +func (s *FileStore) SetPassword(ctx context.Context, username, password string) error { + s.mu.Lock() + defer s.mu.Unlock() + users, err := s.load() + if err != nil { + return err + } + for i, u := range users { + if u.Username != username { + continue + } + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return err + } + users[i].Hash = string(hash) + users[i].MustChange = false + users[i].PasswordSetAt = time.Now().Unix() + users[i].FailedAttempts = 0 + users[i].LockedUntil = 0 + return s.save(users) + } + return ErrUserNotFound +} + +func (s *FileStore) CreateUser(ctx context.Context, username, password string, role Role) error { + s.mu.Lock() + defer s.mu.Unlock() + users, err := s.load() + if err != nil { + return err + } + for _, u := range users { + if u.Username == username { + return ErrUserExists + } + } + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return err + } + now := time.Now().Unix() + users = append(users, User{ + Username: username, Hash: string(hash), Role: role, + MustChange: true, PasswordSetAt: now, CreatedAt: now, + }) + return s.save(users) +} + +func (s *FileStore) DeleteUser(ctx context.Context, username string) error { + s.mu.Lock() + defer s.mu.Unlock() + users, err := s.load() + if err != nil { + return err + } + admins := 0 + for _, u := range users { + if u.Role == RoleAdmin { + admins++ + } + } + out := users[:0] + for _, u := range users { + if u.Username == username { + if u.Role == RoleAdmin && admins == 1 { + return ErrCannotDeleteLastAdmin + } + continue + } + out = append(out, u) + } + if len(out) == len(users) { + return ErrUserNotFound + } + return s.save(out) +} + +func (s *FileStore) GetUser(ctx context.Context, username string) (User, error) { + s.mu.Lock() + defer s.mu.Unlock() + users, err := s.load() + if err != nil { + return User{}, err + } + for _, u := range users { + if u.Username == username { + return u, nil + } + } + return User{}, ErrUserNotFound +} + +func (s *FileStore) ListUsers(ctx context.Context) ([]User, error) { + s.mu.Lock() + defer s.mu.Unlock() + return s.load() +} + +func (s *FileStore) SetRole(ctx context.Context, username string, role Role) error { + if !role.Valid() { + return errors.New("auth: invalid role") + } + return s.mutate(username, func(u *User) { u.Role = role }) +} + +func (s *FileStore) SetLockedUntil(ctx context.Context, username string, until int64) error { + return s.mutate(username, func(u *User) { u.LockedUntil = until }) +} + +func (s *FileStore) IncrementFailures(ctx context.Context, username string) (int, error) { + var n int + err := s.mutate(username, func(u *User) { u.FailedAttempts++; n = u.FailedAttempts }) + return n, err +} + +func (s *FileStore) ResetFailures(ctx context.Context, username string) error { + return s.mutate(username, func(u *User) { u.FailedAttempts = 0; u.LockedUntil = 0 }) +} + +func (s *FileStore) mutate(username string, fn func(*User)) error { + s.mu.Lock() + defer s.mu.Unlock() + users, err := s.load() + if err != nil { + return err + } + for i, u := range users { + if u.Username == username { + fn(&users[i]) + return s.save(users) + } + } + return ErrUserNotFound +} + +func randomPassword(n int) (string, error) { + buf := make([]byte, n) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(buf)[:n], nil +} diff --git a/dashboard/auth/credstore_file_test.go b/dashboard/auth/credstore_file_test.go new file mode 100644 index 0000000..b29badc --- /dev/null +++ b/dashboard/auth/credstore_file_test.go @@ -0,0 +1,101 @@ +// dashboard/auth/credstore_file_test.go +package auth + +import ( + "context" + "path/filepath" + "testing" +) + +func TestFileStoreSeedAndAuthenticate(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "users.json") + store, err := NewFileStore(path) + if err != nil { + t.Fatalf("NewFileStore: %v", err) + } + ctx := context.Background() + pw, err := store.Seed(ctx, "admin") + if err != nil { + t.Fatalf("Seed: %v", err) + } + if pw == "" { + t.Fatal("seeded password should be non-empty") + } + u, err := store.Authenticate(ctx, "admin", pw) + if err != nil { + t.Fatalf("Authenticate: %v", err) + } + if u.Role != RoleAdmin { + t.Fatalf("role: got %s want %s", u.Role, RoleAdmin) + } + if !u.MustChange { + t.Fatal("seeded user should have must_change=true") + } +} + +func TestFileStoreAuthenticateRejectsBadPassword(t *testing.T) { + store, _ := NewFileStore(filepath.Join(t.TempDir(), "users.json")) + _, _ = store.Seed(context.Background(), "admin") + if _, err := store.Authenticate(context.Background(), "admin", "wrong"); err != ErrBadCredentials { + t.Fatalf("err: %v", err) + } +} + +func TestFileStoreSetPassword(t *testing.T) { + path := filepath.Join(t.TempDir(), "users.json") + store, _ := NewFileStore(path) + _, _ = store.Seed(context.Background(), "admin") + if err := store.SetPassword(context.Background(), "admin", "newpass1234"); err != nil { + t.Fatalf("SetPassword: %v", err) + } + u, err := store.Authenticate(context.Background(), "admin", "newpass1234") + if err != nil { + t.Fatalf("Authenticate: %v", err) + } + if u.MustChange { + t.Fatal("must_change should be cleared after rotation") + } + store2, _ := NewFileStore(path) + if _, err := store2.Authenticate(context.Background(), "admin", "newpass1234"); err != nil { + t.Fatalf("re-open Authenticate: %v", err) + } +} + +func TestFileStoreCreateAndDeleteUser(t *testing.T) { + store, _ := NewFileStore(filepath.Join(t.TempDir(), "users.json")) + _, _ = store.Seed(context.Background(), "admin") + if err := store.CreateUser(context.Background(), "bob", "bobpass1234", RoleViewer); err != nil { + t.Fatalf("CreateUser: %v", err) + } + users, err := store.ListUsers(context.Background()) + if err != nil { + t.Fatalf("ListUsers: %v", err) + } + if len(users) != 2 { + t.Fatalf("expected 2 users got %d", len(users)) + } + if err := store.DeleteUser(context.Background(), "bob"); err != nil { + t.Fatalf("DeleteUser: %v", err) + } + if _, err := store.Authenticate(context.Background(), "bob", "bobpass1234"); err == nil { + t.Fatal("expected user to be gone") + } + if err := store.DeleteUser(context.Background(), "admin"); err != ErrCannotDeleteLastAdmin { + t.Fatalf("expected ErrCannotDeleteLastAdmin, got %v", err) + } +} + +func TestFileStoreLockoutFields(t *testing.T) { + path := filepath.Join(t.TempDir(), "users.json") + store, _ := NewFileStore(path) + _, _ = store.Seed(context.Background(), "admin") + if err := store.SetLockedUntil(context.Background(), "admin", 99999999999); err != nil { + t.Fatalf("SetLockedUntil: %v", err) + } + store2, _ := NewFileStore(path) + u, _ := store2.GetUser(context.Background(), "admin") + if u.LockedUntil != 99999999999 { + t.Fatalf("locked_until persisted as %d", u.LockedUntil) + } +} diff --git a/dashboard/auth/credstore_redis.go b/dashboard/auth/credstore_redis.go new file mode 100644 index 0000000..9abc53b --- /dev/null +++ b/dashboard/auth/credstore_redis.go @@ -0,0 +1,231 @@ +// dashboard/auth/credstore_redis.go +package auth + +import ( + "context" + "encoding/json" + "errors" + "os" + "time" + + "github.com/redis/go-redis/v9" + "golang.org/x/crypto/bcrypt" +) + +// RedisStore implements CredStore using a single Redis hash. All cluster +// nodes read/write the same hash so password/role changes propagate. +// +// Layout: +// key = ":users" (default "comqtt:dashboard:users") +// field = username +// value = JSON-marshaled User struct +type RedisStore struct { + Client *redis.Client + KeyPrefix string +} + +// NewRedisStore constructs a RedisStore with sensible defaults. KeyPrefix +// defaults to "comqtt:dashboard" if empty. The caller owns the underlying +// redis.Client lifecycle. +func NewRedisStore(client *redis.Client, keyPrefix string) *RedisStore { + if keyPrefix == "" { + keyPrefix = "comqtt:dashboard" + } + return &RedisStore{Client: client, KeyPrefix: keyPrefix} +} + +func (s *RedisStore) usersKey() string { return s.KeyPrefix + ":users" } + +// Seed creates the admin user once. Idempotent: returns "" with no error +// if a user already exists in the hash. +func (s *RedisStore) Seed(ctx context.Context, username string) (string, error) { + exists, err := s.Client.HLen(ctx, s.usersKey()).Result() + if err != nil { + return "", err + } + if exists > 0 { + return "", nil + } + pw, err := randomPassword(16) + if err != nil { + return "", err + } + if env := os.Getenv("DASHBOARD_INITIAL_PASSWORD"); env != "" { + pw = env + } + hash, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost) + if err != nil { + return "", err + } + now := time.Now().Unix() + u := User{ + Username: username, + Hash: string(hash), + Role: RoleAdmin, + MustChange: true, + PasswordSetAt: now, + CreatedAt: now, + } + body, _ := json.Marshal(u) + if err := s.Client.HSet(ctx, s.usersKey(), username, body).Err(); err != nil { + return "", err + } + return pw, nil +} + +func (s *RedisStore) Authenticate(ctx context.Context, username, password string) (User, error) { + u, err := s.GetUser(ctx, username) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + return User{}, ErrBadCredentials + } + return User{}, err + } + if u.LockedUntil > time.Now().Unix() { + return User{}, ErrLocked + } + if err := bcrypt.CompareHashAndPassword([]byte(u.Hash), []byte(password)); err != nil { + return User{}, ErrBadCredentials + } + return u, nil +} + +func (s *RedisStore) SetPassword(ctx context.Context, username, password string) error { + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return err + } + return s.mutate(ctx, username, func(u *User) { + u.Hash = string(hash) + u.MustChange = false + u.PasswordSetAt = time.Now().Unix() + u.FailedAttempts = 0 + u.LockedUntil = 0 + }) +} + +func (s *RedisStore) CreateUser(ctx context.Context, username, password string, role Role) error { + exists, err := s.Client.HExists(ctx, s.usersKey(), username).Result() + if err != nil { + return err + } + if exists { + return ErrUserExists + } + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return err + } + now := time.Now().Unix() + u := User{ + Username: username, + Hash: string(hash), + Role: role, + MustChange: true, + PasswordSetAt: now, + CreatedAt: now, + } + body, _ := json.Marshal(u) + return s.Client.HSet(ctx, s.usersKey(), username, body).Err() +} + +func (s *RedisStore) DeleteUser(ctx context.Context, username string) error { + users, err := s.ListUsers(ctx) + if err != nil { + return err + } + admins := 0 + var target *User + for i, u := range users { + if u.Role == RoleAdmin { + admins++ + } + if u.Username == username { + target = &users[i] + } + } + if target == nil { + return ErrUserNotFound + } + if target.Role == RoleAdmin && admins == 1 { + return ErrCannotDeleteLastAdmin + } + return s.Client.HDel(ctx, s.usersKey(), username).Err() +} + +func (s *RedisStore) GetUser(ctx context.Context, username string) (User, error) { + body, err := s.Client.HGet(ctx, s.usersKey(), username).Bytes() + if err != nil { + if errors.Is(err, redis.Nil) { + return User{}, ErrUserNotFound + } + return User{}, err + } + var u User + if err := json.Unmarshal(body, &u); err != nil { + return User{}, err + } + return u, nil +} + +func (s *RedisStore) ListUsers(ctx context.Context) ([]User, error) { + all, err := s.Client.HGetAll(ctx, s.usersKey()).Result() + if err != nil { + return nil, err + } + out := make([]User, 0, len(all)) + for _, raw := range all { + var u User + if err := json.Unmarshal([]byte(raw), &u); err != nil { + continue + } + out = append(out, u) + } + return out, nil +} + +func (s *RedisStore) SetLockedUntil(ctx context.Context, username string, until int64) error { + return s.mutate(ctx, username, func(u *User) { u.LockedUntil = until }) +} + +func (s *RedisStore) IncrementFailures(ctx context.Context, username string) (int, error) { + var n int + err := s.mutate(ctx, username, func(u *User) { u.FailedAttempts++; n = u.FailedAttempts }) + return n, err +} + +func (s *RedisStore) ResetFailures(ctx context.Context, username string) error { + return s.mutate(ctx, username, func(u *User) { u.FailedAttempts = 0; u.LockedUntil = 0 }) +} + +func (s *RedisStore) SetRole(ctx context.Context, username string, role Role) error { + if !role.Valid() { + return errors.New("auth: invalid role") + } + return s.mutate(ctx, username, func(u *User) { u.Role = role }) +} + +// mutate atomically updates one user via WATCH/MULTI/EXEC. +func (s *RedisStore) mutate(ctx context.Context, username string, fn func(*User)) error { + key := s.usersKey() + return s.Client.Watch(ctx, func(tx *redis.Tx) error { + body, err := tx.HGet(ctx, key, username).Bytes() + if errors.Is(err, redis.Nil) { + return ErrUserNotFound + } + if err != nil { + return err + } + var u User + if err := json.Unmarshal(body, &u); err != nil { + return err + } + fn(&u) + updated, _ := json.Marshal(&u) + _, err = tx.TxPipelined(ctx, func(p redis.Pipeliner) error { + p.HSet(ctx, key, username, updated) + return nil + }) + return err + }, key) +} diff --git a/dashboard/auth/credstore_redis_test.go b/dashboard/auth/credstore_redis_test.go new file mode 100644 index 0000000..ee9cacf --- /dev/null +++ b/dashboard/auth/credstore_redis_test.go @@ -0,0 +1,140 @@ +// dashboard/auth/credstore_redis_test.go +package auth + +import ( + "context" + "testing" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" +) + +func newRedisStore(t *testing.T) (*RedisStore, *miniredis.Miniredis) { + t.Helper() + mr := miniredis.RunT(t) + client := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + t.Cleanup(func() { _ = client.Close() }) + return NewRedisStore(client, "test:dashboard"), mr +} + +func TestRedisStoreSeedAndAuthenticate(t *testing.T) { + store, _ := newRedisStore(t) + ctx := context.Background() + pw, err := store.Seed(ctx, "admin") + if err != nil { + t.Fatalf("Seed: %v", err) + } + if pw == "" { + t.Fatal("seeded password should be non-empty") + } + u, err := store.Authenticate(ctx, "admin", pw) + if err != nil { + t.Fatalf("Authenticate: %v", err) + } + if u.Role != RoleAdmin { + t.Fatalf("role: %s", u.Role) + } + if !u.MustChange { + t.Fatal("MustChange should be true on seed") + } +} + +func TestRedisStoreSeedIdempotent(t *testing.T) { + store, _ := newRedisStore(t) + ctx := context.Background() + pw1, _ := store.Seed(ctx, "admin") + pw2, err := store.Seed(ctx, "admin") + if err != nil { + t.Fatalf("Seed: %v", err) + } + if pw2 != "" { + t.Fatalf("second Seed should return empty password, got %q", pw2) + } + if pw1 == "" { + t.Fatal("first Seed should return password") + } +} + +func TestRedisStoreAuthenticateBadPassword(t *testing.T) { + store, _ := newRedisStore(t) + _, _ = store.Seed(context.Background(), "admin") + if _, err := store.Authenticate(context.Background(), "admin", "wrong"); err != ErrBadCredentials { + t.Fatalf("err: %v", err) + } +} + +func TestRedisStoreSetPassword(t *testing.T) { + store, _ := newRedisStore(t) + ctx := context.Background() + _, _ = store.Seed(ctx, "admin") + if err := store.SetPassword(ctx, "admin", "newpass1234"); err != nil { + t.Fatalf("SetPassword: %v", err) + } + u, err := store.Authenticate(ctx, "admin", "newpass1234") + if err != nil { + t.Fatalf("Authenticate: %v", err) + } + if u.MustChange { + t.Fatal("must_change should be cleared") + } +} + +func TestRedisStoreCreateAndDeleteUser(t *testing.T) { + store, _ := newRedisStore(t) + ctx := context.Background() + _, _ = store.Seed(ctx, "admin") + if err := store.CreateUser(ctx, "bob", "bobpass1234", RoleViewer); err != nil { + t.Fatalf("CreateUser: %v", err) + } + users, _ := store.ListUsers(ctx) + if len(users) != 2 { + t.Fatalf("expected 2 users, got %d", len(users)) + } + if err := store.DeleteUser(ctx, "bob"); err != nil { + t.Fatalf("DeleteUser: %v", err) + } + if err := store.DeleteUser(ctx, "admin"); err != ErrCannotDeleteLastAdmin { + t.Fatalf("expected ErrCannotDeleteLastAdmin: %v", err) + } +} + +func TestRedisStoreSetRole(t *testing.T) { + store, _ := newRedisStore(t) + ctx := context.Background() + _, _ = store.Seed(ctx, "admin") + _ = store.CreateUser(ctx, "bob", "bobpass1234", RoleViewer) + if err := store.SetRole(ctx, "bob", RoleAdmin); err != nil { + t.Fatalf("SetRole: %v", err) + } + u, _ := store.GetUser(ctx, "bob") + if u.Role != RoleAdmin { + t.Fatalf("role: %s", u.Role) + } +} + +func TestRedisStoreLockoutFields(t *testing.T) { + store, _ := newRedisStore(t) + ctx := context.Background() + _, _ = store.Seed(ctx, "admin") + if err := store.SetLockedUntil(ctx, "admin", 99999999999); err != nil { + t.Fatalf("SetLockedUntil: %v", err) + } + u, _ := store.GetUser(ctx, "admin") + if u.LockedUntil != 99999999999 { + t.Fatalf("locked_until: %d", u.LockedUntil) + } + if _, err := store.IncrementFailures(ctx, "admin"); err != nil { + t.Fatalf("IncrementFailures: %v", err) + } + u, _ = store.GetUser(ctx, "admin") + if u.FailedAttempts != 1 { + t.Fatalf("failed_attempts: %d", u.FailedAttempts) + } + if err := store.ResetFailures(ctx, "admin"); err != nil { + t.Fatalf("ResetFailures: %v", err) + } + u, _ = store.GetUser(ctx, "admin") + if u.FailedAttempts != 0 || u.LockedUntil != 0 { + t.Fatalf("after reset: %+v", u) + } +} diff --git a/dashboard/auth/csrf.go b/dashboard/auth/csrf.go new file mode 100644 index 0000000..1562441 --- /dev/null +++ b/dashboard/auth/csrf.go @@ -0,0 +1,23 @@ +// dashboard/auth/csrf.go +package auth + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" +) + +func NewCSRFToken() string { + buf := make([]byte, 24) + _, _ = rand.Read(buf) + return base64.RawURLEncoding.EncodeToString(buf) +} + +// ValidateCSRFToken constant-time compares the request token against the +// per-render token stamped into the form. Empty tokens never validate. +func ValidateCSRFToken(form, expected string) bool { + if form == "" || expected == "" { + return false + } + return subtle.ConstantTimeCompare([]byte(form), []byte(expected)) == 1 +} diff --git a/dashboard/auth/csrf_test.go b/dashboard/auth/csrf_test.go new file mode 100644 index 0000000..c19ecdd --- /dev/null +++ b/dashboard/auth/csrf_test.go @@ -0,0 +1,23 @@ +// dashboard/auth/csrf_test.go +package auth + +import "testing" + +func TestCSRFRoundTrip(t *testing.T) { + tok := NewCSRFToken() + if !ValidateCSRFToken(tok, tok) { + t.Fatal("token should match itself") + } +} + +func TestCSRFRejectsMismatch(t *testing.T) { + if ValidateCSRFToken(NewCSRFToken(), NewCSRFToken()) { + t.Fatal("two random tokens should not match") + } +} + +func TestCSRFRejectsEmpty(t *testing.T) { + if ValidateCSRFToken("", "") { + t.Fatal("empty tokens must not validate") + } +} diff --git a/dashboard/auth/lockout.go b/dashboard/auth/lockout.go new file mode 100644 index 0000000..5c7a13e --- /dev/null +++ b/dashboard/auth/lockout.go @@ -0,0 +1,70 @@ +// dashboard/auth/lockout.go +package auth + +import ( + "sync" + "time" +) + +type LockoutConfig struct { + Threshold int + Window time.Duration + Duration time.Duration +} + +type Lockout struct { + cfg LockoutConfig + mu sync.Mutex + st map[string]*lockoutState +} + +type lockoutState struct { + failures []time.Time + lockedUntil time.Time +} + +func NewLockout(cfg LockoutConfig) *Lockout { + return &Lockout{cfg: cfg, st: map[string]*lockoutState{}} +} + +func (l *Lockout) Record(username string) { + l.mu.Lock() + defer l.mu.Unlock() + s, ok := l.st[username] + if !ok { + s = &lockoutState{} + l.st[username] = s + } + now := time.Now() + cutoff := now.Add(-l.cfg.Window) + out := s.failures[:0] + for _, t := range s.failures { + if t.After(cutoff) { + out = append(out, t) + } + } + s.failures = append(out, now) + if len(s.failures) >= l.cfg.Threshold { + s.lockedUntil = now.Add(l.cfg.Duration) + s.failures = nil + } +} + +func (l *Lockout) IsLocked(username string) (bool, time.Time) { + l.mu.Lock() + defer l.mu.Unlock() + s, ok := l.st[username] + if !ok { + return false, time.Time{} + } + if time.Now().Before(s.lockedUntil) { + return true, s.lockedUntil + } + return false, time.Time{} +} + +func (l *Lockout) Reset(username string) { + l.mu.Lock() + defer l.mu.Unlock() + delete(l.st, username) +} diff --git a/dashboard/auth/lockout_test.go b/dashboard/auth/lockout_test.go new file mode 100644 index 0000000..8061313 --- /dev/null +++ b/dashboard/auth/lockout_test.go @@ -0,0 +1,42 @@ +// dashboard/auth/lockout_test.go +package auth + +import ( + "testing" + "time" +) + +func TestLockoutBelowThreshold(t *testing.T) { + l := NewLockout(LockoutConfig{Threshold: 5, Window: time.Minute, Duration: 10 * time.Minute}) + for i := 0; i < 4; i++ { + l.Record("alice") + } + if locked, _ := l.IsLocked("alice"); locked { + t.Fatal("should not be locked yet") + } +} + +func TestLockoutAtThreshold(t *testing.T) { + l := NewLockout(LockoutConfig{Threshold: 5, Window: time.Minute, Duration: 10 * time.Minute}) + for i := 0; i < 5; i++ { + l.Record("alice") + } + locked, until := l.IsLocked("alice") + if !locked { + t.Fatal("should be locked") + } + if until.Before(time.Now()) { + t.Fatalf("locked_until in past: %v", until) + } +} + +func TestLockoutResetsClearsState(t *testing.T) { + l := NewLockout(LockoutConfig{Threshold: 5, Window: time.Minute, Duration: 10 * time.Minute}) + for i := 0; i < 5; i++ { + l.Record("alice") + } + l.Reset("alice") + if locked, _ := l.IsLocked("alice"); locked { + t.Fatal("should be cleared after Reset") + } +} diff --git a/dashboard/auth/middleware.go b/dashboard/auth/middleware.go new file mode 100644 index 0000000..5bfb68e --- /dev/null +++ b/dashboard/auth/middleware.go @@ -0,0 +1,98 @@ +// dashboard/auth/middleware.go +package auth + +import ( + "context" + "net/http" + "strings" + "time" +) + +type ctxKey string + +const userCtxKey ctxKey = "comqtt-user" + +func WithUser(ctx context.Context, u User) context.Context { + return context.WithValue(ctx, userCtxKey, u) +} + +func UserFromContext(ctx context.Context) User { + if u, ok := ctx.Value(userCtxKey).(User); ok { + return u + } + return User{} +} + +// RequireAuth verifies the cookie, loads the User, then enforces force-rotate +// and password-expiry policies. It exempts the login/logout/static paths. +func RequireAuth(secret []byte, store CredStore, expiryDays int) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if isExempt(r.URL.Path) { + next.ServeHTTP(w, r) + return + } + c, err := r.Cookie("comqtt_session") + if err != nil { + redirectLogin(w, r) + return + } + payload, err := Verify(secret, c.Value) + if err != nil { + redirectLogin(w, r) + return + } + u, err := store.GetUser(r.Context(), payload.Username) + if err != nil { + redirectLogin(w, r) + return + } + if u.MustChange && !strings.HasPrefix(r.URL.Path, "/dashboard/account/password") { + http.Redirect(w, r, "/dashboard/account/password?reason=must_change", http.StatusFound) + return + } + if expiryDays > 0 && time.Now().Unix()-u.PasswordSetAt > int64(expiryDays)*86400 { + if !strings.HasPrefix(r.URL.Path, "/dashboard/account/password") { + http.Redirect(w, r, "/dashboard/account/password?reason=expired", http.StatusFound) + return + } + } + next.ServeHTTP(w, r.WithContext(WithUser(r.Context(), u))) + }) + } +} + +// RequireRole rejects users whose role is below the required level. +// It assumes RequireAuth has already populated the context. +func RequireRole(min Role) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + u := UserFromContext(r.Context()) + if u.Role == "" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if min == RoleAdmin && u.Role != RoleAdmin { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + next.ServeHTTP(w, r) + }) + } +} + +func isExempt(path string) bool { + switch { + case strings.HasPrefix(path, "/dashboard/login"): + return true + case strings.HasPrefix(path, "/dashboard/logout"): + return true + case strings.HasPrefix(path, "/dashboard/static/"): + return true + } + return false +} + +func redirectLogin(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/dashboard/login?next="+r.URL.Path, http.StatusFound) +} diff --git a/dashboard/auth/middleware_test.go b/dashboard/auth/middleware_test.go new file mode 100644 index 0000000..7364c9e --- /dev/null +++ b/dashboard/auth/middleware_test.go @@ -0,0 +1,95 @@ +// dashboard/auth/middleware_test.go +package auth + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "path/filepath" + "strings" + "testing" + "time" +) + +// setupForMiddleware seeds an admin in a fresh FileStore. The user is +// returned in the must_change=true state. Callers that want a "settled" +// admin should call store.SetPassword to clear the flag. +func setupForMiddleware(t *testing.T) (*FileStore, []byte, string) { + t.Helper() + store, _ := NewFileStore(filepath.Join(t.TempDir(), "users.json")) + _, _ = store.Seed(context.Background(), "admin") + secret := []byte("0123456789abcdef0123456789abcdef") + cookie, _ := Sign(secret, SessionPayload{Username: "admin", Role: string(RoleAdmin), Exp: time.Now().Add(time.Hour).Unix()}) + return store, secret, cookie +} + +func TestRequireAuthRedirectsAnon(t *testing.T) { + store, secret, _ := setupForMiddleware(t) + mw := RequireAuth(secret, store, 90) + h := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok")) + })) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, httptest.NewRequest("GET", "/dashboard/", nil)) + if rr.Code != http.StatusFound { + t.Fatalf("status: %d", rr.Code) + } + if !strings.HasPrefix(rr.Header().Get("Location"), "/dashboard/login") { + t.Fatalf("redirect: %q", rr.Header().Get("Location")) + } +} + +func TestRequireAuthAcceptsValidCookie(t *testing.T) { + store, secret, cookie := setupForMiddleware(t) + // Clear must_change so the handler is allowed to run. + if err := store.SetPassword(context.Background(), "admin", "newpass1234"); err != nil { + t.Fatalf("SetPassword: %v", err) + } + mw := RequireAuth(secret, store, 90) + h := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + u := UserFromContext(r.Context()) + w.Write([]byte("hello " + u.Username)) + })) + req := httptest.NewRequest("GET", "/dashboard/", nil) + req.AddCookie(&http.Cookie{Name: "comqtt_session", Value: cookie}) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status: %d body: %s", rr.Code, rr.Body.String()) + } + body, _ := io.ReadAll(rr.Body) + if string(body) != "hello admin" { + t.Fatalf("body: %q", body) + } +} + +func TestRequireAuthForceRotateRedirects(t *testing.T) { + store, secret, cookie := setupForMiddleware(t) + // Seed leaves must_change=true; we want this state. + mw := RequireAuth(secret, store, 90) + h := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("ok")) })) + req := httptest.NewRequest("GET", "/dashboard/", nil) + req.AddCookie(&http.Cookie{Name: "comqtt_session", Value: cookie}) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + if rr.Code != http.StatusFound { + t.Fatalf("status: %d", rr.Code) + } + if !strings.HasPrefix(rr.Header().Get("Location"), "/dashboard/account/password") { + t.Fatalf("redirect: %q", rr.Header().Get("Location")) + } +} + +func TestRequireRoleRejectsViewerOnAdminRoute(t *testing.T) { + mw := RequireRole(RoleAdmin) + h := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("ok")) })) + req := httptest.NewRequest("DELETE", "/api/v1/whatever", nil) + ctx := WithUser(req.Context(), User{Username: "bob", Role: RoleViewer}) + req = req.WithContext(ctx) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + if rr.Code != http.StatusForbidden { + t.Fatalf("status: %d", rr.Code) + } +} diff --git a/dashboard/auth/secret_redis.go b/dashboard/auth/secret_redis.go new file mode 100644 index 0000000..cf7699b --- /dev/null +++ b/dashboard/auth/secret_redis.go @@ -0,0 +1,46 @@ +// dashboard/auth/secret_redis.go +package auth + +import ( + "context" + "crypto/rand" + "errors" + + "github.com/redis/go-redis/v9" +) + +// EnsureSecret returns the cluster-wide HMAC secret for dashboard session +// cookies. The secret is stored at the given Redis key. If absent, a fresh +// 32-byte secret is generated and SETNX-ed; on race loss, the winning +// value is read back. Returns at most 32 bytes. +func EnsureSecret(ctx context.Context, client *redis.Client, key string) ([]byte, error) { + if client == nil { + return nil, errors.New("auth: nil redis client") + } + if key == "" { + return nil, errors.New("auth: empty key") + } + + // First, try to read. + if b, err := client.Get(ctx, key).Bytes(); err == nil { + return b, nil + } else if !errors.Is(err, redis.Nil) { + return nil, err + } + + // Key absent. Generate and SETNX. + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + return nil, err + } + ok, err := client.SetNX(ctx, key, buf, 0).Result() + if err != nil { + return nil, err + } + if ok { + return buf, nil + } + + // Lost the race; re-read the winning value. + return client.Get(ctx, key).Bytes() +} diff --git a/dashboard/auth/secret_redis_test.go b/dashboard/auth/secret_redis_test.go new file mode 100644 index 0000000..e1dfc91 --- /dev/null +++ b/dashboard/auth/secret_redis_test.go @@ -0,0 +1,82 @@ +// dashboard/auth/secret_redis_test.go +package auth + +import ( + "bytes" + "context" + "sync" + "testing" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" +) + +func newRedisClient(t *testing.T) (*redis.Client, *miniredis.Miniredis) { + t.Helper() + mr := miniredis.RunT(t) + client := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + t.Cleanup(func() { _ = client.Close() }) + return client, mr +} + +func TestEnsureSecretCreatesOnFirstCall(t *testing.T) { + client, _ := newRedisClient(t) + got, err := EnsureSecret(context.Background(), client, "test:secret") + if err != nil { + t.Fatalf("EnsureSecret: %v", err) + } + if len(got) != 32 { + t.Fatalf("expected 32 bytes, got %d", len(got)) + } +} + +func TestEnsureSecretReturnsExistingOnSecondCall(t *testing.T) { + client, _ := newRedisClient(t) + first, err := EnsureSecret(context.Background(), client, "test:secret") + if err != nil { + t.Fatalf("first EnsureSecret: %v", err) + } + second, err := EnsureSecret(context.Background(), client, "test:secret") + if err != nil { + t.Fatalf("second EnsureSecret: %v", err) + } + if !bytes.Equal(first, second) { + t.Fatalf("subsequent call should return same secret: %x vs %x", first, second) + } +} + +func TestEnsureSecretConcurrent(t *testing.T) { + client, _ := newRedisClient(t) + results := make([][]byte, 16) + var wg sync.WaitGroup + for i := range results { + i := i + wg.Add(1) + go func() { + defer wg.Done() + b, err := EnsureSecret(context.Background(), client, "test:secret") + if err != nil { + t.Errorf("goroutine %d: %v", i, err) + return + } + results[i] = b + }() + } + wg.Wait() + first := results[0] + for i, b := range results[1:] { + if !bytes.Equal(first, b) { + t.Fatalf("concurrent results diverged at %d: %x vs %x", i+1, first, b) + } + } +} + +func TestEnsureSecretRejectsBadInputs(t *testing.T) { + client, _ := newRedisClient(t) + if _, err := EnsureSecret(context.Background(), nil, "k"); err == nil { + t.Fatal("expected error for nil client") + } + if _, err := EnsureSecret(context.Background(), client, ""); err == nil { + t.Fatal("expected error for empty key") + } +} diff --git a/dashboard/auth/session.go b/dashboard/auth/session.go new file mode 100644 index 0000000..6b602ce --- /dev/null +++ b/dashboard/auth/session.go @@ -0,0 +1,65 @@ +package auth + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "strings" + "time" +) + +// SessionPayload is what gets HMAC-signed into the comqtt_session cookie. +// Stateless: server does no lookup beyond MAC verification + expiry check. +type SessionPayload struct { + Username string `json:"u"` + Role string `json:"r"` + Exp int64 `json:"e"` // unix seconds + Nonce string `json:"n"` // 8 random bytes b64 +} + +var ( + ErrBadMAC = errors.New("session: bad mac") + ErrExpired = errors.New("session: expired") + ErrMalformed = errors.New("session: malformed") +) + +func Sign(secret []byte, p SessionPayload) (string, error) { + body, err := json.Marshal(p) + if err != nil { + return "", err + } + b := base64.RawURLEncoding.EncodeToString(body) + mac := hmac.New(sha256.New, secret) + mac.Write([]byte(b)) + return b + "." + base64.RawURLEncoding.EncodeToString(mac.Sum(nil)), nil +} + +func Verify(secret []byte, cookie string) (SessionPayload, error) { + parts := strings.SplitN(cookie, ".", 2) + if len(parts) != 2 { + return SessionPayload{}, ErrMalformed + } + expected := hmac.New(sha256.New, secret) + expected.Write([]byte(parts[0])) + got, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return SessionPayload{}, ErrMalformed + } + if !hmac.Equal(expected.Sum(nil), got) { + return SessionPayload{}, ErrBadMAC + } + body, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + return SessionPayload{}, ErrMalformed + } + var p SessionPayload + if err := json.Unmarshal(body, &p); err != nil { + return SessionPayload{}, ErrMalformed + } + if time.Now().Unix() > p.Exp { + return SessionPayload{}, ErrExpired + } + return p, nil +} diff --git a/dashboard/auth/session_test.go b/dashboard/auth/session_test.go new file mode 100644 index 0000000..28fb438 --- /dev/null +++ b/dashboard/auth/session_test.go @@ -0,0 +1,54 @@ +package auth + +import ( + "strings" + "testing" + "time" +) + +func TestSignAndVerifyRoundTrip(t *testing.T) { + secret := []byte("0123456789abcdef0123456789abcdef") + want := SessionPayload{Username: "alice", Role: "admin", Exp: time.Now().Add(time.Hour).Unix(), Nonce: "n1"} + cookie, err := Sign(secret, want) + if err != nil { + t.Fatalf("Sign: %v", err) + } + got, err := Verify(secret, cookie) + if err != nil { + t.Fatalf("Verify: %v", err) + } + if got != want { + t.Fatalf("payload mismatch: got %+v want %+v", got, want) + } +} + +func TestVerifyRejectsTampered(t *testing.T) { + secret := []byte("0123456789abcdef0123456789abcdef") + cookie, _ := Sign(secret, SessionPayload{Username: "alice", Role: "admin", Exp: time.Now().Add(time.Hour).Unix(), Nonce: "n1"}) + parts := strings.SplitN(cookie, ".", 2) + if len(parts) != 2 { + t.Fatalf("malformed cookie: %q", cookie) + } + // Flip a byte in the encoded payload so the HMAC no longer matches. + flipped := []byte(parts[0]) + flipped[3] ^= 0x01 + tampered := string(flipped) + "." + parts[1] + if _, err := Verify(secret, tampered); err == nil { + t.Fatal("expected error on tampered cookie") + } +} + +func TestVerifyRejectsExpired(t *testing.T) { + secret := []byte("0123456789abcdef0123456789abcdef") + cookie, _ := Sign(secret, SessionPayload{Username: "alice", Role: "admin", Exp: time.Now().Add(-time.Hour).Unix(), Nonce: "n1"}) + if _, err := Verify(secret, cookie); err == nil { + t.Fatal("expected expired error") + } +} + +func TestVerifyRejectsBadSecret(t *testing.T) { + cookie, _ := Sign([]byte("0123456789abcdef0123456789abcdef"), SessionPayload{Username: "alice", Role: "admin", Exp: time.Now().Add(time.Hour).Unix()}) + if _, err := Verify([]byte("ffffffffffffffffffffffffffffffff"), cookie); err == nil { + t.Fatal("expected bad-mac error") + } +} diff --git a/dashboard/cluster_integration_test.go b/dashboard/cluster_integration_test.go new file mode 100644 index 0000000..f5408c4 --- /dev/null +++ b/dashboard/cluster_integration_test.go @@ -0,0 +1,220 @@ +//go:build integration +// +build integration + +// dashboard/cluster_integration_test.go +package dashboard_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" + + "github.com/wind-c/comqtt/v2/mqtt" + "github.com/wind-c/comqtt/v2/dashboard/handlers" + "github.com/wind-c/comqtt/v2/dashboard/sse" + "github.com/wind-c/comqtt/v2/mqtt/packets" +) + +// node bundles a fake "broker A" or "broker B" - one mqtt.Server, one +// hub, one bridge, one events handler exposed via httptest. +type node struct { + name string + server *mqtt.Server + hub *sse.Hub + bridge *sse.Bridge + srv *httptest.Server +} + +func newNode(t *testing.T, name string, redisAddr string) *node { + t.Helper() + server := mqtt.New(nil) + hub := sse.NewHub(64) + + if err := server.AddHook(&sse.HubHook{Hub: hub, Node: name}, nil); err != nil { + t.Fatalf("AddHook: %v", err) + } + + client := redis.NewClient(&redis.Options{Addr: redisAddr}) + t.Cleanup(func() { _ = client.Close() }) + + br := sse.NewBridge(client, hub, name) + br.Start(context.Background()) + + srv := httptest.NewServer(handlers.Events(hub)) + + t.Cleanup(func() { + srv.Close() + br.Stop() + hub.Close() + _ = server.Close() + }) + + return &node{name: name, server: server, hub: hub, bridge: br, srv: srv} +} + +// readSSEUntil reads `event:` lines from the SSE stream until one matches +// the wanted type with the wanted node, or the deadline elapses. Returns +// the matched parsed event or an error. +func readSSEUntil(t *testing.T, url string, wantType, wantNode string, timeout time.Duration) (sse.Event, error) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, url+"?as=json", nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return sse.Event{}, err + } + defer resp.Body.Close() + + buf := make([]byte, 1024) + var sb strings.Builder + for { + n, err := resp.Body.Read(buf) + if n > 0 { + sb.Write(buf[:n]) + body := sb.String() + records := strings.Split(body, "\n\n") + sb.Reset() + sb.WriteString(records[len(records)-1]) + for _, rec := range records[:len(records)-1] { + rec = strings.TrimSpace(rec) + if rec == "" || strings.HasPrefix(rec, ":") { + continue + } + var typeLine, dataLine string + for _, line := range strings.Split(rec, "\n") { + if strings.HasPrefix(line, "event: ") { + typeLine = strings.TrimPrefix(line, "event: ") + } else if strings.HasPrefix(line, "data: ") { + dataLine = strings.TrimPrefix(line, "data: ") + } + } + if typeLine != wantType { + continue + } + var ev sse.Event + if err := json.Unmarshal([]byte(dataLine), &ev); err != nil { + continue + } + if ev.Type == wantType && ev.Node == wantNode { + return ev, nil + } + } + } + if err != nil { + return sse.Event{}, err + } + } +} + +// TestClusterEventAggregation simulates two cluster nodes sharing events +// via miniredis. A connect event triggered on node-A is observed on +// node-B's SSE stream. +func TestClusterEventAggregation(t *testing.T) { + mr := miniredis.RunT(t) + addr := mr.Addr() + + a := newNode(t, "node-A", addr) + b := newNode(t, "node-B", addr) + + // Allow both bridges' subscribers to be registered with redis before we + // publish. miniredis is fast but PSubscribe still needs a moment to round + // trip. 150ms is plenty. + time.Sleep(150 * time.Millisecond) + + cl := &mqtt.Client{ID: "alice"} + cl.Properties.Username = []byte("u") + cl.Net.Remote = "127.0.0.1:0" + + hubHook := &sse.HubHook{Hub: a.hub, Node: "node-A"} + if err := hubHook.OnConnect(cl, packets.Packet{}); err != nil { + t.Fatalf("OnConnect: %v", err) + } + + ev, err := readSSEUntil(t, b.srv.URL, "client.connected", "node-A", 5*time.Second) + if err != nil { + t.Fatalf("readSSEUntil: %v", err) + } + if ev.Node != "node-A" { + t.Fatalf("expected node=node-A, got %q", ev.Node) + } + if ev.Type != "client.connected" { + t.Fatalf("expected type=client.connected, got %q", ev.Type) + } +} + +// TestClusterEventAggregationOwnNodeFiltered verifies that a tab connected +// to node-A doesn't see its own events DUPLICATED via the redis round trip. +func TestClusterEventAggregationOwnNodeFiltered(t *testing.T) { + mr := miniredis.RunT(t) + addr := mr.Addr() + a := newNode(t, "node-A", addr) + + time.Sleep(100 * time.Millisecond) + + hubHook := &sse.HubHook{Hub: a.hub, Node: "node-A"} + cl := &mqtt.Client{ID: "alice"} + cl.Properties.Username = []byte("u") + cl.Net.Remote = "127.0.0.1:0" + + // The hub is fan-out and does not buffer for not-yet-connected subscribers, + // so we must open the SSE stream before triggering events. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, a.srv.URL+"?as=json", nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Do: %v", err) + } + defer resp.Body.Close() + + // Allow the SSE handler to subscribe to the hub before publishing. + time.Sleep(150 * time.Millisecond) + + for i := 0; i < 5; i++ { + if err := hubHook.OnConnect(cl, packets.Packet{}); err != nil { + t.Fatalf("OnConnect: %v", err) + } + } + + count := 0 + deadline := time.After(2 * time.Second) + buf := make([]byte, 1024) + var sb strings.Builder +DrainLoop: + for { + select { + case <-deadline: + break DrainLoop + default: + } + n, err := resp.Body.Read(buf) + if n > 0 { + sb.Write(buf[:n]) + count = strings.Count(sb.String(), "event: client.connected") + if count >= 5 { + time.Sleep(500 * time.Millisecond) + n, _ = resp.Body.Read(buf) + if n > 0 { + sb.Write(buf[:n]) + count = strings.Count(sb.String(), "event: client.connected") + } + break DrainLoop + } + } + if err != nil { + break + } + } + if count != 5 { + t.Fatalf("expected exactly 5 client.connected events, got %d", count) + } +} diff --git a/dashboard/dashboard.go b/dashboard/dashboard.go new file mode 100644 index 0000000..0a4212a --- /dev/null +++ b/dashboard/dashboard.go @@ -0,0 +1,301 @@ +// dashboard/dashboard.go +package dashboard + +import ( + "context" + "crypto/rand" + "encoding/base64" + "errors" + "net/http" + "os" + "path/filepath" + "time" + + redis "github.com/redis/go-redis/v9" + "github.com/wind-c/comqtt/v2/mqtt" + "github.com/wind-c/comqtt/v2/dashboard/auth" + "github.com/wind-c/comqtt/v2/dashboard/handlers" + "github.com/wind-c/comqtt/v2/dashboard/sse" + "github.com/wind-c/comqtt/v2/mqtt/rest" +) + +// Options bundles wiring choices for the dashboard. +type Options struct { + // Cluster is true when running inside cmd/cluster. The nav menu shows + // the Cluster link only when this is set. + Cluster bool + + // ClusterAgent provides cluster topology to the Cluster page. nil in + // single-mode; required when Cluster=true to mount /dashboard/cluster. + ClusterAgent handlers.ClusterAgent + + // Server is the broker. Required. + Server *mqtt.Server + + // Store is the credential store. If nil, a FileStore is created at + // CredStorePath (defaulting to ./data/dashboard-users.json) and seeded + // with username "admin" and a random password printed to stdout. + Store auth.CredStore + + // CredStorePath is the file path for the FileStore. Used only when + // Store is nil. Default: ./data/dashboard-users.json. + CredStorePath string + + // Secret is the HMAC secret for the session cookie. If nil, one is + // loaded from SecretPath (auto-generated 32-byte file) or from the + // COMQTT_DASHBOARD_SESSION_SECRET env var (base64). + Secret []byte + + // SecretPath is the file path for the auto-generated secret. Default: + // ./data/dashboard-secret. + SecretPath string + + // PasswordExpiryDays. 0 = never expire. Default: 90. + PasswordExpiryDays int + + // Lockout configuration. Defaults: 5 fails / 5 min window / 10 min lock. + LockoutThreshold int + LockoutWindow time.Duration + LockoutDuration time.Duration + + // SessionTTL. Default: 12h. + SessionTTL time.Duration + + // Redis is the cluster-mode redis client. When set, the cred store and + // HMAC secret are read from redis (not the file paths), and a pub/sub + // bridge fans events between cluster nodes. nil in single-mode. + Redis *redis.Client +} + +// Routes returns the full set of HTTP routes for the dashboard, ready to be +// merged into another route map (e.g. rest.New(server).GenHandlers()) and +// registered onto the existing :8080 listener. +// +// Routes are returned as map[string]rest.Handler keyed on the Go 1.22+ +// pattern syntax "METHOD /path/{params}". The static asset path uses the +// {path...} catch-all. +// +// The returned cleanup func stops background goroutines (rate sampler, and +// the redis pub/sub bridge in cluster mode). Callers must invoke it on +// shutdown to avoid leaking goroutines. +func Routes(opts Options) (map[string]rest.Handler, func(), error) { + if opts.Server == nil { + return nil, nil, errors.New("dashboard: Options.Server is required") + } + if err := opts.applyDefaults(); err != nil { + return nil, nil, err + } + + var store auth.CredStore = opts.Store + if store == nil { + if opts.Redis != nil { + rs := auth.NewRedisStore(opts.Redis, "comqtt:dashboard") + pw, err := rs.Seed(context.Background(), "admin") + if err != nil { + return nil, nil, err + } + if pw != "" { + println("[dashboard] seeded admin password:", pw, "(rotate via /dashboard/account/password)") + } + store = rs + } else { + fs, err := auth.NewFileStore(opts.CredStorePath) + if err != nil { + return nil, nil, err + } + pw, err := fs.Seed(context.Background(), "admin") + if err != nil { + return nil, nil, err + } + if pw != "" { + // Printed exactly once: first boot. Operators must rotate. + println("[dashboard] seeded admin password:", pw, "(rotate via /dashboard/account/password)") + } + store = fs + } + } + + lockout := auth.NewLockout(auth.LockoutConfig{ + Threshold: opts.LockoutThreshold, + Window: opts.LockoutWindow, + Duration: opts.LockoutDuration, + }) + + hub := sse.NewHub(1024) + // Register the broker hook so connect/publish/disconnect events reach + // the hub. The hook id is unique per dashboard instance. + if err := opts.Server.AddHook(&sse.HubHook{Hub: hub, Node: hostname()}, nil); err != nil { + return nil, nil, err + } + + var bridge *sse.Bridge + if opts.Redis != nil { + bridge = sse.NewBridge(opts.Redis, hub, hostname()) + bridge.Start(context.Background()) + } + + rdr := handlers.NewRenderer(assetsFS) + sampler := handlers.NewRateSampler(opts.Server) + + accountDeps := handlers.AccountDeps{ + Store: store, + Lockout: lockout, + Secret: opts.Secret, + Renderer: rdr, + SessionTTL: opts.SessionTTL, + } + overviewDeps := handlers.OverviewDeps{Server: opts.Server, Renderer: rdr, Cluster: opts.Cluster, Sampler: sampler, Agent: opts.ClusterAgent} + clientsDeps := handlers.ClientsDeps{Server: opts.Server, Renderer: rdr, Cluster: opts.Cluster} + clientDetailDeps := handlers.ClientDetailDeps{Server: opts.Server, Renderer: rdr, Cluster: opts.Cluster} + subscriptionsDeps := handlers.SubscriptionsDeps{Server: opts.Server, Renderer: rdr, Cluster: opts.Cluster} + topicsDeps := handlers.TopicsDeps{Server: opts.Server, Renderer: rdr, Cluster: opts.Cluster} + retainedDeps := handlers.RetainedDeps{Server: opts.Server, Renderer: rdr, Cluster: opts.Cluster} + sessionsDeps := handlers.SessionsDeps{Server: opts.Server, Renderer: rdr, Cluster: opts.Cluster} + blacklistDeps := handlers.BlacklistDeps{Server: opts.Server, Renderer: rdr, Cluster: opts.Cluster} + toolsDeps := handlers.ToolsDeps{Server: opts.Server, Renderer: rdr, Cluster: opts.Cluster} + settingsDeps := handlers.SettingsDeps{Server: opts.Server, Renderer: rdr, Cluster: opts.Cluster} + usersDeps := handlers.UsersDeps{Store: store, Renderer: rdr, Cluster: opts.Cluster} + + // Auth wrappers. + requireAuth := auth.RequireAuth(opts.Secret, store, opts.PasswordExpiryDays) + requireAdmin := auth.RequireRole(auth.RoleAdmin) + + // Helper to wrap a HandlerFunc through requireAuth (and optionally requireAdmin). + wrap := func(h http.HandlerFunc) rest.Handler { + wrapped := requireAuth(h) + return wrapped.ServeHTTP + } + wrapAdmin := func(h http.HandlerFunc) rest.Handler { + wrapped := requireAuth(requireAdmin(h)) + return wrapped.ServeHTTP + } + + staticHandler := http.StripPrefix("/dashboard/", http.FileServerFS(assetsFS)) + + routes := map[string]rest.Handler{ + // Public. + "GET /{$}": rootRedirect, + "GET /dashboard/login": handlers.LoginGet(accountDeps), + "POST /dashboard/login": handlers.LoginPost(accountDeps), + "POST /dashboard/logout": handlers.LogoutPost(), + "GET /dashboard/static/": staticHandler.ServeHTTP, + + // Authenticated pages. + "GET /dashboard/{$}": wrap(handlers.OverviewGet(overviewDeps)), + "GET /dashboard/fragments/overview-cards": wrap(handlers.OverviewCards(overviewDeps)), + "GET /dashboard/clients": wrap(handlers.ClientsList(clientsDeps)), + "GET /dashboard/clients/{id}": wrap(handlers.ClientDetail(clientDetailDeps)), + "POST /dashboard/clients/{id}/subscriptions/{topic}/delete": wrapAdmin(handlers.ClientUnsubscribe(clientDetailDeps)), + "GET /dashboard/subscriptions": wrap(handlers.SubscriptionsList(subscriptionsDeps)), + "GET /dashboard/topics": wrap(handlers.TopicsTree(topicsDeps)), + "GET /dashboard/retained": wrap(handlers.RetainedList(retainedDeps)), + "POST /dashboard/retained/{topic}/delete": wrapAdmin(handlers.RetainedClear(retainedDeps)), + "GET /dashboard/sessions": wrap(handlers.SessionsList(sessionsDeps)), + "POST /dashboard/sessions/{id}/delete": wrapAdmin(handlers.SessionsClear(sessionsDeps)), + "GET /dashboard/blacklist": wrap(handlers.BlacklistGet(blacklistDeps)), + "POST /dashboard/blacklist": wrapAdmin(handlers.BlacklistAdd(blacklistDeps)), + "POST /dashboard/blacklist/{id}/delete": wrapAdmin(handlers.BlacklistRemove(blacklistDeps)), + "GET /dashboard/tools": wrap(handlers.ToolsGet(toolsDeps)), + "POST /dashboard/tools/publish": wrapAdmin(handlers.ToolsPublish(toolsDeps)), + "GET /dashboard/settings": wrap(handlers.Settings(settingsDeps)), + "GET /dashboard/account": wrap(handlers.AccountGet(accountDeps)), + "GET /dashboard/account/password": wrap(handlers.ChangePasswordGet(accountDeps)), + "POST /dashboard/account/password": wrap(handlers.ChangePasswordPost(accountDeps)), + + // Admin-only pages. + "GET /dashboard/users": wrapAdmin(handlers.UsersList(usersDeps)), + "POST /dashboard/users": wrapAdmin(handlers.UsersCreate(usersDeps)), + "POST /dashboard/users/{username}/delete": wrapAdmin(handlers.UsersDelete(usersDeps)), + "POST /dashboard/users/{username}/role": wrapAdmin(handlers.UsersToggleRole(usersDeps)), + + // SSE. + "GET /dashboard/events": wrap(handlers.Events(hub)), + } + + if opts.Cluster && opts.ClusterAgent != nil { + routes["GET /dashboard/cluster"] = wrap(handlers.ClusterPage(handlers.ClusterDeps{ + Agent: opts.ClusterAgent, Renderer: rdr, Cluster: true, + })) + } + + cleanup := func() { + sampler.Stop() + if bridge != nil { + bridge.Stop() + } + } + + return routes, cleanup, nil +} + +func rootRedirect(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/dashboard/", http.StatusFound) +} + +func hostname() string { + if h, err := os.Hostname(); err == nil { + return h + } + return "unknown" +} + +func (o *Options) applyDefaults() error { + if o.CredStorePath == "" { + o.CredStorePath = "./data/dashboard-users.json" + } + if o.SecretPath == "" { + o.SecretPath = "./data/dashboard-secret" + } + if o.PasswordExpiryDays == 0 { + o.PasswordExpiryDays = 90 + } + if o.LockoutThreshold == 0 { + o.LockoutThreshold = 5 + } + if o.LockoutWindow == 0 { + o.LockoutWindow = 5 * time.Minute + } + if o.LockoutDuration == 0 { + o.LockoutDuration = 10 * time.Minute + } + if o.SessionTTL == 0 { + o.SessionTTL = 12 * time.Hour + } + if o.Secret == nil { + // In cluster mode, prefer the redis-backed secret so all nodes + // share the same HMAC key. Fall through to env/file on error. + if o.Redis != nil { + if b, err := auth.EnsureSecret(context.Background(), o.Redis, "comqtt:dashboard:secret"); err == nil && len(b) >= 16 { + o.Secret = b + return nil + } + } + // Try env first. + if env := os.Getenv("COMQTT_DASHBOARD_SESSION_SECRET"); env != "" { + b, err := base64.StdEncoding.DecodeString(env) + if err == nil && len(b) >= 16 { + o.Secret = b + return nil + } + } + // Then file. + b, err := os.ReadFile(o.SecretPath) + if err == nil && len(b) >= 16 { + o.Secret = b + return nil + } + // Generate. + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(o.SecretPath), 0o700); err != nil { + return err + } + if err := os.WriteFile(o.SecretPath, buf, 0o600); err != nil { + return err + } + o.Secret = buf + } + return nil +} diff --git a/dashboard/dashboard_test.go b/dashboard/dashboard_test.go new file mode 100644 index 0000000..4a15593 --- /dev/null +++ b/dashboard/dashboard_test.go @@ -0,0 +1,121 @@ +// dashboard/dashboard_test.go +package dashboard + +import ( + "net/http" + "net/http/httptest" + "path/filepath" + "strings" + "testing" + + "github.com/wind-c/comqtt/v2/mqtt" +) + +func TestRoutesRequiresServer(t *testing.T) { + if _, _, err := Routes(Options{}); err == nil { + t.Fatal("expected error for missing Server") + } +} + +func TestRoutesIncludesExpectedKeys(t *testing.T) { + server := mqtt.New(nil) + dir := t.TempDir() + rs, cleanup, err := Routes(Options{ + Server: server, + CredStorePath: filepath.Join(dir, "users.json"), + SecretPath: filepath.Join(dir, "secret"), + }) + if err != nil { + t.Fatalf("Routes: %v", err) + } + defer cleanup() + for _, want := range []string{ + "GET /{$}", + "GET /dashboard/login", + "POST /dashboard/login", + "POST /dashboard/logout", + "GET /dashboard/static/", + "GET /dashboard/{$}", + "GET /dashboard/clients", + "GET /dashboard/users", + "GET /dashboard/events", + } { + if _, ok := rs[want]; !ok { + t.Errorf("missing route: %q", want) + } + } +} + +func TestRoutesRedirectFromRoot(t *testing.T) { + server := mqtt.New(nil) + dir := t.TempDir() + rs, cleanup, _ := Routes(Options{ + Server: server, + CredStorePath: filepath.Join(dir, "users.json"), + SecretPath: filepath.Join(dir, "secret"), + }) + defer cleanup() + mux := http.NewServeMux() + for path, handler := range rs { + mux.HandleFunc(path, handler) + } + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil)) + if rr.Code != http.StatusFound { + t.Fatalf("status: %d", rr.Code) + } + if loc := rr.Header().Get("Location"); !strings.HasPrefix(loc, "/dashboard/") { + t.Fatalf("location: %q", loc) + } +} + +func TestRoutesAnonRedirectsToLogin(t *testing.T) { + server := mqtt.New(nil) + dir := t.TempDir() + rs, cleanup, _ := Routes(Options{ + Server: server, + CredStorePath: filepath.Join(dir, "users.json"), + SecretPath: filepath.Join(dir, "secret"), + }) + defer cleanup() + mux := http.NewServeMux() + for path, handler := range rs { + mux.HandleFunc(path, handler) + } + // Anon GET /dashboard/clients should redirect to /dashboard/login. + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/dashboard/clients", nil)) + if rr.Code != http.StatusFound { + t.Fatalf("status: %d body: %s", rr.Code, rr.Body.String()) + } + if loc := rr.Header().Get("Location"); !strings.HasPrefix(loc, "/dashboard/login") { + t.Fatalf("location: %q", loc) + } +} + +func TestRoutesStaticServesAsset(t *testing.T) { + server := mqtt.New(nil) + dir := t.TempDir() + rs, cleanup, _ := Routes(Options{ + Server: server, + CredStorePath: filepath.Join(dir, "users.json"), + SecretPath: filepath.Join(dir, "secret"), + }) + defer cleanup() + mux := http.NewServeMux() + for path, handler := range rs { + mux.HandleFunc(path, handler) + } + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/dashboard/static/tailwind.css", nil)) + if rr.Code != http.StatusOK { + t.Fatalf("status: %d", rr.Code) + } + if !strings.Contains(rr.Body.String(), "tailwindcss") { + body := rr.Body.String() + if len(body) > 200 { + body = body[:200] + } + t.Fatalf("expected CSS body: %q", body) + } +} diff --git a/dashboard/embed.go b/dashboard/embed.go new file mode 100644 index 0000000..d21a010 --- /dev/null +++ b/dashboard/embed.go @@ -0,0 +1,6 @@ +package dashboard + +import "embed" + +//go:embed all:templates all:static +var assetsFS embed.FS diff --git a/dashboard/handlers/account.go b/dashboard/handlers/account.go new file mode 100644 index 0000000..aa06fae --- /dev/null +++ b/dashboard/handlers/account.go @@ -0,0 +1,239 @@ +// dashboard/handlers/account.go +package handlers + +import ( + "context" + "errors" + "net/http" + "net/url" + "strings" + "time" + + "github.com/wind-c/comqtt/v2/dashboard/auth" +) + +// AccountDeps bundles the dependencies the account handlers need. +type AccountDeps struct { + Store auth.CredStore + Lockout *auth.Lockout + Secret []byte + Renderer *Renderer + // SessionTTL is the cookie Max-Age; defaults to 12h if zero. + SessionTTL time.Duration +} + +func (d AccountDeps) ttl() time.Duration { + if d.SessionTTL == 0 { + return 12 * time.Hour + } + return d.SessionTTL +} + +// LoginGet renders the login form. +func LoginGet(d AccountDeps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + d.Renderer.Render(w, "login", map[string]any{ + "CSRF": auth.NewCSRFToken(), + "Next": sanitizeNext(r.URL.Query().Get("next")), + "Error": "", + }) + } +} + +// LoginPost validates credentials, sets the session cookie on success, or +// re-renders the form with an error. +func LoginPost(d AccountDeps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseForm() + username := strings.TrimSpace(r.PostFormValue("username")) + password := r.PostFormValue("password") + next := sanitizeNext(r.PostFormValue("next")) + + if locked, until := d.Lockout.IsLocked(username); locked { + renderLoginError(d, w, r, next, "Account temporarily locked. Try again at "+until.Format("15:04")+".") + return + } + + ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second) + defer cancel() + u, err := d.Store.Authenticate(ctx, username, password) + if err != nil { + if errors.Is(err, auth.ErrLocked) { + renderLoginError(d, w, r, next, "Account locked.") + return + } + d.Lockout.Record(username) + renderLoginError(d, w, r, next, "Invalid credentials.") + return + } + d.Lockout.Reset(username) + + payload := auth.SessionPayload{ + Username: u.Username, + Role: string(u.Role), + Exp: time.Now().Add(d.ttl()).Unix(), + Nonce: auth.NewCSRFToken(), + } + cookie, err := auth.Sign(d.Secret, payload) + if err != nil { + http.Error(w, "session: "+err.Error(), http.StatusInternalServerError) + return + } + http.SetCookie(w, &http.Cookie{ + Name: "comqtt_session", + Value: cookie, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + Secure: r.TLS != nil, + MaxAge: int(d.ttl().Seconds()), + }) + http.Redirect(w, r, next, http.StatusFound) + } +} + +// LogoutPost clears the session cookie and redirects to the login page. +// Cookie attributes match LoginPost so the browser recognises this as the +// same cookie and actually expires it (Chrome/Safari are strict about that +// match; clearing without HttpOnly/SameSite leaves a stale cookie behind). +func LogoutPost() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + http.SetCookie(w, &http.Cookie{ + Name: "comqtt_session", + Value: "", + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + Secure: r.TLS != nil, + MaxAge: -1, + Expires: time.Unix(0, 0), + }) + http.Redirect(w, r, "/dashboard/login", http.StatusFound) + } +} + +func renderLoginError(d AccountDeps, w http.ResponseWriter, r *http.Request, next, msg string) { + w.WriteHeader(http.StatusUnauthorized) + d.Renderer.Render(w, "login", map[string]any{ + "CSRF": auth.NewCSRFToken(), + "Next": next, + "Error": msg, + }) +} + +const minPasswordLen = 8 + +// ChangePasswordGet renders the password-change form. Used both for forced +// rotation (?reason=must_change|expired) and personal password change. +func ChangePasswordGet(d AccountDeps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + reason := r.URL.Query().Get("reason") + d.Renderer.Render(w, "account/password", map[string]any{ + "CSRF": auth.NewCSRFToken(), + "Reason": reason, + "Error": "", + }) + } +} + +// ChangePasswordPost validates the current password, enforces a new-password +// minimum length, persists the new bcrypt hash, and redirects to the +// dashboard root. The cred store's SetPassword clears must_change. +func ChangePasswordPost(d AccountDeps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseForm() + current := r.PostFormValue("current") + next := r.PostFormValue("new") + confirm := r.PostFormValue("confirm") + reason := r.URL.Query().Get("reason") + + render := func(status int, msg string) { + w.WriteHeader(status) + d.Renderer.Render(w, "account/password", map[string]any{ + "CSRF": auth.NewCSRFToken(), + "Reason": reason, + "Error": msg, + }) + } + + u := auth.UserFromContext(r.Context()) + if u.Username == "" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + if next != confirm { + render(http.StatusBadRequest, "New password and confirmation do not match.") + return + } + if len(next) < minPasswordLen { + render(http.StatusBadRequest, "Password must be at least 8 characters.") + return + } + + ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second) + defer cancel() + if _, err := d.Store.Authenticate(ctx, u.Username, current); err != nil { + render(http.StatusUnauthorized, "Current password is incorrect.") + return + } + if err := d.Store.SetPassword(ctx, u.Username, next); err != nil { + render(http.StatusInternalServerError, "Failed to update password: "+err.Error()) + return + } + http.Redirect(w, r, "/dashboard/", http.StatusFound) + } +} + +// AccountGet renders the personal account details page. +// All authenticated users can view; the page itself is read-only. +func AccountGet(d AccountDeps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + u := auth.UserFromContext(r.Context()) + if u.Username == "" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second) + defer cancel() + stored, err := d.Store.GetUser(ctx, u.Username) + if err != nil { + http.Error(w, "user lookup: "+err.Error(), http.StatusInternalServerError) + return + } + passwordSet := "never" + if stored.PasswordSetAt > 0 { + passwordSet = time.Unix(stored.PasswordSetAt, 0).UTC().Format(time.RFC3339) + } + d.Renderer.Render(w, "account_personal", map[string]any{ + "Title": "Account", + "User": u, + "CSRF": auth.NewCSRFToken(), + "Cluster": false, + "Account": map[string]any{ + "Username": stored.Username, + "Role": string(stored.Role), + "PasswordSetAt": passwordSet, + "MustChange": stored.MustChange, + "Locked": stored.LockedUntil > time.Now().Unix(), + }, + }) + } +} + +// sanitizeNext only accepts paths under /dashboard/ to prevent open-redirect. +// External targets, scheme-relative URLs, and missing leading slash all fall +// back to the default landing page. +func sanitizeNext(raw string) string { + if raw == "" { + return "/dashboard/" + } + parsed, err := url.Parse(raw) + if err != nil || parsed.Host != "" || parsed.Scheme != "" { + return "/dashboard/" + } + if !strings.HasPrefix(parsed.Path, "/dashboard/") { + return "/dashboard/" + } + return parsed.Path +} diff --git a/dashboard/handlers/account_test.go b/dashboard/handlers/account_test.go new file mode 100644 index 0000000..7858b24 --- /dev/null +++ b/dashboard/handlers/account_test.go @@ -0,0 +1,269 @@ +// dashboard/handlers/account_test.go +package handlers + +import ( + "context" + "io/fs" + "net/http" + "net/http/httptest" + "net/url" + "path/filepath" + "strings" + "testing" + "testing/fstest" + "time" + + "github.com/wind-c/comqtt/v2/dashboard/auth" +) + +const loginTemplate = `{{define "login"}}
{{if .Error}}

{{.Error}}

{{end}}
{{end}}` + +const passwordTemplate = `{{define "account/password"}}
{{if .Error}}

{{.Error}}

{{end}}

{{.Reason}}

{{end}}` + +const personalTemplate = `{{define "layout"}}{{template "content" .}}{{end}}{{define "account_personal"}}{{template "layout" .}}{{end}}{{define "content"}}
user={{.Account.Username}} role={{.Account.Role}} mc={{.Account.MustChange}} locked={{.Account.Locked}} pset={{.Account.PasswordSetAt}}
{{end}}` + +func newRenderer(t *testing.T) *Renderer { + t.Helper() + return NewRenderer(fakeTplFS(t)) +} + +func fakeTplFS(t *testing.T) fs.FS { + t.Helper() + return fstest.MapFS{ + "templates/login.html": &fstest.MapFile{Data: []byte(loginTemplate)}, + "templates/account/password.html": &fstest.MapFile{Data: []byte(passwordTemplate)}, + "templates/account/personal.html": &fstest.MapFile{Data: []byte(personalTemplate)}, + } +} + +func newAccountDeps(t *testing.T) (AccountDeps, *auth.FileStore) { + t.Helper() + store, err := auth.NewFileStore(filepath.Join(t.TempDir(), "users.json")) + if err != nil { + t.Fatalf("NewFileStore: %v", err) + } + return AccountDeps{ + Store: store, + Lockout: auth.NewLockout(auth.LockoutConfig{Threshold: 3, Window: time.Minute, Duration: 10 * time.Minute}), + Secret: []byte("0123456789abcdef0123456789abcdef"), + Renderer: newRenderer(t), + }, store +} + +func TestLoginGetRendersForm(t *testing.T) { + deps, _ := newAccountDeps(t) + rr := httptest.NewRecorder() + LoginGet(deps)(rr, httptest.NewRequest(http.MethodGet, "/dashboard/login", nil)) + if rr.Code != http.StatusOK { + t.Fatalf("status: %d", rr.Code) + } + if !strings.Contains(rr.Body.String(), "
") { + t.Fatalf("body: %q", rr.Body.String()) + } +} + +func TestLoginPostValidCredentialsSetsCookie(t *testing.T) { + deps, store := newAccountDeps(t) + pw, _ := store.Seed(context.Background(), "admin") + form := url.Values{"username": {"admin"}, "password": {pw}} + req := httptest.NewRequest(http.MethodPost, "/dashboard/login", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + LoginPost(deps)(rr, req) + if rr.Code != http.StatusFound { + t.Fatalf("status: %d body: %s", rr.Code, rr.Body.String()) + } + cookies := rr.Result().Cookies() + if len(cookies) == 0 || cookies[0].Name != "comqtt_session" || cookies[0].Value == "" { + t.Fatalf("cookie not set: %+v", cookies) + } + if !cookies[0].HttpOnly { + t.Fatal("cookie should be HttpOnly") + } +} + +func TestLoginPostBadCredsRecordsLockout(t *testing.T) { + deps, store := newAccountDeps(t) + _, _ = store.Seed(context.Background(), "admin") + form := url.Values{"username": {"admin"}, "password": {"wrong"}} + for i := 0; i < 3; i++ { + req := httptest.NewRequest(http.MethodPost, "/dashboard/login", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + LoginPost(deps)(rr, req) + } + locked, _ := deps.Lockout.IsLocked("admin") + if !locked { + t.Fatal("expected lockout after threshold failures") + } +} + +func TestLoginPostLockedShowsError(t *testing.T) { + deps, store := newAccountDeps(t) + pw, _ := store.Seed(context.Background(), "admin") + for i := 0; i < 3; i++ { + deps.Lockout.Record("admin") + } + form := url.Values{"username": {"admin"}, "password": {pw}} + req := httptest.NewRequest(http.MethodPost, "/dashboard/login", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + LoginPost(deps)(rr, req) + if rr.Code != http.StatusUnauthorized { + t.Fatalf("status: %d", rr.Code) + } + if !strings.Contains(rr.Body.String(), "locked") { + t.Fatalf("expected locked message: %q", rr.Body.String()) + } +} + +func TestLogoutClearsCookie(t *testing.T) { + rr := httptest.NewRecorder() + LogoutPost()(rr, httptest.NewRequest(http.MethodPost, "/dashboard/logout", nil)) + if rr.Code != http.StatusFound { + t.Fatalf("status: %d", rr.Code) + } + cookies := rr.Result().Cookies() + if len(cookies) == 0 || cookies[0].MaxAge >= 0 { + t.Fatalf("cookie should be expired: %+v", cookies) + } +} + +func TestSanitizeNext(t *testing.T) { + cases := []struct { + in, want string + }{ + {"", "/dashboard/"}, + {"/dashboard/clients", "/dashboard/clients"}, + {"/etc/passwd", "/dashboard/"}, + {"//evil.com/x", "/dashboard/"}, + {"https://evil.com/x", "/dashboard/"}, + {"javascript:alert(1)", "/dashboard/"}, + } + for _, c := range cases { + if got := sanitizeNext(c.in); got != c.want { + t.Errorf("sanitizeNext(%q): got %q want %q", c.in, got, c.want) + } + } +} + +func TestChangePasswordGetRenders(t *testing.T) { + deps, _ := newAccountDeps(t) + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/dashboard/account/password?reason=must_change", nil) + ChangePasswordGet(deps)(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status: %d", rr.Code) + } + if !strings.Contains(rr.Body.String(), "must_change") { + t.Fatalf("expected reason in body: %q", rr.Body.String()) + } +} + +func TestChangePasswordPostHappyPath(t *testing.T) { + deps, store := newAccountDeps(t) + pw, _ := store.Seed(context.Background(), "admin") + form := url.Values{"current": {pw}, "new": {"newpass1234"}, "confirm": {"newpass1234"}} + req := httptest.NewRequest(http.MethodPost, "/dashboard/account/password", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req = req.WithContext(auth.WithUser(req.Context(), auth.User{Username: "admin", Role: auth.RoleAdmin})) + rr := httptest.NewRecorder() + ChangePasswordPost(deps)(rr, req) + if rr.Code != http.StatusFound { + t.Fatalf("status: %d body: %s", rr.Code, rr.Body.String()) + } + if _, err := store.Authenticate(context.Background(), "admin", "newpass1234"); err != nil { + t.Fatalf("new password should work: %v", err) + } + u, _ := store.GetUser(context.Background(), "admin") + if u.MustChange { + t.Fatal("must_change should be cleared") + } +} + +func TestChangePasswordPostMismatchedConfirm(t *testing.T) { + deps, store := newAccountDeps(t) + pw, _ := store.Seed(context.Background(), "admin") + form := url.Values{"current": {pw}, "new": {"newpass1234"}, "confirm": {"different5678"}} + req := httptest.NewRequest(http.MethodPost, "/dashboard/account/password", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req = req.WithContext(auth.WithUser(req.Context(), auth.User{Username: "admin", Role: auth.RoleAdmin})) + rr := httptest.NewRecorder() + ChangePasswordPost(deps)(rr, req) + if rr.Code != http.StatusBadRequest { + t.Fatalf("status: %d", rr.Code) + } + if !strings.Contains(rr.Body.String(), "do not match") { + t.Fatalf("expected mismatch message: %q", rr.Body.String()) + } +} + +func TestChangePasswordPostShortPassword(t *testing.T) { + deps, store := newAccountDeps(t) + pw, _ := store.Seed(context.Background(), "admin") + form := url.Values{"current": {pw}, "new": {"short"}, "confirm": {"short"}} + req := httptest.NewRequest(http.MethodPost, "/dashboard/account/password", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req = req.WithContext(auth.WithUser(req.Context(), auth.User{Username: "admin", Role: auth.RoleAdmin})) + rr := httptest.NewRecorder() + ChangePasswordPost(deps)(rr, req) + if rr.Code != http.StatusBadRequest { + t.Fatalf("status: %d", rr.Code) + } + if !strings.Contains(rr.Body.String(), "8 characters") { + t.Fatalf("expected length error: %q", rr.Body.String()) + } +} + +func TestChangePasswordPostWrongCurrent(t *testing.T) { + deps, store := newAccountDeps(t) + _, _ = store.Seed(context.Background(), "admin") + form := url.Values{"current": {"wrong"}, "new": {"newpass1234"}, "confirm": {"newpass1234"}} + req := httptest.NewRequest(http.MethodPost, "/dashboard/account/password", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req = req.WithContext(auth.WithUser(req.Context(), auth.User{Username: "admin", Role: auth.RoleAdmin})) + rr := httptest.NewRecorder() + ChangePasswordPost(deps)(rr, req) + if rr.Code != http.StatusUnauthorized { + t.Fatalf("status: %d body: %s", rr.Code, rr.Body.String()) + } +} + +func TestAccountGetRendersDetails(t *testing.T) { + deps, store := newAccountDeps(t) + pw, _ := store.Seed(context.Background(), "admin") + if err := store.SetPassword(context.Background(), "admin", "newpass1234"); err != nil { + t.Fatalf("SetPassword: %v", err) + } + _ = pw + + req := httptest.NewRequest(http.MethodGet, "/dashboard/account", nil) + req = req.WithContext(auth.WithUser(req.Context(), auth.User{Username: "admin", Role: auth.RoleAdmin})) + rr := httptest.NewRecorder() + AccountGet(deps)(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status: %d body: %s", rr.Code, rr.Body.String()) + } + body := rr.Body.String() + if !strings.Contains(body, "user=admin") { + t.Fatalf("expected username: %q", body) + } + if !strings.Contains(body, "role=admin") { + t.Fatalf("expected role: %q", body) + } + if !strings.Contains(body, "mc=false") { + t.Fatalf("expected must_change cleared: %q", body) + } + if strings.Contains(body, "pset=never") { + t.Fatalf("PasswordSetAt should be a timestamp not 'never': %q", body) + } +} + +func TestAccountGetUnauthorizedWithoutContext(t *testing.T) { + deps, _ := newAccountDeps(t) + rr := httptest.NewRecorder() + AccountGet(deps)(rr, httptest.NewRequest(http.MethodGet, "/dashboard/account", nil)) + if rr.Code != http.StatusUnauthorized { + t.Fatalf("status: %d", rr.Code) + } +} diff --git a/dashboard/handlers/blacklist.go b/dashboard/handlers/blacklist.go new file mode 100644 index 0000000..167f0cf --- /dev/null +++ b/dashboard/handlers/blacklist.go @@ -0,0 +1,96 @@ +// dashboard/handlers/blacklist.go +package handlers + +import ( + "net/http" + "slices" + "strings" + + "github.com/wind-c/comqtt/v2/mqtt" + "github.com/wind-c/comqtt/v2/dashboard/auth" + "github.com/wind-c/comqtt/v2/mqtt/packets" +) + +// BlacklistDeps bundles dependencies for the Blacklist page. +type BlacklistDeps struct { + Server *mqtt.Server + Renderer *Renderer + Cluster bool +} + +type blacklistPageData struct { + Title string + User auth.User + CSRF string + Cluster bool + Flash string + Error string + Items []string + Readonly bool +} + +// BlacklistGet renders the blacklist page (GET /dashboard/blacklist). +func BlacklistGet(d BlacklistDeps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + u := auth.UserFromContext(r.Context()) + d.Renderer.Render(w, "blacklist", blacklistPageData{ + Title: "Blacklist", + User: u, + CSRF: auth.NewCSRFToken(), + Cluster: d.Cluster, + Items: snapshotBlacklist(d.Server), + Readonly: u.Role != auth.RoleAdmin, + }) + } +} + +// BlacklistAdd handles POST /dashboard/blacklist. +// Admin-only. Adds the client_id to the blacklist and disconnects them if +// currently connected. +func BlacklistAdd(d BlacklistDeps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if auth.UserFromContext(r.Context()).Role != auth.RoleAdmin { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + _ = r.ParseForm() + id := strings.TrimSpace(r.PostFormValue("client_id")) + if id == "" { + http.Error(w, "client_id required", http.StatusBadRequest) + return + } + if !slices.Contains(d.Server.Blacklist, id) { + d.Server.Blacklist = append(d.Server.Blacklist, id) + } + if cl, ok := d.Server.Clients.Get(id); ok { + d.Server.DisconnectClient(cl, packets.ErrNotAuthorized) + } + http.Redirect(w, r, "/dashboard/blacklist", http.StatusFound) + } +} + +// BlacklistRemove handles POST /dashboard/blacklist/{id}/delete. +// Admin-only. Removes the entry from the blacklist. +func BlacklistRemove(d BlacklistDeps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if auth.UserFromContext(r.Context()).Role != auth.RoleAdmin { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + id := r.PathValue("id") + if id == "" { + http.Error(w, "id required", http.StatusBadRequest) + return + } + d.Server.Blacklist = slices.DeleteFunc(d.Server.Blacklist, func(s string) bool { return s == id }) + http.Redirect(w, r, "/dashboard/blacklist", http.StatusFound) + } +} + +// snapshotBlacklist copies the slice so concurrent broker mutation can't +// race the template execution. Cheap relative to bcrypt et al. +func snapshotBlacklist(s *mqtt.Server) []string { + out := make([]string, len(s.Blacklist)) + copy(out, s.Blacklist) + return out +} diff --git a/dashboard/handlers/blacklist_test.go b/dashboard/handlers/blacklist_test.go new file mode 100644 index 0000000..1981ab0 --- /dev/null +++ b/dashboard/handlers/blacklist_test.go @@ -0,0 +1,123 @@ +// dashboard/handlers/blacklist_test.go +package handlers + +import ( + "io/fs" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "testing/fstest" + + "github.com/wind-c/comqtt/v2/mqtt" + "github.com/wind-c/comqtt/v2/dashboard/auth" +) + +func newBlacklistRenderer(t *testing.T) *Renderer { + t.Helper() + return NewRenderer(fs.FS(fstest.MapFS{ + "templates/blacklist.html": &fstest.MapFile{Data: []byte(`{{define "layout"}}{{template "content" .}}{{end}}{{define "blacklist"}}{{template "layout" .}}{{end}}{{define "content"}}
readonly={{.Readonly}}
{{range .Items}}
{{.}}
{{else}}
empty
{{end}}{{end}}`)}, + })) +} + +func newBlacklistDeps(t *testing.T) BlacklistDeps { + t.Helper() + return BlacklistDeps{ + Server: mqtt.New(nil), + Renderer: newBlacklistRenderer(t), + } +} + +func adminCtx(r *http.Request) *http.Request { + return r.WithContext(auth.WithUser(r.Context(), auth.User{Username: "admin", Role: auth.RoleAdmin})) +} + +func viewerCtx(r *http.Request) *http.Request { + return r.WithContext(auth.WithUser(r.Context(), auth.User{Username: "v", Role: auth.RoleViewer})) +} + +func TestBlacklistGetEmpty(t *testing.T) { + deps := newBlacklistDeps(t) + rr := httptest.NewRecorder() + BlacklistGet(deps)(rr, adminCtx(httptest.NewRequest(http.MethodGet, "/dashboard/blacklist", nil))) + if rr.Code != http.StatusOK { + t.Fatalf("status: %d", rr.Code) + } + if !strings.Contains(rr.Body.String(), "empty") { + t.Fatalf("expected empty: %q", rr.Body.String()) + } + if !strings.Contains(rr.Body.String(), "readonly=false") { + t.Fatalf("expected admin readonly=false: %q", rr.Body.String()) + } +} + +func TestBlacklistGetViewerReadonly(t *testing.T) { + deps := newBlacklistDeps(t) + rr := httptest.NewRecorder() + BlacklistGet(deps)(rr, viewerCtx(httptest.NewRequest(http.MethodGet, "/dashboard/blacklist", nil))) + if !strings.Contains(rr.Body.String(), "readonly=true") { + t.Fatalf("expected viewer readonly=true: %q", rr.Body.String()) + } +} + +func TestBlacklistAddAdmin(t *testing.T) { + deps := newBlacklistDeps(t) + form := url.Values{"client_id": {"bad-bot"}} + req := httptest.NewRequest(http.MethodPost, "/dashboard/blacklist", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + BlacklistAdd(deps)(rr, adminCtx(req)) + if rr.Code != http.StatusFound { + t.Fatalf("status: %d body: %s", rr.Code, rr.Body.String()) + } + if len(deps.Server.Blacklist) != 1 || deps.Server.Blacklist[0] != "bad-bot" { + t.Fatalf("blacklist: %v", deps.Server.Blacklist) + } +} + +func TestBlacklistAddViewerForbidden(t *testing.T) { + deps := newBlacklistDeps(t) + form := url.Values{"client_id": {"bad-bot"}} + req := httptest.NewRequest(http.MethodPost, "/dashboard/blacklist", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + BlacklistAdd(deps)(rr, viewerCtx(req)) + if rr.Code != http.StatusForbidden { + t.Fatalf("status: %d", rr.Code) + } + if len(deps.Server.Blacklist) != 0 { + t.Fatalf("viewer should not mutate blacklist: %v", deps.Server.Blacklist) + } +} + +func TestBlacklistAddIsIdempotent(t *testing.T) { + deps := newBlacklistDeps(t) + deps.Server.Blacklist = []string{"existing"} + form := url.Values{"client_id": {"existing"}} + req := httptest.NewRequest(http.MethodPost, "/dashboard/blacklist", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + BlacklistAdd(deps)(rr, adminCtx(req)) + if rr.Code != http.StatusFound { + t.Fatalf("status: %d", rr.Code) + } + if len(deps.Server.Blacklist) != 1 { + t.Fatalf("expected idempotent add, got %v", deps.Server.Blacklist) + } +} + +func TestBlacklistRemoveAdmin(t *testing.T) { + deps := newBlacklistDeps(t) + deps.Server.Blacklist = []string{"a", "b", "c"} + req := httptest.NewRequest(http.MethodPost, "/dashboard/blacklist/b/delete", nil) + req.SetPathValue("id", "b") + rr := httptest.NewRecorder() + BlacklistRemove(deps)(rr, adminCtx(req)) + if rr.Code != http.StatusFound { + t.Fatalf("status: %d", rr.Code) + } + if len(deps.Server.Blacklist) != 2 || deps.Server.Blacklist[0] != "a" || deps.Server.Blacklist[1] != "c" { + t.Fatalf("expected [a c], got %v", deps.Server.Blacklist) + } +} diff --git a/dashboard/handlers/client_detail.go b/dashboard/handlers/client_detail.go new file mode 100644 index 0000000..eae6d38 --- /dev/null +++ b/dashboard/handlers/client_detail.go @@ -0,0 +1,138 @@ +// dashboard/handlers/client_detail.go +package handlers + +import ( + "net/http" + "net/url" + "sort" + + "github.com/wind-c/comqtt/v2/mqtt" + "github.com/wind-c/comqtt/v2/dashboard/auth" +) + +// ClientDetailDeps bundles dependencies for the Client detail page. +type ClientDetailDeps struct { + Server *mqtt.Server + Renderer *Renderer + Cluster bool +} + +type clientDetailPageData struct { + Title string + User auth.User + CSRF string + Cluster bool + Flash string + Error string + Readonly bool + + ClientID string + Online bool + Tab string + SubCount int + Info clientInfoSection + Subs []subscriptionRow +} + +type clientInfoSection struct { + Username string + Remote string + Listener string + ProtocolVersion byte + Keepalive uint16 + Inflight int + Clean bool +} + +type subscriptionRow struct { + Topic string + TopicEncoded string + QoS byte +} + +// ClientDetail handles GET /dashboard/clients/{id}. +func ClientDetail(d ClientDetailDeps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if id == "" { + http.Error(w, "client id required", http.StatusBadRequest) + return + } + tab := r.URL.Query().Get("tab") + if tab == "" { + tab = "info" + } + + u := auth.UserFromContext(r.Context()) + page := clientDetailPageData{ + Title: "Client " + id, + User: u, + CSRF: auth.NewCSRFToken(), + Cluster: d.Cluster, + Readonly: u.Role != auth.RoleAdmin, + ClientID: id, + Tab: tab, + } + + cl, online := d.Server.Clients.Get(id) + if !online { + http.NotFound(w, r) + return + } + page.Online = true + page.Info = clientInfoSection{ + Username: string(cl.Properties.Username), + Remote: cl.Net.Remote, + Listener: cl.Net.Listener, + ProtocolVersion: cl.Properties.ProtocolVersion, + Keepalive: cl.State.Keepalive, + Inflight: cl.State.Inflight.Len(), + Clean: cl.Properties.Clean, + } + + subsMap := cl.State.Subscriptions.GetAll() + page.SubCount = len(subsMap) + page.Subs = make([]subscriptionRow, 0, len(subsMap)) + for filter, sub := range subsMap { + page.Subs = append(page.Subs, subscriptionRow{ + Topic: filter, + TopicEncoded: url.PathEscape(filter), + QoS: sub.Qos, + }) + } + sort.Slice(page.Subs, func(i, j int) bool { return page.Subs[i].Topic < page.Subs[j].Topic }) + + d.Renderer.Render(w, "client_detail", page) + } +} + +// ClientUnsubscribe handles POST /dashboard/clients/{id}/subscriptions/{topic}/delete. +// Admin-only. Removes the subscription from both the broker's topic trie +// (Topics.Unsubscribe) and the per-client map (Subscriptions.Delete) - +// matches the REST handler in mqtt/rest/client_unsubscribe.go. +func ClientUnsubscribe(d ClientDetailDeps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if auth.UserFromContext(r.Context()).Role != auth.RoleAdmin { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + id := r.PathValue("id") + topic, err := url.PathUnescape(r.PathValue("topic")) + if err != nil || topic == "" || id == "" { + http.Error(w, "missing or malformed id/topic", http.StatusBadRequest) + return + } + cl, ok := d.Server.Clients.Get(id) + if !ok { + http.NotFound(w, r) + return + } + if _, has := cl.State.Subscriptions.Get(topic); !has { + http.NotFound(w, r) + return + } + d.Server.Topics.Unsubscribe(topic, id) + cl.State.Subscriptions.Delete(topic) + http.Redirect(w, r, "/dashboard/clients/"+id+"?tab=subs", http.StatusFound) + } +} diff --git a/dashboard/handlers/client_detail_test.go b/dashboard/handlers/client_detail_test.go new file mode 100644 index 0000000..c9a4150 --- /dev/null +++ b/dashboard/handlers/client_detail_test.go @@ -0,0 +1,152 @@ +// dashboard/handlers/client_detail_test.go +package handlers + +import ( + "io/fs" + "net/http" + "net/http/httptest" + "strings" + "testing" + "testing/fstest" + + "github.com/wind-c/comqtt/v2/mqtt" + "github.com/wind-c/comqtt/v2/dashboard/auth" + "github.com/wind-c/comqtt/v2/mqtt/packets" +) + +func newClientDetailRenderer(t *testing.T) *Renderer { + t.Helper() + return NewRenderer(fs.FS(fstest.MapFS{ + "templates/clients/detail.html": &fstest.MapFile{Data: []byte(`{{define "layout"}}{{template "content" .}}{{end}}{{define "client_detail"}}{{template "layout" .}}{{end}}{{define "content"}}
id={{.ClientID}} online={{.Online}} tab={{.Tab}} subcount={{.SubCount}} readonly={{.Readonly}}
{{if eq .Tab "info"}}
user={{.Info.Username}} remote={{.Info.Remote}} kal={{.Info.Keepalive}}
{{end}}{{if eq .Tab "subs"}}{{range .Subs}}
{{.Topic}}|{{.TopicEncoded}}|qos{{.QoS}}
{{else}}
no subs
{{end}}{{end}}{{end}}`)}, + })) +} + +func newClientDetailDeps(t *testing.T) ClientDetailDeps { + t.Helper() + return ClientDetailDeps{ + Server: mqtt.New(nil), + Renderer: newClientDetailRenderer(t), + } +} + +func addClientForDetail(t *testing.T, server *mqtt.Server, id string, filters ...string) { + t.Helper() + cl := &mqtt.Client{ID: id} + cl.Properties.Username = []byte("u") + cl.Properties.ProtocolVersion = 5 + cl.Properties.Clean = false + cl.Net.Remote = "127.0.0.1:0" + cl.Net.Listener = "tcp" + cl.State.Keepalive = 60 + cl.State.Subscriptions = mqtt.NewSubscriptions() + cl.State.Inflight = mqtt.NewInflights() + for _, f := range filters { + cl.State.Subscriptions.Add(f, packets.Subscription{Filter: f, Qos: 1}) + server.Topics.Subscribe(id, packets.Subscription{Filter: f, Qos: 1}) + } + server.Clients.Add(cl) +} + +func TestClientDetailNotFound(t *testing.T) { + deps := newClientDetailDeps(t) + req := httptest.NewRequest(http.MethodGet, "/dashboard/clients/ghost", nil) + req.SetPathValue("id", "ghost") + rr := httptest.NewRecorder() + ClientDetail(deps)(rr, req) + if rr.Code != http.StatusNotFound { + t.Fatalf("status: %d", rr.Code) + } +} + +func TestClientDetailInfoTab(t *testing.T) { + deps := newClientDetailDeps(t) + addClientForDetail(t, deps.Server, "alice") + req := httptest.NewRequest(http.MethodGet, "/dashboard/clients/alice", nil) + req.SetPathValue("id", "alice") + req = req.WithContext(auth.WithUser(req.Context(), auth.User{Username: "admin", Role: auth.RoleAdmin})) + rr := httptest.NewRecorder() + ClientDetail(deps)(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status: %d body: %s", rr.Code, rr.Body.String()) + } + body := rr.Body.String() + if !strings.Contains(body, "id=alice") { + t.Fatalf("missing id: %q", body) + } + if !strings.Contains(body, "online=true") { + t.Fatalf("expected online=true: %q", body) + } + if !strings.Contains(body, "tab=info") { + t.Fatalf("expected default info tab: %q", body) + } + if !strings.Contains(body, "user=u") { + t.Fatalf("expected username: %q", body) + } +} + +func TestClientDetailSubsTab(t *testing.T) { + deps := newClientDetailDeps(t) + addClientForDetail(t, deps.Server, "alice", "sensors/temp", "sensors/+/room1") + req := httptest.NewRequest(http.MethodGet, "/dashboard/clients/alice?tab=subs", nil) + req.SetPathValue("id", "alice") + req = req.WithContext(auth.WithUser(req.Context(), auth.User{Username: "admin", Role: auth.RoleAdmin})) + rr := httptest.NewRecorder() + ClientDetail(deps)(rr, req) + body := rr.Body.String() + if !strings.Contains(body, "subcount=2") { + t.Fatalf("expected 2 subs: %q", body) + } + if !strings.Contains(body, "sensors/temp|sensors%2Ftemp") { + t.Fatalf("expected escaped topic: %q", body) + } + // html/template escapes '+' as + in element content; the encoded + // path segment is what the browser receives. + if !strings.Contains(body, "sensors/+/room1|sensors%2F+%2Froom1") { + t.Fatalf("expected escaped + topic: %q", body) + } +} + +func TestClientDetailReadonlyForViewer(t *testing.T) { + deps := newClientDetailDeps(t) + addClientForDetail(t, deps.Server, "alice") + req := httptest.NewRequest(http.MethodGet, "/dashboard/clients/alice?tab=subs", nil) + req.SetPathValue("id", "alice") + req = req.WithContext(auth.WithUser(req.Context(), auth.User{Username: "v", Role: auth.RoleViewer})) + rr := httptest.NewRecorder() + ClientDetail(deps)(rr, req) + if !strings.Contains(rr.Body.String(), "readonly=true") { + t.Fatalf("expected readonly: %q", rr.Body.String()) + } +} + +func TestClientUnsubscribeAdminHappy(t *testing.T) { + deps := newClientDetailDeps(t) + addClientForDetail(t, deps.Server, "alice", "sensors/temp") + req := httptest.NewRequest(http.MethodPost, "/dashboard/clients/alice/subscriptions/sensors%2Ftemp/delete", nil) + req.SetPathValue("id", "alice") + req.SetPathValue("topic", "sensors%2Ftemp") + req = req.WithContext(auth.WithUser(req.Context(), auth.User{Username: "admin", Role: auth.RoleAdmin})) + rr := httptest.NewRecorder() + ClientUnsubscribe(deps)(rr, req) + if rr.Code != http.StatusFound { + t.Fatalf("status: %d body: %s", rr.Code, rr.Body.String()) + } + cl, _ := deps.Server.Clients.Get("alice") + if _, ok := cl.State.Subscriptions.Get("sensors/temp"); ok { + t.Fatal("subscription should be cleared") + } +} + +func TestClientUnsubscribeViewerForbidden(t *testing.T) { + deps := newClientDetailDeps(t) + addClientForDetail(t, deps.Server, "alice", "sensors/temp") + req := httptest.NewRequest(http.MethodPost, "/dashboard/clients/alice/subscriptions/sensors%2Ftemp/delete", nil) + req.SetPathValue("id", "alice") + req.SetPathValue("topic", "sensors%2Ftemp") + req = req.WithContext(auth.WithUser(req.Context(), auth.User{Username: "v", Role: auth.RoleViewer})) + rr := httptest.NewRecorder() + ClientUnsubscribe(deps)(rr, req) + if rr.Code != http.StatusForbidden { + t.Fatalf("status: %d", rr.Code) + } +} diff --git a/dashboard/handlers/clients.go b/dashboard/handlers/clients.go new file mode 100644 index 0000000..0ea7750 --- /dev/null +++ b/dashboard/handlers/clients.go @@ -0,0 +1,110 @@ +// dashboard/handlers/clients.go +package handlers + +import ( + "net/http" + "net/url" + "sort" + "strings" + + "github.com/wind-c/comqtt/v2/mqtt" + "github.com/wind-c/comqtt/v2/dashboard/auth" + "github.com/wind-c/comqtt/v2/mqtt/rest" +) + +// ClientsDeps bundles the dependencies for the Clients page. +type ClientsDeps struct { + Server *mqtt.Server + Renderer *Renderer + Cluster bool +} + +type clientsPageData struct { + Title string + User auth.User + CSRF string + Cluster bool + Flash string + Q string + Page rest.Page[clientRow] + TotalPages int + PrevQuery string + NextQuery string +} + +// clientRow mirrors the JSON fields of rest/clients_list.go::clientSummary so +// the same template can be driven by either source in future SSE upgrades. +type clientRow struct { + ClientID string + Username string + Remote string + Keepalive uint16 + Subs int + Pending int +} + +// ClientsList handles GET /dashboard/clients. +func ClientsList(d ClientsDeps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + q := strings.ToLower(r.URL.Query().Get("q")) + params := rest.ParsePage(r.URL.Query()) + + all := make([]clientRow, 0, 64) + for _, cl := range d.Server.Clients.GetAll() { + if q != "" && !strings.Contains(strings.ToLower(cl.ID), q) { + continue + } + all = append(all, clientRow{ + ClientID: cl.ID, + Username: string(cl.Properties.Username), + Remote: cl.Net.Remote, + Keepalive: cl.State.Keepalive, + Subs: cl.State.Subscriptions.Len(), + Pending: cl.State.Inflight.Len(), + }) + } + sort.Slice(all, func(i, j int) bool { return all[i].ClientID < all[j].ClientID }) + + page := rest.Page[clientRow]{ + Page: params.Page, + Size: params.Size, + Total: len(all), + Items: rest.ApplyPagination(all, params), + } + totalPages := page.Total / page.Size + if page.Total%page.Size > 0 { + totalPages++ + } + if totalPages < 1 { + totalPages = 1 + } + + d.Renderer.Render(w, "clients_list", clientsPageData{ + Title: "Clients", + User: auth.UserFromContext(r.Context()), + CSRF: auth.NewCSRFToken(), + Cluster: d.Cluster, + Q: r.URL.Query().Get("q"), + Page: page, + TotalPages: totalPages, + PrevQuery: pageQuery(r.URL.Query(), params.Page-1), + NextQuery: pageQuery(r.URL.Query(), params.Page+1), + }) + } +} + +// pageQuery returns the URL-encoded query string with `page=N` substituted. +// All other params (size, q) are preserved. +func pageQuery(q url.Values, page int) string { + out := url.Values{} + for k, v := range q { + if k == "page" { + continue + } + for _, vv := range v { + out.Add(k, vv) + } + } + out.Set("page", itoa(int64(page))) + return out.Encode() +} diff --git a/dashboard/handlers/clients_test.go b/dashboard/handlers/clients_test.go new file mode 100644 index 0000000..98a5533 --- /dev/null +++ b/dashboard/handlers/clients_test.go @@ -0,0 +1,91 @@ +// dashboard/handlers/clients_test.go +package handlers + +import ( + "io/fs" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "testing/fstest" + + "github.com/wind-c/comqtt/v2/mqtt" +) + +func newClientsRenderer(t *testing.T) *Renderer { + t.Helper() + return NewRenderer(fs.FS(fstest.MapFS{ + "templates/clients/list.html": &fstest.MapFile{Data: []byte(`{{define "layout"}}{{template "content" .}}{{end}}{{define "clients_list"}}{{template "layout" .}}{{end}}{{define "content"}}{{range .Page.Items}}{{else}}{{end}}
{{.ClientID}}{{.Username}}
empty
total={{.Page.Total}} q={{.Q}} page={{.Page.Page}}/{{.TotalPages}}
{{end}}`)}, + })) +} + +func addFakeClientForDashboard(t *testing.T, server *mqtt.Server, id string) { + t.Helper() + cl := &mqtt.Client{ID: id} + cl.Properties.Username = []byte("u") + cl.Net.Remote = "127.0.0.1:0" + cl.State.Subscriptions = mqtt.NewSubscriptions() + cl.State.Inflight = mqtt.NewInflights() + server.Clients.Add(cl) +} + +func TestClientsListEmpty(t *testing.T) { + server := mqtt.New(nil) + deps := ClientsDeps{Server: server, Renderer: newClientsRenderer(t)} + rr := httptest.NewRecorder() + ClientsList(deps)(rr, httptest.NewRequest(http.MethodGet, "/dashboard/clients", nil)) + if rr.Code != http.StatusOK { + t.Fatalf("status: %d", rr.Code) + } + if !strings.Contains(rr.Body.String(), "empty") { + t.Fatalf("expected empty row: %q", rr.Body.String()) + } + if !strings.Contains(rr.Body.String(), "total=0") { + t.Fatalf("expected total=0: %q", rr.Body.String()) + } +} + +func TestClientsListWithRowsAndPagination(t *testing.T) { + server := mqtt.New(nil) + for i := 0; i < 60; i++ { + addFakeClientForDashboard(t, server, "client-"+itoa(int64(i))) + } + deps := ClientsDeps{Server: server, Renderer: newClientsRenderer(t)} + rr := httptest.NewRecorder() + ClientsList(deps)(rr, httptest.NewRequest(http.MethodGet, "/dashboard/clients?page=2&size=20", nil)) + body := rr.Body.String() + if !strings.Contains(body, "total=60") { + t.Fatalf("expected total=60: %q", body) + } + if !strings.Contains(body, "page=2/3") { + t.Fatalf("expected page=2/3: %q", body) + } +} + +func TestClientsListSearch(t *testing.T) { + server := mqtt.New(nil) + addFakeClientForDashboard(t, server, "alpha-1") + addFakeClientForDashboard(t, server, "alpha-2") + addFakeClientForDashboard(t, server, "bravo-1") + deps := ClientsDeps{Server: server, Renderer: newClientsRenderer(t)} + rr := httptest.NewRecorder() + ClientsList(deps)(rr, httptest.NewRequest(http.MethodGet, "/dashboard/clients?q=alpha", nil)) + body := rr.Body.String() + if !strings.Contains(body, "total=2") { + t.Fatalf("expected total=2: %q", body) + } + if !strings.Contains(body, "q=alpha") { + t.Fatalf("expected q=alpha echoed: %q", body) + } +} + +func TestPageQueryPreservesOtherParams(t *testing.T) { + got := pageQuery(url.Values{"q": {"alpha"}, "size": {"10"}, "page": {"1"}}, 3) + if !strings.Contains(got, "page=3") { + t.Fatalf("got %q", got) + } + if !strings.Contains(got, "q=alpha") { + t.Fatalf("expected q to be preserved: %q", got) + } +} diff --git a/dashboard/handlers/cluster.go b/dashboard/handlers/cluster.go new file mode 100644 index 0000000..4909968 --- /dev/null +++ b/dashboard/handlers/cluster.go @@ -0,0 +1,71 @@ +// dashboard/handlers/cluster.go +package handlers + +import ( + "net/http" + "sort" + + "github.com/wind-c/comqtt/v2/cluster/discovery" + "github.com/wind-c/comqtt/v2/dashboard/auth" +) + +// ClusterAgent is the small surface from *cluster.Agent that the cluster +// page needs. *cluster.Agent satisfies it. Defined here to avoid an +// import cycle (the cluster package imports handlers? It doesn't, but this +// keeps the handler package free of any cluster-specific imports beyond +// the discovery types.) +type ClusterAgent interface { + GetMemberList() []discovery.Member + Leader() string +} + +type ClusterDeps struct { + Agent ClusterAgent + Renderer *Renderer + Cluster bool +} + +type clusterMemberRow struct { + Name string + Addr string + Port int + IsLeader bool +} + +type clusterPageData struct { + Title string + User auth.User + CSRF string + Cluster bool + Flash string + Error string + Members []clusterMemberRow + Leader string +} + +func ClusterPage(d ClusterDeps) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + members := d.Agent.GetMemberList() + leader := d.Agent.Leader() + + rows := make([]clusterMemberRow, 0, len(members)) + for _, m := range members { + rows = append(rows, clusterMemberRow{ + Name: m.Name, + Addr: m.Addr, + Port: m.Port, + IsLeader: m.Name == leader && leader != "", + }) + } + sort.Slice(rows, func(i, j int) bool { return rows[i].Name < rows[j].Name }) + + d.Renderer.Render(w, "cluster", clusterPageData{ + Title: "Cluster", + User: auth.UserFromContext(r.Context()), + CSRF: auth.NewCSRFToken(), + Cluster: d.Cluster, + Members: rows, + Leader: leader, + }) + } +} diff --git a/dashboard/handlers/cluster_test.go b/dashboard/handlers/cluster_test.go new file mode 100644 index 0000000..7c5f748 --- /dev/null +++ b/dashboard/handlers/cluster_test.go @@ -0,0 +1,92 @@ +// dashboard/handlers/cluster_test.go +package handlers + +import ( + "io/fs" + "net/http" + "net/http/httptest" + "strings" + "testing" + "testing/fstest" + + "github.com/wind-c/comqtt/v2/cluster/discovery" +) + +type fakeAgent struct { + members []discovery.Member + leader string +} + +func (f *fakeAgent) GetMemberList() []discovery.Member { return f.members } +func (f *fakeAgent) Leader() string { return f.leader } + +func newClusterRenderer(t *testing.T) *Renderer { + t.Helper() + return NewRenderer(fs.FS(fstest.MapFS{ + "templates/cluster.html": &fstest.MapFile{Data: []byte(`{{define "layout"}}{{template "content" .}}{{end}}{{define "cluster"}}{{template "layout" .}}{{end}}{{define "content"}}
leader={{.Leader}} count={{len .Members}}
{{range .Members}}
{{.Name}}@{{.Addr}}:{{.Port}} leader={{.IsLeader}}
{{else}}
none
{{end}}{{end}}`)}, + })) +} + +func TestClusterPageEmpty(t *testing.T) { + deps := ClusterDeps{Agent: &fakeAgent{}, Renderer: newClusterRenderer(t)} + rr := httptest.NewRecorder() + ClusterPage(deps)(rr, httptest.NewRequest(http.MethodGet, "/dashboard/cluster", nil)) + if rr.Code != http.StatusOK { + t.Fatalf("status: %d", rr.Code) + } + if !strings.Contains(rr.Body.String(), "count=0") { + t.Fatalf("expected empty: %q", rr.Body.String()) + } + if !strings.Contains(rr.Body.String(), "
") { + t.Fatalf("expected empty row: %q", rr.Body.String()) + } +} + +func TestClusterPageWithMembersHighlightsLeader(t *testing.T) { + deps := ClusterDeps{ + Agent: &fakeAgent{ + members: []discovery.Member{ + {Name: "n2", Addr: "10.0.0.2", Port: 7946}, + {Name: "n1", Addr: "10.0.0.1", Port: 7946}, + {Name: "n3", Addr: "10.0.0.3", Port: 7946}, + }, + leader: "n1", + }, + Renderer: newClusterRenderer(t), + } + rr := httptest.NewRecorder() + ClusterPage(deps)(rr, httptest.NewRequest(http.MethodGet, "/dashboard/cluster", nil)) + body := rr.Body.String() + if !strings.Contains(body, "leader=n1 count=3") { + t.Fatalf("expected leader+count line: %q", body) + } + if !strings.Contains(body, "n1@10.0.0.1:7946 leader=true") { + t.Fatalf("expected n1 marked as leader: %q", body) + } + if !strings.Contains(body, "n2@10.0.0.2:7946 leader=false") { + t.Fatalf("expected n2 not leader: %q", body) + } + // Sort check: n1 should appear before n2 in body + if strings.Index(body, "n1@") > strings.Index(body, "n2@") { + t.Fatalf("expected sorted order: %q", body) + } +} + +func TestClusterPageHandlesUnknownLeader(t *testing.T) { + deps := ClusterDeps{ + Agent: &fakeAgent{ + members: []discovery.Member{{Name: "n1", Addr: "10.0.0.1", Port: 7946}}, + leader: "", + }, + Renderer: newClusterRenderer(t), + } + rr := httptest.NewRecorder() + ClusterPage(deps)(rr, httptest.NewRequest(http.MethodGet, "/dashboard/cluster", nil)) + body := rr.Body.String() + if !strings.Contains(body, "leader= count=1") { + t.Fatalf("expected empty leader: %q", body) + } + if !strings.Contains(body, "n1@10.0.0.1:7946 leader=false") { + t.Fatalf("expected n1 not leader: %q", body) + } +} diff --git a/dashboard/handlers/events.go b/dashboard/handlers/events.go new file mode 100644 index 0000000..0cfbdfb --- /dev/null +++ b/dashboard/handlers/events.go @@ -0,0 +1,151 @@ +// dashboard/handlers/events.go +package handlers + +import ( + "context" + "encoding/json" + "fmt" + "html" + "net/http" + + "github.com/wind-c/comqtt/v2/dashboard/sse" +) + +// Events returns an http.HandlerFunc that streams events from the given Hub +// as Server-Sent Events. Default payload is HTML for direct htmx-sse swap; +// pass ?as=json for the raw JSON payload. +func Events(hub *sse.Hub) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "streaming unsupported", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + + asJSON := r.URL.Query().Get("as") == "json" + + ch := hub.Subscribe() + defer hub.Unsubscribe(ch) + + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + + fmt.Fprintf(w, ": comqtt dashboard events\n\n") + flusher.Flush() + + for { + select { + case <-ctx.Done(): + return + case ev, ok := <-ch: + if !ok { + return + } + var payload string + if asJSON { + b, _ := json.Marshal(ev) + payload = string(b) + } else { + payload = renderEventHTML(ev) + } + fmt.Fprintf(w, "id: %d\nevent: %s\ndata: %s\n\n", ev.Seq, ev.Type, payload) + flusher.Flush() + } + } + } +} + +// renderEventHTML produces a single-line
  • fragment per event so it can +// flow through htmx-sse's hx-swap="afterbegin" without wrecking newlines +// (SSE data: lines are newline-delimited; the fragment must be one line). +func renderEventHTML(ev sse.Event) string { + var summary map[string]any + _ = json.Unmarshal(ev.Payload, &summary) + + icon := eventIcon(ev.Type) + tone := eventTone(ev.Type) + + headline := html.EscapeString(eventHeadline(ev.Type, summary)) + detail := html.EscapeString(eventDetail(ev.Type, summary)) + node := html.EscapeString(ev.Node) + ts := fmt.Sprintf("%d", ev.TS) + + return fmt.Sprintf( + `
  • `+ + `%s`+ + `%s %s`+ + `%s`+ + `
  • `, + tone, icon, headline, detail, ts, node, node, + ) +} + +func eventIcon(t string) string { + switch t { + case "client.connected": + return "+" + case "client.disconnected": + return "-" + case "message.published": + return ">" + case "subscription.added": + return "*" + case "subscription.removed": + return "~" + } + return "." +} + +func eventTone(t string) string { + switch t { + case "client.connected", "subscription.added": + return "text-emerald-700 dark:text-emerald-300" + case "client.disconnected", "subscription.removed": + return "text-rose-700 dark:text-rose-300" + } + return "" +} + +func eventHeadline(t string, p map[string]any) string { + switch t { + case "client.connected": + return strOf(p, "client_id") + " connected" + case "client.disconnected": + return strOf(p, "client_id") + " disconnected" + case "message.published": + return strOf(p, "client_id") + " -> " + strOf(p, "topic") + case "subscription.added": + return strOf(p, "client_id") + " sub " + strOf(p, "topic") + case "subscription.removed": + return strOf(p, "client_id") + " unsub " + strOf(p, "topic") + } + return t +} + +func eventDetail(t string, p map[string]any) string { + switch t { + case "client.connected": + return strOf(p, "remote") + case "client.disconnected": + if reason := strOf(p, "reason"); reason != "" { + return reason + } + return "" + case "message.published": + size, _ := p["payload_size"].(float64) + qos, _ := p["qos"].(float64) + return fmt.Sprintf("qos=%d size=%dB", int(qos), int(size)) + } + return "" +} + +func strOf(m map[string]any, k string) string { + if v, ok := m[k].(string); ok { + return v + } + return "" +} diff --git a/dashboard/handlers/events_test.go b/dashboard/handlers/events_test.go new file mode 100644 index 0000000..c833562 --- /dev/null +++ b/dashboard/handlers/events_test.go @@ -0,0 +1,138 @@ +// dashboard/handlers/events_test.go +package handlers + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/wind-c/comqtt/v2/dashboard/sse" +) + +func TestEventsStreamsJSONWithAsParam(t *testing.T) { + hub := sse.NewHub(8) + defer hub.Close() + srv := httptest.NewServer(Events(hub)) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL+"?as=json", nil) + resp, err := srv.Client().Do(req) + if err != nil { + t.Fatalf("Do: %v", err) + } + defer resp.Body.Close() + + time.Sleep(75 * time.Millisecond) + hub.Publish(sse.Event{Type: "client.connected", Node: "n1", TS: 1, Payload: json.RawMessage(`{"client_id":"alice"}`)}) + + got := readUntil(t, resp.Body, "event: client.connected", 2*time.Second) + if !strings.Contains(got, "data: {") { + t.Fatalf("expected JSON data: %q", got) + } +} + +func TestEventsStreamsHTMLByDefault(t *testing.T) { + hub := sse.NewHub(8) + defer hub.Close() + srv := httptest.NewServer(Events(hub)) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil) + resp, err := srv.Client().Do(req) + if err != nil { + t.Fatalf("Do: %v", err) + } + defer resp.Body.Close() + + time.Sleep(75 * time.Millisecond) + hub.Publish(sse.Event{Type: "client.connected", Node: "n1", TS: 1, Payload: json.RawMessage(`{"client_id":"alice","remote":"127.0.0.1:0"}`)}) + + got := readUntil(t, resp.Body, " fragment: %q", got) + } + if !strings.Contains(got, "alice connected") { + t.Fatalf("expected headline: %q", got) + } + if strings.Count(got, "\n\n") < 1 { + t.Fatalf("expected SSE record terminator: %q", got) + } +} + +func TestEventsClosesOnContextCancel(t *testing.T) { + hub := sse.NewHub(8) + defer hub.Close() + srv := httptest.NewServer(Events(hub)) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil) + resp, err := srv.Client().Do(req) + if err != nil { + t.Fatalf("Do: %v", err) + } + + go func() { + time.Sleep(100 * time.Millisecond) + cancel() + }() + + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() +} + +func TestRenderEventHTMLEscapes(t *testing.T) { + ev := sse.Event{ + Type: "client.connected", + Payload: json.RawMessage(`{"client_id":" + + + + + +
    +
    + +
    + comqtt + / {{.Title}} +
    +
    +
    + + {{template "nav" .}} +
    + {{template "flash" .}} + {{template "content" .}} +
    +
    + + +{{end}} diff --git a/dashboard/templates/_nav.html b/dashboard/templates/_nav.html new file mode 100644 index 0000000..35654a6 --- /dev/null +++ b/dashboard/templates/_nav.html @@ -0,0 +1,53 @@ +{{define "nav"}} + +{{end}} diff --git a/dashboard/templates/account/password.html b/dashboard/templates/account/password.html new file mode 100644 index 0000000..71080ec --- /dev/null +++ b/dashboard/templates/account/password.html @@ -0,0 +1,44 @@ +{{define "account/password"}} + + + + + + Change password - comqtt + + + + +
    +
    +

    Change password

    +

    + {{if eq .Reason "must_change"}}You are signed in with the seeded password. Choose a new one to continue. + {{else if eq .Reason "expired"}}Your password has expired. Choose a new one. + {{else}}Update your dashboard password.{{end}} +

    +
    + {{with .Error}} +
    {{.}}
    + {{end}} + + + + + +
    + + +{{end}} diff --git a/dashboard/templates/account/personal.html b/dashboard/templates/account/personal.html new file mode 100644 index 0000000..afaebe3 --- /dev/null +++ b/dashboard/templates/account/personal.html @@ -0,0 +1,38 @@ +{{define "account_personal"}}{{template "layout" .}}{{end}} + +{{define "content"}} +
    +
    +

    Account

    +

    Your dashboard profile.

    +
    + +
    +
    +
    Username
    +
    {{.Account.Username}}
    + +
    Role
    +
    + {{if eq .Account.Role "admin"}} + admin + {{else}} + viewer + {{end}} +
    + +
    Password set
    +
    {{.Account.PasswordSetAt}}
    + +
    Status
    +
    + {{if .Account.MustChange}}must change password{{else}}active{{end}} + {{if .Account.Locked}}locked{{end}} +
    +
    + +
    +
    +{{end}} diff --git a/dashboard/templates/blacklist.html b/dashboard/templates/blacklist.html new file mode 100644 index 0000000..9ea29bd --- /dev/null +++ b/dashboard/templates/blacklist.html @@ -0,0 +1,39 @@ +{{define "blacklist"}}{{template "layout" .}}{{end}} + +{{define "content"}} +
    +
    +

    Blacklist

    +

    {{len .Items}} blocked client{{if ne (len .Items) 1}}s{{end}}

    +
    + + {{with .Error}} +
    {{.}}
    + {{end}} + +
    + + + +
    + +
    + {{range .Items}} +
    + {{.}} +
    + + +
    +
    + {{else}} +
    Blacklist is empty.
    + {{end}} +
    +
    +{{end}} diff --git a/dashboard/templates/clients/detail.html b/dashboard/templates/clients/detail.html new file mode 100644 index 0000000..bafc787 --- /dev/null +++ b/dashboard/templates/clients/detail.html @@ -0,0 +1,79 @@ +{{define "client_detail"}}{{template "layout" .}}{{end}} + +{{define "content"}} +
    +
    +
    +

    {{.ClientID}}

    +

    + {{if .Online}}online{{else}}offline{{end}} +

    +
    + ← Back to clients +
    + + {{with .Flash}}
    {{.}}
    {{end}} + {{with .Error}}
    {{.}}
    {{end}} + + + + {{if eq .Tab "info"}} +
    +
    +
    Username
    +
    {{.Info.Username}}
    + +
    Remote
    +
    {{.Info.Remote}}
    + +
    Listener
    +
    {{.Info.Listener}}
    + +
    Protocol
    +
    MQTT v{{.Info.ProtocolVersion}}
    + +
    Keepalive
    +
    {{.Info.Keepalive}}s
    + +
    Inflight
    +
    {{.Info.Inflight}}
    + +
    Clean session
    +
    {{.Info.Clean}}
    +
    +
    + {{end}} + + {{if eq .Tab "subs"}} +
    + {{range .Subs}} +
    + {{.Topic}} + + QoS {{.QoS}} +
    + + +
    +
    +
    + {{else}} +
    No subscriptions.
    + {{end}} +
    + {{end}} + + {{if eq .Tab "events"}} +
    + Live event filtering for a single client is not yet wired up. The Overview events feed shows all events; visit /dashboard/ for now. +
    + {{end}} +
    +{{end}} diff --git a/dashboard/templates/clients/list.html b/dashboard/templates/clients/list.html new file mode 100644 index 0000000..dcafebb --- /dev/null +++ b/dashboard/templates/clients/list.html @@ -0,0 +1,62 @@ +{{define "clients_list"}}{{template "layout" .}}{{end}} + +{{define "content"}} +
    +
    +
    +

    Clients

    +

    {{.Page.Total}} client{{if ne .Page.Total 1}}s{{end}} connected

    +
    +
    + + +
    +
    + +
    + + + + + + + + + + + + + {{range .Page.Items}} + + + + + + + + + {{else}} + + {{end}} + +
    Client IDUsernameRemoteSubsPendingKeepalive
    + {{.ClientID}} + {{.Username}}{{.Remote}}{{.Subs}}{{.Pending}}{{.Keepalive}}s
    No clients connected.
    +
    + + {{if gt .TotalPages 1}} +
    +
    Page {{.Page.Page}} of {{.TotalPages}}
    +
    + {{if gt .Page.Page 1}} + Prev + {{end}} + {{if lt .Page.Page .TotalPages}} + Next + {{end}} +
    +
    + {{end}} +
    +{{end}} diff --git a/dashboard/templates/cluster.html b/dashboard/templates/cluster.html new file mode 100644 index 0000000..64e8ec8 --- /dev/null +++ b/dashboard/templates/cluster.html @@ -0,0 +1,52 @@ +{{define "cluster"}}{{template "layout" .}}{{end}} + +{{define "content"}} +
    +
    +
    +

    Cluster

    +

    {{len .Members}} node{{if ne (len .Members) 1}}s{{end}}{{with .Leader}}, leader: {{.}}{{end}}

    +
    +
    + + {{with .Flash}}
    {{.}}
    {{end}} + {{with .Error}}
    {{.}}
    {{end}} + +
    + + + + + + + + + + + {{range .Members}} + + + + + + + {{else}} + + {{end}} + +
    NodeAddressPortRole
    {{.Name}}{{.Addr}}{{.Port}} + {{if .IsLeader}} + leader + {{else}} + follower + {{end}} +
    No cluster members.
    +
    + +

    + Add and Remove peer actions are not yet exposed on the dashboard. Use + POST /api/v1/cluster/peers directly via curl + (admin-only). +

    +
    +{{end}} diff --git a/dashboard/templates/fragments/overview_cards.html b/dashboard/templates/fragments/overview_cards.html new file mode 100644 index 0000000..09a4c93 --- /dev/null +++ b/dashboard/templates/fragments/overview_cards.html @@ -0,0 +1,12 @@ +{{define "overview_cards"}} +
    + {{range .Cards}} +
    +
    {{.Label}}
    +
    {{.Value}}{{with .Unit}} {{.}}{{end}}
    + {{with .Sub}}
    {{.}}
    {{end}} + {{with .Spark}}
    {{.}}
    {{end}} +
    + {{end}} +
    +{{end}} diff --git a/dashboard/templates/login.html b/dashboard/templates/login.html new file mode 100644 index 0000000..e4452ae --- /dev/null +++ b/dashboard/templates/login.html @@ -0,0 +1,39 @@ +{{define "login"}} + + + + + + Sign in - comqtt + + + + +
    +
    +

    + comqtt +

    +

    Sign in to the dashboard

    +
    + {{with .Error}} +
    {{.}}
    + {{end}} + + + + + + +
    + + +{{end}} diff --git a/dashboard/templates/overview.html b/dashboard/templates/overview.html new file mode 100644 index 0000000..6b1d0b7 --- /dev/null +++ b/dashboard/templates/overview.html @@ -0,0 +1,96 @@ +{{define "overview"}}{{template "layout" .}}{{end}} + +{{define "content"}} +
    +
    +
    +

    Overview

    +

    Broker health at a glance

    +
    +
    +
    + Mode: + {{if eq .NodeInfo.Mode "Cluster"}} + {{.NodeInfo.Mode}} + {{else}} + {{.NodeInfo.Mode}} + {{end}} +
    +
    This node: {{.NodeInfo.Self}}
    +
    +
    + +
    +
    +
    + {{template "overview_cards" .}} +
    + +
    +
    + Cluster topology + + {{.NodeInfo.Total}} node{{if ne .NodeInfo.Total 1}}s{{end}}{{with .NodeInfo.Leader}} · leader {{.}}{{end}} + +
    +
    + {{.NodeInfo.Topology}} +
    +
    + + + leader + + + + follower + + + + this node + +
    +
      + {{range .NodeInfo.Members}} +
    • +
      + {{.Name}} + {{if .IsSelf}}(this node){{end}} + {{with .Addr}}{{.}}{{end}} +
      +
      + {{if .IsLeader}} + leader + {{else if eq $.NodeInfo.Mode "Cluster"}} + follower + {{else}} + running + {{end}} +
      +
    • + {{end}} +
    +
    +
    + +
    +
    +
    + Recent events +
    +
    +
      +
    • Waiting for events...
    • +
    +
    +
    +
    +
    +
    +{{end}} diff --git a/dashboard/templates/retained.html b/dashboard/templates/retained.html new file mode 100644 index 0000000..9a717c6 --- /dev/null +++ b/dashboard/templates/retained.html @@ -0,0 +1,43 @@ +{{define "retained"}}{{template "layout" .}}{{end}} + +{{define "content"}} +
    +
    +
    +

    Retained Messages

    +

    {{.Page.Total}} retained

    +
    +
    + + + +
    +
    + +
    + {{range .Page.Items}} +
    + + {{.Topic}} + + QoS {{.QoS}} + {{.Size}}B +
    + + +
    +
    +
    +
    {{.PayloadHex}}
    +
    + {{else}} +
    No retained messages.
    + {{end}} +
    +
    +{{end}} diff --git a/dashboard/templates/sessions.html b/dashboard/templates/sessions.html new file mode 100644 index 0000000..542880f --- /dev/null +++ b/dashboard/templates/sessions.html @@ -0,0 +1,55 @@ +{{define "sessions"}}{{template "layout" .}}{{end}} + +{{define "content"}} +
    +
    +
    +

    Sessions

    +

    {{.Page.Total}} session{{if ne .Page.Total 1}}s{{end}}

    +
    +
    + + +
    +
    + +
    + + + + + + + + + + + + + {{range .Page.Items}} + + + + + + + + + {{else}} + + {{end}} + +
    Client IDOnlineRemoteSubsInflightAction
    {{.ClientID}}{{if .Online}}yes{{else}}no{{end}}{{.Remote}}{{.Subs}}{{.Inflight}} +
    + + +
    +
    No sessions.
    +
    +
    +{{end}} diff --git a/dashboard/templates/settings.html b/dashboard/templates/settings.html new file mode 100644 index 0000000..182568d --- /dev/null +++ b/dashboard/templates/settings.html @@ -0,0 +1,17 @@ +{{define "settings"}}{{template "layout" .}}{{end}} + +{{define "content"}} +
    +
    +

    Settings

    +

    Read-only snapshot of the running configuration.

    +
    + +
    + +
    + {{.HighlightedYAML}} +
    +
    +
    +{{end}} diff --git a/dashboard/templates/subscriptions.html b/dashboard/templates/subscriptions.html new file mode 100644 index 0000000..c52929c --- /dev/null +++ b/dashboard/templates/subscriptions.html @@ -0,0 +1,46 @@ +{{define "subscriptions"}}{{template "layout" .}}{{end}} + +{{define "content"}} +
    +
    +
    +

    Subscriptions

    +

    {{.Page.Total}} subscription{{if ne .Page.Total 1}}s{{end}}

    +
    +
    + + + +
    +
    + +
    + + + + + + + + + + + + {{range .Page.Items}} + + + + + + + + {{else}} + + {{end}} + +
    Client IDTopicQoSNoLocalRetainAsPub
    {{.ClientID}}{{.Topic}}{{.QoS}}{{.NoLocal}}{{.RetainAsPublished}}
    No subscriptions match.
    +
    +
    +{{end}} diff --git a/dashboard/templates/tools.html b/dashboard/templates/tools.html new file mode 100644 index 0000000..d4ac159 --- /dev/null +++ b/dashboard/templates/tools.html @@ -0,0 +1,54 @@ +{{define "tools"}}{{template "layout" .}}{{end}} + +{{define "content"}} +
    +
    +

    Tools

    +

    Publish an MQTT message from the dashboard.

    +
    + + {{with .Flash}} +
    {{.}}
    + {{end}} + {{with .Error}} +
    {{.}}
    + {{end}} + +
    + + + +
    + + +
    +
    + +
    +
    +
    +{{end}} diff --git a/dashboard/templates/topics.html b/dashboard/templates/topics.html new file mode 100644 index 0000000..38f1064 --- /dev/null +++ b/dashboard/templates/topics.html @@ -0,0 +1,28 @@ +{{define "topics"}}{{template "layout" .}}{{end}} + +{{define "content"}} +
    +
    +

    Topics

    +

    {{.UniqueFilters}} unique filter{{if ne .UniqueFilters 1}}s{{end}}

    +
    + +
    + {{template "topic_node" .Root}} +
    +
    +{{end}} + +{{define "topic_node"}} +{{if .Topic}} +
    + + {{.Topic}} + {{if gt .Subscribers 0}}({{.Subscribers}} sub{{if ne .Subscribers 1}}s{{end}}){{end}} + + {{range .Children}}{{template "topic_node" .}}{{end}} +
    +{{else}} +{{range .Children}}{{template "topic_node" .}}{{end}} +{{end}} +{{end}} diff --git a/dashboard/templates/users.html b/dashboard/templates/users.html new file mode 100644 index 0000000..ccadded --- /dev/null +++ b/dashboard/templates/users.html @@ -0,0 +1,78 @@ +{{define "users"}}{{template "layout" .}}{{end}} + +{{define "content"}} +
    +
    +

    Users

    +

    {{len .Items}} dashboard user{{if ne (len .Items) 1}}s{{end}}

    +
    + + {{with .Flash}}
    {{.}}
    {{end}} + {{with .Error}}
    {{.}}
    {{end}} + +
    + + + + + +
    + +
    + + + + + + + + + + + {{range .Items}} + + + + + + + {{else}} + + {{end}} + +
    UsernameRoleStatusActions
    {{.Username}} + {{if eq .Role "admin"}} + admin + {{else}} + viewer + {{end}} + + {{if .MustChange}}must change{{else}}active{{end}} + {{if .Locked}}locked{{end}} + +
    + + +
    +
    + + +
    +
    No users.
    +
    +
    +{{end}} diff --git a/dashboard/web/input.css b/dashboard/web/input.css new file mode 100644 index 0000000..e40cdff --- /dev/null +++ b/dashboard/web/input.css @@ -0,0 +1,88 @@ +@tailwind base; +@tailwind components; +@tailwind utilities; + +/* IoT-Blue palette - HSL tokens. + Foreground/grounding: Slate. Brand accent: Sky 500 (IoT/connectivity feel). + Status colors: Emerald (success), Amber (warning), Red (error). */ +@layer base { + :root { + --background: 210 40% 98%; /* Slate 50 */ + --foreground: 222 47% 11%; /* Slate 900 */ + --card: 0 0% 100%; + --card-foreground: 222 47% 11%; + --popover: 0 0% 100%; + --popover-foreground: 222 47% 11%; + --primary: 199 89% 48%; /* Sky 500 */ + --primary-foreground: 0 0% 100%; + --secondary: 210 40% 96%; /* Slate 100 */ + --secondary-foreground: 222 47% 11%; + --muted: 210 40% 96%; + --muted-foreground: 215 16% 47%; /* Slate 500 */ + --accent: 210 40% 96%; + --accent-foreground: 222 47% 11%; + --destructive: 0 84% 60%; /* Red 500 */ + --destructive-foreground: 0 0% 100%; + --success: 160 84% 39%; /* Emerald 500 */ + --success-foreground: 0 0% 100%; + --warning: 38 92% 50%; /* Amber 500 */ + --warning-foreground: 222 47% 11%; + --border: 214 32% 91%; /* Slate 200 */ + --input: 214 32% 91%; + --ring: 199 89% 48%; /* Sky 500 */ + --radius: 0.5rem; + } + .dark { + --background: 222 47% 11%; /* Slate 900 */ + --foreground: 210 40% 98%; /* Slate 50 */ + --card: 217 33% 17%; /* Slate 800 */ + --card-foreground: 210 40% 98%; + --popover: 217 33% 17%; + --popover-foreground: 210 40% 98%; + --primary: 199 93% 60%; /* Sky 400 */ + --primary-foreground: 222 47% 11%; + --secondary: 217 33% 17%; + --secondary-foreground: 210 40% 98%; + --muted: 217 33% 17%; + --muted-foreground: 215 20% 65%; /* Slate 400 */ + --accent: 217 33% 17%; + --accent-foreground: 210 40% 98%; + --destructive: 0 84% 60%; + --destructive-foreground: 210 40% 98%; + --success: 160 84% 39%; + --success-foreground: 210 40% 98%; + --warning: 38 92% 50%; + --warning-foreground: 222 47% 11%; + --border: 215 19% 27%; /* Slate 700 */ + --input: 215 19% 27%; + --ring: 199 93% 60%; + } + body { font-feature-settings: "cv02","cv03","cv04","cv11"; } + + /* Native even with + appearance:none; replace it with the theme ring. */ + select:focus-visible { + outline: 2px solid hsl(var(--ring)); + outline-offset: 1px; + } +} + +@layer components { + .sparkline { width: 100%; height: 100%; } +} diff --git a/dashboard/web/tailwind.config.js b/dashboard/web/tailwind.config.js new file mode 100644 index 0000000..c7619a5 --- /dev/null +++ b/dashboard/web/tailwind.config.js @@ -0,0 +1,48 @@ +// Tailwind v3 config - Vercel-themed (tweakcn) tokens via CSS variables. +module.exports = { + darkMode: 'class', + content: ['./dashboard/templates/**/*.html'], + // .dark must be safelisted - shadcn-style theming defines it in input.css + // (the rule that flips the CSS variables) but Tailwind would otherwise + // purge it as "not used by any utility class". + safelist: ['sparkline', 'dark'], + theme: { + extend: { + colors: { + background: 'hsl(var(--background))', + foreground: 'hsl(var(--foreground))', + card: 'hsl(var(--card))', + 'card-foreground': 'hsl(var(--card-foreground))', + popover: 'hsl(var(--popover))', + 'popover-foreground': 'hsl(var(--popover-foreground))', + primary: 'hsl(var(--primary))', + 'primary-foreground': 'hsl(var(--primary-foreground))', + secondary: 'hsl(var(--secondary))', + 'secondary-foreground': 'hsl(var(--secondary-foreground))', + muted: 'hsl(var(--muted))', + 'muted-foreground': 'hsl(var(--muted-foreground))', + accent: 'hsl(var(--accent))', + 'accent-foreground': 'hsl(var(--accent-foreground))', + destructive: 'hsl(var(--destructive))', + 'destructive-foreground': 'hsl(var(--destructive-foreground))', + success: 'hsl(var(--success))', + 'success-foreground': 'hsl(var(--success-foreground))', + warning: 'hsl(var(--warning))', + 'warning-foreground': 'hsl(var(--warning-foreground))', + border: 'hsl(var(--border))', + input: 'hsl(var(--input))', + ring: 'hsl(var(--ring))', + }, + borderRadius: { + lg: 'var(--radius)', + md: 'calc(var(--radius) - 2px)', + sm: 'calc(var(--radius) - 4px)', + }, + fontFamily: { + sans: ['ui-sans-serif', 'system-ui', '-apple-system', 'BlinkMacSystemFont', '"Segoe UI"', 'Roboto', '"Helvetica Neue"', 'Arial', 'sans-serif'], + mono: ['ui-monospace', 'SFMono-Regular', 'Menlo', 'Monaco', 'Consolas', '"Liberation Mono"', '"Courier New"', 'monospace'], + }, + }, + }, + plugins: [], +}; diff --git a/go.mod b/go.mod index f15844c..53e8ab2 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.24 toolchain go1.24.3 require ( + github.com/alecthomas/chroma/v2 v2.24.1 github.com/alicebob/miniredis/v2 v2.34.0 github.com/asdine/storm v2.1.2+incompatible github.com/asdine/storm/v3 v3.2.1 @@ -58,6 +59,7 @@ require ( github.com/dgraph-io/badger v1.6.0 // indirect github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dlclark/regexp2 v1.12.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/fatih/color v1.13.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect diff --git a/go.sum b/go.sum index 8f63231..76166ad 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,12 @@ github.com/DataDog/zstd v1.5.2 h1:vUG4lAyuPCXO0TLbXvPv7EB7cNK1QV/luu55UHLrrn8= github.com/DataDog/zstd v1.5.2/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863 h1:BRrxwOZBolJN4gIwvZMJY1tzqBvQgpaZiQRuIDD40jM= github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863/go.mod h1:D0JMgToj/WdxCgd30Kc1UcA9E+WdZoJqeVOuYW7iTBM= +github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= +github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= +github.com/alecthomas/chroma/v2 v2.24.1 h1:m5ffpfZbIb++k8AqFEKy9uVgY12xIQtBsQlc6DfZJQM= +github.com/alecthomas/chroma/v2 v2.24.1/go.mod h1:l+ohZ9xRXIbGe7cIW+YZgOGbvuVLjMps/FYN/CwuabI= +github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs= +github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= @@ -64,6 +70,8 @@ github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczC github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dlclark/regexp2 v1.12.0 h1:0j4c5qQmnC6XOWNjP3PIXURXN2gWx76rd3KvgdPkCz8= +github.com/dlclark/regexp2 v1.12.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= @@ -169,6 +177,8 @@ github.com/hashicorp/raft-boltdb/v2 v2.3.1 h1:ackhdCNPKblmOhjEU9+4lHSJYFkJd6Jqyv github.com/hashicorp/raft-boltdb/v2 v2.3.1/go.mod h1:n4S+g43dXF1tqDT+yzcXHhXM6y7MrlUd3TTwGRcUvQE= github.com/hashicorp/serf v0.10.2 h1:m5IORhuNSjaxeljg5DeQVDlQyVkhRIjJDimbkCa8aAc= github.com/hashicorp/serf v0.10.2/go.mod h1:T1CmSGfSeGfnfNy/w0odXQUR1rfECGd2Qdsp84DjOiY= +github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= +github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= diff --git a/mqtt/rest/client_unsubscribe.go b/mqtt/rest/client_unsubscribe.go new file mode 100644 index 0000000..6585705 --- /dev/null +++ b/mqtt/rest/client_unsubscribe.go @@ -0,0 +1,31 @@ +// mqtt/rest/client_unsubscribe.go +package rest + +import ( + "net/http" +) + +// unsubscribeClient handles DELETE /api/v1/mqtt/clients/{id}/subscriptions/{topic}. +// The {topic} path segment must be URL-encoded by the caller because MQTT +// topic filters can contain '/', '+', '#'. The Go 1.22 ServeMux applies +// path-segment-aware matching so the encoded form is decoded automatically. +func (s *Rest) unsubscribeClient(w http.ResponseWriter, r *http.Request) { + clientID := r.PathValue("id") + filter := r.PathValue("topic") + if clientID == "" || filter == "" { + http.Error(w, "missing client id or topic", http.StatusBadRequest) + return + } + cl, ok := s.server.Clients.Get(clientID) + if !ok { + http.NotFound(w, r) + return + } + if _, has := cl.State.Subscriptions.Get(filter); !has { + http.NotFound(w, r) + return + } + s.server.Topics.Unsubscribe(filter, clientID) + cl.State.Subscriptions.Delete(filter) + w.WriteHeader(http.StatusNoContent) +} diff --git a/mqtt/rest/client_unsubscribe_test.go b/mqtt/rest/client_unsubscribe_test.go new file mode 100644 index 0000000..ca36aba --- /dev/null +++ b/mqtt/rest/client_unsubscribe_test.go @@ -0,0 +1,68 @@ +// mqtt/rest/client_unsubscribe_test.go +package rest + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/wind-c/comqtt/v2/mqtt" + "github.com/wind-c/comqtt/v2/mqtt/packets" +) + +func TestUnsubscribeClientHappyPath(t *testing.T) { + server := mqtt.New(nil) + cl := &mqtt.Client{ID: "alice"} + cl.Properties.Username = []byte("u") + cl.Net.Remote = "127.0.0.1:0" + cl.State.Subscriptions = mqtt.NewSubscriptions() + cl.State.Inflight = mqtt.NewInflights() + cl.State.Subscriptions.Add("sensors/temp", packets.Subscription{Filter: "sensors/temp", Qos: 1}) + server.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "sensors/temp", Qos: 1}) + server.Clients.Add(cl) + rest := New(server) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, "/api/v1/mqtt/clients/alice/subscriptions/sensors%2Ftemp", nil) + req.SetPathValue("id", "alice") + req.SetPathValue("topic", "sensors/temp") + rest.unsubscribeClient(rr, req) + if rr.Code != http.StatusNoContent { + t.Fatalf("status: %d body: %s", rr.Code, rr.Body.String()) + } + if _, has := cl.State.Subscriptions.Get("sensors/temp"); has { + t.Fatal("expected per-client subscription to be cleared") + } +} + +func TestUnsubscribeClientNotFound(t *testing.T) { + server := mqtt.New(nil) + rest := New(server) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, "/x", nil) + req.SetPathValue("id", "ghost") + req.SetPathValue("topic", "anything") + rest.unsubscribeClient(rr, req) + if rr.Code != http.StatusNotFound { + t.Fatalf("status: %d", rr.Code) + } +} + +func TestUnsubscribeClientUnknownTopic(t *testing.T) { + server := mqtt.New(nil) + cl := &mqtt.Client{ID: "alice"} + cl.State.Subscriptions = mqtt.NewSubscriptions() + cl.State.Inflight = mqtt.NewInflights() + server.Clients.Add(cl) + rest := New(server) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, "/x", nil) + req.SetPathValue("id", "alice") + req.SetPathValue("topic", "never/subscribed") + rest.unsubscribeClient(rr, req) + if rr.Code != http.StatusNotFound { + t.Fatalf("status: %d", rr.Code) + } +} diff --git a/mqtt/rest/clients_list.go b/mqtt/rest/clients_list.go new file mode 100644 index 0000000..3cf523a --- /dev/null +++ b/mqtt/rest/clients_list.go @@ -0,0 +1,49 @@ +// mqtt/rest/clients_list.go +package rest + +import ( + "net/http" + "sort" + "strings" +) + +type clientSummary struct { + ClientID string `json:"client_id"` + Username string `json:"username"` + Remote string `json:"remote"` + ConnectedAt int64 `json:"connected_at"` + Keepalive uint16 `json:"keepalive"` + Subs int `json:"subs"` + Pending int `json:"pending"` +} + +// listClients handles GET /api/v1/mqtt/clients. +// Returns a paginated list filtered by ?q= prefix on ClientID. +func (s *Rest) listClients(w http.ResponseWriter, r *http.Request) { + page := ParsePage(r.URL.Query()) + q := strings.ToLower(r.URL.Query().Get("q")) + + all := make([]clientSummary, 0, 256) + for _, cl := range s.server.Clients.GetAll() { + if q != "" && !strings.Contains(strings.ToLower(cl.ID), q) { + continue + } + all = append(all, clientSummary{ + ClientID: cl.ID, + Username: string(cl.Properties.Username), + Remote: cl.Net.Remote, + Keepalive: cl.State.Keepalive, + Subs: cl.State.Subscriptions.Len(), + Pending: cl.State.Inflight.Len(), + }) + } + sort.Slice(all, func(i, j int) bool { return all[i].ClientID < all[j].ClientID }) + + resp := Page[clientSummary]{ + Page: page.Page, + Size: page.Size, + Total: len(all), + Items: ApplyPagination(all, page), + } + Ok(w, resp) +} diff --git a/mqtt/rest/clients_list_test.go b/mqtt/rest/clients_list_test.go new file mode 100644 index 0000000..a1aadba --- /dev/null +++ b/mqtt/rest/clients_list_test.go @@ -0,0 +1,112 @@ +// mqtt/rest/clients_list_test.go +package rest + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/wind-c/comqtt/v2/mqtt" +) + +// addFakeClient inserts a minimally-populated *mqtt.Client into the broker's +// Clients map. The fields populated are exactly the ones listClients reads. +func addFakeClient(t *testing.T, server *mqtt.Server, id, username string) { + t.Helper() + cl := &mqtt.Client{ID: id} + cl.Properties.Username = []byte(username) + cl.Net.Remote = "127.0.0.1:0" + cl.State.Subscriptions = mqtt.NewSubscriptions() + cl.State.Inflight = mqtt.NewInflights() + server.Clients.Add(cl) +} + +func TestListClientsEmpty(t *testing.T) { + server := mqtt.New(nil) + rest := New(server) + + rr := httptest.NewRecorder() + rest.listClients(rr, httptest.NewRequest(http.MethodGet, "/api/v1/mqtt/clients", nil)) + + if rr.Code != http.StatusOK { + t.Fatalf("status: %d body: %s", rr.Code, rr.Body.String()) + } + var got Page[clientSummary] + if err := json.NewDecoder(rr.Body).Decode(&got); err != nil { + t.Fatalf("decode: %v", err) + } + if got.Total != 0 || len(got.Items) != 0 { + t.Fatalf("expected empty: %+v", got) + } + if got.Page != 1 || got.Size != 25 { + t.Fatalf("defaults: %+v", got) + } +} + +func TestListClientsPagination(t *testing.T) { + server := mqtt.New(nil) + for i := 0; i < 60; i++ { + addFakeClient(t, server, fmtID(i), "u") + } + rest := New(server) + + rr := httptest.NewRecorder() + rest.listClients(rr, httptest.NewRequest(http.MethodGet, "/api/v1/mqtt/clients?page=2&size=20", nil)) + if rr.Code != http.StatusOK { + t.Fatalf("status: %d body: %s", rr.Code, rr.Body.String()) + } + var got Page[clientSummary] + if err := json.NewDecoder(rr.Body).Decode(&got); err != nil { + t.Fatalf("decode: %v", err) + } + if got.Total != 60 { + t.Fatalf("total: got %d want 60", got.Total) + } + if len(got.Items) != 20 { + t.Fatalf("items: got %d want 20", len(got.Items)) + } +} + +func TestListClientsPrefixSearch(t *testing.T) { + server := mqtt.New(nil) + addFakeClient(t, server, "alpha-1", "u") + addFakeClient(t, server, "alpha-2", "u") + addFakeClient(t, server, "bravo-1", "u") + rest := New(server) + + rr := httptest.NewRecorder() + rest.listClients(rr, httptest.NewRequest(http.MethodGet, "/api/v1/mqtt/clients?q=alpha", nil)) + var got Page[clientSummary] + _ = json.NewDecoder(rr.Body).Decode(&got) + if got.Total != 2 { + t.Fatalf("expected 2 alpha clients, got %d", got.Total) + } +} + +// fmtID formats a deterministic ID like "client-007". +func fmtID(i int) string { + return "client-" + leftPad(i) +} + +func leftPad(i int) string { + s := "" + if i < 10 { + s = "00" + } else if i < 100 { + s = "0" + } + return s + itoa(i) +} + +func itoa(i int) string { + if i == 0 { + return "0" + } + var b []byte + for i > 0 { + b = append([]byte{byte('0' + i%10)}, b...) + i /= 10 + } + return string(b) +} diff --git a/mqtt/rest/pagination.go b/mqtt/rest/pagination.go new file mode 100644 index 0000000..543fe45 --- /dev/null +++ b/mqtt/rest/pagination.go @@ -0,0 +1,51 @@ +// mqtt/rest/pagination.go +package rest + +import ( + "net/url" + "strconv" +) + +const ( + DefaultPageSize = 25 + MaxPageSize = 500 +) + +type PageParams struct { + Page int + Size int +} + +func ParsePage(q url.Values) PageParams { + page, _ := strconv.Atoi(q.Get("page")) + if page < 1 { + page = 1 + } + size, _ := strconv.Atoi(q.Get("size")) + if size < 1 { + size = DefaultPageSize + } + if size > MaxPageSize { + size = MaxPageSize + } + return PageParams{Page: page, Size: size} +} + +func ApplyPagination[T any](items []T, p PageParams) []T { + start := (p.Page - 1) * p.Size + end := start + p.Size + if start >= len(items) { + return nil + } + if end > len(items) { + end = len(items) + } + return items[start:end] +} + +type Page[T any] struct { + Page int `json:"page"` + Size int `json:"size"` + Total int `json:"total"` + Items []T `json:"items"` +} diff --git a/mqtt/rest/pagination_test.go b/mqtt/rest/pagination_test.go new file mode 100644 index 0000000..4ae1b41 --- /dev/null +++ b/mqtt/rest/pagination_test.go @@ -0,0 +1,45 @@ +// mqtt/rest/pagination_test.go +package rest + +import ( + "net/url" + "testing" +) + +func TestParsePageDefaults(t *testing.T) { + p := ParsePage(url.Values{}) + if p.Page != 1 || p.Size != 25 { + t.Fatalf("defaults: %+v", p) + } +} + +func TestParsePageClamps(t *testing.T) { + p := ParsePage(url.Values{"page": {"0"}, "size": {"5000"}}) + if p.Page != 1 { + t.Fatalf("page should be clamped to 1, got %d", p.Page) + } + if p.Size != MaxPageSize { + t.Fatalf("size should be clamped to %d, got %d", MaxPageSize, p.Size) + } +} + +func TestApplyPaginationSlices(t *testing.T) { + items := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + got := ApplyPagination(items, PageParams{Page: 2, Size: 3}) + want := []int{3, 4, 5} + if !equalInts(got, want) { + t.Fatalf("got %v want %v", got, want) + } +} + +func equalInts(a, b []int) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/mqtt/rest/rest.go b/mqtt/rest/rest.go index 1c18f1b..2eabb2a 100644 --- a/mqtt/rest/rest.go +++ b/mqtt/rest/rest.go @@ -11,15 +11,23 @@ import ( ) const ( - MqttGetOverallPath = "/api/v1/mqtt/stat/overall" - MqttGetOnlinePath = "/api/v1/mqtt/stat/online" - MqttGetClientPath = "/api/v1/mqtt/clients/{id}" - MqttGetBlacklistPath = "/api/v1/mqtt/blacklist" - MqttAddBlacklistPath = "/api/v1/mqtt/blacklist/{id}" - MqttDelBlacklistPath = "/api/v1/mqtt/blacklist/{id}" - MqttPublishMessagePath = "/api/v1/mqtt/message" - MqttGetConfigPath = "/api/v1/mqtt/config" - PrometheusMetrics = "/metrics" + MqttGetOverallPath = "/api/v1/mqtt/stat/overall" + MqttGetOnlinePath = "/api/v1/mqtt/stat/online" + MqttGetClientPath = "/api/v1/mqtt/clients/{id}" + MqttListClientsPath = "/api/v1/mqtt/clients" + MqttUnsubscribeClientPath = "/api/v1/mqtt/clients/{id}/subscriptions/{topic}" + MqttListSubscriptionsPath = "/api/v1/mqtt/subscriptions" + MqttTopicsTreePath = "/api/v1/mqtt/topics" + MqttListRetainedPath = "/api/v1/mqtt/retained" + MqttClearRetainedPath = "/api/v1/mqtt/retained/{topic}" + MqttListSessionsPath = "/api/v1/mqtt/sessions" + MqttClearSessionPath = "/api/v1/mqtt/sessions/{id}" + MqttGetBlacklistPath = "/api/v1/mqtt/blacklist" + MqttAddBlacklistPath = "/api/v1/mqtt/blacklist/{id}" + MqttDelBlacklistPath = "/api/v1/mqtt/blacklist/{id}" + MqttPublishMessagePath = "/api/v1/mqtt/message" + MqttGetConfigPath = "/api/v1/mqtt/config" + PrometheusMetrics = "/metrics" ) type Handler = func(http.ResponseWriter, *http.Request) @@ -36,14 +44,22 @@ func New(server *mqtt.Server) *Rest { func (s *Rest) GenHandlers() map[string]Handler { return map[string]Handler{ - "GET " + MqttGetConfigPath: s.viewConfig, - "GET " + MqttGetOverallPath: s.getOverallInfo, - "GET " + MqttGetOnlinePath: s.getOnlineCount, - "GET " + MqttGetClientPath: s.getClient, - "GET " + MqttGetBlacklistPath: s.blacklist, - "POST " + MqttAddBlacklistPath: s.kickClient, - "DELETE " + MqttDelBlacklistPath: s.blanchClient, - "POST " + MqttPublishMessagePath: s.publishMessage, + "GET " + MqttGetConfigPath: s.viewConfig, + "GET " + MqttGetOverallPath: s.getOverallInfo, + "GET " + MqttGetOnlinePath: s.getOnlineCount, + "GET " + MqttGetClientPath: s.getClient, + "GET " + MqttListClientsPath: s.listClients, + "GET " + MqttListSubscriptionsPath: s.listSubscriptions, + "GET " + MqttTopicsTreePath: s.topicsTree, + "GET " + MqttListRetainedPath: s.listRetained, + "DELETE " + MqttClearRetainedPath: s.clearRetained, + "GET " + MqttListSessionsPath: s.listSessions, + "DELETE " + MqttClearSessionPath: s.clearSession, + "GET " + MqttGetBlacklistPath: s.blacklist, + "POST " + MqttAddBlacklistPath: s.kickClient, + "DELETE " + MqttDelBlacklistPath: s.blanchClient, + "DELETE " + MqttUnsubscribeClientPath: s.unsubscribeClient, + "POST " + MqttPublishMessagePath: s.publishMessage, "GET " + PrometheusMetrics: promhttp.HandlerFor( s.server.Options.PrometheusRegistry, promhttp.HandlerOpts{ diff --git a/mqtt/rest/retained.go b/mqtt/rest/retained.go new file mode 100644 index 0000000..1431455 --- /dev/null +++ b/mqtt/rest/retained.go @@ -0,0 +1,74 @@ +// mqtt/rest/retained.go +package rest + +import ( + "net/http" + "sort" + "strings" +) + +const retainedPayloadCap = 4096 + +type retainedSummary struct { + Topic string `json:"topic"` + QoS byte `json:"qos"` + Size int `json:"size"` + StoredAt int64 `json:"stored_at"` + Payload []byte `json:"payload,omitempty"` +} + +// listRetained handles GET /api/v1/mqtt/retained. +// Optional ?topic= substring filter on the topic name. ?payload=true includes +// the raw payload (truncated to 4096 bytes). Default omits payloads to keep +// list responses small. +func (s *Rest) listRetained(w http.ResponseWriter, r *http.Request) { + page := ParsePage(r.URL.Query()) + topicFilter := strings.ToLower(r.URL.Query().Get("topic")) + includePayload := r.URL.Query().Get("payload") == "true" + + all := make([]retainedSummary, 0, 64) + for topic, pk := range s.server.Topics.Retained.GetAll() { + if topicFilter != "" && !strings.Contains(strings.ToLower(topic), topicFilter) { + continue + } + row := retainedSummary{ + Topic: topic, + QoS: pk.FixedHeader.Qos, + Size: len(pk.Payload), + StoredAt: pk.Created, + } + if includePayload { + if len(pk.Payload) > retainedPayloadCap { + row.Payload = pk.Payload[:retainedPayloadCap] + } else { + row.Payload = pk.Payload + } + } + all = append(all, row) + } + sort.Slice(all, func(i, j int) bool { return all[i].Topic < all[j].Topic }) + + resp := Page[retainedSummary]{ + Page: page.Page, + Size: page.Size, + Total: len(all), + Items: ApplyPagination(all, page), + } + Ok(w, resp) +} + +// clearRetained handles DELETE /api/v1/mqtt/retained/{topic}. +// 404 if no retained message exists at the given topic. +func (s *Rest) clearRetained(w http.ResponseWriter, r *http.Request) { + topic := r.PathValue("topic") + if topic == "" { + http.Error(w, "missing topic", http.StatusBadRequest) + return + } + if _, ok := s.server.Topics.Retained.Get(topic); !ok { + http.NotFound(w, r) + return + } + s.server.Topics.Retained.Delete(topic) + w.WriteHeader(http.StatusNoContent) +} diff --git a/mqtt/rest/retained_test.go b/mqtt/rest/retained_test.go new file mode 100644 index 0000000..1b91ad0 --- /dev/null +++ b/mqtt/rest/retained_test.go @@ -0,0 +1,134 @@ +// mqtt/rest/retained_test.go +package rest + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/wind-c/comqtt/v2/mqtt" + "github.com/wind-c/comqtt/v2/mqtt/packets" +) + +func addRetained(t *testing.T, server *mqtt.Server, topic string, payload []byte, qos byte) { + t.Helper() + pk := packets.Packet{ + TopicName: topic, + Payload: payload, + FixedHeader: packets.FixedHeader{Qos: qos, Retain: true}, + Created: 12345, + } + server.Topics.Retained.Add(topic, pk) +} + +func TestListRetainedEmpty(t *testing.T) { + server := mqtt.New(nil) + rest := New(server) + rr := httptest.NewRecorder() + rest.listRetained(rr, httptest.NewRequest(http.MethodGet, "/api/v1/mqtt/retained", nil)) + if rr.Code != http.StatusOK { + t.Fatalf("status: %d", rr.Code) + } + var got Page[retainedSummary] + _ = json.NewDecoder(rr.Body).Decode(&got) + if got.Total != 0 || len(got.Items) != 0 { + t.Fatalf("expected empty: %+v", got) + } +} + +func TestListRetainedBasic(t *testing.T) { + server := mqtt.New(nil) + addRetained(t, server, "sensors/temp/room1", []byte("21.5"), 0) + addRetained(t, server, "sensors/temp/room2", []byte("22.0"), 1) + addRetained(t, server, "metrics/heap", []byte("100"), 0) + rest := New(server) + + rr := httptest.NewRecorder() + rest.listRetained(rr, httptest.NewRequest(http.MethodGet, "/api/v1/mqtt/retained", nil)) + var got Page[retainedSummary] + _ = json.NewDecoder(rr.Body).Decode(&got) + if got.Total != 3 { + t.Fatalf("total: %d", got.Total) + } + for _, item := range got.Items { + if len(item.Payload) != 0 { + t.Fatalf("payload should be omitted by default: %s -> %q", item.Topic, item.Payload) + } + if item.Size <= 0 { + t.Fatalf("size should be set: %+v", item) + } + if item.StoredAt != 12345 { + t.Fatalf("stored_at not propagated: %+v", item) + } + } +} + +func TestListRetainedTopicFilter(t *testing.T) { + server := mqtt.New(nil) + addRetained(t, server, "sensors/temp/room1", []byte("21.5"), 0) + addRetained(t, server, "metrics/heap", []byte("100"), 0) + rest := New(server) + + rr := httptest.NewRecorder() + rest.listRetained(rr, httptest.NewRequest(http.MethodGet, "/api/v1/mqtt/retained?topic=sensors", nil)) + var got Page[retainedSummary] + _ = json.NewDecoder(rr.Body).Decode(&got) + if got.Total != 1 || got.Items[0].Topic != "sensors/temp/room1" { + t.Fatalf("filter mismatch: %+v", got) + } +} + +func TestListRetainedPayloadIncludedAndCapped(t *testing.T) { + server := mqtt.New(nil) + big := make([]byte, 8192) + for i := range big { + big[i] = byte(i) + } + addRetained(t, server, "big/topic", big, 0) + rest := New(server) + + rr := httptest.NewRecorder() + rest.listRetained(rr, httptest.NewRequest(http.MethodGet, "/api/v1/mqtt/retained?payload=true", nil)) + var got Page[retainedSummary] + _ = json.NewDecoder(rr.Body).Decode(&got) + if got.Total != 1 { + t.Fatalf("total: %d", got.Total) + } + if len(got.Items[0].Payload) != 4096 { + t.Fatalf("payload should be capped at 4096, got %d", len(got.Items[0].Payload)) + } + if got.Items[0].Size != 8192 { + t.Fatalf("size should be the original length: %d", got.Items[0].Size) + } +} + +func TestClearRetainedHappy(t *testing.T) { + server := mqtt.New(nil) + addRetained(t, server, "sensors/temp", []byte("21.5"), 0) + rest := New(server) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, "/x", nil) + req.SetPathValue("topic", "sensors/temp") + rest.clearRetained(rr, req) + if rr.Code != http.StatusNoContent { + t.Fatalf("status: %d body: %s", rr.Code, rr.Body.String()) + } + if _, ok := server.Topics.Retained.Get("sensors/temp"); ok { + t.Fatal("retained should be cleared") + } +} + +func TestClearRetainedNotFound(t *testing.T) { + server := mqtt.New(nil) + rest := New(server) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, "/x", nil) + req.SetPathValue("topic", "ghost") + rest.clearRetained(rr, req) + if rr.Code != http.StatusNotFound { + t.Fatalf("status: %d", rr.Code) + } +} diff --git a/mqtt/rest/sessions.go b/mqtt/rest/sessions.go new file mode 100644 index 0000000..4cc61bb --- /dev/null +++ b/mqtt/rest/sessions.go @@ -0,0 +1,117 @@ +// mqtt/rest/sessions.go +package rest + +import ( + "net/http" + "sort" + "strings" + + "github.com/wind-c/comqtt/v2/mqtt" + "github.com/wind-c/comqtt/v2/mqtt/packets" +) + +type sessionSummary struct { + ClientID string `json:"client_id"` + Username string `json:"username"` + Online bool `json:"online"` + Remote string `json:"remote"` + Listener string `json:"listener,omitempty"` + ProtocolVersion byte `json:"protocol_version"` + Clean bool `json:"clean"` + Subs int `json:"subs"` + Inflight int `json:"inflight"` + SessionExpiryInterval uint32 `json:"session_expiry_interval,omitempty"` +} + +// listSessions handles GET /api/v1/mqtt/sessions. +// Optional ?online=true|false filters to online-only or offline-only. +// Online sessions are pulled from the in-memory Clients map; offline sessions +// are pulled from the storage backend via s.server.Hooks().StoredClients(). +func (s *Rest) listSessions(w http.ResponseWriter, r *http.Request) { + page := ParsePage(r.URL.Query()) + onlineParam := strings.ToLower(r.URL.Query().Get("online")) + + all := make([]sessionSummary, 0, 128) + online := map[string]bool{} + + for _, cl := range s.server.Clients.GetAll() { + online[cl.ID] = true + if onlineParam == "false" { + continue + } + all = append(all, sessionSummary{ + ClientID: cl.ID, + Username: string(cl.Properties.Username), + Online: true, + Remote: cl.Net.Remote, + Listener: cl.Net.Listener, + ProtocolVersion: cl.Properties.ProtocolVersion, + Clean: cl.Properties.Clean, + Subs: cl.State.Subscriptions.Len(), + Inflight: cl.State.Inflight.Len(), + }) + } + + if onlineParam != "true" && s.server.Hooks() != nil && s.server.Hooks().Provides(mqtt.StoredClients) { + stored, err := s.server.Hooks().StoredClients() + if err == nil { + for _, sc := range stored { + if online[sc.ID] { + continue + } + all = append(all, sessionSummary{ + ClientID: sc.ID, + Username: string(sc.Username), + Online: false, + Remote: sc.Remote, + Listener: sc.Listener, + ProtocolVersion: sc.ProtocolVersion, + Clean: sc.Clean, + SessionExpiryInterval: sc.Properties.SessionExpiryInterval, + }) + } + } + } + + sort.Slice(all, func(i, j int) bool { return all[i].ClientID < all[j].ClientID }) + + resp := Page[sessionSummary]{ + Page: page.Page, + Size: page.Size, + Total: len(all), + Items: ApplyPagination(all, page), + } + Ok(w, resp) +} + +// clearSession handles DELETE /api/v1/mqtt/sessions/{id}. +// Disconnects the client if online. Storage-backed session cleanup happens +// automatically when the broker processes the disconnect with clean=true, +// or via the existing storage hook eviction path on session expiry. +func (s *Rest) clearSession(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if id == "" { + http.Error(w, "missing client id", http.StatusBadRequest) + return + } + cl, online := s.server.Clients.Get(id) + storedExists := false + if !online && s.server.Hooks() != nil && s.server.Hooks().Provides(mqtt.StoredClients) { + if stored, err := s.server.Hooks().StoredClients(); err == nil { + for _, sc := range stored { + if sc.ID == id { + storedExists = true + break + } + } + } + } + if !online && !storedExists { + http.NotFound(w, r) + return + } + if online { + s.server.DisconnectClient(cl, packets.CodeDisconnect) + } + w.WriteHeader(http.StatusNoContent) +} diff --git a/mqtt/rest/sessions_test.go b/mqtt/rest/sessions_test.go new file mode 100644 index 0000000..76cdd92 --- /dev/null +++ b/mqtt/rest/sessions_test.go @@ -0,0 +1,91 @@ +// mqtt/rest/sessions_test.go +package rest + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/wind-c/comqtt/v2/mqtt" +) + +func addOnlineSession(t *testing.T, server *mqtt.Server, id, username string) *mqtt.Client { + t.Helper() + cl := &mqtt.Client{ID: id} + cl.Properties.Username = []byte(username) + cl.Net.Remote = "127.0.0.1:0" + cl.Net.Listener = "test-listener" + cl.Properties.ProtocolVersion = 5 + cl.Properties.Clean = false + cl.State.Subscriptions = mqtt.NewSubscriptions() + cl.State.Inflight = mqtt.NewInflights() + server.Clients.Add(cl) + return cl +} + +func TestListSessionsEmpty(t *testing.T) { + server := mqtt.New(nil) + rest := New(server) + rr := httptest.NewRecorder() + rest.listSessions(rr, httptest.NewRequest(http.MethodGet, "/api/v1/mqtt/sessions", nil)) + if rr.Code != http.StatusOK { + t.Fatalf("status: %d", rr.Code) + } + var got Page[sessionSummary] + _ = json.NewDecoder(rr.Body).Decode(&got) + if got.Total != 0 || len(got.Items) != 0 { + t.Fatalf("expected empty: %+v", got) + } +} + +func TestListSessionsOnlineOnly(t *testing.T) { + server := mqtt.New(nil) + addOnlineSession(t, server, "alice", "u") + addOnlineSession(t, server, "bob", "u") + rest := New(server) + rr := httptest.NewRecorder() + rest.listSessions(rr, httptest.NewRequest(http.MethodGet, "/api/v1/mqtt/sessions", nil)) + var got Page[sessionSummary] + _ = json.NewDecoder(rr.Body).Decode(&got) + if got.Total != 2 { + t.Fatalf("total: %d", got.Total) + } + for _, item := range got.Items { + if !item.Online { + t.Fatalf("expected Online=true: %+v", item) + } + } +} + +func TestListSessionsOnlineFilter(t *testing.T) { + server := mqtt.New(nil) + addOnlineSession(t, server, "alice", "u") + rest := New(server) + rr := httptest.NewRecorder() + rest.listSessions(rr, httptest.NewRequest(http.MethodGet, "/api/v1/mqtt/sessions?online=true", nil)) + var got Page[sessionSummary] + _ = json.NewDecoder(rr.Body).Decode(&got) + if got.Total != 1 { + t.Fatalf("total: %d", got.Total) + } + rr2 := httptest.NewRecorder() + rest.listSessions(rr2, httptest.NewRequest(http.MethodGet, "/api/v1/mqtt/sessions?online=false", nil)) + var got2 Page[sessionSummary] + _ = json.NewDecoder(rr2.Body).Decode(&got2) + if got2.Total != 0 { + t.Fatalf("offline-only with no storage hook should return 0, got %d", got2.Total) + } +} + +func TestClearSessionNotFound(t *testing.T) { + server := mqtt.New(nil) + rest := New(server) + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, "/x", nil) + req.SetPathValue("id", "ghost") + rest.clearSession(rr, req) + if rr.Code != http.StatusNotFound { + t.Fatalf("status: %d", rr.Code) + } +} diff --git a/mqtt/rest/subscriptions_list.go b/mqtt/rest/subscriptions_list.go new file mode 100644 index 0000000..d5df719 --- /dev/null +++ b/mqtt/rest/subscriptions_list.go @@ -0,0 +1,60 @@ +// mqtt/rest/subscriptions_list.go +package rest + +import ( + "net/http" + "sort" + "strings" +) + +type subscriptionSummary struct { + ClientID string `json:"client_id"` + Topic string `json:"topic"` + QoS byte `json:"qos"` + NoLocal bool `json:"no_local"` + RetainAsPublished bool `json:"retain_as_published"` + Identifier int `json:"identifier,omitempty"` +} + +// listSubscriptions handles GET /api/v1/mqtt/subscriptions. +// Returns a paginated list of subscriptions filtered by ?topic= substring on +// the topic filter and ?clientid= substring on the client ID. +func (s *Rest) listSubscriptions(w http.ResponseWriter, r *http.Request) { + page := ParsePage(r.URL.Query()) + topicFilter := strings.ToLower(r.URL.Query().Get("topic")) + clientFilter := strings.ToLower(r.URL.Query().Get("clientid")) + + all := make([]subscriptionSummary, 0, 256) + for _, cl := range s.server.Clients.GetAll() { + if clientFilter != "" && !strings.Contains(strings.ToLower(cl.ID), clientFilter) { + continue + } + for filter, sub := range cl.State.Subscriptions.GetAll() { + if topicFilter != "" && !strings.Contains(strings.ToLower(filter), topicFilter) { + continue + } + all = append(all, subscriptionSummary{ + ClientID: cl.ID, + Topic: filter, + QoS: sub.Qos, + NoLocal: sub.NoLocal, + RetainAsPublished: sub.RetainAsPublished, + Identifier: sub.Identifier, + }) + } + } + sort.Slice(all, func(i, j int) bool { + if all[i].ClientID != all[j].ClientID { + return all[i].ClientID < all[j].ClientID + } + return all[i].Topic < all[j].Topic + }) + + resp := Page[subscriptionSummary]{ + Page: page.Page, + Size: page.Size, + Total: len(all), + Items: ApplyPagination(all, page), + } + Ok(w, resp) +} diff --git a/mqtt/rest/subscriptions_list_test.go b/mqtt/rest/subscriptions_list_test.go new file mode 100644 index 0000000..7781167 --- /dev/null +++ b/mqtt/rest/subscriptions_list_test.go @@ -0,0 +1,104 @@ +// mqtt/rest/subscriptions_list_test.go +package rest + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/wind-c/comqtt/v2/mqtt" + "github.com/wind-c/comqtt/v2/mqtt/packets" +) + +// addFakeClientWithSubs inserts a minimal *mqtt.Client into the broker and +// attaches the given topic filters as QoS 1 subscriptions. +func addFakeClientWithSubs(t *testing.T, server *mqtt.Server, id string, filters ...string) { + t.Helper() + cl := &mqtt.Client{ID: id} + cl.Properties.Username = []byte("u") + cl.Net.Remote = "127.0.0.1:0" + cl.State.Subscriptions = mqtt.NewSubscriptions() + cl.State.Inflight = mqtt.NewInflights() + for _, f := range filters { + cl.State.Subscriptions.Add(f, packets.Subscription{Filter: f, Qos: 1}) + } + server.Clients.Add(cl) +} + +func TestListSubscriptionsEmpty(t *testing.T) { + server := mqtt.New(nil) + rest := New(server) + + rr := httptest.NewRecorder() + rest.listSubscriptions(rr, httptest.NewRequest(http.MethodGet, "/api/v1/mqtt/subscriptions", nil)) + + if rr.Code != http.StatusOK { + t.Fatalf("status: %d body: %s", rr.Code, rr.Body.String()) + } + var got Page[subscriptionSummary] + if err := json.NewDecoder(rr.Body).Decode(&got); err != nil { + t.Fatalf("decode: %v", err) + } + if got.Total != 0 || len(got.Items) != 0 { + t.Fatalf("expected empty: %+v", got) + } + if got.Page != 1 || got.Size != 25 { + t.Fatalf("defaults: %+v", got) + } +} + +func TestListSubscriptionsBasic(t *testing.T) { + server := mqtt.New(nil) + addFakeClientWithSubs(t, server, "client-a", "a/1", "a/2") + addFakeClientWithSubs(t, server, "client-b", "b/1", "b/2") + addFakeClientWithSubs(t, server, "client-c", "c/1", "c/2") + rest := New(server) + + rr := httptest.NewRecorder() + rest.listSubscriptions(rr, httptest.NewRequest(http.MethodGet, "/api/v1/mqtt/subscriptions", nil)) + if rr.Code != http.StatusOK { + t.Fatalf("status: %d body: %s", rr.Code, rr.Body.String()) + } + var got Page[subscriptionSummary] + if err := json.NewDecoder(rr.Body).Decode(&got); err != nil { + t.Fatalf("decode: %v", err) + } + if got.Total != 6 { + t.Fatalf("total: got %d want 6", got.Total) + } + if len(got.Items) != 6 { + t.Fatalf("items: got %d want 6", len(got.Items)) + } + for _, it := range got.Items { + if it.QoS != 1 { + t.Fatalf("qos: got %d want 1 for %+v", it.QoS, it) + } + } +} + +func TestListSubscriptionsTopicFilter(t *testing.T) { + server := mqtt.New(nil) + addFakeClientWithSubs(t, server, "client-a", "sensors/temp", "sensors/humidity") + addFakeClientWithSubs(t, server, "client-b", "sensors/temp/room1", "alerts/fire") + addFakeClientWithSubs(t, server, "client-c", "alerts/flood") + rest := New(server) + + rr := httptest.NewRecorder() + rest.listSubscriptions(rr, httptest.NewRequest(http.MethodGet, "/api/v1/mqtt/subscriptions?topic=temp", nil)) + if rr.Code != http.StatusOK { + t.Fatalf("status: %d body: %s", rr.Code, rr.Body.String()) + } + var got Page[subscriptionSummary] + if err := json.NewDecoder(rr.Body).Decode(&got); err != nil { + t.Fatalf("decode: %v", err) + } + if got.Total != 2 { + t.Fatalf("total: got %d want 2", got.Total) + } + for _, it := range got.Items { + if it.Topic != "sensors/temp" && it.Topic != "sensors/temp/room1" { + t.Fatalf("unexpected topic in filtered result: %s", it.Topic) + } + } +} diff --git a/mqtt/rest/topics_tree.go b/mqtt/rest/topics_tree.go new file mode 100644 index 0000000..2917146 --- /dev/null +++ b/mqtt/rest/topics_tree.go @@ -0,0 +1,54 @@ +// mqtt/rest/topics_tree.go +package rest + +import ( + "net/http" + "sort" + "strings" +) + +type topicNode struct { + Topic string `json:"topic"` + Subscribers int `json:"subscribers"` + Children []*topicNode `json:"children,omitempty"` +} + +// topicsTree handles GET /api/v1/mqtt/topics. +// Walks every connected client's subscription set, builds a slash-separated +// trie, and returns the root. +func (s *Rest) topicsTree(w http.ResponseWriter, r *http.Request) { + root := &topicNode{} + for _, cl := range s.server.Clients.GetAll() { + for filter := range cl.State.Subscriptions.GetAll() { + insertTopic(root, filter) + } + } + sortTree(root) + Ok(w, root) +} + +func insertTopic(root *topicNode, filter string) { + cur := root + for _, seg := range strings.Split(filter, "/") { + var found *topicNode + for _, c := range cur.Children { + if c.Topic == seg { + found = c + break + } + } + if found == nil { + found = &topicNode{Topic: seg} + cur.Children = append(cur.Children, found) + } + cur = found + } + cur.Subscribers++ +} + +func sortTree(n *topicNode) { + sort.Slice(n.Children, func(i, j int) bool { return n.Children[i].Topic < n.Children[j].Topic }) + for _, c := range n.Children { + sortTree(c) + } +} diff --git a/mqtt/rest/topics_tree_test.go b/mqtt/rest/topics_tree_test.go new file mode 100644 index 0000000..6d659fe --- /dev/null +++ b/mqtt/rest/topics_tree_test.go @@ -0,0 +1,69 @@ +// mqtt/rest/topics_tree_test.go +package rest + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/wind-c/comqtt/v2/mqtt" + "github.com/wind-c/comqtt/v2/mqtt/packets" +) + +func TestTopicsTreeShape(t *testing.T) { + server := mqtt.New(nil) + + subscribe := func(id, filter string) { + cl := &mqtt.Client{ID: id} + cl.Properties.Username = []byte("u") + cl.Net.Remote = "127.0.0.1:0" + cl.State.Subscriptions = mqtt.NewSubscriptions() + cl.State.Inflight = mqtt.NewInflights() + cl.State.Subscriptions.Add(filter, packets.Subscription{Filter: filter, Qos: 0}) + server.Clients.Add(cl) + } + subscribe("alice", "sensors/temp/room1") + subscribe("bob", "sensors/temp/room2") + subscribe("alice2", "sensors/humidity/+") + + rest := New(server) + rr := httptest.NewRecorder() + rest.topicsTree(rr, httptest.NewRequest(http.MethodGet, "/api/v1/mqtt/topics", nil)) + if rr.Code != http.StatusOK { + t.Fatalf("status: %d", rr.Code) + } + var root topicNode + if err := json.NewDecoder(rr.Body).Decode(&root); err != nil { + t.Fatalf("decode: %v", err) + } + if root.Topic != "" { + t.Fatalf("root.Topic should be empty: %q", root.Topic) + } + if len(root.Children) != 1 || root.Children[0].Topic != "sensors" { + t.Fatalf("expected sensors child, got %+v", root.Children) + } + sensors := root.Children[0] + if len(sensors.Children) != 2 { + t.Fatalf("expected 2 children under sensors, got %d", len(sensors.Children)) + } + // children should be sorted: humidity before temp + if sensors.Children[0].Topic != "humidity" || sensors.Children[1].Topic != "temp" { + t.Fatalf("not sorted: %v", []string{sensors.Children[0].Topic, sensors.Children[1].Topic}) + } +} + +func TestTopicsTreeEmpty(t *testing.T) { + server := mqtt.New(nil) + rest := New(server) + rr := httptest.NewRecorder() + rest.topicsTree(rr, httptest.NewRequest(http.MethodGet, "/api/v1/mqtt/topics", nil)) + if rr.Code != http.StatusOK { + t.Fatalf("status: %d", rr.Code) + } + var root topicNode + _ = json.NewDecoder(rr.Body).Decode(&root) + if len(root.Children) != 0 { + t.Fatalf("expected empty tree, got %+v", root) + } +} diff --git a/mqtt/server.go b/mqtt/server.go index 4593f21..0553ba1 100644 --- a/mqtt/server.go +++ b/mqtt/server.go @@ -249,6 +249,13 @@ func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool) return cl } +// Hooks returns the registered hook bus. Used by external integrations +// (REST handlers, dashboards) that need to inspect stored state without +// the broker reaching back into them. +func (s *Server) Hooks() *Hooks { + return s.hooks +} + // AddHook attaches a new Hook to the server. Ideally, this should be called // before the server is started with s.Serve(). func (s *Server) AddHook(hook Hook, config any) error {