Skip to content

Commit 575a6f8

Browse files
committed
[flang] add ExtendedValue type helpers and factory::genZeroValue
Add some helpers to get the base type and element type of fir::ExtendedValue and to test if a fir::ExtendedValue is a derived type with length parameters. Add a new helper factory::genZeroValue to generate zero scalar value for all the numerical types and false for logicals. These helpers are used only in lowering for now, so add unit tests. Differential Revision: https://reviews.llvm.org/D118795
1 parent 7cc3e02 commit 575a6f8

File tree

4 files changed

+159
-0
lines changed

4 files changed

+159
-0
lines changed

flang/include/flang/Optimizer/Builder/BoxValue.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,30 @@ inline bool isUnboxedValue(const ExtendedValue &exv) {
467467
[](const fir::UnboxedValue &box) { return box ? true : false; },
468468
[](const auto &) { return false; });
469469
}
470+
471+
/// Returns the base type of \p exv. This is the type of \p exv
472+
/// without any memory or box type. The sequence type, if any, is kept.
473+
inline mlir::Type getBaseTypeOf(const ExtendedValue &exv) {
474+
return exv.match(
475+
[](const fir::MutableBoxValue &box) { return box.getBaseTy(); },
476+
[](const fir::BoxValue &box) { return box.getBaseTy(); },
477+
[&](const auto &) {
478+
return fir::unwrapRefType(fir::getBase(exv).getType());
479+
});
480+
}
481+
482+
/// Return the scalar type of \p exv type. This removes all
483+
/// reference, box, or sequence type from \p exv base.
484+
inline mlir::Type getElementTypeOf(const ExtendedValue &exv) {
485+
return fir::unwrapSequenceType(getBaseTypeOf(exv));
486+
}
487+
488+
/// Is the extended value `exv` a derived type with length parameters ?
489+
inline bool isDerivedWithLengthParameters(const ExtendedValue &exv) {
490+
auto record = getElementTypeOf(exv).dyn_cast<fir::RecordType>();
491+
return record && record.getNumLenParams() != 0;
492+
}
493+
470494
} // namespace fir
471495

472496
#endif // FORTRAN_OPTIMIZER_BUILDER_BOXVALUE_H

flang/include/flang/Optimizer/Builder/FIRBuilder.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,11 @@ mlir::Value locationToLineNo(fir::FirOpBuilder &, mlir::Location, mlir::Type);
416416
/// flang/include/flang/Runtime/ragged.h.
417417
mlir::TupleType getRaggedArrayHeaderType(fir::FirOpBuilder &builder);
418418

419+
/// Create the zero value of a given the numerical or logical \p type (`false`
420+
/// for logical types).
421+
mlir::Value createZeroValue(fir::FirOpBuilder &builder, mlir::Location loc,
422+
mlir::Type type);
423+
419424
} // namespace fir::factory
420425

421426
#endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H

flang/lib/Optimizer/Builder/FIRBuilder.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,3 +604,22 @@ fir::factory::getRaggedArrayHeaderType(fir::FirOpBuilder &builder) {
604604
auto shTy = fir::HeapType::get(extTy);
605605
return mlir::TupleType::get(builder.getContext(), {i64Ty, buffTy, shTy});
606606
}
607+
608+
mlir::Value fir::factory::createZeroValue(fir::FirOpBuilder &builder,
609+
mlir::Location loc, mlir::Type type) {
610+
mlir::Type i1 = builder.getIntegerType(1);
611+
if (type.isa<fir::LogicalType>() || type == i1)
612+
return builder.createConvert(loc, type, builder.createBool(loc, false));
613+
if (fir::isa_integer(type))
614+
return builder.createIntegerConstant(loc, type, 0);
615+
if (fir::isa_real(type))
616+
return builder.createRealZeroConstant(loc, type);
617+
if (fir::isa_complex(type)) {
618+
fir::factory::Complex complexHelper(builder, loc);
619+
mlir::Type partType = complexHelper.getComplexPartType(type);
620+
mlir::Value zeroPart = builder.createRealZeroConstant(loc, partType);
621+
return complexHelper.createComplex(type, zeroPart, zeroPart);
622+
}
623+
fir::emitFatalError(loc, "internal: trying to generate zero value of non "
624+
"numeric or logical type");
625+
}

flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,3 +414,114 @@ TEST_F(FIRBuilderTest, getExtents) {
414414
auto readExtents = fir::factory::getExtents(builder, loc, ex);
415415
EXPECT_EQ(2u, readExtents.size());
416416
}
417+
418+
TEST_F(FIRBuilderTest, createZeroValue) {
419+
auto builder = getBuilder();
420+
auto loc = builder.getUnknownLoc();
421+
422+
mlir::Type i64Ty = mlir::IntegerType::get(builder.getContext(), 64);
423+
mlir::Value zeroInt = fir::factory::createZeroValue(builder, loc, i64Ty);
424+
EXPECT_TRUE(zeroInt.getType() == i64Ty);
425+
auto cst =
426+
mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(zeroInt.getDefiningOp());
427+
EXPECT_TRUE(cst);
428+
auto intAttr = cst.getValue().dyn_cast<mlir::IntegerAttr>();
429+
EXPECT_TRUE(intAttr && intAttr.getInt() == 0);
430+
431+
mlir::Type f32Ty = mlir::FloatType::getF32(builder.getContext());
432+
mlir::Value zeroFloat = fir::factory::createZeroValue(builder, loc, f32Ty);
433+
EXPECT_TRUE(zeroFloat.getType() == f32Ty);
434+
auto cst2 = mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(
435+
zeroFloat.getDefiningOp());
436+
EXPECT_TRUE(cst2);
437+
auto floatAttr = cst2.getValue().dyn_cast<mlir::FloatAttr>();
438+
EXPECT_TRUE(floatAttr && floatAttr.getValueAsDouble() == 0.);
439+
440+
mlir::Type boolTy = mlir::IntegerType::get(builder.getContext(), 1);
441+
mlir::Value flaseBool = fir::factory::createZeroValue(builder, loc, boolTy);
442+
EXPECT_TRUE(flaseBool.getType() == boolTy);
443+
auto cst3 = mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(
444+
flaseBool.getDefiningOp());
445+
EXPECT_TRUE(cst3);
446+
auto intAttr2 = cst.getValue().dyn_cast<mlir::IntegerAttr>();
447+
EXPECT_TRUE(intAttr2 && intAttr2.getInt() == 0);
448+
}
449+
450+
TEST_F(FIRBuilderTest, getBaseTypeOf) {
451+
auto builder = getBuilder();
452+
auto loc = builder.getUnknownLoc();
453+
454+
auto makeExv = [&](mlir::Type elementType, mlir::Type arrayType)
455+
-> std::tuple<llvm::SmallVector<fir::ExtendedValue, 4>,
456+
llvm::SmallVector<fir::ExtendedValue, 4>> {
457+
auto ptrTyArray = fir::PointerType::get(arrayType);
458+
auto ptrTyScalar = fir::PointerType::get(elementType);
459+
auto ptrBoxTyArray = fir::BoxType::get(ptrTyArray);
460+
auto ptrBoxTyScalar = fir::BoxType::get(ptrTyScalar);
461+
auto boxRefTyArray = fir::ReferenceType::get(ptrBoxTyArray);
462+
auto boxRefTyScalar = fir::ReferenceType::get(ptrBoxTyScalar);
463+
auto boxTyArray = fir::BoxType::get(arrayType);
464+
auto boxTyScalar = fir::BoxType::get(elementType);
465+
466+
auto ptrValArray = builder.create<fir::UndefOp>(loc, ptrTyArray);
467+
auto ptrValScalar = builder.create<fir::UndefOp>(loc, ptrTyScalar);
468+
auto boxRefValArray = builder.create<fir::UndefOp>(loc, boxRefTyArray);
469+
auto boxRefValScalar = builder.create<fir::UndefOp>(loc, boxRefTyScalar);
470+
auto boxValArray = builder.create<fir::UndefOp>(loc, boxTyArray);
471+
auto boxValScalar = builder.create<fir::UndefOp>(loc, boxTyScalar);
472+
473+
llvm::SmallVector<fir::ExtendedValue, 4> scalars;
474+
scalars.emplace_back(fir::UnboxedValue(ptrValScalar));
475+
scalars.emplace_back(fir::BoxValue(boxValScalar));
476+
scalars.emplace_back(
477+
fir::MutableBoxValue(boxRefValScalar, mlir::ValueRange(), {}));
478+
479+
llvm::SmallVector<fir::ExtendedValue, 4> arrays;
480+
auto extent = builder.create<fir::UndefOp>(loc, builder.getIndexType());
481+
llvm::SmallVector<mlir::Value> extents(
482+
arrayType.dyn_cast<fir::SequenceType>().getDimension(),
483+
extent.getResult());
484+
arrays.emplace_back(fir::ArrayBoxValue(ptrValArray, extents));
485+
arrays.emplace_back(fir::BoxValue(boxValArray));
486+
arrays.emplace_back(
487+
fir::MutableBoxValue(boxRefValArray, mlir::ValueRange(), {}));
488+
return {scalars, arrays};
489+
};
490+
491+
auto f32Ty = mlir::FloatType::getF32(builder.getContext());
492+
mlir::Type f32SeqTy = builder.getVarLenSeqTy(f32Ty);
493+
auto [f32Scalars, f32Arrays] = makeExv(f32Ty, f32SeqTy);
494+
for (const auto &scalar : f32Scalars) {
495+
EXPECT_EQ(fir::getBaseTypeOf(scalar), f32Ty);
496+
EXPECT_EQ(fir::getElementTypeOf(scalar), f32Ty);
497+
EXPECT_FALSE(fir::isDerivedWithLengthParameters(scalar));
498+
}
499+
for (const auto &array : f32Arrays) {
500+
EXPECT_EQ(fir::getBaseTypeOf(array), f32SeqTy);
501+
EXPECT_EQ(fir::getElementTypeOf(array), f32Ty);
502+
EXPECT_FALSE(fir::isDerivedWithLengthParameters(array));
503+
}
504+
505+
auto derivedWithLengthTy =
506+
fir::RecordType::get(builder.getContext(), "derived_test");
507+
508+
llvm::SmallVector<std::pair<std::string, mlir::Type>> parameters;
509+
llvm::SmallVector<std::pair<std::string, mlir::Type>> components;
510+
parameters.emplace_back("p1", builder.getI64Type());
511+
components.emplace_back("c1", f32Ty);
512+
derivedWithLengthTy.finalize(parameters, components);
513+
mlir::Type derivedWithLengthSeqTy =
514+
builder.getVarLenSeqTy(derivedWithLengthTy);
515+
auto [derivedWithLengthScalars, derivedWithLengthArrays] =
516+
makeExv(derivedWithLengthTy, derivedWithLengthSeqTy);
517+
for (const auto &scalar : derivedWithLengthScalars) {
518+
EXPECT_EQ(fir::getBaseTypeOf(scalar), derivedWithLengthTy);
519+
EXPECT_EQ(fir::getElementTypeOf(scalar), derivedWithLengthTy);
520+
EXPECT_TRUE(fir::isDerivedWithLengthParameters(scalar));
521+
}
522+
for (const auto &array : derivedWithLengthArrays) {
523+
EXPECT_EQ(fir::getBaseTypeOf(array), derivedWithLengthSeqTy);
524+
EXPECT_EQ(fir::getElementTypeOf(array), derivedWithLengthTy);
525+
EXPECT_TRUE(fir::isDerivedWithLengthParameters(array));
526+
}
527+
}

0 commit comments

Comments
 (0)