diff --git a/stanza/models/common/doc.py b/stanza/models/common/doc.py
index 22c3fa7bf..a230a2282 100644
--- a/stanza/models/common/doc.py
+++ b/stanza/models/common/doc.py
@@ -483,18 +483,22 @@ def coref(self, chains):
 
     def _attach_coref_mentions(self, chains):
         for sentence in self.sentences:
-            for word in sentence.words:
+            for word in sentence.all_words:
                 word.coref_chains = []
 
         for chain in chains:
             for mention_idx, mention in enumerate(chain.mentions):
                 sentence = self.sentences[mention.sentence]
-                for word_idx in range(mention.start_word, mention.end_word):
-                    is_start = word_idx == mention.start_word
-                    is_end = word_idx == mention.end_word - 1
-                    is_representative = mention_idx == chain.representative_index
-                    attachment = CorefAttachment(chain, is_start, is_end, is_representative)
-                    sentence.words[word_idx].coref_chains.append(attachment)
+                if isinstance(mention.start_word, tuple):
+                    attachment = CorefAttachment(chain, True, True, False)
+                    sentence._empty_words[mention.start_word[1]-1].coref_chains.append(attachment)
+                else:
+                    for word_idx in range(mention.start_word, mention.end_word):
+                        is_start = word_idx == mention.start_word
+                        is_end = word_idx == mention.end_word - 1
+                        is_representative = mention_idx == chain.representative_index
+                        attachment = CorefAttachment(chain, is_start, is_end, is_representative)
+                        sentence.words[word_idx].coref_chains.append(attachment)
 
     def reindex_sentences(self, start_index):
         for sent_id, sentence in zip(range(start_index, start_index + len(self.sentences)), self.sentences):
@@ -737,6 +741,17 @@ def empty_words(self, value):
         """ Set the list of words for this sentence. """
         self._empty_words = value
 
+    @property
+    def all_words(self):
+        """ Access the list of words + empty words for this sentence. """
+        words = self._words
+        empty_words = self._empty_words
+
+        all_words = sorted(words + empty_words,
+                           key=lambda x: (x.id,) if isinstance(x.id, int) else x.id)
+
+        return all_words
+
     @property
     def ents(self):
         """ Access the list of entities in this sentence. """
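The all_words ordering above works because regular words carry integer ids while empty (zero) nodes carry tuple ids in the CoNLL-U empty-node style (e.g. (2, 1) for a node like 2.1); wrapping the integer ids in one-element tuples lets plain tuple comparison interleave the two kinds. A minimal sketch of that sort key, with invented id values:

    # (2,) < (2, 1) < (3,): empty node 2.1 sorts between words 2 and 3
    ids = [3, (2, 1), 1, 2]
    ordered = sorted(ids, key=lambda x: (x,) if isinstance(x, int) else x)
    print(ordered)  # [1, 2, (2, 1), 3]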
""" diff --git a/stanza/models/coref/const.py b/stanza/models/coref/const.py index 931eee122..479ee5fee 100644 --- a/stanza/models/coref/const.py +++ b/stanza/models/coref/const.py @@ -25,3 +25,5 @@ class CorefResult: rough_scores: torch.Tensor = None # [n_words, n_words] span_scores: torch.Tensor = None # [n_heads, n_words, 2] span_y: Tuple[torch.Tensor, torch.Tensor] = None # [n_heads] x2 + + zero_scores: torch.Tensor = None diff --git a/stanza/models/coref/dataset.py b/stanza/models/coref/dataset.py index fca7d4e50..3efe90379 100644 --- a/stanza/models/coref/dataset.py +++ b/stanza/models/coref/dataset.py @@ -50,9 +50,11 @@ def __init__(self, path, config, tokenizer): word2subword.append((len(subwords), len(subwords) + len(tokenized_word))) subwords.extend(tokenized_word) word_id.extend([i] * len(tokenized_word)) + doc["word2subword"] = word2subword doc["subwords"] = subwords doc["word_id"] = word_id + self.__out.append(doc) logger.info("Loaded %d docs from %s.", len(data_f), path) diff --git a/stanza/models/coref/model.py b/stanza/models/coref/model.py index 69d6a8e10..e59a5979b 100644 --- a/stanza/models/coref/model.py +++ b/stanza/models/coref/model.py @@ -33,6 +33,7 @@ from stanza.models.coref.rough_scorer import RoughScorer from stanza.models.coref.span_predictor import SpanPredictor from stanza.models.coref.utils import GraphNode +from stanza.models.coref.utils import sigmoid_focal_loss from stanza.models.coref.word_encoder import WordEncoder from stanza.models.coref.dataset import CorefDataset from stanza.models.coref.tokenizer_customization import * @@ -41,6 +42,8 @@ from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper +import torch.nn as nn + logger = logging.getLogger('stanza') class CorefModel: # pylint: disable=too-many-instance-attributes @@ -140,6 +143,8 @@ def evaluate(self, running_loss = 0.0 s_correct = 0 s_total = 0 + z_correct = 0 + z_total = 0 with conll.open_(self.config, self.epochs_trained, data_split) \ as (gold_f, pred_f): @@ -150,13 +155,21 @@ def evaluate(self, # want to test evaluation on one language continue - res = self.run(doc) + res = self.run(doc, True) + # measure zero prediction accuracy + zero_targets = torch.tensor(doc["is_zero"], device=res.zero_scores.device) + zero_preds = (res.zero_scores > 0).view(-1).to(zero_targets.dtype) + z_correct += (zero_preds == zero_targets).sum().item() + z_total += zero_targets.numel() if (res.coref_y.argmax(dim=1) == 1).all(): logger.warning(f"EVAL: skipping document with no corefs...") continue running_loss += self._coref_criterion(res.coref_scores, res.coref_y).item() + if res.word_clusters is None or res.span_clusters is None: + logger.warning(f"EVAL: skipping document with no clusters...") + continue if res.span_y: pred_starts = res.span_scores[:, :, 0].argmax(dim=1) @@ -191,8 +204,10 @@ def evaluate(self, f" f1: {s_lea[0]:.5f}," f" p: {s_lea[1]:.5f}," f" r: {s_lea[2]:<.5f}" + f" | ZA: {z_correct / z_total:<.5f}" ) logger.info(f"CoNLL-2012 3-Score Average : {w_checker.bakeoff:.5f}") + logger.info(f"Zero prediction accuracy: {z_correct / z_total:.5f}") return (running_loss / len(docs), *s_checker.total_lea, *w_checker.total_lea, *s_checker.mbc, *w_checker.mbc, w_checker.bakeoff, s_checker.bakeoff) @@ -332,6 +347,7 @@ def load_model(path: str, def run(self, # pylint: disable=too-many-locals doc: Doc, + use_gold_spans_for_zeros = False ) -> CorefResult: """ This is a massive 
@@ -380,16 +396,27 @@ def run(self,  # pylint: disable=too-many-locals
         res.coref_y = self._get_ground_truth(
             cluster_ids, top_indices, (top_rough_scores > float("-inf")),
             self.config.clusters_starts_are_singletons,
-            self.config.singletons)
+            self.config.singletons
+        )
 
-        res.word_clusters = self._clusterize(doc, res.coref_scores, top_indices,
-                                             self.config.singletons)
+        res.word_clusters = self._clusterize(
+            doc, res.coref_scores, top_indices,
+            self.config.singletons
+        )
         res.span_scores, res.span_y = self.sp.get_training_data(doc, words)
 
         if not self.training:
             res.span_clusters = self.sp.predict(doc, words, res.word_clusters)
 
+        if not self.training and not use_gold_spans_for_zeros:
+            zero_words = words[[word_id
+                                for cluster in res.word_clusters
+                                for word_id in cluster]]
+        else:
+            zero_words = words[[i[0] for i in sorted(doc["head2span"])]]
+        res.zero_scores = self.zeros_predictor(zero_words)
+
         return res
 
     def save_weights(self, save_path=None, save_optimizers=True):
@@ -454,6 +481,7 @@ def train(self, log=False):
             self.log_norms()
             running_c_loss = 0.0
             running_s_loss = 0.0
+            running_z_loss = 0.0
             random.shuffle(docs_ids)
             pbar = tqdm(docs_ids, unit="docs", ncols=0)
             for doc_indx, doc_id in enumerate(pbar):
@@ -468,6 +496,14 @@ def train(self, log=False):
 
                 res = self.run(doc)
 
+                if res.zero_scores.size(0) == 0:
+                    z_loss = 0.0  # since there are no corefs
+                else:
+                    z_loss = sigmoid_focal_loss(res.zero_scores.squeeze(-1),
+                                                (torch.tensor(doc["is_zero"])
+                                                 .to(res.zero_scores.device).float()),
+                                                reduction="mean")
+
                 c_loss = self._coref_criterion(res.coref_scores, res.coref_y)
 
                 if res.span_y:
@@ -476,20 +512,24 @@ def train(self, log=False):
                 else:
                     s_loss = torch.zeros_like(c_loss)
 
-                del res
-
-                (c_loss + s_loss).backward()
+                (c_loss + s_loss + z_loss).backward()
                 running_c_loss += c_loss.item()
                 running_s_loss += s_loss.item()
+                if res.zero_scores.size(0) != 0:
+                    running_z_loss += z_loss.item()
 
                 # log every 100 docs
                 if log and doc_indx % 100 == 0:
-                    wandb.log({'train_c_loss': c_loss.item(),
-                               'train_s_loss': s_loss.item()})
+                    logged = {
+                        'train_c_loss': c_loss.item(),
+                        'train_s_loss': s_loss.item(),
+                    }
+                    if res.zero_scores.size(0) != 0:
+                        logged['train_z_loss'] = z_loss.item()
+                    wandb.log(logged)
 
-
-                del c_loss, s_loss
+                del c_loss, s_loss, z_loss, res
 
                 for optim in self.optimizers.values():
                     optim.step()
@@ -501,6 +541,7 @@ def train(self, log=False):
                     f" {doc['document_id']:26}"
                     f" c_loss: {running_c_loss / (pbar.n + 1):<.5f}"
                     f" s_loss: {running_s_loss / (pbar.n + 1):<.5f}"
+                    f" z_loss: {running_z_loss / (pbar.n + 1):<.5f}"
                 )
 
             self.epochs_trained += 1
@@ -614,12 +655,17 @@ def _build_model(self, foundation_cache):
         self.we = WordEncoder(bert_emb, self.config).to(self.config.device)
         self.rough_scorer = RoughScorer(bert_emb, self.config).to(self.config.device)
         self.sp = SpanPredictor(bert_emb, self.config.sp_embedding_size).to(self.config.device)
+        self.zeros_predictor = nn.Sequential(
+            nn.Linear(bert_emb, bert_emb),
+            nn.ReLU(),
+            nn.Linear(bert_emb, 1)
+        ).to(self.config.device)
 
         self.trainable: Dict[str, torch.nn.Module] = {
             "bert": self.bert, "we": self.we,
             "rough_scorer": self.rough_scorer,
             "pw": self.pw, "a_scorer": self.a_scorer,
-            "sp": self.sp
+            "sp": self.sp, "zeros_predictor": self.zeros_predictor
         }
 
     def _build_optimizers(self):
@@ -785,4 +831,3 @@ def _set_training(self, value: bool):
         self._training = value
         for module in self.trainable.values():
             module.train(self._training)
-
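To make the new scoring head concrete: zeros_predictor (built in _build_model above) is a small MLP that maps one word embedding per candidate mention head to a single logit, and a positive logit is read as "this mention is a zero". A standalone sketch with invented sizes (the 768-dim embeddings and the batch of five heads are illustrative only):

    import torch
    import torch.nn as nn

    bert_emb = 768                             # illustrative embedding size
    zeros_predictor = nn.Sequential(
        nn.Linear(bert_emb, bert_emb),
        nn.ReLU(),
        nn.Linear(bert_emb, 1),
    )

    zero_words = torch.randn(5, bert_emb)      # embeddings of 5 candidate mention heads
    zero_scores = zeros_predictor(zero_words)  # shape [5, 1], one logit per head
    zero_preds = (zero_scores > 0).view(-1)    # True where a zero mention is predicted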
diff --git a/stanza/models/coref/utils.py b/stanza/models/coref/utils.py
index 027308a31..af6ad1963 100644
--- a/stanza/models/coref/utils.py
+++ b/stanza/models/coref/utils.py
@@ -3,6 +3,7 @@
 from typing import List, Set
 
 import torch
+import torch.nn.functional as F
 
 from stanza.models.coref.const import EPSILON
@@ -33,3 +34,57 @@ def add_dummy(tensor: torch.Tensor, eps: bool = False):
     else:
         dummy = torch.full(shape, EPSILON, **kwargs)  # type: ignore
     return torch.cat((dummy, tensor), dim=1)
+
+def sigmoid_focal_loss(
+    inputs: torch.Tensor,
+    targets: torch.Tensor,
+    alpha: float = 0.25,
+    gamma: float = 2,
+    reduction: str = "none",
+) -> torch.Tensor:
+    """
+    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+
+    Args:
+        inputs (Tensor): A float tensor of arbitrary shape.
+            The predictions for each example.
+        targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
+            classification label for each element in inputs
+            (0 for the negative class and 1 for the positive class).
+        alpha (float): Weighting factor in range [0, 1] to balance
+            positive vs negative examples or -1 for ignore. Default: ``0.25``.
+        gamma (float): Exponent of the modulating factor (1 - p_t) to
+            balance easy vs hard examples. Default: ``2``.
+        reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
+            ``'none'``: No reduction will be applied to the output.
+            ``'mean'``: The output will be averaged.
+            ``'sum'``: The output will be summed. Default: ``'none'``.
+    Returns:
+        Loss tensor with the reduction option applied.
+    """
+    # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
+
+    if not (0 <= alpha <= 1) and alpha != -1:
+        raise ValueError(f"Invalid alpha value: {alpha}. alpha must be in the range [0,1] or -1 for ignore.")
+
+    p = torch.sigmoid(inputs)
+    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+    p_t = p * targets + (1 - p) * (1 - targets)
+    loss = ce_loss * ((1 - p_t) ** gamma)
+
+    if alpha >= 0:
+        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+        loss = alpha_t * loss
+
+    # Check reduction option and return loss accordingly
+    if reduction == "none":
+        pass
+    elif reduction == "mean":
+        loss = loss.mean()
+    elif reduction == "sum":
+        loss = loss.sum()
+    else:
+        raise ValueError(
+            f"Invalid value for arg 'reduction': '{reduction}'\n Supported reduction modes: 'none', 'mean', 'sum'"
+        )
+    return loss
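sigmoid_focal_loss is used for the zero/non-zero decision because zeros are rare compared to overt mentions: per element it computes FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t), so the (1 - p_t)**gamma factor shrinks the contribution of well-classified (mostly negative) examples, and with gamma = 0 it falls back to alpha-weighted binary cross-entropy. A minimal usage sketch, assuming the patch above is applied (the logits and targets are made up):

    import torch
    from stanza.models.coref.utils import sigmoid_focal_loss

    logits  = torch.tensor([2.0, -1.0, 0.5])   # one score per candidate mention head
    targets = torch.tensor([1.0,  0.0, 0.0])   # 1 = zero mention, 0 = overt mention

    loss = sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2, reduction="mean")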
diff --git a/stanza/pipeline/coref_processor.py b/stanza/pipeline/coref_processor.py
index 3c6bb2e01..9f43328e3 100644
--- a/stanza/pipeline/coref_processor.py
+++ b/stanza/pipeline/coref_processor.py
@@ -4,10 +4,13 @@
 
 from stanza.models.common.utils import misc_to_space_after
 from stanza.models.coref.coref_chain import CorefMention, CorefChain
+from stanza.models.common.doc import Word
 from stanza.pipeline._constants import *
 from stanza.pipeline.processor import UDProcessor, register_processor
 
+import torch
+
 def extract_text(document, sent_id, start_word, end_word):
     sentence = document.sentences[sent_id]
     tokens = []
@@ -99,8 +102,13 @@ def process(self, document):
         }
         coref_input = self._model.build_doc(coref_input)
         results = self._model.run(coref_input)
+
+
+        # Handle zero anaphora - zero_scores is always predicted
+        zero_nodes_created = self._handle_zero_anaphora(document, results, sent_ids, word_pos)
+
         clusters = []
-        for span_cluster in results.span_clusters:
+        for cluster_idx, span_cluster in enumerate(results.span_clusters):
             if len(span_cluster) == 0:
                 continue
             span_cluster = sorted(span_cluster)
@@ -122,6 +130,11 @@ def process(self, document):
             best_span = None
             max_propn = 0
             for span_idx, span in enumerate(span_cluster):
+                word_idx = results.word_clusters[cluster_idx][span_idx]
+                is_zero = zero_nodes_created.get((cluster_idx, word_idx))
+                if is_zero:
+                    continue
+
                 sent_id = sent_ids[span[0]]
                 sentence = sentences[sent_id]
                 start_word = word_pos[span[0]]
@@ -139,16 +152,91 @@ def process(self, document):
                     max_propn = num_propn
 
             mentions = []
-            for span in span_cluster:
-                sent_id = sent_ids[span[0]]
-                start_word = word_pos[span[0]]
-                end_word = word_pos[span[1]-1] + 1
-                mentions.append(CorefMention(sent_id, start_word, end_word))
-            representative = mentions[best_span]
-            representative_text = extract_text(document, representative.sentence, representative.start_word, representative.end_word)
+            for span_idx, span in enumerate(span_cluster):
+                word_idx = results.word_clusters[cluster_idx][span_idx]
+                is_zero = zero_nodes_created.get((cluster_idx, word_idx))
+                if is_zero:
+                    (sent_id, zero_word_id) = is_zero
+                    # if the word id is a tuple, it will be attached
+                    # to the zero
+                    mentions.append(
+                        CorefMention(
+                            sent_id,
+                            zero_word_id,
+                            zero_word_id
+                        )
+                    )
+                else:
+                    sent_id = sent_ids[span[0]]
+                    start_word = word_pos[span[0]]
+                    end_word = word_pos[span[1]-1] + 1
+                    mentions.append(CorefMention(sent_id, start_word, end_word))
+
+            # if we ended up with no best span, then our "representative text"
+            # is just an underscore
+            if best_span is not None:
+                representative = mentions[best_span]
+                representative_text = extract_text(document, representative.sentence, representative.start_word, representative.end_word)
+            else:
+                representative_text = "_"
 
             chain = CorefChain(len(clusters), mentions, representative_text, best_span)
             clusters.append(chain)
         document.coref = clusters
 
         return document
+
+    def _handle_zero_anaphora(self, document, results, sent_ids, word_pos):
+        """Handle zero anaphora by creating zero nodes and updating coreference clusters."""
+        if results.zero_scores is None or results.word_clusters is None:
+            return {}
+
+        zero_scores = results.zero_scores.squeeze(-1) if results.zero_scores.dim() > 1 else results.zero_scores
+        is_zero = []
+
+        # Flatten word_clusters to get the word indices that correspond to zero_scores
+        cluster_word_ids = []
+        cluster_mapping = {}
+        counter = 0
+        for indx, cluster in enumerate(results.word_clusters):
+            for _ in range(len(cluster)):
+                cluster_mapping[counter] = indx
+                counter += 1
+            cluster_word_ids.extend(cluster)
+
+        # Find indices where zero_scores > 0
+        zero_indices = (zero_scores > 0.0).nonzero()
+
+        # this dict maps (cluster_id, word_id) to (sent_id, zero_word_id),
+        # which overrides the corresponding span from span_clusters
+        zero_to_coref = {}
+
+        for zero_idx in zero_indices:
+            zero_idx = zero_idx.item()
+            if zero_idx >= len(cluster_word_ids):
+                continue
+
+            word_idx = cluster_word_ids[zero_idx]
+            sent_id = sent_ids[word_idx]
+            word_id = word_pos[word_idx]
+
+            # Create zero node - attach BEFORE the current word
+            # This means the zero node comes after word_id-1 but before word_id
+            zero_word_id = (
+                word_id,
+                len(document.sentences[sent_id]._empty_words)+1
+            )  # attach after word_id-1, before word_id
+            zero_word = Word(document.sentences[sent_id], {
+                "text": "_",
+                "lemma": "_",
+                "id": zero_word_id
+            })
+            document.sentences[sent_id]._empty_words.append(zero_word)
+
+            # Track this zero node for adding to coreference clusters
+            cluster_idx = cluster_mapping[zero_idx]
+            zero_to_coref[(cluster_idx, word_idx)] = (
+                sent_id, zero_word_id
+            )
+
+        return zero_to_coref
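The bookkeeping in _handle_zero_anaphora relies on zero_scores being produced for the flattened list of predicted cluster head words (see the run() change above), so position k in zero_scores maps back to a (cluster index, head word) pair. A toy sketch of that mapping, with invented head word ids:

    # word_clusters as produced by the model: one list of head word ids per cluster
    word_clusters = [[4, 9], [12]]

    cluster_word_ids = []
    cluster_mapping = {}
    for cluster_idx, cluster in enumerate(word_clusters):
        for head in cluster:
            cluster_mapping[len(cluster_word_ids)] = cluster_idx
            cluster_word_ids.append(head)

    print(cluster_word_ids)  # [4, 9, 12]
    print(cluster_mapping)   # {0: 0, 1: 0, 2: 1}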
diff --git a/stanza/utils/datasets/coref/convert_udcoref.py b/stanza/utils/datasets/coref/convert_udcoref.py
index 72b0e8c18..d70c0946f 100644
--- a/stanza/utils/datasets/coref/convert_udcoref.py
+++ b/stanza/utils/datasets/coref/convert_udcoref.py
@@ -10,6 +10,7 @@
 
 from stanza.utils.conll import CoNLL
 
+import warnings
 from random import Random
 
 import argparse
@@ -22,6 +23,7 @@
 UDCOREF_ADDN = 0 if not IS_UDCOREF_FORMAT else 1
 
 def process_documents(docs, augment=False):
+    # docs = sections
     processed_section = []
     for idx, (doc, doc_id, lang) in enumerate(tqdm(docs)):
@@ -35,9 +37,9 @@ def process_documents(docs, augment=False):
         # extract the entities
         # get sentence words and lengths
-        sentences = [[j.text for j in i.words]
+        sentences = [[j.text for j in i.all_words]
                      for i in doc.sentences]
-        sentence_lens = [len(x.words) for x in doc.sentences]
+        sentence_lens = [len(x.all_words) for x in doc.sentences]
 
         cased_words = []
         for x in sentences:
@@ -56,26 +58,28 @@ def process_documents(docs, augment=False):
         # TODO: does SD vs UD matter?
         deprel = []
         for sentence in doc.sentences:
-            for word in sentence.words:
+            for word in sentence.all_words:
                 deprel.append(word.deprel)
-                if word.head == 0:
+                if not word.head or word.head == 0:
                     heads.append("null")
                 else:
                     heads.append(word.head - 1 + word_total)
-            word_total += len(sentence.words)
+            word_total += len(sentence.all_words)
 
         span_clusters = defaultdict(list)
         word_clusters = defaultdict(list)
         head2span = []
+        is_zero = []
         word_total = 0
 
         SPANS = re.compile(r"(\(\w+|[%\w]+\))")
+        do_ctn = False    # set to True if we give up inside the loop below
 
         for parsed_sentence in doc.sentences:
             # spans regex
             # parse the misc column, leaving on "Entity" entries
             misc = [[k.split("=") for k in j if k.split("=")[0] == "Entity"]
-                    for i in parsed_sentence.words
+                    for i in parsed_sentence.all_words
                     for j in [i.misc.split("|") if i.misc else []]]
             # and extract the Entity entry values
             entities = [i[0][1] if len(i) > 0 else None for i in misc]
@@ -112,23 +116,57 @@ def process_documents(docs, augment=False):
             for k, v in final_refs.items():
                 for i in v:
                     coref_spans.append([int(k), i[0], i[1]])
-            sentence_upos = [x.upos for x in parsed_sentence.words]
-            sentence_heads = [x.head - 1 if x.head > 0 else None for x in parsed_sentence.words]
+            sentence_upos = [x.upos for x in parsed_sentence.all_words]
+            sentence_heads = [x.head - 1 if x.head and x.head > 0 else None for x in parsed_sentence.all_words]
+            sentence_text = [x.text for x in parsed_sentence.all_words]
+
+            # if "_" in sentence_text and sentence_text.index("_") in [j for i in coref_spans for j in i]:
+            #     import ipdb
+            #     ipdb.set_trace()
+
             for span in coref_spans:
+                zero = False
+                if sentence_text[span[1]] == "_" and span[1] == span[2]:
+                    is_zero.append([span[0], True])
+                    zero = True
+                    # that's a zero coref, so we should merge it forwards,
+                    # i.e. we pick the next word as the head!
+                    span = [span[0], span[1]+1, span[2]+1]
+                    # if there are two zeros right next to each other,
+                    # we are sad and confused, so we give up in this case
+                    if len(sentence_text) > span[1] and sentence_text[span[1]] == "_":
+                        warnings.warn("Found two zeros next to each other in sequence; we are confused and therefore giving up.")
+                        do_ctn = True
+                        break
+                else:
+                    is_zero.append([span[0], False])
+
                 # input is expected to be start word, end word + 1
                 # counting from 0
                 # whereas the OntoNotes coref_span is [start_word, end_word] inclusive
                 span_start = span[1] + word_total
                 span_end = span[2] + word_total + 1
-                candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1)
+                # if it's a zero coref (i.e. a coref whose head is None), we call
+                # the beginning of the span (i.e. the zero itself) the head
+
+                if zero:
+                    candidate_head = span[1]
+                else:
+                    try:
+                        candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1)
+                    except RecursionError:
+                        candidate_head = span[1]
+
                 if candidate_head is None:
                     for candidate_head in range(span[1], span[2] + 1):
                         # stanza uses 0 to mark the head, whereas OntoNotes is counting
                         # words from 0, so we have to subtract 1 from the stanza heads
                         #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1)
                         # treat the head of the phrase as the first word that has a head outside the phrase
-                        if (parsed_sentence.words[candidate_head].head - 1 < span[1] or
-                            parsed_sentence.words[candidate_head].head - 1 > span[2]):
+                        if (parsed_sentence.all_words[candidate_head].head is not None) and (
+                                parsed_sentence.all_words[candidate_head].head - 1 < span[1] or
+                                parsed_sentence.all_words[candidate_head].head - 1 > span[2]
+                        ):
                             break
                     else:
                         # if none have a head outside the phrase (circular??)
@@ -139,10 +177,45 @@ def process_documents(docs, augment=False):
                 span_clusters[span[0]].append((span_start, span_end))
                 word_clusters[span[0]].append(candidate_head)
                 head2span.append((candidate_head, span_start, span_end))
-            word_total += len(parsed_sentence.words)
+            if do_ctn:
+                break
+            word_total += len(parsed_sentence.all_words)
+        if do_ctn:
+            continue
 
         span_clusters = sorted([sorted(values) for _, values in span_clusters.items()])
         word_clusters = sorted([sorted(values) for _, values in word_clusters.items()])
         head2span = sorted(head2span)
+        is_zero = [i for _, i in sorted(is_zero)]
+
+        # remove zero tokens "_" from cased_words and adjust indices accordingly
+        zero_positions = [i for i, w in enumerate(cased_words) if w == "_"]
+        if zero_positions:
+            old_to_new = {}
+            new_idx = 0
+            for old_idx, w in enumerate(cased_words):
+                if w != "_":
+                    old_to_new[old_idx] = new_idx
+                    new_idx += 1
+            cased_words = [w for w in cased_words if w != "_"]
+            sent_id = [sent_id[i] for i in sorted(old_to_new.keys())]
+            deprel = [deprel[i] for i in sorted(old_to_new.keys())]
+            heads = [heads[i] for i in sorted(old_to_new.keys())]
+            try:
+                span_clusters = [
+                    [(old_to_new[start], old_to_new[end - 1] + 1) for start, end in cluster]
+                    for cluster in span_clusters
+                ]
+            except (KeyError, TypeError) as _:  # either end - 1 == -1, or start/end is None
+                warnings.warn("Somehow, we are still coreferring to a zero. This is likely due to multiple zeros on top of each other. We are giving up.")
We are giving up.") + continue + word_clusters = [ + [old_to_new[h] for h in cluster] + for cluster in word_clusters + ] + head2span = [ + (old_to_new[h], old_to_new[s], old_to_new[e - 1] + 1) + for h, s, e in head2span + ] processed = { "document_id": doc_id, @@ -155,7 +228,8 @@ def process_documents(docs, augment=False): "span_clusters": span_clusters, "word_clusters": word_clusters, "head2span": head2span, - "lang": lang + "lang": lang, + "is_zero": is_zero } processed_section.append(processed) return processed_section @@ -172,7 +246,8 @@ def process_dataset(short_name, coref_output_path, split_test, train_files, dev_ for load in filenames: lang = load.split("/")[-1].split("_")[0] print("Ingesting %s from %s of lang %s" % (section, load, lang)) - docs = CoNLL.conll2multi_docs(load) + docs = CoNLL.conll2multi_docs(load, ignore_gapping=False) + # sections = docs[:10] print(" Ingested %d documents" % len(docs)) if split_test and section == 'train': test_section = [] @@ -216,7 +291,7 @@ def process_dataset(short_name, coref_output_path, split_test, train_files, dev_ json.dump(converted_section, fout, indent=2) def get_dataset_by_language(coref_input_path, langs): - conll_path = os.path.join(coref_input_path, "CorefUD-1.2-public", "data") + conll_path = os.path.join(coref_input_path, "CorefUD-1.3-public", "data") train_filenames = [] dev_filenames = [] for lang in langs: @@ -242,9 +317,9 @@ def main(): coref_output_path = paths['COREF_DATA_DIR'] if args.project: - if args.project == 'slavic': - project = "slavic_udcoref" - langs = ('Polish', 'Russian', 'Czech') + if args.project == 'baltoslavic': + project = "baltoslavic_udcoref" + langs = ('Polish', 'Russian', 'Czech', 'Old_Church_Slavonic', 'Lithuanian') train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) elif args.project == 'hungarian': project = "hu_udcoref" @@ -262,6 +337,26 @@ def main(): project = "norwegian_udcoref" langs = ('Norwegian',) train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) + elif args.project == 'turkish': + project = "turkish_udcoref" + langs = ('Turkish',) + train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) + elif args.project == 'korean': + project = "korean_udcoref" + langs = ('Korean',) + train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) + elif args.project == 'hindi': + project = "hindi_udcoref" + langs = ('Hindi',) + train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) + elif args.project == 'ancient_greek': + project = "ancient_greek_udcoref" + langs = ('Ancient_Greek',) + train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) + elif args.project == 'ancient_hebrew': + project = "ancient_hebrew_udcoref" + langs = ('Ancient_Hebrew',) + train_filenames, dev_filenames = get_dataset_by_language(coref_input_path, langs) else: project = args.directory conll_path = os.path.join(coref_input_path, project) @@ -273,4 +368,3 @@ def main(): if __name__ == '__main__': main() -