Skip to content

Optimize loss calculation with in-place gradients calculation ~40% memory save#185

Merged
sleepcoo merged 3 commits intosgl-project:mainfrom
yubofredwang:optimize-loss-calc
Aug 27, 2025
Merged

Optimize loss calculation with in-place gradients calculation ~40% memory save#185
sleepcoo merged 3 commits intosgl-project:mainfrom
yubofredwang:optimize-loss-calc

Conversation

@yubofredwang
Copy link
Collaborator

@yubofredwang yubofredwang commented Aug 27, 2025

Motivation

The loss calculation on TTT steps of logits are taking up huge chunk of memory. According to the memory profiling, it is due to the intermediate tensors and gradients created during backward torch autograd.

Screenshot 2025-08-26 at 11 05 00 PM

??:0:torch::autograd::generated::EmbeddingBackward0::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&)

We would like to save the gradients into the logits instead of creating a separate tensor.

Modifications

Added a triton implementation of log softmax calculation with in-place mod of input logits.

Related Issues

#112

Accuracy Test

Unit test added

Benchmark & Profiling

Config (B,T,V)  PyTorch (ms)    Triton (ms)     Speedup    PyTorch Mem (GB)   Triton Mem (GB) Memory Save 
-------------------------------------------------------------------------------------------------------------------
(1,1024,32000)  449.08          435.22          1.03x      1.85               0.98            46.7%       
(1,1024,64000)  167.10          467.80          0.36x      3.68               2.81            23.4%       
(1,4096,32000)  127.67          7.03            18.15x     7.32               5.62            23.3%       
(1,4096,64000)  20.78           24.35           0.85x      14.65              11.23           23.3%       
(1,8192,32000)  20.48           13.56           1.51x      21.48              14.65           31.8%       
(1,8192,64000)  41.14           48.11           0.86x      29.30              22.46           23.3%       
(1,16384,32000) 41.11           26.95           1.53x      42.97              29.30           31.8%   

Also 50% faster

Checklist

@gemini-code-assist
Copy link
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@yubofredwang yubofredwang requested a review from zyksir August 27, 2025 06:01
@sleepcoo sleepcoo merged commit d852345 into sgl-project:main Aug 27, 2025
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants