import cPickle
import mxnet as mx
import math
from utils.symbol import Symbol
from operator_py.proposal import *
from operator_py.proposal_target import *
from operator_py.box_annotator_ohem import *
from resnet_v1_101_rcnn_base import resnet_v1_101_rcnn_base


class resnet_v1_101_rcnn_attention_1024_pairwise_position_multi_head_16(resnet_v1_101_rcnn_base):
    """ResNet-101 R-CNN symbol whose two FC layers are each augmented with a
    16-head attention module driven by pairwise box-position embeddings.

    The backbone / RPN builders (``get_resnet_v1_conv4``,
    ``get_resnet_v1_conv5``, ``get_rpn``, ``init_weight_rpn``) are inherited
    from ``resnet_v1_101_rcnn_base`` and are not defined in this file.
    """

    def __init__(self):
        """
        Use __init__ to define parameter network needs
        """
        # NOTE(review): these attributes are not read anywhere in this class;
        # presumably the inherited backbone builders consume them -- confirm
        # against resnet_v1_101_rcnn_base.
        self.eps = 1e-5
        self.use_global_stats = True
        self.workspace = 512
        self.units = (3, 4, 23, 3)  # use for 101
        self.filter_list = [256, 512, 1024, 2048]

    @staticmethod
    def extract_position_embedding(position_mat, feat_dim, wave_length=1000):
        """ Embed a pairwise position matrix with sin/cos features

        Every one of the 4 position values is scaled by 100, divided by
        ``feat_dim // 8`` different wave lengths
        (``wave_length ** (8 * k / feat_dim)``) and passed through sin and
        cos, giving ``feat_dim / 4`` values per position entry.

        Args:
            position_mat: [num_rois, nongt_dim, 4]
            feat_dim: size of the resulting embedding; should be a multiple
                of 8 so that the final reshape is consistent
            wave_length: base of the wave-length progression
        Returns:
            embedding: [num_rois, nongt_dim, feat_dim]
        """
        # Floor division keeps the arange bound an integer regardless of the
        # interpreter's `/` semantics.
        feat_range = mx.sym.arange(0, feat_dim // 8)
        wave_base = mx.sym.full((1,), wave_length)
        exponent = (8. / feat_dim) * feat_range
        dim_mat = mx.sym.broadcast_power(lhs=wave_base, rhs=exponent)
        # dim_mat, [1, 1, 1, feat_dim/8]
        dim_mat = mx.sym.Reshape(dim_mat, shape=(1, 1, 1, -1))
        # position_mat, [num_rois, nongt_dim, 4, 1]
        position_mat = mx.sym.expand_dims(100.0 * position_mat, axis=3)
        # div_mat, [num_rois, nongt_dim, 4, feat_dim/8]
        div_mat = mx.sym.broadcast_div(lhs=position_mat, rhs=dim_mat)
        sin_mat = mx.sym.sin(data=div_mat)
        cos_mat = mx.sym.cos(data=div_mat)
        # embedding, [num_rois, nongt_dim, 4, feat_dim/4]
        embedding = mx.sym.concat(sin_mat, cos_mat, dim=3)
        # embedding, [num_rois, nongt_dim, feat_dim]
        embedding = mx.sym.Reshape(embedding, shape=(0, 0, feat_dim))
        return embedding

    @staticmethod
    def extract_position_matrix(bbox, nongt_dim):
        """ Extract position matrix

        For every ordered pair of boxes (i, j) the four entries are
        log(max(|cx_i - cx_j| / w_i, 1e-3)),
        log(max(|cy_i - cy_j| / h_i, 1e-3)),
        log(w_i / w_j) and log(h_i / h_j).
        Only the first ``nongt_dim`` boxes are kept as the second member j.

        Args:
            bbox: [num_boxes, 4], columns are (xmin, ymin, xmax, ymax)
            nongt_dim: number of leading boxes kept along axis 1
        Returns:
            position_matrix: [num_boxes, nongt_dim, 4]
        """
        # each of the four coordinates, [num_boxes, 1]
        xmin, ymin, xmax, ymax = mx.sym.split(data=bbox,
                                              num_outputs=4, axis=1)
        # [num_boxes, 1]
        bbox_width = xmax - xmin + 1.
        bbox_height = ymax - ymin + 1.
        center_x = 0.5 * (xmin + xmax)
        center_y = 0.5 * (ymin + ymax)
        # [num_boxes, num_boxes]; entry (i, j) compares box i with box j
        delta_x = mx.sym.broadcast_minus(lhs=center_x,
                                         rhs=mx.sym.transpose(center_x))
        delta_x = mx.sym.broadcast_div(delta_x, bbox_width)
        # clamp before the log so coincident centers do not give log(0)
        delta_x = mx.sym.log(mx.sym.maximum(mx.sym.abs(delta_x), 1e-3))
        delta_y = mx.sym.broadcast_minus(lhs=center_y,
                                         rhs=mx.sym.transpose(center_y))
        delta_y = mx.sym.broadcast_div(delta_y, bbox_height)
        delta_y = mx.sym.log(mx.sym.maximum(mx.sym.abs(delta_y), 1e-3))
        delta_width = mx.sym.broadcast_div(lhs=bbox_width,
                                           rhs=mx.sym.transpose(bbox_width))
        delta_width = mx.sym.log(delta_width)
        delta_height = mx.sym.broadcast_div(lhs=bbox_height,
                                            rhs=mx.sym.transpose(bbox_height))
        delta_height = mx.sym.log(delta_height)
        # slice every [num_boxes, num_boxes] map to [num_boxes, nongt_dim]
        # and append a trailing axis so the four maps can be concatenated
        concat_list = [
            mx.sym.expand_dims(
                mx.sym.slice_axis(delta, axis=1, begin=0, end=nongt_dim),
                axis=2)
            for delta in (delta_x, delta_y, delta_width, delta_height)
        ]
        # [num_boxes, nongt_dim, 4]
        position_matrix = mx.sym.concat(*concat_list, dim=2)
        return position_matrix

    def attention_module_multi_head(self, roi_feat, position_embedding,
                                    nongt_dim, fc_dim, feat_dim,
                                    dim=(1024, 1024, 1024),
                                    group=16, index=1):
        """ Attention module with vectorized version

        Args:
            roi_feat: [num_rois, feat_dim]
            position_embedding: [num_rois, nongt_dim, emb_dim]
            nongt_dim: number of leading rois used as keys / values
            fc_dim: should be same as group
            feat_dim: dimension of roi_feat, should be same as dim[2]
            dim: a 3-tuple of (query, key, output)
            group: number of attention heads
            index: suffix appended to every parameter / op name created here
        Returns:
            output: [num_rois, dim[2]]
        Raises:
            AssertionError: if dim[0] != dim[1] or fc_dim != group
        """
        # Check argument consistency before any graph node is created.
        assert dim[0] == dim[1], 'Matrix multiply requires same dimensions!'
        assert fc_dim == group, 'fc_dim != group'

        # Per-head sizes feed Reshape, so they must be integers: use floor
        # division instead of relying on Python 2 `/` semantics.
        dim_group = (dim[0] // group, dim[1] // group, dim[2] // group)
        nongt_roi_feat = mx.symbol.slice_axis(data=roi_feat, axis=0, begin=0, end=nongt_dim)
        # [num_rois * nongt_dim, emb_dim]
        position_embedding_reshape = mx.sym.Reshape(position_embedding, shape=(-3, -2))
        # position_feat_1, [num_rois * nongt_dim, fc_dim]
        position_feat_1 = mx.sym.FullyConnected(name='pair_pos_fc1_' + str(index),
                                                data=position_embedding_reshape,
                                                num_hidden=fc_dim)
        position_feat_1_relu = mx.sym.Activation(data=position_feat_1, act_type='relu')
        # aff_weight, [num_rois, nongt_dim, fc_dim]
        aff_weight = mx.sym.Reshape(position_feat_1_relu, shape=(-1, nongt_dim, fc_dim))
        # aff_weight, [num_rois, fc_dim, nongt_dim]
        aff_weight = mx.sym.transpose(aff_weight, axes=(0, 2, 1))

        # multi head
        q_data = mx.sym.FullyConnected(name='query_' + str(index),
                                       data=roi_feat,
                                       num_hidden=dim[0])
        # q_data_batch, [group, num_rois, dim_group[0]]
        q_data_batch = mx.sym.Reshape(q_data, shape=(-1, group, dim_group[0]))
        q_data_batch = mx.sym.transpose(q_data_batch, axes=(1, 0, 2))
        k_data = mx.symbol.FullyConnected(name='key_' + str(index),
                                          data=nongt_roi_feat,
                                          num_hidden=dim[1])
        # k_data_batch, [group, nongt_dim, dim_group[1]]
        k_data_batch = mx.sym.Reshape(k_data, shape=(-1, group, dim_group[1]))
        k_data_batch = mx.sym.transpose(k_data_batch, axes=(1, 0, 2))
        # Values are the raw features (no separate value projection); the
        # per-head projection is done by the grouped 1x1 convolution below.
        v_data = nongt_roi_feat
        # aff, [group, num_rois, nongt_dim]
        aff = mx.symbol.batch_dot(lhs=q_data_batch, rhs=k_data_batch, transpose_a=False, transpose_b=True)
        # scaled dot product, still [group, num_rois, nongt_dim]
        aff_scale = (1.0 / math.sqrt(float(dim_group[1]))) * aff
        # aff_scale, [num_rois, group, nongt_dim]
        aff_scale = mx.sym.transpose(aff_scale, axes=(1, 0, 2))

        # weighted_aff, [num_rois, fc_dim, nongt_dim]
        # The position weight enters as log(max(w, 1e-6)) so that the softmax
        # below is proportional to w * exp(aff_scale).
        weighted_aff = mx.sym.log(mx.sym.maximum(left=aff_weight, right=1e-6)) + aff_scale
        aff_softmax = mx.symbol.softmax(data=weighted_aff, axis=2, name='softmax_' + str(index))
        # [num_rois * fc_dim, nongt_dim]
        aff_softmax_reshape = mx.sym.Reshape(aff_softmax, shape=(-3, -2))
        # output_t, [num_rois * fc_dim, feat_dim]
        output_t = mx.symbol.dot(lhs=aff_softmax_reshape, rhs=v_data)
        # output_t, [num_rois, fc_dim * feat_dim, 1, 1]
        output_t = mx.sym.Reshape(output_t, shape=(-1, fc_dim * feat_dim, 1, 1))
        # linear_out, [num_rois, dim[2], 1, 1]
        linear_out = mx.symbol.Convolution(name='linear_out_' + str(index), data=output_t,
                                           kernel=(1, 1), num_filter=dim[2], num_group=fc_dim)
        # output, [num_rois, dim[2]]
        output = mx.sym.Reshape(linear_out, shape=(0, 0))
        return output

    @staticmethod
    def _build_proposal_rois(cls_prob, bbox_pred, im_info, cfg, phase_cfg):
        """Create the 'rois' proposal op from the RPN outputs.

        Args:
            cls_prob: reshaped RPN class probabilities
            bbox_pred: RPN box regression output
            im_info: image info variable
            cfg: full config (anchor settings are read from cfg.network)
            phase_cfg: cfg.TRAIN or cfg.TEST (NMS settings, CXX_PROPOSAL flag)
        Returns:
            the proposal symbol, named 'rois'
        """
        shared_kwargs = dict(
            cls_prob=cls_prob, bbox_pred=bbox_pred, im_info=im_info, name='rois',
            scales=tuple(cfg.network.ANCHOR_SCALES),
            ratios=tuple(cfg.network.ANCHOR_RATIOS),
            rpn_pre_nms_top_n=phase_cfg.RPN_PRE_NMS_TOP_N,
            rpn_post_nms_top_n=phase_cfg.RPN_POST_NMS_TOP_N,
            threshold=phase_cfg.RPN_NMS_THRESH,
            rpn_min_size=phase_cfg.RPN_MIN_SIZE)
        if phase_cfg.CXX_PROPOSAL:
            # the built-in operator spells the stride argument `feature_stride`
            return mx.contrib.sym.Proposal(
                feature_stride=cfg.network.RPN_FEAT_STRIDE, **shared_kwargs)
        # the Python custom operator spells it `feat_stride`
        return mx.sym.Custom(
            op_type='proposal', feat_stride=cfg.network.RPN_FEAT_STRIDE, **shared_kwargs)

    def get_symbol(self, cfg, is_train=True):
        """Build the complete network symbol and store it in ``self.sym``.

        Args:
            cfg: experiment config; reads cfg.dataset, cfg.network,
                cfg.CLASS_AGNOSTIC and cfg.TRAIN or cfg.TEST
            is_train: build the training graph (with losses) when True,
                otherwise the inference graph
        Returns:
            mx.sym.Group.
            Training: [rpn_cls_prob, rpn_bbox_loss, cls_prob, bbox_loss,
            BlockGrad(rcnn_label)].
            Inference: [rois, cls_prob, bbox_pred, attention_1, attention_2].
        """
        # config alias for convenient
        num_classes = cfg.dataset.NUM_CLASSES
        num_reg_classes = (2 if cfg.CLASS_AGNOSTIC else num_classes)
        num_anchors = cfg.network.NUM_ANCHORS

        # input init
        data = mx.sym.Variable(name="data")
        im_info = mx.sym.Variable(name="im_info")
        gt_boxes = mx.sym.Variable(name="gt_boxes")
        if is_train:
            rpn_label = mx.sym.Variable(name='label')
            rpn_bbox_target = mx.sym.Variable(name='bbox_target')
            rpn_bbox_weight = mx.sym.Variable(name='bbox_weight')

        # shared convolutional layers
        conv_feat = self.get_resnet_v1_conv4(data)
        # res5
        relu1 = self.get_resnet_v1_conv5(conv_feat)

        rpn_cls_score, rpn_bbox_pred = self.get_rpn(conv_feat, num_anchors)

        if is_train:
            # prepare rpn data
            rpn_cls_score_reshape = mx.sym.Reshape(
                data=rpn_cls_score, shape=(0, 2, -1, 0), name="rpn_cls_score_reshape")

            # classification
            rpn_cls_prob = mx.sym.SoftmaxOutput(data=rpn_cls_score_reshape, label=rpn_label, multi_output=True,
                                                normalization='valid', use_ignore=True, ignore_label=-1,
                                                name="rpn_cls_prob")
            # bounding box regression
            rpn_bbox_loss_ = rpn_bbox_weight * mx.sym.smooth_l1(name='rpn_bbox_loss_', scalar=3.0,
                                                                data=(rpn_bbox_pred - rpn_bbox_target))
            rpn_bbox_loss = mx.sym.MakeLoss(name='rpn_bbox_loss', data=rpn_bbox_loss_,
                                            grad_scale=1.0 / cfg.TRAIN.RPN_BATCH_SIZE)

            # ROI proposal
            rpn_cls_act = mx.sym.SoftmaxActivation(
                data=rpn_cls_score_reshape, mode="channel", name="rpn_cls_act")
            rpn_cls_act_reshape = mx.sym.Reshape(
                data=rpn_cls_act, shape=(0, 2 * num_anchors, -1, 0), name='rpn_cls_act_reshape')
            rois = self._build_proposal_rois(rpn_cls_act_reshape, rpn_bbox_pred, im_info, cfg, cfg.TRAIN)

            # ROI proposal target
            gt_boxes_reshape = mx.sym.Reshape(data=gt_boxes, shape=(-1, 5), name='gt_boxes_reshape')
            rois, label, bbox_target, bbox_weight = mx.sym.Custom(rois=rois, gt_boxes=gt_boxes_reshape,
                                                                  op_type='proposal_target',
                                                                  num_classes=num_reg_classes,
                                                                  batch_images=cfg.TRAIN.BATCH_IMAGES,
                                                                  batch_rois=cfg.TRAIN.BATCH_ROIS,
                                                                  cfg=cPickle.dumps(cfg),
                                                                  fg_fraction=cfg.TRAIN.FG_FRACTION)
        else:
            # ROI Proposal
            rpn_cls_score_reshape = mx.sym.Reshape(
                data=rpn_cls_score, shape=(0, 2, -1, 0), name="rpn_cls_score_reshape")
            rpn_cls_prob = mx.sym.SoftmaxActivation(
                data=rpn_cls_score_reshape, mode="channel", name="rpn_cls_prob")
            rpn_cls_prob_reshape = mx.sym.Reshape(
                data=rpn_cls_prob, shape=(0, 2 * num_anchors, -1, 0), name='rpn_cls_prob_reshape')
            rois = self._build_proposal_rois(rpn_cls_prob_reshape, rpn_bbox_pred, im_info, cfg, cfg.TEST)
            # NOTE(review): nothing in the inference graph below consumes
            # this symbol; kept as in the original.
            gt_boxes_reshape = mx.sym.Reshape(data=gt_boxes, shape=(-1, 5), name='gt_boxes_reshape')

        # number of leading rois used as attention keys / values
        nongt_dim = cfg.TRAIN.RPN_POST_NMS_TOP_N if is_train else cfg.TEST.RPN_POST_NMS_TOP_N

        conv_new_1 = mx.sym.Convolution(data=relu1, kernel=(1, 1), num_filter=256, name="conv_new_1")
        conv_new_1_relu = mx.sym.Activation(data=conv_new_1, act_type='relu', name='conv_new_1_relu')

        roi_pool = mx.symbol.ROIPooling(
            name='roi_pool', data=conv_new_1_relu, rois=rois, pooled_size=(7, 7), spatial_scale=0.0625)

        # drop column 0 of every roi, keeping the remaining box columns
        # (presumably column 0 is the batch index -- confirm against the
        # proposal operator)
        sliced_rois = mx.sym.slice_axis(rois, axis=1, begin=1, end=None)
        # [num_rois, nongt_dim, 4]
        position_matrix = self.extract_position_matrix(sliced_rois, nongt_dim=nongt_dim)
        # [num_rois, nongt_dim, 64]
        position_embedding = self.extract_position_embedding(position_matrix, feat_dim=64)

        # 2 fc
        fc_new_1 = mx.symbol.FullyConnected(name='fc_new_1', data=roi_pool, num_hidden=1024)
        # attention, [num_rois, feat_dim]
        attention_1 = self.attention_module_multi_head(fc_new_1, position_embedding,
                                                       nongt_dim=nongt_dim, fc_dim=16, feat_dim=1024,
                                                       index=1, group=16,
                                                       dim=(1024, 1024, 1024))
        # residual connection around the attention module
        fc_all_1 = fc_new_1 + attention_1
        fc_all_1_relu = mx.sym.Activation(data=fc_all_1, act_type='relu', name='fc_all_1_relu')

        fc_new_2 = mx.symbol.FullyConnected(name='fc_new_2', data=fc_all_1_relu, num_hidden=1024)
        attention_2 = self.attention_module_multi_head(fc_new_2, position_embedding,
                                                       nongt_dim=nongt_dim, fc_dim=16, feat_dim=1024,
                                                       index=2, group=16,
                                                       dim=(1024, 1024, 1024))
        fc_all_2 = fc_new_2 + attention_2
        fc_all_2_relu = mx.sym.Activation(data=fc_all_2, act_type='relu', name='fc_all_2_relu')

        # cls_score/bbox_pred
        cls_score = mx.symbol.FullyConnected(name='cls_score', data=fc_all_2_relu, num_hidden=num_classes)
        bbox_pred = mx.symbol.FullyConnected(name='bbox_pred', data=fc_all_2_relu, num_hidden=num_reg_classes * 4)

        if is_train:
            if cfg.TRAIN.ENABLE_OHEM:
                labels_ohem, bbox_weights_ohem = mx.sym.Custom(op_type='BoxAnnotatorOHEM', num_classes=num_classes,
                                                               num_reg_classes=num_reg_classes,
                                                               roi_per_img=cfg.TRAIN.BATCH_ROIS_OHEM,
                                                               cls_score=cls_score, bbox_pred=bbox_pred, labels=label,
                                                               bbox_targets=bbox_target, bbox_weights=bbox_weight)
                cls_prob = mx.sym.SoftmaxOutput(name='cls_prob', data=cls_score, label=labels_ohem,
                                                normalization='valid', use_ignore=True, ignore_label=-1)
                bbox_loss_ = bbox_weights_ohem * mx.sym.smooth_l1(name='bbox_loss_', scalar=1.0,
                                                                  data=(bbox_pred - bbox_target))
                bbox_loss = mx.sym.MakeLoss(name='bbox_loss', data=bbox_loss_,
                                            grad_scale=1.0 / cfg.TRAIN.BATCH_ROIS_OHEM)
                rcnn_label = labels_ohem
            else:
                cls_prob = mx.sym.SoftmaxOutput(name='cls_prob', data=cls_score, label=label, normalization='valid')
                bbox_loss_ = bbox_weight * mx.sym.smooth_l1(name='bbox_loss_', scalar=1.0,
                                                            data=(bbox_pred - bbox_target))
                # A negative BATCH_ROIS cannot be used as a normaliser, so a
                # fixed 300 is substituted (presumably BATCH_ROIS < 0 means
                # "use every roi" -- confirm against the proposal_target op).
                if cfg.TRAIN.BATCH_ROIS < 0:
                    batch_rois_num = 300
                else:
                    batch_rois_num = cfg.TRAIN.BATCH_ROIS
                bbox_loss = mx.sym.MakeLoss(name='bbox_loss', data=bbox_loss_, grad_scale=1.0 / batch_rois_num)
                rcnn_label = label

            # reshape output
            rcnn_label = mx.sym.Reshape(data=rcnn_label, shape=(cfg.TRAIN.BATCH_IMAGES, -1), name='label_reshape')
            cls_prob = mx.sym.Reshape(data=cls_prob, shape=(cfg.TRAIN.BATCH_IMAGES, -1, num_classes),
                                      name='cls_prob_reshape')
            bbox_loss = mx.sym.Reshape(data=bbox_loss, shape=(cfg.TRAIN.BATCH_IMAGES, -1, 4 * num_reg_classes),
                                       name='bbox_loss_reshape')
            group = mx.sym.Group([rpn_cls_prob, rpn_bbox_loss, cls_prob, bbox_loss, mx.sym.BlockGrad(rcnn_label)])
        else:
            cls_prob = mx.sym.SoftmaxActivation(name='cls_prob', data=cls_score)
            cls_prob = mx.sym.Reshape(data=cls_prob, shape=(cfg.TEST.BATCH_IMAGES, -1, num_classes),
                                      name='cls_prob_reshape')
            bbox_pred = mx.sym.Reshape(data=bbox_pred, shape=(cfg.TEST.BATCH_IMAGES, -1, 4 * num_reg_classes),
                                       name='bbox_pred_reshape')
            # the two attention outputs are exported next to the predictions
            group = mx.sym.Group([rois, cls_prob, bbox_pred, attention_1, attention_2])

        self.sym = group
        return group

    def init_weight_attention_multi_head(self, cfg, arg_params, aux_params, index=1):
        """Initialise the parameters of attention module number ``index``.

        Weights are drawn from N(0, 0.01) and biases are set to zero, with
        shapes looked up in ``self.arg_shape_dict`` (which is not set in this
        class -- it must be populated before this is called).

        Args:
            cfg: unused here
            arg_params: dict of parameter arrays, updated in place
            aux_params: unused here
            index: suffix identifying the attention module
        """
        # The layer order fixes the order of the random draws; keep it stable
        # so seeded initialisation is reproducible.
        for prefix in ('pair_pos_fc1_', 'query_', 'key_', 'linear_out_'):
            weight_name = prefix + str(index) + '_weight'
            bias_name = prefix + str(index) + '_bias'
            arg_params[weight_name] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict[weight_name])
            arg_params[bias_name] = mx.nd.zeros(shape=self.arg_shape_dict[bias_name])

    def init_weight_rcnn(self, cfg, arg_params, aux_params):
        """Initialise every R-CNN head parameter created in ``get_symbol``.

        Covers conv_new_1, fc_new_1, fc_new_2, cls_score, bbox_pred and both
        attention modules. Weights are drawn from N(0, 0.01), biases are zero.

        Args:
            cfg: forwarded to the attention initialiser
            arg_params: dict of parameter arrays, updated in place
            aux_params: forwarded to the attention initialiser
        """
        # Same layer order as the random draws have always been made in.
        for layer in ('conv_new_1', 'fc_new_1', 'fc_new_2', 'cls_score', 'bbox_pred'):
            weight_name = layer + '_weight'
            bias_name = layer + '_bias'
            arg_params[weight_name] = mx.random.normal(0, 0.01, shape=self.arg_shape_dict[weight_name])
            arg_params[bias_name] = mx.nd.zeros(shape=self.arg_shape_dict[bias_name])
        # attention modules are numbered from 1
        for idx in range(2):
            self.init_weight_attention_multi_head(cfg, arg_params, aux_params, index=idx + 1)

    def init_weight(self, cfg, arg_params, aux_params):
        """Initialise the RPN parameters (inherited initialiser) and then the
        R-CNN head parameters."""
        self.init_weight_rpn(cfg, arg_params, aux_params)
        self.init_weight_rcnn(cfg, arg_params, aux_params)