import os
import torch
import numpy as np
import json
import re
import jsonlines
from torch.autograd import Variable


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()

def tok2int_sent(sentence, tokenizer, max_seq_length):
    """Loads a data file into a list of `InputBatch`s."""
    sent_a, sent_b = sentence
    tokens_a = tokenizer.tokenize(sent_a)
    tokens_b = tokenizer.tokenize(sent_b)
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
    tokens =  ["[CLS]"] + tokens_a + ["[SEP]"]
    segment_ids = [0] * len(tokens)
    tokens = tokens + tokens_b + ["[SEP]"]
    segment_ids += [1] * (len(tokens_b) + 1)
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)
    padding = [0] * (max_seq_length - len(input_ids))

    input_ids += padding
    input_mask += padding
    segment_ids += padding

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    return input_ids, input_mask, segment_ids





def tok2int_list(src_list, tokenizer, max_seq_length, max_seq_size=-1):
    inp_padding = list()
    msk_padding = list()
    seg_padding = list()
    for step, sent in enumerate(src_list):
        input_ids, input_mask, input_seg = tok2int_sent(sent, tokenizer, max_seq_length)
        inp_padding.append(input_ids)
        msk_padding.append(input_mask)
        seg_padding.append(input_seg)
    if max_seq_size != -1:
        inp_padding = inp_padding[:max_seq_size]
        msk_padding = msk_padding[:max_seq_size]
        seg_padding = seg_padding[:max_seq_size]
        inp_padding += ([[0] * max_seq_length] * (max_seq_size - len(inp_padding)))
        msk_padding += ([[0] * max_seq_length] * (max_seq_size - len(msk_padding)))
        seg_padding += ([[0] * max_seq_length] * (max_seq_size - len(seg_padding)))
    return inp_padding, msk_padding, seg_padding


class DataLoader(object):
    ''' For data iteration '''

    def __init__(self, data_path, label_map, tokenizer, args, test=False, cuda=True, batch_size=64):
        self.cuda = cuda

        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_len = args.max_len
        self.evi_num = args.evi_num
        self.label_map = label_map
        self.threshold = args.threshold
        self.data_path = data_path
        examples = self.read_file(data_path)
        self.examples = examples
        inputs, labels = list(zip(* examples))
        self.inputs = inputs
        self.labels = labels
        self.test = test

        self.total_num = len(examples)
        if self.test:
            self.total_step = self.total_num // batch_size #np.ceil(self.total_num * 1.0 / batch_size)
        else:
            self.total_step = self.total_num // batch_size
            self.shuffle()
        self.step = 0



    def read_file(self, data_path):
        examples = list()
        with open(data_path) as fin:
            for step, line in enumerate(fin):
                instance = json.loads(line.strip())
                claim = instance['claim']
                evi_list = list()
                for evidence in instance['evidence']:
                    evi_list.append([claim, evidence[3]])
                label = self.label_map[instance['label']]
                evi_list = evi_list[:self.evi_num]
                examples.append([evi_list, label])
        return examples


    def shuffle(self):
        np.random.shuffle(self.examples)


    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def __len__(self):
        return self._n_batch

    def next(self):
        ''' Get the next batch '''

        if self.step < self.total_step:
            inputs = self.inputs[self.step * self.batch_size : (self.step+1)*self.batch_size]
            labels = self.labels[self.step * self.batch_size : (self.step+1)*self.batch_size]
            inp_padding_inputs, msk_padding_inputs, seg_padding_inputs = [], [], []
            for step in range(len(inputs)):
                inp, msk, seg = tok2int_list(inputs[step], self.tokenizer, self.max_len, self.evi_num)
                inp_padding_inputs += inp
                msk_padding_inputs += msk
                seg_padding_inputs += seg

            inp_tensor_input = Variable(
                torch.LongTensor(inp_padding_inputs)).view(-1, self.evi_num, self.max_len)
            msk_tensor_input = Variable(
                torch.LongTensor(msk_padding_inputs)).view(-1, self.evi_num, self.max_len)
            seg_tensor_input = Variable(
                torch.LongTensor(seg_padding_inputs)).view(-1, self.evi_num, self.max_len)
            lab_tensor = Variable(
                torch.LongTensor(labels))
            if self.cuda:
                inp_tensor_input = inp_tensor_input.cuda()
                msk_tensor_input = msk_tensor_input.cuda()
                seg_tensor_input = seg_tensor_input.cuda()
                lab_tensor = lab_tensor.cuda()
            self.step += 1
            return (inp_tensor_input, msk_tensor_input, seg_tensor_input), lab_tensor
        else:
            self.step = 0
            if not self.test:
                self.shuffle()
                inputs, labels = list(zip(*self.examples))
                self.inputs = inputs
                self.labels = labels
            raise StopIteration()

class DataLoaderTest(object):
    ''' For data iteration '''

    def __init__(self, args, label_map, tokenizer, cuda=True, batch_size=64):
        self.cuda = cuda

        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_len = args.max_len
        self.evi_num = args.evi_num
        self.label_map = label_map
        examples = self.read_file(args)
        self.examples = examples
        ids, dids, sids, inputs = list(zip(* examples))
        self.inputs = inputs
        self.ids = ids
        self.dids = dids
        self.sids = sids
        self.total_num = len(examples)
        self.total_step = np.ceil(self.total_num * 1.0 / batch_size)
        self.step = 0


    def read_file(self, args):
        corpus = {doc['doc_id']: doc for doc in jsonlines.open(args.corpus)}
        evidence_retrieval = jsonlines.open(args.evidence_retrieval)
        dataset = jsonlines.open(args.dataset)
        examples = list()
        for data, retrieval in list(zip(dataset, evidence_retrieval)):
            assert data['id'] == retrieval['claim_id']
            for did, sids in retrieval["evidence"].items():
                evidence_list = list()
                content = corpus[int(did)]
                for sid in sids:
                    evidence = content["abstract"][sid]
                    evidence_list.append([data["claim"], evidence])
                examples.append([int(data['id']), did, sids, evidence_list])
        return examples


    def shuffle(self):
        np.random.shuffle(self.examples)

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def __len__(self):
        return self._n_batch

    def next(self):
        ''' Get the next batch '''

        if self.step < self.total_step:
            inputs = self.inputs[self.step * self.batch_size : (self.step+1)*self.batch_size]
            ids = self.ids[self.step * self.batch_size : (self.step+1)*self.batch_size]
            dids = self.dids[self.step * self.batch_size: (self.step + 1) * self.batch_size]
            sids = self.sids[self.step * self.batch_size: (self.step + 1) * self.batch_size]
            inp_padding_inputs, msk_padding_inputs, seg_padding_inputs = [], [], []
            for step in range(len(inputs)):
                inp, msk, seg = tok2int_list(inputs[step], self.tokenizer, self.max_len, self.evi_num)
                inp_padding_inputs += inp
                msk_padding_inputs += msk
                seg_padding_inputs += seg

            inp_tensor_input = Variable(
                torch.LongTensor(inp_padding_inputs))
            msk_tensor_input = Variable(
                torch.LongTensor(msk_padding_inputs))
            seg_tensor_input = Variable(
                torch.LongTensor(seg_padding_inputs))

            if self.cuda:
                inp_tensor_input = inp_tensor_input.cuda()
                msk_tensor_input = msk_tensor_input.cuda()
                seg_tensor_input = seg_tensor_input.cuda()

            self.step += 1
            return  ids, dids, sids, (inp_tensor_input, msk_tensor_input, seg_tensor_input)
        else:
            self.step = 0
            raise StopIteration()
