diff --git a/1_vlm_demo.py b/1_vlm_demo.py index 3c79396..342747b 100644 --- a/1_vlm_demo.py +++ b/1_vlm_demo.py @@ -1,109 +1,300 @@ +""" +=============================================================================== +1_vlm_demo.py - Vision Language Model Demo for 3D Voxel Generation +=============================================================================== + +This script uses a fine-tuned Qwen2.5-VL model to analyze images and generate +3D voxel representations of objects and their parts. + +Pipeline Overview: + 1. Load an image from the demo folder + 2. Send image + prompt to the VLM to get basic object info (parts, materials) + 3. For each detected part, generate voxel coordinates in a 32x32x32 grid + 4. Save voxel data as numpy arrays and optionally as PLY point clouds + +Key Concepts: + - Voxel Grid: A 32x32x32 3D grid where each cell can be occupied or empty + - Voxel Encoding: 3D coordinates (x,y,z) are encoded into a single integer + using bit shifting: index = (x << 10) | (y << 5) | z + - Run-Length Encoding: Consecutive voxel indices are merged (e.g., "199-216") + +Author: PhysX-Anything Team +=============================================================================== +""" + +# ============================================================================= +# IMPORTS +# ============================================================================= + +# Machine Learning & Vision from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor from qwen_vl_utils import process_vision_info import torch + +# Standard Library import base64 import os -import ipdb +import argparse + +# Data Processing import numpy as np from PIL import Image import trimesh + +# Image Processing from rembg import remove -import argparse -def voxel_encode(voxels: np.ndarray, size: int = 32) -> np.ndarray: +# Debugging (can be removed in production) +import ipdb + + +# ============================================================================= +# VOXEL ENCODING/DECODING UTILITIES +# ============================================================================= + +def voxel_encode(voxels: np.ndarray, size: int = 32) -> np.ndarray: + """ + Encode 3D voxel coordinates into single integers using bit-packing. + + This function converts (x, y, z) coordinates into a single integer index + using bit shifting. Each coordinate uses 5 bits (supports values 0-31). + + Formula: index = (x << 10) | (y << 5) | z + + Args: + voxels (np.ndarray): Array of shape (N, 3) containing voxel coordinates + size (int): Grid size (must be 32 for 5-bit encoding) + + Returns: + np.ndarray: Array of encoded integer indices + + Example: + >>> coords = np.array([[0, 0, 0], [1, 2, 3]]) + >>> voxel_encode(coords) + array([0, 1091]) + """ voxels = np.asarray(voxels, dtype=np.int64) - assert voxels.ndim == 2 and voxels.shape[1] == 3, "voxels shape should be (N,3)" - assert size == 32, "size=32(2^5)。" + + # Validate input shape + assert voxels.ndim == 2 and voxels.shape[1] == 3, "voxels shape should be (N, 3)" + assert size == 32, "Grid size must be 32 (2^5) for 5-bit encoding" + + # Validate coordinate ranges if (voxels < 0).any() or (voxels >= size).any(): - raise ValueError("xyz should be within [0, 32).") - + raise ValueError("All coordinates must be in range [0, 32)") + + # Extract individual coordinates x, y, z = voxels[:, 0], voxels[:, 1], voxels[:, 2] + + # Bit-pack into single integer: x uses bits 10-14, y uses bits 5-9, z uses bits 0-4 return (x << 10) | (y << 5) | z def voxel_decode(indices: np.ndarray, size: int = 32) -> np.ndarray: - + """ + Decode integer indices back into 3D voxel coordinates. + + This reverses the voxel_encode operation, extracting (x, y, z) coordinates + from bit-packed integers using bit masking and shifting. + + Args: + indices (np.ndarray): Array of encoded voxel indices + size (int): Grid size (must be 32) + + Returns: + np.ndarray: Array of shape (N, 3) containing decoded coordinates + + Example: + >>> indices = np.array([0, 1091]) + >>> voxel_decode(indices) + array([[0, 0, 0], [1, 2, 3]]) + """ indices = np.asarray(indices, dtype=np.int64).ravel() - assert size == 32, "size=32(2^5)。" + + assert size == 32, "Grid size must be 32 (2^5) for 5-bit decoding" + + # Clamp out-of-range indices and warn if (indices < 0).any() or (indices >= size**3).any(): - - indices=indices.clip(0,size**3-1) - print("index should be within [0, 32768).") - - - x = (indices >> 10) & 31 - y = (indices >> 5) & 31 - z = indices & 31 + indices = indices.clip(0, size**3 - 1) + print("Warning: Some indices were out of range [0, 32768) and have been clamped.") + + # Extract coordinates using bit masking (31 = 0b11111, a 5-bit mask) + x = (indices >> 10) & 31 # Extract bits 10-14 + y = (indices >> 5) & 31 # Extract bits 5-9 + z = indices & 31 # Extract bits 0-4 + return np.stack([x, y, z], axis=1) +# ============================================================================= +# STRING CONVERSION UTILITIES (for VLM communication) +# ============================================================================= def ints_to_space_separated_str(arr: np.ndarray) -> str: + """ + Convert an array of integers to a space-separated string. + + Args: + arr (np.ndarray): Array of integers + + Returns: + str: Space-separated string representation + + Example: + >>> ints_to_space_separated_str(np.array([1, 2, 3])) + '1 2 3' + """ arr = np.asarray(arr, dtype=np.int64).ravel() return " ".join(map(str, arr)) - def merge_adjacent_to_dash(s: str) -> str: - + """ + Compress a sequence of numbers by merging consecutive runs into ranges. + + This function takes a space-separated string of numbers and converts + consecutive sequences into dash-separated ranges for more compact output. + + Args: + s (str): Space-separated string of integers (e.g., "1 2 3 5 6 7 10") + + Returns: + str: Compressed string with ranges (e.g., "1-3 5-7 10") + + Example: + >>> merge_adjacent_to_dash("199 200 201 202 230 231") + '199-202 230-231' + """ if not s.strip(): return "" - + + # Parse, sort, and deduplicate numbers nums = list(map(int, s.split())) - nums = sorted(set(nums)) - + + # Build ranges from consecutive sequences result = [] start = prev = nums[0] + for n in nums[1:]: if n == prev + 1: + # Continue current range prev = n else: + # End current range and start new one result.append(f"{start}-{prev}" if start != prev else f"{start}") start = prev = n + + # Don't forget the last range result.append(f"{start}-{prev}" if start != prev else f"{start}") + return " ".join(result) - def dash_str_to_ints(s: str) -> np.ndarray: - + """ + Expand a compressed dash-notation string back into individual integers. + + This reverses the merge_adjacent_to_dash operation, expanding ranges + like "199-202" back into individual numbers [199, 200, 201, 202]. + + Args: + s (str): Compressed string with ranges (e.g., "199-202 230-231") + + Returns: + np.ndarray: Sorted, deduplicated array of integers + + Example: + >>> dash_str_to_ints("1-3 5-7 10") + array([1, 2, 3, 5, 6, 7, 10]) + """ if not s.strip(): return np.array([], dtype=np.int64) - + out = [] for token in s.split(): if "-" in token: + # Expand range notation a, b = map(int, token.split("-")) if a > b: - a, b = b, a + a, b = b, a # Handle reversed ranges out.extend(range(a, b + 1)) else: + # Single number out.append(int(token)) + return np.array(sorted(set(out)), dtype=np.int64) -def addmessage(message,before,after): - answer={} - answer['role']='assistant' - answer['content']=[{"type": "text", "text": before}] - question={} - question['role']='user' - question['content']=[{"type": "text", "text": after}] - newmessage=message.copy() +# ============================================================================= +# CONVERSATION MANAGEMENT FOR VLM +# ============================================================================= + +def addmessage(message, before, after): + """ + Append a Q&A pair to the conversation history. + + This function simulates a conversation turn by adding the model's response + (before) and the user's follow-up question (after) to the message history. + + Args: + message (list): Current conversation history + before (str): Assistant's response text + after (str): User's next question text + + Returns: + list: Updated conversation history with new Q&A pair + """ + # Create assistant response + answer = { + 'role': 'assistant', + 'content': [{"type": "text", "text": before}] + } + + # Create user follow-up question + question = { + 'role': 'user', + 'content': [{"type": "text", "text": after}] + } + + # Append to conversation (copy to avoid modifying original) + newmessage = message.copy() newmessage.append(answer) newmessage.append(question) + return newmessage +# ============================================================================= +# MODEL INFERENCE +# ============================================================================= -def generate_save(model,messages,save_dir,save_name='test',save=True): - - +def generate_save(model, messages, save_dir, save_name='test', save=True): + """ + Generate a response from the VLM and optionally save it to a file. + + This function processes the conversation, runs inference through the model, + and extracts the generated text response. + + Args: + model: The loaded VLM model + messages (list): Conversation history in OpenAI-style format + save_dir (str): Directory to save the output + save_name (str): Base filename for the output (without extension) + save (bool): Whether to save the output to a file + + Returns: + str: The model's generated text response + """ + # Prepare input using the chat template text = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) + + # Process any images/videos in the messages image_inputs, video_inputs = process_vision_info(messages) + + # Tokenize and prepare inputs inputs = processor( text=[text], images=image_inputs, @@ -112,69 +303,128 @@ def generate_save(model,messages,save_dir,save_name='test',save=True): return_tensors="pt", ) inputs = inputs.to(model.device) - - - generated_ids = model.generate(**inputs, do_sample=False,temperature=0,max_length=32768) + + # Generate response (deterministic with temperature=0) + generated_ids = model.generate( + **inputs, + do_sample=False, + temperature=0, + max_length=32768 # Allow long outputs for voxel coordinates + ) + + # Extract only the newly generated tokens (exclude input prompt) generated_ids_trimmed = [ - out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + out_ids[len(in_ids):] + for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] + + # Decode tokens to text output_text = processor.batch_decode( - generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + generated_ids_trimmed, + skip_special_tokens=True, + clean_up_tokenization_spaces=False ) + + # Save output if requested if save: - with open(os.path.join(save_dir,save_name+'.txt'),'w') as file: - file.write( output_text[0]) + output_path = os.path.join(save_dir, save_name + '.txt') + with open(output_path, 'w') as file: + file.write(output_text[0]) + return output_text[0] +# ============================================================================= +# MAIN EXECUTION +# ============================================================================= + if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument("--demo_path", type=str, default='./demo') - parser.add_argument("--save_part_ply", type=bool, default=True) - parser.add_argument("--remove_bg", type=bool, default=False) - parser.add_argument("--ckpt", type=str, default='./pretrain/vlm') + # ------------------------------------------------------------------------- + # Parse Command Line Arguments + # ------------------------------------------------------------------------- + parser = argparse.ArgumentParser( + description="VLM Demo: Generate 3D voxel representations from images" + ) + parser.add_argument( + "--demo_path", type=str, default='./demo', + help="Path to input images directory" + ) + parser.add_argument( + "--save_part_ply", type=bool, default=True, + help="Whether to save individual part point clouds as PLY files" + ) + parser.add_argument( + "--remove_bg", type=bool, default=False, + help="Whether to remove background from input images" + ) + parser.add_argument( + "--ckpt", type=str, default='./pretrain/vlm', + help="Path to the fine-tuned VLM checkpoint" + ) args = parser.parse_args() - basepath=args.demo_path - namelist=os.listdir(basepath) + # ------------------------------------------------------------------------- + # Setup: Load Input Files and Model + # ------------------------------------------------------------------------- + basepath = args.demo_path + namelist = os.listdir(basepath) - - + print(f"Found {len(namelist)} images in {basepath}") + + # Load the Vision-Language Model with optimizations + print("Loading VLM model...") model = Qwen2_5_VLForConditionalGeneration.from_pretrained( - args.ckpt, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - device_map="auto", - ) - min_pixels = 65536 - max_pixels = 262144 - - processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels) - processor.image_processor.min_pixels=min_pixels - processor.image_processor.max_pixels=max_pixels - processor.image_processor.size["shortest_edge"]=min_pixels - processor.image_processor.size["longest_edge"]=max_pixels - + args.ckpt, + torch_dtype=torch.bfloat16, # Use bfloat16 for memory efficiency + attn_implementation="flash_attention_2", # Use Flash Attention 2 for speed + device_map="auto", # Automatically distribute across GPUs + ) + + # Configure image processor with resolution limits + # These settings balance detail vs. memory usage + min_pixels = 65536 # 256x256 minimum + max_pixels = 262144 # 512x512 maximum + + processor = AutoProcessor.from_pretrained( + "Qwen/Qwen2.5-VL-7B-Instruct", + min_pixels=min_pixels, + max_pixels=max_pixels + ) + processor.image_processor.min_pixels = min_pixels + processor.image_processor.max_pixels = max_pixels + processor.image_processor.size["shortest_edge"] = min_pixels + processor.image_processor.size["longest_edge"] = max_pixels + + # ------------------------------------------------------------------------- + # Process Each Image + # ------------------------------------------------------------------------- for name in namelist: - - - - save_dir=os.path.join('test_demo',name[:-4]) - os.makedirs(os.path.join(save_dir), exist_ok=True) - - image_path = os.path.join(basepath,name) - - - - with open(os.path.join('./dataset/overall_prompt.txt'), "r", encoding="utf-8") as f: + print(f"\n{'='*60}") + print(f"Processing: {name}") + print('='*60) + + # Create output directory for this image + save_dir = os.path.join('test_demo', name[:-4]) # Remove file extension + os.makedirs(save_dir, exist_ok=True) + + image_path = os.path.join(basepath, name) + + # Load the prompt template for object analysis + with open('./dataset/overall_prompt.txt', "r", encoding="utf-8") as f: basicqu = f.read() - + + # Load and resize input image input_image = Image.open(image_path) im_resized = input_image.resize((512, 512), Image.LANCZOS) - + + # Optionally remove background for cleaner analysis if args.remove_bg: im_resized = remove(im_resized) - + + # --------------------------------------------------------------------- + # Step 1: Get Basic Object Information + # --------------------------------------------------------------------- + # Initial message with image and analysis prompt messages = [ { "role": "user", @@ -188,27 +438,64 @@ def generate_save(model,messages,save_dir,save_name='test',save=True): } ] - - - basicoutput=generate_save(model,messages,save_dir,'basic_info') - index=0 - while 'l_'+str(index) in basicoutput: - index+=1 - - allcoord=[] - for part in range(index): - - question="Based on the structured description of l_"+str(part)+", generate its 3D voxel grid in the following format (voxel grid=32, use numbers from 0 to 32767, merge maximal consecutive runs: 199...216 -> 199-216): 184 198 199-216 230-237..." - messages1=addmessage(messages,basicoutput,question) - output1=generate_save(model,messages1,save_dir,'coord_'+str(part),save=True) - print(len(messages1)) + # Get basic object info (name, category, parts, materials, etc.) + print("Step 1: Analyzing object structure...") + basicoutput = generate_save(model, messages, save_dir, 'basic_info') + + # Count how many parts were detected (look for l_0, l_1, l_2, etc.) + num_parts = 0 + while f'l_{num_parts}' in basicoutput: + num_parts += 1 + print(f"Detected {num_parts} parts") + + # --------------------------------------------------------------------- + # Step 2: Generate Voxel Coordinates for Each Part + # --------------------------------------------------------------------- + allcoord = [] # Collect all part coordinates + + for part in range(num_parts): + print(f"Step 2.{part}: Generating voxels for part l_{part}...") + + # Construct prompt for voxel generation + # The model outputs voxel indices in compressed dash notation + question = ( + f"Based on the structured description of l_{part}, " + f"generate its 3D voxel grid in the following format " + f"(voxel grid=32, use numbers from 0 to 32767, " + f"merge maximal consecutive runs: 199...216 -> 199-216): " + f"184 198 199-216 230-237..." + ) + + # Add previous response and new question to conversation + messages1 = addmessage(messages, basicoutput, question) + + # Generate voxel coordinates + output1 = generate_save(model, messages1, save_dir, f'coord_{part}', save=True) + print(f" Conversation length: {len(messages1)} messages") + + # Decode the compressed voxel indices back to 3D coordinates idx_back = dash_str_to_ints(output1) voxels_back = voxel_decode(idx_back) + + print(f" Generated {len(voxels_back)} voxels for part {part}") + + # Save individual part data allcoord.append(voxels_back) - np.save(os.path.join(save_dir,'ind_'+str(part)+'.npy'),voxels_back) + np.save(os.path.join(save_dir, f'ind_{part}.npy'), voxels_back) + + # Optionally save as PLY point cloud for visualization if args.save_part_ply: - partply=trimesh.points.PointCloud(voxels_back) - partply.export(os.path.join(save_dir,'ind_'+str(part)+'.ply')) + partply = trimesh.points.PointCloud(voxels_back) + partply.export(os.path.join(save_dir, f'ind_{part}.ply')) + + # --------------------------------------------------------------------- + # Step 3: Save Combined Voxel Data + # --------------------------------------------------------------------- + if allcoord: + combined_voxels = np.concatenate(allcoord) + np.save(os.path.join(save_dir, 'allind.npy'), combined_voxels) + print(f"Saved combined voxels: {len(combined_voxels)} total voxels") + + print(f"Completed processing: {name}") - np.save(os.path.join(save_dir,'allind.npy'),np.concatenate(allcoord)) diff --git a/2_decoder.py b/2_decoder.py index d8abf1e..30bc497 100644 --- a/2_decoder.py +++ b/2_decoder.py @@ -1,58 +1,180 @@ +""" +=============================================================================== +2_decoder.py - 3D Model Generation from Voxel Data +=============================================================================== + +This script uses the Trellis pipeline to generate high-quality 3D models (GLB) +from voxel data produced by the VLM demo. + +Pipeline Overview: + 1. Load the pre-trained Trellis Image-to-3D decoder pipeline + 2. For each processed demo image: + - Load the original image and voxel coordinates + - Create a sparse structure tensor from voxel data + - Run the pipeline with the structure as a control signal + - Export the result as a GLB file with textures + +Key Concepts: + - Sparse Structure (ss): A 64x64x64 binary tensor indicating occupied voxels + - Voxel Coordinates: 32x32x32 grid coordinates that are upscaled to 64x64x64 + - Control Signal: The sparse structure guides the 3D generation process + +Dependencies: + - Trellis pipeline for 3D generation + - Pre-trained decoder model in ./pretrain/decoder + +Author: PhysX-Anything Team +=============================================================================== +""" + +# ============================================================================= +# IMPORTS +# ============================================================================= + import os -# os.environ['ATTN_BACKEND'] = 'xformers' # Can be 'flash-attn' or 'xformers', default is 'flash-attn' -os.environ['SPCONV_ALGO'] = 'native' # Can be 'native' or 'auto', default is 'auto'. - # 'auto' is faster but will do benchmarking at the beginning. - # Recommended to set to 'native' if run only once. -import imageio -from PIL import Image +# Environment configuration (must be set before importing trellis) +# os.environ['ATTN_BACKEND'] = 'xformers' # Alternative: 'flash-attn' (default) +os.environ['SPCONV_ALGO'] = 'native' # 'native' is slower but avoids benchmarking overhead + +# Trellis 3D generation pipeline from trellis.pipelines import TrellisImageTo3DPipeline from trellis.utils import render_utils, postprocessing_utils -import ipdb + +# Image processing +import imageio +from PIL import Image + +# Data processing import numpy as np import torch import trimesh -# Load a pipeline from a model folder or a Hugging Face model hub. + +# Debugging (can be removed in production) +import ipdb + + +# ============================================================================= +# CONFIGURATION +# ============================================================================= + +# Voxel grid settings +VOXEL_GRID_SIZE = 32 # Size of VLM output voxel grid +SPARSE_RESOLUTION = 64 # Size of decoder's sparse structure tensor + + +# ============================================================================= +# MAIN EXECUTION +# ============================================================================= + +# Load the Trellis Image-to-3D pipeline from pre-trained weights +print("Loading Trellis decoder pipeline...") pipeline = TrellisImageTo3DPipeline.from_pretrained("./pretrain/decoder") pipeline.cuda() +print("Pipeline loaded successfully!") -basepath='./demo' -filepath='./test_demo' -namelist=os.listdir(basepath) +# Setup paths +basepath = './demo' # Original demo images +filepath = './test_demo' # VLM output directory (contains voxel data) +namelist = os.listdir(basepath) +print(f"Found {len(namelist)} images to process") -for name in namelist: - image = Image.open(os.path.join(basepath,name)) - qwenpath=os.path.join(filepath,name[:-4]) +# ============================================================================= +# Process Each Image +# ============================================================================= - if os.path.exists(os.path.join(qwenpath,'allind.npy')): - newcoords=np.load(os.path.join(qwenpath,'allind.npy')) +for name in namelist: + print(f"\n{'='*60}") + print(f"Processing: {name}") + print('='*60) + + # Load the original image + image = Image.open(os.path.join(basepath, name)) + + # Path to VLM output for this image + qwenpath = os.path.join(filepath, name[:-4]) # Remove file extension + + # Check if voxel data exists (generated by 1_vlm_demo.py) + voxel_file = os.path.join(qwenpath, 'allind.npy') + + if os.path.exists(voxel_file): + print("Loading voxel coordinates...") + newcoords = np.load(voxel_file) + print(f" Loaded {len(newcoords)} voxel coordinates") - size=32 - resolution=64 - - newcoords=newcoords+32-(size)//2 + # ===================================================================== + # Create Sparse Structure Tensor + # ===================================================================== + # The voxel coordinates are in a 32x32x32 grid + # We need to place them in a 64x64x64 tensor, centered + size = VOXEL_GRID_SIZE # Original voxel grid size (32) + resolution = SPARSE_RESOLUTION # Target resolution (64) + + # Center the 32x32x32 voxels within the 64x64x64 grid + # Offset = 32 - (32/2) = 32 - 16 = 16 (but current formula gives: 32 - 16 = 16) + # This places voxels in the center region [16, 48) of each axis + offset = resolution // 2 - size // 2 # = 32 - 16 = 16 + newcoords = newcoords + offset + + # Create empty sparse structure tensor + # Shape: (1, batch) x (64, 64, 64) spatial dimensions ss = torch.zeros(1, resolution, resolution, resolution, dtype=torch.long) + + # Mark occupied voxels ss[:, newcoords[:, 0], newcoords[:, 1], newcoords[:, 2]] = 1 - ss=ss.cuda().float().unsqueeze(0) - - - - outputs = pipeline.run_control(ss,image,seed=1,) - + + # Convert to float tensor on GPU with proper shape for pipeline + # Final shape: (1, 1, 64, 64, 64) - batch, channel, depth, height, width + ss = ss.cuda().float().unsqueeze(0) + + print(f" Created sparse structure tensor: {ss.shape}") + print(f" Occupied voxels: {ss.sum().item():.0f}") + + # ===================================================================== + # Run 3D Generation Pipeline + # ===================================================================== + print("Running 3D generation pipeline...") + + # run_control uses the sparse structure as a conditioning signal + # This guides the generation to match the predicted voxel structure + outputs = pipeline.run_control( + ss, # Sparse structure control signal + image, # Original image for appearance/texture + seed=1, # Fixed seed for reproducibility + ) + + print(" Generation complete!") + + # ===================================================================== + # Export as GLB File + # ===================================================================== + print("Exporting GLB file...") + + # Convert outputs to GLB format + # - gaussian: 3D Gaussian splatting representation + # - mesh: Triangle mesh representation glb = postprocessing_utils.to_glb( outputs['gaussian'][0], outputs['mesh'][0], - simplify=0.5, # Ratio of triangles to remove in the simplification process - texture_size=1024, # Size of the texture used for the GLB + simplify=0.5, # Remove 50% of triangles for smaller file size + texture_size=1024, # Texture resolution ) - - - + # Save the GLB file + output_path = os.path.join(qwenpath, 'sample.glb') + glb.export(output_path) - glb.export(os.path.join(qwenpath,'sample.glb')) + print(f" Saved: {output_path}") + else: + print(f" Skipping: No voxel data found at {voxel_file}") + print(" (Run 1_vlm_demo.py first to generate voxel data)") + +print("\n" + "="*60) +print("Decoder processing complete!") +print("="*60) + diff --git a/3_split.py b/3_split.py index 1c78bbb..1513c56 100644 --- a/3_split.py +++ b/3_split.py @@ -1,43 +1,133 @@ +""" +=============================================================================== +3_split.py - Mesh Segmentation Using Geodesic Distance Propagation +=============================================================================== + +This script segments a 3D mesh into multiple parts based on voxel labels from +the VLM output. Each part is exported as a separate OBJ file. + +Pipeline Overview: + 1. Load the generated GLB mesh and voxel part labels + 2. Find nearest labeled voxel for each mesh vertex + 3. Propagate labels using geodesic (surface) distance + 4. Assign face labels based on vertex majority voting + 5. Export each labeled region as a separate mesh + +Algorithm Details: + - Geodesic Propagation: Uses Dijkstra's algorithm on the mesh edge graph + - Label Assignment: Vertices close to labeled voxels become "seed" vertices + - Face Labeling: Each face takes the majority label of its vertices + - Tie Breaking: Uses geodesic distance sum when votes are tied + +Dependencies: + - trimesh: Mesh loading and manipulation + - scipy: KD-tree for nearest neighbor queries + +Author: PhysX-Anything Team +=============================================================================== +""" + +# ============================================================================= +# IMPORTS +# ============================================================================= + import os import heapq +import logging +import argparse + import numpy as np import trimesh from scipy.spatial import cKDTree -import argparse -def build_edge_graph(mesh: trimesh.Trimesh): +# ============================================================================= +# GRAPH CONSTRUCTION +# ============================================================================= +def build_edge_graph(mesh: trimesh.Trimesh): + """ + Build an edge graph from mesh connectivity for geodesic distance computation. + + Creates adjacency lists where each vertex stores its neighbors and edge weights + (Euclidean edge lengths). This graph is used for Dijkstra's algorithm. + + Args: + mesh (trimesh.Trimesh): Input mesh + + Returns: + tuple: (neighbors, weights) where: + - neighbors: List of arrays, neighbors[v] = adjacent vertex indices + - weights: List of arrays, weights[v] = edge lengths to neighbors[v] + """ edges = mesh.edges_unique V = mesh.vertices + + # Initialize adjacency lists for each vertex neighbors = [[] for _ in range(len(V))] - weights = [[] for _ in range(len(V))] + weights = [[] for _ in range(len(V))] + + # Calculate edge lengths e_len = np.linalg.norm(V[edges[:, 0]] - V[edges[:, 1]], axis=1) + + # Build undirected graph (add both directions) for (u, v), w in zip(edges, e_len): - neighbors[u].append(v); weights[u].append(w) - neighbors[v].append(u); weights[v].append(w) + neighbors[u].append(v) + weights[u].append(w) + neighbors[v].append(u) + weights[v].append(w) + + # Convert to numpy arrays for faster access neighbors = [np.asarray(n, dtype=np.int64) for n in neighbors] - weights = [np.asarray(w, dtype=np.float64) for w in weights] + weights = [np.asarray(w, dtype=np.float64) for w in weights] + return neighbors, weights -def nearest_label_all_vertices(vertices, label_to_points): +# ============================================================================= +# NEAREST LABEL COMPUTATION +# ============================================================================= +def nearest_label_all_vertices(vertices, label_to_points): + """ + Find the nearest labeled point for each mesh vertex using KD-trees. + + For each vertex, finds which label's point cloud contains the closest point. + This is used to initialize seed labels for geodesic propagation. + + Args: + vertices (np.ndarray): Mesh vertices, shape (N, 3) + label_to_points (dict): Maps label (str) -> point coordinates (N, 3) + + Returns: + tuple: (nearest_label, dmin_per_v, trees) where: + - nearest_label: Label of nearest point for each vertex + - dmin_per_v: Distance to nearest labeled point + - trees: Dict of KD-trees for each label + """ + # Build KD-tree for each label's point cloud trees = {} labels_sorted = sorted(label_to_points.keys(), key=lambda x: int(x)) + for lab in labels_sorted: P = np.asarray(label_to_points[lab], dtype=np.float64) trees[lab] = cKDTree(P) if len(P) > 0 else None + # Initialize arrays V = vertices.shape[0] nearest_label = np.zeros(V, dtype=np.int64) dmin_per_v = np.full(V, np.inf, dtype=np.float64) + # Find nearest label for each vertex for lab in labels_sorted: tree = trees[lab] if tree is None: continue + + # Query distance to nearest point in this label's cloud d, _ = tree.query(vertices, k=1, workers=-1) + + # Update if this label is closer mask = d < dmin_per_v dmin_per_v[mask] = d[mask] nearest_label[mask] = int(lab) @@ -45,31 +135,63 @@ def nearest_label_all_vertices(vertices, label_to_points): return nearest_label, dmin_per_v, trees +# ============================================================================= +# GEODESIC LABEL PROPAGATION +# ============================================================================= + def multisource_geodesic_propagation_with_fallback( neighbors, weights, seed_mask, seed_labels, fallback_labels ): - + """ + Propagate labels from seed vertices along mesh surface using Dijkstra's algorithm. + + This is a multi-source shortest path algorithm where each seed vertex starts + with distance 0. Labels are propagated to all vertices based on geodesic distance. + + Algorithm: + 1. Initialize seeds with distance 0 and their assigned labels + 2. Run Dijkstra's algorithm, propagating labels along shortest paths + 3. Use fallback labels for any unreachable vertices + + Args: + neighbors: Adjacency list from build_edge_graph() + weights: Edge weights from build_edge_graph() + seed_mask (np.ndarray): Boolean mask indicating seed vertices + seed_labels (np.ndarray): Labels for seed vertices + fallback_labels (np.ndarray): Backup labels for unreachable vertices + + Returns: + tuple: (labels, dist) where: + - labels: Final label assignment for each vertex + - dist: Geodesic distance to nearest seed of same label + """ V = len(neighbors) labels = np.full(V, -1, dtype=np.int64) - dist = np.full(V, np.inf, dtype=np.float64) - pq = [] - + dist = np.full(V, np.inf, dtype=np.float64) + pq = [] # Priority queue: (distance, vertex) + # Initialize seed vertices for v in range(V): if seed_mask[v]: labels[v] = seed_labels[v] dist[v] = 0.0 heapq.heappush(pq, (0.0, v)) + # Handle case with no seeds if len(pq) == 0: return fallback_labels.copy(), np.zeros(V, dtype=np.float64) - # Dijkstra + # Dijkstra's algorithm - propagate labels along shortest paths while pq: d_u, u = heapq.heappop(pq) + + # Skip if we've already found a shorter path to this vertex if d_u != dist[u]: continue + lab_u = labels[u] + + # Relax edges to neighbors for nv, w in zip(neighbors[u], weights[u]): nd = d_u + w if nd < dist[nv]: @@ -77,44 +199,94 @@ def multisource_geodesic_propagation_with_fallback( labels[nv] = lab_u heapq.heappush(pq, (nd, nv)) - + # Apply fallback labels to any unlabeled vertices miss = (labels == -1) if np.any(miss): labels[miss] = fallback_labels[miss] - dist[miss] = 0.0 + dist[miss] = 0.0 return labels, dist -def face_majority_label(mesh: trimesh.Trimesh, vlabels, vdist): +# ============================================================================= +# FACE LABELING +# ============================================================================= +def face_majority_label(mesh: trimesh.Trimesh, vlabels, vdist): + """ + Assign labels to faces based on majority voting of vertex labels. + + Each face's label is determined by: + 1. If all 3 vertices have the same label -> use that label + 2. Otherwise, use the most common label among vertices + 3. If tied, use the label with minimum total geodesic distance + + Args: + mesh (trimesh.Trimesh): Input mesh + vlabels (np.ndarray): Label for each vertex + vdist (np.ndarray): Geodesic distance for each vertex + + Returns: + np.ndarray: Label for each face + """ F = mesh.faces.shape[0] flabels = np.zeros(F, dtype=np.int64) + for i in range(F): + # Get the 3 vertices of this face vs = mesh.faces[i] labs = vlabels[vs] + + # Count votes for each label vals, counts = np.unique(labs, return_counts=True) + if len(vals) == 1: + # All vertices have same label - easy case flabels[i] = vals[0] else: + # Multiple labels - use majority voting idx = np.argmax(counts) + + # Check if there's a clear winner if np.sum(counts == counts[idx]) == 1: flabels[i] = vals[idx] else: + # Tie-breaker: use label with minimum total geodesic distance best_lab, best_sum = None, np.inf for lab in vals: s = vdist[vs][labs == lab].sum() if s < best_sum: best_sum, best_lab = s, lab flabels[i] = best_lab + return flabels -def ensure_nonempty_per_label(mesh, flabels, label_to_points, min_faces=10): +# ============================================================================= +# ENSURING ALL LABELS HAVE FACES +# ============================================================================= +def ensure_nonempty_per_label(mesh, flabels, label_to_points, min_faces=10): + """ + Ensure every label has at least some faces assigned. + + If a label has no faces (due to poor voxel predictions), force-assign + faces near the label's point cloud center. + + Args: + mesh (trimesh.Trimesh): Input mesh + flabels (np.ndarray): Current face labels (modified in-place) + label_to_points (dict): Maps label -> point coordinates + min_faces (int): Minimum faces to assign to missing labels + + Returns: + np.ndarray: Updated face labels + """ labels_sorted = sorted(label_to_points.keys(), key=lambda x: int(x)) F = mesh.faces.shape[0] - adj = mesh.face_adjacency # (M,2) + + # Build face adjacency for region growing + adj = mesh.face_adjacency # (M, 2) pairs of adjacent faces face_nbrs = [[] for _ in range(F)] for a, b in adj: face_nbrs[a].append(b) @@ -122,8 +294,11 @@ def ensure_nonempty_per_label(mesh, flabels, label_to_points, min_faces=10): tri_centers = mesh.triangles_center + # Check each label for lab in labels_sorted: lab_i = int(lab) + + # Skip if this label already has faces if np.any(flabels == lab_i): continue @@ -131,11 +306,14 @@ def ensure_nonempty_per_label(mesh, flabels, label_to_points, min_faces=10): if len(P) == 0: continue + # Find face closest to the point cloud center c = P.mean(axis=0) idx0 = np.argmin(np.linalg.norm(tri_centers - c[None, :], axis=1)) + # Grow region from seed face using BFS picked = set([idx0]) frontier = [idx0] + while len(picked) < min_faces and frontier: new_frontier = [] for f in frontier: @@ -144,29 +322,51 @@ def ensure_nonempty_per_label(mesh, flabels, label_to_points, min_faces=10): picked.add(g) new_frontier.append(g) frontier = new_frontier + + # Assign label to picked faces flabels[list(picked)] = lab_i return flabels -def export_label_submeshes(mesh: trimesh.Trimesh, flabels, out_dir): +# ============================================================================= +# MESH EXPORT +# ============================================================================= +def export_label_submeshes(mesh: trimesh.Trimesh, flabels, out_dir): + """ + Export each labeled region as a separate OBJ file. + + Args: + mesh (trimesh.Trimesh): Input mesh + flabels (np.ndarray): Label for each face + out_dir (str): Output directory path + """ os.makedirs(out_dir, exist_ok=True) unique_labs = np.unique(flabels) + for lab in unique_labs: mask = (flabels == lab) if not np.any(mask): continue + + # Extract submesh for this label sub = mesh.submesh([np.nonzero(mask)[0]], append=True, repair=True) + if sub.vertices.shape[0] == 0 or sub.faces.shape[0] == 0: continue + # Create subdirectory and export os.makedirs(os.path.join(out_dir, f"{lab}"), exist_ok=True) - export_path = os.path.join(out_dir,f"{lab}", f"{lab}.obj") + export_path = os.path.join(out_dir, f"{lab}", f"{lab}.obj") sub.export(export_path) + print(f"[+] Saved: {export_path} (V={len(sub.vertices)}, F={len(sub.faces)})") +# ============================================================================= +# MAIN SEGMENTATION FUNCTION +# ============================================================================= def segment_mesh_by_wrapped_pcd_no_minus1( mesh, @@ -175,47 +375,94 @@ def segment_mesh_by_wrapped_pcd_no_minus1( seed_tau_ratio: float = 0.02, min_seed_faces: int = 20 ): - + """ + Segment a mesh into parts using voxel-based labels and geodesic propagation. + + This is the main entry point for mesh segmentation. It: + 1. Computes nearest voxel labels for each vertex + 2. Uses geodesic propagation to spread labels across the surface + 3. Assigns face labels via majority voting + 4. Exports each segment as a separate mesh + Args: + mesh: Input mesh (Trimesh or Scene) + label_to_points (dict): Maps label (str) -> voxel coordinates (N, 3) + out_dir (str): Output directory for segmented meshes + seed_tau_ratio (float): Ratio of bbox diagonal for seed threshold + min_seed_faces (int): Minimum faces per label (for empty label handling) + """ + # Handle scene objects (merge all geometries) if not isinstance(mesh, trimesh.Trimesh): mesh = trimesh.util.concatenate([g for g in mesh.geometry.values()]) + # Compute seed threshold based on mesh size V = mesh.vertices bbox_diag = np.linalg.norm(mesh.bounds[1] - mesh.bounds[0]) tau_seed = bbox_diag * seed_tau_ratio - nearest_lab, dmin, _ = nearest_label_all_vertices(mesh.vertices, label_to_points) + print(f"Mesh bounding box diagonal: {bbox_diag:.4f}") + print(f"Seed distance threshold: {tau_seed:.4f}") + # Step 1: Find nearest label for each vertex + print("Computing nearest labels for vertices...") + nearest_lab, dmin, _ = nearest_label_all_vertices(mesh.vertices, label_to_points) + # Step 2: Build edge graph and propagate labels geodesically + print("Building edge graph...") neighbors, weights = build_edge_graph(mesh) + + print("Propagating labels via geodesic distance...") seed_mask = (dmin <= tau_seed) vlabels, vdist = multisource_geodesic_propagation_with_fallback( neighbors, weights, seed_mask=seed_mask, - seed_labels=nearest_lab, - fallback_labels=nearest_lab + seed_labels=nearest_lab, + fallback_labels=nearest_lab ) + # Step 3: Assign face labels by majority voting + print("Assigning face labels...") flabels = face_majority_label(mesh, vlabels, vdist) + # Step 4: Ensure all labels have faces flabels = ensure_nonempty_per_label(mesh, flabels, label_to_points, min_faces=min_seed_faces) + # Step 5: Export submeshes + print("Exporting segmented meshes...") export_label_submeshes(mesh, flabels, out_dir) -import logging +# ============================================================================= +# LOGGING SETUP +# ============================================================================= + def get_logger(filename, verbosity=1, name=None): + """ + Create a logger that writes to both file and console. + + Args: + filename (str): Log file path + verbosity (int): 0=DEBUG, 1=INFO, 2=WARNING + name (str): Logger name (optional) + + Returns: + logging.Logger: Configured logger instance + """ level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING} + formatter = logging.Formatter( "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s" ) + logger = logging.getLogger(name) logger.setLevel(level_dict[verbosity]) + # File handler fh = logging.FileHandler(filename, "w") fh.setFormatter(formatter) logger.addHandler(fh) + # Console handler sh = logging.StreamHandler() sh.setFormatter(formatter) logger.addHandler(sh) @@ -223,43 +470,88 @@ def get_logger(filename, verbosity=1, name=None): return logger -parser = argparse.ArgumentParser() -parser.add_argument("--index", type=int, default=0) -parser.add_argument("--range", type=int, default=2000) -args = parser.parse_args() -basepath='./test_demo' -namelist=os.listdir(basepath) -logger = get_logger(os.path.join('exp_split'+str(args.index)+'.log'),verbosity=1) -logger.info('start') - -for name in namelist: - tmpdir=os.path.join(basepath,name) - if os.path.exists(os.path.join(tmpdir,'sample.glb')): - os.makedirs(os.path.join(tmpdir), exist_ok=True) - os.makedirs(os.path.join(tmpdir,'objs'), exist_ok=True) - mesh = trimesh.load(os.path.join(tmpdir,'sample.glb'), force='mesh') - R = trimesh.transformations.rotation_matrix(np.deg2rad(90), [1, 0, 0]) - mesh.apply_transform(R) - - voxel_define=32 - loaded={} - - index=0 - while os.path.exists(os.path.join(tmpdir,'ind_'+str(index)+'.npy')): +# ============================================================================= +# MAIN EXECUTION +# ============================================================================= + +if __name__ == "__main__": + # ------------------------------------------------------------------------- + # Parse Arguments + # ------------------------------------------------------------------------- + parser = argparse.ArgumentParser( + description="Segment 3D meshes into parts based on VLM-generated voxel labels" + ) + parser.add_argument( + "--index", type=int, default=0, + help="Process index (for parallel processing)" + ) + parser.add_argument( + "--range", type=int, default=2000, + help="Processing range (unused, kept for compatibility)" + ) + args = parser.parse_args() + + # ------------------------------------------------------------------------- + # Setup + # ------------------------------------------------------------------------- + basepath = './test_demo' + namelist = os.listdir(basepath) + + # Setup logging + logger = get_logger( + os.path.join(f'exp_split{args.index}.log'), + verbosity=1 + ) + logger.info('Starting mesh segmentation...') + logger.info(f'Found {len(namelist)} items to process') + + # ------------------------------------------------------------------------- + # Process Each Item + # ------------------------------------------------------------------------- + VOXEL_GRID_SIZE = 32 # Size of the voxel grid from VLM + + for name in namelist: + tmpdir = os.path.join(basepath, name) + glb_path = os.path.join(tmpdir, 'sample.glb') + + if os.path.exists(glb_path): + logger.info(f'Processing: {name}') - vertices=np.load(os.path.join(tmpdir,'ind_'+str(index)+'.npy'))/voxel_define-0.5 - loaded[str(index)]=vertices - - index+=1 - - segment_mesh_by_wrapped_pcd_no_minus1( - mesh=mesh, - label_to_points=loaded, - out_dir=os.path.join(tmpdir,'objs'), - seed_tau_ratio=0.02, - min_seed_faces=20 - ) - logger.info('complete: '+name) - else: - logger.info('skip: '+name) + # Create output directories + os.makedirs(tmpdir, exist_ok=True) + os.makedirs(os.path.join(tmpdir, 'objs'), exist_ok=True) + + # Load mesh and apply rotation (GLB uses different coordinate system) + mesh = trimesh.load(glb_path, force='mesh') + R = trimesh.transformations.rotation_matrix(np.deg2rad(90), [1, 0, 0]) + mesh.apply_transform(R) + + # Load voxel labels for each part + loaded = {} + index = 0 + + while os.path.exists(os.path.join(tmpdir, f'ind_{index}.npy')): + # Load voxel coordinates and normalize to [-0.5, 0.5] range + vertices = np.load(os.path.join(tmpdir, f'ind_{index}.npy')) + vertices = vertices / VOXEL_GRID_SIZE - 0.5 + loaded[str(index)] = vertices + index += 1 + + logger.info(f' Loaded {len(loaded)} part labels') + + # Run segmentation + segment_mesh_by_wrapped_pcd_no_minus1( + mesh=mesh, + label_to_points=loaded, + out_dir=os.path.join(tmpdir, 'objs'), + seed_tau_ratio=0.02, # 2% of bbox diagonal + min_seed_faces=20 # Minimum faces per segment + ) + + logger.info(f'Complete: {name}') + else: + logger.info(f'Skip (no GLB): {name}') + + logger.info('All processing complete!') + diff --git a/4_simready_gen.py b/4_simready_gen.py index 1745399..3a62a57 100644 --- a/4_simready_gen.py +++ b/4_simready_gen.py @@ -1,18 +1,78 @@ +""" +=============================================================================== +4_simready_gen.py - Simulation-Ready Asset Generation (URDF & MJCF) +=============================================================================== + +This script generates physics simulation-ready assets from VLM outputs: + - URDF (Unified Robot Description Format) for ROS/PyBullet + - MJCF (MuJoCo-format XML) for MuJoCo physics simulation + +Pipeline Overview: + 1. Parse VLM output to extract: + - Object name, category, dimensions + - Part information (materials, physical properties) + - Kinematic groups (fixed, sliding, revolute joints) + 2. Generate URDF with proper joint hierarchy + 3. Generate MJCF with physics parameters, textures, and materials + +Joint Types Supported: + - A (Free): Floating joint (6-DOF) + - B (Slide): Prismatic joint (1-DOF translation) + - C (Revolute): Hinge joint (1-DOF rotation) + - D (Ball): Ball-and-socket joint (3-DOF rotation) + - CB (Combined): Revolute + Slide joint + +Key Concepts: + - Voxel Grid: 32x32x32 grid for position calculation + - Group Info: Defines kinematic relationships between parts + - Joint Parameters: Direction vectors, positions, and motion ranges + +Dependencies: + - trimesh: Mesh loading for volume calculation + - xml.etree: XML generation for URDF/MJCF + +Author: PhysX-Anything Team +=============================================================================== +""" + +# ============================================================================= +# IMPORTS +# ============================================================================= + import os -import numpy as np -import ipdb import re import json -import xml.etree.ElementTree as ET import shutil +import logging import argparse -from scipy.spatial import cKDTree as KDTree -import trimesh +import xml.etree.ElementTree as ET from typing import List, Dict, Optional from collections import defaultdict, deque -import logging +import numpy as np +import trimesh +from scipy.spatial import cKDTree as KDTree + +# Debugging (can be removed in production) +import ipdb + + +# ============================================================================= +# LOGGING SETUP +# ============================================================================= + def get_logger(filename, verbosity=1, name=None): + """ + Create a logger that writes to both file and console. + + Args: + filename (str): Log file path + verbosity (int): 0=DEBUG, 1=INFO, 2=WARNING + name (str): Logger name (optional) + + Returns: + logging.Logger: Configured logger instance + """ level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING} formatter = logging.Formatter( "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s" @@ -29,9 +89,30 @@ def get_logger(filename, verbosity=1, name=None): logger.addHandler(sh) return logger -def _pairwise_nn(a: np.ndarray, b: np.ndarray): +# ============================================================================= +# ADJACENT REGION DETECTION +# ============================================================================= + +def _pairwise_nn(a: np.ndarray, b: np.ndarray): + """ + Compute nearest neighbor correspondences between two point clouds. + + Uses KD-trees for efficient nearest neighbor queries in both directions + (A->B and B->A) to find mutual nearest neighbors. + + Args: + a (np.ndarray): First point cloud, shape (N, 3) + b (np.ndarray): Second point cloud, shape (M, 3) + + Returns: + tuple: (idx_ab, dist_ab, idx_ba, dist_ba) where: + - idx_ab: For each point in A, index of nearest point in B + - dist_ab: Distance to nearest point in B + - idx_ba: For each point in B, index of nearest point in A + - dist_ba: Distance to nearest point in A + """ ta = KDTree(a) tb = KDTree(b) dist_ab, idx_ab = tb.query(a, k=1, workers=-1) @@ -39,60 +120,114 @@ def _pairwise_nn(a: np.ndarray, b: np.ndarray): return idx_ab, dist_ab, idx_ba, dist_ba -def _robust_threshold(d, method="mad", q=0.2, k=2.5): +def _robust_threshold(d, method="mad", q=0.2, k=2.5): + """ + Compute a robust threshold for outlier detection. + + Supports two methods: + - "quantile": Simple quantile threshold + - "mad": Median Absolute Deviation (more robust to outliers) + + Args: + d (np.ndarray): Distance values + method (str): "quantile" or "mad" + q (float): Quantile for quantile method (default 0.2) + k (float): MAD multiplier (default 2.5) + + Returns: + float: Computed threshold value + """ d = np.asarray(d) if method == "quantile": return np.quantile(d, q) - # MAD + + # MAD (Median Absolute Deviation) method med = np.median(d) mad = np.median(np.abs(d - med)) + 1e-12 - return med + k * 1.4826 * mad + return med + k * 1.4826 * mad # 1.4826 is scale factor for normal distribution + def find_adjacent_region( a: np.ndarray, b: np.ndarray, thr: float | None = None, - thr_mode: str = "mad", + thr_mode: str = "mad", q: float = 0.2, expand_radius: float | None = None, ): - + """ + Find the adjacent/contact region between two point clouds. + + This function identifies points that are close to each other in two point clouds, + useful for determining where two parts of an object meet (e.g., joint locations). + + Algorithm: + 1. Find mutual nearest neighbors between A and B + 2. Filter by distance threshold (auto-computed or provided) + 3. Optionally expand the region by radius + 4. Fit a plane to the midpoints (for joint axis estimation) + + Args: + a (np.ndarray): First point cloud, shape (N, 3) + b (np.ndarray): Second point cloud, shape (M, 3) + thr (float): Distance threshold (None = auto-compute) + thr_mode (str): Threshold computation method ("mad" or "quantile") + q (float): Quantile for threshold computation + expand_radius (float): Expand region by this radius + + Returns: + dict: Contains: + - a_idx: Indices of adjacent points in A + - b_idx: Indices of adjacent points in B + - pairs: (N, 2) array of corresponding point pairs + - midpoints: (N, 3) array of midpoints between pairs + - plane: (center, normal) tuple or None + - thr: Threshold used + """ assert a.ndim == 2 and a.shape[1] == 3 assert b.ndim == 2 and b.shape[1] == 3 idx_ab, dist_ab, idx_ba, dist_ba = _pairwise_nn(a, b) - + # Find mutual nearest neighbors (A->B->A should return to same point) mutual = np.arange(len(a)) == idx_ba[idx_ab] d_mutual = dist_ab[mutual] i_a = np.nonzero(mutual)[0] j_b = idx_ab[mutual] + # Handle empty case if len(i_a) == 0: - return dict(a_idx=np.array([], dtype=int), - b_idx=np.array([], dtype=int), - pairs=np.empty((0,2), dtype=int), - midpoints=np.empty((0,3), dtype=a.dtype), - plane=None, - thr=0.0) + return dict( + a_idx=np.array([], dtype=int), + b_idx=np.array([], dtype=int), + pairs=np.empty((0, 2), dtype=int), + midpoints=np.empty((0, 3), dtype=a.dtype), + plane=None, + thr=0.0 + ) + # Compute threshold used_thr = _robust_threshold(d_mutual, thr_mode, q=q) if thr is None else thr + # Filter by threshold keep = d_mutual <= used_thr i_a = i_a[keep] j_b = j_b[keep] d_kept = d_mutual[keep] + # Fallback to looser threshold if nothing passes if len(i_a) == 0 and len(d_mutual) > 0 and thr is None: used_thr = _robust_threshold(d_mutual, "quantile", q=max(0.4, q)) keep = d_mutual <= used_thr i_a = np.nonzero(mutual)[0][keep] j_b = idx_ab[mutual][keep] - pairs = np.stack([i_a, j_b], axis=1) if len(i_a) else np.empty((0,2), dtype=int) - midpoints = (a[i_a] + b[j_b]) * 0.5 if len(i_a) else np.empty((0,3), dtype=a.dtype) + # Build output arrays + pairs = np.stack([i_a, j_b], axis=1) if len(i_a) else np.empty((0, 2), dtype=int) + midpoints = (a[i_a] + b[j_b]) * 0.5 if len(i_a) else np.empty((0, 3), dtype=a.dtype) + # Helper to expand region within a point cloud def _expand_within_cloud(points, seeds, radius): if radius is None or len(seeds) == 0: return np.unique(seeds) @@ -104,20 +239,20 @@ def _expand_within_cloud(points, seeds, radius): idxs.update(hits) return np.fromiter(idxs, dtype=int) - + # Expand and deduplicate indices a_idx = np.unique(i_a) b_idx = np.unique(j_b) a_idx = _expand_within_cloud(a, a_idx, expand_radius) b_idx = _expand_within_cloud(b, b_idx, expand_radius) + # Fit plane to midpoints (for joint axis estimation) plane = None if len(midpoints) >= 3: c = midpoints.mean(axis=0) X = midpoints - c - # SVD + # SVD to find normal (smallest singular value direction) _, _, vh = np.linalg.svd(X, full_matrices=False) - n = vh[-1] - + n = vh[-1] # Normal is last row of Vh n = n / (np.linalg.norm(n) + 1e-12) plane = (c, n) @@ -131,223 +266,426 @@ def _expand_within_cloud(points, seeds, radius): ) +# ============================================================================= +# VOXEL GRID OPERATIONS +# ============================================================================= +# Grid configuration +GRID = 32 # Voxel grid resolution +NEI6 = np.array([ + [1, 0, 0], [-1, 0, 0], + [0, 1, 0], [0, -1, 0], + [0, 0, 1], [0, 0, -1] +], dtype=np.int8) # 6-connectivity neighborhood offsets -###################################### - - -#refine_voxel -GRID = 32 -NEI6 = np.array([[1,0,0],[-1,0,0],[0,1,0],[0,-1,0],[0,0,1],[0,0,-1]], dtype=np.int8) def rasterize(points, grid=GRID): - occ = np.zeros((grid,grid,grid), dtype=bool) + """ + Convert point coordinates to a binary occupancy grid. + + Args: + points (np.ndarray): Point coordinates, shape (N, 3) + grid (int): Grid resolution + + Returns: + np.ndarray: Binary occupancy grid, shape (grid, grid, grid) + """ + occ = np.zeros((grid, grid, grid), dtype=bool) pts = np.asarray(points, dtype=np.int16) - mask = ((pts>=0)&(pts= 0) & (pts < grid)).all(1) + x, y, z = pts[mask].T + occ[x, y, z] = True + return occ + def boundary_mask(occ): + """ + Find boundary voxels (occupied voxels adjacent to empty space). + + Args: + occ (np.ndarray): Binary occupancy grid + + Returns: + np.ndarray: Binary mask of boundary voxels + """ bnd = np.zeros_like(occ) - xs,ys,zs = np.where(occ) - for dx,dy,dz in NEI6: - x2 = np.clip(xs+dx, 0, occ.shape[0]-1) - y2 = np.clip(ys+dy, 0, occ.shape[1]-1) - z2 = np.clip(zs+dz, 0, occ.shape[2]-1) - bnd[xs,ys,zs] |= ~occ[x2,y2,z2] + xs, ys, zs = np.where(occ) + + # Check 6-connectivity neighbors + for dx, dy, dz in NEI6: + x2 = np.clip(xs + dx, 0, occ.shape[0] - 1) + y2 = np.clip(ys + dy, 0, occ.shape[1] - 1) + z2 = np.clip(zs + dz, 0, occ.shape[2] - 1) + # Mark as boundary if neighbor is empty + bnd[xs, ys, zs] |= ~occ[x2, y2, z2] + return bnd + def idx_to_xyz(idx): + """ + Convert a boolean mask to XYZ coordinates. + + Args: + idx (np.ndarray): Boolean mask + + Returns: + np.ndarray: Coordinates of True values, shape (N, 3) + """ return np.stack(np.where(idx), axis=1).astype(np.int16) + def most_adjacent_shell_6n(A_xyz, B_xyz, grid=GRID): + """ + Find the contact region between two voxel regions using wave propagation. + + This function determines where two parts of an object meet by: + 1. Finding boundary voxels of each region + 2. Checking for direct 6-connectivity contact + 3. If no contact, propagating waves until they meet + + This is useful for determining joint positions between parts. + + Args: + A_xyz (np.ndarray): Voxel coordinates of region A, shape (N, 3) + B_xyz (np.ndarray): Voxel coordinates of region B, shape (M, 3) + grid (int): Grid resolution + + Returns: + dict: Contains: + - metric: "6-neighbor steps" + - min_grid_distance: Number of grid steps between regions + - pairs: Array of touching voxel pairs + - midpoints: Midpoints between touching pairs + """ + # Rasterize point clouds to occupancy grids A = rasterize(A_xyz, grid) B = rasterize(B_xyz, grid) + + # Find boundary voxels (surface of each region) A_front = boundary_mask(A) & A B_front = boundary_mask(B) & B + # Check for direct contact (6-connectivity) touch_pairs = [] if A_front.any() and B_front.any(): - Ax,Ay,Az = np.where(A_front) - Aset = set(zip(Ax,Ay,Az)) + Ax, Ay, Az = np.where(A_front) + Aset = set(zip(Ax, Ay, Az)) B_occ = B - for dx,dy,dz in NEI6: - nb = (np.clip(Ax+dx,0,grid-1), - np.clip(Ay+dy,0,grid-1), - np.clip(Az+dz,0,grid-1)) + + # Check each neighbor direction + for dx, dy, dz in NEI6: + nb = ( + np.clip(Ax + dx, 0, grid - 1), + np.clip(Ay + dy, 0, grid - 1), + np.clip(Az + dz, 0, grid - 1) + ) hit = B_occ[nb] if hit.any(): - for (x,y,z),(x2,y2,z2),h in zip(zip(Ax,Ay,Az), - zip(*nb), hit): + for (x, y, z), (x2, y2, z2), h in zip( + zip(Ax, Ay, Az), zip(*nb), hit + ): if h: - touch_pairs.append(((x,y,z),(x2,y2,z2))) + touch_pairs.append(((x, y, z), (x2, y2, z2))) + + # If direct contact found, return immediately if touch_pairs: - mid = np.array([(np.array(a)+np.array(b))/2.0 for a,b in touch_pairs], dtype=np.float32) + mid = np.array([ + (np.array(a) + np.array(b)) / 2.0 + for a, b in touch_pairs + ], dtype=np.float32) return { "metric": "6-neighbor steps", "min_grid_distance": 1, "pairs": np.array(touch_pairs, dtype=np.int16), - "midpoints": mid, # (M,3), + "midpoints": mid, } + # No direct contact - use wave propagation to find meeting point A_wave = A_front.copy() B_wave = B_front.copy() visitedA = A_front.copy() visitedB = B_front.copy() - dist = 1 + dist = 1 # Distance counter while True: def dilate_once(wave, solid): - xs,ys,zs = np.where(wave) + """Expand wave by one voxel in all 6 directions.""" + xs, ys, zs = np.where(wave) nxt = np.zeros_like(wave) - for dx,dy,dz in NEI6: - x2 = np.clip(xs+dx, 0, grid-1) - y2 = np.clip(ys+dy, 0, grid-1) - z2 = np.clip(zs+dz, 0, grid-1) - nxt[x2,y2,z2] = True - nxt &= ~solid + for dx, dy, dz in NEI6: + x2 = np.clip(xs + dx, 0, grid - 1) + y2 = np.clip(ys + dy, 0, grid - 1) + z2 = np.clip(zs + dz, 0, grid - 1) + nxt[x2, y2, z2] = True + nxt &= ~solid # Don't expand into solid region return nxt + # Expand both waves A_next = dilate_once(A_wave, A) B_next = dilate_once(B_wave, B) + # Remove already-visited voxels A_next &= ~visitedA B_next &= ~visitedB visitedA |= A_next visitedB |= B_next - + # Check if waves meet meet = A_next & visitedB if meet.any(): meet_xyz = idx_to_xyz(meet) - - B_prev = B_wave + + # Find contact pairs + B_prev = B_wave pairs = [] - for x,y,z in meet_xyz: - for dx,dy,dz in NEI6: - x2 = np.clip(x+dx,0,grid-1); y2 = np.clip(y+dy,0,grid-1); z2 = np.clip(z+dz,0,grid-1) - if B_prev[x2,y2,z2]: - pairs.append(((x,y,z),(x2,y2,z2))) + for x, y, z in meet_xyz: + for dx, dy, dz in NEI6: + x2 = np.clip(x + dx, 0, grid - 1) + y2 = np.clip(y + dy, 0, grid - 1) + z2 = np.clip(z + dz, 0, grid - 1) + if B_prev[x2, y2, z2]: + pairs.append(((x, y, z), (x2, y2, z2))) + pairs = np.unique(np.array(pairs, dtype=np.int16), axis=0) - mid = (pairs[:,0,:].astype(np.float32)+pairs[:,1,:].astype(np.float32))/2.0 + mid = (pairs[:, 0, :].astype(np.float32) + pairs[:, 1, :].astype(np.float32)) / 2.0 + return { "metric": "6-neighbor steps", - "min_grid_distance": dist+1, - "pairs": pairs, # (M,2,3) - "midpoints": mid, # (M,3) + "min_grid_distance": dist + 1, + "pairs": pairs, + "midpoints": mid, } + # Check if no more expansion possible (disconnected regions) if not (A_next.any() or B_next.any()): - return {"metric":"6-neighbor steps","min_grid_distance":None,"pairs":np.zeros((0,2,3),np.int16),"midpoints":np.zeros((0,3),np.float32)} + return { + "metric": "6-neighbor steps", + "min_grid_distance": None, + "pairs": np.zeros((0, 2, 3), np.int16), + "midpoints": np.zeros((0, 3), np.float32) + } + A_wave, B_wave = A_next, B_next dist += 1 -def bbox_corners_and_edge_midpoints(pts: np.ndarray): +# ============================================================================= +# BOUNDING BOX UTILITIES +# ============================================================================= +def bbox_corners_and_edge_midpoints(pts: np.ndarray): + """ + Compute bounding box corners, edge midpoints, and center for a point cloud. + + Useful for determining potential joint attachment points. + + Args: + pts (np.ndarray): Point cloud, shape (N, 3) + + Returns: + tuple: (corners, edge_mids, center) where: + - corners: 8 corner points of bounding box + - edge_mids: 12 edge midpoints + - center: Center point of bounding box + """ mins = pts.min(axis=0) maxs = pts.max(axis=0) x0, y0, z0 = mins x1, y1, z1 = maxs - corners = np.array([[x, y, z] - for x in [x0, x1] - for y in [y0, y1] - for z in [z0, z1]], dtype=float) + # Generate 8 corners + corners = np.array([ + [x, y, z] + for x in [x0, x1] + for y in [y0, y1] + for z in [z0, z1] + ], dtype=float) corners = np.unique(corners, axis=0) - + # Generate 12 edge midpoints xm, ym, zm = (x0 + x1) / 2, (y0 + y1) / 2, (z0 + z1) / 2 edge_mids = [] + + # X-parallel edges (4 edges) for y in [y0, y1]: for z in [z0, z1]: edge_mids.append([xm, y, z]) + + # Y-parallel edges (4 edges) for x in [x0, x1]: for z in [z0, z1]: edge_mids.append([x, ym, z]) + + # Z-parallel edges (4 edges) for x in [x0, x1]: for y in [y0, y1]: edge_mids.append([x, y, zm]) + edge_mids = np.unique(np.array(edge_mids, dtype=float), axis=0) center = np.array([xm, ym, zm], dtype=float) + return corners, edge_mids, center -def generate_allcandidate(ind_a_index,ind_b_index,datapath): - ind_a=[] - ind_b=[] - for ind in ind_a_index: - ind_a.append(np.load(os.path.join(datapath,'ind_'+str(ind)+'.npy'))) +# ============================================================================= +# JOINT POSITION CANDIDATE GENERATION +# ============================================================================= +def generate_allcandidate(ind_a_index, ind_b_index, datapath): + """ + Generate candidate joint positions between two groups of parts. + + Finds the contact region between two groups and returns the centers + of each group's contact area as potential joint positions. + + Args: + ind_a_index (list): Part indices for group A + ind_b_index (list): Part indices for group B + datapath (str): Path containing ind_{i}.npy files + + Returns: + np.ndarray: Candidate positions, shape (2, 3), normalized to [-0.5, 0.5] + """ + # Load voxel coordinates for both groups + ind_a = [] + ind_b = [] + + for ind in ind_a_index: + ind_a.append(np.load(os.path.join(datapath, f'ind_{ind}.npy'))) + for ind in ind_b_index: - ind_b.append(np.load(os.path.join(datapath,'ind_'+str(ind)+'.npy'))) + ind_b.append(np.load(os.path.join(datapath, f'ind_{ind}.npy'))) - ind_a=np.concatenate(ind_a) - ind_b=np.concatenate(ind_b) + ind_a = np.concatenate(ind_a) + ind_b = np.concatenate(ind_b) - results=most_adjacent_shell_6n(ind_a,ind_b) - ind_a_nei=results['pairs'][:,0] - ind_b_nei=results['pairs'][:,1] + # Find adjacent region using wave propagation + results = most_adjacent_shell_6n(ind_a, ind_b) + ind_a_nei = results['pairs'][:, 0] + ind_b_nei = results['pairs'][:, 1] - - - - corners, edge_mids, center=bbox_corners_and_edge_midpoints(ind_a_nei) - bbox_corners_a=np.concatenate([center[None]]) + # Get center of each contact region + corners, edge_mids, center = bbox_corners_and_edge_midpoints(ind_a_nei) + bbox_corners_a = np.concatenate([center[None]]) # Just use center + corners, edge_mids, center = bbox_corners_and_edge_midpoints(ind_b_nei) + bbox_corners_b = np.concatenate([center[None]]) # Just use center - corners, edge_mids, center=bbox_corners_and_edge_midpoints(ind_b_nei) - bbox_corners_b=np.concatenate([center[None]]) - - allcandidate=np.concatenate([bbox_corners_a,bbox_corners_b]) - allcandidate=allcandidate/32-0.5 + # Combine and normalize to [-0.5, 0.5] range + allcandidate = np.concatenate([bbox_corners_a, bbox_corners_b]) + allcandidate = allcandidate / 32 - 0.5 + return allcandidate -def generate_allcandidate_center(ind_a_index,ind_b_index,datapath): - ind_a=[] - ind_b=[] +def generate_allcandidate_center(ind_a_index, ind_b_index, datapath): + """ + Generate a single joint position candidate at the midpoint of contact regions. + + Similar to generate_allcandidate but returns the average of all candidate points. + + Args: + ind_a_index (list): Part indices for group A + ind_b_index (list): Part indices for group B + datapath (str): Path containing ind_{i}.npy files + + Returns: + np.ndarray: Single candidate position, shape (3,), normalized to [-0.5, 0.5] + """ + # Load voxel coordinates for both groups + ind_a = [] + ind_b = [] + for ind in ind_a_index: - ind_a.append(np.load(os.path.join(datapath,'ind_'+str(ind)+'.npy'))) - + ind_a.append(np.load(os.path.join(datapath, f'ind_{ind}.npy'))) + for ind in ind_b_index: - ind_b.append(np.load(os.path.join(datapath,'ind_'+str(ind)+'.npy'))) + ind_b.append(np.load(os.path.join(datapath, f'ind_{ind}.npy'))) - ind_a=np.concatenate(ind_a) - ind_b=np.concatenate(ind_b) + ind_a = np.concatenate(ind_a) + ind_b = np.concatenate(ind_b) - results=most_adjacent_shell_6n(ind_a,ind_b) - ind_a_nei=results['pairs'][:,0] - ind_b_nei=results['pairs'][:,1] + # Find adjacent region + results = most_adjacent_shell_6n(ind_a, ind_b) + ind_a_nei = results['pairs'][:, 0] + ind_b_nei = results['pairs'][:, 1] - corners, edge_mids, center=bbox_corners_and_edge_midpoints(ind_a_nei) - bbox_corners_a=np.concatenate([corners, edge_mids, center[None]]).mean(0) + # Get mean of all candidate points from each side + corners, edge_mids, center = bbox_corners_and_edge_midpoints(ind_a_nei) + bbox_corners_a = np.concatenate([corners, edge_mids, center[None]]).mean(0) + corners, edge_mids, center = bbox_corners_and_edge_midpoints(ind_b_nei) + bbox_corners_b = np.concatenate([corners, edge_mids, center[None]]).mean(0) - corners, edge_mids, center=bbox_corners_and_edge_midpoints(ind_b_nei) - bbox_corners_b=np.concatenate([corners, edge_mids, center[None]]).mean(0) - - allcandidate=(bbox_corners_a+bbox_corners_b)/2 - allcandidate=allcandidate/32-0.5 + # Return midpoint between the two centers + allcandidate = (bbox_corners_a + bbox_corners_b) / 2 + allcandidate = allcandidate / 32 - 0.5 + return allcandidate -############################################# + + +# ============================================================================= +# URDF XML GENERATION UTILITIES +# ============================================================================= + def make_origin_element(xyz, rpy): + """ + Create an XML element with position and rotation. + + Args: + xyz (list): Position as list of strings ["x", "y", "z"] + rpy (list): Rotation (roll, pitch, yaw) as list of strings + + Returns: + ET.Element: XML origin element + """ origin = ET.Element('origin') origin.set('xyz', ' '.join(xyz)) origin.set('rpy', ' '.join(rpy)) return origin -def add_inertial(link_element,xyz="0 0 0"): + +def add_inertial(link_element, xyz="0 0 0"): + """ + Add default inertial properties to a URDF link. + + Creates element with default mass and inertia values. + + Args: + link_element (ET.Element): Parent link element + xyz (str): Center of mass position + """ inertial = ET.SubElement(link_element, 'inertial') ET.SubElement(inertial, 'origin', xyz=xyz, rpy="0 0 0") ET.SubElement(inertial, 'mass', value="1.0") - ET.SubElement(inertial, 'inertia', ixx="1.0", ixy="0.0", ixz="0.0", - iyy="1.0", iyz="0.0", izz="1.0") + ET.SubElement( + inertial, 'inertia', + ixx="1.0", ixy="0.0", ixz="0.0", + iyy="1.0", iyz="0.0", izz="1.0" + ) + def add_fixed_joint(robot, name, parent, child, xyz="0 0 0", rpy="0 0 0"): + """ + Add a fixed joint between two links in a URDF. + + Args: + robot (ET.Element): Root robot element + name (str): Joint name + parent (str): Parent link name + child (str): Child link name + xyz (str): Joint position offset + rpy (str): Joint rotation offset + + Returns: + ET.Element: The created joint element + """ joint = ET.SubElement(robot, "joint", name=name, type="fixed") ET.SubElement(joint, "parent", link=parent) ET.SubElement(joint, "child", link=child) @@ -355,7 +693,21 @@ def add_fixed_joint(robot, name, parent, child, xyz="0 0 0", rpy="0 0 0"): return joint +# ============================================================================= +# TEXT PARSING UTILITIES +# ============================================================================= + def _to_nums(lst, expect_len): + """ + Convert a list of strings to floats with padding/truncation. + + Args: + lst (list): List of string values + expect_len (int): Expected output length + + Returns: + list: List of floats with exactly expect_len elements + """ out = [] for s in lst: s = s.strip() @@ -370,48 +722,105 @@ def _to_nums(lst, expect_len): else: v = 0.0 out.append(v) + + # Pad or truncate to expected length if len(out) < expect_len: out += [0.0] * (expect_len - len(out)) elif len(out) > expect_len: out = out[:expect_len] + return out + def clean_npfloat64(values): + """ + Clean numpy float64 string representations from VLM output. + + The VLM sometimes outputs values like "np.float64(0.5)" which need + to be extracted as plain numbers. + + Args: + values (list): List of strings possibly containing np.float64() + + Returns: + list: Cleaned string values + """ cleaned = [] for s in values: s = s.strip() - if s.startswith('np.float64('): + if s.startswith('np.float64('): + # Extract number from np.float64(X) num_str = re.sub(r'.*?\((.*?)\)', r'\1', s) - cleaned.append((num_str)) + cleaned.append(num_str) else: - cleaned.append((s)) + cleaned.append(s) return cleaned -def _extract_bracket_list(block, key, expect_len): +def _extract_bracket_list(block, key, expect_len): + """ + Extract a bracketed list of numbers from text. + + Searches for patterns like "key: [1, 2, 3]" and extracts the values. + + Args: + block (str): Text block to search + key (str): Key name to look for + expect_len (int): Expected number of values + + Returns: + list: Extracted float values, padded to expect_len + """ pattern = rf'{re.escape(key)}[^:\[]*:\s*\[([^\]]*)\]' m = re.search(pattern, block, flags=re.IGNORECASE) + if not m: return [0.0] * expect_len + raw = m.group(1) items = [x for x in raw.split(',')] - items=clean_npfloat64(items) + items = clean_npfloat64(items) + return _to_nums(items, expect_len) -#mujuco +# ============================================================================= +# MJCF XML TREE MANIPULATION +# ============================================================================= def find_body_by_name(root: ET.Element, name: str) -> ET.Element: + """ + Find a body element by name in an MJCF tree. + + Args: + root (ET.Element): Root element to search + name (str): Body name to find + + Returns: + ET.Element: Found body element or None + """ for elem in root.iter("body"): if elem.get("name") == name: return elem return None + def move_element(child: ET.Element, new_parent: ET.Element): + """ + Move an XML element from its current parent to a new parent. + + Note: This is a helper function for XML tree manipulation. + + Args: + child (ET.Element): Element to move + new_parent (ET.Element): New parent element + """ old_parent = child.getparent() if hasattr(child, "getparent") else None + if old_parent is None: for elem in new_parent.iter(): pass + def _find_parent(root, node): for e in root.iter(): for c in list(e): @@ -428,42 +837,65 @@ def _find_parent(root, node): parent.remove(child) new_parent.append(child) -def reparent_by_group_info(mjcf_root: ET.Element, group_info: dict, - base_body_name: str = "base", - group_body_prefix: str = "grouppart_"): +def reparent_by_group_info( + mjcf_root: ET.Element, + group_info: dict, + base_body_name: str = "base", + group_body_prefix: str = "grouppart_" +): + """ + Reparent body elements in MJCF tree based on kinematic group hierarchy. + + This function restructures the MJCF XML tree so that child groups are + properly nested under their parent groups, creating the correct kinematic chain. + + Args: + mjcf_root (ET.Element): Root of MJCF document + group_info (dict): Group hierarchy information from VLM output + Format: {group_id: [members, parent_group, params, type]} + base_body_name (str): Name of the base body + group_body_prefix (str): Prefix for group body names + """ + # Build parent-child relationships parent_of = {} for gkey, gval in group_info.items(): if str(gkey) == "0": - continue - + continue # Skip base group + try: parent_str = str(gval[1]) except Exception as e: - raise ValueError(f"group_info['{gkey}'] lack parent group: {gval}") from e + raise ValueError(f"group_info['{gkey}'] lacks parent group: {gval}") from e parent_of[str(gkey)] = parent_str - + # Find base body base_body = find_body_by_name(mjcf_root, base_body_name) if base_body is None: - raise ValueError(f"cannot find base body: name='{base_body_name}'") + raise ValueError(f"Cannot find base body: name='{base_body_name}'") def body_name_for_group(gid: str) -> str: + """Get body name for a group ID.""" if gid == "0": return base_body_name return f"{group_body_prefix}{gid}" + # Build dependency graph for topological sort children_of = defaultdict(list) indeg = defaultdict(int) - nodes = set(["0"]) + nodes = set(["0"]) # Include base group + for c, p in parent_of.items(): - nodes.add(c); nodes.add(p) + nodes.add(c) + nodes.add(p) children_of[p].append(c) indeg[c] += 1 indeg.setdefault(p, 0) + # Topological sort to process parents before children q = deque([n for n in nodes if indeg[n] == 0]) topo = [] + while q: u = q.popleft() topo.append(u) @@ -473,12 +905,13 @@ def body_name_for_group(gid: str) -> str: q.append(v) if len(topo) != len(nodes): - raise ValueError("Detect loop in group_info") + raise ValueError("Detected cycle in group_info hierarchy") - + # Reparent bodies in topological order for gid in topo: if gid == "0": continue + child_name = body_name_for_group(gid) parent_name = body_name_for_group(parent_of[gid]) @@ -486,20 +919,26 @@ def body_name_for_group(gid: str) -> str: parent_body = find_body_by_name(mjcf_root, parent_name) if child_body is None: - print(f"skip: {child_name}") + print(f"Skipping: {child_name} (not found)") continue + if parent_body is None: - raise ValueError(f"cannot find parent body: {parent_name} (child group: {gid} parent group: {parent_of[gid]})") + raise ValueError( + f"Cannot find parent body: {parent_name} " + f"(child group: {gid}, parent group: {parent_of[gid]})" + ) + # Check if already a child already_child = False for c in list(parent_body): if c is child_body: already_child = True break + if already_child: continue - + # Find current parent and reparent def find_parent(root, node): for e in mjcf_root.iter(): for c in list(e): @@ -514,23 +953,35 @@ def find_parent(root, node): def _indent(elem, level=0): + """ + Add indentation to XML element for pretty printing. + + Args: + elem (ET.Element): Element to indent + level (int): Current indentation level + """ i = "\n" + level * " " if len(elem): if not elem.text or not elem.text.strip(): elem.text = i + " " for e in elem: - _indent(e, level+1) + _indent(e, level + 1) if not e.tail or not e.tail.strip(): e.tail = i if level and (not elem.tail or not elem.tail.strip()): elem.tail = i + +# ============================================================================= +# MJCF GENERATION +# ============================================================================= + def generate_mjcf( - jsondata: dict={}, + jsondata: dict = {}, fixed_base: int = 0, out_path: str = "test.xml", model_name: str = "test", - # physics / options + # Physics / simulation options angle_unit: str = "radian", timestep: float = 0.002, gravity: str = "0 0 -9.81", @@ -538,7 +989,7 @@ def generate_mjcf( integrator: str = "implicitfast", density: float = 1.225, viscosity: float = 1.8e-5, - # visual + # Visual settings realtime: int = 1, shadowsize: int = 16384, numslices: int = 28, @@ -548,11 +999,11 @@ def generate_mjcf( headlight_active: int = 1, rgba_fog: str = "0 1 0 1", rgba_haze: str = "1 0 0 1", - # skybox + # Skybox skybox_file: Optional[str] = "./desert.png", skybox_gridsize: str = "3 4", skybox_gridlayout: str = ".U..LFRB.D..", - # plane checker + # Ground plane texture plane_texture_name: str = "plane", plane_material_name: str = "plane", plane_rgb1: str = ".1 .1 .1", @@ -564,13 +1015,13 @@ def generate_mjcf( plane_reflectance: float = 0.3, plane_texrepeat: str = "1 1", plane_texuniform: str = "true", - # contact / fluid defaults + # Contact / physics defaults geom_solref: str = ".5e-4", geom_solimp: str = "0.9 0.99 1e-4", geom_fluidcoef: str = "0.5 0.25 0.5 2.0 1.0", - # parts (each part creates: mesh, texture, material, default class, and sample body usage) + # Part definitions parts: List[Dict] = None, - # world items + # World layout floor_condim: int = 6, floor_size: str = "0 0 .25", light_pos: str = "30 30 30", @@ -578,304 +1029,403 @@ def generate_mjcf( light_ambient: str = ".3 .3 .3", light_diffuse: str = ".5 .5 .5", light_specular: str = ".5 .5 .5", - # demo body placements + # Object placement base_pos: str = "0 0 1", base_euler: str = "0 0 0", part_pos: str = "0 0 1.2", part_euler: str = "1.5 0 0", deformable: int = 0, ): + """ + Generate a complete MJCF (MuJoCo XML) file for physics simulation. + + This function creates a fully configured MJCF file including: + - Physics simulation parameters + - Visual rendering settings + - Asset definitions (meshes, textures, materials) + - World layout (ground, lights) + - Object bodies with joints based on group_info + + Joint Types (from group_info[-1]): + - 'A': Free joint (6-DOF floating) + - 'B': Slide/Prismatic joint (1-DOF translation) + - 'C': Revolute/Hinge joint (1-DOF rotation) + - 'D': Ball joint (3-DOF rotation) + - 'CB': Combined revolute + slide joint + Args: + jsondata (dict): Parsed VLM output containing group_info, parts, etc. + fixed_base (int): Whether to fix the base (0=free, 1=fixed) + out_path (str): Output file path + model_name (str): Model name in MJCF + ... (many physics and visual parameters) + parts (List[Dict]): Part configurations with mesh/texture paths + deformable (int): Whether to use deformable objects (0=rigid, 1=flex) + + Returns: + str: Output file path + """ if parts is None or len(parts) == 0: - raise ValueError("at least one part") + raise ValueError("At least one part must be provided") - # ---- root + # ========================================================================= + # Create root MJCF structure + # ========================================================================= mujoco = ET.Element("mujoco", attrib={"model": model_name}) - ET.SubElement(mujoco, "compiler", attrib={"angle": angle_unit, "autolimits": "true"}) - ET.SubElement( - mujoco, - "option", - attrib={ - "timestep": f"{timestep}", - "gravity": gravity, - "wind": wind, - "integrator": integrator, - "density": f"{density}", - "viscosity": f"{viscosity}", - }, - ) - - # ---- visual + + # Compiler settings + ET.SubElement(mujoco, "compiler", attrib={ + "angle": angle_unit, + "autolimits": "true" + }) + + # Simulation options + ET.SubElement(mujoco, "option", attrib={ + "timestep": f"{timestep}", + "gravity": gravity, + "wind": wind, + "integrator": integrator, + "density": f"{density}", + "viscosity": f"{viscosity}", + }) + + # ========================================================================= + # Visual settings + # ========================================================================= visual = ET.SubElement(mujoco, "visual") ET.SubElement(visual, "global", attrib={"realtime": str(realtime)}) - ET.SubElement( - visual, - "quality", - attrib={"shadowsize": str(shadowsize), "numslices": str(numslices), "offsamples": str(offsamples)}, - ) - ET.SubElement( - visual, - "headlight", - attrib={"diffuse": headlight_diffuse, "specular": headlight_specular, "active": str(headlight_active)}, - ) - ET.SubElement(visual, "rgba", attrib={"fog": rgba_fog, "haze": rgba_haze}) - - # ---- asset + ET.SubElement(visual, "quality", attrib={ + "shadowsize": str(shadowsize), + "numslices": str(numslices), + "offsamples": str(offsamples) + }) + ET.SubElement(visual, "headlight", attrib={ + "diffuse": headlight_diffuse, + "specular": headlight_specular, + "active": str(headlight_active) + }) + ET.SubElement(visual, "rgba", attrib={ + "fog": rgba_fog, + "haze": rgba_haze + }) + + # ========================================================================= + # Assets (meshes, textures, materials) + # ========================================================================= asset = ET.SubElement(mujoco, "asset") - # parts assets + # Add mesh, texture, and material for each part for p in parts: pname = p["name"] - # mesh - ET.SubElement( - asset, - "mesh", - attrib={"name": pname, "file": p["mesh_file"], "scale": p.get("scale", "1 1 1")}, - ) - # texture + + # Mesh asset + ET.SubElement(asset, "mesh", attrib={ + "name": pname, + "file": p["mesh_file"], + "scale": p.get("scale", "1 1 1") + }) + + # Texture asset tex_name = f"{pname}_tex" - ET.SubElement(asset, "texture", attrib={"type": "2d", "name": tex_name, "file": p["tex_file"]}) - # material + ET.SubElement(asset, "texture", attrib={ + "type": "2d", + "name": tex_name, + "file": p["tex_file"] + }) + + # Material asset mat_name = f"{pname}_img" - ET.SubElement(asset, "material", attrib={"name": mat_name, "texture": tex_name}) + ET.SubElement(asset, "material", attrib={ + "name": mat_name, + "texture": tex_name + }) - # skybox + # Skybox texture if skybox_file: - ET.SubElement( - asset, - "texture", - attrib={"type": "skybox", "file": skybox_file, "gridsize": skybox_gridsize, "gridlayout": skybox_gridlayout}, - ) - - # plane checker texture + material - ET.SubElement( - asset, - "texture", - attrib={ - "name": plane_texture_name, - "type": "2d", - "builtin": "checker", - "rgb1": plane_rgb1, - "rgb2": plane_rgb2, - "width": str(plane_width), - "height": str(plane_height), - "mark": plane_mark, - "markrgb": plane_markrgb, - }, - ) - ET.SubElement( - asset, - "material", - attrib={ - "name": plane_material_name, - "reflectance": str(plane_reflectance), - "texture": plane_texture_name, - "texrepeat": plane_texrepeat, - "texuniform": plane_texuniform, - }, - ) + ET.SubElement(asset, "texture", attrib={ + "type": "skybox", + "file": skybox_file, + "gridsize": skybox_gridsize, + "gridlayout": skybox_gridlayout + }) - # ---- default + # Ground plane texture and material + ET.SubElement(asset, "texture", attrib={ + "name": plane_texture_name, + "type": "2d", + "builtin": "checker", + "rgb1": plane_rgb1, + "rgb2": plane_rgb2, + "width": str(plane_width), + "height": str(plane_height), + "mark": plane_mark, + "markrgb": plane_markrgb, + }) + ET.SubElement(asset, "material", attrib={ + "name": plane_material_name, + "reflectance": str(plane_reflectance), + "texture": plane_texture_name, + "texrepeat": plane_texrepeat, + "texuniform": plane_texuniform, + }) + + # ========================================================================= + # Default settings + # ========================================================================= default = ET.SubElement(mujoco, "default") - ET.SubElement( - default, - "geom", - attrib={"solref": geom_solref, "solimp": geom_solimp, "fluidcoef": geom_fluidcoef}, - ) + ET.SubElement(default, "geom", attrib={ + "solref": geom_solref, + "solimp": geom_solimp, + "fluidcoef": geom_fluidcoef + }) - # per-part default class + # Per-part default classes for p in parts: pname = p["name"] dclass = ET.SubElement(default, "default", attrib={"class": pname}) + attrib = { "type": "mesh", "mesh": pname, "contype": p.get("contype", "1"), "conaffinity": p.get("conaffinity", "1"), } + if "density" in p: attrib["density"] = str(p["density"]) if "fluidshape" in p: attrib["fluidshape"] = p["fluidshape"] + ET.SubElement(dclass, "geom", attrib=attrib) - # ---- worldbody + # ========================================================================= + # World body (floor, lights, objects) + # ========================================================================= world = ET.SubElement(mujoco, "worldbody") - ET.SubElement( - world, - "geom", - attrib={ - "name": "floor", - "pos": "0 0 0", - "size": floor_size, - "type": "plane", - "material": plane_material_name, - "condim": str(floor_condim), - }, - ) - ET.SubElement( - world, - "light", - attrib={ - "directional": "true", - "ambient": light_ambient, - "pos": light_pos, - "dir": light_dir, - "diffuse": light_diffuse, - "specular": light_specular, - }, - ) - #ipdb.set_trace() - - base_body = ET.SubElement(world, "body", attrib={"name": "base", "pos": base_pos, "euler": base_euler}) - if fixed_base==0: + + # Ground plane + ET.SubElement(world, "geom", attrib={ + "name": "floor", + "pos": "0 0 0", + "size": floor_size, + "type": "plane", + "material": plane_material_name, + "condim": str(floor_condim), + }) + + # Directional light + ET.SubElement(world, "light", attrib={ + "directional": "true", + "ambient": light_ambient, + "pos": light_pos, + "dir": light_dir, + "diffuse": light_diffuse, + "specular": light_specular, + }) + + # ========================================================================= + # Base body (root of the object) + # ========================================================================= + base_body = ET.SubElement(world, "body", attrib={ + "name": "base", + "pos": base_pos, + "euler": base_euler + }) + + # Add free joint if base is not fixed + if fixed_base == 0: ET.SubElement(base_body, "freejoint") + + # Add geometries for base group (group 0) for idx in jsondata['group_info']['0']: part = parts_cfg[idx] ET.SubElement(base_body, "geom", attrib={ "class": part["name"], "material": f'{part["name"]}_img' }) - have_free=0 - dimscale=float(p.get("scale", "1 1 1").split(' ')[0]) - for group_idx in range(1,len(jsondata['group_info'])): - if jsondata['group_info'][str(group_idx)][-1]=='A': - have_free+=1 - elif jsondata['group_info'][str(group_idx)][-1]=='B': - movable_body = ET.SubElement(world, "body", attrib={"name": "grouppart_"+str(group_idx), "pos": "0 0 0"}) - - ET.SubElement( - movable_body, "joint", - attrib={ - "type": "slide", - "name": "slide_"+str(group_idx), - "axis": " ".join(map(str, jsondata['group_info'][str(group_idx)][2][:3])), - "range": " ".join(map(str, jsondata['group_info'][str(group_idx)][2][6:8])), - "damping": "0.001", - "frictionloss": "0.0", - "stiffness": "0" - } - ) + + # ========================================================================= + # Process kinematic groups (groups 1, 2, 3, ...) + # ========================================================================= + have_free = 0 # Count of free joints + dimscale = float(p.get("scale", "1 1 1").split(' ')[0]) + + for group_idx in range(1, len(jsondata['group_info'])): + joint_type = jsondata['group_info'][str(group_idx)][-1] # Last element is joint type + group_params = jsondata['group_info'][str(group_idx)][2] # Joint parameters + group_members = jsondata['group_info'][str(group_idx)][0] # Part indices in this group + + # --------------------------------------------------------------------- + # Type A: Free joint (floating, 6-DOF) + # --------------------------------------------------------------------- + if joint_type == 'A': + have_free += 1 + # Free joints are handled separately after all other joints + + # --------------------------------------------------------------------- + # Type B: Slide/Prismatic joint (1-DOF translation) + # --------------------------------------------------------------------- + elif joint_type == 'B': + movable_body = ET.SubElement(world, "body", attrib={ + "name": f"grouppart_{group_idx}", + "pos": "0 0 0" + }) + + # Add prismatic joint + ET.SubElement(movable_body, "joint", attrib={ + "type": "slide", + "name": f"slide_{group_idx}", + "axis": " ".join(map(str, group_params[:3])), # Direction vector + "range": " ".join(map(str, group_params[6:8])), # Motion limits + "damping": "0.001", + "frictionloss": "0.0", + "stiffness": "0" + }) - for idx in jsondata['group_info'][str(group_idx)][0]: + # Add part geometries + for idx in group_members: part = parts_cfg[idx] ET.SubElement(movable_body, "geom", attrib={ "class": part["name"], "material": f'{part["name"]}_img' }) - elif jsondata['group_info'][str(group_idx)][-1]=='C': - movable_body = ET.SubElement(world, "body", attrib={"name": "grouppart_"+str(group_idx), "pos": "0 0 0"}) - if jsondata['group_info'][str(group_idx)][2][6]==-1 and jsondata['group_info'][str(group_idx)][2][7]==1: - ET.SubElement( - movable_body, "joint", - attrib={ - "type": "hinge", - "name": "pivot_"+str(group_idx), - "axis": " ".join(map(str, jsondata['group_info'][str(group_idx)][2][:3])), - "pos": " ".join(map(str, (np.array(jsondata['group_info'][str(group_idx)][2][3:6])*dimscale).tolist())), - "range": " ".join(map(str, (np.array([-3000,3000])*np.pi).tolist())), - "damping": "0.001", - "frictionloss": "0.0", - "stiffness": "0" - } - ) + + # --------------------------------------------------------------------- + # Type C: Revolute/Hinge joint (1-DOF rotation) + # --------------------------------------------------------------------- + elif joint_type == 'C': + movable_body = ET.SubElement(world, "body", attrib={ + "name": f"grouppart_{group_idx}", + "pos": "0 0 0" + }) + + # Check if continuous (unlimited rotation) vs limited + is_continuous = (group_params[6] == -1 and group_params[7] == 1) + if is_continuous: + # Continuous rotation (no limits) + ET.SubElement(movable_body, "joint", attrib={ + "type": "hinge", + "name": f"pivot_{group_idx}", + "axis": " ".join(map(str, group_params[:3])), + "pos": " ".join(map(str, (np.array(group_params[3:6]) * dimscale).tolist())), + "range": " ".join(map(str, (np.array([-3000, 3000]) * np.pi).tolist())), + "damping": "0.001", + "frictionloss": "0.0", + "stiffness": "0" + }) else: - ET.SubElement( - movable_body, "joint", - attrib={ - "type": "hinge", - "name": "pivot_"+str(group_idx), - "axis": " ".join(map(str, jsondata['group_info'][str(group_idx)][2][:3])), - "pos": " ".join(map(str, (np.array(jsondata['group_info'][str(group_idx)][2][3:6])*dimscale).tolist())), - "range": " ".join(map(str, (np.array(jsondata['group_info'][str(group_idx)][2][6:8])*np.pi).tolist())), - "damping": "0.001", - "frictionloss": "0.0", - "stiffness": "0" - } - ) + # Limited rotation + ET.SubElement(movable_body, "joint", attrib={ + "type": "hinge", + "name": f"pivot_{group_idx}", + "axis": " ".join(map(str, group_params[:3])), + "pos": " ".join(map(str, (np.array(group_params[3:6]) * dimscale).tolist())), + "range": " ".join(map(str, (np.array(group_params[6:8]) * np.pi).tolist())), + "damping": "0.001", + "frictionloss": "0.0", + "stiffness": "0" + }) - for idx in jsondata['group_info'][str(group_idx)][0]: + # Add part geometries + for idx in group_members: part = parts_cfg[idx] ET.SubElement(movable_body, "geom", attrib={ "class": part["name"], "material": f'{part["name"]}_img' }) - elif jsondata['group_info'][str(group_idx)][-1]=='D': - movable_body = ET.SubElement(world, "body", attrib={"name": "grouppart_"+str(group_idx), "pos": "0 0 0"}) - ET.SubElement( - movable_body, "joint", - attrib={ - "type": "ball", - "name": "ball_"+str(group_idx), - "pos": " ".join(map(str, (np.array(jsondata['group_info'][str(group_idx)][2][3:6])*dimscale).tolist())), - "damping": "0.001", - "frictionloss": "0.0", - "stiffness": "0" - } - ) + + # --------------------------------------------------------------------- + # Type D: Ball joint (3-DOF rotation, spherical) + # --------------------------------------------------------------------- + elif joint_type == 'D': + movable_body = ET.SubElement(world, "body", attrib={ + "name": f"grouppart_{group_idx}", + "pos": "0 0 0" + }) + + ET.SubElement(movable_body, "joint", attrib={ + "type": "ball", + "name": f"ball_{group_idx}", + "pos": " ".join(map(str, (np.array(group_params[3:6]) * dimscale).tolist())), + "damping": "0.001", + "frictionloss": "0.0", + "stiffness": "0" + }) - for idx in jsondata['group_info'][str(group_idx)][0]: + # Add part geometries + for idx in group_members: part = parts_cfg[idx] ET.SubElement(movable_body, "geom", attrib={ "class": part["name"], "material": f'{part["name"]}_img' }) - elif jsondata['group_info'][str(group_idx)][-1]=='CB': - movable_body = ET.SubElement(world, "body", attrib={"name": "grouppart_"+str(group_idx), "pos": "0 0 0"}) - - if jsondata['group_info'][str(group_idx)][2][6]==-1 and jsondata['group_info'][str(group_idx)][2][7]==1: - ET.SubElement( - movable_body, "joint", - attrib={ - "type": "hinge", - "name": "pivot_"+str(group_idx), - "axis": " ".join(map(str, jsondata['group_info'][str(group_idx)][2][:3])), - "pos": " ".join(map(str, (np.array(jsondata['group_info'][str(group_idx)][2][3:6])*dimscale).tolist())), - "range": " ".join(map(str, (np.array([-3000,3000])*np.pi).tolist())), - "damping": "0.001", - "frictionloss": "0.0", - "stiffness": "0" - } - ) - + + # --------------------------------------------------------------------- + # Type CB: Combined revolute + slide joint + # --------------------------------------------------------------------- + elif joint_type == 'CB': + movable_body = ET.SubElement(world, "body", attrib={ + "name": f"grouppart_{group_idx}", + "pos": "0 0 0" + }) + + # Add revolute joint first + is_continuous = (group_params[6] == -1 and group_params[7] == 1) + + if is_continuous: + ET.SubElement(movable_body, "joint", attrib={ + "type": "hinge", + "name": f"pivot_{group_idx}", + "axis": " ".join(map(str, group_params[:3])), + "pos": " ".join(map(str, (np.array(group_params[3:6]) * dimscale).tolist())), + "range": " ".join(map(str, (np.array([-3000, 3000]) * np.pi).tolist())), + "damping": "0.001", + "frictionloss": "0.0", + "stiffness": "0" + }) else: - ET.SubElement( - movable_body, "joint", - attrib={ - "type": "hinge", - "name": "pivot_"+str(group_idx), - "axis": " ".join(map(str, jsondata['group_info'][str(group_idx)][2][:3])), - "pos": " ".join(map(str, (np.array(jsondata['group_info'][str(group_idx)][2][3:6])*dimscale).tolist())), - "range": " ".join(map(str, (np.array(jsondata['group_info'][str(group_idx)][2][6:8])*np.pi).tolist())), - "damping": "0.001", - "frictionloss": "0.0", - "stiffness": "0" - } - ) - ET.SubElement( - movable_body, "joint", - attrib={ - "type": "slide", - "name": "slide_"+str(group_idx), - "axis": " ".join(map(str, jsondata['group_info'][str(group_idx)][2][8:11])), - "range": " ".join(map(str, jsondata['group_info'][str(group_idx)][2][14:])), + ET.SubElement(movable_body, "joint", attrib={ + "type": "hinge", + "name": f"pivot_{group_idx}", + "axis": " ".join(map(str, group_params[:3])), + "pos": " ".join(map(str, (np.array(group_params[3:6]) * dimscale).tolist())), + "range": " ".join(map(str, (np.array(group_params[6:8]) * np.pi).tolist())), "damping": "0.001", "frictionloss": "0.0", "stiffness": "0" - } - ) + }) + + # Add slide joint + ET.SubElement(movable_body, "joint", attrib={ + "type": "slide", + "name": f"slide_{group_idx}", + "axis": " ".join(map(str, group_params[8:11])), # Slide direction + "range": " ".join(map(str, group_params[14:])), # Slide limits + "damping": "0.001", + "frictionloss": "0.0", + "stiffness": "0" + }) - for idx in jsondata['group_info'][str(group_idx)][0]: + # Add part geometries + for idx in group_members: part = parts_cfg[idx] ET.SubElement(movable_body, "geom", attrib={ "class": part["name"], "material": f'{part["name"]}_img' }) - if have_free>0: - for group_idx in range(1,len(jsondata['group_info'])): - if jsondata['group_info'][str(group_idx)][-1]=='A': - movable_body = ET.SubElement(world, "body", attrib={"name": "grouppart_"+str(group_idx), "pos": "0 0 1", "euler": base_euler}) + + # ========================================================================= + # Handle free joints (Type A) separately - they need special placement + # ========================================================================= + if have_free > 0: + for group_idx in range(1, len(jsondata['group_info'])): + if jsondata['group_info'][str(group_idx)][-1] == 'A': + movable_body = ET.SubElement(world, "body", attrib={ + "name": f"grouppart_{group_idx}", + "pos": "0 0 1", + "euler": base_euler + }) ET.SubElement(movable_body, "freejoint") for idx in jsondata['group_info'][str(group_idx)][0]: @@ -885,570 +1435,890 @@ def generate_mjcf( "material": f'{part["name"]}_img' }) - reparent_by_group_info(mujoco, jsondata['group_info'], base_body_name="base", group_body_prefix="grouppart_") + # ========================================================================= + # Reparent bodies according to kinematic hierarchy + # ========================================================================= + reparent_by_group_info( + mujoco, + jsondata['group_info'], + base_body_name="base", + group_body_prefix="grouppart_" + ) - if have_free>0: - for group_idx in range(1,len(jsondata['group_info'])): - if jsondata['group_info'][str(group_idx)][-1]=='A' and deformable==0: - extract_body_to_world(mujoco, "grouppart_"+str(group_idx)) - elif jsondata['group_info'][str(group_idx)][-1]=='A' and deformable==1: - #ipdb.set_trace() - + # ========================================================================= + # Handle deformable objects (if enabled) + # ========================================================================= + if have_free > 0: + for group_idx in range(1, len(jsondata['group_info'])): + joint_type = jsondata['group_info'][str(group_idx)][-1] + + if joint_type == 'A' and deformable == 0: + # Extract free body to world level (rigid) + extract_body_to_world(mujoco, f"grouppart_{group_idx}") + + elif joint_type == 'A' and deformable == 1: + # Create deformable (flex) object instead of rigid body world = find_worldbody(mujoco) - target = find_body(mujoco, "grouppart_"+str(group_idx)) + target = find_body(mujoco, f"grouppart_{group_idx}") parent = find_parent(mujoco, target) parent.remove(target) - filename=target.findall('geom')[0].get('class') - meshid=filename.split('l_')[1].split('_')[0] - + # Get mesh info for flex computation + filename = target.findall('geom')[0].get('class') + meshid = filename.split('l_')[1].split('_')[0] - - str_list=jsondata['dimension'].split(' ')[0].split('*') + # Calculate scaling from object dimensions + str_list = jsondata['dimension'].split(' ')[0].split('*') sorted_list = sorted(str_list, key=float, reverse=True) - scaling=float(sorted_list[0])/100 - - - - mesh=trimesh.load(os.path.join(out_path.split('basic.xml')[0],"./objs",str(meshid),str(meshid)+'.obj')) - voxel_size=0.01 - voxel_grid=mesh.voxelized(pitch=voxel_size) + scaling = float(sorted_list[0]) / 100 + + # Compute mass from volume and density + mesh = trimesh.load(os.path.join( + out_path.split('basic.xml')[0], + "./objs", str(meshid), f"{meshid}.obj" + )) + voxel_size = 0.01 + voxel_grid = mesh.voxelized(pitch=voxel_size) occupied = voxel_grid.matrix volume = np.sum(occupied) * (voxel_size ** 3) - mass=volume*(scaling**3)*jsondata['parts'][int(meshid)]['density'] - - - - flex=ET.SubElement( - world, "flexcomp", - attrib={ - "type": "mesh", - "file": os.path.join("./objs",str(meshid),str(meshid)+'.obj'), - "pos": "0 0 1", - "scale": p.get("scale", "1 1 1"), - "dim": "2", - "euler": "0 0 0", - "radius": "0.001", - "name": filename, - "dof": "trilinear", - "mass": str(mass), - } - ) - ET.SubElement( - flex, "elasticity", - attrib={ - "young": str(float(jsondata['parts'][int(meshid)]["Young's Modulus (GPa)"])*1e9), - "poisson": str(jsondata['parts'][int(meshid)]["Poisson's Ratio"]), - "damping": "0.001" - } - ) - ET.SubElement( - flex, "contact", - attrib={ - "selfcollide": "none", - "internal": "false", - } - ) - + mass = volume * (scaling ** 3) * jsondata['parts'][int(meshid)]['density'] + + # Create flexcomp element for deformable simulation + flex = ET.SubElement(world, "flexcomp", attrib={ + "type": "mesh", + "file": os.path.join("./objs", str(meshid), f"{meshid}.obj"), + "pos": "0 0 1", + "scale": p.get("scale", "1 1 1"), + "dim": "2", + "euler": "0 0 0", + "radius": "0.001", + "name": filename, + "dof": "trilinear", + "mass": str(mass), + }) + + # Add elasticity properties + ET.SubElement(flex, "elasticity", attrib={ + "young": str(float(jsondata['parts'][int(meshid)]["Young's Modulus (GPa)"]) * 1e9), + "poisson": str(jsondata['parts'][int(meshid)]["Poisson's Ratio"]), + "damping": "0.001" + }) + + # Add contact properties + ET.SubElement(flex, "contact", attrib={ + "selfcollide": "none", + "internal": "false", + }) + # ========================================================================= + # Write MJCF file + # ========================================================================= _indent(mujoco) tree = ET.ElementTree(mujoco) tree.write(out_path, encoding="utf-8", xml_declaration=True) + return out_path -#################### +# ============================================================================= +# MJCF BODY MANIPULATION UTILITIES +# ============================================================================= + def find_worldbody(root: ET.Element) -> ET.Element: + """Find the element in an MJCF tree.""" for e in root.iter("worldbody"): return e - raise ValueError("cannot find ") + raise ValueError("Cannot find ") + def find_body(root: ET.Element, name: str) -> ET.Element | None: + """Find a body by name in an MJCF tree.""" for b in root.iter("body"): if b.get("name") == name: return b return None + def find_parent(root: ET.Element, node: ET.Element) -> ET.Element | None: + """Find the parent element of a node in an XML tree.""" for e in root.iter(): for c in list(e): if c is node: return e return None + def is_direct_child_of_world(root: ET.Element, node: ET.Element) -> bool: + """Check if a node is a direct child of worldbody.""" parent = find_parent(root, node) return parent is not None and parent.tag == "worldbody" -def extract_body_to_world(root: ET.Element, body_name: str) -> bool: +def extract_body_to_world(root: ET.Element, body_name: str) -> bool: + """ + Move a body from its current parent to directly under worldbody. + + Used for free-floating bodies that shouldn't be nested. + + Args: + root (ET.Element): Root MJCF element + body_name (str): Name of body to extract + + Returns: + bool: True if moved, False if already at world level + """ world = find_worldbody(root) target = find_body(root, body_name) if target is None: - raise ValueError(f"❌ cannot find body: {body_name}") + raise ValueError(f"Cannot find body: {body_name}") if is_direct_child_of_world(root, target): - print(f" '{body_name}' skip") + print(f"'{body_name}' already at world level, skipping") return False parent = find_parent(root, target) if parent is None: - raise RuntimeError("❌ cannot find the parent node") + raise RuntimeError("Cannot find the parent node") parent.remove(target) world.append(target) - print(f"✅ Move '{body_name}' from parent body '{parent.get('name')}' to ") + print(f"Moved '{body_name}' from '{parent.get('name')}' to ") return True -#################### +# ============================================================================= +# MAIN EXECUTION +# ============================================================================= if __name__ == '__main__': - - parser = argparse.ArgumentParser(description="Convert urdf format to simplified format") - parser.add_argument('--voxel_define', type=int, default=32, help='Resolution of the voxel.') - parser.add_argument('--basepath', type=str, default='./test_demo', help='Path of the voxel.') - parser.add_argument('--process', type=int, default=0, help='whether use postprocess.') - parser.add_argument('--fixed_base', type=int, default=0, help='whether fix the basement of object in mjcf.') - parser.add_argument('--deformable', type=int, default=0, help='whether introduce deformable objects in mjcf.') + """ + Main entry point for URDF/MJCF generation pipeline. + + This script processes 3D segmented meshes and VLM-annotated basic_info.txt + to generate physics simulation-ready files: + 1. URDF (Universal Robot Description Format) for ROS/PyBullet + 2. MJCF (MuJoCo XML Format) for MuJoCo physics simulation + + The pipeline: + 1. Parse basic_info.txt to extract object metadata and group_info + 2. Optionally post-process joint positions using candidate generation + 3. Generate URDF with articulated joints + 4. Generate MJCF with physics properties + + Usage: + python 4_simready_gen.py --basepath ./test_demo --fixed_base 1 + """ + + # ========================================================================= + # Parse command line arguments + # ========================================================================= + parser = argparse.ArgumentParser( + description="Generate URDF and MJCF from segmented 3D models" + ) + parser.add_argument( + '--voxel_define', type=int, default=32, + help='Voxel grid resolution (default: 32)' + ) + parser.add_argument( + '--basepath', type=str, default='./test_demo', + help='Base path containing object folders' + ) + parser.add_argument( + '--process', type=int, default=0, + help='Enable joint position post-processing (0=off, 1=on)' + ) + parser.add_argument( + '--fixed_base', type=int, default=0, + help='Fix base in MJCF (0=floating, 1=fixed)' + ) + parser.add_argument( + '--deformable', type=int, default=0, + help='Enable deformable objects in MJCF (0=rigid, 1=flex)' + ) args = parser.parse_args() - logger = get_logger(os.path.join('exp_urdf.log'),verbosity=1) - logger.info('start') - - voxel_define=args.voxel_define - basepath=args.basepath - namelist=os.listdir(basepath) - - for filename in namelist: - logger.info('begin: '+filename) - if os.path.exists(os.path.join(basepath,filename,'objs')): - - with open(os.path.join(basepath,filename,'basic_info.txt'), "r", encoding="utf-8") as f: - basicqu = f.read() + # Initialize logging + logger = get_logger(os.path.join('exp_urdf.log'), verbosity=1) + logger.info('Starting URDF/MJCF generation pipeline') - lines = [line.strip() for line in basicqu.strip().split('\n') if line.strip()] - - data = {} - - - data['object_name'] = re.search(r'Name:\s*(.*)', lines[0]).group(1) - data['category'] = re.search(r'Category:\s*(.*)', lines[1]).group(1) - data['dimension'] = re.search(r'Dimension:\s*(.*)', lines[2]).group(1) - - parts = [] - for line in lines: - if line.startswith("l_"): - match = re.match( - r'l_(\d+):\s*([^,]+),\s*([^,]+),\s*([^,]+),\s*([^,]+),\s*([^,]+),\s*([^,]+),\s*(.*)', - line - ) - if match: - label = int(match.group(1)) - name = match.group(2).strip() - priority_rank = int(match.group(3)) - material = match.group(4).strip() - density = match.group(5).strip() - - young = (match.group(6).strip()) - poisson = (match.group(7).strip()) - basic_desc = match.group(8).strip() - - - - parts.append({ - "label": label, - "name": name, - "material": material, - "density": density, - "priority_rank": priority_rank, - "Basic_description": basic_desc, - "Young's Modulus (GPa)": young, - "Poisson's Ratio": poisson - }) - - data['parts'] = parts - - - group_info = {} - for i, line in enumerate(lines): - if re.match(r'^group_\d+\s*:', line.strip(), flags=re.IGNORECASE): - - gm = re.search(r'group_(\d+):\s*\[(.*?)\]', line, flags=re.IGNORECASE) - if not gm: - continue - gid = gm.group(1) - members_raw = gm.group(2) - - members = [] - for tok in members_raw.split(','): - tok = tok.strip().strip("'").strip('"') - nm = re.search(r'l_(\d+)', tok, flags=re.IGNORECASE) - if nm: - members.append(int(nm.group(1))) - - - tm = re.search(r'Type:\s*([A-Za-z])', line, flags=re.IGNORECASE) - gtype = tm.group(1).upper() if tm else "E" - if ': CB' in line: - gtype='CB' - - - rel_idx = None - + voxel_define = args.voxel_define + basepath = args.basepath + namelist = os.listdir(basepath) - rel_matches = re.findall(r'(?:relative\s*to\s*)+group_(\d+)', line, flags=re.IGNORECASE) - if rel_matches: - rel_idx = int(rel_matches[-1]) - - param_vec = [0.0] * 8 - - if gtype not in ("E", "A", "CB"): - - scan_block = line - - - dir_v = _extract_bracket_list(scan_block, 'direction', 3) - pos_v = _extract_bracket_list(scan_block, 'position', 3) + # ========================================================================= + # Process each object folder + # ========================================================================= + for filename in namelist: + logger.info(f'Processing: {filename}') + + # Skip folders without generated meshes + if not os.path.exists(os.path.join(basepath, filename, 'objs')): + logger.info(f'Skipping (no objs folder): {filename}') + continue + # ===================================================================== + # STEP 1: Parse basic_info.txt (VLM output) + # ===================================================================== + with open(os.path.join(basepath, filename, 'basic_info.txt'), "r", encoding="utf-8") as f: + basicqu = f.read() + + lines = [line.strip() for line in basicqu.strip().split('\n') if line.strip()] + + # Initialize data dictionary + data = {} + + # Extract object metadata + data['object_name'] = re.search(r'Name:\s*(.*)', lines[0]).group(1) + data['category'] = re.search(r'Category:\s*(.*)', lines[1]).group(1) + data['dimension'] = re.search(r'Dimension:\s*(.*)', lines[2]).group(1) + + # ----------------------------------------------------------------- + # Parse part definitions (l_0, l_1, l_2, ...) + # Format: l_N: name, priority, material, density, young, poisson, desc + # ----------------------------------------------------------------- + parts = [] + for line in lines: + if line.startswith("l_"): + match = re.match( + r'l_(\d+):\s*([^,]+),\s*([^,]+),\s*([^,]+),\s*([^,]+),\s*([^,]+),\s*([^,]+),\s*(.*)', + line + ) + if match: + parts.append({ + "label": int(match.group(1)), + "name": match.group(2).strip(), + "priority_rank": int(match.group(3)), + "material": match.group(4).strip(), + "density": match.group(5).strip(), + "Young's Modulus (GPa)": match.group(6).strip(), + "Poisson's Ratio": match.group(7).strip(), + "Basic_description": match.group(8).strip() + }) - pos_v=((np.array(pos_v)) / voxel_define - 0.5).tolist() - - if gtype in ("C"): - rng_v = _extract_bracket_list(scan_block, 'range', 2) - rng_v=(np.array(rng_v)/180).tolist() + data['parts'] = parts - if gtype in ("B"): - rng_v = _extract_bracket_list(scan_block, 'range', 2) - rng_v=(np.array(rng_v)/voxel_define).tolist() + # ----------------------------------------------------------------- + # Parse group definitions (kinematic groups with joint info) + # Format: group_N: [members], relative to group_M, Type: X, params... + # ----------------------------------------------------------------- + group_info = {} + + for i, line in enumerate(lines): + if not re.match(r'^group_\d+\s*:', line.strip(), flags=re.IGNORECASE): + continue + + # Extract group ID and members + gm = re.search(r'group_(\d+):\s*\[(.*?)\]', line, flags=re.IGNORECASE) + if not gm: + continue + + gid = gm.group(1) + members_raw = gm.group(2) + + # Parse member part IDs (e.g., "l_0, l_1" -> [0, 1]) + members = [] + for tok in members_raw.split(','): + tok = tok.strip().strip("'").strip('"') + nm = re.search(r'l_(\d+)', tok, flags=re.IGNORECASE) + if nm: + members.append(int(nm.group(1))) + + # Extract joint type (A, B, C, D, CB, or E=fixed) + tm = re.search(r'Type:\s*([A-Za-z])', line, flags=re.IGNORECASE) + gtype = tm.group(1).upper() if tm else "E" + if ': CB' in line: + gtype = 'CB' + + # Extract relative-to parent group + rel_idx = None + rel_matches = re.findall( + r'(?:relative\s*to\s*)+group_(\d+)', line, flags=re.IGNORECASE + ) + if rel_matches: + rel_idx = int(rel_matches[-1]) - param_vec = dir_v + pos_v + rng_v + # Initialize parameter vector (direction, position, range) + param_vec = [0.0] * 8 - elif gtype in ("CB"): - scan_block = line + # --------------------------------------------------------- + # Parse joint parameters based on type + # --------------------------------------------------------- + if gtype not in ("E", "A", "CB"): + # Types B, C, D: direction[3], position[3], range[2] + scan_block = line + dir_v = _extract_bracket_list(scan_block, 'direction', 3) + pos_v = _extract_bracket_list(scan_block, 'position', 3) - dir_v = _extract_bracket_list(scan_block, 'axis direction', 3) - pos_v = _extract_bracket_list(scan_block, 'axis position', 3) + # Normalize position to [-0.5, 0.5] range + pos_v = ((np.array(pos_v)) / voxel_define - 0.5).tolist() + + # Parse range based on joint type + if gtype == "C": + # Revolute: convert degrees to pi-normalized + rng_v = _extract_bracket_list(scan_block, 'range', 2) + rng_v = (np.array(rng_v) / 180).tolist() + + if gtype == "B": + # Prismatic: normalize to voxel scale + rng_v = _extract_bracket_list(scan_block, 'range', 2) + rng_v = (np.array(rng_v) / voxel_define).tolist() + + param_vec = dir_v + pos_v + rng_v + + elif gtype == "CB": + # Combined type: revolute + slide parameters + scan_block = line + + # Revolute axis parameters + dir_v = _extract_bracket_list(scan_block, 'axis direction', 3) + pos_v = _extract_bracket_list(scan_block, 'axis position', 3) + pos_v = ((np.array(pos_v)) / voxel_define - 0.5).tolist() + + rng_v = _extract_bracket_list(scan_block, 'revolute range', 2) + rng_v = (np.array(rng_v) / 180).tolist() - pos_v=((np.array(pos_v)) / voxel_define - 0.5).tolist() - rng_v = _extract_bracket_list(scan_block, 'revolute range', 2) - rng_v=(np.array(rng_v)/180).tolist() + # Slide parameters + dir_v1 = _extract_bracket_list(scan_block, 'slide direction', 3) + rng_v1 = _extract_bracket_list(scan_block, 'slide range', 2) + rng_v1 = (np.array(rng_v1) / voxel_define).tolist() - dir_v1 = _extract_bracket_list(scan_block, 'slide direction', 3) - - rng_v1 = _extract_bracket_list(scan_block, 'slide range', 2) - rng_v1=(np.array(rng_v1)/voxel_define).tolist() + # Combined: [rev_dir, rev_pos, rev_range, slide_dir, pad, slide_range] + param_vec = dir_v + pos_v + rng_v + dir_v1 + [0, 0, 0] + rng_v1 - param_vec = dir_v + pos_v + rng_v+dir_v1+[0,0,0]+rng_v1 + # Store group info + if gid == str(0): + # Group 0 is base (fixed parts), just store members + group_info[gid] = members + else: + # Movable groups: [members, parent_idx, params, type] + group_info[gid] = [members, str(rel_idx), param_vec, gtype] + data['group_info'] = group_info - - if gid==str(0): - group_info[gid] = members + # ===================================================================== + # STEP 2: Post-process joint positions (optional) + # ===================================================================== + # When --process=1, refine joint positions using geometry-based + # candidate detection to snap to actual mesh boundaries + + if args.process: + for group_id in range(1, len(group_info)): + joint_type = group_info[str(group_id)][-1] + + # --------------------------------------------------------- + # Refine revolute (C) and combined (CB) joint positions + # --------------------------------------------------------- + if joint_type in ('C', 'CB'): + # Get parent group's members + parent_id = group_info[str(group_id)][1] + if parent_id == '0': + group_b = group_info['0'] else: - group_info[gid] = [members,str(rel_idx),param_vec,gtype] - - data['group_info'] = group_info - - - - if args.process: - for group_id in range(1,len(group_info)): - if group_info[str(group_id)][-1]=='C' or group_info[str(group_id)][-1]=='CB': - if group_info[str(group_id)][1]=='0': - group_b=group_info['0'] - else: - if group_info[str(group_id)][1]=='0': - group_b=group_info[group_info[str(group_id)][1]] - else: - group_b=group_info[group_info[str(group_id)][1]][0] - - allcandidate=generate_allcandidate(group_info[str(group_id)][0],group_b,os.path.join(basepath,filename)) - - axisdir=np.array(group_info[str(group_id)][2][:3]) - axisdir = np.int32(axisdir / np.linalg.norm(axisdir)) - weights=np.array([1,1,1]) - weights[np.where(axisdir==1)]=0 - error=(allcandidate - np.array(group_info[str(group_id)][2][3:6]))*weights - - dist = np.linalg.norm(error, axis=1) - idx = np.argmin(dist) - nearest_point = allcandidate[idx] - - if np.linalg.norm((nearest_point-np.array(group_info[str(group_id)][2][3:6]))*weights)<0.03: - group_info[str(group_id)][2][3:6]=nearest_point.tolist() - - if group_info[str(group_id)][-1]=='D': - if group_info[str(group_id)][1]=='0': - group_b=group_info['0'] - else: - if group_info[str(group_id)][1]=='0': - group_b=group_info[group_info[str(group_id)][1]] - else: - group_b=group_info[group_info[str(group_id)][1]][0] - - - - allcandidate=generate_allcandidate_center(group_info[str(group_id)][0],group_b,os.path.join(basepath,filename)) - - weights=np.array([1,1,1]) - error=(allcandidate - np.array(group_info[str(group_id)][2][3:6]))*weights - dist = np.linalg.norm(error) - idx = np.argmin(dist) - nearest_point = allcandidate[idx] + parent_data = group_info[parent_id] + group_b = parent_data if isinstance(parent_data, list) and isinstance(parent_data[0], int) else parent_data[0] + + # Generate candidate axis positions from mesh intersection + allcandidate = generate_allcandidate( + group_info[str(group_id)][0], # Current group members + group_b, # Parent group members + os.path.join(basepath, filename) + ) - if np.linalg.norm((nearest_point-np.array(group_info[str(group_id)][2][3:6]))*weights)<0.03: + # Compute weights: ignore axis direction dimension + axisdir = np.array(group_info[str(group_id)][2][:3]) + axisdir = np.int32(axisdir / np.linalg.norm(axisdir)) + weights = np.array([1, 1, 1]) + weights[np.where(axisdir == 1)] = 0 + + # Find nearest candidate point + current_pos = np.array(group_info[str(group_id)][2][3:6]) + error = (allcandidate - current_pos) * weights + dist = np.linalg.norm(error, axis=1) + idx = np.argmin(dist) + nearest_point = allcandidate[idx] + + # Snap if close enough (threshold: 0.03) + if np.linalg.norm((nearest_point - current_pos) * weights) < 0.03: + group_info[str(group_id)][2][3:6] = nearest_point.tolist() + + # --------------------------------------------------------- + # Refine ball joint (D) center positions + # --------------------------------------------------------- + if joint_type == 'D': + # Get parent group's members + parent_id = group_info[str(group_id)][1] + if parent_id == '0': + group_b = group_info['0'] + else: + parent_data = group_info[parent_id] + group_b = parent_data if isinstance(parent_data, list) and isinstance(parent_data[0], int) else parent_data[0] + + # Generate candidate center positions + allcandidate = generate_allcandidate_center( + group_info[str(group_id)][0], + group_b, + os.path.join(basepath, filename) + ) - group_info[str(group_id)][2][3:6]=nearest_point.tolist() + # Find nearest candidate + weights = np.array([1, 1, 1]) + current_pos = np.array(group_info[str(group_id)][2][3:6]) + error = (allcandidate - current_pos) * weights + dist = np.linalg.norm(error) + idx = np.argmin(dist) + nearest_point = allcandidate[idx] + + # Snap if close enough + if np.linalg.norm((nearest_point - current_pos) * weights) < 0.03: + group_info[str(group_id)][2][3:6] = nearest_point.tolist() + + # ===================================================================== + # STEP 3: Save processed data as JSON + # ===================================================================== + with open(os.path.join(basepath, filename, 'basic_info.json'), "w", encoding="utf-8") as f: + json.dump(data, f, indent=4) + + # ===================================================================== + # STEP 4: Generate URDF file + # ===================================================================== + jsonfile = os.path.join(basepath, filename, 'basic_info.json') + geofile = os.path.join(basepath, filename, 'objs') + + with open(jsonfile, 'r') as fp: + jsondata = json.load(fp) + + mov = jsondata['group_info'] + + # Create root URDF element + robot = ET.Element('robot', name='scene') + + # World link (fixed reference frame) + link = ET.SubElement(robot, 'link', name='l_world') + add_inertial(link) + save = 1 # Track number of movable joints + # ----------------------------------------------------------------- + # Case 1: Static object (only base group, no movable parts) + # ----------------------------------------------------------------- + if len(mov) == 1: + fixlist = mov['0'] + + # Create links for each part in base group + for fixindex in fixlist: + link = ET.SubElement(robot, 'link', name=f'l_{fixindex}') + add_inertial(link) - - - - - with open(os.path.join(basepath,filename,'basic_info.json'), "w", encoding="utf-8") as f: - json.dump(data, f, indent=4) - - + mesh_path = os.path.join(geofile, str(fixindex), f'{fixindex}.obj') + if os.path.exists(mesh_path): + visual = ET.SubElement(link, 'visual') + geometry = ET.SubElement(visual, "geometry") + ET.SubElement( + geometry, "mesh", + filename=os.path.join('./objs', str(fixindex), f'{fixindex}.obj'), + scale="1 1 1" + ) + ET.SubElement(visual, "origin", xyz="0 0 0", rpy="0 0 0") + + # Chain parts with fixed joints + for i in range(len(fixlist) - 1): + parentname = f'l_{fixlist[i]}' + childname = f'l_{fixlist[i+1]}' + add_fixed_joint( + robot, f'joint_fixed_{fixlist[i]}_{fixlist[i+1]}', + parentname, childname, xyz="0 0 0", rpy="0 0 0" + ) - jsonfile=os.path.join(basepath,filename,'basic_info.json') - geofile=os.path.join(basepath,filename,'objs') - - with open(jsonfile,'r') as fp: - jsondata=json.load(fp) - - mov=jsondata['group_info'] - - robot = ET.Element('robot', name='scene') - link = ET.SubElement(robot, 'link', name='l_world') - add_inertial(link) + # Connect first part to world + add_fixed_joint( + robot, f'joint_fixed_world{fixlist[0]}', + 'l_world', f'l_{fixlist[0]}', xyz="0 0 0", rpy="0 0 0" + ) - save=1 + # ----------------------------------------------------------------- + # Case 2: Articulated object (multiple kinematic groups) + # ----------------------------------------------------------------- + else: + offset = False + # Create base group links + fixlist = mov['0'] + for fixindex in fixlist: + link = ET.SubElement(robot, 'link', name=f'l_{fixindex}') + add_inertial(link) + + mesh_path = os.path.join(geofile, str(fixindex), f'{fixindex}.obj') + if os.path.exists(mesh_path): + visual = ET.SubElement(link, 'visual') + geometry = ET.SubElement(visual, "geometry") + ET.SubElement( + geometry, "mesh", + filename=os.path.join('./objs', str(fixindex), f'{fixindex}.obj'), + scale="1 1 1" + ) + ET.SubElement(visual, "origin", xyz="0 0 0", rpy="0 0 0") + + # Chain base parts + for i in range(len(fixlist) - 1): + parentname = f'l_{fixlist[i]}' + childname = f'l_{fixlist[i+1]}' + add_fixed_joint( + robot, f'joint_fixed_{fixlist[i]}_{fixlist[i+1]}', + parentname, childname, xyz="0 0 0", rpy="0 0 0" + ) + + add_fixed_joint( + robot, f'joint_fixed_world{fixlist[0]}', + 'l_world', f'l_{fixlist[0]}', xyz="0 0 0", rpy="0 0 0" + ) - if len(mov)==1: - fixlist=mov['0'] + # ------------------------------------------------------------- + # Process movable groups (groups 1, 2, 3, ...) + # ------------------------------------------------------------- + groupnum = len(mov) + for groupindex in range(1, groupnum): + fixlist = mov[str(groupindex)][0] + + # Create links for parts in this group for fixindex in fixlist: - link = ET.SubElement(robot, 'link', name='l_'+str(fixindex)) + link = ET.SubElement(robot, 'link', name=f'l_{fixindex}') add_inertial(link) - if os.path.exists(os.path.join(geofile,str(fixindex),str(fixindex)+'.obj')): + + mesh_path = os.path.join(geofile, str(fixindex), f'{fixindex}.obj') + if os.path.exists(mesh_path): visual = ET.SubElement(link, 'visual') geometry = ET.SubElement(visual, "geometry") - ET.SubElement(geometry, "mesh", filename=os.path.join('./objs',str(fixindex),str(fixindex)+'.obj'), scale="1 1 1") + ET.SubElement( + geometry, "mesh", + filename=os.path.join('./objs', str(fixindex), f'{fixindex}.obj'), + scale="1 1 1" + ) ET.SubElement(visual, "origin", xyz="0 0 0", rpy="0 0 0") - for i in range(len(fixlist)-1): - parentname='l_'+str(fixlist[i]) - childname='l_'+str(fixlist[i+1]) - add_fixed_joint(robot, 'joint_fixed_'+str(fixlist[i])+'_'+str(fixlist[i+1]), parentname, childname, xyz="0 0 0", rpy="0 0 0") + # Chain parts within this group with fixed joints + for i in range(len(fixlist) - 1): + parentname = f'l_{fixlist[i]}' + childname = f'l_{fixlist[i+1]}' + add_fixed_joint( + robot, f'joint_fixed_{fixlist[i]}_{fixlist[i+1]}', + parentname, childname, xyz="0 0 0", rpy="0 0 0" + ) - add_fixed_joint(robot, 'joint_fixed_world'+str(fixlist[0]), 'l_world', 'l_'+str(fixlist[0]), xyz="0 0 0", rpy="0 0 0") - - else: - - offset=False + # Determine parent and child for this kinematic group + parent_group_data = mov[mov[str(groupindex)][1]] + if isinstance(parent_group_data[0], int): + parentgroupindex = str(parent_group_data[0]) + else: + parentgroupindex = str(parent_group_data[0][0]) + + childgroupindex = fixlist[0] + parentgroupname = f'l_{parentgroupindex}' + childgroupname = f'l_{childgroupindex}' + + # Create abstract link for joint connection + abs_link = ET.SubElement( + robot, 'link', + name=f'abstract_{parentgroupindex}_{childgroupindex}' + ) + add_inertial(abs_link) + + joint_type = mov[str(groupindex)][-1] + params = mov[str(groupindex)][-2] + + # --------------------------------------------------------- + # JOINT TYPE A: Free/Floating (6-DOF) + # --------------------------------------------------------- + if joint_type == 'A': + add_fixed_joint( + robot, + f'joint_fixed_abstract_{parentgroupindex}_{childgroupindex}', + f'abstract_{parentgroupindex}_{childgroupindex}', + childgroupname, xyz="0 0 0", rpy="0 0 0" + ) - fixlist=mov['0'] - for fixindex in fixlist: - link = ET.SubElement(robot, 'link', name='l_'+str(fixindex)) - add_inertial(link) - if os.path.exists(os.path.join(geofile,str(fixindex),str(fixindex)+'.obj')): - visual = ET.SubElement(link, 'visual') - geometry = ET.SubElement(visual, "geometry") - ET.SubElement(geometry, "mesh", filename=os.path.join('./objs',str(fixindex),str(fixindex)+'.obj'), scale="1 1 1") - ET.SubElement(visual, "origin", xyz="0 0 0", rpy="0 0 0") + joint = ET.SubElement( + robot, "joint", + name=f'joint_free_{parentgroupname}_abstract_{parentgroupindex}_{childgroupindex}', + type="floating" + ) + ET.SubElement(joint, "parent", link=parentgroupname) + ET.SubElement(joint, "child", link=f'abstract_{parentgroupindex}_{childgroupindex}') + ET.SubElement(joint, "origin", xyz="0 0 0", rpy="0 0 0") + + # --------------------------------------------------------- + # JOINT TYPE B: Prismatic/Slide (1-DOF translation) + # --------------------------------------------------------- + elif joint_type == 'B': + save += 1 + add_fixed_joint( + robot, + f'joint_fixed_abstract_{parentgroupindex}_{childgroupindex}', + f'abstract_{parentgroupindex}_{childgroupindex}', + childgroupname, xyz="0 0 0", rpy="0 0 0" + ) - for i in range(len(fixlist)-1): - parentname='l_'+str(fixlist[i]) - childname='l_'+str(fixlist[i+1]) - add_fixed_joint(robot, 'joint_fixed_'+str(fixlist[i])+'_'+str(fixlist[i+1]), parentname, childname, xyz="0 0 0", rpy="0 0 0") - add_fixed_joint(robot, 'joint_fixed_world'+str(fixlist[0]), 'l_world', 'l_'+str(fixlist[0]), xyz="0 0 0", rpy="0 0 0") - - - groupnum=len(mov) - for groupindex in range(1,groupnum): - fixlist=mov[str(groupindex)][0] - for fixindex in fixlist: - link = ET.SubElement(robot, 'link', name='l_'+str(fixindex)) - add_inertial(link) - if os.path.exists(os.path.join(geofile,str(fixindex),str(fixindex)+'.obj')): - visual = ET.SubElement(link, 'visual') - geometry = ET.SubElement(visual, "geometry") - ET.SubElement(geometry, "mesh", filename=os.path.join('./objs',str(fixindex),str(fixindex)+'.obj'), scale="1 1 1") - ET.SubElement(visual, "origin", xyz="0 0 0", rpy="0 0 0") - - for i in range(len(fixlist)-1): - parentname='l_'+str(fixlist[i]) - childname='l_'+str(fixlist[i+1]) - add_fixed_joint(robot, 'joint_fixed_'+str(fixlist[i])+'_'+str(fixlist[i+1]), parentname, childname, xyz="0 0 0", rpy="0 0 0") - if isinstance(mov[mov[str(groupindex)][1]][0], int): - parentgroupindex=str(mov[mov[str(groupindex)][1]][0]) - else: - parentgroupindex=str(mov[mov[str(groupindex)][1]][0][0]) + # Slide axis direction + xyz = f'{params[0]} {params[1]} {params[2]}' - childgroupindex=fixlist[0] - parentgroupname='l_'+str(parentgroupindex) - childgroupname='l_'+str(childgroupindex) + joint = ET.SubElement( + robot, "joint", + name=f'joint_prismatic_{parentgroupname}_abstract_{parentgroupindex}_{childgroupindex}', + type="prismatic" + ) + ET.SubElement(joint, "parent", link=parentgroupname) + ET.SubElement(joint, "child", link=f'abstract_{parentgroupindex}_{childgroupindex}') + ET.SubElement(joint, "origin", xyz="0 0 0", rpy="0 0 0") + ET.SubElement(joint, "axis", xyz=xyz) + ET.SubElement( + joint, "limit", + lower=str(params[-2]), upper=str(params[-1]), + effort="2000.0", velocity="2.0" + ) - abs_link = ET.SubElement(robot, 'link', name='abstract_'+str(parentgroupindex)+'_'+str(childgroupindex)) - add_inertial(abs_link) + # --------------------------------------------------------- + # JOINT TYPE C: Revolute/Hinge (1-DOF rotation) + # --------------------------------------------------------- + elif joint_type == 'C': + save += 1 + # Axis position and negative for child offset + point = f'{params[3]} {params[4]} {params[5]}' + pointrev = f'{-params[3]} {-params[4]} {-params[5]}' + xyz = f'{params[0]} {params[1]} {params[2]}' + + add_fixed_joint( + robot, + f'joint_fixed_abstract_{parentgroupindex}_{childgroupindex}', + f'abstract_{parentgroupindex}_{childgroupindex}', + childgroupname, xyz=pointrev, rpy="0 0 0" + ) - if mov[str(groupindex)][-1]=='A': - add_fixed_joint(robot, 'joint_fixed_'+'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), 'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), childgroupname, xyz="0 0 0", rpy="0 0 0") - - joint = ET.SubElement(robot, "joint", name='joint_free_'+parentgroupname+'_'+'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), type="floating") - ET.SubElement(joint, "parent", link=parentgroupname) - ET.SubElement(joint, "child", link='abstract_'+str(parentgroupindex)+'_'+str(childgroupindex)) - ET.SubElement(joint, "origin", xyz="0 0 0", rpy="0 0 0") - - elif mov[str(groupindex)][-1]=='B': - save+=1 - add_fixed_joint(robot, 'joint_fixed_'+'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), 'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), childgroupname, xyz="0 0 0", rpy="0 0 0") - - xyz=str(mov[str(groupindex)][-2][0])+' '+str(mov[str(groupindex)][-2][1])+' '+str(mov[str(groupindex)][-2][2]) - - joint = ET.SubElement(robot, "joint", name='joint_prismatic_'+parentgroupname+'_'+'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), type="prismatic") - ET.SubElement(joint, "parent", link=parentgroupname) - ET.SubElement(joint, "child", link='abstract_'+str(parentgroupindex)+'_'+str(childgroupindex)) - ET.SubElement(joint, "origin", xyz="0 0 0", rpy="0 0 0") - ET.SubElement(joint, "axis", xyz=xyz) - ET.SubElement(joint, "limit", lower=str(mov[str(groupindex)][-2][-2]), upper=str(mov[str(groupindex)][-2][-1]), effort="2000.0", velocity="2.0") - - elif mov[str(groupindex)][-1]=='C': - save+=1 - point=str(mov[str(groupindex)][-2][3])+' '+str(mov[str(groupindex)][-2][4])+' '+str(mov[str(groupindex)][-2][5]) - pointrev=str(-mov[str(groupindex)][-2][3])+' '+str(-mov[str(groupindex)][-2][4])+' '+str(-mov[str(groupindex)][-2][5]) - xyz=str(mov[str(groupindex)][-2][0])+' '+str(mov[str(groupindex)][-2][1])+' '+str(mov[str(groupindex)][-2][2]) - - add_fixed_joint(robot, 'joint_fixed_'+'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), 'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), childgroupname, xyz=pointrev, rpy="0 0 0") - - - if mov[str(groupindex)][-2][-2]==-1 and mov[str(groupindex)][-2][-1]==1: - joint = ET.SubElement(robot, "joint", name='joint_revolute_'+parentgroupname+'_'+'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), type="continuous") - else: - joint = ET.SubElement(robot, "joint", name='joint_revolute_'+parentgroupname+'_'+'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), type="revolute") - ET.SubElement(joint, "parent", link=parentgroupname) - ET.SubElement(joint, "child", link='abstract_'+str(parentgroupindex)+'_'+str(childgroupindex)) - - ET.SubElement(joint, "origin", xyz=point, rpy="0 0 0") - ET.SubElement(joint, "axis", xyz=xyz) - - if mov[str(groupindex)][-2][-2]==-1 and mov[str(groupindex)][-2][-1]==1: - ET.SubElement(joint, "limit", effort="2000.0", velocity="2.0") - else: - - ET.SubElement(joint, "limit", lower=str(mov[str(groupindex)][-2][-2]*np.pi), upper=str(mov[str(groupindex)][-2][-1]*np.pi), effort="2000.0", velocity="2.0") - - elif mov[str(groupindex)][-1]=='D': - save+=1 - - point=str(mov[str(groupindex)][-2][3])+' '+str(mov[str(groupindex)][-2][4])+' '+str(mov[str(groupindex)][-2][5]) - pointrev=str(-mov[str(groupindex)][-2][3])+' '+str(-mov[str(groupindex)][-2][4])+' '+str(-mov[str(groupindex)][-2][5]) - xyz=str(mov[str(groupindex)][-2][0])+' '+str(mov[str(groupindex)][-2][1])+' '+str(mov[str(groupindex)][-2][2]) - - add_fixed_joint(robot, 'joint_fixed_'+'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), 'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), childgroupname, xyz=pointrev, rpy="0 0 0") - - abs_linkx = ET.SubElement(robot, 'link', name='abstract_x_'+str(parentgroupindex)+'_'+str(childgroupindex)) - add_inertial(abs_linkx,pointrev) - abs_linkz = ET.SubElement(robot, 'link', name='abstract_z_'+str(parentgroupindex)+'_'+str(childgroupindex)) - add_inertial(abs_linkz,pointrev) - - joint = ET.SubElement(robot, "joint", name='joint_hinge_y_'+parentgroupname+'_'+'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), type="revolute") - ET.SubElement(joint, "parent", link=parentgroupname) - ET.SubElement(joint, "child", link='abstract_z_'+str(parentgroupindex)+'_'+str(childgroupindex)) + # Check if continuous (unlimited range) + is_continuous = (params[-2] == -1 and params[-1] == 1) + + joint = ET.SubElement( + robot, "joint", + name=f'joint_revolute_{parentgroupname}_abstract_{parentgroupindex}_{childgroupindex}', + type="continuous" if is_continuous else "revolute" + ) + ET.SubElement(joint, "parent", link=parentgroupname) + ET.SubElement(joint, "child", link=f'abstract_{parentgroupindex}_{childgroupindex}') + ET.SubElement(joint, "origin", xyz=point, rpy="0 0 0") + ET.SubElement(joint, "axis", xyz=xyz) - ET.SubElement(joint, "origin", xyz=point, rpy="0 0 0") - ET.SubElement(joint, "axis", xyz="0 0 1") - ET.SubElement(joint, "limit", lower=str(-np.pi), upper=str(np.pi), effort="2000.0", velocity="2.0") + if is_continuous: + ET.SubElement(joint, "limit", effort="2000.0", velocity="2.0") + else: + ET.SubElement( + joint, "limit", + lower=str(params[-2] * np.pi), + upper=str(params[-1] * np.pi), + effort="2000.0", velocity="2.0" + ) + + # --------------------------------------------------------- + # JOINT TYPE D: Ball/Spherical (3-DOF rotation) + # Implemented as 3 chained revolute joints (ZXY Euler) + # --------------------------------------------------------- + elif joint_type == 'D': + save += 1 + + point = f'{params[3]} {params[4]} {params[5]}' + pointrev = f'{-params[3]} {-params[4]} {-params[5]}' + xyz = f'{params[0]} {params[1]} {params[2]}' + + add_fixed_joint( + robot, + f'joint_fixed_abstract_{parentgroupindex}_{childgroupindex}', + f'abstract_{parentgroupindex}_{childgroupindex}', + childgroupname, xyz=pointrev, rpy="0 0 0" + ) - joint = ET.SubElement(robot, "joint", name='joint_hinge_z_'+parentgroupname+'_'+'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), type="revolute") - ET.SubElement(joint, "parent", link='abstract_z_'+str(parentgroupindex)+'_'+str(childgroupindex)) - ET.SubElement(joint, "child", link='abstract_x_'+str(parentgroupindex)+'_'+str(childgroupindex)) + # Create intermediate links for ball joint decomposition + abs_linkx = ET.SubElement( + robot, 'link', + name=f'abstract_x_{parentgroupindex}_{childgroupindex}' + ) + add_inertial(abs_linkx, pointrev) + + abs_linkz = ET.SubElement( + robot, 'link', + name=f'abstract_z_{parentgroupindex}_{childgroupindex}' + ) + add_inertial(abs_linkz, pointrev) - ET.SubElement(joint, "origin", xyz="0 0 0", rpy="0 0 0") - ET.SubElement(joint, "axis", xyz="1 0 0") - ET.SubElement(joint, "limit", lower=str(-np.pi), upper=str(np.pi), effort="2000.0", velocity="2.0") + # First rotation joint (Z-axis) + joint = ET.SubElement( + robot, "joint", + name=f'joint_hinge_y_{parentgroupname}_abstract_{parentgroupindex}_{childgroupindex}', + type="revolute" + ) + ET.SubElement(joint, "parent", link=parentgroupname) + ET.SubElement(joint, "child", link=f'abstract_z_{parentgroupindex}_{childgroupindex}') + ET.SubElement(joint, "origin", xyz=point, rpy="0 0 0") + ET.SubElement(joint, "axis", xyz="0 0 1") + ET.SubElement( + joint, "limit", + lower=str(-np.pi), upper=str(np.pi), + effort="2000.0", velocity="2.0" + ) - joint = ET.SubElement(robot, "joint", name='joint_hinge_x_'+parentgroupname+'_'+'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), type="revolute") - ET.SubElement(joint, "parent", link='abstract_x_'+str(parentgroupindex)+'_'+str(childgroupindex)) - ET.SubElement(joint, "child", link='abstract_'+str(parentgroupindex)+'_'+str(childgroupindex)) + # Second rotation joint (X-axis) + joint = ET.SubElement( + robot, "joint", + name=f'joint_hinge_z_{parentgroupname}_abstract_{parentgroupindex}_{childgroupindex}', + type="revolute" + ) + ET.SubElement(joint, "parent", link=f'abstract_z_{parentgroupindex}_{childgroupindex}') + ET.SubElement(joint, "child", link=f'abstract_x_{parentgroupindex}_{childgroupindex}') + ET.SubElement(joint, "origin", xyz="0 0 0", rpy="0 0 0") + ET.SubElement(joint, "axis", xyz="1 0 0") + ET.SubElement( + joint, "limit", + lower=str(-np.pi), upper=str(np.pi), + effort="2000.0", velocity="2.0" + ) - ET.SubElement(joint, "origin", xyz="0 0 0", rpy="0 0 0") - ET.SubElement(joint, "axis", xyz="0 1 0") - ET.SubElement(joint, "limit", lower=str(-np.pi), upper=str(np.pi), effort="2000.0", velocity="2.0") + # Third rotation joint (Y-axis) + joint = ET.SubElement( + robot, "joint", + name=f'joint_hinge_x_{parentgroupname}_abstract_{parentgroupindex}_{childgroupindex}', + type="revolute" + ) + ET.SubElement(joint, "parent", link=f'abstract_x_{parentgroupindex}_{childgroupindex}') + ET.SubElement(joint, "child", link=f'abstract_{parentgroupindex}_{childgroupindex}') + ET.SubElement(joint, "origin", xyz="0 0 0", rpy="0 0 0") + ET.SubElement(joint, "axis", xyz="0 1 0") + ET.SubElement( + joint, "limit", + lower=str(-np.pi), upper=str(np.pi), + effort="2000.0", velocity="2.0" + ) - elif mov[str(groupindex)][-1]=='CB': - save+=1 + # --------------------------------------------------------- + # JOINT TYPE CB: Combined Revolute + Prismatic + # --------------------------------------------------------- + elif joint_type == 'CB': + save += 1 - point=str(mov[str(groupindex)][-2][3])+' '+str(mov[str(groupindex)][-2][4])+' '+str(mov[str(groupindex)][-2][5]) - pointrev=str(-mov[str(groupindex)][-2][3])+' '+str(-mov[str(groupindex)][-2][4])+' '+str(-mov[str(groupindex)][-2][5]) - xyz=str(mov[str(groupindex)][-2][0])+' '+str(mov[str(groupindex)][-2][1])+' '+str(mov[str(groupindex)][-2][2]) - xyz1=str(mov[str(groupindex)][-2][8])+' '+str(mov[str(groupindex)][-2][9])+' '+str(mov[str(groupindex)][-2][10]) + # Axis position + point = f'{params[3]} {params[4]} {params[5]}' + pointrev = f'{-params[3]} {-params[4]} {-params[5]}' + + # Revolute axis direction + xyz = f'{params[0]} {params[1]} {params[2]}' + + # Slide axis direction (params[8:11]) + xyz1 = f'{params[8]} {params[9]} {params[10]}' + + add_fixed_joint( + robot, + f'joint_fixed_abstract_{parentgroupindex}_{childgroupindex}', + f'abstract_{parentgroupindex}_{childgroupindex}', + childgroupname, xyz=pointrev, rpy="0 0 0" + ) - add_fixed_joint(robot, 'joint_fixed_'+'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), 'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), childgroupname, xyz=pointrev, rpy="0 0 0") - - - abs_linkx = ET.SubElement(robot, 'link', name='abstract_x_'+str(parentgroupindex)+'_'+str(childgroupindex)) - add_inertial(abs_linkx) + # Intermediate link for combined joint + abs_linkx = ET.SubElement( + robot, 'link', + name=f'abstract_x_{parentgroupindex}_{childgroupindex}' + ) + add_inertial(abs_linkx) - joint = ET.SubElement(robot, "joint", name='joint_prim_y_'+parentgroupname+'_'+'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), type="prismatic") - ET.SubElement(joint, "parent", link=parentgroupname) - ET.SubElement(joint, "child", link='abstract_x_'+str(parentgroupindex)+'_'+str(childgroupindex)) - - ET.SubElement(joint, "origin", xyz=point, rpy="0 0 0") - ET.SubElement(joint, "axis", xyz=xyz1) - ET.SubElement(joint, "limit", lower=str(mov[str(groupindex)][-2][-2]), upper=str(mov[str(groupindex)][-2][-1]), effort="2000.0", velocity="2.0") - - if mov[str(groupindex)][-2][6]==-1 and mov[str(groupindex)][-2][7]==1: - joint = ET.SubElement(robot, "joint", name='joint_revo_x_'+parentgroupname+'_'+'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), type="continuous") - else: - joint = ET.SubElement(robot, "joint", name='joint_revo_x_'+parentgroupname+'_'+'abstract_'+str(parentgroupindex)+'_'+str(childgroupindex), type="revolute") - ET.SubElement(joint, "parent", link='abstract_x_'+str(parentgroupindex)+'_'+str(childgroupindex)) - ET.SubElement(joint, "child", link='abstract_'+str(parentgroupindex)+'_'+str(childgroupindex)) - - ET.SubElement(joint, "origin", xyz="0 0 0", rpy="0 0 0") - ET.SubElement(joint, "axis", xyz=xyz) - - if mov[str(groupindex)][-2][6]==-1 and mov[str(groupindex)][-2][7]==1: - ET.SubElement(joint, "limit", effort="2000.0", velocity="2.0") - else: - ET.SubElement(joint, "limit", lower=str(mov[str(groupindex)][-2][6]*np.pi), upper=str(mov[str(groupindex)][-2][7]*np.pi), effort="2000.0", velocity="2.0") + # Prismatic joint first + joint = ET.SubElement( + robot, "joint", + name=f'joint_prim_y_{parentgroupname}_abstract_{parentgroupindex}_{childgroupindex}', + type="prismatic" + ) + ET.SubElement(joint, "parent", link=parentgroupname) + ET.SubElement(joint, "child", link=f'abstract_x_{parentgroupindex}_{childgroupindex}') + ET.SubElement(joint, "origin", xyz=point, rpy="0 0 0") + ET.SubElement(joint, "axis", xyz=xyz1) + ET.SubElement( + joint, "limit", + lower=str(params[-2]), upper=str(params[-1]), + effort="2000.0", velocity="2.0" + ) + # Revolute joint second + is_continuous = (params[6] == -1 and params[7] == 1) + + joint = ET.SubElement( + robot, "joint", + name=f'joint_revo_x_{parentgroupname}_abstract_{parentgroupindex}_{childgroupindex}', + type="continuous" if is_continuous else "revolute" + ) + ET.SubElement(joint, "parent", link=f'abstract_x_{parentgroupindex}_{childgroupindex}') + ET.SubElement(joint, "child", link=f'abstract_{parentgroupindex}_{childgroupindex}') + ET.SubElement(joint, "origin", xyz="0 0 0", rpy="0 0 0") + ET.SubElement(joint, "axis", xyz=xyz) + if is_continuous: + ET.SubElement(joint, "limit", effort="2000.0", velocity="2.0") else: - print('error type') - - - - tree = ET.ElementTree(robot) - ET.indent(tree, space=" ", level=0) - tree.write(os.path.join(basepath,filename,'basic.urdf'), encoding="utf-8", xml_declaration=True) + ET.SubElement( + joint, "limit", + lower=str(params[6] * np.pi), + upper=str(params[7] * np.pi), + effort="2000.0", velocity="2.0" + ) + + # Unknown joint type + else: + print(f'Error: Unknown joint type: {joint_type}') + + # ===================================================================== + # STEP 5: Write URDF file + # ===================================================================== + tree = ET.ElementTree(robot) + ET.indent(tree, space=" ", level=0) + tree.write( + os.path.join(basepath, filename, 'basic.urdf'), + encoding="utf-8", xml_declaration=True + ) + # ===================================================================== + # STEP 6: Generate MJCF file + # ===================================================================== + parts_cfg = jsondata['parts'] + + # Calculate model scale from dimensions + nums = [int(x) for x in re.findall(r'\d+', jsondata['dimension'])] + max_num = max(nums) / 100 # Convert to meters + + # Prepare part configurations for MJCF + for partind in range(len(parts_cfg)): + part = parts_cfg[partind] + part['name'] = f'l_{part["label"]}_{part["name"]}' + part['mesh_file'] = os.path.join('./objs', str(partind), f'{partind}.obj') + part['scale'] = f'{max_num} {max_num} {max_num}' + part['tex_file'] = os.path.join('./objs', str(partind), 'material_0.png') + + # Convert density from g/cm³ to kg/m³ + part['density'] = float(part['density'].split('g/cm')[0]) * 1000 + part['fluidshape'] = 'ellipsoid' + part['contype'] = '1' + part['conaffinity'] = '1' + + # Copy skybox texture + shutil.copy( + 'mjcf_source/desert.png', + os.path.join(basepath, filename, 'desert.png') + ) - parts_cfg = jsondata['parts'] + # Generate MJCF + out = generate_mjcf( + jsondata=jsondata, + out_path=os.path.join(basepath, filename, 'basic.xml'), + parts=parts_cfg, + fixed_base=args.fixed_base, + deformable=args.deformable + ) - - nums = [int(x) for x in re.findall(r'\d+', jsondata['dimension'])] - max_num = max(nums)/100 - for partind in range(len(parts_cfg)): - parts_cfg[partind]['name']='l_'+str(parts_cfg[partind]['label'])+'_'+parts_cfg[partind]['name'] - parts_cfg[partind]['mesh_file']=os.path.join('./objs',str(partind),str(partind)+'.obj') - parts_cfg[partind]['scale']=str(max_num)+' '+str(max_num)+' '+str(max_num) - parts_cfg[partind]['tex_file']=os.path.join('./objs',str(partind),'material_0.png') - parts_cfg[partind]['density']=float(parts_cfg[partind]['density'].split('g/cm')[0])*1000 - parts_cfg[partind]['fluidshape']='ellipsoid' - parts_cfg[partind]['contype']='1' - parts_cfg[partind]['conaffinity']='1' - shutil.copy('mjcf_source/desert.png',os.path.join(basepath,filename,'desert.png')) - out = generate_mjcf(jsondata=jsondata,out_path=os.path.join(basepath,filename,'basic.xml'), parts=parts_cfg,fixed_base=args.fixed_base,deformable=args.deformable) - - logger.info('complete: '+filename) - else: - logger.info('skip: '+filename) + logger.info(f'Completed: {filename}')