diff --git a/tests/test_TFNetworkRecLayer.py b/tests/test_TFNetworkRecLayer.py
index 71bb733e78..6351625244 100644
--- a/tests/test_TFNetworkRecLayer.py
+++ b/tests/test_TFNetworkRecLayer.py
@@ -6837,6 +6837,148 @@ def test_RelativePositionalEncodingLayer():
   print(out)  # random...
 
 
+def _build_self_attention_layer(d, input, output, inside_rec_layer, query_axis, num_heads=8, key_dim=64,
+                                value_dim=64, dropout=0.0):
+  """
+  Essentially this does
+    d[output + '_att'] = {"class": "self_attention", "num_heads": num_heads,
+                          "total_key_dim": num_heads * key_dim,
+                          "n_out": num_heads * value_dim, "from": [input],
+                          "attention_left_only": inside_rec_layer,
+                          "attention_dropout": dropout}
+  but using multiple separate layers.
+  """
+  # Create (non-accumulated) query, key and value
+  d[output + '_qkv0'] = {
+    'class': 'linear', 'activation': None, 'with_bias': False, 'from': [input],
+    'n_out': num_heads * (2 * key_dim + value_dim)}  # [B,T?,F|n*(2d_k+d_v)]
+  d[output + '_qkv'] = {
+    'class': 'split_dims', 'axis': 'F', 'dims': (num_heads, 2 * key_dim + value_dim),
+    'from': [output + '_qkv0']}  # [B,T?,n,F|2d_k+d_v]
+  d[output + '_qkv_split'] = {
+    'class': 'split', 'axis': 'F', 'size_splits': (key_dim, key_dim, value_dim), 'from': [output + '_qkv']}
+  d[output + '_query'] = {'class': 'copy', 'from': [output + '_qkv_split/0']}  # [B,T?,n,F|d_k]
+  d[output + '_key'] = {'class': 'copy', 'from': [output + '_qkv_split/1']}  # [B,T?,n,F|d_k]
+  d[output + '_value'] = {'class': 'copy', 'from': [output + '_qkv_split/2']}  # [B,T?,n,F|d_v]
+
+  # Accumulate keys/values (inside the rec loop), or just rename the axis (outside)
+  key_dim_tag = DimensionTag(kind=DimensionTag.Types.Time, description='self-att-keys')
+  key_axis = 'stag:' + key_dim_tag.description
+  if inside_rec_layer:
+    d[output + '_key_accum'] = {
+      'class': 'cum_concat', 'from': [output + '_key'], 'new_dim': key_dim_tag}  # [B,T|rec-history,n,F|d_k]
+    d[output + '_value_accum'] = {
+      'class': 'cum_concat', 'from': [output + '_value'], 'new_dim': key_dim_tag}  # [B,T|rec-history,n,F|d_v]
+  else:
+    d[output + '_key_accum'] = {
+      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
+      'from': [output + '_key']}  # [B,T|keys,n,F|d_k]
+    d[output + '_value_accum'] = {
+      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
+      'from': [output + '_value']}  # [B,T|keys,n,F|d_v]
+
+  # Calculate the energies
+  d[output + '_energy'] = {
+    'class': 'dot', 'from': [output + '_query', output + '_key_accum'],
+    'red1': 'static:-1', 'red2': 'static:-1',
+    'var1': None if inside_rec_layer else query_axis, 'var2': key_dim_tag}  # [B,n,T?,T|rec-history]
+
+  d[output + '_weights'] = {
+    'class': 'softmax_over_spatial', 'from': [output + '_energy'], 'axis': key_axis,
+    'energy_factor': key_dim ** -0.5}  # [B,n,T?,T|rec-history]
+  d[output + '_weights_drop'] = {
+    'class': 'dropout', 'dropout_noise_shape': {'*': None}, 'from': [output + '_weights'],
+    'dropout': dropout}  # [B,n,T?,T|rec-history]
+
+  d[output + '_output'] = {
+    'class': 'dot', 'from': [output + '_weights_drop', output + '_value_accum'],
+    'red1': key_axis, 'red2': key_axis,
+    'var1': None if inside_rec_layer else query_axis, 'var2': 'static:-1'}  # [B,n,T?,F|d_v]
+  d[output + '_att'] = {'class': 'merge_dims', 'axes': 'static', 'from': [output + '_output']}  # [B,T?,F|n*d_v]
+
+
+def test_CumConcatLayer_self_attention_equal_to_SelfAttentionLayer():
+  n_time = 13
+  num_heads, key_dim, value_dim = 2, 3, 3
+  for inside_rec_layer in [False, True]:
+    with make_scope() as session:
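+      # Build two nets that should be equivalent: a single SelfAttentionLayer
+      # vs. the multi-layer construction from _build_self_attention_layer above.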
+      print('Testing inside_rec_layer=%s' % inside_rec_layer)
+
+      # build net dict
+      if inside_rec_layer:
+        net_dict = {
+          "output": {
+            "class": "rec", "target": "classes", "from": [],
+            "unit": {
+              "single_layer_att": {
+                "class": "self_attention", "from": "prev:single_layer_att", "num_heads": num_heads,
+                "total_key_dim": num_heads * key_dim, "n_out": num_heads * value_dim,
+                "attention_left_only": inside_rec_layer, 'is_output_layer': True},  # [B,T,F]
+              "multi_layer_att": None,  # [B,T,F], added below.
+              "output": {"class": "compare", "from": ["single_layer_att", "multi_layer_att"]}}}}
+        _build_self_attention_layer(
+          net_dict["output"]["unit"], 'prev:multi_layer_att', 'multi_layer', inside_rec_layer=True,
+          query_axis='stag:extern_data:classes', num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
+        net_dict["output"]["unit"]["multi_layer_att"]["is_output_layer"] = True
+        net_dict["output"]["unit"]["multi_layer_qkv0"]["is_output_layer"] = True  # we need to set the matrix here
+      else:
+        net_dict = {
+          "single_layer_att": {
+            "class": "self_attention", "from": "data", "num_heads": num_heads, "total_key_dim": num_heads * key_dim,
+            "n_out": num_heads * value_dim, "attention_left_only": inside_rec_layer,
+            'is_output_layer': True},  # [B,T,F]
+          "multi_layer_att": None,  # [B,T,F], added below.
+          "output": {"class": "compare", "from": ["single_layer_att", "multi_layer_att"]}
+        }
+        _build_self_attention_layer(
+          net_dict, 'data', 'multi_layer', inside_rec_layer=False, query_axis='stag:extern_data:data',
+          num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
+        net_dict["multi_layer_att"]["is_output_layer"] = True
+
+      config = Config({
+        "debug_print_layer_output_template": True, "optimize_move_layers_out": True})
+      config.update(dict(num_inputs=num_heads * key_dim, num_outputs=num_heads * value_dim))
+      network = TFNetwork(config=config, train_flag=True)
+      from pprint import pprint
+      pprint(net_dict)
+      network.construct_from_dict(net_dict)
+
+      if inside_rec_layer:
+        single_layer = network.get_layer("output/single_layer_att")
+        multi_layer = network.get_layer("output/multi_layer_att")
+
+        # Note: single_layer.params etc. do not contain the params, need to access the rec cell directly
+        rec_layer = network.get_layer("output")
+        single_weights = rec_layer.cell.net.get_layer("single_layer_att").params["QKV"]
+        multi_weights = rec_layer.cell.net.get_layer("multi_layer_qkv0").params["W"]
+      else:
+        single_layer = network.get_layer("single_layer_att")
+        multi_layer = network.get_layer("multi_layer_att")
+        single_weights = single_layer.params["QKV"]
+        multi_weights = network.get_layer("multi_layer_qkv0").params["W"]
+
+      assert_equal(single_layer.output.batch_shape, (None, None, num_heads * value_dim))
+      assert_equal(multi_layer.output.batch_shape, (None, None, num_heads * value_dim))
+
+      # set weights equal
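+      # Both variants use one combined QKV projection matrix of the same shape,
+      # so we can copy identical random values into both and expect equal outputs.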
+      assert_equal(single_weights.shape, multi_weights.shape)
+      weights = numpy.random.rand(*single_weights.shape)
+      session.run(tf.compat.v1.assign(single_weights, weights))
+      session.run(tf.compat.v1.assign(multi_weights, weights))
+
+      # fetch/compare outputs
+      from tests.test_TFNetworkLayer import make_feed_dict
+      feed_dict = make_feed_dict(network.extern_data.data.values(), same_time=True, n_time=n_time)
+      single, multi = session.run(
+        [single_layer.output.placeholder, multi_layer.output.placeholder], feed_dict=feed_dict)
+      print('single layer output:')
+      pprint(single)
+      print('multi layer output:')
+      pprint(multi)
+      numpy.testing.assert_almost_equal(single, multi, decimal=5)
+      print('They are equal!')
+
+
 if __name__ == "__main__":
   try:
     better_exchook.install()
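For reference (outside the diff): the computation this test checks can be sketched in plain NumPy. This is a minimal illustrative sketch assuming the combined-QKV layout and the left-only (causal) masking described in the layer comments above; self_attention_ref and all of its argument names are hypothetical helpers, not RETURNN API, and dropout is omitted.

import numpy

def self_attention_ref(x, qkv, num_heads, key_dim, value_dim, left_only=False):
  # Hypothetical reference sketch, not RETURNN API.
  # x: [T,F_in], qkv: [F_in, num_heads * (2 * key_dim + value_dim)] -> [T, num_heads * value_dim]
  n_time = x.shape[0]
  combined = (x @ qkv).reshape(n_time, num_heads, 2 * key_dim + value_dim)  # [T,n,2*d_k+d_v]
  q = combined[:, :, :key_dim]             # queries  [T,n,d_k]
  k = combined[:, :, key_dim:2 * key_dim]  # keys     [T,n,d_k]
  v = combined[:, :, 2 * key_dim:]         # values   [T,n,d_v]
  # energies, scaled like energy_factor=key_dim ** -0.5 above
  energy = numpy.einsum('inc,jnc->nij', q, k) * key_dim ** -0.5  # [n,T(query),T(key)]
  if left_only:
    # attention_left_only: each query may only attend to itself and earlier positions
    future = numpy.triu(numpy.ones((n_time, n_time), dtype=bool), k=1)
    energy = numpy.where(future[None], -numpy.inf, energy)
  weights = numpy.exp(energy - energy.max(axis=-1, keepdims=True))
  weights /= weights.sum(axis=-1, keepdims=True)  # softmax over the key axis
  out = numpy.einsum('nij,jnv->inv', weights, v)  # [T,n,d_v]
  return out.reshape(n_time, num_heads * value_dim)  # [T,n*d_v]

Given the same QKV matrix, this per-sequence quantity is what both the single SelfAttentionLayer and the multi-layer construction should produce, which is what assert_almost_equal verifies above.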