-from typing import Optional
+from dataclasses import dataclass
+from typing import List, Optional

 import torch
 import torch.nn.functional as F
 from torch import nn
+from flashlight.lib.text.decoder import (
+    LexiconFreeSeq2SeqDecoder,
+    LexiconFreeSeq2SeqDecoderOptions,
+    ZeroLM,
+    create_emitting_model_state,
+    get_obj_from_emitting_model_state,
+)

 import logging

 logger = logging.getLogger(__name__)


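+# Per-hypothesis decoding state. flashlight treats these objects as opaque:
+# they are wrapped with create_emitting_model_state before being handed to the
+# decoder and recovered inside update_func via get_obj_from_emitting_model_state.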
+@dataclass
+class Seq2SeqModelState:
+
+    timestep: int
+    hidden_states: Optional[List[torch.Tensor]]
+    sequence: torch.Tensor
+    lm_scores: Optional[torch.Tensor]
+
+
 class GenerationUtil:
     """Wrapper to provide generation utils for encoder/decoder models and decoder models.

@@ -92,8 +103,90 @@ def greedy_search(

         return input_ids

-    def beam_search(self, input_ids: torch.Tensor, num_beams: int, max_len: Optional[int]) -> torch.Tensor:
-        raise NotImplementedError()
+    def beam_search(
+        self,
+        input_ids: torch.Tensor,
+        num_beams: int,
+        max_len: int,
+        vocab_size: int,
+        eos_idx: int = 1,
+        **model_kwargs
+    ) -> torch.Tensor:
+
+        # T and N describe the emissions layout that flashlight's decode_step
+        # expects: the maximum number of output timesteps and the token set
+        # (vocabulary) size, respectively.
+        T = max_len
+        N = vocab_size
+
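+        # flashlight invokes update_func once per decoding step with the candidate
+        # tokens and opaque model states from the previous step; it must return
+        # per-candidate token scores plus a matching list of updated model states.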
+        def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
+            # `emissions_ptr` always points at the same encoder output; N and T
+            # are not needed here since scores come from the decoder forward pass.
+
+            if timestep == 0:
+                prev_step_token_idxs = input_ids
+                prev_step_model_states = [
+                    create_emitting_model_state(
+                        Seq2SeqModelState(
+                            timestep=0,
+                            hidden_states=None,
+                            sequence=input_ids,
+                            lm_scores=None,
+                        )
+                    )
+                ]
+
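+            # A single forward pass scores the current prefix; hidden states are
+            # requested so they can be cached on each candidate's model state.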
+            model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            if self.is_huggingface_model:
+                model_inputs["return_dict"] = True
+                model_inputs["output_hidden_states"] = True
+
+            outputs = self.model(**model_inputs)
+            output_key = "logits" if self.is_huggingface_model else "decoder_output"
+            lm_scores = outputs[output_key]
+
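+            # Re-wrap an updated Seq2SeqModelState for every candidate token so
+            # flashlight can hand it back on the next step.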
+            model_states = []
+            for idx, model_state_ptr in zip(prev_step_token_idxs, prev_step_model_states):
+                model_state = get_obj_from_emitting_model_state(model_state_ptr)
+                model_states.append(
+                    create_emitting_model_state(
+                        Seq2SeqModelState(
+                            timestep=timestep,
+                            hidden_states=outputs["decoder_hidden_states"],
+                            sequence=torch.cat([model_state.sequence[:, -1], idx], dim=-1),
+                            lm_scores=lm_scores,
+                        )
+                    )
+                )
+
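+            # Every candidate currently shares the same next-token distribution,
+            # so the flattened score list is replicated once per candidate.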
+            out_probs = lm_scores[0][0].tolist() * len(prev_step_token_idxs)
+            return out_probs, model_states
+
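+        # With ZeroLM and lm_weight=0.0 the external LM contributes no score;
+        # ranking comes entirely from the model scores returned by update_func.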
+        options = LexiconFreeSeq2SeqDecoderOptions(
+            beam_size=num_beams,
+            beam_size_token=self.model.config.vocab_size,
+            beam_threshold=1000,
+            lm_weight=0.0,
+            eos_score=0.0,
+            log_add=True,
+        )
+
+        decoder = LexiconFreeSeq2SeqDecoder(
+            options=options,
+            lm=ZeroLM(),
+            eos_idx=eos_idx,
+            update_func=update_func,
+            max_output_length=max_len,
+        )
+
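+        # The encoder output tensor backs the emissions pointer; flashlight only
+        # threads this raw pointer back into update_func.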
+        emissions = model_kwargs["encoder_outputs"].get("encoder_output")
+
+        decoder.decode_step(emissions.data_ptr(), T, N)
+        hyps = decoder.get_all_final_hypothesis()
+
+        # NOTE: flashlight returns hypothesis objects here, not a token tensor.
+        return hyps

     def generate(
         self,
@@ -135,6 +228,9 @@ def generate(
         if num_beams == 1 or num_beams is None:
             return self.greedy_search(inputs, max_len, eos_idx, pad_idx=pad_idx, **model_kwargs)
         elif num_beams > 1:
-            return self.beam_search(inputs, num_beams, max_len)
+            if torch.has_cuda:
+                logger.warning("No CUDA parallelization available for beam search yet.")
+                # TODO: implement some form of multiprocessing here
+            return self.beam_search(inputs, num_beams, vocab_size=self.model.config.vocab_size, max_len=max_len, **model_kwargs)

         else:
             raise ValueError("`num_beams` must be >= 1.")