Skip to content

Commit 21fa9cb

Browse files
committed
Fix styling
1 parent 5d49d74 commit 21fa9cb

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/trustyai/language/detoxify/tmarco.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ def __init__(
104104
"cuda" if torch.cuda.is_available() else "cpu"
105105
)
106106

107-
def load_models(self, experts: list[str] = None, expert_weights: list = None): # pylint: disable=unsubscriptable-object
107+
def load_models(
108+
self, experts: list[str] = None, expert_weights: list = None
109+
): # pylint: disable=unsubscriptable-object
108110
"""Load expert models."""
109111
if expert_weights is not None:
110112
self.expert_weights = expert_weights
@@ -122,14 +124,14 @@ def load_models(self, experts: list[str] = None, expert_weights: list = None): #
122124
expert = AutoModelForMaskedLM.from_pretrained(
123125
expert,
124126
forced_bos_token_id=self.tokenizer.bos_token_id,
125-
device_map = "auto"
127+
device_map="auto",
126128
)
127129
# Load HuggingFace models
128130
else:
129131
expert = AutoModelForCausalLM.from_pretrained(
130132
expert,
131133
forced_bos_token_id=self.tokenizer.bos_token_id,
132-
device_map = "auto"
134+
device_map="auto",
133135
)
134136
expert_models.append(expert)
135137
self.experts = expert_models

0 commit comments

Comments
 (0)