Skip to content

Commit b52c9ee

Browse files
committed
internal: use direct as_tensor_arg in internal launch paths; reserve try_* for FFI/tests
1 parent f72ad6b commit b52c9ee

File tree

7 files changed

+32
-76
lines changed

7 files changed

+32
-76
lines changed

crates/cubecl-attention/src/base.rs

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -114,17 +114,11 @@ pub fn launch_tmp<R: Runtime, AP: AttentionPrecision>(
114114
config.cube_dim(),
115115
cube_count_plan.resolve(),
116116
TensorInputsLaunch::new(
117-
query
118-
.try_as_tensor_arg(line_sizes.query)
119-
.expect("valid vectorisation for query"),
120-
key.try_as_tensor_arg(line_sizes.key)
121-
.expect("valid vectorisation for key"),
122-
value
123-
.try_as_tensor_arg(line_sizes.value)
124-
.expect("valid vectorisation for value"),
117+
query.as_tensor_arg(line_sizes.query),
118+
key.as_tensor_arg(line_sizes.key),
119+
value.as_tensor_arg(line_sizes.value),
125120
),
126-
out.try_as_tensor_arg(line_sizes.out)
127-
.expect("valid vectorisation for out"),
121+
out.as_tensor_arg(line_sizes.out),
128122
cube_count_plan.as_args(),
129123
config,
130124
);

crates/cubecl-attention/src/components/args.rs

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -527,14 +527,9 @@ impl<EG: Numeric> ConcreteInputsFactory for TensorInputs<EG> {
527527
line_sizes: &AttentionLineSizes,
528528
) -> Self::RuntimeArg<'a, R> {
529529
TensorInputsLaunch::new(
530-
query
531-
.try_as_tensor_arg(line_sizes.query)
532-
.expect("valid vectorisation for query"),
533-
key.try_as_tensor_arg(line_sizes.key)
534-
.expect("valid vectorisation for key"),
535-
value
536-
.try_as_tensor_arg(line_sizes.value)
537-
.expect("valid vectorisation for value"),
530+
query.as_tensor_arg(line_sizes.query),
531+
key.as_tensor_arg(line_sizes.key),
532+
value.as_tensor_arg(line_sizes.value),
538533
// mask.as_tensor_arg(line_sizes.value),
539534
)
540535
}
@@ -547,8 +542,7 @@ impl<EG: Numeric> ConcreteOutputFactory for Tensor<Line<EG>> {
547542
_problem: &AttentionProblem,
548543
line_sizes: &AttentionLineSizes,
549544
) -> Self::RuntimeArg<'a, R> {
550-
out.try_as_tensor_arg(line_sizes.out)
551-
.expect("valid vectorisation for out")
545+
out.as_tensor_arg(line_sizes.out)
552546
}
553547
}
554548

crates/cubecl-convolution/src/components/global/args.rs

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -38,23 +38,11 @@ impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory for TensorIn
3838
line_sizes: &MatmulLineSizes,
3939
) -> Self::RuntimeArg<'a, R> {
4040
TensorInputsLaunch::new(
41-
lhs.data()
42-
.try_as_tensor_arg(line_sizes.lhs)
43-
.expect("valid vec lhs"),
44-
lhs.scale()
45-
.map(|it| it.try_as_tensor_arg(1).expect("vec=1"))
46-
.into(),
47-
rhs.data()
48-
.try_as_tensor_arg(line_sizes.rhs)
49-
.expect("valid vec rhs"),
50-
rhs.scale()
51-
.map(|it| it.try_as_tensor_arg(1).expect("vec=1"))
52-
.into(),
53-
bias.map(|it| {
54-
it.try_as_tensor_arg(line_sizes.out)
55-
.expect("valid vec out")
56-
})
57-
.into(),
41+
lhs.data().as_tensor_arg(line_sizes.lhs),
42+
lhs.scale().map(|it| it.as_tensor_arg(1)).into(),
43+
rhs.data().as_tensor_arg(line_sizes.rhs),
44+
rhs.scale().map(|it| it.as_tensor_arg(1)).into(),
45+
bias.map(|it| it.as_tensor_arg(line_sizes.out)).into(),
5846
)
5947
}
6048
}
@@ -116,9 +104,7 @@ impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
116104
channels_per_pixel: tile_size_k,
117105
pixels_per_column: stage_m,
118106
},
119-
lhs.data()
120-
.try_as_tensor_arg(line_sizes.lhs)
121-
.expect("valid vec lhs"),
107+
lhs.data().as_tensor_arg(line_sizes.lhs),
122108
lhs_elem,
123109
)
124110
.with_elem_stride(elem_stride)
@@ -128,15 +114,12 @@ impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
128114
TensorMapFormat::Tiled {
129115
tile_size: stage_size_rhs,
130116
},
131-
rhs.data().try_as_tensor_arg(1).expect("vec=1"),
117+
rhs.data().as_tensor_arg(1),
132118
Rhs::as_type_native_unchecked(),
133119
)
134120
.with_prefetch(prefetch_rhs);
135121

136-
let bias = bias.map(|it| {
137-
it.try_as_tensor_arg(line_sizes.out)
138-
.expect("valid vec out")
139-
});
122+
let bias = bias.map(|it| it.as_tensor_arg(line_sizes.out));
140123

141124
// TODO: Think about how to handle scales with TMA
142125
TensorMapInputsLaunch::new(lhs, rhs, bias.into())

crates/cubecl-matmul/src/components/global/args.rs

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -836,17 +836,15 @@ impl<Lhs: Numeric, Rhs: Numeric, Acc: Numeric> ConcreteInputsFactory
836836
line_sizes: &MatmulLineSizes,
837837
) -> Self::RuntimeArg<'a, R> {
838838
TensorInputsLaunch::new(
839-
lhs.data()
840-
.try_as_tensor_arg(line_sizes.lhs)
841-
.expect("valid vec for lhs"),
842-
lhs.scale()
843-
.map(|it| it.try_as_tensor_arg(1).expect("vec=1"))
839+
lhs.data().as_tensor_arg(line_sizes.lhs),
840+
lhs
841+
.scale()
842+
.map(|it| it.as_tensor_arg(1))
844843
.into(),
845-
rhs.data()
846-
.try_as_tensor_arg(line_sizes.rhs)
847-
.expect("valid vec for rhs"),
848-
rhs.scale()
849-
.map(|it| it.try_as_tensor_arg(1).expect("vec=1"))
844+
rhs.data().as_tensor_arg(line_sizes.rhs),
845+
rhs
846+
.scale()
847+
.map(|it| it.as_tensor_arg(1))
850848
.into(),
851849
CubeOptionArgs::None,
852850
)
@@ -860,8 +858,7 @@ impl<EG: Numeric> ConcreteOutputFactory for Tensor<Line<EG>> {
860858
_problem: &MatmulProblem,
861859
line_sizes: &MatmulLineSizes,
862860
) -> Self::RuntimeArg<'a, R> {
863-
out.try_as_tensor_arg(line_sizes.out)
864-
.expect("valid vec for out")
861+
out.as_tensor_arg(line_sizes.out)
865862
}
866863
}
867864

@@ -1274,15 +1271,11 @@ impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
12741271
};
12751272

12761273
let lhs = TensorMapArg {
1277-
tensor: lhs
1278-
.try_as_tensor_arg(line_sizes.lhs)
1279-
.expect("valid vec for lhs"),
1274+
tensor: lhs.as_tensor_arg(line_sizes.lhs),
12801275
metadata: meta_lhs,
12811276
};
12821277
let rhs = TensorMapArg {
1283-
tensor: rhs
1284-
.try_as_tensor_arg(line_sizes.rhs)
1285-
.expect("valid vec for rhs"),
1278+
tensor: rhs.as_tensor_arg(line_sizes.rhs),
12861279
metadata: meta_rhs,
12871280
};
12881281

crates/cubecl-reduce/src/launch.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,8 @@ pub(crate) fn launch_reduce<Run: Runtime, P: ReducePrecision, Out: Numeric, Rd:
4343
client,
4444
config.cube_count,
4545
config.cube_dim,
46-
input
47-
.try_as_tensor_arg(config.line_size_input as u8)
48-
.expect("valid reduce input vec"),
49-
output
50-
.try_as_tensor_arg(config.line_size_output as u8)
51-
.expect("valid reduce output vec"),
46+
input.as_tensor_arg(config.line_size_input as u8),
47+
output.as_tensor_arg(config.line_size_output as u8),
5248
ScalarArg::new(axis),
5349
settings,
5450
inst,

crates/cubecl-reduce/src/shared_sum.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,8 @@ pub fn shared_sum<R: Runtime, N: Numeric + CubeElement>(
8989
client,
9090
cube_count,
9191
cube_dim,
92-
input
93-
.try_as_tensor_arg(line_size as u8)
94-
.expect("valid vec"),
95-
output.try_as_tensor_arg(1).expect("vec=1"),
92+
input.as_tensor_arg(line_size as u8),
93+
output.as_tensor_arg(1),
9694
cube_dim.num_elems(),
9795
line_size,
9896
num_lines_per_unit,

crates/cubecl-std/src/tensor/contiguous.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,7 @@ pub fn into_contiguous_ref<R: Runtime, E: CubePrimitive>(
250250
cube_count,
251251
cube_dim,
252252
input,
253-
output
254-
.try_as_tensor_arg(out_vec)
255-
.expect("valid vec for output"),
253+
output.as_tensor_arg(out_vec),
256254
out_layout,
257255
elems_per_unit,
258256
);

0 commit comments

Comments
 (0)