给捷龙康康
This commit is contained in:
		
							parent
							
								
									c0abde4ff9
								
							
						
					
					
						commit
						13cc898ec2
					
				
							
								
								
									
										161
									
								
								load_dataset.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										161
									
								
								load_dataset.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,161 @@ | ||||
| #!/usr/bin/python | ||||
| # -*- coding: UTF-8 -*- | ||||
| """ | ||||
| @author:Marques | ||||
| @file:load_dataset.py | ||||
| @email:admin@marques22.com | ||||
| @email:2021022362@m.scnu.edu.cn | ||||
| @time:2021/12/03 | ||||
| """ | ||||
| import sys | ||||
| from pathlib import Path | ||||
| import pandas as pd | ||||
| import numpy as np | ||||
| import torch | ||||
| import yaml | ||||
| from torch.utils.data import Dataset | ||||
| from tqdm import tqdm | ||||
| 
 | ||||
| """ | ||||
| 1. 读取方法 | ||||
| # 无论是否提前切分,均提前转成npy格式 | ||||
| # 1.1 提前预处理,切分好后生成npy,直接载入切分好的片段   内存占用多  读取简单 | ||||
| 使用此方法: 1.2 提前预处理,载入整夜数据,切分好后生成csv或xls,根据片段读取  内存占用少    读取较为复杂 | ||||
| """ | ||||
| 
 | ||||
| datasets = {} | ||||
| 
 | ||||
| 
 | ||||
| # 减少重复读取 | ||||
| def read_dataset(config, augment=None): | ||||
|     data_path = Path(config["Path"]["dataset"]) | ||||
|     try: | ||||
|         file_list = [] | ||||
|         if data_path.is_dir(): | ||||
|             dataset_list = list(data_path.rglob("*.npy_low_zscore.npy")) | ||||
|             dataset_list.sort() | ||||
|             file_list += dataset_list | ||||
|         elif data_path.is_file(): | ||||
|             raise Exception(f'dataset path should be a dir') | ||||
|         else: | ||||
|             raise Exception(f'{data_path} does not exist') | ||||
|     except Exception as e: | ||||
|         raise Exception(f'Error loading data from {data_path}: {e} \n') | ||||
| 
 | ||||
|     print("loading dataset") | ||||
|     for i in tqdm(file_list): | ||||
|         select_dataset = np.load(i, allow_pickle=True)[0] | ||||
|         # select_dataset = preprocessing.Butterworth(select_dataset, "lowpass", low_cut=20, order=3) | ||||
|         if augment is not None: | ||||
|             select_dataset = augment(select_dataset, config) | ||||
|         datasets[i.name.split("samp")[0]] = select_dataset | ||||
| 
 | ||||
| 
 | ||||
| # 用第二种方法读取 | ||||
| class ApneaDataset(Dataset): | ||||
|     def __init__(self, config, dataset_type, select_sampno, segment_augment=None): | ||||
|         self.data_path = Path(config["Path"]["dataset"]) | ||||
|         self.label_path = Path(config["Path"]["label"]) | ||||
|         self.segment_augment = segment_augment | ||||
|         self.labels_info = None | ||||
|         self.labels = None | ||||
|         self.dataset_type = dataset_type | ||||
|         self.select_sampNo = select_sampno | ||||
|         self.disable_hpy = config["disable_hpy"] | ||||
|         self.apply_samplerate = config["apply_samplerate"] | ||||
| 
 | ||||
|         # self._getAllData() | ||||
|         self._getAllLabels() | ||||
| 
 | ||||
|     def __getitem__(self, index): | ||||
|         # PN patience number | ||||
|         # SP/EP start point, end point | ||||
|         # temp_label.append([sampNo, label[-1], i, hpy_num, csa_num, osa_num, mean_low, flow_low]) | ||||
|         PN, segmentNo, label_type, new_label, SP, EP = self.labels_info[index] | ||||
|         # PN, label, SP, EP, hpy_num, csa_num, osa_num, mean_low, flow_low = self.labels_info[index] | ||||
| 
 | ||||
|         if isinstance(datasets, dict): | ||||
|             segment = self.segment_augment(datasets[str(PN)], SP * self.apply_samplerate, EP * self.apply_samplerate) | ||||
|             return (*segment, self.labels[index], PN, segmentNo, label_type, new_label, SP, EP) | ||||
|         else: | ||||
|             raise Exception(f'dataset read failure!') | ||||
| 
 | ||||
|     def count_SA(self): | ||||
|         # assert isinstance(self.disable_hpy, int) | ||||
|         return sum(self.labels) | ||||
| 
 | ||||
|     def __len__(self): | ||||
|         return len(self.labels_info) | ||||
| 
 | ||||
|     def _getAllLabels(self): | ||||
|         label_path = Path(self.label_path) | ||||
|         if not label_path.exists(): | ||||
|             raise Exception(f'{self.label_path} does not exist') | ||||
| 
 | ||||
|         try: | ||||
|             file_list = [] | ||||
|             if label_path.is_dir(): | ||||
|                 if self.dataset_type == "train": | ||||
|                     label_list = list(label_path.rglob("*_train_label.csv")) | ||||
|                 elif self.dataset_type == "valid": | ||||
|                     label_list = list(label_path.rglob("*_valid_label.csv")) | ||||
|                 elif self.dataset_type == "test": | ||||
|                     label_list = list(label_path.glob("*_sa_test_label.csv")) | ||||
|                     # label_list = list(label_path.rglob("*_test_label.npy")) | ||||
|                 elif self.dataset_type == "all_test": | ||||
|                     label_list = list(label_path.rglob("*_sa_all_label.csv")) | ||||
|                 else: | ||||
|                     raise ValueError("self.dataset type error") | ||||
|                 # label_list = list(label_path.rglob("*_label.npy")) | ||||
|                 label_list.sort() | ||||
|                 file_list += label_list | ||||
|             elif label_path.is_file(): | ||||
|                 raise Exception(f'dataset path should be a dir') | ||||
|             else: | ||||
|                 raise Exception(f'{self.label_path} does not exist') | ||||
|         except Exception as e: | ||||
|             raise Exception(f'Error loading data from {self.label_path}: {e} \n') | ||||
|         print("loading labels") | ||||
|         for i in tqdm(file_list): | ||||
|             if int(i.name.split("_")[0]) not in self.select_sampNo: | ||||
|                 continue | ||||
| 
 | ||||
|             if self.labels_info is None: | ||||
|                 self.labels_info = pd.read_csv(i).to_numpy(dtype=int) | ||||
|             else: | ||||
|                 labels = pd.read_csv(i).to_numpy(dtype=int) | ||||
|                 if len(labels) > 0: | ||||
|                     self.labels_info = np.concatenate((self.labels_info, labels)) | ||||
| 
 | ||||
|         self.labels = (self.labels_info[:, 3] > self.disable_hpy) * 1 | ||||
|         self.labels = torch.from_numpy(self.labels) | ||||
|         gpu = torch.cuda.is_available() | ||||
|         self.labels = self.labels.cuda() if gpu else self.labels | ||||
| 
 | ||||
|         # self.labels_info = self.labels_info[:10000] | ||||
|         print(f"{self.dataset_type} length is {len(self.labels_info)}") | ||||
| 
 | ||||
| 
 | ||||
| class TestApneaDataset2(ApneaDataset): | ||||
|     def __init__(self, config, dataset_type, select_sampno, segment_augment=None): | ||||
|         super(TestApneaDataset2, self).__init__( | ||||
|             config, | ||||
|             dataset_type=dataset_type, | ||||
|             select_sampno=select_sampno, | ||||
|             segment_augment=segment_augment, | ||||
|         ) | ||||
| 
 | ||||
|     def __getitem__(self, index): | ||||
|         PN, segmentNo, label_type, new_label, SP, EP = self.labels_info[index] | ||||
|         # PN, label, SP, EP, hpy_num, csa_num, osa_num, mean_low, flow_low = self.labels_info[index] | ||||
| 
 | ||||
|         if isinstance(datasets, dict): | ||||
|             dataset = datasets[str(PN)] | ||||
|             segment = self.segment_augment(dataset, SP * self.apply_samplerate, EP * self.apply_samplerate) | ||||
|             return (*segment, self.labels[index], PN, segmentNo, label_type, new_label, SP, EP) | ||||
|         else: | ||||
|             raise Exception(f'dataset read failure!') | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     pass | ||||
							
								
								
									
										58
									
								
								my_augment.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								my_augment.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,58 @@ | ||||
| #!/usr/bin/python | ||||
| # -*- coding: UTF-8 -*- | ||||
| """ | ||||
| @author:Marques | ||||
| @file:my_augment.py | ||||
| @email:admin@marques22.com | ||||
| @email:2021022362@m.scnu.edu.cn | ||||
| @time:2022/07/26 | ||||
| """ | ||||
| import torch.cuda | ||||
| import yaml | ||||
| 
 | ||||
| from utils.Preprocessing import BCG_Operation | ||||
| import numpy as np | ||||
| from scipy.signal import stft | ||||
| from torch import from_numpy | ||||
| with open("./settings.yaml") as f: | ||||
|     hyp = yaml.load(f, Loader=yaml.SafeLoader)  # load hyps | ||||
| 
 | ||||
| apply_samplerate = hyp["apply_samplerate"] | ||||
| dataset_samplerate = hyp["dataset_samplerate"] | ||||
| preprocessing = BCG_Operation() | ||||
| preprocessing.sample_rate = dataset_samplerate | ||||
| 
 | ||||
| 
 | ||||
| def my_augment(dataset, config): | ||||
|     # dataset = preprocessing.Butterworth(dataset, "lowpass", low_cut=20, order=6) | ||||
|     # dataset = (dataset - dataset.mean()) / dataset.std() | ||||
|     # dataset_low = preprocessing.Butterworth(dataset, "lowpass", low_cut=0.7, order=6) | ||||
|     # dataset_high = preprocessing.Butterworth(dataset, "highpass", high_cut=1, order=6) | ||||
|     print(f"dataset sample_rate is {config['dataset_samplerate']}  down_ratio is {config['dataset_samplerate'] // config['apply_samplerate']}") | ||||
|     dataset = dataset[::config['dataset_samplerate'] // config['apply_samplerate']] | ||||
|     gpu = torch.cuda.is_available() | ||||
|     dataset = {"raw": from_numpy(dataset).float().cuda() if gpu else from_numpy(dataset).float(), | ||||
|                # "low": dataset_low, | ||||
|                # "high": dataset_high | ||||
|                } | ||||
|     return dataset | ||||
| 
 | ||||
| 
 | ||||
| def get_stft(x, fs, n): | ||||
|     print(len(x)) | ||||
|     f, t, amp = stft(x, fs, nperseg=n) | ||||
|     z = np.abs(amp.copy()) | ||||
|     return f, t, z | ||||
| 
 | ||||
| 
 | ||||
| def my_segment_augment(dataset, SP, EP): | ||||
|     # dataset_segment1 = dataset["low"][int(SP) * 100:int(EP) * 100].copy() | ||||
|     # dataset_segment2 = dataset["high"][int(SP) * 100:int(EP) * 100].copy() | ||||
| 
 | ||||
|     # dataset_segment = np.concatenate(([dataset_segment1], [dataset_segment2]), axis=0) | ||||
|     dataset_segment = dataset["raw"][int(SP):int(EP)].unsqueeze(dim=0) | ||||
|     return [dataset_segment] | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     pass | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user