test_perf.py
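"""Evaluate rendered test images against their ground truth.

For each image in the given test output directory, compute MSE (reported as
PSNR), SSIM (as a percentage) and LPIPS, save a per-pixel error map, and write
a per-image CSV report next to the test output directory.
"""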
import argparse
import torch
from matplotlib import pyplot as plt
from collections import defaultdict
from tqdm import tqdm
from pathlib import Path

from utils import img, device
from utils.loss import ssim, mse_loss, LpipsLoss


# Evaluation only: disable autograd globally and build the LPIPS metric once.
torch.set_grad_enabled(False)
lpips_loss = LpipsLoss().to(device.default())


def save_error_image(gt: torch.Tensor, out: torch.Tensor, filename: str):
    # Per-pixel squared error averaged over channels; an MSE of 1e-2 or more
    # maps to full intensity. Relies on the module-level `test_dir` and
    # `dataset_name` defined below.
    error_image = (mse_loss(out, gt, reduction='none').mean(-3, True) / 1e-2).clamp(0, 1)
    error_image = img.torch2np(error_image)[..., 0]
    output_subdir = test_dir.parent / f"{dataset_name}_error"
    output_subdir.mkdir(exist_ok=True)
    img.save(plt.get_cmap("jet")(error_image), f'{output_subdir}/{filename}')


parser = argparse.ArgumentParser()
parser.add_argument('test_dir', type=str,
                    help='Path to the test output dir')
args = parser.parse_args()


test_dir = Path(args.test_dir)
# The test directory name is expected to start with "<dataset_name>_"; the
# ground-truth directory "<dataset_name>" sits five levels above the test dir.
dataset_name = test_dir.parts[-1].split("_")[0]
gt_dir = Path(*test_dir.parts[:-5]) / dataset_name

test_image_paths = list(test_dir.iterdir())
test_image_paths.sort(key=lambda path: path.stem)

gt_image_paths = list(gt_dir.iterdir())
gt_image_paths.sort(key=lambda path: path.stem)


perf = defaultdict(list)
n = len(test_image_paths)
# Compare each test image with its ground truth, paired by sorted order.
for test_image_path, gt_image_path in tqdm(zip(test_image_paths, gt_image_paths), total=n):
    out_image = img.load(test_image_path).to(device.default())
    gt_image = img.load(gt_image_path).to(device.default())
    perf["mse"].append(mse_loss(out_image, gt_image).item())
    perf["ssim"].append(ssim(out_image, gt_image).item() * 100)
    perf["lpips"].append(lpips_loss(out_image, gt_image).item())
    save_error_image(gt_image, out_image, test_image_path.name)

# Mean MSE over all images, embedded in the report filename below.
perf_mean_error = sum(perf['mse']) / n
perf_time = ""

# Remove old performance reports for this dataset, preserving the timestamp
# from a previous report name (perf_<dataset>_<time>_<error>.csv) if present.
for file in test_dir.parent.glob(f'perf_{dataset_name}*'):
    parts = file.name.split("_")
    if len(parts) == 4:
        perf_time = "_" + parts[2]
    file.unlink()

perf_name = f'perf_{dataset_name}{perf_time}_{perf_mean_error:.2e}.csv'


# Save the new per-image performance report; PSNR is derived from the stored MSE.
with (test_dir.parent / perf_name).open('w') as fp:
    fp.write('PSNR, SSIM, LPIPS\n')
    fp.writelines([
        f'{img.mse2psnr(perf["mse"][i]):.2f}, {perf["ssim"][i]:.2f}, {perf["lpips"][i]:.2e}\n'
        for i in range(n)
    ])
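
# Example invocation (the path is hypothetical; pass your own test output
# directory, whose last component is expected to start with the dataset name):
#
#   python test_perf.py results/run0/output/iter_50000/lab/lab_color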