|
1 | 1 | # -*- coding: utf-8 -*- |
2 | 2 | # pylint: disable = invalid-name, line-too-long, use-dict-literal, consider-using-f-string, too-many-nested-blocks, self-assigning-variable |
3 | 3 | """TMaRCo detoxification.""" |
| 4 | +import os |
4 | 5 | import math |
5 | 6 | import itertools |
6 | 7 | import numpy as np |
|
17 | 18 | TrainingArguments, |
18 | 19 | pipeline, |
19 | 20 | AutoModelForCausalLM, |
| 21 | + AutoModelForMaskedLM, |
20 | 22 | BartForConditionalGeneration, |
21 | 23 | BartTokenizer, |
22 | 24 | AutoModelForSeq2SeqLM, |
@@ -102,18 +104,33 @@ def __init__( |
102 | 104 | "cuda" if torch.cuda.is_available() else "cpu" |
103 | 105 | ) |
104 | 106 |
|
105 | | - def load_models(self, experts: list, expert_weights: list = None): |
| 107 | + def load_models(self, experts: list[str] = None, expert_weights: list = None): |
106 | 108 | """Load expert models.""" |
107 | 109 | if expert_weights is not None: |
108 | 110 | self.expert_weights = expert_weights |
109 | 111 | expert_models = [] |
110 | 112 | for expert in experts: |
111 | | - if isinstance(expert, str): |
| 113 | + # Load TMaRCO models |
| 114 | + if (expert == "trustyai/gplus" or expert == "trustyai/gminus"): |
112 | 115 | expert = BartForConditionalGeneration.from_pretrained( |
113 | 116 | expert, |
114 | 117 | forced_bos_token_id=self.tokenizer.bos_token_id, |
115 | 118 | device_map="auto", |
116 | 119 | ) |
| 120 | + # Load local models |
| 121 | + elif os.path.exists(os.path.dirname(expert)): |
| 122 | + expert = AutoModelForMaskedLM.from_pretrained( |
| 123 | + expert, |
| 124 | + forced_bos_token_id=self.tokenizer.bos_token_id, |
| 125 | + device_map = "auto" |
| 126 | + ) |
| 127 | + # Load HuggingFace models |
| 128 | + else: |
| 129 | + expert = AutoModelForCausalLM.from_pretrained( |
| 130 | + expert, |
| 131 | + forced_bos_token_id=self.tokenizer.bos_token_id, |
| 132 | + device_map = "auto" |
| 133 | + ) |
117 | 134 | expert_models.append(expert) |
118 | 135 | self.experts = expert_models |
119 | 136 |
|
|
0 commit comments