@@ -291,6 +291,46 @@ class MaxTextModel:
291291 ),
292292)
293293
294+
# Long-running Pathways variant of the llama2-70b real-data (TFDS) benchmark:
# identical training setup, plus periodic async checkpointing, goodput
# reporting, and a very large step budget so the job runs indefinitely.
llama2_70b_4096_real_data_pw_long_run = MaxTextModel(
    model_name="llama2-70b-4096-rd-pw-lr",
    model_type="llama2-70b",
    tuning_params={
        "per_device_batch_size": 4,
        "ici_fsdp_parallelism": -1,
        "remat_policy": "full",
        "max_target_length": 4096,
        "attention": "flash",
        "gcs_metrics": True,
        "use_iota_embed": True,
        "reuse_example_batch": 0,
        "profiler": "xplane",
        "dataset_path": "gs://max-datasets-rogue",
        "dataset_type": "tfds",
        "tokenizer_path": "assets/tokenizer.llama2",
        "sa_block_q": 1024,
        "sa_block_q_dkv": 2048,
        "sa_block_q_dq": 2048,
        "steps": 1000000,
        # Additional tuning params for pathways long running test.
        "enable_checkpointing": True,
        "async_checkpointing": True,
        "checkpoint_period": 100,
        "checkpoint_storage_use_ocdbt": False,
        "checkpoint_storage_use_zarr3": False,
        "metrics_file": "metrics.txt",
        "goodput_upload_interval_seconds": 30,
        "enable_pathways_goodput": True,
        "enable_checkpoint_cloud_logger": True,
        "enable_single_controller": True,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.CF_FOR_ALL_GATHER
    ),
)
333+
294334# ici_fsdp_transpose_parallelism gives one TFLOP better performance.
295335llama2_70b_4096 = MaxTextModel (
296336 model_name = "llama2-70b-4096" ,
@@ -319,6 +359,151 @@ class MaxTextModel:
319359 + xla_flags_library .CF_FOR_ALL_GATHER
320360 ),
321361)
# llama2-70b benchmark on synthetic data (no checkpointing): FSDP-transpose
# sharding with QKV-projection rematerialization offloaded to host.
llama2_70b_4096_synthetic = MaxTextModel(
    model_name="llama2_70b_4096_synthetic",
    model_type="llama2-70b",
    tuning_params={
        "per_device_batch_size": 2,
        "ici_fsdp_parallelism": 1,
        "ici_fsdp_transpose_parallelism": -1,
        "ici_tensor_parallelism": 1,
        "remat_policy": "qkv_proj_offloaded",
        "max_target_length": 4096,
        "attention": "flash",
        "gcs_metrics": True,
        "use_iota_embed": True,
        "dataset_path": "gs://max-datasets-rogue",
        "dataset_type": "synthetic",
        "enable_checkpointing": False,
        "profiler": "xplane",
        "sa_block_q": 1024,
        "sa_block_q_dkv": 2048,
        "sa_block_q_dq": 2048,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.CF_FOR_ALL_GATHER
    ),
)
388+
# Long-running Pathways variant of the synthetic-data llama2-70b benchmark:
# same training setup as llama2_70b_4096_synthetic, but with async
# checkpointing every 100 steps, goodput reporting, and a very large step
# budget so the job runs indefinitely.
# Fix: removed a commented-out, dead `"enable_checkpointing": False` entry
# that contradicted the active `True` setting below.
llama2_70b_4096_synthetic_pw_lr = MaxTextModel(
    model_name="llama2_70b_4096_synthetic_pw_lr",
    model_type="llama2-70b",
    tuning_params={
        "per_device_batch_size": 2,
        "ici_fsdp_parallelism": 1,
        "ici_fsdp_transpose_parallelism": -1,
        "ici_tensor_parallelism": 1,
        "remat_policy": "qkv_proj_offloaded",
        "max_target_length": 4096,
        "attention": "flash",
        "gcs_metrics": True,
        "use_iota_embed": True,
        "dataset_path": "gs://max-datasets-rogue",
        "dataset_type": "synthetic",
        "profiler": "xplane",
        "sa_block_q": 1024,
        "sa_block_q_dkv": 2048,
        "sa_block_q_dq": 2048,
        "steps": 1000000,
        # Additional tuning params for pathways long running test.
        "enable_checkpointing": True,
        "async_checkpointing": True,
        "checkpoint_period": 100,
        "checkpoint_storage_use_ocdbt": False,
        "checkpoint_storage_use_zarr3": False,
        "metrics_file": "metrics.txt",
        "goodput_upload_interval_seconds": 30,
        "enable_pathways_goodput": True,
        "enable_checkpoint_cloud_logger": True,
        "enable_single_controller": True,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.CF_FOR_ALL_GATHER
    ),
)
428+
# Long-running Pathways llama2-70b benchmark on synthetic data with batch
# reuse: full rematerialization, async checkpointing every 100 steps, goodput
# reporting, and a very large step budget so the job runs indefinitely.
llama2_70b_4096_pw_long_run = MaxTextModel(
    model_name="llama2-70b-4096-pw-lr",
    model_type="llama2-70b",
    tuning_params={
        "per_device_batch_size": 4,
        "ici_fsdp_parallelism": 1,
        "ici_fsdp_transpose_parallelism": -1,
        "ici_tensor_parallelism": 1,
        "remat_policy": "full",
        "max_target_length": 4096,
        "attention": "flash",
        "gcs_metrics": True,
        "use_iota_embed": True,
        "dataset_path": "gs://max-datasets-rogue",
        "dataset_type": "synthetic",
        "reuse_example_batch": 1,
        "profiler": "xplane",
        "sa_block_q": 1024,
        "sa_block_q_dkv": 2048,
        "sa_block_q_dq": 2048,
        "steps": 1000000,
        # Additional tuning params for pathways long running test.
        "enable_checkpointing": True,
        "async_checkpointing": True,
        "checkpoint_period": 100,
        "checkpoint_storage_use_ocdbt": False,
        "checkpoint_storage_use_zarr3": False,
        "metrics_file": "metrics.txt",
        "goodput_upload_interval_seconds": 30,
        "enable_pathways_goodput": True,
        "enable_checkpoint_cloud_logger": True,
        "enable_single_controller": True,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.CF_FOR_ALL_GATHER
    ),
)
468+
# Pathways llama2-70b run reading real data from the trillium storage bucket,
# with async checkpointing every 100 steps and goodput reporting enabled.
# Fix: the original tuning_params dict listed "enable_checkpointing" twice
# (False, then True). A Python dict literal silently keeps the last value, so
# the dead False entry is removed here and checkpointing stays enabled.
llama2_70b_4096_pw_rd_tfds = MaxTextModel(
    model_name="llama2_70b_4096_pw_rd_tfds",
    model_type="llama2-70b",
    tuning_params={
        "per_device_batch_size": 2,
        "ici_fsdp_parallelism": 1,
        "ici_fsdp_transpose_parallelism": -1,
        "ici_tensor_parallelism": 1,
        "remat_policy": "qkv_proj_offloaded",
        "max_target_length": 4096,
        "attention": "flash",
        "gcs_metrics": True,
        "use_iota_embed": True,
        # NOTE(review): the name suggests TFDS real data, but "dataset_type"
        # is not set here (sibling configs set it explicitly) — confirm the
        # config default is "tfds".
        "dataset_path": "gs://trillium-storage-datasets-sr",
        "profiler": "xplane",
        "sa_block_q": 1024,
        "sa_block_q_dkv": 2048,
        "sa_block_q_dq": 2048,
        # Additional tuning params for pathways long running test.
        "enable_checkpointing": True,
        "async_checkpointing": True,
        "checkpoint_period": 100,
        "checkpoint_storage_use_ocdbt": False,
        "checkpoint_storage_use_zarr3": False,
        "metrics_file": "metrics.txt",
        "goodput_upload_interval_seconds": 30,
        "enable_pathways_goodput": True,
        "enable_checkpoint_cloud_logger": True,
        "enable_single_controller": True,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.CF_FOR_ALL_GATHER
    ),
)
506+
322507
323508llama3_8b_8192 = MaxTextModel (
324509 model_name = "llama3-8b-8192" ,
@@ -695,9 +880,14 @@ class MaxTextModel:
695880 gpt_3_175b ,
696881 llama2_7b_4096 ,
697882 llama2_70b_4096 ,
883+ llama2_70b_4096_pw_long_run ,
698884 llama2_70b_4096_real_data ,
885+ llama2_70b_4096_real_data_pw_long_run ,
886+ llama2_70b_4096_pw_rd_tfds ,
699887 llama3_8b_8192 , # Not Optimizied yet
700888 llama3_70b_8192 , # Not Optimizied yet
889+ llama2_70b_4096_synthetic_pw_lr ,
890+ llama2_70b_4096_synthetic ,
701891 llama3_1_405b_8192_fsdp_dcn ,
702892 llama3_1_8b_8192 ,
703893 llama3_1_70b_8192 ,
0 commit comments