@@ -6550,6 +6550,150 @@ def test_RelativePositionalEncodingLayer():
     print(out)  # random...


+def _build_self_attention_layer(d, input, output, inside_rec_layer, query_axis, num_heads=8, key_dim=64,
+                                value_dim=64, dropout=0.0):
+  """
+  Essentially this does
+    d[output + '_att'] = {"class": "self_attention", "num_heads": num_heads,
+      "total_key_dim": num_heads * key_dim,
+      "n_out": num_heads * value_dim, "from": [input],
+      "attention_left_only": inside_rec_layer,
+      "attention_dropout": dropout, "forward_weights_init": self.ff_init}
+  But using multiple layers.
+  """
+  # Create (non-accumulated) query, key and value
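+  # A single linear projection computes queries, keys and values for all heads at once;
+  # split_dims then separates the heads, and split separates Q, K and V per head.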
+  d[output + '_qkv0'] = {
+    'class': 'linear', 'activation': None, 'with_bias': False, 'from': [input],
+    'n_out': num_heads * (2 * key_dim + value_dim)}  # [B,T?,F|n*(2d_k+d_v)]
+  d[output + '_qkv'] = {
+    'class': 'split_dims', 'axis': 'F', 'dims': (num_heads, 2 * key_dim + value_dim),
+    'from': [output + '_qkv0']}  # [B,T?,n,F|2d_k+d_v]
+  d[output + '_qkv_split'] = {
+    'class': 'split', 'axis': 'F', 'size_splits': (key_dim, key_dim, value_dim),
+    'from': [output + '_qkv']}
+  d[output + '_query'] = {
+    'class': 'copy', 'from': [output + '_qkv_split/0']}  # [B,T?,n,F|d_k]
+  d[output + '_key'] = {
+    'class': 'copy', 'from': [output + '_qkv_split/1']}  # [B,T?,n,F|d_k]
+  d[output + '_value'] = {
+    'class': 'copy', 'from': [output + '_qkv_split/2']}  # [B,T?,n,F|d_v]
+
+  # Accumulate keys/values or rename the axis
+  key_dim_tag = DimensionTag(kind=DimensionTag.Types.Time, description='self-att-keys')
+  key_axis = 'stag:' + key_dim_tag.description
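+  # Inside a rec loop, cum_concat accumulates the keys/values of all steps so far into the new
+  # 'self-att-keys' dim; outside a rec loop, we only re-tag the existing time axis with that dim tag,
+  # so that the query and key time axes are distinguishable in the dot products below.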
+  if inside_rec_layer:
+    d[output + '_key_accum'] = {
+      'class': 'cum_concat', 'from': [output + '_key'], 'new_dim': key_dim_tag}  # [B,T|rec-history,n,F|d_k]
+    d[output + '_value_accum'] = {
+      'class': 'cum_concat', 'from': [output + '_value'], 'new_dim': key_dim_tag}  # [B,T|rec-history,n,F|d_v]
+  else:
+    d[output + '_key_accum'] = {
+      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
+      'from': [output + '_key']}  # [B,T|keys,n,F|d_k]
+    d[output + '_value_accum'] = {
+      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
+      'from': [output + '_value']}  # [B,T|keys,n,F|d_v]
+
+  # Calculate the energies
+  d[output + '_energy'] = {
+    'class': 'dot', 'from': [output + '_query', output + '_key_accum'],
+    'red1': 'static:-1', 'red2': 'static:-1', 'common': ['B', 'static:0']}  # [B,n,T?,T|rec-history]
+
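+  # Normalize over the key axis; energy_factor scales the energies by d_k ** -0.5,
+  # as in standard scaled dot-product attention.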
+  d[output + '_weights'] = {
+    'class': 'softmax_over_spatial', 'from': [output + '_energy'], 'axis': key_axis,
+    'energy_factor': key_dim ** -0.5}  # [B,n,T?,T|rec-history]
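+  # dropout_noise_shape {'*': None} keeps a full (non-broadcast) dropout mask over all axes.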
+  d[output + '_weights_drop'] = {
+    'class': 'dropout', 'dropout_noise_shape': {'*': None}, 'from': [output + '_weights'],
+    'dropout': dropout}  # [B,n,T?,T|rec-history]
+
+  d[output + '_output'] = {
+    'class': 'dot', 'from': [output + '_weights_drop', output + '_value_accum'],
+    'red1': key_axis, 'red2': key_axis, 'common': ['B', query_axis, 'static:0']}  # [B,n,T?,F|d_v]
+  d[output + '_att'] = {
+    'class': 'merge_dims', 'axes': 'static', 'from': [output + '_output']}  # [B,T?,F|n*d_v]
+
+
+def test_CumConcatLayer_self_attention_equal_to_SelfAttentionLayer():
+  n_time = 13
+  num_heads, key_dim, value_dim = 2, 3, 3
+  for inside_rec_layer in [False, True]:
+    with make_scope() as session:
+      print('Testing inside_rec_layer=%s' % inside_rec_layer)
+
+      # build net dict
+      single_layer_net_dict = {
+        "class": "self_attention", "from": "data", "num_heads": num_heads, "total_key_dim": num_heads * key_dim,
+        "n_out": num_heads * value_dim, "attention_left_only": inside_rec_layer, 'is_output_layer': True}  # [B,T,F]
+      if inside_rec_layer:
+        net_dict = {
+          "output": {
+            "class": "rec", "target": "classes",
+            "unit": {
+              "single_layer_att": single_layer_net_dict,  # [B,T,F]
+              "multi_layer_att": None  # [B,T,F], added below.
+            }}}
+        _build_self_attention_layer(
+          net_dict["output"], 'data', 'multi_layer', inside_rec_layer=False, query_axis='stag:extern_data:classes',
+          num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
+        net_dict["output"]["multi_layer_att"]["is_output_layer"] = True
+      else:
+        net_dict = {
+          "single_layer_att": single_layer_net_dict,  # [B,T,F]
+          "multi_layer_att": None  # [B,T,F], added below.
+        }
+        _build_self_attention_layer(
+          net_dict, 'data', 'multi_layer', inside_rec_layer=False, query_axis='stag:extern_data:data',
+          num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
+        net_dict["multi_layer_att"]["is_output_layer"] = True
+
+      config = Config({"debug_print_layer_output_template": True, "debug_add_check_numerics_ops": True})
+      config.update(dict(num_inputs=num_heads * key_dim, num_outputs=num_heads * value_dim))
+      network = TFNetwork(config=config, train_flag=True)
+      network.construct_from_dict(net_dict)
+
+      if inside_rec_layer:
+        single_layer = network.get_layer("output/single_layer_att")
+        multi_layer = network.get_layer("output/multi_layer_att")
+      else:
+        single_layer = network.get_layer("single_layer_att")
+        multi_layer = network.get_layer("multi_layer_att")
+
+      assert_equal(single_layer.output.shape, (None, num_heads * value_dim))
+      assert_equal(multi_layer.output.shape, (None, num_heads * value_dim))
+
+      # set weights equal.
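+      # The multi-layer construction packs Q, K and V per head in the same order as the joint QKV
+      # projection of SelfAttentionLayer, so the same weight matrix can be assigned to both and the
+      # outputs should match exactly.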
+      single_weights = single_layer.params["QKV"]
+      multi_weights = multi_layer.params["W"]
+      assert_equal(single_weights.shape, multi_weights.shape)
+      weights = numpy.random.rand(*single_weights.shape)
+      session.run(tf.assign(single_weights, weights))
+      session.run(tf.assign(multi_weights, weights))
+
+      # fetch/compare outputs
+      from tests.test_TFNetworkLayer import make_feed_dict
+      feed_dict = make_feed_dict(network.extern_data.data.values(), same_time=True, n_time=n_time)
+      single, multi = session.run(
+        [single_layer.output.placeholder, multi_layer.output.placeholder], feed_dict=feed_dict)
+      print('single layer output:')
+      pprint(single)
+      print('multi layer output:')
+      pprint(multi)
+      numpy.testing.assert_almost_equal(single, multi, decimal=5)
+      print('They are equal!')
+
+
+def test_self_attention_optimize_out():
+  num_heads, key_dim, value_dim = 2, 3, 3
+  network = {}
+  _build_self_attention_layer(
+    network, 'data:source', 'att', inside_rec_layer=True, query_axis='stag:extern_data:data',
+    num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
+
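+  # check_reclayer_optimize_out constructs the net both with the layers inside the rec loop and with
+  # them moved out of the loop, and checks that both variants produce the same output.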
+  check_reclayer_optimize_out(
+    {'class': 'copy', 'from': 'att_att', 'n_out': value_dim * num_heads},
+    other_subnet_layers=network)
+
+
 if __name__ == "__main__":
   try:
     better_exchook.install()