
Commit 405218b

Tests for self attention using CumConcatLayer
1 parent b2b5cb5

1 file changed: tests/test_TFNetworkRecLayer.py (144 additions, 0 deletions)
@@ -6837,6 +6837,150 @@ def test_RelativePositionalEncodingLayer():
  print(out)  # random...


def _build_self_attention_layer(d, input, output, inside_rec_layer, query_axis, num_heads=8, key_dim=64,
                                value_dim=64, dropout=0.0):
  """
  Essentially this does
    d[output + '_att'] = {"class": "self_attention", "num_heads": num_heads,
                          "total_key_dim": num_heads * key_dim,
                          "n_out": num_heads * value_dim, "from": [input],
                          "attention_left_only": inside_rec_layer,
                          "attention_dropout": dropout, "forward_weights_init": self.ff_init}
  But using multiple layers.
  """
  # Create (non-accumulated) query, key and value
  d[output + '_qkv0'] = {
    'class': 'linear', 'activation': None, 'with_bias': False, 'from': [input],
    'n_out': num_heads * (2 * key_dim + value_dim)}  # [B,T?,F|n*(2d_k+d_v)]
  d[output + '_qkv'] = {
    'class': 'split_dims', 'axis': 'F', 'dims': (num_heads, 2 * key_dim + value_dim),
    'from': [output + '_qkv0']}  # [B,T?,n,F|2d_k+d_v]
  d[output + '_qkv_split'] = {
    'class': 'split', 'axis': 'F', 'size_splits': (key_dim, key_dim, value_dim),
    'from': [output + '_qkv']}
  d[output + '_query'] = {
    'class': 'copy', 'from': [output + '_qkv_split/0']}  # [B,T?,n,F|d_k]
  d[output + '_key'] = {
    'class': 'copy', 'from': [output + '_qkv_split/1']}  # [B,T?,n,F|d_k]
  d[output + '_value'] = {
    'class': 'copy', 'from': [output + '_qkv_split/2']}  # [B,T?,n,F|d_v]

  # Accumulate keys/values or rename the axis
  key_dim_tag = DimensionTag(kind=DimensionTag.Types.Time, description='self-att-keys')
  key_axis = 'stag:' + key_dim_tag.description
  if inside_rec_layer:
    d[output + '_key_accum'] = {
      'class': 'cum_concat', 'from': [output + '_key'], 'new_dim': key_dim_tag}  # [B,T|rec-history,n,F|d_k]
    d[output + '_value_accum'] = {
      'class': 'cum_concat', 'from': [output + '_value'], 'new_dim': key_dim_tag}  # [B,T|rec-history,n,F|d_v]
  else:
    d[output + '_key_accum'] = {
      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
      'from': [output + '_key']}  # [B,T|keys,n,F|d_k]
    d[output + '_value_accum'] = {
      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
      'from': [output + '_value']}  # [B,T|keys,n,F|d_v]

  # Calculate the energies
  d[output + '_energy'] = {
    'class': 'dot', 'from': [output + '_query', output + '_key_accum'],
    'red1': 'static:-1', 'red2': 'static:-1', 'common': ['B', 'static:0']}  # [B,n,T?,T|rec-history]

  d[output + '_weights'] = {
    'class': 'softmax_over_spatial', 'from': [output + '_energy'], 'axis': key_axis,
    'energy_factor': key_dim ** -0.5}  # [B,n,T?,T|rec-history]
  d[output + '_weights_drop'] = {
    'class': 'dropout', 'dropout_noise_shape': {'*': None}, 'from': [output + '_weights'],
    'dropout': dropout}  # [B,n,T?,T|rec-history]

  d[output + '_output'] = {
    'class': 'dot', 'from': [output + '_weights_drop', output + '_value_accum'],
    'red1': key_axis, 'red2': key_axis, 'common': ['B', query_axis, 'static:0']}  # [B,n,T?,F|d_v]
  d[output + '_att'] = {
    'class': 'merge_dims', 'axes': 'static', 'from': [output + '_output']}  # [B,T?,F|n*d_v]

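Aside from RETURNN specifics, the layer graph assembled above (linear QKV projection, split into heads, dot for the energies, softmax_over_spatial with the 1/sqrt(d_k) factor, dot with the values, merge_dims over the heads) is ordinary multi-head scaled dot-product attention. For orientation, here is a minimal NumPy sketch of the same computation for the non-recurrent case; this reference function is an editorial addition, not part of the commit, and its shapes follow the comments above.

import numpy

def _reference_multi_head_attention(x, qkv_weights, num_heads, key_dim, value_dim):
  """
  Editorial reference only (not part of this commit).
  :param numpy.ndarray x: [B,T,d_in] inputs
  :param numpy.ndarray qkv_weights: [d_in,n*(2*d_k+d_v)] weight matrix, as in the '_qkv0' linear layer
  :return: [B,T,n*d_v], the equivalent of the '_att' layer output
  """
  n_batch, n_time, _ = x.shape
  qkv = numpy.dot(x, qkv_weights)  # like '_qkv0': [B,T,n*(2d_k+d_v)]
  qkv = qkv.reshape(n_batch, n_time, num_heads, 2 * key_dim + value_dim)  # like '_qkv': [B,T,n,2d_k+d_v]
  q, k, v = numpy.split(qkv, [key_dim, 2 * key_dim], axis=-1)  # like '_qkv_split'
  energy = numpy.einsum('bqnd,bknd->bnqk', q, k) * key_dim ** -0.5  # like '_energy' plus energy_factor
  weights = numpy.exp(energy - energy.max(axis=-1, keepdims=True))
  weights /= weights.sum(axis=-1, keepdims=True)  # softmax over the keys axis, like '_weights'
  out = numpy.einsum('bnqk,bknd->bqnd', weights, v)  # like '_output': [B,T,n,d_v]
  return out.reshape(n_batch, n_time, num_heads * value_dim)  # like '_att' (merge_dims)
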
def test_CumConcatLayer_self_attention_equal_to_SelfAttentionLayer():
  n_time = 13
  num_heads, key_dim, value_dim = 2, 3, 3
  for inside_rec_layer in [False, True]:
    with make_scope() as session:
      print('Testing inside_rec_layer=%s' % inside_rec_layer)

      # build net dict
      single_layer_net_dict = {
        "class": "self_attention", "from": "data", "num_heads": num_heads, "total_key_dim": num_heads * key_dim,
        "n_out": num_heads * value_dim, "attention_left_only": inside_rec_layer, 'is_output_layer': True}  # [B,T,F]
      if inside_rec_layer:
        net_dict = {
          "output": {
            "class": "rec", "target": "classes",
            "unit": {
              "single_layer_att": single_layer_net_dict,  # [B,T,F]
              "multi_layer_att": None  # [B,T,F], added below.
            }}}
        _build_self_attention_layer(
          net_dict["output"]["unit"], 'data', 'multi_layer', inside_rec_layer=True,
          query_axis='stag:extern_data:classes',
          num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
        net_dict["output"]["unit"]["multi_layer_att"]["is_output_layer"] = True
      else:
        net_dict = {
          "single_layer_att": single_layer_net_dict,  # [B,T,F]
          "multi_layer_att": None  # [B,T,F], added below.
        }
        _build_self_attention_layer(
          net_dict, 'data', 'multi_layer', inside_rec_layer=False, query_axis='stag:extern_data:data',
          num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
        net_dict["multi_layer_att"]["is_output_layer"] = True

      config = Config({"debug_print_layer_output_template": True, "debug_add_check_numerics_ops": True})
      config.update(dict(num_inputs=num_heads * key_dim, num_outputs=num_heads * value_dim))
      network = TFNetwork(config=config, train_flag=True)
      network.construct_from_dict(net_dict)

      if inside_rec_layer:
        single_layer = network.get_layer("output/single_layer_att")
        multi_layer = network.get_layer("output/multi_layer_att")
        multi_qkv_layer = network.get_layer("output/multi_layer_qkv0")
      else:
        single_layer = network.get_layer("single_layer_att")
        multi_layer = network.get_layer("multi_layer_att")
        multi_qkv_layer = network.get_layer("multi_layer_qkv0")

      assert_equal(single_layer.output.shape, (None, num_heads * value_dim))
      assert_equal(multi_layer.output.shape, (None, num_heads * value_dim))

      # set weights equal. The weight matrix of the multi-layer variant lives in the '_qkv0' linear layer.
      single_weights = single_layer.params["QKV"]
      multi_weights = multi_qkv_layer.params["W"]
      assert_equal(single_weights.shape, multi_weights.shape)
      weights = numpy.random.rand(*single_weights.shape)
      session.run(tf.assign(single_weights, weights))
      session.run(tf.assign(multi_weights, weights))

      # fetch/compare outputs
      from tests.test_TFNetworkLayer import make_feed_dict
      feed_dict = make_feed_dict(network.extern_data.data.values(), same_time=True, n_time=n_time)
      single, multi = session.run(
        [single_layer.output.placeholder, multi_layer.output.placeholder], feed_dict=feed_dict)
      print('single layer output:')
      pprint(single)
      print('multi layer output:')
      pprint(multi)
      numpy.testing.assert_almost_equal(single, multi, decimal=5)
      print('They are equal!')

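In the inside_rec_layer=True case the single-layer reference uses attention_left_only=True, while the multi-layer build relies on CumConcatLayer only ever exposing the history up to the current frame. In terms of the reference sketch above, both amount to masking future key positions before the softmax. A small editorial NumPy illustration of that masking (not part of the commit):

import numpy

def _mask_left_only(energy):
  """
  Editorial reference only (not part of this commit).
  :param numpy.ndarray energy: [B,n,T,T] energies (queries x keys), as in the '_energy' layer
  :return: energies with keys right of the query position set to -inf, so softmax gives them weight 0
  """
  n_time = energy.shape[-1]
  mask = numpy.tril(numpy.ones((n_time, n_time), dtype=bool))  # [T_query,T_key], True where key <= query
  return numpy.where(mask, energy, -numpy.inf)
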
def test_self_attention_optimize_out():
  num_heads, key_dim, value_dim = 2, 3, 3
  network = {}
  _build_self_attention_layer(
    network, 'data:source', 'att', inside_rec_layer=True, query_axis='stag:extern_data:data',
    num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)

  check_reclayer_optimize_out(
    {'class': 'copy', 'from': 'att_att', 'n_out': value_dim * num_heads},
    other_subnet_layers=network)

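As its name suggests, the check_reclayer_optimize_out helper (defined earlier in this test file) constructs the given layers inside a rec layer both with and without the automatic move-out optimization and compares the outputs. That is exactly the property CumConcatLayer is meant to provide: computing attention step by step over the accumulated history yields the same frames as the masked whole-sequence computation. A per-sequence editorial NumPy sketch of the step-wise view (not part of the commit):

import numpy

def _reference_stepwise_attention(q, k, v, key_dim):
  """
  Editorial reference only (not part of this commit).
  :param numpy.ndarray q: [T,n,d_k] queries of one sequence
  :param numpy.ndarray k: [T,n,d_k] keys
  :param numpy.ndarray v: [T,n,d_v] values
  :return: [T,n,d_v], computed one frame at a time over the accumulated history,
    matching the masked whole-sequence computation
  """
  outputs = []
  for t in range(q.shape[0]):
    k_accum, v_accum = k[:t + 1], v[:t + 1]  # history up to frame t, as CumConcatLayer provides it
    energy = numpy.einsum('nd,knd->nk', q[t], k_accum) * key_dim ** -0.5
    weights = numpy.exp(energy - energy.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    outputs.append(numpy.einsum('nk,knd->nd', weights, v_accum))
  return numpy.stack(outputs)
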
if __name__ == "__main__":
  try:
    better_exchook.install()
