#!/usr/bin/python3
# -*- coding: UTF-8 -*-
"""
@author:Marques
@file:load_dataset.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2021/12/03
"""
from pathlib import Path

import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.Preprocessing import BCG_Operation

# Shared preprocessor; the BCG signal is assumed to be sampled at 100 Hz.
preprocessing = BCG_Operation()
preprocessing.sample_rate = 100
"""
1. 读取方法
# 无论是否提前切分,均提前转成npy格式
# 1.1 提前预处理,切分好后生成npy,直接载入切分好的片段   内存占用多  读取简单
使用此方法: 1.2 提前预处理,载入整夜数据,切分好后生成csv或xls,根据片段读取  内存占用少    读取较为复杂
"""

# Module-level cache mapping patient number (string) -> preprocessed whole-night signal.
datasets = {}


# Cache whole-night recordings so that each file is read from disk only once.
def read_dataset(data_path, augment=None):
    data_path = Path(data_path)
    try:
        f = []
        if data_path.is_dir():
            f += sorted(data_path.rglob("*.npy"))
        elif data_path.is_file():
            raise Exception('dataset path should be a dir')
        else:
            raise Exception(f'{data_path} does not exist')
    except Exception as e:
        raise Exception(f'Error loading data from {data_path}: {e}\n')

    print("loading dataset")
    for i in tqdm(f):
        select_dataset = np.load(i)
        # 20 Hz low-pass Butterworth filter to strip high-frequency noise.
        select_dataset = preprocessing.Butterworth(select_dataset, "lowpass", low_cut=20, order=3)
        if augment is not None:
            select_dataset = augment(select_dataset)
        # Cache key is the patient number, i.e. the file-name prefix before "samp".
        datasets[i.name.split("samp")[0]] = select_dataset
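

# A minimal sketch of a segment_augment callable, documenting the interface that
# ApneaDataset.__getitem__ assumes: it receives the cached whole-night signal together
# with the segment's start/end points and must return a tuple, because the result is
# unpacked with *segment.  The function below is illustrative only; the augment actually
# used in training is defined elsewhere in the project.
def example_segment_augment(dataset, SP, EP):
    # Slice the whole-night signal to the labelled segment and cast for the model.
    segment = dataset[int(SP):int(EP)].astype(np.float32)
    return (segment,)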


# Dataset for strategy 1.2: read whole-night signals from the cache, slice per segment.
class ApneaDataset(Dataset):
    def __init__(self, data_path, label_path, select_sampno, dataset_type, segment_augment=None):
        self.data_path = data_path
        self.label_path = label_path
        self.segment_augment = segment_augment
        self.labels = None
        self.dataset_type = dataset_type
        self.select_sampNo = select_sampno

        # self._getAllData()
        self._getAllLabels()

    def __getitem__(self, index):
        # PN: patient number; SP/EP: start point and end point of the segment.
        # temp_label.append([sampNo, label[-1], i, hpy_num, csa_num, osa_num, mean_low, flow_low])
        PN, segmentNo, label_type, new_label, SP, EP = self.labels[index]
        # PN, label, SP, EP, hpy_num, csa_num, osa_num, mean_low, flow_low = self.labels[index]

        if str(PN) in datasets:
            dataset = datasets[str(PN)]
            segment = self.segment_augment(dataset, SP, EP)
            # label_type > 1 marks the segment as sleep apnea (binary target).
            return (*segment, int(float(label_type) > 1), PN, segmentNo, label_type, new_label, SP, EP)
        raise Exception(f'signal for patient {PN} is not loaded; call read_dataset() first')

    def count_SA(self):
        # Number of segments whose new_label column marks sleep apnea.
        return sum(self.labels[:, 3] > 1)

    def __len__(self):
        return len(self.labels)

    def _getAllLabels(self):
        label_path = Path(self.label_path)
        if not label_path.exists():
            raise Exception(f'{self.label_path} does not exist')

        try:
            f = []
            if label_path.is_dir():
                if self.dataset_type == "train":
                    label_list = list(label_path.rglob("*_train_label.csv"))
                elif self.dataset_type == "valid":
                    label_list = list(label_path.rglob("*_valid_label.csv"))
                elif self.dataset_type == "test":
                    label_list = list(label_path.glob("*_sa_test_label.csv"))
                    # label_list = list(label_path.rglob("*_test_label.npy"))
                elif self.dataset_type == "all_test":
                    label_list = list(label_path.rglob("*_sa_all_label.csv"))
                else:
                    raise ValueError(f'unknown dataset_type: {self.dataset_type}')
                # label_list = list(label_path.rglob("*_label.npy"))
                label_list.sort()
                f += label_list
            elif label_path.is_file():
                raise Exception('label path should be a dir')
            else:
                raise Exception(f'{self.label_path} does not exist')
        except Exception as e:
            raise Exception(f'Error loading labels from {self.label_path}: {e}\n')
        print("loading labels")
        for i in tqdm(f):
            # Keep only the patients requested via select_sampno (file names start with it).
            if int(i.name.split("_")[0]) not in self.select_sampNo:
                continue

            if self.labels is None:
                self.labels = pd.read_csv(i).to_numpy(dtype=int)
            else:
                labels = pd.read_csv(i).to_numpy(dtype=int)
                if len(labels) > 0:
                    self.labels = np.concatenate((self.labels, labels))
        # self.labels = self.labels[:10000]
        print(f"{self.dataset_type} length is {len(self.labels)}")


class TestApneaDataset2(ApneaDataset):
    def __init__(self, data_path, label_path, select_sampno, dataset_type, segment_augment=None):
        super().__init__(
            data_path=data_path,
            label_path=label_path,
            dataset_type=dataset_type,
            segment_augment=segment_augment,
            select_sampno=select_sampno
        )

    # __getitem__ and count_SA are inherited unchanged from ApneaDataset.


if __name__ == '__main__':
    pass
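
    # Minimal usage sketch (paths and sample numbers are hypothetical):
    #
    #   from torch.utils.data import DataLoader
    #   read_dataset("path/to/npy_dir")
    #   train_set = ApneaDataset("path/to/npy_dir", "path/to/label_dir",
    #                            select_sampno=[282], dataset_type="train",
    #                            segment_augment=example_segment_augment)
    #   loader = DataLoader(train_set, batch_size=32, shuffle=True)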