Description
In Section 4.3, the paper says, "... we use the distribution ~p_(t+1) (instead of a hard sample x_(t+1)), and feed it forward to obtain (a biased) estimate of the next token's embedding and then update delta_H_t." In the code, I found that the hard sample x_(t+1) (i.e., model(last, ...)) is fed into the model to obtain the probs in the first pass:
```python
all_logits, _, all_hidden = model(last, past=perturbed_past)
hidden = all_hidden[-1]
new_accumulated_hidden = accumulated_hidden + torch.sum(
    hidden,
    dim=1
).detach()
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
logits = all_logits[:, -1, :]
probs = F.softmax(logits, dim=-1)
```
and then, in the second pass, the code feeds the soft distribution ~p_(t+1) (i.e., inputs_embeds) into the model:
```python
if loss_type == PPLM_DISCRIM or loss_type == PPLM_BOW_DISCRIM:
    ce_loss = torch.nn.CrossEntropyLoss()
    # TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
    curr_unpert_past = unpert_past
    curr_probs = torch.unsqueeze(probs, dim=1)
    wte = model.resize_token_embeddings()
    for _ in range(horizon_length):
        inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
        _, curr_unpert_past, curr_all_hidden = model(
            past=curr_unpert_past,
            inputs_embeds=inputs_embeds
        )
        curr_hidden = curr_all_hidden[-1]
        new_accumulated_hidden = new_accumulated_hidden + torch.sum(
            curr_hidden, dim=1)
```
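For reference, here is a minimal, self-contained sketch (not repo code; the tensor sizes are just illustrative) of the difference between a hard embedding lookup for x_(t+1) and the soft, probability-weighted embedding that torch.matmul(curr_probs, wte.weight.data) computes for ~p_(t+1):

```python
import torch

# Illustrative GPT-2-small sizes (assumption, not taken from the repo).
vocab_size, embed_dim = 50257, 768
wte_weight = torch.randn(vocab_size, embed_dim)             # stands in for wte.weight
probs = torch.softmax(torch.randn(1, vocab_size), dim=-1)   # ~p_(t+1)

# Hard sample: pick one token id and look up its embedding row.
hard_id = torch.argmax(probs, dim=-1)    # shape: (1,)
hard_embed = wte_weight[hard_id]         # shape: (1, embed_dim)

# Soft version: expectation of the embedding under ~p_(t+1),
# i.e., a probability-weighted average over all embedding rows.
soft_embed = probs @ wte_weight          # shape: (1, embed_dim)
```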
My questions are:
(1) Why does it use past=curr_unpert_past, instead of past=past, to predict the next token in the second pass? To predict the next token, GPT-2 needs the current token id (or embedding) together with the past_key_values of the tokens before the current one.
(2) In the second pass, the code does not update logits (i.e., _, curr_unpert_past, curr_all_hidden = model(...)), so it cannot update probs and keeps using the probs from the first pass (i.e., probs = F.softmax(logits, dim=-1)). Why not update probs in the second pass? (A hypothetical sketch of what I mean follows below.)
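To make question (2) concrete, here is a hypothetical variant (using the same variable names as the snippet above; this is not what the current code does) that would recompute probs from the logits at every horizon step:

```python
# Hypothetical sketch only: recompute the soft distribution each step
# instead of reusing the probs computed from the first pass.
curr_probs = torch.unsqueeze(probs, dim=1)
for _ in range(horizon_length):
    inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
    curr_logits, curr_unpert_past, curr_all_hidden = model(
        past=curr_unpert_past,
        inputs_embeds=inputs_embeds
    )
    # Update probs from this step's logits before the next iteration.
    curr_probs = F.softmax(curr_logits[:, -1:, :], dim=-1)
    curr_hidden = curr_all_hidden[-1]
    new_accumulated_hidden = new_accumulated_hidden + torch.sum(
        curr_hidden, dim=1)
```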
Thank you so much. Please correct me if I'm wrong.