Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions unsloth_zoo/fused_losses/cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,15 @@ def forward(
device = lm_head_weight.device
if extra_kwargs is None: extra_kwargs = {}

# Fix for multi-GPU: move hidden_states and labels onto lm_head_weight's device
# before computing, because torch.func.grad_and_value fails when its tensors
# are on different devices. The device hidden_states arrived on is remembered
# so that its gradient can be returned on that same device.
original_hidden_states_device = hidden_states.device
if hidden_states.device != device:
hidden_states = hidden_states.to(device)
if labels.device != device:
labels = labels.to(device)

# Get shifted labels first
if shift_labels:
_labels = torch.empty_like(labels, device = device)
Expand Down Expand Up @@ -328,6 +337,7 @@ def accumulate_chunk(
pass
ctx.save_for_backward(grad_inputs, grad_lm_head, grad_lm_head_bias)
ctx.scaling = scaling
ctx.original_hidden_states_device = original_hidden_states_device
return accumulated_loss
pass

Expand All @@ -338,6 +348,10 @@ def backward(ctx, grad_output,):
scaling = ctx.scaling if ctx.scaling is not None else 1.0
torch._assert(torch.all(grad_output == scaling), f"Fused losses expect grad_output to be all {scaling}, but got {grad_output.ravel()[:10]}")
(grad_inputs, grad_lm_head, grad_lm_head_bias, ) = ctx.saved_tensors
# Fix for multi-GPU: move grad_inputs back to the device hidden_states was on
# when forward was called (recorded on ctx), so the gradient matches its input.
original_device = ctx.original_hidden_states_device
if grad_inputs.device != original_device:
grad_inputs = grad_inputs.to(original_device)
return (None, grad_inputs, grad_lm_head, grad_lm_head_bias, None, None, None, None, None, None, None, None, None,)
pass
pass
Expand Down
3 changes: 2 additions & 1 deletion unsloth_zoo/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def chunked_hidden_states_selective_log_softmax(
all_per_token_logps = []

for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
# Fix for multi-GPU: move the hidden-state chunk onto lm_head's device (and dtype)
# before the matmul, since the two may live on different GPUs.
chunk_logits = chunk_hidden_states.to(lm_head.device).to(lm_head.dtype) @ lm_head.t()

if logit_scale_multiply != 0.0:
chunk_logits = chunk_logits * logit_scale_multiply
Expand Down