Skip to content

Commit b3050a2

Browse files
committed
feat: support jwtoken and preset endpoint
* feat: restore test file location * feat: update error message * Revert "feat: fix typo in comment" * feat: fix typo in comment * feat: set resource type for model id * Revert "feat: support model id" * feat: support model id * feat: support preset ep for jwtoken See merge request: !771
1 parent 9ddd58a commit b3050a2

File tree

2 files changed

+40
-13
lines changed

2 files changed

+40
-13
lines changed

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/Const.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ public class Const {
88
public static final String SERVER_REQUEST_HEADER = "X-Request-Id";
99
public static final String REQUEST_MODEL = "X-Request-Model";
1010
public static final String REQUEST_BOT = "X-Request-Bot";
11+
public static final String REQUEST_PROJECT_NAME = "X-Project-Name";
1112
public static final String RETRY_AFTER = "Retry-After";
1213
public static final String REQUEST_BOT_ID = "botId";
1314
public static final Integer DEFAULT_MANDATORY_REFRESH_TIMEOUT = 10 * 60; // 10 min
@@ -16,6 +17,7 @@ public class Const {
1617

1718
public static final String RESOURCE_TYPE_BOT = "bot";
1819
public static final String RESOURCE_TYPE_ENDPOINT = "endpoint";
20+
public static final String RESOURCE_TYPE_PRESETENDPOINT = "presetendpoint";
1921

2022
public static final String CONTEXT_MODE_SESSION = "session";
2123
public static final String CONTEXT_MODE_COMMON_PREFIX = "common_prefix";

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/interceptor/ArkResourceStsAuthenticationInterceptor.java

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,25 +51,40 @@ public ArkResourceStsAuthenticationInterceptor(String ak, String sk, String regi
5151
@Override
5252
public Response intercept(Chain chain) throws IOException {
5353
Request request = chain.request();
54-
String requestResourceType = getRequestResourceType(request);
5554
String requestResourceId = getRequestResourceId(request);
55+
String requestResourceType = getRequestResourceType(request, requestResourceId);
56+
String projectName = getProjectName(request);
57+
58+
if (requestResourceType.equalsIgnoreCase(Const.RESOURCE_TYPE_PRESETENDPOINT) && StringUtils.isBlank(projectName)) {
59+
throw new ArkException("project name is required for preset endpoint");
60+
}
5661

5762
if (request.url().url().getPath().contains("contents/generations") || request.url().url().getPath().contains("images/generations")) {
5863
throw new ArkException("content generation currently does not support ak&sk authentication, use api_key instead.");
5964
}
6065

6166
Request newRequest = chain.request()
6267
.newBuilder()
63-
.header("Authorization", "Bearer " + getResourceStsToken(requestResourceType, requestResourceId))
68+
.header("Authorization", "Bearer " + getResourceStsToken(requestResourceType, requestResourceId, projectName))
6469
.build();
6570
return chain.proceed(newRequest);
6671
}
6772

68-
private String getRequestResourceType(Request request) {
73+
private String getRequestResourceType(Request request, String requestResourceId) {
6974
if (StringUtils.isNotBlank(request.header(Const.REQUEST_BOT))) {
7075
return Const.RESOURCE_TYPE_BOT;
7176
}
72-
return Const.RESOURCE_TYPE_ENDPOINT;
77+
78+
if (StringUtils.isNotBlank(requestResourceId) && requestResourceId.startsWith("ep-m-")) {
79+
return Const.RESOURCE_TYPE_PRESETENDPOINT;
80+
}
81+
82+
if (StringUtils.isNotBlank(requestResourceId) && requestResourceId.startsWith("ep-")) {
83+
return Const.RESOURCE_TYPE_ENDPOINT;
84+
}
85+
86+
// for model id
87+
return Const.RESOURCE_TYPE_PRESETENDPOINT;
7388
}
7489

7590
private String getRequestResourceId(Request request) {
@@ -79,8 +94,15 @@ private String getRequestResourceId(Request request) {
7994
return request.header(Const.REQUEST_MODEL);
8095
}
8196

82-
private String getResourceStsToken(String resourceType, String resourceId) {
83-
refresh(resourceType, resourceId);
97+
private String getProjectName(Request request) {
98+
if (StringUtils.isNotBlank(request.header(Const.REQUEST_PROJECT_NAME))) {
99+
return request.header(Const.REQUEST_PROJECT_NAME);
100+
}
101+
return "";
102+
}
103+
104+
private String getResourceStsToken(String resourceType, String resourceId, String projectName) {
105+
refresh(resourceType, resourceId, projectName);
84106

85107
ArkResourceStsTokenInfo tokenInfo = this.resourceStsTokens.get(getResourceKey(resourceType, resourceId));
86108
if (tokenInfo == null) {
@@ -89,7 +111,7 @@ private String getResourceStsToken(String resourceType, String resourceId) {
89111
return tokenInfo.getToken();
90112
}
91113

92-
private void refresh(String resourceType, String resourceId) {
114+
private void refresh(String resourceType, String resourceId, String projectName) {
93115
if (!need_refresh(resourceType, resourceId, this.advisoryRefreshTimeout)) {
94116
return;
95117
}
@@ -101,7 +123,7 @@ private void refresh(String resourceType, String resourceId) {
101123

102124
try {
103125
boolean isMandatoryRefresh = need_refresh(resourceType, resourceId, this.mandatoryRefreshTimeout);
104-
protectedRefresh(resourceType, resourceId, isMandatoryRefresh);
126+
protectedRefresh(resourceType, resourceId, isMandatoryRefresh, projectName);
105127
} finally {
106128
lock.writeLock().unlock();
107129
}
@@ -111,7 +133,7 @@ private void refresh(String resourceType, String resourceId) {
111133
if (!need_refresh(resourceType, resourceId, this.mandatoryRefreshTimeout)) {
112134
return;
113135
}
114-
protectedRefresh(resourceType, resourceId, true);
136+
protectedRefresh(resourceType, resourceId, true, projectName);
115137
} finally {
116138
lock.writeLock().unlock();
117139
}
@@ -127,12 +149,12 @@ private boolean need_refresh(String resourceType, String resourceId, Integer ref
127149
return tokenInfo.getExpiredTime() - System.currentTimeMillis() / 1000 < refresh_in;
128150
}
129151

130-
private void protectedRefresh(String resourceType, String resourceId, boolean isMandatory) {
152+
private void protectedRefresh(String resourceType, String resourceId, boolean isMandatory, String projectName) {
131153
this.resourceStsTokens.compute(getResourceKey(resourceType, resourceId), new BiFunction<String, ArkResourceStsTokenInfo, ArkResourceStsTokenInfo>() {
132154
@Override
133155
public ArkResourceStsTokenInfo apply(String s, ArkResourceStsTokenInfo stringIntegerPair) {
134156
try {
135-
ArkResourceStsTokenInfo tokenInfo = getToken(resourceType, resourceId, Const.DEFAULT_STS_TIMEOUT);
157+
ArkResourceStsTokenInfo tokenInfo = getToken(resourceType, resourceId, Const.DEFAULT_STS_TIMEOUT, projectName);
136158
return tokenInfo;
137159
} catch (ApiException e) {
138160
if (isMandatory) {
@@ -145,17 +167,20 @@ public ArkResourceStsTokenInfo apply(String s, ArkResourceStsTokenInfo stringInt
145167
}
146168

147169
private ArkResourceStsTokenInfo getEndpointToken(String endpointId, Integer ttl) throws ApiException {
148-
return getToken("endpoint", endpointId, ttl);
170+
return getToken("endpoint", endpointId, ttl, "");
149171
}
150172

151-
private ArkResourceStsTokenInfo getToken(String resourceType, String resourceId, Integer ttl) throws ApiException {
173+
private ArkResourceStsTokenInfo getToken(String resourceType, String resourceId, Integer ttl, String projectName) throws ApiException {
152174
if (ttl < this.advisoryRefreshTimeout * 2) {
153175
throw new ArkException("ttl should not be under " + this.advisoryRefreshTimeout * 2 + " seconds.");
154176
}
155177

156178
GetApiKeyRequest r = new GetApiKeyRequest();
157179
r.durationSeconds(ttl);
158180
r.resourceType(resourceType);
181+
if (StringUtils.isNotBlank(projectName)) {
182+
r.projectName(projectName);
183+
}
159184
List<String> list = new ArrayList<>();
160185
list.add(resourceId);
161186
r.resourceIds(list);

0 commit comments

Comments
 (0)