131 changes: 131 additions & 0 deletions tests/test_model_loader.py
@@ -0,0 +1,131 @@
from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest

from tfkit.utility import model as model_utils
from tfkit.utility.model import load_pretrained_model, load_pretrained_tokenizer


def _make_config(**overrides):
    defaults = {
        "is_encoder_decoder": False,
        "architectures": [],
        "is_decoder": False,
    }
    defaults.update(overrides)
    return SimpleNamespace(**defaults)


def test_load_pretrained_model_prefers_seq2seq(monkeypatch):
    config = _make_config(is_encoder_decoder=True)

    auto_config = MagicMock()
    auto_config.from_pretrained.return_value = config
    monkeypatch.setattr(model_utils, "AutoConfig", auto_config)

    seq2seq_loader = MagicMock()
    seq2seq_instance = object()
    seq2seq_loader.from_pretrained.return_value = seq2seq_instance
    monkeypatch.setattr(model_utils, "AutoModelForSeq2SeqLM", seq2seq_loader)

    causal_loader = MagicMock()
    monkeypatch.setattr(model_utils, "AutoModelForCausalLM", causal_loader)

    base_loader = MagicMock()
    monkeypatch.setattr(model_utils, "AutoModel", base_loader)

    result = load_pretrained_model("mock-model", ["seq2seq"])  # type: ignore[arg-type]

    assert result is seq2seq_instance
    seq2seq_loader.from_pretrained.assert_called_once()
    causal_loader.from_pretrained.assert_not_called()
    base_loader.from_pretrained.assert_not_called()


def test_load_pretrained_model_prefers_causal(monkeypatch):
    config = _make_config(architectures=["CustomForCausalLM"])

    auto_config = MagicMock()
    auto_config.from_pretrained.return_value = config
    monkeypatch.setattr(model_utils, "AutoConfig", auto_config)

    seq2seq_loader = MagicMock()
    monkeypatch.setattr(model_utils, "AutoModelForSeq2SeqLM", seq2seq_loader)

    causal_loader = MagicMock()
    causal_instance = object()
    causal_loader.from_pretrained.return_value = causal_instance
    monkeypatch.setattr(model_utils, "AutoModelForCausalLM", causal_loader)

    base_loader = MagicMock()
    monkeypatch.setattr(model_utils, "AutoModel", base_loader)

    result = load_pretrained_model("mock-model", ["clm"])  # type: ignore[arg-type]

    assert result is causal_instance
    causal_loader.from_pretrained.assert_called_once()
    base_loader.from_pretrained.assert_not_called()


def test_load_pretrained_model_causal_fallback(monkeypatch):
    config = _make_config(architectures=["CustomForCausalLM"])

    auto_config = MagicMock()
    auto_config.from_pretrained.return_value = config
    monkeypatch.setattr(model_utils, "AutoConfig", auto_config)

    seq2seq_loader = MagicMock()
    monkeypatch.setattr(model_utils, "AutoModelForSeq2SeqLM", seq2seq_loader)

    causal_loader = MagicMock()
    causal_loader.from_pretrained.side_effect = ValueError("missing head")
    monkeypatch.setattr(model_utils, "AutoModelForCausalLM", causal_loader)

    base_loader = MagicMock()
    base_instance = object()
    base_loader.from_pretrained.return_value = base_instance
    monkeypatch.setattr(model_utils, "AutoModel", base_loader)

    result = load_pretrained_model("mock-model", ["clm"])  # type: ignore[arg-type]

    assert result is base_instance
    base_loader.from_pretrained.assert_called_once()
    assert config.is_decoder is True


def test_load_pretrained_model_trust_remote_code_env(monkeypatch):
    monkeypatch.setenv("TFKIT_TRUST_REMOTE_CODE", "false")

    config = _make_config()
    auto_config = MagicMock()
    auto_config.from_pretrained.return_value = config
    monkeypatch.setattr(model_utils, "AutoConfig", auto_config)

    base_loader = MagicMock()
    base_instance = object()
    base_loader.from_pretrained.return_value = base_instance
    monkeypatch.setattr(model_utils, "AutoModel", base_loader)

    result = load_pretrained_model("mock-model", ["clas"])  # type: ignore[arg-type]

    assert result is base_instance
    auto_config.from_pretrained.assert_called_once_with(
        "mock-model", trust_remote_code=False
    )
    base_loader.from_pretrained.assert_called_once()
    _, kwargs = base_loader.from_pretrained.call_args
    assert kwargs.get("trust_remote_code") is False


def test_load_pretrained_tokenizer_respects_env(monkeypatch):
    monkeypatch.setenv("TFKIT_TRUST_REMOTE_CODE", "0")

    tokenizer_loader = MagicMock()
    monkeypatch.setattr(model_utils, "AutoTokenizer", tokenizer_loader)

    load_pretrained_tokenizer("mock-tokenizer")

    tokenizer_loader.from_pretrained.assert_called_once_with(
        "mock-tokenizer", trust_remote_code=False
    )
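
The two environment-variable tests above pin down the expected behaviour of TFKIT_TRUST_REMOTE_CODE: both the strings "false" and "0" must result in trust_remote_code=False reaching the Hugging Face loaders. A minimal sketch of the parsing they imply (the helper name, the default, and the extra falsy spellings are assumptions for illustration, not part of the PR):

import os

def _trust_remote_code_enabled() -> bool:
    # "false" and "0" must both disable remote code, per the tests above;
    # also accepting "no"/"off" is an assumed convenience.
    value = os.environ.get("TFKIT_TRUST_REMOTE_CODE", "true")
    return value.strip().lower() not in {"0", "false", "no", "off"}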
147 changes: 147 additions & 0 deletions tests/test_task_generation.py
@@ -0,0 +1,147 @@
from types import SimpleNamespace

import torch
from torch import nn

from tfkit.task.clm.model import Model as CLMModel
from tfkit.task.seq2seq.model import Model as Seq2SeqModel


class DummyTokenizer:
    def __init__(self, vocab_size):
        self.vocab_size = vocab_size

    def __len__(self):
        return self.vocab_size

    def convert_ids_to_tokens(self, idx):
        return f"token-{idx}"


class DummyCausalPretrained(nn.Module):
    def __init__(self):
        super().__init__()
        self.config = SimpleNamespace(vocab_size=5, hidden_size=4)
        self.output_layer = nn.Linear(self.config.hidden_size, self.config.vocab_size)
        self.last_kwargs = None

    def get_output_embeddings(self):
        return self.output_layer

    def forward(self, input_ids, attention_mask=None, return_dict=True, **kwargs):
        self.last_kwargs = kwargs
        batch_size, seq_len = input_ids.shape
        logits = torch.zeros(batch_size, seq_len, self.config.vocab_size)
        outputs = {
            "logits": logits,
            "last_hidden_state": torch.zeros(batch_size, seq_len, self.config.hidden_size),
        }
        if "labels" in kwargs:
            outputs["loss"] = torch.tensor(0.0)
        return outputs


class DummyEncoderPretrained(nn.Module):
    def __init__(self):
        super().__init__()
        self.config = SimpleNamespace(vocab_size=5, hidden_size=4)
        self.last_kwargs = None

    def get_output_embeddings(self):
        return None

    def forward(self, input_ids, attention_mask=None, return_dict=True, **kwargs):
        self.last_kwargs = kwargs
        batch_size, seq_len = input_ids.shape
        hidden = torch.zeros(batch_size, seq_len, self.config.hidden_size)
        return {"last_hidden_state": hidden}


class DummySeq2SeqPretrained(nn.Module):
    def __init__(self):
        super().__init__()
        self.config = SimpleNamespace(vocab_size=3, hidden_size=4)
        self.decoder = nn.Module()
        self.output_layer = nn.Linear(self.config.hidden_size, self.config.vocab_size)

    def get_output_embeddings(self):
        return self.output_layer

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        output_hidden_states=False,
        use_cache=False,
        return_dict=True,
        **kwargs,
    ):
        batch_size, seq_len = decoder_input_ids.shape
        hidden = torch.zeros(batch_size, seq_len, self.config.hidden_size)
        outputs = {
            "last_hidden_state": hidden,
            "decoder_hidden_states": (hidden,),
        }
        return outputs


def test_clm_model_uses_pretrained_head_for_loss():
    tokenizer = DummyTokenizer(vocab_size=5)
    pretrained = DummyCausalPretrained()
    model = CLMModel(tokenizer=tokenizer, pretrained=pretrained)

    batch = {
        "input": torch.zeros((1, 2), dtype=torch.long),
        "mask": torch.ones((1, 2), dtype=torch.long),
        "target": torch.tensor([[0, -1]]),
    }

    loss = model.forward(batch, eval=False)
    assert torch.is_tensor(loss)
    assert "labels" in pretrained.last_kwargs
    assert pretrained.last_kwargs["labels"].tolist() == [[0, -100]]

    eval_batch = {
        **batch,
        "start": [0],
    }
    result = model.forward(eval_batch, eval=True)
    assert isinstance(result, dict)
    assert "max_item" in result


def test_clm_model_falls_back_to_linear_head():
    tokenizer = DummyTokenizer(vocab_size=5)
    pretrained = DummyEncoderPretrained()
    model = CLMModel(tokenizer=tokenizer, pretrained=pretrained)

    batch = {
        "input": torch.zeros((1, 2), dtype=torch.long),
        "mask": torch.ones((1, 2), dtype=torch.long),
        "target": torch.tensor([[0, -1]]),
    }

    loss = model.forward(batch, eval=False)
    assert torch.is_tensor(loss)
    assert pretrained.last_kwargs == {}


def test_seq2seq_model_uses_pretrained_output_head():
    tokenizer = DummyTokenizer(vocab_size=3)
    pretrained = DummySeq2SeqPretrained()
    model = Seq2SeqModel(tokenizer=tokenizer, pretrained=pretrained)

    batch = {
        "input": torch.zeros((1, 1), dtype=torch.long),
        "prev": torch.zeros((1, 1), dtype=torch.long),
        "encoder_mask": torch.ones((1, 1), dtype=torch.long),
        "decoder_mask": torch.ones((1, 1), dtype=torch.long),
        "target": torch.zeros((1, 1), dtype=torch.long),
        "ntarget": torch.full((1, 1), -1),
    }

    loss = model.forward(batch, eval=False)
    assert torch.is_tensor(loss)
    assert model.model is pretrained.output_layer
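
One detail worth spelling out from test_clm_model_uses_pretrained_head_for_loss: tfkit marks padding positions in targets with -1, while Hugging Face language-model heads compute their internal loss with ignore_index=-100, so labels must be remapped before they reach the pretrained head. A standalone illustration of the same remapping the model change below performs:

import torch

target = torch.tensor([[0, -1]])  # tfkit-style target, -1 marks padding
labels = target.clone().long()
labels[labels == -1] = -100       # Hugging Face losses ignore -100 positions
assert labels.tolist() == [[0, -100]]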
55 changes: 46 additions & 9 deletions tfkit/task/clm/model.py
@@ -12,17 +12,60 @@ class Model(BaseTFKitModel):
 
     def __init__(self, tokenizer, pretrained, maxlen=512, **kwargs):
         super().__init__(tokenizer, pretrained, maxlen, **kwargs)
-        self.model = nn.Linear(self.get_hidden_size(), self.get_vocab_size())
+        self.model = self._resolve_output_head()
+        self.uses_pretrained_head = self.model is not None
+        if not self.uses_pretrained_head:
+            self.model = nn.Linear(self.get_hidden_size(), self.get_vocab_size())
+
         self._setup_predictor(AutoRegressivePredictor, Preprocessor)
 
+    def _resolve_output_head(self):
+        """Return the pretrained language modeling head if available."""
+
+        if hasattr(self.pretrained, "get_output_embeddings"):
+            output_embeddings = self.pretrained.get_output_embeddings()
+            if output_embeddings is not None:
+                return output_embeddings
+        if hasattr(self.pretrained, "lm_head"):
+            return self.pretrained.lm_head
+        if hasattr(self.pretrained, "cls"):
+            return self.pretrained.cls
+        return None
+
     def forward(self, batch_data, eval=False, beamsearch=False, max_return=1, **kwargs):
         inputs = batch_data['input']
         masks = batch_data['mask']
         tokens_tensor = torch.as_tensor(inputs)
         mask_tensors = torch.as_tensor(masks)
+        model_kwargs = {
+            'attention_mask': mask_tensors,
+            'return_dict': True,
+        }
+        if eval:
+            model_kwargs['use_cache'] = False
 
-        outputs = self.pretrained(tokens_tensor, attention_mask=mask_tensors)
-        prediction_scores = self.model(outputs[0])
+        if eval:
+            outputs = self.pretrained(tokens_tensor, **model_kwargs)
+            prediction_scores = outputs['logits'] if 'logits' in outputs else outputs[0]
+        else:
+            targets = batch_data['target']
+            loss_tensors = torch.as_tensor(targets)
+
+            if self.uses_pretrained_head:
+                labels = loss_tensors.clone().long()
+                labels[labels == -1] = -100
+                model_kwargs['labels'] = labels
+                outputs = self.pretrained(tokens_tensor, **model_kwargs)
+                prediction_scores = outputs['logits'] if 'logits' in outputs else outputs[0]
+                masked_lm_loss = outputs['loss']
+            else:
+                loss_tensors = loss_tensors.long()
+                outputs = self.pretrained(tokens_tensor, **model_kwargs)
+                hidden_states = outputs['last_hidden_state'] if 'last_hidden_state' in outputs else outputs[0]
+                prediction_scores = self.model(hidden_states)
+                loss_fct = nn.CrossEntropyLoss(ignore_index=-1)  # -1 index = padding token
+                masked_lm_loss = loss_fct(prediction_scores.view(-1, self.vocab_size),
+                                          loss_tensors.view(-1))
 
         if eval:
             result_dict = {}
@@ -39,11 +39,5 @@ def forward(self, batch_data, eval=False, beamsearch=False, max_return=1, **kwargs):
             result_dict['label_prob'] = prob_result
             outputs = result_dict
         else:
-            targets = batch_data['target']
-            loss_tensors = torch.as_tensor(targets)
-            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)  # -1 index = padding token
-            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.vocab_size),
-                                      loss_tensors.view(-1))
-
             outputs = masked_lm_loss
         return outputs
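
With this change a caller no longer has to know whether the backbone ships its own language-modeling head. A minimal usage sketch, assuming tokenizer and pretrained are any Hugging Face-style objects already in scope:

from tfkit.task.clm.model import Model as CLMModel

model = CLMModel(tokenizer=tokenizer, pretrained=pretrained)
if model.uses_pretrained_head:
    # reuses the backbone's output embeddings, lm_head, or cls module
    print("using the pretrained LM head")
else:
    # falls back to a fresh nn.Linear(hidden_size, vocab_size) projection
    print("using a new linear projection head")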