#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author: Marques
@file: load_dataset.py
@email: admin@marques22.com
@email: 2021022362@m.scnu.edu.cn
@time: 2021/12/03
"""
from pathlib import Path

import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.Preprocessing import BCG_Operation

# Shared preprocessing helper; the recordings are handled at 100 Hz.
preprocessing = BCG_Operation()
preprocessing.sample_rate = 100
"""
|
|||
|
1. 读取方法
|
|||
|
# 无论是否提前切分,均提前转成npy格式
|
|||
|
# 1.1 提前预处理,切分好后生成npy,直接载入切分好的片段 内存占用多 读取简单
|
|||
|
使用此方法: 1.2 提前预处理,载入整夜数据,切分好后生成csv或xls,根据片段读取 内存占用少 读取较为复杂
|
|||
|
"""
|
|||
|
|
|||
|
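
# Under method 1.2 each label row points into a whole-night recording; the row is
# unpacked in ApneaDataset.__getitem__ as [sampNo, segmentNo, label_type,
# new_label, SP, EP], where SP/EP delimit one segment. The values below are
# invented purely for illustration:
#
#     sampNo, segmentNo, label_type, new_label, SP, EP
#     123,    0,          2,          2,         30, 60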

# Cache whole-night recordings by sample number to avoid repeated reads.
datasets = {}


def read_dataset(data_path, augment=None):
    """Load every .npy recording under data_path into the module-level cache."""
    data_path = Path(data_path)
    try:
        f = []
        if data_path.is_dir():
            dataset_list = list(data_path.rglob("*.npy"))
            dataset_list.sort()
            f += dataset_list
        elif data_path.is_file():
            raise Exception('dataset path should be a dir')
        else:
            raise Exception(f'{data_path} does not exist')
    except Exception as e:
        raise Exception(f'Error loading data from {data_path}: {e}')

    print("loading dataset")
    for i in tqdm(f):
        select_dataset = np.load(i)
        # Suppress high-frequency noise with a 20 Hz low-pass Butterworth filter.
        select_dataset = preprocessing.Butterworth(select_dataset, "lowpass", low_cut=20, order=3)
        if augment is not None:
            select_dataset = augment(select_dataset)
        # Files are named "<sampNo>samp..."; key the cache by the sample number.
        datasets[i.name.split("samp")[0]] = select_dataset
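
# A minimal usage sketch (the directory and file name are assumed examples of the
# "<sampNo>samp..." convention, not real paths from the project):
#
#     read_dataset("./data")      # caches e.g. "./data/123samp_whole_night.npy"
#     night = datasets["123"]     # filtered whole-night signal for sample 123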


# Read with the second method (1.2): slice segments out of the cached recordings.
class ApneaDataset(Dataset):
    def __init__(self, data_path, label_path, select_sampno, dataset_type, segment_augment=None):
        self.data_path = data_path
        self.label_path = label_path
        self.segment_augment = segment_augment
        self.labels = None
        self.dataset_type = dataset_type
        self.select_sampNo = select_sampno

        # self._getAllData()
        self._getAllLabels()

    def __getitem__(self, index):
        # PN: patient (sample) number; SP/EP: segment start/end points.
        # Label row layout: [sampNo, segmentNo, label_type, new_label, SP, EP]
        PN, segmentNo, label_type, new_label, SP, EP = self.labels[index]

        if isinstance(datasets, dict):
            # Slice the cached recording between the labelled start/end points.
            segment = datasets[str(PN)][SP * 10:EP * 10].copy()
            segment = segment.reshape(-1, 1)
            # Binary target: label_type > 1 marks a sleep-apnea segment.
            return segment, int(float(label_type) > 1)
        else:
            raise Exception('dataset read failure!')

    def count_SA(self):
        # Count segments whose new_label column marks sleep apnea.
        return sum(self.labels[:, 3] > 1)

    def __len__(self):
        return len(self.labels)

    def _getAllLabels(self):
        label_path = Path(self.label_path)
        if not label_path.exists():
            raise Exception(f'{self.label_path} does not exist')

        try:
            f = []
            if label_path.is_dir():
                if self.dataset_type == "train":
                    label_list = list(label_path.rglob("*_train_label.csv"))
                elif self.dataset_type == "valid":
                    label_list = list(label_path.rglob("*_valid_label.csv"))
                elif self.dataset_type == "test":
                    label_list = list(label_path.glob("*_sa_test_label.csv"))
                    # label_list = list(label_path.rglob("*_test_label.npy"))
                elif self.dataset_type == "all_test":
                    label_list = list(label_path.rglob("*_sa_all_label.csv"))
                else:
                    raise ValueError(f"unknown dataset_type: {self.dataset_type}")
                # label_list = list(label_path.rglob("*_label.npy"))
                label_list.sort()
                f += label_list
            elif label_path.is_file():
                raise Exception('label path should be a dir')
            else:
                raise Exception(f'{self.label_path} does not exist')
        except Exception as e:
            raise Exception(f'Error loading labels from {self.label_path}: {e}')

        print("loading labels")
        for i in tqdm(f):
            # Label files are named "<sampNo>_..."; skip unselected samples.
            if int(i.name.split("_")[0]) not in self.select_sampNo:
                continue

            if self.labels is None:
                self.labels = pd.read_csv(i).to_numpy(dtype=int)
            else:
                labels = pd.read_csv(i).to_numpy(dtype=int)
                if len(labels) > 0:
                    self.labels = np.concatenate((self.labels, labels))
        # self.labels = self.labels[:10000]
        print(f"{self.dataset_type} length is {len(self.labels)}")
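
# A minimal training-loop sketch (paths, sample number, and batch size are
# placeholder assumptions, not taken from the original project):
#
#     from torch.utils.data import DataLoader
#     read_dataset("./data")
#     train_set = ApneaDataset("./data", "./labels", [123], "train")
#     loader = DataLoader(train_set, batch_size=32, shuffle=True)
#     for segments, targets in loader:
#         ...  # segments: (batch, length, 1); targets: 0 = normal, 1 = apnea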


class TestApneaDataset(ApneaDataset):
    def __init__(self, data_path, label_path, dataset_type, select_sampno, segment_augment=None):
        super(TestApneaDataset, self).__init__(
            data_path=data_path,
            label_path=label_path,
            dataset_type=dataset_type,
            select_sampno=select_sampno,
            segment_augment=segment_augment
        )

    def __getitem__(self, index):
        # Test label row layout: [sampNo, segmentNo, label_type, SP, EP]
        PN, segmentNo, label_type, SP, EP = self.labels[index]

        if isinstance(datasets, dict):
            segment = datasets[str(PN)][int(SP) * 100:int(EP) * 100].copy()
            if self.segment_augment is not None:
                segment = self.segment_augment(segment)
            # Return the metadata needed to locate the segment afterwards.
            return segment, int(float(label_type) > 1), PN, segmentNo, SP, EP
        else:
            raise Exception('dataset read failure!')


class TestApneaDataset2(ApneaDataset):
    def __init__(self, data_path, label_path, select_sampno, dataset_type, segment_augment=None):
        super(TestApneaDataset2, self).__init__(
            data_path=data_path,
            label_path=label_path,
            dataset_type=dataset_type,
            segment_augment=segment_augment,
            select_sampno=select_sampno
        )

    def __getitem__(self, index):
        # Same label row layout as ApneaDataset.
        PN, segmentNo, label_type, new_label, SP, EP = self.labels[index]

        if isinstance(datasets, dict):
            segment = datasets[str(PN)][SP * 10:EP * 10].copy()
            segment = segment.reshape(-1, 1)
            # Return the full label metadata alongside the binary target.
            return segment, int(float(label_type) > 1), PN, segmentNo, label_type, new_label, SP, EP
        else:
            raise Exception('dataset read failure!')


if __name__ == '__main__':
    pass
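    # A hypothetical smoke test for the test-time dataset (placeholder paths and
    # sample number, assumed for illustration):
    #
    #     read_dataset("./data")
    #     test_set = TestApneaDataset2("./data", "./labels", [123], "all_test")
    #     segment, target, PN, segmentNo, label_type, new_label, SP, EP = test_set[0]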