#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:Marques
@time:2021/10/15
"""
import os

import yaml
import logging
from pathlib import Path

import time
from torch.nn import functional as F
from torch.utils.data import DataLoader
import torch.cuda
from tqdm import tqdm
from torchinfo import summary
from load_dataset import ApneaDataset, read_dataset

from torch import nn
from utils.calc_metrics import CALC_METRICS
from sklearn.model_selection import KFold
from model.Hybrid_Net001 import HYBRIDNET001
# from utils.LossFunction import Foca1lLoss
from my_augment import my_augment, my_segment_augment
# Load the experiment configuration.
with open("./settings.yaml", encoding="utf-8") as f:
    hyp = yaml.load(f, Loader=yaml.SafeLoader)

# os.environ only accepts str values; an unquoted `GPU: 0` in YAML loads as an
# int and would raise TypeError, so coerce explicitly.
os.environ["CUDA_VISIBLE_DEVICES"] = str(hyp["GPU"])
os.environ["WANDB_MODE"] = "dryrun"

# Run timestamp (minute resolution); names the output directory and log file.
realtime = time.strftime('%Y%m%d%H%M', time.localtime(time.time()))

# Data locations from the config.
data_path = hyp["Path"]["dataset"]
label_path = hyp["Path"]["label"]

# Outputs go to <save>/<basename of save>_<realtime>/.
save_dir = Path(hyp["Path"]["save"]) / (Path(hyp["Path"]["save"]).name + "_" + realtime)
save_dir.mkdir(parents=True, exist_ok=True)

# Logging: the root logger writes both to <save_dir>/<realtime>.log and to the console.
logger = logging.getLogger()
logger.setLevel(logging.NOTSET)
fh = logging.FileHandler(save_dir / (realtime + ".log"), mode='a')
fh.setLevel(logging.NOTSET)
fh.setFormatter(logging.Formatter("%(asctime)s: %(message)s"))
logger.addHandler(fh)

ch = logging.StreamHandler()
ch.setLevel(logging.NOTSET)
ch.setFormatter(logging.Formatter("%(asctime)s: %(message)s"))
logger.addHandler(ch)
logging.getLogger('matplotlib.font_manager').disabled = True
logger.info("------------------------------------")
logger.info('hyper_parameters: ' + ', '.join(f'{k}={v}\n' for k, v in hyp.items()))

# Back up the configuration next to the results.
with open(save_dir / 'settings.yaml', 'w', encoding="utf-8") as f:
    yaml.dump(hyp, f, sort_keys=False)

# Hyper-parameters
gpu = torch.cuda.is_available()
epochs = hyp["epoch"]
lr = hyp["lr"]
nc = hyp["nc"]
bs = hyp["batch_size"]
worker = hyp["number_worker"]
select_sampno = hyp["select_sampno"]

# Project-specific dataset preparation (see load_dataset.read_dataset).
read_dataset(data_path, augment=my_augment)
# Shared metric accumulator used by model_train / model_valid / model_test.
calc_metrics = CALC_METRICS(nc)


# Training / evaluation routines
def model_train(model, train_loader, optimizer, scheduler, loss_func, training_state):
    """Run one training epoch and log the aggregated training metrics.

    Args:
        model: network called as ``model(resp, stft)``; its output is squeezed
            on the last dim and compared with ``labels.float()`` by ``loss_func``.
        train_loader: iterable of ``(resp, stft, labels)`` batches.
        optimizer: zeroed and stepped once per batch.
        scheduler: stepped once per batch with the batch loss
            (``scheduler.step(loss)``, i.e. ReduceLROnPlateau-style).
        loss_func: callable ``loss_func(out, labels.float())`` returning a scalar.
        training_state: text shown on the progress bar and written to the log.

    Note:
        Reads the module-level globals ``gpu``, ``calc_metrics``, ``logger`` and
        ``epoch`` (the loop variable of the ``__main__`` block).
    """
    model.train()
    train_loss = 0.0

    pbar = tqdm(train_loader, total=len(train_loader), ncols=80)
    pbar.set_description(training_state)
    for resp, stft, labels in pbar:
        resp = resp.float().cuda() if gpu else resp.float()
        stft = stft.float().cuda() if gpu else stft.float()
        labels = labels.cuda() if gpu else labels

        out = model(resp, stft).squeeze(-1)
        loss = loss_func(out, labels.float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Plateau-style scheduling driven by the per-batch loss (a cosine
        # schedule would instead need a fractional epoch here).
        scheduler.step(loss)

        train_loss += loss.item()
        # Detach so the metric accumulator does not hold a reference to this
        # batch's autograd graph.
        calc_metrics.update(out.detach().cpu(), labels.cpu())

    cur_lr = optimizer.param_groups[-1]['lr']
    # Average over batches; guard against an empty loader.
    train_loss /= max(len(train_loader), 1)

    calc_metrics.compute()
    logger.info("")
    logger.info("--------------------------------------")
    logger.info(training_state)
    logger.info(calc_metrics.get_matrix(loss=train_loss, epoch=epoch, epoch_type="train", cur_lr=cur_lr))
    calc_metrics.reset()


def model_valid(model, valid_loader, wdir, loss_func):
    """Evaluate on the validation loader and checkpoint when the score improves.

    The score is ``calc_metrics.metrics[-1].compute()`` (treated as F1 per the
    variable name). When it beats the module-global ``best_f1``, the weights
    are written to ``wdir`` both as ``best_<epoch>_<score>.pt`` and as
    ``best.pt``. If wandb is active, the score is also recorded in the run
    summary.

    Reads the module-level globals ``gpu``, ``calc_metrics``, ``logger``,
    ``epoch`` and ``wandb``, and updates ``best_f1``.
    """
    global best_f1
    model.eval()
    total_loss = 0.0

    for resp, stft, labels in valid_loader:
        if gpu:
            resp = resp.float().cuda()
            stft = stft.float().cuda()
            labels = labels.cuda()
        else:
            resp = resp.float()
            stft = stft.float()

        with torch.no_grad():
            pred = model(resp, stft).squeeze(-1)
            batch_loss = loss_func(pred, labels.float())

        total_loss += batch_loss.item()
        calc_metrics.update(pred.cpu(), labels.cpu())

    total_loss /= len(valid_loader)
    calc_metrics.compute()
    logger.info(calc_metrics.get_matrix(loss=total_loss, epoch=epoch, epoch_type="valid"))

    valid_f1 = calc_metrics.metrics[-1].compute()
    if valid_f1 > best_f1:
        best_f1 = valid_f1
        score_tag = str(round(float(valid_f1), 3))
        torch.save(model.state_dict(), wdir / f"best_{epoch}_{score_tag}.pt")
        torch.save(model.state_dict(), wdir / "best.pt")
        if wandb is not None:
            wandb.run.summary["best_f1"] = valid_f1
    calc_metrics.reset()


def model_test(model, test_loader, loss_func):
    """Evaluate on the test loader and log the aggregated test metrics.

    Same forward/loss/metric path as validation, without checkpointing.
    Reads the module-level globals ``gpu``, ``calc_metrics``, ``logger`` and
    ``epoch``.
    """
    model.eval()
    running_loss = 0.0

    for resp, stft, labels in test_loader:
        if gpu:
            resp = resp.float().cuda()
            stft = stft.float().cuda()
            labels = labels.cuda()
        else:
            resp = resp.float()
            stft = stft.float()

        with torch.no_grad():
            pred = model(resp, stft).squeeze(-1)
            batch_loss = loss_func(pred, labels.float())

        running_loss += batch_loss.item()
        calc_metrics.update(pred.cpu(), labels.cpu())

    running_loss /= len(test_loader)
    calc_metrics.compute()
    logger.info(calc_metrics.get_matrix(loss=running_loss, epoch=epoch, epoch_type="test"))
    calc_metrics.reset()


if __name__ == '__main__':

    # Weights & Biases is optional. `wandb` is bound at module scope here, so
    # model_valid() and the epoch loop below can test it for None.
    try:
        import wandb
    except ImportError:
        wandb = None
        prefix = 'wandb: '
        logger.info(f"{prefix}Install Weights & Biases logger with 'pip install wandb'")

    if wandb is not None and wandb.run is None:
        wandb_run = wandb.init(
            config=hyp,
            name=save_dir.stem,
            project=hyp["project"],
            notes=hyp["Note"],
            tags=hyp["tags"],
            entity=hyp["entity"],
        )
    # Experiment name = name of the current working directory.
    exam_name = Path("./").absolute().name

    # NOTE(review): eval() executes arbitrary code taken from settings.yaml; this
    # is only acceptable while that file is trusted. A lookup table of model
    # classes would be safer.
    model_net = eval(hyp["model_name"])()
    model_net.initialize_weights()
    # Input sizes presumably follow the forward order model(resp, stft) — TODO confirm.
    summary(model_net, [(1, 120, 1), (1, 1, 121, 26)])

    time.sleep(3)
    if gpu:
        model_net.cuda()

    k_folds = 5
    # K-fold over the selected sample numbers: each fold holds out a disjoint
    # subset of select_sampno as its independent test set.
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)

    logger.info('--------------------------------')
    for fold, (train_ids, test_ids) in enumerate(kfold.split(select_sampno)):
        logger.info(f'Start FOLD {fold} / {k_folds}----------------------')
        train_set = [select_sampno[i] for i in train_ids]
        test_set = [select_sampno[i] for i in test_ids]
        logger.info(f'Train_Set:{train_set}')
        logger.info(f'Independent_Test_Set:{test_set}')

        sub_save_dir = save_dir / f"KFold_{fold}"
        sub_save_dir.mkdir(exist_ok=True, parents=True)
        wdir = sub_save_dir / "weights"
        wdir.mkdir(exist_ok=True, parents=True)

        # Back up the configuration together with this fold's split.
        hyp["train_set"] = train_set
        hyp["test_set"] = test_set
        with open(sub_save_dir / 'settings.yaml', 'w') as f:
            yaml.dump(hyp, f, sort_keys=False)

        train_dataset = ApneaDataset(data_path, label_path, train_set, "train", my_segment_augment)
        valid_dataset = ApneaDataset(data_path, label_path, train_set, "valid", my_segment_augment)
        # Test on this fold's held-out samples. Building it from train_set (as
        # before) left the logged independent test set unused and reported
        # test metrics on samples that were also used for training.
        test_dataset = ApneaDataset(data_path, label_path, test_set, "test", my_segment_augment)

        train_loader = DataLoader(train_dataset, batch_size=bs, pin_memory=True, num_workers=worker, shuffle=True)
        valid_loader = DataLoader(valid_dataset, batch_size=bs, pin_memory=True, num_workers=worker)
        test_loader = DataLoader(test_dataset, batch_size=bs, pin_memory=True, num_workers=worker)

        # Re-initialise the model so folds do not share weights.
        del model_net
        model_net = eval(hyp["model_name"])()
        model_net.initialize_weights()
        if gpu:
            model_net.cuda()

        # Ratio of non-SA to SA samples in the training set, computed once.
        n_sa = train_dataset.count_SA()
        if n_sa == 0:
            raise ValueError(f"FOLD {fold}: training set has no SA samples; cannot compute the loss weight")
        sa_weight = (len(train_dataset) - n_sa) / n_sa
        logger.info(f"Weight is {[sa_weight]}")

        # Loss function. The weight must live on the same device as the model
        # output; the former unconditional .cuda() crashed on CPU-only hosts.
        # NOTE(review): nn.BCELoss's `weight` rescales the loss of every element,
        # so a one-element tensor multiplies all samples by the same factor
        # rather than up-weighting the SA class. If class re-weighting is the
        # intent, consider nn.BCEWithLogitsLoss(pos_weight=...) on raw logits.
        loss_weight = torch.Tensor([sa_weight])
        if gpu:
            loss_weight = loss_weight.cuda()
        loss_function = nn.BCELoss(weight=loss_weight)

        optimizer = torch.optim.Adam(model_net.parameters(), lr=lr)
        # model_train() steps this once per batch with the batch loss, so
        # `patience` is counted in batches, not epochs.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                                               patience=2836, min_lr=1e-8,
                                                               verbose=True)

        # Best validation F1 of this fold. It is module-global so that
        # model_valid() can update it through `global best_f1`.
        best_f1 = 0

        # `epoch` is module-global too; the train/valid/test helpers read it for logging.
        for epoch in range(epochs):
            model_train(model_net, train_loader, optimizer, scheduler, loss_function,
                        f"EXAM:{exam_name} FOLD:{fold}/{k_folds} EPOCH:{epoch}/{epochs}")
            model_valid(model_net, valid_loader, wdir, loss_function)
            model_test(model_net, test_loader, loss_function)
            if wandb is not None:
                calc_metrics.wandb_log(wandb=wandb, cur_lr=optimizer.param_groups[-1]['lr'])