|
32 | 32 | from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS |
33 | 33 | from ..test_pipelines_common import ( |
34 | 34 | PipelineTesterMixin, |
35 | | - check_qkv_fusion_matches_attn_procs_length, |
36 | | - check_qkv_fusion_processors_exist, |
37 | 35 | to_np, |
38 | 36 | ) |
39 | 37 |
|
@@ -233,44 +231,6 @@ def test_attention_slicing_forward_pass( |
233 | 231 | "Attention slicing should not affect the inference results", |
234 | 232 | ) |
235 | 233 |
|
236 | | - def test_fused_qkv_projections(self): |
237 | | - device = "cpu" # ensure determinism for the device-dependent torch.Generator |
238 | | - components = self.get_dummy_components() |
239 | | - pipe = self.pipeline_class(**components) |
240 | | - pipe = pipe.to(device) |
241 | | - pipe.set_progress_bar_config(disable=None) |
242 | | - |
243 | | - inputs = self.get_dummy_inputs(device) |
244 | | - images = pipe(**inputs)[0] # [B, C, H, W] |
245 | | - original_image_slice = images[0, -1, -3:, -3:] |
246 | | - |
247 | | - pipe.fuse_qkv_projections() |
248 | | - assert check_qkv_fusion_processors_exist( |
249 | | - pipe.transformer |
250 | | - ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." |
251 | | - assert check_qkv_fusion_matches_attn_procs_length( |
252 | | - pipe.transformer, pipe.transformer.original_attn_processors |
253 | | - ), "Something wrong with the attention processors concerning the fused QKV projections." |
254 | | - |
255 | | - inputs = self.get_dummy_inputs(device) |
256 | | - images = pipe(**inputs)[0] |
257 | | - image_slice_fused = images[0, -1, -3:, -3:] |
258 | | - |
259 | | - pipe.transformer.unfuse_qkv_projections() |
260 | | - inputs = self.get_dummy_inputs(device) |
261 | | - images = pipe(**inputs)[0] |
262 | | - image_slice_disabled = images[0, -1, -3:, -3:] |
263 | | - |
264 | | - assert np.allclose( |
265 | | - original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 |
266 | | - ), "Fusion of QKV projections shouldn't affect the outputs." |
267 | | - assert np.allclose( |
268 | | - image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 |
269 | | - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." |
270 | | - assert np.allclose( |
271 | | - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 |
272 | | - ), "Original outputs should match when fused QKV projections are disabled." |
273 | | - |
274 | 234 |
|
275 | 235 | @slow |
276 | 236 | @require_torch_gpu |
|
0 commit comments