Skip to content

Commit 452474c

Browse files
authored
Merge pull request #177 from zenml-io/bugfix/fixgame
Fixed import errors
2 parents e403d06 + e95132d commit 452474c

File tree

4 files changed

+10
-9
lines changed

4 files changed

+10
-9
lines changed

gamesense/pipelines/train_accelerated.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818

1919
from steps import (
2020
evaluate_model,
21-
finetune_accelerated,
21+
finetune,
2222
log_metadata_from_step_artifact,
2323
prepare_data,
2424
promote,
2525
)
2626
from zenml import pipeline
27+
from zenml.integrations.huggingface.steps import run_with_accelerate
2728

2829

2930
@pipeline
@@ -75,6 +76,9 @@ def llm_peft_full_finetune(
7576
id="log_metadata_evaluation_base",
7677
)
7778

79+
finetune_accelerated = run_with_accelerate(
80+
finetune, num_processes=2, multi_gpu=True, mixed_precision="bf16"
81+
)
7882
ft_model_dir = finetune_accelerated(
7983
base_model_id=base_model_id,
8084
dataset_dir=datasets_dir,

gamesense/run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@
3535
3636
\b
3737
# Run the pipeline with custom config
38-
python run.py --config custom_finetune.yaml
38+
python run.py --config phi3.5_finetune_local.yaml
3939
"""
4040
)
4141
@click.option(
4242
"--config",
4343
type=str,
44-
default="default_finetune.yaml",
44+
default="phi3.5_finetune_local.yaml",
4545
help="Path to the YAML config file.",
4646
)
4747
@click.option(

gamesense/steps/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#
1717

1818
from .evaluate_model import evaluate_model
19-
from .finetune import finetune, finetune_accelerated
19+
from .finetune import finetune
2020
from .log_metadata import log_metadata_from_step_artifact
2121
from .prepare_datasets import prepare_data
2222
from .promote import promote

gamesense/steps/finetune.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
from zenml.logger import get_logger
3333
from zenml.materializers import BuiltInMaterializer
3434
from zenml.utils.cuda_utils import cleanup_gpu_memory
35+
from zenml.client import Client
36+
3537

3638
logger = get_logger(__name__)
3739

@@ -196,8 +198,3 @@ def finetune(
196198
)
197199

198200
return ft_model_dir
199-
200-
201-
finetune_accelerated = run_with_accelerate(
202-
finetune, num_processes=2, multi_gpu=True, mixed_precision="bf16"
203-
)

0 commit comments

Comments
 (0)