@@ -3142,6 +3142,74 @@ static bool hasScalableVectorType(Type t) {
31423142 return false ;
31433143}
31443144
3145+ // / Verifies the constant array represented by `arrayAttr` matches the provided
3146+ // / `arrayType`.
3147+ static LogicalResult verifyStructArrayConstant (LLVM::ConstantOp op,
3148+ LLVM::LLVMArrayType arrayType,
3149+ ArrayAttr arrayAttr, int dim) {
3150+ if (arrayType.getNumElements () != arrayAttr.size ())
3151+ return op.emitOpError ()
3152+ << " array attribute size does not match array type size in "
3153+ " dimension "
3154+ << dim << " : " << arrayAttr.size () << " vs. "
3155+ << arrayType.getNumElements ();
3156+
3157+ llvm::DenseSet<Attribute> elementsVerified;
3158+
3159+ // Recursively verify sub-dimensions for multidimensional arrays.
3160+ if (auto subArrayType =
3161+ dyn_cast<LLVM::LLVMArrayType>(arrayType.getElementType ())) {
3162+ for (auto [idx, elementAttr] : llvm::enumerate (arrayAttr))
3163+ if (elementsVerified.insert (elementAttr).second ) {
3164+ if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(elementAttr))
3165+ continue ;
3166+ auto subArrayAttr = dyn_cast<ArrayAttr>(elementAttr);
3167+ if (!subArrayAttr)
3168+ return op.emitOpError ()
3169+ << " nested attribute for sub-array in dimension " << dim
3170+ << " at index " << idx
3171+ << " must be a zero, or undef, or array attribute" ;
3172+ if (failed (verifyStructArrayConstant (op, subArrayType, subArrayAttr,
3173+ dim + 1 )))
3174+ return failure ();
3175+ }
3176+ return success ();
3177+ }
3178+
3179+ // Forbid usages of ArrayAttr for simple array types that should use
3180+ // DenseElementsAttr instead. Note that there would be a use case for such
3181+ // array types when one element value is obtained via a ptr-to-int conversion
3182+ // from a symbol and cannot be represented in a DenseElementsAttr, but no MLIR
3183+ // user needs this so far, and it seems better to avoid people misusing the
3184+ // ArrayAttr for simple types.
3185+ auto structType = dyn_cast<LLVM::LLVMStructType>(arrayType.getElementType ());
3186+ if (!structType)
3187+ return op.emitOpError () << " for array with an array attribute must have a "
3188+ " struct element type" ;
3189+
3190+ // Shallow verification that leaf attributes are appropriate as struct initial
3191+ // value.
3192+ size_t numStructElements = structType.getBody ().size ();
3193+ for (auto [idx, elementAttr] : llvm::enumerate (arrayAttr)) {
3194+ if (elementsVerified.insert (elementAttr).second ) {
3195+ if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(elementAttr))
3196+ continue ;
3197+ auto subArrayAttr = dyn_cast<ArrayAttr>(elementAttr);
3198+ if (!subArrayAttr)
3199+ return op.emitOpError ()
3200+ << " nested attribute for struct element at index " << idx
3201+ << " must be a zero, or undef, or array attribute" ;
3202+ if (subArrayAttr.size () != numStructElements)
3203+ return op.emitOpError ()
3204+ << " nested array attribute size for struct element at index "
3205+ << idx << " must match struct size: " << subArrayAttr.size ()
3206+ << " vs. " << numStructElements;
3207+ }
3208+ }
3209+
3210+ return success ();
3211+ }
3212+
31453213LogicalResult LLVM::ConstantOp::verify () {
31463214 if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue ())) {
31473215 auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType ());
@@ -3208,7 +3276,7 @@ LogicalResult LLVM::ConstantOp::verify() {
32083276 if (isa<IntegerType>(getType ()) && !getType ().isInteger (floatWidth)) {
32093277 return emitOpError () << " expected integer type of width " << floatWidth;
32103278 }
3211- } else if (isa<ElementsAttr, ArrayAttr >(getValue ())) {
3279+ } else if (isa<ElementsAttr>(getValue ())) {
32123280 if (hasScalableVectorType (getType ())) {
32133281 // The exact number of elements of a scalable vector is unknown, so we
32143282 // allow only splat attributes.
@@ -3221,15 +3289,20 @@ LogicalResult LLVM::ConstantOp::verify() {
32213289 if (!isa<VectorType, LLVM::LLVMArrayType>(getType ()))
32223290 return emitOpError () << " expected vector or array type" ;
32233291 // The number of elements of the attribute and the type must match.
3224- int64_t attrNumElements;
3225- if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue ()))
3226- attrNumElements = elementsAttr.getNumElements ();
3227- else
3228- attrNumElements = cast<ArrayAttr>(getValue ()).size ();
3229- if (getNumElements (getType ()) != attrNumElements)
3230- return emitOpError ()
3231- << " type and attribute have a different number of elements: "
3232- << getNumElements (getType ()) << " vs. " << attrNumElements;
3292+ if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue ())) {
3293+ int64_t attrNumElements = elementsAttr.getNumElements ();
3294+ if (getNumElements (getType ()) != attrNumElements)
3295+ return emitOpError ()
3296+ << " type and attribute have a different number of elements: "
3297+ << getNumElements (getType ()) << " vs. " << attrNumElements;
3298+ }
3299+ } else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue ())) {
3300+ auto arrayType = dyn_cast<LLVM::LLVMArrayType>(getType ());
3301+ if (!arrayType)
3302+ return emitOpError () << " expected array type" ;
3303+ // When the attribute is an ArrayAttr, check that its nesting matches the
3304+ // corresponding ArrayType or VectorType nesting.
3305+ return verifyStructArrayConstant (*this , arrayType, arrayAttr, /* dim=*/ 0 );
32333306 } else {
32343307 return emitOpError ()
32353308 << " only supports integer, float, string or elements attributes" ;
0 commit comments