Skip to content

Commit fd49c0f

Browse files
authored
feat: updated AI Explain API interface for new API definition [IDE-954] (#78)
1 parent ec4acf4 commit fd49c0f

File tree

6 files changed

+59
-47
lines changed

6 files changed

+59
-47
lines changed

go.mod

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
module github.com/snyk/code-client-go
22

3-
go 1.23.0
4-
5-
toolchain go1.23.6
3+
go 1.23.6
64

75
require (
86
github.com/go-git/go-git/v5 v5.14.0

llm/api_client.go

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,27 @@ import (
44
"bytes"
55
"context"
66
"encoding/json"
7+
"github.com/snyk/code-client-go/observability"
78
"io"
89
"net/http"
910
"net/url"
10-
11-
"github.com/snyk/code-client-go/observability"
1211
)
1312

14-
const (
13+
var (
1514
completeStatus = "COMPLETE"
1615
failedToObtainRequestIdString = "Failed to obtain request id. "
1716
defaultEndpointURL = "http://localhost:10000/explain"
1817
)
1918

20-
func (d *DeepCodeLLMBindingImpl) runExplain(ctx context.Context, options ExplainOptions) (explainResponse, error) {
19+
func (d *DeepCodeLLMBindingImpl) runExplain(ctx context.Context, options ExplainOptions) (Explanations, error) {
2120
span := d.instrumentor.StartSpan(ctx, "code.RunExplain")
2221
defer span.Finish()
2322

2423
requestId, err := observability.GetTraceId(ctx)
2524
logger := d.logger.With().Str("method", "code.RunExplain").Str("requestId", requestId).Logger()
2625
if err != nil {
2726
logger.Err(err).Msg(failedToObtainRequestIdString + err.Error())
28-
return explainResponse{}, err
27+
return Explanations{}, err
2928
}
3029

3130
logger.Debug().Msg("API: Retrieving explain for bundle")
@@ -34,7 +33,7 @@ func (d *DeepCodeLLMBindingImpl) runExplain(ctx context.Context, options Explain
3433
requestBody, err := d.explainRequestBody(&options)
3534
if err != nil {
3635
logger.Err(err).Str("requestBody", string(requestBody)).Msg("error creating request body")
37-
return explainResponse{}, err
36+
return Explanations{}, err
3837
}
3938
logger.Debug().Str("payload body: %s\n", string(requestBody)).Msg("Marshaled payload")
4039

@@ -43,21 +42,21 @@ func (d *DeepCodeLLMBindingImpl) runExplain(ctx context.Context, options Explain
4342
u, err = url.Parse(defaultEndpointURL)
4443
if err != nil {
4544
logger.Err(err).Send()
46-
return explainResponse{}, err
45+
return Explanations{}, err
4746
}
4847
}
4948
req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), bytes.NewBuffer(requestBody))
5049
if err != nil {
5150
logger.Err(err).Str("requestBody", string(requestBody)).Msg("error creating request")
52-
return explainResponse{}, err
51+
return Explanations{}, err
5352
}
5453

5554
d.addDefaultHeaders(req, requestId)
5655

5756
resp, err := d.httpClientFunc().Do(req) //nolint:bodyclose // this seems to be a false positive
5857
if err != nil {
5958
logger.Err(err).Str("requestBody", string(requestBody)).Msg("error getting response")
60-
return explainResponse{}, err
59+
return Explanations{}, err
6160
}
6261
defer func(Body io.ReadCloser) {
6362
bodyCloseErr := Body.Close()
@@ -70,25 +69,28 @@ func (d *DeepCodeLLMBindingImpl) runExplain(ctx context.Context, options Explain
7069
responseBody, err := io.ReadAll(resp.Body)
7170
if err != nil {
7271
logger.Err(err).Str("requestBody", string(requestBody)).Msg("error reading all response")
73-
return explainResponse{}, err
72+
return Explanations{}, err
7473
}
7574
logger.Debug().Str("response body: %s\n", string(responseBody)).Msg("Got the response")
76-
7775
var response explainResponse
76+
var explains Explanations
7877
response.Status = completeStatus
7978
err = json.Unmarshal(responseBody, &response)
8079
if err != nil {
8180
logger.Err(err).Str("responseBody", string(responseBody)).Msg("error unmarshalling")
82-
return explainResponse{}, err
81+
return Explanations{}, err
8382
}
84-
return response, nil
83+
84+
explains = response.Explanation
85+
86+
return explains, nil
8587
}
8688

8789
func (d *DeepCodeLLMBindingImpl) explainRequestBody(options *ExplainOptions) ([]byte, error) {
8890
logger := d.logger.With().Str("method", "code.explainRequestBody").Logger()
8991

9092
var request explainRequest
91-
if options.Diff == "" {
93+
if len(options.Diffs) == 0 {
9294
request.VulnExplanation = &explainVulnerabilityRequest{
9395
RuleId: options.RuleKey,
9496
Derivation: options.Derivation,
@@ -99,7 +101,7 @@ func (d *DeepCodeLLMBindingImpl) explainRequestBody(options *ExplainOptions) ([]
99101
} else {
100102
request.FixExplanation = &explainFixRequest{
101103
RuleId: options.RuleKey,
102-
Diff: options.Diff,
104+
Diffs: options.Diffs,
103105
ExplanationLength: SHORT,
104106
}
105107
logger.Debug().Msg("payload for FixExplanation")

llm/api_client_test.go

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func TestDeepcodeLLMBinding_runExplain(t *testing.T) {
2222
options ExplainOptions
2323
serverResponse string
2424
serverStatusCode int
25-
expectedResponse explainResponse
25+
expectedResponse Explanations
2626
expectedError string
2727
expectedLogMessage string
2828
}{
@@ -33,25 +33,19 @@ func TestDeepcodeLLMBinding_runExplain(t *testing.T) {
3333
Derivation: "Derivation",
3434
RuleMessage: "rule-message",
3535
},
36-
serverResponse: `{"explanation": "This is a vulnerability explanation"}`,
36+
serverResponse: "{\n \"explanation\": \n {\n \"explanation1\": \"This is the first explanation\",\n \"explanation2\": \"this is the second explanation\"\n }\n}",
3737
serverStatusCode: http.StatusOK,
38-
expectedResponse: explainResponse{
39-
Status: completeStatus,
40-
Explanation: "This is a vulnerability explanation",
41-
},
38+
expectedResponse: map[string]string{"explanation1": "This is the first explanation", "explanation2": "this is the second explanation"},
4239
},
4340
{
4441
name: "successful fix explanation",
4542
options: ExplainOptions{
4643
RuleKey: "rule-key",
47-
Diff: "Diff",
44+
Diffs: []string{"Diffs"},
4845
},
49-
serverResponse: `{"explanation": "This is a fix explanation"}`,
46+
serverResponse: "{\n \"explanation\": \n {\n \"explanation1\": \"This is the first explanation\",\n \"explanation2\": \"this is the second explanation\"\n }\n}",
5047
serverStatusCode: http.StatusOK,
51-
expectedResponse: explainResponse{
52-
Status: completeStatus,
53-
Explanation: "This is a fix explanation",
54-
},
48+
expectedResponse: map[string]string{"explanation1": "This is the first explanation", "explanation2": "this is the second explanation"},
5549
},
5650
{
5751
name: "error creating request body",
@@ -146,7 +140,7 @@ func TestDeepcodeLLMBinding_explainRequestBody(t *testing.T) {
146140
t.Run("FixExplanation", func(t *testing.T) {
147141
options := &ExplainOptions{
148142
RuleKey: "test-rule-key",
149-
Diff: "test-Diff",
143+
Diffs: []string{"test-Diffs"},
150144
}
151145
requestBody, err := d.explainRequestBody(options)
152146
require.NoError(t, err)
@@ -158,7 +152,7 @@ func TestDeepcodeLLMBinding_explainRequestBody(t *testing.T) {
158152
assert.Nil(t, request.VulnExplanation)
159153
assert.NotNil(t, request.FixExplanation)
160154
assert.Equal(t, "test-rule-key", request.FixExplanation.RuleId)
161-
assert.Equal(t, "test-Diff", request.FixExplanation.Diff)
155+
assert.Equal(t, []string{"test-Diffs"}, request.FixExplanation.Diffs)
162156
assert.Equal(t, SHORT, request.FixExplanation.ExplanationLength)
163157
})
164158
}

llm/binding.go

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,11 @@ type SnykLLMBindings interface {
4444
// output - a channel that can be used to stream the results
4545
Explain(ctx context.Context, input AIRequest, format OutputFormat, output chan<- string) error
4646
}
47+
type ExplainResult map[string]string
4748

4849
type DeepCodeLLMBinding interface {
4950
SnykLLMBindings
50-
ExplainWithOptions(ctx context.Context, options ExplainOptions) (string, error)
51+
ExplainWithOptions(ctx context.Context, options ExplainOptions) (ExplainResult, error)
5152
}
5253

5354
// DeepCodeLLMBindingImpl is an LLM binding for the Snyk Code LLM.
@@ -60,15 +61,23 @@ type DeepCodeLLMBindingImpl struct {
6061
endpoint *url.URL
6162
}
6263

63-
func (d *DeepCodeLLMBindingImpl) ExplainWithOptions(ctx context.Context, options ExplainOptions) (string, error) {
64+
func (d *DeepCodeLLMBindingImpl) ExplainWithOptions(ctx context.Context, options ExplainOptions) (ExplainResult, error) {
6465
s := d.instrumentor.StartSpan(ctx, "code.ExplainWithOptions")
6566
defer d.instrumentor.Finish(s)
6667
response, err := d.runExplain(s.Context(), options)
68+
explainResult := ExplainResult{}
6769
if err != nil {
68-
return "", err
70+
return explainResult, err
71+
}
72+
index := 0
73+
for _, explanation := range response {
74+
if index < len(options.Diffs) {
75+
explainResult[options.Diffs[index]] = explanation
76+
}
77+
index++
6978
}
7079

71-
return response.Explanation, nil
80+
return explainResult, nil
7281
}
7382

7483
func (d *DeepCodeLLMBindingImpl) PublishIssues(_ context.Context, _ []map[string]string) error {
@@ -85,7 +94,11 @@ func (d *DeepCodeLLMBindingImpl) Explain(ctx context.Context, input AIRequest, _
8594
if err != nil {
8695
return err
8796
}
88-
output <- response
97+
jsonBytes, err := json.Marshal(response)
98+
if err != nil {
99+
return err
100+
}
101+
output <- string(jsonBytes)
89102
return nil
90103
}
91104

llm/binding_test.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func TestExplainWithOptions(t *testing.T) {
3030

3131
explainResponseJSON := explainResponse{
3232
Status: completeStatus,
33-
Explanation: "mock explanation",
33+
Explanation: map[string]string{"explanation1": "This is the first explanation"},
3434
}
3535

3636
expectedResponseBody, err := json.Marshal(explainResponseJSON)
@@ -41,9 +41,14 @@ func TestExplainWithOptions(t *testing.T) {
4141
Body: io.NopCloser(strings.NewReader(string(expectedResponseBody))),
4242
}
4343
mockHTTPClient.EXPECT().Do(gomock.Any()).Return(&mockResponse, nil)
44-
explanation, err := d.ExplainWithOptions(context.Background(), ExplainOptions{})
44+
testDiff := "test diff"
45+
explanation, err := d.ExplainWithOptions(context.Background(), ExplainOptions{Diffs: []string{testDiff}})
4546
assert.NoError(t, err)
46-
assert.Equal(t, explainResponseJSON.Explanation, explanation)
47+
var exptectedExplanationsResponse explainResponse
48+
err = json.Unmarshal(expectedResponseBody, &exptectedExplanationsResponse)
49+
assert.NoError(t, err)
50+
expectedResExplanations := exptectedExplanationsResponse.Explanation
51+
assert.Equal(t, expectedResExplanations["explanation1"], explanation[testDiff])
4752
})
4853

4954
t.Run("runExplain error", func(t *testing.T) {

llm/types.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ const (
99
)
1010

1111
type explainVulnerabilityRequest struct {
12-
RuleId string `json:"rule_key"`
12+
RuleId string `json:"rule_id"`
1313
RuleMessage string `json:"rule_message"`
1414
Derivation string `json:"Derivation"`
1515
ExplanationLength explanationLength `json:"explanation_length"`
1616
}
1717

1818
type explainFixRequest struct {
19-
RuleId string `json:"rule_key"`
20-
Diff string `json:"diff"`
19+
RuleId string `json:"rule_id"`
20+
Diffs []string `json:"diffs"`
2121
ExplanationLength explanationLength `json:"explanation_length"`
2222
}
2323

@@ -27,10 +27,10 @@ type explainRequest struct {
2727
}
2828

2929
type explainResponse struct {
30-
Status string `json:"status"`
31-
Explanation string `json:"explanation"`
30+
Status string `json:"status"`
31+
Explanation Explanations `json:"explanation"`
3232
}
33-
33+
type Explanations map[string]string
3434
type ExplainOptions struct {
3535
// Derivation = Code Flow
3636
// const derivationLineNumbers: Set<number> = new Set<number>();
@@ -62,5 +62,5 @@ type ExplainOptions struct {
6262
RuleMessage string `json:"rule_message"`
6363

6464
// fix difference
65-
Diff string `json:"diff"`
65+
Diffs []string `json:"diffs"`
6666
}

0 commit comments

Comments
 (0)