import logging
import math
import numpy as np
import os
from datetime import datetime
import psutil
import torch
from fvcore.nn.flop_count import flop_count
from matplotlib import pyplot as plt
from torch import nn

import slowfast.utils.logging as logging
from slowfast.datasets.utils import pack_pathway_output
from slowfast.models.batchnorm_helper import SubBatchNorm3d

logger = logging.get_logger(__name__)


def check_nan_losses(loss):
    """
    Determine whether the loss is NaN (not a number).
    Args:
        loss (loss): loss to check whether is NaN.
    Raises:
        RuntimeError: if the loss is NaN. The message carries the current
            timestamp.
    """
    if math.isnan(loss):
        raise RuntimeError("ERROR: Got NaN losses {}".format(datetime.now()))


def params_count(model):
    """
    Compute the number of parameters.
    Args:
        model (model): model to count the number of parameters.
    Returns:
        int: total number of elements over all of `model.parameters()`.
    """
    # The builtin sum keeps the result an int even when the model has no
    # parameters (np.sum over an empty list yields the float 0.0).
    return sum(p.numel() for p in model.parameters())


def gpu_mem_usage():
    """
    Compute the GPU memory usage for the current device (GB).
    Returns:
        float: value of `torch.cuda.max_memory_allocated()` converted from
            bytes to GB.
    """
    mem_usage_bytes = torch.cuda.max_memory_allocated()
    return mem_usage_bytes / 1024 ** 3


def cpu_mem_usage():
    """
    Compute the system memory (RAM) usage for the current device (GB).
    Returns:
        usage (float): used memory (GB).
        total (float): total memory (GB).
    """
    vram = psutil.virtual_memory()
    usage = (vram.total - vram.available) / 1024 ** 3
    total = vram.total / 1024 ** 3
    return usage, total


def get_flop_stats(model, cfg, is_train):
    """
    Compute the gflops for the current model given the config.
    Args:
        model (model): model to compute the flop counts.
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        is_train (bool): if True, compute flops for training. Otherwise,
            compute flops for testing.
    Returns:
        float: the total number of gflops of the given model.
    """
    rgb_dimension = 3
    # Only the spatial crop size differs between the train and test inputs.
    crop_size = (
        cfg.DATA.TRAIN_CROP_SIZE if is_train else cfg.DATA.TEST_CROP_SIZE
    )
    input_tensors = torch.rand(
        rgb_dimension,
        cfg.DATA.NUM_FRAMES,
        crop_size,
        crop_size,
    )
    # Add a leading dimension of size one to every pathway tensor and move
    # it to the GPU.
    flop_inputs = [
        pathway.unsqueeze(0).cuda(non_blocking=True)
        for pathway in pack_pathway_output(cfg, input_tensors)
    ]

    # If detection is enabled, count flops for one proposal.
    if cfg.DETECTION.ENABLE:
        bbox = torch.tensor([[0, 0, 1.0, 0, 1.0]])
        bbox = bbox.cuda()
        inputs = (flop_inputs, bbox)
    else:
        inputs = (flop_inputs,)

    # Counting flops runs the model on random data. Do it in eval mode so
    # that the random inputs cannot update BatchNorm running statistics,
    # then restore whatever mode the caller had set.
    was_training = model.training
    model.eval()
    try:
        gflop_dict, _ = flop_count(model, inputs)
    finally:
        model.train(was_training)
    gflops = sum(gflop_dict.values())
    return gflops


def log_model_info(model, cfg, is_train=True):
    """
    Log info, includes number of parameters, gpu usage and gflops.
    Args:
        model (model): model to log the info.
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        is_train (bool): if True, log info for training. Otherwise,
            log info for testing.
    """
    logger.info("Model:\n{}".format(model))
    logger.info("Params: {:,}".format(params_count(model)))
    # gpu_mem_usage() reports gigabytes, so label the value accordingly.
    logger.info("Mem: {:,} GB".format(gpu_mem_usage()))
    logger.info(
        "FLOPs: {:,} GFLOPs".format(get_flop_stats(model, cfg, is_train))
    )
    logger.info("nvidia-smi")
    os.system("nvidia-smi")


def is_eval_epoch(cfg, cur_epoch):
    """
    Determine if the model should be evaluated at the current epoch.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        cur_epoch (int): current epoch.
    Returns:
        bool: True when `cur_epoch + 1` is a multiple of
            `cfg.TRAIN.EVAL_PERIOD` or equals `cfg.SOLVER.MAX_EPOCH`.
    """
    return (
        cur_epoch + 1
    ) % cfg.TRAIN.EVAL_PERIOD == 0 or cur_epoch + 1 == cfg.SOLVER.MAX_EPOCH


def plot_input(tensor, bboxes=(), texts=(), path="./tmp_vis.png"):
    """
    Plot the input tensor with the optional bounding box and save it to disk.
    Args:
        tensor (tensor): a tensor with shape of `NxCxHxW`.
        bboxes (tuple): per-image collections of bounding boxes, each box
            with format of [x1, y1, x2, y2].
        texts (tuple): a tuple of string to plot.
        path (str): path to the image to save to.
    """
    # Rescale to [0, 1] for display. A constant tensor has a maximum of zero
    # after the shift, so skip the division to avoid producing NaNs.
    tensor = tensor - tensor.min()
    max_val = tensor.max()
    if max_val > 0:
        tensor = tensor / max_val

    num_images = tensor.shape[0]
    # squeeze=False always returns a 2D array of axes, so a single image
    # (N == 1) is indexable in the same way as several.
    f, axes = plt.subplots(
        nrows=1, ncols=num_images, figsize=(50, 20), squeeze=False
    )
    ax = axes[0]
    try:
        for i in range(num_images):
            ax[i].axis("off")
            ax[i].imshow(tensor[i].permute(1, 2, 0))
            if bboxes is not None and len(bboxes) > i:
                for box in bboxes[i]:
                    x1, y1, x2, y2 = box
                    ax[i].vlines(x1, y1, y2, colors="g", linestyles="solid")
                    ax[i].vlines(x2, y1, y2, colors="g", linestyles="solid")
                    ax[i].hlines(y1, x1, x2, colors="g", linestyles="solid")
                    ax[i].hlines(y2, x1, x2, colors="g", linestyles="solid")

            if texts is not None and len(texts) > i:
                ax[i].text(0, 0, texts[i])
        f.savefig(path)
    finally:
        # Release the figure so that repeated calls do not accumulate open
        # figures in pyplot's global state.
        plt.close(f)


def frozen_bn_stats(model):
    """
    Set all the bn layers to eval mode.
    Args:
        model (model): model to set bn layers to eval mode.
    """
    for m in model.modules():
        if isinstance(m, nn.BatchNorm3d):
            m.eval()


def aggregate_split_bn_stats(module):
    """
    Recursively find all SubBN modules and aggregate sub-BN stats.
    Args:
        module (nn.Module)
    Returns:
        count (int): number of SubBN module found.
    """
    count = 0
    for child in module.children():
        if isinstance(child, SubBatchNorm3d):
            child.aggregate_stats()
            count += 1
        else:
            # Not a SubBN module itself: search its children instead.
            count += aggregate_split_bn_stats(child)
    return count