Skip to content

Commit 8187e9c

Browse files
fix: explain result order (#98)
2 parents 732091c + 610e718 commit 8187e9c

File tree

4 files changed

+94
-15
lines changed

4 files changed

+94
-15
lines changed

llm/api_client.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"io"
99
"net/http"
1010
"net/url"
11+
"strings"
1112
)
1213

1314
var (
@@ -96,17 +97,28 @@ func (d *DeepCodeLLMBindingImpl) explainRequestBody(options *ExplainOptions) ([]
9697
} else {
9798
requestBody, marshalErr = json.Marshal(explainFixRequest{
9899
RuleId: options.RuleKey,
99-
Diffs: encodeDiffs(options.Diffs),
100+
Diffs: prepareDiffs(options.Diffs),
100101
ExplanationLength: SHORT,
101102
})
102103
logger.Debug().Msg("payload for FixExplanation")
103104
}
104105
return requestBody, marshalErr
105106
}
106107

107-
func encodeDiffs(diffs []string) []string {
108-
var encodedDiffs []string
108+
func prepareDiffs(diffs []string) []string {
109+
cleanedDiffs := make([]string, 0, len(diffs))
109110
for _, diff := range diffs {
111+
diffLines := strings.Split(diff, "\n")
112+
cleanedLines := ""
113+
for _, line := range diffLines {
114+
if !strings.HasPrefix(line, "---") && !strings.HasPrefix(line, "+++") {
115+
cleanedLines += line + "\n"
116+
}
117+
}
118+
cleanedDiffs = append(cleanedDiffs, cleanedLines)
119+
}
120+
var encodedDiffs []string
121+
for _, diff := range cleanedDiffs {
110122
encodedDiffs = append(encodedDiffs, base64.StdEncoding.EncodeToString([]byte(diff)))
111123
}
112124
return encodedDiffs

llm/api_client_test.go

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package llm
22

33
import (
44
"context"
5+
"encoding/base64"
56
"encoding/json"
67
"io"
78
"net/http"
@@ -151,7 +152,7 @@ func TestDeepcodeLLMBinding_explainRequestBody(t *testing.T) {
151152

152153
assert.NotNil(t, request)
153154
assert.Equal(t, "test-rule-key", request.RuleId)
154-
expectedEncodedDiffs := encodeDiffs([]string{"test-Diffs"})
155+
expectedEncodedDiffs := prepareDiffs([]string{"test-Diffs"})
155156
assert.Equal(t, expectedEncodedDiffs, request.Diffs)
156157
assert.Equal(t, SHORT, request.ExplanationLength)
157158
})
@@ -206,6 +207,52 @@ func TestEndpoint(t *testing.T) {
206207
}
207208
}
208209

210+
func TestPrepareDiffs(t *testing.T) {
211+
testCases := []struct {
212+
name string
213+
input []string
214+
expected []string
215+
}{
216+
{
217+
name: "Single diff with headers and content",
218+
input: []string{
219+
"--- a/file.txt\n" +
220+
"+++ b/file.txt\n" +
221+
"@@ -1,1 +1,1 @@\n" +
222+
"-old line\n" +
223+
"+new line\n",
224+
},
225+
expected: []string{
226+
base64.StdEncoding.EncodeToString([]byte("@@ -1,1 +1,1 @@\n-old line\n+new line\n\n")),
227+
},
228+
},
229+
{
230+
name: "Multiple diffs",
231+
input: []string{
232+
"--- a/file1.txt\n" +
233+
"+++ b/file1.txt\n" +
234+
"-line 1\n" +
235+
"+line 2\n",
236+
"--- a/file2.txt\n" +
237+
"+++ b/file2.txt\n" +
238+
"content2a\n" +
239+
"+content2b\n",
240+
},
241+
expected: []string{
242+
base64.StdEncoding.EncodeToString([]byte("-line 1\n+line 2\n\n")),
243+
base64.StdEncoding.EncodeToString([]byte("content2a\n+content2b\n\n")),
244+
},
245+
},
246+
}
247+
248+
for _, tc := range testCases {
249+
t.Run(tc.name, func(t *testing.T) {
250+
actual := prepareDiffs(tc.input)
251+
assert.Equal(t, tc.expected, actual)
252+
})
253+
}
254+
}
255+
209256
// Helper function for testing
210257
func testLogger(t *testing.T) *zerolog.Logger {
211258
t.Helper()

llm/binding.go

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"net/url"
7+
"slices"
78

89
"github.com/rs/zerolog"
910

@@ -45,7 +46,7 @@ type SnykLLMBindings interface {
4546
// output - a channel that can be used to stream the results
4647
Explain(ctx context.Context, input AIRequest, format OutputFormat, output chan<- string) error
4748
}
48-
type ExplainResult map[string]string
49+
type ExplainResult []string
4950

5051
type DeepCodeLLMBinding interface {
5152
SnykLLMBindings
@@ -69,15 +70,24 @@ func (d *DeepCodeLLMBindingImpl) ExplainWithOptions(ctx context.Context, options
6970
if err != nil {
7071
return explainResult, err
7172
}
72-
index := 0
73-
for _, explanation := range response {
74-
if index < len(options.Diffs) {
75-
explainResult[options.Diffs[index]] = explanation
76-
}
77-
index++
73+
74+
orderedExplainResults := getOrderedResponse(response)
75+
76+
return orderedExplainResults, nil
77+
}
78+
79+
func getOrderedResponse(explainResponse Explanations) []string {
80+
explainMapKeys := make([]string, 0, len(explainResponse))
81+
for k := range explainResponse {
82+
explainMapKeys = append(explainMapKeys, k)
7883
}
84+
slices.Sort(explainMapKeys)
7985

80-
return explainResult, nil
86+
orderedValues := make([]string, 0, len(explainResponse))
87+
for _, key := range explainMapKeys {
88+
orderedValues = append(orderedValues, explainResponse[key])
89+
}
90+
return orderedValues
8191
}
8292

8393
func (d *DeepCodeLLMBindingImpl) PublishIssues(_ context.Context, _ []map[string]string) error {

llm/binding_test.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,14 @@ func TestExplainWithOptions(t *testing.T) {
2929
d, mockHTTPClient := getHTTPMockedBinding(t)
3030

3131
explainResponseJSON := explainResponse{
32-
Status: completeStatus,
33-
Explanation: map[string]string{"explanation1": "This is the first explanation"},
32+
Status: completeStatus,
33+
Explanation: map[string]string{
34+
"explanation1": "This is the first explanation",
35+
"explanation2": "this is the second explanation",
36+
"explanation3": "this is the third explanation",
37+
"explanation4": "this is the fourth explanation",
38+
"explanation5": "this is the fifth explanation",
39+
},
3440
}
3541

3642
expectedResponseBody, err := json.Marshal(explainResponseJSON)
@@ -49,7 +55,11 @@ func TestExplainWithOptions(t *testing.T) {
4955
err = json.Unmarshal(expectedResponseBody, &exptectedExplanationsResponse)
5056
assert.NoError(t, err)
5157
expectedResExplanations := exptectedExplanationsResponse.Explanation
52-
assert.Equal(t, expectedResExplanations["explanation1"], explanation[testDiff])
58+
assert.Equal(t, expectedResExplanations["explanation1"], explanation[0])
59+
assert.Equal(t, expectedResExplanations["explanation2"], explanation[1])
60+
assert.Equal(t, expectedResExplanations["explanation3"], explanation[2])
61+
assert.Equal(t, expectedResExplanations["explanation4"], explanation[3])
62+
assert.Equal(t, expectedResExplanations["explanation5"], explanation[4])
5363
})
5464

5565
t.Run("runExplain error", func(t *testing.T) {

0 commit comments

Comments
 (0)