Skip to content

Commit 16dc02b

Browse files
authored
Merge branch 'aws:master' into rsareddy-dev
2 parents acc861a + 6945a04 commit 16dc02b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+1518
-358
lines changed

CHANGELOG.md

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,43 @@
11
# Changelog
22

3+
## v2.241.0 (2025-03-06)
4+
5+
### Features
6+
7+
* Make DistributedConfig Extensible
8+
* support training for JumpStart model references as part of Curated Hub Phase 2
9+
* Allow ModelTrainer to accept hyperparameters file
10+
11+
### Bug Fixes and Other Changes
12+
13+
* Skip tests with deprecated instance type
14+
* Ensure Model.is_repack() returns a boolean
15+
* Fix error when there is no session to call _create_model_request()
16+
* Use sagemaker session's s3_resource in download_folder
17+
* Added check for the presence of model package group before creating one
18+
* Fix key error in _send_metrics()
19+
20+
## v2.240.0 (2025-02-25)
21+
22+
### Features
23+
24+
* Add support for TGI Neuronx 0.0.27 and HF PT 2.3.0 image in PySDK
25+
26+
### Bug Fixes and Other Changes
27+
28+
* Remove main function entrypoint in ModelBuilder dependency manager.
29+
* forbid extras in Configs
30+
* altconfig hubcontent and reenable integ test
31+
* Merge branch 'master-rba' into local_merge
32+
* py_version doc fixes
33+
* Add backward compatbility for RecordSerializer and RecordDeserializer
34+
* update image_uri_configs 02-21-2025 06:18:10 PST
35+
* update image_uri_configs 02-20-2025 06:18:08 PST
36+
37+
### Documentation Changes
38+
39+
* Removed a line about python version requirements of training script which can misguide users.
40+
341
## v2.239.3 (2025-02-19)
442

543
### Bug Fixes and Other Changes

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.239.4.dev0
1+
2.241.1.dev0

src/sagemaker/estimator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2550,7 +2550,6 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
25502550
raise ValueError(
25512551
"File URIs are supported in local mode only. Please use a S3 URI instead."
25522552
)
2553-
25542553
config = _Job._load_config(inputs, estimator)
25552554

25562555
current_hyperparameters = estimator.hyperparameters()

src/sagemaker/experiments/_metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,8 @@ def _send_metrics(self, metrics):
197197
response = self._metrics_client.batch_put_metrics(**request)
198198
errors = response["Errors"] if "Errors" in response else None
199199
if errors:
200-
message = errors[0]["Message"]
201-
raise Exception(f'{len(errors)} errors with message "{message}"')
200+
error_code = errors[0]["Code"]
201+
raise Exception(f'{len(errors)} errors with error code "{error_code}"')
202202

203203
def _construct_batch_put_metrics_request(self, batch):
204204
"""Creates dictionary object used as request to metrics service."""

src/sagemaker/inputs.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ def __init__(
4343
attribute_names: Optional[List[Union[str, PipelineVariable]]] = None,
4444
target_attribute_name: Optional[Union[str, PipelineVariable]] = None,
4545
shuffle_config: Optional["ShuffleConfig"] = None,
46+
hub_access_config: Optional[dict] = None,
47+
model_access_config: Optional[dict] = None,
4648
):
4749
r"""Create a definition for input data used by an SageMaker training job.
4850
@@ -102,6 +104,13 @@ def __init__(
102104
shuffle_config (sagemaker.inputs.ShuffleConfig): If specified this configuration enables
103105
shuffling on this channel. See the SageMaker API documentation for more info:
104106
https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
107+
hub_access_config (dict): Specify the HubAccessConfig of a
108+
Model Reference for which a training job is being created for.
109+
model_access_config (dict): For models that require a Model Access Config, specify True
110+
or False for to indicate whether model terms of use have been accepted.
111+
The `accept_eula` value must be explicitly defined as `True` in order to
112+
accept the end-user license agreement (EULA) that some
113+
models require. (Default: None).
105114
"""
106115
self.config = {
107116
"DataSource": {"S3DataSource": {"S3DataType": s3_data_type, "S3Uri": s3_data}}
@@ -129,6 +138,27 @@ def __init__(
129138
self.config["TargetAttributeName"] = target_attribute_name
130139
if shuffle_config is not None:
131140
self.config["ShuffleConfig"] = {"Seed": shuffle_config.seed}
141+
self.add_hub_access_config(hub_access_config)
142+
self.add_model_access_config(model_access_config)
143+
144+
def add_hub_access_config(self, hub_access_config=None):
145+
"""Add Hub Access Config to the channel's configuration.
146+
147+
Args:
148+
hub_access_config (dict): The HubAccessConfig to be added to the
149+
channel's configuration.
150+
"""
151+
if hub_access_config is not None:
152+
self.config["DataSource"]["S3DataSource"]["HubAccessConfig"] = hub_access_config
153+
154+
def add_model_access_config(self, model_access_config=None):
155+
"""Add Model Access Config to the channel's configuration.
156+
157+
Args:
158+
model_access_config (dict): Whether model terms of use have been accepted.
159+
"""
160+
if model_access_config is not None:
161+
self.config["DataSource"]["S3DataSource"]["ModelAccessConfig"] = model_access_config
132162

133163

134164
class ShuffleConfig(object):

src/sagemaker/job.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def stop(self):
6565
@staticmethod
6666
def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
6767
"""Placeholder docstring"""
68+
model_access_config, hub_access_config = _Job._get_access_configs(estimator)
6869
input_config = _Job._format_inputs_to_input_config(inputs, validate_uri)
6970
role = (
7071
estimator.sagemaker_session.expand_role(estimator.role)
@@ -95,19 +96,23 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
9596
validate_uri,
9697
content_type="application/x-sagemaker-model",
9798
input_mode="File",
99+
model_access_config=model_access_config,
100+
hub_access_config=hub_access_config,
98101
)
99102
if model_channel:
100103
input_config = [] if input_config is None else input_config
101104
input_config.append(model_channel)
102105

103-
if estimator.enable_network_isolation():
104-
code_channel = _Job._prepare_channel(
105-
input_config, estimator.code_uri, estimator.code_channel_name, validate_uri
106-
)
106+
code_channel = _Job._prepare_channel(
107+
input_config,
108+
estimator.code_uri,
109+
estimator.code_channel_name,
110+
validate_uri,
111+
)
107112

108-
if code_channel:
109-
input_config = [] if input_config is None else input_config
110-
input_config.append(code_channel)
113+
if code_channel:
114+
input_config = [] if input_config is None else input_config
115+
input_config.append(code_channel)
111116

112117
return {
113118
"input_config": input_config,
@@ -118,6 +123,23 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
118123
"vpc_config": vpc_config,
119124
}
120125

126+
@staticmethod
127+
def _get_access_configs(estimator):
128+
"""Return access configs from estimator object.
129+
130+
JumpStartEstimator uses access configs which need to be added to the model channel,
131+
so they are passed down to the job level.
132+
133+
Args:
134+
estimator (EstimatorBase): estimator object with access config field if applicable
135+
"""
136+
model_access_config, hub_access_config = None, None
137+
if hasattr(estimator, "model_access_config"):
138+
model_access_config = estimator.model_access_config
139+
if hasattr(estimator, "hub_access_config"):
140+
hub_access_config = estimator.hub_access_config
141+
return model_access_config, hub_access_config
142+
121143
@staticmethod
122144
def _format_inputs_to_input_config(inputs, validate_uri=True):
123145
"""Placeholder docstring"""
@@ -173,6 +195,8 @@ def _format_string_uri_input(
173195
input_mode=None,
174196
compression=None,
175197
target_attribute_name=None,
198+
model_access_config=None,
199+
hub_access_config=None,
176200
):
177201
"""Placeholder docstring"""
178202
s3_input_result = TrainingInput(
@@ -181,6 +205,8 @@ def _format_string_uri_input(
181205
input_mode=input_mode,
182206
compression=compression,
183207
target_attribute_name=target_attribute_name,
208+
model_access_config=model_access_config,
209+
hub_access_config=hub_access_config,
184210
)
185211
if isinstance(uri_input, str) and validate_uri and uri_input.startswith("s3://"):
186212
return s3_input_result
@@ -193,7 +219,11 @@ def _format_string_uri_input(
193219
)
194220
if isinstance(uri_input, str):
195221
return s3_input_result
196-
if isinstance(uri_input, (TrainingInput, file_input, FileSystemInput)):
222+
if isinstance(uri_input, (file_input, FileSystemInput)):
223+
return uri_input
224+
if isinstance(uri_input, TrainingInput):
225+
uri_input.add_hub_access_config(hub_access_config=hub_access_config)
226+
uri_input.add_model_access_config(model_access_config=model_access_config)
197227
return uri_input
198228
if is_pipeline_variable(uri_input):
199229
return s3_input_result
@@ -211,6 +241,8 @@ def _prepare_channel(
211241
validate_uri=True,
212242
content_type=None,
213243
input_mode=None,
244+
model_access_config=None,
245+
hub_access_config=None,
214246
):
215247
"""Placeholder docstring"""
216248
if not channel_uri:
@@ -226,7 +258,12 @@ def _prepare_channel(
226258
raise ValueError("Duplicate channel {} not allowed.".format(channel_name))
227259

228260
channel_input = _Job._format_string_uri_input(
229-
channel_uri, validate_uri, content_type, input_mode
261+
channel_uri,
262+
validate_uri,
263+
content_type,
264+
input_mode,
265+
model_access_config=model_access_config,
266+
hub_access_config=hub_access_config,
230267
)
231268
channel = _Job._convert_input_to_channel(channel_name, channel_input)
232269

src/sagemaker/jumpstart/artifacts/model_uris.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
get_region_fallback,
3030
verify_model_region_and_return_specs,
3131
)
32+
from sagemaker.s3_utils import is_s3_url
3233
from sagemaker.session import Session
3334
from sagemaker.jumpstart.types import JumpStartModelSpecs
3435

@@ -74,7 +75,7 @@ def _retrieve_hosting_artifact_key(model_specs: JumpStartModelSpecs, instance_ty
7475
def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str:
7576
"""Returns instance specific training artifact key or default one as fallback."""
7677
instance_specific_training_artifact_key: Optional[str] = (
77-
model_specs.training_instance_type_variants.get_instance_specific_artifact_key(
78+
model_specs.training_instance_type_variants.get_instance_specific_training_artifact_key(
7879
instance_type=instance_type
7980
)
8081
if instance_type
@@ -185,8 +186,8 @@ def _retrieve_model_uri(
185186
os.environ.get(ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE)
186187
or default_jumpstart_bucket
187188
)
188-
189-
model_s3_uri = f"s3://{bucket}/{model_artifact_key}"
189+
if not is_s3_url(model_artifact_key):
190+
model_s3_uri = f"s3://{bucket}/{model_artifact_key}"
190191

191192
return model_s3_uri
192193

src/sagemaker/jumpstart/estimator.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@
4141
validate_model_id_and_get_type,
4242
resolve_model_sagemaker_config_field,
4343
verify_model_region_and_return_specs,
44+
remove_env_var_from_estimator_kwargs_if_accept_eula_present,
45+
get_model_access_config,
46+
get_hub_access_config,
4447
)
4548
from sagemaker.utils import stringify_object, format_tags, Tags
4649
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
@@ -619,6 +622,10 @@ def _validate_model_id_and_get_type_hook():
619622
self._enable_network_isolation = estimator_init_kwargs.enable_network_isolation
620623
self.config_name = estimator_init_kwargs.config_name
621624
self.init_kwargs = estimator_init_kwargs.to_kwargs_dict(False)
625+
# Access configs initialized to None, would be given a value when .fit() is called
626+
# if applicable
627+
self.model_access_config = None
628+
self.hub_access_config = None
622629

623630
super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict())
624631

@@ -629,6 +636,7 @@ def fit(
629636
logs: Optional[str] = None,
630637
job_name: Optional[str] = None,
631638
experiment_config: Optional[Dict[str, str]] = None,
639+
accept_eula: Optional[bool] = None,
632640
) -> None:
633641
"""Start training job by calling base ``Estimator`` class ``fit`` method.
634642
@@ -679,8 +687,16 @@ def fit(
679687
is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
680688
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
681689
(Default: None).
690+
accept_eula (bool): For models that require a Model Access Config, specify True or
691+
False to indicate whether model terms of use have been accepted.
692+
The `accept_eula` value must be explicitly defined as `True` in order to
693+
accept the end-user license agreement (EULA) that some
694+
models require. (Default: None).
682695
"""
683-
696+
self.model_access_config = get_model_access_config(accept_eula)
697+
self.hub_access_config = get_hub_access_config(
698+
hub_content_arn=self.init_kwargs.get("model_reference_arn", None)
699+
)
684700
estimator_fit_kwargs = get_fit_kwargs(
685701
model_id=self.model_id,
686702
model_version=self.model_version,
@@ -695,7 +711,9 @@ def fit(
695711
tolerate_deprecated_model=self.tolerate_deprecated_model,
696712
sagemaker_session=self.sagemaker_session,
697713
config_name=self.config_name,
714+
hub_access_config=self.hub_access_config,
698715
)
716+
remove_env_var_from_estimator_kwargs_if_accept_eula_present(self.init_kwargs, accept_eula)
699717

700718
return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict())
701719

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@
7171
from sagemaker.jumpstart.utils import (
7272
add_hub_content_arn_tags,
7373
add_jumpstart_model_info_tags,
74-
get_eula_message,
7574
get_default_jumpstart_session_with_user_agent_suffix,
7675
get_top_ranked_config_name,
7776
update_dict_if_key_not_present,
@@ -265,6 +264,7 @@ def get_fit_kwargs(
265264
tolerate_deprecated_model: Optional[bool] = None,
266265
sagemaker_session: Optional[Session] = None,
267266
config_name: Optional[str] = None,
267+
hub_access_config: Optional[Dict] = None,
268268
) -> JumpStartEstimatorFitKwargs:
269269
"""Returns kwargs required call `fit` on `sagemaker.estimator.Estimator` object."""
270270

@@ -301,10 +301,32 @@ def get_fit_kwargs(
301301
estimator_fit_kwargs = _add_region_to_kwargs(estimator_fit_kwargs)
302302
estimator_fit_kwargs = _add_training_job_name_to_kwargs(estimator_fit_kwargs)
303303
estimator_fit_kwargs = _add_fit_extra_kwargs(estimator_fit_kwargs)
304+
estimator_fit_kwargs = _add_hub_access_config_to_kwargs_inputs(
305+
estimator_fit_kwargs, hub_access_config
306+
)
304307

305308
return estimator_fit_kwargs
306309

307310

311+
def _add_hub_access_config_to_kwargs_inputs(
312+
kwargs: JumpStartEstimatorFitKwargs, hub_access_config=None
313+
):
314+
"""Adds HubAccessConfig to kwargs inputs"""
315+
316+
if isinstance(kwargs.inputs, str):
317+
kwargs.inputs = TrainingInput(s3_data=kwargs.inputs, hub_access_config=hub_access_config)
318+
elif isinstance(kwargs.inputs, TrainingInput):
319+
kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config)
320+
elif isinstance(kwargs.inputs, dict):
321+
for k, v in kwargs.inputs.items():
322+
if isinstance(v, str):
323+
kwargs.inputs[k] = TrainingInput(s3_data=v, hub_access_config=hub_access_config)
324+
elif isinstance(kwargs.inputs, TrainingInput):
325+
kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config)
326+
327+
return kwargs
328+
329+
308330
def get_deploy_kwargs(
309331
model_id: str,
310332
model_version: Optional[str] = None,
@@ -668,18 +690,6 @@ def _add_env_to_kwargs(
668690
value,
669691
)
670692

671-
environment = getattr(kwargs, "environment", {}) or {}
672-
if (
673-
environment.get(SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY)
674-
and str(environment.get("accept_eula", "")).lower() != "true"
675-
):
676-
model_specs = kwargs.specs
677-
if model_specs.is_gated_model():
678-
raise ValueError(
679-
"Need to define ‘accept_eula'='true' within Environment. "
680-
f"{get_eula_message(model_specs, kwargs.region)}"
681-
)
682-
683693
return kwargs
684694

685695

0 commit comments

Comments
 (0)