This repository was archived by the owner on Jul 10, 2025. It is now read-only.

Commit c3aa72e — RFC: MLIR Dialects for TensorFlow (#115)
joker-eph authored and Edd Wilder-James committed

1 file changed: rfcs/20190612-mlir-dialect.md (+335 lines)
# TensorFlow MLIR Dialects

| Status        | Accepted                                    |
| :------------ | :------------------------------------------ |
| **Author(s)** | Mehdi Amini ([email protected])             |
|               | Tatiana Schpeisman ([email protected])      |
|               | Chris Lattner ([email protected])           |
| **Sponsor**   | Alexandre Passos ([email protected])        |
|               | Jacques Pienaar ([email protected])         |
| **Updated**   | 2019-06-10                                  |

## Objective

[MLIR](https://medium.com/tensorflow/mlir-a-new-intermediate-representation-and-compiler-framework-beba999ed18d)
is the intermediate representation and compiler framework we are investing in to
build the compiler infrastructure for TensorFlow. The representation for
TensorFlow exposed in this document is what future high-level transformations
will operate on.

We make use of two different dialects to model TensorFlow graphs in MLIR: first,
the `tf_executor` dialect, which represents the execution model of the
TensorFlow executor (e.g. control dependencies, deadness propagation), and
second, the `tf` dialect, which represents the regular operations in a
TensorFlow graph (the ones that don't have a special contract with the
executor).

One intent of this design is that TensorFlow 2.x features can choose to target
just the `tf` dialect, allowing us to phase out the `tf_executor` dialect in
subsequent TensorFlow releases. The combination of the two dialects makes it
possible to represent arbitrary existing TensorFlow graphs.

The representation in this document does not address the specific needs of
accelerators or "custom backends" for TensorFlow. We plan to provide a generic
infrastructure for replacing the TF/XLA bridge with a more flexible and reusable
system across targets; a later design proposal will address these aspects. This
representation also does not address shape inference: an independent design
exploration is being conducted separately.
## TensorFlow Dialect

The TensorFlow dialect in MLIR is an open dialect (it allows operations that
MLIR doesn't know about) that can contain any TensorFlow operation that is not
handled specially by the executor. These operations don't operate on dead
values, don't have control dependencies, and execute conceptually in program
order. The form used in this dialect aligns with the direction taken by
TensorFlow 2.0 with `tf.function` and autograph, as well as with the needs of
other frontends. This should ease the development of analyses and
transformations: optimizations operate on simpler semantics, and local graph
transformations can be validated in a local scope. Simple patterns like folding
`x - x` into a constant 0 do not need to update any control dependencies. The
dialect should also be easily lowerable towards multiple accelerators and
heterogeneous systems in general.

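For instance, a hypothetical sketch of such a rewrite (the op spelling and
constant syntax here follow the examples in this document and are illustrative,
not a fixed specification):

```mlir {.mlir}
// Before: %sub recomputes x - x; no control dependencies constrain it.
%sub = tf.Sub %x, %x : tensor<*xf32>

// After: the pattern folds to a splat constant of zeros, with no other
// operation needing an update.
%sub = constant splat<tensor<f32>, 0.0>
```
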
Operations in this dialect usually operate on the tensor and scalar types
defined in the standard dialect. The extra types defined here are specific to
TensorFlow: `QINT` types like `!tf.qint8` (etc.), `QUINT` types like
`!tf.quint8`, all of the `REF` types like `!tf.uint8ref`, as well as
`!tf.string`, `!tf.resource`, and `!tf.variant`, which correspond to the
TensorFlow types of the same name.

### Example:

Below is an example of a function operating on the TensorFlow dialect:

```mlir {.mlir}
/// This is a regular function, taking inputs by value and returning a new value.
/// The body is a regular CFG.
func @some_function(%input : tensor<*xf32>) -> tensor<*xf32> {
  // TensorFlow operations are not variadic: this `tf.Add` operation always
  // takes two inputs and returns a single output. This simplifies
  // pattern-matching, verification and rewriting.
  %added = tf.Add %input, %input : tensor<*xf32>

  // Operations have sequential execution semantics in a basic block; there are
  // no control dependencies. The compiler can reorder operations according to
  // the as-if rule (https://en.wikipedia.org/wiki/As-if_rule).
  %three = constant splat<tensor<f32>, 3.0>
  %mul = tf.Mul %input, %three : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>

  // Only control-flow v2 is supported in the TF dialect.
  // The tf.If operation takes three functions that accept the same
  // arguments: the condition returns a bool, and the two branches must return
  // the same type, which is also the return type of the tf.If.
  %value = "tf.If"(%added, %mul)
             {cond: @cond_func, true_branch: @func_foo, false_branch: @func_bar}
             : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>

  return %value : tensor<*xf32>
}
```

## TensorFlow Executor Dialect

The `tf_executor` dialect is intended to model the current TensorFlow executor
semantics and (when combined with the `tf` dialect) can represent arbitrary
TensorFlow 1.x and 2.x graphs. As such it follows the executor model, including
deadness propagation, concurrent semantics, and control dependencies. The
`tf_executor` dialect defines two dialect-specific types:

*   `!tf_executor.control` to represent control dependencies.
*   `!tf_executor.token` to represent the pair of operations modeling the
    NextIteration operation.

The `tf_executor` dialect is closed (its operations are all known to MLIR), as
there are only 8 TensorFlow ops with specific graph-executor behavior and 4
additional operations to represent islands of predictability.

This dialect models the TensorFlow executor semantics; as such, a large part of
the defined operations mirror the
[TensorFlow Control Flow Ops](https://www.tensorflow.org/api_docs/cc/group/control-flow-ops)
and
[implement control flow in TensorFlow](http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf).
Also, almost all of these operations accept a variadic number of control tokens
and return an extra control token as output. Except for `tf_executor.Merge` and
`tf_executor.ControlTrigger`, operations propagate deadness: if any of the
inputs (control or non-control) is dead, all of the outputs (control and
non-control) are dead as well. For `tf_executor.Merge`, the output is dead only
when either an input control token is dead or all of the regular inputs are
dead. For `tf_executor.ControlTrigger`, a live control output is always produced
even when some control inputs are dead.

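These deadness rules can be sketched as follows (hypothetical operand names and
an illustrative `tf_executor.opA`; `Switch` and `Merge` are detailed in the
sections below):

```mlir {.mlir}
// Switch forwards %data to exactly one of its outputs; the other is dead.
%t, %f, %ctl = tf_executor.Switch %pred, %data : tensor<*xf32>

// A regular op consuming a possibly-dead value propagates deadness: if %f is
// dead, then %r and %ctlr are dead as well.
%r, %ctlr = tf_executor.opA %f : tensor<*xf32>

// Merge revives the computation: its output is the non-dead input.
%m, %ctlm = tf_executor.Merge %t, %r : tensor<*xf32>
```
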
### `tf_executor.graph` Operation

The `tf_executor.graph` operation contains a region with a single block that
lists the operations in a TensorFlow graph. The operations are topologically
sorted in-order (no cycles are allowed in the SSA values). The execution model
for operations in this block follows the TensorFlow executor semantics:

1.  Operations that don't have any transitive dependencies through the SSA
    def/use chains may be executed in parallel
    (`tf_executor.NextIteration.Source` is the exception).
2.  SSA values in this block can be implicitly dead. This means that every SSA
    value defined in a `tf_executor.graph` can be considered implicitly wrapped
    in a conceptual `dead_or<T>` structure, which includes a runtime flag
    indicating whether the value is dead or present. Operations may have
    special-case handling of dead values.
3.  Operations in this dialect return a value of type `!tf_executor.control` as
    their last result (the exceptions are `tf_executor.NextIteration.Sink` and
    `tf_executor.fetch`, which don't return any value).

The `tf_executor.graph` op only allows specific `tf_executor` dialect operations
in its body: the `tf_executor.graph` verifier will reject any unknown operation.
In order to execute standard `tf` dialect operations (like `tf.Add`), they must
be wrapped in a `tf_executor.island` operation.

The `tf_executor.graph` operation does not accept any operands; inputs are
implicitly captured by the region, representing the feeds to the graph.

The region attached to `tf_executor.graph` is terminated by a
`tf_executor.fetch` operation. The non-control operands of the terminator
correspond to the result values (or fetches) of the `tf_executor.graph`
operation. The behavior is undefined if any of the operands of the
`tf_executor.fetch` is dead.

```mlir {.mlir}
%fetches = tf_executor.graph : tensor<*xf32> {
  // Operations in the current block execute when their inputs are ready,
  // possibly concurrently.
  // Only operations in the tf_executor dialect are expected here.
  // Ops can return multiple outputs and a control token for control
  // dependencies.
  // We don't mention the control token in the return type here, it is implicit.
  %0, %ctl0 = tf_executor.opA %feed#0, %feed#1 : tensor<*xf32>
  %1, %ctl1 = tf_executor.opB : tensor<*xf32>
  %2, %ctl2 = tf_executor.opC %1, %ctl0 : tensor<*xf32>
  %3, %ctl3 = tf_executor.opD %2 : tensor<*xf32>
  tf_executor.fetch %3 : tensor<*xf32>
} // end of the "tf_executor.graph" operation/region
```

### `tf_executor.island` Operation

The `tf_executor.graph` operation does not allow `tf` dialect operations to be
immediately nested underneath it. The `tf_executor.island` operation is
introduced as a wrapper for general computation (for example, all the `tf`
dialect operations): this results in a more consistent representation, which
makes analysis and transformation simpler.

The `tf_executor.island` operation has a single region with a single block
attached (only functional control flow is allowed). The block is terminated by a
`tf_executor.yield` operation. The operands of the terminator correspond to the
result values of the `tf_executor.island` operation. An extra result of type
`!tf_executor.control` is always produced by every `tf_executor.island`.

Within an island, execution semantics follow standard sequential behavior,
consistent with the direction of TensorFlow 2.0 and autograph, and desirable for
compiler analyses and transformations. Values in an island can't be dead. Other
nested `tf_executor.graph` operations can be present in the region (or in
called functions) to re-enable the TensorFlow executor behavior for a subsection
of the code. This is important for the following reasons:

*   Initially, the functional control-flow operations call functions involving
    nested graphs; if `tf_executor.graph` weren't allowed in an island, these
    operations would need to have an equivalent in the `tf_executor` dialect.
*   Nesting also makes it possible to form islands without involving
    inter-procedural analyses: any function call may involve a callee with a
    graph.

The `tf_executor.island` region allows implicit capture. If any value captured
by a `tf_executor.island` is dead, the whole region does not execute and every
produced value is marked as dead as well.

An arbitrary number of `tf_executor.control` operands are accepted by a
`tf_executor.island` operation. If any operand is dead, the region is not
executed and dead values are immediately returned for every result.

```mlir {.mlir}
// The island implicitly captures %0 and %1. It also takes a control
// dependency %ctl0 as input. It produces a tensor<*xf32> value matching the
// argument of the yield terminator, as well as an extra control token.
%2, %ctl2 = tf_executor.island (%ctl0)
    : (tensor<*xf32>, !tf_executor<"control">) -> tensor<*xf32> {
  %added = tf.Add %1, %0 : tensor<*xf32>
  %mul = tf.Mul %added, %1 : tensor<*xf32>

  // The yield terminator operands are the result values of the island.
  tf_executor.yield %mul : tensor<*xf32>
}
```

218+
The case where a single operation is wrapped inside an island can even be
219+
compressed by inferring the terminator to be the returned value of the
220+
operation. The example above if it only contained the addition with implicit
221+
capture would be displayed as:
222+
223+
```mlir {.mlir}
224+
%2, %ctl2 = tf_executor.island(%ctl0) wraps tf.Add %1, %0 : tensor<*xf32>
225+
```
226+
227+
### `tf_executor.Switch` Operation

[`tf_executor.Switch`](https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/switch):
takes two inputs, `predicate` and `data`, and returns two regular outputs,
`true_output` and `false_output`. The `data` input is copied to `true_output` if
`predicate` evaluates to true; otherwise it is copied to `false_output`. The
other output is marked as dead. If one of the inputs or a control token is dead,
then all of the outputs are marked as dead as well.

### `tf_executor.SwitchN` Operation

[`tf_executor.SwitchN`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/control_flow_ops.cc#L49-L53):
takes two inputs, `data` and `index`, and an integer attribute `num_outs`
indicating the number of outputs. The `data` input is copied to the output
indicated by the `index` input. The other outputs are marked as dead. If one of
the inputs or a control token is dead, then all of the outputs are marked as
dead as well.

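A sketch of the shape of this operation (hypothetical operand names, following
the syntax style of the other examples in this document):

```mlir {.mlir}
// %data is forwarded to the single output selected by %index; the two other
// outputs are marked as dead.
%out0, %out1, %out2, %ctl = tf_executor.SwitchN %data, %index
    {num_outs: 3} : tensor<*xf32>
```
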
### `tf_executor.Merge` Operation

[`tf_executor.Merge`](https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/merge):
takes a variadic number of inputs and returns a single output. The output is
defined as a non-dead input (selected in an undefined way if multiple inputs
are non-dead). If all inputs are dead, the output is also dead.

### NextIteration: `tf_executor.NextIteration.Source` and `tf_executor.NextIteration.Sink` Operations

The TensorFlow
[`NextIteration`](https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/next-iteration)
op is modeled using these two paired operations. Since _NextIteration_ is
intended for modeling the loop back-edges, breaking it into two different
operations makes it possible to keep a structural DAG.
`tf_executor.NextIteration.Source` does not take any operand and produces two
results: one regular value corresponding to the TensorFlow graph, and a second
value of type `!tf_executor.token`. This token is consumed by the paired
`tf_executor.NextIteration.Sink` operation alongside the value that is passed
through the back-edge. No value is returned by
`tf_executor.NextIteration.Sink`. The type of the result of the source must
match the type of the value operand of the sink.

`tf_executor.NextIteration.Source` is an exception in the executor model in the
sense that it executes after the paired `tf_executor.NextIteration.Sink`, even
though there is no data dependency between them.

### `tf_executor.LoopCond` Operation

[`tf_executor.LoopCond`](https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/loop-cond):
forwards its boolean input to its output;
[it acts as a `pivot` for marking the loop-termination condition](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/control_flow_ops.h#L115-L118).

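A minimal sketch of its use (hypothetical operand names; the loop example at
the end of this document does not use `LoopCond`):

```mlir {.mlir}
// %cond is the boolean computed by the loop-condition subgraph; LoopCond
// forwards it unchanged and marks it as the pivot of the loop.
%pivot, %ctl = tf_executor.LoopCond %cond : i1
```
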
### `tf_executor.Enter` Operation

[`tf_executor.Enter`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/control_flow_ops.h#L77-L79):
takes a single input and a `name` string attribute that identifies the
execution frame. It forwards its input to its output in the new execution
frame.

### `tf_executor.Exit` Operation

[`tf_executor.Exit`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/control_flow_ops.h#L90-L92):
forwards its single input to its output, exiting the current execution frame.

### `tf_executor.ControlTrigger` Operation

[`tf_executor.ControlTrigger`](https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/control-trigger):
similar to
[a no-op](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/control_flow_ops.h#L23-L26),
it acts as a placeholder for control dependencies. It always produces a live
control output, even when some control inputs are dead.

### `tf_executor.Send` Operation

[`tf_executor.Send`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/sendrecv_ops.h#L24):
matches TensorFlow semantics.

### `tf_executor.Recv` Operation

[`tf_executor.Recv`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/sendrecv_ops.h#L37):
matches TensorFlow semantics.

## Example

Below is an example of a loop decrementing an initial `%count.init` integer
until it reaches 0, returning the last value in the loop.

```mlir {.mlir}
// Loop `%count.init` times and return the last counter (always zero).
%fetches = tf_executor.graph {

  %loop.init, %ctl0 = tf_executor.Enter %count.init : i32

  %next_count, %tok = tf_executor.NextIteration.Source : i32

  %loop.body.init, %ctlMerge = tf_executor.Merge %loop.init, %next_count : i32

  %dec_count, %ctlAdd = tf_executor.island
      wraps tf.Add %loop.body.init, -1 : (i32, i32) -> i32

  %loop_cond, %ctlNE = tf_executor.island
      wraps tf.NotEqual %dec_count, 0 : (i32, i32) -> i1

  %true, %false, %ctlSwitch = tf_executor.Switch %loop_cond, %dec_count : i32

  tf_executor.NextIteration.Sink[%tok] %false : i32

  %exit_count, %ctlExit = tf_executor.Exit %true : i32

  tf_executor.fetch %exit_count : i32
} // end of the "tf_executor.graph" operation/region
```
