"""Model for InstructNeRF2NeRF"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Type

import torch
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

from nerfstudio.model_components.losses import (
    L1Loss,
    MSELoss,
    interlevel_loss,
)
from nerfstudio.models.nerfacto import NerfactoModel, NerfactoModelConfig


@dataclass
class InstructNeRF2NeRFModelConfig(NerfactoModelConfig):
    """Configuration for the InstructNeRF2NeRFModel."""

    _target: Type = field(default_factory=lambda: InstructNeRF2NeRFModel)
    use_lpips: bool = True
    """Whether to use LPIPS loss"""
    use_l1: bool = True
    """Whether to use L1 loss"""
    patch_size: int = 32
    """Patch size to use for LPIPS loss."""
    lpips_loss_mult: float = 1.0
    """Multiplier for LPIPS loss."""


class InstructNeRF2NeRFModel(NerfactoModel):
    """Model for InstructNeRF2NeRF.

    A Nerfacto model that can use an L1 (instead of MSE) RGB loss and can add an
    LPIPS loss computed on square patches of rays.
    """

    config: InstructNeRF2NeRFModelConfig

    def populate_modules(self):
        """Build the Nerfacto modules, then install the configured RGB loss and LPIPS module."""
        super().populate_modules()

        # L1 when `use_l1` is set, otherwise MSE.
        self.rgb_loss = L1Loss() if self.config.use_l1 else MSELoss()

        # NOTE(review): this unconditionally replaces any `self.lpips` created by
        # NerfactoModel.populate_modules (which may also be used for eval metrics)
        # with a default-constructed module. Kept as-is to preserve behavior; confirm
        # this override is intended.
        self.lpips = LearnedPerceptualImagePatchSimilarity()

    def _to_lpips_patches(self, rgb: torch.Tensor) -> torch.Tensor:
        """Regroup flat per-ray RGB values into channel-first patches scaled to [-1, 1].

        Args:
            rgb: Per-ray colors whose element count is a multiple of
                ``3 * patch_size * patch_size``. Presumably shaped (num_rays, 3) with
                values in [0, 1], and with consecutive rays laid out row-major inside
                each patch (as a patch-based pixel sampler would produce) — confirm
                against the datamanager.

        Returns:
            Tensor of shape (num_patches, 3, patch_size, patch_size), mapped by
            ``x * 2 - 1`` and clamped to [-1, 1].

        Raises:
            ValueError: If the values cannot be grouped into whole patches, e.g. when
                the rays were not sampled as patches.
        """
        patch_size = self.config.patch_size
        values_per_patch = patch_size * patch_size * 3
        if rgb.numel() % values_per_patch != 0:
            raise ValueError(
                f"LPIPS loss needs the ray batch to form whole {patch_size}x{patch_size} RGB patches, "
                f"but got a tensor of shape {tuple(rgb.shape)}. Sample rays as patches or disable use_lpips."
            )
        # `reshape` (not `view`) so non-contiguous inputs also work; identical result otherwise.
        patches = rgb.reshape(-1, patch_size, patch_size, 3).permute(0, 3, 1, 2)
        return (patches * 2 - 1).clamp(-1, 1)

    def get_loss_dict(self, outputs, batch, metrics_dict=None):
        """Compute the loss terms for one batch of rays.

        Args:
            outputs: Model outputs. Always reads ``"rgb"``. In training mode it also
                reads ``"weights_list"`` and ``"ray_samples_list"``. When
                ``config.predict_normals`` is set it additionally reads
                ``"rendered_orientation_loss"`` and ``"rendered_pred_normal_loss"``.
            batch: Batch dict; reads the ground-truth colors under ``"image"``.
            metrics_dict: Precomputed metrics; must contain ``"distortion"`` when the
                model is in training mode.

        Returns:
            Dict of loss tensors. It always contains ``"rgb_loss"``. It contains
            ``"lpips_loss"`` when ``config.use_lpips`` is set. In training mode it
            also contains ``"interlevel_loss"`` and ``"distortion_loss"``, plus
            ``"orientation_loss"`` and ``"pred_normal_loss"`` when normals are
            predicted.
        """
        loss_dict = {}
        image = batch["image"].to(self.device)
        loss_dict["rgb_loss"] = self.rgb_loss(image, outputs["rgb"])

        if self.config.use_lpips:
            out_patches = self._to_lpips_patches(outputs["rgb"])
            gt_patches = self._to_lpips_patches(image)
            loss_dict["lpips_loss"] = self.config.lpips_loss_mult * self.lpips(out_patches, gt_patches)

        if self.training:
            loss_dict["interlevel_loss"] = self.config.interlevel_loss_mult * interlevel_loss(
                outputs["weights_list"], outputs["ray_samples_list"]
            )
            assert metrics_dict is not None and "distortion" in metrics_dict
            loss_dict["distortion_loss"] = self.config.distortion_loss_mult * metrics_dict["distortion"]
            if self.config.predict_normals:
                # orientation loss for computed normals
                loss_dict["orientation_loss"] = self.config.orientation_loss_mult * torch.mean(
                    outputs["rendered_orientation_loss"]
                )
                # ground truth supervision for normals
                loss_dict["pred_normal_loss"] = self.config.pred_normal_loss_mult * torch.mean(
                    outputs["rendered_pred_normal_loss"]
                )
        return loss_dict