Skip to content

Commit 295b672

Browse files
[Backport to llvm_release_180] initial support for SPV_INTEL_device_barrier (KhronosGroup#3536)
Backport of PR KhronosGroup#3461 into `llvm_release_180`. All commits applied cleanly.

Co-authored-by: Ben Ashbaugh <ben.ashbaugh@intel.com>
1 parent e2b703f commit 295b672

File tree

13 files changed

+337
-7
lines changed

13 files changed

+337
-7
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ EXT(SPV_INTEL_runtime_aligned)
5757
EXT(SPV_EXT_arithmetic_fence)
5858
EXT(SPV_INTEL_arithmetic_fence)
5959
EXT(SPV_INTEL_bfloat16_conversion)
60+
EXT(SPV_INTEL_device_barrier)
6061
EXT(SPV_INTEL_joint_matrix)
6162
EXT(SPV_INTEL_hw_thread_queries)
6263
EXT(SPV_INTEL_global_variable_decorations)

lib/SPIRV/SPIRVInternal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,7 @@ CallInst *setAttrByCalledFunc(CallInst *Call);
985985
bool isSPIRVBuiltinVariable(GlobalVariable *GV, SPIRVBuiltinVariableKind *Kind);
986986
// Transform builtin variable from GlobalVariable to builtin call.
987987
// e.g.
988-
// - GlobalInvolcationId[x] -> _Z33__spirv_BuiltInGlobalInvocationIdi(x)
988+
// - GlobalInvocationId[x] -> _Z33__spirv_BuiltInGlobalInvocationIdi(x)
989989
// - WorkDim -> _Z22__spirv_BuiltInWorkDimv()
990990
bool lowerBuiltinVariableToCall(GlobalVariable *GV,
991991
SPIRVBuiltinVariableKind Kind);

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2034,13 +2034,15 @@ static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
20342034
} else if (auto *Load = dyn_cast<LoadInst>(U)) {
20352035
// Figure out which index the accumulated offset corresponds to. If we
20362036
// have a weird offset (e.g., trying to load byte 7), bail out.
2037-
Type *ScalarTy = ReplacementFunc->getReturnType();
20382037
APInt Index;
2039-
uint64_t Remainder;
2040-
APInt::udivrem(AccumulatedOffset, ScalarTy->getScalarSizeInBits() / 8,
2041-
Index, Remainder);
2042-
if (Remainder != 0)
2043-
llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
2038+
Type *ScalarTy = ReplacementFunc->getReturnType();
2039+
if (!ScalarTy->isIntegerTy(1)) {
2040+
uint64_t Remainder;
2041+
APInt::udivrem(AccumulatedOffset, ScalarTy->getScalarSizeInBits() / 8,
2042+
Index, Remainder);
2043+
if (Remainder != 0)
2044+
llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
2045+
}
20442046

20452047
IRBuilder<> Builder(Load);
20462048
Value *Replacement;

lib/SPIRV/libSPIRV/SPIRVEntry.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,8 @@ class SPIRVCapability : public SPIRVEntryNoId<OpCapability> {
912912
return ExtensionID::SPV_INTEL_function_variants;
913913
case internal::CapabilityBFloat16ArithmeticINTEL:
914914
return ExtensionID::SPV_INTEL_bfloat16_arithmetic;
915+
case internal::CapabilityDeviceBarrierINTEL:
916+
return ExtensionID::SPV_INTEL_device_barrier;
915917
default:
916918
return {};
917919
}

lib/SPIRV/libSPIRV/SPIRVEnum.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,8 @@ template <> inline void SPIRVMap<BuiltIn, SPIRVCapVec>::init() {
584584
{internal::CapabilityHWThreadQueryINTEL});
585585
ADD_VEC_INIT(internal::BuiltInGlobalHWThreadIDINTEL,
586586
{internal::CapabilityHWThreadQueryINTEL});
587+
ADD_VEC_INIT(internal::BuiltInDeviceBarrierValidINTEL,
588+
{internal::CapabilityDeviceBarrierINTEL});
587589
}
588590

589591
template <> inline void SPIRVMap<MemorySemanticsMask, SPIRVCapVec>::init() {

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2359,13 +2359,40 @@ class SPIRVControlBarrier : public SPIRVInstruction {
23592359
return getValues(Operands);
23602360
}
23612361

2362+
SPIRVCapVec getRequiredCapability() const override {
2363+
if (isDeviceBarrier()) {
2364+
return getVec(internal::CapabilityDeviceBarrierINTEL);
2365+
}
2366+
return SPIRVInstruction::getRequiredCapability();
2367+
}
2368+
std::optional<ExtensionID> getRequiredExtension() const override {
2369+
if (isDeviceBarrier()) {
2370+
return ExtensionID::SPV_INTEL_device_barrier;
2371+
}
2372+
return std::nullopt;
2373+
}
2374+
23622375
protected:
23632376
_SPIRV_DEF_ENCDEC3(ExecScope, MemScope, MemSema)
23642377
void validate() const override {
23652378
assert(OpCode == OC);
23662379
assert(WordCount == 4);
23672380
SPIRVInstruction::validate();
23682381
}
2382+
2383+
bool isDeviceBarrier() const {
2384+
if (!getModule()->isAllowedToUseExtension(
2385+
ExtensionID::SPV_INTEL_device_barrier))
2386+
return false;
2387+
SPIRVValue *ESV = getValue(ExecScope);
2388+
if (ESV && ESV->getOpCode() == OpConstant) {
2389+
if (static_cast<SPIRVConstant *>(ESV)->getZExtIntValue() != ScopeDevice) {
2390+
return false;
2391+
}
2392+
}
2393+
return true;
2394+
}
2395+
23692396
SPIRVId ExecScope;
23702397
SPIRVId MemScope = SPIRVID_INVALID;
23712398
SPIRVId MemSema = SPIRVID_INVALID;

lib/SPIRV/libSPIRV/SPIRVIsValidEnum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ inline bool isValid(spv::BuiltIn V) {
282282
case BuiltInCullMaskKHR:
283283
case internal::BuiltInSubDeviceIDINTEL:
284284
case internal::BuiltInGlobalHWThreadIDINTEL:
285+
case internal::BuiltInDeviceBarrierValidINTEL:
285286
return true;
286287
default:
287288
return false;

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,8 @@ template <> inline void SPIRVMap<BuiltIn, std::string>::init() {
349349
add(BuiltInMax, "BuiltInMax");
350350
add(internal::BuiltInSubDeviceIDINTEL, "BuiltInSubDeviceIDINTEL");
351351
add(internal::BuiltInGlobalHWThreadIDINTEL, "BuiltInGlobalHWThreadIDINTEL");
352+
add(internal::BuiltInDeviceBarrierValidINTEL,
353+
"BuiltInDeviceBarrierValidINTEL");
352354
}
353355
SPIRV_DEF_NAMEMAP(BuiltIn, SPIRVBuiltInNameMap)
354356

@@ -691,6 +693,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
691693
add(CapabilityFloat8CooperativeMatrixEXT, "Float8CooperativeMatrixEXT");
692694
add(internal::CapabilityPredicatedIOINTEL, "PredicatedIOINTEL");
693695
add(internal::CapabilitySigmoidINTEL, "SigmoidINTEL");
696+
add(internal::CapabilityDeviceBarrierINTEL, "DeviceBarrierINTEL");
694697
add(internal::CapabilityFloat4E2M1INTEL, "Float4E2M1INTEL");
695698
add(internal::CapabilityFloat4E2M1CooperativeMatrixINTEL,
696699
"Float4E2M1CooperativeMatrixINTEL");

lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ enum InternalCapability {
115115
ICapabilityHWThreadQueryINTEL = 6134,
116116
ICapGlobalVariableDecorationsINTEL = 6146,
117117
ICapabilitySigmoidINTEL = 6167,
118+
ICapabilityDeviceBarrierINTEL = 6185,
118119
ICapabilityCooperativeMatrixCheckedInstructionsINTEL = 6192,
119120
ICapabilityFloat4E2M1INTEL = 6212,
120121
ICapabilityFloat4E2M1CooperativeMatrixINTEL = 6213,
@@ -167,6 +168,7 @@ enum InternalFPEncoding {
167168
enum InternalBuiltIn {
168169
IBuiltInSubDeviceIDINTEL = 6135,
169170
IBuiltInGlobalHWThreadIDINTEL = 6136,
171+
IBuiltInDeviceBarrierValidINTEL = 6186,
170172
};
171173

172174
#define _SPIRV_OP(x, y) constexpr x x##y = static_cast<x>(I##x##y);
@@ -198,6 +200,9 @@ _SPIRV_OP(Op, CooperativeMatrixConstructCheckedINTEL)
198200
_SPIRV_OP(Capability, CooperativeMatrixInvocationInstructionsINTEL)
199201
_SPIRV_OP(Op, CooperativeMatrixApplyFunctionINTEL)
200202

203+
_SPIRV_OP(Capability, DeviceBarrierINTEL)
204+
_SPIRV_OP(BuiltIn, DeviceBarrierValidINTEL)
205+
201206
_SPIRV_OP(Capability, HWThreadQueryINTEL)
202207
_SPIRV_OP(BuiltIn, SubDeviceIDINTEL)
203208
_SPIRV_OP(BuiltIn, GlobalHWThreadIDINTEL)
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
;; kernel void test(global uint* dst)
2+
;; {
3+
;; int scope = magic_get_scope();
4+
;; __spirv_ControlBarrier(scope, 1, 264); // local
5+
;; __spirv_ControlBarrier(scope, 1, 520); // global
6+
;; __spirv_ControlBarrier(scope, 1, 2056); // image
7+
;;
8+
;; __spirv_ControlBarrier(scope, 0, 520); // global, all_svm_devices
9+
;; __spirv_ControlBarrier(scope, 1, 520); // global, device
10+
;; __spirv_ControlBarrier(scope, 2, 520); // global, work_group
11+
;; __spirv_ControlBarrier(scope, 3, 520); // global, subgroup
12+
;; __spirv_ControlBarrier(scope, 4, 520); // global, work_item
13+
;;}
14+
15+
; Test for SPV_INTEL_device_barrier (SPIR-V friendly LLVM IR)
16+
; RUN: llvm-as %s -o %t.bc
17+
; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_INTEL_device_barrier
18+
; RUN: llvm-spirv %t.spv -o %t.spt --to-text
19+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
20+
21+
; RUN: llvm-spirv %t.spv -o %t.rev.bc -r --spirv-target-env=SPV-IR
22+
; RUN: llvm-dis %t.rev.bc -o %t.rev.ll
23+
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM
24+
25+
; RUN: llvm-spirv %t.bc -o %t.disabled.spv
26+
; RUN: llvm-spirv %t.disabled.spv -o %t.disabled.spt --to-text
27+
; RUN: FileCheck < %t.disabled.spt %s --check-prefix=CHECK-SPIRV-EXTENSION-DISABLED
28+
29+
; ModuleID = 'device_barrier_spirv.cl'
30+
source_filename = "device_barrier_spirv.cl"
31+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
32+
target triple = "spir64"
33+
34+
; CHECK-SPIRV: Capability DeviceBarrierINTEL
35+
; CHECK-SPIRV: Extension "SPV_INTEL_device_barrier"
36+
; CHECK-SPIRV: TypeInt [[UINT:[0-9]+]] 32 0
37+
;
38+
;; When the SPV_INTEL_device_barrier extension is not enabled, a runtime variable
39+
;; should not cause the device barrier extension or capability to be declared.
40+
; CHECK-SPIRV-EXTENSION-DISABLED-NOT: Capability DeviceBarrierINTEL
41+
; CHECK-SPIRV-EXTENSION-DISABLED-NOT: Extension "SPV_INTEL_device_barrier"
42+
;
43+
; Scopes:
44+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[SCOPE_CROSS_DEVICE:[0-9]+]] 0 {{$}}
45+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[SCOPE_DEVICE:[0-9]+]] 1 {{$}}
46+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[SCOPE_WORK_GROUP:[0-9]+]] 2 {{$}}
47+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[SCOPE_SUBGROUP:[0-9]+]] 3 {{$}}
48+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[SCOPE_INVOCATION:[0-9]+]] 4 {{$}}
49+
;
50+
; Memory Semantics:
51+
; 0x8 AcquireRelease + 0x100 WorkgroupMemory
52+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[ACQREL_LOCAL:[0-9]+]] 264
53+
; 0x8 AcquireRelease + 0x200 CrossWorkgroupMemory
54+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[ACQREL_GLOBAL:[0-9]+]] 520
55+
; 0x8 AcquireRelease + 0x800 ImageMemory
56+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[ACQREL_IMAGE:[0-9]+]] 2056
57+
;
58+
; Runtime execution scope:
59+
; CHECK-SPIRV: FunctionCall [[#]] [[EXEC_SCOPE:[0-9]+]] [[#]]
60+
;
61+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_DEVICE]] [[ACQREL_LOCAL]]
62+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_DEVICE]] [[ACQREL_GLOBAL]]
63+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_DEVICE]] [[ACQREL_IMAGE]]
64+
;
65+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_CROSS_DEVICE]] [[ACQREL_GLOBAL]]
66+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_DEVICE]] [[ACQREL_GLOBAL]]
67+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_WORK_GROUP]] [[ACQREL_GLOBAL]]
68+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_SUBGROUP]] [[ACQREL_GLOBAL]]
69+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_INVOCATION]] [[ACQREL_GLOBAL]]
70+
71+
; CHECK-LLVM-LABEL: define spir_kernel void @test
72+
; Function Attrs: convergent norecurse nounwind
73+
define dso_local spir_kernel void @test(ptr addrspace(1) nocapture noundef readnone align 4 %0) local_unnamed_addr #0 !kernel_arg_addr_space !4 !kernel_arg_access_qual !5 !kernel_arg_type !6 !kernel_arg_base_type !6 !kernel_arg_type_qual !7 {
74+
%2 = call noundef i32 @magic_get_scope()
75+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 1, i32 noundef 264) #2
76+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 1, i32 264) #1
77+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 1, i32 noundef 520) #2
78+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 1, i32 520) #1
79+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 1, i32 noundef 2056) #2
80+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 1, i32 2056) #1
81+
82+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 0, i32 noundef 520) #2
83+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 0, i32 520) #1
84+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 1, i32 noundef 520) #2
85+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 1, i32 520) #1
86+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 2, i32 noundef 520) #2
87+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 2, i32 520) #1
88+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 3, i32 noundef 520) #2
89+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 3, i32 520) #1
90+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 4, i32 noundef 520) #2
91+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 4, i32 520) #1
92+
ret void
93+
}
94+
95+
; Function Attrs: convergent
96+
declare dso_local spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #1
97+
98+
declare spir_func i32 @magic_get_scope()
99+
100+
attributes #0 = { convergent norecurse nounwind "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "uniform-work-group-size"="false" }
101+
attributes #1 = { convergent "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
102+
attributes #2 = { convergent nounwind }
103+
104+
!llvm.module.flags = !{!0, !1}
105+
!opencl.ocl.version = !{!2}
106+
!opencl.spir.version = !{!2}
107+
!llvm.ident = !{!3}
108+
109+
!0 = !{i32 1, !"wchar_size", i32 4}
110+
!1 = !{i32 7, !"frame-pointer", i32 2}
111+
!2 = !{i32 2, i32 0}
112+
!3 = !{!"clang version 15.0.0 (https://github.com/llvm/llvm-project 861386dbd6ff0d91636b7c674c2abb2eccd9d3f2)"}
113+
!4 = !{i32 1}
114+
!5 = !{!"none"}
115+
!6 = !{!"uint*"}
116+
!7 = !{!""}

0 commit comments

Comments
 (0)