Skip to content

Commit 147917c

Browse files
feat: new deepcode LLM binding [IDE-877] (#74)
Co-authored-by: Abdelrahman Shawki Hassan <[email protected]>
1 parent d9cd412 commit 147917c

File tree

9 files changed

+763
-0
lines changed

9 files changed

+763
-0
lines changed

http/http.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,8 @@ func (s *httpClient) httpCall(req *http.Request) (*http.Response, error) {
153153

154154
return response, nil
155155
}
156+
157+
func NewDefaultClientFactory() HTTPClientFactory {
158+
clientFunc := func() *http.Client { return http.DefaultClient }
159+
return clientFunc
160+
}

llm/api_client.go

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package llm
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"io"
8+
"net/http"
9+
"net/url"
10+
11+
"github.com/snyk/code-client-go/observability"
12+
)
13+
14+
const (
15+
completeStatus = "COMPLETE"
16+
failedToObtainRequestIdString = "Failed to obtain request id. "
17+
defaultEndpointURL = "http://localhost:10000/explain"
18+
)
19+
20+
func (d *DeepcodeLLMBinding) runExplain(ctx context.Context, options ExplainOptions) (explainResponse, error) {
21+
span := d.instrumentor.StartSpan(ctx, "code.RunExplain")
22+
defer span.Finish()
23+
24+
requestId, err := observability.GetTraceId(ctx)
25+
logger := d.logger.With().Str("method", "code.RunExplain").Str("requestId", requestId).Logger()
26+
if err != nil {
27+
logger.Err(err).Msg(failedToObtainRequestIdString + err.Error())
28+
return explainResponse{}, err
29+
}
30+
31+
logger.Debug().Msg("API: Retrieving explain for bundle")
32+
defer logger.Debug().Msg("API: Retrieving explain done")
33+
34+
requestBody, err := d.explainRequestBody(&options)
35+
if err != nil {
36+
logger.Err(err).Str("requestBody", string(requestBody)).Msg("error creating request body")
37+
return explainResponse{}, err
38+
}
39+
logger.Debug().Str("payload body: %s\n", string(requestBody)).Msg("Marshaled payload")
40+
41+
u := d.endpoint
42+
if u == nil {
43+
u, err = url.Parse(defaultEndpointURL)
44+
if err != nil {
45+
logger.Err(err).Send()
46+
return explainResponse{}, err
47+
}
48+
}
49+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), bytes.NewBuffer(requestBody))
50+
if err != nil {
51+
logger.Err(err).Str("requestBody", string(requestBody)).Msg("error creating request")
52+
return explainResponse{}, err
53+
}
54+
55+
d.addDefaultHeaders(req, requestId)
56+
57+
resp, err := d.httpClientFunc().Do(req) //nolint:bodyclose // this seems to be a false positive
58+
if err != nil {
59+
logger.Err(err).Str("requestBody", string(requestBody)).Msg("error getting response")
60+
return explainResponse{}, err
61+
}
62+
defer func(Body io.ReadCloser) {
63+
bodyCloseErr := Body.Close()
64+
if bodyCloseErr != nil {
65+
logger.Err(err).Str("requestBody", string(requestBody)).Msg("error closing response")
66+
}
67+
}(resp.Body)
68+
69+
// Read the response body
70+
responseBody, err := io.ReadAll(resp.Body)
71+
if err != nil {
72+
logger.Err(err).Str("requestBody", string(requestBody)).Msg("error reading all response")
73+
return explainResponse{}, err
74+
}
75+
logger.Debug().Str("response body: %s\n", string(responseBody)).Msg("Got the response")
76+
77+
var response explainResponse
78+
response.Status = completeStatus
79+
err = json.Unmarshal(responseBody, &response)
80+
if err != nil {
81+
logger.Err(err).Str("responseBody", string(responseBody)).Msg("error unmarshalling")
82+
return explainResponse{}, err
83+
}
84+
return response, nil
85+
}
86+
87+
func (d *DeepcodeLLMBinding) explainRequestBody(options *ExplainOptions) ([]byte, error) {
88+
logger := d.logger.With().Str("method", "code.explainRequestBody").Logger()
89+
90+
var request explainRequest
91+
if options.Diff == "" {
92+
request.VulnExplanation = &explainVulnerabilityRequest{
93+
RuleId: options.RuleKey,
94+
Derivation: options.Derivation,
95+
RuleMessage: options.RuleMessage,
96+
ExplanationLength: SHORT,
97+
}
98+
logger.Debug().Msg("payload for VulnExplanation")
99+
} else {
100+
request.FixExplanation = &explainFixRequest{
101+
RuleId: options.RuleKey,
102+
Diff: options.Diff,
103+
ExplanationLength: SHORT,
104+
}
105+
logger.Debug().Msg("payload for FixExplanation")
106+
}
107+
requestBody, err := json.Marshal(request)
108+
return requestBody, err
109+
}
110+
111+
func (d *DeepcodeLLMBinding) addDefaultHeaders(req *http.Request, requestId string) {
112+
req.Header.Set("snyk-request-id", requestId)
113+
req.Header.Set("Cache-Control", "private, max-age=0, no-cache")
114+
req.Header.Set("Content-Type", "application/json")
115+
}

llm/api_client_test.go

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
package llm
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"io"
7+
"net/http"
8+
"net/http/httptest"
9+
"net/url"
10+
"testing"
11+
12+
"github.com/rs/zerolog"
13+
"github.com/stretchr/testify/assert"
14+
"github.com/stretchr/testify/require"
15+
16+
"github.com/snyk/code-client-go/observability"
17+
)
18+
19+
func TestDeepcodeLLMBinding_runExplain(t *testing.T) {
20+
tests := []struct {
21+
name string
22+
options ExplainOptions
23+
serverResponse string
24+
serverStatusCode int
25+
expectedResponse explainResponse
26+
expectedError string
27+
expectedLogMessage string
28+
}{
29+
{
30+
name: "successful vuln explanation",
31+
options: ExplainOptions{
32+
RuleKey: "rule-key",
33+
Derivation: "Derivation",
34+
RuleMessage: "rule-message",
35+
},
36+
serverResponse: `{"explanation": "This is a vulnerability explanation"}`,
37+
serverStatusCode: http.StatusOK,
38+
expectedResponse: explainResponse{
39+
Status: completeStatus,
40+
Explanation: "This is a vulnerability explanation",
41+
},
42+
},
43+
{
44+
name: "successful fix explanation",
45+
options: ExplainOptions{
46+
RuleKey: "rule-key",
47+
Diff: "Diff",
48+
},
49+
serverResponse: `{"explanation": "This is a fix explanation"}`,
50+
serverStatusCode: http.StatusOK,
51+
expectedResponse: explainResponse{
52+
Status: completeStatus,
53+
Explanation: "This is a fix explanation",
54+
},
55+
},
56+
{
57+
name: "error creating request body",
58+
options: ExplainOptions{}, // Missing required fields will cause an error
59+
serverStatusCode: http.StatusUnprocessableEntity,
60+
expectedError: "unexpected end of JSON input",
61+
expectedLogMessage: "error creating request body",
62+
},
63+
{
64+
name: "error getting response",
65+
options: ExplainOptions{
66+
RuleKey: "rule-key",
67+
Derivation: "Derivation",
68+
RuleMessage: "rule-message",
69+
},
70+
serverStatusCode: http.StatusInternalServerError,
71+
expectedError: "unexpected end of JSON input",
72+
expectedLogMessage: "error getting response",
73+
},
74+
{
75+
name: "error unmarshalling response",
76+
options: ExplainOptions{
77+
RuleKey: "rule-key",
78+
Derivation: "Derivation",
79+
RuleMessage: "rule-message",
80+
},
81+
serverResponse: `invalid json`,
82+
serverStatusCode: http.StatusOK,
83+
expectedError: "invalid character 'i' looking for beginning of value",
84+
expectedLogMessage: "error unmarshalling",
85+
},
86+
}
87+
88+
for _, tt := range tests {
89+
t.Run(tt.name, func(t *testing.T) {
90+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
91+
w.WriteHeader(tt.serverStatusCode)
92+
_, _ = w.Write([]byte(tt.serverResponse))
93+
if tt.expectedError == "unexpected EOF" {
94+
_ = r.Body.Close() // Close the request body early to simulate a read error
95+
}
96+
}))
97+
defer server.Close()
98+
99+
u, err := url.Parse(server.URL)
100+
assert.NoError(t, err)
101+
102+
d := NewDeepcodeLLMBinding(WithEndpoint(u))
103+
104+
ctx := context.Background()
105+
ctx = observability.GetContextWithTraceId(ctx, "test-trace-id")
106+
107+
response, err := d.runExplain(ctx, tt.options)
108+
109+
if tt.expectedError != "" {
110+
require.Error(t, err)
111+
assert.Contains(t, err.Error(), tt.expectedError)
112+
} else {
113+
require.NoError(t, err)
114+
assert.Equal(t, tt.expectedResponse, response)
115+
}
116+
})
117+
}
118+
}
119+
120+
func TestDeepcodeLLMBinding_explainRequestBody(t *testing.T) {
121+
d := &DeepcodeLLMBinding{
122+
logger: testLogger(t),
123+
}
124+
125+
t.Run("VulnExplanation", func(t *testing.T) {
126+
options := &ExplainOptions{
127+
RuleKey: "test-rule-key",
128+
Derivation: "test-Derivation",
129+
RuleMessage: "test-rule-message",
130+
}
131+
requestBody, err := d.explainRequestBody(options)
132+
require.NoError(t, err)
133+
134+
var request explainRequest
135+
err = json.Unmarshal(requestBody, &request)
136+
require.NoError(t, err)
137+
138+
assert.Nil(t, request.FixExplanation)
139+
assert.NotNil(t, request.VulnExplanation)
140+
assert.Equal(t, "test-rule-key", request.VulnExplanation.RuleId)
141+
assert.Equal(t, "test-Derivation", request.VulnExplanation.Derivation)
142+
assert.Equal(t, "test-rule-message", request.VulnExplanation.RuleMessage)
143+
assert.Equal(t, SHORT, request.VulnExplanation.ExplanationLength)
144+
})
145+
146+
t.Run("FixExplanation", func(t *testing.T) {
147+
options := &ExplainOptions{
148+
RuleKey: "test-rule-key",
149+
Diff: "test-Diff",
150+
}
151+
requestBody, err := d.explainRequestBody(options)
152+
require.NoError(t, err)
153+
154+
var request explainRequest
155+
err = json.Unmarshal(requestBody, &request)
156+
require.NoError(t, err)
157+
158+
assert.Nil(t, request.VulnExplanation)
159+
assert.NotNil(t, request.FixExplanation)
160+
assert.Equal(t, "test-rule-key", request.FixExplanation.RuleId)
161+
assert.Equal(t, "test-Diff", request.FixExplanation.Diff)
162+
assert.Equal(t, SHORT, request.FixExplanation.ExplanationLength)
163+
})
164+
}
165+
166+
// Helper function for testing
167+
func testLogger(t *testing.T) *zerolog.Logger {
168+
t.Helper()
169+
logger := zerolog.New(io.Discard)
170+
return &logger
171+
}
172+
173+
// Test with existing headers
174+
func TestAddDefaultHeadersWithExistingHeaders(t *testing.T) {
175+
d := &DeepcodeLLMBinding{} // Initialize your struct if needed
176+
req := &http.Request{Header: http.Header{"Existing-Header": {"existing-value"}}}
177+
requestId := "test-request-id"
178+
179+
d.addDefaultHeaders(req, requestId)
180+
181+
snykRequestId := req.Header.Get("snyk-request-id")
182+
cacheControl := req.Header.Get("Cache-Control")
183+
contentType := req.Header.Get("Content-Type")
184+
existingHeader := req.Header.Get("Existing-Header")
185+
186+
if snykRequestId != requestId {
187+
t.Errorf("Expected snyk-request-id header to be %s, got %s", requestId, snykRequestId)
188+
}
189+
190+
if cacheControl != "private, max-age=0, no-cache" {
191+
t.Errorf("Expected Cache-Control header to be 'private, max-age=0, no-cache', got %s", cacheControl)
192+
}
193+
194+
if contentType != "application/json" {
195+
t.Errorf("Expected Content-Type header to be 'application/json', got %s", contentType)
196+
}
197+
198+
if existingHeader != "existing-value" {
199+
t.Errorf("Expected Existing-Header to be 'existing-value', got %s", existingHeader)
200+
}
201+
}

0 commit comments

Comments
 (0)