diff --git a/run_pplm.py b/run_pplm.py
index b6d18c9..3844c88 100644
--- a/run_pplm.py
+++ b/run_pplm.py
@@ -32,9 +32,24 @@
 import torch.nn.functional as F
 from torch.autograd import Variable
 from tqdm import trange
-from transformers import GPT2Tokenizer
 from transformers.file_utils import cached_path
-from transformers.modeling_gpt2 import GPT2LMHeadModel
+
+from transformers import (
+    AutoModelWithLMHead,
+    AutoTokenizer,
+    CTRLLMHeadModel,
+    CTRLTokenizer,
+    GPT2LMHeadModel,
+    GPT2Tokenizer,
+    OpenAIGPTLMHeadModel,
+    OpenAIGPTTokenizer,
+    TransfoXLLMHeadModel,
+    TransfoXLTokenizer,
+    XLMTokenizer,
+    XLMWithLMHeadModel,
+    XLNetLMHeadModel,
+    XLNetTokenizer,
+)
 
 from pplm_classification_head import ClassificationHead
 
@@ -86,7 +101,6 @@
     },
 }
 
-
 def to_var(x, requires_grad=False, volatile=False, device='cuda'):
     if torch.cuda.is_available() and device == 'cuda':
         x = x.cuda()
@@ -373,11 +387,18 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) ->
             filepath = id_or_path
         with open(filepath, "r") as f:
            words = f.read().strip().split("\n")
-        bow_indices.append(
-            [tokenizer.encode(word.strip(),
-                              add_prefix_space=True,
-                              add_special_tokens=False)
-             for word in words])
+
+        if isinstance(tokenizer, GPT2Tokenizer):
+            def tokenizer_encode(word):
+                return tokenizer.encode(word.strip(),
+                                        add_prefix_space=True,
+                                        add_special_tokens=False)
+        else:
+            def tokenizer_encode(word):
+                return tokenizer.encode(word.strip(),
+                                        add_special_tokens=False)
+
+        bow_indices.append([tokenizer_encode(word) for word in words])
     return bow_indices
 
 
@@ -561,7 +582,7 @@ def generate_text_pplm(
     if past is None and output_so_far is not None:
         last = output_so_far[:, -1:]
         if output_so_far.shape[1] > 1:
-            _, past, _ = model(output_so_far[:, :-1])
+            past = model(output_so_far[:, :-1])[1]
 
     unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
     unpert_last_hidden = unpert_all_hidden[-1]
@@ -727,7 +748,7 @@ def run_pplm_example(
                 "to discriminator's = {}".format(discrim, pretrained_model))
 
     # load pretrained model
-    model = GPT2LMHeadModel.from_pretrained(
+    model = AutoModelWithLMHead.from_pretrained(
         pretrained_model,
         output_hidden_states=True
     )
@@ -735,9 +756,9 @@
     model.eval()
 
     # load tokenizer
-    tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
+    tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
 
-    # Freeze GPT-2 weights
+    # Freeze pretrained model's weights
     for param in model.parameters():
         param.requires_grad = False
 
@@ -844,8 +865,25 @@
     return
 
+def test():
+    run_pplm_example(
+        pretrained_model="gpt2-medium",
+        # pretrained_model="xlnet-large-cased",
+        cond_text="The potato",
+        num_samples=3,
+        bag_of_words='military',
+        length=50,
+        stepsize=0.03,
+        sample=True,
+        num_iterations=3,
+        window_length=5,
+        gamma=1.5,
+        gm_scale=0.95,
+        kl_scale=0.01,
+        verbosity='regular'
+    )
 
-if __name__ == '__main__':
+def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--pretrained_model",
         "-M",
@@ -934,3 +972,7 @@
 
     args = parser.parse_args()
     run_pplm_example(**vars(args))
+
+if __name__ == '__main__':
+    # main()
+    test()
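
Note on the get_bag_of_words_indices hunk: the dispatch on isinstance(tokenizer, GPT2Tokenizer) is there because GPT-2's byte-level BPE encodes "dog" and " dog" as different tokens, and only GPT2Tokenizer.encode takes the add_prefix_space argument; the other tokenizers in the patch don't support it. A minimal sketch of the behaviour that branch relies on, assuming a transformers 2.x install matching this patch's imports:

    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")

    # Without the prefix space we get the start-of-text variant of the word;
    # with it we get the mid-sentence variant, which is what the model will
    # actually emit while generating, so the bag-of-words ids must match it.
    plain = tokenizer.encode("dog", add_special_tokens=False)
    prefixed = tokenizer.encode("dog", add_prefix_space=True,
                                add_special_tokens=False)
    assert plain != prefixed  # different BPE ids for the two variants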
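
Note on the generate_text_pplm hunk: _, past, _ = model(...) hard-codes a three-element output tuple, which holds for GPT-2 with output_hidden_states=True but not for every architecture; model(...)[1] only assumes the cached state (past for GPT-2, mems for XLNet/Transfo-XL) comes second, however long the tuple is. A minimal sketch of that assumption under transformers 2.x tuple outputs (the model name is just an example):

    import torch
    from transformers import AutoModelWithLMHead, AutoTokenizer

    model = AutoModelWithLMHead.from_pretrained("gpt2",
                                                output_hidden_states=True)
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model.eval()

    input_ids = torch.tensor([tokenizer.encode("The potato")])
    with torch.no_grad():
        # tuple: (logits, past/mems, hidden_states, ...)
        outputs = model(input_ids)

    past = outputs[1]  # second element, regardless of how many follow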