Skip to content

Commit 0377053

Browse files
authored
feat: optimize config valid for runtime auth (#28)
2 parents a0e7052 + 74162c5 commit 0377053

File tree

5 files changed

+192
-12
lines changed

5 files changed

+192
-12
lines changed

agentkit/toolkit/cli/cli_config.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ def config_command(
119119
cr_repo_name: Optional[str] = typer.Option(
120120
None, "--cr_repo_name", "--ve_cr_repo_name", help="CR repository name"
121121
),
122+
cr_auto_create_instance_type: Optional[str] = typer.Option(
123+
None,
124+
"--cr_auto_create_instance_type",
125+
help="CR instance type when auto-creating: Micro/Enterprise",
126+
),
122127
# Runtime configuration parameters
123128
runtime_name: Optional[str] = typer.Option(
124129
None, "--runtime_name", "--ve_runtime_name", help="Runtime instance name"
@@ -135,6 +140,21 @@ def config_command(
135140
"--ve_runtime_apikey_name",
136141
help="Runtime API key secret name",
137142
),
143+
runtime_auth_type: Optional[str] = typer.Option(
144+
None,
145+
"--runtime_auth_type",
146+
help="Runtime authentication type: key_auth/custom_jwt",
147+
),
148+
runtime_jwt_discovery_url: Optional[str] = typer.Option(
149+
None,
150+
"--runtime_jwt_discovery_url",
151+
help="OIDC Discovery URL when runtime_auth_type is custom_jwt",
152+
),
153+
runtime_jwt_allowed_clients: Optional[List[str]] = typer.Option(
154+
None,
155+
"--runtime_jwt_allowed_clients",
156+
help="Allowed OAuth2 client IDs when runtime_auth_type is custom_jwt (can be used multiple times)",
157+
),
138158
):
139159
"""Configure AgentKit (supports interactive and non-interactive modes).
140160
@@ -227,9 +247,13 @@ def config_command(
227247
cr_instance_name=cr_instance_name,
228248
cr_namespace_name=cr_namespace_name,
229249
cr_repo_name=cr_repo_name,
250+
cr_auto_create_instance_type=cr_auto_create_instance_type,
230251
runtime_name=runtime_name,
231252
runtime_role_name=runtime_role_name,
232253
runtime_apikey_name=runtime_apikey_name,
254+
runtime_auth_type=runtime_auth_type,
255+
runtime_jwt_discovery_url=runtime_jwt_discovery_url,
256+
runtime_jwt_allowed_clients=runtime_jwt_allowed_clients,
233257
)
234258

235259
has_cli_params = ConfigParamHandler.has_cli_params(cli_params)

agentkit/toolkit/cli/interactive_config.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,15 @@ def _prompt_for_field(
307307
if len(args) == 2 and type(None) in args:
308308
field_type = args[0]
309309

310+
prompt_condition = metadata.get("prompt_condition")
311+
if prompt_condition:
312+
depends_on = prompt_condition.get("depends_on")
313+
expected_values = prompt_condition.get("values", [])
314+
if depends_on and expected_values:
315+
depend_value = current_config.get(depends_on)
316+
if depend_value not in expected_values:
317+
return current_config.get(name, default)
318+
310319
if get_origin(field_type) is list or field_type is List:
311320
return self._handle_list(description, default, metadata, current, total)
312321

@@ -327,7 +336,6 @@ def _prompt_for_field(
327336
description, metadata, current_config
328337
)
329338

330-
# Conditional validation loop
331339
validation = metadata.get("validation", {})
332340
while True:
333341
# Call specific input handler (basic validation)
@@ -339,8 +347,7 @@ def _prompt_for_field(
339347
enhanced_description, default, metadata, current, total
340348
)
341349

342-
# If conditional validation type, perform conditional validation
343-
if validation.get("type") == "conditional" and value:
350+
if validation.get("type") == "conditional":
344351
errors = self._validate_conditional_value(
345352
name, value, validation, current_config
346353
)
@@ -428,7 +435,11 @@ def _validate_conditional_value(
428435
if depend_value and depend_value in rules:
429436
rule = rules[depend_value]
430437

431-
# choices validation
438+
if rule.get("required") and (
439+
not value or (isinstance(value, str) and not value.strip())
440+
):
441+
errors.append("This field cannot be empty")
442+
432443
if "choices" in rule and value not in rule["choices"]:
433444
msg = rule.get(
434445
"message", f"Must be one of: {', '.join(rule['choices'])}"

agentkit/toolkit/config/config_handler.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,13 @@ def collect_cli_params(
8181
cr_instance_name: Optional[str],
8282
cr_namespace_name: Optional[str],
8383
cr_repo_name: Optional[str],
84+
cr_auto_create_instance_type: Optional[str],
8485
runtime_name: Optional[str],
8586
runtime_role_name: Optional[str],
8687
runtime_apikey_name: Optional[str],
88+
runtime_auth_type: Optional[str],
89+
runtime_jwt_discovery_url: Optional[str],
90+
runtime_jwt_allowed_clients: Optional[List[str]],
8791
) -> Dict[str, Any]:
8892
"""Collect all CLI parameters.
8993
@@ -144,12 +148,22 @@ def collect_cli_params(
144148
strategy_params["cr_namespace_name"] = cr_namespace_name
145149
if cr_repo_name is not None:
146150
strategy_params["cr_repo_name"] = cr_repo_name
151+
if cr_auto_create_instance_type is not None:
152+
strategy_params["cr_auto_create_instance_type"] = (
153+
cr_auto_create_instance_type
154+
)
147155
if runtime_name is not None:
148156
strategy_params["runtime_name"] = runtime_name
149157
if runtime_role_name is not None:
150158
strategy_params["runtime_role_name"] = runtime_role_name
151159
if runtime_apikey_name is not None:
152160
strategy_params["runtime_apikey_name"] = runtime_apikey_name
161+
if runtime_auth_type is not None:
162+
strategy_params["runtime_auth_type"] = runtime_auth_type
163+
if runtime_jwt_discovery_url is not None:
164+
strategy_params["runtime_jwt_discovery_url"] = runtime_jwt_discovery_url
165+
if runtime_jwt_allowed_clients is not None:
166+
strategy_params["runtime_jwt_allowed_clients"] = runtime_jwt_allowed_clients
153167

154168
return {"common": common_params, "strategy": strategy_params}
155169

@@ -230,6 +244,34 @@ def update_config(
230244
else:
231245
new_strategy_config[key] = value
232246

247+
strategy_obj = None
248+
if strategy_name == "local":
249+
from agentkit.toolkit.config import LocalStrategyConfig
250+
251+
strategy_obj = LocalStrategyConfig.from_dict(
252+
new_strategy_config, skip_render=True
253+
)
254+
elif strategy_name == "cloud":
255+
from agentkit.toolkit.config import CloudStrategyConfig
256+
257+
strategy_obj = CloudStrategyConfig.from_dict(
258+
new_strategy_config, skip_render=True
259+
)
260+
elif strategy_name == "hybrid":
261+
from agentkit.toolkit.config import HybridStrategyConfig
262+
263+
strategy_obj = HybridStrategyConfig.from_dict(
264+
new_strategy_config, skip_render=True
265+
)
266+
267+
if strategy_obj is not None:
268+
strategy_errors = self.validator.validate_dataclass(strategy_obj)
269+
if strategy_errors:
270+
console.print("[red]Configuration validation failed:[/red]")
271+
for error in strategy_errors:
272+
console.print(f" [red]✗[/red] {error}")
273+
return False
274+
233275
self._show_config_changes(
234276
old_strategy_config,
235277
new_strategy_config,

agentkit/toolkit/config/config_validator.py

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import re
1818
from typing import List, Any
19-
from dataclasses import fields
19+
from dataclasses import fields, is_dataclass
2020

2121
from agentkit.toolkit.config.config import CommonConfig
2222

@@ -82,7 +82,59 @@ def validate_common_config(config: CommonConfig) -> List[str]:
8282
return errors
8383

8484
@staticmethod
85-
def _validate_conditional_fields(config: CommonConfig) -> List[str]:
85+
def validate_dataclass(config: Any) -> List[str]:
86+
if not is_dataclass(config):
87+
return []
88+
89+
errors: List[str] = []
90+
91+
for field in fields(config):
92+
if field.name.startswith("_"):
93+
continue
94+
95+
validation = field.metadata.get("validation", {})
96+
97+
if validation.get("type") == "conditional":
98+
continue
99+
100+
value = getattr(config, field.name)
101+
102+
if validation.get("required") and (
103+
not value or (isinstance(value, str) and not value.strip())
104+
):
105+
desc = field.metadata.get("description", field.name)
106+
errors.append(f"{desc} is required")
107+
continue
108+
109+
pattern = validation.get("pattern")
110+
if pattern and value and isinstance(value, str):
111+
if not re.match(pattern, value):
112+
desc = field.metadata.get("description", field.name)
113+
msg = validation.get("message", "Invalid format")
114+
errors.append(f"{desc}: {msg}")
115+
116+
choices = field.metadata.get("choices")
117+
if choices and value:
118+
valid_values = []
119+
if isinstance(choices, list):
120+
if choices and isinstance(choices[0], dict):
121+
valid_values = [c["value"] for c in choices]
122+
else:
123+
valid_values = choices
124+
125+
if valid_values and value not in valid_values:
126+
desc = field.metadata.get("description", field.name)
127+
errors.append(
128+
f"{desc} must be one of: {', '.join(map(str, valid_values))}"
129+
)
130+
131+
conditional_errors = ConfigValidator._validate_conditional_fields(config)
132+
errors.extend(conditional_errors)
133+
134+
return errors
135+
136+
@staticmethod
137+
def _validate_conditional_fields(config: Any) -> List[str]:
86138
"""Execute conditional validation (cross-field dependencies).
87139
88140
Args:
@@ -93,7 +145,7 @@ def _validate_conditional_fields(config: CommonConfig) -> List[str]:
93145
"""
94146
errors = []
95147

96-
for field in fields(CommonConfig):
148+
for field in fields(config):
97149
if field.name.startswith("_"):
98150
continue
99151

@@ -111,11 +163,6 @@ def _validate_conditional_fields(config: CommonConfig) -> List[str]:
111163
depend_value = getattr(config, depends_on, None)
112164
current_value = getattr(config, field.name, None)
113165

114-
if not current_value or (
115-
isinstance(current_value, str) and not current_value.strip()
116-
):
117-
continue
118-
119166
if depend_value in rules:
120167
rule = rules[depend_value]
121168
field_errors = ConfigValidator._apply_conditional_rule(
@@ -143,6 +190,12 @@ def _apply_conditional_rule(
143190
errors = []
144191
desc = metadata.get("description", field_name)
145192

193+
if rule.get("required") and (
194+
value is None or (isinstance(value, str) and not value.strip())
195+
):
196+
errors.append(f"{desc} is required")
197+
return errors
198+
146199
if "choices" in rule:
147200
if value not in rule["choices"]:
148201
msg = rule.get(

agentkit/toolkit/config/strategy_configs.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,11 @@ class HybridStrategyConfig(AutoSerializableMixin):
171171
metadata={
172172
"description": "CR instance type when auto-creating (Micro or Enterprise)",
173173
"icon": "⚙️",
174+
"choices": [
175+
{"value": "Micro", "description": "Micro"},
176+
{"value": "Enterprise", "description": "Enterprise"},
177+
],
178+
"hidden": True,
174179
},
175180
)
176181
cr_image_full_url: str = field(
@@ -244,13 +249,33 @@ class HybridStrategyConfig(AutoSerializableMixin):
244249
metadata={
245250
"description": "OIDC Discovery URL for JWT validation (required when auth_type is custom_jwt)",
246251
"examples": "https://userpool-xxx.userpool.auth.id.cn-beijing.volces.com/.well-known/openid-configuration",
252+
"prompt_condition": {
253+
"depends_on": "runtime_auth_type",
254+
"values": [AUTH_TYPE_CUSTOM_JWT],
255+
},
256+
"validation": {
257+
"type": "conditional",
258+
"depends_on": "runtime_auth_type",
259+
"rules": {
260+
AUTH_TYPE_CUSTOM_JWT: {
261+
"required": True,
262+
"pattern": r"^https://.+",
263+
"hint": "(must be a valid https URL)",
264+
"message": "must be a valid https URL",
265+
}
266+
},
267+
},
247268
},
248269
)
249270
runtime_jwt_allowed_clients: List[str] = field(
250271
default_factory=list,
251272
metadata={
252273
"description": "Allowed OAuth2 client IDs (required when auth_type is custom_jwt)",
253274
"examples": "['fa99ec54-8a1c-49b2-9a9e-3f3ba31d9a33']",
275+
"prompt_condition": {
276+
"depends_on": "runtime_auth_type",
277+
"values": [AUTH_TYPE_CUSTOM_JWT],
278+
},
254279
},
255280
)
256281
runtime_endpoint: str = field(
@@ -371,6 +396,11 @@ class CloudStrategyConfig(AutoSerializableMixin):
371396
metadata={
372397
"description": "CR instance type when auto-creating (Micro or Enterprise)",
373398
"icon": "⚙️",
399+
"choices": [
400+
{"value": "Micro", "description": "Micro"},
401+
{"value": "Enterprise", "description": "Enterprise"},
402+
],
403+
"hidden": True,
374404
},
375405
)
376406
cr_region: str = field(
@@ -468,13 +498,33 @@ class CloudStrategyConfig(AutoSerializableMixin):
468498
metadata={
469499
"description": "OIDC Discovery URL for JWT validation (required when auth_type is custom_jwt)",
470500
"examples": "https://userpool-xxx.userpool.auth.id.cn-beijing.volces.com/.well-known/openid-configuration",
501+
"prompt_condition": {
502+
"depends_on": "runtime_auth_type",
503+
"values": [AUTH_TYPE_CUSTOM_JWT],
504+
},
505+
"validation": {
506+
"type": "conditional",
507+
"depends_on": "runtime_auth_type",
508+
"rules": {
509+
AUTH_TYPE_CUSTOM_JWT: {
510+
"required": True,
511+
"pattern": r"^https://.+",
512+
"hint": "(must be a valid https URL)",
513+
"message": "must be a valid https URL",
514+
}
515+
},
516+
},
471517
},
472518
)
473519
runtime_jwt_allowed_clients: List[str] = field(
474520
default_factory=list,
475521
metadata={
476522
"description": "Allowed OAuth2 client IDs (required when auth_type is custom_jwt)",
477523
"examples": "['fa99ec54-8a1c-49b2-9a9e-3f3ba31d9a33']",
524+
"prompt_condition": {
525+
"depends_on": "runtime_auth_type",
526+
"values": [AUTH_TYPE_CUSTOM_JWT],
527+
},
478528
},
479529
)
480530
runtime_endpoint: str = field(

0 commit comments

Comments
 (0)