@@ -1396,13 +1396,13 @@ def _maybe_make_contiguous(
 ) -> torch.nn.Parameter:
     """
     In some cases, the last 2 dimensions (the non-expert dimensions)
-    of the weight scale tensor are transposed. This function transposes
-    the tensor back so the tensor is contiguous().
-    Example: A scale tensor,
-    `x` of shape (E, 32, 16) and stride (512, 1, 32) is transposed to
-    `xt` of shape (E, 16, 32) and stride (512, 32, 1).
-    Note that we specifically use torch.transpose() so `xt` refers
-    to the same underlying memory. The tensors `x` and `xt`, pointing
+    of the weight scale tensor are transposed. This function
+    transforms the tensor (view update) so the tensor is contiguous().
+    Example: A non-contiguous scale tensor,
+    `x` of shape (E, 32, 16) and stride (512, 1, 32) is transformed to
+    `x_` of shape (E, 16, 32) and stride (512, 32, 1).
+    Note that we specifically use torch.transpose() so `x_` refers
+    to the same underlying memory. The tensors `x` and `x_`, pointing
     to the same underlying memory make this transformation safe in the
     context of EPLB. i.e. It is the same memory and just the view
     is different.
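The shape/stride arithmetic in the docstring can be checked directly. This is a minimal sketch, not the function from the diff: `E = 4` and the `base` tensor are assumptions chosen so the strides match the docstring's example `(512, 1, 32)` and `(512, 32, 1)`; it demonstrates that `torch.transpose()` returns a view of the same storage, which is what makes the operation safe here.

```python
import torch

# Assumed setup: E=4 experts, so a contiguous (E, 16, 32) tensor has
# strides (512, 32, 1), matching the docstring's example numbers.
E = 4
base = torch.arange(E * 16 * 32, dtype=torch.float32).reshape(E, 16, 32)

# Transpose the last two dims to mimic the non-contiguous scale tensor `x`.
x = base.transpose(1, 2)
assert x.shape == (E, 32, 16)
assert x.stride() == (512, 1, 32)
assert not x.is_contiguous()

# Transposing back yields a contiguous view `x_`, as in the docstring.
x_ = torch.transpose(x, 1, 2)
assert x_.shape == (E, 16, 32)
assert x_.stride() == (512, 32, 1)
assert x_.is_contiguous()

# `x` and `x_` share the same underlying memory: a write through one
# is visible through the other, and the data pointers are equal.
x_[0, 0, 0] = -1.0
assert x[0, 0, 0] == -1.0
assert x_.data_ptr() == x.data_ptr()
```

Because `x_` is a view rather than a copy, any consumer still holding `x` (e.g. EPLB weight-shuffling code) keeps observing the same memory; only the indexing metadata differs.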