55
66import numpy as np
77
8+ from mm_utils import parsing
89from mm_utils .enums import RefType
910from mm_utils .math import interpolate , wrap_pi_array , wrap_pi_scalar
1011
@@ -142,7 +143,7 @@ def __init__(self, config):
142143
143144 # Parse base pose (optional)
144145 if "base_pose" in config :
145- self .base_target = np . array (config ["base_pose" ])
146+ self .base_target = parsing . parse_array (config ["base_pose" ])
146147 if len (self .base_target ) != 3 :
147148 raise ValueError (
148149 f"base_pose must be SE2 [x, y, yaw], got { len (self .base_target )} dimensions"
@@ -154,7 +155,7 @@ def __init__(self, config):
154155
155156 # Parse EE pose (optional)
156157 if "ee_pose" in config :
157- self .ee_target = np . array (config ["ee_pose" ])
158+ self .ee_target = parsing . parse_array (config ["ee_pose" ])
158159 if len (self .ee_target ) != 6 :
159160 raise ValueError (
160161 f"ee_pose must be SE3 [x, y, z, roll, pitch, yaw], got { len (self .ee_target )} dimensions"
@@ -170,10 +171,23 @@ def __init__(self, config):
170171 )
171172
172173 # Common parameters
173- self .tracking_err_tol = config .get ("tracking_err_tol " , 0.02 )
174+ self .tracking_pos_err_tol = config .get ("tracking_pos_err_tol " , 0.02 )
174175 self .tracking_ori_err_tol = config .get ("tracking_ori_err_tol" , 0.1 )
175176 self .hold_period = config .get ("hold_period" , 0.0 )
176177
178+ # Parse masks to specify which dimensions matter for completion
179+ # base_mask: [x, y, yaw] - True means that dimension is checked
180+ if "base_mask" in config :
181+ self .base_mask = np .array (config ["base_mask" ], dtype = bool )
182+ else :
183+ self .base_mask = np .ones (3 , dtype = bool )
184+
185+ # ee_mask: [x, y, z, roll, pitch, yaw] - True means that dimension is checked
186+ if "ee_mask" in config :
187+ self .ee_mask = np .array (config ["ee_mask" ], dtype = bool )
188+ else :
189+ self .ee_mask = np .ones (6 , dtype = bool )
190+
177191 # State tracking
178192 self .finished = False
179193 self .base_reached = False
@@ -224,10 +238,19 @@ def checkFinished(self, t, states):
224238 # Check base if applicable
225239 if self .has_base_ref :
226240 base_pose = states ["base" ]["pose" ]
227- pos_err = np .linalg .norm (base_pose [:2 ] - self .base_target [:2 ])
228- yaw_err = abs (wrap_pi_scalar (base_pose [2 ] - self .base_target [2 ]))
229- pos_within_tol = pos_err < self .tracking_err_tol
230- ori_within_tol = yaw_err < self .tracking_ori_err_tol
241+
242+ # Check position (x, y) only if mask indicates it matters
243+ pos_mask = self .base_mask [:2 ]
244+ pos_err = np .linalg .norm ((base_pose [:2 ] - self .base_target [:2 ])[pos_mask ])
245+ pos_within_tol = pos_err < self .tracking_pos_err_tol
246+
247+ # Check orientation (yaw) only if mask indicates it matters
248+ if self .base_mask [2 ]:
249+ yaw_err = abs (wrap_pi_scalar (base_pose [2 ] - self .base_target [2 ]))
250+ ori_within_tol = yaw_err < self .tracking_ori_err_tol
251+ else :
252+ ori_within_tol = True
253+ yaw_err = 0.0
231254
232255 if pos_within_tol and ori_within_tol :
233256 if not self .base_reached :
@@ -246,10 +269,16 @@ def checkFinished(self, t, states):
246269 # Check EE if applicable
247270 if self .has_ee_ref :
248271 ee_pose = states ["EE" ]["pose" ]
249- pos_err = np .linalg .norm (ee_pose [:3 ] - self .ee_target [:3 ])
272+
273+ # Check position (x, y, z) only if mask indicates it matters
274+ pos_mask = self .ee_mask [:3 ]
275+ pos_err = np .linalg .norm ((ee_pose [:3 ] - self .ee_target [:3 ])[pos_mask ])
276+ pos_within_tol = pos_err < self .tracking_pos_err_tol
277+
278+ # Check orientation (roll, pitch, yaw) only if mask indicates it matters
279+ ori_mask = self .ee_mask [3 :]
250280 ori_diff = wrap_pi_array (ee_pose [3 :] - self .ee_target [3 :])
251- ori_err = np .linalg .norm (ori_diff )
252- pos_within_tol = pos_err < self .tracking_err_tol
281+ ori_err = np .linalg .norm (ori_diff [ori_mask ])
253282 ori_within_tol = ori_err < self .tracking_ori_err_tol
254283
255284 if pos_within_tol and ori_within_tol :
@@ -305,7 +334,10 @@ def __init__(self, config):
305334
306335 # Parse base path (optional)
307336 if "base_path" in config :
308- base_path = np .array (config ["base_path" ])
337+ # Parse each row to handle pi notation
338+ base_path = np .array (
339+ [parsing .parse_array (row ) for row in config ["base_path" ]]
340+ )
309341 if base_path .shape [1 ] != 3 :
310342 raise ValueError (
311343 f"base_path must be SE2 [x, y, yaw], got shape { base_path .shape } "
@@ -318,7 +350,8 @@ def __init__(self, config):
318350
319351 # Parse EE path (optional)
320352 if "ee_path" in config :
321- ee_path = np .array (config ["ee_path" ])
353+ # Parse each row to handle pi notation
354+ ee_path = np .array ([parsing .parse_array (row ) for row in config ["ee_path" ]])
322355 if ee_path .shape [1 ] != 6 :
323356 raise ValueError (
324357 f"ee_path must be SE3 [x, y, z, roll, pitch, yaw], got shape { ee_path .shape } "
@@ -335,7 +368,7 @@ def __init__(self, config):
335368 )
336369
337370 # Common parameters
338- self .tracking_err_tol = config .get ("tracking_err_tol " , 0.02 )
371+ self .tracking_pos_err_tol = config .get ("tracking_pos_err_tol " , 0.02 )
339372 self .tracking_ori_err_tol = config .get ("tracking_ori_err_tol" , 0.1 )
340373 self .end_stop = config .get ("end_stop" , False )
341374
@@ -472,7 +505,7 @@ def checkFinished(self, t, states):
472505 base_vel = states ["base" ].get ("velocity" )
473506 end_pose = self .base_plan ["p" ][- 1 ]
474507 pos_err , ori_err = self ._compute_error (base_pose , end_pose , is_base = True )
475- pos_cond = pos_err < self .tracking_err_tol
508+ pos_cond = pos_err < self .tracking_pos_err_tol
476509 ori_cond = ori_err < self .tracking_ori_err_tol
477510 pos_ori_cond = pos_cond and ori_cond
478511 vel_cond = base_vel is not None and np .linalg .norm (base_vel ) < 1e-2
@@ -490,7 +523,7 @@ def checkFinished(self, t, states):
490523 ee_vel = states ["EE" ].get ("velocity" )
491524 end_pose = self .ee_plan ["p" ][- 1 ]
492525 pos_err , ori_err = self ._compute_error (ee_pose , end_pose , is_base = False )
493- pos_cond = pos_err < self .tracking_err_tol
526+ pos_cond = pos_err < self .tracking_pos_err_tol
494527 ori_cond = ori_err < self .tracking_ori_err_tol
495528 pos_ori_cond = pos_cond and ori_cond
496529 vel_cond = ee_vel is not None and np .linalg .norm (ee_vel ) < 1e-2
0 commit comments