
Commit a85222c (1 parent 018341e)

fix: Fix packed fp4 casting from int (#890)

* Fix packed fp4 casting from int
* Fix matmul and conv2d test launch

4 files changed (+33, −15)
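For context: fp4 values are 4 bits wide, so the CUDA backend stores them packed, with the output item carrying a packing factor (two values per byte for fp4). A minimal sketch of that nibble packing, using plain u8 codes and hypothetical helper names rather than the backend's actual fp4 conversion:

```rust
/// Hypothetical illustration: two 4-bit codes share one byte, so a packed
/// fp4 line of width 2 is physically a single u8.
fn pack_two_nibbles(lo: u8, hi: u8) -> u8 {
    debug_assert!(lo < 16 && hi < 16, "each code must fit in 4 bits");
    (hi << 4) | lo
}

fn unpack_two_nibbles(byte: u8) -> (u8, u8) {
    (byte & 0x0F, byte >> 4)
}

fn main() {
    let packed = pack_two_nibbles(0b0011, 0b1010);
    assert_eq!(unpack_two_nibbles(packed), (0b0011, 0b1010));
    println!("packed byte: {packed:#010b}");
}
```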

crates/cubecl-convolution/src/tests/convolution_test_launcher.rs

Lines changed: 1 addition & 1 deletion
```diff
@@ -61,7 +61,7 @@ pub fn test_convolution_algorithm<A, Args, P, R>(
         .pick_max()
         .unwrap();
 
-    let config = match A::setup::<R, (P::EG, P::EG, P::ES, P::ES, f32, P::EG)>(
+    let config = match A::setup::<R, (P::EG, P::EG, P::EG, P::ES, P::ES, f32)>(
         &client,
         &problem,
         &selection,
```

crates/cubecl-cpp/src/cuda/convert.rs

Lines changed: 18 additions & 3 deletions
```diff
@@ -57,6 +57,21 @@ pub(crate) fn special_cast<D: Dialect>(
         current_in = out_var;
     }
 
+    // Broadcast scalars to packing factor
+    if out.item().packing_factor() > 1 && input.item().vectorization == 1 {
+        let tmp = Variable::tmp(Item {
+            elem: input.item().elem,
+            vectorization: out.item().packing_factor(),
+            native: input.item().native,
+        });
+        let assign = Instruction::Assign(UnaryInstruction {
+            input: current_in,
+            out: tmp,
+        });
+        writeln!(f, "{assign}")?;
+        current_in = tmp;
+    }
+
     if matches!(
         current_in.elem(),
         Elem::U8
@@ -72,8 +87,8 @@ pub(crate) fn special_cast<D: Dialect>(
         // Precision is irrelevant for int, so use bf16 for the range
         let tmp = Variable::tmp(Item {
             elem: Elem::BF16,
-            vectorization: input.item().vectorization,
-            native: input.item().native,
+            vectorization: current_in.item().vectorization,
+            native: current_in.item().native,
         });
         let assign = Instruction::Assign(UnaryInstruction {
             input: current_in,
@@ -137,7 +152,7 @@ fn cast_to_fp4_fp6<D: Dialect>(
         _ => unreachable!("Must be fp4 or fp6"),
     };
 
-    let in_ty = match input.elem() {
+    let in_ty = match input.elem().unpacked() {
         Elem::F64 => format!("double{pack_suffix}"),
         Elem::TF32 | Elem::F32 => format!("float{pack_suffix}"),
         Elem::F16 => format!("halfraw{pack_suffix}"),
```
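The heart of the fix is the broadcast step added above: when the output item is packed (packing_factor() > 1) but the input is an unvectorized scalar, the scalar has to be replicated across the packing lanes before the packed conversion runs, otherwise the converter is handed fewer values than it packs. A rough standalone sketch of that idea, with hypothetical names and an ordinary f32 intermediate standing in for the bf16 range cast used for ints:

```rust
/// Hypothetical sketch mirroring the shape of the special_cast fix
/// (not the generated CUDA): replicate a scalar across `packing` lanes
/// and route it through a float intermediate before a packed conversion.
fn broadcast_for_packed_cast(value: i32, packing: usize) -> Vec<f32> {
    // One lane per packed element; the real backend emits a vector assign.
    vec![value as f32; packing]
}

fn main() {
    // A scalar int cast into a 2-wide packed fp4 output needs 2 lanes.
    assert_eq!(broadcast_for_packed_cast(3, 2), vec![3.0, 3.0]);
}
```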

crates/cubecl-cpp/src/shared/base.rs

Lines changed: 13 additions & 5 deletions
```diff
@@ -1136,18 +1136,26 @@ impl<D: Dialect> CppCompiler<D> {
             gpu::Operator::Cast(op)
                 if is_fp4_fp6_fp8(op.input.elem_type()) || is_fp4_fp6_fp8(out.elem_type()) =>
             {
-                let inst = self.compile_unary(op, out);
-
                 // We may need these for intermediates
                 self.flags.elem_f16 = true;
                 self.flags.elem_bf16 = true;
-                let vec = inst.input.item().vectorization as u32;
+                let vec_in = op.input.ty.line_size();
+                let packing = out.storage_type().packing_factor();
+                self.compile_type(op.input.ty.line(packing));
+                self.compile_type(
+                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(vec_in),
+                );
                 self.compile_type(
-                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(vec),
+                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(vec_in),
                 );
                 self.compile_type(
-                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(vec),
+                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::F16)).line(packing),
                 );
+                self.compile_type(
+                    gpu::Type::scalar(gpu::ElemType::Float(FloatKind::BF16)).line(packing),
+                );
+
+                let inst = self.compile_unary(op, out);
 
                 instructions.push(Instruction::SpecialCast(inst));
             }
```
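The compiler change above follows from that broadcast: the f16/bf16 intermediates can now appear at both the input line size and the output packing factor, and their vector types have to be registered before compile_unary emits the cast so the typedefs exist in the generated source. A hedged sketch of that kind of bookkeeping, with invented types rather than CubeCL's actual compiler state:

```rust
use std::collections::HashSet;

/// Hypothetical stand-in for a compiler's "types that need a typedef" set.
#[derive(Default)]
struct TypeRegistry {
    // (element name, line width) pairs to declare in the generated source.
    lines: HashSet<(&'static str, u8)>,
}

impl TypeRegistry {
    fn register(&mut self, elem: &'static str, width: u8) {
        self.lines.insert((elem, width));
    }

    /// Register the half-precision intermediates a packed cast may produce,
    /// at both the input line size and the output packing factor.
    fn register_cast_intermediates(&mut self, vec_in: u8, packing: u8) {
        for width in [vec_in, packing] {
            self.register("f16", width);
            self.register("bf16", width);
        }
    }
}

fn main() {
    let mut registry = TypeRegistry::default();
    // e.g. a line-size-1 int input cast into a packing-factor-2 fp4 output.
    registry.register_cast_intermediates(1, 2);
    assert!(registry.lines.contains(&("bf16", 2)));
}
```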

crates/cubecl-matmul/src/tests/layered/matmul_test_launcher.rs

Lines changed: 1 addition & 6 deletions
```diff
@@ -60,12 +60,7 @@ pub fn test_matmul_algorithm<A, P, R>(
         .pick_max()
         .unwrap();
 
-    let config = match A::setup::<(P::EG, P::EG, P::ES, P::ES, P::EA, P::EG), R>(
-        &client,
-        &problem,
-        &selection,
-        &line_sizes,
-    ) {
+    let config = match A::setup::<P::MP, R>(&client, &problem, &selection, &line_sizes) {
         Ok(config) => config,
         Err(err) => {
             let msg = format!("Can't launch the test: {err}");
```
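The launcher fix swaps the hand-written precision tuple for the test precision's own associated type, so the slot order is defined in one place instead of being re-assembled (and potentially misordered, as in the convolution launcher above) at every call site. A hypothetical sketch of that pattern; the trait and type names here are illustrative, not CubeCL's actual definitions:

```rust
/// Hypothetical: a test precision bundles its element types once...
trait TestPrecision {
    /// ...and exposes the full matmul precision as a single associated type,
    /// so call sites cannot re-assemble the tuple in the wrong order.
    type MP;
}

struct F32Precision;

impl TestPrecision for F32Precision {
    // Global, stage, and accumulator element types grouped in one place.
    type MP = (f32, f32, f32);
}

/// A setup-like function that receives the bundled precision rather than
/// individual element types listed at the call site.
fn setup<MP>() -> &'static str {
    std::any::type_name::<MP>()
}

fn main() {
    // The launcher just forwards P::MP; the ordering lives in the impl.
    println!("{}", setup::<<F32Precision as TestPrecision>::MP>());
}
```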
