Skip to content

Commit 8819d65

Browse files
committed
feat: Added compute shader support for vulkan
Vulkan graphics backend has been modified to support compute shaders, additional modifications were also made to the shader compiler so that correct glsl compute shaders can be generated.
1 parent d64fedd commit 8819d65

File tree

12 files changed

+458
-225
lines changed

12 files changed

+458
-225
lines changed

sources/engine/Stride.Graphics/Vulkan/CommandList.Vulkan.cs

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,9 @@ private unsafe void FlushInternal(bool wait)
130130

131131
if (activePipeline != null)
132132
{
133-
vkCmdBindPipeline(currentCommandList.NativeCommandBuffer, VkPipelineBindPoint.Graphics, activePipeline.NativePipeline);
133+
vkCmdBindPipeline(currentCommandList.NativeCommandBuffer, activePipeline.IsCompute ? VkPipelineBindPoint.Compute : VkPipelineBindPoint.Graphics, activePipeline.NativePipeline);
134134
var descriptorSetCopy = descriptorSet;
135-
vkCmdBindDescriptorSets(currentCommandList.NativeCommandBuffer, VkPipelineBindPoint.Graphics, activePipeline.NativeLayout, firstSet: 0, descriptorSetCount: 1, &descriptorSetCopy, dynamicOffsetCount: 0, dynamicOffsets: null);
135+
vkCmdBindDescriptorSets(currentCommandList.NativeCommandBuffer, activePipeline.IsCompute ? VkPipelineBindPoint.Compute : VkPipelineBindPoint.Graphics, activePipeline.NativeLayout, firstSet: 0, descriptorSetCount: 1, &descriptorSetCopy, dynamicOffsetCount: 0, dynamicOffsets: null);
136136
}
137137
SetRenderTargetsImpl(depthStencilBuffer, renderTargetCount, renderTargets);
138138
}
@@ -249,7 +249,11 @@ private unsafe void PrepareDraw()
249249

250250
// Lazily set the render pass and frame buffer
251251
EnsureRenderPass();
252+
BindDescriptorSets();
253+
}
252254

255+
private unsafe void BindDescriptorSets()
256+
{
253257
// Keep track of descriptor pool usage
254258
bool isPoolExhausted = ++allocatedSetCount > GraphicsDevice.MaxDescriptorSetCount;
255259
for (int i = 0; i < DescriptorSetLayout.DescriptorTypeCount; i++)
@@ -328,18 +332,28 @@ private unsafe void PrepareDraw()
328332
sType = VkStructureType.WriteDescriptorSet,
329333
descriptorType = mapping.DescriptorType,
330334
dstSet = localDescriptorSet,
331-
dstBinding = (uint) mapping.DestinationBinding,
335+
dstBinding = (uint)mapping.DestinationBinding,
332336
dstArrayElement = 0,
333337
descriptorCount = 1
334338
};
335339

336340
switch (mapping.DescriptorType)
337341
{
338342
case VkDescriptorType.SampledImage:
339-
var texture = heapObject.Value as Texture;
340-
descriptorData->ImageInfo = new VkDescriptorImageInfo { imageView = texture?.NativeImageView ?? GraphicsDevice.EmptyTexture.NativeImageView, imageLayout = VkImageLayout.ShaderReadOnlyOptimal };
341-
write->pImageInfo = &descriptorData->ImageInfo;
342-
break;
343+
{
344+
var texture = heapObject.Value as Texture;
345+
descriptorData->ImageInfo = new VkDescriptorImageInfo { imageView = texture?.NativeImageView ?? GraphicsDevice.EmptyTexture.NativeImageView, imageLayout = VkImageLayout.ShaderReadOnlyOptimal };
346+
write->pImageInfo = &descriptorData->ImageInfo;
347+
break;
348+
}
349+
350+
case VkDescriptorType.StorageImage:
351+
{
352+
var texture = heapObject.Value as Texture;
353+
descriptorData->ImageInfo = new VkDescriptorImageInfo { imageView = texture?.NativeImageView ?? GraphicsDevice.EmptyTexture.NativeImageView, imageLayout = VkImageLayout.General };
354+
write->pImageInfo = &descriptorData->ImageInfo;
355+
break;
356+
}
343357

344358
case VkDescriptorType.Sampler:
345359
var samplerState = heapObject.Value as SamplerState;
@@ -349,7 +363,7 @@ private unsafe void PrepareDraw()
349363

350364
case VkDescriptorType.UniformBuffer:
351365
var buffer = heapObject.Value as Buffer;
352-
descriptorData->BufferInfo = new VkDescriptorBufferInfo { buffer = buffer?.NativeBuffer ?? VkBuffer.Null, offset = (ulong) heapObject.Offset, range = (ulong) heapObject.Size };
366+
descriptorData->BufferInfo = new VkDescriptorBufferInfo { buffer = buffer?.NativeBuffer ?? VkBuffer.Null, offset = (ulong)heapObject.Offset, range = (ulong)heapObject.Size };
353367
write->pBufferInfo = &descriptorData->BufferInfo;
354368
break;
355369

@@ -364,9 +378,9 @@ private unsafe void PrepareDraw()
364378
}
365379
}
366380

367-
vkUpdateDescriptorSets(GraphicsDevice.NativeDevice, (uint) bindingCount, writes, descriptorCopyCount: 0, descriptorCopies: null);
381+
vkUpdateDescriptorSets(GraphicsDevice.NativeDevice, (uint)bindingCount, writes, descriptorCopyCount: 0, descriptorCopies: null);
368382
#endif
369-
vkCmdBindDescriptorSets(currentCommandList.NativeCommandBuffer, VkPipelineBindPoint.Graphics, activePipeline.NativeLayout, firstSet: 0, descriptorSetCount: 1, &localDescriptorSet, dynamicOffsetCount: 0, dynamicOffsets: null);
383+
vkCmdBindDescriptorSets(currentCommandList.NativeCommandBuffer, activePipeline.IsCompute ? VkPipelineBindPoint.Compute : VkPipelineBindPoint.Graphics, activePipeline.NativeLayout, firstSet: 0, descriptorSetCount: 1, &localDescriptorSet, dynamicOffsetCount: 0, dynamicOffsets: null);
370384
}
371385

372386
private readonly FastList<VkCopyDescriptorSet> copies = new();
@@ -390,7 +404,7 @@ public void SetPipelineState(PipelineState pipelineState)
390404

391405
activePipeline = pipelineState;
392406

393-
vkCmdBindPipeline(currentCommandList.NativeCommandBuffer, VkPipelineBindPoint.Graphics, pipelineState.NativePipeline);
407+
vkCmdBindPipeline(currentCommandList.NativeCommandBuffer, activePipeline.IsCompute ? VkPipelineBindPoint.Compute : VkPipelineBindPoint.Graphics, pipelineState.NativePipeline);
394408
}
395409

396410
public unsafe void SetVertexBuffer(int index, Buffer buffer, int offset, int stride)
@@ -446,7 +460,7 @@ public unsafe void ResourceBarrierTransition(GraphicsResource resource, Graphics
446460
case GraphicsResourceState.PixelShaderResource:
447461
texture.NativeLayout = VkImageLayout.ShaderReadOnlyOptimal;
448462
texture.NativeAccessMask = VkAccessFlags.ShaderRead;
449-
texture.NativePipelineStageMask = VkPipelineStageFlags.FragmentShader;
463+
texture.NativePipelineStageMask = VkPipelineStageFlags.FragmentShader | VkPipelineStageFlags.ComputeShader; // TODO: Not sure why I did this can probably double check ...
450464
break;
451465
case GraphicsResourceState.GenericRead:
452466
texture.NativeLayout = VkImageLayout.General;
@@ -503,6 +517,9 @@ public void SetDescriptorSets(int index, DescriptorSet[] descriptorSets)
503517
/// <inheritdoc />
504518
public void Dispatch(int threadCountX, int threadCountY, int threadCountZ)
505519
{
520+
CleanupRenderPass();
521+
BindDescriptorSets();
522+
vkCmdDispatch(currentCommandList.NativeCommandBuffer, (uint)threadCountX, (uint)threadCountY, (uint)threadCountZ);
506523
}
507524

508525
/// <summary>
@@ -512,6 +529,9 @@ public void Dispatch(int threadCountX, int threadCountY, int threadCountZ)
512529
/// <param name="offsetInBytes">The offset information bytes.</param>
513530
public void Dispatch(Buffer indirectBuffer, int offsetInBytes)
514531
{
532+
CleanupRenderPass();
533+
BindDescriptorSets();
534+
vkCmdDispatchIndirect(currentCommandList.NativeCommandBuffer, indirectBuffer.NativeBuffer, (ulong)offsetInBytes);
515535
}
516536

517537
/// <summary>

sources/engine/Stride.Graphics/Vulkan/GraphicsDevice.Vulkan.cs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ public partial class GraphicsDevice
5353
256, // Sampler
5454
0, // CombinedImageSampler
5555
512, // SampledImage
56-
0, // StorageImage
56+
32, // StorageImage
5757
64, // UniformTexelBuffer
58-
0, // StorageTexelBuffer
58+
32, // StorageTexelBuffer
5959
512, // UniformBuffer
6060
0, // StorageBuffer
6161
0, // UniformBufferDynamic
@@ -292,11 +292,22 @@ private unsafe void InitializePlatformDevice(GraphicsProfile[] graphicsProfiles,
292292
depthClamp = true,
293293
};
294294

295-
Span<VkUtf8String> supportedExtensionProperties = stackalloc VkUtf8String[]
295+
if (graphicsProfiles.Any(x => x >= GraphicsProfile.Level_11_0))
296+
{
297+
enabledFeature.shaderStorageImageReadWithoutFormat = true;
298+
enabledFeature.shaderStorageImageWriteWithoutFormat = true;
299+
}
300+
301+
var extensionProperties = vkEnumerateDeviceExtensionProperties(NativePhysicalDevice);
302+
var availableExtensionNames = new List<string>();
303+
var desiredExtensionNames = new List<string>();
304+
305+
fixed (VkExtensionProperties* extensionPropertiesPtr = extensionProperties)
296306
{
297307
VK_KHR_SWAPCHAIN_EXTENSION_NAME,
298308
VK_EXT_DEBUG_MARKER_EXTENSION_NAME,
299309
};
310+
Span<VkUtf8String> supportedExtensionProperties = stackalloc VkUtf8String[]
300311
var availableExtensionProperties = GetAvailableExtensionProperties(supportedExtensionProperties);
301312
ValidateExtensionPropertiesAvailability(availableExtensionProperties);
302313
var desiredExtensionProperties = new HashSet<VkUtf8String>

0 commit comments

Comments
 (0)