@@ -1,6 +1,7 @@
1 | 1 | # -*- coding: utf-8 -*-
|
2 | 2 | # pylint: disable = invalid-name, line-too-long, use-dict-literal, consider-using-f-string, too-many-nested-blocks, self-assigning-variable
|
3 | 3 | """TMaRCo detoxification."""
|
| 4 | +import os
4 | 5 | import math
|
5 | 6 | import itertools
|
6 | 7 | import numpy as np
@@ -17,6 +18,7 @@
17 | 18 | TrainingArguments,
|
18 | 19 | pipeline,
|
19 | 20 | AutoModelForCausalLM,
|
| 21 | + AutoModelForMaskedLM,
20 | 22 | BartForConditionalGeneration,
|
21 | 23 | BartTokenizer,
|
22 | 24 | AutoModelForSeq2SeqLM,
|
@@ -102,18 +104,33 @@ def __init__(
|
102 | 104 | "cuda" if torch.cuda.is_available() else "cpu"
|
103 | 105 | )
|
104 | 106 |
|
105 | | - def load_models(self, experts: list, expert_weights: list = None):
| 107 | + def load_models(self, experts: list[str] = None, expert_weights: list = None):
106 | 108 | """Load expert models."""
|
107 | 109 | if expert_weights is not None:
|
108 | 110 | self.expert_weights = expert_weights
|
109 | 111 | expert_models = []
|
110 | 112 | for expert in experts:
|
111 | | - if isinstance(expert, str):
| 113 | + # Load the TMaRCo expert models
| 114 | + if expert in ("trustyai/gplus", "trustyai/gminus"):
112 | 115 | expert = BartForConditionalGeneration.from_pretrained(
|
113 | 116 | expert,
|
114 | 117 | forced_bos_token_id=self.tokenizer.bos_token_id,
|
115 | 118 | device_map="auto",
|
116 | 119 | )
|
| 120 | + # Load models from a local path
| 121 | + elif os.path.exists(expert):
| 122 | + expert = AutoModelForMaskedLM.from_pretrained(
| 123 | + expert,
| 124 | + forced_bos_token_id=self.tokenizer.bos_token_id,
| 125 | + device_map="auto",
| 126 | + )
| 127 | + # Load models from the Hugging Face Hub
| 128 | + else:
| 129 | + expert = AutoModelForCausalLM.from_pretrained(
| 130 | + expert,
| 131 | + forced_bos_token_id=self.tokenizer.bos_token_id,
| 132 | + device_map="auto",
| 133 | + )
117 | 134 | expert_models.append(expert)
|
118 | 135 | self.experts = expert_models
|
119 | 136 |
|
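With this change, `load_models` resolves each entry in `experts` in three ways: the published TMaRCo experts load as `BartForConditionalGeneration`, an existing local path loads as a masked LM, and any other string is treated as a Hugging Face Hub id and loaded as a causal LM. A minimal usage sketch follows; the import path, constructor call, and weight values are assumptions for illustration, not part of this diff:

```python
# Hypothetical usage sketch of the reworked load_models; the import path,
# constructor, and the weights below are illustrative assumptions.
from trustyai_detoxify import TMaRCo  # assumed import path

tmarco = TMaRCo()

# Published TMaRCo experts: matched by name, loaded as BART seq2seq models.
tmarco.load_models(
    experts=["trustyai/gplus", "trustyai/gminus"],
    expert_weights=[-0.5, 5],  # illustrative anti-expert/expert weighting
)

# Local checkpoint: the path exists on disk, so it loads as a masked LM.
tmarco.load_models(experts=["/models/my-finetuned-expert"])

# Anything else: treated as a Hub id and loaded as a causal LM.
tmarco.load_models(experts=["gpt2"])
```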