Trigger more meaningful validation error messages when trainer registration failed #40

Open
sfc-gh-lmerrick wants to merge 2 commits into main from lmerrick-validation-error-handling
Conversation

@sfc-gh-lmerrick
Contributor

If a user fails to register their custom Trainer class, or accidentally supplies the wrong trainer type in the trainer config, validation can continue past the failure and surface an unrelated error that is difficult to debug. This PR introduces more explicit error messages that trigger in this case and help the user find the root cause faster.

Collaborator

@sfc-gh-caxu left a comment

Not obvious to me whether this change is needed because the error has been raised from get_registered_trainer

Comment on lines -133 to +125
- raise ValueError(f"{trainer_name} is not a registered Trainer.")
+ raise KeyError(f"{trainer_name} is not a registered Trainer.")
Collaborator

It should be ValueError. ValueError suggests the function (get_registered_trainer) receives an invalid value. https://docs.python.org/3/library/exceptions.html#ValueError

Contributor Author

Is the registry not a key-value mapping? If it's a key-value mapping, I believe KeyError is the more specific, and thus more useful, exception to raise.

Comment on lines +195 to +196
+ except KeyError as e:
+     raise KeyError(
Collaborator

This should be ValueError as well; see my explanation on the earlier comment. I think KeyError is rarely raised directly and is mostly reserved for lookups on a local data type (e.g., a dict or a set).
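For readers weighing the two options, one practical difference is how each exception renders its message. The sketch below uses a hypothetical in-memory registry (`_TRAINER_REGISTRY`, `DummyTrainer`, and the two lookup helpers are illustrative names, not the project's actual code) to compare the text a user would see.

```python
# Hypothetical registry: the project's real registry structure is not shown in this thread.
_TRAINER_REGISTRY: dict[str, type] = {}


class DummyTrainer:
    """Placeholder standing in for a registered Trainer class."""


_TRAINER_REGISTRY["dummy"] = DummyTrainer


def get_trainer_keyerror(name: str) -> type:
    """Look up a trainer, reporting a missing name as KeyError."""
    try:
        return _TRAINER_REGISTRY[name]
    except KeyError as e:
        raise KeyError(f"{name} is not a registered Trainer.") from e


def get_trainer_valueerror(name: str) -> type:
    """Look up a trainer, reporting a missing name as ValueError."""
    if name not in _TRAINER_REGISTRY:
        raise ValueError(f"{name} is not a registered Trainer.")
    return _TRAINER_REGISTRY[name]


def message_of(func, name: str) -> str:
    """Return str(exception) for the error raised by func(name), or "" if none."""
    try:
        func(name)
    except (KeyError, ValueError) as e:
        return str(e)
    return ""


key_msg = message_of(get_trainer_keyerror, "missing")
val_msg = message_of(get_trainer_valueerror, "missing")

# KeyError formats a single argument with repr(), so its message gains quotes.
print(key_msg)
print(val_msg)
```

KeyError derives from LookupError rather than ValueError, so callers that catch `ValueError` will not catch it. Which type fits better here is the design question under discussion; the sketch only shows the observable differences.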

@sfc-gh-lmerrick
Contributor Author

> Not obvious to me whether this change is needed because the error has been raised from get_registered_trainer

In my tests, the Rust-implemented SchemaValidator.validate_python function that drives the validation suppressed this error. It then passed an info: ValidationInfo object to the parse_sub_config field validator method, and that object was simply missing the extra fields that the earlier, failed calls to this function should have added (i.e., there was no data object).
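The failure mode described above can be reproduced in miniature without the real validation library. The toy loop below is a simplified stand-in, not pydantic's implementation: it collects `ValueError`s instead of stopping, and leaves failed fields out of the dict handed to later validators. All function names are hypothetical, and the guard in `eval_frequency_guarded` is one possible mitigation, not necessarily what this PR does.

```python
from typing import Any, Callable

Validator = Callable[[Any, dict], Any]


def run_validators(raw: dict, validators: dict[str, Validator]) -> tuple[dict, list[str]]:
    """Toy loop: validate fields in order, collecting ValueErrors instead of stopping."""
    data: dict = {}
    errors: list[str] = []
    for field, validate in validators.items():
        try:
            data[field] = validate(raw.get(field), data)
        except ValueError as e:
            # The failure is recorded, and the field is left out of `data`.
            errors.append(f"{field}: {e}")
    return data, errors


def validate_data(value: Any, data: dict) -> Any:
    # Hypothetical sub-config parse that fails because the trainer type is unknown.
    raise ValueError("unregistered_or_nonexistent is not a registered Trainer.")


def eval_frequency_unguarded(value: Any, data: dict) -> Any:
    # Assumes the earlier "data" field validated successfully.
    return value if data["data"]["eval_sources"] else 0


def eval_frequency_guarded(value: Any, data: dict) -> Any:
    # Hypothetical guard: report the missing prerequisite explicitly.
    if "data" not in data:
        raise ValueError("'data' failed validation earlier; cannot check eval_frequency.")
    return value if data["data"]["eval_sources"] else 0


raw = {"data": {"eval_sources": []}, "eval_frequency": 0}

unguarded_error = None
try:
    run_validators(raw, {"data": validate_data, "eval_frequency": eval_frequency_unguarded})
except KeyError as e:
    # The unrelated KeyError escapes and hides the collected root cause.
    unguarded_error = e

_, guarded_errors = run_validators(
    raw, {"data": validate_data, "eval_frequency": eval_frequency_guarded}
)
for line in guarded_errors:
    print(line)
```

In the unguarded variant the only error the caller sees is the `KeyError` for the missing `data` entry. In the guarded variant both failures are collected, and the first one names the unregistered trainer.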

@sfc-gh-lmerrick
Contributor Author

Added a test. The test fails before the changes, but succeeds after.

Previous test failure output:

========================================================================================= test session starts ==========================================================================================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.5.0
rootdir: /scratch/tests
configfile: pytest.ini
plugins: devtools-0.12.2, anyio-4.3.0, hypothesis-5.35.1, flakefinder-1.1.0, rerunfailures-14.0, shard-0.1.2, xdist-3.6.1, xdoctest-1.0.2
collected 1 item                                                                                                                                                                                       
Running 1 items in this shard

tests/trainer/test_trainer_validation.py F                                                                                                                                                       [100%]

=============================================================================================== FAILURES ===============================================================================================
______________________________________________________________________________________ test_unregistered_trainer _______________________________________________________________________________________

tmp_path = PosixPath('/tmp/pytest-of-lmerrick/pytest-4/test_unregistered_trainer0')

    @pytest.mark.cpu
    def test_unregistered_trainer(tmp_path):
        config_dict = {
            "type": "unregistered_or_nonexistent",
            "exit_iteration": 2,
            "micro_batch_size": 1,
            "model": {
                "type": "random-weight-hf",
                "name_or_path": "HuggingFaceTB/SmolLM-135M-Instruct",
                "attn_implementation": "eager",
                "dtype": "float32",
            },
            "data": {
                "max_length": 2048,
                "sources": ["HuggingFaceH4/ultrachat_200k-truncated"],
            },
            "deepspeed": {"zero_optimization": {"stage": 0}},
            "optimizer": {"type": "cpu-adam"},
        }
        # Fails in previous implementation of `TrainerConfig.parse_sub_config`, despite
        # the implementation intending for this to succeed.
        with pytest.raises(ValueError) as ctx:
>           config = TrainerConfig(**config_dict)

tests/trainer/test_trainer_validation.py:43: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
arctic_training/config/base.py:25: in __init__
    super().__init__(**data)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

cls = <class 'arctic_training.config.trainer.TrainerConfig'>, v = 0
info = ValidationInfo(config={'title': 'TrainerConfig', 'extra_fields_behavior': 'forbid', 'validate_default': True, 'populat...1, 'gradient_accumulation_steps': 1, 'micro_batch_size': 1, 'seed': 42, 'train_iters': 0}, field_name='eval_frequency')

    @field_validator("eval_frequency", mode="after")
    def validate_eval_frequency(cls, v: int, info: ValidationInfo) -> int:
        if (
>           info.data["data"].eval_sources
            or info.data["data"].train_eval_split[1] > 0.0
        ):
E       KeyError: 'data'

arctic_training/config/trainer.py:158: KeyError
======================================================================================= short test summary info ========================================================================================
FAILED tests/trainer/test_trainer_validation.py::test_unregistered_trainer - KeyError: 'data'
========================================================================================== 1 failed in 0.72s ===========================================================================================
