Skip to content

Commit 43a627a

Browse files
committed
Merge remote-tracking branch 'origin/main' into feature/more-visualizations-rag-pipeline
2 parents 9976ca5 + 343e007 commit 43a627a

File tree

4 files changed

+23
-6
lines changed

4 files changed

+23
-6
lines changed

huggingface-sagemaker/steps/deploying/sagemaker_deployment.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,12 +15,13 @@
1515
# limitations under the License.
1616
#
1717

18+
import os
1819
from typing import Optional
1920

2021
from gradio.aws_helper import get_sagemaker_role, get_sagemaker_session
2122
from sagemaker.huggingface import HuggingFaceModel
2223
from typing_extensions import Annotated
23-
from zenml import get_step_context, step
24+
from zenml import get_step_context, log_artifact_metadata, step
2425
from zenml.logger import get_logger
2526

2627
# Initialize logger
@@ -35,7 +36,7 @@ def deploy_hf_to_sagemaker(
3536
pytorch_version: str = "1.13.1",
3637
py_version: str = "py39",
3738
hf_task: str = "text-classification",
38-
instance_type: str = "ml.g5.2xlarge",
39+
instance_type: str = "ml.t2.medium",
3940
container_startup_health_check_timeout: int = 300,
4041
) -> Annotated[str, "sagemaker_endpoint_name"]:
4142
"""
@@ -83,4 +84,18 @@ def deploy_hf_to_sagemaker(
8384
)
8485
endpoint_name = predictor.endpoint_name
8586
logger.info(f"Model deployed to SageMaker: {endpoint_name}")
87+
88+
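# get region from env variable (NOTE(review): os.environ["AWS_REGION"] raises KeyError when the variable is unset, so the "eu-central-1" fallback only applies to an empty value; os.environ.get("AWS_REGION") would cover both cases)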
89+
region = os.environ["AWS_REGION"] or "eu-central-1"
90+
invocation_url = f"https://runtime.sagemaker.{region}.amazonaws.com/endpoints/{endpoint_name}/invocations"
91+
92+
log_artifact_metadata(
93+
artifact_name="sagemaker_endpoint_name",
94+
metadata={
95+
"invocation_url": invocation_url,
96+
"endpoint_name": endpoint_name,
97+
},
98+
)
99+
100+
86101
return endpoint_name

huggingface-sagemaker/steps/promotion/promote_metric_compare_promoter.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,8 +28,8 @@
2828

2929
@step
3030
def promote_metric_compare_promoter(
31-
latest_metrics: Dict[str, str],
32-
current_metrics: Dict[str, str],
31+
latest_metrics: Dict[str, float],
32+
current_metrics: Dict[str, float],
3333
metric_to_compare: str = "accuracy",
3434
):
3535
"""Try to promote trained model.

huggingface-sagemaker/steps/training/model_trainer.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -154,6 +154,8 @@ def model_trainer(
154154
eval_results = trainer.evaluate(metric_key_prefix="")
155155

156156
# Log the evaluation results in model control plane
157-
log_artifact_metadata(output_name="model", metrics=eval_results)
157+
log_artifact_metadata(
158+
artifact_name="model", metadata={"metrics": eval_results}
159+
)
158160

159161
return model, tokenizer

llm-complete-guide/run.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -264,4 +264,4 @@ def main(
264264
materializer_registry.register_materializer_type(
265265
Document, DocumentMaterializer
266266
)
267-
main()
267+
main()

0 commit comments

Comments
 (0)