import argparse
import torch
from matplotlib import pyplot as plt
from collections import defaultdict
from tqdm import tqdm
from pathlib import Path

from utils import img, device
from utils.loss import ssim, mse_loss, LpipsLoss

# Evaluation only: no gradients are ever needed in this script.
torch.set_grad_enabled(False)
lpips_loss = LpipsLoss().to(device.default())


def save_error_image(gt: torch.Tensor, out: torch.Tensor, filename: str,
                     output_subdir: Path = None):
    """Save a colour-mapped per-pixel squared-error image for one prediction.

    The element-wise squared error is averaged over dim -3 (presumably the
    channel axis -- TODO confirm against ``utils.img.load``). It is divided by
    1e-2 so that an error of 1e-2 saturates the colormap, clamped to [0, 1],
    and rendered with matplotlib's "jet" colormap.

    Args:
        gt: Ground-truth image tensor.
        out: Predicted image tensor, compared against ``gt``.
        filename: File name (not a path) to write inside ``output_subdir``.
        output_subdir: Directory to write into. If omitted, falls back to
            ``<test_dir parent>/<dataset_name>_error`` using the module-level
            ``test_dir`` and ``dataset_name`` set by the script section.
    """
    error_image = (mse_loss(out, gt, reduction='none').mean(-3, True) / 1e-2).clamp(0, 1)
    error_image = img.torch2np(error_image)[..., 0]
    if output_subdir is None:
        # Backward-compatible fallback: relies on script-level globals.
        output_subdir = test_dir.parent / f"{dataset_name}_error"
    output_subdir = Path(output_subdir)
    output_subdir.mkdir(exist_ok=True)
    img.save(plt.get_cmap("jet")(error_image), f'{output_subdir}/{filename}')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('test_dir', type=str, help='Path to the test output dir')
    args = parser.parse_args()

    test_dir = Path(args.test_dir)
    # The dataset name is the part of the test dir name before the first "_".
    dataset_name = test_dir.parts[-1].split("_")[0]
    # Ground truth lives under the path 5 levels above test_dir, in a folder
    # named after the dataset.
    gt_dir = Path(*test_dir.parts[:-5]) / dataset_name

    # Outputs and ground truths are paired purely by sorted stem order.
    test_image_paths = sorted(test_dir.iterdir(), key=lambda path: path.stem)
    gt_image_paths = sorted(gt_dir.iterdir(), key=lambda path: path.stem)

    # Evaluate only the pairs that actually exist. Previously a shorter GT
    # list truncated the zip while `n` stayed at the test count, which skewed
    # the mean and raised IndexError while writing the CSV.
    if len(test_image_paths) != len(gt_image_paths):
        print(f"Warning: {len(test_image_paths)} test images vs "
              f"{len(gt_image_paths)} ground-truth images; evaluating only "
              f"the first {min(len(test_image_paths), len(gt_image_paths))} pairs.")
    n = min(len(test_image_paths), len(gt_image_paths))
    if n == 0:
        # Avoid a ZeroDivisionError in the mean below.
        raise SystemExit(f"No image pairs to evaluate in {test_dir} / {gt_dir}")

    error_dir = test_dir.parent / f"{dataset_name}_error"
    error_dir.mkdir(exist_ok=True)

    perf = defaultdict(list)
    pairs = zip(test_image_paths[:n], gt_image_paths[:n])
    for test_image_path, gt_image_path in tqdm(pairs, total=n):
        out_image = img.load(test_image_path).to(device.default())
        gt_image = img.load(gt_image_path).to(device.default())
        perf["mse"].append(mse_loss(out_image, gt_image).item())
        perf["ssim"].append(ssim(out_image, gt_image).item() * 100)
        perf["lpips"].append(lpips_loss(out_image, gt_image).item())
        save_error_image(gt_image, out_image, test_image_path.name, error_dir)

    perf_mean_error = sum(perf['mse']) / n

    # Remove old performance reports for this dataset. Reports are named
    # perf_<dataset>[_<time>]_<err>.csv, so glob with a trailing "_" to avoid
    # also matching datasets that merely share this name as a prefix. A
    # timestamp found in an old 4-part name is carried over to the new report.
    perf_time = ""
    for file in test_dir.parent.glob(f'perf_{dataset_name}_*'):
        parts = file.name.split("_")
        if len(parts) == 4:
            perf_time = "_" + parts[2]
        file.unlink()

    perf_name = f'perf_{dataset_name}{perf_time}_{perf_mean_error:.2e}.csv'
    # Save the new performance report: one row per evaluated image.
    with (test_dir.parent / perf_name).open('w') as fp:
        fp.write('PSNR, SSIM, LPIPS\n')
        fp.writelines(
            f'{img.mse2psnr(mse):.2f}, {ssim_val:.2f}, {lpips_val:.2e}\n'
            for mse, ssim_val, lpips_val in zip(perf["mse"], perf["ssim"], perf["lpips"])
        )