@@ -1225,29 +1225,35 @@ def load_weights(
             yield param_name
 
     def get_expert_weights(self) -> Iterable[torch.Tensor]:
-        def maybe_make_contiguous(name: str, p: torch.nn.Parameter) -> torch.nn.Parameter:
+        def _maybe_make_contiguous(
+            name: str, p: torch.nn.Parameter
+        ) -> torch.nn.Parameter:
             """
-            Expert weight-scales are transposed and are represented
-            in column-major. This function transposes the tensor back
-            so the tensor is contiguous().
+            In some cases, the last 2 dimensions (the non-expert dimensions)
+            of the weight scale tensor are transposed. This function transposes
+            the tensor back so the tensor is contiguous().
+            Example: a scale tensor
+            `x` of shape (E, 32, 16) and stride (512, 1, 32) is transposed to
+            `xt` of shape (E, 16, 32) and stride (512, 32, 1).
+            Note that we specifically use torch.transpose() so `xt` refers
+            to the same underlying memory. Because `x` and `xt` point to
+            the same storage, this transformation is safe in the
+            context of EPLB: it is the same memory and just the view
+            is different.
+            Note: this function handles the "weight_scale" tensors specifically.
+            It could, however, be generalized to handle similar tensors.
             """
-            if p.is_contiguous():
-                return p
-            if "weight_scale" not in name:
+            # Check if the last 2 dimensions are transposed.
+            is_transposed = p.ndim == 3 and p.stride(1) == 1 and p.stride(2) != 1
+            if p.is_contiguous() or not is_transposed or "weight_scale" not in name:
                 # do nothing.
                 return p
             assert p.ndim == 3
-            # Check if the tensor is tranposed
-            is_colmajor = p.size(1) == 1 and p.size(2) != 1
-            p = torch.transpose(p, 1, 2)
-            assert p.is_contiguous()
+            p.data = torch.transpose(p.data, 1, 2)
             return p
 
         weights = list(self.named_parameters())
-        weights = [(name, maybe_make_contiguous(name, p)) for name, p in weights]
-
-        #for name, weight in weights:
-        #    print(f"{name} is_contiguous() ? {weight.is_contiguous()}")
+        weights = [(name, _maybe_make_contiguous(name, p)) for name, p in weights]
 
         assert all(
             weight.is_contiguous()
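For illustration, here is a small standalone sketch (not part of the diff; the names `base`, `scale`, and `scale_t` and the shapes are made up) that reproduces the docstring's example: transposing the last two dimensions back with `torch.transpose()` yields a contiguous tensor that still shares storage with the original, which is why EPLB's in-place expert rearrangement sees the same memory.

```python
import torch

# Hypothetical example; only meant to mirror the shapes/strides from the
# docstring above, not the actual vLLM parameters.
E, M, N = 4, 16, 32

base = torch.randn(E, M, N)   # shape (4, 16, 32), stride (512, 32, 1), contiguous
scale = base.transpose(1, 2)  # shape (4, 32, 16), stride (512, 1, 32), NOT contiguous

assert not scale.is_contiguous()
assert scale.stride(1) == 1 and scale.stride(2) != 1  # the is_transposed check

# Transposing the last two dims back gives a contiguous view of the SAME storage.
scale_t = torch.transpose(scale, 1, 2)
assert scale_t.is_contiguous()
assert scale_t.data_ptr() == scale.data_ptr()  # no copy was made

print(scale_t.shape, scale_t.stride())  # torch.Size([4, 16, 32]) (512, 32, 1)
```

Note also that the new code assigns the transposed view to `p.data` rather than rebinding `p`, so the registered `torch.nn.Parameter` object keeps its identity while exposing the contiguous view.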