diff --git a/.gitignore b/.gitignore
index 0f70c8e..92e21a0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -132,3 +132,10 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+.idea/
+
+isegm/.DS_Store
+
+data/
+exps/
diff --git a/TEST_Dataset.py b/TEST_Dataset.py
new file mode 100644
index 0000000..95e0b11
--- /dev/null
+++ b/TEST_Dataset.py
@@ -0,0 +1,39 @@
+from isegm.data.datasets import PASCAL
+from isegm.data.points_sampler import MultiClassSampler
+import argparse
+import os
+from pathlib import Path
+
+from isegm.engine.Multi_trainer import Multi_trainer
+from isegm.inference.clicker import Click
+from isegm.model.is_plainvit_model import MultiOutVitModel
+from isegm.model.metrics import AdaptiveMIoU
+from isegm.utils.exp import init_experiment
+from isegm.utils.exp_imports.default import *
+from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss
+from train import load_module
+
+points_sampler = MultiClassSampler(max_num_points=2, prob_gamma=0.80,
+                                   merge_objects_prob=0.15,
+                                   max_num_merged_objects=2)
+trainset = PASCAL(
+    "/home/gyt/gyt/dataset/data/pascal_person_part",
+    split='train',
+    min_object_area=1000,
+    keep_background_prob=0.05,
+    points_sampler=points_sampler,
+    epoch_len=30000,
+    # stuff_prob=0.30
+)
+
+valset = PASCAL(
+    "/home/gyt/gyt/dataset/data/pascal_person_part",
+    split='val',
+    min_object_area=1000,
+    points_sampler=points_sampler,
+    epoch_len=2000
+)
+
+for batch_data in trainset:
+    print(batch_data["points"].shape)
\ No newline at end of file
diff --git a/TEST_read_trained_model.py b/TEST_read_trained_model.py
new file mode 100644
index 0000000..458e942
--- /dev/null
+++ b/TEST_read_trained_model.py
@@ -0,0 +1,221 @@
+import argparse
+import os
+from pathlib import Path
+
+from isegm.data.points_sampler import MultiClassSampler
+from isegm.engine.Multi_trainer import Multi_trainer
+from isegm.inference.clicker import Click
+from isegm.model.is_plainvit_model import MultiOutVitModel
+from isegm.model.metrics import AdaptiveMIoU
+from isegm.utils.exp import init_experiment
+from isegm.utils.exp_imports.default import *
+from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss
+from train import load_module
+
+MODEL_NAME = 'cocolvis_vit_huge448'
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('model_path', type=str,
+                        help='Path to the model script.')
+
+    parser.add_argument('--exp-name', type=str, default='',
+                        help='Here you can specify the name of the experiment. '
+                             'It will be added as a suffix to the experiment folder.')
+
+    parser.add_argument('--workers', type=int, default=4,
+                        metavar='N', help='Dataloader threads.')
+
+    parser.add_argument('--batch-size', type=int, default=-1,
+                        help='You can override the model batch size by specifying a positive number.')
+
+    parser.add_argument('--ngpus', type=int, default=1,
+                        help='Number of GPUs. '
+                             'If you only specify the "--gpus" argument, the ngpus value will be calculated automatically. '
+                             'You should use either this argument or "--gpus".')
+
+    parser.add_argument('--gpus', type=str, default='', required=False,
+                        help='Ids of used GPUs. You should use either this argument or "--ngpus".')
+
+    parser.add_argument('--resume-exp', type=str, default=None,
+                        help='The prefix of the name of the experiment to be continued. '
+                             'If you use this field, you must specify the "--resume-prefix" argument.')
+
+    parser.add_argument('--resume-prefix', type=str, default='latest',
+                        help='The prefix of the name of the checkpoint to be loaded.')
+
+    parser.add_argument('--start-epoch', type=int, default=0,
+                        help='The number of the starting epoch from which training will continue. '
+                             '(it is important for correct logging and learning rate)')
+
+    parser.add_argument('--weights', type=str, default=None,
+                        help='Model weights will be loaded from the specified path if you use this argument.')
+
+    parser.add_argument('--temp-model-path', type=str, default='',
+                        help='Do not use this argument (for internal purposes).')
+
+    parser.add_argument("--local_rank", type=int, default=0)
+
+    # parameters for experimenting
+    parser.add_argument('--layerwise-decay', action='store_true',
+                        help='layer-wise decay for the transformer blocks.')
+
+    parser.add_argument('--upsample', type=str, default='x1',
+                        help='upsample the output.')
+
+    parser.add_argument('--random-split', action='store_true',
+                        help='randomly split the patch instead of a window split.')
+
+    return parser.parse_args()
+
+
+def main():
+    model, model_cfg = init_model()
+    weight_path = "last_checkpoint.pth"
+    weights = torch.load(weight_path)
+    model.load_state_dict(weights['state_dict'])
+    model.eval()
+    cfg = edict()
+    cfg.weights = weight_path
+    cfg.extra_name = "only_init"
+    train(model, cfg, model_cfg)
+
+
+def init_model():
+    model_cfg = edict()
+    model_cfg.crop_size = (448, 448)
+    model_cfg.num_max_points = 24
+
+    backbone_params = dict(
+        img_size=model_cfg.crop_size,
+        patch_size=(14, 14),
+        in_chans=3,
+        embed_dim=1280,
+        depth=32,
+        num_heads=16,
+        mlp_ratio=4,
+        qkv_bias=True,
+    )
+
+    neck_params = dict(
+        in_dim=1280,
+        out_dims=[240, 480, 960, 1920],
+    )
+
+    head_params = dict(
+        in_channels=[240, 480, 960, 1920],
+        in_index=[0, 1, 2, 3],
+        dropout_ratio=0.1,
+        num_classes=7,
+        loss_decode=CrossEntropyLoss(),
+        align_corners=False,
+        upsample='x1',
+        channels={'x1': 256, 'x2': 128, 'x4': 64}['x1'],
+    )
+
+    model = MultiOutVitModel(
+        use_disks=True,
+        norm_radius=5,
+        with_prev_mask=True,
+        backbone_params=backbone_params,
+        neck_params=neck_params,
+        head_params=head_params,
+        random_split=False,
+    )
+    model.to('cuda')
+
+    return model, model_cfg
+
+
+def train(model, cfg, model_cfg):
+    cfg.batch_size = 1
+    cfg.distributed = 'WORLD_SIZE' in os.environ
+    cfg.local_rank = 0
+    cfg.workers = 4
+    cfg.val_batch_size = cfg.batch_size
+    cfg.ngpus = 1
+    cfg.device = torch.device('cuda')
+    cfg.start_epoch = 0
+    cfg.multi_gpu = cfg.ngpus > 1
+    crop_size = model_cfg.crop_size
+
+    cfg.EXPS_PATH = 'TST_OUT'
+    experiments_path = Path(cfg.EXPS_PATH)
+    exp_parent_path = experiments_path  # was: experiments_path / '/'.join(""), which joins nothing and yields the same path
+    exp_parent_path.mkdir(parents=True, exist_ok=True)
+
+    last_exp_indx = 0
+    exp_name = f'{last_exp_indx:03d}'
+    exp_path = exp_parent_path / exp_name
+
+    if cfg.local_rank == 0:
+        exp_path.mkdir(parents=True, exist_ok=True)
+
+    cfg.EXP_PATH = exp_path
+    cfg.CHECKPOINTS_PATH = exp_path / 'checkpoints'
+    cfg.VIS_PATH = exp_path / 'vis'
+    cfg.LOGS_PATH = exp_path / 'logs' / cfg.weights / cfg.extra_name
+
+    loss_cfg = edict()
+    loss_cfg.instance_loss = NormalizedMultiFocalLossSigmoid(alpha=0.5, gamma=2)
+    loss_cfg.instance_loss_weight = 1.0
+
+    train_augmentator = Compose([
+        UniformRandomResize(scale_range=(0.75, 1.40)),
+        HorizontalFlip(),
+        PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0),
+        RandomCrop(*crop_size),
+        RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75),
+        RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75)
+    ], p=1.0)
+
+    val_augmentator = Compose([
+        PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0),
+        RandomCrop(*crop_size)
+    ], p=1.0)
+
+    points_sampler = MultiClassSampler(max_num_points=100, prob_gamma=0.80,
+                                       merge_objects_prob=0.15,
+                                       max_num_merged_objects=2)
+
+    trainset = PASCAL(
+        "/home/gyt/gyt/dataset/data/pascal_person_part",
+        split='train',
+        augmentator=train_augmentator,
+        min_object_area=1000,
+        keep_background_prob=0.05,
+        points_sampler=points_sampler,
+        epoch_len=30000,
+        # stuff_prob=0.30
+    )
+
+    valset = PASCAL(
+        "/home/gyt/gyt/dataset/data/pascal_person_part",
+        split='val',
+        augmentator=val_augmentator,
+        min_object_area=1000,
+        points_sampler=points_sampler,
+        epoch_len=2000
+    )
+
+    optimizer_params = {
+        'lr': 5e-5, 'betas': (0.9, 0.999), 'eps': 1e-8
+    }
+
+    lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR,
+                           milestones=[50, 55], gamma=0.1)
+    trainer = Multi_trainer(model, cfg, model_cfg, loss_cfg,
+                            trainset, valset,
+                            optimizer='adam',
+                            optimizer_params=optimizer_params,
+                            lr_scheduler=lr_scheduler,
+                            checkpoint_interval=[(0, 20), (50, 1)],
+                            image_dump_interval=300,
+                            metrics=[AdaptiveMIoU(num_classes=7)],
+                            max_interactive_points=model_cfg.num_max_points,
+                            max_num_next_clicks=15)
+    trainer.validation(epoch=0)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/TEST_simple_model.py b/TEST_simple_model.py
new file mode 100644
index 0000000..afc5cf7
--- /dev/null
+++ b/TEST_simple_model.py
@@ -0,0 +1,87 @@
+from isegm.inference.clicker import Clicker, Click
+from isegm.model.is_plainvit_model import MultiOutVitModel
+from isegm.utils.exp_imports.default import *
+from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss
+
+def init_model():
+    model_cfg = edict()
+    model_cfg.crop_size = (448, 448)
+    model_cfg.num_max_points = 24
+
+    backbone_params = dict(
+        img_size=model_cfg.crop_size,
+        patch_size=(16, 16),
+        in_chans=3,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4,
+        qkv_bias=True,
+    )
+
+    neck_params = dict(
+        in_dim=1024,
+        out_dims=[192, 384, 768, 1536],
+    )
+
+    head_params = dict(
+        in_channels=[192, 384, 768, 1536],
+        in_index=[0, 1, 2, 3],
+        dropout_ratio=0.1,
+        num_classes=7,
+        loss_decode=CrossEntropyLoss(),
+        align_corners=False,
+        upsample='x1',
+        channels={'x1': 256, 'x2': 128, 'x4': 64}['x1'],
+    )
+
+    model = MultiOutVitModel(
+        use_disks=True,
+        norm_radius=5,
+        with_prev_mask=True,
+        backbone_params=backbone_params,
+        neck_params=neck_params,
+        head_params=head_params,
+        random_split=False,
+    )
+
+    # model.backbone.init_weights_from_pretrained("./weights/pretrained/cocolvis_vit_huge.pth")
+    model.to("cuda")
+
+    return model, model_cfg
+
+def get_points_nd(clicks_lists):
+    total_clicks = []
+    num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
+    num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
+    num_max_points = max(num_pos_clicks + num_neg_clicks)
+    num_max_points = max(1, num_max_points)
+
+    for clicks_list in clicks_lists:
+        clicks_list = clicks_list[:5]
+        pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive]
+        pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)]
+
+        neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive]
+        neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)]
+        total_clicks.append(pos_clicks + neg_clicks)
+
+    return torch.tensor(total_clicks, device="cuda")
+
+def add_mask(img):
+    input_image = torch.cat((img, torch.zeros(1, 1, 448, 448).cuda()), dim=1)
+    return input_image
+
+
+import torch
+import cv2
+import numpy as np
+
+model, model_cfg = init_model()
+
+img = torch.rand(1, 3, 448, 448).cuda()
+click = Click(is_positive=True, coords=(1, 1), indx=0)
+click_list = [[click]]
+out = model(add_mask(img), get_points_nd(click_list))
+
+print("done!")
\ No newline at end of file
diff --git a/config.yml b/config.yml
index 95457a8..b97c012 100755
--- a/config.yml
+++ b/config.yml
@@ -1,5 +1,5 @@
 INTERACTIVE_MODELS_PATH: "./weights"
-EXPS_PATH: "/playpen-raid2/qinliu/models/model_0125_2023"
+EXPS_PATH: "./exps"
 
 # Evaluation datasets
 GRABCUT_PATH: "/playpen-raid2/qinliu/data/GrabCut"
@@ -8,7 +8,7 @@ DAVIS_PATH: "/playpen-raid2/qinliu/data/DAVIS345"
 COCO_MVAL_PATH: "/playpen-raid2/qinliu/data/COCO_MVal"
 BraTS_PATH: "/playpen-raid2/qinliu/data/BraTS20"
 ssTEM_PATH: "/playpen-raid2/qinliu/data/ssTEM"
-OAIZIB_PATH: "/playpen-raid2/qinliu/data/OAI-ZIB/iseg_slices"
+OAIZIB_PATH: "./dataset/OAI-ZIB"
 OAI_PATH: "/playpen-raid2/qinliu/data/OAI"
 HARD_PATH: "/playpen-raid2/qinliu/data/HARD"
 
@@ -19,6 +19,8 @@ LVIS_v1_PATH: "/playpen-raid2/qinliu/data/COCO_2017"
 OPENIMAGES_PATH: "./datasets/OpenImages"
 PASCALVOC_PATH: "/playpen-raid2/qinliu/data/PascalVOC"
 ADE20K_PATH: "./datasets/ADE20K"
+PASCAL_PATH: "/home/gyt/gyt/dataset/data/pascal_person_part"
+CITYSCAPES_PATH: "./data/cityscapes"
 
 # You can download the weights for HRNet from the repository:
 # https://github.com/HRNet/HRNet-Image-Classification
diff --git a/isegm/data/base.py b/isegm/data/base.py
index ee2a532..5c1ae89 100755
--- a/isegm/data/base.py
+++ b/isegm/data/base.py
@@ -1,9 +1,11 @@
 import random
 import pickle
+from collections import namedtuple
+
 import numpy as np
 import torch
 from torchvision import transforms
-from .points_sampler import MultiPointSampler
+from .points_sampler import MultiPointSampler, MultiClassSampler
 from .sample import DSample
 
 
@@ -30,6 +32,19 @@ def __init__(self,
         self.dataset_samples = None
 
     def __getitem__(self, index):
+        '''
+        Args:
+            index: sample index.
+
+        Returns:
+            {
+                'images': torch.Tensor,   # the image tensor
+                'points': np.ndarray,     # clicks; with max_num_points = 24 the shape is (48, 3):
+                                          # the first 24 rows are positive clicks, the last 24 negative;
+                                          # a real click is [y, x, 100], unused rows are (-1, -1, -1)
+                'instances': np.ndarray   # the mask
+            }
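+
+        For example (illustrative), with max_num_points = 2, one positive and
+        one negative click would be encoded as:
+            [[10, 20, 100],    # positive click at (y=10, x=20)
+             [-1, -1,  -1],    # positive padding
+             [57, 80, 100],    # negative click
+             [-1, -1,  -1]]    # negative padding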
+        '''
         if self.samples_precomputed_scores is not None:
             index = np.random.choice(self.samples_precomputed_scores['indices'],
                                      p=self.samples_precomputed_scores['probs'])
@@ -97,3 +112,30 @@ def _load_samples_scores(samples_scores_path, samples_scores_gamma):
     }
     print(f'Loaded {len(probs)} weights with gamma={samples_scores_gamma}')
     return samples_scores
+
+def is_dataset_collate_fn(data):
+    """
+    data: a list of dataset outputs, each a dict with keys 'images'
+    (torch.Tensor), 'points' (np.ndarray of shape (N, 3)) and
+    'instances' (np.ndarray mask); the 'points' arrays may differ in
+    length and are padded to the longest one before stacking.
+    """
+    images = [d['images'] for d in data]
+    points = [d['points'] for d in data]
+    instances = [d['instances'] for d in data]
+
+    all_points = points
+    max_len = max([len(x) for x in all_points])
+    for i in range(len(data)):
+        padding_length = max_len - len(points[i])
+        if padding_length > 0:
+            # create padding of shape (padding_length, 3) filled with (-1, -1, -1)
+            padding = np.full((padding_length, 3), (-1, -1, -1))
+            # concatenate the original points with the padding
+            padded_point = np.concatenate((all_points[i], padding), axis=0)
+            all_points[i] = padded_point
+    images = torch.stack(images)
+    all_points = np.array(all_points)
+    instances = np.array(instances)
+    return images, torch.from_numpy(all_points), torch.from_numpy(instances)
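+
+# Padding sketch (illustrative): if one sample in the batch has 48 points and
+# another has 30, the shorter array gains 18 rows of (-1, -1, -1) so the batch
+# stacks to shape (2, 48, 3).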
diff --git a/isegm/data/datasets/LIP.py b/isegm/data/datasets/LIP.py
new file mode 100644
index 0000000..ac299f4
--- /dev/null
+++ b/isegm/data/datasets/LIP.py
@@ -0,0 +1,241 @@
+import os
+import pickle as pkl
+from pathlib import Path
+
+import cv2
+import numpy as np
+
+from isegm.data.base import ISDataset
+from isegm.data.sample import DSample
+from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes
+
+from tqdm import tqdm
+import pickle
+
+
+class LIP(ISDataset):
+    def __init__(self, dataset_path, split='train', **kwargs):
+        super().__init__(**kwargs)
+        assert split in {'train', 'val', 'trainval', 'test'}
+        self.name = 'LIP'
+        self.dataset_path = Path(dataset_path)
+        self._images_path = self.dataset_path / "train_images"
+        self._insts_path = self.dataset_path / "train_segmentations"
+        self.init_path = self.dataset_path / "20_cls_interactive_point"
+        self.dataset_split = split
+        self.class_num = 20  # number of classes counted in mIoU: includes background, excludes the ignore region
+        self.ignore_id = 255
+
+        self.loadfile = self.dataset_split + ".pkl"
+        if os.path.exists(str(self.dataset_path / self.loadfile)):
+            with open(str(self.dataset_path / self.loadfile), 'rb') as file:
+                self.dataset_samples = pickle.load(file)
+        else:
+            dataset_samples = []
+            idsfile = self.dataset_split + "_id.txt"
+            with open(str(self.dataset_path / idsfile), "r") as f:
+                id_list = [line.strip() for line in f.readlines()]
+            for id in id_list:
+                img_path = self._images_path / (id + ".jpg")
+                gt_path = self._insts_path / (id + ".png")
+                init_path = self.init_path / (id + ".png")
+                dataset_samples.append((img_path, gt_path, init_path))
+            image_id_lst = self.get_images_and_ids_list(dataset_samples)
+            self.dataset_samples = image_id_lst
+            # print(image_id_lst[:5])
+
+    '''
+    def get_sample(self, index) -> DSample:
+        sample_id = self.dataset_samples[index]
+        image_path = str(self._images_path / f'{sample_id}.jpg')
+        mask_path = str(self._insts_path / f'{sample_id}.png')
+
+        image = cv2.imread(image_path)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        instances_mask = cv2.imread(mask_path)
+        instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32)
+        if self.dataset_split == 'test':
+            instance_id = self.instance_ids[index]
+            mask = np.zeros_like(instances_mask)
+            mask[instances_mask == 220] = 220   # ignored area
+            mask[instances_mask == instance_id] = 1
+            objects_ids = [1]
+            instances_mask = mask
+        else:
+            objects_ids = np.unique(instances_mask)
+            objects_ids = [x for x in objects_ids if x != 0 and x != 220]
+
+        return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[220], sample_id=index)
+    '''
+
+    def get_sample(self, index) -> DSample:
+        sample_path, target_path, instance_ids, init_path = self.dataset_samples[index]
+        # sample_id = str(sample_id)
+        # print(sample_id)
+        # num_zero = 6 - len(sample_id)
+        # sample_id = '2007_'+'0'*num_zero + sample_id
+
+        image_path = str(sample_path)
+        mask_path = str(target_path)
+        init_path = str(init_path)
+
+        # print(image_path)
+
+        image = cv2.imread(image_path)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        instances_mask = cv2.imread(mask_path)
+        instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32)
+
+        mask = instances_mask
+        # mask[instances_mask == 255] = 220  # ignored area
+        # mask[instances_mask == instance_id] = 1
+        objects_ids = instance_ids  # instance_ids is now a list
+        instances_mask = mask
+        return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[self.ignore_id], sample_id=index, init_clicks=init_path)
+
+    def get_images_and_ids_list(self, dataset_samples, ignore_id=255):
+        images_and_ids_list = []
+        object_count = 0
+        # for i in tqdm(range(len(dataset_samples))):
+        for i in range(len(dataset_samples)):
+            image_path, mask_path, init_path = dataset_samples[i]
+            instances_mask = cv2.imread(str(mask_path))
+            instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32)
+            objects_ids = np.unique(instances_mask)
+
+            objects_ids = [x for x in objects_ids if x != ignore_id]
+            object_count += len(objects_ids)
+            # for j in objects_ids:
+            images_and_ids_list.append([image_path, mask_path, objects_ids, init_path])
+            # print(i,j,objects_ids)
+        with open(str(self.dataset_path / self.loadfile), "wb") as file:
+            pickle.dump(images_and_ids_list, file)
+        print("file count: " + str(len(dataset_samples)))
+        print("object count: " + str(object_count))
+        return images_and_ids_list
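+
+    # Cached .pkl entry layout produced above (illustrative values):
+    #   [Path('.../train_images/77_1.jpg'),                 # image
+    #    Path('.../train_segmentations/77_1.png'),          # label mask
+    #    [0, 2, 5, 13],                                     # class ids present in the mask
+    #    Path('.../20_cls_interactive_point/77_1.png')]     # initial-click map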
+
+class LIP_train(ISDataset):
+    def __init__(self, dataset_path, split='train', **kwargs):
+        super().__init__(**kwargs)
+        assert split in {'train', 'val', 'trainval', 'test'}
+
+        self._buggy_mask_thresh = 0.08
+        self._buggy_objects = dict()
+
+        self.name = 'LIP'
+        self.dataset_path = Path(dataset_path)
+        self._images_path = self.dataset_path / "train_images"
+        self._insts_path = self.dataset_path / "train_segmentations"
+        self.init_path = self.dataset_path / "20_cls_interactive_point"
+        self.dataset_split = split
+        self.class_num = 19  # number of classes counted in mIoU: includes background, excludes the ignore region
+        self.ignore_id = 0
+
+        self.loadfile = self.dataset_split + ".pkl"
+        if os.path.exists(str(self.dataset_path / self.loadfile)):
+            with open(str(self.dataset_path / self.loadfile), 'rb') as file:
+                self.dataset_samples = pickle.load(file)
+        else:
+            dataset_samples = []
+            idsfile = self.dataset_split + "_id.txt"
+            with open(str(self.dataset_path / idsfile), "r") as f:
+                id_list = [line.strip() for line in f.readlines()]
+            for id in id_list:
+                img_path = self._images_path / (id + ".jpg")
+                gt_path = self._insts_path / (id + ".png")
+                init_path = self.init_path / (id + ".png")
+                dataset_samples.append((img_path, gt_path, init_path))
+            image_id_lst = self.get_images_and_ids_list(dataset_samples)
+            self.dataset_samples = image_id_lst
+            # print(image_id_lst[:5])
+
+    '''
+    def get_sample(self, index) -> DSample:
+        sample_id = self.dataset_samples[index]
+        image_path = str(self._images_path / f'{sample_id}.jpg')
+        mask_path = str(self._insts_path / f'{sample_id}.png')
+
+        image = cv2.imread(image_path)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        instances_mask = cv2.imread(mask_path)
+        instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32)
+        if self.dataset_split == 'test':
+            instance_id = self.instance_ids[index]
+            mask = np.zeros_like(instances_mask)
+            mask[instances_mask == 220] = 220   # ignored area
+            mask[instances_mask == instance_id] = 1
+            objects_ids = [1]
+            instances_mask = mask
+        else:
+            objects_ids = np.unique(instances_mask)
+            objects_ids = [x for x in objects_ids if x != 0 and x != 220]
+
+        return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[220], sample_id=index)
+    '''
+
+    def get_sample(self, index) -> DSample:
+        sample_path, target_path, instance_ids, init_path = self.dataset_samples[index]
+        # sample_id = str(sample_id)
+        # print(sample_id)
+        # num_zero = 6 - len(sample_id)
+        # sample_id = '2007_'+'0'*num_zero + sample_id
+
+        image_path = str(sample_path)
+        mask_path = str(target_path)
+        init_path = str(init_path)
+
+        # print(image_path)
+
+        image = cv2.imread(image_path)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        instances_mask = cv2.imread(mask_path)
+        instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32)
+
+        instances_mask = self.remove_buggy_masks(index, instances_mask)
+        instances_ids, _ = get_labels_with_sizes(instances_mask, ignoreid=0)
+
+        objects_ids = instances_ids  # instance_ids is now a list
+
+        return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[0], sample_id=index)
+
+    def get_images_and_ids_list(self, dataset_samples, ignore_id=0):
+        images_and_ids_list = []
+        object_count = 0
+        # for i in tqdm(range(len(dataset_samples))):
+        for i in range(len(dataset_samples)):
+            image_path, mask_path, init_path = dataset_samples[i]
+            instances_mask = cv2.imread(str(mask_path))
+            instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32)
+            objects_ids = np.unique(instances_mask)
+
+            objects_ids = [x for x in objects_ids if x != ignore_id]
+            object_count += len(objects_ids)
+            # for j in objects_ids:
+            images_and_ids_list.append([image_path, mask_path, objects_ids, init_path])
+            # print(i,j,objects_ids)
+        with open(str(self.dataset_path / self.loadfile), "wb") as file:
+            pickle.dump(images_and_ids_list, file)
+        print("file count: " + str(len(dataset_samples)))
+        print("object count: " + str(object_count))
+        return images_and_ids_list
+
+    def remove_buggy_masks(self, index, instances_mask):
+        if self._buggy_mask_thresh > 0.0:
+            buggy_image_objects = self._buggy_objects.get(index, None)
+            if buggy_image_objects is None:
+                buggy_image_objects = []
+                instances_ids, _ = get_labels_with_sizes(instances_mask)
+                for obj_id in instances_ids:
+                    obj_mask = instances_mask == obj_id
+                    mask_area = obj_mask.sum()
+                    bbox = get_bbox_from_mask(obj_mask)
+                    bbox_area = (bbox[1] - bbox[0] + 1) * (bbox[3] - bbox[2] + 1)
+                    obj_area_ratio = mask_area / bbox_area
+                    if obj_area_ratio < self._buggy_mask_thresh:
+                        buggy_image_objects.append(obj_id)
+
+                self._buggy_objects[index] = buggy_image_objects
+            for obj_id in buggy_image_objects:
+                instances_mask[instances_mask == obj_id] = 0
+
+        return instances_mask
\ No newline at end of file
diff --git a/isegm/data/datasets/PASCAL.py b/isegm/data/datasets/PASCAL.py
new file mode 100644
index 0000000..5da4b82
--- /dev/null
+++ b/isegm/data/datasets/PASCAL.py
@@ -0,0 +1,188 @@
+import os
+import pickle as pkl
+import random
+from pathlib import Path
+
+import cv2
+import numpy as np
+
+from isegm.data.base import ISDataset
+from isegm.data.sample import DSample
+from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes
+from tqdm import tqdm
+import pickle
+
+
+class PASCAL(ISDataset):
+    def __init__(self, dataset_path, split='train', **kwargs):
+        super().__init__(**kwargs)
+        assert split in {'train', 'val', 'trainval', 'test'}
+
+        self._buggy_mask_thresh = 0.08
+        self._buggy_objects = dict()
+
+        self.name = 'PASCAL'
+        self.dataset_path = Path(dataset_path)
+        self._images_path = self.dataset_path / "JPEGImages"
+        self._insts_path = self.dataset_path / "SegmentationPart_label"
+        self.init_path = self.dataset_path / "voc_person_interactive_center_point"
+        self.dataset_split = split
+        self.class_num = 7  # number of classes counted in mIoU: includes background, excludes the ignore region
+        self.ignore_id = 255
+
+        self.loadfile = self.dataset_split + ".pkl"
+        if os.path.exists(str(self.dataset_path / "pascal_person_part_trainval_list" / self.loadfile)):
+            with open(str(self.dataset_path / "pascal_person_part_trainval_list" / self.loadfile), 'rb') as file:
+                self.dataset_samples = pickle.load(file)
+        else:
+            dataset_samples = []
+            idsfile = self.dataset_split + "_id.txt"
+            with open(str(self.dataset_path / "pascal_person_part_trainval_list" / idsfile), "r") as f:
+                id_list = [line.strip() for line in f.readlines()]
+            for id in id_list:
+                img_path = self._images_path / (id + ".jpg")
+                gt_path = self._insts_path / (id + ".png")
+                init_path = self.init_path / (id + ".png")
+                dataset_samples.append((img_path, gt_path, init_path))
+            image_id_lst = self.get_images_and_ids_list(dataset_samples)
+            self.dataset_samples = image_id_lst
+            # print(image_id_lst[:5])
+
+    '''
+    def get_sample(self, index) -> DSample:
+        sample_id = self.dataset_samples[index]
+        image_path = str(self._images_path / f'{sample_id}.jpg')
+        mask_path = str(self._insts_path / f'{sample_id}.png')
+
+        image = cv2.imread(image_path)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        instances_mask = cv2.imread(mask_path)
+        instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32)
+        if self.dataset_split == 'test':
+            instance_id = self.instance_ids[index]
+            mask = np.zeros_like(instances_mask)
+            mask[instances_mask == 220] = 220   # ignored area
+            mask[instances_mask == instance_id] = 1
+            objects_ids = [1]
+            instances_mask = mask
+        else:
+            objects_ids = np.unique(instances_mask)
+            objects_ids = [x for x in objects_ids if x != 0 and x != 220]
+
+        return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[220], sample_id=index)
+    '''
+
+    def get_sample(self, index) -> DSample:
+        sample_path, target_path, instance_ids, init_path = self.dataset_samples[index]
+        # sample_id = str(sample_id)
+        # print(sample_id)
+        # num_zero = 6 - len(sample_id)
+        # sample_id = '2007_'+'0'*num_zero + sample_id
+
+        image_path = str(sample_path)
+        mask_path = str(target_path)
+        init_path = str(init_path)
+
+        # print(image_path)
+
+        image = cv2.imread(image_path)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        instances_mask = cv2.imread(mask_path)
+        instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32)
+        # mask[instances_mask == 255] = 220  # ignored area
+        # mask[instances_mask == instance_id] = 1
+        objects_ids = instance_ids  # instance_ids is now a list
+        return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[self.ignore_id], sample_id=index, init_clicks=init_path)
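+
+    # Init-click map (illustrative): the PNG referenced by init_clicks is 255
+    # everywhere except at click locations, where the pixel value stores the
+    # class id; __getitem__ below turns a pixel at (y=12, x=40) with value 3
+    # into the point [12, 40, 3].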
+    def get_images_and_ids_list(self, dataset_samples, ignore_id=255):
+        images_and_ids_list = []
+        object_count = 0
+        # for i in tqdm(range(len(dataset_samples))):
+        for i in range(len(dataset_samples)):
+            image_path, mask_path, init_path = dataset_samples[i]
+            instances_mask = cv2.imread(str(mask_path))
+            instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32)
+            objects_ids = np.unique(instances_mask)
+
+            objects_ids = [x for x in objects_ids if x != ignore_id]
+            object_count += len(objects_ids)
+            # for j in objects_ids:
+            images_and_ids_list.append([image_path, mask_path, objects_ids, init_path])
+            # print(i,j,objects_ids)
+        with open(str(self.dataset_path / "pascal_person_part_trainval_list" / self.loadfile), "wb") as file:
+            pickle.dump(images_and_ids_list, file)
+        print("file count: " + str(len(dataset_samples)))
+        print("object count: " + str(object_count))
+        return images_and_ids_list
+
+    def remove_buggy_masks(self, index, instances_mask):
+        if self._buggy_mask_thresh > 0.0:
+            buggy_image_objects = self._buggy_objects.get(index, None)
+            if buggy_image_objects is None:
+                buggy_image_objects = []
+                instances_ids, _ = get_labels_with_sizes(instances_mask)
+                for obj_id in instances_ids:
+                    obj_mask = instances_mask == obj_id
+                    mask_area = obj_mask.sum()
+                    bbox = get_bbox_from_mask(obj_mask)
+                    bbox_area = (bbox[1] - bbox[0] + 1) * (bbox[3] - bbox[2] + 1)
+                    obj_area_ratio = mask_area / bbox_area
+                    if obj_area_ratio < self._buggy_mask_thresh:
+                        buggy_image_objects.append(obj_id)
+
+                self._buggy_objects[index] = buggy_image_objects
+            for obj_id in buggy_image_objects:
+                instances_mask[instances_mask == obj_id] = 0
+
+        return instances_mask
+
+    def __len__(self):
+        return len(self.dataset_samples)
+
+    def __getitem__(self, index):  # points are sampled from the whole mask
+        '''
+        Args:
+            index: sample index.
+
+        Returns:
+            {
+                'images': torch.Tensor,   # the image tensor
+                'points': np.ndarray,     # clicks; with max_num_points = 24 the shape is (48, 3):
+                                          # the first 24 rows are positive clicks, the last 24 negative;
+                                          # a real click is [y, x, 100], unused rows are (-1, -1, -1)
+                'instances': np.ndarray   # the mask
+            }
+        '''
+        if self.samples_precomputed_scores is not None:
+            index = np.random.choice(self.samples_precomputed_scores['indices'],
+                                     p=self.samples_precomputed_scores['probs'])
+        else:
+            if self.epoch_len > 0:
+                index = random.randrange(0, len(self.dataset_samples))
+
+        sample = self.get_sample(index)
+        sample = self.augment_sample(sample)
+        sample.remove_small_objects(self.min_object_area)
+        init_points = cv2.imread(sample.init_clicks)[:, :, 0]
+        rows, cols = np.where(init_points != 255)
+        non_255_values = init_points[rows, cols]
+        coords_and_values = list(zip(rows, cols, non_255_values))
+        coords_and_values.extend([(-1, -1, -1)] * (self.points_sampler.max_num_points - len(coords_and_values)))
+        init_points = np.array(coords_and_values)
+        self.points_sampler.sample_object(sample)
+        points = np.array(self.points_sampler.sample_points())
+
+        mask = self.points_sampler.selected_mask
+
+        output = {
+            'images': self.to_tensor(sample.image),
+            'points': init_points.astype(np.float32),
+            'instances': mask,
+            # 'init_points': init_points.astype(np.float32)
+        }
+
+        if self.with_image_info:
+            output['image_info'] = sample.sample_id
+
+        return output
diff --git a/isegm/data/datasets/__init__.py b/isegm/data/datasets/__init__.py
index b8110cb..c4f56dc 100755
--- a/isegm/data/datasets/__init__.py
+++ b/isegm/data/datasets/__init__.py
@@ -15,4 +15,5 @@
 from .ssTEM import ssTEMDataset
 from .oai_zib import OAIZIBDataset
 from .oai import OAIDataset
-from .hard import HARDDataset
\ No newline at end of file
+from .hard import HARDDataset
+from .PASCAL import PASCAL
\ No newline at end of file
diff --git a/isegm/data/datasets/cityscapes.py b/isegm/data/datasets/cityscapes.py
new file mode 100644
index 0000000..bd5e0a2
--- /dev/null
+++ b/isegm/data/datasets/cityscapes.py
@@ -0,0 +1,129 @@
+import os
+import pickle as pkl
+from pathlib import Path
+import random
+import cv2
+import numpy as np
+
+from isegm.data.base import ISDataset
+from isegm.data.sample import DSample
+from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes
+from tqdm import tqdm
+import pickle
+
+
+class CityScapes(ISDataset):
+    def __init__(self, dataset_path, split="train", use_cache=True, first_return_points="init", **kwargs):
+        super(CityScapes, self).__init__(**kwargs)
+        assert split in {"train", "val", "trainval", "test"}
+        assert first_return_points in {"init", "random", "blank"}
+        self.name = "Cityscapes"
+        self.dataset_path = Path(dataset_path)
+        self._images_path = Path("leftImg8bit") / split
+        self._insts_path = Path("gtFine") / split
+        self.init_path = Path("init_interactive_point")
+        self.dataset_split = split
+        self.class_num = 19
+        self.ignore_id = 255
+        self.first_return_points = first_return_points
+
+        self.loadfile = self.dataset_split + ".pkl"
+        if os.path.exists(str(self.dataset_path / self.loadfile)) and use_cache:
+            with open(str(self.dataset_path / self.loadfile), 'rb') as file:
+                self.dataset_samples = pickle.load(file)
+        else:
+            dataset_samples = []
+            for city in os.listdir(self.dataset_path / self._images_path):
+                img_dir = self._images_path / city
+                target_dir = self._insts_path / city
+                init_dir = self.init_path / city
+                for file_name in os.listdir(self.dataset_path / img_dir):
+                    toAddPath = img_dir / file_name
+                    initName = file_name.replace("_leftImg8bit", "")
+                    initPath = init_dir / initName
+                    labelName = file_name.replace("leftImg8bit", "gtFine_labelTrainIds")
+                    labelPath = target_dir / labelName
+                    dataset_samples.append((toAddPath, labelPath, initPath))
+            image_id_lst = self.get_images_and_ids_list(dataset_samples)
+            self.dataset_samples = image_id_lst
+            # print(image_id_lst[:5])
+
+    def get_sample(self, index) -> DSample:
+        sample_path, target_path, instance_ids, init_path = self.dataset_samples[index]
+
+        image_path = str(self.dataset_path / sample_path)
+        mask_path = str(self.dataset_path / target_path)
+        init_path = str(self.dataset_path / init_path)
+
+        image = cv2.imread(image_path)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        instances_mask = cv2.imread(mask_path)
+        instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(
+            np.int32
+        )
+
+        ids = [x for x in np.unique(instances_mask) if x != self.ignore_id]
+
+        objects_ids = ids  # instance_ids is now a list
+
+        return DSample(
+            image,
+            instances_mask,
+            objects_ids=objects_ids,
+            ignore_ids=[self.ignore_id],
+            sample_id=index,
+            init_clicks=init_path,
+        )
+
+    def get_images_and_ids_list(self, dataset_samples):
+        images_and_ids_list = []
+        object_count = 0
+        for i in tqdm(range(len(dataset_samples))):
+            # for i in range(len(dataset_samples)):
+            image_path, mask_path, init_path = dataset_samples[i]
+            instances_mask = cv2.imread(str(self.dataset_path / mask_path))
+            instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(
+                np.int32
+            )
+            objects_ids = np.unique(instances_mask)
+
+            objects_ids = [x for x in objects_ids if x != self.ignore_id]
+            object_count += len(objects_ids)
+
+            images_and_ids_list.append([image_path, mask_path, objects_ids, init_path])
+
+        with open(str(self.dataset_path / self.loadfile), "wb") as file:
+            pickle.dump(images_and_ids_list, file)
+        return images_and_ids_list
+
+    def __len__(self):
+        return len(self.dataset_samples)
+
+    def __getitem__(self, index):
+        sample = self.get_sample(index)
+        sample = self.augment_sample(sample)
+        init_points = cv2.imread(sample.init_clicks)[:, :, 0]
+        rows, cols = np.where(init_points != 255)
+        non_255_values = init_points[rows, cols]
+        coords_and_values = list(zip(rows, cols, non_255_values))
+        # coords_and_values.extend([(-1, -1, -1)] * (self.points_sampler.max_num_points - len(coords_and_values)))
+        init_points = np.array(coords_and_values)
+        self.points_sampler.sample_object(sample)
+        mask = self.points_sampler.selected_mask
+        if self.first_return_points == "init":
+            points = init_points
+        elif self.first_return_points == "random":
+            points = np.array(self.points_sampler.sample_points())
+        elif self.first_return_points == "blank":
+            points = np.array([(-1, -1, -1)])
+        output = {
+            'images': self.to_tensor(sample.image),
+            'points': points.astype(np.float32),
+            'instances': mask,
+            # 'init_points': init_points.astype(np.float32)
+        }
+        return output
diff --git a/isegm/data/points_sampler.py b/isegm/data/points_sampler.py
index 79f0cd7..d57b286 100755
--- a/isegm/data/points_sampler.py
+++ b/isegm/data/points_sampler.py
@@ -303,3 +303,84 @@ def get_point_candidates(obj_mask, k=1.7, full_prob=0.0):
     click_indx = np.random.choice(len(prob_map), p=prob_map)
     click_coords = np.unravel_index(click_indx, dt.shape)
     return np.array([click_coords])
+
+class MultiClassSampler(MultiPointSampler):
+    def __init__(self, ignore_label=255, **kwargs):
+        super().__init__(**kwargs)
+        self.ignore_label = ignore_label
+
+    def sample_object(self, sample: DSample):
+        '''
+        Use the sample's encoded masks directly as the selected mask.
+
+        Args:
+            sample: a DSample instance.
+        '''
+        self._selected_mask = sample._encoded_masks
+
+    def sample_points(self):
+        '''
+        Randomly sample points from the gt mask. The number of points is max_num_points.
+        Note that a sampled class can never be the ignore label (255).
+
+        Returns: points
+        '''
+        assert self._selected_mask is not None
+        num_points = self.max_num_points
+        # num_points = 1 + np.random.choice(np.arange(self.max_num_points), p=self._pos_probs)
+        h, w = self._selected_mask.shape[:2]
+        valid_mask = self._selected_mask[:, :, 0] != self.ignore_label
+        valid_indices = np.where(valid_mask)
+        selected_indices = np.random.choice(len(valid_indices[0]), self.max_num_points, replace=False)
+
+        points = []
+
+        for idx in selected_indices:
+            y, x = valid_indices[0][idx], valid_indices[1][idx]
+            cls = self._selected_mask[y, x, 0]
+            points.append([y, x, cls])
+
+        return points
+
+    def _multi_mask_sample_points(self, selected_masks, is_negative, with_first_click=False):
+        selected_masks = selected_masks[:self.max_num_points]
+
+        each_obj_points = [
+            self._sample_points(mask, is_negative=is_negative[i],
+                                with_first_click=with_first_click)
+            for i, mask in enumerate(selected_masks)
+        ]
+        each_obj_points = [x for x in each_obj_points if len(x) > 0]
+
+        points = []
+        if len(each_obj_points) == 1:
+            points = each_obj_points[0]
+        elif len(each_obj_points) > 1:
+            if self.only_one_first_click:
+                each_obj_points = each_obj_points[:1]
+
+            points = [obj_points[0] for obj_points in each_obj_points]
+
+            aggregated_masks_with_prob = []
+            for indx, x in enumerate(selected_masks):
+                if isinstance(x, (list, tuple)) and x and isinstance(x[0], (list, tuple)):
+                    for t, prob in x:
+                        aggregated_masks_with_prob.append((t, prob / len(selected_masks)))
+                else:
+                    aggregated_masks_with_prob.append((x, 1.0 / len(selected_masks)))
+
+            other_points_union = self._sample_points(aggregated_masks_with_prob, is_negative=True)
+            if len(other_points_union) + len(points) <= self.max_num_points:
+                points.extend(other_points_union)
+            else:
+                points.extend(random.sample(other_points_union, self.max_num_points - len(points)))
+
+        if len(points) < self.max_num_points:
+            points.extend([(-1, -1, -1)] * (self.max_num_points - len(points)))
+
+        return points
diff --git a/isegm/data/sample.py b/isegm/data/sample.py
index a08c409..472caaa 100755
--- a/isegm/data/sample.py
+++ b/isegm/data/sample.py
@@ -7,9 +7,10 @@ class DSample:
     def __init__(self, image, encoded_masks, objects=None,
-                 objects_ids=None, ignore_ids=None, sample_id=None):
+                 objects_ids=None, ignore_ids=None, sample_id=None, init_clicks=None):
         self.image = image
         self.sample_id = sample_id
+        self.init_clicks = init_clicks
 
         if len(encoded_masks.shape) == 2:
             encoded_masks = encoded_masks[:, :, np.newaxis]
diff --git a/isegm/engine/Multi_trainer.py b/isegm/engine/Multi_trainer.py
new file mode 100755
index 0000000..c069cd2
--- /dev/null
+++ b/isegm/engine/Multi_trainer.py
@@ -0,0 +1,439 @@
+import os
+import random
+import logging
+from copy import deepcopy
+from collections import defaultdict
+
+import cv2
+import torch
+import numpy as np
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+
+from isegm.utils.log import logger, TqdmToLogger, SummaryWriterAvg
+from isegm.utils.vis import draw_probmap, draw_points
+from isegm.utils.misc import save_checkpoint
+from isegm.utils.serialization import get_config_repr
+from isegm.utils.distributed import get_dp_wrapper, get_sampler, reduce_loss_dict
+from .optimizer import get_optimizer, get_optimizer_with_layerwise_decay
+from .trainer import ISTrainer
+from ..data.base import is_dataset_collate_fn
+
+
+class Multi_trainer(ISTrainer):
+    def __init__(self, model, cfg, model_cfg, loss_cfg,
+                 trainset, valset,
+                 optimizer='adam',
+                 optimizer_params=None,
+                 layerwise_decay=False,
+                 image_dump_interval=200,
+                 checkpoint_interval=10,
+                 tb_dump_period=25,
+                 max_interactive_points=0,
+                 lr_scheduler=None,
+                 metrics=None,
+                 additional_val_metrics=None,
+                 net_inputs=('images', 'points'),
+                 max_num_next_clicks=0,
+                 click_models=None,
+                 prev_mask_drop_prob=0.0,
+                 ):
+        self.cfg = cfg
+        self.model_cfg = model_cfg
+        self.max_interactive_points = max_interactive_points
+        self.loss_cfg = loss_cfg
+        self.val_loss_cfg = deepcopy(loss_cfg)
+        self.tb_dump_period = tb_dump_period
+        self.net_inputs = net_inputs
+        self.max_num_next_clicks = max_num_next_clicks
+
+        self.click_models = click_models
+        self.prev_mask_drop_prob = prev_mask_drop_prob
+
+        if cfg.distributed:
+            cfg.batch_size //= cfg.ngpus
+            cfg.val_batch_size //= cfg.ngpus
+
+        if metrics is None:
+            metrics = []
+        self.train_metrics = metrics
+        self.val_metrics = deepcopy(metrics)
+        if additional_val_metrics is not None:
+            self.val_metrics.extend(additional_val_metrics)
+
+        self.checkpoint_interval = checkpoint_interval
+        self.image_dump_interval = image_dump_interval
+        self.task_prefix = ''
+        self.sw = None
+
+        self.trainset = trainset
+        self.valset = valset
+
+        logger.info(f'Dataset of {trainset.get_samples_number()} samples was loaded for training.')
+        logger.info(f'Dataset of {valset.get_samples_number()} samples was loaded for validation.')
+
+        self.train_data = DataLoader(
+            trainset, cfg.batch_size,
+            sampler=get_sampler(trainset, shuffle=True, distributed=cfg.distributed),
+            drop_last=True, pin_memory=True,
+            num_workers=cfg.workers,
+            collate_fn=is_dataset_collate_fn
+        )
+
+        self.val_data = DataLoader(
+            valset, cfg.val_batch_size,
+            sampler=get_sampler(valset, shuffle=False, distributed=cfg.distributed),
+            drop_last=True, pin_memory=True,
+            num_workers=cfg.workers,
+            collate_fn=is_dataset_collate_fn
+        )
+
+        if layerwise_decay:
+            self.optim = get_optimizer_with_layerwise_decay(model, optimizer, optimizer_params)
+        else:
+            self.optim = get_optimizer(model, optimizer, optimizer_params)
+        model = self._load_weights(model)
+
+        if cfg.multi_gpu:
+            model = get_dp_wrapper(cfg.distributed)(model, device_ids=cfg.gpu_ids,
+                                                    output_device=cfg.gpu_ids[0])
+
+        if self.is_master:
+            logger.info(model)
+            logger.info(get_config_repr(model._config))
+
+        self.device = cfg.device
+        self.net = model.to(self.device)
+        self.lr = optimizer_params['lr']
+
+        if lr_scheduler is not None:
+            self.lr_scheduler = lr_scheduler(optimizer=self.optim)
+            if cfg.start_epoch > 0:
+                for _ in range(cfg.start_epoch):
+                    self.lr_scheduler.step()
+
+        self.tqdm_out = TqdmToLogger(logger, level=logging.INFO)
+
+        if self.click_models is not None:
+            for click_model in self.click_models:
+                for param in click_model.parameters():
+                    param.requires_grad = False
+                click_model.to(self.device)
+                click_model.eval()
+
+    def run(self, num_epochs, start_epoch=None, validation=True):
+        if start_epoch is None:
+            start_epoch = self.cfg.start_epoch
+
+        logger.info(f'Starting Epoch: {start_epoch}')
+        logger.info(f'Total Epochs: {num_epochs}')
+        for epoch in range(start_epoch, num_epochs):
+            self.training(epoch)
+            if validation:
+                self.validation(epoch)
+
+    def training(self, epoch):
+        if self.sw is None and self.is_master:
+            self.sw = SummaryWriterAvg(log_dir=str(self.cfg.LOGS_PATH),
+                                       flush_secs=10, dump_period=self.tb_dump_period)
+
+        if self.cfg.distributed:
+            self.train_data.sampler.set_epoch(epoch)
+
+        log_prefix = 'Train' + self.task_prefix.capitalize()
+        tbar = tqdm(self.train_data, file=self.tqdm_out, ncols=100) \
+            if self.is_master else self.train_data
+
+        for metric in self.train_metrics:
+            metric.reset_epoch_stats()
+
+        self.net.train()
+        train_loss = 0.0
+        for i, batch_data in enumerate(tbar):
+            global_step = epoch * len(self.train_data) + i
+
+            loss, losses_logging, splitted_batch_data, outputs = \
+                self.batch_forward(batch_data)
+
+            self.optim.zero_grad()
+            loss.backward()
+            self.optim.step()
+
+            losses_logging['overall'] = loss
+            reduce_loss_dict(losses_logging)
+
+            train_loss += losses_logging['overall'].item()
+
+            if self.is_master:
+                for loss_name, loss_value in losses_logging.items():
+                    self.sw.add_scalar(tag=f'{log_prefix}Losses/{loss_name}',
+                                       value=loss_value.item(),
+                                       global_step=global_step)
+
+                for k, v in self.loss_cfg.items():
+                    if '_loss' in k and hasattr(v, 'log_states') and self.loss_cfg.get(k + '_weight', 0.0) > 0:
+                        v.log_states(self.sw, f'{log_prefix}Losses/{k}', global_step)
+
+                # if self.image_dump_interval > 0 and global_step % self.image_dump_interval == 0:
+                #     self.save_visualization(splitted_batch_data, outputs, global_step, prefix='train')
+
+                self.sw.add_scalar(tag=f'{log_prefix}States/learning_rate',
+                                   value=self.lr if not hasattr(self, 'lr_scheduler') else self.lr_scheduler.get_lr()[-1],
+                                   global_step=global_step)
+
+                tbar.set_description(f'Epoch {epoch}, training loss {train_loss/(i+1):.4f}')
+                for metric in self.train_metrics:
+                    metric.log_states(self.sw, f'{log_prefix}Metrics/{metric.name}', global_step)
+
+        if self.is_master:
+            for metric in self.train_metrics:
+                self.sw.add_scalar(tag=f'{log_prefix}Metrics/{metric.name}',
+                                   value=metric.get_epoch_value(),
+                                   global_step=epoch, disable_avg=True)
+
+            save_checkpoint(self.net, self.cfg.CHECKPOINTS_PATH, prefix=self.task_prefix,
+                            epoch=None, multi_gpu=self.cfg.multi_gpu)
+
+            if isinstance(self.checkpoint_interval, (list, tuple)):
+                checkpoint_interval = [x for x in self.checkpoint_interval if x[0] <= epoch][-1][1]
+            else:
+                checkpoint_interval = self.checkpoint_interval
+
+            if epoch % checkpoint_interval == 0:
+                save_checkpoint(self.net, self.cfg.CHECKPOINTS_PATH, prefix=self.task_prefix,
+                                epoch=epoch, multi_gpu=self.cfg.multi_gpu)
+
+        if hasattr(self, 'lr_scheduler'):
+            self.lr_scheduler.step()
+
+    def validation(self, epoch):
+        if self.sw is None and self.is_master:
+            self.sw = SummaryWriterAvg(log_dir=str(self.cfg.LOGS_PATH),
+                                       flush_secs=10, dump_period=self.tb_dump_period)
+
+        log_prefix = 'Val' + self.task_prefix.capitalize()
+        tbar = tqdm(self.val_data, file=self.tqdm_out, ncols=100) if self.is_master else self.val_data
+
+        for metric in self.val_metrics:
+            metric.reset_epoch_stats()
+
+        val_loss = 0
+        losses_logging = defaultdict(list)
+
+        self.net.eval()
+        for i, batch_data in enumerate(tbar):
+            global_step = epoch * len(self.val_data) + i
+            loss, batch_losses_logging, splitted_batch_data, outputs = \
+                self.batch_forward(batch_data, validation=True)
+
+            batch_losses_logging['overall'] = loss
+            reduce_loss_dict(batch_losses_logging)
+            for loss_name, loss_value in batch_losses_logging.items():
+                losses_logging[loss_name].append(loss_value.item())
+
+            val_loss += batch_losses_logging['overall'].item()
+
+            if self.is_master:
+                tbar.set_description(f'Epoch {epoch}, validation loss: {val_loss/(i + 1):.4f}')
+                for metric in self.val_metrics:
+                    metric.log_states(self.sw, f'{log_prefix}Metrics/{metric.name}', global_step)
+
+        if self.is_master:
+            for loss_name, loss_values in losses_logging.items():
+                self.sw.add_scalar(tag=f'{log_prefix}Losses/{loss_name}', value=np.array(loss_values).mean(),
+                                   global_step=epoch, disable_avg=True)
+
+            for metric in self.val_metrics:
+                self.sw.add_scalar(tag=f'{log_prefix}Metrics/{metric.name}', value=metric.get_epoch_value(),
+                                   global_step=epoch, disable_avg=True)
+
+    def batch_forward(self, batch_data, validation=False):
+        batch_data = {'images': batch_data[0], 'points': batch_data[1],
+                      'instances': batch_data[2]}
+        metrics = self.val_metrics if validation else self.train_metrics
+        losses_logging = dict()
+
+        with torch.set_grad_enabled(not validation):
+            batch_data = {k: v.to(self.device) if k != 'points' else v for k, v in batch_data.items()}
+            image, gt_mask, points = batch_data['images'], batch_data['instances'], batch_data['points']
+            orig_gt_mask = gt_mask.clone()
+
+            prev_output = torch.zeros_like(image, dtype=torch.float32)[:, :1, :, :]
+
+            last_click_indx = None
+            # First part
+            with torch.no_grad():
+                num_iters = self.max_num_next_clicks  # max_num_next_clicks is 3 by default
+                '''
+                In this block the net clicks num_iters (max_num_next_clicks) times;
+                prev_output starts as zeros, and the initial points were prepared
+                by the dataset.
+                '''
+                for click_indx in range(num_iters):
+                    last_click_indx = click_indx
+
+                    if not validation:
+                        self.net.eval()
+
+                    if self.click_models is None or click_indx >= len(self.click_models):
+                        eval_model = self.net
+                    else:
+                        eval_model = self.click_models[click_indx]
+
+                    net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image
+                    prev_output = torch.sigmoid(eval_model(net_input, points)['instances'])  # <== this line runs the model
+                    prev_output = torch.max(prev_output, dim=1, keepdim=True)[1]  # per-pixel argmax over the class channels
+                    points = get_next_points(prev_output, orig_gt_mask, points, click_indx + 1)
+
+                if not validation:
+                    self.net.train()
+
+            if self.net.with_prev_mask and self.prev_mask_drop_prob > 0 and last_click_indx is not None:
+                zero_mask = np.random.random(size=prev_output.size(0)) < self.prev_mask_drop_prob
+                prev_output[zero_mask] = torch.zeros_like(prev_output[zero_mask])
+
+            batch_data['points'] = points  # write the points back; if the click loop did not run they are unchanged
+
+            net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image
+            output = self.net(net_input, points)
+
+            loss = 0.0
+            loss = self.add_loss('instance_loss', loss, losses_logging, validation,
+                                 lambda: (output['instances'], batch_data['instances']))
+            loss = self.add_loss('instance_aux_loss', loss, losses_logging, validation,
+                                 lambda: (output['instances_aux'], batch_data['instances']))
+
+            if self.is_master:
+                with torch.no_grad():
+                    for m in metrics:
+                        m.update(*(output.get(x) for x in m.pred_outputs),
+                                 *(batch_data[x] for x in m.gt_outputs))
+        return loss, losses_logging, batch_data, output
+
+    def add_loss(self, loss_name, total_loss, losses_logging, validation, lambda_loss_inputs):
+        loss_cfg = self.loss_cfg if not validation else self.val_loss_cfg
+        loss_weight = loss_cfg.get(loss_name + '_weight', 0.0)
+        if loss_weight > 0.0:
+            loss_criterion = loss_cfg.get(loss_name)
+            loss = loss_criterion(*lambda_loss_inputs())
+            loss = torch.mean(loss)
+            losses_logging[loss_name] = loss
+            loss = loss_weight * loss
+            total_loss = total_loss + loss
+
+        return total_loss
+
+    def save_visualization(self, splitted_batch_data, outputs, global_step, prefix):
+        output_images_path = self.cfg.VIS_PATH / prefix
+        if self.task_prefix:
+            output_images_path /= self.task_prefix
+
+        if not output_images_path.exists():
+            output_images_path.mkdir(parents=True)
+        image_name_prefix = f'{global_step:06d}'
+
+        def _save_image(suffix, image):
+            cv2.imwrite(str(output_images_path / f'{image_name_prefix}_{suffix}.jpg'),
+                        image, [cv2.IMWRITE_JPEG_QUALITY, 85])
+
+        images = splitted_batch_data['images']
+        points = splitted_batch_data['points']
+        instance_masks = splitted_batch_data['instances']
+
+        gt_instance_masks = instance_masks.cpu().numpy()
+        predicted_instance_masks = torch.sigmoid(outputs['instances']).detach().cpu().numpy()
+        points = points.detach().cpu().numpy()
+
+        image_blob, points = images[0], points[0]
+        gt_mask = np.squeeze(gt_instance_masks[0], axis=0)
+        predicted_mask = np.squeeze(predicted_instance_masks[0], axis=0)
+
+        image = image_blob.cpu().numpy() * 255
+        image = image.transpose((1, 2, 0))
+
+        image_with_points = draw_points(image, points[:self.max_interactive_points], (0, 255, 0))
+        image_with_points = draw_points(image_with_points, points[self.max_interactive_points:], (0, 0, 255))
+
+        gt_mask[gt_mask < 0] = 0.25
+        gt_mask = draw_probmap(gt_mask)
+        predicted_mask = draw_probmap(predicted_mask)
+        viz_image = np.hstack((image_with_points, gt_mask, predicted_mask)).astype(np.uint8)
+
+        _save_image('instance_segmentation', viz_image[:, :, ::-1])
+
+    def _load_weights(self, net):
+        if self.cfg.weights is not None:
+            if os.path.isfile(self.cfg.weights):
+                load_weights(net, self.cfg.weights)
+                self.cfg.weights = None
+            else:
+                raise RuntimeError(f"=> no checkpoint found at '{self.cfg.weights}'")
+        elif self.cfg.resume_exp is not None:
+            checkpoints = list(self.cfg.CHECKPOINTS_PATH.glob(f'{self.cfg.resume_prefix}*.pth'))
+            assert len(checkpoints) == 1
+
+            checkpoint_path = checkpoints[0]
+            logger.info(f'Load checkpoint from path: {checkpoint_path}')
+            load_weights(net, str(checkpoint_path))
+        return net
+
+    @property
+    def is_master(self):
+        return self.cfg.local_rank == 0
+
+
+def get_next_points(pred, gt, points, click_indx, points_num=15):
+    """
+    Get point feedback DURING training: sample new clicks from the
+    mislabelled (and not ignored) region of the current prediction.
+
+    Args:
+        pred: predicted class map, shape (batch_size, 1, h, w).
+        gt: ground-truth mask, shape (batch_size, h, w, 1).
+        points: the clicks accumulated so far.
+        click_indx: index of the current click round (must be > 0).
+        points_num (int, optional): clicks sampled per image. Defaults to 15.
+
+    Returns:
+        points: torch.Tensor, [batch_size, num_points, 3]
+    """
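+    # Illustrative: with batch_size = 2 and points_num = 15, up to 30 wrongly
+    # predicted pixels are drawn across the whole batch; each becomes a row
+    # [y, x, gt_class] in its own image's slot of res_points, and slots an
+    # image does not use stay (-1, -1, -1) before being concatenated onto points.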
+    assert click_indx > 0
+    pred_o = pred.cpu().numpy()[:, 0, :, :]
+    gt_o = gt.cpu().numpy()[:, :, :, 0]
+    f_area = np.logical_and(pred_o != gt_o, gt_o != 255)
+    f_area_idx = np.where(f_area)
+    num_samples = min(points_num * gt_o.shape[0], len(f_area_idx[0]))
+    random_indices = np.random.choice(len(f_area_idx[0]), size=num_samples, replace=False)
+    _points = [[f_area_idx[0][x]
+                , f_area_idx[1][x]
+                , f_area_idx[2][x]
+                , gt_o[f_area_idx[0][x], f_area_idx[1][x], f_area_idx[2][x]]
+                ] for x in random_indices]
+    res_points = np.ones((gt_o.shape[0], num_samples, 3)) * (-1)
+    for i, point in enumerate(_points):
+        res_points[point[0], i] = [point[1], point[2], point[3]]
+    points = torch.cat((points, torch.from_numpy(res_points).float().to(points.device)), dim=1)  # TODO: point-number issue
+    return points
+
+
+def get_contours_and_maxidx(cls, gt, pred):
+    t_gt = gt == cls
+    t_pred = pred == cls
+    t_fn = np.logical_and(t_gt, np.logical_not(t_pred)).astype(np.uint8)
+    contours, hierarchy = cv2.findContours(t_fn, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
+    area = []
+    for j in range(len(contours)):
+        area.append(cv2.contourArea(contours[j]))
+    if len(area) == 0:
+        return 0, 0, contours, t_fn
+    else:
+        max_idx = np.argmax(area)
+        area_value = area[max_idx]
+        return area_value, max_idx, contours, t_fn
+
+
+def load_weights(model, path_to_weights):
+    current_state_dict = model.state_dict()
+    new_state_dict = torch.load(path_to_weights, map_location='cpu')['state_dict']
+    current_state_dict.update(new_state_dict)
+    model.load_state_dict(current_state_dict)
diff --git a/isegm/engine/trainer.py b/isegm/engine/trainer.py
index 72a5a92..a6c4d97 100755
--- a/isegm/engine/trainer.py
+++ b/isegm/engine/trainer.py
@@ -381,21 +381,21 @@ def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49):
     fp_mask = np.logical_and(np.logical_not(gt), pred > pred_thresh)
 
     fn_mask = np.pad(fn_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8)
-    fp_mask = np.pad(fp_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8)
+    fp_mask = np.pad(fp_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8)  # an edge padding of one zero pixel for cv2.distanceTransform
 
     num_points = points.size(1) // 2
     points = points.clone()
 
     for bindx in range(fn_mask.shape[0]):
         fn_mask_dt = cv2.distanceTransform(fn_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1]
-        fp_mask_dt = cv2.distanceTransform(fp_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1]
+        fp_mask_dt = cv2.distanceTransform(fp_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1]  # each pixel holds its distance to the nearest zero pixel
 
-        fn_max_dist = np.max(fn_mask_dt)
+        fn_max_dist = np.max(fn_mask_dt)  # the largest such distance
         fp_max_dist = np.max(fp_mask_dt)
 
         is_positive = fn_max_dist > fp_max_dist
         dt = fn_mask_dt if is_positive else fp_mask_dt
-        inner_mask = dt > max(fn_max_dist, fp_max_dist) / 2.0
-        indices = np.argwhere(inner_mask)
+        inner_mask = dt > max(fn_max_dist, fp_max_dist) / 2.0  # keep pixels farther than half the maximum distance from the region border
+        indices = np.argwhere(inner_mask)  # coordinates of the pixels that are True in inner_mask
         if len(indices) > 0:
             coords = indices[np.random.randint(0, len(indices))]
             if is_positive:
diff --git a/isegm/model/is_model.py b/isegm/model/is_model.py
index cf94aef..b003d39 100755
--- a/isegm/model/is_model.py
+++ b/isegm/model/is_model.py
@@ -36,7 +36,7 @@ def __init__(self, with_aux_output=False, norm_radius=5, use_disks=False, cpu_di
         self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0,
                                   cpu_mode=cpu_dist_maps, use_disks=use_disks)
 
-    def forward(self, image, points):
+    def forward(self, image, points):  # points still arrive in (b, 48, 3) form here
         image, prev_mask = self.prepare_input(image)
         coord_features = self.get_coord_features(image, prev_mask, points)
         coord_features = self.maps_transform(coord_features)
@@ -51,6 +51,23 @@ def forward(self, image, points):
         return outputs
 
     def prepare_input(self, image):
+        '''
+        Args:
+            image: the network input, (b, 4, h, w) after the preceding
+                   transforms (3 RGB channels plus the previous mask).
+
+        Returns:
+            image, prev_mask
+
+        Notes:
+            The image is normalized. prev_mask is sliced from the input
+            if with_prev_mask is True; otherwise it is None.
+        '''
         prev_mask = None
         if self.with_prev_mask:
             prev_mask = image[:, 3:, :, :]
diff --git a/isegm/model/is_plainvit_model.py b/isegm/model/is_plainvit_model.py
index cd0147b..bc68247 100644
--- a/isegm/model/is_plainvit_model.py
+++ b/isegm/model/is_plainvit_model.py
@@ -4,6 +4,7 @@
 from .is_model import ISModel
 from .modeling.models_vit import VisionTransformer, PatchEmbed
 from .modeling.swin_transformer import SwinTransfomerSegHead
+from .ops import DistMaps, New_DistMaps
 
 
 class SimpleFPN(nn.Module):
@@ -93,3 +94,58 @@ def backbone_forward(self, image, coord_features=None):
         multi_scale_features = self.neck(backbone_features)
 
         return {'instances': self.head(multi_scale_features), 'instances_aux': None}
+
+class MultiOutVitModel(ISModel):
+    @serialize
+    def __init__(
+        self,
+        backbone_params={},
+        neck_params={},
+        head_params={},
+        random_split=False,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.random_split = random_split
+
+        self.patch_embed_coords = PatchEmbed(
+            img_size=backbone_params['img_size'],
+            patch_size=backbone_params['patch_size'],
+            in_chans=20 if self.with_prev_mask else 2,
+            embed_dim=backbone_params['embed_dim'],
+        )
+
+        self.backbone = VisionTransformer(**backbone_params)
+        self.neck = SimpleFPN(**neck_params)
+        self.head = SwinTransfomerSegHead(**head_params)
+        self.testMsg = "Test Message"
+        self.dist_maps = New_DistMaps(norm_radius=5, spatial_scale=1.0,
+                                      cpu_mode=False, use_disks=False)
+
+    def backbone_forward(self, image, coord_features=None):
+        coord_features = self.patch_embed_coords(coord_features)
+        backbone_features = self.backbone.forward_backbone(image, coord_features, self.random_split)
+
+        # Extract 4-stage backbone feature maps: 1/4, 1/8, 1/16, 1/32
+        B, N, C = backbone_features.shape
+        grid_size = self.backbone.patch_embed.grid_size
+
+        backbone_features = backbone_features.transpose(-1, -2).view(B, C, grid_size[0], grid_size[1])
+        multi_scale_features = self.neck(backbone_features)
+
+        return {'instances': self.head(multi_scale_features), 'instances_aux': None}
+
+    def forward(self, image, points):  # points still arrive in (b, 48, 3) form here
+        image, prev_mask = self.prepare_input(image)
+        coord_features = self.get_coord_features(image, prev_mask, points)
+        # coord_features = self.maps_transform(coord_features)
+        outputs = self.backbone_forward(image, coord_features)
+
+        outputs['instances'] = nn.functional.interpolate(outputs['instances'], size=image.size()[2:],
+                                                         mode='bilinear', align_corners=True)
+        if self.with_aux_output:
+            outputs['instances_aux'] = nn.functional.interpolate(outputs['instances_aux'], size=image.size()[2:],
+                                                                 mode='bilinear', align_corners=True)
+
+        return outputs
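+
+# Rough shape flow for the 448x448 config used in the test scripts (illustrative):
+#   input (B, 4, 448, 448) -> prepare_input -> RGB (B, 3, 448, 448) + prev_mask (B, 1, 448, 448)
+#   clicks + prev_mask -> coord features -> patch_embed_coords -> patch tokens
+#   ViT backbone + SimpleFPN -> multi-scale features -> SwinTransfomerSegHead
+#   -> 'instances' upsampled back to (B, num_classes, 448, 448).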
diff --git a/isegm/model/losses.py b/isegm/model/losses.py
index 38c4ee3..f0e3b24 100755
--- a/isegm/model/losses.py
+++ b/isegm/model/losses.py
@@ -27,7 +27,7 @@ def __init__(self, axis=-1, alpha=0.25, gamma=2, max_mult=-1, eps=1e-12,
         self._k_sum = 0
         self._m_max = 0
 
-    def forward(self, pred, label):
+    def forward(self, pred, label):  # TODO Error here
         one_hot = label > 0.5
         sample_weight = label != self._ignore_label
 
@@ -74,6 +74,75 @@ def log_states(self, sw, name, global_step):
         sw.add_scalar(tag=name + '_m', value=self._m_max, global_step=global_step)
 
 
+class NormalizedMultiFocalLossSigmoid(nn.Module):
+    def __init__(self, axis=-1, alpha=0.25, gamma=2, max_mult=-1, eps=1e-12,
+                 from_sigmoid=False, detach_delimeter=True,
+                 batch_axis=0, weight=None, size_average=True,
+                 ignore_label=-1):
+        super(NormalizedMultiFocalLossSigmoid, self).__init__()
+        self._axis = axis
+        self._alpha = alpha
+        self._gamma = gamma
+        self._ignore_label = ignore_label
+        self._weight = weight if weight is not None else 1.0
+        self._batch_axis = batch_axis
+
+        self._from_logits = from_sigmoid
+        self._eps = eps
+        self._size_average = size_average
+        self._detach_delimeter = detach_delimeter
+        self._max_mult = max_mult
+        self._k_sum = 0
+        self._m_max = 0
+
+    def forward(self, pred, label):  # TODO Error here
+        label = torch.squeeze(label).to(torch.int64)
+        if len(label.shape) == 2:
+            label = label.unsqueeze(0)
+        N, H, W = label.shape
+        C = pred.shape[1]
+        valid_mask = label != self._ignore_label  # must be taken before the remap below, otherwise every pixel counts as valid
+        label[label == self._ignore_label] = C  # map ignored pixels to the extra class C
+        label_one_hot = F.one_hot(label, num_classes=C + 1)[:, :, :, :C].permute(0, 3, 1, 2).float()
+        sample_weight = valid_mask.float()
+        sample_weight_expanded = sample_weight.unsqueeze(1)
+        sample_weight_expanded = sample_weight_expanded.expand(-1, C, -1, -1)
+
+        if not self._from_logits:
+            # for the multi-class case, convert raw scores to softmax probabilities
+            pred = F.softmax(pred, dim=1)
+
+        alpha = torch.where(label_one_hot == 1, self._alpha * sample_weight_expanded,
+                            (1 - self._alpha) * sample_weight_expanded)
+        pt = torch.where(label_one_hot == 1, pred, 1 - pred)
+
+        beta = (1 - pt) ** self._gamma
+
+        sw_sum = torch.sum(sample_weight_expanded, dim=[0, 2, 3], keepdim=True)
+        beta_sum = torch.sum(beta, dim=[0, 2, 3], keepdim=True)
+        mult = sw_sum / (beta_sum + self._eps)
+        if self._detach_delimeter:
+            mult = mult.detach()
+        beta = beta * mult
+
+        # clip beta
+        if self._max_mult > 0:
+            beta = torch.clamp(beta, max=self._max_mult)
+
+        # compute the focal loss
+        loss = -alpha * beta * torch.log(torch.clamp(pt, min=self._eps))
+
+        # apply the sample weights
+        loss = loss * sample_weight_expanded
+
+        # aggregate the loss
+        if self._size_average:
+            loss = loss.sum(dim=[0, 2, 3]) / (sw_sum.sum(dim=[0, 2, 3]) + self._eps)
+        else:
+            loss = loss.sum(dim=[0, 2, 3])
+        return loss.mean()
+
+    def log_states(self, sw, name, global_step):
+        sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step)
+        sw.add_scalar(tag=name + '_m', value=self._m_max, global_step=global_step)
+
+
 class FocalLoss(nn.Module):
     def __init__(self, axis=-1, alpha=0.25, gamma=2, from_logits=False,
                  batch_axis=0,
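A quick sanity check for NormalizedMultiFocalLossSigmoid as defined above: random 7-class logits against integer labels that contain an ignored band. The shapes and the 7-class count are illustrative assumptions:

import torch
from isegm.model.losses import NormalizedMultiFocalLossSigmoid

loss_fn = NormalizedMultiFocalLossSigmoid(alpha=0.5, gamma=2, ignore_label=255)
pred = torch.randn(2, 7, 64, 64)             # (N, C, H, W) raw logits
label = torch.randint(0, 7, (2, 1, 64, 64))  # (N, 1, H, W) class indices
label[:, :, :8, :] = 255                     # a band of ignored pixels
print(loss_fn(pred, label))                  # a scalar tensor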
diff --git a/isegm/model/metrics.py b/isegm/model/metrics.py
index a572dcd..d8c3531 100755
--- a/isegm/model/metrics.py
+++ b/isegm/model/metrics.py
@@ -26,6 +26,89 @@ def name(self):
         return type(self).__name__
 
 
+class AdaptiveMIoU(TrainMetric):
+    def __init__(self, num_classes, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9,
+                 ignore_label=-1, from_logits=True,
+                 pred_output='instances', gt_output='instances'):
+        super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,))
+        self._ignore_label = ignore_label
+        self._from_logits = from_logits
+        self._iou_thresh = [init_thresh] * num_classes
+        self._thresh_step = thresh_step
+        self._thresh_beta = thresh_beta
+        self._iou_beta = iou_beta
+        self._ema_iou = np.zeros(num_classes)
+        self._epoch_iou_sum = np.zeros(num_classes)
+        self._epoch_batch_count = 0
+        self.n_classes = num_classes
+        self.ignore_id = ignore_label
+        self.confusion_matrix = np.zeros((num_classes + 1, num_classes + 1), dtype=np.int64)
+        self.img_count = 0
+        self.ob_count = 0
+        self.click_count = 0
+
+    def update(self, label_preds, label_trues):
+        label_preds = torch.argmax(label_preds, dim=1)
+        label_trues = label_trues.squeeze()
+        if label_trues.dim() == 2:
+            label_trues = label_trues.unsqueeze(0)
+        for lt, lp in zip(label_trues, label_preds):
+            lt_ = lt.detach().cpu().numpy()
+            lp_ = lp.detach().cpu().numpy()
+            self.confusion_matrix += self._fast_hist(lt_.flatten(), lp_.flatten())
+        self._epoch_batch_count += 1
+
+    def _fast_hist(self, label_true, label_pred):
+        # remap ignored ground-truth pixels to the extra class n_classes;
+        # the mask below then drops them so they never enter the histogram
+        label_true = np.where(label_true == self.ignore_id, self.n_classes, label_true)
+        mask = (label_true >= 0) & (label_true < self.n_classes)
+        hist = np.bincount(
+            (self.n_classes + 1) * label_true[mask].astype(int) + label_pred[mask],
+            minlength=(self.n_classes + 1) * (self.n_classes + 1),
+        ).reshape(self.n_classes + 1, self.n_classes + 1)
+        return hist
+
+    def get_epoch_value(self):
+        hist = self.confusion_matrix
+        acc = np.diag(hist).sum() / hist.sum()
+        acc_cls = np.diag(hist) / hist.sum(axis=1)
+        acc_cls = np.nanmean(acc_cls)
+        iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
+        iu = iu[:-1]
+        mean_iu = np.nanmean(iu)
+        return mean_iu
+
+    def reset_epoch_stats(self):
+        self.confusion_matrix = np.zeros((self.n_classes + 1, self.n_classes + 1), dtype=np.int64)
+        self._epoch_batch_count = 0
+
+    def log_states(self, sw, tag_prefix, global_step):
+        hist = self.confusion_matrix
+        # acc = np.diag(hist).sum() / hist.sum()
+        # acc_cls = np.diag(hist) / hist.sum(axis=1)
+        # acc_cls = np.nanmean(acc_cls)
+        iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
+        iu = iu[:-1]
+        mean_iu = np.nanmean(iu)
+        cls_iu = dict(zip(range(self.n_classes), iu))
+        sw.add_scalar(tag=f'{tag_prefix}_miou', value=mean_iu, global_step=global_step)
+        for cls in range(self.n_classes):
+            sw.add_scalar(tag=f'{tag_prefix}_class_{cls}_iou', value=cls_iu[cls], global_step=global_step)
+
+    def get_miou(self, label_trues, label_preds):
+        hist = np.zeros(self.confusion_matrix.shape)
+        for lt, lp in zip(label_trues, label_preds):
+            hist += self._fast_hist(lt.flatten(), lp.flatten())
+        acc_cls = np.diag(hist) / hist.sum(axis=1)
+        iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
+        iu = iu[:-1]
+        mean_iu = np.nanmean(iu)
+        cls_iu = dict(zip(range(self.n_classes), iu))
+        return {
+            "Mean IoU": mean_iu,
+            "Class IoU": cls_iu,
+        }
+
+
 class AdaptiveIoU(TrainMetric):
     def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9,
                  ignore_label=-1, from_logits=True,
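AdaptiveMIoU above accumulates an (n+1)x(n+1) confusion matrix, reserving the extra row/column for ignored labels, and reports the mean IoU over the real classes (the `iu[:-1]` slice). A usage sketch with illustrative shapes and a 7-class setup:

import torch
from isegm.model.metrics import AdaptiveMIoU

metric = AdaptiveMIoU(num_classes=7, ignore_label=255)
pred = torch.randn(2, 7, 64, 64)           # logits; argmax happens inside update()
gt = torch.randint(0, 7, (2, 1, 64, 64))   # integer label maps
metric.update(pred, gt)
print(metric.get_epoch_value())            # mean IoU over the 7 classes
metric.reset_epoch_stats()                 # call between epochs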
diff --git a/isegm/model/ops.py b/isegm/model/ops.py
index 9be9c73..dc3b870 100755
--- a/isegm/model/ops.py
+++ b/isegm/model/ops.py
@@ -54,25 +54,25 @@ def get_coord_features(self, points, batchsize, rows, cols):
                                                    norm_delimeter))
             coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float()
         else:
-            num_points = points.shape[1] // 2
-            points = points.view(-1, points.size(2))
+            num_points = points.shape[1] // 2  # 24
+            points = points.view(-1, points.size(2))  # (b, 48, 3) -> (b*48, 3)
             points, points_order = torch.split(points, [2, 1], dim=1)
 
-            invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0
+            invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0  # boolean tensor of shape (b*48,), e.g. (96,)
             row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device)
             col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device)
 
             coord_rows, coord_cols = torch.meshgrid(row_array, col_array)
-            coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1)
+            coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1)  # (b*48, 2, rows, cols), e.g. (96, 2, 448, 448)
 
             add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1)
             coords.add_(-add_xy)
             if not self.use_disks:
                 coords.div_(self.norm_radius * self.spatial_scale)
-            coords.mul_(coords)
+            coords.mul_(coords)  # squared per-axis offsets, (b*48, 2, h, w)
 
             coords[:, 0] += coords[:, 1]
-            coords = coords[:, :1]
+            coords = coords[:, :1]  # up to here, coords holds the squared distance from each pixel to its point
 
             coords[invalid_points, :, :, :] = 1e6
@@ -90,6 +90,69 @@ def get_coord_features(self, points, batchsize, rows, cols):
     def forward(self, x, coords):
         return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3])
 
+
+class New_DistMaps(nn.Module):
+    def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False, use_disks=False):
+        super(New_DistMaps, self).__init__()
+        self.spatial_scale = spatial_scale
+        self.norm_radius = norm_radius
+        self.cpu_mode = cpu_mode
+        self.use_disks = use_disks
+        if self.cpu_mode:
+            from isegm.utils.cython import get_dist_maps
+            self._get_dist_maps = get_dist_maps
+
+    def get_coord_features(self, points, batchsize, rows, cols):
+        if self.cpu_mode:
+            # NOTE: this branch sets coords but never res, so the return below would fail in cpu_mode
+            coords = []
+            for i in range(batchsize):
+                norm_delimeter = 1.0 if self.use_disks else self.spatial_scale * self.norm_radius
+                coords.append(self._get_dist_maps(points[i].cpu().float().numpy(), rows, cols,
+                                                  norm_delimeter))
+            coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float()
+        else:
+            # mask = ~(points == torch.tensor([-1, -1, -1], device=points.device)).all(dim=-1)
+            # points = points[mask]
+            # points = points.unsqueeze(0)
+            point_num = points.shape[1]
+            if point_num == 0:
+                zeros = torch.zeros(1, 1, rows, cols, device=points.device)
+                res = zeros.repeat(1, 19, 1, 1)  # TODO magic number: 19 classes
+            else:
+                points = points.view(-1, points.size(2))
+                points, points_cls = torch.split(points, [2, 1], dim=1)
+                invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0
+                row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device='cuda')
+                col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device='cuda')
+                coord_rows, coord_cols = torch.meshgrid(row_array, col_array)
+
+                coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1)
+
+                add_xy = ((points * 1).view(points.size(0), points.size(1), 1, 1)).to('cuda')
+                coords.add_(-add_xy)
+                coords.mul_(coords)  # squared per-axis offsets, (b*n, 2, h, w)
+                coords[:, 0] += coords[:, 1]
+                coords = coords[:, :1]
+                coords[invalid_points, :, :, :] = 1e6
+                coords = coords.view(-1, point_num, 1, rows, cols)
+
+                coords = coords.view(-1, 1, rows, cols)
+
+                coords = (coords <= (5) ** 2).float()  # disk of radius 5 around each click; hardcoded instead of self.norm_radius
+
+                zeros = torch.zeros_like(coords)
+                res = zeros.repeat(1, 19, 1, 1)  # TODO magic number: 19 classes
+
+                for i in range(len(points_cls)):
+                    cls = int(points_cls[i])
+                    res[i, cls] = coords[i, 0]
+                res = res.view(-1, point_num, 19, rows, cols)  # TODO magic number: 19 classes
+                res = res.max(dim=1)[0]
+                res = res.to('cuda')
+        return res * 255
+
+    def forward(self, x, coords):
+        return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3])
+
 
 class ScaleLayer(nn.Module):
     def __init__(self, init_value=1.0, lr_mult=1):
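New_DistMaps above turns clicks given as (row, col, class_id) triples into one binary disk map per class, so a click only excites its own class channel; padded entries of -1 are suppressed via the 1e6 fill. A usage sketch (illustrative values; it needs a CUDA device because the implementation above pins its tensors to 'cuda'):

import torch
from isegm.model.ops import New_DistMaps

dist_maps = New_DistMaps(norm_radius=5, spatial_scale=1.0, cpu_mode=False, use_disks=False)
points = torch.tensor([[[100., 120., 3.],    # a click of class 3 at (100, 120)
                        [-1., -1., -1.]]],   # padding entry, ignored
                      device='cuda')
maps = dist_maps.get_coord_features(points, batchsize=1, rows=448, cols=448)
print(maps.shape)  # torch.Size([1, 19, 448, 448]); disks are scaled to 255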
diff --git a/models/iter_mask/multi_out_huge448_pascal_itermask.py b/models/iter_mask/multi_out_huge448_pascal_itermask.py
new file mode 100644
index 0000000..acfd493
--- /dev/null
+++ b/models/iter_mask/multi_out_huge448_pascal_itermask.py
@@ -0,0 +1,132 @@
+from isegm.data.datasets.cityscapes import CityScapes
+from isegm.data.points_sampler import MultiClassSampler
+from isegm.engine.Multi_trainer import Multi_trainer
+from isegm.inference.clicker import Click
+from isegm.model.is_plainvit_model import MultiOutVitModel
+from isegm.model.metrics import AdaptiveMIoU
+from isegm.utils.exp_imports.default import *
+from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss
+
+MODEL_NAME = 'cocolvis_vit_huge448'
+
+
+def main(cfg):
+    model, model_cfg = init_model(cfg)
+    train(model, cfg, model_cfg)
+
+
+def init_model(cfg):
+    model_cfg = edict()
+    model_cfg.crop_size = (448, 448)
+    model_cfg.num_max_points = 24
+
+    backbone_params = dict(
+        img_size=model_cfg.crop_size,
+        patch_size=(14, 14),
+        in_chans=3,
+        embed_dim=1280,
+        depth=32,
+        num_heads=16,
+        mlp_ratio=4,
+        qkv_bias=True,
+    )
+
+    neck_params = dict(
+        in_dim=1280,
+        out_dims=[240, 480, 960, 1920],
+    )
+
+    head_params = dict(
+        in_channels=[240, 480, 960, 1920],
+        in_index=[0, 1, 2, 3],
+        dropout_ratio=0.1,
+        num_classes=19,
+        loss_decode=CrossEntropyLoss(),
+        align_corners=False,
+        upsample=cfg.upsample,
+        channels={'x1': 256, 'x2': 128, 'x4': 64}[cfg.upsample],
+    )
+
+    model = MultiOutVitModel(
+        use_disks=True,
+        norm_radius=5,
+        with_prev_mask=True,
+        backbone_params=backbone_params,
+        neck_params=neck_params,
+        head_params=head_params,
+        random_split=cfg.random_split,
+    )
+
+    model.backbone.init_weights_from_pretrained(cfg.IMAGENET_PRETRAINED_MODELS.MAE_HUGE)
+    model.to(cfg.device)
+
+    return model, model_cfg
+
+
+def train(model, cfg, model_cfg):
+    cfg.batch_size = 32 if cfg.batch_size < 1 else cfg.batch_size
+    cfg.val_batch_size = cfg.batch_size
+    crop_size = model_cfg.crop_size
+
+    loss_cfg = edict()
+    loss_cfg.instance_loss = NormalizedMultiFocalLossSigmoid(alpha=0.5, gamma=2, ignore_label=255)
+    loss_cfg.instance_loss_weight = 1.0
+
+    train_augmentator = Compose([
+        UniformRandomResize(scale_range=(0.75, 1.40)),
+        HorizontalFlip(),
+        PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0),
+        RandomCrop(*crop_size),
+        RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75),
+        RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75)
+    ], p=1.0)
+
+    val_augmentator = Compose([
+        PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0),
+        RandomCrop(*crop_size)
+    ], p=1.0)
+
+    points_sampler = MultiClassSampler(max_num_points=100, prob_gamma=0.80,
+                                       merge_objects_prob=0.15,
+                                       max_num_merged_objects=2)
+
+    trainset = CityScapes(
+        cfg.CITYSCAPES_PATH,
+        first_return_points=cfg.first_return_points,
+        split='train',
+        augmentator=train_augmentator,
+        min_object_area=1000,
+        keep_background_prob=0.05,
+        points_sampler=points_sampler,
+        epoch_len=30000,
+        # stuff_prob=0.30
+    )
+
+    valset = CityScapes(
+        cfg.CITYSCAPES_PATH,
+        first_return_points=cfg.first_return_points,
+        split='val',
+        augmentator=val_augmentator,
+        min_object_area=1000,
+        points_sampler=points_sampler,
+        epoch_len=2000
+    )
+
+    optimizer_params = {
+        'lr': 5e-5, 'betas': (0.9, 0.999), 'eps': 1e-8
+    }
+
+    lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR,
+                           milestones=[50, 55], gamma=0.1)
+    trainer = Multi_trainer(model, cfg, model_cfg, loss_cfg,
+                            trainset, valset,
+                            optimizer='adam',
+                            optimizer_params=optimizer_params,
+                            layerwise_decay=cfg.layerwise_decay,
+                            lr_scheduler=lr_scheduler,
+                            checkpoint_interval=[(0, 20), (50, 1)],
+                            image_dump_interval=300,
+                            metrics=[AdaptiveMIoU(num_classes=19, ignore_label=255)],
+                            max_interactive_points=model_cfg.num_max_points,
+                            max_num_next_clicks=cfg.max_next_clicks)
+    trainer.run(num_epochs=55, validation=False)
\ No newline at end of file
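One detail of the schedule in this config: with trainer.run(num_epochs=55) the epochs run 0..54, so the MultiStepLR milestone at 50 fires once and the milestone at 55 only lands after the last epoch. A standalone sketch of the effective rates (the dummy parameter exists only to build an optimizer):

import torch

opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))],
                       lr=5e-5, betas=(0.9, 0.999), eps=1e-8)
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[50, 55], gamma=0.1)
for epoch in range(55):
    if epoch in (0, 49, 50, 54):
        print(epoch, sched.get_last_lr())  # 5e-5 through epoch 49, 5e-6 from epoch 50 on
    opt.step()
    sched.step()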
diff --git a/models/iter_mask/multi_out_huge448_pascal_itermask_tst.py b/models/iter_mask/multi_out_huge448_pascal_itermask_tst.py
new file mode 100644
index 0000000..3271c8c
--- /dev/null
+++ b/models/iter_mask/multi_out_huge448_pascal_itermask_tst.py
@@ -0,0 +1,126 @@
+from isegm.engine.Multi_trainer import Multi_trainer
+from isegm.inference.clicker import Click
+from isegm.utils.exp_imports.default import *
+from isegm.model.modeling.transformer_helper.cross_entropy_loss import CrossEntropyLoss
+
+MODEL_NAME = 'cocolvis_vit_huge448'
+
+
+def main(cfg):
+    model, model_cfg = init_model(cfg)
+    train(model, cfg, model_cfg)
+
+
+def init_model(cfg):
+    model_cfg = edict()
+    model_cfg.crop_size = (448, 448)
+    model_cfg.num_max_points = 24
+
+    backbone_params = dict(
+        img_size=model_cfg.crop_size,
+        patch_size=(14, 14),
+        in_chans=3,
+        embed_dim=1280,
+        depth=32,
+        num_heads=16,
+        mlp_ratio=4,
+        qkv_bias=True,
+    )
+
+    neck_params = dict(
+        in_dim=1280,
+        out_dims=[240, 480, 960, 1920],
+    )
+
+    head_params = dict(
+        in_channels=[240, 480, 960, 1920],
+        in_index=[0, 1, 2, 3],
+        dropout_ratio=0.1,
+        num_classes=1,
+        loss_decode=CrossEntropyLoss(),
+        align_corners=False,
+        upsample=cfg.upsample,
+        channels={'x1': 256, 'x2': 128, 'x4': 64}[cfg.upsample],
+    )
+
+    model = PlainVitModel(
+        use_disks=True,
+        norm_radius=5,
+        with_prev_mask=True,
+        backbone_params=backbone_params,
+        neck_params=neck_params,
+        head_params=head_params,
+        random_split=cfg.random_split,
+    )
+
+    model.backbone.init_weights_from_pretrained(cfg.IMAGENET_PRETRAINED_MODELS.MAE_HUGE)
+    model.to(cfg.device)
+
+    return model, model_cfg
+
+
+def train(model, cfg, model_cfg):
+    cfg.batch_size = 32 if cfg.batch_size < 1 else cfg.batch_size
+    cfg.val_batch_size = cfg.batch_size
+    crop_size = model_cfg.crop_size
+
+    loss_cfg = edict()
+    loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2)
+    loss_cfg.instance_loss_weight = 1.0
+
+    train_augmentator = Compose([
+        UniformRandomResize(scale_range=(0.75, 1.40)),
+        HorizontalFlip(),
+        PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0),
+        RandomCrop(*crop_size),
+        RandomBrightnessContrast(brightness_limit=(-0.25, 0.25), contrast_limit=(-0.15, 0.4), p=0.75),
+        RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.75)
+    ], p=1.0)
+
+    val_augmentator = Compose([
+        PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0),
+        RandomCrop(*crop_size)
+    ], p=1.0)
+
+    points_sampler = MultiPointSampler(10, prob_gamma=0.80,
+                                       merge_objects_prob=0.15,
+                                       max_num_merged_objects=2)
+
+    trainset = PASCAL(
+        cfg.PASCAL_PATH,
+        split='train',
+        augmentator=train_augmentator,
+        min_object_area=1000,
+        keep_background_prob=0.05,
+        points_sampler=points_sampler,
+        epoch_len=30000,
+        # stuff_prob=0.30
+    )
+
+    valset = PASCAL(
+        cfg.PASCAL_PATH,
+        split='val',
+        augmentator=val_augmentator,
+        min_object_area=1000,
+        points_sampler=points_sampler,
+        epoch_len=2000
+    )
+
+    optimizer_params = {
+        'lr': 5e-5, 'betas': (0.9, 0.999), 'eps': 1e-8
+    }
+
+    lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR,
+                           milestones=[50, 55], gamma=0.1)
+    trainer = ISTrainer(model, cfg, model_cfg, loss_cfg,
+                        trainset, valset,
+                        optimizer='adam',
+                        optimizer_params=optimizer_params,
+                        layerwise_decay=cfg.layerwise_decay,
+                        lr_scheduler=lr_scheduler,
+                        checkpoint_interval=[(0, 20), (50, 1)],
+                        image_dump_interval=300,
+                        metrics=[AdaptiveIoU()],
+                        max_interactive_points=model_cfg.num_max_points,
+                        max_num_next_clicks=3)
+    trainer.run(num_epochs=55, validation=False)
\ No newline at end of file
diff --git a/testcnnect.py b/testcnnect.py
new file mode 100644
index 0000000..5f2adf3
--- /dev/null
+++ b/testcnnect.py
@@ -0,0 +1,8 @@
+import os
+print("Import cv2")
+import matplotlib
+import cv2
+print(os.getcwd())
+# print file path
+print(os.path.abspath(__file__))
+print("Done!")
\ No newline at end of file
diff --git a/train.py b/train.py
index b2f56c8..efb255c 100755
--- a/train.py
+++ b/train.py
@@ -76,6 +76,8 @@ def parse_args():
     parser.add_argument('--random-split', action='store_true',
                         help='random split the patch instead of window split.')
 
+    parser.add_argument('--first-return-points', type=str, default='blank', choices=["init", "random", "blank"])
+    parser.add_argument('--max-next-clicks', type=int, default=1)
 
     return parser.parse_args()
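The two flags added above are consumed by the Cityscapes config earlier in this diff: --first-return-points is forwarded to the datasets as first_return_points (presumably selecting how the first interaction round is encoded: initial clicks, random clicks, or a blank map, judging by the choices), and --max-next-clicks becomes the trainer's max_num_next_clicks. An invocation might look like this (the exact flag combination is an assumption based on the existing train.py interface):

python train.py models/iter_mask/multi_out_huge448_pascal_itermask.py --gpus 0 --first-return-points init --max-next-clicks 3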