from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport argparseimport osimport pprintimport jsonimport timeimport torchimport torch.backends.cudnn as cudnnimport torch.nn.parallelimport torch.optimimport torch.utils.dataimport timeimport torch.utils.data.distributedimport torchvision.transformsimport torch.multiprocessingfrom tqdm import tqdmimport _init_pathsimport modelsfrom config import cfgfrom config import check_configfrom config import update_configfrom core.inference import get_multi_stage_outputsfrom core.inference import aggregate_resultsfrom core.group import HeatmapParserfrom dataset import make_test_dataloader, make_train_dataloaderfrom fp16_utils.fp16util import network_to_halffrom utils.utils import create_loggerfrom utils.utils import get_model_summaryfrom utils.vis import save_debug_imagesfrom utils.vis import save_valid_imagefrom utils.transforms import resize_align_multi_scalefrom utils.transforms import get_final_predsfrom utils.transforms import get_multi_scale_sizefrom arch_manager import ArchManagertorch.multiprocessing.set_sharing_strategy('file_system')def parse_args():    parser = argparse.ArgumentParser(description='Test keypoints network')    # general    parser.add_argument('--cfg',                        help='experiment configure file name',                        required=True,                        type=str)    parser.add_argument('opts',                        help="Modify config options using the command-line",                        default=None,                        nargs=argparse.REMAINDER)                            #fixed config for supernet    parser.add_argument('--superconfig',                        default=None,                        type=str,                        help='fixed arch for supernet training')    args = parser.parse_args()    return args# markdown format outputdef _print_name_value(logger, name_value, full_arch_name):    names = name_value.keys()    values = name_value.values()    num_values = len(name_value)    logger.info(        '| Arch ' +        ' '.join(['| {}'.format(name) for name in names]) +        ' |'    )    logger.info('|---' * (num_values+1) + '|')    if len(full_arch_name) > 15:        full_arch_name = full_arch_name[:8] + '...'    logger.info(        '| ' + full_arch_name + ' ' +        ' '.join(['| {:.3f}'.format(value) for value in values]) +         ' |'    )def main():    args = parse_args()    update_config(cfg, args)    check_config(cfg)    # change the resolution according to config    fixed_arch = None    if args.superconfig is not None:        with open(args.superconfig, 'r') as f:           fixed_arch = json.load(f)        cfg.defrost()        reso = fixed_arch['img_size']        cfg.DATASET.INPUT_SIZE = reso        cfg.DATASET.OUTPUT_SIZE = [reso // 4, reso // 2]        cfg.freeze()    logger, final_output_dir, tb_log_dir = create_logger(        cfg, args.cfg, 'valid'    )    logger.info(pprint.pformat(args))    logger.info(cfg)    # cudnn related setting    cudnn.benchmark = cfg.CUDNN.BENCHMARK    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED    if cfg.MODEL.NAME == 'pose_mobilenet' or cfg.MODEL.NAME == 'pose_simplenet':        arch_manager = ArchManager(cfg)        cfg_arch = arch_manager.fixed_sample()        if fixed_arch is not None:            cfg_arch = fixed_arch        model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(            cfg, is_train=True, cfg_arch = cfg_arch        )    else:        model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(            cfg, is_train=True        )    #set super config    if cfg.MODEL.NAME == 'pose_supermobilenet':        model.arch_manager.is_search = True        if args.superconfig is not None:            with open(args.superconfig, 'r') as f:                model.arch_manager.search_arch = json.load(f)        else:            model.arch_manager.search_arch = model.arch_manager.fixed_sample()    dump_input = torch.rand(        (1, 3, cfg.DATASET.INPUT_SIZE, cfg.DATASET.INPUT_SIZE)    )    logger.info(get_model_summary(cfg.DATASET.INPUT_SIZE, model, dump_input))    if cfg.FP16.ENABLED:        model = network_to_half(model)    if cfg.TEST.MODEL_FILE:        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=True)    else:        model_state_file = os.path.join(            final_output_dir, 'model_best.pth.tar'        )        logger.info('=> loading model from {}'.format(model_state_file))        model.load_state_dict(torch.load(model_state_file))    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()    data_loader, test_dataset = make_test_dataloader(cfg)    train_data_loader, train_dataset = make_train_dataloader(cfg)    if cfg.MODEL.NAME == 'pose_hourglass':        transforms = torchvision.transforms.Compose(            [                torchvision.transforms.ToTensor(),            ]        )    else:        transforms = torchvision.transforms.Compose(            [                torchvision.transforms.ToTensor(),                torchvision.transforms.Normalize(                    mean=[0.485, 0.456, 0.406],                    std=[0.229, 0.224, 0.225]                )            ]        )    parser = HeatmapParser(cfg)    all_preds = []    all_scores = []    pbar = tqdm(total=len(test_dataset)) if cfg.TEST.LOG_PROGRESS else None    #eval mode    model.eval()    for i, (images, annos) in enumerate(data_loader):        assert 1 == images.size(0), 'Test batch size should be 1'        image = images[0].cpu().numpy()        # size at scale 1.0        base_size, center, scale = get_multi_scale_size(            image, cfg.DATASET.INPUT_SIZE, 1.0, min(cfg.TEST.SCALE_FACTOR)        )        with torch.no_grad():            infer_begin = time.time()            final_heatmaps = None            tags_list = []            for idx, s in enumerate(sorted(cfg.TEST.SCALE_FACTOR, reverse=True)):                input_size = cfg.DATASET.INPUT_SIZE                image_resized, center, scale = resize_align_multi_scale(                    image, input_size, s, min(cfg.TEST.SCALE_FACTOR)                )                image_resized = transforms(image_resized)                image_resized = image_resized.unsqueeze(0).cuda()                outputs, heatmaps, tags = get_multi_stage_outputs(                    cfg, model, image_resized, cfg.TEST.FLIP_TEST,                    cfg.TEST.PROJECT2IMAGE,base_size                )                final_heatmaps, tags_list = aggregate_results(                    cfg, s, final_heatmaps, tags_list, heatmaps, tags                )            final_heatmaps = final_heatmaps / float(len(cfg.TEST.SCALE_FACTOR))            tags = torch.cat(tags_list, dim=4)            group_begin = time.time()            grouped, scores = parser.parse(                final_heatmaps, tags, cfg.TEST.ADJUST, cfg.TEST.REFINE            )            final_results = get_final_preds(                grouped, center, scale,                [final_heatmaps.size(3), final_heatmaps.size(2)]            )        if cfg.TEST.LOG_PROGRESS:            pbar.update()        if i % cfg.PRINT_FREQ == 0:            print("finish images: {}".format(i))            # prefix = '{}_{}'.format(os.path.join(final_output_dir, 'result_valid'), i)            # save_valid_image(image, final_results, '{}.jpg'.format(prefix), dataset=test_dataset.name)        all_preds.append(final_results)        all_scores.append(scores)    if cfg.TEST.LOG_PROGRESS:        pbar.close()    name_values, _ = test_dataset.evaluate(        cfg, all_preds, all_scores, final_output_dir    )    if isinstance(name_values, list):        for name_value in name_values:            _print_name_value(logger, name_value, cfg.MODEL.NAME)    else:        _print_name_value(logger, name_values, cfg.MODEL.NAME)if __name__ == '__main__':    main()