@@ -2,61 +2,156 @@ use std::any::TypeId;
22
33use cubecl:: prelude:: * ;
44use cubecl_core as cubecl;
5+ use cubecl_std:: {
6+ FastDivmodArgs ,
7+ tensor:: {
8+ View ,
9+ launch:: ViewArg ,
10+ layout:: {
11+ Coords3d ,
12+ chain:: { Chain , ChainLaunch } ,
13+ } ,
14+ } ,
15+ } ;
516
617use crate :: {
7- components:: ConvolutionProblem ,
18+ components:: {
19+ ConvGemmConfig , ConvolutionProblem ,
20+ global:: {
21+ layout:: {
22+ BiasLayout , BiasLayoutLaunch , Im2colLayout , Im2colLayoutLaunch , NhwcLayout ,
23+ NhwcLayoutLaunch , OutLayout , OutLayoutLaunch , WeightLayout , WeightLayoutLaunch ,
24+ } ,
25+ read:: layout:: {
26+ TmaDummyLayout , TmaDummyLayoutLaunch , TmaWeightLayout , TmaWeightLayoutLaunch ,
27+ } ,
28+ } ,
29+ } ,
830 kernels:: layered:: algorithm:: simple_tma:: { calculate_lower_corner, calculate_upper_corner} ,
931} ;
1032use cubecl_matmul:: {
1133 MatmulInputHandleRef ,
1234 components:: {
13- MatmulLineSizes , MatmulSelection ,
35+ MatmulIdent , MatmulLineSizes , MatmulSelection ,
1436 global:: args:: { TensorInputs , TensorInputsLaunch , TensorMapInputs , TensorMapInputsLaunch } ,
1537 } ,
1638} ;
1739
/// Create the input runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteInputsFactory: LaunchArg {
    /// Build the launch-time runtime argument for the kernel inputs.
    ///
    /// * `client` - compute client used to precompute launch metadata on the host.
    /// * `lhs` / `rhs` - handles to the input tensors (data plus optional scale).
    /// * `bias` - optional bias tensor handle.
    /// * `selection` - the chosen matmul tiling selection.
    /// * `problem` - the convolution problem description (shapes, strides, padding).
    /// * `line_sizes` - vectorization line sizes for lhs/rhs/out.
    /// * `config` - the convolution GEMM config (layouts, bounds-check flags).
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server, R::Channel>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
    ) -> Self::RuntimeArg<'a, R>;
}
55+
/// Create the output runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteOutputFactory: LaunchArg {
    /// Build the launch-time runtime argument for the kernel output.
    ///
    /// * `client` - compute client used to precompute launch metadata on the host.
    /// * `out` - handle to the output tensor.
    /// * `selection` - the chosen matmul tiling selection.
    /// * `problem` - the convolution problem description (shapes, strides, padding).
    /// * `line_sizes` - vectorization line sizes for lhs/rhs/out.
    /// * `config` - the convolution GEMM config (layouts, bounds-check flags).
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server, R::Channel>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
    ) -> Self::RuntimeArg<'a, R>;
}
3068
3169impl < Lhs : Numeric , Rhs : Numeric , EO : Numeric > ConcreteInputsFactory for TensorInputs < Lhs , Rhs , EO > {
3270 fn create < ' a , R : Runtime > (
71+ client : & ComputeClient < R :: Server , R :: Channel > ,
3372 lhs : & ' a MatmulInputHandleRef < ' a , R > ,
3473 rhs : & ' a MatmulInputHandleRef < ' a , R > ,
3574 bias : Option < & ' a TensorHandleRef < ' a , R > > ,
3675 _selection : & MatmulSelection ,
37- _problem : & ConvolutionProblem ,
76+ problem : & ConvolutionProblem ,
3877 line_sizes : & MatmulLineSizes ,
78+ config : impl ConvGemmConfig ,
3979 ) -> Self :: RuntimeArg < ' a , R > {
80+ type LhsLayout = Chain < NhwcLayout , Im2colLayout > ;
81+ type RhsLayout = Chain < NhwcLayout , WeightLayout > ;
82+
83+ let layout_nhwc = |handle, line_size, check| {
84+ NhwcLayoutLaunch :: from_handle ( handle, line_size as u32 , check)
85+ } ;
86+ let layout_lhs = Im2colLayoutLaunch :: from_args (
87+ client,
88+ problem,
89+ config. convolution_params ( ) ,
90+ config. global_memory_config ( MatmulIdent :: Lhs ) ,
91+ ) ;
92+ let layout_rhs = WeightLayoutLaunch :: from_args (
93+ client,
94+ problem,
95+ config. convolution_params ( ) ,
96+ config. global_memory_config ( MatmulIdent :: Rhs ) ,
97+ ) ;
98+ let layout_bias =
99+ BiasLayoutLaunch :: new ( ScalarArg :: new ( problem. n as u32 ) , line_sizes. out as u32 ) ;
100+
101+ let layout_lhs = {
102+ let global = layout_nhwc ( lhs. data ( ) , line_sizes. lhs , config. check_spatial_bounds ( ) ) ;
103+ ChainLaunch :: new ( global, layout_lhs)
104+ } ;
105+ let layout_rhs = {
106+ let global = layout_nhwc ( rhs. data ( ) , line_sizes. rhs , false ) ;
107+ ChainLaunch :: new ( global, layout_rhs)
108+ } ;
109+
40110 TensorInputsLaunch :: new (
41- lhs. data ( ) . as_tensor_arg ( line_sizes. lhs ) ,
42- lhs. scale ( ) . map ( |it| it. as_tensor_arg ( 1 ) ) . into ( ) ,
43- rhs. data ( ) . as_tensor_arg ( line_sizes. rhs ) ,
44- rhs. scale ( ) . map ( |it| it. as_tensor_arg ( 1 ) ) . into ( ) ,
45- bias. map ( |it| it. as_tensor_arg ( line_sizes. out ) ) . into ( ) ,
111+ ViewArg :: new :: < LhsLayout > ( lhs. data ( ) . as_array_arg ( line_sizes. lhs ) , layout_lhs) ,
112+ ViewArg :: new :: < RhsLayout > ( rhs. data ( ) . as_array_arg ( line_sizes. rhs ) , layout_rhs) ,
113+ bias. map ( |bias| {
114+ ViewArg :: new :: < BiasLayout > ( bias. as_array_arg ( line_sizes. out ) , layout_bias)
115+ } )
116+ . into ( ) ,
46117 )
47118 }
48119}
49120
121+ impl < EG : Numeric > ConcreteOutputFactory for View < Line < EG > , Coords3d , ReadWrite > {
122+ fn create < ' a , R : Runtime > (
123+ client : & ComputeClient < R :: Server , R :: Channel > ,
124+ out : & ' a TensorHandleRef < ' a , R > ,
125+ _selection : & MatmulSelection ,
126+ problem : & ConvolutionProblem ,
127+ line_sizes : & MatmulLineSizes ,
128+ config : impl ConvGemmConfig ,
129+ ) -> Self :: RuntimeArg < ' a , R > {
130+ type Layout = Chain < NhwcLayout , OutLayout > ;
131+
132+ let global = NhwcLayoutLaunch :: from_handle ( out, line_sizes. out as u32 , false ) ;
133+ let layout = OutLayoutLaunch :: from_args (
134+ client,
135+ problem,
136+ config. global_memory_config ( MatmulIdent :: Out ) ,
137+ ) ;
138+ let layout = ChainLaunch :: new ( global, layout) ;
139+ ViewArg :: new :: < Layout > ( out. as_array_arg ( line_sizes. out ) , layout)
140+ }
141+ }
142+
50143impl < Lhs : Numeric , Rhs : Numeric , EO : Numeric > ConcreteInputsFactory
51144 for TensorMapInputs < Lhs , Rhs , EO >
52145{
53146 fn create < ' a , R : Runtime > (
147+ client : & ComputeClient < R :: Server , R :: Channel > ,
54148 lhs : & ' a MatmulInputHandleRef < ' a , R > ,
55149 rhs : & ' a MatmulInputHandleRef < ' a , R > ,
56150 bias : Option < & ' a TensorHandleRef < ' a , R > > ,
57151 selection : & MatmulSelection ,
58152 problem : & ConvolutionProblem ,
59153 line_sizes : & MatmulLineSizes ,
154+ config : impl ConvGemmConfig ,
60155 ) -> Self :: RuntimeArg < ' a , R > {
61156 let tiling_scheme = selection. tiling_scheme ;
62157 let stage_m = tiling_scheme. elements_in_stage_m ( ) ;
@@ -119,9 +214,23 @@ impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
119214 )
120215 . with_prefetch ( prefetch_rhs) ;
121216
122- let bias = bias. map ( |it| it. as_tensor_arg ( line_sizes. out ) ) ;
217+ let padded_channels =
218+ ( problem. channels as u32 ) . next_multiple_of ( config. tiling_scheme ( ) . elements_in_tile_k ( ) ) ;
219+
220+ // Dummy layout since we don't support im2col loading rn
221+ let lhs_layout = TmaDummyLayoutLaunch :: new ( ) ;
222+ let rhs_layout = TmaWeightLayoutLaunch :: new ( FastDivmodArgs :: new ( client, padded_channels) ) ;
123223
124- // TODO: Think about how to handle scales with TMA
125- TensorMapInputsLaunch :: new ( lhs, rhs, bias. into ( ) )
224+ let bias = bias. map ( |bias| {
225+ let layout =
226+ BiasLayoutLaunch :: new ( ScalarArg :: new ( problem. n as u32 ) , line_sizes. out as u32 ) ;
227+ ViewArg :: new :: < BiasLayout > ( bias. as_array_arg ( line_sizes. out ) , layout)
228+ } ) ;
229+
230+ TensorMapInputsLaunch :: new (
231+ ViewArg :: new_tensor_map :: < TmaDummyLayout > ( lhs, lhs_layout) ,
232+ ViewArg :: new_tensor_map :: < TmaWeightLayout > ( rhs, rhs_layout) ,
233+ bias. into ( ) ,
234+ )
126235 }
127236}
0 commit comments