Skip to content

Commit e0b6e5d

Browse files
authored
[TRNT-4228] Fix linter issues (#91)
Signed-off-by: Antonio Gamez Diaz <antonio.gamez@suse.com>
1 parent e41a3f7 commit e0b6e5d

File tree

7 files changed

+197
-38
lines changed

7 files changed

+197
-38
lines changed

.tool-versions

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
golang 1.25.0
2-
golangci-lint 2.5.0
2+
golangci-lint 2.10.1
33
make 4.4.1
44
shellcheck 0.11.0
5-
yamllint 1.37.1
5+
yamllint 1.38.0

internal/server/heathcheck_server.go

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2025 SUSE LLC
1+
// Copyright 2025-2026 SUSE LLC
22
// SPDX-License-Identifier: Apache-2.0
33

44
package server
@@ -66,24 +66,25 @@ func createReadinessChecker(ctx context.Context, serveOpts *ServeOptions) http.H
6666
Timeout: 5 * time.Second,
6767
}
6868

69-
// Start with the MCP server check
70-
checks := []health.Check{
71-
{
72-
Name: "mcp-server",
73-
Check: func(ctx context.Context) error {
74-
// Check connectivity to the MCP server using an MCP client.
75-
return checkMCPServer(ctx, serveOpts)
76-
},
69+
// Precompute OAS checks so we can preallocate capacity
70+
oasChecks := createOASPathHealthChecks(ctx, serveOpts, httpClient)
71+
72+
// Start with the MCP server check and preallocate capacity for all checks
73+
checks := make([]health.Check, 0, 1+len(oasChecks))
74+
checks = append(checks, health.Check{
75+
Name: "mcp-server",
76+
Check: func(ctx context.Context) error {
77+
return checkMCPServer(ctx, serveOpts)
7778
},
78-
}
79+
})
7980

8081
slog.InfoContext(ctx, "creating health check for MCP server")
8182

8283
// Add individual health checks for each OAS path
83-
checks = append(checks, createOASPathHealthChecks(ctx, serveOpts, httpClient)...)
84+
checks = append(checks, oasChecks...)
8485

85-
// Build the checker options
86-
options := []health.CheckerOption{}
86+
// Build the checker options and preallocate based on number of checks
87+
options := make([]health.CheckerOption, 0, len(checks))
8788
for _, check := range checks {
8889
options = append(options, health.WithCheck(check))
8990
}
@@ -223,7 +224,7 @@ func checkMCPServer(ctx context.Context, serveOpts *ServeOptions) error {
223224
case utils.TransportSSE:
224225
mcpTransport = &mcp.SSEClientTransport{
225226
Endpoint: (&url.URL{
226-
Scheme: "http",
227+
Scheme: utils.HTTPScheme,
227228
Host: fmt.Sprintf("localhost:%d", serveOpts.Port),
228229
Path: "/sse",
229230
}).String(),
@@ -234,7 +235,7 @@ func checkMCPServer(ctx context.Context, serveOpts *ServeOptions) error {
234235
case utils.TransportStreamable:
235236
mcpTransport = &mcp.StreamableClientTransport{
236237
Endpoint: (&url.URL{
237-
Scheme: "http",
238+
Scheme: utils.HTTPScheme,
238239
Host: fmt.Sprintf("localhost:%d", serveOpts.Port),
239240
Path: "/mcp",
240241
}).String(),
@@ -278,17 +279,34 @@ func checkAPIServiceHealth(
278279
serveOpts *ServeOptions,
279280
httpClient *http.Client,
280281
) error {
281-
// Create the HTTP request
282-
req, err := http.NewRequestWithContext(ctx, http.MethodGet, healthURL, nil)
282+
// Validate the health URL
283+
parsedURL, err := url.Parse(healthURL)
284+
if err != nil {
285+
return fmt.Errorf("invalid health URL %s: %w", healthURL, err)
286+
}
287+
288+
err = utils.ValidateHTTPURL(parsedURL)
289+
if err != nil {
290+
return err
291+
}
292+
293+
// Create the HTTP request using the validated parsedURL
294+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsedURL.String(), nil)
283295
if err != nil {
284296
return fmt.Errorf("failed to create request for %s: %w", healthURL, err)
285297
}
286298

287299
// Set User-Agent header
288300
req.Header.Set("User-Agent", fmt.Sprintf("%s/%s", serveOpts.Name, serveOpts.Version))
289301

290-
// Make the request
291-
resp, err := httpClient.Do(req)
302+
// Make the request through the configured RoundTripper.
303+
// This avoids automatic redirect handling and keeps the request path explicit.
304+
transport := httpClient.Transport
305+
if transport == nil {
306+
transport = http.DefaultTransport
307+
}
308+
309+
resp, err := transport.RoundTrip(req)
292310
if err != nil {
293311
return fmt.Errorf("failed to connect to %s (derived from OAS path %s): %w", healthURL, oasPath, err)
294312
}

internal/server/mcp_server.go

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2025 SUSE LLC
1+
// Copyright 2025-2026 SUSE LLC
22
// SPDX-License-Identifier: Apache-2.0
33

44
package server
@@ -262,7 +262,23 @@ func registerToolsFromSpec(srv *mcp.Server, oasDoc *openapi3.T, serveOpts *Serve
262262
TagFilter: nil, // TODO(agamez): revert back to "serveOpts.TagFilter," once we can.
263263
ConfirmDangerousActions: false, // TODO(agamez): not really working IRL, make it configurable?
264264
RequestHandler: func(req *http.Request) (*http.Response, error) {
265-
return httpClient.Do(req)
265+
// Validate the request URL to mitigate SSRF (gosec G704).
266+
if req == nil {
267+
return nil, fmt.Errorf("invalid request: missing request")
268+
}
269+
270+
err := utils.ValidateHTTPURL(req.URL)
271+
if err != nil {
272+
return nil, err
273+
}
274+
275+
// Use the client's transport RoundTrip to avoid opaque Do sinks and automatic redirects.
276+
transport := httpClient.Transport
277+
if transport == nil {
278+
transport = http.DefaultTransport
279+
}
280+
281+
return transport.RoundTrip(req)
266282
},
267283
NameFormat: func(oldOperationID string) string {
268284
// Convert dots to underscores first
@@ -317,17 +333,33 @@ func loadOpenAPISpecFromURL(ctx context.Context, path string, serveOpts *ServeOp
317333
Timeout: 30 * time.Second,
318334
}
319335

320-
// Generate the GET request.
321-
req, err := http.NewRequestWithContext(ctx, http.MethodGet, path, nil)
336+
// Validate and parse the path URL
337+
parsedPath, err := url.Parse(path)
338+
if err != nil {
339+
return nil, fmt.Errorf("invalid OpenAPI spec URL %s: %w", path, err)
340+
}
341+
342+
err = utils.ValidateHTTPURL(parsedPath)
343+
if err != nil {
344+
return nil, err
345+
}
346+
347+
// Generate the GET request using the validated URL.
348+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsedPath.String(), nil)
322349
if err != nil {
323350
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
324351
}
325352

326353
// Set the UA to track the version.
327354
req.Header.Set("User-Agent", fmt.Sprintf("%s/%s", serveOpts.Name, serveOpts.Version))
328355

329-
// Perform the request.
330-
resp, err := client.Do(req)
356+
// Perform the request via the client's transport to avoid opaque Do sinks and automatic redirects.
357+
transport := client.Transport
358+
if transport == nil {
359+
transport = http.DefaultTransport
360+
}
361+
362+
resp, err := transport.RoundTrip(req)
331363
if err != nil {
332364
return nil, fmt.Errorf("failed to fetch OpenAPI spec from URL: %w", err)
333365
}

internal/server/mcp_server_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2025 SUSE LLC
1+
// Copyright 2025-2026 SUSE LLC
22
// SPDX-License-Identifier: Apache-2.0
33

44
//nolint:lll
@@ -57,7 +57,7 @@ func TestStartSSEServer(t *testing.T) {
5757
srv := server.CreateMCPServer(ctx, &server.ServeOptions{Name: "test", Version: "v1"})
5858
port := getAvailablePort(t)
5959
listenAddr := fmt.Sprintf(":%d", port)
60-
checkURL := (&url.URL{Scheme: "http", Host: fmt.Sprintf("localhost:%d", port), Path: tt.checkPath}).String()
60+
checkURL := (&url.URL{Scheme: utils.HTTPScheme, Host: fmt.Sprintf("localhost:%d", port), Path: tt.checkPath}).String()
6161

6262
testServerShutdown(t, cancel, func() error {
6363
serverErrChan := make(chan error, 1)
@@ -105,7 +105,7 @@ func TestStartStreamableHTTPServer(t *testing.T) {
105105
srv := server.CreateMCPServer(ctx, &server.ServeOptions{Name: "test", Version: "v1"})
106106
port := getAvailablePort(t)
107107
listenAddr := fmt.Sprintf(":%d", port)
108-
checkURL := (&url.URL{Scheme: "http", Host: fmt.Sprintf("localhost:%d", port), Path: tt.checkPath}).String()
108+
checkURL := (&url.URL{Scheme: utils.HTTPScheme, Host: fmt.Sprintf("localhost:%d", port), Path: tt.checkPath}).String()
109109

110110
testServerShutdown(t, cancel, func() error {
111111
serverErrChan := make(chan error, 1)
@@ -165,7 +165,7 @@ func TestStartServer(t *testing.T) {
165165

166166
port := getAvailablePort(t)
167167
listenAddr := fmt.Sprintf(":%d", port)
168-
checkURL := (&url.URL{Scheme: "http", Host: fmt.Sprintf("localhost:%d", port)}).String()
168+
checkURL := (&url.URL{Scheme: utils.HTTPScheme, Host: fmt.Sprintf("localhost:%d", port)}).String()
169169

170170
testServerShutdown(t, cancel, func() error {
171171
serverErrChan := make(chan error, 1)
@@ -677,7 +677,7 @@ func TestHandleMCPServerRun(t *testing.T) {
677677
assert.Contains(t, err.Error(), tt.errContains)
678678
}
679679
} else {
680-
checkURL := (&url.URL{Scheme: "http", Host: fmt.Sprintf("localhost:%d", port), Path: tt.path}).String()
680+
checkURL := (&url.URL{Scheme: utils.HTTPScheme, Host: fmt.Sprintf("localhost:%d", port), Path: tt.path}).String()
681681
testServerShutdown(t, cancel, func() error {
682682
serverErrChan := make(chan error, 1)
683683

@@ -723,7 +723,7 @@ func waitForServerReady(t *testing.T, urlStr string, timeout time.Duration) {
723723
for time.Now().Before(deadline) {
724724
// Use client.Do to ensure the context is passed for cancellation
725725
// and to avoid issues with client.Get's default redirect behavior
726-
resp, err := client.Do(req)
726+
resp, err := client.Do(req) // nolint:gosec // This is just a test, no SSRF risk
727727
if err == nil {
728728
_ = resp.Body.Close()
729729
// Consider the server ready if it returns any non-5xx response.
@@ -788,7 +788,7 @@ func createTempOASFile(t *testing.T, oasContent string) string {
788788

789789
tmpFile, err := os.CreateTemp(t.TempDir(), "openapi-*.json")
790790
require.NoError(t, err)
791-
t.Cleanup(func() { err = os.Remove(tmpFile.Name()); require.NoError(t, err) })
791+
t.Cleanup(func() { err = os.Remove(tmpFile.Name()); require.NoError(t, err) }) // nolint:gosec // This is just a test, no SSRF risk
792792

793793
_, err = tmpFile.WriteString(oasContent)
794794
require.NoError(t, err)

internal/server/server_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2025 SUSE LLC
1+
// Copyright 2025-2026 SUSE LLC
22
// SPDX-License-Identifier: Apache-2.0
33

44
// Package server_test is the where the server logic is tested.
@@ -93,7 +93,7 @@ func TestServe(t *testing.T) {
9393
assert.Contains(t, err.Error(), tt.errContains)
9494
} else {
9595
checkURL := (&url.URL{
96-
Scheme: "http",
96+
Scheme: utils.HTTPScheme,
9797
Host: fmt.Sprintf("localhost:%d", port),
9898
Path: tt.path,
9999
}).String()
@@ -233,7 +233,7 @@ func TestWaitForShutdown(t *testing.T) {
233233
}()
234234

235235
checkURL := (&url.URL{
236-
Scheme: "http",
236+
Scheme: utils.HTTPScheme,
237237
Host: fmt.Sprintf("localhost:%d", port),
238238
Path: tt.checkPath,
239239
}).String()
@@ -252,7 +252,7 @@ func TestWaitForShutdown(t *testing.T) {
252252
require.NoError(t, err)
253253

254254
client := &http.Client{}
255-
resp, err := client.Do(req)
255+
resp, err := client.Do(req) // nolint:gosec // This is just a test, no SSRF risk
256256
require.Error(t, err, "Server should be down")
257257

258258
if resp != nil && resp.Body != nil {

internal/utils/utils.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright 2026 SUSE LLC
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package utils //nolint:revive
5+
6+
import (
7+
"fmt"
8+
"net/url"
9+
)
10+
11+
const (
12+
// HTTPScheme is the HTTPS scheme
13+
HTTPScheme = "http"
14+
// HTTPSScheme is the HTTP scheme
15+
HTTPSScheme = "https"
16+
)
17+
18+
// ValidateHTTPURL validates scheme and host constraints for outbound HTTP requests.
19+
func ValidateHTTPURL(parsedURL *url.URL) error {
20+
if parsedURL == nil {
21+
return fmt.Errorf("invalid URL: missing URL")
22+
}
23+
24+
if parsedURL.Scheme != HTTPScheme && parsedURL.Scheme != HTTPSScheme {
25+
return fmt.Errorf("invalid URL: unsupported protocol scheme %q", parsedURL.Scheme)
26+
}
27+
28+
if parsedURL.Host == "" {
29+
return fmt.Errorf("invalid URL: missing host in %s", parsedURL.String())
30+
}
31+
32+
return nil
33+
}

internal/utils/utils_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// Copyright 2026 SUSE LLC
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package utils_test
5+
6+
import (
7+
"net/url"
8+
"testing"
9+
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
"github.com/trento-project/mcp-server/internal/utils"
13+
)
14+
15+
func TestValidateHTTPURL(t *testing.T) {
16+
t.Parallel()
17+
18+
tests := []struct {
19+
name string
20+
url string
21+
expectErr bool
22+
errContains string
23+
}{
24+
{
25+
name: "valid HTTP URL",
26+
url: "http://example.com",
27+
expectErr: false,
28+
},
29+
{
30+
name: "valid HTTPS URL",
31+
url: "https://example.com",
32+
expectErr: false,
33+
},
34+
{
35+
name: "invalid URL",
36+
url: "ftp://example.com",
37+
expectErr: true,
38+
errContains: "unsupported protocol scheme",
39+
},
40+
{
41+
name: "missing host",
42+
url: "http://",
43+
expectErr: true,
44+
errContains: "missing host",
45+
},
46+
{
47+
name: "nil URL",
48+
url: "",
49+
expectErr: true,
50+
errContains: "missing URL",
51+
},
52+
}
53+
54+
for _, tc := range tests {
55+
t.Run(tc.name, func(t *testing.T) {
56+
t.Parallel()
57+
58+
var parsedURL *url.URL
59+
var err error
60+
61+
if tc.url != "" {
62+
parsedURL, err = url.Parse(tc.url)
63+
require.NoError(t, err)
64+
}
65+
66+
err = utils.ValidateHTTPURL(parsedURL)
67+
68+
if tc.expectErr {
69+
require.Error(t, err)
70+
assert.Contains(t, err.Error(), tc.errContains)
71+
} else {
72+
require.NoError(t, err)
73+
}
74+
})
75+
}
76+
}

0 commit comments

Comments
 (0)