diff --git a/load_dataset.py b/load_dataset.py
new file mode 100644
index 0000000..a88b6a2
--- /dev/null
+++ b/load_dataset.py
@@ -0,0 +1,161 @@
+#!/usr/bin/python
+# -*- coding: UTF-8 -*-
+"""
+@author:Marques
+@file:load_dataset.py
+@email:admin@marques22.com
+@email:2021022362@m.scnu.edu.cn
+@time:2021/12/03
+"""
+import sys
+from pathlib import Path
+import pandas as pd
+import numpy as np
+import torch
+import yaml
+from torch.utils.data import Dataset
+from tqdm import tqdm
+
+"""
+1. Reading strategy
+# Whether or not the segments are cut in advance, the data is converted to npy beforehand.
+# 1.1 Preprocess and cut in advance, save the segments as npy and load them directly: more memory, simpler reads.
+This file uses: 1.2 Preprocess in advance, load whole-night recordings, save the cut points as csv/xls and read per segment: less memory, more complex reads.
+"""
+
+datasets = {}
+
+
+# Module-level cache so each recording is read from disk only once
+def read_dataset(config, augment=None):
+    data_path = Path(config["Path"]["dataset"])
+    try:
+        file_list = []
+        if data_path.is_dir():
+            dataset_list = list(data_path.rglob("*.npy_low_zscore.npy"))
+            dataset_list.sort()
+            file_list += dataset_list
+        elif data_path.is_file():
+            raise Exception('dataset path should be a dir')
+        else:
+            raise Exception(f'{data_path} does not exist')
+    except Exception as e:
+        raise Exception(f'Error loading data from {data_path}: {e} \n')
+
+    print("loading dataset")
+    for i in tqdm(file_list):
+        select_dataset = np.load(i, allow_pickle=True)[0]
+        # select_dataset = preprocessing.Butterworth(select_dataset, "lowpass", low_cut=20, order=3)
+        if augment is not None:
+            select_dataset = augment(select_dataset, config)
+        datasets[i.name.split("samp")[0]] = select_dataset
+
+
+# Reads with method 1.2: whole-night recordings are sliced into segments on demand
+class ApneaDataset(Dataset):
+    def __init__(self, config, dataset_type, select_sampno, segment_augment=None):
+        self.data_path = Path(config["Path"]["dataset"])
+        self.label_path = Path(config["Path"]["label"])
+        self.segment_augment = segment_augment
+        self.labels_info = None
+        self.labels = None
+        self.dataset_type = dataset_type
+        self.select_sampNo = select_sampno
+        self.disable_hpy = config["disable_hpy"]
+        self.apply_samplerate = config["apply_samplerate"]
+
+        # self._getAllData()
+        self._getAllLabels()
+
+    def __getitem__(self, index):
+        # PN: patient number
+        # SP/EP: segment start point / end point
+        # temp_label.append([sampNo, label[-1], i, hpy_num, csa_num, osa_num, mean_low, flow_low])
+        PN, segmentNo, label_type, new_label, SP, EP = self.labels_info[index]
+        # PN, label, SP, EP, hpy_num, csa_num, osa_num, mean_low, flow_low = self.labels_info[index]
+
+        if str(PN) in datasets:  # recording cached by read_dataset()
+            segment = self.segment_augment(datasets[str(PN)], SP * self.apply_samplerate, EP * self.apply_samplerate)
+            return (*segment, self.labels[index], PN, segmentNo, label_type, new_label, SP, EP)
+        else:
+            raise KeyError(f'recording {PN} not loaded; call read_dataset() first')
+
+    def count_SA(self):
+        # assert isinstance(self.disable_hpy, int)
+        return sum(self.labels)
+
+    def __len__(self):
+        return len(self.labels_info)
+
+    def _getAllLabels(self):
+        label_path = Path(self.label_path)
+        if not label_path.exists():
+            raise Exception(f'{self.label_path} does not exist')
+
+        try:
+            file_list = []
+            if label_path.is_dir():
+                if self.dataset_type == "train":
+                    label_list = list(label_path.rglob("*_train_label.csv"))
+                elif self.dataset_type == "valid":
+                    label_list = list(label_path.rglob("*_valid_label.csv"))
+                elif self.dataset_type == "test":
+                    label_list = list(label_path.glob("*_sa_test_label.csv"))
+                    # label_list = list(label_path.rglob("*_test_label.npy"))
+                elif self.dataset_type == "all_test":
+                    label_list = list(label_path.rglob("*_sa_all_label.csv"))
+                else:
+                    raise ValueError(f"unknown dataset_type: {self.dataset_type}")
+                    # label_list = list(label_path.rglob("*_label.npy"))
+                label_list.sort()
+                file_list += label_list
+            elif label_path.is_file():
+                raise Exception('label path should be a dir')
+            else:
+                raise Exception(f'{self.label_path} does not exist')
+        except Exception as e:
+            raise Exception(f'Error loading data from {self.label_path}: {e} \n')
+        print("loading labels")
+        for i in tqdm(file_list):
+            if int(i.name.split("_")[0]) not in self.select_sampNo:
+                continue
+
+            if self.labels_info is None:
+                self.labels_info = pd.read_csv(i).to_numpy(dtype=int)
+            else:
+                labels = pd.read_csv(i).to_numpy(dtype=int)
+                if len(labels) > 0:
+                    self.labels_info = np.concatenate((self.labels_info, labels))
+
+        self.labels = (self.labels_info[:, 3] > self.disable_hpy) * 1
+        self.labels = torch.from_numpy(self.labels)
+        gpu = torch.cuda.is_available()
+        self.labels = self.labels.cuda() if gpu else self.labels
+
+        # self.labels_info = self.labels_info[:10000]
+        print(f"{self.dataset_type} length is {len(self.labels_info)}")
+
+
+class TestApneaDataset2(ApneaDataset):
+    def __init__(self, config, dataset_type, select_sampno, segment_augment=None):
+        super(TestApneaDataset2, self).__init__(
+            config,
+            dataset_type=dataset_type,
+            select_sampno=select_sampno,
+            segment_augment=segment_augment,
+        )
+
+    def __getitem__(self, index):
+        PN, segmentNo, label_type, new_label, SP, EP = self.labels_info[index]
+        # PN, label, SP, EP, hpy_num, csa_num, osa_num, mean_low, flow_low = self.labels_info[index]
+
+        if str(PN) in datasets:  # recording cached by read_dataset()
+            dataset = datasets[str(PN)]
+            segment = self.segment_augment(dataset, SP * self.apply_samplerate, EP * self.apply_samplerate)
+            return (*segment, self.labels[index], PN, segmentNo, label_type, new_label, SP, EP)
+        else:
+            raise KeyError(f'recording {PN} not loaded; call read_dataset() first')
+
+
+if __name__ == '__main__':
+    pass
diff --git a/my_augment.py b/my_augment.py
new file mode 100644
index 0000000..8b78209
--- /dev/null
+++ b/my_augment.py
@@ -0,0 +1,58 @@
+#!/usr/bin/python
+# -*- coding: UTF-8 -*-
+"""
+@author:Marques
+@file:my_augment.py
+@email:admin@marques22.com
+@email:2021022362@m.scnu.edu.cn
+@time:2022/07/26
+"""
+import torch.cuda
+import yaml
+
+from utils.Preprocessing import BCG_Operation
+import numpy as np
+from scipy.signal import stft
+from torch import from_numpy
+with open("./settings.yaml") as f:
+    hyp = yaml.load(f, Loader=yaml.SafeLoader)  # load hyps
+
+apply_samplerate = hyp["apply_samplerate"]
+dataset_samplerate = hyp["dataset_samplerate"]
+preprocessing = BCG_Operation()
+preprocessing.sample_rate = dataset_samplerate
+
+
+def my_augment(dataset, config):
+    # dataset = preprocessing.Butterworth(dataset, "lowpass", low_cut=20, order=6)
+    # dataset = (dataset - dataset.mean()) / dataset.std()
+    # dataset_low = preprocessing.Butterworth(dataset, "lowpass", low_cut=0.7, order=6)
+    # dataset_high = preprocessing.Butterworth(dataset, "highpass", high_cut=1, order=6)
+    print(f"dataset sample_rate is {config['dataset_samplerate']}, down_ratio is {config['dataset_samplerate'] // config['apply_samplerate']}")
+    dataset = dataset[::config['dataset_samplerate'] // config['apply_samplerate']]
+    gpu = torch.cuda.is_available()
+    dataset = {"raw": from_numpy(dataset).float().cuda() if gpu else from_numpy(dataset).float(),
+               # "low": dataset_low,
+               # "high": dataset_high
+               }
+    return dataset
+
+
+def get_stft(x, fs, n):
+    print(len(x))
+    f, t, amp = stft(x, fs, nperseg=n)
+    z = np.abs(amp.copy())
+    return f, t, z
+
+
+def my_segment_augment(dataset, SP, EP):
+    # dataset_segment1 = dataset["low"][int(SP) * 100:int(EP) * 100].copy()
+    # dataset_segment2 = dataset["high"][int(SP) * 100:int(EP) * 100].copy()
+
+    # dataset_segment = np.concatenate(([dataset_segment1], [dataset_segment2]), axis=0)
+    dataset_segment = dataset["raw"][int(SP):int(EP)].unsqueeze(dim=0)
+    return [dataset_segment]
+
+
+if __name__ == '__main__':
+    pass