#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:Marques
@file:load_dataset.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2021/12/03
"""
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

"""
1. Data-loading strategies
# Whether or not the data is pre-segmented, convert it to .npy in advance.
# 1.1 Preprocess in advance, save each segment as .npy and load the segments
#     directly: high memory usage, simple reads.
The strategy used here:
1.2 Preprocess in advance, load whole-night recordings and store the segment
    index as csv/xls, then read segments on demand: low memory usage,
    slightly more complex reads.
"""

datasets = {}  # module-level cache to avoid re-reading the same recordings


def read_dataset(config, augment=None):
    data_path = Path(config["Path"]["dataset"])

    try:
        file_list = []
        if data_path.is_dir():
            dataset_list = list(data_path.rglob("*.npy_low_zscore.npy"))
            dataset_list.sort()
            file_list += dataset_list
        elif data_path.is_file():
            raise Exception('dataset path should be a dir')
        else:
            raise Exception(f'{data_path} does not exist')
    except Exception as e:
        raise Exception(f'Error loading data from {data_path}: {e} \n')

    print("loading dataset")
    for i in tqdm(file_list):
        select_dataset = np.load(i, allow_pickle=True)[0]
        # select_dataset = preprocessing.Butterworth(select_dataset, "lowpass", low_cut=20, order=3)
        if augment is not None:
            select_dataset = augment(select_dataset, config)
        # key each recording by the sample number preceding "samp" in the file name
        datasets[i.name.split("samp")[0]] = select_dataset


# Read with the second strategy: whole-night recordings plus a csv segment index.
class ApneaDataset(Dataset):
    def __init__(self, config, dataset_type, select_sampno, segment_augment=None):
        self.data_path = Path(config["Path"]["dataset"])
        self.label_path = Path(config["Path"]["label"])
        self.segment_augment = segment_augment
        self.labels_info = None
        self.labels = None
        self.dataset_type = dataset_type
        self.select_sampNo = select_sampno
        self.disable_hpy = config["disable_hpy"]
        self.apply_samplerate = config["apply_samplerate"]
        # self._getAllData()
        self._getAllLabels()

    def __getitem__(self, index):
        # PN: patient number; SP/EP: segment start point / end point (seconds)
        # temp_label.append([sampNo, label[-1], i, hpy_num, csa_num, osa_num, mean_low, flow_low])
        PN, segmentNo, label_type, new_label, SP, EP = self.labels_info[index]
        # PN, label, SP, EP, hpy_num, csa_num, osa_num, mean_low, flow_low = self.labels_info[index]
        if str(PN) in datasets:
            segment = self.segment_augment(datasets[str(PN)],
                                           SP * self.apply_samplerate,
                                           EP * self.apply_samplerate)
            return (*segment, self.labels[index], PN, segmentNo, label_type, new_label, SP, EP)
        raise Exception(f'dataset read failure: sample {PN} was not loaded')

    def count_SA(self):
        # number of positive (sleep-apnea) segments
        return int(self.labels.sum())

    def __len__(self):
        return len(self.labels_info)

    def _getAllLabels(self):
        label_path = Path(self.label_path)
        if not label_path.exists():
            raise Exception(f'{self.label_path} does not exist')

        try:
            file_list = []
            if label_path.is_dir():
                if self.dataset_type == "train":
                    label_list = list(label_path.rglob("*_train_label.csv"))
                elif self.dataset_type == "valid":
                    label_list = list(label_path.rglob("*_valid_label.csv"))
                elif self.dataset_type == "test":
                    label_list = list(label_path.glob("*_sa_test_label.csv"))
                    # label_list = list(label_path.rglob("*_test_label.npy"))
                elif self.dataset_type == "all_test":
                    label_list = list(label_path.rglob("*_sa_all_label.csv"))
                else:
                    raise ValueError("self.dataset type error")
                # label_list = list(label_path.rglob("*_label.npy"))
                label_list.sort()
                file_list += label_list
            elif label_path.is_file():
                raise Exception('label path should be a dir')
            else:
                raise Exception(f'{self.label_path} does not exist')
        except Exception as e:
            raise Exception(f'Error loading data from {self.label_path}: {e} \n')

        print("loading labels")
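        # Assumed csv layout, inferred from the unpacking in __getitem__:
        # [sampNo, segmentNo, label_type, new_label, SP, EP]. Column 3
        # (new_label) is thresholded against disable_hpy below, and SP/EP are
        # segment bounds in seconds, scaled by apply_samplerate at read time.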
        for i in tqdm(file_list):
            if int(i.name.split("_")[0]) not in self.select_sampNo:
                continue

            if self.labels_info is None:
                self.labels_info = pd.read_csv(i).to_numpy(dtype=int)
            else:
                labels = pd.read_csv(i).to_numpy(dtype=int)
                if len(labels) > 0:
                    self.labels_info = np.concatenate((self.labels_info, labels))

        if self.labels_info is None:
            raise Exception(f'no label file matched select_sampNo in {self.label_path}')

        self.labels = (self.labels_info[:, 3] > self.disable_hpy) * 1
        self.labels = torch.from_numpy(self.labels)
        gpu = torch.cuda.is_available()
        self.labels = self.labels.cuda() if gpu else self.labels

        # self.labels_info = self.labels_info[:10000]
        print(f"{self.dataset_type} length is {len(self.labels_info)}")


class TestApneaDataset2(ApneaDataset):
    def __init__(self, config, dataset_type, select_sampno, segment_augment=None):
        super(TestApneaDataset2, self).__init__(
            config,
            dataset_type=dataset_type,
            select_sampno=select_sampno,
            segment_augment=segment_augment,
        )

    def __getitem__(self, index):
        # Mirrors ApneaDataset.__getitem__; kept for compatibility with test code.
        PN, segmentNo, label_type, new_label, SP, EP = self.labels_info[index]
        # PN, label, SP, EP, hpy_num, csa_num, osa_num, mean_low, flow_low = self.labels_info[index]
        if str(PN) in datasets:
            dataset = datasets[str(PN)]
            segment = self.segment_augment(dataset,
                                           SP * self.apply_samplerate,
                                           EP * self.apply_samplerate)
            return (*segment, self.labels[index], PN, segmentNo, label_type, new_label, SP, EP)
        raise Exception(f'dataset read failure: sample {PN} was not loaded')


if __name__ == '__main__':
    pass
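    # Minimal usage sketch (illustrative only, not part of the pipeline).
    # The config keys mirror the ones accessed above; the paths, the sample
    # number, and simple_segment_augment are hypothetical placeholders.
    def simple_segment_augment(recording, sp, ep):
        # Hypothetical augment: slice the signal and return it as a 1-tuple so
        # it matches the "(*segment, ...)" unpacking in __getitem__.
        return (torch.from_numpy(recording[int(sp):int(ep)]).float(),)

    example_config = {
        "Path": {"dataset": "./data/npy", "label": "./data/labels"},
        "disable_hpy": 0,         # threshold on label column 3 for a positive
        "apply_samplerate": 100,  # samples per second used to scale SP/EP
    }
    read_dataset(example_config)  # fill the module-level cache first
    train_set = ApneaDataset(example_config, "train", select_sampno=[1],
                             segment_augment=simple_segment_augment)
    print(len(train_set), train_set.count_SA())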