@@ -150,8 +150,6 @@ def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
150150
151151def doMulWide : Predicate<"doMulWide">;
152152
153- def allowFMA : Predicate<"allowFMA()">;
154- def noFMA : Predicate<"!allowFMA()">;
155153def allowUnsafeFPMath : Predicate<"allowUnsafeFPMath()">; // Holds when the selector's allowUnsafeFPMath() call returns true.
156154def noUnsafeFPMath : Predicate<"!allowUnsafeFPMath()">; // Exact negation of allowUnsafeFPMath above.
157155
@@ -367,167 +365,89 @@ multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
367365// This multiclass should be used for nodes that can be folded to make fma ops.
368366// In this case, we use the ".rn" variant when FMA is disabled, as this behaves
369367// just like the non ".rn" op, but prevents ptxas from creating FMAs.
370- multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
371- def f64rr :
372- NVPTXInst<(outs Float64Regs:$dst),
373- (ins Float64Regs:$a, Float64Regs:$b),
374- !strconcat(OpcStr, ".f64 \t$dst, $a, $b;"),
375- [(set f64:$dst, (OpNode f64:$a, f64:$b))]>,
376- Requires<[allowFMA]>;
377- def f64ri :
378- NVPTXInst<(outs Float64Regs:$dst),
379- (ins Float64Regs:$a, f64imm:$b),
380- !strconcat(OpcStr, ".f64 \t$dst, $a, $b;"),
381- [(set f64:$dst, (OpNode f64:$a, fpimm:$b))]>,
382- Requires<[allowFMA]>;
383- def f32rr_ftz :
384- NVPTXInst<(outs Float32Regs:$dst),
385- (ins Float32Regs:$a, Float32Regs:$b),
386- !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b;"),
387- [(set f32:$dst, (OpNode f32:$a, f32:$b))]>,
388- Requires<[allowFMA, doF32FTZ]>;
389- def f32ri_ftz :
390- NVPTXInst<(outs Float32Regs:$dst),
391- (ins Float32Regs:$a, f32imm:$b),
392- !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b;"),
393- [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
394- Requires<[allowFMA, doF32FTZ]>;
395- def f32rr :
396- NVPTXInst<(outs Float32Regs:$dst),
397- (ins Float32Regs:$a, Float32Regs:$b),
398- !strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
399- [(set f32:$dst, (OpNode f32:$a, f32:$b))]>,
400- Requires<[allowFMA]>;
401- def f32ri :
402- NVPTXInst<(outs Float32Regs:$dst),
403- (ins Float32Regs:$a, f32imm:$b),
404- !strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
405- [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
406- Requires<[allowFMA]>;
407-
408- def f16rr_ftz :
409- NVPTXInst<(outs Int16Regs:$dst),
410- (ins Int16Regs:$a, Int16Regs:$b),
411- !strconcat(OpcStr, ".ftz.f16 \t$dst, $a, $b;"),
412- [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
413- Requires<[useFP16Math, allowFMA, doF32FTZ]>;
414- def f16rr :
415- NVPTXInst<(outs Int16Regs:$dst),
416- (ins Int16Regs:$a, Int16Regs:$b),
417- !strconcat(OpcStr, ".f16 \t$dst, $a, $b;"),
418- [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
419- Requires<[useFP16Math, allowFMA]>;
368+ multiclass F3<string op_str, SDPatternOperator op_pat> { // Two-operand FP instruction family: one NVPTXInst per type/operand form; op_str is the asm mnemonic prefix, op_pat the operator used in every selection pattern.
369+ def f64rr : // f64, register-register; no Requires<> guard.
370+ NVPTXInst<(outs Float64Regs:$dst),
371+ (ins Float64Regs:$a, Float64Regs:$b),
372+ op_str # ".f64 \t$dst, $a, $b;",
373+ [(set f64:$dst, (op_pat f64:$a, f64:$b))]>;
374+ def f64ri : // f64, register-immediate: $b is an f64imm operand matched as fpimm.
375+ NVPTXInst<(outs Float64Regs:$dst),
376+ (ins Float64Regs:$a, f64imm:$b),
377+ op_str # ".f64 \t$dst, $a, $b;",
378+ [(set f64:$dst, (op_pat f64:$a, fpimm:$b))]>;
379+ def f32rr_ftz : // f32 register-register printed with the ".ftz" modifier; gated on doF32FTZ.
380+ NVPTXInst<(outs Float32Regs:$dst),
381+ (ins Float32Regs:$a, Float32Regs:$b),
382+ op_str # ".ftz.f32 \t$dst, $a, $b;",
383+ [(set f32:$dst, (op_pat f32:$a, f32:$b))]>,
384+ Requires<[doF32FTZ]>;
385+ def f32ri_ftz : // f32 register-immediate with ".ftz"; gated on doF32FTZ.
386+ NVPTXInst<(outs Float32Regs:$dst),
387+ (ins Float32Regs:$a, f32imm:$b),
388+ op_str # ".ftz.f32 \t$dst, $a, $b;",
389+ [(set f32:$dst, (op_pat f32:$a, fpimm:$b))]>,
390+ Requires<[doF32FTZ]>;
391+ def f32rr : // f32 register-register without ".ftz"; no Requires<> guard.
392+ NVPTXInst<(outs Float32Regs:$dst),
393+ (ins Float32Regs:$a, Float32Regs:$b),
394+ op_str # ".f32 \t$dst, $a, $b;",
395+ [(set f32:$dst, (op_pat f32:$a, f32:$b))]>;
396+ def f32ri : // f32 register-immediate without ".ftz"; no Requires<> guard.
397+ NVPTXInst<(outs Float32Regs:$dst),
398+ (ins Float32Regs:$a, f32imm:$b),
399+ op_str # ".f32 \t$dst, $a, $b;",
400+ [(set f32:$dst, (op_pat f32:$a, fpimm:$b))]>;
401+
402+ def f16rr_ftz : // Scalar f16 (operands in Int16Regs) with ".ftz"; needs both useFP16Math and doF32FTZ.
403+ NVPTXInst<(outs Int16Regs:$dst),
404+ (ins Int16Regs:$a, Int16Regs:$b),
405+ op_str # ".ftz.f16 \t$dst, $a, $b;",
406+ [(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
407+ Requires<[useFP16Math, doF32FTZ]>;
408+ def f16rr : // Scalar f16 without ".ftz"; needs useFP16Math only.
409+ NVPTXInst<(outs Int16Regs:$dst),
410+ (ins Int16Regs:$a, Int16Regs:$b),
411+ op_str # ".f16 \t$dst, $a, $b;",
412+ [(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
413+ Requires<[useFP16Math]>;
414+
415+ def f16x2rr_ftz : // v2f16 (operands in Int32Regs) with ".ftz"; needs both useFP16Math and doF32FTZ.
416+ NVPTXInst<(outs Int32Regs:$dst),
417+ (ins Int32Regs:$a, Int32Regs:$b),
418+ op_str # ".ftz.f16x2 \t$dst, $a, $b;",
419+ [(set v2f16:$dst, (op_pat v2f16:$a, v2f16:$b))]>,
420+ Requires<[useFP16Math, doF32FTZ]>;
421+ def f16x2rr : // v2f16 without ".ftz"; needs useFP16Math only.
422+ NVPTXInst<(outs Int32Regs:$dst),
423+ (ins Int32Regs:$a, Int32Regs:$b),
424+ op_str # ".f16x2 \t$dst, $a, $b;",
425+ [(set v2f16:$dst, (op_pat v2f16:$a, v2f16:$b))]>,
426+ Requires<[useFP16Math]>;
427+ def bf16rr : // Scalar bf16 (operands in Int16Regs); needs hasBF16Math. This multiclass defines no ".ftz" bf16 form.
428+ NVPTXInst<(outs Int16Regs:$dst),
429+ (ins Int16Regs:$a, Int16Regs:$b),
430+ op_str # ".bf16 \t$dst, $a, $b;",
431+ [(set bf16:$dst, (op_pat bf16:$a, bf16:$b))]>,
432+ Requires<[hasBF16Math]>;
433+
434+ def bf16x2rr : // v2bf16 (operands in Int32Regs); needs hasBF16Math.
435+ NVPTXInst<(outs Int32Regs:$dst),
436+ (ins Int32Regs:$a, Int32Regs:$b),
437+ op_str # ".bf16x2 \t$dst, $a, $b;",
438+ [(set v2bf16:$dst, (op_pat v2bf16:$a, v2bf16:$b))]>,
439+ Requires<[hasBF16Math]>;
440+ }
420441
421- def f16x2rr_ftz :
422- NVPTXInst<(outs Int32Regs:$dst),
423- (ins Int32Regs:$a, Int32Regs:$b),
424- !strconcat(OpcStr, ".ftz.f16x2 \t$dst, $a, $b;"),
425- [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
426- Requires<[useFP16Math, allowFMA, doF32FTZ]>;
427- def f16x2rr :
428- NVPTXInst<(outs Int32Regs:$dst),
429- (ins Int32Regs:$a, Int32Regs:$b),
430- !strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
431- [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
432- Requires<[useFP16Math, allowFMA]>;
433- def bf16rr :
434- NVPTXInst<(outs Int16Regs:$dst),
435- (ins Int16Regs:$a, Int16Regs:$b),
436- !strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"),
437- [(set bf16:$dst, (OpNode bf16:$a, bf16:$b))]>,
438- Requires<[hasBF16Math, allowFMA]>;
442+ class BinOpAllowsFMA<SDPatternOperator operator> // Two-operand fragment over `operator` that matches only when the C++ predicate below returns true.
443+ : PatFrag<(ops node:$A, node:$B),
444+ (operator node:$A, node:$B), [{
445+ return allowFMA() || N->getFlags().hasAllowContract(); // True if allowFMA() holds, or if this node itself carries the allow-contract flag.
446+ }]>;
439447
440- def bf16x2rr :
441- NVPTXInst<(outs Int32Regs:$dst),
442- (ins Int32Regs:$a, Int32Regs:$b),
443- !strconcat(OpcStr, ".bf16x2 \t$dst, $a, $b;"),
444- [(set v2bf16:$dst, (OpNode v2bf16:$a, v2bf16:$b))]>,
445- Requires<[hasBF16Math, allowFMA]>;
446- // These have strange names so we don't perturb existing mir tests.
447- def _rnf64rr :
448- NVPTXInst<(outs Float64Regs:$dst),
449- (ins Float64Regs:$a, Float64Regs:$b),
450- !strconcat(OpcStr, ".rn.f64 \t$dst, $a, $b;"),
451- [(set f64:$dst, (OpNode f64:$a, f64:$b))]>,
452- Requires<[noFMA]>;
453- def _rnf64ri :
454- NVPTXInst<(outs Float64Regs:$dst),
455- (ins Float64Regs:$a, f64imm:$b),
456- !strconcat(OpcStr, ".rn.f64 \t$dst, $a, $b;"),
457- [(set f64:$dst, (OpNode f64:$a, fpimm:$b))]>,
458- Requires<[noFMA]>;
459- def _rnf32rr_ftz :
460- NVPTXInst<(outs Float32Regs:$dst),
461- (ins Float32Regs:$a, Float32Regs:$b),
462- !strconcat(OpcStr, ".rn.ftz.f32 \t$dst, $a, $b;"),
463- [(set f32:$dst, (OpNode f32:$a, Float32Regs:$b))]>,
464- Requires<[noFMA, doF32FTZ]>;
465- def _rnf32ri_ftz :
466- NVPTXInst<(outs Float32Regs:$dst),
467- (ins Float32Regs:$a, f32imm:$b),
468- !strconcat(OpcStr, ".rn.ftz.f32 \t$dst, $a, $b;"),
469- [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
470- Requires<[noFMA, doF32FTZ]>;
471- def _rnf32rr :
472- NVPTXInst<(outs Float32Regs:$dst),
473- (ins Float32Regs:$a, Float32Regs:$b),
474- !strconcat(OpcStr, ".rn.f32 \t$dst, $a, $b;"),
475- [(set f32:$dst, (OpNode f32:$a, f32:$b))]>,
476- Requires<[noFMA]>;
477- def _rnf32ri :
478- NVPTXInst<(outs Float32Regs:$dst),
479- (ins Float32Regs:$a, f32imm:$b),
480- !strconcat(OpcStr, ".rn.f32 \t$dst, $a, $b;"),
481- [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
482- Requires<[noFMA]>;
483- def _rnf16rr_ftz :
484- NVPTXInst<(outs Int16Regs:$dst),
485- (ins Int16Regs:$a, Int16Regs:$b),
486- !strconcat(OpcStr, ".rn.ftz.f16 \t$dst, $a, $b;"),
487- [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
488- Requires<[useFP16Math, noFMA, doF32FTZ]>;
489- def _rnf16rr :
490- NVPTXInst<(outs Int16Regs:$dst),
491- (ins Int16Regs:$a, Int16Regs:$b),
492- !strconcat(OpcStr, ".rn.f16 \t$dst, $a, $b;"),
493- [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
494- Requires<[useFP16Math, noFMA]>;
495- def _rnf16x2rr_ftz :
496- NVPTXInst<(outs Int32Regs:$dst),
497- (ins Int32Regs:$a, Int32Regs:$b),
498- !strconcat(OpcStr, ".rn.ftz.f16x2 \t$dst, $a, $b;"),
499- [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
500- Requires<[useFP16Math, noFMA, doF32FTZ]>;
501- def _rnf16x2rr :
502- NVPTXInst<(outs Int32Regs:$dst),
503- (ins Int32Regs:$a, Int32Regs:$b),
504- !strconcat(OpcStr, ".rn.f16x2 \t$dst, $a, $b;"),
505- [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
506- Requires<[useFP16Math, noFMA]>;
507- def _rnbf16rr_ftz :
508- NVPTXInst<(outs Int16Regs:$dst),
509- (ins Int16Regs:$a, Int16Regs:$b),
510- !strconcat(OpcStr, ".rn.ftz.bf16 \t$dst, $a, $b;"),
511- [(set bf16:$dst, (OpNode bf16:$a, bf16:$b))]>,
512- Requires<[hasBF16Math, noFMA, doF32FTZ]>;
513- def _rnbf16rr :
514- NVPTXInst<(outs Int16Regs:$dst),
515- (ins Int16Regs:$a, Int16Regs:$b),
516- !strconcat(OpcStr, ".rn.bf16 \t$dst, $a, $b;"),
517- [(set bf16:$dst, (OpNode bf16:$a, bf16:$b))]>,
518- Requires<[hasBF16Math, noFMA]>;
519- def _rnbf16x2rr_ftz :
520- NVPTXInst<(outs Int32Regs:$dst),
521- (ins Int32Regs:$a, Int32Regs:$b),
522- !strconcat(OpcStr, ".rn.ftz.bf16x2 \t$dst, $a, $b;"),
523- [(set v2bf16:$dst, (OpNode v2bf16:$a, v2bf16:$b))]>,
524- Requires<[hasBF16Math, noFMA, doF32FTZ]>;
525- def _rnbf16x2rr :
526- NVPTXInst<(outs Int32Regs:$dst),
527- (ins Int32Regs:$a, Int32Regs:$b),
528- !strconcat(OpcStr, ".rn.bf16x2 \t$dst, $a, $b;"),
529- [(set v2bf16:$dst, (OpNode v2bf16:$a, v2bf16:$b))]>,
530- Requires<[hasBF16Math, noFMA]>;
448+ multiclass F3_fma_component<string op_str, SDNode op_node> { // Instantiates F3 twice: once with the plain mnemonic and once with ".rn" appended.
449+ defm "" : F3<op_str, BinOpAllowsFMA<op_node>>; // Empty name infix; patterns are wrapped in BinOpAllowsFMA, so these records match only when its predicate holds.
450+ defm _rn : F3<op_str # ".rn", op_node>; // Records named with the "_rn" infix; mnemonic gains ".rn" and op_node is matched with no extra predicate.
531451}
532452
533453// Template for operations which take two f32 or f64 operands. Provides three
0 commit comments