diff --git a/configs/all_tasks.json b/configs/all_tasks.json
new file mode 100644
index 0000000..67d8f58
--- /dev/null
+++ b/configs/all_tasks.json
@@ -0,0 +1,35 @@
+{
+    "infer": {
+        "posttrain": ["arc_ar", "arc_bn", "arc_ca", "arc_da", "arc_de", "arc_es", "arc_eu", "arc_fr", "arc_gu", "arc_hi", "arc_hr", "arc_hu", "arc_hy", "arc_id", "arc_it", "arc_kn", "arc_ml", "arc_mr", "arc_ne", "arc_nl", "arc_pt", "arc_ro", "arc_ru", "arc_sk", "arc_sr", "arc_sv", "arc_ta", "arc_te", "arc_uk", "arc_vi", "arc_zh", "global_mmlu_ar", "global_mmlu_bn", "global_mmlu_de", "global_mmlu_es", "global_mmlu_fr", "global_mmlu_hi", "global_mmlu_id", "global_mmlu_it", "global_mmlu_ja", "global_mmlu_ko", "global_mmlu_pt", "global_mmlu_sw", "global_mmlu_yo", "global_mmlu_zh", "hellaswag", "hellaswag_ar", "hellaswag_bn", "hellaswag_ca", "hellaswag_da", "hellaswag_de", "hellaswag_es", "hellaswag_eu", "hellaswag_fr", "hellaswag_gu", "hellaswag_hi", "hellaswag_hr", "hellaswag_hu", "hellaswag_hy", "hellaswag_id", "hellaswag_it", "hellaswag_kn", "hellaswag_ml", "hellaswag_mr", "hellaswag_ne", "hellaswag_nl", "hellaswag_pt", "hellaswag_ro", "hellaswag_ru", "hellaswag_sk", "hellaswag_sr", "hellaswag_sv", "hellaswag_ta", "hellaswag_te", "hellaswag_uk", "hellaswag_vi", "include_base_44_albanian", "include_base_44_arabic", "include_base_44_armenian", "include_base_44_azerbaijani", "include_base_44_basque", "include_base_44_belarusian", "include_base_44_bengali", "include_base_44_bulgarian", "include_base_44_chinese", "include_base_44_croatian", "include_base_44_finnish", "include_base_44_french", "include_base_44_georgian", "include_base_44_german", "include_base_44_hebrew", "include_base_44_hindi", "include_base_44_hungarian", "include_base_44_indonesian", "include_base_44_italian", "include_base_44_japanese", "include_base_44_kazakh", "include_base_44_korean", "include_base_44_lithuanian", "include_base_44_malay", "include_base_44_malayalam", "include_base_44_nepali", "include_base_44_persian", "include_base_44_polish", "include_base_44_portuguese", "include_base_44_russian", "include_base_44_serbian", "include_base_44_spanish", "include_base_44_tagalog", "include_base_44_tamil", "include_base_44_telugu", "include_base_44_turkish", "include_base_44_ukrainian", "include_base_44_uzbek", "include_base_44_vietnamese", "mmlu", "winogrande", "xcopa_et", "xcopa_ht", "xcopa_qu", "xnli_ar", "xnli_bg", "xnli_de", "xnli_el", "xnli_en", "xnli_es", "xnli_fr", "xnli_hi", "xnli_ru", "xnli_sw", "xnli_th", "xnli_tr", "xnli_ur", "xnli_vi", "xnli_zh", "xwinograd_en", "xwinograd_fr", "xwinograd_pt", "xwinograd_ru", "xwinograd_zh"]
+    },
+    "other": [
+        {"name": "ai2_arc", "kinds": ["pretrain"], "size": null, "language": "en", "dimension": null, "alias": ["arc_easy", "arc_challenge"]},
+        {"name": "include_base_44_greek", "kinds": ["pretrain"], "size": null, "language": "el", "dimension": null},
+        {"name": "include_base_44_north macedonian", "kinds": ["pretrain"], "size": null, "language": "mk", "dimension": null},
+        {"name": "xwinograd_jp", "kinds": ["pretrain"], "size": null, "language": "ja", "dimension": null},
+        {"name": "include_base_44_dutch", "kinds": ["pretrain"], "size": 576, "language": null, "dimension": null},
+        {"name": "piqa", "kinds": ["pretrain"], "size": 21000, "language": null, "dimension": null},
+        {"name": "switzerland_qa_de", "kinds": ["pretrain"], "size": 9160, "language": null, "dimension": null},
+        {"name": "switzerland_qa_fr", "kinds": ["pretrain"], "size": 9160, "language": null, "dimension": null},
+        {"name": "switzerland_qa_it", "kinds": ["pretrain"], "size": 9160, "language": null, "dimension": null},
+        {"name": "switzerland_qa_rm", "kinds": ["pretrain"], "size": 9160, "language": null, "dimension": null},
+        {"name": "switzerland_qa_en", "kinds": ["pretrain"], "size": 9160, "language": null, "dimension": null},
+        {"name": "cultural_bench", "kinds": ["pretrain"], "size": null, "language": "en", "dimension": null},
+        {"name": "blend_algeria", "kinds": ["pretrain"], "size": 20364, "language": "ar", "dimension": null},
+        {"name": "blend_assam", "kinds": ["pretrain"], "size": 21293, "language": "as", "dimension": null},
+        {"name": "blend_azerbaijan", "kinds": ["pretrain"], "size": 19932, "language": "az", "dimension": null},
+        {"name": "blend_china", "kinds": ["pretrain"], "size": 20410, "language": "zh", "dimension": null},
+        {"name": "blend_ethiopia", "kinds": ["pretrain"], "size": 22712, "language": "am", "dimension": null},
+        {"name": "blend_greece", "kinds": ["pretrain"], "size": 20383, "language": "el", "dimension": null},
+        {"name": "blend_indonesia", "kinds": ["pretrain"], "size": 18417, "language": "id", "dimension": null},
+        {"name": "blend_iran", "kinds": ["pretrain"], "size": 19371, "language": "fa", "dimension": null},
+        {"name": "blend_mexico", "kinds": ["pretrain"], "size": 20513, "language": "es", "dimension": null},
+        {"name": "blend_north_korea", "kinds": ["pretrain"], "size": 17005, "language": "ko", "dimension": null},
+        {"name": "blend_northern_nigeria", "kinds": ["pretrain"], "size": 16317, "language": "ha", "dimension": null},
+        {"name": "blend_south_korea", "kinds": ["pretrain"], "size": 21439, "language": "ko", "dimension": null},
+        {"name": "blend_spain", "kinds": ["pretrain"], "size": 19280, "language": "es", "dimension": null},
+        {"name": "blend_uk", "kinds": ["pretrain"], "size": 16723, "language": "en", "dimension": null},
+        {"name": "blend_us", "kinds": ["pretrain"], "size": 16491, "language": "en", "dimension": null},
+        {"name": "blend_west_java", "kinds": ["pretrain"], "size": 15289, "language": "su", "dimension": null}
+    ]
+}
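For context on how the rows above are consumed: each `other` entry maps onto the `Task` dataclass introduced in `src/evals/tasks.py` later in this diff, with `null` fields inferred by `Task.__post_init__`. A minimal sketch of the field mapping (not the real loader, which is `get_all_tasks`; the `[0]` index is just an illustrative pick):

```python
import json

import iso639  # from the iso639-lang dependency added to pyproject.toml below

# Sketch of how get_all_tasks() maps one JSON row onto Task fields.
with open("configs/all_tasks.json") as f:
    raw = json.load(f)

row = raw["other"][0]  # {"name": "ai2_arc", "kinds": ["pretrain"], ...}
language = None if row["language"] is None else iso639.Lang(row["language"])
print(row["name"], tuple(row["kinds"]), row.get("size"), language,
      tuple(row.get("alias", ())))
```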
"switzerland_qa_it", "kinds": ["pretrain"], "size": 9160, "language": null, "dimension": null}, + {"name": "switzerland_qa_rm", "kinds": ["pretrain"], "size": 9160, "language": null, "dimension": null}, + {"name": "switzerland_qa_en", "kinds": ["pretrain"], "size": 9160, "language": null, "dimension": null}, + {"name": "cultural_bench", "kinds": ["pretrain"], "size": null, "language": "en", "dimension": null}, + {"name": "blend_algeria", "kinds": ["pretrain"], "size": 20364, "language": "ar", "kind": null}, + {"name": "blend_assam", "kinds": ["pretrain"], "size": 21293, "language": "as", "kind": null}, + {"name": "blend_azerbaijan", "kinds": ["pretrain"], "size": 19932, "language": "az", "kind": null}, + {"name": "blend_china", "kinds": ["pretrain"], "size": 20410, "language": "zh", "kind": null}, + {"name": "blend_ethiopia", "kinds": ["pretrain"], "size": 22712, "language": "am", "kind": null}, + {"name": "blend_greece", "kinds": ["pretrain"], "size": 20383, "language": "el", "kind": null}, + {"name": "blend_indonesia", "kinds": ["pretrain"], "size": 18417, "language": "id", "kind": null}, + {"name": "blend_iran", "kinds": ["pretrain"], "size": 19371, "language": "fa", "kind": null}, + {"name": "blend_mexico", "kinds": ["pretrain"], "size": 20513, "language": "es", "kind": null}, + {"name": "blend_north_korea", "kinds": ["pretrain"], "size": 17005, "language": "ko", "kind": null}, + {"name": "blend_northern_nigeria", "kinds": ["pretrain"], "size": 16317, "language": "ha", "kind": null}, + {"name": "blend_south_korea", "kinds": ["pretrain"], "size": 21439, "language": "ko", "kind": null}, + {"name": "blend_spain", "kinds": ["pretrain"], "size": 19280, "language": "es", "kind": null}, + {"name": "blend_uk", "kinds": ["pretrain"], "size": 16723, "language": "en", "kind": null}, + {"name": "blend_us", "kinds": ["pretrain"], "size": 16491, "language": "en", "kind": null}, + {"name": "blend_west_java", "kinds": ["pretrain"], "size": 15289, "language": "su", "kind": null} + ] +} diff --git a/configs/automation.json b/configs/automation.json index e507bb6..a258ee9 100644 --- a/configs/automation.json +++ b/configs/automation.json @@ -12,10 +12,11 @@ "/capstor/scratch/cscs/asolergi/main_run_70B_megatron/Megatron-LM/logs/Meg-Runs/main-runs-v1/apertus3-70b-512-nodes-1e-5lr/checkpoints", "/capstor/scratch/cscs/asolergi/main_run_70B_megatron/Megatron-LM/logs/Meg-Runs/main-runs-v1/apertus3-70b-512-nodes-1e-5lr/checkpoints-512-noOverlap" ], + "max_samples": 500000, "size": 70, "tokens_per_iter": "8388608:523519,16777216:", - "frequency": 30000, - "start_eval_from": 830000 + "frequency": 15000, + "start_eval_from": 1070000 } } } diff --git a/configs/tasks.json b/configs/tasks.json index 409b8c3..51cd9c2 100644 --- a/configs/tasks.json +++ b/configs/tasks.json @@ -1,9 +1,49 @@ { - "show_in_table": ["mmlu/acc", "gsm8k/exact_match", "arc_challenge/acc", "hellaswag/acc", "m_hellaswag/acc", "m_arc/acc", "include_base_44/acc"], - "root": "swissai_eval", - "groups": { - "swissai_eval": ["mmlu", "hellaswag", "mmlu_continuation", "winogrande", "piqa", "openbookqa", "arc_challenge", "arc_easy", "commonsense_qa", "lambada_openai", "lambada_standard", "wikitext", "gsm8k", "squadv2", "include_base_44", "xcopa", "xnli", "xwinograd", "pawsx", "m_arc", "global_mmlu", "m_hellaswag"], - "english": ["mmlu", "hellaswag", "mmlu_continuation", "winogrande", "piqa", "openbookqa", "arc_challenge", "arc_easy", "commonsense_qa", "lambada_openai", "lambada_standard", "wikitext", "gsm8k", "squadv2"], - "multilingual": 
["include_base_44", "xcopa", "xnli", "xwinograd", "pawsx", "m_arc", "global_mmlu", "m_hellaswag"] + "show_in_table": [ + "mmlu/acc", + "hellaswag/acc" + ], + "language_groups": { + "english": [ + "en" + ], + "swiss": [ + "de", + "fr", + "it", + "rm" + ], + "eu": [ + "sq", + "hy", + "eu", + "be", + "bg", + "ca", + "hr", + "da", + "nl", + "en", + "et", + "fi", + "fr", + "ka", + "de", + "el", + "hu", + "it", + "lt", + "mk", + "pl", + "pt", + "ro", + "rm", + "ru", + "sr", + "sk", + "es", + "sv", + "uk" + ] } } diff --git a/pyproject.toml b/pyproject.toml index 08548d2..d190201 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,19 @@ [project] -name = "swissai-evals" +name = "evals" version = "0.1.0" description = "Swissai evaluation scripts" readme = "README.md" requires-python = ">=3.11" dependencies = [ + "iso639-lang>=2.6.1", "pandas>=2.3.0", + "prtpy>=0.8.3", + "pyyaml>=6.0.2", + "requests>=2.32.4", "wandb>=0.20.1", ] + + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" diff --git a/scripts/automate.py b/scripts/automate.py index 1df681e..2638da9 100644 --- a/scripts/automate.py +++ b/scripts/automate.py @@ -13,20 +13,13 @@ import collections import re import os +import math import json import subprocess import shutil from pathlib import Path - - -def unify(completed: list[str]) -> list[str]: - completed_set = set(completed) - unified = [] - for groupname, tasks in filter(lambda t: t[0] != ROOT_EVAL, TASKS["groups"].items()): - if set(tasks) <= completed_set: - unified.append(groupname) - return unified +from evals.tasks import Task, get_all_tasks, get_partition def get_running(as_jobname: bool = False) -> dict[str, dict[int, list[str]] | list[str]]: @@ -47,10 +40,7 @@ def get_running(as_jobname: bool = False) -> dict[str, dict[int, list[str]] | li running.append(jobname) else: name, group, it = rmatch.groups() - if group == ROOT_EVAL: - running[name][int(it)] += ALL_EVALS - else: - running[name][int(it)].append(group) + running[name][int(it)] += group return running @@ -60,9 +50,9 @@ def get_evaluated(model: str) -> dict[int, list[str]]: it = int(re.match("^iter_([0-9]+)$", path.parent.parent.parent.parent.name).group(1)) with open(path) as f: info = json.load(f) - for task in info["results"]: - status[it].append(task) - return {it: unify(tasks) for it, tasks in status.items()} + for taskname in info["results"]: + status[it].append(taskname) + return status def get_available(model_dirs: list[Path]) -> list[int]: @@ -73,49 +63,81 @@ def get_available(model_dirs: list[Path]) -> list[int]: return available -def submit(name: str, model: dict, it: int, tasks: list[str]): - task_alias = ROOT_EVAL if tasks == ALL_EVALS else " ".join(tasks) - tasks = " ".join(tasks) +def submit(name: str, model: dict, it: int, tasks: list[Task]): + # Get partition of tasks. + total_size = sum(task.size for task in ALL_TASKS) + n_shards = math.ceil(total_size/model["max_samples"]) + partition = get_partition(tasks=tasks, shards=n_shards) + default_partition = get_partition(tasks=ALL_TASKS, shards=n_shards) + + # Schedule all tasks requested. 
path, = (model_dir for model_dir in model["model_dirs"] if Path(f"{model_dir}/iter_{it:07d}").exists()) - cmd = ["sbatch", - f"--job-name=eval_{name}_{task_alias}_{it}", - "scripts/evaluate.sbatch", - str(path), - str(it), - model["tokens_per_iter"], - name] - env = {**os.environ, - "LOGS_ROOT": CFG["logs_root"], - "TOKENIZER": "alehc/swissai-tokenizer", - "BOS": "true", - "SIZE": str(model["size"]), - "HF_TEMP_DIR": CFG["hf_temp_dir"], - "TASKS": tasks} - print("Launching", name, it, tasks, path) - subprocess.run(cmd, env=env, stdout=subprocess.PIPE) + for part in partition: + # Get special jobname depending on the tasks requested. + matches = [(i, default_part) for i, default_part in enumerate(default_partition) + if part == default_part] + if len(matches) == 0: + jobname = "mixed" + else: + (shard_i, _), = matches + jobname = f"shard{shard_i}of{n_shards}" + jobname = f"eval_{name}_{jobname}_{it}" + + cmd = ["sbatch", f"--job-name={jobname}", "scripts/evaluate.sbatch", str(path), + str(it), model["tokens_per_iter"], name] + env = {**os.environ, + "LOGS_ROOT": CFG["logs_root"], + "TOKENIZER": "alehc/swissai-tokenizer", + "BOS": "true", + "SIZE": str(model["size"]), + "HF_TEMP_DIR": CFG["hf_temp_dir"], + "TASKS": ",".join(task.name for task in part)} + print("Launching", jobname) + subprocess.run(cmd, env=env, stdout=subprocess.PIPE) def submit_needed(): running = get_running() for name, model in CFG["models"].items(): + total_size = sum(task.size for task in ALL_TASKS) + n_shards = math.ceil(total_size/model["max_samples"]) + default_partition = get_partition(tasks=ALL_TASKS, shards=n_shards) + + # Get tasks alredy evaluated (reading them from the `results.json`). status = get_evaluated(name) - for it, tasks in running[name].items(): - if it in status: - status[it] += tasks - else: - status[it] = tasks + default_partition = get_partition(tasks=ALL_TASKS, shards=n_shards) + + # Handle already evaluated: if a "mixed" group is running, assume it will + # contain all missing tasks because we don't know which one does it contain in reality, + # otherwise obtain the correct shard. + for it, groups in running[name].items(): + for group in groups: + if groups == "mixed": + actual_tasks = ALL_TASKS + else: + shard_i, total_shards = re.match("^shard([0-9]+)of([0-9]+)$", group).groups() + assert total_shards == n_shards + actual_tasks = default_partition[int(shard_i)] + + if it in status: + status[it] += [task.name for task in actual_tasks] + else: + status[it] = [task.name for task in actual_tasks] available = get_available(model["model_dirs"]) for it in available: if (it - model["start_eval_from"]) % model["frequency"] == 0 and it >= model["start_eval_from"]: - missing = sorted(set(ALL_EVALS) - set(status.get(it, []))) + # Determine missing set. 
+ missing = [] + handled = status.get(it, []) + for task in ALL_TASKS: + if len(task.alias) > 0 and any(actual_name not in handled for actual_name in task.alias): + missing.append(task) + elif len(task.alias) == 0 and task.name not in handled: + missing.append(task) if len(missing) > 0: - if model["size"] < 70: - submit(name, model, it, missing) - else: - for task in missing: - submit(name, model, it, [task]) + submit(name, model, it, missing) def update_hf_checkpoints(): @@ -172,10 +194,7 @@ def main(): if __name__ == "__main__": + ALL_TASKS = get_all_tasks() with open("configs/automation.json") as f: CFG = json.load(f) - with open("configs/tasks.json") as f: - TASKS = json.load(f) - ROOT_EVAL = TASKS["root"] - ALL_EVALS = sorted([task for task in TASKS["groups"] if task != ROOT_EVAL]) main() diff --git a/scripts/update_wandb.py b/scripts/update_wandb.py index f55bc84..29b9772 100644 --- a/scripts/update_wandb.py +++ b/scripts/update_wandb.py @@ -1,6 +1,5 @@ import collections import statistics -import functools import re import json import os @@ -10,9 +9,24 @@ import pandas as pd import wandb +import iso639 +from evals.tasks import get_all_tasks, Task + + +def get_log(infos: List[dict], tasks_cfg: dict, all_tasks: list[Task]) -> dict[str, float]: + def agg(log: dict[str, dict[str, float]], prefix: str, tasks_to_agg: list[str], warn: bool = True): + missing = set(tasks_to_agg) - set(log) + if len(missing) > 0: + if warn: + print("WARNING! Macro aggregation for", prefix, "not available. Missing:", sorted(missing)) + return + + for metric in filter(lambda metric: "stderr" not in metric, all_metrics): + values = [log[taskname][metric] for taskname in tasks_to_agg if metric in log[taskname]] + if len(values) > 0: + log[f"{prefix}.macro"][metric] = statistics.mean(values) -def get_log(infos: List[dict], tasks_cfg: dict) -> Dict[str, float]: # Aggregate raw info. groups = {} results = {} @@ -25,26 +39,44 @@ def get_log(infos: List[dict], tasks_cfg: dict) -> Dict[str, float]: log = collections.defaultdict(dict) for dataname, details in results.items(): for metricname, val in details.items(): - if metricname == "alias" or val == "N/A": + if metricname == "alias" or val in ["N/A", " "]: continue - assert isinstance(val, float), val + assert isinstance(val, float), f"{dataname}.{metricname} = val" metricname, _ = metricname.split(",") # for some reason it is always acc,none so we remove the none. all_metrics.add(metricname) log[dataname][metricname] = val - # Do macro aggregations. - for groupname, subgroups in tasks_cfg["groups"].items(): - missing = set(subgroups) - set(results) - if len(missing) > 0: - print("WARNING! Macro aggregation for", groupname, "not available. Missing:", sorted(missing)) - continue - metrics = sorted(functools.reduce(set.union, (set(log[dataname]) for dataname in subgroups))) - for metric in filter(lambda metric: "stderr" not in metric, all_metrics): - values = [log[dataname][metric] for dataname in subgroups if metric in log[dataname]] - if len(values) > 0: - log[f"{groupname}_macro"][metric] = statistics.mean(values) - - # Finally, push to wandb. + # Now that we have all the "leaf task groups" we can do four aggregations: + # Let's start with the {language_group} agg. + for lang_group_name, langs in tasks_cfg["language_groups"].items(): + tasks_to_agg = [task.name for task in all_tasks + if task.language.pt1 in langs] + agg(log, f"language_group/{lang_group_name}", list(tasks_to_agg)) + + # Aggregate start with the {dimension} agg. 
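To make the sharding scheme above concrete, here is a hedged sketch of what `submit()` computes before launching sbatch jobs: a greedy partition of all tasks into shards of at most `max_samples` examples, with stable `shard{i}of{n}` names for shards that coincide with the default partition. The 500000 figure is the `max_samples` value from `configs/automation.json`; `HF_TOKEN` must be set so task sizes can be inferred:

```python
import math

from evals.tasks import get_all_tasks, get_partition

all_tasks = get_all_tasks()
total_size = sum(task.size for task in all_tasks)
n_shards = math.ceil(total_size / 500000)  # max_samples from configs/automation.json

# Same call submit() makes; each shard becomes one sbatch job whose TASKS env
# variable is the comma-joined list of task names.
for i, part in enumerate(get_partition(tasks=all_tasks, shards=n_shards)):
    print(f"shard{i}of{n_shards}: {len(part)} tasks, {sum(t.size for t in part)} samples")
```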
diff --git a/scripts/update_wandb.py b/scripts/update_wandb.py
index f55bc84..29b9772 100644
--- a/scripts/update_wandb.py
+++ b/scripts/update_wandb.py
@@ -1,6 +1,5 @@
 import collections
 import statistics
-import functools
 import re
 import json
 import os
@@ -10,9 +9,24 @@
 import pandas as pd
 import wandb
+import iso639
+
+from evals.tasks import get_all_tasks, Task
+
+
+def get_log(infos: List[dict], tasks_cfg: dict, all_tasks: list[Task]) -> dict[str, float]:
+    def agg(log: dict[str, dict[str, float]], prefix: str, tasks_to_agg: list[str], warn: bool = True):
+        # `all_metrics` is populated by the loop below before agg() is first called.
+        missing = set(tasks_to_agg) - set(log)
+        if len(missing) > 0:
+            if warn:
+                print("WARNING! Macro aggregation for", prefix, "not available. Missing:", sorted(missing))
+            return
+
+        for metric in filter(lambda metric: "stderr" not in metric, all_metrics):
+            values = [log[taskname][metric] for taskname in tasks_to_agg if metric in log[taskname]]
+            if len(values) > 0:
+                log[f"{prefix}.macro"][metric] = statistics.mean(values)

-def get_log(infos: List[dict], tasks_cfg: dict) -> Dict[str, float]:
     # Aggregate raw info.
     groups = {}
     results = {}
@@ -25,26 +39,44 @@
     log = collections.defaultdict(dict)
     for dataname, details in results.items():
         for metricname, val in details.items():
-            if metricname == "alias" or val == "N/A":
+            if metricname == "alias" or val in ["N/A", " "]:
                 continue
-            assert isinstance(val, float), val
+            assert isinstance(val, float), f"{dataname}.{metricname} = {val}"
             metricname, _ = metricname.split(",")  # for some reason it is always acc,none so we remove the none.
             all_metrics.add(metricname)
             log[dataname][metricname] = val

-    # Do macro aggregations.
-    for groupname, subgroups in tasks_cfg["groups"].items():
-        missing = set(subgroups) - set(results)
-        if len(missing) > 0:
-            print("WARNING! Macro aggregation for", groupname, "not available. Missing:", sorted(missing))
-            continue
-        metrics = sorted(functools.reduce(set.union, (set(log[dataname]) for dataname in subgroups)))
-        for metric in filter(lambda metric: "stderr" not in metric, all_metrics):
-            values = [log[dataname][metric] for dataname in subgroups if metric in log[dataname]]
-            if len(values) > 0:
-                log[f"{groupname}_macro"][metric] = statistics.mean(values)
-
-    # Finally, push to wandb.
+    # Now that we have all the "leaf" task results, we can do four aggregations.
+    # Start with the {language_group} agg.
+    for lang_group_name, langs in tasks_cfg["language_groups"].items():
+        tasks_to_agg = [task.name for task in all_tasks
+                        if task.language.pt1 in langs]
+        agg(log, f"language_group/{lang_group_name}", tasks_to_agg)
+
+    # Next, the {dimension} agg.
+    all_dims = sorted({task.dimension for task in all_tasks})
+    for dim in all_dims:
+        tasks_to_agg = [task.name for task in all_tasks
+                        if task.dimension == dim]
+        agg(log, f"dimension/{dim}", tasks_to_agg)
+
+    # Aggregate {dimension}.{language_group}.macro.
+    for lang_group_name, langs in tasks_cfg["language_groups"].items():
+        for dim in all_dims:
+            tasks_to_agg = [task.name for task in all_tasks
+                            if task.language.pt1 in langs and task.dimension == dim]
+            agg(log, f"dimension_group/{dim}.{lang_group_name}", tasks_to_agg)
+
+    # Last, {dimension}.{language}.
+    all_langs = sorted({task.language.pt1 for task in all_tasks})
+    for lang in all_langs:
+        for dim in all_dims:
+            tasks_to_agg = [task.name for task in all_tasks
+                            if task.language.pt1 == lang and task.dimension == dim]
+            agg(log, f"dimension_lang/{dim}.{lang}", tasks_to_agg)
+
+    # Finally, prepare the wandb format.
     wandb_log = {}
     for dataname, details in log.items():
         for metric, value in details.items():
@@ -65,12 +97,31 @@
     return history


-def main(logs_root: Path, name: Optional[str], it: Optional[int],
-         tasks: Path):
+def repair(all_tasks: list[Task]) -> list[Task]:
+    repaired = []
+    for task in all_tasks:
+        if task.name == "ai2_arc":
+            # The harness reports ai2_arc as two tasks, so expand the alias.
+            repaired += [Task("arc_easy", (), 0, iso639.Lang("en"), task.dimension),
+                         Task("arc_challenge", (), 0, iso639.Lang("en"), task.dimension)]
+        else:
+            repaired.append(task)
+    return repaired
+

-    # model => {metric => value}
-    with open(tasks) as f:
+def main(logs_root: Path, name: Optional[str], it: Optional[int], cfg: Path):
+    all_tasks = get_all_tasks(all_tasks_json=cfg/"all_tasks.json")
+    all_tasks = repair(all_tasks)
+    with open(cfg/"tasks.json") as f:
         tasks_cfg = json.load(f)
+    all_languages = {task.language.pt1 for task in all_tasks}
+
+    for lang_group in tasks_cfg["language_groups"].values():
+        for lang in lang_group:
+            assert lang in all_languages or lang == "rm", lang
+
+    tasks_cfg["language_groups"]["global"] = list(all_languages)
+    tasks_cfg["language_groups"]["multilingual"] = list(all_languages - {"en"})

     # Grab each possible log and update wandb run.
     # First, iterate model names.
@@ -99,9 +150,8 @@
                     results.append(json.load(f))

             if len(results) > 0:
-                log = get_log(results, tasks_cfg)
+                log = get_log(results, tasks_cfg, all_tasks)
                 log.update({"ConsumedTokens": consumed_tokens, "OptStep": current_it})
-                sublog = {k: v for k, v in log.items() if "macro/acc" in k}
                 # Update log if needed.
                 if consumed_tokens in history:
                     if "eval_table" in history[consumed_tokens]:
@@ -112,10 +162,8 @@
                         print(sorted(set(history[consumed_tokens]) - set(log)))
                         print("Important! wandb log at current iteration already found, but differs. Updating")
                         run.log(log)
-                    print("Logged successfully:", sublog)
                 else:
                     run.log(log)
-                    print("Logged successfully:", sublog)

                 # Update all_logs so we can build the table after this big loop.
                 if p1.name not in latest_logs or latest_logs[p1.name]["ConsumedTokens"] < consumed_tokens:
@@ -137,6 +185,7 @@
         df = pd.DataFrame([sublog])
         with wandb.init(id=name, name=name) as run:
             run.log({"eval_table": wandb.Table(dataframe=df), "ConsumedTokens": log["ConsumedTokens"]})
+    print("Goodbye")


 if __name__ == "__main__":
@@ -144,6 +193,6 @@
     parser.add_argument("logs_root", type=Path)
     parser.add_argument("--name")
     parser.add_argument("--it", type=int)
-    parser.add_argument("--tasks", type=Path, default=Path("configs/tasks.json"))
+    parser.add_argument("--cfg", type=Path, default=Path("configs"))
     args = parser.parse_args()
     main(**vars(args))
diff --git a/src/evals/__init__.py b/src/evals/__init__.py
new file mode 100644
index 0000000..e69de29
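Before the `tasks.py` module itself, a note on the metric names the new `get_log()` emits: `agg()` writes into `log[f"{prefix}.macro"]`, so the four aggregation families should surface in wandb under keys like the following. Values and the exact `/{metric}` suffix are hypothetical here, since the key-building loop is unchanged context in this diff:

```python
# Hypothetical examples of the four "{prefix}.macro" families built in get_log():
example_keys = {
    "language_group/swiss.macro/acc": 0.61,                  # per language group
    "dimension/general_abilities.macro/acc": 0.58,           # per dimension
    "dimension_group/factual_regional.eu.macro/acc": 0.55,   # dimension x language group
    "dimension_lang/factual_agnostic.de.macro/acc": 0.57,    # dimension x language
}
print(example_keys)
```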
diff --git a/src/evals/tasks.py b/src/evals/tasks.py
new file mode 100644
index 0000000..483c802
--- /dev/null
+++ b/src/evals/tasks.py
@@ -0,0 +1,174 @@
+from __future__ import annotations
+
+import re
+import dataclasses
+import enum
+import json
+import os
+from pathlib import Path
+from typing import Optional
+
+import requests
+import iso639
+import prtpy
+
+
+REQUEST_CACHE = {}
+
+
+class Dimension(enum.StrEnum):
+    general_abilities = enum.auto()
+    factual_agnostic = enum.auto()
+    factual_regional = enum.auto()
+
+    @classmethod
+    def get(cls, name: str) -> Dimension:
+        general = ["hellaswag", "piqa", "arc", "ai2_arc", "winogrande", "xwinograd", "xnli", "copa", "xcopa"]
+        agnostic = ["mmlu", "global_mmlu"]
+        regional = ["include", "switzerland_qa", "cultural_bench", "blend"]
+
+        if any(name.startswith(group) for group in general):
+            return Dimension.general_abilities
+        if any(name.startswith(group) for group in agnostic):
+            return Dimension.factual_agnostic
+        if any(name.startswith(group) for group in regional):
+            return Dimension.factual_regional
+        raise ValueError(f"Could not infer dimension for task {name}")
+
+
+class TaskKind(enum.StrEnum):
+    pretrain = enum.auto()
+    posttrain = enum.auto()
+
+
+@dataclasses.dataclass
+class Task:
+    name: str
+    kinds: tuple[TaskKind, ...]
+    size: Optional[int] = None
+    language: Optional[iso639.Lang] = None
+    dimension: Optional[Dimension] = None
+    alias: tuple[str, ...] = ()
+
+    def __hash__(self) -> int:
+        return hash((self.name, self.kinds, self.size, self.language.pt1, self.dimension, self.alias))
+
+    def __post_init__(self):
+        # Infer size, language and dimension if not specified.
+        if self.size is None:
+            self.size = _infer_size(self.name)
+        if self.language is None:
+            chunks = self.name.split("_")
+            if len(chunks) == 1:  # no language code, assume English.
+                self.language = iso639.Lang("en")
+            elif len(chunks[-1]) == 2:  # two-letter ISO 639-1 code expected.
+                self.language = iso639.Lang(chunks[-1])
+            else:  # full name of the language given.
+                self.language = iso639.Lang(chunks[-1].title())
+        if self.dimension is None:
+            self.dimension = Dimension.get(self.name)
+
+
+def _infer_size(name: str) -> int:
+    def query(source: str) -> dict:
+        headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
+        url = f"https://datasets-server.huggingface.co/info?dataset={source}"
+        response = requests.get(url, headers=headers)
+        return response.json()
+
+    def get_split(names: list[str]) -> str:
+        maybe_specific = [pattern for pattern in specific_splits if re.match(pattern, name) is not None]
+        if len(maybe_specific):
+            pattern, = maybe_specific
+            return specific_splits[pattern]
+        if len(names) == 1:
+            return names[0]
+        if "test" in names:
+            return "test"
+        raise ValueError(f"Unknown split for task {name}: {names}")
+
+    # Get `source, config` when exact names match.
+    exact_sources = {
+        "hellaswag": ("Rowan/hellaswag", "default"),
+        "mmlu": ("cais/mmlu", "all"),
+        "winogrande": ("allenai/winogrande", "winogrande_xl"),
+        "ai2_arc": ("allenai/ai2_arc", None),
+        "cultural_bench": ("kellycyy/CulturalBench", None),
+    }
+
+    # Get `source` and infer the language when task names have underscores.
+    underscore_sources = {
+        "arc": "alexandrainst/m_arc",
+        "global_mmlu": "CohereLabs/Global-MMLU-Lite",
+        "hellaswag": "alexandrainst/m_hellaswag",
+        "include_base_44": "CohereLabs/include-base-44",
+        "xcopa": "cambridgeltl/xcopa",
+        "xnli": "facebook/xnli",
+        "xwinograd": "Muennighoff/xwinograd",
+    }
+
+    # These tasks use a very specific split:
+    specific_splits = {
+        r"xnli_.*": "validation",
+        "winogrande": "validation",
+        "^hellaswag$": "validation",
+    }
+
+    # Get `source` and `config` based on the above dicts.
+    if name in exact_sources:
+        source, config = exact_sources[name]
+    elif "_" in name:
+        root = "_".join(name.split("_")[:-1])
+        if root not in underscore_sources:
+            raise ValueError(f"Could not infer size for task {name}")
+        source = underscore_sources[root]
+        config = name.split("_")[-1]
+    else:
+        raise ValueError(f"Could not infer size for task {name}")
+    if name.startswith("include_base_44"):  # INCLUDE language configs are capitalized.
+        config = config.title()
+    if source not in REQUEST_CACHE:  # HTTPS request if not already queried.
+        REQUEST_CACHE[source] = query(source)
+    req = REQUEST_CACHE[source]
+
+    # Get the `split_name` available.
+    if config is None:
+        all_splits = [subset["splits"] for subset in req["dataset_info"].values()]
+        for split in all_splits:
+            assert set(split) == set(all_splits[0]), name
+    else:
+        all_splits = [req["dataset_info"][config]["splits"]]
+    split_name = get_split(list(all_splits[0]))
+
+    return sum(splits[split_name]["num_examples"] for splits in all_splits)
+
+
+def get_all_tasks(all_tasks_json: Path = Path("configs/all_tasks.json")) -> list[Task]:
+    with open(all_tasks_json) as f:
+        raw_tasks = json.load(f)
+
+    tasks = []
+    for kind, names in raw_tasks["infer"].items():
+        for name in names:
+            tasks.append(Task(name, (kind,)))
+    for row in raw_tasks["other"]:
+        tasks.append(Task(
+            name=row["name"],
+            kinds=tuple(row["kinds"]),
+            size=row.get("size"),
+            language=None if row["language"] is None else iso639.Lang(row["language"]),
+            dimension=row.get("dimension"),
+            alias=tuple(row.get("alias", ())),
+        ))
+    return tasks
+
+
+def get_partition(tasks: Optional[list[Task]] = None, shards: int = 1,
+                  all_tasks_json: Path = Path("configs/all_tasks.json")) -> tuple[list[Task], ...]:
+    if tasks is None:
+        tasks = get_all_tasks(all_tasks_json=all_tasks_json)
+    if shards == 1:
+        return tasks,
+    return tuple(prtpy.partition(prtpy.partitioning.greedy, shards, tasks,
+                                 valueof=lambda task: task.size))
diff --git a/src/evals/write_task_yaml.py b/src/evals/write_task_yaml.py
new file mode 100644
index 0000000..ccc3063
--- /dev/null
+++ b/src/evals/write_task_yaml.py
@@ -0,0 +1,30 @@
+import argparse
+from pathlib import Path
+
+import yaml
+
+from evals.tasks import get_all_tasks
+
+
+def main(config: Path, out: Path):
+    info = {
+        "group": "swissai_eval",
+        "tasks": sorted([task.name for task in get_all_tasks(all_tasks_json=config)]),
+        "aggregate_metric_list": [
+            {"metric": "acc", "aggregation": "mean", "weight_by_size": False},
+            {"metric": "acc_norm", "aggregation": "mean", "weight_by_size": False},
+            {"metric": "perplexity", "aggregation": "mean", "weight_by_size": False},
+            {"metric": "f1", "aggregation": "mean", "weight_by_size": False},
+            {"metric": "exact_match", "aggregation": "mean", "weight_by_size": False},
+        ],
+        "metadata": {"version": 1.1},
+    }
+    with open(out, "w+") as f:
+        yaml.dump(info, f, sort_keys=False)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", type=Path, default=Path("configs/all_tasks.json"))
+    parser.add_argument("--out", type=Path, default=Path("default.yaml"))
+    main(**vars(parser.parse_args()))
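Finally, a hedged end-to-end sketch tying the new module together: regenerate the harness group YAML and inspect a two-way partition. This assumes an editable install (`pip install -e .`), the default config paths, and `HF_TOKEN` set for size inference; note that `get_all_tasks()` runs twice here (once inside `write_task_yaml`), so results come from the request cache only within a single process:

```python
from pathlib import Path

from evals.tasks import get_all_tasks, get_partition
from evals.write_task_yaml import main as write_task_yaml

write_task_yaml(config=Path("configs/all_tasks.json"), out=Path("default.yaml"))

for i, part in enumerate(get_partition(tasks=get_all_tasks(), shards=2)):
    print(f"shard {i}: {len(part)} tasks, {sum(task.size for task in part)} samples")
```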