1515# limitations under the License.
1616#
1717
18+ import os
1819from typing import Optional
1920
2021from gradio .aws_helper import get_sagemaker_role , get_sagemaker_session
2122from sagemaker .huggingface import HuggingFaceModel
2223from typing_extensions import Annotated
23- from zenml import get_step_context , step
24+ from zenml import get_step_context , log_artifact_metadata , step
2425from zenml .logger import get_logger
2526
2627# Initialize logger
@@ -35,7 +36,7 @@ def deploy_hf_to_sagemaker(
3536 pytorch_version : str = "1.13.1" ,
3637 py_version : str = "py39" ,
3738 hf_task : str = "text-classification" ,
38- instance_type : str = "ml.g5.2xlarge " ,
39+ instance_type : str = "ml.t2.medium " ,
3940 container_startup_health_check_timeout : int = 300 ,
4041) -> Annotated [str , "sagemaker_endpoint_name" ]:
4142 """
@@ -83,4 +84,18 @@ def deploy_hf_to_sagemaker(
8384 )
8485 endpoint_name = predictor .endpoint_name
8586 logger .info (f"Model deployed to SageMaker: { endpoint_name } " )
87+
88+ # get region from env variable
89+ region = os .environ ["AWS_REGION" ] or "eu-central-1"
90+ invocation_url = f"https://runtime.sagemaker.{ region } .amazonaws.com/endpoints/{ endpoint_name } /invocations"
91+
92+ log_artifact_metadata (
93+ artifact_name = "sagemaker_endpoint_name" ,
94+ metadata = {
95+ "invocation_url" : invocation_url ,
96+ "endpoint_name" : endpoint_name ,
97+ },
98+ )
99+
100+
86101 return endpoint_name
0 commit comments