From 070109276a5784445e24dd651a553b0a9cfb491f Mon Sep 17 00:00:00 2001 From: guba Date: Mon, 4 Mar 2024 22:57:35 +0800 Subject: [PATCH 01/42] ADD: add initial files for modification. --- config.yml | 3 +- isegm/model/is_plainvit_model.py | 39 ++++++ ...i_out_vit_large448_pascal_new_scheduler.py | 126 ++++++++++++++++++ 3 files changed, 167 insertions(+), 1 deletion(-) create mode 100644 models/iter_mask/multi_out_vit_large448_pascal_new_scheduler.py diff --git a/config.yml b/config.yml index 95457a8..faca3c5 100755 --- a/config.yml +++ b/config.yml @@ -1,5 +1,5 @@ INTERACTIVE_MODELS_PATH: "./weights" -EXPS_PATH: "/playpen-raid2/qinliu/models/model_0125_2023" +EXPS_PATH: "./exps" # Evaluation datasets GRABCUT_PATH: "/playpen-raid2/qinliu/data/GrabCut" @@ -19,6 +19,7 @@ LVIS_v1_PATH: "/playpen-raid2/qinliu/data/COCO_2017" OPENIMAGES_PATH: "./datasets/OpenImages" PASCALVOC_PATH: "/playpen-raid2/qinliu/data/PascalVOC" ADE20K_PATH: "./datasets/ADE20K" +PASCAL_PATH: "/home/gyt/gyt/dataset/data/pascal_person_part" # You can download the weights for HRNet from the repository: # https://github.com/HRNet/HRNet-Image-Classification diff --git a/isegm/model/is_plainvit_model.py b/isegm/model/is_plainvit_model.py index cd0147b..8422fd8 100644 --- a/isegm/model/is_plainvit_model.py +++ b/isegm/model/is_plainvit_model.py @@ -93,3 +93,42 @@ def backbone_forward(self, image, coord_features=None): multi_scale_features = self.neck(backbone_features) return {'instances': self.head(multi_scale_features), 'instances_aux': None} + +class MultiOutVitModel(ISModel): + @serialize + def __init__( + self, + backbone_params={}, + neck_params={}, + head_params={}, + random_split=False, + **kwargs + ): + + super().__init__(**kwargs) + self.random_split = random_split + + self.patch_embed_coords = PatchEmbed( + img_size= backbone_params['img_size'], + patch_size=backbone_params['patch_size'], + in_chans=3 if self.with_prev_mask else 2, + embed_dim=backbone_params['embed_dim'], + ) + + 
self.backbone = VisionTransformer(**backbone_params) + self.neck = SimpleFPN(**neck_params) + self.head = SwinTransfomerSegHead(**head_params) + self.testMsg = "Test Message" + + def backbone_forward(self, image, coord_features=None): + coord_features = self.patch_embed_coords(coord_features) + backbone_features = self.backbone.forward_backbone(image, coord_features, self.random_split) + + # Extract 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32 + B, N, C = backbone_features.shape + grid_size = self.backbone.patch_embed.grid_size + + backbone_features = backbone_features.transpose(-1,-2).view(B, C, grid_size[0], grid_size[1]) + multi_scale_features = self.neck(backbone_features) + + return {'instances': self.head(multi_scale_features), 'instances_aux': None} diff --git a/models/iter_mask/multi_out_vit_large448_pascal_new_scheduler.py b/models/iter_mask/multi_out_vit_large448_pascal_new_scheduler.py new file mode 100644 index 0000000..ed2a397 --- /dev/null +++ b/models/iter_mask/multi_out_vit_large448_pascal_new_scheduler.py @@ -0,0 +1,126 @@ +from isegm.model.is_plainvit_model import MultiOutVitModel +from isegm.utils.exp_imports.default import * +from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss + + +MODEL_NAME = 'cocolvis_vit_large448' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (448, 448) + model_cfg.num_max_points = 24 + + backbone_params = dict( + img_size=model_cfg.crop_size, + patch_size=(16,16), + in_chans=3, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + ) + + neck_params = dict( + in_dim = 1024, + out_dims = [192, 384, 768, 1536], + ) + + head_params = dict( + in_channels=[192, 384, 768, 1536], + in_index=[0, 1, 2, 3], + dropout_ratio=0.1, + num_classes=1, + loss_decode=CrossEntropyLoss(), + align_corners=False, + upsample=cfg.upsample, + channels={'x1': 256, 'x2': 
128, 'x4': 64}[cfg.upsample], + ) + + model = MultiOutVitModel( + use_disks=True, + norm_radius=5, + with_prev_mask=True, + backbone_params=backbone_params, + neck_params=neck_params, + head_params=head_params, + random_split=cfg.random_split, + ) + + model.backbone.init_weights_from_pretrained(cfg.IMAGENET_PRETRAINED_MODELS.MAE_LARGE) + model.to(cfg.device) + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 32 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.80, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = CocoLvisDataset( + cfg.LVIS_v1_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=1000, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=30000, + stuff_prob=0.30 + ) + + valset = CocoLvisDataset( + cfg.LVIS_v1_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=1000, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-5, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[50, 55], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, 
loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + layerwise_decay=cfg.layerwise_decay, + lr_scheduler=lr_scheduler, + checkpoint_interval=[(0, 20), (50, 1)], + image_dump_interval=300, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points, + max_num_next_clicks=3) + trainer.run(num_epochs=55, validation=False) From 18a300afc6a4199ba403ba0971926b393f8c9779 Mon Sep 17 00:00:00 2001 From: guba Date: Mon, 4 Mar 2024 22:58:14 +0800 Subject: [PATCH 02/42] TST: Add some auxiliary code for test issue --- TEST_simple_model.py | 0 testcnnect.py | 8 ++++++++ 2 files changed, 8 insertions(+) create mode 100644 TEST_simple_model.py create mode 100644 testcnnect.py diff --git a/TEST_simple_model.py b/TEST_simple_model.py new file mode 100644 index 0000000..e69de29 diff --git a/testcnnect.py b/testcnnect.py new file mode 100644 index 0000000..5f2adf3 --- /dev/null +++ b/testcnnect.py @@ -0,0 +1,8 @@ +import os +print("Import cv2") +import matplotlib +import cv2 +print(os.getcwd()) +# print file path +print(os.path.abspath(__file__)) +print("Done!") \ No newline at end of file From 55179568ca9a1621b1077bf3356adf3caf1b98bf Mon Sep 17 00:00:00 2001 From: guba Date: Thu, 7 Mar 2024 15:37:34 +0800 Subject: [PATCH 03/42] ADD: add dataset files --- isegm/data/datasets/LIP.py | 241 ++++++++++++++++++++++++++++++ isegm/data/datasets/PASCAL.py | 219 +++++++++++++++++++++++++++ isegm/data/datasets/cityscapes.py | 195 ++++++++++++++++++++++++ 3 files changed, 655 insertions(+) create mode 100644 isegm/data/datasets/LIP.py create mode 100644 isegm/data/datasets/PASCAL.py create mode 100644 isegm/data/datasets/cityscapes.py diff --git a/isegm/data/datasets/LIP.py b/isegm/data/datasets/LIP.py new file mode 100644 index 0000000..ac299f4 --- /dev/null +++ b/isegm/data/datasets/LIP.py @@ -0,0 +1,241 @@ +import os +import pickle as pkl +from pathlib import Path + +import cv2 +import numpy as np + +from isegm.data.base import ISDataset 
+from isegm.data.sample import DSample +from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes + +from tqdm import tqdm +import pickle + + +class LIP(ISDataset): + def __init__(self, dataset_path, split='train', **kwargs): + super().__init__(**kwargs) + assert split in {'train', 'val', 'trainval', 'test'} + self.name = 'LIP' + self.dataset_path = Path(dataset_path) + self._images_path = self.dataset_path / "train_images" + self._insts_path = self.dataset_path / "train_segmentations" + self.init_path = self.dataset_path / "20_cls_interactive_point" + self.dataset_split = split + self.class_num = 20 # 这个class_num 指所有在miou中可以被计算的类,包含背景类但不包含忽略区域 + self.ignore_id = 255 + + self.loadfile = self.dataset_split+".pkl" + if os.path.exists(str(self.dataset_path/self.loadfile)): + with open(str(self.dataset_path/self.loadfile), 'rb') as file: + self.dataset_samples = pickle.load(file) + else: + dataset_samples = [] + idsfile = self.dataset_split+"_id.txt" + with open(str(self.dataset_path/idsfile), "r") as f: + id_list = [line.strip() for line in f.readlines()] + for id in id_list: + img_path = self._images_path/(id+".jpg") + gt_path = self._insts_path/(id+".png") + init_path = self.init_path/(id+".png") + dataset_samples.append((img_path, gt_path, init_path)) + image_id_lst = self.get_images_and_ids_list(dataset_samples) + self.dataset_samples = image_id_lst + # print(image_id_lst[:5]) + + ''' + def get_sample(self, index) -> DSample: + sample_id = self.dataset_samples[index] + image_path = str(self._images_path / f'{sample_id}.jpg') + mask_path = str(self._insts_path / f'{sample_id}.png') + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + instances_mask = cv2.imread(mask_path) + instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) + if self.dataset_split == 'test': + instance_id = self.instance_ids[index] + mask = np.zeros_like(instances_mask) + mask[instances_mask == 220] = 220 # ignored area + 
mask[instances_mask == instance_id] = 1 + objects_ids = [1] + instances_mask = mask + else: + objects_ids = np.unique(instances_mask) + objects_ids = [x for x in objects_ids if x != 0 and x != 220] + + return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[220], sample_id=index) + ''' + + def get_sample(self, index) -> DSample: + sample_path, target_path, instance_ids, init_path = self.dataset_samples[index] + # sample_id = str(sample_id) + # print(sample_id) + # num_zero = 6 - len(sample_id) + # sample_id = '2007_'+'0'*num_zero + sample_id + + image_path = str(sample_path) + mask_path = str(target_path) + init_path = str(init_path) + + # print(image_path) + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + instances_mask = cv2.imread(mask_path) + instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) + + + mask = instances_mask + # mask[instances_mask == 255] = 220 # ignored area + # mask[instances_mask == instance_id] = 1 + objects_ids = instance_ids # 现在instance_ids 是一个列表 + instances_mask = mask + return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[self.ignore_id], sample_id=index, init_clicks=init_path) + + def get_images_and_ids_list(self, dataset_samples, ignore_id = 255): + images_and_ids_list = [] + object_count = 0 + # for i in tqdm(range(len(dataset_samples))): + for i in range(len(dataset_samples)): + image_path, mask_path, init_path = dataset_samples[i] + instances_mask = cv2.imread(str(mask_path)) + instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) + objects_ids = np.unique(instances_mask) + + objects_ids = [x for x in objects_ids if x != ignore_id] + object_count+=len(objects_ids) + # for j in objects_ids: + images_and_ids_list.append([image_path, mask_path ,objects_ids, init_path]) + # print(i,j,objects_ids) + with open(str(self.dataset_path/self.loadfile), "wb") as file: + pickle.dump(images_and_ids_list, file) + 
print("file count: "+str(len(dataset_samples))) + print("object count: "+str(object_count)) + return images_and_ids_list + +class LIP_train(ISDataset): + def __init__(self, dataset_path, split='train', **kwargs): + super().__init__(**kwargs) + assert split in {'train', 'val', 'trainval', 'test'} + + self._buggy_mask_thresh = 0.08 + self._buggy_objects = dict() + + self.name = 'LIP' + self.dataset_path = Path(dataset_path) + self._images_path = self.dataset_path / "train_images" + self._insts_path = self.dataset_path / "train_segmentations" + self.init_path = self.dataset_path / "20_cls_interactive_point" + self.dataset_split = split + self.class_num = 19 # 这个class_num 指所有在miou中可以被计算的类,包含背景类但不包含忽略区域 + self.ignore_id = 0 + + self.loadfile = self.dataset_split+".pkl" + if os.path.exists(str(self.dataset_path/self.loadfile)): + with open(str(self.dataset_path/self.loadfile), 'rb') as file: + self.dataset_samples = pickle.load(file) + else: + dataset_samples = [] + idsfile = self.dataset_split+"_id.txt" + with open(str(self.dataset_path/idsfile), "r") as f: + id_list = [line.strip() for line in f.readlines()] + for id in id_list: + img_path = self._images_path/(id+".jpg") + gt_path = self._insts_path/(id+".png") + init_path = self.init_path/(id+".png") + dataset_samples.append((img_path, gt_path, init_path)) + image_id_lst = self.get_images_and_ids_list(dataset_samples) + self.dataset_samples = image_id_lst + # print(image_id_lst[:5]) + + ''' + def get_sample(self, index) -> DSample: + sample_id = self.dataset_samples[index] + image_path = str(self._images_path / f'{sample_id}.jpg') + mask_path = str(self._insts_path / f'{sample_id}.png') + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + instances_mask = cv2.imread(mask_path) + instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) + if self.dataset_split == 'test': + instance_id = self.instance_ids[index] + mask = np.zeros_like(instances_mask) + 
mask[instances_mask == 220] = 220 # ignored area + mask[instances_mask == instance_id] = 1 + objects_ids = [1] + instances_mask = mask + else: + objects_ids = np.unique(instances_mask) + objects_ids = [x for x in objects_ids if x != 0 and x != 220] + + return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[220], sample_id=index) + ''' + + def get_sample(self, index) -> DSample: + sample_path, target_path, instance_ids, init_path = self.dataset_samples[index] + # sample_id = str(sample_id) + # print(sample_id) + # num_zero = 6 - len(sample_id) + # sample_id = '2007_'+'0'*num_zero + sample_id + + image_path = str(sample_path) + mask_path = str(target_path) + init_path = str(init_path) + + # print(image_path) + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + instances_mask = cv2.imread(mask_path) + instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) + + instances_mask = self.remove_buggy_masks(index, instances_mask) + instances_ids, _ = get_labels_with_sizes(instances_mask, ignoreid=0) + + objects_ids = instances_ids # 现在instance_ids 是一个列表 + + return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[0], sample_id=index) + + def get_images_and_ids_list(self, dataset_samples, ignore_id = 0): + images_and_ids_list = [] + object_count = 0 + # for i in tqdm(range(len(dataset_samples))): + for i in range(len(dataset_samples)): + image_path, mask_path, init_path = dataset_samples[i] + instances_mask = cv2.imread(str(mask_path)) + instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) + objects_ids = np.unique(instances_mask) + + objects_ids = [x for x in objects_ids if x != ignore_id] + object_count+=len(objects_ids) + # for j in objects_ids: + images_and_ids_list.append([image_path, mask_path ,objects_ids, init_path]) + # print(i,j,objects_ids) + with open(str(self.dataset_path/self.loadfile), "wb") as file: + pickle.dump(images_and_ids_list, 
file) + print("file count: "+str(len(dataset_samples))) + print("object count: "+str(object_count)) + return images_and_ids_list + def remove_buggy_masks(self, index, instances_mask): + if self._buggy_mask_thresh > 0.0: + buggy_image_objects = self._buggy_objects.get(index, None) + if buggy_image_objects is None: + buggy_image_objects = [] + instances_ids, _ = get_labels_with_sizes(instances_mask) + for obj_id in instances_ids: + obj_mask = instances_mask == obj_id + mask_area = obj_mask.sum() + bbox = get_bbox_from_mask(obj_mask) + bbox_area = (bbox[1] - bbox[0] + 1) * (bbox[3] - bbox[2] + 1) + obj_area_ratio = mask_area / bbox_area + if obj_area_ratio < self._buggy_mask_thresh: + buggy_image_objects.append(obj_id) + + self._buggy_objects[index] = buggy_image_objects + for obj_id in buggy_image_objects: + instances_mask[instances_mask == obj_id] = 0 + + return instances_mask \ No newline at end of file diff --git a/isegm/data/datasets/PASCAL.py b/isegm/data/datasets/PASCAL.py new file mode 100644 index 0000000..8a4ba28 --- /dev/null +++ b/isegm/data/datasets/PASCAL.py @@ -0,0 +1,219 @@ +import os +import pickle as pkl +from pathlib import Path + +import cv2 +import numpy as np + +from isegm.data.base import ISDataset +from isegm.data.sample import DSample +from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes +from tqdm import tqdm +import pickle + + +class PASCAL(ISDataset): + def __init__(self, dataset_path, split='train', **kwargs): + super().__init__(**kwargs) + assert split in {'train', 'val', 'trainval', 'test'} + + self._buggy_mask_thresh = 0.08 + self._buggy_objects = dict() + + self.name = 'PASCAL' + self.dataset_path = Path(dataset_path) + self._images_path = self.dataset_path / "JPEGImages" + self._insts_path = self.dataset_path / "SegmentationPart_label" + self.init_path = self.dataset_path / "voc_person_interactive_center_point" + self.dataset_split = split + self.class_num = 7 # 这个class_num 指所有在miou中可以被计算的类,包含背景类但不包含忽略区域 + 
self.ignore_id = 255 + + self.loadfile = self.dataset_split+".pkl" + if os.path.exists(str(self.dataset_path/"pascal_person_part_trainval_list"/self.loadfile)): + with open(str(self.dataset_path/"pascal_person_part_trainval_list"/self.loadfile), 'rb') as file: + self.dataset_samples = pickle.load(file) + else: + dataset_samples = [] + idsfile = self.dataset_split+"_id.txt" + with open(str(self.dataset_path/"pascal_person_part_trainval_list"/idsfile), "r") as f: + id_list = [line.strip() for line in f.readlines()] + for id in id_list: + img_path = self._images_path/(id+".jpg") + gt_path = self._insts_path/(id+".png") + init_path = self.init_path/(id+".png") + dataset_samples.append((img_path, gt_path, init_path)) + image_id_lst = self.get_images_and_ids_list(dataset_samples) + self.dataset_samples = image_id_lst + # print(image_id_lst[:5]) + + ''' + def get_sample(self, index) -> DSample: + sample_id = self.dataset_samples[index] + image_path = str(self._images_path / f'{sample_id}.jpg') + mask_path = str(self._insts_path / f'{sample_id}.png') + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + instances_mask = cv2.imread(mask_path) + instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) + if self.dataset_split == 'test': + instance_id = self.instance_ids[index] + mask = np.zeros_like(instances_mask) + mask[instances_mask == 220] = 220 # ignored area + mask[instances_mask == instance_id] = 1 + objects_ids = [1] + instances_mask = mask + else: + objects_ids = np.unique(instances_mask) + objects_ids = [x for x in objects_ids if x != 0 and x != 220] + + return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[220], sample_id=index) + ''' + + def get_sample(self, index) -> DSample: + sample_path, target_path, instance_ids, init_path = self.dataset_samples[index] + # sample_id = str(sample_id) + # print(sample_id) + # num_zero = 6 - len(sample_id) + # sample_id = '2007_'+'0'*num_zero + 
sample_id + + image_path = str(sample_path) + mask_path = str(target_path) + init_path = str(init_path) + + # print(image_path) + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + instances_mask = cv2.imread(mask_path) + instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) + + + mask = instances_mask + # mask[instances_mask == 255] = 220 # ignored area + # mask[instances_mask == instance_id] = 1 + objects_ids = instance_ids # 现在instance_ids 是一个列表 + instances_mask = mask + return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[self.ignore_id], sample_id=index, init_clicks=init_path) + + def get_images_and_ids_list(self, dataset_samples, ignore_id = 255): + images_and_ids_list = [] + object_count = 0 + # for i in tqdm(range(len(dataset_samples))): + for i in range(len(dataset_samples)): + image_path, mask_path, init_path = dataset_samples[i] + instances_mask = cv2.imread(str(mask_path)) + instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) + objects_ids = np.unique(instances_mask) + + objects_ids = [x for x in objects_ids if x != ignore_id] + object_count+=len(objects_ids) + # for j in objects_ids: + images_and_ids_list.append([image_path, mask_path ,objects_ids, init_path]) + # print(i,j,objects_ids) + with open(str(self.dataset_path/"pascal_person_part_trainval_list"/self.loadfile), "wb") as file: + pickle.dump(images_and_ids_list, file) + print("file count: "+str(len(dataset_samples))) + print("object count: "+str(object_count)) + return images_and_ids_list + +class PASCAL_train(ISDataset): #TODO train requires same ignore_id + def __init__(self, dataset_path, split='train', **kwargs): + super().__init__(**kwargs) + assert split in {'train', 'val', 'trainval', 'test'} + self._buggy_mask_thresh = 0.05 + self._buggy_objects = dict() + + self.name = 'PASCAL' + self.dataset_path = Path(dataset_path) + self._images_path = self.dataset_path / "JPEGImages" + 
self._insts_path = self.dataset_path / "SegmentationPart_label" + self.init_path = self.dataset_path / "voc_person_interactive_center_point" + self.dataset_split = split + self.class_num = 7 # 这个class_num 指所有在miou中可以被计算的类,包含背景类但不包含忽略区域 + self.ignore_id = 255 + + self.loadfile = self.dataset_split+".pkl" + if os.path.exists(str(self.dataset_path/"pascal_person_part_trainval_list"/self.loadfile)): + with open(str(self.dataset_path/"pascal_person_part_trainval_list"/self.loadfile), 'rb') as file: + self.dataset_samples = pickle.load(file) + else: + dataset_samples = [] + idsfile = self.dataset_split+"_id.txt" + with open(str(self.dataset_path/"pascal_person_part_trainval_list"/idsfile), "r") as f: + id_list = [line.strip() for line in f.readlines()] + for id in id_list: + img_path = self._images_path/(id+".jpg") + gt_path = self._insts_path/(id+".png") + init_path = self.init_path/(id+".png") + dataset_samples.append((img_path, gt_path, init_path)) + image_id_lst = self.get_images_and_ids_list(dataset_samples) + self.dataset_samples = image_id_lst + # print(image_id_lst[:5]) + + def get_sample(self, index) -> DSample: + sample_path, target_path, instance_ids, init_path = self.dataset_samples[index] + # sample_id = str(sample_id) + # print(sample_id) + # num_zero = 6 - len(sample_id) + # sample_id = '2007_'+'0'*num_zero + sample_id + + image_path = str(sample_path) + mask_path = str(target_path) + init_path = str(init_path) + + # print(image_path) + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + instances_mask = cv2.imread(mask_path) + instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) + + instances_mask = self.remove_buggy_masks(index, instances_mask) + instances_ids, _ = get_labels_with_sizes(instances_mask, ignoreid=0) + + objects_ids = instances_ids # 现在instance_ids 是一个列表 + + return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[0], sample_id=index) + + def 
get_images_and_ids_list(self, dataset_samples, ignore_id = 0): + images_and_ids_list = [] + object_count = 0 + # for i in tqdm(range(len(dataset_samples))): + for i in range(len(dataset_samples)): + image_path, mask_path, init_path = dataset_samples[i] + instances_mask = cv2.imread(str(mask_path)) + instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) + objects_ids = np.unique(instances_mask) + + objects_ids = [x for x in objects_ids if x != ignore_id] + object_count+=len(objects_ids) + # for j in objects_ids: + images_and_ids_list.append([image_path, mask_path ,objects_ids, init_path]) + # print(i,j,objects_ids) + with open(str(self.dataset_path/"pascal_person_part_trainval_list"/self.loadfile), "wb") as file: + pickle.dump(images_and_ids_list, file) + print("file count: "+str(len(dataset_samples))) + print("object count: "+str(object_count)) + return images_and_ids_list + def remove_buggy_masks(self, index, instances_mask): + if self._buggy_mask_thresh > 0.0: + buggy_image_objects = self._buggy_objects.get(index, None) + if buggy_image_objects is None: + buggy_image_objects = [] + instances_ids, _ = get_labels_with_sizes(instances_mask) + for obj_id in instances_ids: + obj_mask = instances_mask == obj_id + mask_area = obj_mask.sum() + bbox = get_bbox_from_mask(obj_mask) + bbox_area = (bbox[1] - bbox[0] + 1) * (bbox[3] - bbox[2] + 1) + obj_area_ratio = mask_area / bbox_area + if obj_area_ratio < self._buggy_mask_thresh: + buggy_image_objects.append(obj_id) + + self._buggy_objects[index] = buggy_image_objects + for obj_id in buggy_image_objects: + instances_mask[instances_mask == obj_id] = 0 + + return instances_mask \ No newline at end of file diff --git a/isegm/data/datasets/cityscapes.py b/isegm/data/datasets/cityscapes.py new file mode 100644 index 0000000..29032cd --- /dev/null +++ b/isegm/data/datasets/cityscapes.py @@ -0,0 +1,195 @@ +import os +import pickle as pkl +from pathlib import Path +import random +import cv2 +import 
numpy as np + +from isegm.data.base import ISDataset +from isegm.data.sample import DSample +from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes +from tqdm import tqdm +import pickle + + +class CityScapes(ISDataset): + def __init__(self, dataset_path, split="train", **kwargs): + super(CityScapes, self).__init__(**kwargs) + assert split in {"train", "val", "trainval", "test"} + self.name = "Cityscapes" + self.dataset_path = Path(dataset_path) + self._images_path = self.dataset_path / "leftImg8bit" / split + self._insts_path = self.dataset_path / "gtFine" / split + self.init_path = self.dataset_path / "init_interactive_point" + self.dataset_split = split + self.class_num = 19 + self.ignore_id = 255 + + self.loadfile = self.dataset_split+".pkl" + if os.path.exists(str(self.dataset_path/self.loadfile)): + with open(str(self.dataset_path/self.loadfile), 'rb') as file: + self.dataset_samples = pickle.load(file) + else: + dataset_samples = [] + for city in os.listdir(self._images_path): + img_dir = self._images_path / city + target_dir = self._insts_path / city + init_dir = self.init_path / city + for file_name in os.listdir(img_dir): + toAddPath = img_dir / file_name + initName = file_name.replace("_leftImg8bit", "") + initPath = init_dir / initName + labelName = file_name.replace("leftImg8bit", "gtFine_labelTrainIds") + labelPath = target_dir / labelName + dataset_samples.append((toAddPath, labelPath, initPath)) + + image_id_lst = self.get_images_and_ids_list(dataset_samples) + self.dataset_samples = image_id_lst + # print(image_id_lst[:5]) + + def get_sample(self, index) -> DSample: + sample_path, target_path, instance_ids, init_path = self.dataset_samples[index] + + image_path = str(sample_path) + mask_path = str(target_path) + init_path = str(init_path) + + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + instances_mask = cv2.imread(mask_path) + instances_mask = cv2.cvtColor(instances_mask, 
cv2.COLOR_BGR2GRAY).astype( + np.int32 + ) + + ids = [x for x in np.unique(instances_mask) if x != self.ignore_id] + + objects_ids = ids # 现在instance_ids 是一个列表 + + return DSample( + image, + instances_mask, + objects_ids=objects_ids, + ignore_ids=[self.ignore_id], + sample_id=index, + init_clicks=init_path, + ) + + def get_images_and_ids_list(self, dataset_samples): + images_and_ids_list = [] + object_count = 0 + # for i in tqdm(range(len(dataset_samples))): + for i in range(len(dataset_samples)): + image_path, mask_path, init_path = dataset_samples[i] + instances_mask = cv2.imread(str(mask_path)) + instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype( + np.int32 + ) + objects_ids = np.unique(instances_mask) + + objects_ids = [x for x in objects_ids if x != self.ignore_id] + object_count += len(objects_ids) + + images_and_ids_list.append([image_path, mask_path, objects_ids, init_path]) + + with open(str(self.dataset_path/self.loadfile), "wb") as file: + pickle.dump(images_and_ids_list, file) + return images_and_ids_list + + +class CityScapes_train(ISDataset): + def __init__(self, dataset_path, split="train", **kwargs): + super(CityScapes_train, self).__init__(**kwargs) + assert split in {"train", "val", "trainval", "test"} + + self._buggy_mask_thresh = 0.08 + self._buggy_objects = dict() + + self.name = "Cityscapes" + self.dataset_path = Path(dataset_path) + self._images_path = self.dataset_path / "leftImg8bit" / split + self._insts_path = self.dataset_path / "gtFine" / split + self.dataset_split = split + + dataset_samples = [] + for city in os.listdir(self._images_path): + img_dir = self._images_path / city + target_dir = self._insts_path / city + + for file_name in os.listdir(img_dir): + toAddPath = img_dir / file_name + labelName = file_name.replace("leftImg8bit", "gtFine_labelTrainIds") + labelPath = target_dir / labelName + dataset_samples.append((toAddPath, labelPath)) + self.dataset_samples = dataset_samples + # print(image_id_lst[:5]) + 
+ def get_sample(self, index) -> DSample: + sample_path, target_path = self.dataset_samples[index] + # sample_id = str(sample_id) + # print(sample_id) + # num_zero = 6 - len(sample_id) + # sample_id = '2007_'+'0'*num_zero + sample_id + + image_path = str(sample_path) + mask_path = str(target_path) + + # print(image_path) + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + instances_mask = cv2.imread(mask_path) + instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype( + np.int32 + ) + + instances_mask = self.remove_buggy_masks(index, instances_mask) + instances_ids, _ = get_labels_with_sizes(instances_mask, ignoreid=255) + + objects_ids = instances_ids + + return DSample( + image, + instances_mask, + objects_ids=objects_ids, + ignore_ids=[255], + sample_id=index, + ) + + def get_images_and_ids_list(self, dataset_samples): + images_and_ids_list = [] + # for i in tqdm(range(len(dataset_samples))): + for i in range(len(dataset_samples)): + image_path, mask_path = dataset_samples[i] + instances_mask = cv2.imread(str(mask_path)) + instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype( + np.int32 + ) + objects_ids = np.unique(instances_mask) + + objects_ids = [x for x in objects_ids if x != 255] + for j in objects_ids: + images_and_ids_list.append([image_path, mask_path, j]) + # print(i,j,objects_ids) + return images_and_ids_list + + def remove_buggy_masks(self, index, instances_mask): + if self._buggy_mask_thresh > 0.0: + buggy_image_objects = self._buggy_objects.get(index, None) + if buggy_image_objects is None: + buggy_image_objects = [] + instances_ids, _ = get_labels_with_sizes(instances_mask) + for obj_id in instances_ids: + obj_mask = instances_mask == obj_id + mask_area = obj_mask.sum() + bbox = get_bbox_from_mask(obj_mask) + bbox_area = (bbox[1] - bbox[0] + 1) * (bbox[3] - bbox[2] + 1) + obj_area_ratio = mask_area / bbox_area + if obj_area_ratio < self._buggy_mask_thresh: + 
buggy_image_objects.append(obj_id) + + self._buggy_objects[index] = buggy_image_objects + for obj_id in buggy_image_objects: + instances_mask[instances_mask == obj_id] = 0 + + return instances_mask From 9f657bfcdec3f687cae0df9c20aa28974bec82ef Mon Sep 17 00:00:00 2001 From: guba Date: Thu, 7 Mar 2024 15:37:52 +0800 Subject: [PATCH 04/42] ADD: Changed test datset path --- config.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config.yml b/config.yml index faca3c5..87fc098 100755 --- a/config.yml +++ b/config.yml @@ -8,7 +8,7 @@ DAVIS_PATH: "/playpen-raid2/qinliu/data/DAVIS345" COCO_MVAL_PATH: "/playpen-raid2/qinliu/data/COCO_MVal" BraTS_PATH: "/playpen-raid2/qinliu/data/BraTS20" ssTEM_PATH: "/playpen-raid2/qinliu/data/ssTEM" -OAIZIB_PATH: "/playpen-raid2/qinliu/data/OAI-ZIB/iseg_slices" +OAIZIB_PATH: "./dataset/OAI-ZIB" OAI_PATH: "/playpen-raid2/qinliu/data/OAI" HARD_PATH: "/playpen-raid2/qinliu/data/HARD" @@ -40,4 +40,4 @@ IMAGENET_PRETRAINED_MODELS: SWIN_LARGE: "./weights/pretrained/swin_large_patch4_window12_384_22k.pth" MAE_BASE: "./weights/pretrained/mae_pretrain_vit_base.pth" MAE_LARGE: "./weights/pretrained/mae_pretrain_vit_large.pth" - MAE_HUGE: "./weights/pretrained/mae_pretrain_vit_huge.pth" + MAE_HUGE: "./weights/pretrained/cocolvis_vit_huge.pth" From 662e78ba5d1ba6d05578f412d73df5341b917d8d Mon Sep 17 00:00:00 2001 From: guba Date: Thu, 7 Mar 2024 15:38:03 +0800 Subject: [PATCH 05/42] ADD: modded test code --- TEST_dataset_read.py | 87 ++++++++++++++++++++++++++++++++++++++++++++ TEST_simple_model.py | 87 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+) create mode 100644 TEST_dataset_read.py diff --git a/TEST_dataset_read.py b/TEST_dataset_read.py new file mode 100644 index 0000000..afc5cf7 --- /dev/null +++ b/TEST_dataset_read.py @@ -0,0 +1,87 @@ +from isegm.inference.clicker import Clicker, Click +from isegm.model.is_plainvit_model import MultiOutVitModel +from isegm.utils.exp_imports.default 
import * +from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss + +def init_model(): + model_cfg = edict() + model_cfg.crop_size = (448, 448) + model_cfg.num_max_points = 24 + + backbone_params = dict( + img_size=model_cfg.crop_size, + patch_size=(16,16), + in_chans=3, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + ) + + neck_params = dict( + in_dim = 1024, + out_dims = [192, 384, 768, 1536], + ) + + head_params = dict( + in_channels=[192, 384, 768, 1536], + in_index=[0, 1, 2, 3], + dropout_ratio=0.1, + num_classes=7, + loss_decode=CrossEntropyLoss(), + align_corners=False, + upsample='x1', + channels={'x1': 256, 'x2': 128, 'x4': 64}['x1'], + ) + + model = MultiOutVitModel( + use_disks=True, + norm_radius=5, + with_prev_mask=True, + backbone_params=backbone_params, + neck_params=neck_params, + head_params=head_params, + random_split=False, + ) + + # model.backbone.init_weights_from_pretrained("./weights/pretrained/cocolvis_vit_huge.pth") + model.to("cuda") + + return model, model_cfg + +def get_points_nd(clicks_lists): + total_clicks = [] + num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists] + num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)] + num_max_points = max(num_pos_clicks + num_neg_clicks) + num_max_points = max(1, num_max_points) + + for clicks_list in clicks_lists: + clicks_list = clicks_list[:5] + pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive] + pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)] + + neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive] + neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)] + total_clicks.append(pos_clicks + neg_clicks) + + return torch.tensor(total_clicks, device="cuda") + +def add_mask(img): + input_image = torch.cat((img, 
torch.zeros(1,1,448,448).cuda()), dim=1) + return input_image + + +model, model_cfg = init_model() + +import torch +import cv2 +import numpy as np + +img = torch.rand(1, 3, 448, 448).cuda() +click = Click(is_positive=True, coords=(1, 1), indx=0) +click_list = [[click]] +out = model(add_mask(img), get_points_nd(click_list)) + +print("done!") \ No newline at end of file diff --git a/TEST_simple_model.py b/TEST_simple_model.py index e69de29..afc5cf7 100644 --- a/TEST_simple_model.py +++ b/TEST_simple_model.py @@ -0,0 +1,87 @@ +from isegm.inference.clicker import Clicker, Click +from isegm.model.is_plainvit_model import MultiOutVitModel +from isegm.utils.exp_imports.default import * +from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss + +def init_model(): + model_cfg = edict() + model_cfg.crop_size = (448, 448) + model_cfg.num_max_points = 24 + + backbone_params = dict( + img_size=model_cfg.crop_size, + patch_size=(16,16), + in_chans=3, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + ) + + neck_params = dict( + in_dim = 1024, + out_dims = [192, 384, 768, 1536], + ) + + head_params = dict( + in_channels=[192, 384, 768, 1536], + in_index=[0, 1, 2, 3], + dropout_ratio=0.1, + num_classes=7, + loss_decode=CrossEntropyLoss(), + align_corners=False, + upsample='x1', + channels={'x1': 256, 'x2': 128, 'x4': 64}['x1'], + ) + + model = MultiOutVitModel( + use_disks=True, + norm_radius=5, + with_prev_mask=True, + backbone_params=backbone_params, + neck_params=neck_params, + head_params=head_params, + random_split=False, + ) + + # model.backbone.init_weights_from_pretrained("./weights/pretrained/cocolvis_vit_huge.pth") + model.to("cuda") + + return model, model_cfg + +def get_points_nd(clicks_lists): + total_clicks = [] + num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists] + num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, 
num_pos_clicks)] + num_max_points = max(num_pos_clicks + num_neg_clicks) + num_max_points = max(1, num_max_points) + + for clicks_list in clicks_lists: + clicks_list = clicks_list[:5] + pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive] + pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)] + + neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive] + neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)] + total_clicks.append(pos_clicks + neg_clicks) + + return torch.tensor(total_clicks, device="cuda") + +def add_mask(img): + input_image = torch.cat((img, torch.zeros(1,1,448,448).cuda()), dim=1) + return input_image + + +model, model_cfg = init_model() + +import torch +import cv2 +import numpy as np + +img = torch.rand(1, 3, 448, 448).cuda() +click = Click(is_positive=True, coords=(1, 1), indx=0) +click_list = [[click]] +out = model(add_mask(img), get_points_nd(click_list)) + +print("done!") \ No newline at end of file From 063b8f4353dd3c007ad9a7aeabf503f6362c9d66 Mon Sep 17 00:00:00 2001 From: guba Date: Thu, 7 Mar 2024 15:38:15 +0800 Subject: [PATCH 06/42] MOD: Change the model in train code --- models/iter_mask/multi_out_vit_large448_pascal_new_scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/iter_mask/multi_out_vit_large448_pascal_new_scheduler.py b/models/iter_mask/multi_out_vit_large448_pascal_new_scheduler.py index ed2a397..1549913 100644 --- a/models/iter_mask/multi_out_vit_large448_pascal_new_scheduler.py +++ b/models/iter_mask/multi_out_vit_large448_pascal_new_scheduler.py @@ -36,7 +36,7 @@ def init_model(cfg): in_channels=[192, 384, 768, 1536], in_index=[0, 1, 2, 3], dropout_ratio=0.1, - num_classes=1, + num_classes=7, loss_decode=CrossEntropyLoss(), align_corners=False, upsample=cfg.upsample, From 980492ec0fe62e66327972376be0788f468d0cbc Mon Sep 17 00:00:00 2001 From: guba Date: Thu, 7 Mar 2024 20:44:46 
+0800 Subject: [PATCH 07/42] MOD: Dataset mod --- TEST_dataset_read.py | 95 ++------------ config.yml | 2 +- isegm/data/datasets/PASCAL.py | 90 +------------ isegm/data/datasets/__init__.py | 3 +- isegm/data/sample.py | 3 +- .../multi_out_huge448_pascal_itermask.py | 124 ++++++++++++++++++ ...i_out_vit_large448_pascal_new_scheduler.py | 12 +- 7 files changed, 151 insertions(+), 178 deletions(-) create mode 100644 models/iter_mask/multi_out_huge448_pascal_itermask.py diff --git a/TEST_dataset_read.py b/TEST_dataset_read.py index afc5cf7..1a1d8e9 100644 --- a/TEST_dataset_read.py +++ b/TEST_dataset_read.py @@ -1,87 +1,16 @@ -from isegm.inference.clicker import Clicker, Click -from isegm.model.is_plainvit_model import MultiOutVitModel -from isegm.utils.exp_imports.default import * -from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss - -def init_model(): - model_cfg = edict() - model_cfg.crop_size = (448, 448) - model_cfg.num_max_points = 24 - - backbone_params = dict( - img_size=model_cfg.crop_size, - patch_size=(16,16), - in_chans=3, - embed_dim=1024, - depth=24, - num_heads=16, - mlp_ratio=4, - qkv_bias=True, - ) - - neck_params = dict( - in_dim = 1024, - out_dims = [192, 384, 768, 1536], - ) - - head_params = dict( - in_channels=[192, 384, 768, 1536], - in_index=[0, 1, 2, 3], - dropout_ratio=0.1, - num_classes=7, - loss_decode=CrossEntropyLoss(), - align_corners=False, - upsample='x1', - channels={'x1': 256, 'x2': 128, 'x4': 64}['x1'], - ) - - model = MultiOutVitModel( - use_disks=True, - norm_radius=5, - with_prev_mask=True, - backbone_params=backbone_params, - neck_params=neck_params, - head_params=head_params, - random_split=False, - ) - - # model.backbone.init_weights_from_pretrained("./weights/pretrained/cocolvis_vit_huge.pth") - model.to("cuda") - - return model, model_cfg - -def get_points_nd(clicks_lists): - total_clicks = [] - num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists] - 
num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)] - num_max_points = max(num_pos_clicks + num_neg_clicks) - num_max_points = max(1, num_max_points) - - for clicks_list in clicks_lists: - clicks_list = clicks_list[:5] - pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive] - pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)] - - neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive] - neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)] - total_clicks.append(pos_clicks + neg_clicks) - - return torch.tensor(total_clicks, device="cuda") - -def add_mask(img): - input_image = torch.cat((img, torch.zeros(1,1,448,448).cuda()), dim=1) - return input_image - +from isegm.data.datasets.PASCAL import PASCAL +import cv2 +import matplotlib.pyplot as plt -model, model_cfg = init_model() +def show_sample(sample): + plt.imshow(sample.image) + plt.show() + plt.imshow(sample._encoded_masks) + plt.show() + print("done") -import torch -import cv2 -import numpy as np +dataset = PASCAL(dataset_path="/home/gyt/gyt/dataset/data/pascal_person_part", split='train') -img = torch.rand(1, 3, 448, 448).cuda() -click = Click(is_positive=True, coords=(1, 1), indx=0) -click_list = [[click]] -out = model(add_mask(img), get_points_nd(click_list)) +a_sample = dataset.get_sample(0) +show_sample(a_sample) -print("done!") \ No newline at end of file diff --git a/config.yml b/config.yml index 87fc098..0d426e1 100755 --- a/config.yml +++ b/config.yml @@ -40,4 +40,4 @@ IMAGENET_PRETRAINED_MODELS: SWIN_LARGE: "./weights/pretrained/swin_large_patch4_window12_384_22k.pth" MAE_BASE: "./weights/pretrained/mae_pretrain_vit_base.pth" MAE_LARGE: "./weights/pretrained/mae_pretrain_vit_large.pth" - MAE_HUGE: "./weights/pretrained/cocolvis_vit_huge.pth" + MAE_HUGE: "./weights/pretrained/mae_pretrain_vit_huge.pth" diff --git 
a/isegm/data/datasets/PASCAL.py b/isegm/data/datasets/PASCAL.py index 8a4ba28..3af1e04 100644 --- a/isegm/data/datasets/PASCAL.py +++ b/isegm/data/datasets/PASCAL.py @@ -45,7 +45,7 @@ def __init__(self, dataset_path, split='train', **kwargs): dataset_samples.append((img_path, gt_path, init_path)) image_id_lst = self.get_images_and_ids_list(dataset_samples) self.dataset_samples = image_id_lst - # print(image_id_lst[:5]) + # print(image_id_lst[:5]) ''' def get_sample(self, index) -> DSample: @@ -88,13 +88,9 @@ def get_sample(self, index) -> DSample: image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) instances_mask = cv2.imread(mask_path) instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) - - - mask = instances_mask # mask[instances_mask == 255] = 220 # ignored area # mask[instances_mask == instance_id] = 1 objects_ids = instance_ids # 现在instance_ids 是一个列表 - instances_mask = mask return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[self.ignore_id], sample_id=index, init_clicks=init_path) def get_images_and_ids_list(self, dataset_samples, ignore_id = 255): @@ -117,86 +113,6 @@ def get_images_and_ids_list(self, dataset_samples, ignore_id = 255): print("file count: "+str(len(dataset_samples))) print("object count: "+str(object_count)) return images_and_ids_list - -class PASCAL_train(ISDataset): #TODO train requires same ignore_id - def __init__(self, dataset_path, split='train', **kwargs): - super().__init__(**kwargs) - assert split in {'train', 'val', 'trainval', 'test'} - self._buggy_mask_thresh = 0.05 - self._buggy_objects = dict() - - self.name = 'PASCAL' - self.dataset_path = Path(dataset_path) - self._images_path = self.dataset_path / "JPEGImages" - self._insts_path = self.dataset_path / "SegmentationPart_label" - self.init_path = self.dataset_path / "voc_person_interactive_center_point" - self.dataset_split = split - self.class_num = 7 # 这个class_num 指所有在miou中可以被计算的类,包含背景类但不包含忽略区域 - self.ignore_id = 255 - - 
self.loadfile = self.dataset_split+".pkl" - if os.path.exists(str(self.dataset_path/"pascal_person_part_trainval_list"/self.loadfile)): - with open(str(self.dataset_path/"pascal_person_part_trainval_list"/self.loadfile), 'rb') as file: - self.dataset_samples = pickle.load(file) - else: - dataset_samples = [] - idsfile = self.dataset_split+"_id.txt" - with open(str(self.dataset_path/"pascal_person_part_trainval_list"/idsfile), "r") as f: - id_list = [line.strip() for line in f.readlines()] - for id in id_list: - img_path = self._images_path/(id+".jpg") - gt_path = self._insts_path/(id+".png") - init_path = self.init_path/(id+".png") - dataset_samples.append((img_path, gt_path, init_path)) - image_id_lst = self.get_images_and_ids_list(dataset_samples) - self.dataset_samples = image_id_lst - # print(image_id_lst[:5]) - - def get_sample(self, index) -> DSample: - sample_path, target_path, instance_ids, init_path = self.dataset_samples[index] - # sample_id = str(sample_id) - # print(sample_id) - # num_zero = 6 - len(sample_id) - # sample_id = '2007_'+'0'*num_zero + sample_id - - image_path = str(sample_path) - mask_path = str(target_path) - init_path = str(init_path) - - # print(image_path) - - image = cv2.imread(image_path) - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - instances_mask = cv2.imread(mask_path) - instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) - - instances_mask = self.remove_buggy_masks(index, instances_mask) - instances_ids, _ = get_labels_with_sizes(instances_mask, ignoreid=0) - - objects_ids = instances_ids # 现在instance_ids 是一个列表 - - return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[0], sample_id=index) - - def get_images_and_ids_list(self, dataset_samples, ignore_id = 0): - images_and_ids_list = [] - object_count = 0 - # for i in tqdm(range(len(dataset_samples))): - for i in range(len(dataset_samples)): - image_path, mask_path, init_path = dataset_samples[i] - instances_mask = 
cv2.imread(str(mask_path)) - instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) - objects_ids = np.unique(instances_mask) - - objects_ids = [x for x in objects_ids if x != ignore_id] - object_count+=len(objects_ids) - # for j in objects_ids: - images_and_ids_list.append([image_path, mask_path ,objects_ids, init_path]) - # print(i,j,objects_ids) - with open(str(self.dataset_path/"pascal_person_part_trainval_list"/self.loadfile), "wb") as file: - pickle.dump(images_and_ids_list, file) - print("file count: "+str(len(dataset_samples))) - print("object count: "+str(object_count)) - return images_and_ids_list def remove_buggy_masks(self, index, instances_mask): if self._buggy_mask_thresh > 0.0: buggy_image_objects = self._buggy_objects.get(index, None) @@ -216,4 +132,6 @@ def remove_buggy_masks(self, index, instances_mask): for obj_id in buggy_image_objects: instances_mask[instances_mask == obj_id] = 0 - return instances_mask \ No newline at end of file + return instances_mask + + diff --git a/isegm/data/datasets/__init__.py b/isegm/data/datasets/__init__.py index b8110cb..c4f56dc 100755 --- a/isegm/data/datasets/__init__.py +++ b/isegm/data/datasets/__init__.py @@ -15,4 +15,5 @@ from .ssTEM import ssTEMDataset from .oai_zib import OAIZIBDataset from .oai import OAIDataset -from .hard import HARDDataset \ No newline at end of file +from .hard import HARDDataset +from .PASCAL import PASCAL \ No newline at end of file diff --git a/isegm/data/sample.py b/isegm/data/sample.py index a08c409..472caaa 100755 --- a/isegm/data/sample.py +++ b/isegm/data/sample.py @@ -7,9 +7,10 @@ class DSample: def __init__(self, image, encoded_masks, objects=None, - objects_ids=None, ignore_ids=None, sample_id=None): + objects_ids=None, ignore_ids=None, sample_id=None, init_clicks=None): self.image = image self.sample_id = sample_id + self.init_clicks = init_clicks if len(encoded_masks.shape) == 2: encoded_masks = encoded_masks[:, :, np.newaxis] diff --git 
a/models/iter_mask/multi_out_huge448_pascal_itermask.py b/models/iter_mask/multi_out_huge448_pascal_itermask.py new file mode 100644 index 0000000..4bc3ec0 --- /dev/null +++ b/models/iter_mask/multi_out_huge448_pascal_itermask.py @@ -0,0 +1,124 @@ +from isegm.utils.exp_imports.default import * +from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss + +MODEL_NAME = 'cocolvis_vit_huge448' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (448, 448) + model_cfg.num_max_points = 24 + + backbone_params = dict( + img_size=model_cfg.crop_size, + patch_size=(14,14), + in_chans=3, + embed_dim=1280, + depth=32, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + ) + + neck_params = dict( + in_dim = 1280, + out_dims = [240, 480, 960, 1920], + ) + + head_params = dict( + in_channels=[240, 480, 960, 1920], + in_index=[0, 1, 2, 3], + dropout_ratio=0.1, + num_classes=7, + loss_decode=CrossEntropyLoss(), + align_corners=False, + upsample=cfg.upsample, + channels={'x1': 256, 'x2': 128, 'x4': 64}[cfg.upsample], + ) + + model = PlainVitModel( + use_disks=True, + norm_radius=5, + with_prev_mask=True, + backbone_params=backbone_params, + neck_params=neck_params, + head_params=head_params, + random_split=cfg.random_split, + ) + + model.backbone.init_weights_from_pretrained(cfg.IMAGENET_PRETRAINED_MODELS.MAE_HUGE) + model.to(cfg.device) + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 32 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + 
RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.80, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = PASCAL( + cfg.PASCAL_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=1000, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=30000, + stuff_prob=0.30 + ) + + valset = PASCAL( + cfg.PASCAL_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=1000, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-5, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[50, 55], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + layerwise_decay=cfg.layerwise_decay, + lr_scheduler=lr_scheduler, + checkpoint_interval=[(0, 20), (50, 1)], + image_dump_interval=300, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points, + max_num_next_clicks=3) + trainer.run(num_epochs=55, validation=False) \ No newline at end of file diff --git a/models/iter_mask/multi_out_vit_large448_pascal_new_scheduler.py b/models/iter_mask/multi_out_vit_large448_pascal_new_scheduler.py index 1549913..a88cd14 100644 --- a/models/iter_mask/multi_out_vit_large448_pascal_new_scheduler.py +++ b/models/iter_mask/multi_out_vit_large448_pascal_new_scheduler.py @@ -3,7 +3,7 @@ from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss -MODEL_NAME = 'cocolvis_vit_large448' +MODEL_NAME = 'cocolvis_vit_huge448' def 
main(cfg): @@ -53,7 +53,7 @@ def init_model(cfg): random_split=cfg.random_split, ) - model.backbone.init_weights_from_pretrained(cfg.IMAGENET_PRETRAINED_MODELS.MAE_LARGE) + model.backbone.init_weights_from_pretrained(cfg.IMAGENET_PRETRAINED_MODELS.MAE_HUGE) model.to(cfg.device) return model, model_cfg @@ -86,8 +86,8 @@ def train(model, cfg, model_cfg): merge_objects_prob=0.15, max_num_merged_objects=2) - trainset = CocoLvisDataset( - cfg.LVIS_v1_PATH, + trainset = PASCAL( + cfg.PASCAL_PATH, split='train', augmentator=train_augmentator, min_object_area=1000, @@ -97,8 +97,8 @@ def train(model, cfg, model_cfg): stuff_prob=0.30 ) - valset = CocoLvisDataset( - cfg.LVIS_v1_PATH, + valset = PASCAL( + cfg.PASCAL_PATH, split='val', augmentator=val_augmentator, min_object_area=1000, From be9264040c89b9d0fe92c2dc7f3d8bf49321ab53 Mon Sep 17 00:00:00 2001 From: guba Date: Thu, 7 Mar 2024 20:46:01 +0800 Subject: [PATCH 08/42] TAG: Ready to change trainer --- models/iter_mask/multi_out_huge448_pascal_itermask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/iter_mask/multi_out_huge448_pascal_itermask.py b/models/iter_mask/multi_out_huge448_pascal_itermask.py index 4bc3ec0..eb7df98 100644 --- a/models/iter_mask/multi_out_huge448_pascal_itermask.py +++ b/models/iter_mask/multi_out_huge448_pascal_itermask.py @@ -92,7 +92,7 @@ def train(model, cfg, model_cfg): keep_background_prob=0.05, points_sampler=points_sampler, epoch_len=30000, - stuff_prob=0.30 + # stuff_prob=0.30 ) valset = PASCAL( From d6418570b6c4f026b287b844c3f3fcb3c577d84e Mon Sep 17 00:00:00 2001 From: guba Date: Fri, 8 Mar 2024 00:11:10 +0800 Subject: [PATCH 09/42] ADD: add mutli trainer file --- isegm/engine/Multi_trainer.py | 418 ++++++++++++++++++++++++++++++++++ 1 file changed, 418 insertions(+) create mode 100755 isegm/engine/Multi_trainer.py diff --git a/isegm/engine/Multi_trainer.py b/isegm/engine/Multi_trainer.py new file mode 100755 index 0000000..9b0488d --- /dev/null +++ 
b/isegm/engine/Multi_trainer.py @@ -0,0 +1,418 @@ +import os +import random +import logging +from copy import deepcopy +from collections import defaultdict + +import cv2 +import torch +import numpy as np +from tqdm import tqdm +from torch.utils.data import DataLoader + +from isegm.utils.log import logger, TqdmToLogger, SummaryWriterAvg +from isegm.utils.vis import draw_probmap, draw_points +from isegm.utils.misc import save_checkpoint +from isegm.utils.serialization import get_config_repr +from isegm.utils.distributed import get_dp_wrapper, get_sampler, reduce_loss_dict +from .optimizer import get_optimizer, get_optimizer_with_layerwise_decay +from .trainer import ISTrainer + + +class Multi_trainer(ISTrainer): + def __init__(self, model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=None, + layerwise_decay=False, + image_dump_interval=200, + checkpoint_interval=10, + tb_dump_period=25, + max_interactive_points=0, + lr_scheduler=None, + metrics=None, + additional_val_metrics=None, + net_inputs=('images', 'points'), + max_num_next_clicks=0, + click_models=None, + prev_mask_drop_prob=0.0, + ): + self.cfg = cfg + self.model_cfg = model_cfg + self.max_interactive_points = max_interactive_points + self.loss_cfg = loss_cfg + self.val_loss_cfg = deepcopy(loss_cfg) + self.tb_dump_period = tb_dump_period + self.net_inputs = net_inputs + self.max_num_next_clicks = max_num_next_clicks + + self.click_models = click_models + self.prev_mask_drop_prob = prev_mask_drop_prob + + if cfg.distributed: + cfg.batch_size //= cfg.ngpus + cfg.val_batch_size //= cfg.ngpus + + if metrics is None: + metrics = [] + self.train_metrics = metrics + self.val_metrics = deepcopy(metrics) + if additional_val_metrics is not None: + self.val_metrics.extend(additional_val_metrics) + + self.checkpoint_interval = checkpoint_interval + self.image_dump_interval = image_dump_interval + self.task_prefix = '' + self.sw = None + + self.trainset = trainset + self.valset = 
valset + + logger.info(f'Dataset of {trainset.get_samples_number()} samples was loaded for training.') + logger.info(f'Dataset of {valset.get_samples_number()} samples was loaded for validation.') + + self.train_data = DataLoader( + trainset, cfg.batch_size, + sampler=get_sampler(trainset, shuffle=True, distributed=cfg.distributed), + drop_last=True, pin_memory=True, + num_workers=cfg.workers + ) + + self.val_data = DataLoader( + valset, cfg.val_batch_size, + sampler=get_sampler(valset, shuffle=False, distributed=cfg.distributed), + drop_last=True, pin_memory=True, + num_workers=cfg.workers + ) + + if layerwise_decay: + self.optim = get_optimizer_with_layerwise_decay(model, optimizer, optimizer_params) + else: + self.optim = get_optimizer(model, optimizer, optimizer_params) + model = self._load_weights(model) + + if cfg.multi_gpu: + model = get_dp_wrapper(cfg.distributed)(model, device_ids=cfg.gpu_ids, + output_device=cfg.gpu_ids[0]) + + if self.is_master: + logger.info(model) + logger.info(get_config_repr(model._config)) + + self.device = cfg.device + self.net = model.to(self.device) + self.lr = optimizer_params['lr'] + + if lr_scheduler is not None: + self.lr_scheduler = lr_scheduler(optimizer=self.optim) + if cfg.start_epoch > 0: + for _ in range(cfg.start_epoch): + self.lr_scheduler.step() + + self.tqdm_out = TqdmToLogger(logger, level=logging.INFO) + + if self.click_models is not None: + for click_model in self.click_models: + for param in click_model.parameters(): + param.requires_grad = False + click_model.to(self.device) + click_model.eval() + + def run(self, num_epochs, start_epoch=None, validation=True): + if start_epoch is None: + start_epoch = self.cfg.start_epoch + + logger.info(f'Starting Epoch: {start_epoch}') + logger.info(f'Total Epochs: {num_epochs}') + for epoch in range(start_epoch, num_epochs): + self.training(epoch) + if validation: + self.validation(epoch) + + def training(self, epoch): + if self.sw is None and self.is_master: + self.sw = 
SummaryWriterAvg(log_dir=str(self.cfg.LOGS_PATH), + flush_secs=10, dump_period=self.tb_dump_period) + + if self.cfg.distributed: + self.train_data.sampler.set_epoch(epoch) + + log_prefix = 'Train' + self.task_prefix.capitalize() + tbar = tqdm(self.train_data, file=self.tqdm_out, ncols=100)\ + if self.is_master else self.train_data + + for metric in self.train_metrics: + metric.reset_epoch_stats() + + self.net.train() + train_loss = 0.0 + for i, batch_data in enumerate(tbar): + global_step = epoch * len(self.train_data) + i + + loss, losses_logging, splitted_batch_data, outputs = \ + self.batch_forward(batch_data) + + self.optim.zero_grad() + loss.backward() + self.optim.step() + + losses_logging['overall'] = loss + reduce_loss_dict(losses_logging) + + train_loss += losses_logging['overall'].item() + + if self.is_master: + for loss_name, loss_value in losses_logging.items(): + self.sw.add_scalar(tag=f'{log_prefix}Losses/{loss_name}', + value=loss_value.item(), + global_step=global_step) + + for k, v in self.loss_cfg.items(): + if '_loss' in k and hasattr(v, 'log_states') and self.loss_cfg.get(k + '_weight', 0.0) > 0: + v.log_states(self.sw, f'{log_prefix}Losses/{k}', global_step) + + if self.image_dump_interval > 0 and global_step % self.image_dump_interval == 0: + self.save_visualization(splitted_batch_data, outputs, global_step, prefix='train') + + self.sw.add_scalar(tag=f'{log_prefix}States/learning_rate', + value=self.lr if not hasattr(self, 'lr_scheduler') else self.lr_scheduler.get_lr()[-1], + global_step=global_step) + + tbar.set_description(f'Epoch {epoch}, training loss {train_loss/(i+1):.4f}') + for metric in self.train_metrics: + metric.log_states(self.sw, f'{log_prefix}Metrics/{metric.name}', global_step) + + if self.is_master: + for metric in self.train_metrics: + self.sw.add_scalar(tag=f'{log_prefix}Metrics/{metric.name}', + value=metric.get_epoch_value(), + global_step=epoch, disable_avg=True) + + save_checkpoint(self.net, self.cfg.CHECKPOINTS_PATH, 
prefix=self.task_prefix, + epoch=None, multi_gpu=self.cfg.multi_gpu) + + if isinstance(self.checkpoint_interval, (list, tuple)): + checkpoint_interval = [x for x in self.checkpoint_interval if x[0] <= epoch][-1][1] + else: + checkpoint_interval = self.checkpoint_interval + + if epoch % checkpoint_interval == 0: + save_checkpoint(self.net, self.cfg.CHECKPOINTS_PATH, prefix=self.task_prefix, + epoch=epoch, multi_gpu=self.cfg.multi_gpu) + + if hasattr(self, 'lr_scheduler'): + self.lr_scheduler.step() + + def validation(self, epoch): + if self.sw is None and self.is_master: + self.sw = SummaryWriterAvg(log_dir=str(self.cfg.LOGS_PATH), + flush_secs=10, dump_period=self.tb_dump_period) + + log_prefix = 'Val' + self.task_prefix.capitalize() + tbar = tqdm(self.val_data, file=self.tqdm_out, ncols=100) if self.is_master else self.val_data + + for metric in self.val_metrics: + metric.reset_epoch_stats() + + val_loss = 0 + losses_logging = defaultdict(list) + + self.net.eval() + for i, batch_data in enumerate(tbar): + global_step = epoch * len(self.val_data) + i + loss, batch_losses_logging, splitted_batch_data, outputs = \ + self.batch_forward(batch_data, validation=True) + + batch_losses_logging['overall'] = loss + reduce_loss_dict(batch_losses_logging) + for loss_name, loss_value in batch_losses_logging.items(): + losses_logging[loss_name].append(loss_value.item()) + + val_loss += batch_losses_logging['overall'].item() + + if self.is_master: + tbar.set_description(f'Epoch {epoch}, validation loss: {val_loss/(i + 1):.4f}') + for metric in self.val_metrics: + metric.log_states(self.sw, f'{log_prefix}Metrics/{metric.name}', global_step) + + if self.is_master: + for loss_name, loss_values in losses_logging.items(): + self.sw.add_scalar(tag=f'{log_prefix}Losses/{loss_name}', value=np.array(loss_values).mean(), + global_step=epoch, disable_avg=True) + + for metric in self.val_metrics: + self.sw.add_scalar(tag=f'{log_prefix}Metrics/{metric.name}', value=metric.get_epoch_value(), + 
global_step=epoch, disable_avg=True) + + def batch_forward(self, batch_data, validation=False): + metrics = self.val_metrics if validation else self.train_metrics + losses_logging = dict() + + with torch.set_grad_enabled(not validation): + batch_data = {k: v.to(self.device) for k, v in batch_data.items()} + image, gt_mask, points = batch_data['images'], batch_data['instances'], batch_data['points'] + orig_image, orig_gt_mask, orig_points = image.clone(), gt_mask.clone(), points.clone() + + prev_output = torch.zeros_like(image, dtype=torch.float32)[:, :1, :, :] + + last_click_indx = None + + with torch.no_grad(): + num_iters = random.randint(0, self.max_num_next_clicks) + + for click_indx in range(num_iters): + last_click_indx = click_indx + + if not validation: + self.net.eval() + + if self.click_models is None or click_indx >= len(self.click_models): + eval_model = self.net + else: + eval_model = self.click_models[click_indx] + + net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image + prev_output = torch.sigmoid(eval_model(net_input, points)['instances']) + + points = get_next_points(prev_output, orig_gt_mask, points, click_indx + 1) + + if not validation: + self.net.train() + + if self.net.with_prev_mask and self.prev_mask_drop_prob > 0 and last_click_indx is not None: + zero_mask = np.random.random(size=prev_output.size(0)) < self.prev_mask_drop_prob + prev_output[zero_mask] = torch.zeros_like(prev_output[zero_mask]) + + batch_data['points'] = points + + net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image + output = self.net(net_input, points) + + loss = 0.0 + loss = self.add_loss('instance_loss', loss, losses_logging, validation, + lambda: (output['instances'], batch_data['instances'])) + loss = self.add_loss('instance_aux_loss', loss, losses_logging, validation, + lambda: (output['instances_aux'], batch_data['instances'])) + + if self.is_master: + with torch.no_grad(): + for m in metrics: 
+ m.update(*(output.get(x) for x in m.pred_outputs), + *(batch_data[x] for x in m.gt_outputs)) + return loss, losses_logging, batch_data, output + + def add_loss(self, loss_name, total_loss, losses_logging, validation, lambda_loss_inputs): + loss_cfg = self.loss_cfg if not validation else self.val_loss_cfg + loss_weight = loss_cfg.get(loss_name + '_weight', 0.0) + if loss_weight > 0.0: + loss_criterion = loss_cfg.get(loss_name) + loss = loss_criterion(*lambda_loss_inputs()) + loss = torch.mean(loss) + losses_logging[loss_name] = loss + loss = loss_weight * loss + total_loss = total_loss + loss + + return total_loss + + def save_visualization(self, splitted_batch_data, outputs, global_step, prefix): + output_images_path = self.cfg.VIS_PATH / prefix + if self.task_prefix: + output_images_path /= self.task_prefix + + if not output_images_path.exists(): + output_images_path.mkdir(parents=True) + image_name_prefix = f'{global_step:06d}' + + def _save_image(suffix, image): + cv2.imwrite(str(output_images_path / f'{image_name_prefix}_{suffix}.jpg'), + image, [cv2.IMWRITE_JPEG_QUALITY, 85]) + + images = splitted_batch_data['images'] + points = splitted_batch_data['points'] + instance_masks = splitted_batch_data['instances'] + + gt_instance_masks = instance_masks.cpu().numpy() + predicted_instance_masks = torch.sigmoid(outputs['instances']).detach().cpu().numpy() + points = points.detach().cpu().numpy() + + image_blob, points = images[0], points[0] + gt_mask = np.squeeze(gt_instance_masks[0], axis=0) + predicted_mask = np.squeeze(predicted_instance_masks[0], axis=0) + + image = image_blob.cpu().numpy() * 255 + image = image.transpose((1, 2, 0)) + + image_with_points = draw_points(image, points[:self.max_interactive_points], (0, 255, 0)) + image_with_points = draw_points(image_with_points, points[self.max_interactive_points:], (0, 0, 255)) + + gt_mask[gt_mask < 0] = 0.25 + gt_mask = draw_probmap(gt_mask) + predicted_mask = draw_probmap(predicted_mask) + viz_image = 
np.hstack((image_with_points, gt_mask, predicted_mask)).astype(np.uint8) + + _save_image('instance_segmentation', viz_image[:, :, ::-1]) + + def _load_weights(self, net): + if self.cfg.weights is not None: + if os.path.isfile(self.cfg.weights): + load_weights(net, self.cfg.weights) + self.cfg.weights = None + else: + raise RuntimeError(f"=> no checkpoint found at '{self.cfg.weights}'") + elif self.cfg.resume_exp is not None: + checkpoints = list(self.cfg.CHECKPOINTS_PATH.glob(f'{self.cfg.resume_prefix}*.pth')) + assert len(checkpoints) == 1 + + checkpoint_path = checkpoints[0] + logger.info(f'Load checkpoint from path: {checkpoint_path}') + load_weights(net, str(checkpoint_path)) + return net + + @property + def is_master(self): + return self.cfg.local_rank == 0 + + +def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49): + assert click_indx > 0 + pred = pred.cpu().numpy()[:, 0, :, :] + gt = gt.cpu().numpy()[:, 0, :, :] > 0.5 + + fn_mask = np.logical_and(gt, pred < pred_thresh) + fp_mask = np.logical_and(np.logical_not(gt), pred > pred_thresh) + + fn_mask = np.pad(fn_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8) + fp_mask = np.pad(fp_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8) + num_points = points.size(1) // 2 + points = points.clone() + + for bindx in range(fn_mask.shape[0]): + fn_mask_dt = cv2.distanceTransform(fn_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1] + fp_mask_dt = cv2.distanceTransform(fp_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1] + + fn_max_dist = np.max(fn_mask_dt) + fp_max_dist = np.max(fp_mask_dt) + + is_positive = fn_max_dist > fp_max_dist + dt = fn_mask_dt if is_positive else fp_mask_dt + inner_mask = dt > max(fn_max_dist, fp_max_dist) / 2.0 + indices = np.argwhere(inner_mask) + if len(indices) > 0: + coords = indices[np.random.randint(0, len(indices))] + if is_positive: + points[bindx, num_points - click_indx, 0] = float(coords[0]) + points[bindx, num_points - click_indx, 1] = float(coords[1]) + 
points[bindx, num_points - click_indx, 2] = float(click_indx) + else: + points[bindx, 2 * num_points - click_indx, 0] = float(coords[0]) + points[bindx, 2 * num_points - click_indx, 1] = float(coords[1]) + points[bindx, 2 * num_points - click_indx, 2] = float(click_indx) + + return points + + +def load_weights(model, path_to_weights): + current_state_dict = model.state_dict() + new_state_dict = torch.load(path_to_weights, map_location='cpu')['state_dict'] + current_state_dict.update(new_state_dict) + model.load_state_dict(current_state_dict) From f6433b40dbdd7838ca2e2ac6f25b020153547764 Mon Sep 17 00:00:00 2001 From: guba Date: Fri, 8 Mar 2024 00:11:32 +0800 Subject: [PATCH 10/42] DEL: delete old _large file which is a mistake --- ...i_out_vit_large448_pascal_new_scheduler.py | 126 ------------------ 1 file changed, 126 deletions(-) delete mode 100644 models/iter_mask/multi_out_vit_large448_pascal_new_scheduler.py diff --git a/models/iter_mask/multi_out_vit_large448_pascal_new_scheduler.py b/models/iter_mask/multi_out_vit_large448_pascal_new_scheduler.py deleted file mode 100644 index a88cd14..0000000 --- a/models/iter_mask/multi_out_vit_large448_pascal_new_scheduler.py +++ /dev/null @@ -1,126 +0,0 @@ -from isegm.model.is_plainvit_model import MultiOutVitModel -from isegm.utils.exp_imports.default import * -from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss - - -MODEL_NAME = 'cocolvis_vit_huge448' - - -def main(cfg): - model, model_cfg = init_model(cfg) - train(model, cfg, model_cfg) - - -def init_model(cfg): - model_cfg = edict() - model_cfg.crop_size = (448, 448) - model_cfg.num_max_points = 24 - - backbone_params = dict( - img_size=model_cfg.crop_size, - patch_size=(16,16), - in_chans=3, - embed_dim=1024, - depth=24, - num_heads=16, - mlp_ratio=4, - qkv_bias=True, - ) - - neck_params = dict( - in_dim = 1024, - out_dims = [192, 384, 768, 1536], - ) - - head_params = dict( - in_channels=[192, 384, 768, 1536], - in_index=[0, 1, 
2, 3], - dropout_ratio=0.1, - num_classes=7, - loss_decode=CrossEntropyLoss(), - align_corners=False, - upsample=cfg.upsample, - channels={'x1': 256, 'x2': 128, 'x4': 64}[cfg.upsample], - ) - - model = MultiOutVitModel( - use_disks=True, - norm_radius=5, - with_prev_mask=True, - backbone_params=backbone_params, - neck_params=neck_params, - head_params=head_params, - random_split=cfg.random_split, - ) - - model.backbone.init_weights_from_pretrained(cfg.IMAGENET_PRETRAINED_MODELS.MAE_HUGE) - model.to(cfg.device) - - return model, model_cfg - - -def train(model, cfg, model_cfg): - cfg.batch_size = 32 if cfg.batch_size < 1 else cfg.batch_size - cfg.val_batch_size = cfg.batch_size - crop_size = model_cfg.crop_size - - loss_cfg = edict() - loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) - loss_cfg.instance_loss_weight = 1.0 - - train_augmentator = Compose([ - UniformRandomResize(scale_range=(0.75, 1.40)), - HorizontalFlip(), - PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), - RandomCrop(*crop_size), - RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), - RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) - ], p=1.0) - - val_augmentator = Compose([ - PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), - RandomCrop(*crop_size) - ], p=1.0) - - points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.80, - merge_objects_prob=0.15, - max_num_merged_objects=2) - - trainset = PASCAL( - cfg.PASCAL_PATH, - split='train', - augmentator=train_augmentator, - min_object_area=1000, - keep_background_prob=0.05, - points_sampler=points_sampler, - epoch_len=30000, - stuff_prob=0.30 - ) - - valset = PASCAL( - cfg.PASCAL_PATH, - split='val', - augmentator=val_augmentator, - min_object_area=1000, - points_sampler=points_sampler, - epoch_len=2000 - ) - - optimizer_params = { - 'lr': 5e-5, 'betas': (0.9, 0.999), 'eps': 1e-8 - } - - 
lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, - milestones=[50, 55], gamma=0.1) - trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, - trainset, valset, - optimizer='adam', - optimizer_params=optimizer_params, - layerwise_decay=cfg.layerwise_decay, - lr_scheduler=lr_scheduler, - checkpoint_interval=[(0, 20), (50, 1)], - image_dump_interval=300, - metrics=[AdaptiveIoU()], - max_interactive_points=model_cfg.num_max_points, - max_num_next_clicks=3) - trainer.run(num_epochs=55, validation=False) From e4c7374ce087551781231a9b08c527e029a0c463 Mon Sep 17 00:00:00 2001 From: guba Date: Fri, 8 Mar 2024 00:11:49 +0800 Subject: [PATCH 11/42] COMMENT: add line of comment --- isegm/model/is_plainvit_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/isegm/model/is_plainvit_model.py b/isegm/model/is_plainvit_model.py index 8422fd8..cc6ade3 100644 --- a/isegm/model/is_plainvit_model.py +++ b/isegm/model/is_plainvit_model.py @@ -94,7 +94,7 @@ def backbone_forward(self, image, coord_features=None): return {'instances': self.head(multi_scale_features), 'instances_aux': None} -class MultiOutVitModel(ISModel): +class MultiOutVitModel(ISModel): # unused @serialize def __init__( self, From d08d303c9391abd533ac47bf0705cace47a143a7 Mon Sep 17 00:00:00 2001 From: guba Date: Fri, 8 Mar 2024 00:12:05 +0800 Subject: [PATCH 12/42] MOD: Change the trainer in new file to Multi_trainer --- models/iter_mask/multi_out_huge448_pascal_itermask.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/models/iter_mask/multi_out_huge448_pascal_itermask.py b/models/iter_mask/multi_out_huge448_pascal_itermask.py index eb7df98..6a6212a 100644 --- a/models/iter_mask/multi_out_huge448_pascal_itermask.py +++ b/models/iter_mask/multi_out_huge448_pascal_itermask.py @@ -1,3 +1,5 @@ +from isegm.engine.Multi_trainer import Multi_trainer +from isegm.inference.clicker import Click from isegm.utils.exp_imports.default import * from 
isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss @@ -110,7 +112,7 @@ def train(model, cfg, model_cfg): lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, milestones=[50, 55], gamma=0.1) - trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainer = Multi_trainer(model, cfg, model_cfg, loss_cfg, trainset, valset, optimizer='adam', optimizer_params=optimizer_params, From b81c4b147d3a746e22929cfb06e2e843298ff137 Mon Sep 17 00:00:00 2001 From: guba Date: Sun, 10 Mar 2024 20:35:41 +0800 Subject: [PATCH 13/42] MOD: Add some comments --- isegm/data/base.py | 13 +++++++++++++ isegm/engine/Multi_trainer.py | 4 ++-- isegm/model/is_model.py | 19 ++++++++++++++++++- isegm/model/ops.py | 12 ++++++------ 4 files changed, 39 insertions(+), 9 deletions(-) diff --git a/isegm/data/base.py b/isegm/data/base.py index ee2a532..04d3e6f 100755 --- a/isegm/data/base.py +++ b/isegm/data/base.py @@ -30,6 +30,19 @@ def __init__(self, self.dataset_samples = None def __getitem__(self, index): + ''' + + Args: + index: + + Returns: + { + 'images': torch.Tensor, # The image tensor, + 'points': np.ndarray, # Points, take max_num_points as 24, then shape is (48, 3). First 24 is pos, last 24 is neg. First few is [y, x, 100], then extended with (-1, -1, -1). 
+ 'instances': np.ndarray # The mask + } + + ''' if self.samples_precomputed_scores is not None: index = np.random.choice(self.samples_precomputed_scores['indices'], p=self.samples_precomputed_scores['probs']) diff --git a/isegm/engine/Multi_trainer.py b/isegm/engine/Multi_trainer.py index 9b0488d..f1e3a9e 100755 --- a/isegm/engine/Multi_trainer.py +++ b/isegm/engine/Multi_trainer.py @@ -258,7 +258,7 @@ def batch_forward(self, batch_data, validation=False): last_click_indx = None with torch.no_grad(): - num_iters = random.randint(0, self.max_num_next_clicks) + num_iters = random.randint(0, self.max_num_next_clicks) # Here max_num_next_clicks is 3 in default for click_indx in range(num_iters): last_click_indx = click_indx @@ -283,7 +283,7 @@ def batch_forward(self, batch_data, validation=False): zero_mask = np.random.random(size=prev_output.size(0)) < self.prev_mask_drop_prob prev_output[zero_mask] = torch.zeros_like(prev_output[zero_mask]) - batch_data['points'] = points + batch_data['points'] = points # write back if the `for click_indx` loop is executed, if it is not then is remains the same net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image output = self.net(net_input, points) diff --git a/isegm/model/is_model.py b/isegm/model/is_model.py index cf94aef..b003d39 100755 --- a/isegm/model/is_model.py +++ b/isegm/model/is_model.py @@ -36,7 +36,7 @@ def __init__(self, with_aux_output=False, norm_radius=5, use_disks=False, cpu_di self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0, cpu_mode=cpu_dist_maps, use_disks=use_disks) - def forward(self, image, points): + def forward(self, image, points): # Here points is still in like (b, 48, 3) form image, prev_mask = self.prepare_input(image) coord_features = self.get_coord_features(image, prev_mask, points) coord_features = self.maps_transform(coord_features) @@ -51,6 +51,23 @@ def forward(self, image, points): return outputs def prepare_input(self, image): + ''' + + 
Args: + image: + + Returns: + image, prev_mask + + Notes: + Note that Image is in (b, 4, h, w) form via former transforms. + + Image is normalized. + + Prev_mask is sliced from the input image if with_prev_mask is True. Else it's None. + + + ''' prev_mask = None if self.with_prev_mask: prev_mask = image[:, 3:, :, :] diff --git a/isegm/model/ops.py b/isegm/model/ops.py index 9be9c73..abdf812 100755 --- a/isegm/model/ops.py +++ b/isegm/model/ops.py @@ -54,25 +54,25 @@ def get_coord_features(self, points, batchsize, rows, cols): norm_delimeter)) coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float() else: - num_points = points.shape[1] // 2 - points = points.view(-1, points.size(2)) + num_points = points.shape[1] // 2 # 24 + points = points.view(-1, points.size(2)) # (b, 48, 3) -> (b*48, 3) points, points_order = torch.split(points, [2, 1], dim=1) - invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0 + invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0 # tensor of shape (96,) with True/False row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device) col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device) coord_rows, coord_cols = torch.meshgrid(row_array, col_array) - coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1) + coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1) # 96, 2, 448, 448 add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1) coords.add_(-add_xy) if not self.use_disks: coords.div_(self.norm_radius * self.spatial_scale) - coords.mul_(coords) + coords.mul_(coords) # 96, 2, h, w coords[:, 0] += coords[:, 1] - coords = coords[:, :1] + coords = coords[:, :1] # Till here, coords store the squared distance from the points to the chosen pixels coords[invalid_points, :, :, :] = 1e6 From 
0f14c8cc884f5d1691214c43f68ec12db3bc3cef Mon Sep 17 00:00:00 2001 From: guba Date: Tue, 26 Mar 2024 22:15:12 +0800 Subject: [PATCH 14/42] I'm tired, every fucking thing together, don't care if it is a disaster --- TEST_Dataset.py | 39 ++++ TEST_read_trained_model.py | 221 ++++++++++++++++++ isegm/data/datasets/PASCAL.py | 48 ++++ isegm/data/points_sampler.py | 72 ++++++ isegm/engine/Multi_trainer.py | 82 ++++--- isegm/engine/trainer.py | 10 +- isegm/model/is_plainvit_model.py | 19 +- isegm/model/losses.py | 70 +++++- isegm/model/metrics.py | 83 +++++++ isegm/model/ops.py | 62 +++++ .../multi_out_huge448_pascal_itermask.py | 13 +- .../multi_out_huge448_pascal_itermask_tst.py | 126 ++++++++++ 12 files changed, 802 insertions(+), 43 deletions(-) create mode 100644 TEST_Dataset.py create mode 100644 TEST_read_trained_model.py create mode 100644 models/iter_mask/multi_out_huge448_pascal_itermask_tst.py diff --git a/TEST_Dataset.py b/TEST_Dataset.py new file mode 100644 index 0000000..95e0b11 --- /dev/null +++ b/TEST_Dataset.py @@ -0,0 +1,39 @@ +from isegm.data.datasets import PASCAL +from isegm.data.points_sampler import MultiClassSampler +import argparse +import os +from pathlib import Path + +from isegm.data.points_sampler import MultiClassSampler +from isegm.engine.Multi_trainer import Multi_trainer +from isegm.inference.clicker import Click +from isegm.model.is_plainvit_model import MultiOutVitModel +from isegm.model.metrics import AdaptiveMIoU +from isegm.utils.exp import init_experiment +from isegm.utils.exp_imports.default import * +from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss +from train import load_module + +points_sampler = MultiClassSampler(2, prob_gamma=0.80, + merge_objects_prob=0.15, + max_num_merged_objects=2) +trainset = PASCAL( + "/home/gyt/gyt/dataset/data/pascal_person_part", + split='train', + min_object_area=1000, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=30000, + # 
stuff_prob=0.30 +) + +valset = PASCAL( + "/home/gyt/gyt/dataset/data/pascal_person_part", + split='val', + min_object_area=1000, + points_sampler=points_sampler, + epoch_len=2000 +) + +for batch_data in trainset: + print(batch_data["points"].shape) \ No newline at end of file diff --git a/TEST_read_trained_model.py b/TEST_read_trained_model.py new file mode 100644 index 0000000..c6b305d --- /dev/null +++ b/TEST_read_trained_model.py @@ -0,0 +1,221 @@ +import argparse +import os +from pathlib import Path + +from isegm.data.points_sampler import MultiClassSampler +from isegm.engine.Multi_trainer import Multi_trainer +from isegm.inference.clicker import Click +from isegm.model.is_plainvit_model import MultiOutVitModel +from isegm.model.metrics import AdaptiveMIoU +from isegm.utils.exp import init_experiment +from isegm.utils.exp_imports.default import * +from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss +from train import load_module + +MODEL_NAME = 'cocolvis_vit_huge448' + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument('model_path', type=str, + help='Path to the model script.') + + parser.add_argument('--exp-name', type=str, default='', + help='Here you can specify the name of the experiment. ' + 'It will be added as a suffix to the experiment folder.') + + parser.add_argument('--workers', type=int, default=4, + metavar='N', help='Dataloader threads.') + + parser.add_argument('--batch-size', type=int, default=-1, + help='You can override model batch size by specify positive number.') + + parser.add_argument('--ngpus', type=int, default=1, + help='Number of GPUs. ' + 'If you only specify "--gpus" argument, the ngpus value will be calculated automatically. ' + 'You should use either this argument or "--gpus".') + + parser.add_argument('--gpus', type=str, default='', required=False, + help='Ids of used GPUs. 
You should use either this argument or "--ngpus".') + + parser.add_argument('--resume-exp', type=str, default=None, + help='The prefix of the name of the experiment to be continued. ' + 'If you use this field, you must specify the "--resume-prefix" argument.') + + parser.add_argument('--resume-prefix', type=str, default='latest', + help='The prefix of the name of the checkpoint to be loaded.') + + parser.add_argument('--start-epoch', type=int, default=0, + help='The number of the starting epoch from which training will continue. ' + '(it is important for correct logging and learning rate)') + + parser.add_argument('--weights', type=str, default=None, + help='Model weights will be loaded from the specified path if you use this argument.') + + parser.add_argument('--temp-model-path', type=str, default='', + help='Do not use this argument (for internal purposes).') + + parser.add_argument("--local_rank", type=int, default=0) + + # parameters for experimenting + parser.add_argument('--layerwise-decay', action='store_true', + help='layer wise decay for transformer blocks.') + + parser.add_argument('--upsample', type=str, default='x1', + help='upsample the output.') + + parser.add_argument('--random-split', action='store_true', + help='random split the patch instead of window split.') + + return parser.parse_args() +def main(): + model, model_cfg = init_model() + weight_path = "last_checkpoint.pth" + weights = torch.load(weight_path) + model.load_state_dict(weights['state_dict']) + model.eval() + cfg = edict() + + train(model, cfg, model_cfg) + + +def init_model(): + model_cfg = edict() + model_cfg.crop_size = (448, 448) + model_cfg.num_max_points = 24 + + backbone_params = dict( + img_size=model_cfg.crop_size, + patch_size=(14,14), + in_chans=3, + embed_dim=1280, + depth=32, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + ) + + neck_params = dict( + in_dim = 1280, + out_dims = [240, 480, 960, 1920], + ) + + head_params = dict( + in_channels=[240, 480, 960, 1920], + 
in_index=[0, 1, 2, 3], + dropout_ratio=0.1, + num_classes=7, + loss_decode=CrossEntropyLoss(), + align_corners=False, + upsample='x1', + channels={'x1': 256, 'x2': 128, 'x4': 64}['x1'], + ) + + model = MultiOutVitModel( + use_disks=True, + norm_radius=5, + with_prev_mask=True, + backbone_params=backbone_params, + neck_params=neck_params, + head_params=head_params, + random_split=False, + ) + model.to('cuda') + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 1 + cfg.distributed = 'WORLD_SIZE' in os.environ + cfg.local_rank = 0 + cfg.workers = 4 + cfg.weights = "last_checkpoint.pth" + cfg.val_batch_size = cfg.batch_size + cfg.ngpus = 1 + cfg.device = torch.device('cuda') + cfg.start_epoch = 0 + cfg.multi_gpu = cfg.ngpus > 1 + crop_size = model_cfg.crop_size + + cfg.EXPS_PATH = 'TST_OUT' + experiments_path = Path(cfg.EXPS_PATH) + exp_parent_path = experiments_path / '/'.join("") + exp_parent_path.mkdir(parents=True, exist_ok=True) + + + last_exp_indx = 0 + exp_name = f'{last_exp_indx:03d}' + exp_path = exp_parent_path / exp_name + + if cfg.local_rank == 0: + exp_path.mkdir(parents=True, exist_ok=True) + + cfg.EXP_PATH = exp_path + cfg.CHECKPOINTS_PATH = exp_path / 'checkpoints' + cfg.VIS_PATH = exp_path / 'vis' + cfg.LOGS_PATH = exp_path / 'logs' + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedMultiFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = 
MultiClassSampler(100, prob_gamma=0.80, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = PASCAL( + "/home/gyt/gyt/dataset/data/pascal_person_part", + split='train', + augmentator=train_augmentator, + min_object_area=1000, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=30000, + # stuff_prob=0.30 + ) + + valset = PASCAL( + "/home/gyt/gyt/dataset/data/pascal_person_part", + split='val', + augmentator=val_augmentator, + min_object_area=1000, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-5, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[50, 55], gamma=0.1) + trainer = Multi_trainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + checkpoint_interval=[(0, 20), (50, 1)], + image_dump_interval=300, + metrics=[AdaptiveMIoU(num_classes=7)], + max_interactive_points=model_cfg.num_max_points, + max_num_next_clicks=15) + trainer.validation(epoch=0) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/isegm/data/datasets/PASCAL.py b/isegm/data/datasets/PASCAL.py index 3af1e04..ea73df1 100644 --- a/isegm/data/datasets/PASCAL.py +++ b/isegm/data/datasets/PASCAL.py @@ -1,5 +1,6 @@ import os import pickle as pkl +import random from pathlib import Path import cv2 @@ -133,5 +134,52 @@ def remove_buggy_masks(self, index, instances_mask): instances_mask[instances_mask == obj_id] = 0 return instances_mask + + def __getitem__(self, index): # points should be sampled from the whole mask + ''' + + Args: + index: + + Returns: + { + 'images': torch.Tensor, # The image tensor, + 'points': np.ndarray, # Points, take max_num_points as 24, then shape is (48, 3). First 24 is pos, last 24 is neg. First few is [y, x, 100], then extended with (-1, -1, -1). 
+ 'instances': np.ndarray # The mask + } + + ''' + if self.samples_precomputed_scores is not None: + index = np.random.choice(self.samples_precomputed_scores['indices'], + p=self.samples_precomputed_scores['probs']) + else: + if self.epoch_len > 0: + index = random.randrange(0, len(self.dataset_samples)) + + sample = self.get_sample(index) + sample = self.augment_sample(sample) + sample.remove_small_objects(self.min_object_area) + init_points = cv2.imread(sample.init_clicks)[:,:,0] + rows, cols = np.where(init_points != 255) + non_255_values = init_points[rows, cols] + coords_and_values = list(zip(rows, cols, non_255_values)) + coords_and_values.extend([(-1, -1, -1)] * (self.points_sampler.max_num_points - len(coords_and_values))) + init_points = np.array(coords_and_values) + self.points_sampler.sample_object(sample) + points = np.array(self.points_sampler.sample_points()) + + mask = self.points_sampler.selected_mask + + output = { + 'images': self.to_tensor(sample.image), + 'points': init_points.astype(np.float32), + 'instances': mask, + # 'init_points': init_points.astype(np.float32) + } + + if self.with_image_info: + output['image_info'] = sample.sample_id + + return output diff --git a/isegm/data/points_sampler.py b/isegm/data/points_sampler.py index 79f0cd7..69b801b 100755 --- a/isegm/data/points_sampler.py +++ b/isegm/data/points_sampler.py @@ -303,3 +303,75 @@ def get_point_candidates(obj_mask, k=1.7, full_prob=0.0): click_indx = np.random.choice(len(prob_map), p=prob_map) click_coords = np.unravel_index(click_indx, dt.shape) return np.array([click_coords]) + +class MultiClassSampler(MultiPointSampler): + def sample_object(self, sample: DSample): + ''' + + Args: + sample: + + Returns: + + Notes: + Note that sample is a DSample instance. + ''' + self._selected_mask = sample._encoded_masks + + def sample_points(self): + ''' + Randomly sample points from gt_mask. 
\n + Returns: + + ''' + assert self._selected_mask is not None + num_points = self.max_num_points + # num_points = 1 + np.random.choice(np.arange(self.max_num_points), p=self._pos_probs) + h, w = self._selected_mask.shape[:2] + points = [] + y = np.random.randint(0, h, num_points) + x = np.random.randint(0, w, num_points) + for i in range(num_points): + cls = self._selected_mask[y[i], x[i]][0] + points.append([y[i], x[i], cls]) + + return points + + def _multi_mask_sample_points(self, selected_masks, is_negative, with_first_click=False): + selected_masks = selected_masks[:self.max_num_points] + + each_obj_points = [ + self._sample_points(mask, is_negative=is_negative[i], + with_first_click=with_first_click) + for i, mask in enumerate(selected_masks) + ] + each_obj_points = [x for x in each_obj_points if len(x) > 0] + + points = [] + if len(each_obj_points) == 1: + points = each_obj_points[0] + elif len(each_obj_points) > 1: + if self.only_one_first_click: + each_obj_points = each_obj_points[:1] + + points = [obj_points[0] for obj_points in each_obj_points] + + aggregated_masks_with_prob = [] + for indx, x in enumerate(selected_masks): + if isinstance(x, (list, tuple)) and x and isinstance(x[0], (list, tuple)): + for t, prob in x: + aggregated_masks_with_prob.append((t, prob / len(selected_masks))) + else: + aggregated_masks_with_prob.append((x, 1.0 / len(selected_masks))) + + other_points_union = self._sample_points(aggregated_masks_with_prob, is_negative=True) + if len(other_points_union) + len(points) <= self.max_num_points: + points.extend(other_points_union) + else: + points.extend(random.sample(other_points_union, self.max_num_points - len(points))) + + if len(points) < self.max_num_points: + points.extend([(-1, -1, -1)] * (self.max_num_points - len(points))) + + return points + diff --git a/isegm/engine/Multi_trainer.py b/isegm/engine/Multi_trainer.py index f1e3a9e..9ee6d6e 100755 --- a/isegm/engine/Multi_trainer.py +++ b/isegm/engine/Multi_trainer.py @@ 
-171,8 +171,8 @@ def training(self, epoch): if '_loss' in k and hasattr(v, 'log_states') and self.loss_cfg.get(k + '_weight', 0.0) > 0: v.log_states(self.sw, f'{log_prefix}Losses/{k}', global_step) - if self.image_dump_interval > 0 and global_step % self.image_dump_interval == 0: - self.save_visualization(splitted_batch_data, outputs, global_step, prefix='train') + # if self.image_dump_interval > 0 and global_step % self.image_dump_interval == 0: + # self.save_visualization(splitted_batch_data, outputs, global_step, prefix='train') self.sw.add_scalar(tag=f'{log_prefix}States/learning_rate', value=self.lr if not hasattr(self, 'lr_scheduler') else self.lr_scheduler.get_lr()[-1], @@ -256,9 +256,9 @@ def batch_forward(self, batch_data, validation=False): prev_output = torch.zeros_like(image, dtype=torch.float32)[:, :1, :, :] last_click_indx = None - + # First part with torch.no_grad(): - num_iters = random.randint(0, self.max_num_next_clicks) # Here max_num_next_clicks is 3 in default + num_iters = self.max_num_next_clicks # Here max_num_next_clicks is 3 in default for click_indx in range(num_iters): last_click_indx = click_indx @@ -273,7 +273,7 @@ def batch_forward(self, batch_data, validation=False): net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image prev_output = torch.sigmoid(eval_model(net_input, points)['instances']) - + prev_output = torch.max(prev_output, dim=1, keepdim=True)[1] points = get_next_points(prev_output, orig_gt_mask, points, click_indx + 1) if not validation: @@ -375,42 +375,62 @@ def is_master(self): def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49): assert click_indx > 0 - pred = pred.cpu().numpy()[:, 0, :, :] - gt = gt.cpu().numpy()[:, 0, :, :] > 0.5 - - fn_mask = np.logical_and(gt, pred < pred_thresh) - fp_mask = np.logical_and(np.logical_not(gt), pred > pred_thresh) - - fn_mask = np.pad(fn_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8) - fp_mask = np.pad(fp_mask, ((0, 0), 
(1, 1), (1, 1)), 'constant').astype(np.uint8) - num_points = points.size(1) // 2 - points = points.clone() - - for bindx in range(fn_mask.shape[0]): - fn_mask_dt = cv2.distanceTransform(fn_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1] - fp_mask_dt = cv2.distanceTransform(fp_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1] + pred_o = pred.cpu().numpy()[:, 0, :, :] + gt_o = gt.cpu().numpy()[:, :, :, 0] + rows = [] + for bindx in range(pred.shape[0]): + gt = gt_o[bindx] + pred = pred_o[bindx] + areas = [] + for cls in np.unique(gt): + area_value, max_idx, _, _ = get_contours_and_maxidx(cls, gt, pred) + areas.append([cls, area_value]) + areas = sorted(areas, key=lambda x: x[1], reverse=True) + max_area_cls = areas[0][0] + area, max_idx, contours, fn_mask = get_contours_and_maxidx(max_area_cls, gt, pred) + for k in range(len(contours)): + if k != max_idx: + cv2.fillPoly(fn_mask, [contours[k]], 0) + + points = points.clone() + + fn_mask_dt = cv2.distanceTransform(fn_mask, cv2.DIST_L2, 5)[1:-1, 1:-1] fn_max_dist = np.max(fn_mask_dt) - fp_max_dist = np.max(fp_mask_dt) - is_positive = fn_max_dist > fp_max_dist - dt = fn_mask_dt if is_positive else fp_mask_dt - inner_mask = dt > max(fn_max_dist, fp_max_dist) / 2.0 + is_positive = True + dt = fn_mask_dt + inner_mask = dt > (fn_max_dist / 2.0) indices = np.argwhere(inner_mask) if len(indices) > 0: coords = indices[np.random.randint(0, len(indices))] - if is_positive: - points[bindx, num_points - click_indx, 0] = float(coords[0]) - points[bindx, num_points - click_indx, 1] = float(coords[1]) - points[bindx, num_points - click_indx, 2] = float(click_indx) - else: - points[bindx, 2 * num_points - click_indx, 0] = float(coords[0]) - points[bindx, 2 * num_points - click_indx, 1] = float(coords[1]) - points[bindx, 2 * num_points - click_indx, 2] = float(click_indx) + new_row = np.array([float(coords[0]),float(coords[1]), float(max_area_cls)]) + rows.append(new_row) + else: + return points + rows = np.vstack(rows) + rows = rows[:,np.newaxis,:] + 
points = torch.cat((points, torch.from_numpy(rows).float().to(points.device)), dim=1) # TODO Point Number issue return points +def get_contours_and_maxidx(cls, gt, pred): + t_gt = gt == cls + t_pred = pred == cls + t_fn = np.logical_and(t_gt, np.logical_not(t_pred)).astype(np.uint8) + contours, hierarchy = cv2.findContours(t_fn, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) + area = [] + for j in range(len(contours)): + area.append(cv2.contourArea(contours[j])) + if len(area) == 0: + return 0, 0, contours, t_fn + else: + max_idx = np.argmax(area) + area_value = area[max_idx] + return area_value, max_idx, contours, t_fn + + def load_weights(model, path_to_weights): current_state_dict = model.state_dict() new_state_dict = torch.load(path_to_weights, map_location='cpu')['state_dict'] diff --git a/isegm/engine/trainer.py b/isegm/engine/trainer.py index 72a5a92..a6c4d97 100755 --- a/isegm/engine/trainer.py +++ b/isegm/engine/trainer.py @@ -381,21 +381,21 @@ def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49): fp_mask = np.logical_and(np.logical_not(gt), pred > pred_thresh) fn_mask = np.pad(fn_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8) - fp_mask = np.pad(fp_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8) + fp_mask = np.pad(fp_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8) # A edge padding of 1 pixel of zero for cv2.distanceTransform num_points = points.size(1) // 2 points = points.clone() for bindx in range(fn_mask.shape[0]): fn_mask_dt = cv2.distanceTransform(fn_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1] - fp_mask_dt = cv2.distanceTransform(fp_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1] + fp_mask_dt = cv2.distanceTransform(fp_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1] # Output a mask with each pixel with distance to nearest zero - fn_max_dist = np.max(fn_mask_dt) + fn_max_dist = np.max(fn_mask_dt) # max distance as value fp_max_dist = np.max(fp_mask_dt) is_positive = fn_max_dist > fp_max_dist dt = fn_mask_dt if is_positive 
else fp_mask_dt - inner_mask = dt > max(fn_max_dist, fp_max_dist) / 2.0 - indices = np.argwhere(inner_mask) + inner_mask = dt > max(fn_max_dist, fp_max_dist) / 2.0 # output a mask with each pixel with distance to nearest zero > (max distance / 2) + indices = np.argwhere(inner_mask) # get the nd list of the pixels that are True in the inner_mask if len(indices) > 0: coords = indices[np.random.randint(0, len(indices))] if is_positive: diff --git a/isegm/model/is_plainvit_model.py b/isegm/model/is_plainvit_model.py index cc6ade3..c9da0a4 100644 --- a/isegm/model/is_plainvit_model.py +++ b/isegm/model/is_plainvit_model.py @@ -4,6 +4,7 @@ from .is_model import ISModel from .modeling.models_vit import VisionTransformer, PatchEmbed from .modeling.swin_transformer import SwinTransfomerSegHead +from .ops import DistMaps, New_DistMaps class SimpleFPN(nn.Module): @@ -111,7 +112,7 @@ def __init__( self.patch_embed_coords = PatchEmbed( img_size= backbone_params['img_size'], patch_size=backbone_params['patch_size'], - in_chans=3 if self.with_prev_mask else 2, + in_chans=8 if self.with_prev_mask else 2, embed_dim=backbone_params['embed_dim'], ) @@ -119,6 +120,8 @@ def __init__( self.neck = SimpleFPN(**neck_params) self.head = SwinTransfomerSegHead(**head_params) self.testMsg = "Test Message" + self.dist_maps = New_DistMaps(norm_radius=5, spatial_scale=1.0, + cpu_mode=False, use_disks=False) def backbone_forward(self, image, coord_features=None): coord_features = self.patch_embed_coords(coord_features) @@ -132,3 +135,17 @@ def backbone_forward(self, image, coord_features=None): multi_scale_features = self.neck(backbone_features) return {'instances': self.head(multi_scale_features), 'instances_aux': None} + + def forward(self, image, points): # Here points is still in like (b, 48, 3) form + image, prev_mask = self.prepare_input(image) + coord_features = self.get_coord_features(image, prev_mask, points) + # coord_features = self.maps_transform(coord_features) + outputs = 
self.backbone_forward(image, coord_features) + + outputs['instances'] = nn.functional.interpolate(outputs['instances'], size=image.size()[2:], + mode='bilinear', align_corners=True) + if self.with_aux_output: + outputs['instances_aux'] = nn.functional.interpolate(outputs['instances_aux'], size=image.size()[2:], + mode='bilinear', align_corners=True) + + return outputs diff --git a/isegm/model/losses.py b/isegm/model/losses.py index 38c4ee3..aeda7d4 100755 --- a/isegm/model/losses.py +++ b/isegm/model/losses.py @@ -27,7 +27,7 @@ def __init__(self, axis=-1, alpha=0.25, gamma=2, max_mult=-1, eps=1e-12, self._k_sum = 0 self._m_max = 0 - def forward(self, pred, label): + def forward(self, pred, label): # TODO Error here one_hot = label > 0.5 sample_weight = label != self._ignore_label @@ -74,6 +74,74 @@ def log_states(self, sw, name, global_step): sw.add_scalar(tag=name + '_m', value=self._m_max, global_step=global_step) +class NormalizedMultiFocalLossSigmoid(nn.Module): + def __init__(self, axis=-1, alpha=0.25, gamma=2, max_mult=-1, eps=1e-12, + from_sigmoid=False, detach_delimeter=True, + batch_axis=0, weight=None, size_average=True, + ignore_label=-1): + super(NormalizedMultiFocalLossSigmoid, self).__init__() + self._axis = axis + self._alpha = alpha + self._gamma = gamma + self._ignore_label = ignore_label + self._weight = weight if weight is not None else 1.0 + self._batch_axis = batch_axis + + self._from_logits = from_sigmoid + self._eps = eps + self._size_average = size_average + self._detach_delimeter = detach_delimeter + self._max_mult = max_mult + self._k_sum = 0 + self._m_max = 0 + + def forward(self, pred, label): # TODO Error here + label = torch.squeeze(label).to(torch.int64) + if len(label.shape) == 2: + label = label.unsqueeze(0) + N, H, W = label.shape + C = pred.shape[1] + label_one_hot = F.one_hot(label, num_classes=C).permute(0, 3, 1, 2).float() + sample_weight = (label != self._ignore_label).float() + sample_weight_expanded = 
sample_weight.unsqueeze(1) + sample_weight_expanded = sample_weight_expanded.expand(-1, C, -1, -1) + if not self._from_logits: + # 对于多分类问题,我们假设 pred 已经是 softmax 输出 + pred = F.softmax(pred, dim=1) + + alpha = torch.where(label_one_hot == 1, self._alpha * sample_weight_expanded, (1 - self._alpha) * sample_weight_expanded) + pt = torch.where(label_one_hot == 1, pred, 1 - pred) + + beta = (1 - pt) ** self._gamma + + sw_sum = torch.sum(sample_weight_expanded, dim=[0, 2, 3], keepdim=True) + beta_sum = torch.sum(beta, dim=[0, 2, 3], keepdim=True) + mult = sw_sum / (beta_sum + self._eps) + if self._detach_delimeter: + mult = mult.detach() + beta = beta * mult + + # 对 beta 进行裁剪 + if self._max_mult > 0: + beta = torch.clamp(beta, max=self._max_mult) + + # 计算损失 + loss = -alpha * beta * torch.log(torch.clamp(pt, min=self._eps)) + + # 应用样本权重 + loss = loss * sample_weight_expanded + + # 汇总损失 + if self._size_average: + loss = loss.sum(dim=[0, 2, 3]) / (sw_sum.sum(dim=[0, 2, 3]) + self._eps) + else: + loss = loss.sum(dim=[0, 2, 3]) + return loss.mean() + + def log_states(self, sw, name, global_step): + sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step) + sw.add_scalar(tag=name + '_m', value=self._m_max, global_step=global_step) + class FocalLoss(nn.Module): def __init__(self, axis=-1, alpha=0.25, gamma=2, from_logits=False, batch_axis=0, diff --git a/isegm/model/metrics.py b/isegm/model/metrics.py index a572dcd..0c48852 100755 --- a/isegm/model/metrics.py +++ b/isegm/model/metrics.py @@ -26,6 +26,89 @@ def name(self): return type(self).__name__ +class AdaptiveMIoU(TrainMetric): + def __init__(self, num_classes, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9, + ignore_label=-1, from_logits=True, + pred_output='instances', gt_output='instances'): + super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,)) + self._ignore_label = ignore_label + self._from_logits = from_logits + self._iou_thresh = [init_thresh] * num_classes + 
self._thresh_step = thresh_step + self._thresh_beta = thresh_beta + self._iou_beta = iou_beta + self._ema_iou = np.zeros(num_classes) + self._epoch_iou_sum = np.zeros(num_classes) + self._epoch_batch_count = 0 + self.n_classes = num_classes + self.ignore_id = ignore_label + self.confusion_matrix = np.zeros((num_classes + 1, num_classes + 1), dtype=np.int64) + self.img_count = 0 + self.ob_count = 0 + self.click_count = 0 + + def update(self, label_preds, label_trues): + label_preds = torch.argmax(label_preds, dim=1) + label_trues = label_trues.squeeze() + if label_trues.dim() == 2: + label_trues = label_trues.unsqueeze(0) + for lt, lp in zip(label_trues, label_preds): + lt = lt.detach().cpu().numpy() + lp = lp.detach().cpu().numpy() + self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten()) + self._epoch_batch_count += 1 + + def _fast_hist(self,label_pred, label_true): + ignore_To = self.n_classes + label_pred = np.where(label_pred == self.ignore_id, ignore_To, label_pred) + mask = (label_true >= 0) & (label_true < self.n_classes) + hist = np.bincount( + (self.n_classes+1) * label_true[mask].astype(int) + label_pred[mask], + minlength=(self.n_classes+1) * (self.n_classes + 1), + ).reshape(self.n_classes+1, self.n_classes + 1) + return hist + + def get_epoch_value(self): + hist = self.confusion_matrix + acc = np.diag(hist).sum() / hist.sum() + acc_cls = np.diag(hist) / hist.sum(axis=1) + acc_cls = np.nanmean(acc_cls) + iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) + iu = iu[:-1] + mean_iu = np.nanmean(iu) + return mean_iu + + def reset_epoch_stats(self): + self.confusion_matrix = np.zeros((self.n_classes+1, self.n_classes+1),dtype=np.int64) + self._epoch_batch_count = 0 + + def log_states(self, sw, tag_prefix, global_step): + hist = self.confusion_matrix + acc = np.diag(hist).sum() / hist.sum() + acc_cls = np.diag(hist) / hist.sum(axis=1) + acc_cls = np.nanmean(acc_cls) + iu = np.diag(hist) / (hist.sum(axis=1) + 
hist.sum(axis=0) - np.diag(hist)) + iu = iu[:-1] + mean_iu = np.nanmean(iu) + cls_iu = dict(zip(range(self.n_classes), iu)) + + for cls in range(self.n_classes): + sw.add_scalar(tag=f'{tag_prefix}_class_{cls}_ema_iou', value=cls_iu[cls], global_step=global_step) + + def get_miou(self, label_trues, label_preds): + hist = np.zeros(self.confusion_matrix.shape) + for lt, lp in zip(label_trues, label_preds): + hist += self._fast_hist(lt.flatten(), lp.flatten()) + acc_cls = np.diag(hist) / hist.sum(axis=1) + iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) + iu = iu[:-1] + mean_iu = np.nanmean(iu) + cls_iu = dict(zip(range(self.n_classes), iu)) + return { + "Mean IoU": mean_iu, + "Class IoU": cls_iu, + } + class AdaptiveIoU(TrainMetric): def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9, ignore_label=-1, from_logits=True, diff --git a/isegm/model/ops.py b/isegm/model/ops.py index abdf812..b31cd22 100755 --- a/isegm/model/ops.py +++ b/isegm/model/ops.py @@ -90,6 +90,68 @@ def get_coord_features(self, points, batchsize, rows, cols): def forward(self, x, coords): return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3]) +class New_DistMaps(nn.Module): + def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False, use_disks=False): + super(New_DistMaps, self).__init__() + self.spatial_scale = spatial_scale + self.norm_radius = norm_radius + self.cpu_mode = cpu_mode + self.use_disks = use_disks + if self.cpu_mode: + from isegm.utils.cython import get_dist_maps + self._get_dist_maps = get_dist_maps + + def get_coord_features(self, points, batchsize, rows, cols): + if self.cpu_mode: + coords = [] + for i in range(batchsize): + norm_delimeter = 1.0 if self.use_disks else self.spatial_scale * self.norm_radius + coords.append(self._get_dist_maps(points[i].cpu().float().numpy(), rows, cols, + norm_delimeter)) + coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float() + else: + 
mask = ~(points == torch.tensor([-1, -1, -1],device=points.device)).all(dim=-1) + points = points[mask] + points = points.unsqueeze(0) + point_num = points.shape[1] + if point_num == 0: + print("No point") + exit(1) + else: + points = points.view(-1, points.size(2)) + points, points_cls = torch.split(points, [2, 1], dim=1) + invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0 + row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device) + col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device) + coord_rows, coord_cols = torch.meshgrid(row_array, col_array) + + coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1) + + add_xy = (points * 1).view(points.size(0), points.size(1), 1, 1) + coords.add_(-add_xy) + coords.mul_(coords) # 96, 2, h, w + coords[:, 0] += coords[:, 1] + coords = coords[:, :1] + coords[invalid_points, :, :, :] = 1e6 + coords = coords.view(-1, point_num, 1, rows, cols) + + coords = coords.view(-1, 1, rows, cols) + + coords = (coords <= (5) ** 2).float() + + zeros = torch.zeros_like(coords) + res = zeros.repeat(1, 7, 1, 1) # TODO 7 cls magic + + for i in range(len(points_cls)): + cls = int(points_cls[i]) + res[i, cls] = coords[i, 0] + res = res.view(-1,point_num, 7, rows, cols) # TODO 7 cls magic + res = res.max(dim=1)[0] + return res + + def forward(self, x, coords): + return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3]) + class ScaleLayer(nn.Module): def __init__(self, init_value=1.0, lr_mult=1): diff --git a/models/iter_mask/multi_out_huge448_pascal_itermask.py b/models/iter_mask/multi_out_huge448_pascal_itermask.py index 6a6212a..5f44a6e 100644 --- a/models/iter_mask/multi_out_huge448_pascal_itermask.py +++ b/models/iter_mask/multi_out_huge448_pascal_itermask.py @@ -1,5 +1,8 @@ +from isegm.data.points_sampler import MultiClassSampler from isegm.engine.Multi_trainer import Multi_trainer from 
isegm.inference.clicker import Click +from isegm.model.is_plainvit_model import MultiOutVitModel +from isegm.model.metrics import AdaptiveMIoU from isegm.utils.exp_imports.default import * from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss @@ -43,7 +46,7 @@ def init_model(cfg): channels={'x1': 256, 'x2': 128, 'x4': 64}[cfg.upsample], ) - model = PlainVitModel( + model = MultiOutVitModel( use_disks=True, norm_radius=5, with_prev_mask=True, @@ -65,7 +68,7 @@ def train(model, cfg, model_cfg): crop_size = model_cfg.crop_size loss_cfg = edict() - loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss = NormalizedMultiFocalLossSigmoid(alpha=0.5, gamma=2) loss_cfg.instance_loss_weight = 1.0 train_augmentator = Compose([ @@ -82,7 +85,7 @@ def train(model, cfg, model_cfg): RandomCrop(*crop_size) ], p=1.0) - points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.80, + points_sampler = MultiClassSampler(100, prob_gamma=0.80, merge_objects_prob=0.15, max_num_merged_objects=2) @@ -120,7 +123,7 @@ def train(model, cfg, model_cfg): lr_scheduler=lr_scheduler, checkpoint_interval=[(0, 20), (50, 1)], image_dump_interval=300, - metrics=[AdaptiveIoU()], + metrics=[AdaptiveMIoU(num_classes=7)], max_interactive_points=model_cfg.num_max_points, - max_num_next_clicks=3) + max_num_next_clicks=15) trainer.run(num_epochs=55, validation=False) \ No newline at end of file diff --git a/models/iter_mask/multi_out_huge448_pascal_itermask_tst.py b/models/iter_mask/multi_out_huge448_pascal_itermask_tst.py new file mode 100644 index 0000000..3271c8c --- /dev/null +++ b/models/iter_mask/multi_out_huge448_pascal_itermask_tst.py @@ -0,0 +1,126 @@ +from isegm.engine.Multi_trainer import Multi_trainer +from isegm.inference.clicker import Click +from isegm.utils.exp_imports.default import * +from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss + +MODEL_NAME = 
'cocolvis_vit_huge448' + + +def main(cfg): + model, model_cfg = init_model(cfg) + train(model, cfg, model_cfg) + + +def init_model(cfg): + model_cfg = edict() + model_cfg.crop_size = (448, 448) + model_cfg.num_max_points = 24 + + backbone_params = dict( + img_size=model_cfg.crop_size, + patch_size=(14,14), + in_chans=3, + embed_dim=1280, + depth=32, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + ) + + neck_params = dict( + in_dim = 1280, + out_dims = [240, 480, 960, 1920], + ) + + head_params = dict( + in_channels=[240, 480, 960, 1920], + in_index=[0, 1, 2, 3], + dropout_ratio=0.1, + num_classes=1, + loss_decode=CrossEntropyLoss(), + align_corners=False, + upsample=cfg.upsample, + channels={'x1': 256, 'x2': 128, 'x4': 64}[cfg.upsample], + ) + + model = PlainVitModel( + use_disks=True, + norm_radius=5, + with_prev_mask=True, + backbone_params=backbone_params, + neck_params=neck_params, + head_params=head_params, + random_split=cfg.random_split, + ) + + model.backbone.init_weights_from_pretrained(cfg.IMAGENET_PRETRAINED_MODELS.MAE_HUGE) + model.to(cfg.device) + + return model, model_cfg + + +def train(model, cfg, model_cfg): + cfg.batch_size = 32 if cfg.batch_size < 1 else cfg.batch_size + cfg.val_batch_size = cfg.batch_size + crop_size = model_cfg.crop_size + + loss_cfg = edict() + loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss_weight = 1.0 + + train_augmentator = Compose([ + UniformRandomResize(scale_range=(0.75, 1.40)), + HorizontalFlip(), + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size), + RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75), + RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75) + ], p=1.0) + + val_augmentator = Compose([ + PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), + RandomCrop(*crop_size) + ], p=1.0) + + points_sampler = MultiPointSampler(10, 
prob_gamma=0.80, + merge_objects_prob=0.15, + max_num_merged_objects=2) + + trainset = PASCAL( + cfg.PASCAL_PATH, + split='train', + augmentator=train_augmentator, + min_object_area=1000, + keep_background_prob=0.05, + points_sampler=points_sampler, + epoch_len=30000, + # stuff_prob=0.30 + ) + + valset = PASCAL( + cfg.PASCAL_PATH, + split='val', + augmentator=val_augmentator, + min_object_area=1000, + points_sampler=points_sampler, + epoch_len=2000 + ) + + optimizer_params = { + 'lr': 5e-5, 'betas': (0.9, 0.999), 'eps': 1e-8 + } + + lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, + milestones=[50, 55], gamma=0.1) + trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=optimizer_params, + layerwise_decay=cfg.layerwise_decay, + lr_scheduler=lr_scheduler, + checkpoint_interval=[(0, 20), (50, 1)], + image_dump_interval=300, + metrics=[AdaptiveIoU()], + max_interactive_points=model_cfg.num_max_points, + max_num_next_clicks=3) + trainer.run(num_epochs=55, validation=False) \ No newline at end of file From 5f24ac1f8eb4a306632f5d9002ec22aa33fa151d Mon Sep 17 00:00:00 2001 From: guba Date: Tue, 26 Mar 2024 22:27:57 +0800 Subject: [PATCH 15/42] Add MIOU scalar to tensorboard 2 lines in metrics.py --- isegm/model/metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/isegm/model/metrics.py b/isegm/model/metrics.py index 0c48852..7c06594 100755 --- a/isegm/model/metrics.py +++ b/isegm/model/metrics.py @@ -91,9 +91,9 @@ def log_states(self, sw, tag_prefix, global_step): iu = iu[:-1] mean_iu = np.nanmean(iu) cls_iu = dict(zip(range(self.n_classes), iu)) - + sw.add_scalar(tag=f'{tag_prefix}_miou', value=mean_iu, global_step=global_step) for cls in range(self.n_classes): - sw.add_scalar(tag=f'{tag_prefix}_class_{cls}_ema_iou', value=cls_iu[cls], global_step=global_step) + sw.add_scalar(tag=f'{tag_prefix}_class_{cls}_iou', value=cls_iu[cls], global_step=global_step) def get_miou(self, 
label_trues, label_preds): hist = np.zeros(self.confusion_matrix.shape) From b3fd11d7bc5af757953736d7cdd43a1aa3f6d76d Mon Sep 17 00:00:00 2001 From: guba Date: Tue, 26 Mar 2024 22:29:22 +0800 Subject: [PATCH 16/42] Better TST, now each tst file will appear in separate folder named by weight_path I know it implies that weights can only be put in root path but I DONT CARE --- TEST_read_trained_model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/TEST_read_trained_model.py b/TEST_read_trained_model.py index c6b305d..793e353 100644 --- a/TEST_read_trained_model.py +++ b/TEST_read_trained_model.py @@ -76,7 +76,7 @@ def main(): model.load_state_dict(weights['state_dict']) model.eval() cfg = edict() - + cfg.weights = weight_path train(model, cfg, model_cfg) @@ -131,7 +131,6 @@ def train(model, cfg, model_cfg): cfg.distributed = 'WORLD_SIZE' in os.environ cfg.local_rank = 0 cfg.workers = 4 - cfg.weights = "last_checkpoint.pth" cfg.val_batch_size = cfg.batch_size cfg.ngpus = 1 cfg.device = torch.device('cuda') @@ -155,7 +154,7 @@ def train(model, cfg, model_cfg): cfg.EXP_PATH = exp_path cfg.CHECKPOINTS_PATH = exp_path / 'checkpoints' cfg.VIS_PATH = exp_path / 'vis' - cfg.LOGS_PATH = exp_path / 'logs' + cfg.LOGS_PATH = exp_path / 'logs' / cfg.weights loss_cfg = edict() loss_cfg.instance_loss = NormalizedMultiFocalLossSigmoid(alpha=0.5, gamma=2) From e2139f8aab80a2e2256ad4f7fbd61261a43bdbfd Mon Sep 17 00:00:00 2001 From: guba Date: Wed, 27 Mar 2024 17:21:46 +0800 Subject: [PATCH 17/42] Point sampler treat [-1,-1,-1] as no input (all zero mask) --- isegm/model/ops.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/isegm/model/ops.py b/isegm/model/ops.py index b31cd22..cb84f1a 100755 --- a/isegm/model/ops.py +++ b/isegm/model/ops.py @@ -115,8 +115,9 @@ def get_coord_features(self, points, batchsize, rows, cols): points = points.unsqueeze(0) point_num = points.shape[1] if point_num == 0: - print("No point") - exit(1) + zeros = 
torch.zeros(1, 1, rows, cols, device=points.device) + res = zeros.repeat(1, 7, 1, 1) # TODO 7 cls magic + return res else: points = points.view(-1, points.size(2)) points, points_cls = torch.split(points, [2, 1], dim=1) From 2b85d0bc980388b9d0fdd761a3ff25d85d9dc42e Mon Sep 17 00:00:00 2001 From: guba Date: Wed, 27 Mar 2024 17:22:00 +0800 Subject: [PATCH 18/42] Add a extra_name thing --- TEST_read_trained_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/TEST_read_trained_model.py b/TEST_read_trained_model.py index 793e353..458e942 100644 --- a/TEST_read_trained_model.py +++ b/TEST_read_trained_model.py @@ -77,6 +77,7 @@ def main(): model.eval() cfg = edict() cfg.weights = weight_path + cfg.extra_name = "only_init" train(model, cfg, model_cfg) @@ -154,7 +155,7 @@ def train(model, cfg, model_cfg): cfg.EXP_PATH = exp_path cfg.CHECKPOINTS_PATH = exp_path / 'checkpoints' cfg.VIS_PATH = exp_path / 'vis' - cfg.LOGS_PATH = exp_path / 'logs' / cfg.weights + cfg.LOGS_PATH = exp_path / 'logs' / cfg.weights /cfg.extra_name loss_cfg = edict() loss_cfg.instance_loss = NormalizedMultiFocalLossSigmoid(alpha=0.5, gamma=2) From 5570b9dd066add67e788cb9d1f557b63f13ed024 Mon Sep 17 00:00:00 2001 From: guba Date: Thu, 28 Mar 2024 16:27:33 +0800 Subject: [PATCH 19/42] Train: Comment code that mask (-1,-1,-1) out since they can't fulfill batch computation requirement --- isegm/model/ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/isegm/model/ops.py b/isegm/model/ops.py index cb84f1a..1085e8f 100755 --- a/isegm/model/ops.py +++ b/isegm/model/ops.py @@ -110,9 +110,9 @@ def get_coord_features(self, points, batchsize, rows, cols): norm_delimeter)) coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float() else: - mask = ~(points == torch.tensor([-1, -1, -1],device=points.device)).all(dim=-1) - points = points[mask] - points = points.unsqueeze(0) + # mask = ~(points == torch.tensor([-1, -1, 
-1],device=points.device)).all(dim=-1) + # points = points[mask] + # points = points.unsqueeze(0) point_num = points.shape[1] if point_num == 0: zeros = torch.zeros(1, 1, rows, cols, device=points.device) From 3a36d6eac5d54a68ec51c720bec8ec13dc80a6e9 Mon Sep 17 00:00:00 2001 From: guba Date: Thu, 28 Mar 2024 16:28:04 +0800 Subject: [PATCH 20/42] Train: Make the trainer inter for once in main file --- models/iter_mask/multi_out_huge448_pascal_itermask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/iter_mask/multi_out_huge448_pascal_itermask.py b/models/iter_mask/multi_out_huge448_pascal_itermask.py index 5f44a6e..82c225b 100644 --- a/models/iter_mask/multi_out_huge448_pascal_itermask.py +++ b/models/iter_mask/multi_out_huge448_pascal_itermask.py @@ -125,5 +125,5 @@ def train(model, cfg, model_cfg): image_dump_interval=300, metrics=[AdaptiveMIoU(num_classes=7)], max_interactive_points=model_cfg.num_max_points, - max_num_next_clicks=15) + max_num_next_clicks=1) trainer.run(num_epochs=55, validation=False) \ No newline at end of file From 00c5a796431f970539d31b1575bf75532d15a297 Mon Sep 17 00:00:00 2001 From: guba Date: Thu, 28 Mar 2024 16:34:24 +0800 Subject: [PATCH 21/42] Train: Get_next_point function modify --- isegm/engine/Multi_trainer.py | 50 ++++++++++------------------------- 1 file changed, 14 insertions(+), 36 deletions(-) diff --git a/isegm/engine/Multi_trainer.py b/isegm/engine/Multi_trainer.py index 9ee6d6e..8fb9634 100755 --- a/isegm/engine/Multi_trainer.py +++ b/isegm/engine/Multi_trainer.py @@ -373,45 +373,23 @@ def is_master(self): return self.cfg.local_rank == 0 -def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49): +def get_next_points(pred, gt, points, click_indx, points_num=15): assert click_indx > 0 pred_o = pred.cpu().numpy()[:, 0, :, :] gt_o = gt.cpu().numpy()[:, :, :, 0] - rows = [] - for bindx in range(pred.shape[0]): - gt = gt_o[bindx] - pred = pred_o[bindx] - areas = [] - for cls in np.unique(gt): 
- area_value, max_idx, _, _ = get_contours_and_maxidx(cls, gt, pred) - areas.append([cls, area_value]) - areas = sorted(areas, key=lambda x: x[1], reverse=True) - max_area_cls = areas[0][0] - area, max_idx, contours, fn_mask = get_contours_and_maxidx(max_area_cls, gt, pred) - for k in range(len(contours)): - if k != max_idx: - cv2.fillPoly(fn_mask, [contours[k]], 0) - - points = points.clone() - - fn_mask_dt = cv2.distanceTransform(fn_mask, cv2.DIST_L2, 5)[1:-1, 1:-1] - - fn_max_dist = np.max(fn_mask_dt) - - is_positive = True - dt = fn_mask_dt - inner_mask = dt > (fn_max_dist / 2.0) - indices = np.argwhere(inner_mask) - if len(indices) > 0: - coords = indices[np.random.randint(0, len(indices))] - new_row = np.array([float(coords[0]),float(coords[1]), float(max_area_cls)]) - rows.append(new_row) - else: - return points - - rows = np.vstack(rows) - rows = rows[:,np.newaxis,:] - points = torch.cat((points, torch.from_numpy(rows).float().to(points.device)), dim=1) # TODO Point Number issue + f_area = pred_o != gt_o + f_area_idx = np.where(f_area) + num_samples = min(points_num*gt_o.shape[0], len(f_area_idx[0])) + random_indices = np.random.choice(len(f_area_idx[0]), size=num_samples, replace=False) + _points = [[f_area_idx[0][x] + , f_area_idx[1][x] + , f_area_idx[2][x] + , gt_o[f_area_idx[0][x],f_area_idx[1][x], f_area_idx[2][x]] + ] for x in random_indices] + res_points = np.ones((gt_o.shape[0],num_samples,3))*(-1) + for i, point in enumerate(_points): + res_points[point[0], i] = [point[1], point[2], point[3]] + points = torch.cat((points, torch.from_numpy(res_points).float().to(points.device)), dim=1) # TODO Point Number issue return points From 34ca6ebd10e42503ae137a32b327f28c473cffc6 Mon Sep 17 00:00:00 2001 From: guba Date: Thu, 28 Mar 2024 21:23:02 +0800 Subject: [PATCH 22/42] Train: point encoding to 255 --- isegm/model/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/isegm/model/ops.py b/isegm/model/ops.py index 1085e8f..27bfb94 100755 
--- a/isegm/model/ops.py +++ b/isegm/model/ops.py @@ -148,7 +148,7 @@ def get_coord_features(self, points, batchsize, rows, cols): res[i, cls] = coords[i, 0] res = res.view(-1,point_num, 7, rows, cols) # TODO 7 cls magic res = res.max(dim=1)[0] - return res + return res*255 def forward(self, x, coords): return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3]) From 554893d8811ed8c878cf3e79052e276e544b91b0 Mon Sep 17 00:00:00 2001 From: guba Date: Thu, 28 Mar 2024 21:52:08 +0800 Subject: [PATCH 23/42] Train: Complete the __len__ function in PASCAL --- isegm/data/datasets/PASCAL.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/isegm/data/datasets/PASCAL.py b/isegm/data/datasets/PASCAL.py index ea73df1..5da4b82 100644 --- a/isegm/data/datasets/PASCAL.py +++ b/isegm/data/datasets/PASCAL.py @@ -135,6 +135,9 @@ def remove_buggy_masks(self, index, instances_mask): return instances_mask + def __len__(self): + return len(self.dataset_samples) + def __getitem__(self, index): # points should be sampled from the whole mask ''' From aa9965d2bd515da1a0bdff78581d4993bf0432b4 Mon Sep 17 00:00:00 2001 From: guba Date: Thu, 18 Apr 2024 15:02:36 +0800 Subject: [PATCH 24/42] DELETE : DELETE A TEST FILE --- TEST_dataset_read.py | 16 ---------------- 1 file changed, 16 deletions(-) delete mode 100644 TEST_dataset_read.py diff --git a/TEST_dataset_read.py b/TEST_dataset_read.py deleted file mode 100644 index 1a1d8e9..0000000 --- a/TEST_dataset_read.py +++ /dev/null @@ -1,16 +0,0 @@ -from isegm.data.datasets.PASCAL import PASCAL -import cv2 -import matplotlib.pyplot as plt - -def show_sample(sample): - plt.imshow(sample.image) - plt.show() - plt.imshow(sample._encoded_masks) - plt.show() - print("done") - -dataset = PASCAL(dataset_path="/home/gyt/gyt/dataset/data/pascal_person_part", split='train') - -a_sample = dataset.get_sample(0) -show_sample(a_sample) - From 0714d38278bbac720d2e6999db7ccab2239f628a Mon Sep 17 00:00:00 2001 From: guba Date: Fri, 19 Apr 2024 
14:48:27 +0800 Subject: [PATCH 25/42] DATASET: FIX CS DATASET --- config.yml | 1 + isegm/data/datasets/cityscapes.py | 143 +++++------------- .../multi_out_huge448_pascal_itermask.py | 9 +- 3 files changed, 40 insertions(+), 113 deletions(-) diff --git a/config.yml b/config.yml index 0d426e1..c372dbb 100755 --- a/config.yml +++ b/config.yml @@ -20,6 +20,7 @@ OPENIMAGES_PATH: "./datasets/OpenImages" PASCALVOC_PATH: "/playpen-raid2/qinliu/data/PascalVOC" ADE20K_PATH: "./datasets/ADE20K" PASCAL_PATH: "/home/gyt/gyt/dataset/data/pascal_person_part" +CITYSCAPES_PATH: "/home/gyt/gyt/dataset/data/cityscapes" # You can download the weights for HRNet from the repository: # https://github.com/HRNet/HRNet-Image-Classification diff --git a/isegm/data/datasets/cityscapes.py b/isegm/data/datasets/cityscapes.py index 29032cd..389b0d3 100644 --- a/isegm/data/datasets/cityscapes.py +++ b/isegm/data/datasets/cityscapes.py @@ -13,36 +13,35 @@ class CityScapes(ISDataset): - def __init__(self, dataset_path, split="train", **kwargs): + def __init__(self, dataset_path, split="train", use_cache=True, **kwargs): super(CityScapes, self).__init__(**kwargs) assert split in {"train", "val", "trainval", "test"} self.name = "Cityscapes" self.dataset_path = Path(dataset_path) - self._images_path = self.dataset_path / "leftImg8bit" / split - self._insts_path = self.dataset_path / "gtFine" / split - self.init_path = self.dataset_path / "init_interactive_point" + self._images_path = Path("leftImg8bit") / split + self._insts_path = Path("gtFine") / split + self.init_path = Path("init_interactive_point") self.dataset_split = split self.class_num = 19 self.ignore_id = 255 self.loadfile = self.dataset_split+".pkl" - if os.path.exists(str(self.dataset_path/self.loadfile)): + if os.path.exists(str(self.dataset_path/self.loadfile)) and use_cache: with open(str(self.dataset_path/self.loadfile), 'rb') as file: self.dataset_samples = pickle.load(file) else: dataset_samples = [] - for city in 
os.listdir(self._images_path): + for city in os.listdir(self.dataset_path/self._images_path): img_dir = self._images_path / city target_dir = self._insts_path / city init_dir = self.init_path / city - for file_name in os.listdir(img_dir): + for file_name in os.listdir(self.dataset_path/img_dir): toAddPath = img_dir / file_name initName = file_name.replace("_leftImg8bit", "") initPath = init_dir / initName labelName = file_name.replace("leftImg8bit", "gtFine_labelTrainIds") labelPath = target_dir / labelName dataset_samples.append((toAddPath, labelPath, initPath)) - image_id_lst = self.get_images_and_ids_list(dataset_samples) self.dataset_samples = image_id_lst # print(image_id_lst[:5]) @@ -50,9 +49,9 @@ def __init__(self, dataset_path, split="train", **kwargs): def get_sample(self, index) -> DSample: sample_path, target_path, instance_ids, init_path = self.dataset_samples[index] - image_path = str(sample_path) - mask_path = str(target_path) - init_path = str(init_path) + image_path = str(self.dataset_path/sample_path) + mask_path = str(self.dataset_path/target_path) + init_path = str(self.dataset_path/init_path) image = cv2.imread(image_path) @@ -78,10 +77,10 @@ def get_sample(self, index) -> DSample: def get_images_and_ids_list(self, dataset_samples): images_and_ids_list = [] object_count = 0 - # for i in tqdm(range(len(dataset_samples))): - for i in range(len(dataset_samples)): + for i in tqdm(range(len(dataset_samples))): + # for i in range(len(dataset_samples)): image_path, mask_path, init_path = dataset_samples[i] - instances_mask = cv2.imread(str(mask_path)) + instances_mask = cv2.imread(str(self.dataset_path/mask_path)) instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype( np.int32 ) @@ -96,100 +95,26 @@ def get_images_and_ids_list(self, dataset_samples): pickle.dump(images_and_ids_list, file) return images_and_ids_list + def __len__(self): + return len(self.dataset_samples) + + def __getitem__(self, index): + sample = 
self.get_sample(index) + sample = self.augment_sample(sample) + init_points = cv2.imread(sample.init_clicks)[:, :, 0] + rows, cols = np.where(init_points != 255) + non_255_values = init_points[rows, cols] + coords_and_values = list(zip(rows, cols, non_255_values)) + coords_and_values.extend([(-1, -1, -1)] * (self.points_sampler.max_num_points - len(coords_and_values))) + init_points = np.array(coords_and_values) + self.points_sampler.sample_object(sample) + mask = self.points_sampler.selected_mask + output = { + 'images': self.to_tensor(sample.image), + 'points': init_points.astype(np.float32), + 'instances': mask, + # 'init_points': init_points.astype(np.float32) + } + return output -class CityScapes_train(ISDataset): - def __init__(self, dataset_path, split="train", **kwargs): - super(CityScapes_train, self).__init__(**kwargs) - assert split in {"train", "val", "trainval", "test"} - - self._buggy_mask_thresh = 0.08 - self._buggy_objects = dict() - - self.name = "Cityscapes" - self.dataset_path = Path(dataset_path) - self._images_path = self.dataset_path / "leftImg8bit" / split - self._insts_path = self.dataset_path / "gtFine" / split - self.dataset_split = split - - dataset_samples = [] - for city in os.listdir(self._images_path): - img_dir = self._images_path / city - target_dir = self._insts_path / city - - for file_name in os.listdir(img_dir): - toAddPath = img_dir / file_name - labelName = file_name.replace("leftImg8bit", "gtFine_labelTrainIds") - labelPath = target_dir / labelName - dataset_samples.append((toAddPath, labelPath)) - self.dataset_samples = dataset_samples - # print(image_id_lst[:5]) - - def get_sample(self, index) -> DSample: - sample_path, target_path = self.dataset_samples[index] - # sample_id = str(sample_id) - # print(sample_id) - # num_zero = 6 - len(sample_id) - # sample_id = '2007_'+'0'*num_zero + sample_id - - image_path = str(sample_path) - mask_path = str(target_path) - - # print(image_path) - - image = cv2.imread(image_path) - image 
= cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - instances_mask = cv2.imread(mask_path) - instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype( - np.int32 - ) - - instances_mask = self.remove_buggy_masks(index, instances_mask) - instances_ids, _ = get_labels_with_sizes(instances_mask, ignoreid=255) - - objects_ids = instances_ids - - return DSample( - image, - instances_mask, - objects_ids=objects_ids, - ignore_ids=[255], - sample_id=index, - ) - - def get_images_and_ids_list(self, dataset_samples): - images_and_ids_list = [] - # for i in tqdm(range(len(dataset_samples))): - for i in range(len(dataset_samples)): - image_path, mask_path = dataset_samples[i] - instances_mask = cv2.imread(str(mask_path)) - instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype( - np.int32 - ) - objects_ids = np.unique(instances_mask) - - objects_ids = [x for x in objects_ids if x != 255] - for j in objects_ids: - images_and_ids_list.append([image_path, mask_path, j]) - # print(i,j,objects_ids) - return images_and_ids_list - def remove_buggy_masks(self, index, instances_mask): - if self._buggy_mask_thresh > 0.0: - buggy_image_objects = self._buggy_objects.get(index, None) - if buggy_image_objects is None: - buggy_image_objects = [] - instances_ids, _ = get_labels_with_sizes(instances_mask) - for obj_id in instances_ids: - obj_mask = instances_mask == obj_id - mask_area = obj_mask.sum() - bbox = get_bbox_from_mask(obj_mask) - bbox_area = (bbox[1] - bbox[0] + 1) * (bbox[3] - bbox[2] + 1) - obj_area_ratio = mask_area / bbox_area - if obj_area_ratio < self._buggy_mask_thresh: - buggy_image_objects.append(obj_id) - - self._buggy_objects[index] = buggy_image_objects - for obj_id in buggy_image_objects: - instances_mask[instances_mask == obj_id] = 0 - - return instances_mask diff --git a/models/iter_mask/multi_out_huge448_pascal_itermask.py b/models/iter_mask/multi_out_huge448_pascal_itermask.py index 82c225b..cc80587 100644 --- 
a/models/iter_mask/multi_out_huge448_pascal_itermask.py +++ b/models/iter_mask/multi_out_huge448_pascal_itermask.py @@ -1,3 +1,4 @@ +from isegm.data.datasets.cityscapes import CityScapes from isegm.data.points_sampler import MultiClassSampler from isegm.engine.Multi_trainer import Multi_trainer from isegm.inference.clicker import Click @@ -89,8 +90,8 @@ def train(model, cfg, model_cfg): merge_objects_prob=0.15, max_num_merged_objects=2) - trainset = PASCAL( - cfg.PASCAL_PATH, + trainset = CityScapes( + cfg.CITYSCAPES_PATH, split='train', augmentator=train_augmentator, min_object_area=1000, @@ -100,8 +101,8 @@ def train(model, cfg, model_cfg): # stuff_prob=0.30 ) - valset = PASCAL( - cfg.PASCAL_PATH, + valset = CityScapes( + cfg.CITYSCAPES_PATH, split='val', augmentator=val_augmentator, min_object_area=1000, From cb07c1c3ad182bbb3a9c553c66f198d11d6ba6e8 Mon Sep 17 00:00:00 2001 From: guba Date: Fri, 19 Apr 2024 16:23:01 +0800 Subject: [PATCH 26/42] MODEL: Structure changes --- isegm/model/is_plainvit_model.py | 2 +- models/iter_mask/multi_out_huge448_pascal_itermask.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/isegm/model/is_plainvit_model.py b/isegm/model/is_plainvit_model.py index c9da0a4..bc68247 100644 --- a/isegm/model/is_plainvit_model.py +++ b/isegm/model/is_plainvit_model.py @@ -112,7 +112,7 @@ def __init__( self.patch_embed_coords = PatchEmbed( img_size= backbone_params['img_size'], patch_size=backbone_params['patch_size'], - in_chans=8 if self.with_prev_mask else 2, + in_chans=20 if self.with_prev_mask else 2, embed_dim=backbone_params['embed_dim'], ) diff --git a/models/iter_mask/multi_out_huge448_pascal_itermask.py b/models/iter_mask/multi_out_huge448_pascal_itermask.py index cc80587..be00839 100644 --- a/models/iter_mask/multi_out_huge448_pascal_itermask.py +++ b/models/iter_mask/multi_out_huge448_pascal_itermask.py @@ -40,7 +40,7 @@ def init_model(cfg): in_channels=[240, 480, 960, 1920], in_index=[0, 1, 2, 3], 
dropout_ratio=0.1, - num_classes=7, + num_classes=19, loss_decode=CrossEntropyLoss(), align_corners=False, upsample=cfg.upsample, From 7e6e9cacede64e287ee7e8f11525813cb9988b09 Mon Sep 17 00:00:00 2001 From: guba Date: Fri, 19 Apr 2024 16:23:17 +0800 Subject: [PATCH 27/42] Loss: change loss for 255 ignore --- isegm/model/losses.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/isegm/model/losses.py b/isegm/model/losses.py index aeda7d4..f0e3b24 100755 --- a/isegm/model/losses.py +++ b/isegm/model/losses.py @@ -101,7 +101,8 @@ def forward(self, pred, label): # TODO Error here label = label.unsqueeze(0) N, H, W = label.shape C = pred.shape[1] - label_one_hot = F.one_hot(label, num_classes=C).permute(0, 3, 1, 2).float() + label[label==self._ignore_label] = C + label_one_hot = F.one_hot(label, num_classes=C+1)[:,:,:,:C].permute(0, 3, 1, 2).float() sample_weight = (label != self._ignore_label).float() sample_weight_expanded = sample_weight.unsqueeze(1) sample_weight_expanded = sample_weight_expanded.expand(-1, C, -1, -1) From 4f481f4ca345da71bf91d9857cebcc2db4a9be71 Mon Sep 17 00:00:00 2001 From: guba Date: Fri, 19 Apr 2024 16:23:38 +0800 Subject: [PATCH 28/42] Metric: Change metric calc for 255 ignore --- isegm/model/metrics.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/isegm/model/metrics.py b/isegm/model/metrics.py index 7c06594..d8c3531 100755 --- a/isegm/model/metrics.py +++ b/isegm/model/metrics.py @@ -53,9 +53,9 @@ def update(self, label_preds, label_trues): if label_trues.dim() == 2: label_trues = label_trues.unsqueeze(0) for lt, lp in zip(label_trues, label_preds): - lt = lt.detach().cpu().numpy() - lp = lp.detach().cpu().numpy() - self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten()) + lt_ = lt.detach().cpu().numpy() + lp_ = lp.detach().cpu().numpy() + self.confusion_matrix += self._fast_hist(lt_.flatten(), lp_.flatten()) self._epoch_batch_count += 1 def _fast_hist(self,label_pred, 
label_true): @@ -84,9 +84,9 @@ def reset_epoch_stats(self): def log_states(self, sw, tag_prefix, global_step): hist = self.confusion_matrix - acc = np.diag(hist).sum() / hist.sum() - acc_cls = np.diag(hist) / hist.sum(axis=1) - acc_cls = np.nanmean(acc_cls) + # acc = np.diag(hist).sum() / hist.sum() + # acc_cls = np.diag(hist) / hist.sum(axis=1) + # acc_cls = np.nanmean(acc_cls) iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) iu = iu[:-1] mean_iu = np.nanmean(iu) From 266d4a374dee55bd2a9534be7385767d0567408a Mon Sep 17 00:00:00 2001 From: guba Date: Fri, 19 Apr 2024 16:24:25 +0800 Subject: [PATCH 29/42] Points: Changes point calc for 255 ignore and 19 cls --- isegm/data/points_sampler.py | 6 ++++-- isegm/engine/Multi_trainer.py | 2 +- isegm/model/ops.py | 6 +++--- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/isegm/data/points_sampler.py b/isegm/data/points_sampler.py index 69b801b..249efd1 100755 --- a/isegm/data/points_sampler.py +++ b/isegm/data/points_sampler.py @@ -320,8 +320,10 @@ def sample_object(self, sample: DSample): def sample_points(self): ''' - Randomly sample points from gt_mask. \n - Returns:
+ + Returns: points ''' assert self._selected_mask is not None diff --git a/isegm/engine/Multi_trainer.py b/isegm/engine/Multi_trainer.py index 8fb9634..9cafa0a 100755 --- a/isegm/engine/Multi_trainer.py +++ b/isegm/engine/Multi_trainer.py @@ -377,7 +377,7 @@ def get_next_points(pred, gt, points, click_indx, points_num=15): assert click_indx > 0 pred_o = pred.cpu().numpy()[:, 0, :, :] gt_o = gt.cpu().numpy()[:, :, :, 0] - f_area = pred_o != gt_o + f_area = np.logical_and(pred_o != gt_o, gt_o != 255) f_area_idx = np.where(f_area) num_samples = min(points_num*gt_o.shape[0], len(f_area_idx[0])) random_indices = np.random.choice(len(f_area_idx[0]), size=num_samples, replace=False) diff --git a/isegm/model/ops.py b/isegm/model/ops.py index 27bfb94..f501fd5 100755 --- a/isegm/model/ops.py +++ b/isegm/model/ops.py @@ -116,7 +116,7 @@ def get_coord_features(self, points, batchsize, rows, cols): point_num = points.shape[1] if point_num == 0: zeros = torch.zeros(1, 1, rows, cols, device=points.device) - res = zeros.repeat(1, 7, 1, 1) # TODO 7 cls magic + res = zeros.repeat(1, 19, 1, 1) # TODO 7 cls magic return res else: points = points.view(-1, points.size(2)) @@ -141,12 +141,12 @@ def get_coord_features(self, points, batchsize, rows, cols): coords = (coords <= (5) ** 2).float() zeros = torch.zeros_like(coords) - res = zeros.repeat(1, 7, 1, 1) # TODO 7 cls magic + res = zeros.repeat(1, 19, 1, 1) # TODO 19 cls magic for i in range(len(points_cls)): cls = int(points_cls[i]) res[i, cls] = coords[i, 0] - res = res.view(-1,point_num, 7, rows, cols) # TODO 7 cls magic + res = res.view(-1,point_num, 19, rows, cols) # TODO 19 cls magic res = res.max(dim=1)[0] return res*255 From fddb73ce86446f0d6f8e78293e612143116cedfe Mon Sep 17 00:00:00 2001 From: guba Date: Fri, 19 Apr 2024 16:25:33 +0800 Subject: [PATCH 30/42] MOD: Add ignore arguments to train file --- models/iter_mask/multi_out_huge448_pascal_itermask.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/models/iter_mask/multi_out_huge448_pascal_itermask.py b/models/iter_mask/multi_out_huge448_pascal_itermask.py index be00839..ac7ff51 100644 --- a/models/iter_mask/multi_out_huge448_pascal_itermask.py +++ b/models/iter_mask/multi_out_huge448_pascal_itermask.py @@ -69,7 +69,7 @@ def train(model, cfg, model_cfg): crop_size = model_cfg.crop_size loss_cfg = edict() - loss_cfg.instance_loss = NormalizedMultiFocalLossSigmoid(alpha=0.5, gamma=2) + loss_cfg.instance_loss = NormalizedMultiFocalLossSigmoid(alpha=0.5, gamma=2,ignore_label=255) loss_cfg.instance_loss_weight = 1.0 train_augmentator = Compose([ @@ -124,7 +124,7 @@ def train(model, cfg, model_cfg): lr_scheduler=lr_scheduler, checkpoint_interval=[(0, 20), (50, 1)], image_dump_interval=300, - metrics=[AdaptiveMIoU(num_classes=7)], + metrics=[AdaptiveMIoU(num_classes=19,ignore_label=255)], max_interactive_points=model_cfg.num_max_points, max_num_next_clicks=1) trainer.run(num_epochs=55, validation=False) \ No newline at end of file From 3bf2d624979ac1e680ce11d0ba6f0a9ebf5064c6 Mon Sep 17 00:00:00 2001 From: guba Date: Fri, 19 Apr 2024 19:03:29 +0800 Subject: [PATCH 31/42] MOD: Point Sampler Add an `ignore_label` argument to MultiClassSampler. Now the sampler randomly samples from the non-ignored area. Modify the CS dataset file; first points now use the randomly sampled ones.
--- isegm/data/base.py | 2 +- isegm/data/datasets/cityscapes.py | 3 ++- isegm/data/points_sampler.py | 17 ++++++++++++----- .../multi_out_huge448_pascal_itermask.py | 2 +- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/isegm/data/base.py b/isegm/data/base.py index 04d3e6f..7f4be3a 100755 --- a/isegm/data/base.py +++ b/isegm/data/base.py @@ -3,7 +3,7 @@ import numpy as np import torch from torchvision import transforms -from .points_sampler import MultiPointSampler +from .points_sampler import MultiPointSampler, MultiClassSampler from .sample import DSample diff --git a/isegm/data/datasets/cityscapes.py b/isegm/data/datasets/cityscapes.py index 389b0d3..dc5c7b2 100644 --- a/isegm/data/datasets/cityscapes.py +++ b/isegm/data/datasets/cityscapes.py @@ -108,10 +108,11 @@ def __getitem__(self, index): coords_and_values.extend([(-1, -1, -1)] * (self.points_sampler.max_num_points - len(coords_and_values))) init_points = np.array(coords_and_values) self.points_sampler.sample_object(sample) + points = np.array(self.points_sampler.sample_points()) mask = self.points_sampler.selected_mask output = { 'images': self.to_tensor(sample.image), - 'points': init_points.astype(np.float32), + 'points': points.astype(np.float32), 'instances': mask, # 'init_points': init_points.astype(np.float32) } diff --git a/isegm/data/points_sampler.py b/isegm/data/points_sampler.py index 249efd1..d57b286 100755 --- a/isegm/data/points_sampler.py +++ b/isegm/data/points_sampler.py @@ -305,6 +305,9 @@ def get_point_candidates(obj_mask, k=1.7, full_prob=0.0): return np.array([click_coords]) class MultiClassSampler(MultiPointSampler): + def __init__(self, ignore_label=255, **kwargs): + super().__init__(**kwargs) + self.ignore_label = ignore_label def sample_object(self, sample: DSample): ''' @@ -330,12 +333,16 @@ def sample_points(self): num_points = self.max_num_points # num_points = 1 + np.random.choice(np.arange(self.max_num_points), p=self._pos_probs) h, w = 
self._selected_mask.shape[:2] + valid_mask = self._selected_mask[:, :, 0] != self.ignore_label + valid_indices = np.where(valid_mask) + selected_indices = np.random.choice(len(valid_indices[0]), self.max_num_points, replace=False) + points = [] - y = np.random.randint(0, h, num_points) - x = np.random.randint(0, w, num_points) - for i in range(num_points): - cls = self._selected_mask[y[i], x[i]][0] - points.append([y[i], x[i], cls]) + + for idx in selected_indices: + y, x = valid_indices[0][idx], valid_indices[1][idx] + cls = self._selected_mask[y, x, 0] + points.append([y, x, cls]) return points diff --git a/models/iter_mask/multi_out_huge448_pascal_itermask.py b/models/iter_mask/multi_out_huge448_pascal_itermask.py index ac7ff51..fdc83b5 100644 --- a/models/iter_mask/multi_out_huge448_pascal_itermask.py +++ b/models/iter_mask/multi_out_huge448_pascal_itermask.py @@ -86,7 +86,7 @@ def train(model, cfg, model_cfg): RandomCrop(*crop_size) ], p=1.0) - points_sampler = MultiClassSampler(100, prob_gamma=0.80, + points_sampler = MultiClassSampler(max_num_points=100, prob_gamma=0.80, merge_objects_prob=0.15, max_num_merged_objects=2) From 1138913754ec1d0cef3b89d8bb9efb27f5a1a60d Mon Sep 17 00:00:00 2001 From: guba Date: Fri, 19 Apr 2024 19:27:53 +0800 Subject: [PATCH 32/42] Add run argument --first-return-points --- isegm/data/datasets/cityscapes.py | 11 ++++++++++- isegm/engine/Multi_trainer.py | 2 +- models/iter_mask/multi_out_huge448_pascal_itermask.py | 2 ++ train.py | 1 + 4 files changed, 14 insertions(+), 2 deletions(-) diff --git a/isegm/data/datasets/cityscapes.py b/isegm/data/datasets/cityscapes.py index dc5c7b2..42a3f90 100644 --- a/isegm/data/datasets/cityscapes.py +++ b/isegm/data/datasets/cityscapes.py @@ -13,9 +13,10 @@ class CityScapes(ISDataset): - def __init__(self, dataset_path, split="train", use_cache=True, **kwargs): + def __init__(self, dataset_path, split="train", use_cache=True, first_return_points=True, **kwargs): super(CityScapes, 
self).__init__(**kwargs) assert split in {"train", "val", "trainval", "test"} + assert first_return_points in {"init", "random", "blank"} self.name = "Cityscapes" self.dataset_path = Path(dataset_path) self._images_path = Path("leftImg8bit") / split @@ -24,6 +25,8 @@ def __init__(self, dataset_path, split="train", use_cache=True, **kwargs): self.dataset_split = split self.class_num = 19 self.ignore_id = 255 + self.first_return_points = first_return_points + self.loadfile = self.dataset_split+".pkl" if os.path.exists(str(self.dataset_path/self.loadfile)) and use_cache: @@ -110,6 +113,12 @@ def __getitem__(self, index): self.points_sampler.sample_object(sample) points = np.array(self.points_sampler.sample_points()) mask = self.points_sampler.selected_mask + if self.first_return_points=="init": + points = init_points + elif self.first_return_points=="random": + points = points + else: + points = np.ones((self.points_sampler.max_num_points, 3))*-1 output = { 'images': self.to_tensor(sample.image), 'points': points.astype(np.float32), diff --git a/isegm/engine/Multi_trainer.py b/isegm/engine/Multi_trainer.py index 9cafa0a..eb97797 100755 --- a/isegm/engine/Multi_trainer.py +++ b/isegm/engine/Multi_trainer.py @@ -251,7 +251,7 @@ def batch_forward(self, batch_data, validation=False): with torch.set_grad_enabled(not validation): batch_data = {k: v.to(self.device) for k, v in batch_data.items()} image, gt_mask, points = batch_data['images'], batch_data['instances'], batch_data['points'] - orig_image, orig_gt_mask, orig_points = image.clone(), gt_mask.clone(), points.clone() + orig_gt_mask = gt_mask.clone() prev_output = torch.zeros_like(image, dtype=torch.float32)[:, :1, :, :] diff --git a/models/iter_mask/multi_out_huge448_pascal_itermask.py b/models/iter_mask/multi_out_huge448_pascal_itermask.py index fdc83b5..2045aa4 100644 --- a/models/iter_mask/multi_out_huge448_pascal_itermask.py +++ b/models/iter_mask/multi_out_huge448_pascal_itermask.py @@ -92,6 +92,7 @@ def 
train(model, cfg, model_cfg): trainset = CityScapes( cfg.CITYSCAPES_PATH, + first_return_points=cfg.first_return_points, split='train', augmentator=train_augmentator, min_object_area=1000, @@ -103,6 +104,7 @@ def train(model, cfg, model_cfg): valset = CityScapes( cfg.CITYSCAPES_PATH, + first_return_points=cfg.first_return_points, split='val', augmentator=val_augmentator, min_object_area=1000, diff --git a/train.py b/train.py index b2f56c8..49e5807 100755 --- a/train.py +++ b/train.py @@ -76,6 +76,7 @@ def parse_args(): parser.add_argument('--random-split', action='store_true', help='random split the patch instead of window split.') + parser.add_argument('--first-return-points',type=str, default='blank',choices=["init", "random", "blank"]) return parser.parse_args() From ff344b7ce87b48fc2abd1ac0b8c84f674e90ac03 Mon Sep 17 00:00:00 2001 From: guba Date: Fri, 19 Apr 2024 20:50:21 +0800 Subject: [PATCH 33/42] CityScapes won't run sample_points if not specified --- isegm/data/datasets/cityscapes.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/isegm/data/datasets/cityscapes.py b/isegm/data/datasets/cityscapes.py index 42a3f90..1ca734d 100644 --- a/isegm/data/datasets/cityscapes.py +++ b/isegm/data/datasets/cityscapes.py @@ -111,12 +111,11 @@ def __getitem__(self, index): coords_and_values.extend([(-1, -1, -1)] * (self.points_sampler.max_num_points - len(coords_and_values))) init_points = np.array(coords_and_values) self.points_sampler.sample_object(sample) - points = np.array(self.points_sampler.sample_points()) mask = self.points_sampler.selected_mask if self.first_return_points=="init": points = init_points elif self.first_return_points=="random": - points = points + points = np.array(self.points_sampler.sample_points()) else: points = np.ones((self.points_sampler.max_num_points, 3))*-1 output = { From 8d442f18c048f3324153730f21bf9cbd625c185c Mon Sep 17 00:00:00 2001 From: guba Date: Sat, 20 Apr 2024 22:38:40 +0800 Subject: [PATCH 34/42] Add 
argument Now the number of iter in training can be modified through running args. --- models/iter_mask/multi_out_huge448_pascal_itermask.py | 2 +- train.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/models/iter_mask/multi_out_huge448_pascal_itermask.py b/models/iter_mask/multi_out_huge448_pascal_itermask.py index 2045aa4..acfd493 100644 --- a/models/iter_mask/multi_out_huge448_pascal_itermask.py +++ b/models/iter_mask/multi_out_huge448_pascal_itermask.py @@ -128,5 +128,5 @@ def train(model, cfg, model_cfg): image_dump_interval=300, metrics=[AdaptiveMIoU(num_classes=19,ignore_label=255)], max_interactive_points=model_cfg.num_max_points, - max_num_next_clicks=1) + max_num_next_clicks=cfg.max_next_clicks) trainer.run(num_epochs=55, validation=False) \ No newline at end of file diff --git a/train.py b/train.py index 49e5807..efb255c 100755 --- a/train.py +++ b/train.py @@ -77,6 +77,7 @@ def parse_args(): parser.add_argument('--random-split', action='store_true', help='random split the patch instead of window split.') parser.add_argument('--first-return-points',type=str, default='blank',choices=["init", "random", "blank"]) + parser.add_argument('--max-next-clicks',type=int, default=1) return parser.parse_args() From 073f2dc285f72715028864630304d112c14b3f0e Mon Sep 17 00:00:00 2001 From: guba Date: Sat, 20 Apr 2024 23:53:54 +0800 Subject: [PATCH 35/42] Dataloader issue solved Add collate_fn, make a dict in trainer itself --- isegm/data/base.py | 29 +++++++++++++++++++++++++++++ isegm/data/datasets/cityscapes.py | 4 ++-- isegm/engine/Multi_trainer.py | 9 +++++++-- 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/isegm/data/base.py b/isegm/data/base.py index 7f4be3a..5c1ae89 100755 --- a/isegm/data/base.py +++ b/isegm/data/base.py @@ -1,5 +1,7 @@ import random import pickle +from collections import namedtuple + import numpy as np import torch from torchvision import transforms @@ -110,3 +112,30 @@ def 
_load_samples_scores(samples_scores_path, samples_scores_gamma): } print(f'Loaded {len(probs)} weights with gamma={samples_scores_gamma}') return samples_scores + +def is_dataset_collate_fn(data): + """ + data: is a list of tuples with (example, label, length) + where 'example' is a tensor of arbitrary shape + and label/length are scalars + """ + images = [d['images'] for d in data] + points = [d['points'] for d in data] + instances = [d['instances'] for d in data] + + all_points = points + max_len = max([len(x) for x in all_points]) + for i in range(len(data)): + padding_length = max_len - len(points[i]) + if padding_length > 0: + # Create padding of shape (padding_length, 3) and fill it with (-1, -1, -1) + padding = np.full((padding_length, 3), (-1, -1, -1)) + # Concatenate the original data with the padding + padded_point = np.concatenate((all_points[i], padding), axis=0) + all_points[i]=padded_point + images = torch.stack(images) + all_points = np.array(all_points) + instances = np.array(instances) + return images, torch.from_numpy(all_points), torch.from_numpy(instances) + + diff --git a/isegm/data/datasets/cityscapes.py b/isegm/data/datasets/cityscapes.py index 1ca734d..65688cb 100644 --- a/isegm/data/datasets/cityscapes.py +++ b/isegm/data/datasets/cityscapes.py @@ -108,7 +108,7 @@ def __getitem__(self, index): rows, cols = np.where(init_points != 255) non_255_values = init_points[rows, cols] coords_and_values = list(zip(rows, cols, non_255_values)) - coords_and_values.extend([(-1, -1, -1)] * (self.points_sampler.max_num_points - len(coords_and_values))) + # coords_and_values.extend([(-1, -1, -1)] * (self.points_sampler.max_num_points - len(coords_and_values))) init_points = np.array(coords_and_values) self.points_sampler.sample_object(sample) mask = self.points_sampler.selected_mask @@ -117,7 +117,7 @@ def __getitem__(self, index): elif self.first_return_points=="random": points = np.array(self.points_sampler.sample_points()) else: - points = 
np.ones((self.points_sampler.max_num_points, 3))*-1 + points = np.array([(-1, -1, -1)]) output = { 'images': self.to_tensor(sample.image), 'points': points.astype(np.float32), diff --git a/isegm/engine/Multi_trainer.py b/isegm/engine/Multi_trainer.py index eb97797..a9f6834 100755 --- a/isegm/engine/Multi_trainer.py +++ b/isegm/engine/Multi_trainer.py @@ -17,6 +17,7 @@ from isegm.utils.distributed import get_dp_wrapper, get_sampler, reduce_loss_dict from .optimizer import get_optimizer, get_optimizer_with_layerwise_decay from .trainer import ISTrainer +from ..data.base import is_dataset_collate_fn class Multi_trainer(ISTrainer): @@ -75,14 +76,16 @@ def __init__(self, model, cfg, model_cfg, loss_cfg, trainset, cfg.batch_size, sampler=get_sampler(trainset, shuffle=True, distributed=cfg.distributed), drop_last=True, pin_memory=True, - num_workers=cfg.workers + num_workers=cfg.workers, + collate_fn=is_dataset_collate_fn ) self.val_data = DataLoader( valset, cfg.val_batch_size, sampler=get_sampler(valset, shuffle=False, distributed=cfg.distributed), drop_last=True, pin_memory=True, - num_workers=cfg.workers + num_workers=cfg.workers, + collate_fn=is_dataset_collate_fn ) if layerwise_decay: @@ -245,6 +248,8 @@ def validation(self, epoch): global_step=epoch, disable_avg=True) def batch_forward(self, batch_data, validation=False): + batch_data = {'images': batch_data[0], 'points': batch_data[1], + 'instances': batch_data[2]} metrics = self.val_metrics if validation else self.train_metrics losses_logging = dict() From 79c24bbf4b2adf8cb797edefedf82ffffab3dba1 Mon Sep 17 00:00:00 2001 From: guba Date: Wed, 24 Apr 2024 08:29:18 +0800 Subject: [PATCH 36/42] MEM: FIX MEMORY ISSUE IN `get_coord_features()` --- isegm/engine/Multi_trainer.py | 2 +- isegm/model/ops.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/isegm/engine/Multi_trainer.py b/isegm/engine/Multi_trainer.py index a9f6834..7c68521 100755 --- a/isegm/engine/Multi_trainer.py +++ 
b/isegm/engine/Multi_trainer.py @@ -254,7 +254,7 @@ def batch_forward(self, batch_data, validation=False): losses_logging = dict() with torch.set_grad_enabled(not validation): - batch_data = {k: v.to(self.device) for k, v in batch_data.items()} + batch_data = {k: v.to(self.device) if k!='points' else v for k, v in batch_data.items()} image, gt_mask, points = batch_data['images'], batch_data['instances'], batch_data['points'] orig_gt_mask = gt_mask.clone() diff --git a/isegm/model/ops.py b/isegm/model/ops.py index f501fd5..dc3b870 100755 --- a/isegm/model/ops.py +++ b/isegm/model/ops.py @@ -117,18 +117,17 @@ def get_coord_features(self, points, batchsize, rows, cols): if point_num == 0: zeros = torch.zeros(1, 1, rows, cols, device=points.device) res = zeros.repeat(1, 19, 1, 1) # TODO 7 cls magic - return res else: points = points.view(-1, points.size(2)) points, points_cls = torch.split(points, [2, 1], dim=1) invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0 - row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device) - col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device) + row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device='cuda') + col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device='cuda') coord_rows, coord_cols = torch.meshgrid(row_array, col_array) coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1) - add_xy = (points * 1).view(points.size(0), points.size(1), 1, 1) + add_xy = ((points * 1).view(points.size(0), points.size(1), 1, 1)).to('cuda') coords.add_(-add_xy) coords.mul_(coords) # 96, 2, h, w coords[:, 0] += coords[:, 1] @@ -148,6 +147,7 @@ def get_coord_features(self, points, batchsize, rows, cols): res[i, cls] = coords[i, 0] res = res.view(-1,point_num, 19, rows, cols) # TODO 19 cls magic res = res.max(dim=1)[0] + res = res.to('cuda') return res*255 def 
forward(self, x, coords): From 48aca95d92a1c0cc77702cb9dd13c56dd0d40469 Mon Sep 17 00:00:00 2001 From: guba Date: Wed, 24 Apr 2024 08:30:04 +0800 Subject: [PATCH 37/42] BUGFIX: CITYSCAPE DEFAULT ARGS CHANGE first_return_points='init' instead of True --- isegm/data/datasets/cityscapes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/isegm/data/datasets/cityscapes.py b/isegm/data/datasets/cityscapes.py index 65688cb..dc94519 100644 --- a/isegm/data/datasets/cityscapes.py +++ b/isegm/data/datasets/cityscapes.py @@ -13,7 +13,7 @@ class CityScapes(ISDataset): - def __init__(self, dataset_path, split="train", use_cache=True, first_return_points=True, **kwargs): + def __init__(self, dataset_path, split="train", use_cache=True, first_return_points="init", **kwargs): super(CityScapes, self).__init__(**kwargs) assert split in {"train", "val", "trainval", "test"} assert first_return_points in {"init", "random", "blank"} From 2223682fa11259fa2ef4d6a54d838912ec695ad5 Mon Sep 17 00:00:00 2001 From: guba Date: Thu, 2 May 2024 21:48:28 +0800 Subject: [PATCH 38/42] Update: config.yml and .gitignore --- .gitignore | 4 ++++ config.yml | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 0f70c8e..ff1054f 100644 --- a/.gitignore +++ b/.gitignore @@ -132,3 +132,7 @@ dmypy.json # Pyre type checker .pyre/ + +.idea/ + +isegm/.DS_Store diff --git a/config.yml b/config.yml index c372dbb..b97c012 100755 --- a/config.yml +++ b/config.yml @@ -20,7 +20,7 @@ OPENIMAGES_PATH: "./datasets/OpenImages" PASCALVOC_PATH: "/playpen-raid2/qinliu/data/PascalVOC" ADE20K_PATH: "./datasets/ADE20K" PASCAL_PATH: "/home/gyt/gyt/dataset/data/pascal_person_part" -CITYSCAPES_PATH: "/home/gyt/gyt/dataset/data/cityscapes" +CITYSCAPES_PATH: "./data/cityscapes" # You can download the weights for HRNet from the repository: # https://github.com/HRNet/HRNet-Image-Classification From e1bc2e892bc3fee749dc92fa17fc914bd80ee5ce Mon Sep 17 00:00:00 2001 From: cc 
Date: Thu, 2 May 2024 22:06:51 +0800 Subject: [PATCH 39/42] Update: gitignore ignore data folder, maybe should commit before I really mv the folder --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index ff1054f..5b43c1a 100644 --- a/.gitignore +++ b/.gitignore @@ -136,3 +136,5 @@ dmypy.json .idea/ isegm/.DS_Store + +data/ \ No newline at end of file From 0bd0c78c8ab8518c4d147a958194856fa54182e0 Mon Sep 17 00:00:00 2001 From: cc Date: Thu, 2 May 2024 22:54:47 +0800 Subject: [PATCH 40/42] Readability: Dataset point return fix --- isegm/data/datasets/cityscapes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/isegm/data/datasets/cityscapes.py b/isegm/data/datasets/cityscapes.py index dc94519..bd5e0a2 100644 --- a/isegm/data/datasets/cityscapes.py +++ b/isegm/data/datasets/cityscapes.py @@ -116,7 +116,7 @@ def __getitem__(self, index): points = init_points elif self.first_return_points=="random": points = np.array(self.points_sampler.sample_points()) - else: + elif self.first_return_points=="blank": points = np.array([(-1, -1, -1)]) output = { 'images': self.to_tensor(sample.image), From b2678acfc365a4df73d1ddd18b95a5c76fd0c0f8 Mon Sep 17 00:00:00 2001 From: cc Date: Thu, 2 May 2024 22:55:14 +0800 Subject: [PATCH 41/42] update gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 5b43c1a..92e21a0 100644 --- a/.gitignore +++ b/.gitignore @@ -137,4 +137,5 @@ dmypy.json isegm/.DS_Store -data/ \ No newline at end of file +data/ +exps/ From 2a5aabaa82559cf6663749c9a118023579067f87 Mon Sep 17 00:00:00 2001 From: cc Date: Thu, 2 May 2024 22:56:06 +0800 Subject: [PATCH 42/42] Doc: Code comment --- isegm/engine/Multi_trainer.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/isegm/engine/Multi_trainer.py b/isegm/engine/Multi_trainer.py index 7c68521..c069cd2 100755 --- a/isegm/engine/Multi_trainer.py +++ 
b/isegm/engine/Multi_trainer.py @@ -264,7 +264,11 @@ def batch_forward(self, batch_data, validation=False): # First part with torch.no_grad(): num_iters = self.max_num_next_clicks # Here max_num_next_clicks is 3 in default - + ''' + In this block, the net will click for num_iters times(max-next-clicks) + prev_output is initial zeros. + The points are processed in dataset + ''' for click_indx in range(num_iters): last_click_indx = click_indx @@ -277,7 +281,7 @@ def batch_forward(self, batch_data, validation=False): eval_model = self.click_models[click_indx] net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image - prev_output = torch.sigmoid(eval_model(net_input, points)['instances']) + prev_output = torch.sigmoid(eval_model(net_input, points)['instances']) # <== This line run the model prev_output = torch.max(prev_output, dim=1, keepdim=True)[1] points = get_next_points(prev_output, orig_gt_mask, points, click_indx + 1) @@ -379,6 +383,20 @@ def is_master(self): def get_next_points(pred, gt, points, click_indx, points_num=15): + """ + This function get points feedback DURING training. + + Args: + pred (_type_): _description_ + gt (_type_): _description_ + points (_type_): _description_ + click_indx (_type_): _description_ + points_num (int, optional): _description_. Defaults to 15. + + Returns: + points: torch.Tensor: [batch_size, num_points, 3] + """ + assert click_indx > 0 pred_o = pred.cpu().numpy()[:, 0, :, :] gt_o = gt.cpu().numpy()[:, :, :, 0]