Skip to content

Commit 4be6f6e

Browse files
authored
Merge pull request #203 from stanfordnlp/peterwz-versions
[P0] Fix test failures due to transformers version change
2 parents f5119c1 + 8695a09 commit 4be6f6e

File tree

3 files changed

+82
-19
lines changed

3 files changed

+82
-19
lines changed

pyvene/models/intervenable_base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1292,7 +1292,16 @@ def load(
12921292
binary_filename = f"intkey_{k}.bin"
12931293
intervention.is_source_constant = \
12941294
saving_config.intervention_constant_sources[i]
1295-
intervention.set_interchange_dim(saving_config.intervention_dimensions[i])
1295+
dim = saving_config.intervention_dimensions[i]
1296+
if dim is None:
1297+
# Infer interchange dimension from component name to be compatible with old versions
1298+
component_name = saving_config.representations[i].component
1299+
if component_name.startswith("head_"):
1300+
dim = model.config.hidden_size // model.config.num_attention_heads
1301+
else:
1302+
dim = model.config.hidden_size
1303+
1304+
intervention.set_interchange_dim(dim)
12961305
if saving_config.intervention_constant_sources[i] and \
12971306
not isinstance(intervention, ZeroIntervention) and \
12981307
not isinstance(intervention, SourcelessIntervention):

pyvene_101.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1877,7 +1877,7 @@
18771877
" tokenizer(\"The capital of Italy is\", return_tensors=\"pt\"),\n",
18781878
"]\n",
18791879
"base_outputs, counterfactual_outputs = pv_gpt2(\n",
1880-
" base, sources, {\"sources->base\": ([[[3]]], [[[3]]])}\n",
1880+
" base, sources, {\"sources->base\": ([[[3]]], [[[3]]])}, output_original_output=True\n",
18811881
")\n",
18821882
"print(counterfactual_outputs.last_hidden_state - base_outputs.last_hidden_state)\n",
18831883
"# call backward will put gradients on model's weights\n",
@@ -2785,7 +2785,7 @@
27852785
" model=resnet\n",
27862786
")\n",
27872787
"intervened_outputs = pv_resnet(\n",
2788-
" base_inputs, [source_inputs], return_dict=True\n",
2788+
" base_inputs, [source_inputs], return_dict=True, output_original_output=True\n",
27892789
")\n",
27902790
"(intervened_outputs.intervened_outputs.logits - intervened_outputs.original_outputs.logits).sum()"
27912791
]
@@ -2842,7 +2842,7 @@
28422842
")\n",
28432843
"\n",
28442844
"intervened_outputs = pv_resnet(\n",
2845-
" base_inputs, [source_inputs], return_dict=True\n",
2845+
" base_inputs, [source_inputs], return_dict=True, output_original_output=True\n",
28462846
")\n",
28472847
"(intervened_outputs.intervened_outputs.logits - intervened_outputs.original_outputs.logits).sum()"
28482848
]

tests/utils.py

Lines changed: 69 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pandas as pd
88
import numpy as np
99
from transformers import GPT2Config, LlamaConfig
10+
from transformers.models.gpt2.modeling_gpt2 import eager_attention_forward
1011
import math
1112
from torch import nn
1213

@@ -79,6 +80,22 @@ def is_package_installed(package_name):
7980
forward calls to fetch activations or run with cached activations
8081
"""
8182

83+
def split_heads(tensor, num_heads, attn_head_size):
84+
"""
85+
Splits hidden_size dim into attn_head_size and num_heads
86+
"""
87+
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
88+
tensor = tensor.view(new_shape)
89+
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
90+
91+
def merge_heads(tensor, num_heads, attn_head_size):
92+
"""
93+
Merges attn_head_size dim and num_attn_heads dim into hidden_size
94+
"""
95+
tensor = tensor.permute(0, 2, 1, 3).contiguous()
96+
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
97+
return tensor.view(new_shape)
98+
8299

83100
def DO_INTERVENTION(name, orig_hidden_states, INTERVENTION_ACTIVATIONS):
84101
if name in INTERVENTION_ACTIVATIONS:
@@ -100,9 +117,9 @@ def GPT2_SELF_ATTENTION_RUN(
100117
value = DO_INTERVENTION(f"{i}.value_output", value, INTERVENTION_ACTIVATIONS)
101118
CACHE_ACTIVATIONS[f"{i}.value_output"] = value
102119

103-
head_query = self_attn._split_heads(query, self_attn.num_heads, self_attn.head_dim)
104-
head_key = self_attn._split_heads(key, self_attn.num_heads, self_attn.head_dim)
105-
head_value = self_attn._split_heads(value, self_attn.num_heads, self_attn.head_dim)
120+
head_query = split_heads(query, self_attn.num_heads, self_attn.head_dim)
121+
head_key = split_heads(key, self_attn.num_heads, self_attn.head_dim)
122+
head_value = split_heads(value, self_attn.num_heads, self_attn.head_dim)
106123

107124
head_query = DO_INTERVENTION(
108125
f"{i}.head_query_output", head_query, INTERVENTION_ACTIVATIONS
@@ -117,18 +134,24 @@ def GPT2_SELF_ATTENTION_RUN(
117134
)
118135
CACHE_ACTIVATIONS[f"{i}.head_value_output"] = head_value
119136

120-
head_attention_value_output, attn_weights = self_attn._attn(
121-
head_query, head_key, head_value
137+
head_attention_value_output, _ = eager_attention_forward(
138+
module=self_attn,
139+
query=head_query,
140+
key=head_key,
141+
value=head_value,
142+
attention_mask=None,
122143
)
123144

145+
head_attention_value_output = head_attention_value_output.permute(0, 2, 1, 3)
146+
124147
head_attention_value_output = DO_INTERVENTION(
125148
f"{i}.head_attention_value_output",
126149
head_attention_value_output,
127150
INTERVENTION_ACTIVATIONS,
128151
)
129152
CACHE_ACTIVATIONS[f"{i}.head_attention_value_output"] = head_attention_value_output
130153

131-
attn_value_output = self_attn._merge_heads(
154+
attn_value_output = merge_heads(
132155
head_attention_value_output, self_attn.num_heads, self_attn.head_dim
133156
)
134157
attn_value_output = DO_INTERVENTION(
@@ -287,7 +310,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
287310
return q_embed, k_embed
288311

289312
def Llama_SELF_ATTENTION_RUN(
290-
self_attn, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
313+
self_attn,
314+
hidden_states,
315+
i,
316+
CACHE_ACTIVATIONS,
317+
INTERVENTION_ACTIVATIONS,
318+
num_heads,
319+
num_key_value_heads,
320+
rotary_emb
291321
):
292322
bsz, q_len, _ = hidden_states.size()
293323

@@ -302,9 +332,9 @@ def Llama_SELF_ATTENTION_RUN(
302332
value = DO_INTERVENTION(f"{i}.value_output", value, INTERVENTION_ACTIVATIONS)
303333
CACHE_ACTIVATIONS[f"{i}.value_output"] = value
304334

305-
head_query = query.view(bsz, q_len, self_attn.num_heads, self_attn.head_dim).transpose(1, 2)
306-
head_key = key.view(bsz, q_len, self_attn.num_key_value_heads, self_attn.head_dim).transpose(1, 2)
307-
head_value = value.view(bsz, q_len, self_attn.num_key_value_heads, self_attn.head_dim).transpose(1, 2)
335+
head_query = query.view(bsz, q_len, num_heads, self_attn.head_dim).transpose(1, 2)
336+
head_key = key.view(bsz, q_len, num_key_value_heads, self_attn.head_dim).transpose(1, 2)
337+
head_value = value.view(bsz, q_len, num_key_value_heads, self_attn.head_dim).transpose(1, 2)
308338

309339
head_query = DO_INTERVENTION(
310340
f"{i}.head_query_output", head_query, INTERVENTION_ACTIVATIONS
@@ -320,7 +350,7 @@ def Llama_SELF_ATTENTION_RUN(
320350
CACHE_ACTIVATIONS[f"{i}.head_value_output"] = head_value
321351

322352
position_ids = torch.arange(q_len, device=hidden_states.device).repeat(bsz, 1)
323-
cos, sin = self_attn.rotary_emb(head_value, position_ids)
353+
cos, sin = rotary_emb(head_value, position_ids)
324354
head_query, head_key = apply_rotary_pos_emb(head_query, head_key, cos, sin)
325355

326356
head_key = repeat_kv(head_key, self_attn.num_key_value_groups)
@@ -340,7 +370,7 @@ def Llama_SELF_ATTENTION_RUN(
340370
INTERVENTION_ACTIVATIONS,
341371
)
342372
CACHE_ACTIVATIONS[f"{i}.head_attention_value_output"] = head_attention_value_output
343-
attn_value_output = head_attention_value_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self_attn.hidden_size)
373+
attn_value_output = head_attention_value_output.transpose(1, 2).contiguous().reshape(bsz, q_len, num_heads * self_attn.head_dim)
344374
attn_value_output = DO_INTERVENTION(
345375
f"{i}.attention_value_output", attn_value_output, INTERVENTION_ACTIVATIONS
346376
)
@@ -364,7 +394,14 @@ def Llama_MLP_RUN(mlp, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVAT
364394
return hidden_states_down_proj
365395

366396
def Llama_BLOCK_RUN(
367-
block, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
397+
block,
398+
hidden_states,
399+
i,
400+
CACHE_ACTIVATIONS,
401+
INTERVENTION_ACTIVATIONS,
402+
num_heads,
403+
num_key_value_heads,
404+
rotary_emb
368405
):
369406
# self attention + residual
370407
residual = hidden_states
@@ -376,7 +413,14 @@ def Llama_BLOCK_RUN(
376413
CACHE_ACTIVATIONS[f"{i}.attention_input"] = hidden_states
377414

378415
attn_outputs = Llama_SELF_ATTENTION_RUN(
379-
block.self_attn, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
416+
block.self_attn,
417+
hidden_states,
418+
i,
419+
CACHE_ACTIVATIONS,
420+
INTERVENTION_ACTIVATIONS,
421+
num_heads,
422+
num_key_value_heads,
423+
rotary_emb
380424
)
381425

382426
attn_outputs = DO_INTERVENTION(
@@ -417,6 +461,9 @@ def Llama_RUN(llama, input_ids, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS):
417461
"""
418462
# embed
419463
inputs_embeds = llama.model.embed_tokens(input_ids)
464+
num_heads = llama.model.config.num_attention_heads
465+
num_key_value_heads = llama.model.config.num_key_value_heads
466+
rotary_emb = llama.model.rotary_emb
420467
hidden_states = inputs_embeds
421468
for i, block in enumerate(llama.model.layers):
422469
hidden_states = DO_INTERVENTION(
@@ -425,7 +472,14 @@ def Llama_RUN(llama, input_ids, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS):
425472
CACHE_ACTIVATIONS[f"{i}.block_input"] = hidden_states
426473

427474
hidden_states = Llama_BLOCK_RUN(
428-
block, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS
475+
block,
476+
hidden_states,
477+
i,
478+
CACHE_ACTIVATIONS,
479+
INTERVENTION_ACTIVATIONS,
480+
num_heads,
481+
num_key_value_heads,
482+
rotary_emb
429483
)
430484

431485
hidden_states = DO_INTERVENTION(

0 commit comments

Comments (0)