@@ -193,7 +193,7 @@ def _create_r1_scheme(self) -> TransformScheme:
193193 randomize=self.randomize,
194194 requires_grad=self.learnable,
195195 precision=self.precision,
196-    block_size=self.transform_block_size,
196+    head_dim=self.transform_block_size,
197197 apply=[
198198 TransformArgs(
199199 targets=[
@@ -240,7 +240,7 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
240240 randomize=self.randomize,
241241 requires_grad=self.learnable,
242242 precision=self.precision,
243-    block_size=head_dim,
243+    head_dim=head_dim,
244244 apply=[
245245 TransformArgs(targets=[self.mappings.attn_v], location="weight_output"),
246246 TransformArgs(
@@ -262,7 +262,7 @@ def _create_r4_scheme(self) -> TransformScheme:
262262 randomize=self.randomize,
263263 requires_grad=self.learnable,
264264 precision=self.precision,
265-    block_size=self.transform_block_size,
265+    head_dim=self.transform_block_size,
266266 apply=[
267267 TransformArgs(
268268 targets=[*self.mappings.mlp_out],
0 commit comments