@@ -2445,6 +2445,29 @@ def kernel():
24452445""" )
24462446
24472447
@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
def test_amd_mfma_scaled_none(target):
    """Passing None as a scale operand to cdna4.mfma_scaled must fail at compile time."""

    @gluon.jit
    def kernel():
        # 16x16x128 scaled-MFMA tile; the scale tensor is [16, 4] in a linear layout.
        mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(4, [16, 16, 128], True, [1, 1])
        scale_layout: ttgl.constexpr = ttgl.DistributedLinearLayout(
            [], [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], [], [], [16, 4])

        a = ttgl.full([16, 64], 0x11, ttgl.uint8, ttgl.DotOperandLayout(0, mfma_layout, 16))
        b = ttgl.full([64, 16], 0x22, ttgl.uint8, ttgl.DotOperandLayout(1, mfma_layout, 16))
        b_scale = ttgl.full([16, 4], 0x01, ttgl.uint8, scale_layout)
        acc = ttgl.full([16, 16], 0, ttgl.float32, mfma_layout)

        # The a-scale argument is deliberately None: that is the error under test.
        ttgl.amd.cdna4.mfma_scaled(a, None, 'e2m1', b, b_scale, 'e2m1', acc)

    with pytest.raises(CompilationError) as excinfo:
        run_parser(kernel, target=target)

    assert "Scales must not be None" in str(excinfo.value)
2469+
2470+
24482471@pytest .mark .parametrize ("target" , [HIP_TARGET_GFX1250 ])
24492472def test_amd_wmma_scaled (target ):
24502473
@@ -2497,6 +2520,32 @@ def kernel():
24972520""" )
24982521
24992522
@pytest.mark.parametrize("target", [HIP_TARGET_GFX1250])
def test_amd_wmma_scaled_none(target):
    """Passing None as a scale operand to gfx1250.wmma_scaled must fail at compile time."""

    @gluon.jit
    def kernel():
        # Accumulator uses the full 16x16x128 WMMA shape; the packed variant
        # (K = 64) is the parent layout for the uint8 dot operands.
        wmma_layout: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, True, [1, 1], [16, 16, 128])
        wmma_layout_packed: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, True, [1, 1], [16, 16, 64])
        scale_layout: ttgl.constexpr = ttgl.DistributedLinearLayout(
            [[0, 1], [0, 2]], [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], [], [], [16, 4])
        a_layout: ttgl.constexpr = ttgl.DotOperandLayout(0, wmma_layout_packed, 16)
        b_layout: ttgl.constexpr = ttgl.DotOperandLayout(1, wmma_layout_packed, 16)

        a = ttgl.full([16, 64], 0x11, ttgl.uint8, a_layout)
        b = ttgl.full([64, 16], 0x22, ttgl.uint8, b_layout)
        b_scale = ttgl.full([16, 4], 0x01, ttgl.uint8, scale_layout)
        acc = ttgl.full([16, 16], 0, ttgl.float32, wmma_layout)

        # The a-scale argument is deliberately None: that is the error under test.
        ttgl.amd.gfx1250.wmma_scaled(a, None, 'e2m1', b, b_scale, 'e2m1', acc)

    with pytest.raises(CompilationError) as excinfo:
        run_parser(kernel, target=target)

    assert "Scales must not be None" in str(excinfo.value)
2547+
2548+
25002549@gluon .jit
25012550def padded_shared_layout_kernel ():
25022551 shape : ttgl .constexpr = [64 , 64 ]
0 commit comments