sleep_apnea_hybrid/exam/042/load_dataset.py

156 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:Marques
@file:load_dataset.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2021/12/03
"""
import sys
from pathlib import Path
import pandas as pd
import numpy as np
import torch.utils.data
from torch.utils.data import Dataset
from tqdm import tqdm
from utils.Preprocessing import BCG_Operation
# Shared signal-processing helper from utils.Preprocessing; read_dataset calls its
# Butterworth method on every recording it loads.
preprocessing = BCG_Operation()
# Sampling rate handed to the helper (presumably in Hz -- TODO confirm the unit).
preprocessing.sample_rate = 100
"""
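1. Loading strategies
# Whether or not the data is segmented in advance, it is converted to .npy format beforehand.
# 1.1 Pre-process and segment in advance, save the segments as .npy and load them directly: higher memory use, simple to read.
Method used here: 1.2 Pre-process in advance, load the whole-night recording, and read each segment according to a csv or xls file generated after segmentation: lower memory use, more complex to read.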
"""
# Module-level cache filled by read_dataset and read by the dataset classes'
# __getitem__: maps a recording key (the file name up to the first "samp") to its
# processed array, which avoids reading the same recording repeatedly.
datasets = {}
def read_dataset(data_path, augment=None):
    """Load every ``.npy`` file under ``data_path`` into the module-level ``datasets`` cache.

    Each file is loaded with ``np.load``, passed through
    ``preprocessing.Butterworth(..., "lowpass", low_cut=20, order=3)``, optionally
    transformed by ``augment``, and stored in ``datasets`` under the part of its
    file name that precedes the first ``"samp"``.

    Args:
        data_path: Directory searched recursively for ``*.npy`` files.
        augment: Optional callable applied to each filtered array; its return
            value is what gets cached.

    Raises:
        Exception: If ``data_path`` is a file, does not exist, or the directory
            scan fails. The message is prefixed with
            ``Error loading data from <data_path>``.
    """
    data_path = Path(data_path)
    try:
        # Guard clauses: only a directory is an acceptable input.
        if data_path.is_file():
            raise Exception('dataset path should be a dir')
        if not data_path.is_dir():
            raise Exception(f'{data_path} does not exist')
        recording_files = sorted(data_path.rglob("*.npy"))
    except Exception as e:
        raise Exception(f'Error loading data from {data_path}: {e} \n')

    print("loading dataset")
    for recording_file in tqdm(recording_files):
        signal = preprocessing.Butterworth(
            np.load(recording_file), "lowpass", low_cut=20, order=3
        )
        if augment is not None:
            signal = augment(signal)
        # Cache key is the file name up to (not including) the first "samp".
        cache_key = recording_file.name.split("samp")[0]
        datasets[cache_key] = signal
# Data is read with the second method (1.2) described in the note near the top of this module.
class ApneaDataset(Dataset):
    """Dataset of labelled segments cut from recordings held in the ``datasets`` cache.

    Label rows are read from CSV files in ``label_path``. Each row is unpacked
    as ``PN, segmentNo, label_type, new_label, SP, EP``, where ``PN`` selects the
    recording in the module-level ``datasets`` dict and ``SP``/``EP`` are the
    start and end points handed to ``segment_augment``.

    Args:
        data_path: Stored on the instance; not otherwise used by this class.
        label_path: Directory containing the label CSV files.
        select_sampno: Container of integer sample numbers. A label file is
            loaded only if the integer before the first ``"_"`` in its name is
            in this container.
        dataset_type: One of ``"train"``, ``"valid"``, ``"test"``, ``"all_test"``;
            selects which label file pattern is searched for.
        segment_augment: Callable invoked as ``segment_augment(dataset, SP, EP)``
            whose result is unpacked into the returned item. ``__getitem__``
            calls it unconditionally, so the default ``None`` is not usable
            for item access.

    Raises:
        Exception: If ``label_path`` does not exist, is not a directory, or
            ``dataset_type`` is not recognised.
    """

    # One label row holds: PN, segmentNo, label_type, new_label, SP, EP.
    _LABEL_COLUMNS = 6

    # dataset_type -> (file name pattern, search sub-directories too?).
    # NOTE(review): "test" is the only type searched non-recursively -- confirm
    # this difference is intentional.
    _LABEL_PATTERNS = {
        "train": ("*_train_label.csv", True),
        "valid": ("*_valid_label.csv", True),
        "test": ("*_sa_test_label.csv", False),
        "all_test": ("*_sa_all_label.csv", True),
    }

    def __init__(self, data_path, label_path, select_sampno, dataset_type, segment_augment=None):
        self.data_path = data_path
        self.label_path = label_path
        self.segment_augment = segment_augment
        self.labels = None
        self.dataset_type = dataset_type
        self.select_sampNo = select_sampno
        self._getAllLabels()

    def __getitem__(self, index):
        """Return the segment for label row ``index`` followed by its label fields.

        Returns:
            A tuple ``(*segment, binary_label, PN, segmentNo, label_type,
            new_label, SP, EP)`` where ``segment`` is the result of
            ``segment_augment(dataset, SP, EP)`` and ``binary_label`` is ``1``
            when ``label_type > 1`` and ``0`` otherwise.

        Raises:
            Exception: If the module-level ``datasets`` is not a dict.
            KeyError: If no recording is cached under ``str(PN)``.
        """
        # PN: patient number; SP/EP: start point / end point of the segment.
        PN, segmentNo, label_type, new_label, SP, EP = self.labels[index]
        if not isinstance(datasets, dict):
            raise Exception('dataset read failure!')
        dataset = datasets[str(PN)]
        segment = self.segment_augment(dataset, SP, EP)
        return (*segment, int(float(label_type) > 1), PN, segmentNo, label_type, new_label, SP, EP)

    def count_SA(self):
        """Return the number of label rows whose column 3 is greater than 1.

        NOTE(review): column 3 is ``new_label``, whereas ``__getitem__`` derives
        its binary label from ``label_type`` (column 2) -- confirm which column
        is intended here.
        """
        return (self.labels[:, 3] > 1).sum()

    def __len__(self):
        """Return the number of loaded label rows."""
        return len(self.labels)

    def _getAllLabels(self):
        """Load all matching label CSV files into ``self.labels``.

        ``self.labels`` becomes an integer array with one row per segment. When
        no file matches ``select_sampNo`` (or every matching file is empty) it
        is an empty ``(0, 6)`` array, so ``len()`` and ``count_SA()`` keep
        working instead of failing on ``None``.
        """
        label_files = self._find_label_files()
        print("loading labels")
        tables = []
        for label_file in tqdm(label_files):
            # File names start with the sample number, e.g. "<sampNo>_..._label.csv".
            if int(label_file.name.split("_")[0]) not in self.select_sampNo:
                continue
            tables.append(pd.read_csv(label_file).to_numpy(dtype=int))
        self.labels = self._stack_label_tables(tables)
        print(f"{self.dataset_type} length is {len(self.labels)}")

    def _find_label_files(self):
        """Return the sorted label files for ``self.dataset_type`` under ``self.label_path``."""
        label_path = Path(self.label_path)
        if not label_path.exists():
            raise Exception(f'{self.label_path} does not exist')
        try:
            if not label_path.is_dir():
                raise Exception('dataset path should be a dir')
            if self.dataset_type not in self._LABEL_PATTERNS:
                raise ValueError("self.dataset type error")
            pattern, recursive = self._LABEL_PATTERNS[self.dataset_type]
            finder = label_path.rglob if recursive else label_path.glob
            return sorted(finder(pattern))
        except Exception as e:
            # Keep the historical message but chain the cause for debugging.
            raise Exception(f'Error loading data from {self.label_path}: {e} \n') from e

    @classmethod
    def _stack_label_tables(cls, tables):
        """Concatenate per-file label arrays once; return an empty table if none have rows."""
        non_empty = [table for table in tables if len(table) > 0]
        if not non_empty:
            return np.empty((0, cls._LABEL_COLUMNS), dtype=int)
        return np.concatenate(non_empty)
class TestApneaDataset2(ApneaDataset):
    """Variant of ``ApneaDataset`` with its own ``__getitem__``.

    Construction is delegated unchanged to ``ApneaDataset``. Item access looks up
    the recording for the row's ``PN`` in the module-level ``datasets`` dict,
    cuts the segment with ``segment_augment(dataset, SP, EP)`` and returns it
    together with the label fields.
    """

    def __init__(self, data_path, label_path, select_sampno, dataset_type, segment_augment=None):
        super().__init__(
            data_path=data_path,
            label_path=label_path,
            select_sampno=select_sampno,
            dataset_type=dataset_type,
            segment_augment=segment_augment,
        )

    def __getitem__(self, index):
        """Return ``(*segment, binary_label, PN, segmentNo, label_type, new_label, SP, EP)``."""
        PN, segmentNo, label_type, new_label, SP, EP = self.labels[index]

        # Guard first: the cache must be a dict keyed by str(PN).
        if not isinstance(datasets, dict):
            raise Exception('dataset read failure!')

        recording = datasets[str(PN)]
        segment = self.segment_augment(recording, SP, EP)
        # Binary label: 1 when label_type is greater than 1, otherwise 0.
        binary_label = int(float(label_type) > 1)
        return (*segment, binary_label, PN, segmentNo, label_type, new_label, SP, EP)
# Script entry point: this module only provides definitions, so running it
# directly does nothing.
if __name__ == '__main__':
    pass