High-performance JAX implementation of PointMaze environments with MuJoCo-inspired physics.
Pointax provides a complete JAX implementation of the PointMaze environment from Gymnasium Robotics, featuring full JIT compilation, vectorization support, and a simplified but accurate 2D physics engine inspired by MuJoCo.
The physics engine is fully differentiable and JIT-compilable, and covers:
- Collision Detection: Sphere-AABB intersection with anti-sticking mechanisms
- Force Integration: Proper velocity and position updates with motor scaling
- Boundary Handling: Smooth wall interactions with friction coefficients
- Parameter Matching: Robot radius (0.1m), motor gear (100x), and dynamics match MuJoCo reference
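To make the mechanics above concrete, here is a minimal sketch of sphere-AABB collision and motor-scaled force integration in plain `jax.numpy`. This illustrates the technique, not Pointax's actual internals: the function names, time step, and mass are assumptions, while the 0.1m radius and 100x gear come from the list above.

```python
import jax.numpy as jnp

ROBOT_RADIUS = 0.1   # meters (matches the MuJoCo reference)
MOTOR_GEAR = 100.0   # action-to-force scaling (matches the MuJoCo reference)

def sphere_aabb_collision(pos, box_min, box_max, radius=ROBOT_RADIUS):
    """Circle vs. axis-aligned box: return (penetrating?, outward push vector)."""
    closest = jnp.clip(pos, box_min, box_max)  # closest point on the box
    delta = pos - closest
    dist = jnp.linalg.norm(delta)
    penetrating = dist < radius
    # Push the sphere out along the contact normal; a positional correction
    # like this is one common anti-sticking mechanism.
    normal = jnp.where(dist > 1e-9, delta / jnp.maximum(dist, 1e-9), jnp.array([0.0, 1.0]))
    push = normal * jnp.maximum(radius - dist, 0.0)
    return penetrating, push

def integrate(pos, vel, action, dt=0.01, mass=1.0):
    """Semi-implicit Euler step: scale the clipped action by the motor gear, then update."""
    force = jnp.clip(action, -1.0, 1.0) * MOTOR_GEAR
    vel = vel + (force / mass) * dt
    return pos + vel * dt, vel
```

Because everything stays in `jax.numpy`, both functions can be JIT-compiled and differentiated through.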
```python
import jax
import pointax

# Create environment
env = pointax.make_umaze()
params = env.default_params

# Reset and step
key = jax.random.PRNGKey(42)
obs, state = env.reset_env(key, params)

action = jax.numpy.array([0.5, 0.0])  # Move right
obs, state, reward, done, info = env.step_env(key, state, action, params)
print(f"Reward: {reward}, Success: {info['is_success']}")
```

From GitHub (not yet on PyPI):
```bash
pip install git+https://github.com/riiswa/pointax.git
```

From source:

```bash
git clone https://github.com/riiswa/pointax.git
cd pointax
pip install -e .
```

| Environment | Size | Description |
|---|---|---|
| `UMaze` | 5×5 | U-shaped maze with single path |
| `Open` | 7×5 | Open rectangular arena |
| `Medium` | 8×8 | Moderately complex maze |
| `Large` | 12×9 | Large complex maze |
| `Giant` | 16×12 | Very large maze |
| Environment | Description |
|---|---|
| `Open_Diverse_G` | Open maze with multiple goal locations |
| `Medium_Diverse_G` | Medium maze with diverse goal placement |
| `Large_Diverse_G` | Large maze with many goal options |
| Environment | Description |
|---|---|
| `Open_Diverse_GR` | Open maze with flexible goal/reset locations |
| `Medium_Diverse_GR` | Medium maze with combined goal/reset spots |
| `Large_Diverse_GR` | Large maze with maximum location diversity |
```python
import pointax

# Standard environments
env = pointax.make_umaze(reward_type="sparse")
env = pointax.make_large(reward_type="dense")

# Diverse environments
env = pointax.make_open_diverse_g()
env = pointax.make_large_diverse_gr(reward_type="dense")
```

Create environments from simple 2D layouts:
```python
# Simple custom maze (1=wall, 0=empty)
custom_maze = [
    [1, 1, 1, 1, 1],
    [1, 0, 0, 0, 1],
    [1, 0, 1, 0, 1],
    [1, 0, 0, 0, 1],
    [1, 1, 1, 1, 1],
]
env = pointax.make_custom(custom_maze)

# Maze with specific goal/reset locations
maze_with_goals = [
    [1, 1, 1, 1, 1, 1, 1],
    [1, 'R', 0, 0, 0, 'G', 1],  # 'R'=reset, 'G'=goal
    [1, 0, 1, 1, 1, 0, 1],
    [1, 0, 0, 'G', 0, 0, 1],    # Multiple goals
    [1, 'C', 0, 0, 0, 'C', 1],  # 'C'=combined goal/reset
    [1, 1, 1, 1, 1, 1, 1],
]
env = pointax.make_custom(maze_with_goals, reward_type="dense")

# Boolean maze (intuitive for many users)
bool_maze = [
    [True,  True,  True],   # True = wall
    [True,  False, True],   # False = empty
    [True,  True,  True],
]
env = pointax.make_custom(bool_maze, maze_id="MyMaze")
```

Pointax is designed for JAX workflows:
```python
import jax
import jax.numpy as jnp
import pointax

env = pointax.make_medium()
params = env.default_params

# JIT compilation
@jax.jit
def fast_step(key, state, action):
    return env.step_env(key, state, action, params)

# Vectorization (params is closed over, so only keys/states/actions are batched)
@jax.vmap
def batch_step(keys, states, actions):
    return env.step_env(keys, states, actions, params)

# Use with batches
batch_size = 64
keys = jax.random.split(jax.random.PRNGKey(42), batch_size)
actions = jnp.zeros((batch_size, 2))
# ... batch operations
```

Observation space:
- Type: `Box(-inf, inf, (6,), float32)`
- Contents: `[pos_x, pos_y, vel_x, vel_y, goal_x, goal_y]`

Action space:
- Type: `Box(-1.0, 1.0, (2,), float32)`
- Contents: continuous forces in the x and y directions

Rewards:
- Sparse: `1.0` when the goal is reached (distance ≤ 0.45m), `0.0` otherwise
- Dense: `-distance_to_goal` (negative distance, suited to gradient-based learning)
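The sparse and dense modes can be sketched in a few lines of `jax.numpy`. This is a simplified stand-in for illustration, not Pointax's internal code; the 0.45m success radius is taken from the description above.

```python
import jax.numpy as jnp

GOAL_THRESHOLD = 0.45  # meters, success radius from the description above

def sparse_reward(pos, goal):
    """1.0 inside the success radius, 0.0 otherwise."""
    return jnp.where(jnp.linalg.norm(pos - goal) <= GOAL_THRESHOLD, 1.0, 0.0)

def dense_reward(pos, goal):
    """Negative Euclidean distance to the goal: a smooth learning signal."""
    return -jnp.linalg.norm(pos - goal)
```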
| Symbol | Value | Description |
|---|---|---|
| `1` or `True` | 1 | Wall (impassable) |
| `0` or `False` | 0 | Empty space |
| `'G'` | 2 | Goal location |
| `'R'` | 3 | Reset location |
| `'C'` | 4 | Combined goal/reset location |
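For illustration, a mixed layout can be normalized to these integer codes in a few lines of plain Python (`encode_maze` is a hypothetical helper written for this example, not part of the Pointax API):

```python
# Hypothetical helper mirroring the symbol table above (not part of the Pointax API).
# Python bools hash as ints (True == 1, False == 0), so they resolve via the same keys.
_CODES = {0: 0, 1: 1, 'G': 2, 'R': 3, 'C': 4}

def encode_maze(layout):
    """Convert a mixed int/bool/str layout to the integer cell encoding."""
    return [[_CODES[cell] for cell in row] for row in layout]
```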
We welcome contributions! Please see CONTRIBUTING.md for guidelines.
If you use Pointax in your research, please cite:
```bibtex
@misc{pointax2025,
  title={Pointax: JAX-Native PointMaze Environment},
  author={Waris Radji},
  year={2025},
  url={https://github.com/riiswa/pointax}
}
```

This project is licensed under the MIT License - see the LICENSE file for details.
- Original PointMaze environment from Gymnasium Robotics
- Gymnax for their API
- MuJoCo physics engine for inspiration
- JAX team for the excellent framework
Pointax builds upon and complements several excellent JAX-based RL environments:
- JaxGCRL: Brax-accelerated goal-conditioned RL environments, including maze implementations.
- QDax: Quality-Diversity optimization library containing an alternative PointMaze environment.

Pointax differentiates itself by providing a pure JAX implementation without Brax overhead.
