#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:andrew
@file:ApneaDetection.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2023/09/15

Command-line entry point for SA (apnea) detection.

For every ``.txt`` / ``.npy`` signal file found at ``--data_path`` the script:
loads the signal, preprocesses it with ``XinXiaoPreprocess`` at ``--hz``,
runs ``SA_Detect`` with model ``SADetectModel/SAmodel<--model>.pt``, and
summarises the result with ``AnalyseSegment`` (optionally saving a CSV per
input file into ``--output``).
"""
from pathlib import Path
import argparse

import numpy as np
import pandas as pd

from utils.SignalPreprocess import XinXiaoPreprocess
from utils.ModelDetection import SA_Detect
from utils.ResultSummary import AnalyseSegment


def is_valid_file(par, arg):
    """Argparse-style validator: return ``arg`` if it names an existing file.

    Not referenced elsewhere in this script.

    Args:
        par: parser object whose ``error`` method is called on failure
            (presumably an ``argparse.ArgumentParser``, whose ``error`` exits).
        arg: path string to check.

    Returns:
        ``arg`` unchanged when it is an existing file; otherwise whatever
        ``par.error`` returns (``None`` if it does not exit).
    """
    if not Path(arg).is_file():
        par.error(f'只能输入模型文件而不是文件夹:{arg}')
    else:
        return arg


def read_auto(file_path):
    """Load a signal file, choosing the reader from the file suffix.

    Args:
        file_path: ``pathlib.Path`` of the input file.

    Returns:
        numpy array. ``.txt`` files are parsed with ``pandas.read_csv``
        (no header row) and flattened to 1-D; ``.npy`` files are returned
        exactly as stored by ``np.load``.

    Raises:
        ValueError: for any other suffix.
    """
    if file_path.suffix == '.txt':
        # Parse the text file as headerless CSV, then flatten to 1-D.
        return pd.read_csv(file_path, header=None).to_numpy().reshape(-1)
    if file_path.suffix == '.npy':
        return np.load(str(file_path))
    raise ValueError('这个文件类型好像不支持需要自己写读取程序')


def _collect_input_files(data_path):
    """Return the list of signal files to process for ``data_path``.

    A file path is returned as a one-element list; a directory yields its
    ``*.txt`` files followed by its ``*.npy`` files, each group sorted so the
    processing order is deterministic (``Path.glob`` order is arbitrary).

    Raises:
        FileNotFoundError: if ``data_path`` does not exist or the directory
            contains no ``.txt``/``.npy`` files.
    """
    if data_path.is_file():
        all_file_path = [data_path]
    elif data_path.is_dir():
        all_file_path = sorted(data_path.glob('*.txt')) + sorted(data_path.glob('*.npy'))
    else:
        raise FileNotFoundError('没有找到以txt或npy结尾的文件')
    if not all_file_path:
        raise FileNotFoundError('没有找到以txt或npy结尾的文件')
    return all_file_path


def main(opt):
    """Run SA detection on every input file described by ``opt``.

    Args:
        opt: parsed command-line namespace providing ``data_path``, ``hz``,
            ``model``, ``batch_size``, ``output`` and ``TST``.

    Raises:
        FileNotFoundError: if no input files are found or the selected model
            file does not exist.
        ValueError: from ``read_auto`` when a single input file has an
            unsupported suffix.
    """
    # Locate the input data set.
    data_path = Path(opt.data_path)
    print(data_path)
    all_file_path = _collect_input_files(data_path)

    # The model file does not depend on the input file, so check it once
    # instead of once per file.
    model_path = f'SADetectModel/SAmodel{opt.model}.pt'
    if not Path(model_path).exists():
        raise FileNotFoundError('模型文件不存在')

    # A manually entered total sleep time describes a single recording
    # (see the --TST help); applying one value to many files would skew
    # every file's summary, so fall back to the default in that case.
    true_sleep_time = opt.TST
    if len(all_file_path) > 1 and true_sleep_time:
        print('警告: 检测到多个文件, 手动输入的睡眠总时长(TST)仅支持单文件测试, 已忽略')
        true_sleep_time = 0

    # An empty/None output option disables saving.
    save_dir = None
    if opt.output:
        save_dir = Path(opt.output)
        save_dir.mkdir(parents=True, exist_ok=True)

    for one_file in all_file_path:
        data = read_auto(one_file)
        data = XinXiaoPreprocess(data, opt.hz)

        print(f'正在使用模型{model_path}进行预测')
        result = SA_Detect(model_path, data, opt.batch_size)
        # Uncomment to also dump the raw per-segment predictions:
        # result.to_csv(Path(opt.output) / (one_file.stem + "segment.csv"), index=False)

        save_path = save_dir / (one_file.stem + ".csv") if save_dir is not None else None
        AnalyseSegment(result, save_path=save_path, true_sleep_time=true_sleep_time)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='SA Detection Configuration ----- Marques')
    parser.add_argument('-d', '--data_path', nargs='?', type=str, default='./Data',
                        help='待测试数据文件路径或文件夹')
    parser.add_argument('-n', '--hz', type=int, nargs='?', default=1000,
                        help='信号采样率')
    # type=int is required: without it a command-line value stays a str and
    # can never match the integer choices.
    parser.add_argument('-m', '--model', nargs='?', type=int, default=0, choices=[0, 1, 3, 4],
                        help='选择一个模型路径')
    parser.add_argument('-b', "--batch_size", nargs="?", default=4096, type=int,
                        help="模型每次预测片段数量")
    parser.add_argument('-o', '--output', nargs='?', default='./Output',
                        help='输出文件夹')
    parser.add_argument('-t', "--TST", nargs="?", default=0, type=int,
                        help="睡眠总时长(仅单文件测试时支持手动输入)")
    option = parser.parse_args()

    # Manual overrides: uncomment to take precedence over defaults and
    # command-line input.
    # option.data_path = "/home/marques/code/marques/apnea/dataset/zhongda/zhongda_origin_npy/3103.npy"
    # option.data_path = "/home/marques/code/marques/apnea/dataset/zhongda/zhongda_origin_npy"
    # option.hz = 1000
    # option.model = 0
    # option.batch_size = 4096
    # option.output = "./Output"
    # option.TST = 565

    main(option)