
Commit 3bf45d7

dimitar-asenov authored and TensorFlow MLIR Team committed

[XLA:GPU] Generalize the Reduce-Precision in MHLO to also work on Tensors and use it in a Triton emitter.

PiperOrigin-RevId: 681364188

1 parent ad1f329 commit 3bf45d7

File tree

3 files changed: +196 -110 lines changed

  BUILD
  mhlo/transforms/map_mhlo_to_scalar_op.h
  mhlo/transforms/transformation_helpers.h

BUILD

Lines changed: 17 additions & 0 deletions

@@ -482,6 +482,23 @@ cc_library(
     name = "map_mhlo_to_scalar_op",
     hdrs = ["mhlo/transforms/map_mhlo_to_scalar_op.h"],
     strip_include_prefix = ".",
+    deps = [
+        ":mlir_hlo",
+        ":transformation_helpers",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:ComplexDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:MathDialect",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:Support",
+    ],
+)
+
+cc_library(
+    name = "transformation_helpers",
+    hdrs = ["mhlo/transforms/transformation_helpers.h"],
+    strip_include_prefix = ".",
     deps = [
         ":mlir_hlo",
         "@llvm-project//llvm:Support",

mhlo/transforms/map_mhlo_to_scalar_op.h

Lines changed: 5 additions & 110 deletions

@@ -22,6 +22,7 @@ limitations under the License.
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "mhlo/IR/hlo_ops.h"
+#include "mhlo/transforms/transformation_helpers.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Math/IR/Math.h"
@@ -469,117 +470,11 @@ inline Value mapMhloOpToStdScalarOp<mhlo::CompareOp>(

 template <>
 inline Value mapMhloOpToStdScalarOp<mhlo::ReducePrecisionOp>(
-    Location loc, ArrayRef<Type> /*resultTypes*/, ArrayRef<Type> argTypes,
+    Location loc, ArrayRef<Type> /*resultTypes*/, ArrayRef<Type> /*argTypes*/,
     mhlo::ReducePrecisionOp::Adaptor adaptor, OpBuilder* builder) {
-  using llvm::APInt;
-  mlir::ImplicitLocOpBuilder b(loc, *builder);
-
-  // Integer and float types for casting and constant generation.
-  auto floatType =
-      mlir::cast<FloatType>(getElementTypeOrSelf(argTypes.front()));
-  int64_t nbits = floatType.getWidth();
-  auto intType = mlir::IntegerType::get(loc.getContext(), nbits);
-
-  Value xAsInt = b.create<arith::BitcastOp>(intType, adaptor.getOperand());
-
-  // SignificandWidth includes the implicit extra bit.
-  auto srcMantissaBits = floatType.getFPMantissaWidth() - 1;
-  int srcExponentBits = nbits - 1 - srcMantissaBits;
-
-  // Clear the sign bit; it does not participate in rounding and we will
-  // restore it later.
-  APInt signBitMask(nbits, 1);
-  signBitMask <<= nbits - 1;
-
-  APInt expBitsMask(nbits, 1);
-  expBitsMask = ((expBitsMask << srcExponentBits) - 1) << srcMantissaBits;
-
-  auto createConstant = [&](const APInt& v) {
-    return b.create<arith::ConstantIntOp>(v.getZExtValue(), intType)
-        .getResult();
-  };
-
-  Value xAbsBits =
-      b.create<arith::AndIOp>(xAsInt, createConstant(~signBitMask));
-  Value xIsNan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, xAbsBits,
-                                         createConstant(expBitsMask));
-
-  int destMantissaBits = adaptor.getMantissaBits();
-  if (destMantissaBits < static_cast<int>(srcMantissaBits)) {
-    // Last remaining mantissa bit.
-    APInt lastMantissaBitMask(nbits, 1);
-    lastMantissaBitMask <<= srcMantissaBits - destMantissaBits;
-
-    // Compute rounding bias for round-to-nearest with ties to even. This is
-    // equal to a base value of 0111... plus one bit if the last remaining
-    // mantissa bit is 1.
-    APInt baseRoundingBias = lastMantissaBitMask.lshr(1) - 1;
-
-    Value mantissaDiff = b.create<arith::ConstantIntOp>(
-        srcMantissaBits - destMantissaBits, intType);
-    Value highestMantissaMaskVal = createConstant(lastMantissaBitMask);
-    Value baseRoundingBiasVal = createConstant(baseRoundingBias);
-    Value xLastMantissaBit = b.create<arith::ShRUIOp>(
-        b.create<arith::AndIOp>(xAsInt, highestMantissaMaskVal), mantissaDiff);
-    Value xRoundingBias =
-        b.create<arith::AddIOp>(xLastMantissaBit, baseRoundingBiasVal);
-
-    // Add rounding bias, and mask out truncated bits. Note that the case
-    // where adding the rounding bias overflows into the exponent bits is
-    // correct; the non-masked mantissa bits will all be zero, and the
-    // exponent will be incremented by one.
-    APInt truncationMask = ~(lastMantissaBitMask - 1);
-    Value xRounded = b.create<arith::AddIOp>(xAsInt, xRoundingBias);
-    xAsInt = b.create<arith::AndIOp>(xRounded, createConstant(truncationMask));
-  }
-
-  int destExponentBits = adaptor.getExponentBits();
-  if (destExponentBits < srcExponentBits) {
-    // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
-    // significant bit -- is equal to 1.0f for all exponent sizes. Adding
-    // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
-    // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest
-    // exponent (corresponding to 0.0f).
-    //
-    // Thus, the f32 exponent corresponding to the highest non-infinite
-    // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
-    // exponent corresponding to the lowest exponent for a bit size of n is
-    // (2^7-1) - 2^(n-1)-1.
-    //
-    // Note that we have already checked that exponent_bits >= 1.
-    APInt exponentBias(nbits, 1);
-    exponentBias = (exponentBias << (srcExponentBits - 1)) - 1;
-
-    APInt reducedExponentBias(nbits, 1);
-    reducedExponentBias = (reducedExponentBias << (destExponentBits - 1)) - 1;
-
-    APInt reducedMaxExponent = exponentBias + reducedExponentBias;
-    APInt reducedMinExponent = exponentBias - reducedExponentBias;
-
-    // Do we overflow or underflow?
-    Value xExponent =
-        b.create<arith::AndIOp>(xAsInt, createConstant(expBitsMask));
-    Value xOverflows = b.create<arith::CmpIOp>(
-        arith::CmpIPredicate::ugt, xExponent,
-        createConstant(reducedMaxExponent << srcMantissaBits));
-    Value xUnderflows = b.create<arith::CmpIOp>(
-        arith::CmpIPredicate::ule, xExponent,
-        createConstant(reducedMinExponent << srcMantissaBits));
-
-    // Compute appropriately-signed values of zero and infinity.
-    Value xSignedZero =
-        b.create<arith::AndIOp>(xAsInt, createConstant(signBitMask));
-    Value xSignedInf =
-        b.create<arith::OrIOp>(xSignedZero, createConstant(expBitsMask));
-
-    // Force to zero or infinity if overflow or underflow. (Note that this
-    // truncates all denormal values to zero, rather than rounding them.)
-    xAsInt = b.create<arith::SelectOp>(xOverflows, xSignedInf, xAsInt);
-    xAsInt = b.create<arith::SelectOp>(xUnderflows, xSignedZero, xAsInt);
-  }
-
-  Value result = b.create<arith::BitcastOp>(floatType, xAsInt);
-  return b.create<arith::SelectOp>(xIsNan, adaptor.getOperand(), result);
+  return reducePrecision<arith::BitcastOp>(loc, adaptor.getOperand(),
+                                           adaptor.getExponentBits(),
+                                           adaptor.getMantissaBits(), builder);
 }

 template <>
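With the scalar lowering reduced to a single call into the shared helper, other emitters can reuse the same bit-manipulation sequence by instantiating the template with their own bitcast op. A minimal sketch of what such a call site might look like; the wrapper name and the `OtherBitcastOp` parameter are hypothetical stand-ins for whatever bitcast op the target dialect provides (per the commit message, the Triton emitter is one such caller):

#include "mhlo/transforms/transformation_helpers.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Value.h"

// Hypothetical wrapper: `OtherBitcastOp` is whichever bitcast op the target
// dialect exposes (e.g. Triton's bitcast instead of arith::BitcastOp).
template <typename OtherBitcastOp>
mlir::Value emitReducePrecision(mlir::Location loc, mlir::Value input,
                                int exponentBits, int mantissaBits,
                                mlir::OpBuilder& b) {
  // Delegates all of the rounding/clamping logic to the shared helper; only
  // the bitcast op differs between emitters.
  return mlir::mhlo::reducePrecision<OtherBitcastOp>(loc, input, exponentBits,
                                                     mantissaBits, &b);
}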
mhlo/transforms/transformation_helpers.h

Lines changed: 174 additions & 0 deletions

@@ -0,0 +1,174 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_MLIR_HLO_MHLO_TRANSFORMS_TRANSFORMATION_HELPERS_H_
+#define TENSORFLOW_COMPILER_XLA_MLIR_HLO_MHLO_TRANSFORMS_TRANSFORMATION_HELPERS_H_
+
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir::mhlo {
+
+// Creates an integer constant that is either a tensor (if a shape is
+// provided) or a scalar.
+template <typename T>
+arith::ConstantOp createConst(ImplicitLocOpBuilder& b,
+                              mlir::IntegerType intType, T value,
+                              std::optional<ArrayRef<int64_t>> shape) {
+  if (shape.has_value()) {
+    auto tensorType = mlir::RankedTensorType::get(shape.value(), intType);
+    return b.create<arith::ConstantOp>(mlir::DenseElementsAttr::get(
+        tensorType, mlir::APInt(intType.getIntOrFloatBitWidth(), value)));
+  }
+  return b.create<arith::ConstantOp>(b.getIntegerAttr(intType, value));
+}
+
+// Returns the input value with a reduced precision as specified by the target
+// exponent and mantissa bits. This function preserves the input shape on the
+// output - i.e. it works with both scalars and tensors.
+//
+// The templated bitcast type allows this function to work with different
+// kinds of bitcasts, e.g. `arith.bitcast` or `triton.bitcast`.
+template <typename BitCastOp>
+Value reducePrecision(Location loc, Value input, int destExponentBits,
+                      int destMantissaBits, OpBuilder* builder) {
+  using llvm::APInt;
+  mlir::ImplicitLocOpBuilder b(loc, *builder);
+
+  // Integer and float types for casting and constant generation.
+  auto floatType = mlir::cast<FloatType>(getElementTypeOrSelf(input.getType()));
+  int64_t nbits = floatType.getWidth();
+  auto intScalarType = mlir::IntegerType::get(loc.getContext(), nbits);
+
+  Type intType = intScalarType;
+  std::optional<std::vector<int64_t>> shape;
+  if (auto tensorType = llvm::dyn_cast<TensorType>(input.getType())) {
+    shape = tensorType.getShape().vec();
+    intType = tensorType.clone(intScalarType);
+  }
+
+  Value xAsInt = b.create<BitCastOp>(intType, input);
+
+  // SignificandWidth includes the implicit extra bit.
+  auto srcMantissaBits = floatType.getFPMantissaWidth() - 1;
+  int srcExponentBits = nbits - 1 - srcMantissaBits;
+
+  // Clear the sign bit; it does not participate in rounding and we will
+  // restore it later.
+  APInt signBitMask(nbits, 1);
+  signBitMask <<= nbits - 1;
+
+  APInt expBitsMask(nbits, 1);
+  expBitsMask = ((expBitsMask << srcExponentBits) - 1) << srcMantissaBits;
+
+  auto createConstant = [&](const APInt& v) {
+    return createConst(b, intScalarType, v.getZExtValue(), shape);
+  };
+
+  Value xAbsBits =
+      b.create<arith::AndIOp>(xAsInt, createConstant(~signBitMask));
+  Value xIsNan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, xAbsBits,
+                                         createConstant(expBitsMask));
+
+  if (destMantissaBits < static_cast<int>(srcMantissaBits)) {
+    // Last remaining mantissa bit.
+    APInt lastMantissaBitMask(nbits, 1);
+    lastMantissaBitMask <<= srcMantissaBits - destMantissaBits;
+
+    // Compute rounding bias for round-to-nearest with ties to even. This is
+    // equal to a base value of 0111... plus one bit if the last remaining
+    // mantissa bit is 1.
+    APInt baseRoundingBias = lastMantissaBitMask.lshr(1) - 1;
+
+    Value mantissaDiff = createConst(b, intScalarType,
+                                     srcMantissaBits - destMantissaBits, shape);
+    Value highestMantissaMaskVal = createConstant(lastMantissaBitMask);
+    Value baseRoundingBiasVal = createConstant(baseRoundingBias);
+    Value xLastMantissaBit = b.create<arith::ShRUIOp>(
+        b.create<arith::AndIOp>(xAsInt, highestMantissaMaskVal), mantissaDiff);
+    Value xRoundingBias =
+        b.create<arith::AddIOp>(xLastMantissaBit, baseRoundingBiasVal);
+
+    // Add rounding bias, and mask out truncated bits. Note that the case
+    // where adding the rounding bias overflows into the exponent bits is
+    // correct; the non-masked mantissa bits will all be zero, and the
+    // exponent will be incremented by one.
+    APInt truncationMask = ~(lastMantissaBitMask - 1);
+    Value xRounded = b.create<arith::AddIOp>(xAsInt, xRoundingBias);
+    xAsInt = b.create<arith::AndIOp>(xRounded, createConstant(truncationMask));
+  }
+
+  if (destExponentBits < srcExponentBits) {
+    // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
+    // significant bit -- is equal to 1.0f for all exponent sizes. Adding
+    // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
+    // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest
+    // exponent (corresponding to 0.0f).
+    //
+    // Thus, the f32 exponent corresponding to the highest non-infinite
+    // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
+    // exponent corresponding to the lowest exponent for a bit size of n is
+    // (2^7-1) - 2^(n-1)-1.
+    //
+    // Note that we have already checked that exponent_bits >= 1.
+    APInt exponentBias(nbits, 1);
+    exponentBias = (exponentBias << (srcExponentBits - 1)) - 1;
+
+    APInt reducedExponentBias(nbits, 1);
+    reducedExponentBias = (reducedExponentBias << (destExponentBits - 1)) - 1;
+
+    APInt reducedMaxExponent = exponentBias + reducedExponentBias;
+    APInt reducedMinExponent = exponentBias - reducedExponentBias;
+
+    // Do we overflow or underflow?
+    Value xExponent =
+        b.create<arith::AndIOp>(xAsInt, createConstant(expBitsMask));
+    Value xOverflows = b.create<arith::CmpIOp>(
+        arith::CmpIPredicate::ugt, xExponent,
+        createConstant(reducedMaxExponent << srcMantissaBits));
+    Value xUnderflows = b.create<arith::CmpIOp>(
+        arith::CmpIPredicate::ule, xExponent,
+        createConstant(reducedMinExponent << srcMantissaBits));
+
+    // Compute appropriately-signed values of zero and infinity.
+    Value xSignedZero =
+        b.create<arith::AndIOp>(xAsInt, createConstant(signBitMask));
+    Value xSignedInf =
+        b.create<arith::OrIOp>(xSignedZero, createConstant(expBitsMask));
+
+    // Force to zero or infinity if overflow or underflow. (Note that this
+    // truncates all denormal values to zero, rather than rounding them.)
+    xAsInt = b.create<arith::SelectOp>(xOverflows, xSignedInf, xAsInt);
+    xAsInt = b.create<arith::SelectOp>(xUnderflows, xSignedZero, xAsInt);
+  }
+
+  Value result = b.create<BitCastOp>(input.getType(), xAsInt);
+  return b.create<arith::SelectOp>(xIsNan, input, result);
+}
+
+}  // namespace mlir::mhlo
+
+#endif  // TENSORFLOW_COMPILER_XLA_MLIR_HLO_MHLO_TRANSFORMS_TRANSFORMATION_HELPERS_H_
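The add-bias-then-mask rounding above can be hard to follow in IR-builder form. Below is a minimal host-side sketch of the same trick in plain C++ on a uint32_t (not MLIR), assuming f32 input reduced to 10 mantissa bits; exponent narrowing and NaN handling are omitted, and the function and variable names are illustrative only:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Rounds an f32 bit pattern to `destMantissaBits` explicit mantissa bits
// using the same add-bias-then-mask sequence as reducePrecision
// (round to nearest, ties to even).
uint32_t roundMantissa(uint32_t bits, int destMantissaBits) {
  const int srcMantissaBits = 23;  // f32 explicit mantissa bits.
  uint32_t lastKeptBit = 1u << (srcMantissaBits - destMantissaBits);
  uint32_t baseBias = (lastKeptBit >> 1) - 1;  // The 0111... base bias.
  // Add one more when the last kept bit is 1, so ties round to even.
  uint32_t tieBias =
      (bits & lastKeptBit) >> (srcMantissaBits - destMantissaBits);
  uint32_t rounded = bits + baseBias + tieBias;  // A carry may bump the exponent.
  return rounded & ~(lastKeptBit - 1);           // Drop the truncated bits.
}

int main() {
  // 1 + 1.5 * 2^-10 sits exactly halfway between two 10-bit mantissas;
  // ties-to-even picks the even neighbour, 1 + 2^-9.
  float x = 1.0f + 1.5f / 1024.0f;
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  uint32_t r = roundMantissa(bits, /*destMantissaBits=*/10);
  float y;
  std::memcpy(&y, &r, sizeof(y));
  std::printf("%.10f -> %.10f\n", x, y);  // 1.0014648438 -> 1.0019531250
}

The exponent-narrowing branch follows the same kind of arithmetic: reducing f32's 8 exponent bits to 5, for example, gives a reduced bias of 2^4 - 1 = 15, so biased exponents above 127 + 15 = 142 are forced to infinity and those at or below 127 - 15 = 112 are flushed to zero.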
