|
1 | 1 | use cubecl::prelude::*; |
2 | 2 | use cubecl_core::{self as cubecl}; |
3 | | -use cubecl_std::tensor::r#virtual::{VirtualTensorOperations, VirtualTensorOperationsExpand}; |
| 3 | +use cubecl_std::{ |
| 4 | + CubeOption, CubeOptionExpand, |
| 5 | + tensor::r#virtual::{VirtualTensorOperations, VirtualTensorOperationsExpand}, |
| 6 | +}; |
4 | 7 |
|
5 | 8 | use crate::components::{ |
6 | 9 | line_size::AttentionLineSizes, problem::AttentionProblem, selection::AttentionSelection, |
@@ -88,17 +91,17 @@ pub trait AttentionArgs: Send + Sync + 'static + Clone { |
88 | 91 | /// Reinterpret query as tensor map |
89 | 92 | fn as_tensor_map_query<Q: Float, K: Float, V: Float, O: Float>( |
90 | 93 | state: &Self::State<Q, K, V, O>, |
91 | | - ) -> TensorMap<Q>; |
| 94 | + ) -> CubeOption<TensorMap<Q>>; |
92 | 95 |
|
93 | 96 | /// Reinterpret key as tensor map |
94 | 97 | fn as_tensor_map_key<Q: Float, K: Float, V: Float, O: Float>( |
95 | 98 | state: &Self::State<Q, K, V, O>, |
96 | | - ) -> TensorMap<K>; |
| 99 | + ) -> CubeOption<TensorMap<K>>; |
97 | 100 |
|
98 | 101 | /// Reinterpret value as tensor map |
99 | 102 | fn as_tensor_map_value<Q: Float, K: Float, V: Float, O: Float>( |
100 | 103 | state: &Self::State<Q, K, V, O>, |
101 | | - ) -> TensorMap<V>; |
| 104 | + ) -> CubeOption<TensorMap<V>>; |
102 | 105 |
|
103 | 106 | /// Write the line to the output at the given coordinate using the state. |
104 | 107 | fn write_out<Q: Float, K: Float, V: Float, O: Float>( |
@@ -290,11 +293,8 @@ impl<Q: Float, K: Float, V: Float, O: Float, MA: AttentionArgs> VirtualTensorOpe |
290 | 293 | TensorOutputExpand::__expand_buffer_len_method(self.clone(), scope) |
291 | 294 | } |
292 | 295 |
|
293 | | - fn __expand_as_tensor_map_method( |
294 | | - &self, |
295 | | - _scope: &mut Scope, |
296 | | - ) -> ExpandElementTyped<TensorMap<O>> { |
297 | | - unimplemented!("TensorOutputExpand can't be turned into a tensor map"); |
| 296 | + fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> CubeOptionExpand<TensorMap<O>> { |
| 297 | + CubeOption::__expand_new_None(scope) |
298 | 298 | } |
299 | 299 | } |
300 | 300 |
|
@@ -367,7 +367,7 @@ impl<Q: Float, K: Float, V: Float, O: Float, MA: AttentionArgs> VirtualTensorOpe |
367 | 367 | TensorQueryExpand::__expand_buffer_len_method(self.clone(), scope) |
368 | 368 | } |
369 | 369 |
|
370 | | - fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> ExpandElementTyped<TensorMap<Q>> { |
| 370 | + fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> CubeOptionExpand<TensorMap<Q>> { |
371 | 371 | TensorQueryExpand::__expand_as_tensor_map_method(self.clone(), scope) |
372 | 372 | } |
373 | 373 | } |
@@ -441,7 +441,7 @@ impl<Q: Float, K: Float, V: Float, O: Float, MA: AttentionArgs> VirtualTensorOpe |
441 | 441 | TensorKeyExpand::__expand_buffer_len_method(self.clone(), scope) |
442 | 442 | } |
443 | 443 |
|
444 | | - fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> ExpandElementTyped<TensorMap<K>> { |
| 444 | + fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> CubeOptionExpand<TensorMap<K>> { |
445 | 445 | TensorKeyExpand::__expand_as_tensor_map_method(self.clone(), scope) |
446 | 446 | } |
447 | 447 | } |
@@ -515,7 +515,7 @@ impl<Q: Float, K: Float, V: Float, O: Float, MA: AttentionArgs> VirtualTensorOpe |
515 | 515 | TensorValueExpand::__expand_buffer_len_method(self.clone(), scope) |
516 | 516 | } |
517 | 517 |
|
518 | | - fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> ExpandElementTyped<TensorMap<V>> { |
| 518 | + fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> CubeOptionExpand<TensorMap<V>> { |
519 | 519 | TensorValueExpand::__expand_as_tensor_map_method(self.clone(), scope) |
520 | 520 | } |
521 | 521 | } |
@@ -606,7 +606,7 @@ impl<Q: Float, K: Float, V: Float, O: Float, MA: AttentionArgs> TensorQuery<Q, K |
606 | 606 | } |
607 | 607 |
|
608 | 608 | /// Reinterpret the query as a tensor map; `None` when the args (e.g. `TensorArgs`) cannot provide one. |
609 | | - pub fn as_tensor_map(&self) -> TensorMap<Q> { |
| 609 | + pub fn as_tensor_map(&self) -> CubeOption<TensorMap<Q>> { |
610 | 610 | unsafe { MA::as_tensor_map_query(&(*self.state)) } |
611 | 611 | } |
612 | 612 |
|
@@ -660,7 +660,7 @@ impl<Q: Float, K: Float, V: Float, O: Float, MA: AttentionArgs> TensorKey<Q, K, |
660 | 660 | } |
661 | 661 |
|
662 | 662 | /// Reinterpret the key as a tensor map; `None` when the args (e.g. `TensorArgs`) cannot provide one. |
663 | | - pub fn as_tensor_map(&self) -> TensorMap<K> { |
| 663 | + pub fn as_tensor_map(&self) -> CubeOption<TensorMap<K>> { |
664 | 664 | unsafe { MA::as_tensor_map_key(&(*self.state)) } |
665 | 665 | } |
666 | 666 |
|
@@ -714,7 +714,7 @@ impl<Q: Float, K: Float, V: Float, O: Float, MA: AttentionArgs> TensorValue<Q, K |
714 | 714 | } |
715 | 715 |
|
716 | 716 | /// Reinterpret the value as a tensor map; `None` when the args (e.g. `TensorArgs`) cannot provide one. |
717 | | - pub fn as_tensor_map(&self) -> TensorMap<V> { |
| 717 | + pub fn as_tensor_map(&self) -> CubeOption<TensorMap<V>> { |
718 | 718 | unsafe { MA::as_tensor_map_value(&(*self.state)) } |
719 | 719 | } |
720 | 720 |
|
@@ -885,26 +885,20 @@ impl AttentionArgs for TensorArgs { |
885 | 885 |
|
886 | 886 | fn as_tensor_map_query<Q: Float, K: Float, V: Float, O: Float>( |
887 | 887 | _state: &Self::State<Q, K, V, O>, |
888 | | - ) -> TensorMap<Q> { |
889 | | - comptime!(unimplemented!("Can't use `TensorArgs` as `TensorMap`")); |
890 | | - #[allow(unreachable_code)] |
891 | | - TensorMap::dummy() |
| 888 | + ) -> CubeOption<TensorMap<Q>> { |
| 889 | + CubeOption::new_None() |
892 | 890 | } |
893 | 891 |
|
894 | 892 | fn as_tensor_map_key<Q: Float, K: Float, V: Float, O: Float>( |
895 | 893 | _state: &Self::State<Q, K, V, O>, |
896 | | - ) -> TensorMap<K> { |
897 | | - comptime!(unimplemented!("Can't use `TensorArgs` as `TensorMap`")); |
898 | | - #[allow(unreachable_code)] |
899 | | - TensorMap::dummy() |
| 894 | + ) -> CubeOption<TensorMap<K>> { |
| 895 | + CubeOption::new_None() |
900 | 896 | } |
901 | 897 |
|
902 | 898 | fn as_tensor_map_value<Q: Float, K: Float, V: Float, O: Float>( |
903 | 899 | _state: &Self::State<Q, K, V, O>, |
904 | | - ) -> TensorMap<V> { |
905 | | - comptime!(unimplemented!("Can't use `TensorArgs` as `TensorMap`")); |
906 | | - #[allow(unreachable_code)] |
907 | | - TensorMap::dummy() |
| 900 | + ) -> CubeOption<TensorMap<V>> { |
| 901 | + CubeOption::new_None() |
908 | 902 | } |
909 | 903 |
|
910 | 904 | fn shape_query<Q: Float, K: Float, V: Float, O: Float>( |
|
0 commit comments