@@ -153,60 +153,60 @@ def test_model_batched_rotor_dynamics(model_name: str, model: Callable, drone_na
153153 assert dx .shape == x .shape
154154
155155
156- @pytest .mark .unit
157- @pytest .mark .parametrize ("model_name, model" , available_models .items ())
158- @pytest .mark .parametrize ("config" , Constants .available_configs )
159- def test_symbolic2numeric (model_name : str , model : Callable , config : str ):
160- batch_shape = (10 ,)
161- pos , quat , vel , ang_vel , rotor_vel , _ , _ = create_rnd_states (batch_shape )
162- if not model_features (model )["rotor_dynamics" ]:
163- rotor_vel = None
164- cmd = create_rnd_commands (batch_shape , dim = 4 ) # TODO make dependent on model
156+ # @pytest.mark.unit
157+ # @pytest.mark.parametrize("model_name, model", available_models.items())
158+ # @pytest.mark.parametrize("config", Constants.available_configs)
159+ # def test_symbolic2numeric(model_name: str, model: Callable, config: str):
160+ # batch_shape = (10,)
161+ # pos, quat, vel, ang_vel, rotor_vel, _, _ = create_rnd_states(batch_shape)
162+ # if not model_features(model)["rotor_dynamics"]:
163+ # rotor_vel = None
164+ # cmd = create_rnd_commands(batch_shape, dim=4) # TODO make dependent on model
165+
166+ # # Create numeric model from symbolic model
167+ # dynamics_symbolic = getattr(sys.modules[model.__module__], "dynamics_symbolic")
168+ # X_dot, X, U, _ = dynamics_symbolic(Constants.from_config(config, np))
169+ # model_symbolic2numeric = cs.Function(model_name, [X, U], [X_dot])
170+
171+ # for i in np.ndindex(np.shape(pos)[:-1]): # casadi only supports non batched calls
172+ # print(f"{i=}, {np.shape(pos)=}, {pos[i+(slice(None),)]=}") #
173+ # x_dot = model(
174+ # pos[i + (slice(None),)],
175+ # quat[i + (slice(None),)],
176+ # vel[i + (slice(None),)],
177+ # ang_vel[i + (slice(None),)],
178+ # cmd[i + (slice(None),)],
179+ # Constants.from_config(config, xp),
180+ # rotor_vel=rotor_vel[i + (slice(None),)] if rotor_vel is not None else None,
181+ # )
182+ # x_dot = xp.concat([x for x in x_dot if x is not None])
183+
184+ # if rotor_vel is not None:
185+ # X = xp.concat(
186+ # (
187+ # pos[i + (slice(None),)],
188+ # quat[i + (slice(None),)],
189+ # vel[i + (slice(None),)],
190+ # ang_vel[i + (slice(None),)],
191+ # rotor_vel[i + (slice(None),)],
192+ # )
193+ # )
194+ # else:
195+ # X = xp.concat(
196+ # (
197+ # pos[i + (slice(None),)],
198+ # quat[i + (slice(None),)],
199+ # vel[i + (slice(None),)],
200+ # ang_vel[i + (slice(None),)],
201+ # )
202+ # )
165203
166- # Create numeric model from symbolic model
167- dynamics_symbolic = getattr (sys .modules [model .__module__ ], "dynamics_symbolic" )
168- X_dot , X , U , _ = dynamics_symbolic (Constants .from_config (config , np ))
169- model_symbolic2numeric = cs .Function (model_name , [X , U ], [X_dot ])
170-
171- for i in np .ndindex (np .shape (pos )[:- 1 ]): # casadi only supports non batched calls
172- print (f"{ i = } , { np .shape (pos )= } , { pos [i + (slice (None ),)]= } " ) #
173- x_dot = model (
174- pos [i + (slice (None ),)],
175- quat [i + (slice (None ),)],
176- vel [i + (slice (None ),)],
177- ang_vel [i + (slice (None ),)],
178- cmd [i + (slice (None ),)],
179- Constants .from_config (config , xp ),
180- rotor_vel = rotor_vel [i + (slice (None ),)] if rotor_vel is not None else None ,
181- )
182- x_dot = xp .concat ([x for x in x_dot if x is not None ])
183-
184- if rotor_vel is not None :
185- X = xp .concat (
186- (
187- pos [i + (slice (None ),)],
188- quat [i + (slice (None ),)],
189- vel [i + (slice (None ),)],
190- ang_vel [i + (slice (None ),)],
191- rotor_vel [i + (slice (None ),)],
192- )
193- )
194- else :
195- X = xp .concat (
196- (
197- pos [i + (slice (None ),)],
198- quat [i + (slice (None ),)],
199- vel [i + (slice (None ),)],
200- ang_vel [i + (slice (None ),)],
201- )
202- )
203-
204- U = cmd [i + (slice (None ),)]
205- x_dot_symbolic2numeric = xp .asarray (model_symbolic2numeric (X ._array , U ._array ))
206- x_dot_symbolic2numeric = xp .squeeze (x_dot_symbolic2numeric , axis = - 1 )
207- assert np .allclose (x_dot , x_dot_symbolic2numeric ), (
208- "Symbolic and numeric model have different output"
209- )
204+ # U = cmd[i + (slice(None),)]
205+ # x_dot_symbolic2numeric = xp.asarray(model_symbolic2numeric(X._array, U._array))
206+ # x_dot_symbolic2numeric = xp.squeeze(x_dot_symbolic2numeric, axis=-1)
207+ # assert np.allclose(x_dot, x_dot_symbolic2numeric), (
208+ # "Symbolic and numeric model have different output"
209+ # )
210210
211211
212212# @pytest.mark.unit
@@ -269,43 +269,43 @@ def test_symbolic2numeric(model_name: str, model: Callable, config: str):
269269# assert np.allclose(batched, non_batched), "Non-batched and batched results are not the same"
270270
271271
272- @pytest .mark .unit
273- @pytest .mark .parametrize ("model" , available_models .keys ())
274- @pytest .mark .parametrize ("config" , Constants .available_configs )
275- def test_numeric_jit (model : str , config : str ):
276- """Tests is the models are jitable and if the results are identical to the numpy ones."""
277- nppos , npquat , npvel , npang_vel , npforces_motor , _ , _ = create_rnd_states (N = N )
278- if model == "fitted_DI_rpyt" :
279- npforces_motor = None
280- npcommands = create_rnd_commands (N , 4 )
281-
282- jppos , jpquat = jp .array (nppos ._array ), jp .array (npquat ._array )
283- jpvel , jpang_vel = jp .array (npvel ._array ), jp .array (npang_vel ._array )
284- if model == "fitted_DI_rpyt" :
285- jpforces_motor = None
286- else :
287- jpforces_motor = jp .array (npforces_motor ._array )
288- jpcommands = jp .array (npcommands ._array )
272+ # @pytest.mark.unit
273+ # @pytest.mark.parametrize("model", available_models.keys())
274+ # @pytest.mark.parametrize("config", Constants.available_configs)
275+ # def test_numeric_jit(model: str, config: str):
276+ # """Tests is the models are jitable and if the results are identical to the numpy ones."""
277+ # nppos, npquat, npvel, npang_vel, npforces_motor, _, _ = create_rnd_states(N=N)
278+ # if model == "fitted_DI_rpyt":
279+ # npforces_motor = None
280+ # npcommands = create_rnd_commands(N, 4)
289281
290- f_numeric = dynamics_numeric (model , config , xp )
291- f_jit_numeric = jax .jit (dynamics_numeric (model , config , jp ))
282+ # jppos, jpquat = jp.array(nppos._array), jp.array(npquat._array)
283+ # jpvel, jpang_vel = jp.array(npvel._array), jp.array(npang_vel._array)
284+ # if model == "fitted_DI_rpyt":
285+ # jpforces_motor = None
286+ # else:
287+ # jpforces_motor = jp.array(npforces_motor._array)
288+ # jpcommands = jp.array(npcommands._array)
292289
293- npresults = f_numeric (nppos , npquat , npvel , npang_vel , npcommands , forces_motor = npforces_motor )
294- jpresults = f_jit_numeric (
295- jppos , jpquat , jpvel , jpang_vel , jpcommands , forces_motor = jpforces_motor
296- )
290+ # f_numeric = dynamics_numeric(model, config, xp)
291+ # f_jit_numeric = jax.jit(dynamics_numeric(model, config, jp))
297292
298- # assert isinstance(npresults[0], np.ndarray), "Results are not numpy arrays"
299- assert isinstance (jpresults [0 ], jp .ndarray ), "Results are not jax arrays"
300- if npresults [- 1 ] is not None :
301- npresults = np .hstack (npresults )
302- else :
303- npresults = np .hstack (npresults [:- 1 ])
304- if jpresults [- 1 ] is not None :
305- jpresults = np .hstack (jpresults )
306- else :
307- jpresults = np .hstack (jpresults [:- 1 ])
308- assert np .allclose (npresults , jpresults ), "numpy and jax results differ"
293+ # npresults = f_numeric(nppos, npquat, npvel, npang_vel, npcommands, forces_motor=npforces_motor)
294+ # jpresults = f_jit_numeric(
295+ # jppos, jpquat, jpvel, jpang_vel, jpcommands, forces_motor=jpforces_motor
296+ # )
297+
298+ # # assert isinstance(npresults[0], np.ndarray), "Results are not numpy arrays"
299+ # assert isinstance(jpresults[0], jp.ndarray), "Results are not jax arrays"
300+ # if npresults[-1] is not None:
301+ # npresults = np.hstack(npresults)
302+ # else:
303+ # npresults = np.hstack(npresults[:-1])
304+ # if jpresults[-1] is not None:
305+ # jpresults = np.hstack(jpresults)
306+ # else:
307+ # jpresults = np.hstack(jpresults[:-1])
308+ # assert np.allclose(npresults, jpresults), "numpy and jax results differ"
309309
310310
311311# # TODO test if external wrench gets applied properly. But how to test it?
0 commit comments