142 changes: 142 additions & 0 deletions tests/test_TFNetworkRecLayer.py
@@ -6837,6 +6837,148 @@ def test_RelativePositionalEncodingLayer():
  print(out)  # random...


def _build_self_attention_layer(d, input, output, inside_rec_layer, query_axis, num_heads=8, key_dim=64,
                                value_dim=64, dropout=0.0):
"""
Essentially this does
d[output + '_att'] = {"class": "self_attention", "num_heads": num_heads,
"total_key_dim": num_heads * key_dim,
"n_out": num_heads * value_dim, "from": [input],
"attention_left_only": inside_rec_layer,
"attention_dropout": dropout, "forward_weights_init": self.ff_init}
But using multiple layers.
"""
  # Create (non-accumulated) query, key and value
  d[output + '_qkv0'] = {
    'class': 'linear', 'activation': None, 'with_bias': False, 'from': [input],
    'n_out': num_heads * (2 * key_dim + value_dim)}  # [B,T?,F|n*(2d_k+d_v)]
  d[output + '_qkv'] = {
    'class': 'split_dims', 'axis': 'F', 'dims': (num_heads, 2 * key_dim + value_dim),
    'from': [output + '_qkv0']}  # [B,T?,n,F|2d_k+d_v]
  d[output + '_qkv_split'] = {
    'class': 'split', 'axis': 'F', 'size_splits': (key_dim, key_dim, value_dim), 'from': [output + '_qkv']}
  d[output + '_query'] = {'class': 'copy', 'from': [output + '_qkv_split/0']}  # [B,T?,n,F|d_k]
  d[output + '_key'] = {'class': 'copy', 'from': [output + '_qkv_split/1']}  # [B,T?,n,F|d_k]
  d[output + '_value'] = {'class': 'copy', 'from': [output + '_qkv_split/2']}  # [B,T?,n,F|d_v]
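  # Note: "T?" in the shape comments denotes the time axis, which only exists when this is built outside
  # a rec loop; inside the loop, each step operates on a single frame, hence the accumulation of keys
  # and values below.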

  # Accumulate keys/values or rename the axis
  key_dim_tag = DimensionTag(kind=DimensionTag.Types.Time, description='self-att-keys')
  key_axis = 'stag:' + key_dim_tag.description
  if inside_rec_layer:
    d[output + '_key_accum'] = {
      'class': 'cum_concat', 'from': [output + '_key'], 'new_dim': key_dim_tag}  # [B,T|rec-history,n,F|d_k]
    d[output + '_value_accum'] = {
      'class': 'cum_concat', 'from': [output + '_value'], 'new_dim': key_dim_tag}  # [B,T|rec-history,n,F|d_v]
  else:
    d[output + '_key_accum'] = {
      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
      'from': [output + '_key']}  # [B,T|keys,n,F|d_k]
    d[output + '_value_accum'] = {
      'class': 'reinterpret_data', 'set_dim_tags': {query_axis: key_dim_tag},
      'from': [output + '_value']}  # [B,T|keys,n,F|d_v]
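    # Outside a rec loop there is nothing to accumulate; reinterpret_data just re-tags the time axis of
    # the keys/values with the separate 'self-att-keys' dim tag, so it can be distinguished from the
    # query axis in the dot products and the softmax below.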

  # Calculate the energies
  d[output + '_energy'] = {
    'class': 'dot', 'from': [output + '_query', output + '_key_accum'],
    'red1': 'static:-1', 'red2': 'static:-1',
    'var1': None if inside_rec_layer else query_axis, 'var2': key_dim_tag}  # [B,n,T?,T|rec-history]
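  # This computes the dot-product attention energies per head:
  #   energy[b, h, q, k] = sum_d query[b, q, h, d] * key[b, k, h, d]
  # The 1/sqrt(d_k) scaling is applied below via energy_factor in softmax_over_spatial.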

  d[output + '_weights'] = {
    'class': 'softmax_over_spatial', 'from': [output + '_energy'], 'axis': key_axis,
    'energy_factor': key_dim ** -0.5}  # [B,n,T?,T|rec-history]
  d[output + '_weights_drop'] = {
    'class': 'dropout', 'dropout_noise_shape': {'*': None}, 'from': [output + '_weights'],
    'dropout': dropout}  # [B,n,T?,T|rec-history]

  d[output + '_output'] = {
    'class': 'dot', 'from': [output + '_weights_drop', output + '_value_accum'],
    'red1': key_axis, 'red2': key_axis,
    "var1": None if inside_rec_layer else query_axis, "var2": "static:-1"}  # [B,n,T?,F|d_v]
  d[output + '_att'] = {'class': 'merge_dims', 'axes': 'static', 'from': [output + '_output']}  # [B,T?,F|n*d_v]
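  # merge_dims concatenates the heads again, giving a feature dim of num_heads * value_dim,
  # which matches the n_out of the single SelfAttentionLayer.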


def test_CumConcatLayer_self_attention_equal_to_SelfAttentionLayer():
  n_time = 13
  num_heads, key_dim, value_dim = 2, 3, 3
  for inside_rec_layer in [False, True]:
    with make_scope() as session:
      print('Testing inside_rec_layer=%s' % inside_rec_layer)

      # build net dict
      if inside_rec_layer:
        net_dict = {
          "output": {
            "class": "rec", "target": "classes", "from": [],
            "unit": {
              "single_layer_att": {
                "class": "self_attention", "from": "prev:single_layer_att", "num_heads": num_heads,
                "total_key_dim": num_heads * key_dim, "n_out": num_heads * value_dim,
                "attention_left_only": inside_rec_layer, 'is_output_layer': True},  # [B,T,F]
              "multi_layer_att": None,  # [B,T,F], added below.
              "output": {"class": "compare", "from": ["single_layer_att", "multi_layer_att"]}}}}
        _build_self_attention_layer(
          net_dict["output"]["unit"], 'prev:multi_layer_att', 'multi_layer', inside_rec_layer=True,
          query_axis='stag:extern_data:classes', num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
        net_dict["output"]["unit"]["multi_layer_att"]["is_output_layer"] = True
        net_dict["output"]["unit"]["multi_layer_qkv0"]["is_output_layer"] = True  # we need to set the matrix here
      else:
        net_dict = {
          "single_layer_att": {
            "class": "self_attention", "from": "data", "num_heads": num_heads, "total_key_dim": num_heads * key_dim,
            "n_out": num_heads * value_dim, "attention_left_only": inside_rec_layer,
            'is_output_layer': True},  # [B,T,F]
          "multi_layer_att": None,  # [B,T,F], added below.
          "output": {"class": "compare", "from": ["single_layer_att", "multi_layer_att"]}
        }
        _build_self_attention_layer(
          net_dict, 'data', 'multi_layer', inside_rec_layer=False, query_axis='stag:extern_data:data',
          num_heads=num_heads, key_dim=key_dim, value_dim=value_dim)
        net_dict["multi_layer_att"]["is_output_layer"] = True

      config = Config({
        "debug_print_layer_output_template": True, "optimize_move_layers_out": True})
      config.update(dict(num_inputs=num_heads * key_dim, num_outputs=num_heads * value_dim))
      network = TFNetwork(config=config, train_flag=True)
      from pprint import pprint
      pprint(net_dict)
      network.construct_from_dict(net_dict)

      if inside_rec_layer:
        single_layer = network.get_layer("output/single_layer_att")
        multi_layer = network.get_layer("output/multi_layer_att")

        # Note: single_layer.params etc. do not contain the params; we need to access the rec cell directly
        rec_layer = network.get_layer("output")
        single_weights = rec_layer.cell.net.get_layer("single_layer_att").params["QKV"]
        multi_weights = rec_layer.cell.net.get_layer("multi_layer_qkv0").params["W"]
      else:
        single_layer = network.get_layer("single_layer_att")
        multi_layer = network.get_layer("multi_layer_att")
        single_weights = single_layer.params["QKV"]
        multi_weights = network.get_layer("multi_layer_qkv0").params["W"]

      assert_equal(single_layer.output.batch_shape, (None, None, num_heads * value_dim))
      assert_equal(multi_layer.output.batch_shape, (None, None, num_heads * value_dim))

      # set weights equal.
      assert_equal(single_weights.shape, multi_weights.shape)
      weights = numpy.random.rand(*single_weights.shape)
      session.run(tf.compat.v1.assign(single_weights, weights))
      session.run(tf.compat.v1.assign(multi_weights, weights))
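      # Both variants use a single combined QKV projection of the same shape (asserted above), so
      # assigning identical values should make the two computations match exactly.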

      # fetch/compare outputs
      from tests.test_TFNetworkLayer import make_feed_dict
      feed_dict = make_feed_dict(network.extern_data.data.values(), same_time=True, n_time=n_time)
      single, multi = session.run(
        [single_layer.output.placeholder, multi_layer.output.placeholder], feed_dict=feed_dict)
      print('single layer output:')
      pprint(single)
      print('multi layer output:')
      pprint(multi)
      numpy.testing.assert_almost_equal(single, multi, decimal=5)
      print('They are equal!')


if __name__ == "__main__":
  try:
    better_exchook.install()