Commit ed3f017
Fix
The `dtype` property would return to true dtype of the variable, instead of the dtype of the value that you get explicitly via `.value()` or implicitly by doing any operation.
This would cause seemingly correct things like this to fail with a dtype mismatch:
```
y = variable * tf.cast(x, variable.dtype)
```
Forcing users to write workarounds like:
```
v = variable.value()
y = variable * tf.cast(x, v.dtype)
```
Additionally, `assign`, `assign_add`, `assign_sub` expected the value to be of the true dtype, not the cast dtype.
This would cause seemingly correct things like this to fail with a dtype mismatch:
```
variable.assign(variable * factor)
```
(This is a common use case for non-trainable variables.)
Forcing users to write workarounds like:
```
variable.assign(tf.cast(variable * factor, variable.dtype))
```
This changes fixes these issues to make autocasting fully transparent:
- `dtype` returns the cast dtype if applicable
- `assign*` accept the cast dtype for the value if applicable
Note that this is consistent with how autocasting works in Keras 3.
PiperOrigin-RevId: 650386711dtype and assign* in AutocastVariable.1 parent 909a2a4 commit ed3f017
File tree
1 file changed
+8
-2
lines changed- tensorflow_model_optimization/python/core/sparsity/keras
1 file changed
+8
-2
lines changedLines changed: 8 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
233 | 233 | | |
234 | 234 | | |
235 | 235 | | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
236 | 242 | | |
237 | 243 | | |
238 | 244 | | |
239 | 245 | | |
240 | | - | |
| 246 | + | |
241 | 247 | | |
242 | 248 | | |
243 | 249 | | |
244 | 250 | | |
245 | 251 | | |
246 | 252 | | |
247 | 253 | | |
248 | | - | |
| 254 | + | |
249 | 255 | | |
250 | 256 | | |
251 | 257 | | |
| |||
0 commit comments