@@ -28,6 +28,7 @@ pub fn launch<R: Runtime, AP: AttentionPrecision>(
2828 query : TensorHandle < R , QG < AP > > ,
2929 key : TensorHandle < R , KG < AP > > ,
3030 value : TensorHandle < R , VG < AP > > ,
31+ mask : Option < TensorHandle < R , MSK < AP > > > ,
3132 out : TensorHandle < R , OG < AP > > ,
3233) -> Result < ( ) , AttentionSetupError > {
3334 launch_ref :: < R , AP > (
@@ -36,6 +37,7 @@ pub fn launch<R: Runtime, AP: AttentionPrecision>(
3637 & query. as_ref ( ) ,
3738 & key. as_ref ( ) ,
3839 & value. as_ref ( ) ,
40+ & mask. as_ref ( ) . map ( |m| m. as_ref ( ) ) ,
3941 & out. as_ref ( ) ,
4042 )
4143}
@@ -47,10 +49,11 @@ pub fn launch_ref<R: Runtime, AP: AttentionPrecision>(
4749 query : & TensorHandleRef < R > ,
4850 key : & TensorHandleRef < R > ,
4951 value : & TensorHandleRef < R > ,
52+ mask : & Option < TensorHandleRef < R > > ,
5053 out : & TensorHandleRef < R > ,
5154) -> Result < ( ) , AttentionSetupError > {
5255 match strategy {
53- Strategy :: Tmp => launch_tmp :: < R , AP > ( client, query, key, value, out) ,
56+ Strategy :: Tmp => launch_tmp :: < R , AP > ( client, query, key, value, mask , out) ,
5457 }
5558}
5659
@@ -59,6 +62,7 @@ pub fn launch_tmp<R: Runtime, AP: AttentionPrecision>(
5962 query : & TensorHandleRef < R > ,
6063 key : & TensorHandleRef < R > ,
6164 value : & TensorHandleRef < R > ,
65+ mask : & Option < TensorHandleRef < R > > ,
6266 out : & TensorHandleRef < R > ,
6367) -> Result < ( ) , AttentionSetupError > {
6468 let line_sizes = AvailableLineSizes :: from_elem_types :: < R > (
@@ -81,7 +85,8 @@ pub fn launch_tmp<R: Runtime, AP: AttentionPrecision>(
8185 num_heads : query. shape [ 2 ] ,
8286 head_dim : query. shape [ 3 ] ,
8387 val_dim : value. shape [ 3 ] ,
84- masked : false ,
88+ masked : mask. is_some ( ) ,
89+ causal : false ,
8590 } ;
8691
8792 let tile_size = AttentionTileSize {
@@ -123,6 +128,9 @@ pub fn launch_tmp<R: Runtime, AP: AttentionPrecision>(
123128 query. as_tensor_arg ( line_sizes. query ) ,
124129 key. as_tensor_arg ( line_sizes. key ) ,
125130 value. as_tensor_arg ( line_sizes. value ) ,
131+ mask. as_ref ( )
132+ . map ( |it| it. as_tensor_arg ( line_sizes. out ) )
133+ . into ( ) ,
126134 ) ,
127135 out. as_tensor_arg ( line_sizes. out ) ,
128136 cube_count_plan. as_args ( ) ,
0 commit comments