Skip to content

Commit 445333d

Browse files
committed
Add basic Metal Shader Converter support
1 parent fdb2fad commit 445333d

9 files changed

Lines changed: 136 additions & 6 deletions

File tree

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ option(VULKAN_SUPPORT "Vulkan support" ON)
1111
cmake_dependent_option(DIRECTX_SUPPORT "DirectX 12 support" ON "WIN32" OFF)
1212
option(AGILITY_SDK_REQUIRED "Use Agility SDK" OFF)
1313
cmake_dependent_option(METAL_SUPPORT "Metal support" ON "APPLE" OFF)
14+
option(USE_METAL_SHADER_CONVERTER "Use Metal Shader Converter" OFF)
1415
option(ENABLE_VALIDATION "Enable backend graphics api validation layer" ON)
1516
option(BUILD_SAMPLES "Build samples" ON)
1617
cmake_dependent_option(BUILD_TESTING "Build unit tests" ON "NOT IOS_OR_TVOS AND NOT ANDROID" OFF)
@@ -54,6 +55,9 @@ endif()
5455
if (ENABLE_VALIDATION)
5556
add_compile_definitions(ENABLE_VALIDATION)
5657
endif()
58+
if (USE_METAL_SHADER_CONVERTER)
59+
add_compile_definitions(USE_METAL_SHADER_CONVERTER)
60+
endif()
5761

5862
if (BUILD_TESTING)
5963
enable_testing()

src/FlyCube/CMakeLists.txt

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,11 @@ list(APPEND GPUDescriptorPool
114114
)
115115

116116
list(APPEND HLSLCompiler
117-
HLSLCompiler/Compiler.h
117+
$<$<BOOL:${USE_METAL_SHADER_CONVERTER}>:HLSLCompiler/MetalIrConverterImpl.mm>
118+
$<$<BOOL:${USE_METAL_SHADER_CONVERTER}>:HLSLCompiler/MetalShaderConverter.cpp>
119+
$<$<BOOL:${USE_METAL_SHADER_CONVERTER}>:HLSLCompiler/MetalShaderConverter.h>
118120
HLSLCompiler/Compiler.cpp
121+
HLSLCompiler/Compiler.h
119122
HLSLCompiler/DXCLoader.cpp
120123
HLSLCompiler/DXCLoader.h
121124
HLSLCompiler/MSLConverter.cpp
@@ -312,13 +315,14 @@ add_library(FlyCube
312315
)
313316

314317
target_link_libraries(FlyCube
315-
$<$<BOOL:${APPLE}>:-framework\ Foundation>
316-
$<$<BOOL:${APPLE}>:-framework\ QuartzCore>
317318
$<$<AND:$<BOOL:${DIRECTX_SUPPORT}>,$<PLATFORM_ID:Windows>>:d3d12>
318319
$<$<AND:$<BOOL:${DIRECTX_SUPPORT}>,$<PLATFORM_ID:Windows>>:dxgi>
319320
$<$<AND:$<BOOL:${DIRECTX_SUPPORT}>,$<PLATFORM_ID:Windows>>:dxguid>
321+
$<$<BOOL:${APPLE}>:-framework\ Foundation>
322+
$<$<BOOL:${APPLE}>:-framework\ QuartzCore>
320323
$<$<BOOL:${METAL_SUPPORT}>:-framework\ Metal>
321324
$<$<BOOL:${METAL_SUPPORT}>:MVKPixelFormats>
325+
$<$<BOOL:${USE_METAL_SHADER_CONVERTER}>:/usr/local/lib/libmetalirconverter.dylib>
322326
$<$<BOOL:${VULKAN_SUPPORT}>:vulkan>
323327
DirectX-Headers
324328
dxc
@@ -332,6 +336,8 @@ target_link_libraries(FlyCube
332336
target_include_directories(FlyCube
333337
PUBLIC
334338
"${CMAKE_CURRENT_SOURCE_DIR}"
339+
$<$<BOOL:${USE_METAL_SHADER_CONVERTER}>:/usr/local/include/metal_irconverter_runtime>
340+
$<$<BOOL:${USE_METAL_SHADER_CONVERTER}>:/usr/local/include/metal_irconverter>
335341
)
336342

337343
if (APPLE)

src/FlyCube/CommandList/MTCommandList.mm

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
#include "Utilities/NotReached.h"
1212
#include "View/MTView.h"
1313

14+
#if defined(USE_METAL_SHADER_CONVERTER)
15+
#include <metal_irconverter_runtime.h>
16+
#endif
17+
1418
namespace {
1519

1620
MTLIndexType ConvertIndexType(gli::format format)
@@ -472,7 +476,11 @@ MTLStages ResourceStateToMTLStages(ResourceState state)
472476
{
473477
id<MTLBuffer> vertex = CastToImpl<MTResource>(resource)->GetBuffer();
474478
AddAllocation(vertex);
479+
#if defined(USE_METAL_SHADER_CONVERTER)
480+
uint32_t index = kIRVertexBufferBindPoint + slot;
481+
#else
475482
uint32_t index = device_.GetMaxPerStageBufferCount() - slot - 1;
483+
#endif
476484
[state_->argument_tables.at(ShaderType::kVertex) setAddress:vertex.gpuAddress + offset atIndex:index];
477485
}
478486

src/FlyCube/Device/MTDevice.mm

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,16 @@
2525

2626
#include <type_traits>
2727

28+
namespace {
29+
30+
#if defined(USE_METAL_SHADER_CONVERTER)
31+
constexpr ShaderBlobType kShaderBlobType = ShaderBlobType::kDXIL;
32+
#else
33+
constexpr ShaderBlobType kShaderBlobType = ShaderBlobType::kSPIRV;
34+
#endif
35+
36+
} // namespace
37+
2838
MTLCompareFunction ConvertToCompareFunction(ComparisonFunc func)
2939
{
3040
switch (func) {
@@ -183,7 +193,7 @@ MTLCompareFunction ConvertToCompareFunction(ComparisonFunc func)
183193

184194
std::shared_ptr<Shader> MTDevice::CompileShader(const ShaderDesc& desc)
185195
{
186-
return std::make_shared<MTShader>(*this, Compile(desc, ShaderBlobType::kSPIRV), ShaderBlobType::kSPIRV, desc.type);
196+
return std::make_shared<MTShader>(*this, Compile(desc, kShaderBlobType), kShaderBlobType, desc.type);
187197
}
188198

189199
std::shared_ptr<Pipeline> MTDevice::CreateGraphicsPipeline(const GraphicsPipelineDesc& desc)
@@ -351,7 +361,7 @@ MTLCompareFunction ConvertToCompareFunction(ComparisonFunc func)
351361

352362
ShaderBlobType MTDevice::GetSupportedShaderBlobType() const
353363
{
354-
return ShaderBlobType::kSPIRV;
364+
return kShaderBlobType;
355365
}
356366

357367
uint64_t MTDevice::GetConstantBufferOffsetAlignment() const
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#import <Metal/Metal.h>
2+
3+
#define IR_PRIVATE_IMPLEMENTATION
4+
#include <metal_irconverter.h>
5+
#include <metal_irconverter_runtime.h>
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#include "HLSLCompiler/MetalShaderConverter.h"
2+
3+
#include "Utilities/Logging.h"
4+
#include "Utilities/NotReached.h"
5+
6+
#include <metal_irconverter.h>
7+
8+
namespace {
9+
10+
IRShaderStage GetShaderStage(ShaderType type)
11+
{
12+
switch (type) {
13+
case ShaderType::kVertex:
14+
return IRShaderStageVertex;
15+
case ShaderType::kPixel:
16+
return IRShaderStageFragment;
17+
case ShaderType::kGeometry:
18+
return IRShaderStageGeometry;
19+
case ShaderType::kCompute:
20+
return IRShaderStageCompute;
21+
case ShaderType::kAmplification:
22+
return IRShaderStageAmplification;
23+
case ShaderType::kMesh:
24+
return IRShaderStageMesh;
25+
default:
26+
NOTREACHED();
27+
}
28+
}
29+
30+
} // namespace
31+
32+
std::vector<uint8_t> ConvertToMetalLibBytecode(ShaderType shader_type, const std::vector<uint8_t>& blob)
33+
{
34+
IRCompiler* compiler = IRCompilerCreate();
35+
IRObject* dxil_obj = IRObjectCreateFromDXIL(blob.data(), blob.size(), IRBytecodeOwnershipNone);
36+
37+
if (shader_type == ShaderType::kVertex) {
38+
IRCompilerSetStageInGenerationMode(compiler, IRStageInCodeGenerationModeUseMetalVertexFetch);
39+
}
40+
41+
IRError* error = nullptr;
42+
IRObject* metal_ir = IRCompilerAllocCompileAndLink(compiler, nullptr, dxil_obj, &error);
43+
if (!metal_ir) {
44+
Logging::Println("IRCompilerAllocCompileAndLink failed: {}", IRErrorGetCode(error));
45+
IRErrorDestroy(error);
46+
IRObjectDestroy(dxil_obj);
47+
IRCompilerDestroy(compiler);
48+
return {};
49+
}
50+
51+
auto* metal_lib = IRMetalLibBinaryCreate();
52+
IRObjectGetMetalLibBinary(metal_ir, GetShaderStage(shader_type), metal_lib);
53+
54+
size_t metal_lib_size = IRMetalLibGetBytecodeSize(metal_lib);
55+
std::vector<uint8_t> metal_lib_bytecode(metal_lib_size);
56+
IRMetalLibGetBytecode(metal_lib, metal_lib_bytecode.data());
57+
58+
IRMetalLibBinaryDestroy(metal_lib);
59+
IRObjectDestroy(metal_ir);
60+
IRObjectDestroy(dxil_obj);
61+
IRCompilerDestroy(compiler);
62+
63+
return metal_lib_bytecode;
64+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#pragma once
2+
#include "Instance/BaseTypes.h"
3+
4+
std::vector<uint8_t> ConvertToMetalLibBytecode(ShaderType shader_type, const std::vector<uint8_t>& blob);

src/FlyCube/Pipeline/MTGraphicsPipeline.mm

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
#include "Utilities/Logging.h"
88
#include "Utilities/NotReached.h"
99

10+
#if defined(USE_METAL_SHADER_CONVERTER)
11+
#include <metal_irconverter_runtime.h>
12+
#endif
13+
1014
namespace {
1115

1216
MTLStencilOperation ConvertStencilOperation(StencilOp op)
@@ -245,13 +249,21 @@ MTLBlendFactor ConvertBlendOp(BlendFactor factor)
245249
CHECK(input_layout_stride[vertex.slot] == vertex.stride);
246250
}
247251

252+
#if defined(USE_METAL_SHADER_CONVERTER)
253+
const uint32_t buffer_index = kIRVertexBufferBindPoint + vertex.slot;
254+
#else
248255
const uint32_t buffer_index = device_.GetMaxPerStageBufferCount() - vertex.slot - 1;
256+
#endif
249257
MTLVertexBufferLayoutDescriptor* layout = vertex_descriptor.layouts[buffer_index];
250258
layout.stride = vertex.stride;
251259
layout.stepFunction = MTLVertexStepFunctionPerVertex;
252260
layout.stepRate = 1;
253261

262+
#if defined(USE_METAL_SHADER_CONVERTER)
263+
const uint32_t location = kIRStageInAttributeStartIndex + shader->GetInputLayoutLocation(vertex.semantic_name);
264+
#else
254265
const uint32_t location = shader->GetInputLayoutLocation(vertex.semantic_name);
266+
#endif
255267
MTLVertexAttributeDescriptor* attribute = vertex_descriptor.attributes[location];
256268
attribute.offset = vertex.offset;
257269
attribute.bufferIndex = buffer_index;

src/FlyCube/Shader/MTShader.mm

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,28 @@
11
#include "Shader/MTShader.h"
22

33
#include "Device/MTDevice.h"
4-
#include "HLSLCompiler/MSLConverter.h"
54
#include "Utilities/Logging.h"
65

6+
#if defined(USE_METAL_SHADER_CONVERTER)
7+
#include "HLSLCompiler/MetalShaderConverter.h"
8+
#else
9+
#include "HLSLCompiler/MSLConverter.h"
10+
#endif
11+
712
MTShader::MTShader(MTDevice& device, const std::vector<uint8_t>& blob, ShaderBlobType blob_type, ShaderType shader_type)
813
: ShaderBase(blob, blob_type, shader_type)
914
{
15+
#if defined(USE_METAL_SHADER_CONVERTER)
16+
std::string entry_point = "main";
17+
auto metal_lib_bytecode = ConvertToMetalLibBytecode(shader_type, blob);
18+
dispatch_data_t metal_lib_data = dispatch_data_create(metal_lib_bytecode.data(), metal_lib_bytecode.size(), nullptr,
19+
DISPATCH_DATA_DESTRUCTOR_DEFAULT);
20+
NSError* error = nullptr;
21+
id<MTLLibrary> library = [device.GetDevice() newLibraryWithData:metal_lib_data error:&error];
22+
if (library == nullptr) {
23+
Logging::Println("Failed to create MTLLibrary: {}", error);
24+
}
25+
#else
1026
std::string entry_point;
1127
std::string msl_source = GetMSLShader(shader_type, blob_, slot_remapping_, entry_point);
1228

@@ -17,6 +33,7 @@
1733
if (!library) {
1834
Logging::Println("Failed to create MTLLibrary: {}", error);
1935
}
36+
#endif
2037

2138
function_descriptor_ = [MTL4LibraryFunctionDescriptor new];
2239
function_descriptor_.library = library;

0 commit comments

Comments
 (0)