@@ -186,7 +186,7 @@ def rowwise_adagrad() -> Dict[str, Any]:
186
186
g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw;
187
187
"""
188
188
)
189
- split_precomputation += """
189
+ split_precomputation += """
190
190
// Define the rowwise adagrad optimizer state struct view
191
191
struct [[maybe_unused]] OptimizerState {
192
192
at::acc_type<cache_t, true> momentum;
@@ -197,17 +197,17 @@ def rowwise_adagrad() -> Dict[str, Any]:
197
197
198
198
at::acc_type<cache_t, true> multiplier = 0.0;
199
199
at::acc_type<cache_t, true> correction = 0.0;
200
- if (threadIdx.x == 0) {
200
+ if (threadIdx.x == 0) {
201
201
auto new_sum_square_grads = g_avg_square;
202
-
203
- // Update the optimizer state. Use optimizer state offloading only if
202
+
203
+ // Update the optimizer state. Use optimizer state offloading only if
204
204
// SSD and if enabled by the user
205
205
if (enable_optimizer_offloading) {
206
206
// Fetch the pointer to the optimizer state along the cache row
207
207
auto* optimizer = weight_row_template.template optimizer_state_ptr<OptimizerState>();
208
208
new_sum_square_grads += optimizer->momentum;
209
209
optimizer->momentum = new_sum_square_grads;
210
-
210
+
211
211
} else {
212
212
new_sum_square_grads += momentum1[idx];
213
213
momentum1[idx] = new_sum_square_grads;
@@ -570,14 +570,17 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
570
570
if (regularization_mode == 3) { // counter-based regularization (regularization_mode=3)
571
571
if (adjustment_enabled) {
572
572
if (weight_decay_mode == 3) { // AdagradW (weight_decay_mode=3)
573
- if (counter_halflife < 0 ) {
573
+ if (counter_halflife == -1 ) {
574
574
adjusted_multiplier = multiplier * sqrtf(row_counter[idx] * 1.0);
575
- exp_reg_correction = 1.0 - weight_decay * learning_rate;
576
- const auto lazy_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx];
577
- const auto lazy_multiplier = powf(exp_reg_correction, min(lazy_delta, iter * 1.0 - adjustment_iter) - 1.0);
578
- adjusted_multiplier *= lazy_multiplier;
579
- exp_reg_correction *= lazy_multiplier;
580
575
}
576
+ else if (counter_halflife == -2) {
577
+ adjusted_multiplier = min(learning_rate * powf(row_counter[idx] * 1.0, 1.0), adjustment_ub) / (sqrtf(new_sum_square_grads) + eps);
578
+ }
579
+ exp_reg_correction = 1.0 - weight_decay * learning_rate;
580
+ const auto lazy_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx];
581
+ const auto lazy_multiplier = powf(exp_reg_correction, min(lazy_delta, iter * 1.0 - adjustment_iter) - 1.0);
582
+ adjusted_multiplier *= lazy_multiplier;
583
+ exp_reg_correction *= lazy_multiplier;
581
584
} else if (weight_decay_mode == 2) { // Decoupled weight decay (weight_decay_mode=2)
582
585
exp_reg_correction = 1.0 - freq * weight_decay * learning_rate;
583
586
} else if (weight_decay_mode == 1) { // L2 regularization (coupled wd)
@@ -1040,8 +1043,8 @@ def adam() -> Dict[str, Any]:
1040
1043
DEVICE_INLINE momentum2_ph_t* momentum2_ptr(const int32_t D) {
1041
1044
// Cast to uintptr_t for pointer arithmetic
1042
1045
auto addr = reinterpret_cast<uintptr_t>(momentum1_ptr() + D);
1043
-
1044
- // Cast back to momentum2_ph_t* and return
1046
+
1047
+ // Cast back to momentum2_ph_t* and return
1045
1048
return reinterpret_cast<momentum2_ph_t *>(addr);
1046
1049
}
1047
1050
};
@@ -1179,16 +1182,16 @@ def partial_rowwise_adam() -> Dict[str, Any]:
1179
1182
struct OptimizerState {
1180
1183
// momentum2 is a single value placed at the beginning of the struct
1181
1184
momentum2_ph_t momentum2;
1182
-
1185
+
1183
1186
// momentum1 is an array of values placed after momentum2, aligned to 4-byte boundary
1184
1187
// to support mixed state precision (e.g. FP32 momentum1 and FP16 momentum2)
1185
1188
alignas(4) momentum1_ph_t momentum1[1];
1186
-
1189
+
1187
1190
// momentum2_ptr returns a pointer to the beginning of the struct
1188
1191
DEVICE_INLINE momentum2_ph_t* momentum2_ptr() {
1189
1192
return &momentum2;
1190
1193
}
1191
-
1194
+
1192
1195
// momentum1_ptr returns a pointer to the beginning of the momentum1 array
1193
1196
DEVICE_INLINE momentum1_ph_t* momentum1_ptr() {
1194
1197
return momentum1;
@@ -1231,11 +1234,11 @@ def partial_rowwise_adam() -> Dict[str, Any]:
1231
1234
// Create a Vec4T for momentum1 values - either directly from momentum1_start
1232
1235
// or from a temporary aligned buffer if optimizer offloading is enabled
1233
1236
Vec4T<momentum1_ph_t> m_t;
1234
-
1237
+
1235
1238
if (enable_optimizer_offloading) {
1236
- // When offloading is enabled, we need to ensure proper alignment, so
1239
+ // When offloading is enabled, we need to ensure proper alignment, so
1237
1240
// first copy to a temporary aligned array before loading to Vec4T
1238
- m_t = vec4_load_unaligned(momentum1_start + d);
1241
+ m_t = vec4_load_unaligned(momentum1_start + d);
1239
1242
m_t.mul_(beta1);
1240
1243
m_t.fma_(grad, 1.0 - beta1);
1241
1244
vec4_store_unaligned(m_t, momentum1_start + d);
@@ -1247,7 +1250,7 @@ def partial_rowwise_adam() -> Dict[str, Any]:
1247
1250
m_t.fma_(grad, 1.0 - beta1);
1248
1251
m_t.store(&momentum1_start[d]);
1249
1252
}
1250
-
1253
+
1251
1254
// Update weights using the momentum values
1252
1255
weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.x);
1253
1256
weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.y);
0 commit comments