Skip to content

Commit 75bf7a3

Browse files
authored
Fix camera settings in render() (#51)
1 parent d74a404 commit 75bf7a3

File tree

2 files changed

+54
-5
lines changed

2 files changed

+54
-5
lines changed

crazyflow/sim/sim.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,11 @@ def render(
157157
) -> NDArray | None:
158158
if self.viewer is None:
159159
if isinstance(camera, str):
160-
cam_id, cam_name = None, camera
160+
cam_id = mujoco.mj_name2id(self.mj_model, mujoco.mjtObj.mjOBJ_CAMERA, camera)
161+
assert cam_id > -1, f"Camera name '{camera}' not found in the model."
161162
elif isinstance(camera, int):
162-
cam_id, cam_name = camera, None
163-
if cam_id < -1:
164-
raise ValueError(f"camera id must be >=-1, was {cam_id}")
163+
cam_id = camera
164+
assert cam_id >= -1, f"camera id must be >=-1, was {cam_id}"
165165
else:
166166
raise TypeError("camera argument must be integer or string")
167167
self.mj_model.vis.global_.offwidth = width
@@ -174,8 +174,14 @@ def render(
174174
height=height,
175175
width=width,
176176
camera_id=cam_id,
177-
camera_name=cam_name,
178177
)
178+
# In human mode, cam_id is set to -1, so we force it to the desired value
179+
if mode == "human" and cam_id > -1:
180+
# Render one frame to force mj to create the viewer
181+
self.viewer.render(mode)
182+
self.viewer.viewer.cam.fixedcamid = cam_id
183+
self.viewer.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
184+
179185
self.mj_data.qpos[:] = self.mjx_data.qpos[world, :]
180186
self.mj_data.mocap_pos[:] = self.mjx_data.mocap_pos[world, :]
181187
self.mj_data.mocap_quat[:] = self.mjx_data.mocap_quat[world, :]

tests/unit/test_render.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import mujoco
2+
import pytest
3+
from conftest import skip_if_headless
4+
5+
from crazyflow import Sim
6+
7+
8+
@pytest.mark.unit
9+
@pytest.mark.parametrize("cam_name", ["fpv_cam:0", "track_cam:0", "fpv_cam:1", "track_cam:1"])
10+
@pytest.mark.render
11+
@skip_if_headless
12+
def test_render_camera_selection_from_name(cam_name: str):
13+
sim = Sim(drone_model="cf21B_500", n_drones=2)
14+
cam_id = mujoco.mj_name2id(sim.mj_model, mujoco.mjtObj.mjOBJ_CAMERA, cam_name)
15+
sim.render(mode="human", camera=cam_name)
16+
viewer_cam = sim.viewer.viewer.cam
17+
assert viewer_cam.type == mujoco.mjtCamera.mjCAMERA_FIXED, "Camera type was not set to FIXED"
18+
assert viewer_cam.fixedcamid == cam_id, f"Expected cam ID {cam_id}, got {viewer_cam.fixedcamid}"
19+
sim.close()
20+
21+
22+
@pytest.mark.unit
23+
@pytest.mark.parametrize("cam_id", [0, 1, 2, 3])
24+
@pytest.mark.render
25+
@skip_if_headless
26+
def test_render_camera_selection_from_id(cam_id: int):
27+
sim = Sim(drone_model="cf21B_500", n_drones=2)
28+
sim.render(mode="human", camera=cam_id)
29+
viewer_cam = sim.viewer.viewer.cam
30+
assert viewer_cam.type == mujoco.mjtCamera.mjCAMERA_FIXED, "Camera type was not set to FIXED"
31+
assert viewer_cam.fixedcamid == cam_id, f"Expected cam ID {cam_id}, got {viewer_cam.fixedcamid}"
32+
sim.close()
33+
34+
35+
@pytest.mark.unit
36+
@pytest.mark.render
37+
@skip_if_headless
38+
def test_render_free_camera():
39+
sim = Sim(drone_model="cf21B_500", n_drones=2)
40+
sim.render(mode="human")
41+
viewer_cam = sim.viewer.viewer.cam
42+
assert viewer_cam.type == mujoco.mjtCamera.mjCAMERA_FREE, "Camera type was not set to FREE"
43+
sim.close()

0 commit comments

Comments
 (0)