-
Notifications
You must be signed in to change notification settings - Fork 11
Open
Description
I trained a segmentation model and ran inference on the validation data, but the performance is extremely poor. Could anyone explain what might be going wrong?
gt:
# Standard library.
import argparse
import os
import warnings

# Third-party.
import numba as nb
import numpy as np
import open3d as o3d
import torch
import yaml
from mmengine import Config
from mmengine.logging.logger import MMLogger

# Project-local.
from dataloader.dataset import get_nuScenes_label_name
from utils.load_save_util import revise_ckpt

# Silence noisy library warnings for cleaner inference logs.
warnings.filterwarnings("ignore")
# Color lookup for nuScenes lidar-seg labels:
# ID 0 is ignore/noise, IDs 1-16 are the semantic classes.
COLOR_MAP = {
    0:  [0, 0, 0],        # ignore / noise (black)
    # movable objects & traffic furniture (1-10)
    1:  [0.6, 0.4, 0.2],  # barrier (brown)
    2:  [1.0, 0.6, 0.0],  # bicycle (orange)
    3:  [0.6, 0.0, 1.0],  # bus (purple)
    4:  [1.0, 0.0, 0.0],  # car (red)
    5:  [0.4, 0.2, 0.6],  # construction_vehicle (indigo)
    6:  [1.0, 0.0, 1.0],  # motorcycle (magenta)
    7:  [0.0, 1.0, 1.0],  # pedestrian (cyan)
    8:  [1.0, 1.0, 0.0],  # traffic_cone (yellow)
    9:  [0.5, 0.5, 0.5],  # trailer (gray)
    10: [0.2, 0.2, 0.8],  # truck (blue)
    # static structure & ground (11-16)
    11: [0.2, 0.2, 0.2],  # driveable_surface (dark gray)
    12: [0.8, 0.8, 0.8],  # other_flat (light gray)
    13: [0.0, 1.0, 0.0],  # sidewalk (green)
    14: [0.2, 0.6, 0.2],  # terrain (dim green)
    15: [0.8, 0.6, 0.0],  # manmade (tan)
    16: [0.0, 0.4, 0.0],  # vegetation (dark green)
}


def get_colors_from_labels(labels):
    """Map a (N,) label array/tensor to an (N, 3) RGB color array.

    Labels present in COLOR_MAP get their mapped color. Any row whose
    color sums to zero afterwards — i.e. unmapped IDs, and label 0
    (mapped to black) — is turned white so those points stay visible.
    """
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()
    colors = np.zeros((labels.shape[0], 3))
    for cls_id, rgb in COLOR_MAP.items():
        colors[labels == cls_id] = rgb
    # All-zero rows (unknown IDs and the black "ignore" class) -> white.
    colors[colors.sum(axis=1) == 0] = [1, 1, 1]
    return colors
def visualize_results(points_xyz, pred_labels, title="Prediction Result"):
    """Show predicted labels as a colored Open3D point cloud and save it.

    points_xyz: (N, 3) coordinates (ndarray or torch tensor).
    pred_labels: (N,) predicted label ids (ndarray or torch tensor).
    Opens an interactive window, then writes "seg_pcd_pred.pcd".
    """
    # Normalize both inputs to numpy before handing them to Open3D.
    if isinstance(points_xyz, torch.Tensor):
        points_xyz = points_xyz.cpu().numpy()
    if isinstance(pred_labels, torch.Tensor):
        pred_labels = pred_labels.cpu().numpy()

    # Build the colored prediction cloud.
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points_xyz[:, :3])
    cloud.colors = o3d.utility.Vector3dVector(get_colors_from_labels(pred_labels))

    # Open the interactive viewer, then persist the cloud to disk.
    print(f"Opening visualization: {title}")
    o3d.visualization.draw_geometries([cloud],
                                      window_name=title,
                                      width=1024, height=768,
                                      left=50, top=50)
    o3d.io.write_point_cloud("seg_pcd_pred.pcd", cloud)
def polar2cat(input_xyz_polar):
    """Convert polar coordinates back to Cartesian.

    Axis 0 of the input indexes (rho, phi, z); returns a stack of
    (x, y, z) along axis 0 with the same trailing shape.
    """
    rho, phi, z = input_xyz_polar[0], input_xyz_polar[1], input_xyz_polar[2]
    return np.stack((rho * np.cos(phi), rho * np.sin(phi), z), axis=0)
def cart2polar(input_xyz):
    """Convert Cartesian points (N, 3+) to polar (rho, phi, z), shape (N, 3)."""
    x, y = input_xyz[:, 0], input_xyz[:, 1]
    radius = np.sqrt(x ** 2 + y ** 2)
    angle = np.arctan2(y, x)
    return np.stack((radius, angle, input_xyz[:, 2]), axis=1)


def load_custom_data(dataset_config,
                     lidar_path="data/nuscenes/samples/LIDAR_TOP/n015-2018-08-02-17-16-37+0800__LIDAR_TOP__1533201470448696.pcd.bin"):
    """Load a raw nuScenes lidar sweep as an (N, 5) float32 array.

    Columns are [x, y, z, intensity, ring]. ``dataset_config`` is accepted
    for interface compatibility but unused here.
    """
    raw = np.fromfile(lidar_path, dtype=np.float32)
    return raw.reshape(-1, 5)


def custom_data_preprocess(points, cfg):
    """Turn one raw (N, 5) sweep into the model's three input tensors.

    Returns float32 tensors with a leading batch axis of 1:
      - per-point features: [offset to coarse-voxel center (polar),
        polar xyz, cartesian xy, remaining columns (intensity, ring)]
      - per-point coarse-grid integer indices (cast to float)
      - per-point fine-grid fractional indices
    Requires ``cfg.dataset_params['fixed_volume_space']`` to be truthy.
    """
    ds_cfg = cfg.dataset_params
    coarse_grid = np.asarray(cfg.grid_size).astype(np.int32)
    fine_grid = np.asarray(ds_cfg['grid_size_vox']).astype(np.int32)
    assert ds_cfg['fixed_volume_space']
    max_bound = np.asarray(ds_cfg['max_volume_space'])
    min_bound = np.asarray(ds_cfg['min_volume_space'])

    xyz, feat = points[:, :3], points[:, 3:]
    # Work in cylindrical coordinates for the voxel grid.
    xyz_pol = cart2polar(xyz)

    crop_range = max_bound - min_bound
    intervals = crop_range / coarse_grid
    intervals_vox = crop_range / fine_grid
    if (intervals == 0).any():
        print("Zero interval!")

    # Clip into the fixed volume, then compute grid indices.
    clipped = np.clip(xyz_pol, min_bound, max_bound - 1e-3)
    grid_ind = np.floor((clipped - min_bound) / intervals).astype(np.int32)
    grid_ind_vox_float = ((clipped - min_bound) / intervals_vox).astype(np.float32)

    # Center each point on its coarse voxel (PointNet-style local offset).
    voxel_centers = (grid_ind.astype(np.float32) + 0.5) * intervals + min_bound
    centered = xyz_pol - voxel_centers
    features = np.concatenate((centered, xyz_pol, xyz[:, :2], feat), axis=1)

    # Add the batch axis and hand back torch tensors.
    feat_t = torch.from_numpy(features[np.newaxis].astype(np.float32))
    ind_t = torch.from_numpy(grid_ind[np.newaxis].astype(np.float32))
    ind_vox_t = torch.from_numpy(grid_ind_vox_float[np.newaxis].astype(np.float32))
    return feat_t, ind_t, ind_vox_t
# --- [modified function] main ---
def main(args):
    """Run single-sweep lidar segmentation inference and visualize it.

    Builds the model from ``args.py_config``, loads ``args.ckpt_path``,
    preprocesses one hard-coded nuScenes sweep, predicts per-point labels,
    prints a per-class count table, and shows the colored point cloud.
    Raises RuntimeError when no CUDA device is available.
    """
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cpu":
        raise RuntimeError("GPU is required for inference")
    cfg = Config.fromfile(args.py_config)
    dataset_config = cfg.dataset_params
    logger = MMLogger(name='infer_log', log_file=args.log_file, log_level='INFO')
    logger.info(f"Checkpoint Path: {args.ckpt_path}")
    logger.info(f"Device: {device}")
    # nuScenes label names, used below for the per-class statistics table.
    SemKITTI_label_name = get_nuScenes_label_name(dataset_config["label_mapping"])
    unique_label = np.asarray(cfg.unique_label)
    print(f"unique_label: {unique_label}")
    unique_label_str = [SemKITTI_label_name[x] for x in unique_label]
    print(f"unique_label_str: {unique_label_str}")
    # Model construction and checkpoint loading.
    from builder import model_builder
    my_model = model_builder.build(cfg.model).to(device)
    n_params = sum(p.numel() for p in my_model.parameters() if p.requires_grad)
    logger.info(f"Model parameters: {n_params}")
    assert os.path.isfile(args.ckpt_path), f"Checkpoint not found: {args.ckpt_path}"
    ckpt = torch.load(args.ckpt_path, map_location='cpu')
    if 'state_dict' in ckpt:
        ckpt = ckpt['state_dict']
    # NOTE(review): strict=False silently ignores missing/unexpected keys.
    # A key mismatch here would also explain poor inference quality —
    # inspect the returned IncompatibleKeys when debugging.
    my_model.load_state_dict(revise_ckpt(ckpt), strict=False)
    my_model.eval()
    logger.info("Checkpoint loaded successfully.")
    # Data loading and preprocessing (raw points only; no labels needed).
    raw_points = load_custom_data(dataset_config)  # 1. load raw points only
    print(f"input raw_points shape {raw_points.shape}")
    # raw-label printing removed
    process_data = custom_data_preprocess(raw_points, cfg)  # 2. pass raw points only
    # Inference.
    with torch.no_grad():
        # 3. Unpack the 3 tensors: point features, point->coarse-voxel
        # indices, point->fine-voxel indices.
        (points, val_grid_float, val_grid_vox) = process_data
        val_grid_vox = val_grid_vox.to(torch.float32).cuda()  # val_grid_vox stays float
        points = points.cuda()
        val_grid_float = val_grid_float.to(torch.float32).cuda()
        print(f"points shape {points.shape}")
        print(f"val_grid_float shape {val_grid_float.shape}")
        print(f"val_grid_vox shape {val_grid_vox.shape}")
        # Forward pass.
        predict_labels_vox, predict_labels_pts = my_model(points=points, grid_ind=val_grid_float,
                                                          grid_ind_vox=val_grid_vox)
        print(f"predict_labels_vox shape {predict_labels_vox.shape}")
        print(f"predict_labels_pts shape {predict_labels_pts.shape}")
        print(f"开始进行后处理")
        # Collapse trailing singleton dims, then argmax over the class axis.
        predict_labels_pts = predict_labels_pts.squeeze(-1).squeeze(-1)
        predict_labels_pts = torch.argmax(predict_labels_pts, dim=1)  # bs, n
        predict_labels_pts = predict_labels_pts.detach().cpu()
        predict_labels_vox = torch.argmax(predict_labels_vox, dim=1)
        predict_labels_vox = predict_labels_vox.detach().cpu()
        print(f"predict_labels_vox shape {predict_labels_vox.shape}")
        print(f"predict_labels_pts shape {predict_labels_pts.shape}")
        print(f"后处理完成")
    # Per-sample reporting and visualization. NOTE(review): block nesting was
    # reconstructed from a flattened source; everything below only touches
    # detached CPU tensors, so placement outside no_grad is safe.
    for count in range(predict_labels_pts.shape[0]):  # batch loop (usually 1)
        pred_cpu = predict_labels_pts.detach().cpu()  # predicted labels
        # Current sample in the batch.
        cur_pred = pred_cpu[count]  # (N,)
        # Count occurrences of each predicted label id.
        unique_labels, counts = torch.unique(cur_pred, return_counts=True)
        print(f"\n--- 预测结果类别数量统计 (Sample {count}) ---")
        # Build the table rows.
        stats_data = []
        total_points = cur_pred.numel()
        for label_id, count_val in zip(unique_labels.tolist(), counts.tolist()):
            # Fall back to 'Unknown' when the id is not in the name map.
            class_name = SemKITTI_label_name.get(label_id, f"Unknown ID {label_id}")
            percentage = (count_val / total_points) * 100
            stats_data.append({
                "ID": label_id,
                "Class Name": class_name,
                "Count": count_val,
                "Percentage": f"{percentage:.2f}%"
            })
        # Emit the statistics as a Markdown table.
        print("| ID | Class Name | Count | Percentage |")
        print("|:--:|:-----------|:------|:-----------|")
        for row in stats_data:
            print(f"| {row['ID']} | {row['Class Name']} | {row['Count']} | {row['Percentage']} |")
        print(f"Total points: {total_points}")
        print("------------------------------------------\n")
        vis_xyz = raw_points[:, :3]  # raw point coordinates (N, 3)
        # 4. Visualize: only the prediction and the raw coordinates.
        visualize_results(vis_xyz, cur_pred, title=f"Prediction Result Sample {count}")
    logger.info("Inference completed.")
if __name__ == '__main__':
    # CLI: model config, checkpoint, and log file; then run inference.
    parser = argparse.ArgumentParser(description="Custom PCD Inference for PointOcc (Pure Inference)")
    parser.add_argument('--py-config', default='config/pointtpv_nusc_lidarseg.py', help='Model config file')
    parser.add_argument('--ckpt-path', type=str, default="weights/seg/epoch_19.pth", help='Model checkpoint path')
    parser.add_argument('--log-file', type=str, default='infer.log', help='Log file path')
    main(parser.parse_args())
Metadata
Metadata
Assignees
Labels
No labels
