@@ -90,7 +90,8 @@ def __init__(self, config, model, backend, **kwargs):
9090 # mapping between supported abstract type and module name.
9191 ###
9292 self .representations = {}
93- self .interventions = {}
93+ self .interventions = torch .nn .ModuleDict ({})
94+ self .intervention_hooks = {}
9495 self ._key_collision_counter = {}
9596 self .return_collect_activations = False
9697 # Flags and counters below are for interventions in the model.generate
@@ -116,6 +117,7 @@ def __init__(self, config, model, backend, **kwargs):
116117 config .representations
117118 ):
118119 _key = self ._get_representation_key (representation )
120+ print (f"Intervention key: { _key } " )
119121
120122 if representation .intervention is not None :
121123 intervention = representation .intervention
@@ -164,7 +166,11 @@ def __init__(self, config, model, backend, **kwargs):
164166 model , representation , backend
165167 )
166168 self .representations [_key ] = representation
167- self .interventions [_key ] = (intervention , module_hook )
169+ if isinstance (intervention , types .FunctionType ):
170+ self .interventions [_key ] = LambdaIntervention (intervention )
171+ else :
172+ self .interventions [_key ] = intervention
173+ self .intervention_hooks [_key ] = module_hook
168174 self ._key_getter_call_counter [
169175 _key
170176 ] = 0 # we memo how many the hook is called,
@@ -266,11 +272,13 @@ def _get_representation_key(self, representation):
266272 c = representation .component
267273 u = representation .unit
268274 n = representation .max_number_of_units
275+ _u = u .replace ("." , "_" ) # this will need internal functions to be changed as well.
269276 if "." in c :
277+ _c = c .replace ("." , "_" )
270278 # string access for sure
271- key_proposal = f"comp. { c } .unit. { u } .nunit. { n } "
279+ key_proposal = f"comp_ { _c } _unit_ { _u } _nunit_ { n } "
272280 else :
273- key_proposal = f"layer. { l } .comp. { c } .unit. { u } .nunit. { n } "
281+ key_proposal = f"layer_ { l } _comp_ { c } _unit_ { _u } _nunit_ { n } "
274282 if key_proposal not in self ._key_collision_counter :
275283 self ._key_collision_counter [key_proposal ] = 0
276284 else :
@@ -283,8 +291,8 @@ def get_trainable_parameters(self):
283291 """
284292 ret_params = []
285293 for k , v in self .interventions .items ():
286- if isinstance (v [ 0 ] , TrainableIntervention ):
287- ret_params += [p for p in v [ 0 ] .parameters ()]
294+ if isinstance (v , TrainableIntervention ):
295+ ret_params += [p for p in v .parameters ()]
288296 for p in self .model .parameters ():
289297 if p .requires_grad :
290298 ret_params += [p ]
@@ -296,8 +304,8 @@ def named_parameters(self, recurse=True):
296304 """
297305 ret_params = []
298306 for k , v in self .interventions .items ():
299- if isinstance (v [ 0 ] , TrainableIntervention ):
300- ret_params += [(k + '.' + n , p ) for n , p in v [ 0 ] .named_parameters ()]
307+ if isinstance (v , TrainableIntervention ):
308+ ret_params += [(k + '.' + n , p ) for n , p in v .named_parameters ()]
301309 for n , p in self .model .named_parameters ():
302310 if p .requires_grad :
303311 ret_params += [('model.' + n , p )]
@@ -320,9 +328,9 @@ def set_temperature(self, temp: torch.Tensor):
320328 Set temperature if needed
321329 """
322330 for k , v in self .interventions .items ():
323- if isinstance (v [ 0 ] , BoundlessRotatedSpaceIntervention ) or \
324- isinstance (v [ 0 ] , SigmoidMaskIntervention ):
325- v [ 0 ] .set_temperature (temp )
331+ if isinstance (v , BoundlessRotatedSpaceIntervention ) or \
332+ isinstance (v , SigmoidMaskIntervention ):
333+ v .set_temperature (temp )
326334
327335 def enable_model_gradients (self ):
328336 """
@@ -356,7 +364,7 @@ def set_device(self, device, set_model=True):
356364 Set device of interventions and the model
357365 """
358366 for k , v in self .interventions .items ():
359- v [ 0 ] .to (device )
367+ v .to (device )
360368 if set_model :
361369 self .model .to (device )
362370
@@ -373,13 +381,13 @@ def count_parameters(self, include_model=False):
373381 _linked_key_set = set ([])
374382 total_parameters = 0
375383 for k , v in self .interventions .items ():
376- if isinstance (v [ 0 ] , TrainableIntervention ):
384+ if isinstance (v , TrainableIntervention ):
377385 if k in self ._intervention_reverse_link :
378386 if not self ._intervention_reverse_link [k ] in _linked_key_set :
379387 _linked_key_set .add (self ._intervention_reverse_link [k ])
380- total_parameters += count_parameters (v [ 0 ] )
388+ total_parameters += count_parameters (v )
381389 else :
382- total_parameters += count_parameters (v [ 0 ] )
390+ total_parameters += count_parameters (v )
383391 if include_model :
384392 total_parameters += sum (
385393 p .numel () for p in self .model .parameters () if p .requires_grad )
@@ -390,16 +398,16 @@ def set_zero_grad(self):
390398 Set device of interventions and the model
391399 """
392400 for k , v in self .interventions .items ():
393- if isinstance (v [ 0 ] , TrainableIntervention ):
394- v [ 0 ] .zero_grad ()
401+ if isinstance (v , TrainableIntervention ):
402+ v .zero_grad ()
395403
396404 def zero_grad (self ):
397405 """
398406 The above, but for HuggingFace.
399407 """
400408 for k , v in self .interventions .items ():
401- if isinstance (v [ 0 ] , TrainableIntervention ):
402- v [ 0 ] .zero_grad ()
409+ if isinstance (v , TrainableIntervention ):
410+ v .zero_grad ()
403411
404412 def _input_validation (
405413 self ,
@@ -758,7 +766,8 @@ def _intervention_getter(
758766 """
759767 handlers = []
760768 for key_i , key in enumerate (keys ):
761- intervention , (module_hook , hook_type ) = self .interventions [key ]
769+ intervention = self .interventions [key ]
770+ (module_hook , hook_type ) = self .intervention_hooks [key ]
762771 if self ._is_generation :
763772 raise NotImplementedError ("Generation is not implemented for ndif backend" )
764773
@@ -803,7 +812,8 @@ def _intervention_setter(
803812 self ._tidy_stateful_activations ()
804813
805814 for key_i , key in enumerate (keys ):
806- intervention , (module_hook , hook_type ) = self .interventions [key ]
815+ intervention = self .interventions [key ]
816+ (module_hook , hook_type ) = self .intervention_hooks [key ]
807817 if unit_locations_base [0 ] is not None :
808818 self ._batched_setter_activation_select [key ] = [
809819 0 for _ in range (len (unit_locations_base [0 ]))
@@ -846,7 +856,7 @@ def _intervention_setter(
846856 # no-op to the output
847857
848858 else :
849- if not isinstance (self .interventions [key ][ 0 ], types . FunctionType ):
859+ if not isinstance (self .interventions [key ], LambdaIntervention ):
850860 if intervention .is_source_constant :
851861 intervened_representation = do_intervention (
852862 selected_output ,
@@ -944,8 +954,8 @@ def _sync_forward_with_parallel_intervention(
944954 for key in keys :
945955 # skip in case smart jump
946956 if key in self .activations or \
947- isinstance (self .interventions [key ][ 0 ], types . FunctionType ) or \
948- self .interventions [key ][ 0 ] .is_source_constant :
957+ isinstance (self .interventions [key ], LambdaIntervention ) or \
958+ self .interventions [key ].is_source_constant :
949959 self ._intervention_setter (
950960 [key ],
951961 [
@@ -1056,7 +1066,7 @@ def forward(
10561066 if self .return_collect_activations :
10571067 for key in self .sorted_keys :
10581068 if isinstance (
1059- self .interventions [key ][ 0 ] ,
1069+ self .interventions [key ],
10601070 CollectIntervention
10611071 ):
10621072 collected_activations += self .activations [key ].clone ()
@@ -1191,7 +1201,7 @@ def save(
11911201 serialized_representations
11921202
11931203 for k , v in self .interventions .items ():
1194- intervention = v [ 0 ]
1204+ intervention = v
11951205 saving_config .intervention_types += [str (type (intervention ))]
11961206 binary_filename = f"intkey_{ k } .bin"
11971207 # save intervention binary file
@@ -1288,7 +1298,7 @@ def load(
12881298
12891299 # load binary files
12901300 for i , (k , v ) in enumerate (intervenable .interventions .items ()):
1291- intervention = v [ 0 ]
1301+ intervention = v
12921302 binary_filename = f"intkey_{ k } .bin"
12931303 intervention .is_source_constant = \
12941304 saving_config .intervention_constant_sources [i ]
@@ -1334,7 +1344,7 @@ def save_intervention(self, save_directory, include_model=True):
13341344
13351345 # save binary files
13361346 for k , v in self .interventions .items ():
1337- intervention = v [ 0 ]
1347+ intervention = v
13381348 binary_filename = f"intkey_{ k } .bin"
13391349 # save intervention binary file
13401350 if isinstance (intervention , TrainableIntervention ):
@@ -1357,7 +1367,7 @@ def load_intervention(self, load_directory, include_model=True):
13571367 """
13581368 # load binary files
13591369 for i , (k , v ) in enumerate (self .interventions .items ()):
1360- intervention = v [ 0 ]
1370+ intervention = v
13611371 binary_filename = f"intkey_{ k } .bin"
13621372 if isinstance (intervention , TrainableIntervention ):
13631373 saved_state_dict = torch .load (os .path .join (load_directory , binary_filename ))
@@ -1379,7 +1389,8 @@ def _intervention_getter(
13791389 """
13801390 handlers = []
13811391 for key_i , key in enumerate (keys ):
1382- intervention , module_hook = self .interventions [key ]
1392+ intervention = self .interventions [key ]
1393+ module_hook = self .intervention_hooks [key ]
13831394
13841395 def hook_callback (model , args , kwargs , output = None ):
13851396 if self ._is_generation :
@@ -1524,7 +1535,8 @@ def _intervention_setter(
15241535
15251536 handlers = []
15261537 for key_i , key in enumerate (keys ):
1527- intervention , module_hook = self .interventions [key ]
1538+ intervention = self .interventions [key ]
1539+ module_hook = self .intervention_hooks [key ]
15281540 if unit_locations_base [0 ] is not None :
15291541 self ._batched_setter_activation_select [key ] = [
15301542 0 for _ in range (len (unit_locations_base [0 ]))
@@ -1570,7 +1582,7 @@ def hook_callback(model, args, kwargs, output=None):
15701582 # no-op to the output
15711583
15721584 else :
1573- if not isinstance (self .interventions [key ][ 0 ], types . FunctionType ):
1585+ if not isinstance (self .interventions [key ], LambdaIntervention ):
15741586 if intervention .is_source_constant :
15751587 raw_intervened_representation = do_intervention (
15761588 selected_output ,
@@ -1710,8 +1722,8 @@ def _wait_for_forward_with_parallel_intervention(
17101722 for key in keys :
17111723 # skip in case smart jump
17121724 if key in self .activations or \
1713- isinstance (self .interventions [key ][ 0 ], types . FunctionType ) or \
1714- self .interventions [key ][ 0 ] .is_source_constant :
1725+ isinstance (self .interventions [key ], LambdaIntervention ) or \
1726+ self .interventions [key ].is_source_constant :
17151727 set_handlers = self ._intervention_setter (
17161728 [key ],
17171729 [
@@ -1780,8 +1792,8 @@ def _wait_for_forward_with_serial_intervention(
17801792 for key in keys :
17811793 # skip in case smart jump
17821794 if key in self .activations or \
1783- isinstance (self .interventions [key ][ 0 ], types . FunctionType ) or \
1784- self .interventions [key ][ 0 ] .is_source_constant :
1795+ isinstance (self .interventions [key ], LambdaIntervention ) or \
1796+ self .interventions [key ].is_source_constant :
17851797 # set with intervened activation to source_i+1
17861798 set_handlers = self ._intervention_setter (
17871799 [key ],
@@ -1947,7 +1959,7 @@ def forward(
19471959 if self .return_collect_activations :
19481960 for key in self .sorted_keys :
19491961 if isinstance (
1950- self .interventions [key ][ 0 ] ,
1962+ self .interventions [key ],
19511963 CollectIntervention
19521964 ):
19531965 collected_activations += self .activations [key ]
@@ -2081,7 +2093,7 @@ def generate(
20812093 if self .return_collect_activations :
20822094 for key in self .sorted_keys :
20832095 if isinstance (
2084- self .interventions [key ][ 0 ] ,
2096+ self .interventions [key ],
20852097 CollectIntervention
20862098 ):
20872099 collected_activations += self .activations [key ]