loss.py 3.63 KB
Newer Older
BobYeah's avatar
BobYeah committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
from ssim import *
from perc_loss import * 

# Loss objects created once, when this module is imported.
l1loss = torch.nn.L1Loss()  # not referenced by the code visible in this file

# Perceptual-loss instance used by the loss functions in this file; it is
# moved to the CUDA device with index 1 right after construction.
# NOTE(review): the device string is hard-coded, so importing this module
# presumably fails on hosts without a second GPU -- confirm before reuse.
perc_loss = VGGPerceptualLoss().to("cuda:1")

##### LOSS #####
def calImageGradients(images):
    """Return forward differences of ``images`` along axes 2 and 3.

    Args:
        images: 4-D tensor; presumably laid out as (batch, channel, height,
            width) -- confirm against the callers.

    Returns:
        Tuple ``(dx, dy)``. ``dx[:, :, i, :]`` is
        ``images[:, :, i + 1, :] - images[:, :, i, :]`` (one entry shorter
        along axis 2); ``dy`` is the same difference taken along axis 3.
    """
    # Neighbouring slices along axis 2.
    axis2_next = images[:, :, 1:, :]
    axis2_prev = images[:, :, :-1, :]
    # Neighbouring slices along axis 3.
    axis3_next = images[:, :, :, 1:]
    axis3_prev = images[:, :, :, :-1]
    return axis2_next - axis2_prev, axis3_next - axis3_prev

def loss_new(generated, gt):
    """Log-MSE loss on intensities and image gradients plus a perceptual term.

    The result is ``log10(MSE(generated, gt))`` plus the average of the
    ``log10(MSE)`` values of the two finite-difference fields returned by
    ``calImageGradients``, plus ``perc_loss(generated, gt)``.  For signals
    with unit peak value ``log10(MSE)`` equals ``-PSNR / 10``, so lowering
    this loss raises PSNR.

    Args:
        generated: Predicted images (4-D tensor, as required by
            ``calImageGradients``).
        gt: Ground-truth images with the same shape as ``generated``.
            NOTE(review): ``perc_loss`` lives on a fixed device, so both
            inputs presumably have to be on that device -- confirm.

    Returns:
        Tensor holding the sum of the three terms.
    """

    def log10_mse(pred, target):
        # An exactly-zero MSE (e.g. identical gradient fields when the
        # prediction differs from the target only by a constant offset)
        # would give log10(0) = -inf and NaN gradients.  Flooring at the
        # dtype's smallest positive normal value keeps the loss finite and
        # leaves any MSE at or above that floor unchanged.
        mse = torch.nn.functional.mse_loss(pred, target)
        return torch.log10(mse.clamp(min=torch.finfo(mse.dtype).tiny))

    gt_dx, gt_dy = calImageGradients(gt)
    gen_dx, gen_dy = calImageGradients(generated)

    intensity_term = log10_mse(generated, gt)
    gradient_term = 0.5 * (log10_mse(gt_dx, gen_dx) + log10_mse(gt_dy, gen_dy))
    perceptual_term = perc_loss(generated, gt)
    return intensity_term + gradient_term + perceptual_term

def loss_without_perc(generated, gt):
    """Log-MSE loss on intensities and image gradients (no perceptual term).

    The result is ``log10(MSE(generated, gt))`` plus the average of the
    ``log10(MSE)`` values of the two finite-difference fields returned by
    ``calImageGradients``.  For signals with unit peak value ``log10(MSE)``
    equals ``-PSNR / 10``, so lowering this loss raises PSNR.

    Args:
        generated: Predicted images (4-D tensor, as required by
            ``calImageGradients``).
        gt: Ground-truth images with the same shape as ``generated``.

    Returns:
        Tensor holding the sum of the intensity and gradient terms.
    """

    def log10_mse(pred, target):
        # An exactly-zero MSE (e.g. identical gradient fields when the
        # prediction differs from the target only by a constant offset)
        # would give log10(0) = -inf and NaN gradients.  Flooring at the
        # dtype's smallest positive normal value keeps the loss finite and
        # leaves any MSE at or above that floor unchanged.
        mse = torch.nn.functional.mse_loss(pred, target)
        return torch.log10(mse.clamp(min=torch.finfo(mse.dtype).tiny))

    gt_dx, gt_dy = calImageGradients(gt)
    gen_dx, gen_dy = calImageGradients(generated)

    intensity_term = log10_mse(generated, gt)
    gradient_term = 0.5 * (log10_mse(gt_dx, gen_dx) + log10_mse(gt_dy, gen_dy))
    return intensity_term + gradient_term
##### LOSS #####


class ReconstructionLoss(torch.nn.Module):
    """Log-MSE reconstruction loss on intensities and image gradients.

    ``forward`` returns ``log10(MSE(generated, gt))`` plus the average of the
    ``log10(MSE)`` values of the two finite-difference fields returned by
    ``calImageGradients``.  The module has no parameters or buffers.
    """

    def __init__(self):
        super(ReconstructionLoss, self).__init__()

    @staticmethod
    def _log10_mse(pred, target):
        """Return ``log10`` of the MSE, floored so a zero error stays finite."""
        mse = torch.nn.functional.mse_loss(pred, target)
        # log10(0) would be -inf and back-propagate NaN; flooring at the
        # dtype's smallest positive normal value leaves any MSE at or above
        # that floor unchanged.
        return torch.log10(mse.clamp(min=torch.finfo(mse.dtype).tiny))

    def forward(self, generated, gt):
        """Compute the loss.

        Args:
            generated: Predicted images (4-D tensor, as required by
                ``calImageGradients``).
            gt: Ground-truth images with the same shape as ``generated``.

        Returns:
            Tensor holding the sum of the intensity and gradient terms.
        """
        gt_dx, gt_dy = calImageGradients(gt)
        gen_dx, gen_dy = calImageGradients(generated)

        intensity_term = self._log10_mse(generated, gt)
        gradient_term = 0.5 * (
            self._log10_mse(gt_dx, gen_dx) + self._log10_mse(gt_dy, gen_dy)
        )
        return intensity_term + gradient_term

class PerceptionReconstructionLoss(torch.nn.Module):
    """Log-MSE reconstruction loss plus the module-level perceptual loss.

    ``forward`` returns ``log10(MSE(generated, gt))`` plus the average of the
    ``log10(MSE)`` values of the two finite-difference fields returned by
    ``calImageGradients``, plus ``perc_loss(generated, gt)``.  The perceptual
    network is the shared module-level ``perc_loss`` object, not a submodule
    of this class.
    """

    def __init__(self):
        super(PerceptionReconstructionLoss, self).__init__()

    @staticmethod
    def _log10_mse(pred, target):
        """Return ``log10`` of the MSE, floored so a zero error stays finite."""
        mse = torch.nn.functional.mse_loss(pred, target)
        # log10(0) would be -inf and back-propagate NaN; flooring at the
        # dtype's smallest positive normal value leaves any MSE at or above
        # that floor unchanged.
        return torch.log10(mse.clamp(min=torch.finfo(mse.dtype).tiny))

    def forward(self, generated, gt):
        """Compute the loss.

        Args:
            generated: Predicted images (4-D tensor, as required by
                ``calImageGradients``).
            gt: Ground-truth images with the same shape as ``generated``.
                NOTE(review): ``perc_loss`` lives on a fixed device, so both
                inputs presumably have to be on that device -- confirm.

        Returns:
            Tensor holding the sum of the intensity, gradient and
            perceptual terms.
        """
        gt_dx, gt_dy = calImageGradients(gt)
        gen_dx, gen_dy = calImageGradients(generated)

        intensity_term = self._log10_mse(generated, gt)
        gradient_term = 0.5 * (
            self._log10_mse(gt_dx, gen_dx) + self._log10_mse(gt_dy, gen_dy)
        )
        perceptual_term = perc_loss(generated, gt)
        return intensity_term + gradient_term + perceptual_term