"""
StarGAN v2
Copyright (c) 2020-present NAVER Corp.

This work is licensed under the Creative Commons Attribution-NonCommercial
4.0 International License. To view a copy of this license, visit
http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""

import os
import torch


class CheckpointIO(object):
    """Save and restore the state dicts of a set of named modules.

    Modules are registered by keyword name (at construction time or later via
    ``register``). ``save`` writes one file holding ``{name: state_dict}`` for
    every registered module, and ``load`` copies the stored state dicts back
    into the registered modules.

    Args:
        fname_template: Path template for checkpoint files. It is expanded
            with ``fname_template.format(step)``. The directory part of the
            template, if any, is created on construction.
        data_parallel: When true, every registered object is treated as a
            wrapper and its ``.module`` attribute is saved/loaded instead of
            the object itself.
        **kwargs: Initial ``name=module`` registrations.
    """

    def __init__(self, fname_template, data_parallel=False, **kwargs):
        ckpt_dir = os.path.dirname(fname_template)
        # os.makedirs('') raises FileNotFoundError, so a template without a
        # directory component (checkpoints in the working directory) must
        # skip directory creation.
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        self.fname_template = fname_template
        self.module_dict = kwargs
        self.data_parallel = data_parallel

    def register(self, **kwargs):
        """Add (or replace) ``name=module`` registrations."""
        self.module_dict.update(kwargs)

    def _unwrap(self, module):
        """Return the object whose state dict is saved/loaded for *module*."""
        return module.module if self.data_parallel else module

    def save(self, step):
        """Write the state dicts of all registered modules for *step*.

        Args:
            step: Value substituted into the file name template.
        """
        fname = self.fname_template.format(step)
        print('Saving checkpoint into %s...' % fname)

        outdict = {
            name: self._unwrap(module).state_dict()
            for name, module in self.module_dict.items()
        }
        torch.save(outdict, fname)

    def load(self, args, step):
        """Load the checkpoint for *step* into the registered modules.

        Loading is best-effort: a failure is reported on stdout and never
        raised. Modules processed before a failure keep the weights that
        were already loaded into them.

        Args:
            args: Object providing ``mode`` and ``step2``. With
                ``mode == "eval"`` tensors are mapped to CUDA when it is
                available, otherwise (and in every other mode) to the CPU.
                With a truthy ``step2``, a registered module whose name
                contains ``"prompt"`` is skipped when the checkpoint holds
                no entry for that name.
            step: Value substituted into the file name template.
        """
        fname = self.fname_template.format(step)
        print('Loading checkpoint from %s...' % fname)
        try:
            self._restore(args, fname)
        except FileNotFoundError:
            print("@@@@@@@@@@@@@@@@@" + fname + ' does not exist!')
        except Exception as err:
            # Deliberately broad: loading stays best-effort, but the real
            # cause (missing key, shape mismatch, unreadable file, ...) is
            # reported instead of being mislabelled as a missing file.
            # KeyboardInterrupt/SystemExit are not caught here.
            print("@@@@@@@@@@@@@@@@@ failed to load " + fname + ': %r' % (err,))

    def _restore(self, args, fname):
        """Read *fname* and copy its state dicts into the registered modules.

        Raises whatever ``torch.load`` / ``load_state_dict`` raise; ``load``
        is responsible for reporting those errors.
        """
        # Requesting CUDA on a machine without it makes torch.load fail, so
        # fall back to the CPU; load_state_dict copies values into the
        # module's existing parameters either way.
        use_cuda = args.mode == "eval" and torch.cuda.is_available()
        device = torch.device('cuda' if use_cuda else 'cpu')
        checkpoint = torch.load(fname, map_location=device)

        for name, module in self.module_dict.items():
            # In step2 mode a "prompt" module may legitimately be absent from
            # the checkpoint. Test for this module's own key rather than the
            # literal key "prompt", so names such as "prompt_ema" are handled.
            if args.step2 and "prompt" in name and name not in checkpoint:
                continue
            self._unwrap(module).load_state_dict(checkpoint[name])

        

