import argparse
import os
import pprint
import shutil
import sys
import logging
import time
import timeit
from pathlib import Path

import numpy as np

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim
from torch.utils.data.distributed import DistributedSampler
from tensorboardX import SummaryWriter

import _init_paths
import models
import datasets
from config import config
from config import update_config
from core.criterion import CrossEntropy, OhemCrossEntropy
from core.function import train, validate
from utils.modelsummary import get_model_summary
from utils.utils import create_logger, FullModel, get_rank


def parse_args():
    """Parse command-line arguments and merge them into the global ``config``.

    Returns:
        argparse.Namespace with ``cfg``, ``local_rank``, ``opts`` and ``freq``.
    """
    parser = argparse.ArgumentParser(description='Train segmentation network')

    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        required=True,
                        type=str)
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    # Validation interval in epochs; used as a modulus in main(), so it must
    # be a positive integer.
    parser.add_argument('--freq', type=int, default=1)

    args = parser.parse_args()
    if args.freq < 1:
        parser.error('--freq must be a positive integer')
    update_config(config, args)

    return args


def _resolve_attr(root, dotted_name):
    """Follow a dotted attribute path (e.g. ``'seg_hrnet.get_seg_model'``) from ``root``.

    Used instead of ``eval('<root>.' + name)`` so that config strings are only
    ever treated as attribute names, never evaluated as Python expressions.
    """
    obj = root
    for part in dotted_name.split('.'):
        obj = getattr(obj, part)
    return obj


def _make_loader(dataset, batch_size, distributed, shuffle, pin_memory,
                 drop_last=False):
    """Build a DataLoader for ``dataset``; returns ``(loader, sampler)``.

    When ``distributed`` is true a ``DistributedSampler`` is attached (and the
    loader's own shuffling is disabled, since the two are mutually exclusive);
    otherwise ``sampler`` is ``None``.
    """
    sampler = DistributedSampler(dataset) if distributed else None
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle and sampler is None,
        num_workers=config.WORKERS,
        pin_memory=pin_memory,
        drop_last=drop_last,
        sampler=sampler)
    return loader, sampler


def main():
    """Train the segmentation model described by the global ``config``.

    Builds the model, datasets, loss and SGD optimizer, optionally resumes
    from ``checkpoint.pth.tar``, then runs END_EPOCH regular epochs followed by
    EXTRA_EPOCH epochs on the extra training set. Rank 0 writes the
    checkpoint, ``best.pth`` and ``final_state.pth`` into the output directory.
    """
    args = parse_args()

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(config)

    writer_dict = {
        'writer': SummaryWriter(tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    gpus = list(config.GPUS)
    # NOTE(review): true for any non-empty GPU list. The model is wrapped in
    # DistributedDataParallel unconditionally below, which relies on the
    # process group initialised only when this flag is set.
    distributed = len(gpus) >= 1
    device = torch.device('cuda:{}'.format(args.local_rank))

    # build model
    model = _resolve_attr(models, config.MODEL.NAME + '.get_seg_model')(config)

    if args.local_rank == 0:
        # provide the summary of model
        dump_input = torch.rand(
            (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
        )
        logger.info(get_model_summary(model.to(device), dump_input.to(device)))

        # copy model file so the run directory records the code that produced it
        this_dir = os.path.dirname(__file__)
        models_dst_dir = os.path.join(final_output_dir, 'models')
        if os.path.exists(models_dst_dir):
            shutil.rmtree(models_dst_dir)
        shutil.copytree(os.path.join(this_dir, '../lib/models'), models_dst_dir)

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://",
        )

    # prepare data
    dataset_cls = _resolve_attr(datasets, config.DATASET.DATASET)
    # config IMAGE_SIZE is (width, height); datasets take (height, width).
    crop_size = (config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
    train_dataset = dataset_cls(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TRAIN_SET,
        num_samples=None,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=config.TRAIN.MULTI_SCALE,
        flip=config.TRAIN.FLIP,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TRAIN.BASE_SIZE,
        crop_size=crop_size,
        downsample_rate=config.TRAIN.DOWNSAMPLERATE,
        scale_factor=config.TRAIN.SCALE_FACTOR)

    trainloader, train_sampler = _make_loader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        distributed=distributed,
        shuffle=config.TRAIN.SHUFFLE,
        pin_memory=True,
        drop_last=True)

    extra_trainloader = None
    extra_train_sampler = None
    if config.DATASET.EXTRA_TRAIN_SET:
        extra_train_dataset = dataset_cls(
            root=config.DATASET.ROOT,
            list_path=config.DATASET.EXTRA_TRAIN_SET,
            num_samples=None,
            num_classes=config.DATASET.NUM_CLASSES,
            multi_scale=config.TRAIN.MULTI_SCALE,
            flip=config.TRAIN.FLIP,
            ignore_label=config.TRAIN.IGNORE_LABEL,
            base_size=config.TRAIN.BASE_SIZE,
            crop_size=crop_size,
            downsample_rate=config.TRAIN.DOWNSAMPLERATE,
            scale_factor=config.TRAIN.SCALE_FACTOR)

        extra_trainloader, extra_train_sampler = _make_loader(
            extra_train_dataset,
            batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
            distributed=distributed,
            shuffle=config.TRAIN.SHUFFLE,
            pin_memory=True,
            drop_last=True)

    test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
    test_dataset = dataset_cls(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TEST_SET,
        num_samples=config.TEST.NUM_SAMPLES,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=False,
        flip=False,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TEST.BASE_SIZE,
        crop_size=test_size,
        center_crop_test=config.TEST.CENTER_CROP_TEST,
        downsample_rate=1)

    testloader, _ = _make_loader(
        test_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        distributed=distributed,
        shuffle=False,
        pin_memory=False)

    # criterion
    if config.LOSS.USE_OHEM:
        criterion = OhemCrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                     thres=config.LOSS.OHEMTHRES,
                                     min_kept=config.LOSS.OHEMKEEP,
                                     weight=train_dataset.class_weights)
    else:
        criterion = CrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                 weight=train_dataset.class_weights)

    model = FullModel(model, criterion)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(device)
    model = nn.parallel.DistributedDataParallel(
        model, device_ids=[args.local_rank], output_device=args.local_rank)

    # optimizer
    if config.TRAIN.OPTIMIZER == 'sgd':
        optimizer = torch.optim.SGD([{'params':
                                      filter(lambda p: p.requires_grad,
                                             model.parameters()),
                                      'lr': config.TRAIN.LR}],
                                    lr=config.TRAIN.LR,
                                    momentum=config.TRAIN.MOMENTUM,
                                    weight_decay=config.TRAIN.WD,
                                    nesterov=config.TRAIN.NESTEROV,
                                    )
    else:
        raise ValueError('Only Support SGD optimizer')

    # Iterations per epoch per process (plain int: np.int was removed in NumPy 1.24).
    epoch_iters = len(train_dataset) // (
        config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus))

    best_mIoU = 0
    last_epoch = 0
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir,
                                        'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            # Load tensors onto the CPU storage they were saved from.
            checkpoint = torch.load(model_state_file,
                                    map_location=lambda storage, loc: storage)
            best_mIoU = checkpoint['best_mIoU']
            last_epoch = checkpoint['epoch']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})"
                        .format(checkpoint['epoch']))

    start = timeit.default_timer()
    end_epoch = config.TRAIN.END_EPOCH + config.TRAIN.EXTRA_EPOCH
    num_iters = config.TRAIN.END_EPOCH * epoch_iters
    extra_iters = config.TRAIN.EXTRA_EPOCH * epoch_iters

    # Fail fast instead of hitting an undefined loader after all regular epochs.
    if (extra_trainloader is None
            and end_epoch > max(last_epoch, config.TRAIN.END_EPOCH)):
        raise ValueError(
            'TRAIN.EXTRA_EPOCH > 0 requires DATASET.EXTRA_TRAIN_SET to be set')

    checkpoint_path = os.path.join(final_output_dir, 'checkpoint.pth.tar')

    for epoch in range(last_epoch, end_epoch):
        if epoch >= config.TRAIN.END_EPOCH:
            # Reseed the sampler of the loader actually used this epoch so
            # each extra epoch sees a different shuffle.
            if extra_train_sampler is not None:
                extra_train_sampler.set_epoch(epoch)
            train(config, epoch - config.TRAIN.END_EPOCH,
                  config.TRAIN.EXTRA_EPOCH, epoch_iters,
                  config.TRAIN.EXTRA_LR, extra_iters,
                  extra_trainloader, optimizer, model,
                  writer_dict, device)
        else:
            if train_sampler is not None:
                train_sampler.set_epoch(epoch)
            train(config, epoch, config.TRAIN.END_EPOCH,
                  epoch_iters, config.TRAIN.LR, num_iters,
                  trainloader, optimizer, model, writer_dict,
                  device)

        # Single source of truth for "validated this epoch": the best-model
        # check below must only look at metrics computed in this epoch.
        run_validation = epoch % args.freq == 0 or epoch == end_epoch - 1
        if run_validation:
            valid_loss, mean_IoU, IoU_array = validate(
                config, testloader, model, writer_dict, device)

        if args.local_rank == 0:
            if run_validation:
                if mean_IoU > best_mIoU:
                    best_mIoU = mean_IoU
                    torch.save(model.module.state_dict(),
                               os.path.join(final_output_dir, 'best.pth'))
                msg = 'Loss: {:.3f}, MeanIU: {: 4.4f}, Best_mIoU: {: 4.4f}'.format(
                    valid_loss, mean_IoU, best_mIoU)
                logger.info(msg)
                logger.info(IoU_array)

            # Saved after the best-score update so a resumed run starts from
            # the correct best_mIoU.
            logger.info('=> saving checkpoint to {}'.format(checkpoint_path))
            torch.save({
                'epoch': epoch + 1,
                'best_mIoU': best_mIoU,
                'state_dict': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, checkpoint_path)

            if epoch == end_epoch - 1:
                torch.save(model.module.state_dict(),
                           os.path.join(final_output_dir, 'final_state.pth'))

                writer_dict['writer'].close()
                end = timeit.default_timer()
                logger.info('Hours: %d' % int((end - start) / 3600))
                logger.info('Done')


if __name__ == '__main__':
    main()