@@ -131,6 +131,24 @@ SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
131131 return {&getBodyGraph ()};
132132}
133133
134+ // ===----------------------------------------------------------------------===//
135+ // TOSA variable operator support.
136+ // ===----------------------------------------------------------------------===//
137+
138+ static SmallVector<int64_t > convertToMlirShape (ArrayRef<int64_t > shape) {
139+ return to_vector (llvm::map_range (shape, [](int64_t dim) {
140+ return dim == -1 ? ShapedType::kDynamic : dim;
141+ }));
142+ }
143+
144+ // returns type of variable op
145+ RankedTensorType mlir::tosa::getVariableType (tosa::VariableOp variableOp) {
146+ Type elementType = variableOp.getType ();
147+ DenseIntElementsAttr varShapeAttr = variableOp.getVarShape ();
148+ auto shape = convertToMlirShape (to_vector (varShapeAttr.getValues <int64_t >()));
149+ return RankedTensorType::get (shape, elementType);
150+ }
151+
134152// ===----------------------------------------------------------------------===//
135153// Tosa dialect initialization.
136154// ===----------------------------------------------------------------------===//
@@ -177,42 +195,80 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
177195// Parsers and printers
178196// ===----------------------------------------------------------------------===//
179197
180- ParseResult mlir::tosa::parseTypeOrAttr (OpAsmParser &parser, TypeAttr &typeAttr,
181- Attribute &attr) {
198+ namespace {
199+
200+ ParseResult getShapeAndElementType (OpAsmParser &parser, Type parsedType,
201+ DenseElementsAttr &varShapeAttr,
202+ TypeAttr &typeAttr) {
203+ if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
204+ if (!shapedType.hasRank ())
205+ return parser.emitError (parser.getCurrentLocation ())
206+ << " expected ranked type" ;
207+
208+ auto elementType = shapedType.getElementType ();
209+ typeAttr = TypeAttr::get (elementType);
210+ ArrayRef<int64_t > shape = shapedType.getShape ();
211+ Builder builder (parser.getContext ());
212+ varShapeAttr = builder.getIndexTensorAttr (convertFromMlirShape (shape));
213+ return success ();
214+ }
215+ return parser.emitError (parser.getCurrentLocation ())
216+ << " expected shaped type" ;
217+ }
218+
219+ } // namespace
220+
221+ // parses the optional initial value or type for a tosa variable
222+ // with initial value:
223+ // tosa.variable @name = dense<0.0> : tensor<1x8xf32>
224+ //
225+ // without initial value:
226+ // tosa.variable @name : tensor<1x8xf32>
227+ ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue (
228+ OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
229+ Attribute &initialValueAttr) {
182230 if (succeeded (parser.parseOptionalEqual ())) {
183- if (failed (parser.parseAttribute (attr ))) {
231+ if (failed (parser.parseAttribute (initialValueAttr ))) {
184232 return parser.emitError (parser.getCurrentLocation ())
185233 << " expected attribute" ;
186234 }
187- if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
188- typeAttr = TypeAttr::get (typedAttr.getType ());
235+ if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
236+ return getShapeAndElementType (parser, typedAttr.getType (), varShapeAttr,
237+ typeAttr);
189238 }
190- return success ();
239+ return parser.emitError (parser.getCurrentLocation ())
240+ << " expected Typed attr" ;
191241 }
192242
193- Type type;
194- if (failed (parser.parseColonType (type))) {
195- return parser.emitError (parser.getCurrentLocation ()) << " expected type" ;
243+ initialValueAttr = nullptr ;
244+ Type parsedType;
245+ if (failed (parser.parseColonType (parsedType))) {
246+ return parser.emitError (parser.getCurrentLocation ())
247+ << " expected type after colon" ;
196248 }
197- typeAttr = TypeAttr::get (type);
198-
199- return success ();
249+ return getShapeAndElementType (parser, parsedType, varShapeAttr, typeAttr);
200250}
201251
202- void mlir::tosa::printTypeOrAttr (OpAsmPrinter &p, Operation *op, TypeAttr type,
203- Attribute attr) {
252+ void mlir::tosa::printVariableOpTypeOrInitialValue (
253+ OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
254+ TypeAttr typeAttr, Attribute initialValueAttr) {
204255 bool needsSpace = false ;
205- auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
206- if (!typedAttr || typedAttr.getType () != type.getValue ()) {
256+ if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
257+ auto shape =
258+ convertToMlirShape (to_vector (varShapeAttr.getValues <int64_t >()));
259+ Type elementType = typeAttr.getValue ();
260+ RankedTensorType tensorType =
261+ RankedTensorType::get (ArrayRef<int64_t >(shape), elementType);
262+ auto tensorTypeAttr = TypeAttr::get (tensorType);
207263 p << " : " ;
208- p.printAttribute (type );
264+ p.printAttribute (tensorTypeAttr );
209265 needsSpace = true ; // subsequent attr value needs a space separator
210266 }
211- if (attr ) {
267+ if (initialValueAttr ) {
212268 if (needsSpace)
213269 p << ' ' ;
214270 p << " = " ;
215- p.printAttribute (attr );
271+ p.printAttribute (initialValueAttr );
216272 }
217273}
218274
@@ -657,8 +713,9 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
657713 << symName << " ' has not been declared by 'tosa.variable'" ;
658714
659715 // Verify type and shape
660- Type varType = cast<tosa::VariableOp>(varOp.value ()).getType ();
661- if (errorIfTypeOrShapeMismatch (op, type, name, varType, " the input tensor" )
716+ auto variableType = getVariableType (varOp.value ());
717+ if (errorIfTypeOrShapeMismatch (op, type, name, variableType,
718+ " the input tensor" )
662719 .failed ())
663720 return failure ();
664721
@@ -1103,6 +1160,33 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
11031160 result.types .push_back (outputType);
11041161}
11051162
1163+ static void buildVariableOp (OpBuilder &builder, OperationState &result,
1164+ StringRef name, Type variableType,
1165+ Attribute initialValue) {
1166+ const Location loc{result.location };
1167+ auto nameAttr = builder.getStringAttr (name);
1168+
1169+ auto shapedType = dyn_cast<ShapedType>(variableType);
1170+ if (!shapedType) {
1171+ (void )emitError (loc, " variable type must be a shaped type" );
1172+ return ;
1173+ }
1174+ if (!shapedType.hasRank ()) {
1175+ (void )emitError (loc, " variable type must be a ranked type" );
1176+ return ;
1177+ }
1178+
1179+ auto elementType = shapedType.getElementType ();
1180+ auto elementTypeAttr = TypeAttr::get (elementType);
1181+ ArrayRef<int64_t > shape = shapedType.getShape ();
1182+ auto varShapeAttr = builder.getIndexTensorAttr (convertFromMlirShape (shape));
1183+
1184+ result.addAttribute (" name" , nameAttr);
1185+ result.addAttribute (" var_shape" , varShapeAttr);
1186+ result.addAttribute (" type" , elementTypeAttr);
1187+ result.addAttribute (" initial_value" , initialValue);
1188+ }
1189+
11061190// ===----------------------------------------------------------------------===//
11071191// TOSA Operator Return Type Inference.
11081192// ===----------------------------------------------------------------------===//
@@ -1676,12 +1760,6 @@ LogicalResult tosa::PadOp::verify() {
16761760 return success ();
16771761}
16781762
1679- static SmallVector<int64_t > convertToMlirShape (ArrayRef<int64_t > shape) {
1680- return to_vector (llvm::map_range (shape, [](int64_t dim) {
1681- return dim == -1 ? ShapedType::kDynamic : dim;
1682- }));
1683- }
1684-
16851763LogicalResult tosa::SliceOp::inferReturnTypeComponents (
16861764 MLIRContext *context, ::std::optional<Location> location,
16871765 SliceOp::Adaptor adaptor,
0 commit comments