Skip to content

Commit eff2761

Browse files
feat!: explain endpoint is passed in the params and not upon client instantiation (#88)
2 parents a2a8f1a + 4b80523 commit eff2761

File tree

7 files changed

+69
-75
lines changed

7 files changed

+69
-75
lines changed

llm/api_client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func (d *DeepCodeLLMBindingImpl) runExplain(ctx context.Context, options Explain
3131
}
3232
logger.Debug().Str("payload body: %s\n", string(requestBody)).Msg("Marshaled payload")
3333

34-
u := d.endpoint
34+
u := options.Endpoint
3535
if u == nil {
3636
u, err = url.Parse(defaultEndpointURL)
3737
if err != nil {

llm/api_client_test.go

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,9 @@ func TestDeepcodeLLMBinding_runExplain(t *testing.T) {
9292

9393
u, err := url.Parse(server.URL)
9494
assert.NoError(t, err)
95+
tt.options.Endpoint = u
9596

96-
d := NewDeepcodeLLMBinding(WithEndpoint(u))
97+
d := NewDeepcodeLLMBinding()
9798

9899
ctx := context.Background()
99100
ctx = observability.GetContextWithTraceId(ctx, "test-trace-id")
@@ -156,6 +157,55 @@ func TestDeepcodeLLMBinding_explainRequestBody(t *testing.T) {
156157
})
157158
}
158159

160+
func TestEndpoint(t *testing.T) {
161+
testCases := []struct {
162+
name string
163+
inputURL string
164+
expected url.URL
165+
}{
166+
{
167+
name: "Valid URL",
168+
inputURL: "http://localhost:8080",
169+
expected: url.URL{Scheme: "http", Host: "localhost:8080"},
170+
},
171+
{
172+
name: "URL with Path",
173+
inputURL: "https://example.com/path/to/resource",
174+
expected: url.URL{Scheme: "https", Host: "example.com", Path: "/path/to/resource"},
175+
},
176+
{
177+
name: "URL with Query Params",
178+
inputURL: "http://api.example.com?param1=value1&param2=value2",
179+
expected: url.URL{Scheme: "http", Host: "api.example.com", RawQuery: "param1=value1&param2=value2"},
180+
},
181+
}
182+
183+
for _, tc := range testCases {
184+
t.Run(tc.name, func(t *testing.T) {
185+
parsedURL, err := url.Parse(tc.inputURL)
186+
if err != nil {
187+
t.Fatalf("Failed to parse URL: %v", err)
188+
}
189+
190+
options := &ExplainOptions{}
191+
options.Endpoint = parsedURL
192+
193+
if options.Endpoint.Scheme != tc.expected.Scheme {
194+
t.Errorf("Expected Scheme: %s, Got: %s", tc.expected.Scheme, options.Endpoint.Scheme)
195+
}
196+
if options.Endpoint.Host != tc.expected.Host {
197+
t.Errorf("Expected Host: %s, Got: %s", tc.expected.Host, options.Endpoint.Host)
198+
}
199+
if options.Endpoint.Path != tc.expected.Path {
200+
t.Errorf("Expected Path: %s, Got: %s", tc.expected.Path, options.Endpoint.Path)
201+
}
202+
if options.Endpoint.RawQuery != tc.expected.RawQuery {
203+
t.Errorf("Expected RawQuery: %s, Got: %s", tc.expected.RawQuery, options.Endpoint.RawQuery)
204+
}
205+
})
206+
}
207+
}
208+
159209
// Helper function for testing
160210
func testLogger(t *testing.T) *zerolog.Logger {
161211
t.Helper()

llm/binding.go

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ const JSON OutputFormat = "json"
1919
const MarkDown OutputFormat = "md"
2020

2121
type AIRequest struct {
22-
Id string `json:"id"`
23-
Input string `json:"inputs"`
22+
Id string `json:"id"`
23+
Input string `json:"inputs"`
24+
Endpoint *url.URL `json:"endpoint"`
2425
}
2526

2627
var _ DeepCodeLLMBinding = (*DeepCodeLLMBindingImpl)(nil)
@@ -58,7 +59,6 @@ type DeepCodeLLMBindingImpl struct {
5859
logger *zerolog.Logger
5960
outputFormat OutputFormat
6061
instrumentor observability.Instrumentor
61-
endpoint *url.URL
6262
}
6363

6464
func (d *DeepCodeLLMBindingImpl) ExplainWithOptions(ctx context.Context, options ExplainOptions) (ExplainResult, error) {
@@ -103,12 +103,6 @@ func (d *DeepCodeLLMBindingImpl) Explain(ctx context.Context, input AIRequest, _
103103
}
104104

105105
func NewDeepcodeLLMBinding(opts ...Option) *DeepCodeLLMBindingImpl {
106-
endpoint, err := url.Parse(defaultEndpointURL)
107-
if err != nil {
108-
// time to panic, as our default should never be invalid
109-
panic(err)
110-
}
111-
112106
nopLogger := zerolog.Nop()
113107
binding := &DeepCodeLLMBindingImpl{
114108
logger: &nopLogger,
@@ -121,7 +115,6 @@ func NewDeepcodeLLMBinding(opts ...Option) *DeepCodeLLMBindingImpl {
121115
},
122116
outputFormat: MarkDown,
123117
instrumentor: observability.NewInstrumentor(),
124-
endpoint: endpoint,
125118
}
126119
for _, opt := range opts {
127120
opt(binding)

llm/binding_smoke_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package llm
22

33
import (
44
"context"
5+
"net/url"
56
"testing"
67

78
"github.com/google/uuid"
@@ -20,6 +21,9 @@ func TestDeepcodeLLMBinding_Explain_Smoke(t *testing.T) {
2021
WithLogger(&logger),
2122
)
2223
outputChain := make(chan string)
23-
err := binding.Explain(context.Background(), AIRequest{Id: uuid.New().String(), Input: "{}"}, HTML, outputChain)
24+
endpoint, errEndpoint := url.Parse(defaultEndpointURL)
25+
assert.NoError(t, errEndpoint)
26+
27+
err := binding.Explain(context.Background(), AIRequest{Id: uuid.New().String(), Input: "{}", Endpoint: endpoint}, HTML, outputChain)
2428
assert.NoError(t, err)
2529
}

llm/binding_test.go

Lines changed: 4 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func TestDeepcodeLLMBinding_PublishIssues(t *testing.T) {
2626

2727
func TestExplainWithOptions(t *testing.T) {
2828
t.Run("success", func(t *testing.T) {
29-
d, mockHTTPClient := getHTTPMockedBinding(t, &url.URL{Scheme: "http", Host: "test.com"})
29+
d, mockHTTPClient := getHTTPMockedBinding(t)
3030

3131
explainResponseJSON := explainResponse{
3232
Status: completeStatus,
@@ -42,7 +42,8 @@ func TestExplainWithOptions(t *testing.T) {
4242
}
4343
mockHTTPClient.EXPECT().Do(gomock.Any()).Return(&mockResponse, nil)
4444
testDiff := "test diff"
45-
explanation, err := d.ExplainWithOptions(context.Background(), ExplainOptions{Diffs: []string{testDiff}})
45+
endpoint := &url.URL{Scheme: "http", Host: "test.com"}
46+
explanation, err := d.ExplainWithOptions(context.Background(), ExplainOptions{Diffs: []string{testDiff}, Endpoint: endpoint})
4647
assert.NoError(t, err)
4748
var exptectedExplanationsResponse explainResponse
4849
err = json.Unmarshal(expectedResponseBody, &exptectedExplanationsResponse)
@@ -56,13 +57,12 @@ func TestExplainWithOptions(t *testing.T) {
5657
})
5758
}
5859

59-
func getHTTPMockedBinding(t *testing.T, endpoint *url.URL) (*DeepCodeLLMBindingImpl, *mocks.MockHTTPClient) {
60+
func getHTTPMockedBinding(t *testing.T) (*DeepCodeLLMBindingImpl, *mocks.MockHTTPClient) {
6061
t.Helper()
6162
ctrl := gomock.NewController(t)
6263
mockHTTPClient := mocks.NewMockHTTPClient(ctrl)
6364
d := NewDeepcodeLLMBinding(
6465
WithHTTPClient(func() http.HTTPClient { return mockHTTPClient }),
65-
WithEndpoint(endpoint),
6666
)
6767
return d, mockHTTPClient
6868
}
@@ -83,7 +83,6 @@ func TestNewDeepcodeLLMBinding(t *testing.T) {
8383
func TestNewDeepcodeLLMBinding_Defaults(t *testing.T) {
8484
binding := NewDeepcodeLLMBinding()
8585

86-
assert.NotNil(t, binding.endpoint)
8786
assert.NotNil(t, binding.logger)
8887
assert.NotNil(t, binding.httpClientFunc)
8988
assert.NotNil(t, binding.instrumentor)
@@ -128,55 +127,6 @@ func TestWithOutputFormat(t *testing.T) {
128127
assert.Equal(t, MarkDown, binding.outputFormat)
129128
}
130129

131-
func TestWithEndpoint(t *testing.T) {
132-
testCases := []struct {
133-
name string
134-
inputURL string
135-
expected url.URL
136-
}{
137-
{
138-
name: "Valid URL",
139-
inputURL: "http://localhost:8080",
140-
expected: url.URL{Scheme: "http", Host: "localhost:8080"},
141-
},
142-
{
143-
name: "URL with Path",
144-
inputURL: "https://example.com/path/to/resource",
145-
expected: url.URL{Scheme: "https", Host: "example.com", Path: "/path/to/resource"},
146-
},
147-
{
148-
name: "URL with Query Params",
149-
inputURL: "http://api.example.com?param1=value1&param2=value2",
150-
expected: url.URL{Scheme: "http", Host: "api.example.com", RawQuery: "param1=value1&param2=value2"},
151-
},
152-
}
153-
154-
for _, tc := range testCases {
155-
t.Run(tc.name, func(t *testing.T) {
156-
parsedURL, err := url.Parse(tc.inputURL)
157-
if err != nil {
158-
t.Fatalf("Failed to parse URL: %v", err)
159-
}
160-
161-
binding := &DeepCodeLLMBindingImpl{}
162-
WithEndpoint(parsedURL)(binding)
163-
164-
if binding.endpoint.Scheme != tc.expected.Scheme {
165-
t.Errorf("Expected Scheme: %s, Got: %s", tc.expected.Scheme, binding.endpoint.Scheme)
166-
}
167-
if binding.endpoint.Host != tc.expected.Host {
168-
t.Errorf("Expected Host: %s, Got: %s", tc.expected.Host, binding.endpoint.Host)
169-
}
170-
if binding.endpoint.Path != tc.expected.Path {
171-
t.Errorf("Expected Path: %s, Got: %s", tc.expected.Path, binding.endpoint.Path)
172-
}
173-
if binding.endpoint.RawQuery != tc.expected.RawQuery {
174-
t.Errorf("Expected RawQuery: %s, Got: %s", tc.expected.RawQuery, binding.endpoint.RawQuery)
175-
}
176-
})
177-
}
178-
}
179-
180130
func TestWithInstrumentor(t *testing.T) {
181131
// Test case 1: Provide a mock instrumentor
182132
binding := &DeepCodeLLMBindingImpl{}

llm/options.go

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
package llm
22

33
import (
4-
"net/url"
5-
64
"github.com/rs/zerolog"
75
"github.com/snyk/code-client-go/http"
86
"github.com/snyk/code-client-go/observability"
@@ -16,12 +14,6 @@ func WithHTTPClient(httpClientFunc func() http.HTTPClient) func(*DeepCodeLLMBind
1614
}
1715
}
1816

19-
func WithEndpoint(endpoint *url.URL) func(*DeepCodeLLMBindingImpl) {
20-
return func(binding *DeepCodeLLMBindingImpl) {
21-
binding.endpoint = endpoint
22-
}
23-
}
24-
2517
func WithLogger(logger *zerolog.Logger) func(*DeepCodeLLMBindingImpl) {
2618
return func(binding *DeepCodeLLMBindingImpl) {
2719
binding.logger = logger

llm/types.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package llm
22

3+
import "net/url"
4+
35
type explanationLength string
46

57
const (
@@ -58,4 +60,7 @@ type ExplainOptions struct {
5860

5961
// fix difference
6062
Diffs []string `json:"diffs"`
63+
64+
// Endpoint to call
65+
Endpoint *url.URL `json:"endpoint"`
6166
}

0 commit comments

Comments
 (0)