Skip to content

Commit 39cb557

Browse files
GleasonKTensorFlow MLIR Team
authored andcommitted
[StableHLO Builder] Add API to set frontend attributes
PiperOrigin-RevId: 820455957
1 parent 4d03be9 commit 39cb557

File tree

3 files changed

+91
-0
lines changed

3 files changed

+91
-0
lines changed

stablehlo/stablehlo/integrations/cpp/builder/StablehloBuilder.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,36 @@ limitations under the License.
2929
#include "stablehlo/dialect/TypeInference.h"
3030
#include "stablehlo/integrations/cpp/builder/AttrTypeBuilderUtil.h"
3131
#include "stablehlo/integrations/cpp/builder/MlirBuilder.h"
32+
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
3233

3334
namespace mlir {
3435
namespace stablehlo {
3536

37+
///////////////
38+
// Dialect Helpers
39+
///////////////
40+
41+
MlirOp AttachFrontendAttribute(MlirBuilder& builder, MlirOp op, StringRef name,
42+
Attribute value) {
43+
constexpr char kFrontendAttrName[] = "mhlo.frontend_attributes";
44+
Operation* mlirOp = unwrap(op).getDefiningOp();
45+
SmallVector<NamedAttribute> attrs;
46+
DictionaryAttr frontendAttr =
47+
mlirOp->getAttrOfType<DictionaryAttr>(kFrontendAttrName);
48+
if (frontendAttr) {
49+
for (NamedAttribute attr : frontendAttr.getValue()) {
50+
// Populate all non-conflicting names.
51+
if (attr.getName() != name) {
52+
attrs.push_back(attr);
53+
}
54+
}
55+
}
56+
attrs.emplace_back(name, value);
57+
mlirOp->setAttr(kFrontendAttrName,
58+
DictionaryAttr::get(&builder.getContext(), attrs));
59+
return op;
60+
}
61+
3662
/////////////////
3763
// MANUAL APIs
3864
/////////////////

stablehlo/stablehlo/integrations/cpp/builder/StablehloBuilder.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,23 @@ limitations under the License.
2727
#include "stablehlo/dialect/StablehloOps.h"
2828
#include "stablehlo/integrations/cpp/builder/AttrTypeBuilderUtil.h"
2929
#include "stablehlo/integrations/cpp/builder/MlirBuilder.h"
30+
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
3031

3132
namespace mlir {
3233
namespace stablehlo {
3334

35+
///////////////
36+
// Dialect Helpers
37+
///////////////
38+
39+
// Appends or overwrites an entry in the `mhlo.frontend_attributes` attribute
40+
//
41+
// of the given op.
42+
// Ex:
43+
// stablehlo.abs %0 { mhlo.frontend_attributes = { "foo" = 123 } }
44+
MlirOp AttachFrontendAttribute(MlirBuilder& builder, MlirOp op, StringRef name,
45+
Attribute value);
46+
3447
/////////////////
3548
// MANUAL APIs
3649
/////////////////

stablehlo/stablehlo/integrations/cpp/builder/StablehloBuilderTest.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,5 +1592,57 @@ TEST(MlirBuilderTest, ResultAccuracyAttrTolerance) {
15921592
EXPECT_EQ(expected, debugString(*module));
15931593
}
15941594

1595+
TEST(MlirBuilderTest, FrontendAttributesAppend) {
1596+
std::string expected = R"mlir(module {
1597+
func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> {
1598+
%0 = stablehlo.exponential %arg0 {mhlo.frontend_attributes = {bar = "hello", foo = 123 : i32}} : tensor<2xf32>
1599+
return %0 : tensor<2xf32>
1600+
}
1601+
})mlir";
1602+
1603+
StablehloModuleBuilder mb;
1604+
{
1605+
Location funcLoc = fileLineColLoc(mb->getContext(), "main.mlir", 1, 1);
1606+
func::FunctionBuilder fb(mb.get(), "main", funcLoc);
1607+
auto type = makeTensorType(fb.getContext(), {2}, ElementType::F32);
1608+
auto arg0 = func::Argument(fb, type);
1609+
auto exp = Exp(arg0);
1610+
stablehlo::AttachFrontendAttribute(
1611+
fb, exp, "foo", fb.getOpBuilder().getI32IntegerAttr(123));
1612+
stablehlo::AttachFrontendAttribute(
1613+
fb, exp, "bar", fb.getOpBuilder().getStringAttr("hello"));
1614+
func::Return(fb, {exp});
1615+
}
1616+
1617+
OwningOpRef<ModuleOp> module = mb->build();
1618+
EXPECT_EQ(expected, debugString(*module));
1619+
}
1620+
1621+
TEST(MlirBuilderTest, FrontendAttributesOverwrite) {
1622+
std::string expected = R"mlir(module {
1623+
func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> {
1624+
%0 = stablehlo.exponential %arg0 {mhlo.frontend_attributes = {foo = 456 : i32}} : tensor<2xf32>
1625+
return %0 : tensor<2xf32>
1626+
}
1627+
})mlir";
1628+
1629+
StablehloModuleBuilder mb;
1630+
{
1631+
Location funcLoc = fileLineColLoc(mb->getContext(), "main.mlir", 1, 1);
1632+
func::FunctionBuilder fb(mb.get(), "main", funcLoc);
1633+
auto type = makeTensorType(fb.getContext(), {2}, ElementType::F32);
1634+
auto arg0 = func::Argument(fb, type);
1635+
auto exp = Exp(arg0);
1636+
stablehlo::AttachFrontendAttribute(
1637+
fb, exp, "foo", fb.getOpBuilder().getI32IntegerAttr(123));
1638+
stablehlo::AttachFrontendAttribute(
1639+
fb, exp, "foo", fb.getOpBuilder().getI32IntegerAttr(456));
1640+
func::Return(fb, {exp});
1641+
}
1642+
1643+
OwningOpRef<ModuleOp> module = mb->build();
1644+
EXPECT_EQ(expected, debugString(*module));
1645+
}
1646+
15951647
} // namespace stablehlo
15961648
} // namespace mlir

0 commit comments

Comments
 (0)