"""Heatmap regression and associative-embedding (tag) loss modules."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging

import numpy as np
import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


def make_input(t, requires_grad=False, need_cuda=True):
    """Wrap ``t`` in an autograd Variable and reduce it to its scalar sum.

    Args:
        t: tensor to wrap.
        requires_grad: forwarded to ``torch.autograd.Variable``.
        need_cuda: when true the result is moved to the GPU with ``.cuda()``,
            which raises on hosts without CUDA.

    Returns:
        A 0-dim tensor holding ``t.sum()``.
    """
    inp = torch.autograd.Variable(t, requires_grad=requires_grad)
    inp = inp.sum()
    if need_cuda:
        inp = inp.cuda()
    return inp


class HeatmapLoss(nn.Module):
    """Masked mean-squared error between predicted and target heatmaps."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, gt, mask):
        """Return one loss value per batch element.

        Args:
            pred: 4-D prediction; the code averages over dims 3, 2 and 1.
            gt: target with exactly the same size as ``pred``.
            mask: 3-D weight map; a singleton dim is inserted at position 1
                and broadcast to ``pred``.

        Returns:
            1-D tensor with the masked squared error averaged over every
            dimension except the first.
        """
        assert pred.size() == gt.size()
        loss = ((pred - gt) ** 2) * mask[:, None, :, :].expand_as(pred)
        return loss.mean(dim=3).mean(dim=2).mean(dim=1)


class AELoss(nn.Module):
    """Associative-embedding loss made of a ``push`` and a ``pull`` term.

    ``pull`` draws the tags of one person's joints towards their mean and
    ``push`` separates the mean tags of different people.
    """

    def __init__(self, loss_type, max_num_people, output_resolution):
        """Store the configuration.

        Args:
            loss_type: ``'exp'`` or ``'max'``; selects the push formulation.
            max_num_people: stored on the instance; not read by the loss
                computations in this class.
            output_resolution: stored on the instance; not read by the loss
                computations in this class.
        """
        super().__init__()
        self.loss_type = loss_type
        self.max_num_people = max_num_people
        self.output_resolution = output_resolution

    def singleTagLoss(self, pred_tag, joints):
        """Associative embedding loss for one image (per-image reference path).

        Args:
            pred_tag: tag values indexed by flattened location.
            joints: iterable of people, each an iterable of
                ``(location, visible)`` pairs; a joint is used only when
                ``visible > 0``.

        Returns:
            ``(push, pull)``. ``push`` is zero when fewer than two people
            have a visible joint; both are zero when nobody does.

        Raises:
            ValueError: if ``self.loss_type`` is neither ``'exp'`` nor
                ``'max'`` (only checked when at least two people are present).
        """
        tags = []
        pull = 0
        for joints_per_person in joints:
            tmp = [
                pred_tag[joint[0]]
                for joint in joints_per_person
                if joint[1] > 0
            ]
            if not tmp:
                continue
            tmp = torch.stack(tmp)
            tags.append(torch.mean(tmp, dim=0))
            pull = pull + torch.mean((tmp - tags[-1].expand_as(tmp)) ** 2)

        num_tags = len(tags)

        if num_tags < 2:
            # Build the zero losses on the same device as the predictions.
            # They used to be forced onto the GPU, which crashed on CPU-only
            # hosts and could not be combined with CPU-side pull losses.
            device = pred_tag.device
            zero_push = torch.zeros((), dtype=torch.float32, device=device)
            if num_tags == 0:
                zero_pull = torch.zeros((), dtype=torch.float32, device=device)
                return zero_push, zero_pull
            return zero_push, pull / num_tags

        tags = torch.stack(tags)

        size = (num_tags, num_tags)
        A = tags.expand(*size)
        B = A.permute(1, 0)

        diff = A - B

        if self.loss_type == 'exp':
            diff = torch.pow(diff, 2)
            push = torch.exp(-diff)
            # The diagonal (a person against itself) contributes exp(0) == 1.
            push = torch.sum(push) - num_tags
        elif self.loss_type == 'max':
            diff = 1 - torch.abs(diff)
            push = torch.clamp(diff, min=0).sum() - num_tags
        else:
            raise ValueError('Unkown ae loss type')

        return push / ((num_tags - 1) * num_tags) * 0.5, pull / num_tags

    def batchTagLoss(self, pred_tag_map, joints):
        """Vectorised associative embedding loss over a whole batch.

        Args:
            pred_tag_map: (B, K * H * W); a trailing singleton third dim is
                squeezed away.
            joints: (B, 30, K, 2); ``[..., 0]`` is the index into the
                flattened tag map and ``[..., 1]`` the visibility weight.

        Returns:
            ``(push, pull)`` scalars averaged over the batch.

        Raises:
            ValueError: if ``self.loss_type`` is neither ``'exp'`` nor
                ``'max'``.
        """
        if pred_tag_map.dim() == 3:
            pred_tag_map = pred_tag_map.squeeze(2)

        batch_size = int(joints.shape[0])
        max_person = int(joints.shape[1])
        num_joints = int(joints.shape[2])

        joints_vis = joints[:, :, :, 1].float()  # (B, 30, K)
        person_joints_cnt = joints_vis.sum(2, keepdim=True)  # (B, 30, 1)
        joints_loc = joints[:, :, :, 0].reshape(batch_size, -1).long()  # (B, 30 * K)

        tags = torch.gather(pred_tag_map, index=joints_loc, dim=1)  # (B, 30 * K)
        # Invisible joints are zeroed so they do not contribute to the means.
        tags = tags.reshape(batch_size, max_person, num_joints) * joints_vis  # (B, 30, K)

        has_joint = person_joints_cnt > 0  # (B, 30, 1)
        person_cnt = has_joint.float().squeeze(2).sum(dim=1, keepdim=True)  # (B, 1)
        # Avoid dividing by zero for images without any annotated person.
        person_cnt[person_cnt == 0] = 1

        person_vis = has_joint.expand(batch_size, max_person, max_person).float()  # (B, 30, 30)
        person_vis = person_vis * person_vis.permute(0, 2, 1)

        # Same guard for people without visible joints; their numerators are
        # already zero because of the joints_vis weighting.
        person_joints_cnt[person_joints_cnt == 0] = 1

        # PULL
        tags_mean = tags.sum(2, keepdim=True) / person_joints_cnt  # (B, 30, 1)
        assert torch.isnan(tags_mean).sum() == 0

        pull = torch.sum(
            joints_vis * (tags - tags_mean) ** 2, dim=2, keepdim=True
        ) / person_joints_cnt  # (B, 30, 1)
        pull = pull.squeeze(2).sum(1, keepdim=True) / person_cnt
        pull = torch.mean(pull)

        # PUSH
        tags_mean = tags_mean.expand(batch_size, max_person, max_person)  # (B, 30, 30)
        diff = (tags_mean - tags_mean.permute(0, 2, 1)) * person_vis  # (B, 30, 30)

        if self.loss_type == 'exp':
            affinity = torch.exp(-diff ** 2) * person_vis  # (B, 30, 30)
        elif self.loss_type == 'max':
            affinity = torch.clamp(1 - torch.abs(diff), min=0) * person_vis
        else:
            # Previously fell through and failed with UnboundLocalError.
            raise ValueError(
                'Unknown ae loss type: {}'.format(self.loss_type)
            )

        cnt = person_cnt.squeeze(1)  # (B)
        pair_cnt = torch.clamp((person_cnt - 1) * person_cnt, min=1).squeeze(1)
        # Subtract the diagonal (each valid person paired with itself).
        push = 0.5 * (torch.sum(affinity, dim=(1, 2)) - cnt) / pair_cnt  # (B)
        # With fewer than two people there is nothing to push apart.
        push[cnt < 2] = 0
        push = torch.mean(push)

        return push, pull

    def forward(self, tags, joints):
        """Compute the tag loss for the batch.

        Args:
            tags: (B, K * H * W)
            joints: (B, 30, K, 2)

        Returns:
            ``(push, pull)`` as produced by :meth:`batchTagLoss`.
        """
        batch_push, batch_pull = self.batchTagLoss(tags, joints)
        return batch_push, batch_pull


class JointsMSELoss(nn.Module):
    """Per-joint mean-squared error, optionally weighted per joint."""

    def __init__(self, use_target_weight):
        super(JointsMSELoss, self).__init__()
        # reduction='mean' is the non-deprecated spelling of size_average=True.
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight):
        """Return the average over joints of ``0.5 * MSE``.

        Args:
            output: predictions; dim 0 is the batch and dim 1 the joints,
                remaining dims are flattened.
            target: ground truth reshaped the same way as ``output``.
            target_weight: indexed as ``target_weight[:, idx]`` and multiplied
                into both prediction and target when ``use_target_weight`` is
                set. Presumably (batch, joints, 1) -- confirm with the caller.

        Returns:
            Scalar loss tensor.
        """
        batch_size = output.size(0)
        num_joints = output.size(1)
        heatmaps_pred = output.reshape((batch_size, num_joints, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)
        loss = 0

        for idx in range(num_joints):
            # squeeze(1) removes only the joint axis, so a batch of one keeps
            # its batch dimension.
            heatmap_pred = heatmaps_pred[idx].squeeze(1)
            heatmap_gt = heatmaps_gt[idx].squeeze(1)
            if self.use_target_weight:
                weight = target_weight[:, idx]
                loss += 0.5 * self.criterion(
                    heatmap_pred.mul(weight),
                    heatmap_gt.mul(weight)
                )
            else:
                loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)

        return loss / num_joints


class LossFactory(nn.Module):
    """Single-stage combination of heatmap and associative-embedding losses."""

    def __init__(self, cfg):
        """Build the enabled losses from ``cfg``.

        Reads ``cfg.DATASET.NUM_JOINTS`` and the ``cfg.LOSS`` switches and
        factors. Only logs an error when neither loss is enabled.
        """
        super().__init__()
        self.num_joints = cfg.DATASET.NUM_JOINTS
        self.heatmaps_loss = None
        self.ae_loss = None
        self.heatmaps_loss_factor = 1.0
        self.push_loss_factor = 1.0
        self.pull_loss_factor = 1.0

        if cfg.LOSS.WITH_HEATMAPS_LOSS:
            self.heatmaps_loss = HeatmapLoss()
            self.heatmaps_loss_factor = cfg.LOSS.HEATMAPS_LOSS_FACTOR
        if cfg.LOSS.WITH_AE_LOSS:
            self.ae_loss = AELoss(
                cfg.LOSS.AE_LOSS_TYPE,
                cfg.DATASET.MAX_NUM_PEOPLE,
                cfg.DATASET.OUTPUT_SIZE
            )
            self.push_loss_factor = cfg.LOSS.PUSH_LOSS_FACTOR
            self.pull_loss_factor = cfg.LOSS.PULL_LOSS_FACTOR

        if self.heatmaps_loss is None and self.ae_loss is None:
            logger.error('At least enable one loss!')

    def forward(self, outputs, heatmaps, masks, joints):
        """Return ``([heatmaps_loss], [push_loss], [pull_loss])``.

        The first ``num_joints`` channels of ``outputs`` are treated as
        heatmaps and the rest as tags. A disabled loss yields ``None``.
        """
        # TODO(bowen): outputs and heatmaps can be lists of same length
        heatmaps_pred = outputs[:, :self.num_joints]
        tags_pred = outputs[:, self.num_joints:]

        heatmaps_loss = None
        push_loss = None
        pull_loss = None

        if self.heatmaps_loss is not None:
            heatmaps_loss = self.heatmaps_loss(heatmaps_pred, heatmaps, masks)
            heatmaps_loss = heatmaps_loss * self.heatmaps_loss_factor

        if self.ae_loss is not None:
            batch_size = tags_pred.size()[0]
            tags_pred = tags_pred.contiguous().view(batch_size, -1, 1)

            push_loss, pull_loss = self.ae_loss(tags_pred, joints)
            push_loss = push_loss * self.push_loss_factor
            pull_loss = pull_loss * self.pull_loss_factor

        return [heatmaps_loss], [push_loss], [pull_loss]


class MultiLossFactory(nn.Module):
    """Multi-stage combination of heatmap and associative-embedding losses.

    Every ``cfg.LOSS`` list holds one entry per stage.
    """

    def __init__(self, cfg):
        super().__init__()
        # init check
        self._init_check(cfg)

        self.num_joints = cfg.MODEL.NUM_JOINTS
        self.num_stages = cfg.LOSS.NUM_STAGES

        # Disabled stages are stored as None placeholders so that indices
        # keep lining up with the stage number.
        self.heatmaps_loss = nn.ModuleList(
            [
                HeatmapLoss() if with_heatmaps_loss else None
                for with_heatmaps_loss in cfg.LOSS.WITH_HEATMAPS_LOSS
            ]
        )
        self.heatmaps_loss_factor = cfg.LOSS.HEATMAPS_LOSS_FACTOR

        self.ae_loss = nn.ModuleList(
            [
                AELoss(
                    cfg.LOSS.AE_LOSS_TYPE,
                    cfg.DATASET.MAX_NUM_PEOPLE,
                    cfg.DATASET.OUTPUT_SIZE[index]
                ) if with_ae_loss else None
                for index, with_ae_loss in enumerate(cfg.LOSS.WITH_AE_LOSS)
            ]
        )
        self.push_loss_factor = cfg.LOSS.PUSH_LOSS_FACTOR
        self.pull_loss_factor = cfg.LOSS.PULL_LOSS_FACTOR

    def forward(self, outputs, heatmaps, masks, joints):
        """Compute the per-stage losses.

        Args:
            outputs, heatmaps, masks, joints: lists with one entry per stage.

        Returns:
            ``(heatmaps_losses, push_losses, pull_losses)``; each is a list
            with one entry per stage, ``None`` where that loss is disabled.
        """
        # forward check
        self._forward_check(outputs, heatmaps, masks, joints)

        heatmaps_losses = []
        push_losses = []
        pull_losses = []
        for idx, output in enumerate(outputs):
            # Tags start after the heatmap channels only when this stage
            # actually predicts heatmaps.
            offset_feat = 0
            heatmap_criterion = self.heatmaps_loss[idx]
            if heatmap_criterion is not None:
                heatmaps_pred = output[:, :self.num_joints]
                offset_feat = self.num_joints

                heatmaps_loss = heatmap_criterion(
                    heatmaps_pred, heatmaps[idx], masks[idx]
                )
                heatmaps_loss = heatmaps_loss * self.heatmaps_loss_factor[idx]
                heatmaps_losses.append(heatmaps_loss)
            else:
                heatmaps_losses.append(None)

            ae_criterion = self.ae_loss[idx]
            if ae_criterion is not None:
                tags_pred = output[:, offset_feat:]
                batch_size = tags_pred.size()[0]
                tags_pred = tags_pred.contiguous().view(batch_size, -1, 1)

                push_loss, pull_loss = ae_criterion(tags_pred, joints[idx])
                push_loss = push_loss * self.push_loss_factor[idx]
                pull_loss = pull_loss * self.pull_loss_factor[idx]

                push_losses.append(push_loss)
                pull_losses.append(pull_loss)
            else:
                push_losses.append(None)
                pull_losses.append(None)

        return heatmaps_losses, push_losses, pull_losses

    def _init_check(self, cfg):
        """Assert that the per-stage ``cfg.LOSS`` lists are well formed."""
        assert isinstance(cfg.LOSS.WITH_HEATMAPS_LOSS, (list, tuple)), \
            'LOSS.WITH_HEATMAPS_LOSS should be a list or tuple'
        assert isinstance(cfg.LOSS.HEATMAPS_LOSS_FACTOR, (list, tuple)), \
            'LOSS.HEATMAPS_LOSS_FACTOR should be a list or tuple'
        assert isinstance(cfg.LOSS.WITH_AE_LOSS, (list, tuple)), \
            'LOSS.WITH_AE_LOSS should be a list or tuple'
        assert isinstance(cfg.LOSS.PUSH_LOSS_FACTOR, (list, tuple)), \
            'LOSS.PUSH_LOSS_FACTOR should be a list or tuple'
        # This check used to repeat PUSH_LOSS_FACTOR, leaving the pull factor
        # unvalidated.
        assert isinstance(cfg.LOSS.PULL_LOSS_FACTOR, (list, tuple)), \
            'LOSS.PULL_LOSS_FACTOR should be a list or tuple'
        assert len(cfg.LOSS.WITH_HEATMAPS_LOSS) == cfg.LOSS.NUM_STAGES, \
            'LOSS.WITH_HEATMAPS_LOSS and LOSS.NUM_STAGE should have same length, got {} vs {}.'.\
            format(len(cfg.LOSS.WITH_HEATMAPS_LOSS), cfg.LOSS.NUM_STAGES)
        assert len(cfg.LOSS.WITH_HEATMAPS_LOSS) == len(cfg.LOSS.HEATMAPS_LOSS_FACTOR), \
            'LOSS.WITH_HEATMAPS_LOSS and LOSS.HEATMAPS_LOSS_FACTOR should have same length, got {} vs {}.'.\
            format(len(cfg.LOSS.WITH_HEATMAPS_LOSS), len(cfg.LOSS.HEATMAPS_LOSS_FACTOR))
        assert len(cfg.LOSS.WITH_AE_LOSS) == cfg.LOSS.NUM_STAGES, \
            'LOSS.WITH_AE_LOSS and LOSS.NUM_STAGE should have same length, got {} vs {}.'.\
            format(len(cfg.LOSS.WITH_AE_LOSS), cfg.LOSS.NUM_STAGES)
        assert len(cfg.LOSS.WITH_AE_LOSS) == len(cfg.LOSS.PUSH_LOSS_FACTOR), \
            'LOSS.WITH_AE_LOSS and LOSS.PUSH_LOSS_FACTOR should have same length, got {} vs {}.'. \
            format(len(cfg.LOSS.WITH_AE_LOSS), len(cfg.LOSS.PUSH_LOSS_FACTOR))
        assert len(cfg.LOSS.WITH_AE_LOSS) == len(cfg.LOSS.PULL_LOSS_FACTOR), \
            'LOSS.WITH_AE_LOSS and LOSS.PULL_LOSS_FACTOR should have same length, got {} vs {}.'. \
            format(len(cfg.LOSS.WITH_AE_LOSS), len(cfg.LOSS.PULL_LOSS_FACTOR))

    def _forward_check(self, outputs, heatmaps, masks, joints):
        """Assert that the forward inputs are lists with one entry per stage."""
        assert isinstance(outputs, list), \
            'outputs should be a list, got {} instead.'.format(type(outputs))
        assert isinstance(heatmaps, list), \
            'heatmaps should be a list, got {} instead.'.format(type(heatmaps))
        assert isinstance(masks, list), \
            'masks should be a list, got {} instead.'.format(type(masks))
        assert isinstance(joints, list), \
            'joints should be a list, got {} instead.'.format(type(joints))
        assert len(outputs) == self.num_stages, \
            'len(outputs) and num_stages should been same, got {} vs {}.'.format(len(outputs), self.num_stages)
        assert len(outputs) == len(heatmaps), \
            'outputs and heatmaps should have same length, got {} vs {}.'.format(len(outputs), len(heatmaps))
        assert len(outputs) == len(masks), \
            'outputs and masks should have same length, got {} vs {}.'.format(len(outputs), len(masks))
        assert len(outputs) == len(joints), \
            'outputs and joints should have same length, got {} vs {}.'.format(len(outputs), len(joints))
        assert len(outputs) == len(self.heatmaps_loss), \
            'outputs and heatmaps_loss should have same length, got {} vs {}.'. \
            format(len(outputs), len(self.heatmaps_loss))
        assert len(outputs) == len(self.ae_loss), \
            'outputs and ae_loss should have same length, got {} vs {}.'. \
            format(len(outputs), len(self.ae_loss))


def test_ae_loss():
    """Smoke test: run the 'exp' AE loss on a tiny hand-built example."""
    # np.float was removed in NumPy 1.24; float64 is the dtype it aliased.
    t = torch.tensor(
        np.arange(0, 32).reshape(1, 2, 4, 4).astype(np.float64) * 0.1,
        requires_grad=True
    )
    t.register_hook(lambda x: print('t', x))

    ae_loss = AELoss('exp', 30, 512)

    # joints[person, joint] = (flattened location, visibility)
    joints = np.zeros((2, 2, 2))
    joints[0, 0] = (3, 1)
    joints[1, 0] = (10, 1)
    joints[0, 1] = (22, 1)
    joints[1, 1] = (30, 1)
    joints = torch.LongTensor(joints)
    joints = joints.view(1, 2, 2, 2)

    t = t.contiguous().view(1, -1, 1)
    losses = ae_loss(t, joints)

    print(losses)


if __name__ == '__main__':
    test_ae_loss()