import os

import numpy as np

from utils import ModuleFindTool
from utils.IID import generate_iid_data, generate_non_iid_data


class BaseDataset:
    """Shared state and client-partitioning logic for dataset wrappers.

    ``init`` records the raw train/test data and the training label range,
    then calls ``generate_data`` to build ``index_list`` according to
    ``iid_config``.

    Note that this class never assigns ``train_dataset`` or ``test_dataset``
    itself; ``get_train_dataset`` / ``get_test_dataset`` return ``None``
    unless something else (presumably a subclass) sets them.
    """

    def __init__(self, iid_config):
        """Create an empty wrapper holding the data-distribution config.

        Args:
            iid_config: distribution setting consumed by ``generate_data``:
                a bool selects IID, a dict containing ``"path"`` selects a
                custom generator, anything else is passed to the non-IID
                generator.
        """
        self.iid_config = iid_config
        self.datasets = []
        # Filled in by init() / generate_data().
        self.raw_data = None
        self.train_data = None
        self.test_data = None
        self.train_labels = None
        self.train_data_size = None
        self.label_min = None
        self.label_max = None
        self.index_list = None
        # Not assigned anywhere in this class (see class docstring).
        self.train_dataset = None
        self.test_dataset = None
        # ``data/`` directory inside the parent of this file's directory.
        self.path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../data/')

    def init(self, clients, train_dataset, test_dataset):
        """Record split data and label range, then build the client partition.

        Args:
            clients: forwarded unchanged to the distribution generators
                (presumably the number of clients -- confirm in ``utils.IID``).
            train_dataset: object exposing ``.data`` and ``.targets``.
            test_dataset: object exposing ``.data``.

        Raises:
            ValueError: if ``train_dataset.targets`` is empty, because no
                label range can be derived from it.
        """
        train_data = train_dataset.data
        labels = np.array(train_dataset.targets)
        # Fail early with a clear message instead of numpy's zero-size
        # reduction error from labels.max()/min().
        if labels.size == 0:
            raise ValueError("train_dataset.targets is empty; cannot derive label range")

        self.raw_data = train_data
        self.train_data = train_data
        self.train_labels = labels
        self.test_data = test_dataset.data
        self.label_max = labels.max()
        self.label_min = labels.min()

        # len() instead of .shape[0]: identical for numpy arrays and tensors,
        # and also works when ``.data`` is a plain Python list.
        self.train_data_size = len(train_data)
        self.generate_data(clients, train_dataset, test_dataset)

    def get_test_dataset(self):
        """Return ``self.test_dataset`` (``None`` unless set elsewhere)."""
        return self.test_dataset

    def get_index_list(self):
        """Return the partition produced by ``generate_data`` (``None`` before ``init``)."""
        return self.index_list

    def get_train_dataset(self):
        """Return ``self.train_dataset`` (``None`` unless set elsewhere)."""
        return self.train_dataset

    def get_config(self):
        """Return the ``iid_config`` passed to the constructor."""
        return self.iid_config

    def generate_data(self, clients, train_dataset, test_dataset):
        """Populate ``self.index_list`` according to ``self.iid_config``.

        * bool: ``generate_iid_data(self, clients)``.
        * dict containing ``"path"``: the object returned by
          ``ModuleFindTool.find_class_by_path(path)`` is called with no
          arguments, the result is called with ``iid_config["params"]``, and
          that generator's ``generate_data`` produces the index list.
        * anything else: ``generate_non_iid_data`` with the label bounds
          ``label_min`` and ``label_max + 1`` (presumably a half-open range).

        ``test_dataset`` is accepted for interface compatibility and is not
        used here.
        """
        # NOTE(review): any bool, False included, takes the IID branch --
        # confirm this is intended.
        if isinstance(self.iid_config, bool):
            print("generating iid data...")
            self.index_list = generate_iid_data(self, clients)
        elif isinstance(self.iid_config, dict) and "path" in self.iid_config:
            print("generate customize data distribution")
            data_distribution_generator = ModuleFindTool.find_class_by_path(self.iid_config["path"])()(self.iid_config["params"])
            self.index_list = data_distribution_generator.generate_data(self.iid_config, self, train_dataset)
        else:
            print("generating non_iid data...")
            self.index_list = generate_non_iid_data(self.iid_config, self, clients, self.label_min, self.label_max + 1,
                                                    train_dataset)
        print("data generation process completed")
