"""Default configuration tree and helpers to build a run-specific config.

The defaults live in the module-level ``_C`` node. ``get_config(args)``
returns a frozen clone of those defaults after merging a YAML file and
command line overrides into it.
"""
import os
import yaml
from yacs.config import CfgNode as CN

_C = CN()

# Base config files
_C.BASE = ['']

# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Batch size for a single GPU, could be overwritten by command line argument
_C.DATA.BATCH_SIZE = 128
# Path to dataset, could be overwritten by command line argument
_C.DATA.DATA_PATH = ''
# Dataset name
_C.DATA.DATASET = 'imagenet'
# Input image size
_C.DATA.IMG_SIZE = 224
# Interpolation to resize image (random, bilinear, bicubic)
_C.DATA.INTERPOLATION = 'bicubic'
# Use zipped dataset instead of folder dataset
# could be overwritten by command line argument
_C.DATA.ZIP_MODE = False
# Cache Data in Memory, could be overwritten by command line argument
_C.DATA.CACHE_MODE = 'part'
# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
_C.DATA.PIN_MEMORY = True
# Number of data loading threads
_C.DATA.NUM_WORKERS = 8

# [SimMIM] Mask patch size for MaskGenerator
_C.DATA.MASK_PATCH_SIZE = 32
# [SimMIM] Mask ratio for MaskGenerator
_C.DATA.MASK_RATIO = 0.6

# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model type
_C.MODEL.TYPE = 'swin'
# Model name
_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
# Pretrained weight from checkpoint, could be imagenet22k pretrained weight
# could be overwritten by command line argument
_C.MODEL.PRETRAINED = ''
# Checkpoint to resume, could be overwritten by command line argument
_C.MODEL.RESUME = ''
# Number of classes, overwritten in data preparation
_C.MODEL.NUM_CLASSES = 1000
# Dropout rate
_C.MODEL.DROP_RATE = 0.0
# Drop path rate
_C.MODEL.DROP_PATH_RATE = 0.1
# Label Smoothing
_C.MODEL.LABEL_SMOOTHING = 0.1

# Swin Transformer parameters
_C.MODEL.SWIN = CN()
_C.MODEL.SWIN.PATCH_SIZE = 4
_C.MODEL.SWIN.IN_CHANS = 3
_C.MODEL.SWIN.EMBED_DIM = 96
_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN.WINDOW_SIZE = 7
_C.MODEL.SWIN.MLP_RATIO = 4.
_C.MODEL.SWIN.QKV_BIAS = True
_C.MODEL.SWIN.QK_SCALE = None
_C.MODEL.SWIN.APE = False
_C.MODEL.SWIN.PATCH_NORM = True

# Swin Transformer V2 parameters
_C.MODEL.SWINV2 = CN()
_C.MODEL.SWINV2.PATCH_SIZE = 4
_C.MODEL.SWINV2.IN_CHANS = 3
_C.MODEL.SWINV2.EMBED_DIM = 96
_C.MODEL.SWINV2.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWINV2.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWINV2.WINDOW_SIZE = 7
_C.MODEL.SWINV2.MLP_RATIO = 4.
_C.MODEL.SWINV2.QKV_BIAS = True
_C.MODEL.SWINV2.APE = False
_C.MODEL.SWINV2.PATCH_NORM = True
_C.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0]

# Swin Transformer MoE parameters
_C.MODEL.SWIN_MOE = CN()
_C.MODEL.SWIN_MOE.PATCH_SIZE = 4
_C.MODEL.SWIN_MOE.IN_CHANS = 3
_C.MODEL.SWIN_MOE.EMBED_DIM = 96
_C.MODEL.SWIN_MOE.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN_MOE.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN_MOE.WINDOW_SIZE = 7
_C.MODEL.SWIN_MOE.MLP_RATIO = 4.
_C.MODEL.SWIN_MOE.QKV_BIAS = True
_C.MODEL.SWIN_MOE.QK_SCALE = None
_C.MODEL.SWIN_MOE.APE = False
_C.MODEL.SWIN_MOE.PATCH_NORM = True
_C.MODEL.SWIN_MOE.MLP_FC2_BIAS = True
_C.MODEL.SWIN_MOE.INIT_STD = 0.02
_C.MODEL.SWIN_MOE.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0]
_C.MODEL.SWIN_MOE.MOE_BLOCKS = [[-1], [-1], [-1], [-1]]
_C.MODEL.SWIN_MOE.NUM_LOCAL_EXPERTS = 1
_C.MODEL.SWIN_MOE.TOP_VALUE = 1
_C.MODEL.SWIN_MOE.CAPACITY_FACTOR = 1.25
_C.MODEL.SWIN_MOE.COSINE_ROUTER = False
_C.MODEL.SWIN_MOE.NORMALIZE_GATE = False
_C.MODEL.SWIN_MOE.USE_BPR = True
_C.MODEL.SWIN_MOE.IS_GSHARD_LOSS = False
_C.MODEL.SWIN_MOE.GATE_NOISE = 1.0
_C.MODEL.SWIN_MOE.COSINE_ROUTER_DIM = 256
_C.MODEL.SWIN_MOE.COSINE_ROUTER_INIT_T = 0.5
_C.MODEL.SWIN_MOE.MOE_DROP = 0.0
_C.MODEL.SWIN_MOE.AUX_LOSS_WEIGHT = 0.01

# Swin MLP parameters
_C.MODEL.SWIN_MLP = CN()
_C.MODEL.SWIN_MLP.PATCH_SIZE = 4
_C.MODEL.SWIN_MLP.IN_CHANS = 3
_C.MODEL.SWIN_MLP.EMBED_DIM = 96
_C.MODEL.SWIN_MLP.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN_MLP.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN_MLP.WINDOW_SIZE = 7
_C.MODEL.SWIN_MLP.MLP_RATIO = 4.
_C.MODEL.SWIN_MLP.APE = False
_C.MODEL.SWIN_MLP.PATCH_NORM = True

# [SimMIM] Norm target during training
_C.MODEL.SIMMIM = CN()
_C.MODEL.SIMMIM.NORM_TARGET = CN()
_C.MODEL.SIMMIM.NORM_TARGET.ENABLE = False
_C.MODEL.SIMMIM.NORM_TARGET.PATCH_SIZE = 47

# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.EPOCHS = 300
_C.TRAIN.WARMUP_EPOCHS = 20
_C.TRAIN.WEIGHT_DECAY = 0.05
_C.TRAIN.BASE_LR = 5e-4
_C.TRAIN.WARMUP_LR = 5e-7
_C.TRAIN.MIN_LR = 5e-6
# Clip gradient norm
_C.TRAIN.CLIP_GRAD = 5.0
# Auto resume from latest checkpoint
_C.TRAIN.AUTO_RESUME = True
# Gradient accumulation steps
# could be overwritten by command line argument
_C.TRAIN.ACCUMULATION_STEPS = 1
# Whether to use gradient checkpointing to save memory
# could be overwritten by command line argument
_C.TRAIN.USE_CHECKPOINT = False

# LR scheduler
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
# Epoch interval to decay LR, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
# LR decay rate, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
# warmup_prefix used in CosineLRScheduler
_C.TRAIN.LR_SCHEDULER.WARMUP_PREFIX = True
# [SimMIM] Gamma / Multi steps value, used in MultiStepLRScheduler
_C.TRAIN.LR_SCHEDULER.GAMMA = 0.1
_C.TRAIN.LR_SCHEDULER.MULTISTEPS = []

# Optimizer
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'adamw'
# Optimizer Epsilon
_C.TRAIN.OPTIMIZER.EPS = 1e-8
# Optimizer Betas
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
# SGD momentum
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9

# [SimMIM] Layer decay for fine-tuning
_C.TRAIN.LAYER_DECAY = 1.0

# MoE
_C.TRAIN.MOE = CN()
# Only save model on master device
_C.TRAIN.MOE.SAVE_MASTER = False

# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
_C.AUG = CN()
# Color jitter factor
_C.AUG.COLOR_JITTER = 0.4
# Use AutoAugment policy. "v0" or "original"
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
# Random erase prob
_C.AUG.REPROB = 0.25
# Random erase mode
_C.AUG.REMODE = 'pixel'
# Random erase count
_C.AUG.RECOUNT = 1
# Mixup alpha, mixup enabled if > 0
_C.AUG.MIXUP = 0.8
# Cutmix alpha, cutmix enabled if > 0
_C.AUG.CUTMIX = 1.0
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
_C.AUG.CUTMIX_MINMAX = None
# Probability of performing mixup or cutmix when either/both is enabled
_C.AUG.MIXUP_PROB = 1.0
# Probability of switching to cutmix when both mixup and cutmix enabled
_C.AUG.MIXUP_SWITCH_PROB = 0.5
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
_C.AUG.MIXUP_MODE = 'batch'

# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()
# Whether to use center crop when testing
_C.TEST.CROP = True
# Whether to use SequentialSampler as validation sampler
_C.TEST.SEQUENTIAL = False
_C.TEST.SHUFFLE = False

# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# [SimMIM] Whether to enable pytorch amp, overwritten by command line argument
_C.ENABLE_AMP = False

# Enable Pytorch automatic mixed precision (amp).
_C.AMP_ENABLE = True
# [Deprecated] Mixed precision opt level of apex, if O0, no apex amp is used ('O0', 'O1', 'O2')
_C.AMP_OPT_LEVEL = ''
# Path to output folder, overwritten by command line argument
_C.OUTPUT = ''
# Tag of experiment, overwritten by command line argument
_C.TAG = 'default'
# Frequency to save checkpoint
_C.SAVE_FREQ = 1
# Frequency of logging info
_C.PRINT_FREQ = 10
# Fixed random seed
_C.SEED = 0
# Perform evaluation only, overwritten by command line argument
_C.EVAL_MODE = False
# Test throughput only, overwritten by command line argument
_C.THROUGHPUT_MODE = False
# local rank for DistributedDataParallel, given by command line argument
_C.LOCAL_RANK = 0
# for acceleration
_C.FUSED_WINDOW_PROCESS = False
_C.FUSED_LAYERNORM = False


def _update_config_from_file(config, cfg_file):
    """Merge ``cfg_file`` (and, first, every file it lists under BASE) into ``config``.

    Base files are resolved relative to the directory of ``cfg_file`` and are
    merged recursively before ``cfg_file`` itself, so values in ``cfg_file``
    win over values from its bases. ``config`` is modified in place and is
    frozen when this function returns.

    Args:
        config: yacs ``CfgNode`` to update.
        cfg_file: path to a YAML config file.
    """
    config.defrost()
    with open(cfg_file, 'r') as f:
        # NOTE(review): FullLoader can build more than plain data types; if
        # config files may come from untrusted sources consider yaml.safe_load.
        # An empty YAML document loads as None, so fall back to an empty
        # mapping instead of failing on the BASE lookup below.
        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) or {}

    base_dir = os.path.dirname(cfg_file)
    for base_cfg in yaml_cfg.get('BASE', ['']):
        # The default BASE entry is an empty string, meaning "no base file".
        if base_cfg:
            _update_config_from_file(config, os.path.join(base_dir, base_cfg))

    print('=> merge config from {}'.format(cfg_file))
    config.merge_from_file(cfg_file)
    config.freeze()


def update_config(config, args):
    """Apply the config file and command line overrides from ``args`` to ``config``.

    Order of precedence (later wins): defaults already in ``config``, the YAML
    file ``args.cfg`` (with its BASE files), the ``args.opts`` key/value list,
    then the individual command line attributes handled below. Finally
    ``config.OUTPUT`` is extended to ``OUTPUT/MODEL.NAME/TAG`` and the config
    is frozen.

    Args:
        config: yacs ``CfgNode`` to update in place.
        args: object exposing the parsed command line options as attributes.
            ``cfg``, ``opts`` and ``local_rank`` are read unconditionally; all
            other attributes are optional and ignored when missing or falsy.
    """
    _update_config_from_file(config, args.cfg)

    config.defrost()
    if args.opts:
        config.merge_from_list(args.opts)

    def _check_args(name):
        """Return True when ``args`` has attribute ``name`` with a truthy value."""
        # getattr replaces the original eval(f'args.{name}'): same result for
        # attribute names, without evaluating a constructed expression.
        return bool(getattr(args, name, None))

    # merge from specific arguments
    if _check_args('batch_size'):
        config.DATA.BATCH_SIZE = args.batch_size
    if _check_args('data_path'):
        config.DATA.DATA_PATH = args.data_path
    if _check_args('zip'):
        config.DATA.ZIP_MODE = True
    if _check_args('cache_mode'):
        config.DATA.CACHE_MODE = args.cache_mode
    if _check_args('pretrained'):
        config.MODEL.PRETRAINED = args.pretrained
    if _check_args('resume'):
        config.MODEL.RESUME = args.resume
    if _check_args('accumulation_steps'):
        config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
    if _check_args('use_checkpoint'):
        config.TRAIN.USE_CHECKPOINT = True
    if _check_args('amp_opt_level'):
        print("[warning] Apex amp has been deprecated, please use pytorch amp instead!")
        # The deprecated apex level 'O0' is mapped to "amp disabled".
        if args.amp_opt_level == 'O0':
            config.AMP_ENABLE = False
    if _check_args('disable_amp'):
        config.AMP_ENABLE = False
    if _check_args('output'):
        config.OUTPUT = args.output
    if _check_args('tag'):
        config.TAG = args.tag
    if _check_args('eval'):
        config.EVAL_MODE = True
    if _check_args('throughput'):
        config.THROUGHPUT_MODE = True

    # [SimMIM]
    if _check_args('enable_amp'):
        config.ENABLE_AMP = args.enable_amp

    # for acceleration
    if _check_args('fused_window_process'):
        config.FUSED_WINDOW_PROCESS = True
    if _check_args('fused_layernorm'):
        config.FUSED_LAYERNORM = True
    # Overwrite optimizer if not None, currently we use it for [fused_adam, fused_lamb]
    if _check_args('optim'):
        config.TRAIN.OPTIMIZER.NAME = args.optim

    # set local rank for distributed training
    config.LOCAL_RANK = args.local_rank

    # output folder
    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)

    config.freeze()


def get_config(args):
    """Get a yacs CfgNode object with default values, updated from ``args``.

    Args:
        args: parsed command line options, as expected by ``update_config``.

    Returns:
        A frozen clone of the module defaults with all overrides applied.
    """
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    config = _C.clone()
    update_config(config, args)

    return config