0915CXH_DL_SA/ApneaDetection.py
2023-09-17 14:49:23 +08:00

92 lines
3.4 KiB
Python

#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:andrew
@file:ApneaDetection.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2023/09/15
"""
from pathlib import Path
import numpy as np
import pandas as pd
from utils.SignalPreprocess import XinXiaoPreprocess
from utils.ModelDetection import SA_Detect
from utils.ResultSummary import AnalyseSegment
import argparse
def is_valid_file(par, arg):
    """Accept ``arg`` only if it names an existing regular file.

    Intended as an argparse-style validator.

    Args:
        par: object exposing ``error(message)``, e.g. an
            ``argparse.ArgumentParser`` (whose ``error`` exits the program).
        arg: candidate path string.

    Returns:
        ``arg`` unchanged when ``Path(arg)`` is an existing file. Otherwise the
        problem is reported through ``par.error`` (message: "only a model file
        may be given, not a folder"), and ``None`` is returned if that call
        returns at all.
    """
    if Path(arg).is_file():
        return arg
    par.error(f'只能输入模型文件而不是文件夹:{arg}')
def read_auto(file_path):
    """Load a signal from a ``.txt`` or ``.npy`` file.

    Args:
        file_path: path to the file, as a ``pathlib.Path`` or a ``str``.

    Returns:
        numpy.ndarray: for ``.txt`` files, the contents parsed by
        ``pandas.read_csv`` (no header row) and flattened to 1-D. For ``.npy``
        files, the array exactly as returned by ``numpy.load``.

    Raises:
        ValueError: if the extension is neither ``.txt`` nor ``.npy``
            (compared case-insensitively).
    """
    path = Path(file_path)
    # Case-insensitive, so files such as "REC.TXT" or "rec.NPY" are accepted too.
    suffix = path.suffix.lower()
    if suffix == '.txt':
        return pd.read_csv(path, header=None).to_numpy().reshape(-1)
    if suffix == '.npy':
        return np.load(path)
    # Message: "this file type does not seem to be supported; write your own reader".
    raise ValueError('这个文件类型好像不支持需要自己写读取程序')
def main(opt):
# 读取数据集
data_path = Path(opt.data_path)
all_file_path = []
print(data_path)
if data_path.is_file():
all_file_path.append(data_path)
elif data_path.is_dir():
all_file_path = list(data_path.glob('*.txt')) + list(data_path.glob('*.npy'))
else:
raise FileNotFoundError('没有找到以txt或npy结尾的文件')
if not all_file_path:
raise FileNotFoundError('没有找到以txt或npy结尾的文件')
for one_file in all_file_path:
data = read_auto(one_file)
data = XinXiaoPreprocess(data, opt.hz)
model_path = f'SADetectModel/SAmodel{opt.model}.pt'
if not Path(model_path).exists():
raise FileNotFoundError('模型文件不存在')
else:
print(f'正在使用模型{model_path}进行预测')
result = SA_Detect(model_path, data, opt.batch_size)
# result.to_csv(Path(opt.output) / (one_file.stem + "segment.csv"), index=False)
if opt.output:
Path(opt.output).mkdir(parents=True, exist_ok=True)
save_path = Path(opt.output) / (one_file.stem + ".csv")
else:
save_path = None
AnalyseSegment(result, save_path=save_path, true_sleep_time=opt.TST)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='SA Detection Configuration ----- Marques')
parser.add_argument('-d', '--data_path', nargs='?', type=str, default='./Data', help='待测试数据文件路径或文件夹')
parser.add_argument('-n', '--hz', type=int, nargs='?', default=1000, help='信号采样率')
parser.add_argument('-m', '--model', nargs='?', default=0, choices=[0, 1, 3, 4], help='选择一个模型路径')
parser.add_argument('-b', "--batch_size", nargs="?", default=4096, type=int, help="模型每次预测片段数量")
parser.add_argument('-o', '--output', nargs='?', default='./Output', help='输出文件夹')
parser.add_argument('-t', "--TST", nargs="?", default=0, type=int, help="睡眠总时长(仅单文件测试是支持手动输入)")
option = parser.parse_args()
# 手动输入,优先级高于默认值和命令行输入
# option.data_path = "/home/marques/code/marques/apnea/dataset/zhongda/zhongda_origin_npy/3103.npy"
# option.data_path = "/home/marques/code/marques/apnea/dataset/zhongda/zhongda_origin_npy"
# option.hz = 1000
# option.model = 0
# option.batch_size = 4096
# option.output = "./Output"
# option.TST = 565
main(option)