
Commit 0f70937

dudongsu authored and facebook-github-bot committed
add one mode in Rowwise_adagrad_with_counter (pytorch#4920)
Summary:
X-link: facebookresearch/FBGEMM#1944
Pull Request resolved: pytorch#4920

We have tested a new method that, instead of adapting the learning rate with sqrt(t), uses t to schedule the learning rate up to a predefined maximum learning rate. This new mode is selected with counter_halflife = -2 under AdagradW.

Reviewed By: spcyppt

Differential Revision: D82576034

fbshipit-source-id: 6e01c3489909ecf3382613e822caf63d90e8f1f4
1 parent 1de2434 commit 0f70937
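For context, here is a minimal Python sketch of how the two AdagradW modes compute the per-row step-size multiplier. The helper function itself is hypothetical (not FBGEMM code) and the default `eps` is an assumption; variable names mirror the kernel locals in the diffs below.

```python
import math

# Sketch only: contrast of the counter_halflife = -1 and -2 schedules.
def adjusted_multiplier(counter_halflife, learning_rate, multiplier,
                        row_counter, new_sum_square_grads, adjustment_ub,
                        eps=1e-8):
    if counter_halflife == -1:
        # Existing mode: scale the rowwise-Adagrad multiplier by sqrt(t),
        # where t is the per-row update counter.
        return multiplier * math.sqrt(row_counter)
    if counter_halflife == -2:
        # New mode: schedule the learning rate linearly in t, capped at the
        # predefined maximum (adjustment_ub), then divide by the
        # rowwise-Adagrad denominator sqrt(G) + eps.
        return min(learning_rate * row_counter, adjustment_ub) / (
            math.sqrt(new_sum_square_grads) + eps
        )
    raise ValueError("sketch only covers counter_halflife in {-1, -2}")
```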

File tree

2 files changed: +44 −26 lines

fbgemm_gpu/codegen/genscript/optimizers.py
fbgemm_gpu/test/tbe/training/backward_optimizers_test.py

fbgemm_gpu/codegen/genscript/optimizers.py

Lines changed: 23 additions & 20 deletions
@@ -186,7 +186,7 @@ def rowwise_adagrad() -> Dict[str, Any]:
             g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw;
         """
     )
-    split_precomputation += """
+    split_precomputation += """
         // Define the rowwise adagrad optimizer state struct view
         struct [[maybe_unused]] OptimizerState {
             at::acc_type<cache_t, true> momentum;
@@ -197,17 +197,17 @@ def rowwise_adagrad() -> Dict[str, Any]:

         at::acc_type<cache_t, true> multiplier = 0.0;
         at::acc_type<cache_t, true> correction = 0.0;
-        if (threadIdx.x == 0) {
+        if (threadIdx.x == 0) {
            auto new_sum_square_grads = g_avg_square;
-
-           // Update the optimizer state. Use optimizer state offloading only if
+
+           // Update the optimizer state. Use optimizer state offloading only if
            // SSD and if enabled by the user
            if (enable_optimizer_offloading) {
                // Fetch the pointer to the optimizer state along the cache row
                auto* optimizer = weight_row_template.template optimizer_state_ptr<OptimizerState>();
                new_sum_square_grads += optimizer->momentum;
                optimizer->momentum = new_sum_square_grads;
-
+
            } else {
                new_sum_square_grads += momentum1[idx];
                momentum1[idx] = new_sum_square_grads;
@@ -570,14 +570,17 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
        if (regularization_mode == 3) { // counter-based regularization (regularization_mode=3)
            if (adjustment_enabled) {
                if (weight_decay_mode == 3) { // AdagradW (weight_decay_mode=3)
-                   if (counter_halflife < 0) {
+                   if (counter_halflife == -1) {
                        adjusted_multiplier = multiplier * sqrtf(row_counter[idx] * 1.0);
-                       exp_reg_correction = 1.0 - weight_decay * learning_rate;
-                       const auto lazy_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx];
-                       const auto lazy_multiplier = powf(exp_reg_correction, min(lazy_delta, iter * 1.0 - adjustment_iter) - 1.0);
-                       adjusted_multiplier *= lazy_multiplier;
-                       exp_reg_correction *= lazy_multiplier;
                    }
+                   else if (counter_halflife == -2) {
+                       adjusted_multiplier = min(learning_rate * powf(row_counter[idx] * 1.0, 1.0), adjustment_ub) / (sqrtf(new_sum_square_grads) + eps);
+                   }
+                   exp_reg_correction = 1.0 - weight_decay * learning_rate;
+                   const auto lazy_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx];
+                   const auto lazy_multiplier = powf(exp_reg_correction, min(lazy_delta, iter * 1.0 - adjustment_iter) - 1.0);
+                   adjusted_multiplier *= lazy_multiplier;
+                   exp_reg_correction *= lazy_multiplier;
                } else if (weight_decay_mode == 2) { // Decoupled weight decay (weight_decay_mode=2)
                    exp_reg_correction = 1.0 - freq * weight_decay * learning_rate;
                } else if (weight_decay_mode == 1) { // L2 regularization (coupled wd)
@@ -1040,8 +1043,8 @@ def adam() -> Dict[str, Any]:
        DEVICE_INLINE momentum2_ph_t* momentum2_ptr(const int32_t D) {
            // Cast to uintptr_t for pointer arithmetic
            auto addr = reinterpret_cast<uintptr_t>(momentum1_ptr() + D);
-
-           // Cast back to momentum2_ph_t* and return
+
+           // Cast back to momentum2_ph_t* and return
            return reinterpret_cast<momentum2_ph_t *>(addr);
        }
    };
@@ -1179,16 +1182,16 @@ def partial_rowwise_adam() -> Dict[str, Any]:
    struct OptimizerState {
        // momentum2 is a single value placed at the beginning of the struct
        momentum2_ph_t momentum2;
-
+
        // momentum1 is an array of values placed after momentum2, aligned to 4-byte boundary
        // to support mixed state precision (e.g. FP32 momentum1 and FP16 momentum2)
        alignas(4) momentum1_ph_t momentum1[1];
-
+
        // momentum2_ptr returns a pointer to the beginning of the struct
        DEVICE_INLINE momentum2_ph_t* momentum2_ptr() {
            return &momentum2;
        }
-
+
        // momentum1_ptr returns a pointer to the beginning of the momentum1 array
        DEVICE_INLINE momentum1_ph_t* momentum1_ptr() {
            return momentum1;
@@ -1231,11 +1234,11 @@ def partial_rowwise_adam() -> Dict[str, Any]:
        // Create a Vec4T for momentum1 values - either directly from momentum1_start
        // or from a temporary aligned buffer if optimizer offloading is enabled
        Vec4T<momentum1_ph_t> m_t;
-
+
        if (enable_optimizer_offloading) {
-           // When offloading is enabled, we need to ensure proper alignment, so
+           // When offloading is enabled, we need to ensure proper alignment, so
            // first copy to a temporary aligned array before loading to Vec4T
-           m_t = vec4_load_unaligned(momentum1_start + d);
+           m_t = vec4_load_unaligned(momentum1_start + d);
            m_t.mul_(beta1);
            m_t.fma_(grad, 1.0 - beta1);
            vec4_store_unaligned(m_t, momentum1_start + d);
@@ -1247,7 +1250,7 @@ def partial_rowwise_adam() -> Dict[str, Any]:
            m_t.fma_(grad, 1.0 - beta1);
            m_t.store(&momentum1_start[d]);
        }
-
+
        // Update weights using the momentum values
        weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.x);
        weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.y);
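The rowwise_adagrad_with_counter hunk above also moves the lazy decoupled-weight-decay correction out of the counter_halflife == -1 branch, so it now applies to both modes. A scalar Python sketch of that correction (hypothetical helper; names mirror the kernel locals, and the prev_iter / adjustment_iter semantics are taken from the diff):

```python
def lazy_weight_decay_correction(adjusted_multiplier, weight_decay,
                                 learning_rate, iter_, prev_iter,
                                 adjustment_iter):
    # Per-step decoupled weight-decay factor.
    exp_reg_correction = 1.0 - weight_decay * learning_rate
    # Number of steps since this row was last updated (1 if never seen).
    lazy_delta = 1.0 if prev_iter == 0 else iter_ - prev_iter
    # Catch up on the decay that accrued while the row was not touched,
    # bounded by the number of steps past adjustment_iter.
    lazy_multiplier = exp_reg_correction ** (
        min(lazy_delta, iter_ - adjustment_iter) - 1.0
    )
    return adjusted_multiplier * lazy_multiplier, exp_reg_correction * lazy_multiplier
```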

fbgemm_gpu/test/tbe/training/backward_optimizers_test.py

Lines changed: 21 additions & 6 deletions
@@ -108,6 +108,7 @@ def execute_backward_optimizers_( # noqa C901
        optimizer_state_dtypes: Optional[dict[str, SparseType]] = None,
        use_rowwise_bias_correction: bool = False,
        counter_weight_decay_mode: Optional[CounterWeightDecayMode] = None,
+       counter_halflife: int = -1,
    ) -> None:
        # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!

@@ -297,7 +298,7 @@ def execute_backward_optimizers_( # noqa C901
        else:
            counter_based_regularization = CounterBasedRegularizationDefinition(
                counter_weight_decay_mode=CounterWeightDecayMode.ADAGRADW,
-               counter_halflife=-1,
+               counter_halflife=counter_halflife,
                adjustment_iter=-1,
                adjustment_ub=0.1,
                learning_rate_mode=LearningRateMode.EQUAL,
@@ -893,11 +894,22 @@ def _get_wts_from_counter_adagrad_using_counter(
            adjustment_iter > 0 and iter_ > adjustment_iter
        ):
            if counter_weight_decay_mode == CounterWeightDecayMode.ADAGRADW:
-               adjusted_multiplier = torch.where(
-                   row_counter > 0,
-                   multiplier * torch.sqrt(row_counter),
-                   torch.Tensor([0.0]),
-               )
+               if counter_halflife == -1:
+                   adjusted_multiplier = torch.where(
+                       row_counter > 0,
+                       multiplier * torch.sqrt(row_counter),
+                       torch.Tensor([0.0]),
+                   )
+               elif counter_halflife == -2:
+                   adjusted_multiplier = torch.where(
+                       row_counter > 0,
+                       torch.minimum(
+                           torch.tensor([learning_rate]) * row_counter,
+                           torch.tensor([adjustment_ub]),
+                       )
+                       / denom,
+                       torch.tensor([0.0]),
+                   )
                exp_reg_correction = torch.where(
                    row_counter > 0,
                    1.0 - weight_decay * learning_rate,
@@ -1177,6 +1189,7 @@ def test_backward_optimizers_partial_rowwise_adam_bf16_momentum( # noqa C901
            CounterWeightDecayMode.ADAGRADW,
        ]
    ),
+   counter_halflife=st.sampled_from([-1, -2]),
 )
 @settings(
    verbosity=VERBOSITY,
@@ -1201,6 +1214,7 @@ def test_backward_optimizers_adagrad( # noqa C901
        use_cpu: bool,
        weight_decay_mode: WeightDecayMode,
        counter_weight_decay_mode: CounterWeightDecayMode,
+       counter_halflife: int,
    ) -> None:
        if (
            pooling_mode == PoolingMode.NONE
@@ -1222,6 +1236,7 @@ def test_backward_optimizers_adagrad( # noqa C901
            use_cpu,
            weight_decay_mode,
            counter_weight_decay_mode=counter_weight_decay_mode,
+           counter_halflife=counter_halflife,
        )

    @given(
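As a usage note, the new mode is selected purely through the regularization definition passed to the table-batched embedding op. A minimal configuration sketch follows; the import path is an assumption, and all values other than counter_halflife are copied from the test configuration above:

```python
# Import path assumed; the classes are the same ones the test above uses.
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    CounterBasedRegularizationDefinition,
    CounterWeightDecayMode,
    LearningRateMode,
)

counter_based_regularization = CounterBasedRegularizationDefinition(
    counter_weight_decay_mode=CounterWeightDecayMode.ADAGRADW,
    counter_halflife=-2,   # new mode: linear-in-t schedule with an upper bound
    adjustment_iter=-1,
    adjustment_ub=0.1,     # acts as the predefined maximum learning rate
    learning_rate_mode=LearningRateMode.EQUAL,
)
```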
