Skip to content

Commit e723980

Browse files
authored
fix: interface signature for Explain API (#75)
1 parent 147917c commit e723980

File tree

6 files changed

+52
-42
lines changed

6 files changed

+52
-42
lines changed

llm/api_client.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ const (
1717
defaultEndpointURL = "http://localhost:10000/explain"
1818
)
1919

20-
func (d *DeepcodeLLMBinding) runExplain(ctx context.Context, options ExplainOptions) (explainResponse, error) {
20+
func (d *DeepCodeLLMBindingImpl) runExplain(ctx context.Context, options ExplainOptions) (explainResponse, error) {
2121
span := d.instrumentor.StartSpan(ctx, "code.RunExplain")
2222
defer span.Finish()
2323

@@ -46,7 +46,7 @@ func (d *DeepcodeLLMBinding) runExplain(ctx context.Context, options ExplainOpti
4646
return explainResponse{}, err
4747
}
4848
}
49-
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), bytes.NewBuffer(requestBody))
49+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), bytes.NewBuffer(requestBody))
5050
if err != nil {
5151
logger.Err(err).Str("requestBody", string(requestBody)).Msg("error creating request")
5252
return explainResponse{}, err
@@ -84,7 +84,7 @@ func (d *DeepcodeLLMBinding) runExplain(ctx context.Context, options ExplainOpti
8484
return response, nil
8585
}
8686

87-
func (d *DeepcodeLLMBinding) explainRequestBody(options *ExplainOptions) ([]byte, error) {
87+
func (d *DeepCodeLLMBindingImpl) explainRequestBody(options *ExplainOptions) ([]byte, error) {
8888
logger := d.logger.With().Str("method", "code.explainRequestBody").Logger()
8989

9090
var request explainRequest
@@ -108,7 +108,7 @@ func (d *DeepcodeLLMBinding) explainRequestBody(options *ExplainOptions) ([]byte
108108
return requestBody, err
109109
}
110110

111-
func (d *DeepcodeLLMBinding) addDefaultHeaders(req *http.Request, requestId string) {
111+
func (d *DeepCodeLLMBindingImpl) addDefaultHeaders(req *http.Request, requestId string) {
112112
req.Header.Set("snyk-request-id", requestId)
113113
req.Header.Set("Cache-Control", "private, max-age=0, no-cache")
114114
req.Header.Set("Content-Type", "application/json")

llm/api_client_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ func TestDeepcodeLLMBinding_runExplain(t *testing.T) {
118118
}
119119

120120
func TestDeepcodeLLMBinding_explainRequestBody(t *testing.T) {
121-
d := &DeepcodeLLMBinding{
121+
d := &DeepCodeLLMBindingImpl{
122122
logger: testLogger(t),
123123
}
124124

@@ -172,7 +172,7 @@ func testLogger(t *testing.T) *zerolog.Logger {
172172

173173
// Test with existing headers
174174
func TestAddDefaultHeadersWithExistingHeaders(t *testing.T) {
175-
d := &DeepcodeLLMBinding{} // Initialize your struct if needed
175+
d := &DeepCodeLLMBindingImpl{} // Initialize your struct if needed
176176
req := &http.Request{Header: http.Header{"Existing-Header": {"existing-value"}}}
177177
requestId := "test-request-id"
178178

llm/binding.go

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"net/url"
77

88
"github.com/rs/zerolog"
9+
910
"github.com/snyk/code-client-go/http"
1011

1112
"github.com/snyk/code-client-go/observability"
@@ -22,41 +23,45 @@ type AIRequest struct {
2223
Input string `json:"inputs"`
2324
}
2425

26+
var _ DeepCodeLLMBinding = (*DeepCodeLLMBindingImpl)(nil)
27+
var _ SnykLLMBindings = (*DeepCodeLLMBindingImpl)(nil)
28+
2529
type SnykLLMBindings interface {
2630
// PublishIssues sends issues to an LLM for further processing.
2731
// the map in the slice of issues map is a json representation of json key : value
2832
// In case of errors, they are returned
29-
PublishIssues(issues []map[string]string) error
33+
PublishIssues(ctx context.Context, issues []map[string]string) error
3034

3135
// Explain forwards an input and desired output format to an LLM to
3236
// receive an explanation. The implementation should alter the LLM
3337
// prompt to honor the output format, but is not required to enforce
3438
// the format. The results should be streamed into the given channel
3539
//
3640
// Parameters:
41+
// ctx - request context
3742
// input - the thing to be explained as a string
3843
// format - the requested outputFormat
3944
// output - a channel that can be used to stream the results
40-
Explain(input AIRequest, format OutputFormat, output chan<- string) error
45+
Explain(ctx context.Context, input AIRequest, format OutputFormat, output chan<- string) error
4146
}
4247

4348
type DeepCodeLLMBinding interface {
4449
SnykLLMBindings
45-
ExplainWithOptions(options ExplainOptions) (string, error)
50+
ExplainWithOptions(ctx context.Context, options ExplainOptions) (string, error)
4651
}
4752

48-
// DeepcodeLLMBinding is an LLM binding for the Snyk Code LLM.
53+
// DeepCodeLLMBindingImpl is an LLM binding for the Snyk Code LLM.
4954
// Currently, it only supports explain.
50-
type DeepcodeLLMBinding struct {
55+
type DeepCodeLLMBindingImpl struct {
5156
httpClientFunc func() http.HTTPClient
5257
logger *zerolog.Logger
5358
outputFormat OutputFormat
5459
instrumentor observability.Instrumentor
5560
endpoint *url.URL
5661
}
5762

58-
func (d *DeepcodeLLMBinding) ExplainWithOptions(options ExplainOptions) (string, error) {
59-
s := d.instrumentor.StartSpan(context.Background(), "code.ExplainWithOptions")
63+
func (d *DeepCodeLLMBindingImpl) ExplainWithOptions(ctx context.Context, options ExplainOptions) (string, error) {
64+
s := d.instrumentor.StartSpan(ctx, "code.ExplainWithOptions")
6065
defer d.instrumentor.Finish(s)
6166
response, err := d.runExplain(s.Context(), options)
6267
if err != nil {
@@ -66,33 +71,33 @@ func (d *DeepcodeLLMBinding) ExplainWithOptions(options ExplainOptions) (string,
6671
return response.Explanation, nil
6772
}
6873

69-
func (d *DeepcodeLLMBinding) PublishIssues(issues []map[string]string) error {
74+
func (d *DeepCodeLLMBindingImpl) PublishIssues(_ context.Context, _ []map[string]string) error {
7075
panic("implement me")
7176
}
7277

73-
func (d *DeepcodeLLMBinding) Explain(input string, format OutputFormat, output chan<- string) error {
78+
func (d *DeepCodeLLMBindingImpl) Explain(ctx context.Context, input AIRequest, _ OutputFormat, output chan<- string) error {
7479
var options ExplainOptions
75-
err := json.Unmarshal([]byte(input), &options)
80+
err := json.Unmarshal([]byte(input.Input), &options)
7681
if err != nil {
7782
return err
7883
}
79-
response, err := d.ExplainWithOptions(options)
84+
response, err := d.ExplainWithOptions(ctx, options)
8085
if err != nil {
8186
return err
8287
}
8388
output <- response
8489
return nil
8590
}
8691

87-
func NewDeepcodeLLMBinding(opts ...Option) *DeepcodeLLMBinding {
92+
func NewDeepcodeLLMBinding(opts ...Option) *DeepCodeLLMBindingImpl {
8893
endpoint, err := url.Parse(defaultEndpointURL)
8994
if err != nil {
9095
// time to panic, as our default should never be invalid
9196
panic(err)
9297
}
9398

9499
nopLogger := zerolog.Nop()
95-
binding := &DeepcodeLLMBinding{
100+
binding := &DeepCodeLLMBindingImpl{
96101
logger: &nopLogger,
97102
httpClientFunc: func() http.HTTPClient {
98103
return http.NewHTTPClient(

llm/binding_smoke_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package llm
22

33
import (
4+
"context"
45
"testing"
56

7+
"github.com/google/uuid"
68
"github.com/rs/zerolog"
7-
"github.com/snyk/code-client-go/http"
89
"github.com/stretchr/testify/assert"
10+
11+
"github.com/snyk/code-client-go/http"
912
)
1013

1114
func TestDeepcodeLLMBinding_Explain_Smoke(t *testing.T) {
@@ -17,6 +20,6 @@ func TestDeepcodeLLMBinding_Explain_Smoke(t *testing.T) {
1720
WithLogger(&logger),
1821
)
1922
outputChain := make(chan string)
20-
err := binding.Explain("{}", HTML, outputChain)
23+
err := binding.Explain(context.Background(), AIRequest{Id: uuid.New().String(), Input: "{}"}, HTML, outputChain)
2124
assert.NoError(t, err)
2225
}

llm/binding_test.go

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package llm
22

33
import (
4+
"context"
45
"encoding/json"
56
"io"
67
http2 "net/http"
@@ -10,16 +11,17 @@ import (
1011

1112
"github.com/golang/mock/gomock"
1213
"github.com/rs/zerolog"
14+
"github.com/stretchr/testify/assert"
15+
1316
"github.com/snyk/code-client-go/http"
1417
"github.com/snyk/code-client-go/http/mocks"
15-
"github.com/stretchr/testify/assert"
1618

1719
"github.com/snyk/code-client-go/observability"
1820
)
1921

2022
func TestDeepcodeLLMBinding_PublishIssues(t *testing.T) {
2123
binding := NewDeepcodeLLMBinding()
22-
assert.PanicsWithValue(t, "implement me", func() { _ = binding.PublishIssues([]map[string]string{}) })
24+
assert.PanicsWithValue(t, "implement me", func() { _ = binding.PublishIssues(context.Background(), []map[string]string{}) })
2325
}
2426

2527
func TestExplainWithOptions(t *testing.T) {
@@ -39,7 +41,7 @@ func TestExplainWithOptions(t *testing.T) {
3941
Body: io.NopCloser(strings.NewReader(string(expectedResponseBody))),
4042
}
4143
mockHTTPClient.EXPECT().Do(gomock.Any()).Return(&mockResponse, nil)
42-
explanation, err := d.ExplainWithOptions(ExplainOptions{})
44+
explanation, err := d.ExplainWithOptions(context.Background(), ExplainOptions{})
4345
assert.NoError(t, err)
4446
assert.Equal(t, explainResponseJSON.Explanation, explanation)
4547
})
@@ -49,7 +51,7 @@ func TestExplainWithOptions(t *testing.T) {
4951
})
5052
}
5153

52-
func getHTTPMockedBinding(t *testing.T, endpoint *url.URL) (*DeepcodeLLMBinding, *mocks.MockHTTPClient) {
54+
func getHTTPMockedBinding(t *testing.T, endpoint *url.URL) (*DeepCodeLLMBindingImpl, *mocks.MockHTTPClient) {
5355
t.Helper()
5456
ctrl := gomock.NewController(t)
5557
mockHTTPClient := mocks.NewMockHTTPClient(ctrl)
@@ -84,14 +86,14 @@ func TestNewDeepcodeLLMBinding_Defaults(t *testing.T) {
8486

8587
func TestWithHTTPClient(t *testing.T) {
8688
client := http.NewHTTPClient(http.NewDefaultClientFactory())
87-
binding := &DeepcodeLLMBinding{}
89+
binding := &DeepCodeLLMBindingImpl{}
8890
WithHTTPClient(func() http.HTTPClient { return client })(binding)
8991
assert.Equal(t, client, binding.httpClientFunc())
9092
}
9193

9294
func TestWithLogger(t *testing.T) {
9395
logger := zerolog.Nop()
94-
binding := &DeepcodeLLMBinding{}
96+
binding := &DeepCodeLLMBindingImpl{}
9597
WithLogger(&logger)(binding)
9698
assert.Equal(t, &logger, binding.logger)
9799
}
@@ -104,7 +106,7 @@ func TestOutputFormatConstants(t *testing.T) {
104106
}
105107

106108
func TestWithOutputFormat(t *testing.T) {
107-
binding := &DeepcodeLLMBinding{}
109+
binding := &DeepCodeLLMBindingImpl{}
108110

109111
// Test setting valid output formats
110112
WithOutputFormat(JSON)(binding)
@@ -151,7 +153,7 @@ func TestWithEndpoint(t *testing.T) {
151153
t.Fatalf("Failed to parse URL: %v", err)
152154
}
153155

154-
binding := &DeepcodeLLMBinding{}
156+
binding := &DeepCodeLLMBindingImpl{}
155157
WithEndpoint(parsedURL)(binding)
156158

157159
if binding.endpoint.Scheme != tc.expected.Scheme {
@@ -172,15 +174,15 @@ func TestWithEndpoint(t *testing.T) {
172174

173175
func TestWithInstrumentor(t *testing.T) {
174176
// Test case 1: Provide a mock instrumentor
175-
binding := &DeepcodeLLMBinding{}
177+
binding := &DeepCodeLLMBindingImpl{}
176178

177179
instrumentor := observability.NewInstrumentor()
178180
WithInstrumentor(instrumentor)(binding)
179181

180182
assert.Equal(t, instrumentor, binding.instrumentor)
181183

182184
// Test case 2: Provide a nil instrumentor (should still set it)
183-
binding = &DeepcodeLLMBinding{} // Reset binding for the next test
185+
binding = &DeepCodeLLMBindingImpl{} // Reset binding for the next test
184186

185187
WithInstrumentor(nil)(binding)
186188

llm/options.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,37 +8,37 @@ import (
88
"github.com/snyk/code-client-go/observability"
99
)
1010

11-
type Option func(*DeepcodeLLMBinding)
11+
type Option func(*DeepCodeLLMBindingImpl)
1212

13-
func WithHTTPClient(httpClientFunc func() http.HTTPClient) func(*DeepcodeLLMBinding) {
14-
return func(binding *DeepcodeLLMBinding) {
13+
func WithHTTPClient(httpClientFunc func() http.HTTPClient) func(*DeepCodeLLMBindingImpl) {
14+
return func(binding *DeepCodeLLMBindingImpl) {
1515
binding.httpClientFunc = httpClientFunc
1616
}
1717
}
1818

19-
func WithEndpoint(endpoint *url.URL) func(*DeepcodeLLMBinding) {
20-
return func(binding *DeepcodeLLMBinding) {
19+
func WithEndpoint(endpoint *url.URL) func(*DeepCodeLLMBindingImpl) {
20+
return func(binding *DeepCodeLLMBindingImpl) {
2121
binding.endpoint = endpoint
2222
}
2323
}
2424

25-
func WithLogger(logger *zerolog.Logger) func(*DeepcodeLLMBinding) {
26-
return func(binding *DeepcodeLLMBinding) {
25+
func WithLogger(logger *zerolog.Logger) func(*DeepCodeLLMBindingImpl) {
26+
return func(binding *DeepCodeLLMBindingImpl) {
2727
binding.logger = logger
2828
}
2929
}
3030

31-
func WithOutputFormat(outputFormat OutputFormat) func(*DeepcodeLLMBinding) {
32-
return func(binding *DeepcodeLLMBinding) {
31+
func WithOutputFormat(outputFormat OutputFormat) func(*DeepCodeLLMBindingImpl) {
32+
return func(binding *DeepCodeLLMBindingImpl) {
3333
if outputFormat != HTML && outputFormat != JSON && outputFormat != MarkDown {
3434
return
3535
}
3636
binding.outputFormat = outputFormat
3737
}
3838
}
3939

40-
func WithInstrumentor(instrumentor observability.Instrumentor) func(*DeepcodeLLMBinding) {
41-
return func(binding *DeepcodeLLMBinding) {
40+
func WithInstrumentor(instrumentor observability.Instrumentor) func(*DeepCodeLLMBindingImpl) {
41+
return func(binding *DeepCodeLLMBindingImpl) {
4242
binding.instrumentor = instrumentor
4343
}
4444
}

0 commit comments

Comments
 (0)