
Commit 132829f

Merge pull request #1 from tysam-code/cleaner-logging
Updated the logging to better align with repo goals, README.md
2 parents: d683ee9 + bd341d7

2 files changed: +52 -37 lines changed

README.md

Lines changed: 4 additions & 4 deletions
@@ -8,7 +8,7 @@ Welcome to the hyperlightspeedbench CIFAR-10 (HLB-CIFAR10) repo.
 `git clone https://github.com/tysam-code/hlb-CIFAR10 && cd hlb-CIFAR10 && python -m pip install -r requirements.txt && python main.py`


-If you're curious, this code is generally Colab friendly and is built to appropriately reset state without having to reload the instance (in fact -- most of this was developed in Colab!)
+If you're curious, this code is generally Colab friendly (in fact -- most of this was developed in Colab!). Just be sure to uncomment the reset block at the top of the code.


 ### Main
@@ -22,10 +22,10 @@ Goals:
 * near world-record single-GPU training time (~<18.1 seconds on an A100) .
 * <2 seconds training time in <2 years

-This is a neural network implementation that recreates and reproduces from nearly the ground-up in a painstakingly accurate manner a hacking-friendly version of [David Page's original ultra-fast CIFAR-10 implementation on a single GPU](https://myrtle.ai/learn/how-to-train-your-resnet/) -- 94% accuracy in ~<18.1 seconds on an A100 GPU. There is only one primary functional difference that I am aware of. The code has been rewritten practically from scratch in an annotated, hackable flat structure that for me has been extremely fast to prototype ideas in. This code took about 120-130 hours of work from start to finish, and about about 80-90+ of those hours were mind-numbingly tedious debugging of the minutia between my implementation and David's implementation. It turns out that there are so many little things to consider to actually achieve and hold the accuracy David achieved, I find it an interesting balance of tons of wiggle room in places and none at all in others.
+This is a neural network implementation that painstakingly reproduces from nearly the ground-up a hacking-friendly version of [David Page's original ultra-fast CIFAR-10 implementation on a single GPU](https://myrtle.ai/learn/how-to-train-your-resnet/) -- 94% accuracy in ~<18.1 seconds on an A100 GPU. There is only one primary functional difference that I am aware of. The intended structure of the code is a flat structure intended for quick hacking in practically _any_ (!!!) stage of the training pipeline. This code took about 120-130 hours of work from start to finish, and about about 80-90+ of those hours were mind-numbingly tedious debugging of performance differences between my work David's original work. It was somewhat surprising in places which things really mattered, and which did not. To that end, I found it very educational to write (and may do a writeup someday if enough people and I have enough interest in it).


-I built this because I loved David's work but for for my personal experimentation, his nearly-purely-functional style made implementing radical idea sketches nearly impossible. As a complement to his work, this code is in a single file and extremely flat, but is not as durable for long-term production-level bug maintenance. You're meant to check out a fresh repo whenever you have a new idea. The upside is that since making this repository, I've already gone from idea-to-new-single-GPU-world-record in under 10 minutes for one idea, and maybe under an hourish for doing the same thing a second, different idea as well. I personally find this code a delight to use, and hope you do too! :D Please let me know, whichever way it ends up going for you. I hope to publish those updates in the future, but for now, this is a (relatively) accurate baseline.
+I built this because I loved David's work but found it difficult for my quick-experiment-and-hacking usecases. As a complement to his work, this code is in a single file and extremely flat, but is not as durable for long-term production-level bug maintenance. You're meant to check out a fresh repo whenever you have a new idea. The upside for me in this repository is that I've already been able to explore a wide variety of ideas rapidly, some of which already improve over the baseline (hopefully more of that in future releases). I truly enjoy personally using this code, and hope you do as well! :D Please let me know if you have any feedback. I hope to continue publishing updates to this in the future, but for now, this is a (relatively) accurate baseline.


 Your support helps a lot -- even if it's a dollar as month. I have several more projects I'm in various stages on, and you can help me have the money and time to get them to the finish line! If you like what I'm doing, or this project has brought you some value, please consider subscribing on my [Patreon](https://www.patreon.com/user/posts?u=83632131). There's not too many extra rewards besides better software more frequently. Alternatively, if you want me to work up to a part-time amount of hours with you, feel free to reach out to me at [email protected]. I'd love to hear from you.
@@ -49,4 +49,4 @@ Currently, submissions to this codebase as a benchmark are closed as we figure o

 #### Bugs & Etc.

-If you find a bug, open an issue! L:D If you have a success story, let me know! It helps me understand what works and doesn't more than you might expect -- if I know how this is specifically helping people, that can help me further improve as a developer, as I can keep that in mind when developing other software for people in the future. :D :)
+If you find a bug, open an issue! L:D If you have a success story, let me know! It helps me understand what works and doesn't more than you might expect -- if I know how this is specifically helping people, that can help me further improve as a developer, as I can keep that in mind when developing other software for people in the future. :D :)

main.py

Lines changed: 48 additions & 33 deletions
@@ -1,4 +1,5 @@
-# Bug: Currently we're having some trouble with differentiating between a colab notebook and a .py file, so right now you'll just need to uncomment the below if you're wanting to properly clear your variable state in your colab notebook each time.
+# BUG: Currently, I haven't found a good way to distinguish between python and ipython, so we're leaving the 'colab mode' to require a manual
+# uncomment of this code to fully guard against state errors. This lets people run the one-line download-and-train command appropriately.
 """
 # If we are in an ipython session or a notebook, clear the state to avoid bugs
 try:
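As an aside on the BUG note in this hunk, below is one hedged sketch of how an IPython/Colab session could be detected at runtime. It is illustrative only and is not code from this commit (the commit keeps the reset block commented out and asks you to uncomment it manually); the `get_ipython()` check and the `%reset -f` magic call are assumptions about one possible mechanism.

```python
# Hypothetical sketch (not part of this commit): clear notebook state only when
# actually running under IPython/Colab, so a plain `python main.py` run is unaffected.
try:
    from IPython import get_ipython
    shell = get_ipython()  # returns None under a plain `python` interpreter
    if shell is not None:
        shell.run_line_magic('reset', '-f')  # wipe the user namespace to avoid stale-state bugs
except ImportError:
    pass  # IPython isn't installed, so we can't be in a notebook session
```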
@@ -17,11 +18,6 @@
 import torch.nn.functional as F
 from torch import nn

-import rich
-from rich.live import Live
-from rich.progress import Progress, Console, track
-from rich.table import Table
-
 import torchvision
 from torchvision import transforms

@@ -457,11 +453,26 @@ def init_split_parameter_dictionaries(network):
 ## Hey look, it's the soft-targets/label-smoothed loss! Native to PyTorch. Now, _that_ is pretty cool, and simplifies things a lot, to boot! :D :)
 loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2, reduction='none')

-console = Console()
-progress_bar = Progress()
-progress_table = Table()
-for column_name in ['epoch', 'train_loss', 'val_loss', 'train_acc', 'val_acc', 'val_ema_acc', 'timing (s)']:
-    progress_table.add_column(column_name)
+logging_columns_list = ['epoch', 'train_loss', 'val_loss', 'train_acc', 'val_acc', 'ema_val_acc', 'total_time_seconds']
+# define the printing function and print the column heads
+def print_training_details(columns_list, separator_left='| ', separator_right=' ', final="|", column_heads_only=False, is_final_entry=False):
+    print_string = ""
+    if column_heads_only:
+        for column_head_name in columns_list:
+            print_string += separator_left + column_head_name + separator_right
+        print_string += final
+        print('-'*(len(print_string))) # print the top bar
+        print(print_string)
+        print('-'*(len(print_string))) # print the bottom bar
+    else:
+        for column_value in columns_list:
+            print_string += separator_left + column_value + separator_right
+        print_string += final
+        print(print_string)
+        if is_final_entry:
+            print('-'*(len(print_string))) # print the final output bar
+
+print_training_details(logging_columns_list, column_heads_only=True) # print out the training column heads.

 ########################################
 # Train and Eval #
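For readers skimming the diff, here is a quick sense of what the `column_heads_only` path above prints at startup. The snippet is standalone and reuses the column names from the hunk, but it is an illustration rather than code from the commit.

```python
# Illustration only: rebuild the header the new logger prints -- a bar, the
# pipe-separated column names, and another bar.
logging_columns_list = ['epoch', 'train_loss', 'val_loss', 'train_acc', 'val_acc', 'ema_val_acc', 'total_time_seconds']

header = "".join('| ' + name + ' ' for name in logging_columns_list) + "|"
print('-' * len(header))
print(header)
print('-' * len(header))
```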
@@ -471,13 +482,13 @@ def init_split_parameter_dictionaries(network):
 def main():
     # Initializing constants for the whole run.
     net_ema = None ## Reset any existing network emas, we want to have _something_ to check for existence so we can initialize the EMA right from where the network is during training
-    ## (as opposed to initializing the network_ema from the randomly-initialized starter network, then forcing it to play catch-up all of a sudden in the last several epochs)
+    ## (as opposed to initializing the network_ema from the randomly-initialized starter network, then forcing it to play catch-up all of a sudden in the last several epochs)

-    total_time = 0.
+    total_time_seconds = 0.
     current_steps = 0.

     # TODO: Doesn't currently account for partial epochs really (since we're not doing "real" epochs across the whole batchsize)....
-    num_steps_per_epoch = len(data['train']['images']) // batchsize # todo: a tad bit of cleanup here. :) :>
+    num_steps_per_epoch = len(data['train']['images']) // batchsize # todo: a bit of a tad of cleanup here. ::::))) :>>>
     total_train_steps = num_steps_per_epoch * hyp['misc']['train_epochs']
     ema_epoch_start = hyp['misc']['train_epochs'] - hyp['misc']['ema']['epochs']
     num_low_lr_steps_for_ema = hyp['misc']['ema']['epochs'] * num_steps_per_epoch
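The `net_ema = None` comment above carries the key design point: the EMA copy is created lazily from the partially trained weights rather than from the random init, so it never has to "catch up". A hedged sketch of that pattern follows; it is illustrative only, and the function name, decay value, and update rule are assumptions rather than this repo's actual EMA code.

```python
# Illustrative sketch (not this repo's code): lazily initialize an EMA copy from the
# current weights the first time it is needed, then blend on every later call.
import copy
import torch

def update_ema(net, net_ema, decay=0.99):
    if net_ema is None:
        # First EMA step: clone the *current*, partially trained network so the EMA
        # starts exactly where training already is, not at the random initialization.
        return copy.deepcopy(net).eval()
    with torch.no_grad():
        for p_ema, p in zip(net_ema.parameters(), net.parameters()):
            p_ema.mul_(decay).add_(p, alpha=1. - decay)
    return net_ema
```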
@@ -506,18 +517,16 @@ def main():
     lr_sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=non_bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps-num_low_lr_steps_for_ema, anneal_strategy='linear', cycle_momentum=False)
     lr_sched_bias = torch.optim.lr_scheduler.OneCycleLR(opt_bias, max_lr=bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps-num_low_lr_steps_for_ema, anneal_strategy='linear', cycle_momentum=False)

-
     ## For accurately timing GPU code
     starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
     ## There's another repository that's mainly reorganized David's code while still maintaining some of the functional structure, and it
     ## has a timing feature too, but there's no synchronizes so I suspect the times reported are much faster than they may be in actuality
     ## due to some of the quirks of timing GPU operations.
     torch.cuda.synchronize() ## clean up any pre-net setup operations

-    # Hack -- disabling the below until we figure out how to properly nest these tables during training.
-    #with Live(progress_table, refresh_per_second=1):
-    if True:
-        for epoch in track(range(hyp['misc']['train_epochs'])):
+
+    if True: ## Sometimes we need a conditional/for loop here, this is placed to save the trouble of needing to indent
+        for epoch in range(hyp['misc']['train_epochs']):
             #################
             # Training Mode #
             #################
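The comment block in this hunk about missing synchronizes is the reason for the explicit `torch.cuda.synchronize()` calls: CUDA kernels launch asynchronously, so timing without synchronization can under-report. Below is a minimal, self-contained sketch of the same event-based timing pattern; the matmul is just a stand-in workload and a CUDA device is assumed.

```python
# Sketch of CUDA-event timing with explicit synchronization (illustrative workload only).
import torch

assert torch.cuda.is_available()
starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)

x = torch.randn(4096, 4096, device='cuda')
torch.cuda.synchronize()          # drain any setup work before starting the clock
starter.record()
y = x @ x                         # stand-in for a chunk of training work
ender.record()
torch.cuda.synchronize()          # wait until both recorded events have actually completed
total_time_seconds = 1e-3 * starter.elapsed_time(ender)  # elapsed_time() returns milliseconds
print(f"{total_time_seconds:.4f} seconds")
```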
@@ -539,8 +548,8 @@ def main():

                 # we only take the last-saved accs and losses from train
                 if epoch_step % 50 == 0:
-                    accuracy_train = (outputs.detach().argmax(-1) == targets).float().mean().item()
-                    loss_train = loss.cpu().item()/batchsize
+                    train_acc = (outputs.detach().argmax(-1) == targets).float().mean().item()
+                    train_loss = loss.detach().cpu().item()/batchsize

                 loss.backward()

@@ -567,7 +576,7 @@ def main():

             ender.record()
             torch.cuda.synchronize()
-            total_time += 1e-3 * starter.elapsed_time(ender)
+            total_time_seconds += 1e-3 * starter.elapsed_time(ender)

             ####################
             # Evaluation Mode #
@@ -579,7 +588,7 @@ def main():
             loss_list_val, acc_list, acc_list_ema = [], [], []

             with torch.no_grad():
-                # TODO: Copy is probably slow, we can def avoid this I think....
+                # TODO: Copy is probably slow, we can def avoid this somehow, I think....
                 for inputs, targets in get_batches(data, key='eval', batchsize=eval_batchsize):
                     if epoch >= ema_epoch_start:
                         outputs = net_ema(inputs)
@@ -588,18 +597,24 @@ def main():
                     loss_list_val.append(loss_fn(outputs, targets).float().mean())
                     acc_list.append((outputs.argmax(-1) == targets).float().mean())

-                accuracy_non_ema = torch.stack(acc_list).mean().item()
-                accuracy_ema = None
+                val_acc = torch.stack(acc_list).mean().item()
+                ema_val_acc = None
                 # TODO: We can fuse these two operations (just above and below) all-together like :D :))))
                 if epoch >= ema_epoch_start:
-                    accuracy_ema = torch.stack(acc_list_ema).mean().item()
-
-                loss_val = torch.stack(loss_list_val).mean().item()
-
-                format_for_table = lambda x: "{0:.4f}".format(x) if x is not None else None
-                progress_table.add_row(*list(map(format_for_table, [epoch, loss_train, loss_val, accuracy_train, accuracy_non_ema, accuracy_ema, total_time])))
-                #progress_bar.advance() ## Now that we've finish everything, we update our eta (and we do it last so the ETA is more accurate.)
-                console.print(progress_table)
+                    ema_val_acc = torch.stack(acc_list_ema).mean().item()
+
+                val_loss = torch.stack(loss_list_val).mean().item()
+                # We basically need to look up local variables by name so we can have the names, so we can pad to the proper column width.
+                ## Printing stuff in the terminal can get tricky and this used to use an outside library, but some of the required stuff seemed even
+                ## more heinous than this, unfortunately. So we switched to the "more simple" version of this!
+                format_for_table = lambda x, locals: (f"{locals[x]}".rjust(len(x))) \
+                                       if type(locals[x]) == int else "{:0.4f}".format(locals[x]).rjust(len(x)) \
+                                       if locals[x] is not None \
+                                       else " "*len(x)
+
+                # Print out our training details (sorry for the complexity, the whole logging business here is a bit of a hot mess once the columns need to be aligned and such....)
+                ## We also check to see if we're in our final epoch so we can print the 'bottom' of the table for each round.
+                print_training_details(list(map(partial(format_for_table, locals=locals()), logging_columns_list)), is_final_entry=(epoch == hyp['misc']['train_epochs'] - 1))

 if __name__ == "__main__":
     for run_num in range(5):
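The "look up local variables by name ... pad to the proper column width" comment is the crux of the new `format_for_table` lambda: the column names double as variable names, so a `locals()` snapshot yields both the value and the header width to right-justify against. The snippet below is a small self-contained illustration of that idea with made-up names and values; it is not the commit's exact code.

```python
# Illustration only: format a row by looking each column name up in a locals() snapshot
# and right-justifying the value to the width of its column header.
def report(epoch, val_acc, total_time_seconds):
    columns = ['epoch', 'val_acc', 'total_time_seconds']
    scope = locals()  # snapshot of this function's local names -> values
    fmt = lambda name: (str(scope[name]) if isinstance(scope[name], int)
                        else "{:0.4f}".format(scope[name])).rjust(len(name))
    print("".join('| ' + fmt(name) + ' ' for name in columns) + "|")

report(epoch=9, val_acc=0.9312, total_time_seconds=16.84)
```

Note that the snapshot has to be taken in the scope that owns the variables; the commit does the equivalent by evaluating `locals()` once at the call site and binding it with `partial(format_for_table, locals=locals())`.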
