import tensorflow as tf
import six


def _variable_scope_name():
    """Return the current variable-scope name, or "" if no such API exists."""
    # `tf.get_variable_scope` is gone from the TF2 top-level namespace and lives
    # on as `tf.compat.v1.get_variable_scope`; try both so the error path in
    # `assert_rank` does not itself crash with an AttributeError.
    compat_v1 = getattr(getattr(tf, "compat", None), "v1", None)
    get_scope = getattr(compat_v1, "get_variable_scope", None)
    if get_scope is None:
        get_scope = getattr(tf, "get_variable_scope", None)
    if get_scope is None:
        return ""
    return get_scope().name


def assert_rank(tensor, expected_rank, name=None):
    """Raises an exception if the tensor rank is not of the expected rank.

    Args:
      tensor: A tf.Tensor to check the rank of.
      expected_rank: Python integer or iterable of integers, the accepted rank(s).
      name: Optional name of the tensor for the error message. Defaults to
        `tensor.name` when the tensor has one (eager tensors do not).

    Raises:
      ValueError: If the actual rank is not one of the expected ranks. This
        includes the case where the rank is statically unknown (None).
    """
    if name is None:
        # Eager tensors raise AttributeError on `.name`; fall back to None.
        name = getattr(tensor, "name", None)

    if isinstance(expected_rank, six.integer_types):
        expected_ranks = {expected_rank}
    else:
        expected_ranks = set(expected_rank)

    # `ndims` is None when the rank is not statically known.
    actual_rank = tensor.shape.ndims
    if actual_rank not in expected_ranks:
        scope_name = _variable_scope_name()
        # `%s` (not `%d`) so that an unknown rank (None) still yields the
        # intended ValueError rather than a formatting TypeError.
        raise ValueError(
            "For the tensor `%s` in scope `%s`, the actual rank "
            "`%s` (shape = %s) is not equal to the expected rank `%s`" %
            (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))


def get_shape_list_imagebert(tensor, expected_rank=None, name=None):
    """Returns a list of the shape of tensor, preferring static dimensions.

    Args:
      tensor: A tf.Tensor object to find the shape of.
      expected_rank: (optional) int or iterable of ints. The expected rank of
        `tensor`. If this is specified and the `tensor` has a different rank,
        an exception will be thrown.
      name: Optional name of the tensor for the error message.

    Returns:
      A list of dimensions of the shape of tensor. All static dimensions will
      be returned as python integers, and dynamic dimensions will be returned
      as tf.Tensor scalars.

    Raises:
      ValueError: From `assert_rank` on a rank mismatch, or from
        `TensorShape.as_list()` when the rank is entirely unknown.
    """
    if expected_rank is not None:
        # `assert_rank` resolves a missing `name` itself.
        assert_rank(tensor, expected_rank, name)

    shape = tensor.shape.as_list()
    non_static_indexes = [index for index, dim in enumerate(shape) if dim is None]
    if not non_static_indexes:
        return shape

    # Only build the dynamic shape op when at least one dim is unknown.
    dyn_shape = tf.shape(tensor)
    for index in non_static_indexes:
        shape[index] = dyn_shape[index]
    return shape


def get_shape_list(x):
    """Return the shape of `x`, using static dimensions where known.

    Args:
      x: A tf.Tensor with a statically known rank.

    Returns:
      A list with one entry per dimension. Each entry is a Python int where the
      static shape knows it, otherwise a scalar slice of `tf.shape(x)`.
    """
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]


def get_initializer(initializer_range=0.02):
    """Creates a `tf.keras.initializers.TruncatedNormal` with the given range.

    Args:
        initializer_range: float, initializer range for stddev.

    Returns:
        TruncatedNormal initializer with stddev = `initializer_range`.
    """
    return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)


def gather_indexes(sequence_tensor, positions):
    """Gathers the vectors at the specific positions over a minibatch.

    Args:
      sequence_tensor: Tensor whose first three dims are read as
        (batch_size, seq_length, width); the reshape below requires it to be
        exactly rank 3.
      positions: Integer tensor of positions along axis 1 of `sequence_tensor`.
        It is added to a [batch_size, 1] offset column, so presumably shaped
        [batch_size, num_positions] (TODO confirm with callers). Any integer
        dtype is accepted; it is cast to int64.

    Returns:
      A [batch_size * num_positions, width] tensor of the gathered vectors
      (assuming `positions` is [batch_size, num_positions]).
    """
    sequence_shape = get_shape_list(sequence_tensor)
    # All index/shape arithmetic is done in int64 so that mixed static (Python
    # int) and dynamic (int32 tensor) dimensions can be combined safely.
    batch_size = tf.cast(sequence_shape[0], tf.int64)
    seq_length = tf.cast(sequence_shape[1], tf.int64)
    width = tf.cast(sequence_shape[2], tf.int64)
    positions = tf.cast(positions, tf.int64)

    # Offset of each example's first row once the batch is flattened.
    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int64) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(sequence_tensor,
                                      [batch_size * seq_length, width])
    output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
    return output_tensor