#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:Marques
@file:load_dataset.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2021/12/03
"""
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
"""
1. 读取方法
# 无论是否提前切分均提前转成npy格式
# 1.1 提前预处理切分好后生成npy直接载入切分好的片段 内存占用多 读取简单
使用此方法: 1.2 提前预处理载入整夜数据切分好后生成csv或xls根据片段读取 内存占用少 读取较为复杂
"""
datasets = {}
# 减少重复读取
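
# The config dict is expected to look roughly like the sketch below, which is
# reconstructed from the keys accessed in this module; the values shown are
# placeholders (the project presumably loads them from a YAML file):
#
#     config = {
#         "Path": {
#             "dataset": "/path/to/npy_recordings",  # *.npy_low_zscore.npy files
#             "label": "/path/to/label_csvs",        # *_label.csv files
#         },
#         "disable_hpy": 0,         # threshold for binarising segment labels
#         "apply_samplerate": 100,  # samples per second of the stored signal
#     }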


def read_dataset(config, augment=None):
    """Load every whole-night recording into the module-level `datasets` cache,
    keyed by sample number (the part of the file name before "samp")."""
    data_path = Path(config["Path"]["dataset"])
    try:
        file_list = []
        if data_path.is_dir():
            dataset_list = sorted(data_path.rglob("*.npy_low_zscore.npy"))
            file_list += dataset_list
        elif data_path.is_file():
            raise Exception('dataset path should be a dir')
        else:
            raise Exception(f'{data_path} does not exist')
    except Exception as e:
        raise Exception(f'Error loading data from {data_path}: {e}\n')

    print("loading dataset")
    for i in tqdm(file_list):
        select_dataset = np.load(i, allow_pickle=True)[0]
        # select_dataset = preprocessing.Butterworth(select_dataset, "lowpass", low_cut=20, order=3)
        if augment is not None:
            select_dataset = augment(select_dataset, config)
        datasets[i.name.split("samp")[0]] = select_dataset
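

# `augment` and `segment_augment` are caller-supplied hooks. Below is a minimal
# sketch of a segment_augment-compatible function (hypothetical, not from the
# original code): __getitem__ passes the whole-night array plus start/end
# offsets in samples, and unpacks the result with *, so it must return a tuple.
def example_segment_augment(whole_night, sp, ep):
    segment = np.asarray(whole_night[int(sp):int(ep)], dtype=np.float32)
    return (torch.from_numpy(segment),)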


# Reading approach 1.2: whole-night recordings plus a segment-index table.
class ApneaDataset(Dataset):
    def __init__(self, config, dataset_type, select_sampno, segment_augment=None):
        self.data_path = Path(config["Path"]["dataset"])
        self.label_path = Path(config["Path"]["label"])
        self.segment_augment = segment_augment
        self.labels_info = None
        self.labels = None
        self.dataset_type = dataset_type
        self.select_sampNo = select_sampno
        self.disable_hpy = config["disable_hpy"]
        self.apply_samplerate = config["apply_samplerate"]
        # self._getAllData()
        self._getAllLabels()

    def __getitem__(self, index):
        # PN: patient number
        # SP/EP: start point / end point (in seconds; scaled by the sample rate below)
        # temp_label.append([sampNo, label[-1], i, hpy_num, csa_num, osa_num, mean_low, flow_low])
        PN, segmentNo, label_type, new_label, SP, EP = self.labels_info[index]
        # PN, label, SP, EP, hpy_num, csa_num, osa_num, mean_low, flow_low = self.labels_info[index]
        if isinstance(datasets, dict):
            segment = self.segment_augment(datasets[str(PN)], SP * self.apply_samplerate, EP * self.apply_samplerate)
            return (*segment, self.labels[index], PN, segmentNo, label_type, new_label, SP, EP)
        else:
            raise Exception('dataset read failure!')

    def count_SA(self):
        # assert isinstance(self.disable_hpy, int)
        return sum(self.labels)

    def __len__(self):
        return len(self.labels_info)

    def _getAllLabels(self):
        label_path = Path(self.label_path)
        if not label_path.exists():
            raise Exception(f'{self.label_path} does not exist')
        try:
            file_list = []
            if label_path.is_dir():
                if self.dataset_type == "train":
                    label_list = list(label_path.rglob("*_train_label.csv"))
                elif self.dataset_type == "valid":
                    label_list = list(label_path.rglob("*_valid_label.csv"))
                elif self.dataset_type == "test":
                    label_list = list(label_path.glob("*_sa_test_label.csv"))
                    # label_list = list(label_path.rglob("*_test_label.npy"))
                elif self.dataset_type == "all_test":
                    label_list = list(label_path.rglob("*_sa_all_label.csv"))
                else:
                    raise ValueError("self.dataset_type error")
                # label_list = list(label_path.rglob("*_label.npy"))
                label_list.sort()
                file_list += label_list
            elif label_path.is_file():
                raise Exception('label path should be a dir')
            else:
                raise Exception(f'{self.label_path} does not exist')
        except Exception as e:
            raise Exception(f'Error loading labels from {self.label_path}: {e}\n')

        print("loading labels")
        for i in tqdm(file_list):
            # Keep only the recordings whose sample number was selected.
            if int(i.name.split("_")[0]) not in self.select_sampNo:
                continue
            if self.labels_info is None:
                self.labels_info = pd.read_csv(i).to_numpy(dtype=int)
            else:
                labels = pd.read_csv(i).to_numpy(dtype=int)
                if len(labels) > 0:
                    self.labels_info = np.concatenate((self.labels_info, labels))

        # A segment is positive when its new_label (column 3) exceeds the disable_hpy threshold.
        self.labels = (self.labels_info[:, 3] > self.disable_hpy) * 1
        self.labels = torch.from_numpy(self.labels)
        gpu = torch.cuda.is_available()
        self.labels = self.labels.cuda() if gpu else self.labels
        # self.labels_info = self.labels_info[:10000]
        print(f"{self.dataset_type} length is {len(self.labels_info)}")


class TestApneaDataset2(ApneaDataset):
    def __init__(self, config, dataset_type, select_sampno, segment_augment=None):
        super(TestApneaDataset2, self).__init__(
            config,
            dataset_type=dataset_type,
            select_sampno=select_sampno,
            segment_augment=segment_augment,
        )

    def __getitem__(self, index):
        PN, segmentNo, label_type, new_label, SP, EP = self.labels_info[index]
        # PN, label, SP, EP, hpy_num, csa_num, osa_num, mean_low, flow_low = self.labels_info[index]
        if isinstance(datasets, dict):
            dataset = datasets[str(PN)]
            segment = self.segment_augment(dataset, SP * self.apply_samplerate, EP * self.apply_samplerate)
            return (*segment, self.labels[index], PN, segmentNo, label_type, new_label, SP, EP)
        else:
            raise Exception('dataset read failure!')


if __name__ == '__main__':
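    # A minimal smoke-test sketch. The paths, sample numbers, and
    # example_segment_augment helper are illustrative assumptions, not part of
    # the original pipeline.
    demo_config = {
        "Path": {"dataset": "./data", "label": "./labels"},
        "disable_hpy": 0,
        "apply_samplerate": 100,
    }
    read_dataset(demo_config)
    train_set = ApneaDataset(demo_config, "train", select_sampno=[1, 2, 3],
                             segment_augment=example_segment_augment)
    print(f"{len(train_set)} segments, {int(train_set.count_SA())} positive")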