77import pandas as pd
88import numpy as np
99from transformers import GPT2Config , LlamaConfig
10+ from transformers .models .gpt2 .modeling_gpt2 import eager_attention_forward
1011import math
1112from torch import nn
1213
@@ -79,6 +80,22 @@ def is_package_installed(package_name):
7980forward calls to fetch activations or run with cached activations
8081"""
8182
def split_heads(tensor, num_heads, attn_head_size):
    """
    Reshape the last (hidden) dimension into separate attention heads.

    Input:  (batch, seq_length, num_heads * attn_head_size)
    Output: (batch, num_heads, seq_length, attn_head_size)
    """
    leading_dims = tensor.size()[:-1]
    headed = tensor.view(*leading_dims, num_heads, attn_head_size)
    # Move the head axis ahead of the sequence axis so each head is contiguous.
    return headed.permute(0, 2, 1, 3)
90+
def merge_heads(tensor, num_heads, attn_head_size):
    """
    Inverse of split_heads: fold the head axis back into the hidden dimension.

    Input:  (batch, num_heads, seq_length, attn_head_size)
    Output: (batch, seq_length, num_heads * attn_head_size)
    """
    # Bring the sequence axis back in front of the head axis; contiguous()
    # is required before view() after a permute.
    seq_major = tensor.permute(0, 2, 1, 3).contiguous()
    merged_shape = seq_major.size()[:-2] + (num_heads * attn_head_size,)
    return seq_major.view(merged_shape)
98+
8299
83100def DO_INTERVENTION (name , orig_hidden_states , INTERVENTION_ACTIVATIONS ):
84101 if name in INTERVENTION_ACTIVATIONS :
@@ -100,9 +117,9 @@ def GPT2_SELF_ATTENTION_RUN(
100117 value = DO_INTERVENTION (f"{ i } .value_output" , value , INTERVENTION_ACTIVATIONS )
101118 CACHE_ACTIVATIONS [f"{ i } .value_output" ] = value
102119
103- head_query = self_attn . _split_heads (query , self_attn .num_heads , self_attn .head_dim )
104- head_key = self_attn . _split_heads (key , self_attn .num_heads , self_attn .head_dim )
105- head_value = self_attn . _split_heads (value , self_attn .num_heads , self_attn .head_dim )
120+ head_query = split_heads (query , self_attn .num_heads , self_attn .head_dim )
121+ head_key = split_heads (key , self_attn .num_heads , self_attn .head_dim )
122+ head_value = split_heads (value , self_attn .num_heads , self_attn .head_dim )
106123
107124 head_query = DO_INTERVENTION (
108125 f"{ i } .head_query_output" , head_query , INTERVENTION_ACTIVATIONS
@@ -117,18 +134,24 @@ def GPT2_SELF_ATTENTION_RUN(
117134 )
118135 CACHE_ACTIVATIONS [f"{ i } .head_value_output" ] = head_value
119136
120- head_attention_value_output , attn_weights = self_attn ._attn (
121- head_query , head_key , head_value
137+ head_attention_value_output , _ = eager_attention_forward (
138+ module = self_attn ,
139+ query = head_query ,
140+ key = head_key ,
141+ value = head_value ,
142+ attention_mask = None ,
122143 )
123144
145+ head_attention_value_output = head_attention_value_output .permute (0 , 2 , 1 , 3 )
146+
124147 head_attention_value_output = DO_INTERVENTION (
125148 f"{ i } .head_attention_value_output" ,
126149 head_attention_value_output ,
127150 INTERVENTION_ACTIVATIONS ,
128151 )
129152 CACHE_ACTIVATIONS [f"{ i } .head_attention_value_output" ] = head_attention_value_output
130153
131- attn_value_output = self_attn . _merge_heads (
154+ attn_value_output = merge_heads (
132155 head_attention_value_output , self_attn .num_heads , self_attn .head_dim
133156 )
134157 attn_value_output = DO_INTERVENTION (
@@ -287,7 +310,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
287310 return q_embed , k_embed
288311
289312def Llama_SELF_ATTENTION_RUN (
290- self_attn , hidden_states , i , CACHE_ACTIVATIONS , INTERVENTION_ACTIVATIONS
313+ self_attn ,
314+ hidden_states ,
315+ i ,
316+ CACHE_ACTIVATIONS ,
317+ INTERVENTION_ACTIVATIONS ,
318+ num_heads ,
319+ num_key_value_heads ,
320+ rotary_emb
291321):
292322 bsz , q_len , _ = hidden_states .size ()
293323
@@ -302,9 +332,9 @@ def Llama_SELF_ATTENTION_RUN(
302332 value = DO_INTERVENTION (f"{ i } .value_output" , value , INTERVENTION_ACTIVATIONS )
303333 CACHE_ACTIVATIONS [f"{ i } .value_output" ] = value
304334
305- head_query = query .view (bsz , q_len , self_attn . num_heads , self_attn .head_dim ).transpose (1 , 2 )
306- head_key = key .view (bsz , q_len , self_attn . num_key_value_heads , self_attn .head_dim ).transpose (1 , 2 )
307- head_value = value .view (bsz , q_len , self_attn . num_key_value_heads , self_attn .head_dim ).transpose (1 , 2 )
335+ head_query = query .view (bsz , q_len , num_heads , self_attn .head_dim ).transpose (1 , 2 )
336+ head_key = key .view (bsz , q_len , num_key_value_heads , self_attn .head_dim ).transpose (1 , 2 )
337+ head_value = value .view (bsz , q_len , num_key_value_heads , self_attn .head_dim ).transpose (1 , 2 )
308338
309339 head_query = DO_INTERVENTION (
310340 f"{ i } .head_query_output" , head_query , INTERVENTION_ACTIVATIONS
@@ -320,7 +350,7 @@ def Llama_SELF_ATTENTION_RUN(
320350 CACHE_ACTIVATIONS [f"{ i } .head_value_output" ] = head_value
321351
322352 position_ids = torch .arange (q_len , device = hidden_states .device ).repeat (bsz , 1 )
323- cos , sin = self_attn . rotary_emb (head_value , position_ids )
353+ cos , sin = rotary_emb (head_value , position_ids )
324354 head_query , head_key = apply_rotary_pos_emb (head_query , head_key , cos , sin )
325355
326356 head_key = repeat_kv (head_key , self_attn .num_key_value_groups )
@@ -340,7 +370,7 @@ def Llama_SELF_ATTENTION_RUN(
340370 INTERVENTION_ACTIVATIONS ,
341371 )
342372 CACHE_ACTIVATIONS [f"{ i } .head_attention_value_output" ] = head_attention_value_output
343- attn_value_output = head_attention_value_output .transpose (1 , 2 ).contiguous ().reshape (bsz , q_len , self_attn .hidden_size )
373+ attn_value_output = head_attention_value_output .transpose (1 , 2 ).contiguous ().reshape (bsz , q_len , num_heads * self_attn .head_dim )
344374 attn_value_output = DO_INTERVENTION (
345375 f"{ i } .attention_value_output" , attn_value_output , INTERVENTION_ACTIVATIONS
346376 )
@@ -364,7 +394,14 @@ def Llama_MLP_RUN(mlp, hidden_states, i, CACHE_ACTIVATIONS, INTERVENTION_ACTIVAT
364394 return hidden_states_down_proj
365395
366396def Llama_BLOCK_RUN (
367- block , hidden_states , i , CACHE_ACTIVATIONS , INTERVENTION_ACTIVATIONS
397+ block ,
398+ hidden_states ,
399+ i ,
400+ CACHE_ACTIVATIONS ,
401+ INTERVENTION_ACTIVATIONS ,
402+ num_heads ,
403+ num_key_value_heads ,
404+ rotary_emb
368405):
369406 # self attention + residual
370407 residual = hidden_states
@@ -376,7 +413,14 @@ def Llama_BLOCK_RUN(
376413 CACHE_ACTIVATIONS [f"{ i } .attention_input" ] = hidden_states
377414
378415 attn_outputs = Llama_SELF_ATTENTION_RUN (
379- block .self_attn , hidden_states , i , CACHE_ACTIVATIONS , INTERVENTION_ACTIVATIONS
416+ block .self_attn ,
417+ hidden_states ,
418+ i ,
419+ CACHE_ACTIVATIONS ,
420+ INTERVENTION_ACTIVATIONS ,
421+ num_heads ,
422+ num_key_value_heads ,
423+ rotary_emb
380424 )
381425
382426 attn_outputs = DO_INTERVENTION (
@@ -417,6 +461,9 @@ def Llama_RUN(llama, input_ids, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS):
417461 """
418462 # embed
419463 inputs_embeds = llama .model .embed_tokens (input_ids )
464+ num_heads = llama .model .config .num_attention_heads
465+ num_key_value_heads = llama .model .config .num_key_value_heads
466+ rotary_emb = llama .model .rotary_emb
420467 hidden_states = inputs_embeds
421468 for i , block in enumerate (llama .model .layers ):
422469 hidden_states = DO_INTERVENTION (
@@ -425,7 +472,14 @@ def Llama_RUN(llama, input_ids, CACHE_ACTIVATIONS, INTERVENTION_ACTIVATIONS):
425472 CACHE_ACTIVATIONS [f"{ i } .block_input" ] = hidden_states
426473
427474 hidden_states = Llama_BLOCK_RUN (
428- block , hidden_states , i , CACHE_ACTIVATIONS , INTERVENTION_ACTIVATIONS
475+ block ,
476+ hidden_states ,
477+ i ,
478+ CACHE_ACTIVATIONS ,
479+ INTERVENTION_ACTIVATIONS ,
480+ num_heads ,
481+ num_key_value_heads ,
482+ rotary_emb
429483 )
430484
431485 hidden_states = DO_INTERVENTION (
0 commit comments