Skip to content
This repository was archived by the owner on Jul 10, 2025. It is now read-only.

Commit b922289

Browse files
committed
addressing initial comments
1 parent f1f032a commit b922289

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

rfcs/20210731-tfjs-named-tensors.md

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,10 @@ const keyM: = ...
7373
const valueM: = ...
7474
const xs: = ...
7575

76-
const keys = tf.matMul(xs, keyM.read())
77-
const queries = tf.matMul(xs, queryM.read())
78-
const attention = tf.softmax(
79-
tf.div(tf.matMul(queries, keys, true, false), inputRepSizeSqrt)) as tf.Tensor2D;
80-
const values = tf.matMul(xs, valueM.read(), true, false)
76+
const inputKeys = tf.matMul(xs, keyM.read());
77+
const inputQueries = tf.matMul(xs, queryM.read());
78+
const attention = tf.matMul(inputQueries, inputKeys, true, false);
79+
const values = tf.matMul(xs, valueM.read(), true, false);
8180
const attendedValues = tf.matMul(attention, values, false, true);
8281
```
8382

@@ -147,26 +146,31 @@ dimension.
147146

148147
```ts
149148
const g1: GTensor<'inputTokens'|'tokenRep'> = ...;
150-
const g2: GTensor<'tokenRep'|'queryRep'> = ...;
149+
const g2: GTensor<'tokenRep'|'queryTokens'> = ...;
151150
const g3 = g1.dim.tokenRep.dot(g2.dim.tokenRep);
152151
// g3: GTensor<'inputTokens'|'queryRep'>
153152
```
154153

155154
Type-checking ensures that dimension names match. i.e.
156155

157156
```ts
158-
g1.dim.tokenRep.dot(g2.dim.foo)
159-
// Type Error: "tokenRep" is different from "foo".
157+
g1.dim.inputTokens.dot(g2.dim.queryTokens)
158+
// Type Error: "inputTokens" is different from "queryTokens".
160159
```
161160

162-
Dimensions can be renamed also to provide a new dimension object with the
163-
correct name, e.g. if `g2` didn't have an `tokenRep` dimension, but we wanted to
164-
dot product with the `foo` dimension, we could do:
161+
Sometimes one wants to multiply dimensions that have different names. In such
162+
cases, this is done by an explicit renaming e.g. for the above example, `g2`'s
163+
dimension `inputRep` can be renamed to `tokenRep` to allow the above
164+
multiplication:
165165

166166
```ts
167-
g1.dim.tokenRep.dot(g2.dim.foo.rename('tokenRep'));
167+
g1.dim.inputTokens.dot(g2.dim.queryTokens.rename('inputTokens'));
168168
```
169169

170+
This can be seen as an analag to explicit type-casting, and ensures that
171+
dimensions with different names are being multiplied together intentionally by
172+
the user.
173+
170174
By working at this more abstract level, you never need to worry about the axis position, you just reference it by name. Underneath this abstraction, we can now optimise the "layout" of the tensors (the order of the axis) and the various permutation operations.
171175

172176
The vision is that this also provides a higher level abstraction that can be used to efficiently compile to XLA, and thus provide a better high level abstraction for ML programming in TypeScript, with better tool support, making it easier for more people be able to explore and write ML algorithms, and remove a large part of the boring and frustrating challenges of making sure indexes align correctly. A side effect is that this also makes code much more readable (see the Attention Head implementation below).

0 commit comments

Comments
 (0)