For Jielong to take a look
This commit is contained in:
parent c0abde4ff9
commit 13cc898ec2
load_dataset.py (new file, 161 lines)
@@ -0,0 +1,161 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:Marques
@file:load_dataset.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2021/12/03
"""
import sys
from pathlib import Path

import pandas as pd
import numpy as np
import torch
import yaml
from torch.utils.data import Dataset
from tqdm import tqdm

"""
1. Loading methods
# Whether or not the data is split in advance, it is converted to npy format beforehand.
# 1.1 Preprocess in advance, split the data and save the segments as npy, then load the ready-made segments directly: more memory, simpler reads.
Method used here: 1.2 Preprocess in advance, load the whole-night recording, write the split information to csv or xls, and read each segment on demand: less memory, somewhat more complex reads.
"""
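# Note (inferred from ApneaDataset.__getitem__ below, not stated in this commit):
# each label csv is expected to hold six integer columns per segment, read in the
# order (sampNo, segmentNo, label_type, new_label, SP, EP); SP and EP are later
# scaled by apply_samplerate. A compatible file could be written, for example, with
# pd.DataFrame(rows, columns=["sampNo", "segmentNo", "label_type", "new_label",
#                             "SP", "EP"]).to_csv("123_train_label.csv", index=False)
# where the column names and the "123" prefix are illustrative placeholders.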

# module-level cache of whole-night recordings, keyed by sample number


# avoid repeated reads of the same recording
def read_dataset(config, augment=None):
    data_path = Path(config["Path"]["dataset"])
    try:
        file_list = []
        if data_path.is_dir():
            dataset_list = list(data_path.rglob("*.npy_low_zscore.npy"))
            dataset_list.sort()
            file_list += dataset_list
        elif data_path.is_file():
            raise Exception(f'dataset path should be a dir')
        else:
            raise Exception(f'{data_path} does not exist')
    except Exception as e:
        raise Exception(f'Error loading data from {data_path}: {e} \n')

    print("loading dataset")
    for i in tqdm(file_list):
        select_dataset = np.load(i, allow_pickle=True)[0]
        # select_dataset = preprocessing.Butterworth(select_dataset, "lowpass", low_cut=20, order=3)
        if augment is not None:
            select_dataset = augment(select_dataset, config)
        datasets[i.name.split("samp")[0]] = select_dataset
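

# Usage sketch (illustrative, not part of this commit): read_dataset expects a config
# dict whose config["Path"]["dataset"] entry points at a directory containing the
# *_low_zscore.npy whole-night recordings; loading the config from ./settings.yaml
# mirrors what my_augment.py does.
# with open("./settings.yaml") as f:
#     cfg = yaml.load(f, Loader=yaml.SafeLoader)
# read_dataset(cfg)  # fills the module-level `datasets` cache, keyed by sample number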


# read the data with the second method (1.2)
class ApneaDataset(Dataset):
    def __init__(self, config, dataset_type, select_sampno, segment_augment=None):
        self.data_path = Path(config["Path"]["dataset"])
        self.label_path = Path(config["Path"]["label"])
        self.segment_augment = segment_augment
        self.labels_info = None
        self.labels = None
        self.dataset_type = dataset_type
        self.select_sampNo = select_sampno
        self.disable_hpy = config["disable_hpy"]
        self.apply_samplerate = config["apply_samplerate"]

        # self._getAllData()
        self._getAllLabels()

    def __getitem__(self, index):
        # PN: patient number
        # SP/EP: start point, end point (scaled by apply_samplerate below)
        # temp_label.append([sampNo, label[-1], i, hpy_num, csa_num, osa_num, mean_low, flow_low])
        PN, segmentNo, label_type, new_label, SP, EP = self.labels_info[index]
        # PN, label, SP, EP, hpy_num, csa_num, osa_num, mean_low, flow_low = self.labels_info[index]

        if isinstance(datasets, dict):
            segment = self.segment_augment(datasets[str(PN)], SP * self.apply_samplerate, EP * self.apply_samplerate)
            return (*segment, self.labels[index], PN, segmentNo, label_type, new_label, SP, EP)
        else:
            raise Exception(f'dataset read failure!')

    def count_SA(self):
        # assert isinstance(self.disable_hpy, int)
        return sum(self.labels)

    def __len__(self):
        return len(self.labels_info)

    def _getAllLabels(self):
        label_path = Path(self.label_path)
        if not label_path.exists():
            raise Exception(f'{self.label_path} does not exist')

        try:
            file_list = []
            if label_path.is_dir():
                if self.dataset_type == "train":
                    label_list = list(label_path.rglob("*_train_label.csv"))
                elif self.dataset_type == "valid":
                    label_list = list(label_path.rglob("*_valid_label.csv"))
                elif self.dataset_type == "test":
                    label_list = list(label_path.glob("*_sa_test_label.csv"))
                    # label_list = list(label_path.rglob("*_test_label.npy"))
                elif self.dataset_type == "all_test":
                    label_list = list(label_path.rglob("*_sa_all_label.csv"))
                else:
                    raise ValueError("self.dataset type error")
                # label_list = list(label_path.rglob("*_label.npy"))
                label_list.sort()
                file_list += label_list
            elif label_path.is_file():
                raise Exception(f'dataset path should be a dir')
            else:
                raise Exception(f'{self.label_path} does not exist')
        except Exception as e:
            raise Exception(f'Error loading data from {self.label_path}: {e} \n')

        print("loading labels")
        for i in tqdm(file_list):
            if int(i.name.split("_")[0]) not in self.select_sampNo:
                continue

            if self.labels_info is None:
                self.labels_info = pd.read_csv(i).to_numpy(dtype=int)
            else:
                labels = pd.read_csv(i).to_numpy(dtype=int)
                if len(labels) > 0:
                    self.labels_info = np.concatenate((self.labels_info, labels))

        # binarise the labels: a segment counts as positive when column 3 exceeds the disable_hpy threshold
        self.labels = (self.labels_info[:, 3] > self.disable_hpy) * 1
        self.labels = torch.from_numpy(self.labels)
        gpu = torch.cuda.is_available()
        self.labels = self.labels.cuda() if gpu else self.labels

        # self.labels_info = self.labels_info[:10000]
        print(f"{self.dataset_type} length is {len(self.labels_info)}")
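# Usage sketch (illustrative, not part of this commit): the dataset is typically wrapped
# in a torch DataLoader; my_augment / my_segment_augment come from my_augment.py in this
# commit, while cfg and the sample numbers below are hypothetical placeholders.
# from torch.utils.data import DataLoader
# from my_augment import my_augment, my_segment_augment
# read_dataset(cfg, augment=my_augment)
# train_set = ApneaDataset(cfg, "train", select_sampno=[1, 2, 3],
#                          segment_augment=my_segment_augment)
# train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
# segment, label, *meta = next(iter(train_loader))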


class TestApneaDataset2(ApneaDataset):
    def __init__(self, config, dataset_type, select_sampno, segment_augment=None):
        super(TestApneaDataset2, self).__init__(
            config,
            dataset_type=dataset_type,
            select_sampno=select_sampno,
            segment_augment=segment_augment,
        )

    def __getitem__(self, index):
        PN, segmentNo, label_type, new_label, SP, EP = self.labels_info[index]
        # PN, label, SP, EP, hpy_num, csa_num, osa_num, mean_low, flow_low = self.labels_info[index]

        if isinstance(datasets, dict):
            dataset = datasets[str(PN)]
            segment = self.segment_augment(dataset, SP * self.apply_samplerate, EP * self.apply_samplerate)
            return (*segment, self.labels[index], PN, segmentNo, label_type, new_label, SP, EP)
        else:
            raise Exception(f'dataset read failure!')


if __name__ == '__main__':
    pass
my_augment.py (new file, 58 lines)
@@ -0,0 +1,58 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:Marques
@file:my_augment.py
@email:admin@marques22.com
@email:2021022362@m.scnu.edu.cn
@time:2022/07/26
"""
import torch.cuda
import yaml

from utils.Preprocessing import BCG_Operation
import numpy as np
from scipy.signal import stft
from torch import from_numpy

with open("./settings.yaml") as f:
    hyp = yaml.load(f, Loader=yaml.SafeLoader)  # load hyps

apply_samplerate = hyp["apply_samplerate"]
dataset_samplerate = hyp["dataset_samplerate"]
preprocessing = BCG_Operation()
preprocessing.sample_rate = dataset_samplerate

def my_augment(dataset, config):
    # dataset = preprocessing.Butterworth(dataset, "lowpass", low_cut=20, order=6)
    # dataset = (dataset - dataset.mean()) / dataset.std()
    # dataset_low = preprocessing.Butterworth(dataset, "lowpass", low_cut=0.7, order=6)
    # dataset_high = preprocessing.Butterworth(dataset, "highpass", high_cut=1, order=6)
    print(f"dataset sample_rate is {config['dataset_samplerate']} down_ratio is {config['dataset_samplerate'] // config['apply_samplerate']}")
    # downsample by keeping every (dataset_samplerate // apply_samplerate)-th sample
    dataset = dataset[::config['dataset_samplerate'] // config['apply_samplerate']]
    gpu = torch.cuda.is_available()
    dataset = {"raw": from_numpy(dataset).float().cuda() if gpu else from_numpy(dataset).float(),
               # "low": dataset_low,
               # "high": dataset_high
               }
    return dataset
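
# Usage sketch (illustrative, not part of this commit): my_augment takes a 1-D numpy
# recording plus a config dict with integer "dataset_samplerate" and "apply_samplerate"
# entries; the 1000 Hz / 100 Hz rates below are hypothetical.
# raw = np.zeros(60 * 1000, dtype=np.float32)  # 60 s of signal at 1000 Hz
# out = my_augment(raw, {"dataset_samplerate": 1000, "apply_samplerate": 100})
# out["raw"].shape  # torch.Size([6000]), i.e. every 10th sample kept
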

def get_stft(x, fs, n):
    print(len(x))
    f, t, amp = stft(x, fs, nperseg=n)
    z = np.abs(amp.copy())
    return f, t, z
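
# Usage sketch (illustrative, not part of this commit): get_stft wraps scipy.signal.stft
# and returns the frequency bins, time bins and magnitude spectrogram; the signal below
# is a hypothetical placeholder.
# sig = np.random.randn(3000)  # e.g. 30 s at 100 Hz
# f, t, z = get_stft(sig, fs=100, n=256)  # z.shape == (len(f), len(t))
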
def my_segment_augment(dataset, SP, EP):
    # dataset_segment1 = dataset["low"][int(SP) * 100:int(EP) * 100].copy()
    # dataset_segment2 = dataset["high"][int(SP) * 100:int(EP) * 100].copy()

    # dataset_segment = np.concatenate(([dataset_segment1], [dataset_segment2]), axis=0)
    dataset_segment = dataset["raw"][int(SP):int(EP)].unsqueeze(dim=0)
    return [dataset_segment]
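
# Usage sketch (illustrative, not part of this commit): given the dict produced by
# my_augment, slice one segment; SP/EP here are already expressed in samples, which is
# why ApneaDataset multiplies the csv values by apply_samplerate before calling this.
# seg, = my_segment_augment(out, SP=0, EP=3000)  # seg.shape == torch.Size([1, 3000])
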

if __name__ == '__main__':
    pass