Skip to content

Commit 48bdc36

Browse files
amacatiratheron
andauthored
Add capsule draw function (#49)
--------- Co-authored-by: Rather1337 <marcel.rath@gmx.de>
1 parent 3b28b44 commit 48bdc36

File tree

4 files changed

+136
-8
lines changed

4 files changed

+136
-8
lines changed

crazyflow/sim/visualize.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,37 @@ def draw_points(sim: Sim, points: NDArray, rgba: NDArray | None = None, size: fl
6969
)
7070

7171

72+
def draw_capsule(
73+
sim: Sim,
74+
p1: NDArray,
75+
p2: NDArray,
76+
radius: float = 0.05,
77+
rgba: NDArray | None = None,
78+
cylinder: bool = False,
79+
):
80+
"""Draw a capsule (pill) or cylinder between two points.
81+
82+
Args:
83+
sim: The simulation.
84+
p1: Start point [3,]
85+
p2: End point [3,]
86+
radius: The thickness of the geom in [m].
87+
rgba: The color of the object.
88+
cylinder: If True, draws a flat-ended cylinder. If False, draws a pill-shaped capsule.
89+
"""
90+
if sim.viewer is None:
91+
return
92+
93+
pos = (p1 + p2) / 2.0 # Center of the geom
94+
half_length = np.linalg.norm(p2 - p1) / 2.0 # MuJoCo uses half-extents
95+
size = np.array([radius, half_length, 0])
96+
# Align the z-axis of the geom to the vector from p1 to p2
97+
mat = _rotation_matrix_from_points(p1[None, :], p2[None, :]).as_matrix().flatten()
98+
geom_type = mujoco.mjtGeom.mjGEOM_CYLINDER if cylinder else mujoco.mjtGeom.mjGEOM_CAPSULE
99+
rgba = rgba if rgba is not None else np.array([1, 0, 0, 1.0])
100+
sim.viewer.viewer.add_marker(type=geom_type, pos=pos, size=size, mat=mat, rgba=rgba)
101+
102+
72103
def change_material(
73104
sim: Sim,
74105
mat_name: str,

tests/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,15 @@ def device() -> str:
3131
if "gpu" in available_backends():
3232
return "gpu"
3333
return "cpu"
34+
35+
36+
def skip_headless():
37+
if os.environ.get("DISPLAY") is None:
38+
pytest.skip("DISPLAY is not set, skipping test in headless environment")
39+
40+
41+
# Marker for conditional skip in headless environments
42+
skip_if_headless = pytest.mark.skipif(
43+
os.environ.get("DISPLAY") is None,
44+
reason="DISPLAY is not set, skipping test in headless environment",
45+
)

tests/unit/test_sim.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22

33
from __future__ import annotations
44

5-
import os
65
from typing import TYPE_CHECKING
76

87
import jax
98
import jax.numpy as jnp
109
import mujoco
1110
import numpy as np
1211
import pytest
12+
from conftest import skip_if_headless
1313
from jax import Array
1414

1515
from crazyflow.control import Control
@@ -23,11 +23,6 @@
2323
from typing import Any
2424

2525

26-
def skip_headless():
27-
if os.environ.get("DISPLAY") is None:
28-
pytest.skip("DISPLAY is not set, skipping test in headless environment")
29-
30-
3126
def array_meta_assert(
3227
x: Array,
3328
shape: tuple[int, ...] | None = None,
@@ -272,15 +267,15 @@ def test_sim_state_control_device(device: str):
272267

273268

274269
@pytest.mark.render
270+
@skip_if_headless
275271
def test_render_human(device: str):
276272
sim = Sim(device=device)
277273
sim.render()
278274
sim.viewer.close()
279275

280276

281-
# Do not mark as render to ensure it runs by default. This function will not open a viewer.
277+
@skip_if_headless
282278
def test_render_rgb_array(device: str):
283-
skip_headless()
284279
sim = Sim(n_worlds=2, device=device)
285280
img = sim.render(mode="rgb_array", width=1024, height=1024)
286281
assert isinstance(img, np.ndarray), "Image must be a numpy array"

tests/unit/test_visualizations.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import numpy as np
2+
import pytest
3+
from conftest import skip_if_headless
4+
5+
from crazyflow import Sim
6+
from crazyflow.sim.visualize import draw_capsule, draw_line, draw_points
7+
8+
9+
@pytest.mark.unit
10+
@skip_if_headless
11+
def test_draw_capsule(device: str):
12+
"""Test drawing a capsule and verify it changes the rendering."""
13+
sim = Sim(device=device)
14+
# Render without drawing
15+
sim.render(mode="rgb_array", width=120, height=90) # Warm up the renderer
16+
img_before = sim.render(mode="rgb_array", width=120, height=90)
17+
# Draw a capsule and render
18+
p1 = np.array([0.0, 0.0, 0.5])
19+
p2 = np.array([0.5, 0.5, 1.0])
20+
rgba = np.array([1.0, 0.0, 0.0, 1.0])
21+
draw_capsule(sim, p1, p2, radius=0.05, rgba=rgba)
22+
img_after = sim.render(mode="rgb_array", width=120, height=90)
23+
# Verify that the image changed
24+
assert not np.array_equal(img_before, img_after), "Drawing capsule should change the rendering"
25+
sim.close()
26+
27+
28+
@pytest.mark.unit
29+
@skip_if_headless
30+
def test_draw_capsule_cylinder(device: str):
31+
"""Test drawing a cylinder."""
32+
sim = Sim(device=device)
33+
sim.render(mode="rgb_array", width=120, height=90) # Warm up the renderer
34+
img_before = sim.render(mode="rgb_array", width=120, height=90)
35+
# Draw a cylinder instead of capsule
36+
p1 = np.array([0.0, 0.0, 0.5])
37+
p2 = np.array([0.0, 0.0, 1.0])
38+
rgba = np.array([0.0, 1.0, 0.0, 1.0])
39+
draw_capsule(sim, p1, p2, radius=0.05, rgba=rgba, cylinder=True)
40+
img_after = sim.render(mode="rgb_array", width=120, height=90)
41+
assert not np.array_equal(img_before, img_after), "Drawing cylinder should change the rendering"
42+
sim.close()
43+
44+
45+
@pytest.mark.unit
46+
@skip_if_headless
47+
def test_draw_line(device: str):
48+
"""Test drawing a line and verify it changes the rendering."""
49+
sim = Sim(device=device)
50+
sim.render(mode="rgb_array", width=120, height=90) # Warm up the renderer
51+
img_before = sim.render(mode="rgb_array", width=120, height=90)
52+
# Draw a line with multiple points
53+
points = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]])
54+
rgba = np.array([0.0, 0.0, 1.0, 1.0])
55+
draw_line(sim, points, rgba=rgba, start_size=5.0, end_size=2.0)
56+
img_after = sim.render(mode="rgb_array", width=120, height=90)
57+
assert not np.array_equal(img_before, img_after), "Drawing line should change the rendering"
58+
sim.close()
59+
60+
61+
@pytest.mark.unit
62+
@skip_if_headless
63+
def test_draw_points(device: str):
64+
"""Test drawing points and verify it changes the rendering."""
65+
sim = Sim(device=device)
66+
sim.render(mode="rgb_array", width=120, height=90) # Warm up the renderer
67+
img_before = sim.render(mode="rgb_array", width=120, height=90)
68+
# Draw multiple points
69+
points = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.5], [0.0, 0.0, 1.0]])
70+
rgba = np.array([1.0, 1.0, 0.0, 1.0])
71+
draw_points(sim, points, rgba=rgba, size=0.02)
72+
img_after = sim.render(mode="rgb_array", width=120, height=90)
73+
assert not np.array_equal(img_before, img_after), "Drawing points should change the rendering"
74+
sim.close()
75+
76+
77+
@pytest.mark.unit
78+
@skip_if_headless
79+
def test_draw_combined(device: str):
80+
"""Test drawing multiple visualization elements together."""
81+
sim = Sim(device=device)
82+
p1 = np.array([0.0, 0.0, 0.5])
83+
p2 = np.array([0.2, 0.0, 0.8])
84+
draw_capsule(sim, p1, p2, radius=0.03, rgba=np.array([1.0, 0.0, 0.0, 1.0]))
85+
line_points = np.array([[0.3, 0.0, 0.5], [0.3, 0.2, 0.7], [0.5, 0.2, 0.9]])
86+
draw_line(sim, line_points, rgba=np.array([0.0, 1.0, 0.0, 1.0]))
87+
points = np.array([[0.6, 0.0, 0.6], [0.7, 0.1, 0.7], [0.8, 0.0, 0.8]])
88+
draw_points(sim, points, rgba=np.array([0.0, 0.0, 1.0, 1.0]), size=0.025)
89+
sim.render(mode="rgb_array", width=120, height=90)
90+
sim.close()

0 commit comments

Comments
 (0)