Skip to content

Commit 4bcc8e4

Browse files
committed
fix unittest
1 parent c3528ec commit 4bcc8e4

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

scripts/train_eagle3.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -777,13 +777,20 @@ def main():
777777
attention_backend=args.attention_backend,
778778
)
779779
else:
780-
eagle3_model = OnlineEagle3Model(
781-
target_model=target_model,
782-
draft_model=draft_model,
783-
length=args.ttt_length,
784-
attention_backend=args.attention_backend,
785-
)
786-
780+
if is_online:
781+
eagle3_model = OnlineEagle3Model(
782+
target_model=target_model,
783+
draft_model=draft_model,
784+
length=args.ttt_length,
785+
attention_backend=args.attention_backend,
786+
)
787+
else:
788+
# offline: the target_model is TargetHead not a model
789+
eagle3_model = OnlineEagle3Model(
790+
draft_model=draft_model,
791+
length=args.ttt_length,
792+
attention_backend=args.attention_backend,
793+
)
787794
eagle3_model = FSDP(
788795
eagle3_model,
789796
use_orig_params=True,

0 commit comments

Comments
 (0)