Skip to content

Commit 68c2469

Browse files
committed
Merge remote-tracking branch 'upstream/main'
2 parents c5d8dfa + cd4fdde commit 68c2469

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

tuning/sft_trainer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,11 +361,12 @@ def train(
361361
# from our object directly. In the future, we should consider renaming this class and / or
362362
# not adding things that are not directly used by the trainer instance to it.
363363

364-
transformer_train_arg_fields = [x.name for x in dataclasses.fields(SFTConfig)]
364+
# To filter out fields that are not defined as init (eg. _n_gpu)
365+
transformer_train_arg_fields = [
366+
x.name for x in dataclasses.fields(SFTConfig) if x.init
367+
]
365368
transformer_kwargs = {
366-
k: v
367-
for k, v in train_args.to_dict().items()
368-
if k in transformer_train_arg_fields
369+
k: v for k, v in vars(train_args).items() if k in transformer_train_arg_fields
369370
}
370371

371372
additional_args = {

0 commit comments

Comments
 (0)