@@ -1001,3 +1001,317 @@ def _my_reshape(x):
     else:
       return x
   return _my_reshape(q), _my_reshape(k), _my_reshape(v), _my_reshape(bias)
+
+
+def make_params_mtlprompt(task_num,
+                          num_heads,
+                          prefix_hidden_dim,
+                          kv_dim,
+                          context,
+                          mtlprompt_share,
+                          dropout_rate,
+                          prompt_length=None):
+  """Returns the parameters for MTL-Prompt mode."""
+
+  tf.logging.info("MTL-Prompt mode is ON!")
+  # Get task ids for the batch from the context cache.
+  task_id = context.cache["task-id"]
+
+  # Remove length dim from shape [batch_size, length] -> [batch_size].
+  task_id = mtf.reshape(task_id, task_id.shape - task_id.shape.dims[-1])
+  batch_size_dim = task_id.shape.get_dim_by_name("batch")
+
+  task_num_dim = mtf.Dimension("task_num", task_num)
+  prompt_length_dim = mtf.Dimension("memory_length", prompt_length)
+  heads_dim = mtf.Dimension("heads", num_heads)
+  prefix_hidden_dim = mtf.Dimension("prefix_hidden", prefix_hidden_dim)
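+  # Note: the prompt length dimension reuses the name "memory_length" so that
+  # the generated prompts can later be concatenated with the attention
+  # keys/values along that dimension (see concat_hyper_prompts_kv below).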
+
+  with tf.variable_scope("prompt"):
+    # Initialize projection network matrices.
+
+    # MTL-Prompt-Share mode.
+    hidden_shape = [context.model.model_dim, prefix_hidden_dim]
+    scratch_shape = [prefix_hidden_dim, heads_dim, kv_dim]
+    if not mtlprompt_share:
+      # MTL-Prompt-Sep mode.
+      hidden_shape = [task_num_dim] + hidden_shape
+      scratch_shape = [task_num_dim] + scratch_shape
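+    # The "up" matrices map model_dim -> prefix_hidden and the "down" matrices
+    # map prefix_hidden -> (heads, kv); without mtlprompt_share, the leading
+    # task_num dimension gives every task its own projection matrices.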
+
+    w_k_up = mtf.get_variable(
+        context.mesh,
+        "w_k_up",
+        mtf.Shape(hidden_shape),
+        dtype=context.variable_dtype)
+    w_k_down = mtf.get_variable(
+        context.mesh,
+        "w_k_down",
+        mtf.Shape(scratch_shape),
+        dtype=context.variable_dtype)
+    w_v_up = mtf.get_variable(
+        context.mesh,
+        "w_v_up",
+        mtf.Shape(hidden_shape),
+        dtype=context.variable_dtype)
+    w_v_down = mtf.get_variable(
+        context.mesh,
+        "w_v_down",
+        mtf.Shape(scratch_shape),
+        dtype=context.variable_dtype)
+
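+    # Look up the shared global prompts for this stack (the top-level scope is
+    # the encoder or decoder name) from context.shared_params.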
+    scope_name = tf.get_variable_scope().name
+    scope = scope_name.split("/")[0]
+    prompts = context.shared_params[scope]["prompts"]
+    prompts = mtf.layers.layer_norm(
+        prompts, dim=context.model.model_dim, name="prompts_layernorm")
+    if context.train and dropout_rate != 0.0:
+      prompts = mtf.dropout(prompts, context.train, 1.0 - dropout_rate)
+
+    prompts_k_hidden = mtf.matmul(
+        prompts,
+        w_k_up,
+        output_shape=[task_num_dim, prompt_length_dim, prefix_hidden_dim],
+        reduced_dims=[context.model.model_dim])
+    prompts_k_hidden = mtf.relu(prompts_k_hidden)
+    if context.train and dropout_rate != 0.0:
+      prompts_k_hidden = mtf.dropout(
+          prompts_k_hidden, context.train, keep_prob=1.0 - dropout_rate)
+    prompts_k = mtf.matmul(
+        prompts_k_hidden,
+        w_k_down,
+        output_shape=[task_num_dim, prompt_length_dim, heads_dim, kv_dim],
+        reduced_dims=[prefix_hidden_dim])
+
+    prompts_v_hidden = mtf.matmul(
+        prompts,
+        w_v_up,
+        output_shape=[task_num_dim, prompt_length_dim, prefix_hidden_dim],
+        reduced_dims=[context.model.model_dim])
+    prompts_v_hidden = mtf.relu(prompts_v_hidden)
+    if context.train and dropout_rate != 0.0:
+      prompts_v_hidden = mtf.dropout(
+          prompts_v_hidden, context.train, keep_prob=1.0 - dropout_rate)
+    prompts_v = mtf.matmul(
+        prompts_v_hidden,
+        w_v_down,
+        output_shape=[task_num_dim, prompt_length_dim, heads_dim, kv_dim],
+        reduced_dims=[prefix_hidden_dim])
+
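+    # prompts_k / prompts_v hold one prompt per task with shape
+    # [task_num, prompt_length, heads, kv]; gather along task_num with the
+    # per-example task ids to select each example's prompts.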
+    prompts_batch_k = mtf.gather(
+        prompts_k,
+        task_id,
+        task_num_dim,
+        output_shape=[batch_size_dim, prompt_length_dim, heads_dim, kv_dim])
+    prompts_batch_v = mtf.gather(
+        prompts_v,
+        task_id,
+        task_num_dim,
+        output_shape=[batch_size_dim, prompt_length_dim, heads_dim, kv_dim])
+    return prompts_batch_k, prompts_batch_v
+
+
+def make_params_hyperprompt(task_num,
+                            num_heads,
+                            prefix_hidden_dim,
+                            kv_dim,
+                            context,
+                            mtlprompt_share,
+                            dropout_rate,
+                            prompt_length,
+                            scope=None):
+  """Returns the parameters for HyperPrompt mode."""
+  del mtlprompt_share
+  tf.logging.info("HyperPrompt mode is ON!")
+  # Get task ids for the batch from the context cache.
+  task_id = context.cache["task-id"]
+
+  # Remove length dim from shape [batch_size, length] -> [batch_size].
+  task_id = mtf.reshape(task_id, task_id.shape - task_id.shape.dims[-1])
+
+  batch_size_dim = [dim for dim in task_id.shape.dims if dim.name == "batch"][0]
+
+  task_num_dim = mtf.Dimension("task_num", task_num)
+  prompt_length_dim = mtf.Dimension("memory_length", prompt_length)
+  heads_dim = mtf.Dimension("heads", num_heads)
+  prefix_hidden_dim = mtf.Dimension("prefix_hidden", prefix_hidden_dim)
+
+  prompts = context.shared_params[scope]["prompts"]
+  prompts = mtf.layers.layer_norm(
+      prompts, dim=context.model.model_dim, name="prompts_layernorm")
+  if context.train and dropout_rate != 0.0:
+    prompts = mtf.dropout(prompts, context.train, 1.0 - dropout_rate)
+
+  scope_name = tf.get_variable_scope().name
+  scope = scope_name.split("/")[0]
+
+  # Get the layer id.
+  layer_id = int(scope_name.split("/")[1].split("_")[1])
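+  # Parsing assumes the second component of the variable-scope path ends in an
+  # underscore followed by the numeric layer index.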
+
+  task_raw_embeddings = context.shared_params[scope]["task_raw_embedding"]
+  task_raw_embedding_dim = task_raw_embeddings.shape.dims[1]
+
+  task_projector_layer_one = context.shared_params[scope][
+      "task_projector_layer_one"]
+  task_projector_layer_one_in_dim = task_projector_layer_one.shape.dims[0]
+  task_hidden_dim = task_projector_layer_one.shape.dims[1]
+
+  task_projector_layer_two = context.shared_params[scope][
+      "task_projector_layer_two"]
+  task_final_embedding_dim = task_projector_layer_two.shape.dims[1]
+
+  layer_id_embeddings = context.shared_params[scope]["layer_embedding"]
+  layer_num_dim = layer_id_embeddings.shape.dims[0]
+  layer_id_embedding_dim = layer_id_embeddings.shape.dims[1]
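+  # These shared tensors are the hypernetwork inputs: raw per-task embeddings,
+  # a two-layer projector MLP, and per-layer id embeddings, all read from
+  # context.shared_params so they are reused across layers.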
+
+  # Get the layer id embedding for the batch.
+  if layer_id not in range(0, layer_num_dim.size):
+    raise ValueError(
+        "Failed to parse a valid layer_id from scope %r: got %d, expected a "
+        "value in [0, %d)." % (scope_name, layer_id, layer_num_dim.size))
+  layer_id_task_num = mtf.constant(
+      task_raw_embeddings.mesh,
+      layer_id,
+      shape=mtf.Shape([task_num_dim]),
+      dtype=tf.int32)
+  layer_id_emb_task_num = mtf.gather(
+      layer_id_embeddings,
+      layer_id_task_num,
+      layer_num_dim,
+      output_shape=[task_num_dim, layer_id_embedding_dim])
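+  # Every task receives the same (current layer's) id embedding, broadcast to
+  # shape [task_num, layer_id_embedding_dim] for concatenation below.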
+
+  task_embeddings_concat = mtf.concat(
+      [task_raw_embeddings, layer_id_emb_task_num],
+      concat_dim_name=task_raw_embedding_dim.name)
+
+  # Feed the raw task embedding to an MLP to obtain the layer-aware task
+  # embedding.
+  task_embeddings_concat_hidden = mtf.matmul(
+      task_embeddings_concat,
+      task_projector_layer_one,
+      output_shape=[task_num_dim, task_hidden_dim],
+      reduced_dims=[task_projector_layer_one_in_dim])
+  task_embeddings_concat_hidden_relu = mtf.relu(task_embeddings_concat_hidden)
+
+  if context.train and dropout_rate != 0.0:
+    task_embeddings_concat_hidden_relu = mtf.dropout(
+        task_embeddings_concat_hidden_relu,
+        context.train,
+        keep_prob=1.0 - dropout_rate)
+
+  task_embeddings_layer_awared = mtf.matmul(
+      task_embeddings_concat_hidden_relu,
+      task_projector_layer_two,
+      output_shape=[task_num_dim, task_final_embedding_dim],
+      reduced_dims=[task_hidden_dim])
+
+  task_embeddings_layer_awared = mtf.layers.layer_norm(
+      task_embeddings_layer_awared,
+      dim=task_final_embedding_dim,
+      name="prompt_task_embed_layernorm")
+
+  hypernet_w_k_up = context.shared_params[scope]["hypernet_w_k_up"]
+  hypernet_w_k_down = context.shared_params[scope]["hypernet_w_k_down"]
+  hypernet_w_v_up = context.shared_params[scope]["hypernet_w_v_up"]
+  hypernet_w_v_down = context.shared_params[scope]["hypernet_w_v_down"]
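+  # The hypernetwork weights come from the shared parameter cache, so a single
+  # hypernetwork is reused by every layer in this scope; only the layer-aware
+  # task embedding differs between layers.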
+
+  # The hypernetwork generates the per-task prompt projection weights.
+  w_k_up = mtf.matmul(
+      task_embeddings_layer_awared,
+      hypernet_w_k_up,
+      output_shape=[task_num_dim, context.model.model_dim, prefix_hidden_dim],
+      reduced_dims=[task_final_embedding_dim])
+
+  w_k_down = mtf.matmul(
+      task_embeddings_layer_awared,
+      hypernet_w_k_down,
+      output_shape=[task_num_dim, prefix_hidden_dim, heads_dim, kv_dim],
+      reduced_dims=[task_final_embedding_dim])
+
+  w_v_up = mtf.matmul(
+      task_embeddings_layer_awared,
+      hypernet_w_v_up,
+      output_shape=[task_num_dim, context.model.model_dim, prefix_hidden_dim],
+      reduced_dims=[task_final_embedding_dim])
+
+  w_v_down = mtf.matmul(
+      task_embeddings_layer_awared,
+      hypernet_w_v_down,
+      output_shape=[task_num_dim, prefix_hidden_dim, heads_dim, kv_dim],
+      reduced_dims=[task_final_embedding_dim])
+
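+  # The generated w_*_up / w_*_down tensors play the same role as the per-task
+  # variables in make_params_mtlprompt above, but are produced by the
+  # hypernetwork instead of being learned directly.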
+  prompts_k_hidden = mtf.matmul(
+      prompts,
+      w_k_up,
+      output_shape=[task_num_dim, prompt_length_dim, prefix_hidden_dim],
+      reduced_dims=[context.model.model_dim])
+  prompts_k_hidden = mtf.relu(prompts_k_hidden)
+  if context.train and dropout_rate != 0.0:
+    prompts_k_hidden = mtf.dropout(
+        prompts_k_hidden, context.train, keep_prob=1.0 - dropout_rate)
+  prompts_k = mtf.matmul(
+      prompts_k_hidden,
+      w_k_down,
+      output_shape=[task_num_dim, prompt_length_dim, heads_dim, kv_dim],
+      reduced_dims=[prefix_hidden_dim])
+  prompts_batch_k = mtf.gather(
+      prompts_k,
+      task_id,
+      task_num_dim,
+      output_shape=[batch_size_dim, prompt_length_dim, heads_dim, kv_dim])
+
+  prompts_v_hidden = mtf.matmul(
+      prompts,
+      w_v_up,
+      output_shape=[task_num_dim, prompt_length_dim, prefix_hidden_dim],
+      reduced_dims=[context.model.model_dim])
+  prompts_v_hidden = mtf.relu(prompts_v_hidden)
+  if context.train and dropout_rate != 0.0:
+    prompts_v_hidden = mtf.dropout(
+        prompts_v_hidden, context.train, keep_prob=1.0 - dropout_rate)
+  prompts_v = mtf.matmul(
+      prompts_v_hidden,
+      w_v_down,
+      output_shape=[task_num_dim, prompt_length_dim, heads_dim, kv_dim],
+      reduced_dims=[prefix_hidden_dim])
+  prompts_batch_v = mtf.gather(
+      prompts_v,
+      task_id,
+      task_num_dim,
+      output_shape=[batch_size_dim, prompt_length_dim, heads_dim, kv_dim])
+
+  return prompts_batch_k, prompts_batch_v
+
+
+def concat_hyper_prompts_kv(k, v, scope_encoder_or_decoder, use_hyperprompt,
+                            memory_length, task_num, num_heads,
+                            prefix_hidden_dim, kv_dim, context, mtlprompt_share,
+                            dropout_rate, prompt_length):
+  """Performs the concatenation of hyper-prompts to keys and values."""
+  # Inject hyper-prompts into keys and values.
+  if use_hyperprompt:
+    # HyperPrompt mode.
+    prompts_batch_k, prompts_batch_v = make_params_hyperprompt(
+        task_num,
+        num_heads,
+        prefix_hidden_dim,
+        kv_dim,
+        context,
+        mtlprompt_share,
+        dropout_rate,
+        prompt_length=prompt_length,
+        scope=scope_encoder_or_decoder)
+  else:
+    # MTL-Prompt mode.
+    prompts_batch_k, prompts_batch_v = make_params_mtlprompt(
+        task_num,
+        num_heads,
+        prefix_hidden_dim,
+        kv_dim,
+        context,
+        mtlprompt_share,
+        dropout_rate,
+        prompt_length=prompt_length)
+
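+  # Prepend the prompt keys/values along memory_length, then extend
+  # memory_length and recompute memory_position to cover the added prompts.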
+  k = mtf.concat([prompts_batch_k, k], concat_dim_name="memory_length")
+  v = mtf.concat([prompts_batch_v, v], concat_dim_name="memory_length")
+  memory_length = mtf.Dimension("memory_length",
+                                memory_length.size + prompt_length)
+  memory_position = mtf.range(context.mesh, memory_length, tf.int32)
+
+  return k, v, memory_position, memory_length