@@ -6883,6 +6883,148 @@ def test_RelativePositionalEncodingLayer():
   print(out)  # random...
 
 
+def _build_self_attention_layer(d, input, output, inside_rec_layer, query_axis, num_heads=8, key_dim=64,
+                                value_dim=64, dropout=0.0):
+  """
+  Essentially this builds the equivalent of
+    d[output + '_att'] = {"class": "self_attention", "num_heads": num_heads,
+      "total_key_dim": num_heads * key_dim,
+      "n_out": num_heads * value_dim, "from": [input],
+      "attention_left_only": inside_rec_layer,
+      "attention_dropout": dropout}
+  but using multiple primitive layers.
+  """
+  # Create (non-accumulated) query, key and value
+  d[output + '_qkv0'] = {
+    'class': 'linear', 'activation': None, 'with_bias': False, 'from': [input],
+    'n_out': num_heads * (2 * key_dim + value_dim)}  # [B,T?,F|n*(2d_k+d_v)]
+  d[output + '_qkv'] = {
+    'class': 'split_dims', 'axis': 'F', 'dims': (num_heads, 2 * key_dim + value_dim),
+    'from': [output + '_qkv0']}  # [B,T?,n,F|2d_k+d_v]
+  d[output + '_qkv_split'] = {
+    'class': 'split', 'axis': 'F', 'size_splits': (key_dim, key_dim, value_dim), 'from': [output + '_qkv']}
+  d[output + '_query'] = {'class': 'copy', 'from': [output + '_qkv_split/0']}  # [B,T?,n,F|d_k]
+  d[output + '_key'] = {'class': 'copy', 'from': [output + '_qkv_split/1']}  # [B,T?,n,F|d_k]
+  d[output + '_value'] = {'class': 'copy', 'from': [output + '_qkv_split/2']}  # [B,T?,n,F|d_v]
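+  # Shape notation in the comments: B=batch, n=num_heads, d_k=key_dim, d_v=value_dim;
+  # "T?" marks the time axis, which is only present outside a rec loop (inside the loop these are per-step tensors).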
+
+  # Accumulate keys/values or rename the axis
+  key_dim_tag = DimensionTag(kind=DimensionTag.Types.Time, description='self-att-keys')
+  key_axis = 'stag:' + key_dim_tag.description
+  if inside_rec_layer:
+    d[output + '_key_accum'] = {
+      'class': 'cum_concat', 'from': [output + '_key'], 'new_dim': key_dim_tag}  # [B,T|rec-history,n,F|d_k]
+    d[output + '_value_accum'] = {
+      'class': 'cum_concat', 'from': [output + '_value'], 'new_dim': key_dim_tag}  # [B,T|rec-history,n,F|d_v]
+  else:
+    d[output + '_key_accum'] = {
+      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
+      'from': [output + '_key']}  # [B,T|keys,n,F|d_k]
+    d[output + '_value_accum'] = {
+      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
+      'from': [output + '_value']}  # [B,T|keys,n,F|d_v]
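+  # In both cases the keys/values end up carrying the same dedicated 'self-att-keys' dim tag,
+  # so the dot/softmax layers below can address that axis uniformly, inside and outside the rec loop.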
+
+  # Calculate the energies
+  d[output + '_energy'] = {
+    'class': 'dot', 'from': [output + '_query', output + '_key_accum'],
+    'red1': 'static:-1', 'red2': 'static:-1',
+    'var1': None if inside_rec_layer else query_axis, 'var2': key_dim_tag}  # [B,n,T?,T|rec-history]
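+  # energy[b,n,q,k] = sum_d query[b,q,n,d] * key[b,k,n,d]; the 1/sqrt(d_k) scaling is applied
+  # via 'energy_factor' in the softmax layer below.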
+
+  d[output + '_weights'] = {
+    'class': 'softmax_over_spatial', 'from': [output + '_energy'], 'axis': key_axis,
+    'energy_factor': key_dim ** -0.5}  # [B,n,T?,T|rec-history]
+  d[output + '_weights_drop'] = {
+    'class': 'dropout', 'dropout_noise_shape': {'*': None}, 'from': [output + '_weights'],
+    'dropout': dropout}  # [B,n,T?,T|rec-history]
+
+  d[output + '_output'] = {
+    'class': 'dot', 'from': [output + '_weights_drop', output + '_value_accum'],
+    'red1': key_axis, 'red2': key_axis,
+    "var1": None if inside_rec_layer else query_axis, "var2": "static:-1"}  # [B,n,T?,F|d_v]
+  d[output + '_att'] = {'class': 'merge_dims', 'axes': 'static', 'from': [output + '_output']}  # [B,T?,F|n*d_v]
+
+
+def test_CumConcatLayer_self_attention_equal_to_SelfAttentionLayer():
+  n_time = 13
+  num_heads, key_dim, value_dim = 2, 3, 3
+  for inside_rec_layer in [False, True]:
+    with make_scope() as session:
+      print('Testing inside_rec_layer=%s' % inside_rec_layer)
+
+      # build net dict
+      if inside_rec_layer:
+        net_dict = {
+          "output": {
+            "class": "rec", "target": "classes", "from": [],
+            "unit": {
+              "single_layer_att": {
+                "class": "self_attention", "from": "prev:single_layer_att", "num_heads": num_heads,
+                "total_key_dim": num_heads * key_dim, "n_out": num_heads * value_dim,
+                "attention_left_only": inside_rec_layer, 'is_output_layer': True},  # [B,T,F]
+              "multi_layer_att": None,  # [B,T,F], added below.
+              "output": {"class": "compare", "from": ["single_layer_att", "multi_layer_att"]}}}}
+        _build_self_attention_layer(
+          net_dict["output"]["unit"], 'prev:multi_layer_att', 'multi_layer', inside_rec_layer=True,
+          query_axis='stag:extern_data:classes', num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
+        net_dict["output"]["unit"]["multi_layer_att"]["is_output_layer"] = True
6969+ net_dict ["output" ]["unit" ]["multi_layer_qkv0" ]["is_output_layer" ] = True # we need to set the matrix here
+      else:
+        net_dict = {
+          "single_layer_att": {
+            "class": "self_attention", "from": "data", "num_heads": num_heads, "total_key_dim": num_heads * key_dim,
+            "n_out": num_heads * value_dim, "attention_left_only": inside_rec_layer,
+            'is_output_layer': True},  # [B,T,F]
+          "multi_layer_att": None,  # [B,T,F], added below.
+          "output": {"class": "compare", "from": ["single_layer_att", "multi_layer_att"]}
+        }
+        _build_self_attention_layer(
+          net_dict, 'data', 'multi_layer', inside_rec_layer=False, query_axis='stag:extern_data:data',
+          num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
+        net_dict["multi_layer_att"]["is_output_layer"] = True
+
+      config = Config({
+        "debug_print_layer_output_template": True, "optimize_move_layers_out": True})
+      config.update(dict(num_inputs=num_heads * key_dim, num_outputs=num_heads * value_dim))
+      network = TFNetwork(config=config, train_flag=True)
+      from pprint import pprint
+      pprint(net_dict)
+      network.construct_from_dict(net_dict)
+
+      if inside_rec_layer:
+        single_layer = network.get_layer("output/single_layer_att")
+        multi_layer = network.get_layer("output/multi_layer_att")
+
+        # Note: single_layer.params etc. do not contain the params, need to access the rec cell directly
+        rec_layer = network.get_layer("output")
+        single_weights = rec_layer.cell.net.get_layer("single_layer_att").params["QKV"]
+        multi_weights = rec_layer.cell.net.get_layer("multi_layer_qkv0").params["W"]
+      else:
+        single_layer = network.get_layer("single_layer_att")
+        multi_layer = network.get_layer("multi_layer_att")
+        single_weights = single_layer.params["QKV"]
+        multi_weights = network.get_layer("multi_layer_qkv0").params["W"]
+
+      assert_equal(single_layer.output.batch_shape, (None, None, num_heads * value_dim))
+      assert_equal(multi_layer.output.batch_shape, (None, None, num_heads * value_dim))
+
+      # set weights equal.
+      assert_equal(single_weights.shape, multi_weights.shape)
+      weights = numpy.random.rand(*single_weights.shape)
+      session.run(tf.compat.v1.assign(single_weights, weights))
+      session.run(tf.compat.v1.assign(multi_weights, weights))
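+      # With identical projection weights, both constructions should produce numerically identical
+      # outputs (up to float precision), which assert_almost_equal checks below.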
+
+      # fetch/compare outputs
+      from tests.test_TFNetworkLayer import make_feed_dict
+      feed_dict = make_feed_dict(network.extern_data.data.values(), same_time=True, n_time=n_time)
+      single, multi = session.run(
+        [single_layer.output.placeholder, multi_layer.output.placeholder], feed_dict=feed_dict)
+      print('single layer output:')
+      pprint(single)
+      print('multi layer output:')
+      pprint(multi)
+      numpy.testing.assert_almost_equal(single, multi, decimal=5)
+      print('They are equal!')
+
+
 if __name__ == "__main__":
   try:
     better_exchook.install()