""" PyTorch IndicTrans model."""import mathfrom typing import List, Optional, Tuple, Unionimport torchimport torch.nn as nnfrom torch.nn import functional as Ffrom transformers.activations import ACT2FNfrom transformers.integrations.deepspeed import is_deepspeed_zero3_enabledfrom transformers.modeling_outputs import (    BaseModelOutput,    BaseModelOutputWithPastAndCrossAttentions,    Seq2SeqLMOutput,    Seq2SeqModelOutput,)from transformers.utils import loggingfrom transformers.modeling_utils import PreTrainedModelfrom configuration_indictrans import IndicTransConfiglogger = logging.get_logger(__name__)_CONFIG_FOR_DOC = "IndicTransConfig"INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]# Copied from transformers.models.bart.modeling_bart.shift_tokens_rightdef shift_tokens_right(    input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):    """    Shift input ids one token to the right.    """    shifted_input_ids = input_ids.new_zeros(input_ids.shape)    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()    shifted_input_ids[:, 0] = decoder_start_token_id    if pad_token_id is None:        raise ValueError("self.model.config.pad_token_id has to be defined.")    # replace possible -100 values in labels by `pad_token_id`    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)    return shifted_input_ids# Copied from transformers.models.bart.modeling_bart._make_causal_maskdef _make_causal_mask(    input_ids_shape: torch.Size,    dtype: torch.dtype,    device: torch.device,    past_key_values_length: int = 0,):    """    Make causal mask used for bi-directional self-attention.    """    bsz, tgt_len = input_ids_shape    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)    mask_cond = torch.arange(mask.size(-1), device=device)    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)    mask = mask.to(dtype)    if past_key_values_length > 0:        mask = torch.cat(            [                torch.zeros(                    tgt_len, past_key_values_length, dtype=dtype, device=device                ),                mask,            ],            dim=-1,        )    return mask[None, None, :, :].expand(        bsz, 1, tgt_len, tgt_len + past_key_values_length    )# Copied from transformers.models.bart.modeling_bart._expand_maskdef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):    """    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.    """    bsz, src_len = mask.size()    tgt_len = tgt_len if tgt_len is not None else src_len    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)    inverted_mask = 1.0 - expanded_mask    return inverted_mask.masked_fill(        inverted_mask.to(torch.bool), torch.finfo(dtype).min    )def create_position_ids_from_input_ids(    input_ids, padding_idx, past_key_values_length=0):    """    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols    are ignored. This is modified from fairseq's `utils.make_positions`.    """    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.    mask = input_ids.ne(padding_idx).int()    incremental_indices = (        torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length    ) * mask    return incremental_indices.long() + padding_idx# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding->IndicTransclass IndicTransSinusoidalPositionalEmbedding(nn.Module):    """This module produces sinusoidal positional embeddings of any length."""    def __init__(        self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None    ):        super().__init__()        self.offset = 2        self.embedding_dim = embedding_dim        self.padding_idx = padding_idx        self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)    def make_weights(        self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None    ):        emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)        if hasattr(self, "weights"):            # in forward put the weights on the correct dtype and device of the param            emb_weights = emb_weights.to(                dtype=self.weights.dtype, device=self.weights.device            )        self.register_buffer("weights", emb_weights, persistent=False)    @staticmethod    def get_embedding(        num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None    ):        """        Build sinusoidal embeddings.        This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of        "Attention Is All You Need".        """        half_dim = embedding_dim // 2        emb = math.log(10000) / (half_dim - 1)        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(            1        ) * emb.unsqueeze(0)        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(            num_embeddings, -1        )        if embedding_dim % 2 == 1:            # zero pad            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)        if padding_idx is not None:            emb[padding_idx, :] = 0        return emb.to(torch.get_default_dtype())    @torch.no_grad()    def forward(        self,        input_ids: torch.Tensor = None,        inputs_embeds: torch.Tensor = None,        past_key_values_length: int = 0,    ):        if input_ids is not None:            bsz, seq_len = input_ids.size()            # Create the position ids from the input token ids. Any padded tokens remain padded.            position_ids = create_position_ids_from_input_ids(                input_ids, self.padding_idx, past_key_values_length            ).to(input_ids.device)        else:            bsz, seq_len = inputs_embeds.size()[:-1]            position_ids = self.create_position_ids_from_inputs_embeds(                inputs_embeds, past_key_values_length            )        # expand embeddings if needed        max_pos = self.padding_idx + 1 + seq_len + past_key_values_length        if max_pos > self.weights.size(0):            self.make_weights(                max_pos + self.offset, self.embedding_dim, self.padding_idx            )        return (            self.weights.index_select(0, position_ids.view(-1))            .view(bsz, seq_len, self.weights.shape[-1])            .detach()        )    def create_position_ids_from_inputs_embeds(        self, inputs_embeds, past_key_values_length    ):        """        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.        Args:            inputs_embeds: torch.Tensor        Returns: torch.Tensor        """        input_shape = inputs_embeds.size()[:-1]        sequence_length = input_shape[1]        position_ids = torch.arange(            self.padding_idx + 1,            sequence_length + self.padding_idx + 1,            dtype=torch.long,            device=inputs_embeds.device,        )        return (            position_ids.unsqueeze(0).expand(input_shape).contiguous()            + past_key_values_length        )# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->IndicTransclass IndicTransAttention(nn.Module):    """Multi-headed attention from 'Attention Is All You Need' paper"""    def __init__(        self,        embed_dim: int,        num_heads: int,        dropout: float = 0.0,        is_decoder: bool = False,        bias: bool = True,    ):        super().__init__()        self.embed_dim = embed_dim        self.num_heads = num_heads        self.dropout = dropout        self.head_dim = embed_dim // num_heads        if (self.head_dim * num_heads) != self.embed_dim:            raise ValueError(                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"                f" and `num_heads`: {num_heads})."            )        self.scaling = self.head_dim**-0.5        self.is_decoder = is_decoder        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):        return (            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)            .transpose(1, 2)            .contiguous()        )    def forward(        self,        hidden_states: torch.Tensor,        key_value_states: Optional[torch.Tensor] = None,        past_key_value: Optional[Tuple[torch.Tensor]] = None,        attention_mask: Optional[torch.Tensor] = None,        layer_head_mask: Optional[torch.Tensor] = None,        output_attentions: bool = False,    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:        """Input shape: Batch x Time x Channel"""        # if key_value_states are provided this layer is used as a cross-attention layer        # for the decoder        is_cross_attention = key_value_states is not None        bsz, tgt_len, _ = hidden_states.size()        # get query proj        query_states = self.q_proj(hidden_states) * self.scaling        # get key, value proj        # `past_key_value[0].shape[2] == key_value_states.shape[1]`        # is checking that the `sequence_length` of the `past_key_value` is the same as        # the provided `key_value_states` to support prefix tuning        if (            is_cross_attention            and past_key_value is not None            and past_key_value[0].shape[2] == key_value_states.shape[1]        ):            # reuse k,v, cross_attentions            key_states = past_key_value[0]            value_states = past_key_value[1]        elif is_cross_attention:            # cross_attentions            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)        elif past_key_value is not None:            # reuse k, v, self_attention            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)            key_states = torch.cat([past_key_value[0], key_states], dim=2)            value_states = torch.cat([past_key_value[1], value_states], dim=2)        else:            # self_attention            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)        if self.is_decoder:            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.            # Further calls to cross_attention layer can then reuse all cross-attention            # key/value_states (first "if" case)            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of            # all previous decoder key/value_states. Further calls to uni-directional self-attention            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)            # if encoder bi-directional self-attention `past_key_value` is always `None`            past_key_value = (key_states, value_states)        proj_shape = (bsz * self.num_heads, -1, self.head_dim)        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)        key_states = key_states.reshape(*proj_shape)        value_states = value_states.reshape(*proj_shape)        src_len = key_states.size(1)        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):            raise ValueError(                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"                f" {attn_weights.size()}"            )        if attention_mask is not None:            if attention_mask.size() != (bsz, 1, tgt_len, src_len):                raise ValueError(                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"                )            attn_weights = (                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)                + attention_mask            )            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)        attn_weights = F.softmax(attn_weights, dim=-1)        if layer_head_mask is not None:            if layer_head_mask.size() != (self.num_heads,):                raise ValueError(                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"                    f" {layer_head_mask.size()}"                )            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(                bsz, self.num_heads, tgt_len, src_len            )            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)        if output_attentions:            # this operation is a bit awkward, but it's required to            # make sure that attn_weights keeps its gradient.            # In order to do so, attn_weights have to be reshaped            # twice and have to be reused in the following            attn_weights_reshaped = attn_weights.view(                bsz, self.num_heads, tgt_len, src_len            )            attn_weights = attn_weights_reshaped.view(                bsz * self.num_heads, tgt_len, src_len            )        else:            attn_weights_reshaped = None        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)        attn_output = torch.bmm(attn_probs, value_states)        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):            raise ValueError(                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"                f" {attn_output.size()}"            )        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)        attn_output = attn_output.transpose(1, 2)        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be        # partitioned across GPUs when using tensor-parallelism.        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)        attn_output = self.out_proj(attn_output)        return attn_output, attn_weights_reshaped, past_key_value# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->IndicTransclass IndicTransEncoderLayer(nn.Module):    def __init__(self, config: IndicTransConfig):        super().__init__()        self.embed_dim = config.encoder_embed_dim        self.self_attn = IndicTransAttention(            embed_dim=self.embed_dim,            num_heads=config.encoder_attention_heads,            dropout=config.attention_dropout,        )        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)        self.dropout = config.dropout        self.activation_fn = ACT2FN[config.activation_function]        self.activation_dropout = config.activation_dropout        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)        self.final_layer_norm = nn.LayerNorm(self.embed_dim)        self.normalize_before = config.encoder_normalize_before    def forward(        self,        hidden_states: torch.Tensor,        attention_mask: torch.Tensor,        layer_head_mask: torch.Tensor,        output_attentions: bool = False,    ) -> torch.Tensor:        """        Args:            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`            attention_mask (`torch.FloatTensor`): attention mask of size                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size                `(encoder_attention_heads,)`.            output_attentions (`bool`, *optional*):                Whether or not to return the attentions tensors of all attention layers. See `attentions` under                returned tensors for more detail.        """        residual = hidden_states        if self.normalize_before:            hidden_states = self.self_attn_layer_norm(hidden_states)        hidden_states, attn_weights, _ = self.self_attn(            hidden_states=hidden_states,            attention_mask=attention_mask,            layer_head_mask=layer_head_mask,            output_attentions=output_attentions,        )        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)        hidden_states = residual + hidden_states        if not self.normalize_before:            hidden_states = self.self_attn_layer_norm(hidden_states)        residual = hidden_states        if self.normalize_before:            hidden_states = self.final_layer_norm(hidden_states)        hidden_states = self.activation_fn(self.fc1(hidden_states))        hidden_states = F.dropout(            hidden_states, p=self.activation_dropout, training=self.training        )        hidden_states = self.fc2(hidden_states)        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)        hidden_states = residual + hidden_states        if not self.normalize_before:            hidden_states = self.final_layer_norm(hidden_states)        if hidden_states.dtype == torch.float16 and (            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()        ):            clamp_value = torch.finfo(hidden_states.dtype).max - 1000            hidden_states = torch.clamp(                hidden_states, min=-clamp_value, max=clamp_value            )        outputs = (hidden_states,)        if output_attentions:            outputs += (attn_weights,)        return outputs# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->IndicTransclass IndicTransDecoderLayer(nn.Module):    def __init__(self, config: IndicTransConfig):        super().__init__()        self.embed_dim = config.decoder_embed_dim        self.self_attn = IndicTransAttention(            embed_dim=self.embed_dim,            num_heads=config.decoder_attention_heads,            dropout=config.attention_dropout,            is_decoder=True,        )        self.dropout = config.dropout        self.activation_fn = ACT2FN[config.activation_function]        self.activation_dropout = config.activation_dropout        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)        self.encoder_attn = IndicTransAttention(            self.embed_dim,            config.decoder_attention_heads,            dropout=config.attention_dropout,            is_decoder=True,        )        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)        self.final_layer_norm = nn.LayerNorm(self.embed_dim)        self.normalize_before = config.decoder_normalize_before    def forward(        self,        hidden_states: torch.Tensor,        attention_mask: Optional[torch.Tensor] = None,        encoder_hidden_states: Optional[torch.Tensor] = None,        encoder_attention_mask: Optional[torch.Tensor] = None,        layer_head_mask: Optional[torch.Tensor] = None,        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,        past_key_value: Optional[Tuple[torch.Tensor]] = None,        output_attentions: Optional[bool] = False,        use_cache: Optional[bool] = True,    ) -> torch.Tensor:        """        Args:            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`            attention_mask (`torch.FloatTensor`): attention mask of size                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.            encoder_hidden_states (`torch.FloatTensor`):                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size                `(encoder_attention_heads,)`.            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of                size `(decoder_attention_heads,)`.            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states            output_attentions (`bool`, *optional*):                Whether or not to return the attentions tensors of all attention layers. See `attentions` under                returned tensors for more detail.        """        residual = hidden_states        if self.normalize_before:            hidden_states = self.self_attn_layer_norm(hidden_states)        # Self Attention        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2        self_attn_past_key_value = (            past_key_value[:2] if past_key_value is not None else None        )        # add present self-attn cache to positions 1,2 of present_key_value tuple        hidden_states, self_attn_weights, present_key_value = self.self_attn(            hidden_states=hidden_states,            past_key_value=self_attn_past_key_value,            attention_mask=attention_mask,            layer_head_mask=layer_head_mask,            output_attentions=output_attentions,        )        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)        hidden_states = residual + hidden_states        if not self.normalize_before:            hidden_states = self.self_attn_layer_norm(hidden_states)        # Cross-Attention Block        cross_attn_present_key_value = None        cross_attn_weights = None        if encoder_hidden_states is not None:            residual = hidden_states            if self.normalize_before:                hidden_states = self.encoder_attn_layer_norm(hidden_states)            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple            cross_attn_past_key_value = (                past_key_value[-2:] if past_key_value is not None else None            )            (                hidden_states,                cross_attn_weights,                cross_attn_present_key_value,            ) = self.encoder_attn(                hidden_states=hidden_states,                key_value_states=encoder_hidden_states,                attention_mask=encoder_attention_mask,                layer_head_mask=cross_attn_layer_head_mask,                past_key_value=cross_attn_past_key_value,                output_attentions=output_attentions,            )            hidden_states = F.dropout(                hidden_states, p=self.dropout, training=self.training            )            hidden_states = residual + hidden_states            if not self.normalize_before:                hidden_states = self.encoder_attn_layer_norm(hidden_states)            # add cross-attn to positions 3,4 of present_key_value tuple            present_key_value = present_key_value + cross_attn_present_key_value        # Fully Connected        residual = hidden_states        if self.normalize_before:            hidden_states = self.final_layer_norm(hidden_states)        hidden_states = self.activation_fn(self.fc1(hidden_states))        hidden_states = F.dropout(            hidden_states, p=self.activation_dropout, training=self.training        )        hidden_states = self.fc2(hidden_states)        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)        hidden_states = residual + hidden_states        if not self.normalize_before:            hidden_states = self.final_layer_norm(hidden_states)        outputs = (hidden_states,)        if output_attentions:            outputs += (self_attn_weights, cross_attn_weights)        if use_cache:            outputs += (present_key_value,)        return outputs# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100PretrainedModel->IndicTransclass IndicTransPreTrainedModel(PreTrainedModel):    config_class = IndicTransConfig    base_model_prefix = "model"    supports_gradient_checkpointing = True    _no_split_modules = ["IndicTransAttention"]    def _init_weights(self, module):        std = self.config.init_std        if isinstance(module, nn.Linear):            module.weight.data.normal_(mean=0.0, std=std)            if module.bias is not None:                module.bias.data.zero_()        elif isinstance(module, nn.Embedding):            module.weight.data.normal_(mean=0.0, std=std)            if module.padding_idx is not None:                module.weight.data[module.padding_idx].zero_()    def _set_gradient_checkpointing(self, module, value=False):        if isinstance(module, (IndicTransDecoder, IndicTransEncoder)):            module.gradient_checkpointing = value# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100EncoderLayer->IndicTransclass IndicTransEncoder(IndicTransPreTrainedModel):    """    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a    [`IndicTransEncoderLayer`].    Args:        config: IndicTransConfig        embed_tokens (nn.Embedding): output embedding    """    def __init__(        self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None    ):        super().__init__(config)        self.dropout = config.dropout        self.layerdrop = config.encoder_layerdrop        embed_dim = config.encoder_embed_dim        self.padding_idx = config.pad_token_id        self.max_source_positions = config.max_source_positions        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0        self.embed_tokens = nn.Embedding(            config.encoder_vocab_size, embed_dim, self.padding_idx        )        if embed_tokens is not None:            self.embed_tokens.weight = embed_tokens.weight        self.embed_positions = IndicTransSinusoidalPositionalEmbedding(            config.max_source_positions,            embed_dim,            self.padding_idx,        )        self.layers = nn.ModuleList(            [IndicTransEncoderLayer(config) for _ in range(config.encoder_layers)]        )        self.layer_norm = (            nn.LayerNorm(embed_dim) if config.encoder_normalize_before else None        )        self.layernorm_embedding = (            nn.LayerNorm(embed_dim) if config.layernorm_embedding else None        )        self.gradient_checkpointing = False        # Initialize weights and apply final processing        self.post_init()    def forward(        self,        input_ids: Optional[torch.Tensor] = None,        attention_mask: Optional[torch.Tensor] = None,        head_mask: Optional[torch.Tensor] = None,        inputs_embeds: Optional[torch.Tensor] = None,        output_attentions: Optional[bool] = None,        output_hidden_states: Optional[bool] = None,        return_dict: Optional[bool] = None,    ):        r"""        Args:            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you                provide it.                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and                [`PreTrainedTokenizer.__call__`] for details.                [What are input IDs?](../glossary#input-ids)            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:                - 1 for tokens that are **not masked**,                - 0 for tokens that are **masked**.                [What are attention masks?](../glossary#attention-mask)            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:                - 1 indicates the head is **not masked**,                - 0 indicates the head is **masked**.            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.                This is useful if you want more control over how to convert `input_ids` indices into associated vectors                than the model's internal embedding lookup matrix.            output_attentions (`bool`, *optional*):                Whether or not to return the attentions tensors of all attention layers. See `attentions` under                returned tensors for more detail.            output_hidden_states (`bool`, *optional*):                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors                for more detail.            return_dict (`bool`, *optional*):                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.        """        output_attentions = (            output_attentions            if output_attentions is not None            else self.config.output_attentions        )        output_hidden_states = (            output_hidden_states            if output_hidden_states is not None            else self.config.output_hidden_states        )        return_dict = (            return_dict if return_dict is not None else self.config.use_return_dict        )        # retrieve input_ids and inputs_embeds        if input_ids is not None and inputs_embeds is not None:            raise ValueError(                "You cannot specify both input_ids and inputs_embeds at the same time"            )        elif input_ids is not None:            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)            input_shape = input_ids.size()            input_ids = input_ids.view(-1, input_shape[-1])        elif inputs_embeds is not None:            input_shape = inputs_embeds.size()[:-1]        else:            raise ValueError("You have to specify either input_ids or inputs_embeds")        if inputs_embeds is None:            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale        embed_pos = self.embed_positions(input_ids, inputs_embeds)        embed_pos = embed_pos.to(inputs_embeds.device)        hidden_states = inputs_embeds + embed_pos        if self.layernorm_embedding is not None:            x = self.layernorm_embedding(hidden_states)        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)        # expand attention_mask        if attention_mask is not None:            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)        encoder_states = () if output_hidden_states else None        all_attentions = () if output_attentions else None        # check if head_mask has a correct number of layers specified if desired        if head_mask is not None:            if head_mask.size()[0] != len(self.layers):                raise ValueError(                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"                    f" {head_mask.size()[0]}."                )        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()        for idx, encoder_layer in enumerate(self.layers):            if output_hidden_states:                encoder_states = encoder_states + (hidden_states,)            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)            dropout_probability = torch.rand([])            skip_the_layer = (                True                if self.training and (dropout_probability < self.layerdrop)                else False            )            if not skip_the_layer or deepspeed_zero3_is_enabled:                # under deepspeed zero3 all gpus must run in sync                if self.gradient_checkpointing and self.training:                    # create gradient checkpointing function                    def create_custom_forward(module):                        def custom_forward(*inputs):                            return module(*inputs, output_attentions)                        return custom_forward                    layer_outputs = torch.utils.checkpoint.checkpoint(                        create_custom_forward(encoder_layer),                        hidden_states,                        attention_mask,                        (head_mask[idx] if head_mask is not None else None),                    )                else:                    layer_outputs = encoder_layer(                        hidden_states,                        attention_mask,                        layer_head_mask=(                            head_mask[idx] if head_mask is not None else None                        ),                        output_attentions=output_attentions,                    )                hidden_states = layer_outputs[0]            if skip_the_layer:                layer_outputs = (None, None)            if output_attentions:                all_attentions = all_attentions + (layer_outputs[1],)        if self.layer_norm is not None:            hidden_states = self.layer_norm(hidden_states)        if output_hidden_states:            encoder_states = encoder_states + (hidden_states,)        if not return_dict:            return tuple(                v                for v in [hidden_states, encoder_states, all_attentions]                if v is not None            )        return BaseModelOutput(            last_hidden_state=hidden_states,            hidden_states=encoder_states,            attentions=all_attentions,        )# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100DecoderLayer->IndicTransclass IndicTransDecoder(IndicTransPreTrainedModel):    """    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`IndicTransDecoderLayer`]    Args:        config: IndicTransConfig        embed_tokens (nn.Embedding): output embedding    """    def __init__(        self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None    ):        super().__init__(config)        self.dropout = config.dropout        self.layerdrop = config.decoder_layerdrop        embed_dim = config.encoder_embed_dim        self.padding_idx = config.pad_token_id        self.max_target_positions = config.max_target_positions        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0        self.embed_tokens = nn.Embedding(            config.decoder_vocab_size, embed_dim, self.padding_idx        )        if embed_tokens is not None:            self.embed_tokens.weight = embed_tokens.weight        self.embed_positions = IndicTransSinusoidalPositionalEmbedding(            config.max_target_positions,            embed_dim,            self.padding_idx,        )        self.layers = nn.ModuleList(            [IndicTransDecoderLayer(config) for _ in range(config.decoder_layers)]        )        self.layer_norm = (            nn.LayerNorm(embed_dim) if config.decoder_normalize_before else None        )        self.layernorm_embedding = (            nn.LayerNorm(embed_dim) if config.layernorm_embedding else None        )        self.gradient_checkpointing = False        # Initialize weights and apply final processing        self.post_init()    def forward(        self,        input_ids: Optional[torch.Tensor] = None,        attention_mask: Optional[torch.Tensor] = None,        encoder_hidden_states: Optional[torch.Tensor] = None,        encoder_attention_mask: Optional[torch.Tensor] = None,        head_mask: Optional[torch.Tensor] = None,        cross_attn_head_mask: Optional[torch.Tensor] = None,        past_key_values: Optional[List[torch.FloatTensor]] = None,        inputs_embeds: Optional[torch.Tensor] = None,        use_cache: Optional[bool] = None,        output_attentions: Optional[bool] = None,        output_hidden_states: Optional[bool] = None,        return_dict: Optional[bool] = None,    ):        r"""        Args:            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you                provide it.                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and                [`PreTrainedTokenizer.__call__`] for details.                [What are input IDs?](../glossary#input-ids)            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:                - 1 for tokens that are **not masked**,                - 0 for tokens that are **masked**.                [What are attention masks?](../glossary#attention-mask)            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention                of the decoder.            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values                selected in `[0, 1]`:                - 1 for tokens that are **not masked**,                - 0 for tokens that are **masked**.                [What are attention masks?](../glossary#attention-mask)            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:                - 1 indicates the head is **not masked**,                - 0 indicates the head is **masked**.            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):                Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing                cross-attention on hidden heads. Mask values selected in `[0, 1]`:                - 1 indicates the head is **not masked**,                - 0 indicates the head is **masked**.            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of                all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of                shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing                `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more                control over how to convert `input_ids` indices into associated vectors than the model's internal                embedding lookup matrix.            output_attentions (`bool`, *optional*):                Whether or not to return the attentions tensors of all attention layers. See `attentions` under                returned tensors for more detail.            output_hidden_states (`bool`, *optional*):                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors                for more detail.            return_dict (`bool`, *optional*):                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.        """        output_attentions = (            output_attentions            if output_attentions is not None            else self.config.output_attentions        )        output_hidden_states = (            output_hidden_states            if output_hidden_states is not None            else self.config.output_hidden_states        )        use_cache = use_cache if use_cache is not None else self.config.use_cache        return_dict = (            return_dict if return_dict is not None else self.config.use_return_dict        )        # retrieve input_ids and inputs_embeds        if input_ids is not None and inputs_embeds is not None:            raise ValueError(                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"            )        elif input_ids is not None:            input_shape = input_ids.size()            input_ids = input_ids.view(-1, input_shape[-1])        elif inputs_embeds is not None:            input_shape = inputs_embeds.size()[:-1]        else:            raise ValueError(                "You have to specify either decoder_input_ids or decoder_inputs_embeds"            )        # past_key_values_length        past_key_values_length = (            past_key_values[0][0].shape[2] if past_key_values is not None else 0        )        if inputs_embeds is None:            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale        # create causal mask        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]        combined_attention_mask = None        if input_shape[-1] > 1:            combined_attention_mask = _make_causal_mask(                input_shape,                inputs_embeds.dtype,                device=inputs_embeds.device,                past_key_values_length=past_key_values_length,            )        if attention_mask is not None and combined_attention_mask is not None:            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]            combined_attention_mask = combined_attention_mask + _expand_mask(                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]            )        # expand encoder attention mask        if encoder_hidden_states is not None and encoder_attention_mask is not None:            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]            encoder_attention_mask = _expand_mask(                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]            )        # embed positions        positions = self.embed_positions(            input_ids, inputs_embeds, past_key_values_length        )        positions = positions.to(inputs_embeds.device)        hidden_states = inputs_embeds + positions        if self.layernorm_embedding is not None:            hidden_states = self.layernorm_embedding(hidden_states)        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)        if self.gradient_checkpointing and self.training:            if use_cache:                logger.warning_once(                    "`use_cache=True` is incompatible with gradient checkpointing. Setting"                    " `use_cache=False`..."                )                use_cache = False        # decoder layers        all_hidden_states = () if output_hidden_states else None        all_self_attns = () if output_attentions else None        all_cross_attentions = () if output_attentions else None        next_decoder_cache = () if use_cache else None        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired        for attn_mask, mask_name in zip(            [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]        ):            if attn_mask is not None:                if attn_mask.size()[0] != len(self.layers):                    raise ValueError(                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"                        f" {head_mask.size()[0]}."                    )        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()        for idx, decoder_layer in enumerate(self.layers):            if output_hidden_states:                all_hidden_states += (hidden_states,)            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)            dropout_probability = torch.rand([])            skip_the_layer = (                True                if self.training and (dropout_probability < self.layerdrop)                else False            )            if not skip_the_layer or deepspeed_zero3_is_enabled:                # under deepspeed zero3 all gpus must run in sync                past_key_value = (                    past_key_values[idx] if past_key_values is not None else None                )                if self.gradient_checkpointing and self.training:                    def create_custom_forward(module):                        def custom_forward(*inputs):                            # None for past_key_value                            return module(*inputs, output_attentions, use_cache)                        return custom_forward                    layer_outputs = torch.utils.checkpoint.checkpoint(                        create_custom_forward(decoder_layer),                        hidden_states,                        combined_attention_mask,                        encoder_hidden_states,                        encoder_attention_mask,                        head_mask[idx] if head_mask is not None else None,                        cross_attn_head_mask[idx]                        if cross_attn_head_mask is not None                        else None,                        None,                    )                else:                    layer_outputs = decoder_layer(                        hidden_states,                        attention_mask=combined_attention_mask,                        encoder_hidden_states=encoder_hidden_states,                        encoder_attention_mask=encoder_attention_mask,                        layer_head_mask=(                            head_mask[idx] if head_mask is not None else None                        ),                        cross_attn_layer_head_mask=(                            cross_attn_head_mask[idx]                            if cross_attn_head_mask is not None                            else None                        ),                        past_key_value=past_key_value,                        output_attentions=output_attentions,                        use_cache=use_cache,                    )                hidden_states = layer_outputs[0]            if skip_the_layer:                continue            if use_cache:                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)            if output_attentions:                all_self_attns += (layer_outputs[1],)                all_cross_attentions += (layer_outputs[2],)        if self.layer_norm is not None:            hidden_states = self.layer_norm(hidden_states)        # add hidden states from the last decoder layer        if output_hidden_states:            all_hidden_states += (hidden_states,)        next_cache = next_decoder_cache if use_cache else None        if not return_dict:            return tuple(                v                for v in [                    hidden_states,                    next_cache,                    all_hidden_states,                    all_self_attns,                    all_cross_attentions,                ]                if v is not None            )        return BaseModelOutputWithPastAndCrossAttentions(            last_hidden_state=hidden_states,            past_key_values=next_cache,            hidden_states=all_hidden_states,            attentions=all_self_attns,            cross_attentions=all_cross_attentions,        )# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Model->IndicTransclass IndicTransModel(IndicTransPreTrainedModel):    _tied_weights_keys = None    def __init__(self, config: IndicTransConfig):        super().__init__(config)        self.encoder = IndicTransEncoder(config)        self.decoder = IndicTransDecoder(config)        # Initialize weights and apply final processing        self.post_init()    def get_encoder(self):        return self.encoder    def get_decoder(self):        return self.decoder    def forward(        self,        input_ids: Optional[torch.LongTensor] = None,        attention_mask: Optional[torch.Tensor] = None,        decoder_input_ids: Optional[torch.LongTensor] = None,        decoder_attention_mask: Optional[torch.LongTensor] = None,        head_mask: Optional[torch.Tensor] = None,        decoder_head_mask: Optional[torch.Tensor] = None,        cross_attn_head_mask: Optional[torch.Tensor] = None,        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,        inputs_embeds: Optional[torch.FloatTensor] = None,        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,        use_cache: Optional[bool] = None,        output_attentions: Optional[bool] = None,        output_hidden_states: Optional[bool] = None,        return_dict: Optional[bool] = None,    ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:        output_attentions = (            output_attentions            if output_attentions is not None            else self.config.output_attentions        )        output_hidden_states = (            output_hidden_states            if output_hidden_states is not None            else self.config.output_hidden_states        )        use_cache = use_cache if use_cache is not None else self.config.use_cache        return_dict = (            return_dict if return_dict is not None else self.config.use_return_dict        )        if encoder_outputs is None:            encoder_outputs = self.encoder(                input_ids=input_ids,                attention_mask=attention_mask,                head_mask=head_mask,                inputs_embeds=inputs_embeds,                output_attentions=output_attentions,                output_hidden_states=output_hidden_states,                return_dict=return_dict,            )        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):            encoder_outputs = BaseModelOutput(                last_hidden_state=encoder_outputs[0],                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,            )        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)        decoder_outputs = self.decoder(            input_ids=decoder_input_ids,            attention_mask=decoder_attention_mask,            encoder_hidden_states=encoder_outputs[0],            encoder_attention_mask=attention_mask,            head_mask=decoder_head_mask,            cross_attn_head_mask=cross_attn_head_mask,            past_key_values=past_key_values,            inputs_embeds=decoder_inputs_embeds,            use_cache=use_cache,            output_attentions=output_attentions,            output_hidden_states=output_hidden_states,            return_dict=return_dict,        )        if not return_dict:            return decoder_outputs + encoder_outputs        return Seq2SeqModelOutput(            last_hidden_state=decoder_outputs.last_hidden_state,            past_key_values=decoder_outputs.past_key_values,            decoder_hidden_states=decoder_outputs.hidden_states,            decoder_attentions=decoder_outputs.attentions,            cross_attentions=decoder_outputs.cross_attentions,            encoder_last_hidden_state=encoder_outputs.last_hidden_state,            encoder_hidden_states=encoder_outputs.hidden_states,            encoder_attentions=encoder_outputs.attentions,        )# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->IndicTransclass IndicTransForConditionalGeneration(IndicTransPreTrainedModel):    base_model_prefix = "model"    _tied_weights_keys = None    _label_smoothing = 0.0    def __init__(self, config: IndicTransConfig):        super().__init__(config)        self.model = IndicTransModel(config)        self.lm_head = nn.Linear(            config.decoder_embed_dim, config.decoder_vocab_size, bias=False        )        if config.share_decoder_input_output_embed:            self.lm_head.weight = self.model.decoder.embed_tokens.weight        self.post_init()    def tie_weights(self):        pass    def get_encoder(self):        return self.model.get_encoder()    def get_decoder(self):        return self.model.get_decoder()    def get_output_embeddings(self):        return self.lm_head    def set_output_embeddings(self, new_embeddings):        self.lm_head = new_embeddings        def set_label_smoothing(self, label_smoothing):        self._label_smoothing = label_smoothing    def forward(        self,        input_ids: Optional[torch.LongTensor] = None,        attention_mask: Optional[torch.Tensor] = None,        decoder_input_ids: Optional[torch.LongTensor] = None,        decoder_attention_mask: Optional[torch.LongTensor] = None,        head_mask: Optional[torch.Tensor] = None,        decoder_head_mask: Optional[torch.Tensor] = None,        cross_attn_head_mask: Optional[torch.Tensor] = None,        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,        inputs_embeds: Optional[torch.FloatTensor] = None,        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,        labels: Optional[torch.LongTensor] = None,        use_cache: Optional[bool] = None,        output_attentions: Optional[bool] = None,        output_hidden_states: Optional[bool] = None,        return_dict: Optional[bool] = None,    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:        r"""        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.        Returns:        """        return_dict = (            return_dict if return_dict is not None else self.config.use_return_dict        )        if labels is not None:            if decoder_input_ids is None:                decoder_input_ids = shift_tokens_right(                    labels, self.config.pad_token_id, self.config.decoder_start_token_id                )        outputs = self.model(            input_ids,            attention_mask=attention_mask,            decoder_input_ids=decoder_input_ids,            encoder_outputs=encoder_outputs,            decoder_attention_mask=decoder_attention_mask,            head_mask=head_mask,            decoder_head_mask=decoder_head_mask,            cross_attn_head_mask=cross_attn_head_mask,            past_key_values=past_key_values,            inputs_embeds=inputs_embeds,            decoder_inputs_embeds=decoder_inputs_embeds,            use_cache=use_cache,            output_attentions=output_attentions,            output_hidden_states=output_hidden_states,            return_dict=return_dict,        )        lm_logits = self.lm_head(outputs[0])        masked_lm_loss = None        if labels is not None:            # move labels to the correct device to enable PP            labels = labels.to(lm_logits.device)            masked_lm_loss = F.cross_entropy(                input=lm_logits.view(-1, self.config.decoder_vocab_size),                target=labels.view(-1),                ignore_index=self.config.pad_token_id,                label_smoothing=self._label_smoothing,            )        if not return_dict:            output = (lm_logits,) + outputs[1:]            return (                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output            )        return Seq2SeqLMOutput(            loss=masked_lm_loss,            logits=lm_logits,            past_key_values=outputs.past_key_values,            decoder_hidden_states=outputs.decoder_hidden_states,            decoder_attentions=outputs.decoder_attentions,            cross_attentions=outputs.cross_attentions,            encoder_last_hidden_state=outputs.encoder_last_hidden_state,            encoder_hidden_states=outputs.encoder_hidden_states,            encoder_attentions=outputs.encoder_attentions,        )    def prepare_inputs_for_generation(        self,        decoder_input_ids,        past_key_values=None,        attention_mask=None,        head_mask=None,        decoder_head_mask=None,        cross_attn_head_mask=None,        use_cache=None,        encoder_outputs=None,        **kwargs,    ):        # cut decoder_input_ids if past is used        if past_key_values is not None:            decoder_input_ids = decoder_input_ids[:, -1:]        return {            "input_ids": None,  # encoder_outputs is defined. input_ids not needed            "encoder_outputs": encoder_outputs,            "past_key_values": past_key_values,            "decoder_input_ids": decoder_input_ids,            "attention_mask": attention_mask,            "head_mask": head_mask,            "decoder_head_mask": decoder_head_mask,            "cross_attn_head_mask": cross_attn_head_mask,            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)        }    @staticmethod    def _reorder_cache(past_key_values, beam_idx):        reordered_past = ()        for layer_past in past_key_values:            reordered_past += (                tuple(                    past_state.index_select(0, beam_idx) for past_state in layer_past                ),            )        return reordered_past