This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 39f4bd6

Author: Mesh TensorFlow Team
#HyperPrompt Part 2 of the HyperPrompt implementation: the actual computation of HyperPrompt inside the self-attention layer.
PiperOrigin-RevId: 429613966
1 parent: bbb6ce7

File tree

2 files changed: +440, -12 lines

mesh_tensorflow/transformer/attention.py

Lines changed: 314 additions & 0 deletions
@@ -1001,3 +1001,317 @@ def _my_reshape(x):
    else:
      return x
  return _my_reshape(q), _my_reshape(k), _my_reshape(v), _my_reshape(bias)


def make_params_mtlprompt(task_num,
                          num_heads,
                          prefix_hidden_dim,
                          kv_dim,
                          context,
                          mtlprompt_share,
                          dropout_rate,
                          prompt_length=None):
  """Returns the parameters for MTL-Prompt mode."""

  tf.logging.info("MTL-Prompt mode is ON!")
  # Get task ids for the batch from the context cache.
  task_id = context.cache["task-id"]

  # Remove length dim from shape [batch_size, length] -> [batch_size].
  task_id = mtf.reshape(task_id, task_id.shape - task_id.shape.dims[-1])
  batch_size_dim = task_id.shape.get_dim_by_name("batch")

  task_num_dim = mtf.Dimension("task_num", task_num)
  prompt_length_dim = mtf.Dimension("memory_length", prompt_length)
  heads_dim = mtf.Dimension("heads", num_heads)
  prefix_hidden_dim = mtf.Dimension("prefix_hidden", prefix_hidden_dim)

  with tf.variable_scope("prompt"):
    # Initialize projection network matrices.

    # MTL-Prompt-Share mode.
    hidden_shape = [context.model.model_dim, prefix_hidden_dim]
    scratch_shape = [prefix_hidden_dim, heads_dim, kv_dim]
    if not mtlprompt_share:
      # MTL-Prompt-Sep mode.
      hidden_shape = [task_num_dim] + hidden_shape
      scratch_shape = [task_num_dim] + scratch_shape

    w_k_up = mtf.get_variable(
        context.mesh,
        "w_k_up",
        mtf.Shape(hidden_shape),
        dtype=context.variable_dtype)
    w_k_down = mtf.get_variable(
        context.mesh,
        "w_k_down",
        mtf.Shape(scratch_shape),
        dtype=context.variable_dtype)
    w_v_up = mtf.get_variable(
        context.mesh,
        "w_v_up",
        mtf.Shape(hidden_shape),
        dtype=context.variable_dtype)
    w_v_down = mtf.get_variable(
        context.mesh,
        "w_v_down",
        mtf.Shape(scratch_shape),
        dtype=context.variable_dtype)

    scope_name = tf.get_variable_scope().name
    scope = scope_name.split("/")[0]
    prompts = context.shared_params[scope]["prompts"]
    prompts = mtf.layers.layer_norm(
        prompts, dim=context.model.model_dim, name="prompts_layernorm")
    if context.train and dropout_rate != 0.0:
      prompts = mtf.dropout(prompts, context.train, 1.0 - dropout_rate)

    prompts_k_hidden = mtf.matmul(
        prompts,
        w_k_up,
        output_shape=[task_num_dim, prompt_length_dim, prefix_hidden_dim],
        reduced_dims=[context.model.model_dim])
    prompts_k_hidden = mtf.relu(prompts_k_hidden)
    if context.train and dropout_rate != 0.0:
      prompts_k_hidden = mtf.dropout(
          prompts_k_hidden, context.train, keep_prob=1.0 - dropout_rate)
    prompts_k = mtf.matmul(
        prompts_k_hidden,
        w_k_down,
        output_shape=[task_num_dim, prompt_length_dim, heads_dim, kv_dim],
        reduced_dims=[prefix_hidden_dim])

    prompts_v_hidden = mtf.matmul(
        prompts,
        w_v_up,
        output_shape=[task_num_dim, prompt_length_dim, prefix_hidden_dim],
        reduced_dims=[context.model.model_dim])
    prompts_v_hidden = mtf.relu(prompts_v_hidden)
    if context.train and dropout_rate != 0.0:
      prompts_v_hidden = mtf.dropout(
          prompts_v_hidden, context.train, keep_prob=1.0 - dropout_rate)
    prompts_v = mtf.matmul(
        prompts_v_hidden,
        w_v_down,
        output_shape=[task_num_dim, prompt_length_dim, heads_dim, kv_dim],
        reduced_dims=[prefix_hidden_dim])

    prompts_batch_k = mtf.gather(
        prompts_k,
        task_id,
        task_num_dim,
        output_shape=[batch_size_dim, prompt_length_dim, heads_dim, kv_dim])
    prompts_batch_v = mtf.gather(
        prompts_v,
        task_id,
        task_num_dim,
        output_shape=[batch_size_dim, prompt_length_dim, heads_dim, kv_dim])
  return prompts_batch_k, prompts_batch_v


def make_params_hyperprompt(task_num,
                            num_heads,
                            prefix_hidden_dim,
                            kv_dim,
                            context,
                            mtlprompt_share,
                            dropout_rate,
                            prompt_length,
                            scope=None):
  """Returns the parameters for HyperPrompt mode."""
  del mtlprompt_share
  tf.logging.info("HyperPrompt mode is ON!")
  # Get task ids for the batch from the context cache.
  task_id = context.cache["task-id"]

  # Remove length dim from shape [batch_size, length] -> [batch_size].
  task_id = mtf.reshape(task_id, task_id.shape - task_id.shape.dims[-1])

  batch_size_dim = [dim for dim in task_id.shape.dims if dim.name == "batch"][0]

  task_num_dim = mtf.Dimension("task_num", task_num)
  prompt_length_dim = mtf.Dimension("memory_length", prompt_length)
  heads_dim = mtf.Dimension("heads", num_heads)
  prefix_hidden_dim = mtf.Dimension("prefix_hidden", prefix_hidden_dim)

  prompts = context.shared_params[scope]["prompts"]
  prompts = mtf.layers.layer_norm(
      prompts, dim=context.model.model_dim, name="prompts_layernorm")
  if context.train and dropout_rate != 0.0:
    prompts = mtf.dropout(prompts, context.train, 1.0 - dropout_rate)

  scope_name = tf.get_variable_scope().name
  scope = scope_name.split("/")[0]

  # Get the layer id.
  layer_id = int(scope_name.split("/")[1].split("_")[1])

  task_raw_embeddings = context.shared_params[scope]["task_raw_embedding"]
  task_raw_embedding_dim = task_raw_embeddings.shape.dims[1]

  task_projector_layer_one = context.shared_params[scope][
      "task_projector_layer_one"]
  task_projector_layer_one_in_dim = task_projector_layer_one.shape.dims[0]
  task_hidden_dim = task_projector_layer_one.shape.dims[1]

  task_projector_layer_two = context.shared_params[scope][
      "task_projector_layer_two"]
  task_final_embedding_dim = task_projector_layer_two.shape.dims[1]

  layer_id_embeddings = context.shared_params[scope]["layer_embedding"]
  layer_num_dim = layer_id_embeddings.shape.dims[0]
  layer_id_embedding_dim = layer_id_embeddings.shape.dims[1]

  # Get the layer id embedding for the batch.
  if layer_id not in range(0, layer_num_dim.size):
    raise ValueError("Failed to parse a valid layer_id from the scope name.")
  layer_id_task_num = mtf.constant(
      task_raw_embeddings.mesh,
      layer_id,
      shape=mtf.Shape([task_num_dim]),
      dtype=tf.int32)
  layer_id_emb_task_num = mtf.gather(
      layer_id_embeddings,
      layer_id_task_num,
      layer_num_dim,
      output_shape=[task_num_dim, layer_id_embedding_dim])

  task_embeddings_concat = mtf.concat(
      [task_raw_embeddings, layer_id_emb_task_num],
      concat_dim_name=task_raw_embedding_dim.name)

  # Feed raw task-embedding to MLP to obtain the layer-aware task embedding.
  task_embeddings_concat_hidden = mtf.matmul(
      task_embeddings_concat,
      task_projector_layer_one,
      output_shape=[task_num_dim, task_hidden_dim],
      reduced_dims=[task_projector_layer_one_in_dim])
  task_embeddings_concat_hidden_relu = mtf.relu(task_embeddings_concat_hidden)

  if context.train and dropout_rate != 0.0:
    task_embeddings_concat_hidden_relu = mtf.dropout(
        task_embeddings_concat_hidden_relu,
        context.train,
        keep_prob=1.0 - dropout_rate)

  task_embeddings_layer_awared = mtf.matmul(
      task_embeddings_concat_hidden_relu,
      task_projector_layer_two,
      output_shape=[task_num_dim, task_final_embedding_dim],
      reduced_dims=[task_hidden_dim])

  task_embeddings_layer_awared = mtf.layers.layer_norm(
      task_embeddings_layer_awared,
      dim=task_final_embedding_dim,
      name="prompt_task_embed_layernorm")

  hypernet_w_k_up = context.shared_params[scope]["hypernet_w_k_up"]
  hypernet_w_k_down = context.shared_params[scope]["hypernet_w_k_down"]
  hypernet_w_v_up = context.shared_params[scope]["hypernet_w_v_up"]
  hypernet_w_v_down = context.shared_params[scope]["hypernet_w_v_down"]

  # The hypernetwork generates the prompt transformations.
  w_k_up = mtf.matmul(
      task_embeddings_layer_awared,
      hypernet_w_k_up,
      output_shape=[task_num_dim, context.model.model_dim, prefix_hidden_dim],
      reduced_dims=[task_final_embedding_dim])

  w_k_down = mtf.matmul(
      task_embeddings_layer_awared,
      hypernet_w_k_down,
      output_shape=[task_num_dim, prefix_hidden_dim, heads_dim, kv_dim],
      reduced_dims=[task_final_embedding_dim])

  w_v_up = mtf.matmul(
      task_embeddings_layer_awared,
      hypernet_w_v_up,
      output_shape=[task_num_dim, context.model.model_dim, prefix_hidden_dim],
      reduced_dims=[task_final_embedding_dim])

  w_v_down = mtf.matmul(
      task_embeddings_layer_awared,
      hypernet_w_v_down,
      output_shape=[task_num_dim, prefix_hidden_dim, heads_dim, kv_dim],
      reduced_dims=[task_final_embedding_dim])

  prompts_k_hidden = mtf.matmul(
      prompts,
      w_k_up,
      output_shape=[task_num_dim, prompt_length_dim, prefix_hidden_dim],
      reduced_dims=[context.model.model_dim])
  prompts_k_hidden = mtf.relu(prompts_k_hidden)
  if context.train and dropout_rate != 0.0:
    prompts_k_hidden = mtf.dropout(
        prompts_k_hidden, context.train, keep_prob=1.0 - dropout_rate)
  prompts_k = mtf.matmul(
      prompts_k_hidden,
      w_k_down,
      output_shape=[task_num_dim, prompt_length_dim, heads_dim, kv_dim],
      reduced_dims=[prefix_hidden_dim])
  prompts_batch_k = mtf.gather(
      prompts_k,
      task_id,
      task_num_dim,
      output_shape=[batch_size_dim, prompt_length_dim, heads_dim, kv_dim])

  prompts_v_hidden = mtf.matmul(
      prompts,
      w_v_up,
      output_shape=[task_num_dim, prompt_length_dim, prefix_hidden_dim],
      reduced_dims=[context.model.model_dim])
  prompts_v_hidden = mtf.relu(prompts_v_hidden)
  if context.train and dropout_rate != 0.0:
    prompts_v_hidden = mtf.dropout(
        prompts_v_hidden, context.train, keep_prob=1.0 - dropout_rate)
  prompts_v = mtf.matmul(
      prompts_v_hidden,
      w_v_down,
      output_shape=[task_num_dim, prompt_length_dim, heads_dim, kv_dim],
      reduced_dims=[prefix_hidden_dim])
  prompts_batch_v = mtf.gather(
      prompts_v,
      task_id,
      task_num_dim,
      output_shape=[batch_size_dim, prompt_length_dim, heads_dim, kv_dim])

  return prompts_batch_k, prompts_batch_v


def concat_hyper_prompts_kv(k, v, scope_encoder_or_decoder, use_hyperprompt,
                            memory_length, task_num, num_heads,
                            prefix_hidden_dim, kv_dim, context, mtlprompt_share,
                            dropout_rate, prompt_length):
  """Performs the concatenation of hyper-prompts to key and value."""
  # Inject hyper-prompts into keys and values.
  if use_hyperprompt:
    # HyperPrompt mode.
    prompts_batch_k, prompts_batch_v = make_params_hyperprompt(
        task_num,
        num_heads,
        prefix_hidden_dim,
        kv_dim,
        context,
        mtlprompt_share,
        dropout_rate,
        prompt_length=prompt_length,
        scope=scope_encoder_or_decoder)
  else:
    # MTL-Prompt mode.
    prompts_batch_k, prompts_batch_v = make_params_mtlprompt(
        task_num,
        num_heads,
        prefix_hidden_dim,
        kv_dim,
        context,
        mtlprompt_share,
        dropout_rate,
        prompt_length=prompt_length)

  k = mtf.concat([prompts_batch_k, k], concat_dim_name="memory_length")
  v = mtf.concat([prompts_batch_v, v], concat_dim_name="memory_length")
  memory_length = mtf.Dimension("memory_length",
                                memory_length.size + prompt_length)
  memory_position = mtf.range(context.mesh, memory_length, tf.int32)

  return k, v, memory_position, memory_length
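
As a shape-level illustration of what the new helpers compute in HyperPrompt mode, the following standalone NumPy sketch may help when reading the diff. It is not part of the commit: every size and name in it is an illustrative stand-in, and np.einsum stands in for mtf.matmul. The flow it mirrors: a hypernetwork maps a layer-aware task embedding to per-task up/down projections, the shared prompts pass through that ReLU bottleneck to produce per-task key prompts, each example's prompts are gathered by task id, and the result is prepended to the keys along the memory_length axis (values are handled identically).

# Standalone NumPy sketch (illustrative only; not part of this commit).
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes (all made up).
batch, task_num, prompt_len, heads, d_kv = 2, 4, 3, 8, 16
d_model, d_prefix, d_task, mem_len = 32, 24, 12, 10

# Shared global prompts and layer-aware task embeddings, as held in
# context.shared_params in the real code.
prompts = rng.normal(size=(task_num, prompt_len, d_model))
task_embed = rng.normal(size=(task_num, d_task))

# Hypernetwork weights that generate the per-task key-prompt projections.
hyper_w_k_up = rng.normal(size=(d_task, d_model, d_prefix))
hyper_w_k_down = rng.normal(size=(d_task, d_prefix, heads, d_kv))

# Hypernetwork output: one (up, down) projection pair per task.
w_k_up = np.einsum("te,emp->tmp", task_embed, hyper_w_k_up)        # [task, d_model, d_prefix]
w_k_down = np.einsum("te,ephk->tphk", task_embed, hyper_w_k_down)  # [task, d_prefix, heads, d_kv]

# Project the shared prompts into per-task key prompts via a ReLU bottleneck.
k_hidden = np.maximum(np.einsum("tlm,tmp->tlp", prompts, w_k_up), 0.0)
prompts_k = np.einsum("tlp,tphk->tlhk", k_hidden, w_k_down)        # [task, prompt_len, heads, d_kv]

# Gather each example's prompts by task id, then prepend them to the keys.
task_id = rng.integers(0, task_num, size=batch)
prompts_batch_k = prompts_k[task_id]                               # [batch, prompt_len, heads, d_kv]
k = rng.normal(size=(batch, mem_len, heads, d_kv))
k = np.concatenate([prompts_batch_k, k], axis=1)                   # memory_length grows by prompt_len
print(k.shape)  # (2, 13, 8, 16)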
