Draft: DroQ and TD3+TQC jax implementation #272
araffin wants to merge 31 commits into vwxyzjn:master
Conversation
This reverts commit d5704b3.
Results are very preliminary: Adan performs on par with or slightly better than Adam, but nothing significant yet.
FYI, https://github.com/deepmind/distrax might be a better replacement for TensorFlow Probability.
FWIW, you can also use tensorflow_probability with a JAX backend, and then you don't need TensorFlow at all (in one of their tutorials they even explicitly uninstall tf).
FYI, I converted that single file into a proof of concept of SB3 + JAX (SBX): https://github.com/araffin/sbx
Description
FYI: unpolished JAX implementations of TD3+DroQ and TD3+TQC.
Related to #262 #258
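For readers unfamiliar with TQC, here is a minimal sketch of the truncated target used by that family of methods, written in JAX. It is not the code from this PR: the function name `truncated_quantile_target`, the argument names, and the default values are hypothetical, and it assumes the target critics return an array of shape `(n_critics, batch, n_quantiles)`. Since the actor here is TD3-style (deterministic), no entropy term is added to the target.

```python
import jax.numpy as jnp


def truncated_quantile_target(next_quantiles, rewards, dones,
                              gamma=0.99, n_drop_per_critic=2):
    """Sketch of a TQC-style target (hypothetical helper, not the PR's code).

    next_quantiles: (n_critics, batch, n_quantiles) from the target critics.
    rewards, dones: (batch,) arrays.
    """
    n_critics, batch_size, n_quantiles = next_quantiles.shape
    # Pool the quantiles of all critics: (batch, n_critics * n_quantiles)
    pooled = jnp.transpose(next_quantiles, (1, 0, 2)).reshape(batch_size, -1)
    # Sort in ascending order and drop the largest quantiles
    # to reduce overestimation
    n_keep = n_critics * (n_quantiles - n_drop_per_critic)
    kept = jnp.sort(pooled, axis=-1)[:, :n_keep]
    # One TD target per kept quantile: (batch, n_keep)
    return rewards[:, None] + gamma * (1.0 - dones[:, None]) * kept


# Tiny usage example: 2 critics, batch of 1, 3 quantiles each
next_quantiles = jnp.arange(6, dtype=jnp.float32).reshape(2, 1, 3)
rewards = jnp.array([1.0])
dones = jnp.array([0.0])
target = truncated_quantile_target(
    next_quantiles, rewards, dones, gamma=0.5, n_drop_per_critic=1
)
```

In this toy example the six pooled quantiles are 0 to 5, the two largest are dropped, and the remaining four are discounted and shifted by the reward. Each critic would then be regressed onto these targets with a quantile Huber loss, which is left out of the sketch.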
My plan is to try to have SAC in JAX, but currently JAX relies on TensorFlow for probability distributions :/
So I adapted TD3 instead.
I also want to make it even faster, but that would require tweaking the way the replay buffer is used.
EDIT: apparently tfd doesn't depend on tf anymore in the latest version: https://www.tensorflow.org/probability/examples/TensorFlow_Probability_on_JAX
Reference:
EDIT: SBX = SB3 + JAX: https://github.com/araffin/sbx
Known difference from the original implementation: the Q-functions (qf) are updated at the same time as the actor instead of after each gradient step.
Types of changes
Checklist:
pre-commit run --all-files passes (required).
mkdocs serve.
If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.
--capture-video flag toggled on (required).
mkdocs serve.
(width=500 and height=300).