Skip to content

Commit df5cf42

Browse files
committed
Fix Cohere and remove lazy list initialization
1 parent 6fca813 commit df5cf42

File tree

3 files changed

+34
-30
lines changed

3 files changed

+34
-30
lines changed

src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public static class Builder implements ObjectBuilder<AnthropicGenerative> {
5858
private Integer maxTokens;
5959
private Float temperature;
6060
private String baseUrl;
61-
private List<String> stopSequences;
61+
private final List<String> stopSequences = new ArrayList<>();
6262

6363
/** Base URL of the generative provider. */
6464
public Builder baseUrl(String baseUrl) {
@@ -101,9 +101,6 @@ public Builder stopSequences(String... stopSequences) {
101101
* Set tokens which should signal the model to stop generating further output.
102102
*/
103103
public Builder stopSequences(List<String> stopSequences) {
104-
if (this.stopSequences == null) {
105-
this.stopSequences = new ArrayList<>();
106-
}
107104
this.stopSequences.addAll(stopSequences);
108105
return this;
109106
}
@@ -202,9 +199,9 @@ public static class Builder implements ObjectBuilder<AnthropicGenerative.Provide
202199
private String model;
203200
private Integer maxTokens;
204201
private Float temperature;
205-
private List<String> stopSequences;
206-
private List<String> images;
207-
private List<String> imageProperties;
202+
private final List<String> stopSequences = new ArrayList<>();
203+
private final List<String> images = new ArrayList<>();
204+
private final List<String> imageProperties = new ArrayList<>();
208205

209206
/** Base URL of the generative provider. */
210207
public Builder baseUrl(String baseUrl) {
@@ -247,9 +244,6 @@ public Builder stopSequences(String... stopSequences) {
247244
* Set tokens which should signal the model to stop generating further output.
248245
*/
249246
public Builder stopSequences(List<String> stopSequences) {
250-
if (this.stopSequences == null) {
251-
this.stopSequences = new ArrayList<>();
252-
}
253247
this.stopSequences.addAll(stopSequences);
254248
return this;
255249
}
@@ -259,9 +253,6 @@ public Builder images(String... images) {
259253
}
260254

261255
public Builder images(List<String> images) {
262-
if (this.images == null) {
263-
this.images = new ArrayList<>();
264-
}
265256
this.images.addAll(images);
266257
return this;
267258
}
@@ -271,9 +262,6 @@ public Builder imageProperties(String... imageProperties) {
271262
}
272263

273264
public Builder imageProperties(List<String> imageProperties) {
274-
if (this.imageProperties == null) {
275-
this.imageProperties = new ArrayList<>();
276-
}
277265
this.imageProperties.addAll(imageProperties);
278266
return this;
279267
}

src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ public Builder(Service service, String region) {
9090
private Integer maxTokensToSample;
9191
private Float topP;
9292
private Integer topK;
93-
private List<String> stopSequences;
93+
private final List<String> stopSequences = new ArrayList<>();
9494

9595
/** Base URL of the generative provider. */
9696
protected Builder endpoint(String endpoint) {
@@ -153,9 +153,6 @@ public Builder stopSequences(String... stopSequences) {
153153

154154
/** Stop sequences for the model. */
155155
public Builder stopSequences(List<String> stopSequences) {
156-
if (this.stopSequences == null) {
157-
this.stopSequences = new ArrayList<>();
158-
}
159156
this.stopSequences.addAll(stopSequences);
160157
return this;
161158
}

src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ public record CohereGenerative(
2020
@SerializedName("maxTokens") Integer maxTokens,
2121
@SerializedName("temperature") Float temperature,
2222
@SerializedName("returnLikelihoods") String returnLikelihoodsProperty,
23-
@SerializedName("stopSequences") List<String> stopSequences) implements Generative {
23+
@SerializedName("stopSequences") List<String> stopSequences,
24+
@SerializedName("P") Float topP,
25+
@SerializedName("presencePenalty") Float presencePenalty,
26+
@SerializedName("frequencyPenalty") Float frequencyPenalty) implements Generative {
2427

2528
@Override
2629
public Kind _kind() {
@@ -48,7 +51,10 @@ public CohereGenerative(Builder builder) {
4851
builder.maxTokens,
4952
builder.temperature,
5053
builder.returnLikelihoodsProperty,
51-
builder.stopSequences);
54+
builder.stopSequences,
55+
builder.topP,
56+
builder.presencePenalty,
57+
builder.frequencyPenalty);
5258
}
5359

5460
public static class Builder implements ObjectBuilder<CohereGenerative> {
@@ -58,7 +64,10 @@ public static class Builder implements ObjectBuilder<CohereGenerative> {
5864
private Integer maxTokens;
5965
private Float temperature;
6066
private String returnLikelihoodsProperty;
61-
private List<String> stopSequences;
67+
private final List<String> stopSequences = new ArrayList<>();
68+
private Float topP;
69+
private Float presencePenalty;
70+
private Float frequencyPenalty;
6271

6372
/** Base URL of the generative provider. */
6473
public Builder baseUrl(String baseUrl) {
@@ -72,6 +81,12 @@ public Builder topK(int topK) {
7281
return this;
7382
}
7483

84+
/** Top P value for nucleus sampling. */
85+
public Builder topP(float topP) {
86+
this.topP = topP;
87+
return this;
88+
}
89+
7590
/** Select generative model. */
7691
public Builder model(String model) {
7792
this.model = model;
@@ -100,9 +115,6 @@ public Builder stopSequences(String... stopSequences) {
100115
* Set tokens which should signal the model to stop generating further output.
101116
*/
102117
public Builder stopSequences(List<String> stopSequences) {
103-
if (this.stopSequences == null) {
104-
this.stopSequences = new ArrayList<>();
105-
}
106118
this.stopSequences.addAll(stopSequences);
107119
return this;
108120
}
@@ -116,6 +128,16 @@ public Builder temperature(float temperature) {
116128
return this;
117129
}
118130

131+
public Builder presencePenalty(float presencePenalty) {
132+
this.presencePenalty = presencePenalty;
133+
return this;
134+
}
135+
136+
public Builder frequencyPenalty(float frequencyPenalty) {
137+
this.frequencyPenalty = frequencyPenalty;
138+
return this;
139+
}
140+
119141
@Override
120142
public CohereGenerative build() {
121143
return new CohereGenerative(this);
@@ -211,7 +233,7 @@ public static class Builder implements ObjectBuilder<CohereGenerative.Provider>
211233
private Float temperature;
212234
private Float frequencyPenalty;
213235
private Float presencePenalty;
214-
private List<String> stopSequences;
236+
private final List<String> stopSequences = new ArrayList<>();
215237

216238
/** Base URL of the generative provider. */
217239
public Builder baseUrl(String baseUrl) {
@@ -265,9 +287,6 @@ public Builder stopSequences(String... stopSequences) {
265287
* Set tokens which should signal the model to stop generating further output.
266288
*/
267289
public Builder stopSequences(List<String> stopSequences) {
268-
if (this.stopSequences == null) {
269-
this.stopSequences = new ArrayList<>();
270-
}
271290
this.stopSequences.addAll(stopSequences);
272291
return this;
273292
}

0 commit comments

Comments
 (0)