#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:andrew
@file:ApneaDetection.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2023/09/15
"""
from pathlib import Path
import numpy as np
import pandas as pd
from utils.SignalPreprocess import XinXiaoPreprocess
from utils.ModelDetection import SA_Detect
from utils.ResultSummary import AnalyseSegment
import argparse


def is_valid_file(par, arg):
    """Validate that *arg* names an existing regular file.

    Intended as an argparse-style validator: *par* only needs an ``error``
    method, which is called with a message when validation fails.

    Args:
        par: object with an ``error(message)`` method (e.g. an ArgumentParser).
        arg: path string to check.

    Returns:
        *arg* unchanged when it is an existing file; otherwise whatever
        follows ``par.error`` (``None`` if ``par.error`` returns normally).
    """
    if Path(arg).is_file():
        return arg
    # Directories (or missing paths) are rejected through the parser's error hook.
    par.error(f'只能输入模型文件而不是文件夹:{arg}')


def read_auto(file_path):
    """Load signal data from a ``.txt`` or ``.npy`` file.

    Args:
        file_path: ``pathlib.Path`` or path string of the file to load. The
            suffix is matched case-insensitively.

    Returns:
        numpy.ndarray: for ``.txt`` files, every value parsed by
        ``pandas.read_csv`` (no header row) flattened to 1-D; for ``.npy``
        files, the array exactly as returned by ``np.load``.

    Raises:
        ValueError: if the suffix is neither ``.txt`` nor ``.npy``.
    """
    # Accept plain strings as well as Path objects.
    file_path = Path(file_path)
    suffix = file_path.suffix.lower()
    if suffix == '.txt':
        # Read the txt with pandas' CSV reader; header=None so the first line
        # is treated as data rather than as column names.
        return pd.read_csv(file_path, header=None).to_numpy().reshape(-1)
    if suffix == '.npy':
        return np.load(file_path)
    raise ValueError('这个文件类型好像不支持需要自己写读取程序')


def _collect_input_files(data_path):
    """Return the input files to process for *data_path* (a file or a directory).

    A file path is returned as-is (its suffix is checked later by the reader).
    For a directory, its ``*.txt`` files followed by its ``*.npy`` files are
    returned, each group sorted so the processing order is deterministic.

    Raises:
        FileNotFoundError: if the path does not exist or no matching file is found.
    """
    if data_path.is_file():
        files = [data_path]
    elif data_path.is_dir():
        # glob order is filesystem-dependent; sort for reproducible runs.
        files = sorted(data_path.glob('*.txt')) + sorted(data_path.glob('*.npy'))
    else:
        raise FileNotFoundError('没有找到以txt或npy结尾的文件')
    if not files:
        raise FileNotFoundError('没有找到以txt或npy结尾的文件')
    return files


def main(opt):
    """Run SA detection on every input file described by *opt*.

    For each file: read it, preprocess it with ``XinXiaoPreprocess`` at
    ``opt.hz``, run ``SA_Detect`` with the selected model and batch size, and
    pass the result to ``AnalyseSegment`` (saving ``<stem>.csv`` into
    ``opt.output`` when an output folder is given).

    Args:
        opt: options object with ``data_path``, ``hz``, ``model``,
            ``batch_size``, ``output`` and ``TST`` attributes.

    Raises:
        FileNotFoundError: if no input file is found or the model file is missing.
    """
    data_path = Path(opt.data_path)
    print(data_path)
    all_file_path = _collect_input_files(data_path)

    # The model does not depend on the input file: validate it once, before
    # spending time reading and preprocessing any data.
    model_path = f'SADetectModel/SAmodel{opt.model}.pt'
    if not Path(model_path).exists():
        raise FileNotFoundError('模型文件不存在')

    output_dir = Path(opt.output) if opt.output else None
    if output_dir is not None:
        output_dir.mkdir(parents=True, exist_ok=True)

    # A manually entered total sleep time describes a single recording; applying
    # it to every file of a batch would silently skew all the other results.
    true_sleep_time = opt.TST
    if true_sleep_time and len(all_file_path) > 1:
        print(f'警告: 睡眠总时长(TST={opt.TST})仅支持单文件测试, 多文件时已忽略')
        true_sleep_time = 0

    for one_file in all_file_path:
        data = XinXiaoPreprocess(read_auto(one_file), opt.hz)
        print(f'正在使用模型{model_path}进行预测')
        result = SA_Detect(model_path, data, opt.batch_size)
        save_path = output_dir / (one_file.stem + ".csv") if output_dir is not None else None
        AnalyseSegment(result, save_path=save_path, true_sleep_time=true_sleep_time)


def build_parser():
    """Build the command-line parser for SA detection.

    Returns:
        argparse.ArgumentParser: parser producing the options consumed by ``main``.
    """
    parser = argparse.ArgumentParser(description='SA Detection Configuration ----- Marques')
    parser.add_argument('-d', '--data_path', nargs='?', type=str, default='./Data', help='待测试数据文件路径或文件夹')
    parser.add_argument('-n', '--hz', type=int, nargs='?', default=1000, help='信号采样率')
    # type=int is required: argparse checks `choices` after type conversion, so
    # without it any value given on the command line (a str such as '1') can
    # never match the int choices and is always rejected.
    parser.add_argument('-m', '--model', nargs='?', type=int, default=0, choices=[0, 1, 3, 4], help='选择一个模型路径')
    parser.add_argument('-b', "--batch_size", nargs="?", default=4096, type=int, help="模型每次预测片段数量")
    parser.add_argument('-o', '--output', nargs='?', default='./Output', help='输出文件夹')
    parser.add_argument('-t', "--TST", nargs="?", default=0, type=int, help="睡眠总时长(仅单文件测试是支持手动输入)")
    return parser


if __name__ == '__main__':
    option = build_parser().parse_args()

    # Manual overrides: uncomment to take precedence over both the defaults and the command line.
    # option.data_path = "/home/marques/code/marques/apnea/dataset/zhongda/zhongda_origin_npy/3103.npy"
    # option.data_path = "/home/marques/code/marques/apnea/dataset/zhongda/zhongda_origin_npy"
    # option.hz = 1000
    # option.model = 0
    # option.batch_size = 4096
    # option.output = "./Output"
    # option.TST = 565

    main(option)