Skip to content

Commit 781cd02

Browse files
authored
Merge pull request #204 from stanfordnlp/zen/fsdp
[P0] Initiate the support of FSDP training (#205)
2 parents 4be6f6e + 262ed44 commit 781cd02

16 files changed

+107
-85
lines changed

pyvene/models/intervenable_base.py

Lines changed: 50 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ def __init__(self, config, model, backend, **kwargs):
9090
# mapping between supported abstract type and module name.
9191
###
9292
self.representations = {}
93-
self.interventions = {}
93+
self.interventions = torch.nn.ModuleDict({})
94+
self.intervention_hooks = {}
9495
self._key_collision_counter = {}
9596
self.return_collect_activations = False
9697
# Flags and counters below are for interventions in the model.generate
@@ -116,6 +117,7 @@ def __init__(self, config, model, backend, **kwargs):
116117
config.representations
117118
):
118119
_key = self._get_representation_key(representation)
120+
print(f"Intervention key: {_key}")
119121

120122
if representation.intervention is not None:
121123
intervention = representation.intervention
@@ -164,7 +166,11 @@ def __init__(self, config, model, backend, **kwargs):
164166
model, representation, backend
165167
)
166168
self.representations[_key] = representation
167-
self.interventions[_key] = (intervention, module_hook)
169+
if isinstance(intervention, types.FunctionType):
170+
self.interventions[_key] = LambdaIntervention(intervention)
171+
else:
172+
self.interventions[_key] = intervention
173+
self.intervention_hooks[_key] = module_hook
168174
self._key_getter_call_counter[
169175
_key
170176
] = 0 # we memo how many the hook is called,
@@ -266,11 +272,13 @@ def _get_representation_key(self, representation):
266272
c = representation.component
267273
u = representation.unit
268274
n = representation.max_number_of_units
275+
_u = u.replace(".", "_") # this will need internal functions to be changed as well.
269276
if "." in c:
277+
_c = c.replace(".", "_")
270278
# string access for sure
271-
key_proposal = f"comp.{c}.unit.{u}.nunit.{n}"
279+
key_proposal = f"comp_{_c}_unit_{_u}_nunit_{n}"
272280
else:
273-
key_proposal = f"layer.{l}.comp.{c}.unit.{u}.nunit.{n}"
281+
key_proposal = f"layer_{l}_comp_{c}_unit_{_u}_nunit_{n}"
274282
if key_proposal not in self._key_collision_counter:
275283
self._key_collision_counter[key_proposal] = 0
276284
else:
@@ -283,8 +291,8 @@ def get_trainable_parameters(self):
283291
"""
284292
ret_params = []
285293
for k, v in self.interventions.items():
286-
if isinstance(v[0], TrainableIntervention):
287-
ret_params += [p for p in v[0].parameters()]
294+
if isinstance(v, TrainableIntervention):
295+
ret_params += [p for p in v.parameters()]
288296
for p in self.model.parameters():
289297
if p.requires_grad:
290298
ret_params += [p]
@@ -296,8 +304,8 @@ def named_parameters(self, recurse=True):
296304
"""
297305
ret_params = []
298306
for k, v in self.interventions.items():
299-
if isinstance(v[0], TrainableIntervention):
300-
ret_params += [(k + '.' + n, p) for n, p in v[0].named_parameters()]
307+
if isinstance(v, TrainableIntervention):
308+
ret_params += [(k + '.' + n, p) for n, p in v.named_parameters()]
301309
for n, p in self.model.named_parameters():
302310
if p.requires_grad:
303311
ret_params += [('model.' + n, p)]
@@ -320,9 +328,9 @@ def set_temperature(self, temp: torch.Tensor):
320328
Set temperature if needed
321329
"""
322330
for k, v in self.interventions.items():
323-
if isinstance(v[0], BoundlessRotatedSpaceIntervention) or \
324-
isinstance(v[0], SigmoidMaskIntervention):
325-
v[0].set_temperature(temp)
331+
if isinstance(v, BoundlessRotatedSpaceIntervention) or \
332+
isinstance(v, SigmoidMaskIntervention):
333+
v.set_temperature(temp)
326334

327335
def enable_model_gradients(self):
328336
"""
@@ -356,7 +364,7 @@ def set_device(self, device, set_model=True):
356364
Set device of interventions and the model
357365
"""
358366
for k, v in self.interventions.items():
359-
v[0].to(device)
367+
v.to(device)
360368
if set_model:
361369
self.model.to(device)
362370

@@ -373,13 +381,13 @@ def count_parameters(self, include_model=False):
373381
_linked_key_set = set([])
374382
total_parameters = 0
375383
for k, v in self.interventions.items():
376-
if isinstance(v[0], TrainableIntervention):
384+
if isinstance(v, TrainableIntervention):
377385
if k in self._intervention_reverse_link:
378386
if not self._intervention_reverse_link[k] in _linked_key_set:
379387
_linked_key_set.add(self._intervention_reverse_link[k])
380-
total_parameters += count_parameters(v[0])
388+
total_parameters += count_parameters(v)
381389
else:
382-
total_parameters += count_parameters(v[0])
390+
total_parameters += count_parameters(v)
383391
if include_model:
384392
total_parameters += sum(
385393
p.numel() for p in self.model.parameters() if p.requires_grad)
@@ -390,16 +398,16 @@ def set_zero_grad(self):
390398
Set device of interventions and the model
391399
"""
392400
for k, v in self.interventions.items():
393-
if isinstance(v[0], TrainableIntervention):
394-
v[0].zero_grad()
401+
if isinstance(v, TrainableIntervention):
402+
v.zero_grad()
395403

396404
def zero_grad(self):
397405
"""
398406
The above, but for HuggingFace.
399407
"""
400408
for k, v in self.interventions.items():
401-
if isinstance(v[0], TrainableIntervention):
402-
v[0].zero_grad()
409+
if isinstance(v, TrainableIntervention):
410+
v.zero_grad()
403411

404412
def _input_validation(
405413
self,
@@ -758,7 +766,8 @@ def _intervention_getter(
758766
"""
759767
handlers = []
760768
for key_i, key in enumerate(keys):
761-
intervention, (module_hook, hook_type) = self.interventions[key]
769+
intervention = self.interventions[key]
770+
(module_hook, hook_type) = self.intervention_hooks[key]
762771
if self._is_generation:
763772
raise NotImplementedError("Generation is not implemented for ndif backend")
764773

@@ -803,7 +812,8 @@ def _intervention_setter(
803812
self._tidy_stateful_activations()
804813

805814
for key_i, key in enumerate(keys):
806-
intervention, (module_hook, hook_type) = self.interventions[key]
815+
intervention = self.interventions[key]
816+
(module_hook, hook_type) = self.intervention_hooks[key]
807817
if unit_locations_base[0] is not None:
808818
self._batched_setter_activation_select[key] = [
809819
0 for _ in range(len(unit_locations_base[0]))
@@ -846,7 +856,7 @@ def _intervention_setter(
846856
# no-op to the output
847857

848858
else:
849-
if not isinstance(self.interventions[key][0], types.FunctionType):
859+
if not isinstance(self.interventions[key], LambdaIntervention):
850860
if intervention.is_source_constant:
851861
intervened_representation = do_intervention(
852862
selected_output,
@@ -944,8 +954,8 @@ def _sync_forward_with_parallel_intervention(
944954
for key in keys:
945955
# skip in case smart jump
946956
if key in self.activations or \
947-
isinstance(self.interventions[key][0], types.FunctionType) or \
948-
self.interventions[key][0].is_source_constant:
957+
isinstance(self.interventions[key], LambdaIntervention) or \
958+
self.interventions[key].is_source_constant:
949959
self._intervention_setter(
950960
[key],
951961
[
@@ -1056,7 +1066,7 @@ def forward(
10561066
if self.return_collect_activations:
10571067
for key in self.sorted_keys:
10581068
if isinstance(
1059-
self.interventions[key][0],
1069+
self.interventions[key],
10601070
CollectIntervention
10611071
):
10621072
collected_activations += self.activations[key].clone()
@@ -1191,7 +1201,7 @@ def save(
11911201
serialized_representations
11921202

11931203
for k, v in self.interventions.items():
1194-
intervention = v[0]
1204+
intervention = v
11951205
saving_config.intervention_types += [str(type(intervention))]
11961206
binary_filename = f"intkey_{k}.bin"
11971207
# save intervention binary file
@@ -1288,7 +1298,7 @@ def load(
12881298

12891299
# load binary files
12901300
for i, (k, v) in enumerate(intervenable.interventions.items()):
1291-
intervention = v[0]
1301+
intervention = v
12921302
binary_filename = f"intkey_{k}.bin"
12931303
intervention.is_source_constant = \
12941304
saving_config.intervention_constant_sources[i]
@@ -1334,7 +1344,7 @@ def save_intervention(self, save_directory, include_model=True):
13341344

13351345
# save binary files
13361346
for k, v in self.interventions.items():
1337-
intervention = v[0]
1347+
intervention = v
13381348
binary_filename = f"intkey_{k}.bin"
13391349
# save intervention binary file
13401350
if isinstance(intervention, TrainableIntervention):
@@ -1357,7 +1367,7 @@ def load_intervention(self, load_directory, include_model=True):
13571367
"""
13581368
# load binary files
13591369
for i, (k, v) in enumerate(self.interventions.items()):
1360-
intervention = v[0]
1370+
intervention = v
13611371
binary_filename = f"intkey_{k}.bin"
13621372
if isinstance(intervention, TrainableIntervention):
13631373
saved_state_dict = torch.load(os.path.join(load_directory, binary_filename))
@@ -1379,7 +1389,8 @@ def _intervention_getter(
13791389
"""
13801390
handlers = []
13811391
for key_i, key in enumerate(keys):
1382-
intervention, module_hook = self.interventions[key]
1392+
intervention = self.interventions[key]
1393+
module_hook = self.intervention_hooks[key]
13831394

13841395
def hook_callback(model, args, kwargs, output=None):
13851396
if self._is_generation:
@@ -1524,7 +1535,8 @@ def _intervention_setter(
15241535

15251536
handlers = []
15261537
for key_i, key in enumerate(keys):
1527-
intervention, module_hook = self.interventions[key]
1538+
intervention = self.interventions[key]
1539+
module_hook = self.intervention_hooks[key]
15281540
if unit_locations_base[0] is not None:
15291541
self._batched_setter_activation_select[key] = [
15301542
0 for _ in range(len(unit_locations_base[0]))
@@ -1570,7 +1582,7 @@ def hook_callback(model, args, kwargs, output=None):
15701582
# no-op to the output
15711583

15721584
else:
1573-
if not isinstance(self.interventions[key][0], types.FunctionType):
1585+
if not isinstance(self.interventions[key], LambdaIntervention):
15741586
if intervention.is_source_constant:
15751587
raw_intervened_representation = do_intervention(
15761588
selected_output,
@@ -1710,8 +1722,8 @@ def _wait_for_forward_with_parallel_intervention(
17101722
for key in keys:
17111723
# skip in case smart jump
17121724
if key in self.activations or \
1713-
isinstance(self.interventions[key][0], types.FunctionType) or \
1714-
self.interventions[key][0].is_source_constant:
1725+
isinstance(self.interventions[key], LambdaIntervention) or \
1726+
self.interventions[key].is_source_constant:
17151727
set_handlers = self._intervention_setter(
17161728
[key],
17171729
[
@@ -1780,8 +1792,8 @@ def _wait_for_forward_with_serial_intervention(
17801792
for key in keys:
17811793
# skip in case smart jump
17821794
if key in self.activations or \
1783-
isinstance(self.interventions[key][0], types.FunctionType) or \
1784-
self.interventions[key][0].is_source_constant:
1795+
isinstance(self.interventions[key], LambdaIntervention) or \
1796+
self.interventions[key].is_source_constant:
17851797
# set with intervened activation to source_i+1
17861798
set_handlers = self._intervention_setter(
17871799
[key],
@@ -1947,7 +1959,7 @@ def forward(
19471959
if self.return_collect_activations:
19481960
for key in self.sorted_keys:
19491961
if isinstance(
1950-
self.interventions[key][0],
1962+
self.interventions[key],
19511963
CollectIntervention
19521964
):
19531965
collected_activations += self.activations[key]
@@ -2081,7 +2093,7 @@ def generate(
20812093
if self.return_collect_activations:
20822094
for key in self.sorted_keys:
20832095
if isinstance(
2084-
self.interventions[key][0],
2096+
self.interventions[key],
20852097
CollectIntervention
20862098
):
20872099
collected_activations += self.activations[key]

pyvene/models/modeling_utils.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,21 @@
66
from .constants import *
77

88

9+
class LambdaIntervention(torch.nn.Module):
    """Adapter that exposes an arbitrary Python callable as an ``nn.Module``.

    This is purely a functional wrapper: the callable (typically a lambda)
    is stored and invoked verbatim on ``forward``. Tensors the callable
    closes over are *not* registered as parameters or buffers.
    """

    def __init__(self, func):
        super().__init__()
        # Hold a plain reference to the callable; it is applied unchanged
        # in forward(), so its signature dictates what forward() accepts.
        self.func = func

    def forward(self, *args, **kwargs):
        # Delegate directly to the wrapped callable.
        return self.func(*args, **kwargs)
22+
23+
924
def get_internal_model_type(model):
1025
"""Return the model type."""
1126
return type(model)
@@ -435,7 +450,7 @@ def do_intervention(
435450
):
436451
"""Do the actual intervention."""
437452

438-
if isinstance(intervention, types.FunctionType):
453+
if isinstance(intervention, LambdaIntervention):
439454
if subspaces is None:
440455
return intervention(base_representation, source_representation)
441456
else:

pyvene_101.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2715,7 +2715,7 @@
27152715
"# zero-out grads\n",
27162716
"_ = pv_gpt2.model.eval()\n",
27172717
"for k, v in pv_gpt2.interventions.items():\n",
2718-
" v[0].zero_grad()\n",
2718+
" v.zero_grad()\n",
27192719
"\n",
27202720
"original_outputs, counterfactual_outputs = pv_gpt2(\n",
27212721
" base, \n",

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
setup(
1212
name="pyvene",
13-
version="0.1.6",
13+
version="0.1.7dev",
1414
description="Use Activation Intervention to Interpret Causal Mechanism of Model",
1515
long_description=long_description,
1616
long_description_content_type='text/markdown',

tests/integration_tests/ComplexInterventionWithGPT2TestCase.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,8 @@ def _test_subspace_partition_in_forward(self, intervention_type):
119119
RotatedSpaceIntervention,
120120
LowRankRotatedSpaceIntervention,
121121
}:
122-
list(fast.interventions.values())[0][
123-
0
124-
].rotate_layer.weight = list(intervenable.interventions.values())[0][
125-
0
126-
].rotate_layer.weight
122+
list(fast.interventions.values())[0].rotate_layer.weight = \
123+
list(intervenable.interventions.values())[0].rotate_layer.weight
127124

128125
_, without_partition_our_output = fast(
129126
base,

tests/integration_tests/InterventionWithGPT2TestCase.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def _test_with_head_position_intervention(
8989
intervention_types=intervention_type,
9090
)
9191
intervenable = IntervenableModel(config, self.gpt2)
92-
intervention = list(intervenable.interventions.values())[0][0]
92+
intervention = list(intervenable.interventions.values())[0]
9393

9494
base_activations = {}
9595
source_activations = {}

tests/integration_tests/InterventionWithLlamaTestCase.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def _test_with_head_position_intervention(
8888
intervention_types=intervention_type,
8989
)
9090
intervenable = IntervenableModel(config, self.llama)
91-
intervention = list(intervenable.interventions.values())[0][0]
91+
intervention = list(intervenable.interventions.values())[0]
9292

9393
base_activations = {}
9494
source_activations = {}

tutorials/advanced_tutorials/Boundless_DAS.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -422,8 +422,8 @@
422422
"warm_up_steps = 0.1 * t_total\n",
423423
"optimizer_params = []\n",
424424
"for k, v in intervenable.interventions.items():\n",
425-
" optimizer_params += [{\"params\": v[0].rotate_layer.parameters()}]\n",
426-
" optimizer_params += [{\"params\": v[0].intervention_boundaries, \"lr\": 1e-2}]\n",
425+
" optimizer_params += [{\"params\": v.rotate_layer.parameters()}]\n",
426+
" optimizer_params += [{\"params\": v.intervention_boundaries, \"lr\": 1e-2}]\n",
427427
"optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)\n",
428428
"scheduler = get_linear_schedule_with_warmup(\n",
429429
" optimizer, num_warmup_steps=warm_up_steps, num_training_steps=t_total\n",
@@ -470,7 +470,7 @@
470470
" loss = loss_fct(shift_logits, shift_labels)\n",
471471
"\n",
472472
" for k, v in intervenable.interventions.items():\n",
473-
" boundary_loss = 1.0 * v[0].intervention_boundaries.sum()\n",
473+
" boundary_loss = 1.0 * v.intervention_boundaries.sum()\n",
474474
" loss += boundary_loss\n",
475475
"\n",
476476
" return loss"

tutorials/advanced_tutorials/DAS_Main_Introduction.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1438,7 +1438,7 @@
14381438
"t_total = int(len(dataset) * epochs)\n",
14391439
"optimizer_params = []\n",
14401440
"for k, v in intervenable.interventions.items():\n",
1441-
" optimizer_params += [{\"params\": v[0].rotate_layer.parameters()}]\n",
1441+
" optimizer_params += [{\"params\": v.rotate_layer.parameters()}]\n",
14421442
" break\n",
14431443
"optimizer = torch.optim.Adam(optimizer_params, lr=0.001)\n",
14441444
"\n",

0 commit comments

Comments (0)