import os
import re
import sys
import tempfile

import torchvision
from hydra.experimental import compose, initialize

from lightly import cli
from tests.api_workflow.mocked_api_workflow_client import (
    MockedApiWorkflowClient,
    MockedApiWorkflowSetup,
)


class TestCLITrain(MockedApiWorkflowSetup):
    """Checks that ``cli.train_cli`` leaves a checkpoint file behind.

    Every test runs against a throw-away folder of fake JPEG images and a
    Hydra config composed from ``lightly/cli/config`` with the trainer
    limited to a single epoch.
    """

    def setUp(self):
        """Create the fake image folder and compose the training config."""
        MockedApiWorkflowSetup.setUp(self)
        self.create_fake_dataset()
        with initialize(config_path="../../lightly/cli/config", job_name="test_app"):
            self.cfg = compose(
                config_name="config",
                overrides=[
                    f"input_dir={self.folder_path}",
                    "trainer.max_epochs=1",
                ],
            )

    def create_fake_dataset(self):
        """Write a handful of fake images into a fresh temporary folder.

        Sets ``self.dataset`` (the torchvision ``FakeData`` instance),
        ``self.folder_path`` (directory holding the images) and
        ``self.sample_names`` (file names, in dataset order).
        """
        n_data = 5
        self.dataset = torchvision.datasets.FakeData(
            size=n_data, image_size=(3, 32, 32)
        )

        # Keep a handle on the TemporaryDirectory object so tearDown can
        # delete the folder; a bare mkdtemp() would leave it on disk forever.
        self._tmp_dir = tempfile.TemporaryDirectory()
        self.folder_path = self._tmp_dir.name
        self.sample_names = [f"img_{i}.jpg" for i in range(n_data)]
        for sample_idx, sample_name in enumerate(self.sample_names):
            # Each dataset item is indexed so that element 0 is the image.
            image = self.dataset[sample_idx][0]
            image.save(os.path.join(self.folder_path, sample_name))

    def test_checkpoint_created(self):
        """Training must export the path of an existing ``.ckpt`` file."""
        env_var_name = self.cfg["environment_variable_names"][
            "lightly_last_checkpoint_path"
        ]
        # Drop any value left over from an earlier run in this process so the
        # assertions below can only be satisfied by this call to train_cli.
        os.environ.pop(env_var_name, None)

        cli.train_cli(self.cfg)

        checkpoint_path = os.getenv(env_var_name)
        # Fail with a readable message instead of an AttributeError on None.
        assert (
            checkpoint_path is not None
        ), f"environment variable {env_var_name!r} was not set by train_cli"
        assert checkpoint_path.endswith(".ckpt")
        assert os.path.isfile(checkpoint_path)

    def tearDown(self) -> None:
        """Delete the temporary image folder and stray embedding files."""
        # getattr guards against a setUp that failed before the folder existed.
        tmp_dir = getattr(self, "_tmp_dir", None)
        if tmp_dir is not None:
            tmp_dir.cleanup()

        # Best-effort removal of embedding files from the working directory;
        # they may legitimately not exist, so a missing file is not an error.
        for filename in ["embeddings.csv", "embeddings_sorted.csv"]:
            try:
                os.remove(filename)
            except FileNotFoundError:
                pass
