
Commit 9e088cd

Tests for self attention using CumConcatLayer (#590)
1 parent ac3a9e2

File tree

1 file changed: +142 −0 lines


tests/test_TFNetworkRecLayer.py

Lines changed: 142 additions & 0 deletions
@@ -6883,6 +6883,148 @@ def test_RelativePositionalEncodingLayer():
  print(out)  # random...


def _build_self_attention_layer(d, input, output, inside_rec_layer, query_axis, num_heads=8, key_dim=64,
                                value_dim=64, dropout=0.0):
  """
  Essentially this does
    d[output + '_att'] = {"class": "self_attention", "num_heads": num_heads,
      "total_key_dim": num_heads * key_dim,
      "n_out": num_heads * value_dim, "from": [input],
      "attention_left_only": inside_rec_layer,
      "attention_dropout": dropout, "forward_weights_init": self.ff_init}
  But using multiple layers.
  """
  # Create (non-accumulated) query, key and value
  d[output + '_qkv0'] = {
    'class': 'linear', 'activation': None, 'with_bias': False, 'from': [input],
    'n_out': num_heads * (2 * key_dim + value_dim)}  # [B,T?,F|n*(2d_k+d_v)]
  d[output + '_qkv'] = {
    'class': 'split_dims', 'axis': 'F', 'dims': (num_heads, 2 * key_dim + value_dim),
    'from': [output + '_qkv0']}  # [B,T?,n,F|2d_k+d_v]
  d[output + '_qkv_split'] = {
    'class': 'split', 'axis': 'F', 'size_splits': (key_dim, key_dim, value_dim), 'from': [output + '_qkv']}
  d[output + '_query'] = {'class': 'copy', 'from': [output + '_qkv_split/0']}  # [B,T?,n,F|d_k]
  d[output + '_key'] = {'class': 'copy', 'from': [output + '_qkv_split/1']}  # [B,T?,n,F|d_k]
  d[output + '_value'] = {'class': 'copy', 'from': [output + '_qkv_split/2']}  # [B,T?,n,F|d_v]

  # Accumulate keys/values or rename the axis
  key_dim_tag = DimensionTag(kind=DimensionTag.Types.Time, description='self-att-keys')
  key_axis = 'stag:' + key_dim_tag.description
  if inside_rec_layer:
    d[output + '_key_accum'] = {
      'class': 'cum_concat', 'from': [output + '_key'], 'new_dim': key_dim_tag}  # [B,T|rec-history,n,F|d_k]
    d[output + '_value_accum'] = {
      'class': 'cum_concat', 'from': [output + '_value'], 'new_dim': key_dim_tag}  # [B,T|rec-history,n,F|d_v]
  else:
    d[output + '_key_accum'] = {
      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
      'from': [output + '_key']}  # [B,T|keys,n,F|d_k]
    d[output + '_value_accum'] = {
      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
      'from': [output + '_value']}  # [B,T|keys,n,F|d_v]

  # Calculate the energies
  d[output + '_energy'] = {
    'class': 'dot', 'from': [output + '_query', output + '_key_accum'],
    'red1': 'static:-1', 'red2': 'static:-1',
    'var1': None if inside_rec_layer else query_axis, 'var2': key_dim_tag}  # [B,n,T?,T|rec-history]

  d[output + '_weights'] = {
    'class': 'softmax_over_spatial', 'from': [output + '_energy'], 'axis': key_axis,
    'energy_factor': key_dim ** -0.5}  # [B,n,T?,T|rec-history]
  d[output + '_weights_drop'] = {
    'class': 'dropout', 'dropout_noise_shape': {'*': None}, 'from': [output + '_weights'],
    'dropout': dropout}  # [B,n,T?,T|rec-history]

  d[output + '_output'] = {
    'class': 'dot', 'from': [output + '_weights_drop', output + '_value_accum'],
    'red1': key_axis, 'red2': key_axis,
    "var1": None if inside_rec_layer else query_axis, "var2": "static:-1"}  # [B,n,T?,F|d_v]
  d[output + '_att'] = {'class': 'merge_dims', 'axes': 'static', 'from': [output + '_output']}  # [B,T?,F|n*d_v]

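For orientation, here is a minimal NumPy sketch of the per-sequence computation that the layer graph built by _build_self_attention_layer describes: a joint QKV projection, per-head scaled dot-product energies, a softmax over the key axis, and a weighted sum over the values. This sketch is not part of the committed test file; the function name, shapes, and the dense causal mask (standing in for attention_left_only / CumConcatLayer inside the rec loop) are illustrative assumptions. The dropout step of the graph is omitted, matching the test's dropout=0.0 default.

    import numpy


    def reference_self_attention(x, weights, num_heads, key_dim, value_dim, causal=False):
      """Plain multi-head scaled dot-product self attention for one sequence x of shape [T, n_in]."""
      time = x.shape[0]
      qkv = x.dot(weights)  # joint projection, like the '_qkv0' linear layer: [T, n*(2*d_k+d_v)]
      qkv = qkv.reshape(time, num_heads, 2 * key_dim + value_dim)  # like '_qkv' (split_dims)
      query, key, value = numpy.split(qkv, [key_dim, 2 * key_dim], axis=-1)  # like '_qkv_split'
      out = numpy.zeros((time, num_heads, value_dim))
      for h in range(num_heads):
        energy = query[:, h].dot(key[:, h].T) * key_dim ** -0.5  # like '_energy' with energy_factor
        if causal:  # illustrative stand-in for attention_left_only / accumulated keys in the rec loop
          energy = numpy.where(numpy.tril(numpy.ones((time, time), dtype=bool)), energy, -numpy.inf)
        energy -= energy.max(axis=-1, keepdims=True)
        att_weights = numpy.exp(energy)
        att_weights /= att_weights.sum(axis=-1, keepdims=True)  # like '_weights' (softmax_over_spatial)
        out[:, h] = att_weights.dot(value[:, h])  # like '_output' (dot with the accumulated values)
      return out.reshape(time, num_heads * value_dim)  # like '_att' (merge_dims)
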
def test_CumConcatLayer_self_attention_equal_to_SelfAttentionLayer():
  n_time = 13
  num_heads, key_dim, value_dim = 2, 3, 3
  for inside_rec_layer in [False, True]:
    with make_scope() as session:
      print('Testing inside_rec_layer=%s' % inside_rec_layer)

      # build net dict
      if inside_rec_layer:
        net_dict = {
          "output": {
            "class": "rec", "target": "classes", "from": [],
            "unit": {
              "single_layer_att": {
                "class": "self_attention", "from": "prev:single_layer_att", "num_heads": num_heads,
                "total_key_dim": num_heads * key_dim, "n_out": num_heads * value_dim,
                "attention_left_only": inside_rec_layer, 'is_output_layer': True},  # [B,T,F]
              "multi_layer_att": None,  # [B,T,F], added below.
              "output": {"class": "compare", "from": ["single_layer_att", "multi_layer_att"]}}}}
        _build_self_attention_layer(
          net_dict["output"]["unit"], 'prev:multi_layer_att', 'multi_layer', inside_rec_layer=True,
          query_axis='stag:extern_data:classes', num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
        net_dict["output"]["unit"]["multi_layer_att"]["is_output_layer"] = True
        net_dict["output"]["unit"]["multi_layer_qkv0"]["is_output_layer"] = True  # we need to set the matrix here
      else:
        net_dict = {
          "single_layer_att": {
            "class": "self_attention", "from": "data", "num_heads": num_heads, "total_key_dim": num_heads * key_dim,
            "n_out": num_heads * value_dim, "attention_left_only": inside_rec_layer,
            'is_output_layer': True},  # [B,T,F]
          "multi_layer_att": None,  # [B,T,F], added below.
          "output": {"class": "compare", "from": ["single_layer_att", "multi_layer_att"]}
        }
        _build_self_attention_layer(
          net_dict, 'data', 'multi_layer', inside_rec_layer=False, query_axis='stag:extern_data:data',
          num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
        net_dict["multi_layer_att"]["is_output_layer"] = True

      config = Config({
        "debug_print_layer_output_template": True, "optimize_move_layers_out": True})
      config.update(dict(num_inputs=num_heads * key_dim, num_outputs=num_heads * value_dim))
      network = TFNetwork(config=config, train_flag=True)
      from pprint import pprint
      pprint(net_dict)
      network.construct_from_dict(net_dict)

      if inside_rec_layer:
        single_layer = network.get_layer("output/single_layer_att")
        multi_layer = network.get_layer("output/multi_layer_att")

        # Note: single_layer.params etc. do not contain the params, need to access rec cell directly
        rec_layer = network.get_layer("output")
        single_weights = rec_layer.cell.net.get_layer("single_layer_att").params["QKV"]
        multi_weights = rec_layer.cell.net.get_layer("multi_layer_qkv0").params["W"]
      else:
        single_layer = network.get_layer("single_layer_att")
        multi_layer = network.get_layer("multi_layer_att")
        single_weights = single_layer.params["QKV"]
        multi_weights = network.get_layer("multi_layer_qkv0").params["W"]

      assert_equal(single_layer.output.batch_shape, (None, None, num_heads * value_dim))
      assert_equal(multi_layer.output.batch_shape, (None, None, num_heads * value_dim))

      # set weights equal.
      assert_equal(single_weights.shape, multi_weights.shape)
      weights = numpy.random.rand(*single_weights.shape)
      session.run(tf.compat.v1.assign(single_weights, weights))
      session.run(tf.compat.v1.assign(multi_weights, weights))

      # fetch/compare outputs
      from tests.test_TFNetworkLayer import make_feed_dict
      feed_dict = make_feed_dict(network.extern_data.data.values(), same_time=True, n_time=n_time)
      single, multi = session.run(
        [single_layer.output.placeholder, multi_layer.output.placeholder], feed_dict=feed_dict)
      print('single layer output:')
      pprint(single)
      print('multi layer output:')
      pprint(multi)
      numpy.testing.assert_almost_equal(single, multi, decimal=5)
      print('They are equal!')


if __name__ == "__main__":
  try:
    better_exchook.install()
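Assuming the file's usual __main__ handler (only partially visible in the trailing context above), the new test can presumably be run on its own with something like `python tests/test_TFNetworkRecLayer.py test_CumConcatLayer_self_attention_equal_to_SelfAttentionLayer`; invoked without an argument, the whole test file runs.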

0 commit comments
