diff --git a/tests/test_model_loader.py b/tests/test_model_loader.py
new file mode 100644
index 0000000..bf7d4d4
--- /dev/null
+++ b/tests/test_model_loader.py
@@ -0,0 +1,131 @@
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+import pytest
+
+from tfkit.utility import model as model_utils
+from tfkit.utility.model import load_pretrained_model, load_pretrained_tokenizer
+
+
+def _make_config(**overrides):
+    defaults = {
+        "is_encoder_decoder": False,
+        "architectures": [],
+        "is_decoder": False,
+    }
+    defaults.update(overrides)
+    return SimpleNamespace(**defaults)
+
+
+def test_load_pretrained_model_prefers_seq2seq(monkeypatch):
+    config = _make_config(is_encoder_decoder=True)
+
+    auto_config = MagicMock()
+    auto_config.from_pretrained.return_value = config
+    monkeypatch.setattr(model_utils, "AutoConfig", auto_config)
+
+    seq2seq_loader = MagicMock()
+    seq2seq_instance = object()
+    seq2seq_loader.from_pretrained.return_value = seq2seq_instance
+    monkeypatch.setattr(model_utils, "AutoModelForSeq2SeqLM", seq2seq_loader)
+
+    causal_loader = MagicMock()
+    monkeypatch.setattr(model_utils, "AutoModelForCausalLM", causal_loader)
+
+    base_loader = MagicMock()
+    monkeypatch.setattr(model_utils, "AutoModel", base_loader)
+
+    result = load_pretrained_model("mock-model", ["seq2seq"])  # type: ignore[arg-type]
+
+    assert result is seq2seq_instance
+    seq2seq_loader.from_pretrained.assert_called_once()
+    causal_loader.from_pretrained.assert_not_called()
+    base_loader.from_pretrained.assert_not_called()
+
+
+def test_load_pretrained_model_prefers_causal(monkeypatch):
+    config = _make_config(architectures=["CustomForCausalLM"])
+
+    auto_config = MagicMock()
+    auto_config.from_pretrained.return_value = config
+    monkeypatch.setattr(model_utils, "AutoConfig", auto_config)
+
+    seq2seq_loader = MagicMock()
+    monkeypatch.setattr(model_utils, "AutoModelForSeq2SeqLM", seq2seq_loader)
+
+    causal_loader = MagicMock()
+    causal_instance = object()
+    causal_loader.from_pretrained.return_value = causal_instance
+    monkeypatch.setattr(model_utils, "AutoModelForCausalLM", causal_loader)
+
+    base_loader = MagicMock()
+    monkeypatch.setattr(model_utils, "AutoModel", base_loader)
+
+    result = load_pretrained_model("mock-model", ["clm"])  # type: ignore[arg-type]
+
+    assert result is causal_instance
+    causal_loader.from_pretrained.assert_called_once()
+    base_loader.from_pretrained.assert_not_called()
+
+
+def test_load_pretrained_model_causal_fallback(monkeypatch):
+    config = _make_config(architectures=["CustomForCausalLM"])
+
+    auto_config = MagicMock()
+    auto_config.from_pretrained.return_value = config
+    monkeypatch.setattr(model_utils, "AutoConfig", auto_config)
+
+    seq2seq_loader = MagicMock()
+    monkeypatch.setattr(model_utils, "AutoModelForSeq2SeqLM", seq2seq_loader)
+
+    causal_loader = MagicMock()
+    causal_loader.from_pretrained.side_effect = ValueError("missing head")
+    monkeypatch.setattr(model_utils, "AutoModelForCausalLM", causal_loader)
+
+    base_loader = MagicMock()
+    base_instance = object()
+    base_loader.from_pretrained.return_value = base_instance
+    monkeypatch.setattr(model_utils, "AutoModel", base_loader)
+
+    result = load_pretrained_model("mock-model", ["clm"])  # type: ignore[arg-type]
+
+    assert result is base_instance
+    base_loader.from_pretrained.assert_called_once()
+    assert config.is_decoder is True
+
+
+def test_load_pretrained_model_trust_remote_code_env(monkeypatch):
+    monkeypatch.setenv("TFKIT_TRUST_REMOTE_CODE", "false")
+
+    config = _make_config()
+    auto_config = MagicMock()
+    auto_config.from_pretrained.return_value = config
+    monkeypatch.setattr(model_utils, "AutoConfig", auto_config)
+
+    base_loader = MagicMock()
+    base_instance = object()
+    base_loader.from_pretrained.return_value = base_instance
+    monkeypatch.setattr(model_utils, "AutoModel", base_loader)
+
+    result = load_pretrained_model("mock-model", ["clas"])  # type: ignore[arg-type]
+
+    assert result is base_instance
+    auto_config.from_pretrained.assert_called_once_with(
+        "mock-model", trust_remote_code=False
+    )
+    base_loader.from_pretrained.assert_called_once()
+    _, kwargs = base_loader.from_pretrained.call_args
+    assert kwargs.get("trust_remote_code") is False
+
+
+def test_load_pretrained_tokenizer_respects_env(monkeypatch):
+    monkeypatch.setenv("TFKIT_TRUST_REMOTE_CODE", "0")
+
+    tokenizer_loader = MagicMock()
+    monkeypatch.setattr(model_utils, "AutoTokenizer", tokenizer_loader)
+
+    load_pretrained_tokenizer("mock-tokenizer")
+
+    tokenizer_loader.from_pretrained.assert_called_once_with(
+        "mock-tokenizer", trust_remote_code=False
+    )
diff --git a/tests/test_task_generation.py b/tests/test_task_generation.py
new file mode 100644
index 0000000..af550e9
--- /dev/null
+++ b/tests/test_task_generation.py
@@ -0,0 +1,147 @@
+from types import SimpleNamespace
+
+import torch
+from torch import nn
+
+from tfkit.task.clm.model import Model as CLMModel
+from tfkit.task.seq2seq.model import Model as Seq2SeqModel
+
+
+class DummyTokenizer:
+    def __init__(self, vocab_size):
+        self.vocab_size = vocab_size
+
+    def __len__(self):
+        return self.vocab_size
+
+    def convert_ids_to_tokens(self, idx):
+        return f"token-{idx}"
+
+
+class DummyCausalPretrained(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.config = SimpleNamespace(vocab_size=5, hidden_size=4)
+        self.output_layer = nn.Linear(self.config.hidden_size, self.config.vocab_size)
+        self.last_kwargs = None
+
+    def get_output_embeddings(self):
+        return self.output_layer
+
+    def forward(self, input_ids, attention_mask=None, return_dict=True, **kwargs):
+        self.last_kwargs = kwargs
+        batch_size, seq_len = input_ids.shape
+        logits = torch.zeros(batch_size, seq_len, self.config.vocab_size)
+        outputs = {
+            "logits": logits,
+            "last_hidden_state": torch.zeros(batch_size, seq_len, self.config.hidden_size),
+        }
+        if "labels" in kwargs:
+            outputs["loss"] = torch.tensor(0.0)
+        return outputs
+
+
+class DummyEncoderPretrained(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.config = SimpleNamespace(vocab_size=5, hidden_size=4)
+        self.last_kwargs = None
+
+    def get_output_embeddings(self):
+        return None
+
+    def forward(self, input_ids, attention_mask=None, return_dict=True, **kwargs):
+        self.last_kwargs = kwargs
+        batch_size, seq_len = input_ids.shape
+        hidden = torch.zeros(batch_size, seq_len, self.config.hidden_size)
+        return {"last_hidden_state": hidden}
+
+
+class DummySeq2SeqPretrained(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.config = SimpleNamespace(vocab_size=3, hidden_size=4)
+        self.decoder = nn.Module()
+        self.output_layer = nn.Linear(self.config.hidden_size, self.config.vocab_size)
+
+    def get_output_embeddings(self):
+        return self.output_layer
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        decoder_input_ids=None,
+        decoder_attention_mask=None,
+        output_hidden_states=False,
+        use_cache=False,
+        return_dict=True,
+        **kwargs,
+    ):
+        batch_size, seq_len = decoder_input_ids.shape
+        hidden = torch.zeros(batch_size, seq_len, self.config.hidden_size)
+        outputs = {
+            "last_hidden_state": hidden,
"decoder_hidden_states": (hidden,), + } + return outputs + + +def test_clm_model_uses_pretrained_head_for_loss(): + tokenizer = DummyTokenizer(vocab_size=5) + pretrained = DummyCausalPretrained() + model = CLMModel(tokenizer=tokenizer, pretrained=pretrained) + + batch = { + "input": torch.zeros((1, 2), dtype=torch.long), + "mask": torch.ones((1, 2), dtype=torch.long), + "target": torch.tensor([[0, -1]]), + } + + loss = model.forward(batch, eval=False) + assert torch.is_tensor(loss) + assert "labels" in pretrained.last_kwargs + assert pretrained.last_kwargs["labels"].tolist() == [[0, -100]] + + eval_batch = { + **batch, + "start": [0], + } + result = model.forward(eval_batch, eval=True) + assert isinstance(result, dict) + assert "max_item" in result + + +def test_clm_model_falls_back_to_linear_head(): + tokenizer = DummyTokenizer(vocab_size=5) + pretrained = DummyEncoderPretrained() + model = CLMModel(tokenizer=tokenizer, pretrained=pretrained) + + batch = { + "input": torch.zeros((1, 2), dtype=torch.long), + "mask": torch.ones((1, 2), dtype=torch.long), + "target": torch.tensor([[0, -1]]), + } + + loss = model.forward(batch, eval=False) + assert torch.is_tensor(loss) + assert pretrained.last_kwargs == {} + + +def test_seq2seq_model_uses_pretrained_output_head(): + tokenizer = DummyTokenizer(vocab_size=3) + pretrained = DummySeq2SeqPretrained() + model = Seq2SeqModel(tokenizer=tokenizer, pretrained=pretrained) + + batch = { + "input": torch.zeros((1, 1), dtype=torch.long), + "prev": torch.zeros((1, 1), dtype=torch.long), + "encoder_mask": torch.ones((1, 1), dtype=torch.long), + "decoder_mask": torch.ones((1, 1), dtype=torch.long), + "target": torch.zeros((1, 1), dtype=torch.long), + "ntarget": torch.full((1, 1), -1), + } + + loss = model.forward(batch, eval=False) + assert torch.is_tensor(loss) + assert model.model is pretrained.output_layer diff --git a/tfkit/task/clm/model.py b/tfkit/task/clm/model.py index fadbf70..130c4d2 100644 --- a/tfkit/task/clm/model.py +++ b/tfkit/task/clm/model.py @@ -12,17 +12,60 @@ class Model(BaseTFKitModel): def __init__(self, tokenizer, pretrained, maxlen=512, **kwargs): super().__init__(tokenizer, pretrained, maxlen, **kwargs) - self.model = nn.Linear(self.get_hidden_size(), self.get_vocab_size()) + self.model = self._resolve_output_head() + self.uses_pretrained_head = self.model is not None + if not self.uses_pretrained_head: + self.model = nn.Linear(self.get_hidden_size(), self.get_vocab_size()) + self._setup_predictor(AutoRegressivePredictor, Preprocessor) + def _resolve_output_head(self): + """Return the pretrained language modeling head if available.""" + + if hasattr(self.pretrained, "get_output_embeddings"): + output_embeddings = self.pretrained.get_output_embeddings() + if output_embeddings is not None: + return output_embeddings + if hasattr(self.pretrained, "lm_head"): + return self.pretrained.lm_head + if hasattr(self.pretrained, "cls"): + return self.pretrained.cls + return None + def forward(self, batch_data, eval=False, beamsearch=False, max_return=1, **kwargs): inputs = batch_data['input'] masks = batch_data['mask'] tokens_tensor = torch.as_tensor(inputs) mask_tensors = torch.as_tensor(masks) + model_kwargs = { + 'attention_mask': mask_tensors, + 'return_dict': True, + } + if eval: + model_kwargs['use_cache'] = False + + if eval: + outputs = self.pretrained(tokens_tensor, **model_kwargs) + prediction_scores = outputs['logits'] if 'logits' in outputs else outputs[0] + else: + targets = batch_data['target'] + loss_tensors = 
-        outputs = self.pretrained(tokens_tensor, attention_mask=mask_tensors)
-        prediction_scores = self.model(outputs[0])
+            if self.uses_pretrained_head:
+                labels = loss_tensors.clone().long()
+                labels[labels == -1] = -100
+                model_kwargs['labels'] = labels
+                outputs = self.pretrained(tokens_tensor, **model_kwargs)
+                prediction_scores = outputs['logits'] if 'logits' in outputs else outputs[0]
+                masked_lm_loss = outputs['loss']
+            else:
+                loss_tensors = loss_tensors.long()
+                outputs = self.pretrained(tokens_tensor, **model_kwargs)
+                hidden_states = outputs['last_hidden_state'] if 'last_hidden_state' in outputs else outputs[0]
+                prediction_scores = self.model(hidden_states)
+                loss_fct = nn.CrossEntropyLoss(ignore_index=-1)  # -1 index = padding token
+                masked_lm_loss = loss_fct(prediction_scores.view(-1, self.vocab_size),
+                                          loss_tensors.view(-1))
 
         if eval:
             result_dict = {}
@@ -39,11 +82,5 @@ def forward(self, batch_data, eval=False, beamsearch=False, max_return=1, **kwar
             result_dict['label_prob'] = prob_result
             outputs = result_dict
         else:
-            targets = batch_data['target']
-            loss_tensors = torch.as_tensor(targets)
-            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)  # -1 index = padding token
-            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.vocab_size),
-                                      loss_tensors.view(-1))
-
             outputs = masked_lm_loss
         return outputs
diff --git a/tfkit/task/seq2seq/model.py b/tfkit/task/seq2seq/model.py
index 0705ed7..4f084ba 100644
--- a/tfkit/task/seq2seq/model.py
+++ b/tfkit/task/seq2seq/model.py
@@ -24,12 +24,26 @@ def __init__(self, tokenizer, pretrained, maxlen: int = DEFAULT_MAXLEN,
         self.selfkd = selfkd
 
         self.decoder_model, init_weight = self._initialize_decoder()
-        self.model = nn.Linear(self.decoder_hidden_size, self.get_vocab_size(), bias=False)
-        if init_weight is not None:
-            self.model.weight = init_weight
-
+        self.model = self._resolve_output_projection()
+
+        if self.model is None:
+            self.model = nn.Linear(self.decoder_hidden_size, self.get_vocab_size(), bias=False)
+            if init_weight is not None:
+                self.model.weight = init_weight
+
         self._setup_predictor(AutoRegressivePredictor, Preprocessor)
 
+    def _resolve_output_projection(self):
+        """Return the pretrained output head when available."""
+
+        if hasattr(self.pretrained, "get_output_embeddings"):
+            output_embeddings = self.pretrained.get_output_embeddings()
+            if output_embeddings is not None:
+                return output_embeddings
+        if hasattr(self.pretrained, "lm_head"):
+            return self.pretrained.lm_head
+        return None
+
     def _initialize_decoder(self) -> Tuple[Optional[nn.Module], Optional[torch.Tensor]]:
         """Initialize decoder model and return initial weights if available."""
         init_weight = None
@@ -55,7 +69,7 @@ def forward(self, batch_data, eval=False, beamsearch=False, max_return=1, **kwar
         else:
             prediction_output, prediction_all_hidden = self.encoder_forward(batch_data, eval, beamsearch)
 
-        prediction_scores = self.model(prediction_output)
+        prediction_scores = self._project_to_vocab(prediction_output)
 
         if eval:
             outputs = self.process_eval_output(prediction_scores, max_return)
@@ -128,7 +142,7 @@ def calculate_loss(self, batch_data, prediction_scores, prediction_all_hidden):
         if self.selfkd:
             selfkdloss_fct = SelfKDLoss(ignore_index=-1)
             for decoder_hidden in prediction_all_hidden[:-1]:
-                student = self.model(decoder_hidden)
+                student = self._project_to_vocab(decoder_hidden)
                 lm_loss += selfkdloss_fct(student.view(-1, self.vocab_size),
                                           prediction_scores.view(-1, self.vocab_size), loss_tensors.view(-1))
 
@@ -152,3 +166,6 @@ def calculate_loss(self, batch_data, prediction_scores, prediction_all_hidden):
         lm_loss += negative_loss
 
         return lm_loss
+
+    def _project_to_vocab(self, hidden_states):
+        return self.model(hidden_states)
diff --git a/tfkit/utility/base_model.py b/tfkit/utility/base_model.py
index 6a7273e..a3b4880 100644
--- a/tfkit/utility/base_model.py
+++ b/tfkit/utility/base_model.py
@@ -85,4 +85,4 @@ def get_vocab_size(self) -> int:
         Returns:
             Vocabulary size
         """
-        return self.vocab_size
\ No newline at end of file
+        return self.vocab_size
diff --git a/tfkit/utility/constants.py b/tfkit/utility/constants.py
index 5d11c27..55571f1 100644
--- a/tfkit/utility/constants.py
+++ b/tfkit/utility/constants.py
@@ -22,6 +22,7 @@
 # Environment variables
 ENV_TOKENIZERS_PARALLELISM = "TOKENIZERS_PARALLELISM"
 ENV_OMP_NUM_THREADS = "OMP_NUM_THREADS"
+ENV_TRUST_REMOTE_CODE = "TFKIT_TRUST_REMOTE_CODE"
 
 # Special tokens
 BLANK_TOKEN = ""
@@ -52,4 +53,4 @@
     'WARNING': 30,
     'ERROR': 40,
     'CRITICAL': 50
-}
\ No newline at end of file
+}
diff --git a/tfkit/utility/model.py b/tfkit/utility/model.py
index e379760..ef032b9 100644
--- a/tfkit/utility/model.py
+++ b/tfkit/utility/model.py
@@ -1,13 +1,21 @@
 import copy
 import importlib
 import os
-from typing import List
+from typing import Iterable, List, Sequence
 
 import inquirer
 import nlp2
 import torch
 from torch import nn
-from transformers import AutoTokenizer, AutoModel
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
+    AutoTokenizer,
+)
+
+from tfkit.utility.constants import ENV_TRUST_REMOTE_CODE
 
 
 def list_all_model(ignore_list=[]):
@@ -29,15 +37,83 @@ def load_model_class(model_name):
     return importlib.import_module('.' + model_name, 'tfkit.task')
 
 
-def load_pretrained_model(pretrained_config, model_type):
-    pretrained = AutoModel.from_pretrained(pretrained_config)
-    if 'clm' in model_type:
-        pretrained.config.is_decoder = True
-    return pretrained
+def _should_trust_remote_code(explicit: bool | None = None) -> bool:
+    """Determine whether remote code should be trusted when loading models."""
+
+    if explicit is not None:
+        return explicit
+
+    env_value = os.getenv(ENV_TRUST_REMOTE_CODE, "1").strip().lower()
+    return env_value in {"1", "true", "yes", "on"}
+
+
+def _normalize_model_types(model_type: Iterable[str] | str | None) -> Sequence[str]:
+    """Normalize model type input into a sequence for easier processing."""
+
+    if model_type is None:
+        return ()
+    if isinstance(model_type, str):
+        return (model_type,)
+    return tuple(model_type)
+
+
+def _requires_seq2seq_model(tasks: Sequence[str], config) -> bool:
+    """Determine if a sequence-to-sequence model should be used."""
+
+    if getattr(config, "is_encoder_decoder", False):
+        return True
+    return any(task in {"seq2seq"} for task in tasks)
+
 
+def _requires_causal_model(tasks: Sequence[str], config) -> bool:
+    """Determine if a causal language model should be used."""
 
-def load_pretrained_tokenizer(pretrained_config):
-    tokenizer = AutoTokenizer.from_pretrained(pretrained_config)
+    if any(task in {"clm"} for task in tasks):
+        return True
+    architectures = getattr(config, "architectures", []) or []
+    return any("ForCausalLM" in arch for arch in architectures)
+
+
+def load_pretrained_model(pretrained_config, model_type, trust_remote_code: bool | None = None):
+    """Load a pretrained model compatible with the requested tasks."""
+
+    trust_remote = _should_trust_remote_code(trust_remote_code)
+    tasks = _normalize_model_types(model_type)
+    config = AutoConfig.from_pretrained(pretrained_config, trust_remote_code=trust_remote)
+
+    if _requires_seq2seq_model(tasks, config):
+        try:
+            return AutoModelForSeq2SeqLM.from_pretrained(
+                pretrained_config,
+                config=config,
+                trust_remote_code=trust_remote,
+            )
+        except (ValueError, OSError):
+            pass
+
+    if _requires_causal_model(tasks, config):
+        try:
+            return AutoModelForCausalLM.from_pretrained(
+                pretrained_config,
+                config=config,
+                trust_remote_code=trust_remote,
+            )
+        except (ValueError, OSError):
+            # Fall back to encoder model when causal head is unavailable
+            config.is_decoder = True
+
+    return AutoModel.from_pretrained(
+        pretrained_config,
+        config=config,
+        trust_remote_code=trust_remote,
+    )
+
+
+def load_pretrained_tokenizer(pretrained_config, trust_remote_code: bool | None = None):
+    tokenizer = AutoTokenizer.from_pretrained(
+        pretrained_config,
+        trust_remote_code=_should_trust_remote_code(trust_remote_code),
+    )
     return tokenizer
@@ -48,6 +124,9 @@ def resize_pretrain_tok(pretrained, tokenizer):
 
 
 def add_tokens_to_pretrain(pretrained, tokenizer, add_tokens, sample_init=False):
+    if not add_tokens:
+        return pretrained, tokenizer
+
     origin_vocab_size = tokenizer.vocab_size
     print("===ADD TOKEN===")
     num_added_toks = tokenizer.add_tokens(add_tokens)
@@ -63,7 +142,7 @@ def add_tokens_to_pretrain(pretrained, tokenizer, add_tokens, sample_init=False)
     return pretrained, tokenizer
 
 
-def load_trained_model(model_path, pretrained_config=None, tag=None):
+def load_trained_model(model_path, pretrained_config=None, tag=None, trust_remote_code: bool | None = None):
    """loading saved task"""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
@@ -96,8 +175,9 @@ def load_trained_model(model_path, pretrained_config=None, tag=None):
     type = model_types[type_ind]
     add_tokens = torchpack['add_tokens'] if 'add_tokens' in torchpack else None
     # load task
-    tokenizer = AutoTokenizer.from_pretrained(config)
-    pretrained = AutoModel.from_pretrained(config)
+    trust_remote = _should_trust_remote_code(trust_remote_code)
+    tokenizer = load_pretrained_tokenizer(config, trust_remote_code=trust_remote)
+    pretrained = load_pretrained_model(config, type, trust_remote_code=trust_remote)
 
     pretrained, tokenizer = add_tokens_to_pretrain(pretrained, tokenizer, add_tokens)