Skip to content

Commit 1b06020

Browse files
committed
remove qkv fusion tets
1 parent 2158f00 commit 1b06020

File tree

1 file changed

+0
-40
lines changed

1 file changed

+0
-40
lines changed

tests/pipelines/cogview3/test_cogview3plus.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@
3232
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
3333
from ..test_pipelines_common import (
3434
PipelineTesterMixin,
35-
check_qkv_fusion_matches_attn_procs_length,
36-
check_qkv_fusion_processors_exist,
3735
to_np,
3836
)
3937

@@ -233,44 +231,6 @@ def test_attention_slicing_forward_pass(
233231
"Attention slicing should not affect the inference results",
234232
)
235233

236-
def test_fused_qkv_projections(self):
237-
device = "cpu" # ensure determinism for the device-dependent torch.Generator
238-
components = self.get_dummy_components()
239-
pipe = self.pipeline_class(**components)
240-
pipe = pipe.to(device)
241-
pipe.set_progress_bar_config(disable=None)
242-
243-
inputs = self.get_dummy_inputs(device)
244-
images = pipe(**inputs)[0] # [B, C, H, W]
245-
original_image_slice = images[0, -1, -3:, -3:]
246-
247-
pipe.fuse_qkv_projections()
248-
assert check_qkv_fusion_processors_exist(
249-
pipe.transformer
250-
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
251-
assert check_qkv_fusion_matches_attn_procs_length(
252-
pipe.transformer, pipe.transformer.original_attn_processors
253-
), "Something wrong with the attention processors concerning the fused QKV projections."
254-
255-
inputs = self.get_dummy_inputs(device)
256-
images = pipe(**inputs)[0]
257-
image_slice_fused = images[0, -1, -3:, -3:]
258-
259-
pipe.transformer.unfuse_qkv_projections()
260-
inputs = self.get_dummy_inputs(device)
261-
images = pipe(**inputs)[0]
262-
image_slice_disabled = images[0, -1, -3:, -3:]
263-
264-
assert np.allclose(
265-
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
266-
), "Fusion of QKV projections shouldn't affect the outputs."
267-
assert np.allclose(
268-
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
269-
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
270-
assert np.allclose(
271-
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
272-
), "Original outputs should match when fused QKV projections are disabled."
273-
274234

275235
@slow
276236
@require_torch_gpu

0 commit comments

Comments
 (0)