@@ -65,21 +65,29 @@ def test_load_dataset(self):
                 str(256),
                 "--dataloader.dataset",
                 dataset_name,
-                "--dataloader.classifier_free_guidance_prob",
+                "--dataloader.prompt_dropout_prob",
                 "0.447",
-                "--dataloader.encoder.test_mode",
-                "--encoder.test_mode",
+                "--tokenizer.test_mode",
+                "--tokenizer.t5_tokenizer_path",
+                "tests/assets/tokenizer",
+                "--tokenizer.clip_tokenizer_path",
+                "tests/assets/tokenizer",
+                "--encoder.random_init",
                 "--encoder.t5_encoder",
                 "tests/assets/flux_test_encoders/t5-v1_1-xxl",
                 "--encoder.clip_encoder",
                 "tests/assets/flux_test_encoders/clip-vit-large-patch14",
             ]
         )
 
+        # Build the tokenizer container from config
+        tokenizer = config.tokenizer.build(tokenizer_path=config.hf_assets_path)
+
         dl = config.dataloader.build(
             dp_world_size=world_size,
             dp_rank=rank,
             local_batch_size=batch_size,
+            tokenizer=tokenizer,
         )
@@ -91,11 +99,11 @@ def test_load_dataset(self):
             len(input_data) == 3
         )  # (clip_encodings, t5_encodings, prompt)
         assert labels.shape == (batch_size, 3, 256, 256)
-        assert input_data["clip_tokens"].shape == (
+        assert input_data["clip"].shape == (
             batch_size,
             77,
         )
-        assert input_data["t5_tokens"].shape == (
+        assert input_data["t5"].shape == (
             batch_size,
             256,
         )
@@ -107,6 +115,7 @@ def test_load_dataset(self):
             dp_world_size=world_size,
             dp_rank=rank,
             local_batch_size=batch_size,
+            tokenizer=tokenizer,
         )
         dl_resumed.load_state_dict(state)
         it_resumed = iter(dl_resumed)
@@ -119,10 +128,6 @@ def test_load_dataset(self):
             torch.manual_seed(i)
             input_ids, labels = next(it_resumed)
 
-            assert torch.equal(
-                input_ids["clip_tokens"], expected_input_ids["clip_tokens"]
-            )
-            assert torch.equal(
-                input_ids["t5_tokens"], expected_input_ids["t5_tokens"]
-            )
+            assert torch.equal(input_ids["clip"], expected_input_ids["clip"])
+            assert torch.equal(input_ids["t5"], expected_input_ids["t5"])
             assert torch.equal(labels, expected_labels)