#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:Marques
@file:load_dataset.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2021/12/03
"""
import sys
from pathlib import Path

import pandas as pd
import numpy as np
import torch.utils.data
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.Preprocessing import BCG_Operation

preprocessing = BCG_Operation()
preprocessing.sample_rate = 100

"""
1. Loading strategies
# Recordings are converted to .npy in advance, whether or not they are pre-segmented.
# 1.1 Preprocess in advance, segment, save every segment as .npy and load the segments directly
#     -> higher memory usage, simple reading
# 1.2 Preprocess in advance, load the whole-night recording, write the segment index to csv/xls
#     and read each segment through that index
#     -> lower memory usage, slightly more complex reading
# Strategy 1.2 is the one used below.
"""

datasets = {}  # module-level cache of whole-night recordings, avoids repeated reads


def read_dataset(data_path, augment=None):
    """Load every .npy recording under data_path into the global cache."""
    data_path = Path(data_path)
    try:
        f = []
        if data_path.is_dir():
            dataset_list = list(data_path.rglob("*.npy"))
            dataset_list.sort()
            f += dataset_list
        elif data_path.is_file():
            raise Exception('dataset path should be a dir')
        else:
            raise Exception(f'{data_path} does not exist')
    except Exception as e:
        raise Exception(f'Error loading data from {data_path}: {e} \n')

    print("loading dataset")
    for i in tqdm(f):
        select_dataset = np.load(i)
        # 3rd-order low-pass Butterworth filter with a 20 Hz cut-off
        select_dataset = preprocessing.Butterworth(select_dataset, "lowpass", low_cut=20, order=3)
        if augment is not None:
            select_dataset = augment(select_dataset)
        # key the cache by the sample number prefix of the file name (text before "samp")
        datasets[i.name.split("samp")[0]] = select_dataset


# Reading follows the second strategy (1.2): whole-night signals stay in the
# cache and each item is sliced out of them via the label index.
class ApneaDataset(Dataset):
    def __init__(self, data_path, label_path, select_sampno, dataset_type, segment_augment=None):
        self.data_path = data_path
        self.label_path = label_path
        self.segment_augment = segment_augment
        self.labels = None
        self.dataset_type = dataset_type
        self.select_sampNo = select_sampno
        # self._getAllData()
        self._getAllLabels()

    def __getitem__(self, index):
        # PN: patient number
        # SP/EP: start point, end point of the segment
        # temp_label.append([sampNo, label[-1], i, hpy_num, csa_num, osa_num, mean_low, flow_low])
        PN, segmentNo, label_type, new_label, SP, EP = self.labels[index]
        # PN, label, SP, EP, hpy_num, csa_num, osa_num, mean_low, flow_low = self.labels[index]
        if str(PN) in datasets:
            # slice the cached whole-night signal and reshape it into 5 columns
            segment = datasets[str(PN)][SP*10:EP*10].copy()
            segment = segment.reshape(-1, 5)
            return segment, int(float(label_type) > 1)
        else:
            raise Exception(f'recording {PN} is not in the cache; call read_dataset() first')

    def count_SA(self):
        # number of segments whose new_label marks a sleep-apnea event
        return sum(self.labels[:, 3] > 1)

    def __len__(self):
        return len(self.labels)

    def _getAllLabels(self):
        label_path = Path(self.label_path)
        if not label_path.exists():
            raise Exception(f'{self.label_path} does not exist')

        try:
            f = []
            if label_path.is_dir():
                if self.dataset_type == "train":
                    label_list = list(label_path.rglob("*_train_label.csv"))
                elif self.dataset_type == "valid":
                    label_list = list(label_path.rglob("*_valid_label.csv"))
                elif self.dataset_type == "test":
                    label_list = list(label_path.glob("*_sa_test_label.csv"))
                    # label_list = list(label_path.rglob("*_test_label.npy"))
                elif self.dataset_type == "all_test":
                    label_list = list(label_path.rglob("*_sa_all_label.csv"))
                else:
                    raise ValueError(f"unknown dataset_type: {self.dataset_type}")
                # label_list = list(label_path.rglob("*_label.npy"))
                label_list.sort()
                f += label_list
            elif label_path.is_file():
                raise Exception('dataset path should be a dir')
            else:
                raise Exception(f'{self.label_path} does not exist')
        except Exception as e:
            raise Exception(f'Error loading data from {self.label_path}: {e} \n')

        print("loading labels")
        for i in tqdm(f):
            # keep only the recordings whose sample number was selected
            if int(i.name.split("_")[0]) not in self.select_sampNo:
                continue
            if self.labels is None:
                self.labels = pd.read_csv(i).to_numpy(dtype=int)
            else:
                labels = pd.read_csv(i).to_numpy(dtype=int)
                if len(labels) > 0:
                    self.labels = np.concatenate((self.labels, labels))
        # self.labels = self.labels[:10000]
        print(f"{self.dataset_type} length is {len(self.labels)}")


class TestApneaDataset2(ApneaDataset):
    """Test-time variant that also returns the label metadata for each segment."""

    def __init__(self, data_path, label_path, select_sampno, dataset_type, segment_augment=None):
        super(TestApneaDataset2, self).__init__(
            data_path=data_path,
            label_path=label_path,
            dataset_type=dataset_type,
            segment_augment=segment_augment,
            select_sampno=select_sampno
        )

    def __getitem__(self, index):
        PN, segmentNo, label_type, new_label, SP, EP = self.labels[index]
        # PN, label, SP, EP, hpy_num, csa_num, osa_num, mean_low, flow_low = self.labels[index]
        if str(PN) in datasets:
            segment = datasets[str(PN)][SP*10:EP*10].copy()
            segment = segment.reshape(-1, 5)
            return segment, int(float(label_type) > 1), PN, segmentNo, label_type, new_label, SP, EP
        else:
            raise Exception(f'recording {PN} is not in the cache; call read_dataset() first')


if __name__ == '__main__':
    pass
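
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how the recording cache and the Dataset
# are expected to work together. The directory names, sample numbers and the
# batch size below are hypothetical placeholders, not values from the project.
# ---------------------------------------------------------------------------
def _example_usage(data_dir="./data/npy", label_dir="./data/labels"):
    """Minimal sketch: fill the recording cache, then iterate mini-batches."""
    from torch.utils.data import DataLoader

    # 1) Load and low-pass filter every whole-night recording into `datasets`.
    read_dataset(data_dir)

    # 2) Build the segment-index dataset (strategy 1.2) and a standard DataLoader.
    train_set = ApneaDataset(data_dir, label_dir,
                             select_sampno=[1, 2, 3], dataset_type="train")
    loader = DataLoader(train_set, batch_size=32, shuffle=True)

    for segments, targets in loader:
        # segments: (batch, n, 5) tensor; targets are 1 when label_type > 1
        print(segments.shape, targets.shape)
        break

# _example_usage()  # uncomment to run the sketch against real data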