
Commit 767d463

A couple of experiments on dataset mixing that didn't have any positive effect

lang_lr_attenuation: decays the learning rate per epoch
lang_lr_weights: does scaling as well, so the learning rate for an individual language can be set lower than the default
Document how to use the flags in the --help
Add some comments on an experiment that seemed like it might have helped, but so far didn't
1 parent 101fdb5 commit 767d463

4 files changed: +48 -0 lines changed


stanza/models/coref/config.py

Lines changed: 2 additions & 0 deletions
@@ -67,3 +67,5 @@ class Config: # pylint: disable=too-many-instance-attributes, too-few-public-me
     max_train_len: int
     use_zeros: bool
 
+    lang_lr_attenuation: str
+    lang_lr_weights: str

stanza/models/coref/coref_config.toml

Lines changed: 16 additions & 0 deletions
@@ -122,6 +122,22 @@ max_train_len = 5000
 # if this is set to false, the model will set its zero_predictor to, well, 0
 use_zeros = true
 
+# two different methods for specifying how to weaken the LR for certain languages
+# however, in their current forms, on an HE experiment, neither worked
+# better than just mixing the two datasets together unweighted
+# Starting from the HE IAHLT dataset, and possibly mixing in the ger/rom ud coref,
+# averaging over 5 different seeds, we got the following results:
+# HE only: 0.497
+# Attenuated: 0.508
+# Scaled: 0.517
+# Mixed: 0.517
+# the attenuation scheme for that experiment was 1/epoch
+# These were the settings
+# --lang_lr_weights es=0.2,en=0.2,de=0.2,ca=0.2,fr=0.2,no=0.2
+# --lang_lr_attenuation es,en,de,ca,fr,no
+lang_lr_attenuation = ""
+lang_lr_weights = ""
+
 # =============================================================================
 # Extra keyword arguments to be passed to bert tokenizers of specified models
 [DEFAULT.tokenizer_kwargs]
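
As a worked example (not part of the commit), the settings quoted in the comments give a language both a 0.2 weight and the 1/epoch attenuation; following the loss-scaling logic added to model.py below, the effective scale per epoch works out to:

# Illustrative only: combined effect of a 0.2 weight from lang_lr_weights and
# the lang_lr_attenuation scheme, i.e. weight / max(epoch, 1).
weight = 0.2
for epoch in range(5):
    print(f"epoch {epoch}: loss scale {weight / max(epoch, 1.0):.3f}")
# epoch 0: loss scale 0.200
# epoch 1: loss scale 0.200
# epoch 2: loss scale 0.100
# epoch 3: loss scale 0.067
# epoch 4: loss scale 0.050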

stanza/models/coref/model.py

Lines changed: 20 additions & 0 deletions
@@ -486,6 +486,19 @@ def train(self, log=False):
             # new model, set it to always predict not-zero
             self.disable_zeros_predictor()
 
+        attenuated_languages = set()
+        if self.config.lang_lr_attenuation:
+            attenuated_languages = self.config.lang_lr_attenuation.split(",")
+            logger.info("Attenuating LR for the following languages: %s", attenuated_languages)
+
+        lr_scaled_languages = dict()
+        if self.config.lang_lr_weights:
+            scaled_languages = self.config.lang_lr_weights.split(",")
+            for piece in scaled_languages:
+                pieces = piece.split("=")
+                lr_scaled_languages[pieces[0]] = float(pieces[1])
+            logger.info("Scaling LR for the following languages: %s", lr_scaled_languages)
+
         best_f1 = None
         for epoch in range(self.epochs_trained, self.config.train_epochs):
             self.training = True
@@ -526,6 +539,13 @@ def train(self, log=False):
                 else:
                     s_loss = torch.zeros_like(c_loss)
 
+                lr_scale = lr_scaled_languages.get(doc.get("lang"), 1.0)
+                if doc.get("lang") in attenuated_languages:
+                    lr_scale = lr_scale / max(epoch, 1.0)
+                c_loss = c_loss * lr_scale
+                s_loss = s_loss * lr_scale
+                z_loss = z_loss * lr_scale
+
                 (c_loss + s_loss + z_loss).backward()
 
                 running_c_loss += c_loss.item()
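
The same logic, restated as a standalone sketch (the helper name is hypothetical, not in the commit), makes the effect of max(epoch, 1.0) explicit: the scale is unchanged in epochs 0 and 1 and only starts decaying from epoch 2.

# Hypothetical helper mirroring the loss scaling added to train() above.
def loss_scale(lang, epoch, lang_lr_weights="", lang_lr_attenuation=""):
    weights = dict(piece.split("=") for piece in lang_lr_weights.split(",") if piece)
    attenuated = set(lang_lr_attenuation.split(",")) if lang_lr_attenuation else set()
    scale = float(weights.get(lang, 1.0))
    if lang in attenuated:
        scale = scale / max(epoch, 1.0)
    return scale

assert loss_scale("he", epoch=3) == 1.0               # unlisted language is untouched
assert loss_scale("es", 0, "es=0.2", "es") == 0.2     # no decay at epoch 0
assert loss_scale("es", 1, "es=0.2", "es") == 0.2     # ...or at epoch 1
assert loss_scale("es", 4, "es=0.2", "es") == 0.05    # weight / epoch afterwards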

stanza/models/wl_coref.py

Lines changed: 10 additions & 0 deletions
@@ -134,6 +134,11 @@ def deterministic() -> None:
     argparser.add_argument("--seed", type=int, default=2020,
                            help="Random seed to set")
 
+    argparser.add_argument("--lang_lr_attenuation", type=str, default=None,
+                           help="A comma-separated list of languages where the LR will be scaled by 1/epoch, such as --lang_lr_attenuation=es,en,de,...")
+    argparser.add_argument("--lang_lr_weights", type=str, default=None,
+                           help="A comma-separated list of languages and their weights of LR scaling for different languages, such as es=0.5,en=1.0,...")
+
     argparser.add_argument("--max_train_len", type=int, default=5000,
                            help="Skip any documents longer than this maximum length")
     argparser.add_argument("--no_max_train_len", action="store_const", const=float("inf"), dest="max_train_len",
@@ -196,6 +201,11 @@ def deterministic() -> None:
     if args.max_train_len:
         config.max_train_len = args.max_train_len
 
+    if args.lang_lr_attenuation:
+        config.lang_lr_attenuation = args.lang_lr_attenuation
+    if args.lang_lr_weights:
+        config.lang_lr_weights = args.lang_lr_weights
+
     # if wandb, generate wandb configuration
     if args.mode == "train":
         if args.wandb:
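
For reference, a minimal round trip through the two new options (standalone sketch, not part of the commit): the values stay as raw comma-separated strings in the config and are only split apart later in model.py.

import argparse

# Minimal parser containing just the two options added above.
argparser = argparse.ArgumentParser()
argparser.add_argument("--lang_lr_attenuation", type=str, default=None)
argparser.add_argument("--lang_lr_weights", type=str, default=None)

args = argparser.parse_args(["--lang_lr_weights", "es=0.2,en=0.2",
                             "--lang_lr_attenuation", "es,en"])
print(args.lang_lr_weights)      # es=0.2,en=0.2  (raw string, parsed in model.py)
print(args.lang_lr_attenuation)  # es,en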
