diff --git a/.golangci.yml b/.golangci.yml
index 33cba94..0077400 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -41,6 +41,7 @@ linters:
     rules:
       - linters:
           - errcheck
+          - goconst
           - unparam
         path: _test.go
     paths:
diff --git a/pkg/server/metrics.go b/pkg/server/metrics.go
new file mode 100644
index 0000000..43e2c11
--- /dev/null
+++ b/pkg/server/metrics.go
@@ -0,0 +1,130 @@
+package server
+
+import (
+	"fmt"
+	"io"
+	"sort"
+	"strings"
+	"time"
+)
+
+// prometheusLabelEscaper applies, in a single pass, the escapes the Prometheus
+// text exposition format requires inside a quoted label value.
+var prometheusLabelEscaper = strings.NewReplacer(`\`, `\\`, "\n", `\n`, `"`, `\"`)
+
+// prometheusEscapeLabelValue escapes backslashes, line feeds and double quotes
+// in s so it can be written between the quotes of a Prometheus label value.
+func prometheusEscapeLabelValue(s string) string {
+	return prometheusLabelEscaper.Replace(s)
+}
+
+// prometheusLabelValueOrUnknown normalizes a raw label value: an empty string
+// becomes "unknown" and invalid UTF-8 sequences are replaced with U+FFFD.
+//
+// Module names come straight from the handshake bytes read in relay, and the
+// text exposition format only allows valid UTF-8 in label values, so without
+// the replacement a single client could make a scrape of /metrics unparsable.
+func prometheusLabelValueOrUnknown(s string) string {
+	if s == "" {
+		return "unknown"
+	}
+	return strings.ToValidUTF8(s, "\uFFFD")
+}
+
+// prometheusLabels renders the index, module and upstream labels shared by all
+// per-connection metrics, without the surrounding braces.
+func prometheusLabels(index uint32, module, upstream string) string {
+	return fmt.Sprintf(
+		`index="%d",module="%s",upstream="%s"`,
+		index,
+		prometheusEscapeLabelValue(prometheusLabelValueOrUnknown(module)),
+		prometheusEscapeLabelValue(prometheusLabelValueOrUnknown(upstream)),
+	)
+}
+
+// prometheusConnectionGroup is the map key used to count connections per
+// (module, upstream) pair. A struct key cannot collide the way a key built by
+// joining the two strings with a separator could.
+type prometheusConnectionGroup struct {
+	module   string
+	upstream string
+}
+
+// writePrometheusMetrics writes the proxy's connection metrics to w in the
+// Prometheus text exposition format. now is the reference time used to compute
+// connection durations.
+//
+// Write errors are deliberately ignored: the output is best effort, and the
+// /metrics handler has no way to report a failure once the body has started.
+func (s *Server) writePrometheusMetrics(w io.Writer, now time.Time) {
+	connections := s.ListConnectionInfo()
+
+	// Snapshot every connection exactly once so that all metric families in
+	// this scrape agree on a connection's module, upstream and counters, and
+	// so that each ConnInfo is locked once rather than once per family.
+	snapshots := make([]connInfoSnapshot, 0, len(connections))
+	labels := make([]string, 0, len(connections))
+	connectionCounts := make(map[prometheusConnectionGroup]int)
+	for _, conn := range connections {
+		snapshot := conn.snapshot()
+		snapshots = append(snapshots, snapshot)
+		labels = append(labels, prometheusLabels(snapshot.Index, snapshot.Module, snapshot.UpstreamAddr))
+		key := prometheusConnectionGroup{
+			module:   prometheusLabelValueOrUnknown(snapshot.Module),
+			upstream: prometheusLabelValueOrUnknown(snapshot.UpstreamAddr),
+		}
+		connectionCounts[key]++
+	}
+
+	_, _ = fmt.Fprintln(w, "# HELP rsync_proxy_active_connections Current active rsync proxy connections.")
+	_, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_active_connections gauge")
+	_, _ = fmt.Fprintf(w, "rsync_proxy_active_connections %d\n", s.GetActiveConnectionCount())
+
+	// Map iteration order is random; sort the groups so the output is stable.
+	keys := make([]prometheusConnectionGroup, 0, len(connectionCounts))
+	for key := range connectionCounts {
+		keys = append(keys, key)
+	}
+	sort.Slice(keys, func(i, j int) bool {
+		if keys[i].module != keys[j].module {
+			return keys[i].module < keys[j].module
+		}
+		return keys[i].upstream < keys[j].upstream
+	})
+
+	_, _ = fmt.Fprintln(w, "# HELP rsync_proxy_active_connections_by_module Current active rsync proxy connections by module and upstream.")
+	_, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_active_connections_by_module gauge")
+	for _, key := range keys {
+		module := prometheusEscapeLabelValue(key.module)
+		upstream := prometheusEscapeLabelValue(key.upstream)
+		_, _ = fmt.Fprintf(w, "rsync_proxy_active_connections_by_module{module=\"%s\",upstream=\"%s\"} %d\n", module, upstream, connectionCounts[key])
+	}
+
+	_, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_sent_bytes Bytes sent to clients for active connections.")
+	_, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_sent_bytes gauge")
+	for i, snapshot := range snapshots {
+		_, _ = fmt.Fprintf(w, "rsync_proxy_connection_sent_bytes{%s} %d\n", labels[i], snapshot.SentBytes)
+	}
+
+	_, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_received_bytes Bytes received from clients for active connections.")
+	_, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_received_bytes gauge")
+	for i, snapshot := range snapshots {
+		_, _ = fmt.Fprintf(w, "rsync_proxy_connection_received_bytes{%s} %d\n", labels[i], snapshot.ReceivedBytes)
+	}
+
+	_, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_connected_timestamp_seconds Unix timestamp when active connections were established.")
+	_, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_connected_timestamp_seconds gauge")
+	for i, snapshot := range snapshots {
+		_, _ = fmt.Fprintf(w, "rsync_proxy_connection_connected_timestamp_seconds{%s} %d\n", labels[i], snapshot.ConnectedAt.Unix())
+	}
+
+	_, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_duration_seconds Current duration of active connections.")
+	_, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_duration_seconds gauge")
+	for i, snapshot := range snapshots {
+		duration := now.Sub(snapshot.ConnectedAt).Seconds()
+		if duration < 0 {
+			duration = 0
+		}
+		_, _ = fmt.Fprintf(w, "rsync_proxy_connection_duration_seconds{%s} %.3f\n", labels[i], duration)
+	}
+}
diff --git a/pkg/server/server.go b/pkg/server/server.go
index a2302e6..67d28b3 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -50,6 +50,8 @@ var (
 const lineFeed = '\n'
 
 type ConnInfo struct {
+	// mu guards Module and UpstreamAddr; see SetModule, SetUpstreamAddr and snapshot.
+	mu            sync.RWMutex
 	Index         uint32
 	LocalAddr     string
 	RemoteAddr    string
@@ -60,18 +62,39 @@ type ConnInfo struct {
 	ReceivedBytes atomic.Int64
 }
 
-func (c *ConnInfo) MarshalJSON() ([]byte, error) {
-	// Handle atomic value (cannot marshal directly)
-	return json.Marshal(struct {
-		Index         uint32    `json:"index"`
-		LocalAddr     string    `json:"local"`
-		RemoteAddr    string    `json:"remote"`
-		ConnectedAt   time.Time `json:"connected"`
-		Module        string    `json:"module"`
-		UpstreamAddr  string    `json:"upstream"`
-		SentBytes     int64     `json:"sentBytes"`
-		ReceivedBytes int64     `json:"receivedBytes"`
-	}{
+// connInfoSnapshot is a point-in-time copy of a ConnInfo made of plain values,
+// so it can be marshaled to JSON or formatted without holding any lock.
+type connInfoSnapshot struct {
+	Index         uint32    `json:"index"`
+	LocalAddr     string    `json:"local"`
+	RemoteAddr    string    `json:"remote"`
+	ConnectedAt   time.Time `json:"connected"`
+	Module        string    `json:"module"`
+	UpstreamAddr  string    `json:"upstream"`
+	SentBytes     int64     `json:"sentBytes"`
+	ReceivedBytes int64     `json:"receivedBytes"`
+}
+
+// SetModule stores module in c.Module while holding the write lock.
+func (c *ConnInfo) SetModule(module string) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.Module = module
+}
+
+// SetUpstreamAddr stores upstreamAddr in c.UpstreamAddr while holding the write lock.
+func (c *ConnInfo) SetUpstreamAddr(upstreamAddr string) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.UpstreamAddr = upstreamAddr
+}
+
+// snapshot returns a copy of c's fields taken under the read lock, with the
+// atomic byte counters loaded into plain integers.
+func (c *ConnInfo) snapshot() connInfoSnapshot {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+	return connInfoSnapshot{
 		Index:         c.Index,
 		LocalAddr:     c.LocalAddr,
 		RemoteAddr:    c.RemoteAddr,
@@ -80,7 +103,12 @@ func (c *ConnInfo) MarshalJSON() ([]byte, error) {
 		UpstreamAddr:  c.UpstreamAddr,
 		SentBytes:     c.SentBytes.Load(),
 		ReceivedBytes: c.ReceivedBytes.Load(),
-	})
+	}
+}
+
+// MarshalJSON implements json.Marshaler by encoding a snapshot of c.
+func (c *ConnInfo) MarshalJSON() ([]byte, error) {
+	return json.Marshal(c.snapshot())
 }
 
 type Target struct {
@@ -537,8 +565,7 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err
 	}
 
 	moduleName := string(buf[:n-1]) // trim trailing \n
-	info.Module = moduleName
-	s.connInfo.Store(index, &info)
+	info.SetModule(moduleName)
 
 	targets, ok := s.getTargetsForModule(moduleName)
 	if !ok {
@@ -551,8 +578,7 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err
 	target := targets[chooseTargetByClientIP(net.ParseIP(ip), len(targets))]
 	upstreamAddr := target.Addr
 	useProxyProtocol := target.UseProxyProtocol
-	info.UpstreamAddr = upstreamAddr
-	s.connInfo.Store(index, &info)
+	info.SetUpstreamAddr(upstreamAddr)
 
 	upstreamQueue, ok := s.getQueueForUpstream(target.Upstream)
 	if !ok {
@@ -804,6 +830,18 @@ func (s *Server) runHTTPServer() error {
 		_, _ = fmt.Fprintf(w, "rsync-proxy,host=%s count=%d %d\n", hostname, count, timestamp)
 	})
 
+	mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {
+		if r.Method != http.MethodGet {
+			// A 405 response must tell the client which methods are allowed.
+			w.Header().Set("Allow", http.MethodGet)
+			w.WriteHeader(http.StatusMethodNotAllowed)
+			return
+		}
+
+		w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
+		s.writePrometheusMetrics(w, time.Now())
+	})
+
 	return http.Serve(s.HTTPListener, &mux)
 }
 
diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go
index 291c453..bc25978 100644
--- a/pkg/server/server_test.go
+++ b/pkg/server/server_test.go
@@ -1,11 +1,13 @@
 package server
 
 import (
+	"bytes"
 	"crypto/tls"
 	"crypto/x509"
 	"fmt"
 	"io"
 	"net"
+	"net/http"
 	"os"
 	"path/filepath"
 	"strings"
@@ -57,6 +59,11 @@ func startServer(t *testing.T) *Server {
 	return srv
 }
 
+// testHTTPClient returns a client whose one-second timeout keeps a hung server from stalling tests.
+func testHTTPClient() *http.Client {
+	return &http.Client{Timeout: time.Second}
+}
+
 func doClientHandshake(conn *rsync.Conn, version []byte, module string) (svrVersion string, err error) {
 	_, err = conn.Write(version)
 	if err != nil {
@@ -334,12 +341,156 @@ func TestStatusIncludesSelectedUpstream(t *testing.T) {
 
 	require.Eventually(t, func() bool {
 		infos := srv.ListConnectionInfo()
-		return len(infos) == 1 && infos[0].UpstreamAddr == upstreamAddr
+		if len(infos) != 1 {
+			return false
+		}
+		return infos[0].snapshot().UpstreamAddr == upstreamAddr
 	}, time.Second, 10*time.Millisecond)
 
 	wg.Done()
 }
 
+// TestMetricsEndpointNoConnections checks the /metrics response when no client is connected.
+func TestMetricsEndpointNoConnections(t *testing.T) {
+	srv := startServer(t)
+	defer srv.Close()
+
+	resp, err := testHTTPClient().Get("http://" + srv.HTTPListener.Addr().String() + "/metrics")
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	text := string(body)
+
+	assert.Equal(t, http.StatusOK, resp.StatusCode)
+	assert.Equal(t, "text/plain; version=0.0.4; charset=utf-8", resp.Header.Get("Content-Type"))
+	assert.Contains(t, text, "# HELP rsync_proxy_active_connections Current active rsync proxy connections.")
+	assert.Contains(t, text, "# TYPE rsync_proxy_active_connections gauge")
+	assert.Contains(t, text, "rsync_proxy_active_connections 0\n")
+}
+
+// TestMetricsEndpointRejectsNonGET checks that /metrics answers a POST with 405 and an Allow header.
+func TestMetricsEndpointRejectsNonGET(t *testing.T) {
+	srv := startServer(t)
+	defer srv.Close()
+
+	resp, err := testHTTPClient().Post("http://"+srv.HTTPListener.Addr().String()+"/metrics", "text/plain", nil)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
+	assert.Equal(t, http.MethodGet, resp.Header.Get("Allow"))
+}
+
+// TestMetricsIncludesActiveConnections checks the per-connection series while one client is connected.
+func TestMetricsIncludesActiveConnections(t *testing.T) {
+	srv := startServer(t)
+	defer srv.Close()
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	fakeRsync := rsync.NewServer(func(conn *rsync.Conn) {
+		defer conn.Close()
+		_, _, err := doServerHandshake(conn, RsyncdServerVersion)
+		// require would call t.FailNow, which is only valid on the test goroutine.
+		assert.NoError(t, err)
+		wg.Wait()
+	})
+	fakeRsync.Start()
+	defer fakeRsync.Close()
+	// Deferred so the handler is released even when a require below aborts the test.
+	defer wg.Done()
+
+	upstreamAddr := fakeRsync.Listener.Addr().String()
+	srv.modules = map[string][]Target{
+		"fake": {{Upstream: "u1", Addr: upstreamAddr}},
+	}
+	srv.upstreamQueues = map[string]*queue.Queue{"u1": queue.New(0, 0)}
+
+	rawConn, err := net.Dial("tcp", srv.TCPListener.Addr().String())
+	require.NoError(t, err)
+	conn := rsync.NewConn(rawConn)
+	defer conn.Close()
+
+	_, err = doClientHandshake(conn, RsyncdServerVersion, "fake")
+	require.NoError(t, err)
+
+	require.Eventually(t, func() bool {
+		infos := srv.ListConnectionInfo()
+		if len(infos) != 1 {
+			return false
+		}
+		return infos[0].snapshot().UpstreamAddr == upstreamAddr
+	}, time.Second, 10*time.Millisecond)
+
+	resp, err := testHTTPClient().Get("http://" + srv.HTTPListener.Addr().String() + "/metrics")
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	text := string(body)
+
+	assert.Equal(t, http.StatusOK, resp.StatusCode)
+	assert.Contains(t, text, "rsync_proxy_active_connections 1\n")
+	assert.Contains(t, text, fmt.Sprintf("rsync_proxy_active_connections_by_module{module=\"fake\",upstream=%q} 1\n", upstreamAddr))
+	assert.Contains(t, text, "rsync_proxy_connection_sent_bytes{index=\"")
+	assert.Contains(t, text, "module=\"fake\"")
+	assert.Contains(t, text, fmt.Sprintf("upstream=%q", upstreamAddr))
+	assert.Contains(t, text, "rsync_proxy_connection_received_bytes{index=\"")
+	assert.Contains(t, text, "rsync_proxy_connection_connected_timestamp_seconds{index=\"")
+	assert.Contains(t, text, "rsync_proxy_connection_duration_seconds{index=\"")
+	assert.NotContains(t, text, rawConn.LocalAddr().String())
+}
+
+// TestPrometheusConnectionGroupingUsesStructuredKey checks that distinct pairs stay separate and invalid UTF-8 becomes U+FFFD.
+func TestPrometheusConnectionGroupingUsesStructuredKey(t *testing.T) {
+	srv := New()
+
+	first := &ConnInfo{Index: 1, ConnectedAt: time.Unix(100, 0)}
+	first.Module = "a\xffb"
+	first.UpstreamAddr = "c"
+	srv.connInfo.Store(first.Index, first)
+
+	second := &ConnInfo{Index: 2, ConnectedAt: time.Unix(100, 0)}
+	second.Module = "a"
+	second.UpstreamAddr = "b\xffc"
+	srv.connInfo.Store(second.Index, second)
+
+	var buf bytes.Buffer
+	srv.writePrometheusMetrics(&buf, time.Unix(101, 0))
+	text := buf.String()
+
+	assert.Contains(t, text, "rsync_proxy_active_connections_by_module{module=\"a\uFFFDb\",upstream=\"c\"} 1\n")
+	assert.Contains(t, text, "rsync_proxy_active_connections_by_module{module=\"a\",upstream=\"b\uFFFDc\"} 1\n")
+	assert.NotContains(t, text, "rsync_proxy_active_connections_by_module{module=\"a\",upstream=\"b\uFFFDc\"} 2\n")
+}
+
+// TestPrometheusDurationIncludesFractionalSeconds checks that durations keep millisecond precision.
+func TestPrometheusDurationIncludesFractionalSeconds(t *testing.T) {
+	srv := New()
+	conn := &ConnInfo{Index: 1, ConnectedAt: time.Unix(100, 0)}
+	conn.Module = "fake"
+	conn.UpstreamAddr = "127.0.0.1:873"
+	srv.connInfo.Store(conn.Index, conn)
+
+	var buf bytes.Buffer
+	srv.writePrometheusMetrics(&buf, time.Unix(100, 250_000_000))
+
+	assert.Contains(t, buf.String(), "rsync_proxy_connection_duration_seconds{index=\"1\",module=\"fake\",upstream=\"127.0.0.1:873\"} 0.250\n")
+}
+
+// TestPrometheusLabelValueEscaping checks label escaping, the empty-value fallback and UTF-8 sanitizing.
+func TestPrometheusLabelValueEscaping(t *testing.T) {
+	assert.Equal(t, `plain`, prometheusEscapeLabelValue("plain"))
+	assert.Equal(t, `quote\"value`, prometheusEscapeLabelValue(`quote"value`))
+	assert.Equal(t, `slash\\value`, prometheusEscapeLabelValue(`slash\value`))
+	assert.Equal(t, `line\nbreak`, prometheusEscapeLabelValue("line\nbreak"))
+	assert.Equal(t, `unknown`, prometheusLabelValueOrUnknown(""))
+	assert.Equal(t, "a\uFFFDb", prometheusLabelValueOrUnknown("a\xffb"))
+}
+
 func TestPerUpstreamQueueIsolation(t *testing.T) {
 	srv := startServer(t)
 	defer srv.Close()