-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathact_samples.py
More file actions
38 lines (28 loc) · 1.49 KB
/
act_samples.py
File metadata and controls
38 lines (28 loc) · 1.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from pathlib import Path

import pandas as pd
import torch
from tango import Step
@Step.register('dir_samples')
class DirectionActivationSamples(Step):
    """Tango step that dumps per-sample probe-direction activations to a CSV.

    For every evaluation sample (negative and positive variant), the hidden
    state is projected onto ``probe.direction`` and written out together with
    its relative rank, text, token count, polarity and label.
    """

    VERSION = "002"
    CACHEABLE = False

    def run(self, dataset, hidden_states, probe, csv_name, nr_samples=-1) -> None:
        """Write ``local_outputs/<csv_name>.csv`` with one row per (sample, polarity).

        Rows are ordered: all negative variants first, then all positive ones.

        Args:
            dataset: Re-iterable of sample tuples. Reads ``s[0]['attention_mask']``,
                ``s[2]`` (negative text), ``s[3]`` (positive text) and ``s[4]``
                (label, compared against 0 / 1). Only entries from index
                ``N // 2`` onward are used, where ``N`` is the total number of
                activations. Those entries must line up one-to-one with the
                eval hidden states.
            hidden_states: Mapping whose ``'eval'`` entry's element ``[0]`` is
                projected onto the probe direction.
            probe: Object exposing a ``direction`` tensor.
            csv_name: Output file stem, relative to ``local_outputs/``.
            nr_samples: Currently unused; kept for interface compatibility.

        Raises:
            ValueError: If the selected dataset entries do not match the number
                of activations per polarity. Otherwise the CSV rows would be
                silently misaligned or truncated.
        """
        # NOTE(review): assumes hidden_states['eval'][0] is (2, n, d) stacked as
        # [neg, pos] and probe.direction is (d,) — TODO confirm.
        projections = hidden_states['eval'][0] @ probe.direction
        activations = torch.cat([projections[0], projections[1]])  # neg, pos
        total = activations.shape[0]
        half = total // 2

        # argsort of argsort yields each activation's rank (0 = smallest), so
        # rel_rank lies in [0, 1).
        ranks = activations.argsort().argsort()
        relative = ranks / total

        # Materialize the selected slice once instead of re-iterating the
        # dataset for every column.
        samples = list(dataset)[half:]
        if len(samples) != half:
            raise ValueError(
                f"dataset[{half}:] has {len(samples)} entries but there are "
                f"{total} activations ({half} per polarity); rows would be misaligned"
            )

        sample_texts = [s[2] for s in samples] + [s[3] for s in samples]
        # NOTE(review): the positive half reuses s[0]'s attention mask, as the
        # original code did. If s[1] holds the positive encoding, this should
        # read s[1] for the positive half — confirm against the dataset.
        neg_token_counts = [s[0]['attention_mask'].sum().item() for s in samples]
        token_counts = neg_token_counts + neg_token_counts
        polarities = ['negative'] * half + ['positive'] * half
        # True when this row's polarity is the sample's labelled one
        # (label 0 -> negative, label 1 -> positive).
        labels = [s[4] == 0 for s in samples] + [s[4] == 1 for s in samples]

        df = pd.DataFrame({
            'rel_rank': relative.tolist(),
            'activation': activations.tolist(),
            'sample_text': sample_texts,
            'token_count': token_counts,
            'polarity': polarities,
            'label': labels,
        })

        out_path = Path('local_outputs') / f'{csv_name}.csv'
        # Create the output directory instead of failing when it is missing.
        out_path.parent.mkdir(parents=True, exist_ok=True)
        df.to_csv(out_path)