|
| 1 | +/* Copyright 2024 The OpenXLA Authors. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. |
| 14 | +==============================================================================*/ |
| 15 | + |
| 16 | +#ifndef TENSORFLOW_COMPILER_XLA_MLIR_HLO_MHLO_TRANSFORMS_TRANSFORMATION_HELPERS_H_ |
| 17 | +#define TENSORFLOW_COMPILER_XLA_MLIR_HLO_MHLO_TRANSFORMS_TRANSFORMATION_HELPERS_H_ |
| 18 | + |
| 19 | +#include <cstdint> |
| 20 | +#include <optional> |
| 21 | +#include <vector> |
| 22 | + |
| 23 | +#include "llvm/Support/Casting.h" |
| 24 | +#include "mlir/Dialect/Arith/IR/Arith.h" |
| 25 | +#include "mlir/IR/Builders.h" |
| 26 | +#include "mlir/IR/BuiltinTypes.h" |
| 27 | +#include "mlir/IR/ImplicitLocOpBuilder.h" |
| 28 | +#include "mlir/IR/Location.h" |
| 29 | +#include "mlir/IR/TypeUtilities.h" |
| 30 | +#include "mlir/IR/Value.h" |
| 31 | +#include "mlir/Support/LLVM.h" |
| 32 | + |
| 33 | +namespace mlir::mhlo { |
| 34 | + |
| 35 | +// Creates an integer constant that is either a tensor (if shape is provided) or |
| 36 | +// a scalar. |
| 37 | +template <typename T> |
| 38 | +arith::ConstantOp createConst(ImplicitLocOpBuilder& b, |
| 39 | + mlir::IntegerType intType, T value, |
| 40 | + std::optional<ArrayRef<int64_t>> shape) { |
| 41 | + if (shape.has_value()) { |
| 42 | + auto tensorType = mlir::RankedTensorType::get(shape.value(), intType); |
| 43 | + return b.create<arith::ConstantOp>(mlir::DenseElementsAttr::get( |
| 44 | + tensorType, mlir::APInt(intType.getIntOrFloatBitWidth(), value))); |
| 45 | + } |
| 46 | + return b.create<arith::ConstantOp>(b.getIntegerAttr(intType, value)); |
| 47 | +} |
| 48 | + |
// Returns the input value with a reduced precision as specified by the target
// exponent and mantissa bits. This function will preserve the input shape on
// the output - i.e. it works with both scalars and tensors.
//
// The templated bitcast type allows this function to work with different kinds
// of bitcasts, e.g. `arith.bitcast` or `triton.bitcast`.
//
// Preconditions (not checked here):
//  * `input` must have a floating-point element type (the cast below asserts
//    otherwise).
//  * destExponentBits >= 1 — see the `exponents_bits >= 1` note below;
//    presumably validated by callers. TODO(review): confirm at call sites.
template <typename BitCastOp>
Value reducePrecision(Location loc, Value input, int destExponentBits,
                      int destMantissaBits, OpBuilder* builder) {
  using llvm::APInt;
  mlir::ImplicitLocOpBuilder b(loc, *builder);

  // Integer and float types for casting and constant generation.
  auto floatType = mlir::cast<FloatType>(getElementTypeOrSelf(input.getType()));
  int64_t nbits = floatType.getWidth();
  auto intScalarType = mlir::IntegerType::get(loc.getContext(), nbits);

  // If the input is a tensor, all constants and intermediate values below must
  // use a tensor of the same shape; otherwise everything stays scalar.
  Type intType = intScalarType;
  std::optional<std::vector<int64_t>> shape;
  if (auto tensorType = llvm::dyn_cast<TensorType>(input.getType())) {
    shape = tensorType.getShape().vec();
    intType = tensorType.clone(intScalarType);
  }

  // Reinterpret the float bits as an integer of the same width so the sign,
  // exponent, and mantissa fields can be manipulated directly.
  Value xAsInt = b.create<BitCastOp>(intType, input);

  // SignificandWidth includes the implicit extra bit.
  auto srcMantissaBits = floatType.getFPMantissaWidth() - 1;
  int srcExponentBits = nbits - 1 - srcMantissaBits;

  // Clear the sign bit, it does not participate in rounding and we will restore
  // it later.
  APInt signBitMask(nbits, 1);
  signBitMask <<= nbits - 1;

  // Mask covering every exponent bit (an all-ones exponent encodes Inf/NaN).
  APInt expBitsMask(nbits, 1);
  expBitsMask = ((expBitsMask << srcExponentBits) - 1) << srcMantissaBits;

  // Materializes an APInt as a constant matching the (possibly tensor-shaped)
  // integer type of `xAsInt`.
  auto createConstant = [&](const APInt& v) {
    return createConst(b, intScalarType, v.getZExtValue(), shape);
  };

  // Detect NaN *before* rounding: any magnitude strictly greater than the
  // exponent mask has a nonzero mantissa with an all-ones exponent, i.e. NaN.
  Value xAbsBits =
      b.create<arith::AndIOp>(xAsInt, createConstant(~signBitMask));
  Value xIsNan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, xAbsBits,
                                         createConstant(expBitsMask));

  if (destMantissaBits < static_cast<int>(srcMantissaBits)) {
    // Last remaining mantissa bit.
    APInt lastMantissaBitMask(nbits, 1);
    lastMantissaBitMask <<= srcMantissaBits - destMantissaBits;

    // Compute rounding bias for round-to-nearest with ties to even. This is
    // equal to a base value of 0111... plus one bit if the last remaining
    // mantissa bit is 1.
    APInt baseRoundingBias = lastMantissaBitMask.lshr(1) - 1;

    // Number of mantissa bits being dropped; used as the shift amount below.
    Value mantissaDiff = createConst(b, intScalarType,
                                     srcMantissaBits - destMantissaBits, shape);

    Value highestMantissaMaskVal = createConstant(lastMantissaBitMask);
    Value baseRoundingBiasVal = createConstant(baseRoundingBias);
    // Extract the last kept mantissa bit (0 or 1) to implement ties-to-even.
    Value xLastMantissaBit = b.create<arith::ShRUIOp>(
        b.create<arith::AndIOp>(xAsInt, highestMantissaMaskVal), mantissaDiff);
    Value xRoundingBias =
        b.create<arith::AddIOp>(xLastMantissaBit, baseRoundingBiasVal);

    // Add rounding bias, and mask out truncated bits. Note that the case
    // where adding the rounding bias overflows into the exponent bits is
    // correct; the non-masked mantissa bits will all be zero, and the
    // exponent will be incremented by one.
    APInt truncationMask = ~(lastMantissaBitMask - 1);
    Value xRounded = b.create<arith::AddIOp>(xAsInt, xRoundingBias);
    xAsInt = b.create<arith::AndIOp>(xRounded, createConstant(truncationMask));
  }

  if (destExponentBits < srcExponentBits) {
    // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
    // significant bit -- is equal to 1.0f for all exponent sizes. Adding
    // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
    // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest
    // exponent (corresponding to 0.0f).
    //
    // Thus, the f32 exponent corresponding to the highest non-infinite
    // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
    // exponent corresponding to the lowest exponent for a bit size of n is
    // (2^7-1) - 2^(n-1)-1.
    //
    // Note that we have already checked that exponents_bits >= 1.
    APInt exponentBias(nbits, 1);
    exponentBias = (exponentBias << (srcExponentBits - 1)) - 1;

    APInt reducedExponentBias(nbits, 1);
    reducedExponentBias = (reducedExponentBias << (destExponentBits - 1)) - 1;

    APInt reducedMaxExponent = exponentBias + reducedExponentBias;
    APInt reducedMinExponent = exponentBias - reducedExponentBias;

    // Do we overflow or underflow? (Unsigned compares on the isolated
    // exponent field, which is already shifted into place.)
    Value xExponent =
        b.create<arith::AndIOp>(xAsInt, createConstant(expBitsMask));
    Value xOverflows = b.create<arith::CmpIOp>(
        arith::CmpIPredicate::ugt, xExponent,
        createConstant(reducedMaxExponent << srcMantissaBits));
    Value xUnderflows = b.create<arith::CmpIOp>(
        arith::CmpIPredicate::ule, xExponent,
        createConstant(reducedMinExponent << srcMantissaBits));

    // Compute appropriately-signed values of zero and infinity.
    Value xSignedZero =
        b.create<arith::AndIOp>(xAsInt, createConstant(signBitMask));
    Value xSignedInf =
        b.create<arith::OrIOp>(xSignedZero, createConstant(expBitsMask));

    // Force to zero or infinity if overflow or underflow. (Note that this
    // truncates all denormal values to zero, rather than rounding them.)
    xAsInt = b.create<arith::SelectOp>(xOverflows, xSignedInf, xAsInt);
    xAsInt = b.create<arith::SelectOp>(xUnderflows, xSignedZero, xAsInt);
  }

  // Cast the manipulated bits back to the original float type, and pass NaN
  // inputs through untouched (the arithmetic above may have clobbered the NaN
  // encoding, so we select the original value using the pre-computed flag).
  Value result = b.create<BitCastOp>(input.getType(), xAsInt);
  return b.create<arith::SelectOp>(xIsNan, input, result);
}
| 172 | +} // namespace mlir::mhlo |
| 173 | + |
| 174 | +#endif // TENSORFLOW_COMPILER_XLA_MLIR_HLO_MHLO_TRANSFORMS_TRANSFORMATION_HELPERS_H_ |
0 commit comments