from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport argparseimport osimport pprintimport torchimport torch.nn.parallelimport torch.backends.cudnn as cudnnimport torch.optimimport torch.utils.dataimport torch.utils.data.distributedimport torchvision.transforms as transformsimport _init_pathsfrom core.config import configfrom core.config import update_configfrom core.config import update_dirfrom core.loss import JointsMSELossfrom core.function import validatefrom utils.utils import create_loggerimport datasetimport modelsdef parse_args():    parser = argparse.ArgumentParser(description='Train keypoints network')    # general    parser.add_argument('--cfg',                        help='experiment configure file name',                        required=True,                        type=str)    args, rest = parser.parse_known_args()    # update config    update_config(args.cfg)    # training    parser.add_argument('--frequent',                        help='frequency of logging',                        default=config.PRINT_FREQ,                        type=int)    parser.add_argument('--gpus',                        help='gpus',                        type=str)    parser.add_argument('--workers',                        help='num of dataloader workers',                        type=int)    parser.add_argument('--model-file',                        help='model state file',                        type=str)    parser.add_argument('--use-detect-bbox',                        help='use detect bbox',                        action='store_true')    parser.add_argument('--flip-test',                        help='use flip test',                        action='store_true')    parser.add_argument('--post-process',                        help='use post process',                        action='store_true')    parser.add_argument('--shift-heatmap',                        help='shift heatmap',                        action='store_true')    parser.add_argument('--coco-bbox-file',                        help='coco detection bbox file',                        type=str)    args = parser.parse_args()    return argsdef reset_config(config, args):    if args.gpus:        config.GPUS = args.gpus    if args.workers:        config.WORKERS = args.workers    if args.use_detect_bbox:        config.TEST.USE_GT_BBOX = not args.use_detect_bbox    if args.flip_test:        config.TEST.FLIP_TEST = args.flip_test    if args.post_process:        config.TEST.POST_PROCESS = args.post_process    if args.shift_heatmap:        config.TEST.SHIFT_HEATMAP = args.shift_heatmap    if args.model_file:        config.TEST.MODEL_FILE = args.model_file    if args.coco_bbox_file:        config.TEST.COCO_BBOX_FILE = args.coco_bbox_filedef main():    args = parse_args()    reset_config(config, args)    logger, final_output_dir, tb_log_dir = create_logger(        config, args.cfg, 'valid')    logger.info(pprint.pformat(args))    logger.info(pprint.pformat(config))    # cudnn related setting    cudnn.benchmark = config.CUDNN.BENCHMARK    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC    torch.backends.cudnn.enabled = config.CUDNN.ENABLED    model = eval('models.'+config.MODEL.NAME+'.get_pose_net')(        config, is_train=False    )    if config.TEST.MODEL_FILE:        logger.info('=> loading model from {}'.format(config.TEST.MODEL_FILE))        model.load_state_dict(torch.load(config.TEST.MODEL_FILE))    else:        model_state_file = os.path.join(final_output_dir,                                        'final_state.pth.tar')        logger.info('=> loading model from {}'.format(model_state_file))        model.load_state_dict(torch.load(model_state_file))    gpus = [int(i) for i in config.GPUS.split(',')]    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()    # define loss function (criterion) and optimizer    criterion = JointsMSELoss(        use_target_weight=config.LOSS.USE_TARGET_WEIGHT    ).cuda()    # Data loading code    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],                                     std=[0.229, 0.224, 0.225])    valid_dataset = eval('dataset.'+config.DATASET.DATASET)(        config,        config.DATASET.ROOT,        config.DATASET.TEST_SET,        False,        transforms.Compose([            transforms.ToTensor(),            normalize,        ])    )    valid_loader = torch.utils.data.DataLoader(        valid_dataset,        batch_size=config.TEST.BATCH_SIZE*len(gpus),        shuffle=False,        num_workers=config.WORKERS,        pin_memory=True    )    # evaluate on validation set    validate(config, valid_loader, valid_dataset, model, criterion,             final_output_dir, tb_log_dir)if __name__ == '__main__':    main()