Hi, thanks for open-sourcing these batch-invariant implementations; I really appreciate this work!
However, I find that PyTorch's native log_softmax already appears to be batch-invariant. Below is my code:
```python
import torch
from batch_invariant_ops import set_batch_invariant_mode

device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu")
torch.set_default_device(device_type)

count_batch_num = 10

# Just to get the logging out of the way haha
with set_batch_invariant_mode(True):
    pass

def test_batch_invariance(dtype=torch.float32):
    B, D = 2048, 4096 * 10
    a = torch.linspace(-100, 100, B * D, dtype=dtype).reshape(B, D)
    # Method 1: slice the batch first, then apply log_softmax
    out1 = torch._log_softmax(a[:count_batch_num], dim=-1, half_to_float=False)
    # Method 2: apply log_softmax to the full batch, then slice
    out2 = torch._log_softmax(a, dim=-1, half_to_float=False)[:count_batch_num]
    # Check if results are identical
    diff = (out1 - out2).abs().max()
    return diff.item() == 0, diff

def run_iters(iters=10):
    for dtype in [torch.float32, torch.bfloat16]:
        is_deterministic = True
        difflist = []
        for i in range(iters):
            isd, df = test_batch_invariance(dtype)
            is_deterministic = is_deterministic and isd
            difflist.append(df)
        print(
            f"Batch Deterministic: {is_deterministic} "
            f"run-to-run max/min/diff {max(difflist)}/{min(difflist)}/{max(difflist) - min(difflist)} "
            f"for {dtype} in {iters} iterations"
        )

# Test with standard PyTorch (likely to show differences)
print("Standard PyTorch:")
with set_batch_invariant_mode(False):
    run_iters()

# Test with batch-invariant operations
print("\nBatch-Invariant Mode:")
with set_batch_invariant_mode(True):
    run_iters()
```
Is something wrong with my trial on an H20 96G?
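
For what it's worth, my unverified guess is that log_softmax only reduces along dim=-1, so each output row depends solely on its own input row, and the reduction order within a row need not change with the number of rows in the batch (unlike matmul, where tiling/split-K choices can vary with shape). Below is a minimal reference sketch of that row-independence; this is my own numerically stable implementation, not PyTorch's or this library's actual kernel:

```python
import torch

def rowwise_log_softmax(x: torch.Tensor) -> torch.Tensor:
    # Numerically stable log-softmax, reducing only along the last dim:
    # log_softmax(x) = (x - max) - log(sum(exp(x - max)))
    m = x.max(dim=-1, keepdim=True).values
    z = x - m
    return z - z.exp().sum(dim=-1, keepdim=True).log()

x = torch.linspace(-100, 100, 2048 * 4096).reshape(2048, 4096)
# Every value in row i depends only on row i, so slicing before or after
# should only differ if the per-row reduction strategy itself changes
# with the number of rows.
a = rowwise_log_softmax(x[:10])
b = rowwise_log_softmax(x)[:10]
print((a - b).abs().max())
```

Of course, an eager kernel could in principle still pick a different per-row reduction strategy depending on the overall shape, which is why I am asking whether the native op is expected to be batch-invariant here.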