19 changes: 13 additions & 6 deletions slangpy/slang/atomics.slang
@@ -105,15 +105,22 @@ public extension<T: IAtomicAddable> StorageTraits<T>
{
#ifdef __TARGET_CUDA__
public typealias AtomicBufferType = T*;
public static void atomicAdd(AtomicBufferType buffer, uint idx, T value) { T::atomicAdd(buffer, idx, value); }
public static T atomicLoad(AtomicBufferType buffer, uint idx) { return buffer[idx]; }
public static void atomicStore(AtomicBufferType buffer, uint idx, T value) { buffer[idx] = value; }
// Stride-aware versions (stride parameter ignored on CUDA since it uses element indexing)
public static void atomicAddWithStride(AtomicBufferType buffer, uint idx, uint byte_stride, T value) { T::atomicAdd(buffer, idx, value); }
public static T atomicLoadWithStride(AtomicBufferType buffer, uint idx, uint byte_stride) { return buffer[idx]; }
public static void atomicStoreWithStride(AtomicBufferType buffer, uint idx, uint byte_stride, T value) { buffer[idx] = value; }
#else
public typealias AtomicBufferType = RWByteAddressBuffer;
public static void atomicAdd(AtomicBufferType buffer, uint idx, T value) { T::atomicAdd(buffer, idx*sizeof(T), value); }
public static T atomicLoad(AtomicBufferType buffer, uint idx) { return buffer.Load<T>(idx*sizeof(T));}
public static void atomicStore(AtomicBufferType buffer, uint idx, T value) { buffer.Store<T>(idx*sizeof(T), value); }
// Stride-aware versions that use explicit byte stride instead of sizeof(T)
// This fixes alignment issues on Metal where sizeof(float3)=12 but buffer stride=16
public static void atomicAddWithStride(AtomicBufferType buffer, uint idx, uint byte_stride, T value) { T::atomicAdd(buffer, idx*byte_stride, value); }
public static T atomicLoadWithStride(AtomicBufferType buffer, uint idx, uint byte_stride) { return buffer.Load<T>(idx*byte_stride);}
public static void atomicStoreWithStride(AtomicBufferType buffer, uint idx, uint byte_stride, T value) { buffer.Store<T>(idx*byte_stride, value); }
#endif
// Non-stride versions call strided versions with sizeof(T) as default stride
public static void atomicAdd(AtomicBufferType buffer, uint idx, T value) { atomicAddWithStride(buffer, idx, sizeof(T), value); }
public static T atomicLoad(AtomicBufferType buffer, uint idx) { return atomicLoadWithStride(buffer, idx, sizeof(T)); }
public static void atomicStore(AtomicBufferType buffer, uint idx, T value) { atomicStoreWithStride(buffer, idx, sizeof(T), value); }
}

// 2xfloat16 -> uint tricks
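The new `*WithStride` entry points replace the hard-coded `idx*sizeof(T)` byte offset with an explicit `byte_stride` supplied by the caller. A minimal sketch of the arithmetic (plain Python, not SlangPy API; the 12/16-byte numbers are the float3 case called out in the comments above):

```python
# Why idx * sizeof(T) breaks on Metal: float3 holds 12 bytes of data, but a Metal
# buffer stores it with a 16-byte stride, so element i actually starts at byte i*16.
SIZEOF_FLOAT3 = 12        # three 32-bit floats
METAL_STRIDE_FLOAT3 = 16  # float3 rounded up to its 16-byte alignment

def offset_with_sizeof(idx: int) -> int:
    return idx * SIZEOF_FLOAT3   # old atomicAdd/atomicLoad/atomicStore behaviour

def offset_with_stride(idx: int, byte_stride: int) -> int:
    return idx * byte_stride     # new *WithStride behaviour

for idx in range(3):
    print(idx, offset_with_sizeof(idx), offset_with_stride(idx, METAL_STRIDE_FLOAT3))
# 0  0  0
# 1 12 16   <- from element 1 onward the old offset lands inside the previous element's padding
# 2 24 32
```

On CUDA the buffer is a typed pointer and indexing is per element, which is why the stride parameter is simply ignored on that path.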
22 changes: 17 additions & 5 deletions slangpy/slang/tensor.slang
@@ -365,17 +365,26 @@ public struct AtomicTensor<T: IAtomicAddable, let D : int> : IRWTensor<T, D>
// Underlying data storage
public StorageTraits<T>::AtomicBufferType _data;

#ifdef __TARGET_METAL__
// Element byte stride - passed from C++ on Metal where sizeof(T) differs from buffer stride.
// For example, sizeof(float3)=12 but Metal buffer stride=16.
public uint _element_byte_stride;
#else
// On non-Metal platforms, sizeof(T) matches buffer stride, so use static const.
static const uint _element_byte_stride = sizeof(T);
#endif

public T read_buffer(int idx)
{
return StorageTraits<T>::atomicLoad(_data, idx);
return StorageTraits<T>::atomicLoadWithStride(_data, idx, _element_byte_stride);
}
public void write_buffer(int idx, T value)
{
StorageTraits<T>::atomicStore(_data, idx, value);
StorageTraits<T>::atomicStoreWithStride(_data, idx, _element_byte_stride, value);
}
public void accumulate_buffer(int idx, T value)
{
StorageTraits<T>::atomicAdd(_data, idx, value);
StorageTraits<T>::atomicAddWithStride(_data, idx, _element_byte_stride, value);
}
[ForceInline] void _accumulate_each<each I : __BuiltinIntegerType>(T value, expand each I idx)
{
@@ -392,11 +401,11 @@ public struct AtomicTensor<T: IAtomicAddable, let D : int> : IRWTensor<T, D>

[ForceInline] public void add<I : __BuiltinIntegerType>(I idx[D], T value)
{
StorageTraits<T>::atomicAdd(_data, _idx(idx, _strides, _offset), value);
StorageTraits<T>::atomicAddWithStride(_data, _idx(idx, _strides, _offset), _element_byte_stride, value);
}
[ForceInline] public void add<I : __BuiltinIntegerType>(vector<I, D> idx, T value)
{
StorageTraits<T>::atomicAdd(_data, _idx(idx, _strides, _offset), value);
StorageTraits<T>::atomicAddWithStride(_data, _idx(idx, _strides, _offset), _element_byte_stride, value);
}

public __subscript<I : __BuiltinIntegerType>(I indices[D])->T
Expand All @@ -417,6 +426,9 @@ public struct AtomicTensor<T: IAtomicAddable, let D : int> : IRWTensor<T, D>
public void __slangpy_load<let SliceD : int>(ContextND<D - SliceD> ctx, out AtomicTensor<T, SliceD> value)
{
value._data = _data;
#ifdef __TARGET_METAL__
value._element_byte_stride = _element_byte_stride;
#endif
_slice(ctx.call_id, _strides, _offset, value._strides, value._offset);
[ForceUnroll]
for (int i = 0; i < SliceD; ++i)
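On Metal, `_element_byte_stride` is a runtime field filled in from C++; on other targets it stays a `static const sizeof(T)`, and `__slangpy_load` copies it into the sliced tensor. A hedged Python model of that bookkeeping (hypothetical class, not the real Slang type), just to show why the stride has to travel with every view:

```python
from dataclasses import dataclass

@dataclass
class AtomicTensorView:
    """Hypothetical stand-in for AtomicTensor<T, D>; models only the addressing."""
    element_byte_stride: int  # runtime value on Metal, sizeof(T) elsewhere
    offset: int = 0           # element offset, as in _offset

    def byte_offset(self, element_index: int) -> int:
        # Mirrors atomic*WithStride: element index -> byte address.
        return (self.offset + element_index) * self.element_byte_stride

    def slice(self, start: int) -> "AtomicTensorView":
        # Mirrors __slangpy_load: the child view inherits the stride; if it fell back
        # to sizeof(T), a sliced float3 tensor on Metal would read the wrong bytes.
        return AtomicTensorView(self.element_byte_stride, self.offset + start)

view = AtomicTensorView(element_byte_stride=16)  # float3 on Metal
assert view.slice(4).byte_offset(0) == 64        # 4 * 16, not 4 * 12
```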
19 changes: 3 additions & 16 deletions slangpy/tests/slangpy_tests/test_differential_function_call.py
@@ -216,14 +216,8 @@ def test_vec3_call_with_buffers(device_type: DeviceType):

kernel_eval_polynomial.bwds(a, b, res)

# TODO: https://github.com/shader-slang/slangpy/issues/118
# We use ByteAddressBuffer to store the out grads, however, in the shader code, we use
# `sizeof(T)` to calculate the offset of each element, which is wrong because sizeof(T)
# is not guaranteed to be aligned on metal target. So we will just read the raw data back.
# The WAR solution is to provide a element_stride to shader. Slang will add intrinsic to
# calculate the aligned stride in shader code.
a_grad_data = a.grad.storage.to_numpy().view(np.float32)[0 : 32 * 3].reshape(-1, 3)
b_grad_data = b.grad.storage.to_numpy().view(np.float32)[0 : 32 * 3].reshape(-1, 3)
a_grad_data = helpers.read_tensor_from_numpy(a.grad).reshape(-1, 3)
b_grad_data = helpers.read_tensor_from_numpy(b.grad).reshape(-1, 3)

exprected_grad = python_eval_polynomial_a_deriv(a_data, b_data)
assert np.allclose(a_grad_data, exprected_grad)
@@ -289,14 +283,7 @@ def test_vec3_call_with_buffers_soa(device_type: DeviceType):
a_z_grad_data = a_z.grad.storage.to_numpy().view(np.float32).reshape(-1, 1)

a_grad_data = np.column_stack((a_x_grad_data, a_y_grad_data, a_z_grad_data))

# TODO: https://github.com/shader-slang/slangpy/issues/118
# We use ByteAddressBuffer to store the out grads, however, in the shader code, we use
# `sizeof(T)` to calculate the offset of each element, which is wrong because sizeof(T)
# is not guaranteed to be aligned on metal target. So we will just read the raw data back.
# The WAR solution is to provide a element_stride to shader. Slang will add intrinsic to
# calculate the aligned stride in shader code.
b_grad_data = b.grad.storage.to_numpy().view(np.float32)[0 : 32 * 3]
b_grad_data = helpers.read_tensor_from_numpy(b.grad).reshape(-1, 3)

exprected_grad = python_eval_polynomial_a_deriv(a_data, b_data)
assert np.allclose(a_grad_data, exprected_grad)
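The tests now read gradients back through `helpers.read_tensor_from_numpy` instead of slicing the raw storage by hand. Its implementation is not part of this diff; the sketch below is only a guess at what such a helper might do for float3 gradients (the function name, signature, and details here are assumptions, not the real `helpers` module):

```python
import numpy as np

def read_vec3_tensor(tensor, element_count: int) -> np.ndarray:
    """Hypothetical sketch: read the backing storage and strip any per-element padding.
    Assumes tensor.storage.to_numpy() returns the raw bytes, as in the old test code."""
    raw = tensor.storage.to_numpy().view(np.float32)
    floats_per_element = raw.size // element_count  # 3 when tightly packed, 4 with a 16-byte Metal stride
    return raw.reshape(element_count, floats_per_element)[:, :3]
```

Whatever the real helper does internally, the change moves the stride-aware readback into one shared place instead of repeating the workaround (and its TODO comment) in every test.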
25 changes: 25 additions & 0 deletions src/slangpy_ext/utils/slangpytensor.cpp
@@ -162,6 +162,12 @@ NativeTensorMarshall::TensorFieldOffsets NativeTensorMarshall::extract_tensor_fi
offsets.shape = tensor_cursor["_shape"].offset();
offsets.strides = tensor_cursor["_strides"].offset();
offsets.offset = tensor_cursor["_offset"].offset();

// Extract element_byte_stride offset if present (for AtomicTensor on Metal)
ShaderCursor ebs_field = tensor_cursor.find_field("_element_byte_stride");
if (ebs_field.is_valid())
offsets.element_byte_stride = ebs_field.offset();

offsets.is_valid = true;
offsets.array_stride
= (int)tensor_cursor["_shape"].slang_type_layout()->getElementStride(SLANG_PARAMETER_CATEGORY_UNIFORM);
@@ -488,6 +494,17 @@ void NativeTensorMarshall::write_tensor_fields_from_buffer(
offsets.offset.uniform_offset - m_cached_offsets.field_offset.uniform_offset,
offset
);

// Write element byte stride if field exists (for AtomicTensor on Metal)
// This is needed because sizeof(T) in shader may differ from buffer stride
// due to alignment requirements (e.g., sizeof(float3)=12 but Metal buffer stride=16)
if (offsets.element_byte_stride.is_valid()) {
write_value_helper(
base_address,
offsets.element_byte_stride.uniform_offset - m_cached_offsets.field_offset.uniform_offset,
static_cast<uint32_t>(buffer->desc().struct_size)
);
}
}

void NativeTensorMarshall::write_tensor_fields_from_pointer(
@@ -530,6 +547,14 @@ void NativeTensorMarshall::write_tensor_fields_from_pointer(
offsets.offset.uniform_offset - m_cached_offsets.field_offset.uniform_offset,
offset
);

// Note: element_byte_stride is not written here for PyTorch tensors.
// On CUDA (the only backend PyTorch supports), _element_byte_stride is static const in shader.
// This field only exists as a runtime field on Metal (for AtomicTensor), and PyTorch doesn't support Metal.
SGL_CHECK(
!offsets.element_byte_stride.is_valid(),
"Unexpected element_byte_stride field for PyTorch tensor - this path should only be used on CUDA"
);
}

void NativeTensorMarshall::write_native_tensor_fields(
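The value written into `_element_byte_stride` is the per-element stride recorded in the buffer descriptor (`struct_size`); on Metal that is the aligned element size, 16 bytes for float3, rather than `sizeof`. The rounding that produces it is ordinary align-up arithmetic; a small illustration of that arithmetic only (not sgl's implementation):

```python
def align_up(size: int, alignment: int) -> int:
    """Round size up to the next multiple of alignment."""
    return (size + alignment - 1) // alignment * alignment

# float3 carries 12 bytes of data but has 16-byte alignment on Metal -> 16-byte buffer stride.
assert align_up(12, 16) == 16
# float, float2 and float4 are already stride-aligned, so sizeof(T) and the stride agree
# and the non-Metal static-const path is unaffected.
assert align_up(4, 4) == 4 and align_up(8, 8) == 8 and align_up(16, 16) == 16
```

The PyTorch path deliberately asserts that the field is absent: CUDA is the only backend PyTorch runs on, and there `_element_byte_stride` compiles away to a `static const`.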
11 changes: 6 additions & 5 deletions src/slangpy_ext/utils/slangpytensor.h
@@ -131,11 +131,12 @@ class NativeTensorMarshall : public NativeMarshall {
/// Public so NativeTorchTensorMarshall can reuse them
struct TensorFieldOffsets {
int array_stride;
ShaderOffset data; // Offset for _data field
ShaderOffset shape; // Offset for _shape field
ShaderOffset strides; // Offset for _strides field
ShaderOffset offset; // Offset for _offset field
bool is_valid = false; // Whether offsets have been initialized
ShaderOffset data; // Offset for _data field
ShaderOffset shape; // Offset for _shape field
ShaderOffset strides; // Offset for _strides field
ShaderOffset offset; // Offset for _offset field
ShaderOffset element_byte_stride; // Offset for _element_byte_stride field (if present)
bool is_valid = false; // Whether offsets have been initialized
};

/// Cached offsets for all tensor variants (primal, grad_in, grad_out)