@@ -244,25 +244,41 @@ template <class AttrElementT,
244244Attribute constFoldBinaryOp (ArrayRef<Attribute> operands,
245245 const CalculationT &calculate) {
246246 assert (operands.size () == 2 && " binary op takes two operands" );
247+ if (!operands[0 ] || !operands[1 ])
248+ return {};
249+ if (operands[0 ].getType () != operands[1 ].getType ())
250+ return {};
247251
248- if (auto lhs = operands[0 ].dyn_cast_or_null <AttrElementT>()) {
249- auto rhs = operands[1 ].dyn_cast_or_null <AttrElementT>();
250- if (!rhs || lhs.getType () != rhs.getType ())
251- return {};
252+ if (operands[0 ].isa <AttrElementT>() && operands[1 ].isa <AttrElementT>()) {
253+ auto lhs = operands[0 ].cast <AttrElementT>();
254+ auto rhs = operands[1 ].cast <AttrElementT>();
252255
253256 return AttrElementT::get (lhs.getType (),
254257 calculate (lhs.getValue (), rhs.getValue ()));
255- } else if (auto lhs = operands[0 ].dyn_cast_or_null <SplatElementsAttr>()) {
256- auto rhs = operands[1 ].dyn_cast_or_null <SplatElementsAttr>();
257- if (!rhs || lhs.getType () != rhs.getType ())
258- return {};
259-
260- auto elementResult = constFoldBinaryOp<AttrElementT>(
261- {lhs.getSplatValue (), rhs.getSplatValue ()}, calculate);
262- if (!elementResult)
263- return {};
264-
258+ } else if (operands[0 ].isa <SplatElementsAttr>() &&
259+ operands[1 ].isa <SplatElementsAttr>()) {
260+ // Both operands are splats so we can avoid expanding the values out and
261+ // just fold based on the splat value.
262+ auto lhs = operands[0 ].cast <SplatElementsAttr>();
263+ auto rhs = operands[1 ].cast <SplatElementsAttr>();
264+
265+ auto elementResult = calculate (lhs.getSplatValue <ElementValueT>(),
266+ rhs.getSplatValue <ElementValueT>());
265267 return DenseElementsAttr::get (lhs.getType (), elementResult);
268+ } else if (operands[0 ].isa <ElementsAttr>() &&
269+ operands[1 ].isa <ElementsAttr>()) {
270+ // Operands are ElementsAttr-derived; perform an element-wise fold by
271+ // expanding the values.
272+ auto lhs = operands[0 ].cast <ElementsAttr>();
273+ auto rhs = operands[1 ].cast <ElementsAttr>();
274+
275+ auto lhsIt = lhs.getValues <ElementValueT>().begin ();
276+ auto rhsIt = rhs.getValues <ElementValueT>().begin ();
277+ SmallVector<ElementValueT, 4 > elementResults;
278+ elementResults.reserve (lhs.getNumElements ());
279+ for (size_t i = 0 , e = lhs.getNumElements (); i < e; ++i, ++lhsIt, ++rhsIt)
280+ elementResults.push_back (calculate (*lhsIt, *rhsIt));
281+ return DenseElementsAttr::get (lhs.getType (), elementResults);
266282 }
267283 return {};
268284}
0 commit comments