import unittest
from collections import OrderedDict
import holmes_extractor as holmes
import os

# Shared fixtures: three Holmes managers on the same spaCy pipeline that differ in
# whether coreference resolution is performed and whether an ontology is attached.
_MODEL_NAME = "en_core_web_trf"

script_directory = os.path.dirname(os.path.realpath(__file__))
_ontology_path = os.sep.join((script_directory, "test_ontology.owl"))

# Coreference resolution on, test ontology attached.
ontology = holmes.Ontology(_ontology_path)
holmes_manager = holmes.Manager(
    _MODEL_NAME,
    perform_coreference_resolution=True,
    ontology=ontology,
    number_of_workers=1,
)

# Coreference resolution on, no ontology.
no_ontology_holmes_manager = holmes.Manager(
    _MODEL_NAME,
    perform_coreference_resolution=True,
    number_of_workers=1,
)

# Coreference resolution off, with its own Ontology instance loaded from the same
# file (presumably so the two managers do not share ontology state; TODO confirm).
ontology2 = holmes.Ontology(_ontology_path)
no_coref_holmes_manager = holmes.Manager(
    _MODEL_NAME,
    perform_coreference_resolution=False,
    ontology=ontology2,
    number_of_workers=1,
)


def get_first_key_in_dict(dictionary: OrderedDict) -> str:
    """Return the first key of *dictionary* in iteration order.

    The tests use this to read the leading label from a classifier result.

    Raises:
        IndexError: if *dictionary* is empty.
    """
    # Take the first key directly instead of copying every key into a list.
    try:
        return next(iter(dictionary))
    except StopIteration:
        raise IndexError("cannot take the first key of an empty dictionary") from None


class EnglishSupervisedTopicClassificationTest(unittest.TestCase):
    """Tests for Holmes supervised topic classification on English text.

    The tests use the module-level managers: ``holmes_manager`` (coreference and
    ontology), ``no_ontology_holmes_manager`` and ``no_coref_holmes_manager``.
    Many tests register a trivial ``"fast"``/``"dummy"`` document next to the
    document under test. Presumably ``prepare()`` needs at least two
    classifications; TODO confirm.
    """

    # Output matrix expected from both whole-scenario tests. Treat as read-only.
    _EXPECTED_OUTPUT_MATRIX = [
        [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0],
        [0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0],
        [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
    ]

    def _prepare_classification_ontology_scenario(self, sttb):
        """Register the training data shared by the whole-scenario tests.

        Registers seven one-line documents and the additional labels ``"parrot"``
        and ``"hound"``, then calls ``prepare()``. It then asserts the
        classification implication dictionary and the classification list that
        both whole-scenario tests expect.
        """
        sttb.parse_and_register_training_document("A puppy", "puppy", "d0")
        sttb.parse_and_register_training_document("A pussy", "cat", "d1")
        sttb.parse_and_register_training_document("A dog on a lead", "dog", "d2")
        sttb.parse_and_register_training_document("Mimi Momo", "Mimi Momo", "d3")
        sttb.parse_and_register_training_document("An animal", "animal", "d4")
        sttb.parse_and_register_training_document("A computer", "computers", "d5")
        sttb.parse_and_register_training_document("A robot", "computers", "d6")
        sttb.register_additional_classification_label("parrot")
        sttb.register_additional_classification_label("hound")
        sttb.prepare()
        self.assertEqual(
            {
                "Mimi Momo": ["animal", "cat"],
                "dog": ["animal", "hound"],
                "puppy": ["animal", "dog", "hound"],
                "cat": ["animal"],
                "hound": ["animal", "dog"],
            },
            sttb.classification_implication_dict,
        )
        self.assertEqual(
            ["Mimi Momo", "animal", "cat", "computers", "dog", "hound", "puppy"],
            sttb.classifications,
        )

    def _train_until_classifies_correctly(self, sttb, attempts=20):
        """Retrain *sttb* until its classifier gets both probe sentences right.

        With so little training data the network does not learn consistently,
        so training is retried up to *attempts* times. If no attempt succeeds,
        the test fails. The original loop guarded its final assertion with
        ``i == 20`` inside ``range(20)``, so that assertion could never run.

        Returns:
            The ``(trainer, classifier)`` pair from the first successful attempt.
        """
        for _ in range(attempts):
            trainer = sttb.train(
                minimum_occurrences=0,
                cv_threshold=0,
                max_epochs=1000,
                learning_rate=0.0001,
                convergence_threshold=0,
            )
            stc = trainer.classifier()
            if (
                get_first_key_in_dict(stc.parse_and_classify("You are a robot."))
                == "computers"
                and get_first_key_in_dict(stc.parse_and_classify("You are a cat."))
                == "animal"
            ):
                return trainer, stc
        self.fail(
            "classifier did not learn the 'robot'/'cat' probe sentences "
            f"within {attempts} training attempts"
        )

    def test_get_labels_to_classification_frequencies_direct_matching(self):
        sttb = holmes_manager.get_supervised_topic_training_basis(one_hot=False)
        sttb.parse_and_register_training_document("A lion chases a tiger", "animals")
        sttb.parse_and_register_training_document("fast", "dummy")
        sttb.prepare()
        freq = sttb.labels_to_classification_frequencies
        self.assertEqual(freq["predicate-actor: chasing-lion"], {"animals": 1})
        self.assertEqual(freq["predicate-patient: chasing-tiger"], {"animals": 1})
        self.assertEqual(
            freq["predicate-actor: chasing-lion/predicate-patient: chasing-tiger"],
            {"animals": 1},
        )
        self.assertEqual(freq["word: lion"], {"animals": 1})
        self.assertEqual(freq["word: tiger"], {"animals": 1})

    def test_get_labels_to_classification_frequencies_ontology_matching(self):
        sttb = holmes_manager.get_supervised_topic_training_basis(one_hot=False)
        sttb.parse_and_register_training_document("A dog chases a cat", "animals")
        sttb.parse_and_register_training_document("fast", "dummy")
        sttb.prepare()
        freq = sttb.labels_to_classification_frequencies
        self.assertEqual(freq["predicate-actor: chasing-animal"], {"animals": 1})
        self.assertEqual(freq["predicate-patient: chasing-animal"], {"animals": 1})
        self.assertEqual(
            freq["predicate-actor: chasing-animal/predicate-patient: chasing-animal"],
            {"animals": 1},
        )
        self.assertEqual(freq["word: animal"], {"animals": 2})

    def test_get_labels_to_classification_frequencies_ontology_multiword_matching(self):
        sttb = holmes_manager.get_supervised_topic_training_basis(one_hot=False)
        sttb.parse_and_register_training_document(
            "A gymnast jumps over a wastage horse", "gym"
        )
        sttb.parse_and_register_training_document("fast", "dummy")
        sttb.prepare()
        freq = sttb.labels_to_classification_frequencies
        self.assertEqual(freq["predicate-actor: jump-gymnast"], {"gym": 1})
        self.assertEqual(freq["word: gymnast"], {"gym": 1})
        self.assertEqual(freq["word: gymnastics equipment"], {"gym": 1})

    def test_linked_matching_common_dependent(self):
        sttb = holmes_manager.get_supervised_topic_training_basis(one_hot=False)
        sttb.parse_and_register_training_document(
            "A lion eats and consumes a tiger", "animals"
        )
        sttb.parse_and_register_training_document("fast", "dummy")
        sttb.prepare()
        freq = sttb.labels_to_classification_frequencies
        self.assertEqual(freq["predicate-actor: consume-lion"], {"animals": 1})
        self.assertEqual(freq["predicate-actor: eat-lion"], {"animals": 1})
        self.assertEqual(freq["predicate-patient: consume-tiger"], {"animals": 1})
        self.assertEqual(
            freq["predicate-actor: consume-lion/predicate-patient: consume-tiger"],
            {"animals": 1},
        )
        self.assertEqual(
            freq["predicate-actor: consume-lion/predicate-actor: eat-lion"],
            {"animals": 1},
        )
        self.assertEqual(freq["word: lion"], {"animals": 1})
        self.assertEqual(freq["word: tiger"], {"animals": 1})

    def test_linked_matching_common_dependent_control(self):
        sttb = holmes_manager.get_supervised_topic_training_basis(one_hot=False)
        sttb.parse_and_register_training_document(
            "A lion eats and a lion consumes", "animals"
        )
        sttb.parse_and_register_training_document("fast", "dummy")
        sttb.prepare()
        freq = sttb.labels_to_classification_frequencies
        self.assertEqual(freq["predicate-actor: consume-lion"], {"animals": 1})
        self.assertEqual(freq["predicate-actor: eat-lion"], {"animals": 1})
        # Two separate lions: the two relations must not be linked.
        self.assertNotIn(
            "predicate-actor: consume-lion/predicate-actor: eat-lion", freq
        )
        self.assertEqual(freq["word: lion"], {"animals": 2})

    def test_linked_matching_stepped_lower_first(self):
        sttb = holmes_manager.get_supervised_topic_training_basis(one_hot=False)
        sttb.parse_and_register_training_document("A big lion eats", "animals")
        sttb.parse_and_register_training_document("fast", "dummy")
        sttb.prepare()
        freq = sttb.labels_to_classification_frequencies
        self.assertEqual(freq["governor-adjective: lion-big"], {"animals": 1})
        self.assertEqual(freq["predicate-actor: eat-lion"], {"animals": 1})
        self.assertEqual(
            freq["governor-adjective: lion-big/predicate-actor: eat-lion"],
            {"animals": 1},
        )
        self.assertEqual(freq["word: lion"], {"animals": 1})

    def test_linked_matching_stepped_lower_second(self):
        sttb = holmes_manager.get_supervised_topic_training_basis(one_hot=False)
        sttb.parse_and_register_training_document(
            "Something eats a big lion", "animals"
        )
        sttb.parse_and_register_training_document("fast", "dummy")
        sttb.prepare()
        freq = sttb.labels_to_classification_frequencies
        self.assertEqual(freq["governor-adjective: lion-big"], {"animals": 1})
        self.assertEqual(freq["predicate-patient: eat-lion"], {"animals": 1})
        self.assertEqual(
            freq["governor-adjective: lion-big/predicate-patient: eat-lion"],
            {"animals": 1},
        )
        self.assertEqual(freq["word: lion"], {"animals": 1})

    def test_linked_matching_stepped_control(self):
        # Uses the manager without coreference resolution, so the two mentions of
        # "lion" are not merged.
        sttb = no_coref_holmes_manager.get_supervised_topic_training_basis(
            one_hot=False
        )
        sttb.parse_and_register_training_document(
            "There is a big lion and the lion eats", "animals"
        )
        sttb.parse_and_register_training_document("fast", "dummy")
        sttb.prepare()
        freq = sttb.labels_to_classification_frequencies
        self.assertEqual(freq["governor-adjective: lion-big"], {"animals": 1})
        self.assertEqual(freq["predicate-actor: eat-lion"], {"animals": 1})
        self.assertNotIn(
            "governor-adjective: lion-big/predicate-actor: eat-lion", freq
        )
        self.assertEqual(freq["word: lion"], {"animals": 2})

    def test_repeating_relation_through_coreference(self):
        sttb = no_ontology_holmes_manager.get_supervised_topic_training_basis()
        sttb.parse_and_register_training_document(
            "The building was used last year. It is used this year", "test"
        )
        sttb.parse_and_register_training_document("fast", "dummy")
        sttb.prepare()
        freq = sttb.labels_to_classification_frequencies
        self.assertNotIn(
            "predicate-patient: use-building/predicate-patient: use-building", freq
        )

    def test_one_hot(self):
        def build_frequencies(one_hot):
            """Train on four identical two-sentence documents and return the frequencies."""
            sttb = no_coref_holmes_manager.get_supervised_topic_training_basis(
                one_hot=one_hot
            )
            for label in ("animals", "animals", "animals2", "animals2"):
                sttb.parse_and_register_training_document(
                    "A dog chases a cat. A dog chases a cat", label
                )
            sttb.prepare()
            return sttb.labels_to_classification_frequencies

        freq1 = build_frequencies(False)
        freq2 = build_frequencies(True)
        # Without one-hot encoding, each repetition within a document is counted.
        self.assertEqual(
            freq1["predicate-actor: chasing-animal/predicate-patient: chasing-animal"],
            {"animals": 4, "animals2": 4},
        )
        self.assertEqual(
            freq1["predicate-actor: chasing-animal"], {"animals": 4, "animals2": 4}
        )
        self.assertEqual(
            freq1["predicate-patient: chasing-animal"], {"animals": 4, "animals2": 4}
        )
        self.assertEqual(freq1["word: animal"], {"animals": 8, "animals2": 8})
        # With one-hot encoding, each document contributes at most 1 per label.
        self.assertEqual(
            freq2["predicate-actor: chasing-animal/predicate-patient: chasing-animal"],
            {"animals": 2, "animals2": 2},
        )
        self.assertEqual(
            freq2["predicate-actor: chasing-animal"], {"animals": 2, "animals2": 2}
        )
        self.assertEqual(
            freq2["predicate-patient: chasing-animal"], {"animals": 2, "animals2": 2}
        )
        self.assertEqual(freq2["word: animal"], {"animals": 2, "animals2": 2})

    def test_multiple_document_classes(self):
        sttb = holmes_manager.get_supervised_topic_training_basis(one_hot=False)
        sttb.parse_and_register_training_document("A dog chases a cat", "animals")
        sttb.parse_and_register_training_document("A cat chases a dog", "animals")
        sttb.parse_and_register_training_document("A cat chases a horse", "animals")
        sttb.parse_and_register_training_document("A cat chases a horse", "animals")
        sttb.parse_and_register_training_document("A gymnast jumps over a horse", "gym")
        sttb.parse_and_register_training_document(
            "A gymnast jumps over a wastage horse", "gym"
        )
        sttb.prepare()
        freq = sttb.labels_to_classification_frequencies
        self.assertEqual(freq["predicate-actor: chasing-animal"], {"animals": 4})
        self.assertEqual(freq["predicate-actor: jump-gymnast"], {"gym": 2})
        self.assertEqual(freq["predicate-patient: chasing-animal"], {"animals": 4})
        self.assertEqual(
            freq["predicate-actor: chasing-animal/predicate-patient: chasing-animal"],
            {"animals": 4},
        )
        self.assertEqual(freq["word: animal"], {"animals": 8, "gym": 2})
        self.assertEqual(freq["word: gymnast"], {"gym": 2})
        self.assertEqual(freq["word: gymnastics equipment"], {"animals": 2, "gym": 2})

    def test_whole_scenario_with_classification_ontology(self):
        sttb = holmes_manager.get_supervised_topic_training_basis(
            classification_ontology=ontology, one_hot=False
        )
        self._prepare_classification_ontology_scenario(sttb)
        trainer, stc = self._train_until_classifies_correctly(sttb)

        expected_labels = [
            "prepgovernor-noun: animal-lead",
            "word: animal",
            "word: computer",
            "word: lead",
            "word: robot",
        ]
        self.assertEqual(expected_labels, list(trainer.sorted_label_dict.keys()))
        self.assertEqual(
            [{1: 1}, {1: 1}, {1: 1, 0: 1, 3: 1}, {1: 1}, {1: 1}, {2: 1}, {4: 1}],
            trainer.occurrence_dicts,
        )
        self.assertEqual(self._EXPECTED_OUTPUT_MATRIX, trainer.output_matrix.tolist())
        self.assertEqual([5, 5, 6], trainer._hidden_layer_sizes)
        self.assertIsNone(
            stc.parse_and_classify("My name is Charles and I like sewing.")
        )

        # Round-trip the model through serialization into a manager without an
        # ontology and check that the deserialized classifier behaves the same.
        serialized_supervised_topic_classifier_model = stc.serialize_model()
        stc2 = no_ontology_holmes_manager.deserialize_supervised_topic_classifier(
            serialized_supervised_topic_classifier_model, verbose=True
        )
        self.assertEqual(expected_labels, list(stc2.model.sorted_label_dict.keys()))
        self.assertEqual(
            get_first_key_in_dict(stc2.parse_and_classify("You are a robot.")),
            "computers",
        )
        self.assertEqual(
            get_first_key_in_dict(stc2.parse_and_classify("You are a cat.")), "animal"
        )
        self.assertIsNone(
            stc2.parse_and_classify("My name is Charles and I like sewing.")
        )

    def test_whole_scenario_with_classification_ontology_and_match_all_words(self):
        sttb = holmes_manager.get_supervised_topic_training_basis(
            classification_ontology=ontology, match_all_words=True, one_hot=False
        )
        self._prepare_classification_ontology_scenario(sttb)
        trainer, stc = self._train_until_classifies_correctly(sttb)

        expected_labels = [
            "prepgovernor-noun: animal-lead",
            "word: animal",
            "word: computer",
            "word: lead",
            "word: mimi",
            "word: momo",
            "word: on",
            "word: robot",
        ]
        self.assertEqual(expected_labels, list(trainer.sorted_label_dict.keys()))
        self.assertEqual(
            [
                {1: 1},
                {1: 1},
                {0: 1, 1: 1, 3: 1, 6: 1},
                {1: 1, 4: 1, 5: 1},
                {1: 1},
                {2: 1},
                {7: 1},
            ],
            trainer.occurrence_dicts,
        )
        self.assertEqual(self._EXPECTED_OUTPUT_MATRIX, trainer.output_matrix.tolist())
        self.assertEqual([8, 7, 7], trainer._hidden_layer_sizes)
        self.assertIsNone(
            stc.parse_and_classify("My name is Charles and I like sewing.")
        )

        # Round-trip the model through serialization into a manager without an
        # ontology and check that the deserialized classifier behaves the same.
        serialized_supervised_topic_classifier_model = stc.serialize_model()
        stc2 = no_ontology_holmes_manager.deserialize_supervised_topic_classifier(
            serialized_supervised_topic_classifier_model
        )
        self.assertEqual(expected_labels, list(stc2.model.sorted_label_dict.keys()))
        self.assertEqual(
            get_first_key_in_dict(stc2.parse_and_classify("You are a robot.")),
            "computers",
        )
        self.assertEqual(
            get_first_key_in_dict(stc2.parse_and_classify("You are a cat.")), "animal"
        )
        self.assertIsNone(
            stc2.parse_and_classify("My name is Charles and I like sewing.")
        )

    def test_filtering(self):
        sttb = holmes_manager.get_supervised_topic_training_basis()
        sttb.parse_and_register_training_document("A dog chases a cat", "animals")
        sttb.parse_and_register_training_document("A cat chases a dog", "animals")
        sttb.parse_and_register_training_document("A cat chases a horse", "animals")
        sttb.parse_and_register_training_document("A cat chases a horse", "animals")
        sttb.parse_and_register_training_document("A gymnast jumps over a horse", "gym")
        sttb.parse_and_register_training_document(
            "A gymnast jumps over a vaulting horse", "gym"
        )
        sttb.prepare()
        trainer = sttb.train(minimum_occurrences=4, cv_threshold=0.0)
        self.assertEqual(
            list(trainer.sorted_label_dict.keys()),
            [
                "predicate-actor: chasing-animal",
                "predicate-actor: chasing-animal/predicate-patient: chasing-animal",
                "predicate-patient: chasing-animal",
                "word: animal",
            ],
        )
        self.assertEqual(
            {phr.label for phr in trainer.phraselet_infos},
            {
                "predicate-actor: chasing-animal",
                "predicate-patient: chasing-animal",
                "word: animal",
            },
        )
        trainer2 = sttb.train(minimum_occurrences=4, cv_threshold=1)
        self.assertEqual(
            list(trainer2.sorted_label_dict.keys()),
            [
                "predicate-actor: chasing-animal",
                "predicate-actor: chasing-animal/predicate-patient: chasing-animal",
                "predicate-patient: chasing-animal",
            ],
        )
        self.assertEqual(
            {phr.label for phr in trainer2.phraselet_infos},
            {"predicate-actor: chasing-animal", "predicate-patient: chasing-animal"},
        )
