Skip to content

Commit 924dc17

Browse files
authored
Remove llvm-muladd pass and move it's functionality to to llvm-simdloop (JuliaLang#55802)
Closes JuliaLang#55785 I'm not sure if we want to backport this like this. Because that removes some functionality (the pass itself). So LLVM.jl and friends might need annoying version code. We can maybe keep the code there and just not run the pass in a backport.
1 parent b02d671 commit 924dc17

File tree

12 files changed

+125
-213
lines changed

12 files changed

+125
-213
lines changed

doc/src/devdocs/llvm-passes.md

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -114,18 +114,6 @@ This pass is used to verify Julia's invariants about LLVM IR. This includes thin
114114

115115
These passes are used to perform transformations on LLVM IR that LLVM will not perform itself, e.g. fast math flag propagation, escape analysis, and optimizations on Julia-specific internal functions. They use knowledge about Julia's semantics to perform these optimizations.
116116

117-
### CombineMulAdd
118-
119-
* Filename: `llvm-muladd.cpp`
120-
* Class Name: `CombineMulAddPass`
121-
* Opt Name: `function(CombineMulAdd)`
122-
123-
This pass serves to optimize the particular combination of a regular `fmul` with a fast `fadd` into a contract `fmul` with a fast `fadd`. This is later optimized by the backend to a [fused multiply-add](https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation#Fused_multiply%E2%80%93add) instruction, which can provide significantly faster operations at the cost of more [unpredictable semantics](https://simonbyrne.github.io/notes/fastmath/).
124-
125-
!!! note
126-
127-
This optimization only occurs when the `fmul` has a single use, which is the fast `fadd`.
128-
129117
### AllocOpt
130118

131119
* Filename: `llvm-alloc-opt.cpp`

doc/src/devdocs/llvm.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ The code for lowering Julia AST to LLVM IR or interpreting it directly is in dir
3030
| `llvm-julia-licm.cpp` | Custom LLVM pass to hoist/sink Julia-specific intrinsics |
3131
| `llvm-late-gc-lowering.cpp` | Custom LLVM pass to root GC-tracked values |
3232
| `llvm-lower-handlers.cpp` | Custom LLVM pass to lower try-catch blocks |
33-
| `llvm-muladd.cpp` | Custom LLVM pass for fast-match FMA |
3433
| `llvm-multiversioning.cpp` | Custom LLVM pass to generate sysimg code on multiple architectures |
3534
| `llvm-propagate-addrspaces.cpp` | Custom LLVM pass to canonicalize addrspaces |
3635
| `llvm-ptls.cpp` | Custom LLVM pass to lower TLS operations |

src/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ RT_LLVMLINK :=
5252
CG_LLVMLINK :=
5353

5454
ifeq ($(JULIACODEGEN),LLVM)
55-
CODEGEN_SRCS := codegen jitlayers aotcompile debuginfo disasm llvm-simdloop llvm-muladd \
55+
CODEGEN_SRCS := codegen jitlayers aotcompile debuginfo disasm llvm-simdloop \
5656
llvm-final-gc-lowering llvm-pass-helpers llvm-late-gc-lowering llvm-ptls \
5757
llvm-lower-handlers llvm-gc-invariant-verifier llvm-propagate-addrspaces \
5858
llvm-multiversioning llvm-alloc-opt llvm-alloc-helpers cgmemmgr llvm-remove-addrspaces \

src/jl_exported_funcs.inc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,6 @@
554554
YY(LLVMExtraMPMAddRemoveAddrspacesPass) \
555555
YY(LLVMExtraMPMAddLowerPTLSPass) \
556556
YY(LLVMExtraFPMAddDemoteFloat16Pass) \
557-
YY(LLVMExtraFPMAddCombineMulAddPass) \
558557
YY(LLVMExtraFPMAddLateLowerGCPass) \
559558
YY(LLVMExtraFPMAddAllocOptPass) \
560559
YY(LLVMExtraFPMAddPropagateJuliaAddrspacesPass) \

src/llvm-julia-passes.inc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ MODULE_PASS("LowerPTLSPass", LowerPTLSPass, LowerPTLSPass())
1111
//Function passes
1212
#ifdef FUNCTION_PASS
1313
FUNCTION_PASS("DemoteFloat16", DemoteFloat16Pass, DemoteFloat16Pass())
14-
FUNCTION_PASS("CombineMulAdd", CombineMulAddPass, CombineMulAddPass())
1514
FUNCTION_PASS("LateLowerGCFrame", LateLowerGCPass, LateLowerGCPass())
1615
FUNCTION_PASS("AllocOpt", AllocOptPass, AllocOptPass())
1716
FUNCTION_PASS("PropagateJuliaAddrspaces", PropagateJuliaAddrspacesPass, PropagateJuliaAddrspacesPass())

src/llvm-muladd.cpp

Lines changed: 0 additions & 117 deletions
This file was deleted.

src/llvm-simdloop.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ STATISTIC(ReductionChainLength, "Total sum of instructions folded from reduction
4141
STATISTIC(MaxChainLength, "Max length of reduction chain");
4242
STATISTIC(AddChains, "Addition reduction chains");
4343
STATISTIC(MulChains, "Multiply reduction chains");
44+
STATISTIC(TotalContracted, "Total number of multiplies marked for FMA");
4445

4546
#ifndef __clang_gcanalyzer__
4647
#define REMARK(remark) ORE.emit(remark)
@@ -49,6 +50,49 @@ STATISTIC(MulChains, "Multiply reduction chains");
4950
#endif
5051
namespace {
5152

53+
/**
54+
* Combine
55+
* ```
56+
* %v0 = fmul ... %a, %b
57+
* %v = fadd contract ... %v0, %c
58+
* ```
59+
* to
60+
* %v0 = fmul contract ... %a, %b
61+
* %v = fadd contract ... %v0, %c
62+
* when `%v0` has no other use
63+
*/
64+
65+
static bool checkCombine(Value *maybeMul, Loop &L, OptimizationRemarkEmitter &ORE) JL_NOTSAFEPOINT
66+
{
67+
auto mulOp = dyn_cast<Instruction>(maybeMul);
68+
if (!mulOp || mulOp->getOpcode() != Instruction::FMul)
69+
return false;
70+
if (!L.contains(mulOp))
71+
return false;
72+
if (!mulOp->hasOneUse()) {
73+
LLVM_DEBUG(dbgs() << "mulOp has multiple uses: " << *maybeMul << "\n");
74+
REMARK([&](){
75+
return OptimizationRemarkMissed(DEBUG_TYPE, "Multiuse FMul", mulOp)
76+
<< "fmul had multiple uses " << ore::NV("fmul", mulOp);
77+
});
78+
return false;
79+
}
80+
// On 5.0+ we only need to mark the mulOp as contract and the backend will do the work for us.
81+
auto fmf = mulOp->getFastMathFlags();
82+
if (!fmf.allowContract()) {
83+
LLVM_DEBUG(dbgs() << "Marking mulOp for FMA: " << *maybeMul << "\n");
84+
REMARK([&](){
85+
return OptimizationRemark(DEBUG_TYPE, "Marked for FMA", mulOp)
86+
<< "marked for fma " << ore::NV("fmul", mulOp);
87+
});
88+
++TotalContracted;
89+
fmf.setAllowContract(true);
90+
mulOp->copyFastMathFlags(fmf);
91+
return true;
92+
}
93+
return false;
94+
}
95+
5296
static unsigned getReduceOpcode(Instruction *J, Instruction *operand) JL_NOTSAFEPOINT
5397
{
5498
switch (J->getOpcode()) {
@@ -150,6 +194,28 @@ static void enableUnsafeAlgebraIfReduction(PHINode *Phi, Loop &L, OptimizationRe
150194
});
151195
(*K)->setHasAllowReassoc(true);
152196
(*K)->setHasAllowContract(true);
197+
switch ((*K)->getOpcode()) {
198+
case Instruction::FAdd: {
199+
if (!(*K)->hasAllowContract())
200+
continue;
201+
// (*K)->getOperand(0)->print(dbgs());
202+
// (*K)->getOperand(1)->print(dbgs());
203+
checkCombine((*K)->getOperand(0), L, ORE);
204+
checkCombine((*K)->getOperand(1), L, ORE);
205+
break;
206+
}
207+
case Instruction::FSub: {
208+
if (!(*K)->hasAllowContract())
209+
continue;
210+
// (*K)->getOperand(0)->print(dbgs());
211+
// (*K)->getOperand(1)->print(dbgs());
212+
checkCombine((*K)->getOperand(0), L, ORE);
213+
checkCombine((*K)->getOperand(1), L, ORE);
214+
break;
215+
}
216+
default:
217+
break;
218+
}
153219
if (SE)
154220
SE->forgetValue(*K);
155221
++length;

src/passes.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@ struct DemoteFloat16Pass : PassInfoMixin<DemoteFloat16Pass> {
1515
static bool isRequired() { return true; }
1616
};
1717

18-
struct CombineMulAddPass : PassInfoMixin<CombineMulAddPass> {
19-
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) JL_NOTSAFEPOINT;
20-
};
21-
2218
struct LateLowerGCPass : PassInfoMixin<LateLowerGCPass> {
2319
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) JL_NOTSAFEPOINT;
2420
static bool isRequired() { return true; }

src/pipeline.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,6 @@ static void buildCleanupPipeline(ModulePassManager &MPM, PassBuilder *PB, Optimi
568568
if (options.cleanup) {
569569
if (O.getSpeedupLevel() >= 2) {
570570
FunctionPassManager FPM;
571-
JULIA_PASS(FPM.addPass(CombineMulAddPass()));
572571
FPM.addPass(DivRemPairsPass());
573572
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
574573
}

test/llvmpasses/julia-simdloop.ll

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,18 @@
33
; RUN: opt --load-pass-plugin=libjulia-codegen%shlibext -passes='loop(LowerSIMDLoop)' -S %s | FileCheck %s
44

55
; CHECK-LABEL: @simd_test(
6-
define void @simd_test(double *%a, double *%b) {
6+
define void @simd_test(ptr %a, ptr %b) {
77
top:
88
br label %loop
99
loop:
1010
%i = phi i64 [0, %top], [%nexti, %loop]
11-
%aptr = getelementptr double, double *%a, i64 %i
12-
%bptr = getelementptr double, double *%b, i64 %i
11+
%aptr = getelementptr double, ptr %a, i64 %i
12+
%bptr = getelementptr double, ptr %b, i64 %i
1313
; CHECK: llvm.mem.parallel_loop_access
14-
%aval = load double, double *%aptr
15-
%bval = load double, double *%aptr
14+
%aval = load double, ptr %aptr
15+
%bval = load double, ptr %aptr
1616
%cval = fadd double %aval, %bval
17-
store double %cval, double *%bptr
17+
store double %cval, ptr %bptr
1818
%nexti = add i64 %i, 1
1919
%done = icmp sgt i64 %nexti, 500
2020
br i1 %done, label %loopdone, label %loop, !llvm.loop !1
@@ -23,15 +23,15 @@ loopdone:
2323
}
2424

2525
; CHECK-LABEL: @simd_test_sub(
26-
define double @simd_test_sub(double *%a) {
26+
define double @simd_test_sub(ptr %a) {
2727
top:
2828
br label %loop
2929
loop:
3030
%i = phi i64 [0, %top], [%nexti, %loop]
3131
%v = phi double [0.000000e+00, %top], [%nextv, %loop]
32-
%aptr = getelementptr double, double *%a, i64 %i
32+
%aptr = getelementptr double, ptr %a, i64 %i
3333
; CHECK: llvm.mem.parallel_loop_access
34-
%aval = load double, double *%aptr
34+
%aval = load double, ptr %aptr
3535
%nextv = fsub double %v, %aval
3636
; CHECK: fsub reassoc contract double %v, %aval
3737
%nexti = add i64 %i, 1
@@ -42,14 +42,14 @@ loopdone:
4242
}
4343

4444
; CHECK-LABEL: @simd_test_sub2(
45-
define double @simd_test_sub2(double *%a) {
45+
define double @simd_test_sub2(ptr %a) {
4646
top:
4747
br label %loop
4848
loop:
4949
%i = phi i64 [0, %top], [%nexti, %loop]
5050
%v = phi double [0.000000e+00, %top], [%nextv, %loop]
51-
%aptr = getelementptr double, double *%a, i64 %i
52-
%aval = load double, double *%aptr
51+
%aptr = getelementptr double, ptr %a, i64 %i
52+
%aval = load double, ptr %aptr
5353
%nextv = fsub double %v, %aval
5454
; CHECK: fsub reassoc contract double %v, %aval
5555
%nexti = add i64 %i, 1
@@ -59,6 +59,26 @@ loopdone:
5959
ret double %nextv
6060
}
6161

62+
; CHECK-LABEL: @simd_test_sub4(
63+
define double @simd_test_sub4(ptr %a) {
64+
top:
65+
br label %loop
66+
loop:
67+
%i = phi i64 [0, %top], [%nexti, %loop]
68+
%v = phi double [0.000000e+00, %top], [%nextv, %loop]
69+
%aptr = getelementptr double, double *%a, i64 %i
70+
%aval = load double, double *%aptr
71+
%nextv2 = fmul double %aval, %aval
72+
; CHECK: fmul contract double %aval, %aval
73+
%nextv = fsub double %v, %nextv2
74+
; CHECK: fsub reassoc contract double %v, %nextv2
75+
%nexti = add i64 %i, 1
76+
%done = icmp sgt i64 %nexti, 500
77+
br i1 %done, label %loopdone, label %loop, !llvm.loop !0
78+
loopdone:
79+
ret double %nextv
80+
}
81+
6282
; Tests if we correctly pass through other metadata
6383
; CHECK-LABEL: @disabled(
6484
define i32 @disabled(i32* noalias nocapture %a, i32* noalias nocapture readonly %b, i32 %N) {
@@ -82,6 +102,31 @@ for.end: ; preds = %for.body
82102
ret i32 %1
83103
}
84104

105+
; Check that we don't add contract to non loop things
106+
; CHECK-LABEL: @dont_add_no_loop(
107+
define double @dont_add_no_loop(ptr nocapture noundef nonnull readonly align 8 dereferenceable(72) %"a::Tuple", ptr nocapture noundef nonnull readonly align 8 dereferenceable(24) %"b::Tuple") #0 {
108+
top:
109+
%"a::Tuple[9]_ptr" = getelementptr inbounds i8, ptr %"a::Tuple", i64 64
110+
%"b::Tuple[3]_ptr" = getelementptr inbounds i8, ptr %"b::Tuple", i64 16
111+
%"a::Tuple[6]_ptr" = getelementptr inbounds i8, ptr %"a::Tuple", i64 40
112+
%"b::Tuple[2]_ptr" = getelementptr inbounds i8, ptr %"b::Tuple", i64 8
113+
%"a::Tuple[3]_ptr" = getelementptr inbounds i8, ptr %"a::Tuple", i64 16
114+
%"a::Tuple[3]_ptr.unbox" = load double, ptr %"a::Tuple[3]_ptr", align 8
115+
%"b::Tuple.unbox" = load double, ptr %"b::Tuple", align 8
116+
%0 = fmul double %"a::Tuple[3]_ptr.unbox", %"b::Tuple.unbox"
117+
; CHECK: fmul double %
118+
%"a::Tuple[6]_ptr.unbox" = load double, ptr %"a::Tuple[6]_ptr", align 8
119+
%"b::Tuple[2]_ptr.unbox" = load double, ptr %"b::Tuple[2]_ptr", align 8
120+
%1 = fmul contract double %"a::Tuple[6]_ptr.unbox", %"b::Tuple[2]_ptr.unbox"
121+
%2 = fadd contract double %0, %1
122+
%"a::Tuple[9]_ptr.unbox" = load double, ptr %"a::Tuple[9]_ptr", align 8
123+
%"b::Tuple[3]_ptr.unbox" = load double, ptr %"b::Tuple[3]_ptr", align 8
124+
%3 = fmul contract double %"a::Tuple[9]_ptr.unbox", %"b::Tuple[3]_ptr.unbox"
125+
%4 = fadd contract double %2, %3
126+
ret double %4
127+
}
128+
129+
85130
!0 = distinct !{!0, !"julia.simdloop"}
86131
!1 = distinct !{!1, !"julia.simdloop", !"julia.ivdep"}
87132
!2 = distinct !{!2, !"julia.simdloop", !"julia.ivdep", !3}

0 commit comments

Comments
 (0)