Skip to content

Commit 3889010

Browse files
author
maxtext authors
committed
Merge pull request AI-Hypercomputer#1094 from AI-Hypercomputer:sujinesh/llama2_v6e_pw_long_running_test
PiperOrigin-RevId: 714177448
2 parents 776da9f + de7e105 commit 3889010

File tree

3 files changed

+338
-37
lines changed

3 files changed

+338
-37
lines changed

benchmarks/benchmark_runner.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from maxtext_xpk_runner import BenchmarkRunner
2929
from maxtext_xpk_runner import HWConfig
3030
from maxtext_xpk_runner import SWconfig
31+
from maxtext_xpk_runner import PathwaysConfig
3132
from maxtext_xpk_runner import xpk_benchmark_runner
3233
from maxtext_xpk_runner import XpkConfig
3334

@@ -86,6 +87,11 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
8687
'llama2_7b_4096',
8788
'llama2_70b_4096',
8889
'llama2_70b_4096_real_data',
90+
'llama2_70b_4096_pw_long_run',
91+
'llama2_70b_4096_real_data_pw_long_run',
92+
'llama2_70b_4096_pw_rd_tfds',
93+
'llama2_70b_4096_synthetic_pw_lr',
94+
'llama2_70b_4096_synthetic',
8995
'llama3_70b_8192',
9096
'llama3_1_405b_8192_fsdp_dcn',
9197
'mixtral_8x7b_dropped',
@@ -103,6 +109,11 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
103109
'llama2_7b_4096 '
104110
'llama2_70b_4096 '
105111
'llama2_70b_4096_real_data '
112+
'llama2_70b_4096_pw_long_run '
113+
'llama2_70b_4096_real_data_pw_long_run '
114+
'llama2_70b_4096_pw_rd_tfds '
115+
'llama2_70b_4096_synthetic_pw_lr '
116+
'llama2_70b_4096_synthetic '
106117
'llama3_1_405b_8192_fsdp_dcn '
107118
'mixtral_8x7b_dropped '
108119
'mixtral_8x7b_dropped_int8 '
@@ -124,6 +135,51 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
124135
default='maxtext_base_image',
125136
help='version of base docker image to be benchmarked command.',
126137
)
138+
custom_parser.add_argument(
139+
'--pathways_server_image',
140+
type=str,
141+
default=(
142+
'us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/server:latest'
143+
),
144+
help='version of pathways server image to be benchmarked command.',
145+
)
146+
custom_parser.add_argument(
147+
'--pathways_proxy_image',
148+
type=str,
149+
default='us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/proxy_server:latest',
150+
help='version of pathways proxy image to be benchmarked command.',
151+
)
152+
custom_parser.add_argument(
153+
'--pathways_runner_image',
154+
type=str,
155+
default='us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_stable:latest',
156+
help='version of pathways runner image to be benchmarked command.',
157+
)
158+
custom_parser.add_argument(
159+
'--use_pathways',
160+
type=bool,
161+
default=False,
162+
help='whether to use pathways or not.',
163+
)
164+
custom_parser.add_argument(
165+
'--xpk_path',
166+
type=str,
167+
default='~/xpk',
168+
help='path to xpk dir.',
169+
)
170+
custom_parser.add_argument(
171+
'--priority',
172+
type=str,
173+
default='medium',
174+
help='Priority the XPK workload should run with.',
175+
)
176+
custom_parser.add_argument(
177+
'--max_restarts',
178+
type=int,
179+
default=0,
180+
help='Number of restarts to attempt.',
181+
)
182+
127183

128184
def main() -> None:
129185
parser = argparse.ArgumentParser(
@@ -139,11 +195,19 @@ def main() -> None:
139195
num_slices=options.num_slices,
140196
device_type=options.device_type,
141197
base_output_directory=options.base_output_directory,
198+
priority=options.priority,
199+
max_restarts=options.max_restarts,
142200
)
143201

144202
v6e_env_configs = SWconfig(
145203
base_docker_image=options.base_docker_image,
146204
libtpu_version=options.libtpu_version,
205+
pathways_config=PathwaysConfig(
206+
use_pathways=options.use_pathways,
207+
server_image=options.pathways_server_image,
208+
proxy_image=options.pathways_proxy_image,
209+
runner_image=options.pathways_runner_image,
210+
),
147211
)
148212

149213
v6e_256_configs = HWConfig(
@@ -159,7 +223,7 @@ def main() -> None:
159223
hardware_config=v6e_256_configs,
160224
)
161225

162-
xpk_benchmark_runner(cluster_config, [model_runner])
226+
xpk_benchmark_runner(cluster_config, [model_runner], options.xpk_path)
163227

164228

165229
if __name__ == '__main__':

benchmarks/maxtext_trillium_model_configs.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,46 @@ class MaxTextModel:
291291
),
292292
)
293293

294+
295+
llama2_70b_4096_real_data_pw_long_run = MaxTextModel(
296+
model_name="llama2-70b-4096-rd-pw-lr",
297+
model_type="llama2-70b",
298+
tuning_params={
299+
"per_device_batch_size": 4,
300+
"ici_fsdp_parallelism": -1,
301+
"remat_policy": "full",
302+
"max_target_length": 4096,
303+
"attention": "flash",
304+
"gcs_metrics": True,
305+
"use_iota_embed": True,
306+
"reuse_example_batch": 0,
307+
"profiler": "xplane",
308+
"dataset_path": "gs://max-datasets-rogue",
309+
"dataset_type": "tfds",
310+
"tokenizer_path": "assets/tokenizer.llama2",
311+
"sa_block_q": 1024,
312+
"sa_block_q_dkv": 2048,
313+
"sa_block_q_dq": 2048,
314+
"steps": 1000000,
315+
316+
# Additional tuning params for pathways long running test.
317+
"enable_checkpointing": True,
318+
"async_checkpointing": True,
319+
"checkpoint_period": 100,
320+
"checkpoint_storage_use_ocdbt": False,
321+
"checkpoint_storage_use_zarr3": False,
322+
"metrics_file": "metrics.txt",
323+
"goodput_upload_interval_seconds": 30,
324+
"enable_pathways_goodput": True,
325+
"enable_checkpoint_cloud_logger": True,
326+
"enable_single_controller": True,
327+
},
328+
xla_flags=(
329+
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
330+
+ xla_flags_library.CF_FOR_ALL_GATHER
331+
),
332+
)
333+
294334
# ici_fsdp_transpose_parallelism gives one TFLOP better performance.
295335
llama2_70b_4096 = MaxTextModel(
296336
model_name="llama2-70b-4096",
@@ -319,6 +359,151 @@ class MaxTextModel:
319359
+ xla_flags_library.CF_FOR_ALL_GATHER
320360
),
321361
)
362+
llama2_70b_4096_synthetic = MaxTextModel(
363+
model_name="llama2_70b_4096_synthetic",
364+
model_type="llama2-70b",
365+
tuning_params={
366+
"per_device_batch_size": 2,
367+
"ici_fsdp_parallelism": 1,
368+
"ici_fsdp_transpose_parallelism": -1,
369+
"ici_tensor_parallelism": 1,
370+
"remat_policy": "qkv_proj_offloaded",
371+
"max_target_length": 4096,
372+
"attention": "flash",
373+
"gcs_metrics": True,
374+
"use_iota_embed": True,
375+
"dataset_path": "gs://max-datasets-rogue",
376+
"dataset_type": "synthetic",
377+
"enable_checkpointing": False,
378+
"profiler": "xplane",
379+
"sa_block_q": 1024,
380+
"sa_block_q_dkv": 2048,
381+
"sa_block_q_dq": 2048,
382+
},
383+
xla_flags=(
384+
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
385+
+ xla_flags_library.CF_FOR_ALL_GATHER
386+
),
387+
)
388+
389+
llama2_70b_4096_synthetic_pw_lr = MaxTextModel(
390+
model_name="llama2_70b_4096_synthetic_pw_lr",
391+
model_type="llama2-70b",
392+
tuning_params={
393+
"per_device_batch_size": 2,
394+
"ici_fsdp_parallelism": 1,
395+
"ici_fsdp_transpose_parallelism": -1,
396+
"ici_tensor_parallelism": 1,
397+
"remat_policy": "qkv_proj_offloaded",
398+
"max_target_length": 4096,
399+
"attention": "flash",
400+
"gcs_metrics": True,
401+
"use_iota_embed": True,
402+
"dataset_path": "gs://max-datasets-rogue",
403+
"dataset_type": "synthetic",
404+
# "enable_checkpointing": False,
405+
"profiler": "xplane",
406+
"sa_block_q": 1024,
407+
"sa_block_q_dkv": 2048,
408+
"sa_block_q_dq": 2048,
409+
"steps": 1000000,
410+
411+
# Additional tuning params for pathways long running test.
412+
"enable_checkpointing": True,
413+
"async_checkpointing": True,
414+
"checkpoint_period": 100,
415+
"checkpoint_storage_use_ocdbt": False,
416+
"checkpoint_storage_use_zarr3": False,
417+
"metrics_file": "metrics.txt",
418+
"goodput_upload_interval_seconds": 30,
419+
"enable_pathways_goodput": True,
420+
"enable_checkpoint_cloud_logger": True,
421+
"enable_single_controller": True,
422+
},
423+
xla_flags=(
424+
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
425+
+ xla_flags_library.CF_FOR_ALL_GATHER
426+
),
427+
)
428+
429+
llama2_70b_4096_pw_long_run = MaxTextModel(
430+
model_name="llama2-70b-4096-pw-lr",
431+
model_type="llama2-70b",
432+
tuning_params={
433+
"per_device_batch_size": 4,
434+
"ici_fsdp_parallelism": 1,
435+
"ici_fsdp_transpose_parallelism": -1,
436+
"ici_tensor_parallelism": 1,
437+
"remat_policy": "full",
438+
"max_target_length": 4096,
439+
"attention": "flash",
440+
"gcs_metrics": True,
441+
"use_iota_embed": True,
442+
"dataset_path": "gs://max-datasets-rogue",
443+
"dataset_type": "synthetic",
444+
"reuse_example_batch": 1,
445+
"profiler": "xplane",
446+
"sa_block_q": 1024,
447+
"sa_block_q_dkv": 2048,
448+
"sa_block_q_dq": 2048,
449+
"steps": 1000000,
450+
451+
# Additional tuning params for pathways long running test.
452+
"enable_checkpointing": True,
453+
"async_checkpointing": True,
454+
"checkpoint_period": 100,
455+
"checkpoint_storage_use_ocdbt": False,
456+
"checkpoint_storage_use_zarr3": False,
457+
"metrics_file": "metrics.txt",
458+
"goodput_upload_interval_seconds": 30,
459+
"enable_pathways_goodput": True,
460+
"enable_checkpoint_cloud_logger": True,
461+
"enable_single_controller": True,
462+
},
463+
xla_flags=(
464+
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
465+
+ xla_flags_library.CF_FOR_ALL_GATHER
466+
),
467+
)
468+
469+
llama2_70b_4096_pw_rd_tfds = MaxTextModel(
470+
model_name="llama2_70b_4096_pw_rd_tfds",
471+
model_type="llama2-70b",
472+
tuning_params={
473+
"per_device_batch_size": 2,
474+
"ici_fsdp_parallelism": 1,
475+
"ici_fsdp_transpose_parallelism": -1,
476+
"ici_tensor_parallelism": 1,
477+
"remat_policy": "qkv_proj_offloaded",
478+
"max_target_length": 4096,
479+
"attention": "flash",
480+
"gcs_metrics": True,
481+
"use_iota_embed": True,
482+
"dataset_path": "gs://trillium-storage-datasets-sr",
483+
"enable_checkpointing": False,
484+
"profiler": "xplane",
485+
"sa_block_q": 1024,
486+
"sa_block_q_dkv": 2048,
487+
"sa_block_q_dq": 2048,
488+
489+
# Additional tuning params for pathways long running test.
490+
"enable_checkpointing": True,
491+
"async_checkpointing": True,
492+
"checkpoint_period": 100,
493+
"checkpoint_storage_use_ocdbt": False,
494+
"checkpoint_storage_use_zarr3": False,
495+
"metrics_file": "metrics.txt",
496+
"goodput_upload_interval_seconds": 30,
497+
"enable_pathways_goodput": True,
498+
"enable_checkpoint_cloud_logger": True,
499+
"enable_single_controller": True,
500+
},
501+
xla_flags=(
502+
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
503+
+ xla_flags_library.CF_FOR_ALL_GATHER
504+
),
505+
)
506+
322507

323508
llama3_8b_8192 = MaxTextModel(
324509
model_name="llama3-8b-8192",
@@ -695,9 +880,14 @@ class MaxTextModel:
695880
gpt_3_175b,
696881
llama2_7b_4096,
697882
llama2_70b_4096,
883+
llama2_70b_4096_pw_long_run,
698884
llama2_70b_4096_real_data,
885+
llama2_70b_4096_real_data_pw_long_run,
886+
llama2_70b_4096_pw_rd_tfds,
699887
llama3_8b_8192, # Not Optimized yet
700888
llama3_70b_8192, # Not Optimized yet
889+
llama2_70b_4096_synthetic_pw_lr,
890+
llama2_70b_4096_synthetic,
701891
llama3_1_405b_8192_fsdp_dcn,
702892
llama3_1_8b_8192,
703893
llama3_1_70b_8192,

0 commit comments

Comments
 (0)