Skip to content

Commit a4a3da8

Browse files
Add ability to load local and HF models (#212)
1 parent fa3e66a commit a4a3da8

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

src/trustyai/language/detoxify/tmarco.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22
# pylint: disable = invalid-name, line-too-long, use-dict-literal, consider-using-f-string, too-many-nested-blocks, self-assigning-variable
33
"""TMaRCo detoxification."""
4+
import os
45
import math
56
import itertools
67
import numpy as np
@@ -17,6 +18,7 @@
1718
TrainingArguments,
1819
pipeline,
1920
AutoModelForCausalLM,
21+
AutoModelForMaskedLM,
2022
BartForConditionalGeneration,
2123
BartTokenizer,
2224
AutoModelForSeq2SeqLM,
@@ -102,18 +104,33 @@ def __init__(
102104
"cuda" if torch.cuda.is_available() else "cpu"
103105
)
104106

105-
def load_models(self, experts: list, expert_weights: list = None):
107+
def load_models(self, experts: list[str] = None, expert_weights: list = None):
106108
"""Load expert models."""
107109
if expert_weights is not None:
108110
self.expert_weights = expert_weights
109111
expert_models = []
110112
for expert in experts:
111-
if isinstance(expert, str):
113+
# Load TMaRCO models
114+
if (expert == "trustyai/gplus" or expert == "trustyai/gminus"):
112115
expert = BartForConditionalGeneration.from_pretrained(
113116
expert,
114117
forced_bos_token_id=self.tokenizer.bos_token_id,
115118
device_map="auto",
116119
)
120+
# Load local models
121+
elif os.path.exists(os.path.dirname(expert)):
122+
expert = AutoModelForMaskedLM.from_pretrained(
123+
expert,
124+
forced_bos_token_id=self.tokenizer.bos_token_id,
125+
device_map = "auto"
126+
)
127+
# Load HuggingFace models
128+
else:
129+
expert = AutoModelForCausalLM.from_pretrained(
130+
expert,
131+
forced_bos_token_id=self.tokenizer.bos_token_id,
132+
device_map = "auto"
133+
)
117134
expert_models.append(expert)
118135
self.experts = expert_models
119136

0 commit comments

Comments
 (0)