Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions src/trustyai/language/detoxify/tmarco.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# pylint: disable = invalid-name, line-too-long, use-dict-literal, consider-using-f-string, too-many-nested-blocks, self-assigning-variable
"""TMaRCo detoxification."""
import os
import math
import itertools
import numpy as np
Expand All @@ -17,6 +18,7 @@
TrainingArguments,
pipeline,
AutoModelForCausalLM,
AutoModelForMaskedLM,
BartForConditionalGeneration,
BartTokenizer,
AutoModelForSeq2SeqLM,
Expand Down Expand Up @@ -102,18 +104,33 @@ def __init__(
"cuda" if torch.cuda.is_available() else "cpu"
)

def load_models(self, experts: list[str] = None, expert_weights: list = None):
    """Load expert models and store them on ``self.experts``.

    Each entry of ``experts`` is either a string identifying a model to load
    or any non-string object, which is stored as-is.

    String entries are resolved as follows:

    * ``"trustyai/gplus"`` / ``"trustyai/gminus"`` are loaded with
      ``BartForConditionalGeneration``.
    * a string naming an existing filesystem path is loaded with
      ``AutoModelForMaskedLM``.
    * any other string is loaded with ``AutoModelForCausalLM``.

    :param experts: expert identifiers and/or non-string objects. When
        ``None``, nothing is loaded and ``self.experts`` is left unchanged.
    :param expert_weights: when not ``None``, replaces ``self.expert_weights``.
    """
    if expert_weights is not None:
        self.expert_weights = expert_weights

    # The parameter defaults to None; iterating it would raise TypeError, so
    # treat "not provided" as "nothing to (re)load".
    if experts is None:
        return

    expert_models = []
    for expert in experts:
        model = expert
        if isinstance(expert, str):
            if expert in ("trustyai/gplus", "trustyai/gminus"):
                # Load TMaRCo models
                model = BartForConditionalGeneration.from_pretrained(
                    expert,
                    forced_bos_token_id=self.tokenizer.bos_token_id,
                    device_map="auto",
                )
            elif os.path.exists(expert):
                # Load local models. Check the path itself rather than its
                # parent directory, so a checkpoint in the current directory
                # is recognised and a hub id such as "org/name" is not
                # misclassified just because a directory "org" exists.
                model = AutoModelForMaskedLM.from_pretrained(
                    expert,
                    forced_bos_token_id=self.tokenizer.bos_token_id,
                    device_map="auto",
                )
            else:
                # Load HuggingFace models
                model = AutoModelForCausalLM.from_pretrained(
                    expert,
                    forced_bos_token_id=self.tokenizer.bos_token_id,
                    device_map="auto",
                )
        expert_models.append(model)
    self.experts = expert_models

Expand Down
Loading