92 lines
3.4 KiB
Python
92 lines
3.4 KiB
Python
#!/usr/bin/python
|
|
# -*- coding: UTF-8 -*-
|
|
"""
|
|
@author:andrew
|
|
@file:ApneaDetection.py
|
|
@email:admin@marques22.com
|
|
@email:2021022362@m.scnu.edu.cn
|
|
@time:2023/09/15
|
|
"""
|
|
from pathlib import Path
|
|
import numpy as np
|
|
import pandas as pd
|
|
from utils.SignalPreprocess import XinXiaoPreprocess
|
|
from utils.ModelDetection import SA_Detect
|
|
from utils.ResultSummary import AnalyseSegment
|
|
import argparse
|
|
|
|
|
|
def is_valid_file(par, arg):
    """Argparse-style validator: accept *arg* only if it names an existing regular file.

    Args:
        par: an object exposing ``error(message)`` (e.g. an ``argparse.ArgumentParser``).
        arg: the path string to validate.

    Returns:
        *arg* unchanged when ``Path(arg).is_file()`` is true. Otherwise ``par.error`` is
        called with the message below; if that call returns instead of exiting, the
        result is ``None``.
    """
    if Path(arg).is_file():
        return arg
    # Directories and nonexistent paths both end up here.
    par.error(f'只能输入模型文件而不是文件夹:{arg}')
|
|
|
|
|
|
def read_auto(file_path):
    """Load a signal from a ``.txt`` or ``.npy`` file, choosing the reader by suffix.

    Args:
        file_path: path to the signal file, as ``str`` or ``pathlib.Path``.
            The suffix comparison is case-insensitive (``.NPY`` is accepted).

    Returns:
        numpy.ndarray:
            - ``.txt``: the file is parsed by ``pandas.read_csv`` (comma-separated,
              no header row) and every value is flattened into a 1-D array.
            - ``.npy``: the array stored in the file, as returned by ``numpy.load``.

    Raises:
        ValueError: the suffix is neither ``.txt`` nor ``.npy``.
    """
    # Accept plain strings too; the original required a Path for ``.suffix``.
    file_path = Path(file_path)
    suffix = file_path.suffix.lower()
    if suffix == '.txt':
        # Use pandas read_csv to read the txt file; reshape(-1) flattens any
        # multi-column table into a single 1-D sequence.
        return pd.read_csv(file_path, header=None).to_numpy().reshape(-1)
    if suffix == '.npy':
        return np.load(str(file_path))
    raise ValueError('这个文件类型好像不支持需要自己写读取程序')
|
|
|
|
|
|
def _collect_data_files(data_path):
    """Return the signal files to process for *data_path* (a ``pathlib.Path``).

    A regular file is returned as a one-element list. A directory yields its
    top-level ``*.txt`` and ``*.npy`` files, sorted so the processing order is
    reproducible across filesystems.

    Raises:
        FileNotFoundError: *data_path* does not exist, or the directory holds no
            ``.txt``/``.npy`` files.
    """
    if data_path.is_file():
        return [data_path]
    if data_path.is_dir():
        files = sorted(list(data_path.glob('*.txt')) + list(data_path.glob('*.npy')))
        if not files:
            raise FileNotFoundError('没有找到以txt或npy结尾的文件')
        return files
    raise FileNotFoundError(f'数据路径不存在:{data_path}')


def main(opt):
    """Run sleep-apnea detection on every signal file selected by *opt*.

    For each file: load it with ``read_auto``, preprocess with
    ``XinXiaoPreprocess(data, opt.hz)``, and predict with ``SA_Detect``. Then
    summarise with ``AnalyseSegment``, saving ``<output>/<stem>.csv`` when
    ``opt.output`` is set.

    Args:
        opt: namespace with attributes ``data_path``, ``hz``, ``model``,
            ``batch_size``, ``output`` and ``TST`` (see the CLI parser).

    Raises:
        FileNotFoundError: no input files were found, or the selected model file
            ``SADetectModel/SAmodel<model>.pt`` does not exist.
    """
    data_path = Path(opt.data_path)
    print(data_path)
    all_file_path = _collect_data_files(data_path)

    # The model and output folder do not depend on the input file: validate and
    # create them once, before any (potentially slow) data loading, so a wrong
    # --model fails immediately.
    model_path = f'SADetectModel/SAmodel{opt.model}.pt'
    if not Path(model_path).exists():
        raise FileNotFoundError('模型文件不存在')
    output_dir = Path(opt.output) if opt.output else None
    if output_dir is not None:
        output_dir.mkdir(parents=True, exist_ok=True)

    # --TST is documented as single-file only; applying one recording's sleep
    # time to every file in a folder would skew all the others, so fall back to
    # the CLI default (0) when several files are processed.
    true_sleep_time = opt.TST
    if len(all_file_path) > 1 and opt.TST:
        print('多文件测试不支持手动输入睡眠总时长(TST),已忽略该参数')
        true_sleep_time = 0

    for one_file in all_file_path:
        data = read_auto(one_file)
        data = XinXiaoPreprocess(data, opt.hz)

        print(f'正在使用模型{model_path}进行预测')
        result = SA_Detect(model_path, data, opt.batch_size)

        save_path = output_dir / (one_file.stem + ".csv") if output_dir is not None else None
        AnalyseSegment(result, save_path=save_path, true_sleep_time=true_sleep_time)
|
|
|
|
|
|
def build_parser():
    """Build the command-line parser for SA detection.

    ``--model`` is converted with ``type=int``. Without it argparse compares the
    raw string (e.g. ``'1'``) against the integer ``choices`` and rejects every
    value given on the command line.
    """
    parser = argparse.ArgumentParser(description='SA Detection Configuration ----- Marques')
    parser.add_argument('-d', '--data_path', nargs='?', type=str, default='./Data', help='待测试数据文件路径或文件夹')
    parser.add_argument('-n', '--hz', type=int, nargs='?', default=1000, help='信号采样率')
    parser.add_argument('-m', '--model', nargs='?', type=int, default=0, choices=[0, 1, 3, 4], help='选择一个模型路径')
    parser.add_argument('-b', "--batch_size", nargs="?", default=4096, type=int, help="模型每次预测片段数量")
    parser.add_argument('-o', '--output', nargs='?', default='./Output', help='输出文件夹')
    parser.add_argument('-t', "--TST", nargs="?", default=0, type=int, help="睡眠总时长(仅单文件测试时支持手动输入)")
    return parser


if __name__ == '__main__':
    option = build_parser().parse_args()

    # Manual overrides: uncommenting a line below takes precedence over both the
    # defaults and the command-line values.
    # option.data_path = "/home/marques/code/marques/apnea/dataset/zhongda/zhongda_origin_npy/3103.npy"
    # option.data_path = "/home/marques/code/marques/apnea/dataset/zhongda/zhongda_origin_npy"
    # option.hz = 1000
    # option.model = 0
    # option.batch_size = 4096
    # option.output = "./Output"
    # option.TST = 565

    main(option)
|