Skip to content

Commit dfe11c6

Browse files
authored
Merge pull request #20 from voidful/codex/refactor-code-for-latest-model-support
Refactor model loading for modern transformers
2 parents f0680e6 + 2c42e3d commit dfe11c6

File tree

7 files changed

+442
-29
lines changed

7 files changed

+442
-29
lines changed

tests/test_model_loader.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
from types import SimpleNamespace
2+
from unittest.mock import MagicMock
3+
4+
import pytest
5+
6+
from tfkit.utility import model as model_utils
7+
from tfkit.utility.model import load_pretrained_model, load_pretrained_tokenizer
8+
9+
10+
def _make_config(**overrides):
11+
defaults = {
12+
"is_encoder_decoder": False,
13+
"architectures": [],
14+
"is_decoder": False,
15+
}
16+
defaults.update(overrides)
17+
return SimpleNamespace(**defaults)
18+
19+
20+
def test_load_pretrained_model_prefers_seq2seq(monkeypatch):
    """An encoder-decoder config must route loading through AutoModelForSeq2SeqLM."""
    cfg = _make_config(is_encoder_decoder=True)

    cfg_loader = MagicMock()
    cfg_loader.from_pretrained.return_value = cfg
    monkeypatch.setattr(model_utils, "AutoConfig", cfg_loader)

    seq2seq_cls = MagicMock()
    expected = object()
    seq2seq_cls.from_pretrained.return_value = expected
    monkeypatch.setattr(model_utils, "AutoModelForSeq2SeqLM", seq2seq_cls)

    causal_cls = MagicMock()
    monkeypatch.setattr(model_utils, "AutoModelForCausalLM", causal_cls)

    plain_cls = MagicMock()
    monkeypatch.setattr(model_utils, "AutoModel", plain_cls)

    loaded = load_pretrained_model("mock-model", ["seq2seq"])  # type: ignore[arg-type]

    assert loaded is expected
    seq2seq_cls.from_pretrained.assert_called_once()
    # Neither fallback loader should have been touched.
    causal_cls.from_pretrained.assert_not_called()
    plain_cls.from_pretrained.assert_not_called()
def test_load_pretrained_model_prefers_causal(monkeypatch):
    """A ``*ForCausalLM`` architecture should load via AutoModelForCausalLM."""
    cfg = _make_config(architectures=["CustomForCausalLM"])

    cfg_loader = MagicMock()
    cfg_loader.from_pretrained.return_value = cfg
    monkeypatch.setattr(model_utils, "AutoConfig", cfg_loader)

    # Seq2seq loader is patched only so it cannot leak real behaviour.
    monkeypatch.setattr(model_utils, "AutoModelForSeq2SeqLM", MagicMock())

    causal_cls = MagicMock()
    expected = object()
    causal_cls.from_pretrained.return_value = expected
    monkeypatch.setattr(model_utils, "AutoModelForCausalLM", causal_cls)

    plain_cls = MagicMock()
    monkeypatch.setattr(model_utils, "AutoModel", plain_cls)

    loaded = load_pretrained_model("mock-model", ["clm"])  # type: ignore[arg-type]

    assert loaded is expected
    causal_cls.from_pretrained.assert_called_once()
    plain_cls.from_pretrained.assert_not_called()
def test_load_pretrained_model_causal_fallback(monkeypatch):
    """If the causal head cannot be loaded, fall back to AutoModel as a decoder."""
    cfg = _make_config(architectures=["CustomForCausalLM"])

    cfg_loader = MagicMock()
    cfg_loader.from_pretrained.return_value = cfg
    monkeypatch.setattr(model_utils, "AutoConfig", cfg_loader)

    monkeypatch.setattr(model_utils, "AutoModelForSeq2SeqLM", MagicMock())

    causal_cls = MagicMock()
    causal_cls.from_pretrained.side_effect = ValueError("missing head")
    monkeypatch.setattr(model_utils, "AutoModelForCausalLM", causal_cls)

    plain_cls = MagicMock()
    expected = object()
    plain_cls.from_pretrained.return_value = expected
    monkeypatch.setattr(model_utils, "AutoModel", plain_cls)

    loaded = load_pretrained_model("mock-model", ["clm"])  # type: ignore[arg-type]

    assert loaded is expected
    plain_cls.from_pretrained.assert_called_once()
    # The fallback path must still mark the config as a decoder.
    assert cfg.is_decoder is True
def test_load_pretrained_model_trust_remote_code_env(monkeypatch):
    """TFKIT_TRUST_REMOTE_CODE=false must propagate trust_remote_code=False."""
    monkeypatch.setenv("TFKIT_TRUST_REMOTE_CODE", "false")

    cfg_loader = MagicMock()
    cfg_loader.from_pretrained.return_value = _make_config()
    monkeypatch.setattr(model_utils, "AutoConfig", cfg_loader)

    plain_cls = MagicMock()
    expected = object()
    plain_cls.from_pretrained.return_value = expected
    monkeypatch.setattr(model_utils, "AutoModel", plain_cls)

    loaded = load_pretrained_model("mock-model", ["clas"])  # type: ignore[arg-type]

    assert loaded is expected
    cfg_loader.from_pretrained.assert_called_once_with(
        "mock-model", trust_remote_code=False
    )
    plain_cls.from_pretrained.assert_called_once()
    # The flag must also reach the model loader's kwargs.
    _, call_kwargs = plain_cls.from_pretrained.call_args
    assert call_kwargs.get("trust_remote_code") is False
def test_load_pretrained_tokenizer_respects_env(monkeypatch):
    """Tokenizer loading honours the TFKIT_TRUST_REMOTE_CODE environment switch."""
    monkeypatch.setenv("TFKIT_TRUST_REMOTE_CODE", "0")

    tokenizer_cls = MagicMock()
    monkeypatch.setattr(model_utils, "AutoTokenizer", tokenizer_cls)

    load_pretrained_tokenizer("mock-tokenizer")

    tokenizer_cls.from_pretrained.assert_called_once_with(
        "mock-tokenizer", trust_remote_code=False
    )

tests/test_task_generation.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
from types import SimpleNamespace
2+
3+
import torch
4+
from torch import nn
5+
6+
from tfkit.task.clm.model import Model as CLMModel
7+
from tfkit.task.seq2seq.model import Model as Seq2SeqModel
8+
9+
10+
class DummyTokenizer:
    """Minimal tokenizer stub exposing only what the task models touch."""

    def __init__(self, vocab_size):
        # Size reported via __len__ and read directly by the models.
        self.vocab_size = vocab_size

    def __len__(self):
        return self.vocab_size

    def convert_ids_to_tokens(self, idx):
        # Deterministic placeholder token string for any id.
        return f"token-{idx}"
class DummyCausalPretrained(nn.Module):
    """Causal-LM backbone stub: zero logits/hidden states, records forward kwargs."""

    def __init__(self):
        super().__init__()
        self.config = SimpleNamespace(vocab_size=5, hidden_size=4)
        self.output_layer = nn.Linear(self.config.hidden_size, self.config.vocab_size)
        # Extra kwargs of the most recent forward() call (None before first call).
        self.last_kwargs = None

    def get_output_embeddings(self):
        # Advertise a pretrained LM head so callers take the head-based path.
        return self.output_layer

    def forward(self, input_ids, attention_mask=None, return_dict=True, **kwargs):
        self.last_kwargs = kwargs
        n_rows, n_cols = input_ids.shape
        result = {
            "logits": torch.zeros(n_rows, n_cols, self.config.vocab_size),
            "last_hidden_state": torch.zeros(n_rows, n_cols, self.config.hidden_size),
        }
        if "labels" in kwargs:
            # Mimic HF models, which return a loss whenever labels are supplied.
            result["loss"] = torch.tensor(0.0)
        return result
class DummyEncoderPretrained(nn.Module):
    """Encoder-only backbone stub with no LM head; forces the linear-head fallback."""

    def __init__(self):
        super().__init__()
        self.config = SimpleNamespace(vocab_size=5, hidden_size=4)
        # Extra kwargs of the most recent forward() call (None before first call).
        self.last_kwargs = None

    def get_output_embeddings(self):
        # No pretrained head is available on this backbone.
        return None

    def forward(self, input_ids, attention_mask=None, return_dict=True, **kwargs):
        self.last_kwargs = kwargs
        n_rows, n_cols = input_ids.shape
        return {
            "last_hidden_state": torch.zeros(n_rows, n_cols, self.config.hidden_size),
        }
class DummySeq2SeqPretrained(nn.Module):
    """Seq2seq backbone stub producing zero decoder states of a fixed hidden size."""

    def __init__(self):
        super().__init__()
        self.config = SimpleNamespace(vocab_size=3, hidden_size=4)
        # Presence of a .decoder attribute marks this backbone as seq2seq.
        self.decoder = nn.Module()
        self.output_layer = nn.Linear(self.config.hidden_size, self.config.vocab_size)

    def get_output_embeddings(self):
        return self.output_layer

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        output_hidden_states=False,
        use_cache=False,
        return_dict=True,
        **kwargs,
    ):
        # Output length mirrors the decoder input; hidden states are all zeros.
        n_rows, n_cols = decoder_input_ids.shape
        states = torch.zeros(n_rows, n_cols, self.config.hidden_size)
        return {
            "last_hidden_state": states,
            "decoder_hidden_states": (states,),
        }
def test_clm_model_uses_pretrained_head_for_loss():
    """With a pretrained LM head, the CLM task delegates loss computation to it."""
    backbone = DummyCausalPretrained()
    model = CLMModel(tokenizer=DummyTokenizer(vocab_size=5), pretrained=backbone)

    train_batch = {
        "input": torch.zeros((1, 2), dtype=torch.long),
        "mask": torch.ones((1, 2), dtype=torch.long),
        "target": torch.tensor([[0, -1]]),
    }

    loss = model.forward(train_batch, eval=False)
    assert torch.is_tensor(loss)
    # Padding (-1) targets must be remapped to the HF ignore index (-100).
    assert "labels" in backbone.last_kwargs
    assert backbone.last_kwargs["labels"].tolist() == [[0, -100]]

    eval_result = model.forward({**train_batch, "start": [0]}, eval=True)
    assert isinstance(eval_result, dict)
    assert "max_item" in eval_result
def test_clm_model_falls_back_to_linear_head():
    """Without a pretrained head, the CLM task trains through its own linear layer."""
    backbone = DummyEncoderPretrained()
    model = CLMModel(tokenizer=DummyTokenizer(vocab_size=5), pretrained=backbone)

    batch = {
        "input": torch.zeros((1, 2), dtype=torch.long),
        "mask": torch.ones((1, 2), dtype=torch.long),
        "target": torch.tensor([[0, -1]]),
    }

    loss = model.forward(batch, eval=False)
    assert torch.is_tensor(loss)
    # No labels should be forwarded to the backbone on the fallback path.
    assert backbone.last_kwargs == {}
def test_seq2seq_model_uses_pretrained_output_head():
    """The seq2seq task should adopt the backbone's own output projection."""
    backbone = DummySeq2SeqPretrained()
    model = Seq2SeqModel(tokenizer=DummyTokenizer(vocab_size=3), pretrained=backbone)

    batch = {
        "input": torch.zeros((1, 1), dtype=torch.long),
        "prev": torch.zeros((1, 1), dtype=torch.long),
        "encoder_mask": torch.ones((1, 1), dtype=torch.long),
        "decoder_mask": torch.ones((1, 1), dtype=torch.long),
        "target": torch.zeros((1, 1), dtype=torch.long),
        "ntarget": torch.full((1, 1), -1),
    }

    loss = model.forward(batch, eval=False)
    assert torch.is_tensor(loss)
    assert model.model is backbone.output_layer

tfkit/task/clm/model.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,60 @@ class Model(BaseTFKitModel):
1212

1313
def __init__(self, tokenizer, pretrained, maxlen=512, **kwargs):
    """Initialise the CLM task head.

    Prefers the LM head shipped with ``pretrained``; otherwise falls back
    to a fresh linear projection from hidden size to vocabulary size.
    """
    super().__init__(tokenizer, pretrained, maxlen, **kwargs)
    head = self._resolve_output_head()
    self.uses_pretrained_head = head is not None
    # Fall back to a randomly initialised projection when no head is exposed.
    self.model = head if self.uses_pretrained_head else nn.Linear(
        self.get_hidden_size(), self.get_vocab_size()
    )
    self._setup_predictor(AutoRegressivePredictor, Preprocessor)
22+
def _resolve_output_head(self):
23+
"""Return the pretrained language modeling head if available."""
24+
25+
if hasattr(self.pretrained, "get_output_embeddings"):
26+
output_embeddings = self.pretrained.get_output_embeddings()
27+
if output_embeddings is not None:
28+
return output_embeddings
29+
if hasattr(self.pretrained, "lm_head"):
30+
return self.pretrained.lm_head
31+
if hasattr(self.pretrained, "cls"):
32+
return self.pretrained.cls
33+
return None
34+
1835
def forward(self, batch_data, eval=False, beamsearch=False, max_return=1, **kwargs):
1936
inputs = batch_data['input']
2037
masks = batch_data['mask']
2138
tokens_tensor = torch.as_tensor(inputs)
2239
mask_tensors = torch.as_tensor(masks)
40+
model_kwargs = {
41+
'attention_mask': mask_tensors,
42+
'return_dict': True,
43+
}
44+
if eval:
45+
model_kwargs['use_cache'] = False
46+
47+
if eval:
48+
outputs = self.pretrained(tokens_tensor, **model_kwargs)
49+
prediction_scores = outputs['logits'] if 'logits' in outputs else outputs[0]
50+
else:
51+
targets = batch_data['target']
52+
loss_tensors = torch.as_tensor(targets)
2353

24-
outputs = self.pretrained(tokens_tensor, attention_mask=mask_tensors)
25-
prediction_scores = self.model(outputs[0])
54+
if self.uses_pretrained_head:
55+
labels = loss_tensors.clone().long()
56+
labels[labels == -1] = -100
57+
model_kwargs['labels'] = labels
58+
outputs = self.pretrained(tokens_tensor, **model_kwargs)
59+
prediction_scores = outputs['logits'] if 'logits' in outputs else outputs[0]
60+
masked_lm_loss = outputs['loss']
61+
else:
62+
loss_tensors = loss_tensors.long()
63+
outputs = self.pretrained(tokens_tensor, **model_kwargs)
64+
hidden_states = outputs['last_hidden_state'] if 'last_hidden_state' in outputs else outputs[0]
65+
prediction_scores = self.model(hidden_states)
66+
loss_fct = nn.CrossEntropyLoss(ignore_index=-1) # -1 index = padding token
67+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.vocab_size),
68+
loss_tensors.view(-1))
2669

2770
if eval:
2871
result_dict = {}
@@ -39,11 +82,5 @@ def forward(self, batch_data, eval=False, beamsearch=False, max_return=1, **kwar
3982
result_dict['label_prob'] = prob_result
4083
outputs = result_dict
4184
else:
42-
targets = batch_data['target']
43-
loss_tensors = torch.as_tensor(targets)
44-
loss_fct = nn.CrossEntropyLoss(ignore_index=-1) # -1 index = padding token
45-
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.vocab_size),
46-
loss_tensors.view(-1))
47-
4885
outputs = masked_lm_loss
4986
return outputs

0 commit comments

Comments
 (0)