Commit 4612d20

Use argparse utils to show default args on command line
1 parent c079904 commit 4612d20

27 files changed: +30 −30 lines
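The change is the same across the files shown: pass argparse.ArgumentDefaultsHelpFormatter as the parser's formatter_class so that --help prints each argument's default value. A minimal sketch of the effect (the parser and argument below are illustrative, not taken from the repo):

import argparse

# Illustrative parser, not one of the repo's scripts.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description="Demo of default-showing help output.",
)
parser.add_argument("--emb_dim", type=int, default=768, help="Model embedding dimension.")
parser.print_help()
# The formatter appends each argument's default to its help string, e.g.:
#   --emb_dim EMB_DIM  Model embedding dimension. (default: 768)

Without the formatter_class argument, --help shows only the help string, and defaults have to be duplicated by hand in every help message.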

ch04/04_gqa/gpt_with_kv_gqa.py

Lines changed: 1 addition & 1 deletion
@@ -290,7 +290,7 @@ def generate_text_simple_cached(model, idx, max_new_tokens,
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Run GPT with grouped-query attention.")
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Run GPT with grouped-query attention.")
     parser.add_argument("--emb_dim", type=int, default=768, help="Model embedding dimension.")
     parser.add_argument("--n_heads", type=int, default=12, help="Number of attention heads.")
     parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.")

ch04/04_gqa/gpt_with_kv_mha.py

Lines changed: 1 addition & 1 deletion
@@ -278,7 +278,7 @@ def generate_text_simple_cached(model, idx, max_new_tokens,
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Run GPT with standard multi-head attention.")
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Run GPT with standard multi-head attention.")
     parser.add_argument("--emb_dim", type=int, default=768, help="Model embedding dimension.")
     parser.add_argument("--n_heads", type=int, default=12, help="Number of attention heads.")
     parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.")

ch04/04_gqa/memory_estimator_gqa.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ def calc_kv_bytes_total(batch, context_length, emb_dim, n_heads,
 
 
 def main():
-    p = argparse.ArgumentParser(description="Estimate KV-cache memory for MHA vs GQA")
+    p = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Estimate KV-cache memory for MHA vs GQA")
     p.add_argument("--context_length", default=1024, type=int)
     p.add_argument("--emb_dim", required=True, type=int)
     p.add_argument("--n_heads", required=True, type=int)

ch04/05_mla/gpt_with_kv_mha.py

Lines changed: 1 addition & 1 deletion
@@ -278,7 +278,7 @@ def generate_text_simple_cached(model, idx, max_new_tokens,
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Run GPT with standard multi-head attention.")
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Run GPT with standard multi-head attention.")
     parser.add_argument("--emb_dim", type=int, default=768, help="Model embedding dimension.")
     parser.add_argument("--n_heads", type=int, default=12, help="Number of attention heads.")
     parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.")

ch04/05_mla/gpt_with_kv_mla.py

Lines changed: 2 additions & 2 deletions
@@ -286,13 +286,13 @@ def generate_text_simple_cached(model, idx, max_new_tokens,
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Run GPT with standard multi-head attention.")
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Run GPT with standard multi-head attention.")
     parser.add_argument("--emb_dim", type=int, default=768, help="Model embedding dimension.")
     parser.add_argument("--n_heads", type=int, default=12, help="Number of attention heads.")
     parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.")
     parser.add_argument("--max_new_tokens", type=int, default=200, help="Number of tokens to generate.")
     parser.add_argument("--latent_dim", type=int, default=None,
-                        help="Latent dim for MLA (default: d_out//8)")
+                        help="Latent dim for MLA")
 
     args = parser.parse_args()
 
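Note the second change in this file: with the new formatter, argparse itself reports the default for --latent_dim (None), so the hand-written "(default: d_out//8)" note is dropped from the help string; presumably the d_out//8 fallback is computed later in the script when the flag is left unset. A minimal sketch of why the manual note would now be redundant (hypothetical parser, not the repo's):

import argparse

p = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
p.add_argument("--latent_dim", type=int, default=None, help="Latent dim for MLA")
p.print_help()
# Prints roughly:
#   --latent_dim LATENT_DIM  Latent dim for MLA (default: None)
# A hard-coded "(default: d_out//8)" in the help text would clash with the
# "(default: None)" that the formatter appends automatically.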
ch04/05_mla/memory_estimator_mla.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def calc_mla_bytes_total(batch, context_length, n_layers, latent_dim, bytes_per_
 
 
 def main():
-    p = argparse.ArgumentParser(description="Estimate KV-cache memory for MHA vs GQA vs MLA")
+    p = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Estimate KV-cache memory for MHA vs GQA vs MLA")
     p.add_argument("--context_length", default=1024, type=int)
     p.add_argument("--emb_dim", required=True, type=int)
     p.add_argument("--n_heads", required=True, type=int)

ch04/06_swa/gpt_with_kv_mha.py

Lines changed: 1 addition & 1 deletion
@@ -278,7 +278,7 @@ def generate_text_simple_cached(model, idx, max_new_tokens,
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Run GPT with standard multi-head attention.")
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Run GPT with standard multi-head attention.")
     parser.add_argument("--emb_dim", type=int, default=768, help="Model embedding dimension.")
     parser.add_argument("--n_heads", type=int, default=12, help="Number of attention heads.")
     parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.")

ch04/06_swa/gpt_with_kv_swa.py

Lines changed: 1 addition & 1 deletion
@@ -311,7 +311,7 @@ def generate_text_simple_cached(model, idx, max_new_tokens,
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Run GPT with standard multi-head attention.")
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Run GPT with standard multi-head attention.")
     parser.add_argument("--emb_dim", type=int, default=768, help="Model embedding dimension.")
     parser.add_argument("--n_heads", type=int, default=12, help="Number of attention heads.")
     parser.add_argument("--n_layers", type=int, default=12, help="Number of transformer blocks.")

ch04/06_swa/memory_estimator_swa.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def estimate_totals(context_length, sliding_window_size, emb_dim, n_heads, n_lay
 
 
 def main():
-    p = argparse.ArgumentParser(description="Estimate KV-cache memory for MHA/GQA with SWA layer ratio")
+    p = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Estimate KV-cache memory for MHA/GQA with SWA layer ratio")
     p.add_argument("--context_length", default=1024, type=int)
     p.add_argument("--sliding_window_size", required=True, type=int,
                    help="SWA window size W per SWA layer.")

ch04/06_swa/plot_memory_estimates_swa.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def calc_kv_bytes_total_gqa_swa(
 
 
 def main():
-    p = argparse.ArgumentParser(
+    p = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
         description="KV-cache vs Context Length — MHA vs GQA with SWA overlays"
     )
     p.add_argument("--emb_dim", type=int, required=True)
