@@ -27,20 +27,15 @@ def lazy_initialization(self, key_states: torch.Tensor):
2727 assert (
2828 hasattr (key_states , "shape" ) and key_states is not None
2929 ), f"Attribute 'shape' is wrong for type { type (key_states )} "
30- assert isinstance (key_states .shape , tuple ), (
31- f"Unxpected type { type (key_states .shape )} for key_states.shape, "
32- f"__dict__={ key_states .shape .__dict__ } "
33- )
34- new_shape = list (key_states .shape )
35- new_shape [- 2 ] = 0
30+ like = key_states [:, :0 ]
3631 # PATCHED: used a tensor with an empty shape and not an empty list to initialize
3732 if isinstance (key_states , torch ._subclasses .fake_tensor .FakeTensor ):
3833 with key_states .fake_mode :
39- self .keys = torch .empty ( new_shape , dtype = self .dtype , device = self .device )
40- self .values = torch .empty ( new_shape , dtype = self .dtype , device = self .device )
34+ self .keys = torch .empty_like ( like , dtype = self .dtype , device = self .device )
35+ self .values = torch .empty_like ( like , dtype = self .dtype , device = self .device )
4136 else :
42- self .keys = torch .empty ( new_shape , dtype = self .dtype , device = self .device )
43- self .values = torch .empty ( new_shape , dtype = self .dtype , device = self .device )
37+ self .keys = torch .empty_like ( like , dtype = self .dtype , device = self .device )
38+ self .values = torch .empty_like ( like , dtype = self .dtype , device = self .device )
4439 if patch_is_initialized :
4540 self .is_initialized = True
4641
0 commit comments