This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit b1b5364

William Fedus authored and Mesh TensorFlow Team committed
Clarify names within Mesh Tensorflow
PiperOrigin-RevId: 352603360
1 parent 914bb1f commit b1b5364

File tree: 3 files changed (+55, -56 lines)
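In short, the commit renames the mixture-of-experts gating options so they match the published technique names: the former rand_1 (random top-1) gating becomes "switch" (Switch Transformer gating), and the former "switch" gating, which adds no-token-left-behind re-routing, becomes "ntlb". A summary of the renames, collected from the three diffs below (the moe_-prefixed HParams keys are renamed the same way):

    # Identifier renames introduced by this commit (collected from the diffs below).
    RENAMES = {
        "rand_1_policy_train": "switch_policy_train",
        "rand_1_policy_eval": "switch_policy_eval",
        "rand_1_dropout": "switch_dropout",
        "rand_1_temperature": "switch_temperature",
        "rand_1_jitter": "switch_jitter",
        "switch_top_k": "ntlb_top_k",        # in moe.py; the other two files keep switch_top_k
        "_rand_1_gating": "_switch_gating",  # function: formerly "random top-1" gating
        "_switch_gating": "_ntlb_gating",    # function: no-token-left-behind gating
        'moe_gating="rand_1"': 'moe_gating="switch"',
        'moe_gating="switch"': 'moe_gating="ntlb"',
    }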

mesh_tensorflow/transformer/attention.py

Lines changed: 5 additions & 5 deletions
@@ -656,11 +656,11 @@ def __init__(self,
           min_expert_capacity=experts_hparams.min_expert_capacity,
           capacity_factor_train=experts_hparams.capacity_factor_train,
           capacity_factor_eval=experts_hparams.capacity_factor_eval,
-          rand_1_policy_train=experts_hparams.rand_1_policy_train,
-          rand_1_policy_eval=experts_hparams.rand_1_policy_eval,
-          rand_1_dropout=experts_hparams.rand_1_dropout,
-          rand_1_temperature=experts_hparams.rand_1_temperature,
-          rand_1_jitter=experts_hparams.rand_1_jitter,
+          switch_policy_train=experts_hparams.switch_policy_train,
+          switch_policy_eval=experts_hparams.switch_policy_eval,
+          switch_dropout=experts_hparams.switch_dropout,
+          switch_temperature=experts_hparams.switch_temperature,
+          switch_jitter=experts_hparams.switch_jitter,
           switch_top_k=experts_hparams.switch_top_k,
           hidden_size=experts_hparams.hidden_size,
           output_dim=moe_output_dims,

mesh_tensorflow/transformer/moe.py

Lines changed: 40 additions & 41 deletions
@@ -53,12 +53,12 @@ def __init__(self,
                activation="relu",
                moe_gating="top_2",
                min_expert_capacity=4,
-               rand_1_policy_train="input_jitter",
-               rand_1_policy_eval="input_jitter",
-               rand_1_dropout=0.1,
-               rand_1_temperature=1.0,
-               rand_1_jitter=1e-2,
-               switch_top_k=4,
+               switch_policy_train="input_jitter",
+               switch_policy_eval="input_jitter",
+               switch_dropout=0.1,
+               switch_temperature=1.0,
+               switch_jitter=1e-2,
+               ntlb_top_k=4,
                output_dim=None,
                use_experts_attention=False):
     self._hparams = HParams(
@@ -76,13 +76,13 @@ def __init__(self,
         moe_second_threshold_train=second_threshold_train,
         moe_second_threshold_eval=second_threshold_eval,
         moe_dropout_rate=dropout_rate,
-        moe_rand_1_policy_train=rand_1_policy_train,
-        moe_rand_1_policy_eval=rand_1_policy_eval,
-        moe_rand_1_dropout=rand_1_dropout,
-        moe_rand_1_temperature=rand_1_temperature,
-        moe_rand_1_jitter=rand_1_jitter,
+        moe_switch_policy_train=switch_policy_train,
+        moe_switch_policy_eval=switch_policy_eval,
+        moe_switch_dropout=switch_dropout,
+        moe_switch_temperature=switch_temperature,
+        moe_switch_jitter=switch_jitter,
         moe_output_dim=output_dim,
-        moe_switch_top_k=switch_top_k,
+        moe_ntlb_top_k=ntlb_top_k,
         moe_use_experts_attention=use_experts_attention)
     self._activation = activation

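As the hunk above shows, each constructor argument is re-exposed on the layer's HParams under a moe_-prefixed key, so every rename lands twice. A minimal stand-in illustration of that convention (this HParams class is a hypothetical sketch, not the library's):

    # Hypothetical minimal stand-in for the HParams container used above:
    # constructor kwargs are stored as moe_*-prefixed attributes.
    class HParams:
        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

    hparams = HParams(moe_switch_dropout=0.1)  # from the switch_dropout argument
    assert hparams.moe_switch_dropout == 0.1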
@@ -389,8 +389,8 @@ def transformer_moe_layer_v1(
         variable_dtype=variable_dtype,
         importance=nonpadding,
         num_microbatches=num_microbatches)
-  elif hparams.moe_gating == "rand_1":
-    dispatch_tensor, combine_tensor, loss = _rand_1_gating(
+  elif hparams.moe_gating == "switch":
+    dispatch_tensor, combine_tensor, loss = _switch_gating(
         inputs=inputs,
         outer_expert_dims=None,
         experts_dim=experts_dim_unsplit,
@@ -400,8 +400,8 @@ def transformer_moe_layer_v1(
         variable_dtype=variable_dtype,
         importance=nonpadding,
         num_microbatches=num_microbatches)
-  elif hparams.moe_gating == "switch":
-    dispatch_tensor, combine_tensor, loss = _switch_gating(
+  elif hparams.moe_gating == "ntlb":
+    dispatch_tensor, combine_tensor, loss = _ntlb_gating(
         inputs=inputs,
         outer_expert_dims=None,
         experts_dim=experts_dim_unsplit,
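Taken together, these two hunks re-key the gating dispatch in transformer_moe_layer_v1. A minimal runnable sketch of the resulting selection logic (stub gating functions; the real ones live in moe.py and take many more arguments, and the pre-existing "top_2" branch is assumed unchanged since it does not appear in this diff):

    # Sketch of the post-rename gating dispatch (stubs, not the mtf code).
    def _top_2_gating(**kw): ...   # assumed unchanged; not in this diff
    def _switch_gating(**kw): ...  # formerly _rand_1_gating
    def _ntlb_gating(**kw): ...    # formerly _switch_gating

    def select_gating_fn(moe_gating):
        try:
            return {"top_2": _top_2_gating,
                    "switch": _switch_gating,
                    "ntlb": _ntlb_gating}[moe_gating]
        except KeyError:
            raise ValueError("unknown moe_gating %r" % moe_gating)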
@@ -774,26 +774,26 @@ def transformer_moe_layer_v2(
   return output, (loss_outer + loss_inner) * hparams.moe_loss_coef
 
 
-def _switch_gating(inputs,
-                   outer_expert_dims,
-                   experts_dim,
-                   expert_capacity_dim,
-                   hparams,
-                   train,
-                   variable_dtype,
-                   importance=None,
-                   name="switch_gating",
-                   num_microbatches=None):
-  """Compute a switch top-1 gating with no-token-left behind behavior."""
+def _ntlb_gating(inputs,
+                 outer_expert_dims,
+                 experts_dim,
+                 expert_capacity_dim,
+                 hparams,
+                 train,
+                 variable_dtype,
+                 importance=None,
+                 name="ntlb_gating",
+                 num_microbatches=None):
+  """Compute Switch gating with no-token-left behind (NTLB) behavior."""
   # SELECT EXPERT
   if train:
-    policy = hparams.moe_rand_1_policy_train
+    policy = hparams.moe_switch_policy_train
   else:
-    policy = hparams.moe_rand_1_policy_eval
+    policy = hparams.moe_switch_policy_eval
 
   # Input perturbations
   if train and policy == "input_jitter":
-    inputs = mtf.layers.multiplicative_jitter(inputs, hparams.moe_rand_1_jitter)
+    inputs = mtf.layers.multiplicative_jitter(inputs, hparams.moe_switch_jitter)
 
   gate_logits = mtf.layers.dense(
       inputs,
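The input_jitter policy seen above perturbs the gate inputs multiplicatively before the router's dense projection. A rough numpy stand-in for the idea (an assumption about mtf.layers.multiplicative_jitter's behavior, based on the Switch Transformer description of scaling inputs by uniform noise in [1 - eps, 1 + eps]):

    import numpy as np

    def multiplicative_jitter(x, epsilon=1e-2, rng=None):
        # Scale each element by a uniform random factor in [1 - eps, 1 + eps],
        # mirroring what the jitter policy is described to do.
        rng = rng or np.random.default_rng(0)
        return x * rng.uniform(1.0 - epsilon, 1.0 + epsilon, size=x.shape)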
@@ -809,7 +809,7 @@ def _switch_gating(inputs,
   raw_gates = mtf.to_float(raw_gates)
 
   # Top-k operation
-  k_dim = mtf.Dimension("k", hparams.moe_switch_top_k)
+  k_dim = mtf.Dimension("k", hparams.moe_ntlb_top_k)
   expert_gate, expert_index = mtf.top_k(
       raw_gates, reduced_dim=experts_dim, k_dim=k_dim)
   expert_mask = mtf.one_hot(expert_index, experts_dim)
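Unlike plain Switch gating, which routes each token to a single expert, the NTLB variant keeps the top k experts per token (k = moe_ntlb_top_k) so tokens bumped from a full expert can be re-routed to their next choice. A numpy sketch of just the selection step above (shapes simplified; not the mtf implementation):

    import numpy as np

    def top_k_experts(raw_gates, k):
        # raw_gates: [tokens, experts] softmax gate probabilities.
        expert_index = np.argsort(-raw_gates, axis=-1)[:, :k]  # best k experts per token
        expert_gate = np.take_along_axis(raw_gates, expert_index, axis=-1)
        return expert_gate, expert_index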
@@ -913,27 +913,27 @@ def _switch_gating(inputs,
   return dispatch_tensor, combine_tensor, loss
 
 
-def _rand_1_gating(
+def _switch_gating(
     inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
-    hparams, train, variable_dtype, importance=None, name="rand_1_gating",
+    hparams, train, variable_dtype, importance=None, name="switch_gating",
     num_microbatches=None):
-  """Compute a random top-1 gating."""
+  """Compute Switch gating."""
   # SELECT EXPERT
   if train:
-    policy = hparams.moe_rand_1_policy_train
+    policy = hparams.moe_switch_policy_train
   else:
-    policy = hparams.moe_rand_1_policy_eval
+    policy = hparams.moe_switch_policy_eval
 
   # The internals of this function run in float32.
   #   bfloat16 seems to reduce quality.
   gate_inputs = mtf.to_float(inputs)
 
   # Input perturbations
   if train and policy == "input_dropout":
-    gate_inputs = mtf.dropout(gate_inputs, 1.0 - hparams.moe_rand_1_dropout)
+    gate_inputs = mtf.dropout(gate_inputs, 1.0 - hparams.moe_switch_dropout)
   elif train and policy == "input_jitter":
     gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
-                                                   hparams.moe_rand_1_jitter)
+                                                   hparams.moe_switch_jitter)
 
   gate_logits = mtf.layers.dense(
       gate_inputs,
@@ -950,15 +950,14 @@ def _rand_1_gating(
     mtf.scalar_summary("expert_gate", mtf.reduce_mean(expert_gate))
   elif policy == "sample":
     expert_index = mtf.sample_with_temperature(
-        gate_logits, experts_dim, temperature=hparams.moe_rand_1_temperature)
+        gate_logits, experts_dim, temperature=hparams.moe_switch_temperature)
     expert_gate = mtf.gather(raw_gates, expert_index, dim=experts_dim)
   else:
-    raise ValueError("Unknown rand_1 policy %s" % policy)
+    raise ValueError("Unknown Switch gating policy %s" % policy)
 
   expert_mask = mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype)
 
   # LOAD BALANCING LOSS
-  # TODO(liamfedus): Check entropy loss.
   group_size_dim = inputs.shape[-2]
   density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
   density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
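The density terms at the end of this hunk feed the auxiliary load-balancing loss; the diff cuts off before the loss itself. A rough numpy sketch of the standard form from the Switch Transformer paper (an assumption about the code that follows, which is not shown here):

    import numpy as np

    def load_balancing_loss(expert_mask, raw_gates):
        # expert_mask: [group, experts] one-hot of each token's chosen expert.
        # raw_gates:   [group, experts] softmax gate probabilities.
        density_1 = expert_mask.mean(axis=0)      # fraction of tokens routed per expert
        density_1_proxy = raw_gates.mean(axis=0)  # mean gate probability per expert
        n = raw_gates.shape[-1]
        return float(np.mean(density_1 * density_1_proxy) * n * n)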

mesh_tensorflow/transformer/transformer_layers.py

Lines changed: 10 additions & 10 deletions
@@ -390,11 +390,11 @@ def __init__(self,
                capacity_factor_eval=2.0,
                moe_gating="switch",
                min_expert_capacity=4,
-               rand_1_policy_train="input_jitter",
-               rand_1_policy_eval="input_jitter",
-               rand_1_dropout=0.0,
-               rand_1_temperature=1.0,
-               rand_1_jitter=1e-2,
+               switch_policy_train="input_jitter",
+               switch_policy_eval="input_jitter",
+               switch_dropout=0.0,
+               switch_temperature=1.0,
+               switch_jitter=1e-2,
                switch_top_k=4,
                hidden_size=3072,
                use_experts_attention=True,
@@ -408,11 +408,11 @@ def __init__(self,
         min_expert_capacity=min_expert_capacity,
         capacity_factor_train=capacity_factor_train,
         capacity_factor_eval=capacity_factor_eval,
-        rand_1_policy_train=rand_1_policy_train,
-        rand_1_policy_eval=rand_1_policy_eval,
-        rand_1_dropout=rand_1_dropout,
-        rand_1_temperature=rand_1_temperature,
-        rand_1_jitter=rand_1_jitter,
+        switch_policy_train=switch_policy_train,
+        switch_policy_eval=switch_policy_eval,
+        switch_dropout=switch_dropout,
+        switch_temperature=switch_temperature,
+        switch_jitter=switch_jitter,
         switch_top_k=switch_top_k,
         hidden_size=hidden_size,
         use_experts_attention=use_experts_attention)
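After the rename, callers of this layer pass the switch_*-named keywords. A hypothetical construction sketch (only the keyword names and defaults come from the diff; the class name MoETransformerLayer is a placeholder, since the diff does not show which class this __init__ belongs to):

    # Hypothetical usage; keyword names and defaults are from the diff above,
    # the class name is a placeholder for the layer whose __init__ changed.
    layer = MoETransformerLayer(
        moe_gating="switch",
        switch_policy_train="input_jitter",
        switch_policy_eval="input_jitter",
        switch_dropout=0.0,
        switch_temperature=1.0,
        switch_jitter=1e-2,
        switch_top_k=4,
        hidden_size=3072,
        use_experts_attention=True,
    )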

0 commit comments