Commit 1bc644a1, authored by Nianchen Deng

Commit message: sync

Parent commit: 6294701e
import importlib
import os
from pathlib import Path
import sys
from inspect import isclass
from model.base import BaseModel
from .train import train_classes, Train
from .trainer import Trainer, trainer_classes
from .train_with_space import TrainWithSpace
#from .train_multi_scale import TrainMultiScale
# Names re-exported (and later registered as trainer classes) by this package.
__all__ = ["Trainer", "TrainWithSpace"]

# Automatically import any python files in this directory so that the train
# classes they define get registered as a side effect of their definition
# (see BaseTrainMeta in .train).
package_dir = os.path.dirname(__file__)
package = os.path.basename(package_dir)
for file in os.listdir(package_dir):
    path = os.path.join(package_dir, file)
    # Skip private/hidden entries such as __init__.py or editor droppings.
    if file.startswith('_') or file.startswith('.'):
        continue
    if file.endswith('.py') or os.path.isdir(path):
        # Strip '.py' for modules; sub-package directories keep their name.
        model_name = file[:-3] if file.endswith('.py') else file
        # NOTE(review): imports via the package's base name only — assumes the
        # parent directory is on sys.path; confirm this holds for all entry points.
        importlib.import_module(f'{package}.{model_name}')
def get_class(class_name: str) -> type:
    """Return the train class registered under *class_name*.

    Classes register themselves in ``train_classes`` via ``BaseTrainMeta``
    when their defining module is imported.

    Raises:
        KeyError: if no train class with this name has been registered.
            The message lists the available names to ease debugging.
    """
    try:
        return train_classes[class_name]
    except KeyError:
        # Same exception type as a plain lookup, but with an actionable message.
        raise KeyError(f"Train class '{class_name}' is not registered; "
                       f"available: {sorted(train_classes)}") from None
def get_trainer(model: BaseModel, run_dir: Path, states: dict) -> Train:
    """Instantiate the trainer class declared by ``model.TrainerClass``."""
    trainer_cls = get_class(model.TrainerClass)
    return trainer_cls(model, run_dir, states)
# Register all trainer classes exported in __all__ into `trainer_classes`
# so they can later be looked up by name (see .trainer.trainer_classes).
for item in __all__:
    var = getattr(sys.modules[__name__], item)
    # Only concrete Trainer subclasses are registered; other exports are skipped.
    if isclass(var) and issubclass(var, Trainer):
        trainer_classes[item] = var
\ No newline at end of file
import csv
import json
import logging
import torch
import torch.nn.functional as nn_f
from typing import Any, Dict, Union
from pathlib import Path
import loss
from utils import netio, math
from utils.misc import format_time, print_and_log
from utils.progress_bar import progress_bar
from utils.perf import Perf, enable_perf, perf, get_perf_result
from utils.env import set_env
from utils.type import InputData, ReturnData
from data.loader import DataLoader
from model import serialize
from model.base import BaseModel
# Registry mapping class name -> train class; filled in by BaseTrainMeta
# whenever a class using that metaclass is defined.
train_classes = {}


class BaseTrainMeta(type):
    """Metaclass that auto-registers every class created with it.

    Each newly created class is stored in ``train_classes`` under its own
    name, so trainers become discoverable simply by being defined.
    """

    def __new__(cls, name, bases, attrs):
        created = super().__new__(cls, name, bases, attrs)
        train_classes[name] = created
        return created
class Train(object, metaclass=BaseTrainMeta):
    """Basic epoch-based training driver.

    Reads its hyper-parameters from the ``"train"`` section of the `states`
    dict, runs the optimization loop, saves checkpoints after every epoch and
    optionally collects performance measurements. Subclasses are registered
    in ``train_classes`` automatically via ``BaseTrainMeta``.
    """

    @property
    def perf_mode(self):
        # Performance-measurement mode is active when a positive number of
        # frames to profile has been configured.
        return self.perf_frames > 0

    def _arg(self, name: str, default=None):
        # Read a hyper-parameter from the "train" section of the states dict.
        return self.states.get("train", {}).get(name, default)

    def __init__(self, model: BaseModel, run_dir: Path, states: dict) -> None:
        """Set up the trainer.

        :param model: model to train; switched to train mode here.
        :param run_dir: directory for checkpoints and outputs.
        :param states: checkpoint/state dict; may carry "epoch", "iters",
            "opti" (optimizer state) and a "train" section of hyper-parameters.
        """
        super().__init__()
        print_and_log(
            f"Create trainer {__class__} with args: {json.dumps(states.get('train', {}))}")
        self.model = model
        self.run_dir = run_dir
        self.states = states
        # Resume epoch/iteration counters if present in the loaded states.
        self.epoch = states.get("epoch", 0)
        self.iters = states.get("iters", 0)
        self.max_epochs = self._arg("max_epochs", 50)
        self.checkpoint_interval = self._arg("checkpoint_interval", 10)
        self.perf_frames = self._arg("perf_frames", 0)
        self.model.train()
        self.reset_optimizer()
        if 'opti' in states:
            self.optimizer.load_state_dict(states['opti'])
        # For performance measurement
        if self.perf_mode:
            enable_perf()
        # Made visible to other modules through set_env() in train().
        self.env = {
            "trainer": self
        }

    def reset_optimizer(self):
        # Fresh Adam optimizer over all model parameters (fixed lr).
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)

    def train(self, data_loader: DataLoader):
        """Run the training loop until `max_epochs`, checkpointing each epoch."""
        set_env(self.env)
        self.data_loader = data_loader
        # In perf mode only `perf_frames` iterations are counted per epoch.
        self.iters_per_epoch = self.perf_frames or len(data_loader)
        print(f"Begin training... Max epochs: {self.max_epochs}")
        while self.epoch < self.max_epochs:
            self._train_epoch()
            self._save_checkpoint()
        print("Train finished")

    def _save_checkpoint(self):
        """Save the current states and archive off-interval checkpoints."""
        (self.run_dir / '_misc').mkdir(exist_ok=True)
        # Clean checkpoints: epochs that don't match the interval are moved
        # into _misc instead of being deleted.
        for i in range(1, self.epoch):
            if i % self.checkpoint_interval != 0:
                checkpoint_path = self.run_dir / f'checkpoint_{i}.tar'
                if checkpoint_path.exists():
                    checkpoint_path.rename(self.run_dir / f'_misc/checkpoint_{i}.tar')
        # Save checkpoint
        self.states.update({
            **serialize(self.model),
            "epoch": self.epoch,
            "iters": self.iters,
            "opti": self.optimizer.state_dict()
        })
        netio.save_checkpoint(self.states, self.run_dir, self.epoch)

    def _show_progress(self, iters_in_epoch: int, avg_loss: float = 0, recent_loss: float = 0):
        # Render a console progress bar for the current epoch.
        iters_per_epoch = self.perf_frames or len(self.data_loader)
        progress_bar(iters_in_epoch, iters_per_epoch,
                     f"Loss: {recent_loss:.2e} ({avg_loss:.2e})",
                     f"Epoch {self.epoch + 1:<3d}",
                     f" {self.run_dir.absolute()}")

    def _show_perf(self):
        # Print a hierarchical report of the collected perf measurements.
        s = "Performance Report ==>\n"
        res = get_perf_result()
        if res is None:
            s += "No available data.\n"
        else:
            for key, val in res.items():
                # Indent each entry by the nesting depth of its perf-node path.
                path_segs = key.split("/")
                s += " " * (len(path_segs) - 1) + f"{path_segs[-1]}: {val:.1f}ms\n"
        print(s)

    def _forward(self, data: InputData) -> ReturnData:
        # Forward pass requesting the outputs the losses below may need.
        return self.model(data, 'color', 'energies', 'speculars')

    @perf
    def _train_iter(self, data: Dict[str, Union[torch.Tensor, Any]]) -> float:
        """Run one optimization step on a batch and return the scalar loss.

        `data` is either a single batch dict or a list of partial batches
        whose outputs are concatenated before computing the loss.
        """
        def filtered_data(data, filter):
            # Apply the model's rays filter to ground-truth data, if any.
            if filter is not None:
                return data[filter]
            return data
        with perf("Forward"):
            if isinstance(data, list):
                # Multi-part batch: forward each part, then concatenate.
                out_colors = []
                out_energies = []
                out_speculars = []
                gt_colors = []
                for datum in data:
                    partial_out = self._forward(datum)
                    out_colors.append(partial_out['color'])
                    out_energies.append(partial_out['energies'].flatten())
                    if 'speculars' in partial_out:
                        out_speculars.append(partial_out['speculars'].flatten())
                    gt_colors.append(filtered_data(datum["color"], partial_out.get("rays_filter")))
                out_colors = torch.cat(out_colors)
                out_energies = torch.cat(out_energies)
                out_speculars = torch.cat(out_speculars) if len(out_speculars) > 0 else None
                gt_colors = torch.cat(gt_colors)
            else:
                out = self._forward(data)
                out_colors = out['color']
                out_energies = out['energies']
                out_speculars = out.get('speculars')
                gt_colors = filtered_data(data['color'], out.get("rays_filter"))
        with perf("Compute loss"):
            # Color reconstruction loss plus optional regularization terms.
            loss_val = loss.mse_loss(out_colors, gt_colors)
            if self._arg("density_regularization_weight"):
                loss_val += loss.cauchy_loss(out_energies, s=self._arg("density_regularization_scale"))\
                    * self._arg("density_regularization_weight")
            if self._arg("specular_regularization_weight") and out_speculars is not None:
                loss_val += loss.cauchy_loss(out_speculars, s=self._arg("specular_regularization_scale")) \
                    * self._arg("specular_regularization_weight")
            #return loss_val.item() # TODO remove this line
        with perf("Backward"):
            self.optimizer.zero_grad(True)
            loss_val.backward()
        with perf("Update"):
            self.optimizer.step()
        return loss_val.item()

    def _train_epoch(self):
        """Train over one full pass of the data loader, reporting progress."""
        iters_in_epoch = 0
        recent_loss = []
        tot_loss = 0
        train_epoch_node = Perf.Node("Train Epoch")
        self._show_progress(iters_in_epoch)
        for data in self.data_loader:
            loss_val = self._train_iter(data)
            self.iters += 1
            iters_in_epoch += 1
            # Moving average over the most recent 50 iterations.
            recent_loss = (recent_loss + [loss_val])[-50:]
            recent_avg_loss = sum(recent_loss) / len(recent_loss)
            tot_loss += loss_val
            avg_loss = tot_loss / iters_in_epoch
            #loss_min = min(loss_min, loss_val)
            #loss_max = max(loss_max, loss_val)
            #loss_avg = (loss_avg * iters_in_epoch + loss_val) / (iters_in_epoch + 1)
            self._show_progress(iters_in_epoch, avg_loss=avg_loss, recent_loss=recent_avg_loss)
            if self.perf_mode and iters_in_epoch >= self.perf_frames:
                # Profiling complete: print the report and terminate the run.
                self._show_perf()
                exit()
        train_epoch_node.close()
        torch.cuda.synchronize()
        self.epoch += 1
        epoch_dur = train_epoch_node.duration() / 1000
        # NOTE(review): `avg_loss` is unbound if the data loader yields nothing.
        logging.info(f"Epoch {self.epoch} spent {format_time(epoch_dur)} "
                     f"(Avg. {format_time(epoch_dur / self.iters_per_epoch)}/iter). "
                     f"Loss is {avg_loss:.2e}")
        #print(list(self.model.model(0).named_parameters())[2])
        #print(list(self.model.model(1).named_parameters())[2])

    def _train_epoch_debug(self): # TBR
        """Debug-only epoch variant that dumps per-ray data to CSV (to be removed).

        NOTE(review): calls `_show_progress` with a `loss=` keyword the method
        does not accept, and the code after the early `return` references an
        undefined name `states` — this path appears stale.
        """
        iters_in_epoch = 0
        loss_min = math.huge
        loss_max = 0
        loss_avg = 0
        self._show_progress(iters_in_epoch, loss={'val': 0, 'min': 0, 'max': 0, 'avg': 0})
        indices = []
        debug_data = []
        for idx, rays_o, rays_d, extra in self.data_loader:
            out = self.model(rays_o, rays_d, extra_outputs=['layers', 'weights'])
            loss_val = nn_f.mse_loss(out['color'], extra['color']).item()
            loss_min = min(loss_min, loss_val)
            loss_max = max(loss_max, loss_val)
            loss_avg = (loss_avg * iters_in_epoch + loss_val) / (iters_in_epoch + 1)
            self.iters += 1
            iters_in_epoch += 1
            self._show_progress(iters_in_epoch, loss={
                'val': loss_val,
                'min': loss_min,
                'max': loss_max,
                'avg': loss_avg
            })
            indices.append(idx)
            # Collect a per-ray debug record; commented entries were earlier probes.
            debug_data.append(torch.cat([
                extra['view_idx'][..., None],
                extra['pix_idx'][..., None],
                rays_d,
                #out['samples'].pts[:, 215:225].reshape(idx.size(0), -1),
                #out['samples'].dirs[:, :3].reshape(idx.size(0), -1),
                #out['samples'].voxel_indices[:, 215:225],
                out['states'].densities[:, 210:230].detach().reshape(idx.size(0), -1),
                out['states'].energies[:, 210:230].detach().reshape(idx.size(0), -1)
                # out['color'].detach()
            ], dim=-1))
            # states: VolumnRenderer.States = out['states'] # TBR
        # Sort records by ray index so shuffled/sequential runs are comparable.
        indices = torch.cat(indices, dim=0)
        debug_data = torch.cat(debug_data, dim=0)
        indices, sort = indices.sort()
        debug_data = debug_data[sort]
        name = "rand.csv" if self.data_loader.shuffle else "seq.csv"
        with (self.run_dir / name).open("w") as fp:
            csv_writer = csv.writer(fp)
            csv_writer.writerows(torch.cat([indices[:20, None], debug_data[:20]], dim=-1).tolist())
        return
        # Dead code below: unreachable after the return and uses undefined `states`.
        with (self.run_dir / 'states.csv').open("w") as fp:
            csv_writer = csv.writer(fp)
            for chunk_info in states.chunk_infos:
                csv_writer.writerow(
                    [*chunk_info['range'], chunk_info['hits'], chunk_info['core_i']])
                if chunk_info['hits'] > 0:
                    csv_writer.writerows(torch.cat([
                        chunk_info['samples'].pts,
                        chunk_info['samples'].dirs,
                        chunk_info['samples'].voxel_indices[:, None],
                        chunk_info['colors'],
                        chunk_info['energies']
                    ], dim=-1).tolist())
                csv_writer.writerow([])
from typing import Union
from .train_with_space import TrainWithSpace
import torch
from pathlib import Path
from model.cnerf import CNeRF
from data.loader import DataLoader, MultiScaleDataLoader
from .train_with_space import TrainWithSpace
from .trainer import Trainer
from model import CNeRF
from data import RaysLoader, MultiScaleDataLoader
from modules import Voxels
from utils.misc import print_and_log
from utils.logging import print_and_log
class TrainMultiScale(TrainWithSpace):
model: CNeRF
data_loader: Union[DataLoader, MultiScaleDataLoader]
data_loader: RaysLoader | MultiScaleDataLoader
def __init__(self, model: CNeRF, run_dir: Path, states: dict) -> None:
super().__init__(model, run_dir, states)
self.freeze_epochs = self._arg("freeze_epochs", [])
self.level_by_level = True#self._arg("level_by_level", False)
def _train_epoch(self):
def _train_epoch(self, profiler: torch.profiler.profile = None):
l = self._check_epoch_matches(self.freeze_epochs, self.epoch)
if l >= 0:
self.model.trigger_stage(l + 1)
......@@ -35,7 +35,7 @@ class TrainMultiScale(TrainWithSpace):
space: Voxels = self.model.model(self.model.stage).space
if self._check_epoch_matches(self.prune_epochs, self.epoch + 1) >= 0:
self.voxel_access = torch.zeros(space.n_voxels, dtype=torch.long, device=space.device)
super(TrainWithSpace, self)._train_epoch()
super(Trainer, self)._train_epoch(profiler)
if self.voxel_access is not None:
before, after = space.prune(self.voxel_access > 0)
print_and_log(f"Prune by weights: {before} -> {after}")
......
from .train import Train
import sys
import torch
from pathlib import Path
from typing import List
from .trainer import Trainer
from modules import Voxels
from model.base import BaseModel
from data.loader import DataLoader
from utils.samples import Samples
from model import Model
from data import RaysLoader
from utils.types import *
from utils.mem_profiler import MemProfiler
from utils.misc import print_and_log
from utils.type import InputData, ReturnData
from utils.logging import print_and_log
class TrainWithSpace(Train):
class TrainWithSpace(Trainer):
def __init__(self, model: BaseModel, run_dir: Path, states: dict) -> None:
def __init__(self, model: Model, run_dir: Path, states: dict) -> None:
super().__init__(model, run_dir, states)
self.prune_epochs = [] if self.perf_mode else self._arg("prune_epochs", [])
self.split_epochs = [] if self.perf_mode else self._arg("split_epochs", [])
self.voxel_access = None
#MemProfiler.enable = True
def _train_epoch(self):
def _train_epoch(self, profiler: torch.profiler.profile = None):
self._split()
space: Voxels = self.model.space
if self._check_epoch_matches(self.prune_epochs, self.epoch + 1) >= 0:
self.voxel_access = torch.zeros(space.n_voxels, dtype=torch.long, device=space.device)
super()._train_epoch()
super()._train_epoch(profiler)
if self.voxel_access is not None:
before, after = space.prune(self.voxel_access > 0)
print_and_log(f"Prune by weights: {before} -> {after}")
self.voxel_access = None
# self._prune()
def _forward(self, data: InputData) -> ReturnData:
def _forward(self, rays: Rays) -> ReturnData:
if self.voxel_access is None:
return super()._forward(data)
out = self.model(data, 'color', 'energies', 'speculars', 'weights', "samples")
return super()._forward(rays)
out = self.model(rays, 'color', 'energies', 'speculars', 'weights', "samples")
with torch.no_grad():
access_voxels = out['samples'].voxel_indices[out['weights'][..., 0] > 0.01]
self.voxel_access.index_add_(0, access_voxels, torch.ones_like(access_voxels))
......@@ -63,7 +59,7 @@ class TrainWithSpace(Train):
except NotImplementedError:
print_and_log("The space does not support pruning operation. Just skip it.")
def _check_epoch_matches(self, key_epochs: List[int], epoch: int = None):
def _check_epoch_matches(self, key_epochs: list[int], epoch: int = None):
epoch = epoch if epoch is not None else self.epoch
if epoch == 0 or len(key_epochs) == 0:
return -1
......@@ -95,7 +91,7 @@ class TrainWithSpace(Train):
], 0) # (M[, ...])
return space.prune(scores > threshold)
def _prune_voxels_by_weights(self, data_loader: DataLoader = None):
def _prune_voxels_by_weights(self, data_loader: RaysLoader = None):
space: Voxels = self.model.space
data_loader = data_loader or self.data_loader
batch_size = data_loader.batch_size
......
from collections import defaultdict
from statistics import mean
from tqdm import tqdm
from tensorboardX import SummaryWriter
from torch.optim import Optimizer, Adam
from torch.optim.lr_scheduler import _LRScheduler, ExponentialLR
from utils import netio, logging, loss, misc
from utils.args import BaseArgs
from utils.profile import Profiler, enable_profile, profile
from utils.types import *
from data import Dataset, RaysLoader
from model import Model
# Registry of trainer classes by name; values are Trainer subclasses.
trainer_classes: dict[str, "Trainer"] = {}


class Trainer:
    """Drives model training in one of three modes.

    * profile mode (``profile_iters`` set): run a fixed number of warm-up and
      measured iterations, print a profiling report and exit;
    * iters mode (``max_iters`` set): train by iteration count, checkpoint
      every ``checkpoint_interval`` iterations;
    * epochs mode (default): train by epochs, checkpoint every epoch.
    """

    class Args(BaseArgs):
        # Trainer hyper-parameters; fields typed `| None` are optional.
        max_iters: int | None
        max_epochs: int = 20
        checkpoint_interval: int | None
        batch_size: int = 4096
        loss: list[str] = ["Color_L2", "CoarseColor_L2"]
        lr: float = 5e-4
        lr_decay: float | None
        profile_iters: int | None
        trainset: str

    args: Args
    states: dict[str, Any]
    optimizer: Optimizer
    scheduler: _LRScheduler | None

    # Loss terms selectable through `Args.loss`. Each entry declares which
    # model outputs it needs so that only those are requested in the forward.
    loss_defs = {
        "Color_L2": {
            "fn": lambda out, gt: loss.mse_loss(out["color"], gt["color"]),
            "required_outputs": ["color"]
        },
        "CoarseColor_L2": {
            "fn": lambda out, gt: loss.mse_loss(out["coarse_color"], gt["color"]),
            "required_outputs": ["coarse_color"]
        },
        "Density_Reg": {
            "fn": lambda out, gt: loss.cauchy_loss(out["densities"], 1e4) * 1e-4,
            "required_outputs": ["densities"]
        }
    }

    @property
    def profile_mode(self) -> bool:
        # Profile mode is active when `profile_iters` was provided.
        return self.profile_iters is not None

    @staticmethod
    def get_class(typename: str) -> Type["Trainer"] | None:
        """Return the registered trainer class named *typename*, or None."""
        return trainer_classes.get(typename)

    def __init__(self, model: Model, run_dir: Path, args: Args = None) -> None:
        """Configure the training mode, optimizer, losses and logging.

        :param model: model to train; switched to train mode here.
        :param run_dir: directory for checkpoints and logs.
        :param args: trainer arguments; defaults are used when omitted.
        """
        self.model = model.train()
        self.run_dir = run_dir
        self.args = args or self.__class__.Args()
        self.epoch = 0
        self.iters = 0
        self.profile_warmup_iters = 10
        self.profile_iters = self.args.profile_iters
        if self.profile_mode:  # profile mode
            self.max_iters = self.profile_warmup_iters + self.profile_iters
            self.max_epochs = None
            self.checkpoint_interval = None
        elif self.args.max_iters:  # iters mode
            self.max_iters = self.args.max_iters
            self.max_epochs = None
            self.checkpoint_interval = self.args.checkpoint_interval or 10000
        else:  # epochs mode
            self.max_iters = None
            self.max_epochs = self.args.max_epochs
            self.checkpoint_interval = self.args.checkpoint_interval or 10
        self._init_optimizer()
        self._init_scheduler()
        # Union (deduplicated) of model outputs required by the loss terms.
        self.required_outputs = []
        for key in self.args.loss:
            self.required_outputs += self.loss_defs[key]["required_outputs"]
        self.required_outputs = list(set(self.required_outputs))
        if self.profile_mode:  # Enable performance measurement in profile mode
            def handle_profile_result(result: Profiler.ProfileResult):
                # Print the report and terminate the process once profiling ends.
                print(result.get_report())
                exit()
            enable_profile(self.profile_warmup_iters, self.profile_iters, handle_profile_result)
        else:  # Enable logging (Tensorboard & txt) in normal mode
            tb_log_dir = self.run_dir / "_log"
            tb_log_dir.mkdir(exist_ok=True)
            self.tb_writer = SummaryWriter(tb_log_dir, purge_step=0)
            logging.initialize(self.run_dir / "train.log")
            logging.print_and_log(f"Model arguments: {self.model.args}")
            logging.print_and_log(f"Trainer arguments: {self.args}")
        # Debug: print model structure
        print(model)

    def state_dict(self) -> dict[str, Any]:
        """Collect everything needed to resume training later."""
        return {
            "model": self.model.state_dict(),
            "epoch": self.epoch,
            "iters": self.iters,
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict() if self.scheduler else None
        }

    def load_state_dict(self, state_dict: dict[str, Any]):
        """Restore counters, model, optimizer and scheduler from a checkpoint."""
        self.epoch = state_dict.get("epoch", self.epoch)
        self.iters = state_dict.get("iters", self.iters)
        if "model" in state_dict:
            self.model.load_state_dict(state_dict["model"])
        if "optimizer" in state_dict:
            self.optimizer.load_state_dict(state_dict["optimizer"])
        if self.scheduler and "scheduler" in state_dict:
            self.scheduler.load_state_dict(state_dict["scheduler"])

    def reset_optimizer(self):
        """Re-create the optimizer, keeping the scheduler's progress."""
        self._init_optimizer()
        if self.scheduler is not None:
            # Rebuild the scheduler around the new optimizer, then restore
            # its step counter so the decay curve continues seamlessly.
            scheduler_state = self.scheduler.state_dict()
            self._init_scheduler()
            self.scheduler.load_state_dict(scheduler_state)

    def train(self, dataset: Dataset):
        """Main entry: train on `dataset` until the mode's stop condition."""
        self.rays_loader = RaysLoader(dataset, self.args.batch_size, shuffle=True,
                                      device=self.model.device)
        # Chunk size for the forward pass; halved on CUDA OOM (see _train_iter).
        self.forward_chunk_size = self.args.batch_size
        if self.max_iters:
            print(f"Begin training... Max iters: {self.max_iters}")
            self.progress = tqdm(total=self.max_iters, dynamic_ncols=True)
            self.rays_iter = self.rays_loader.__iter__()
            while self.iters < self.max_iters:
                self._train_iters(min(self.checkpoint_interval, self.max_iters - self.iters))
                self._save_checkpoint()
        else:
            print(f"Begin training... Max epochs: {self.max_epochs}")
            while self.epoch < self.max_epochs:
                self._train_epoch()
                self._save_checkpoint()
        print("Train finished")

    @staticmethod
    def create(model: Model, run_dir: PathLike, typename: str, args: dict[str, Any] = None) -> "Trainer":
        """Instantiate a registered trainer class by name.

        Raises:
            ValueError: if `typename` is not a registered trainer class.
        """
        # NOTE(review): `args` is passed through as a dict although __init__
        # annotates it as Args — presumably BaseArgs accepts a dict; confirm.
        if typename not in trainer_classes:
            raise ValueError(f"Class {typename} is not found")
        return trainer_classes.get(typename)(model, run_dir, args)

    def _init_scheduler(self):
        # NOTE(review): when lr_decay is falsy, `scheduler` is set to that
        # falsy value (None/0.0) rather than strictly None — callers only
        # truth-test it, so this works, but it is subtle.
        self.scheduler = self.args.lr_decay and ExponentialLR(self.optimizer, self.args.lr_decay)

    def _init_optimizer(self):
        # Adam over all model parameters with the configured learning rate.
        self.optimizer = Adam(self.model.parameters(), lr=self.args.lr)

    def _save_checkpoint(self):
        """Persist args and state; no-op when checkpointing is disabled."""
        if self.checkpoint_interval is None:
            return
        ckpt = {
            "args": {
                "model": self.model.__class__.__name__,
                "model_args": vars(self.model.args),
                "trainer": self.__class__.__name__,
                "trainer_args": vars(self.args)
            },
            "states": self.state_dict()
        }
        if self.max_iters:
            # For iters mode, a checkpoint will be saved every `checkpoint_interval` iterations
            netio.save_checkpoint(ckpt, self.run_dir, self.iters)
        else:
            # For epochs mode, a checkpoint will be saved every epoch.
            # Checkpoints which don't match `checkpoint_interval` will be cleaned later
            netio.clean_checkpoint(self.run_dir, self.checkpoint_interval)
            netio.save_checkpoint(ckpt, self.run_dir, self.epoch)

    def _update_progress(self, loss: float = 0):
        # Advance the tqdm bar, showing the loss when it is meaningful.
        self.progress.set_postfix_str(f"Loss: {loss:.2e}" if loss > 0 else "")
        self.progress.update()

    @profile("Forward")
    def _forward(self, rays: Rays) -> ReturnData:
        # Forward pass requesting only the outputs the losses need.
        return self.model(rays, *self.required_outputs)

    @profile("Compute Loss")
    def _compute_loss(self, rays: Rays, out: ReturnData) -> dict[str, torch.Tensor]:
        """Evaluate each configured loss term; missing outputs are skipped."""
        torch.isnan  # NOTE(review): stray no-op expression — debugging leftover; consider removing.
        gt = rays.select(out["rays_filter"]) if "rays_filter" in out else rays
        loss_terms: dict[str, torch.Tensor] = {}
        for key in self.args.loss:
            try:
                loss_terms[key] = self.loss_defs[key]["fn"](out, gt)
            except KeyError:
                # A term whose required output is absent is silently skipped.
                pass
        # Debug: print loss terms
        #self.progress.write(",".join([f"{key}: {value.item():.2e}" for key, value in loss_terms.items()]))
        return loss_terms

    @profile("Train iteration")
    def _train_iter(self, rays: Rays) -> float:
        """One optimization step over `rays`, chunked to fit in memory.

        On CUDA OOM the forward chunk size is halved and the whole iteration
        is retried recursively. Returns the mean overall loss of the chunks.
        """
        try:
            self.optimizer.zero_grad(True)
            loss_terms = defaultdict(list)
            # Accumulate gradients chunk by chunk (backward per chunk).
            for offset in range(0, rays.shape[0], self.forward_chunk_size):
                rays_chunk = rays.select(slice(offset, offset + self.forward_chunk_size))
                out_chunk = self._forward(rays_chunk)
                loss_chunk = self._compute_loss(rays_chunk, out_chunk)
                loss_value = sum(loss_chunk.values())
                with profile("Backward"):
                    loss_value.backward()
                loss_terms["Overall_Loss"].append(loss_value.item())
                for key, value in loss_chunk.items():
                    loss_terms[key].append(value.item())
            loss_terms = {key: mean(value) for key, value in loss_terms.items()}
            with profile("Update"):
                self.optimizer.step()
                if self.scheduler:
                    self.scheduler.step()
            # Debug: print lr
            #self.progress.write(f"Learning rate: {self.optimizer.param_groups[0]['lr']}")
            self.iters += 1
            if hasattr(self, "tb_writer"):
                for key, value in loss_terms.items():
                    self.tb_writer.add_scalar(f"Loss/{key}", value, self.iters)
            return loss_terms["Overall_Loss"]
        except RuntimeError as e:
            if not e.__str__().startswith("CUDA out of memory"):
                raise e
            self.progress.write("CUDA out of memory, half forward batch and retry.")
            logging.warning("CUDA out of memory, half forward batch and retry.")
            self.forward_chunk_size //= 2
            torch.cuda.empty_cache()
            return self._train_iter(rays)

    def _train_iters(self, iters: int):
        """Train for a fixed number of iterations (iters mode)."""
        recent_loss_list = []
        tot_loss = 0
        train_iters_node = Profiler.Node("Train Iterations")
        for _ in range(iters):
            try:
                rays = self.rays_iter.__next__()
            except StopIteration:
                self.rays_iter = self.rays_loader.__iter__()  # A new epoch
                rays = self.rays_iter.__next__()
            loss_val = self._train_iter(rays)
            recent_loss_list = (recent_loss_list + [loss_val])[-50:]  # Keep recent 50 iterations
            recent_avg_loss = sum(recent_loss_list) / len(recent_loss_list)
            tot_loss += loss_val
            self._update_progress(recent_avg_loss)
        train_iters_node.close()
        torch.cuda.synchronize()
        avg_time = train_iters_node.device_duration / 1000 / iters
        avg_loss = tot_loss / iters
        state_str = f"Iter {self.iters}: Avg. {misc.format_time(avg_time)}/iter; Loss: {avg_loss:.2e}"
        self.progress.write(state_str)
        logging.info(state_str)

    def _train_epoch(self):
        """Train over one full pass of the rays loader (epochs mode)."""
        iters_per_epoch = len(self.rays_loader)
        recent_loss_list = []
        tot_loss = 0
        self.progress = tqdm(total=iters_per_epoch, desc=f"Epoch {self.epoch + 1:<3d}", leave=False,
                             dynamic_ncols=True)
        train_epoch_node = Profiler.Node("Train Epoch")
        for rays in self.rays_loader:
            with profile("Train iteration"):
                loss_val = self._train_iter(rays)
            # Moving average over the most recent 50 iterations.
            recent_loss_list = (recent_loss_list + [loss_val])[-50:]
            recent_avg_loss = sum(recent_loss_list) / len(recent_loss_list)
            tot_loss += loss_val
            self._update_progress(recent_avg_loss)
        self.progress.close()
        train_epoch_node.close()
        torch.cuda.synchronize()
        self.epoch += 1
        epoch_time = train_epoch_node.device_duration / 1000
        avg_time = epoch_time / iters_per_epoch
        avg_loss = tot_loss / iters_per_epoch
        state_str = f"Epoch {self.epoch} spent {misc.format_time(epoch_time)} "\
            f"(Avg. {misc.format_time(avg_time)}/iter). Loss is {avg_loss:.2e}."
        logging.print_and_log(state_str)
import os
import sys
import argparse
import torch
import torch.optim
import time
from tensorboardX import SummaryWriter
from torch import nn
# Command-line interface. The same script serves both training and testing;
# -t/--test selects the mode, -p/--prompt enables interactive configuration.
parser = argparse.ArgumentParser()
# Arguments for train >>>
parser.add_argument('-c', '--config', type=str,
                    help='Net config files')
parser.add_argument('-i', '--config-id', type=str,
                    help='Net config id')
parser.add_argument('-e', '--epochs', type=int, default=200,
                    help='Max epochs for train')
parser.add_argument('-n', '--prev-net', type=str)
# Arguments for test >>>
parser.add_argument('-r', '--output-res', type=str,
                    help='Output resolution')
parser.add_argument('-o', '--output', nargs='+', type=str, default=['perf', 'color'],
                    help='Specify what to output (perf, color, depth, all)')
parser.add_argument('--output-type', type=str, default='image',
                    help='Specify the output type (image, video, debug)')
# Other arguments >>>
parser.add_argument('-t', '--test', action='store_true',
                    help='Start in test mode')
parser.add_argument('-m', '--model', type=str,
                    help='The model file to load for continue train or test')
parser.add_argument('-d', '--device', type=int, default=0,
                    help='Which CUDA device to use.')
parser.add_argument('-l', '--log-redirect', action='store_true',
                    help='Is log redirected to file?')
parser.add_argument('-p', '--prompt', action='store_true',
                    help='Interactive prompt mode')
parser.add_argument('dataset', type=str,
                    help='Dataset description file')
args = parser.parse_args()

# Bind all subsequent CUDA work to the selected device before importing
# the project modules below, which may allocate on the default device.
torch.cuda.set_device(args.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())
from utils import netio
from utils import math
from utils import device
from utils import img
from utils import interact
from utils import color
from utils.progress_bar import progress_bar
from utils.perf import Perf
from data.spherical_view_syn import *
from data.loader import FastDataLoader
from configs.spherical_view_syn import SphericalViewSynConfig
from loss.ssim import ssim
# Resolve the dataset description file: either an explicit .json path or a
# directory that contains 'train.json'.
data_desc_path = args.dataset if args.dataset.endswith('.json') \
    else os.path.join(args.dataset, 'train.json')
data_desc_name = os.path.splitext(os.path.basename(data_desc_path))[0]
data_dir = os.path.dirname(data_desc_path) + '/'
config = SphericalViewSynConfig()

BATCH_SIZE = 4096  # rays per training batch
SAVE_INTERVAL = 10  # NOTE(review): appears unused in the visible code — confirm.
TEST_BATCH_SIZE = 1  # views per test batch
TEST_MAX_RAYS = 32768 // 2  # rays per forward chunk during testing

# Toggles
ROT_ONLY = False
EVAL_TIME_PERFORMANCE = False
# ========
#ROT_ONLY = True
#EVAL_TIME_PERFORMANCE = True
def get_model_files(datadir):
    """Collect paths (relative to *datadir*) of all '.pth' model files under it.

    The leading *datadir* prefix is stripped from each path, so results are
    suitable for re-joining with *datadir* later.
    """
    found = []
    for root, _, files in os.walk(datadir):
        for entry in files:
            if entry.endswith('.pth'):
                found.append(os.path.join(root, entry).replace(datadir, ''))
    return found
def set_outputs(args, outputs_str: str):
    """Parse a comma-separated output spec into ``args.output``.

    Whitespace around each item is stripped; the result is a list of strings.
    """
    args.output = list(map(str.strip, outputs_str.split(',')))
# Select train or test mode and derive run_dir / run_id / config from the
# CLI arguments, optionally prompting the user interactively (-p/--prompt).
if not args.test:
    print('Start in train mode.')
    if args.prompt:  # 2.1 Prompt max epochs
        args.epochs = interact.input_ex('Max epochs:', interact.input_to_int(min=1),
                                        default=200)
    epochRange = range(1, args.epochs + 1)
    if args.prompt:  # 2.2 Prompt continue train
        model_files = get_model_files(data_dir)
        args.model = interact.input_enum('Continue train on model:', model_files,
                                         err_msg='No such model file', default='')
    if args.model:
        # Continue training from an existing checkpoint.
        cont_model = os.path.join(data_dir, args.model)
        model_name = os.path.splitext(os.path.basename(cont_model))[0]
        # Epoch number parsed from 'model-epoch_<n>' (chars after index 12).
        epochRange = range(int(model_name[12:]) + 1, epochRange.stop)
        run_dir = os.path.dirname(cont_model) + '/'
        run_id = os.path.basename(run_dir[:-1])
        config.from_id(run_id)
    else:
        if args.prompt:  # 2.3 Prompt config file and additional config items
            config_files = [
                f[:-3] for f in os.listdir('configs')
                if f.endswith('.py') and f != 'spherical_view_syn.py'
            ]
            args.config = interact.input_enum('Specify config file:', config_files,
                                              err_msg='No such config file', default='')
            args.config_id = interact.input_ex('Specify custom config items:',
                                               default='')
        if args.config:
            config.load(os.path.join('configs', args.config + '.py'))
        if args.config_id:
            config.from_id(args.config_id)
        # Fresh run: the run directory is named after the config id.
        run_id = config.to_id()
        run_dir = data_dir + run_id + '/'
    log_dir = run_dir + 'log/'
else:  # Test mode
    print('Start in test mode.')
    if args.prompt:  # 3. Prompt test model, output resolution, output mode
        model_files = get_model_files(data_dir)
        args.model = interact.input_enum('Specify test model:', model_files,
                                         err_msg='No such model file')
        args.output_res = interact.input_ex('Specify output resolution:',
                                            default='')
        set_outputs(args, 'depth')
    test_model_path = os.path.join(data_dir, args.model)
    test_model_name = os.path.splitext(os.path.basename(test_model_path))[0]
    run_dir = os.path.dirname(test_model_path) + '/'
    run_id = os.path.basename(run_dir[:-1])
    config.from_id(run_id)
    # Deterministic sampling during evaluation.
    config.sa['perturb_sample'] = False
    # Resolution is given as 'WxH'; None means the dataset's native resolution.
    args.output_res = tuple(int(s) for s in args.output_res.split('x')) \
        if args.output_res else None
    output_dir = f"{run_dir}output_{int(test_model_name.split('_')[-1])}"
    output_dataset_id = '%s%s' % (
        data_desc_name,
        '_%dx%d' % (args.output_res[0], args.output_res[1]) if args.output_res else '')
    # Map each known output kind to whether it was requested ('all' enables every kind).
    args.output_flags = {
        item: item in args.output or 'all' in args.output
        for item in ['perf', 'color', 'depth', 'layers']
    }

config.print()
print("run dir: ", run_dir)
# Initialize model
model = config.create_net().to(device.default())
loss_func = nn.MSELoss().to(device.default())
if args.prev_net:
    # Optionally attach a previously trained network as `model.prev_net`.
    # NOTE(review): os.path.split returns (head, tail), so [-2] is the whole
    # directory path, not a single id component — confirm this is intended.
    prev_net_config_id = os.path.split(args.prev_net)[-2]
    prev_net_config = SphericalViewSynConfig()
    prev_net_config.from_id(prev_net_config_id)
    prev_net = prev_net_config.create_net().to(device.default())
    netio.load(args.prev_net, prev_net)
    model.prev_net = prev_net

# State for the console status-line toggle used in train_loop().
toggle_show_dir = False
last_toggle_time = 0
def train_loop(data_loader, optimizer, perf, writer, epoch, iters):
    """Train for one epoch and return the updated global iteration count.

    :param data_loader: yields (idx, _, rays_o, rays_d) batches.
    :param optimizer: optimizer stepping the module-level `model`.
    :param perf: Perf instance receiving per-stage checkpoints.
    :param writer: tensorboard SummaryWriter for the loss curve.
    :param epoch: 1-based epoch number (display/logging only).
    :param iters: global iteration counter carried across epochs.
    """
    global toggle_show_dir
    global last_toggle_time
    dataset: SphericalViewSynDataset = data_loader.dataset
    sub_iters = 0
    iters_in_epoch = len(data_loader)
    loss_min = 1e5
    loss_max = 0
    loss_avg = 0
    perf1 = Perf(args.log_redirect, True)
    for idx, _, rays_o, rays_d in data_loader:
        rays_bins = dataset.patched_bins[idx] if dataset.load_bins else None
        perf.checkpoint("Load")
        out = model(rays_o, rays_d)
        perf.checkpoint("Forward")
        optimizer.zero_grad()
        # Convert normalized bin values to integer class indices over the
        # last output dimension.
        # NOTE(review): assumes rays_bins[..., 0] lies in [0.5, 1] — confirm.
        rays_bins = ((rays_bins[..., 0:1] - 0.5) * 2 * (out.size(-1) - 1)).to(torch.long)
        # One-hot target over the bin dimension.
        gt = torch.zeros_like(out)
        gt.scatter_(-1, rays_bins, 1)
        loss_value = loss_func(out, gt)
        #loss_value = loss_func(out, rays_bins[..., 0])
        perf.checkpoint("Compute loss")
        loss_value.backward()
        perf.checkpoint("Backward")
        optimizer.step()
        perf.checkpoint("Update")
        loss_value = loss_value.item()
        loss_min = min(loss_min, loss_value)
        loss_max = max(loss_max, loss_value)
        # Running mean over the iterations seen so far in this epoch.
        loss_avg = (loss_avg * sub_iters + loss_value) / (sub_iters + 1)
        if not args.log_redirect:
            progress_bar(sub_iters, iters_in_epoch,
                         f"Loss: {loss_value:.2e} ({loss_min:.2e}/{loss_avg:.2e}/{loss_max:.2e})",
                         f"Epoch {epoch:<3d}")
        # Every ~3 seconds toggle an extra status line showing the run dir.
        current_time = time.time()
        if last_toggle_time == 0:
            last_toggle_time = current_time
        if current_time - last_toggle_time > 3:
            toggle_show_dir = not toggle_show_dir
            last_toggle_time = current_time
        if toggle_show_dir:
            sys.stdout.write(f'Epoch {epoch:<3d} [ {run_dir} ]\r')
        # Write tensorboard logs.
        writer.add_scalar("loss mse", loss_value, iters)
        # if patch and iters % 100 == 0:
        #     output_vs_gt = torch.cat([out[0:4], gt[0:4]], 0).detach()
        #     writer.add_image("Output_vs_gt", torchvision.utils.make_grid(
        #         output_vs_gt, nrow=4).cpu().numpy(), iters)
        iters += 1
        sub_iters += 1
    if args.log_redirect:
        # Summarize the epoch in the redirected log instead of a progress bar.
        perf1.checkpoint('Epoch %d (%.2e/%.2e/%.2e)' %
                         (epoch, loss_min, loss_avg, loss_max), True)
    return iters
def save_checkpoint(epoch, iters):
    """Save the model for *epoch* and delete stale intermediate checkpoints.

    Retention policy: keep every 50th epoch before the most recent multiple
    of 50, and every 10th epoch after it; remove everything else.
    """
    latest_block_start = epoch // 50 * 50
    for past in range(1, epoch):
        stale = (past < latest_block_start and past % 50 != 0) or past % 10 != 0
        if stale and os.path.exists(f'{run_dir}model-epoch_{past}.pth'):
            os.remove(f'{run_dir}model-epoch_{past}.pth')
    netio.save(f'{run_dir}model-epoch_{epoch}.pth', model, iters, print_log=False)
def train():
    """Full training entry point: data loading, optimizer, epoch loop, checkpoints.

    Uses module-level globals: model, config, run_dir, log_dir, epochRange,
    BATCH_SIZE, EVAL_TIME_PERFORMANCE.
    """
    # 1. Initialize data loader
    print("Load dataset: " + data_desc_path)
    dataset = SphericalViewSynDataset(data_desc_path, c=config.c, load_images=False,
                                      load_bins=True)
    dataset.set_patch_size(1)
    data_loader = FastDataLoader(dataset, BATCH_SIZE, shuffle=True, pin_memory=True)

    # 2. Initialize components
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

    if epochRange.start > 1:
        # Resuming: load weights and the iteration count from the checkpoint
        # of the epoch just before the resumed range.
        iters = netio.load(f'{run_dir}model-epoch_{epochRange.start - 1}.pth', model)
    else:
        os.makedirs(run_dir, exist_ok=True)
        os.makedirs(log_dir, exist_ok=True)
        iters = 0

    # 3. Train
    model.train()
    perf = Perf(EVAL_TIME_PERFORMANCE, start=True)
    writer = SummaryWriter(log_dir)
    print("Begin training...")
    for epoch in epochRange:
        iters = train_loop(data_loader, optimizer, perf, writer, epoch, iters)
        save_checkpoint(epoch, iters)
    print("Train finished")
def test():
    """Run inference over every view and save per-ray "bins" maps as images.

    For each ray the model's 1-D output is reduced to its top-3 strict local
    maxima; their positions are encoded into [0.5, 1] and written out per
    view. Optionally records per-view timing (and PSNR/SSIM when ground-truth
    images exist) to a CSV report in `output_dir`.
    """
    with torch.no_grad():
        # 1. Load dataset (rays only; extra data only when perf eval is on)
        print("Load dataset: " + data_desc_path)
        dataset = SphericalViewSynDataset(data_desc_path, res=args.output_res, load_images=False,
                                          load_bins=args.output_flags['perf'])
        data_loader = FastDataLoader(dataset, 1, shuffle=False, pin_memory=True)
        # 2. Load trained model
        netio.load(test_model_path, model)
        model.eval()
        # 3. Test on dataset
        print("Begin test, batch size is %d" % TEST_BATCH_SIZE)
        i = 0  # current view index
        global_offset = 0  # flat-pixel offset of the current view's first ray
        chns = color.chns(config.c)
        n = dataset.n_views
        total_pixels = n * dataset.view_res[0] * dataset.view_res[1]
        out = {}
        if args.output_flags['perf']:
            perf_times = torch.empty(n)
            perf = Perf(True, start=True)
        out['bins'] = torch.zeros(total_pixels, 3, device=device.default())
        for vi, _, rays_o, rays_d in data_loader:
            rays_o = rays_o.view(-1, 3)
            rays_d = rays_d.view(-1, 3)
            #rays_bins = dataset.patched_bins[vi].view(-1, 3)
            n_rays = rays_o.size(0)
            # Process rays in chunks of TEST_MAX_RAYS to bound memory usage
            for offset in range(0, n_rays, TEST_MAX_RAYS):
                idx = slice(offset, min(offset + TEST_MAX_RAYS, n_rays))
                global_idx = slice(idx.start + global_offset, idx.stop + global_offset)
                ret = model(rays_o[idx], rays_d[idx])
                # Keep only strict local maxima along the last dimension,
                # comparing each entry to its neighbors at distances 1..3
                is_local_max = torch.ones_like(ret, dtype=torch.bool)
                for delta in range(-3, 0):
                    is_local_max[..., -delta:].logical_and_(
                        ret[..., -delta:] > ret[..., :delta])
                for delta in range(1, 4):
                    is_local_max[..., :-delta].logical_and_(
                        ret[..., :-delta] > ret[..., delta:])
                ret[is_local_max.logical_not()] = 0
                vals, idxs = torch.topk(ret, 3)  # (B, 3)
                vals = vals / vals.sum(-1, keepdim=True)
                # Encode top-3 peak positions into [0.5, 1]; peaks whose
                # normalized weight is <= 0.1 are zeroed out
                out['bins'][global_idx] = (idxs.to(torch.float) / (ret.size(-1) - 1) * 0.5 + 0.5) * \
                    (vals > 0.1)
            if args.output_flags['perf']:
                perf_times[i] = perf.checkpoint()
            progress_bar(i, n, 'Inferring...')
            i += 1
            global_offset += n_rays
        # 4. Save results
        print('Saving results...')
        os.makedirs(output_dir, exist_ok=True)
        # Reshape flat per-pixel buffers back to (views, H, W, ...)
        for key in out:
            shape = [n] + list(dataset.view_res) + list(out[key].size()[1:])
            out[key] = out[key].view(shape)
        out['bins'] = out['bins'].permute(0, 3, 1, 2)  # to (views, C, H, W)
        if args.output_flags['perf']:
            perf_errors = torch.ones(n) * math.nan
            perf_ssims = torch.ones(n) * math.nan
            if dataset.view_images != None:
                for i in range(n):
                    # NOTE(review): reads out['color'], which this function
                    # never fills (only out['bins'] is written) — this branch
                    # likely raises KeyError when ground-truth images are
                    # present; confirm against the full pipeline.
                    perf_errors[i] = loss_func(dataset.view_images[i], out['color'][i]).item()
                    perf_ssims[i] = ssim(dataset.view_images[i:i + 1],
                                         out['color'][i:i + 1]).item() * 100
            perf_mean_time = torch.mean(perf_times).item()
            perf_mean_error = torch.mean(perf_errors).item()
            perf_name = 'perf_%s_%.1fms_%.2e.csv' % (
                output_dataset_id, perf_mean_time, perf_mean_error)
            # Remove old performance reports
            for file in os.listdir(output_dir):
                if file.startswith(f'perf_{output_dataset_id}'):
                    os.remove(f"{output_dir}/{file}")
            # Save new performance reports
            with open(f"{output_dir}/{perf_name}", 'w') as fp:
                fp.write('View, Time, PSNR, SSIM\n')
                fp.writelines([
                    f'{dataset.view_idxs[i]}, {perf_times[i].item():.2f}, '
                    f'{img.mse2psnr(perf_errors[i].item()):.2f}, {perf_ssims[i].item():.2f}\n'
                    for i in range(n)
                ])
        output_subdir = f"{output_dir}/{output_dataset_id}_bins"
        os.makedirs(output_subdir, exist_ok=True)
        img.save(out['bins'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.view_idxs])
# Script entry point: run inference when --test is given, otherwise train.
if __name__ == "__main__":
    if args.test:
        test()
    else:
        train()
from pathlib import Path

from utils import netio

# One-off migration script: rewrite saved checkpoints so every sub-model
# inherits the top-level model arguments, then set `n_samples` according to
# the progressive-training schedule the run followed.
ckpt_dir = "/home/dengnc/dvs/data/classroom/_nets/ms_train_t0.8/_cnerfadv_ioc/"  # renamed from `dir`, which shadowed the builtin

for epoch in range(1, 151):
    path = f"{ckpt_dir}checkpoint_{epoch}.tar"
    if not Path(path).exists():
        continue
    print(f"Update epoch {epoch}")
    states = netio.load_checkpoint(path)[0]
    model_args = states["args"]
    # Top-level args shared by every sub-model; these two keys stay top-level
    shared_args = {
        key: value for key, value in model_args.items()
        if key != "sub_models" and key != "interp_on_coarse"
    }
    for i in range(len(model_args["sub_models"])):
        # Sub-model-specific values take precedence over the shared defaults
        model_args["sub_models"][i] = {**shared_args, **model_args["sub_models"][i]}
    # Progressive sampling schedule: more samples as training advances
    if epoch >= 30:
        model_args["sub_models"][0]["n_samples"] = 64
    elif epoch >= 10:
        model_args["sub_models"][0]["n_samples"] = 32
    if epoch >= 70:
        model_args["sub_models"][1]["n_samples"] = 128
    if epoch >= 120:
        model_args["sub_models"][2]["n_samples"] = 256
    netio.save_checkpoint(states, ckpt_dir, epoch)
\ No newline at end of file
## Fast Super Resolution CNN
![pic](http://mmlab.ie.cuhk.edu.hk/projects/FSRCNN/img/framework.png)
### note
This model has a high likelihood of diverging after 20 epochs.
\ No newline at end of file
import torch
import torch.nn as nn
class Net(torch.nn.Module):
    """FSRCNN: feature extraction, shrink/map/expand, then deconvolution.

    d: feature dimension, s: shrunk dimension, m: number of 3x3 mapping layers.
    """

    def __init__(self, num_channels, upscale_factor, d=64, s=12, m=4):
        super(Net, self).__init__()
        # Feature extraction on the low-resolution input
        self.first_part = nn.Sequential(
            nn.Conv2d(in_channels=num_channels, out_channels=d,
                      kernel_size=5, stride=1, padding=2),
            nn.PReLU())
        # Shrink -> m mapping layers -> expand
        mid = [nn.Conv2d(in_channels=d, out_channels=s,
                         kernel_size=1, stride=1, padding=0),
               nn.PReLU()]
        for _ in range(m):
            mid.append(nn.Conv2d(in_channels=s, out_channels=s,
                                 kernel_size=3, stride=1, padding=1))
            mid.append(nn.PReLU())
        mid.append(nn.Conv2d(in_channels=s, out_channels=d,
                             kernel_size=1, stride=1, padding=0))
        mid.append(nn.PReLU())
        self.layers = mid
        self.mid_part = nn.Sequential(*mid)
        # Deconvolution back to the target resolution. For any factor f,
        # padding = 5 - (f + 1) // 2 together with output_padding = 1 for
        # even f (0 for odd f) yields an output exactly f times larger.
        self.last_part = nn.ConvTranspose2d(
            in_channels=d, out_channels=num_channels, kernel_size=9,
            stride=upscale_factor, padding=5 - (upscale_factor + 1) // 2,
            output_padding=(upscale_factor + 1) % 2)

    def forward(self, x):
        """Map a low-resolution batch to the high-resolution output."""
        return self.last_part(self.mid_part(self.first_part(x)))

    def weight_init(self, mean=0.0, std=0.02):
        """Gaussian init for conv weights; near-zero init for the deconv."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                module.weight.data.normal_(mean, std)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.ConvTranspose2d):
                module.weight.data.normal_(0.0, 0.0001)
                if module.bias is not None:
                    module.bias.data.zero_()
from __future__ import print_function
from math import log10
import sys
import torch
import torch.backends.cudnn as cudnn
import torchvision
from .model import Net
from utils.progress_bar import progress_bar
class FSRCNNTrainer(object):
    """Training/evaluation driver for the FSRCNN super-resolution model."""

    def __init__(self, config, training_loader, testing_loader, writer=None):
        super(FSRCNNTrainer, self).__init__()
        self.CUDA = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.CUDA else 'cpu')
        self.model = None
        self.lr = config.lr
        self.nEpochs = config.nEpochs
        self.criterion = None
        self.optimizer = None
        self.scheduler = None
        self.seed = config.seed
        self.upscale_factor = config.upscale_factor
        self.training_loader = training_loader
        self.testing_loader = testing_loader
        self.writer = writer  # optional tensorboard SummaryWriter

    def build_model(self):
        """Instantiate model, criterion, optimizer and scheduler; seed RNGs."""
        self.model = Net(
            num_channels=1, upscale_factor=self.upscale_factor).to(self.device)
        self.model.weight_init(mean=0.0, std=0.2)
        self.criterion = torch.nn.MSELoss()
        torch.manual_seed(self.seed)
        if self.CUDA:
            torch.cuda.manual_seed(self.seed)
            cudnn.benchmark = True
            self.criterion.cuda()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[50, 75, 100], gamma=0.5)  # lr decay

    def save_model(self):
        """Serialize the whole model object to a fixed path."""
        model_out_path = "model_path.pth"
        torch.save(self.model, model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))

    def train(self, epoch, iters):
        """Run one training epoch; return the updated global iteration count."""
        self.model.train()
        train_loss = 0
        for batch_num, (_, data, target) in enumerate(self.training_loader):
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            out = self.model(data)
            loss = self.criterion(out, target)
            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
            sys.stdout.write('Epoch %d: ' % epoch)
            progress_bar(batch_num, len(self.training_loader),
                         'Loss: %.4f' % (train_loss / (batch_num + 1)))
            if self.writer:
                self.writer.add_scalar("loss", loss, iters)
                if iters % 100 == 0:
                    # Interleave outputs and targets for visual comparison
                    output_vs_gt = torch.stack([out, target], 1) \
                        .flatten(0, 1).detach()
                    self.writer.add_image(
                        "Output_vs_gt",
                        torchvision.utils.make_grid(output_vs_gt, nrow=2).cpu().numpy(),
                        iters)
            iters += 1
        print(" Average Loss: {:.4f}".format(
            train_loss / len(self.training_loader)))
        return iters

    def test(self):
        """Evaluate average PSNR over the testing loader."""
        self.model.eval()
        avg_psnr = 0
        with torch.no_grad():
            for batch_num, (data, target) in enumerate(self.testing_loader):
                data, target = data.to(self.device), target.to(self.device)
                prediction = self.model(data)
                mse = self.criterion(prediction, target)
                psnr = 10 * log10(1 / mse.item())
                avg_psnr += psnr
                progress_bar(batch_num, len(self.testing_loader),
                             'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))
        print(" Average PSNR: {:.4f} dB".format(
            avg_psnr / len(self.testing_loader)))

    def run(self):
        """Full schedule: nEpochs of train + test with lr decay."""
        self.build_model()
        iters = 0  # global iteration counter shared across epochs
        for epoch in range(1, self.nEpochs + 1):
            print("\n===> Epoch {} starts:".format(epoch))
            # BUG FIX: train() requires (epoch, iters); it was previously
            # called with no arguments, which raised TypeError and also
            # discarded the iteration counter used for tensorboard logging.
            iters = self.train(epoch, iters)
            self.test()
            self.scheduler.step(epoch)
            if epoch == self.nEpochs:
                self.save_model()
## Super Resolution CNN
The authors of the SRCNN describe their network, pointing out the equivalence of their method to the sparse-coding method, which is a widely used learning method for image SR. This is an important and educational aspect of their work, because it shows how example-based learning methods can be adapted and generalized to CNN models.
The SRCNN consists of the following operations:
1. **Preprocessing**: Up-scales LR image to desired HR size.
2. **Feature extraction**: Extracts a set of feature maps from the up-scaled LR image.
3. **Non-linear mapping**: Maps the feature maps representing LR to HR patches.
4. **Reconstruction**: Produces the HR image from HR patches.
Operations 2–4 above can each be cast as a convolutional layer in a CNN that accepts as input the preprocessed image from step 1 above and outputs the HR image.
import torch
import torch.nn as nn
class Net(torch.nn.Module):
    """SRCNN-style network with a sub-pixel (PixelShuffle) upscaling head."""

    def __init__(self, num_channels, base_filter, upscale_factor=2):
        super(Net, self).__init__()
        self.layers = torch.nn.Sequential(
            # Feature extraction
            nn.Conv2d(num_channels, base_filter, kernel_size=9, stride=1,
                      padding=4, bias=True),
            nn.ReLU(inplace=True),
            # Non-linear mapping
            nn.Conv2d(base_filter, base_filter // 2, kernel_size=1, bias=True),
            nn.ReLU(inplace=True),
            # Reconstruction to sub-pixel channels, then shuffle to HR
            nn.Conv2d(base_filter // 2, num_channels * (upscale_factor ** 2),
                      kernel_size=5, stride=1, padding=2, bias=True),
            nn.PixelShuffle(upscale_factor))

    def forward(self, x):
        """Upscale a batch of LR images by the configured factor."""
        return self.layers(x)

    def weight_init(self, mean, std):
        """Apply `normal_init` to every direct child module."""
        for name in self._modules:
            normal_init(self._modules[name], mean, std)
def normal_init(m, mean, std):
    """Initialize conv / transposed-conv weights from N(mean, std^2), zero bias.

    Modules of any other type are left untouched. Guards against a missing
    bias (e.g. Conv2d(..., bias=False)), which previously raised
    AttributeError on `None.data`.
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        m.weight.data.normal_(mean, std)
        if m.bias is not None:
            m.bias.data.zero_()
from __future__ import print_function
from math import log10
import sys
import torch
import torch.backends.cudnn as cudnn
import torchvision
from .model import Net
from utils.progress_bar import progress_bar
class SRCNNTrainer(object):
    """Training/evaluation driver for the SRCNN super-resolution model."""

    def __init__(self, config, training_loader, testing_loader, writer=None):
        super(SRCNNTrainer, self).__init__()
        self.CUDA = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.CUDA else 'cpu')
        self.model = None
        self.lr = config.lr
        self.nEpochs = config.nEpochs
        self.criterion = None
        self.optimizer = None
        self.scheduler = None
        self.seed = config.seed
        self.upscale_factor = config.upscale_factor
        self.training_loader = training_loader
        self.testing_loader = testing_loader
        self.writer = writer  # optional tensorboard SummaryWriter

    def build_model(self, num_channels=1):
        """Create model/criterion/optimizer/scheduler and seed the RNGs.

        `num_channels` now defaults to 1 so run(), which passes no argument,
        keeps working (it previously raised TypeError).
        """
        self.model = Net(num_channels=num_channels, base_filter=64,
                         upscale_factor=self.upscale_factor).to(self.device)
        self.model.weight_init(mean=0.0, std=0.01)
        self.criterion = torch.nn.MSELoss()
        torch.manual_seed(self.seed)
        if self.CUDA:
            torch.cuda.manual_seed(self.seed)
            cudnn.benchmark = True
            self.criterion.cuda()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[50, 75, 100], gamma=0.5)  # lr decay

    def save_model(self):
        """Serialize the whole model object to a fixed path."""
        model_out_path = "model_path.pth"
        torch.save(self.model, model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))

    def train(self, epoch, iters, channels=None):
        """Run one epoch; return the updated global iteration counter.

        `channels` optionally selects a channel slice (e.g. the Y channel
        of YCbCr input).
        """
        self.model.train()
        train_loss = 0
        for batch_num, (_, data, target) in enumerate(self.training_loader):
            if channels:
                data = data[..., channels, :, :]
                target = target[..., channels, :, :]
            data = data.to(self.device)
            target = target.to(self.device)
            self.optimizer.zero_grad()
            out = self.model(data)
            loss = self.criterion(out, target)
            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
            sys.stdout.write('Epoch %d: ' % epoch)
            progress_bar(batch_num, len(self.training_loader),
                         'Loss: %.4f' % (train_loss / (batch_num + 1)))
            if self.writer:
                self.writer.add_scalar("loss", loss, iters)
                if iters % 100 == 0:
                    # Interleave outputs and targets for visual comparison
                    output_vs_gt = torch.stack([out, target], 1) \
                        .flatten(0, 1).detach()
                    self.writer.add_image(
                        "Output_vs_gt",
                        torchvision.utils.make_grid(output_vs_gt, nrow=2).cpu().numpy(),
                        iters)
            iters += 1
        print(" Average Loss: {:.4f}".format(train_loss / len(self.training_loader)))
        return iters

    def test(self):
        """Evaluate average PSNR over the testing loader."""
        self.model.eval()
        avg_psnr = 0
        with torch.no_grad():
            for batch_num, (data, target) in enumerate(self.testing_loader):
                data, target = data.to(self.device), target.to(self.device)
                prediction = self.model(data)
                mse = self.criterion(prediction, target)
                psnr = 10 * log10(1 / mse.item())
                avg_psnr += psnr
                progress_bar(batch_num, len(self.testing_loader),
                             'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))
        print(" Average PSNR: {:.4f} dB".format(avg_psnr / len(self.testing_loader)))

    def run(self):
        """Full schedule: nEpochs of train + test with lr decay."""
        self.build_model()
        iters = 0  # global iteration counter shared across epochs
        for epoch in range(1, self.nEpochs + 1):
            print("\n===> Epoch {} starts:".format(epoch))
            # BUG FIX: train() requires (epoch, iters); it was previously
            # called with no arguments, which raised TypeError.
            iters = self.train(epoch, iters)
            self.test()
            self.scheduler.step(epoch)
            if epoch == self.nEpochs:
                self.save_model()
# SRGAN: Super-Resolution using GANs
This is a complete Pytorch implementation of [Christian Ledig et al: "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"](https://arxiv.org/abs/1609.04802),
reproducing their results.
This paper's main result is that through using an adversarial and a content loss, a convolutional neural network is able to produce sharp, almost photo-realistic upsamplings of images.
The implementation tries to be as faithful as possible to the original paper.
See [implementation details](#method-and-implementation-details) for a closer look.
## Method and Implementation Details
Architecture diagram of the super-resolution and discriminator networks by Ledig et al:
<p align='center'>
<img src='https://github.com/mseitzer/srgan/blob/master/images/architecture.png' width=580>
</p>
The implementation tries to stay as close as possible to the details given in the paper.
As such, the pretrained SRGAN is also trained with 1e6 and 1e5 update steps.
The high amount of update steps proved to be essential for performance, which pretty much monotonically increases with training time.
Some further implementation choices for which the paper does not give any details:
- Initialization: orthogonal for the super-resolution network, randomly from a normal distribution with std=0.02 for the discriminator network
- Padding: reflection padding (instead of the more commonly used zero padding)
## Batch-size
A batch size of 2 is recommended if the GPU has only 8 GB of RAM.
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
def swish(x):
    """Swish activation: elementwise x * sigmoid(x)."""
    gate = torch.sigmoid(x)
    return gate * x
class ResidualBlock(nn.Module):
    """Two conv+BN layers with a swish in between and an identity skip."""

    def __init__(self, in_channels, kernel, out_channels, stride):
        super(ResidualBlock, self).__init__()
        same_pad = kernel // 2  # "same" padding for odd kernels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel,
                               stride=stride, padding=same_pad)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel,
                               stride=stride, padding=same_pad)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        hidden = swish(self.bn1(self.conv1(x)))
        return x + self.bn2(self.conv2(hidden))
class UpsampleBlock(nn.Module):
    """2x upsampling via conv to 4x channels followed by PixelShuffle."""

    def __init__(self, in_channels):
        super(UpsampleBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels * 4,
                              kernel_size=3, stride=1, padding=1)
        self.shuffler = nn.PixelShuffle(2)

    def forward(self, x):
        shuffled = self.shuffler(self.conv(x))
        return swish(shuffled)
class Generator(nn.Module):
    """SRGAN generator: residual trunk with skip, then pixel-shuffle upsampling."""

    def __init__(self, n_residual_blocks, upsample_factor, num_channel=1, base_filter=64):
        super(Generator, self).__init__()
        self.n_residual_blocks = n_residual_blocks
        self.upsample_factor = upsample_factor
        self.conv1 = nn.Conv2d(num_channel, base_filter, kernel_size=9, stride=1, padding=4)
        # Residual trunk, registered as residual_block1..N
        for i in range(1, self.n_residual_blocks + 1):
            self.add_module(f'residual_block{i}',
                            ResidualBlock(in_channels=base_filter, out_channels=base_filter,
                                          kernel=3, stride=1))
        self.conv2 = nn.Conv2d(base_filter, base_filter, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(base_filter)
        # One 2x upsampling stage per factor of two
        for i in range(1, self.upsample_factor // 2 + 1):
            self.add_module(f'upsample{i}', UpsampleBlock(base_filter))
        self.conv3 = nn.Conv2d(base_filter, num_channel, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        x = swish(self.conv1(x))
        y = x.clone()
        for i in range(1, self.n_residual_blocks + 1):
            y = getattr(self, f'residual_block{i}')(y)
        # Global skip connection around the residual trunk
        x = self.bn2(self.conv2(y)) + x
        for i in range(1, self.upsample_factor // 2 + 1):
            x = getattr(self, f'upsample{i}')(x)
        return self.conv3(x)

    def weight_init(self, mean=0.0, std=0.02):
        """Apply `normal_init` to every direct child module."""
        for name in self._modules:
            normal_init(self._modules[name], mean, std)
class Discriminator(nn.Module):
    """SRGAN discriminator: 8 conv stages, then a 1x1-conv FCN head."""

    def __init__(self, num_channel=1, base_filter=64):
        super(Discriminator, self).__init__()
        f = base_filter
        self.conv1 = nn.Conv2d(num_channel, f, kernel_size=3, stride=1, padding=1)
        # conv2..conv8 (each with BN): alternate stride 2 / stride 1 and
        # double the channel count every other layer
        specs = [(f, f, 2), (f, f * 2, 1), (f * 2, f * 2, 2), (f * 2, f * 4, 1),
                 (f * 4, f * 4, 2), (f * 4, f * 8, 1), (f * 8, f * 8, 2)]
        for n, (cin, cout, stride) in enumerate(specs, start=2):
            setattr(self, f'conv{n}',
                    nn.Conv2d(cin, cout, kernel_size=3, stride=stride, padding=1))
            setattr(self, f'bn{n}', nn.BatchNorm2d(cout))
        # Replaced original paper FC layers with FCN
        self.conv9 = nn.Conv2d(f * 8, num_channel, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = swish(self.conv1(x))
        for n in range(2, 9):
            x = swish(getattr(self, f'bn{n}')(getattr(self, f'conv{n}')(x)))
        x = self.conv9(x)
        # Global average pool, sigmoid, and flatten to (batch, -1)
        pooled = F.avg_pool2d(x, x.size()[2:])
        return torch.sigmoid(pooled).view(x.size()[0], -1)

    def weight_init(self, mean=0.0, std=0.02):
        """Apply `normal_init` to every direct child module."""
        for name in self._modules:
            normal_init(self._modules[name], mean, std)
def normal_init(m, mean, std):
    """Initialize conv / transposed-conv weights from N(mean, std^2), zero bias.

    Modules of any other type are left untouched. Guards against a missing
    bias (e.g. Conv2d(..., bias=False)), which previously raised
    AttributeError on `None.data`.
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        m.weight.data.normal_(mean, std)
        if m.bias is not None:
            m.bias.data.zero_()
from __future__ import print_function
from math import log10
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
from torchvision.models.vgg import vgg16
from .model import Generator, Discriminator
from utils.progress_bar import progress_bar
class SRGANTrainer(object):
    """Training/evaluation driver for SRGAN (generator + discriminator)."""

    def __init__(self, config, training_loader, testing_loader, writer):
        super(SRGANTrainer, self).__init__()
        self.GPU_IN_USE = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.GPU_IN_USE else 'cpu')
        self.netG = None
        self.netD = None
        self.lr = config.lr
        self.nEpochs = config.nEpochs
        self.epoch_pretrain = 10  # MSE-only generator pretraining epochs
        self.criterionG = None
        self.criterionD = None
        self.optimizerG = None
        self.optimizerD = None
        self.feature_extractor = None
        self.schedulerG = None
        self.schedulerD = None
        self.scheduler = None  # backward-compatible alias (= schedulerD)
        self.seed = config.seed
        self.upscale_factor = config.upscale_factor
        self.num_residuals = 16
        self.training_loader = training_loader
        self.testing_loader = testing_loader
        self.writer = writer

    def build_model(self):
        """Create G/D nets, losses, optimizers and lr schedulers; seed RNGs."""
        self.netG = Generator(n_residual_blocks=self.num_residuals, upsample_factor=self.upscale_factor, base_filter=64, num_channel=1).to(self.device)
        self.netD = Discriminator(base_filter=64, num_channel=1).to(self.device)
        # NOTE(review): loaded but never used by the losses below — confirm
        # whether a VGG content loss was intended.
        self.feature_extractor = vgg16(pretrained=True)
        self.netG.weight_init(mean=0.0, std=0.2)
        self.netD.weight_init(mean=0.0, std=0.2)
        self.criterionG = nn.MSELoss()
        self.criterionD = nn.BCELoss()
        torch.manual_seed(self.seed)
        if self.GPU_IN_USE:
            torch.cuda.manual_seed(self.seed)
            self.feature_extractor.cuda()
            cudnn.benchmark = True
            self.criterionG.cuda()
            self.criterionD.cuda()
        self.optimizerG = optim.Adam(self.netG.parameters(), lr=self.lr, betas=(0.9, 0.999))
        self.optimizerD = optim.SGD(self.netD.parameters(), lr=self.lr / 100, momentum=0.9, nesterov=True)
        # BUG FIX: both schedulers were previously assigned to the same
        # attribute, so the generator's scheduler was overwritten and its
        # learning rate never decayed. Keep them separate and step both.
        self.schedulerG = optim.lr_scheduler.MultiStepLR(self.optimizerG, milestones=[50, 75, 100], gamma=0.5)  # lr decay
        self.schedulerD = optim.lr_scheduler.MultiStepLR(self.optimizerD, milestones=[50, 75, 100], gamma=0.5)  # lr decay
        self.scheduler = self.schedulerD  # backward-compatible alias

    @staticmethod
    def to_data(x):
        """Move a tensor back to CPU (if on CUDA) and return its data."""
        if torch.cuda.is_available():
            x = x.cpu()
        return x.data

    def save(self):
        """Serialize both networks to fixed paths."""
        g_model_out_path = "SRGAN_Generator_model_path.pth"
        d_model_out_path = "SRGAN_Discriminator_model_path.pth"
        torch.save(self.netG, g_model_out_path)
        torch.save(self.netD, d_model_out_path)
        print("Checkpoint saved to {}".format(g_model_out_path))
        print("Checkpoint saved to {}".format(d_model_out_path))

    def pretrain(self):
        """One epoch of MSE-only generator pretraining."""
        self.netG.train()
        for batch_num, (_, data, target) in enumerate(self.training_loader):
            data, target = data.to(self.device), target.to(self.device)
            self.netG.zero_grad()
            loss = self.criterionG(self.netG(data), target)
            loss.backward()
            self.optimizerG.step()

    def train(self, epoch, iters):
        """One adversarial epoch (D step then G step per batch).

        Returns the updated global iteration counter.
        """
        self.netG.train()
        self.netD.train()
        g_train_loss = 0
        d_train_loss = 0
        for batch_num, (_, data, target) in enumerate(self.training_loader):
            # Labels for the adversarial loss
            real_label = torch.ones(data.size(0), data.size(1)).to(self.device)
            fake_label = torch.zeros(data.size(0), data.size(1)).to(self.device)
            data, target = data.to(self.device), target.to(self.device)

            # Train Discriminator: real towards 1, generated towards 0
            self.optimizerD.zero_grad()
            d_real = self.netD(target)
            d_real_loss = self.criterionD(d_real, real_label)
            d_fake = self.netD(self.netG(data))
            d_fake_loss = self.criterionD(d_fake, fake_label)
            d_total = d_real_loss + d_fake_loss
            d_train_loss += d_total.item()
            d_total.backward()
            self.optimizerD.step()

            # Train Generator: MSE content loss + weighted adversarial loss
            self.optimizerG.zero_grad()
            g_real = self.netG(data)
            g_fake = self.netD(g_real)
            gan_loss = self.criterionD(g_fake, real_label)
            mse_loss = self.criterionG(g_real, target)
            g_total = mse_loss + 1e-3 * gan_loss
            g_train_loss += g_total.item()
            g_total.backward()
            self.optimizerG.step()

            sys.stdout.write('Epoch %d: ' % epoch)
            progress_bar(batch_num, len(self.training_loader), 'G_Loss: %.4f | D_Loss: %.4f' % (g_train_loss / (batch_num + 1), d_train_loss / (batch_num + 1)))
            if self.writer:
                self.writer.add_scalar("G_Loss", g_train_loss / (batch_num + 1), iters)
                self.writer.add_scalar("D_Loss", d_train_loss / (batch_num + 1), iters)
                if iters % 100 == 0:
                    # Interleave outputs and targets for visual comparison
                    output_vs_gt = torch.stack([g_real, target], 1) \
                        .flatten(0, 1).detach()
                    self.writer.add_image(
                        "Output_vs_gt",
                        torchvision.utils.make_grid(output_vs_gt, nrow=2).cpu().numpy(),
                        iters)
            iters += 1
        print(" Average G_Loss: {:.4f}".format(g_train_loss / len(self.training_loader)))
        return iters

    def test(self):
        """Evaluate average generator PSNR over the testing loader."""
        self.netG.eval()
        avg_psnr = 0
        with torch.no_grad():
            for batch_num, (data, target) in enumerate(self.testing_loader):
                data, target = data.to(self.device), target.to(self.device)
                prediction = self.netG(data)
                mse = self.criterionG(prediction, target)
                psnr = 10 * log10(1 / mse.item())
                avg_psnr += psnr
                progress_bar(batch_num, len(self.testing_loader), 'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))
        print(" Average PSNR: {:.4f} dB".format(avg_psnr / len(self.testing_loader)))

    def run(self):
        """Pretrain G with MSE, then adversarial training for nEpochs."""
        self.build_model()
        for epoch in range(1, self.epoch_pretrain + 1):
            self.pretrain()
            print("{}/{} pretrained".format(epoch, self.epoch_pretrain))
        iters = 0  # global iteration counter for tensorboard logging
        for epoch in range(1, self.nEpochs + 1):
            print("\n===> Epoch {} starts:".format(epoch))
            # BUG FIX: train() requires (epoch, iters); it was called with no
            # arguments. Step both schedulers so G and D lr both decay.
            iters = self.train(epoch, iters)
            self.test()
            self.schedulerG.step(epoch)
            self.schedulerD.step(epoch)
            if epoch == self.nEpochs:
                self.save()
import torch.nn as nn
import torch.nn.init as init
class Net(nn.Module):
    """ESPCN-style sub-pixel network: three relu convs, one conv to
    upscale_factor^2 channels, then PixelShuffle to the HR resolution.
    Expects single-channel input."""

    def __init__(self, upscale_factor):
        super(Net, self).__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, kernel_size=3, stride=1, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        self._initialize_weights()

    def _initialize_weights(self):
        """Orthogonal init; relu gain for the hidden convs."""
        relu_gain = init.calculate_gain('relu')
        for conv in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(conv.weight, relu_gain)
        init.orthogonal_(self.conv4.weight)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        hidden = self.relu(self.conv2(hidden))
        hidden = self.relu(self.conv3(hidden))
        return self.pixel_shuffle(self.conv4(hidden))
from __future__ import print_function
from math import log10
import sys
import torch
import torch.backends.cudnn as cudnn
import torchvision
from .model import Net
from utils.progress_bar import progress_bar
class SubPixelTrainer(object):
    """Training/evaluation driver for the sub-pixel (ESPCN) model."""

    def __init__(self, config, training_loader, testing_loader, writer=None):
        super(SubPixelTrainer, self).__init__()
        self.CUDA = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.CUDA else 'cpu')
        self.model = None
        self.lr = config.lr
        self.nEpochs = config.nEpochs
        self.criterion = None
        self.optimizer = None
        self.scheduler = None
        self.seed = config.seed
        self.upscale_factor = config.upscale_factor
        self.training_loader = training_loader
        self.testing_loader = testing_loader
        self.writer = writer  # optional tensorboard SummaryWriter

    def build_model(self, num_channels=1):
        """Create model/criterion/optimizer/scheduler and seed the RNGs.

        `num_channels` now defaults to 1 (the only supported value) so
        run(), which passes no argument, keeps working.
        """
        if num_channels != 1:
            raise ValueError('num_channels must be 1')
        self.model = Net(upscale_factor=self.upscale_factor).to(self.device)
        self.criterion = torch.nn.MSELoss()
        torch.manual_seed(self.seed)
        if self.CUDA:
            torch.cuda.manual_seed(self.seed)
            cudnn.benchmark = True
            self.criterion.cuda()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[50, 75, 100], gamma=0.5)  # lr decay

    def save(self):
        """Serialize the whole model object to a fixed path."""
        model_out_path = "model_path.pth"
        torch.save(self.model, model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))

    def train(self, epoch, iters, channels=None):
        """Run one epoch; return the updated global iteration counter.

        `channels` optionally selects a channel slice (e.g. Y of YCbCr).
        """
        self.model.train()
        train_loss = 0
        for batch_num, (_, data, target) in enumerate(self.training_loader):
            if channels:
                data = data[..., channels, :, :]
                target = target[..., channels, :, :]
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            out = self.model(data)
            loss = self.criterion(out, target)
            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
            sys.stdout.write('Epoch %d: ' % epoch)
            progress_bar(batch_num, len(self.training_loader),
                         'Loss: %.4f' % (train_loss / (batch_num + 1)))
            if self.writer:
                self.writer.add_scalar("loss", loss, iters)
                if iters % 100 == 0:
                    # Interleave outputs and targets for visual comparison
                    output_vs_gt = torch.stack([out, target], 1) \
                        .flatten(0, 1).detach()
                    self.writer.add_image(
                        "Output_vs_gt",
                        torchvision.utils.make_grid(
                            output_vs_gt, nrow=2).cpu().numpy(),
                        iters)
            iters += 1
        print(" Average Loss: {:.4f}".format(
            train_loss / len(self.training_loader)))
        return iters

    def test(self):
        """Evaluate average PSNR over the testing loader."""
        self.model.eval()
        avg_psnr = 0
        with torch.no_grad():
            for batch_num, (data, target) in enumerate(self.testing_loader):
                data, target = data.to(self.device), target.to(self.device)
                prediction = self.model(data)
                mse = self.criterion(prediction, target)
                psnr = 10 * log10(1 / mse.item())
                avg_psnr += psnr
                progress_bar(batch_num, len(self.testing_loader),
                             'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))
        print(" Average PSNR: {:.4f} dB".format(
            avg_psnr / len(self.testing_loader)))

    def run(self):
        """Full schedule: nEpochs of train + test with lr decay."""
        self.build_model()
        iters = 0  # global iteration counter shared across epochs
        for epoch in range(1, self.nEpochs + 1):
            print("\n===> Epoch {} starts:".format(epoch))
            # BUG FIX: train() requires (epoch, iters); it was previously
            # called with no arguments, which raised TypeError.
            iters = self.train(epoch, iters)
            self.test()
            self.scheduler.step(epoch)
            if epoch == self.nEpochs:
                self.save()
from __future__ import print_function
import argparse
import os
import sys
import torch
import torch.nn.functional as nn_f
from tensorboardX.writer import SummaryWriter
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
# ===========================================================
# Training settings
# ===========================================================
parser = argparse.ArgumentParser(description='PyTorch Super Res Example')
# hyper-parameters
parser.add_argument('--device', type=int, default=3,
                    help='Which CUDA device to use.')
parser.add_argument('--batchSize', type=int, default=1,
                    help='training batch size')
parser.add_argument('--testBatchSize', type=int,
                    default=1, help='testing batch size')
parser.add_argument('--nEpochs', type=int, default=20,
                    help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.01,
                    help='Learning Rate. Default=0.01')
parser.add_argument('--seed', type=int, default=123,
                    help='random seed to use. Default=123')
parser.add_argument('--dataset', type=str, required=True,
                    help='dataset directory')
parser.add_argument('--test', type=str, help='path of model to test')
parser.add_argument('--testOutPatt', type=str, help='test output path pattern')
parser.add_argument('--color', type=str, default='rgb',
                    help='color')
# model configuration
parser.add_argument('--upscale_factor', '-uf', type=int,
                    default=2, help="super resolution upscale factor")
#parser.add_argument('--model', '-m', type=str, default='srgan', help='choose which model is going to use')
args = parser.parse_args()

# Select device
torch.cuda.set_device(args.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())

# NOTE(review): these project imports are deliberately placed after the CUDA
# device is selected — presumably because they may query the current device
# at import time; confirm before reordering.
from utils import misc
from utils import netio
from utils import img
from utils import color
#from .upsampling.SubPixelCNN.solver import SubPixelTrainer as Solver
from upsampling.SRCNN.solver import SRCNNTrainer as Solver
from upsampling.upsampling_dataset import UpsamplingDataset
from data.loader import FastDataLoader

# All dataset-relative paths below resolve against the dataset directory
os.chdir(args.dataset)
print('Change working directory to ' + os.getcwd())
run_dir = 'run/'
# Convert the color-space name string (e.g. 'rgb') to its internal value
args.color = color.from_str(args.color)
def train():
    """Train the SR solver on paired LR/HR images under the dataset directory.

    Writes tensorboard logs to `run_dir` and saves the final model there.
    For YCbCr input, only the Y channel (index 2) is trained on.
    """
    os.makedirs(run_dir, exist_ok=True)
    # BUG FIX: UpsamplingDataset's color parameter is named `c`, not `color`;
    # passing `color=` raised TypeError.
    train_set = UpsamplingDataset('.', 'input/out_view_%04d.png',
                                  'gt/view_%04d.png', c=args.color)
    training_data_loader = FastDataLoader(dataset=train_set,
                                          batch_size=args.batchSize,
                                          shuffle=True,
                                          drop_last=False)
    trainer = Solver(args, training_data_loader, training_data_loader,
                     SummaryWriter(run_dir))
    trainer.build_model(3 if args.color == color.RGB else 1)
    iters = 0  # global iteration counter for tensorboard logging
    for epoch in range(1, args.nEpochs + 1):
        print("\n===> Epoch {} starts:".format(epoch))
        iters = trainer.train(epoch, iters,
                              channels=slice(2, 3) if args.color == color.YCbCr
                              else None)
    netio.save(run_dir + 'model-epoch_%d.pth' % args.nEpochs, trainer.model)
def test():
    """Upsample every input view with the trained model and save the results.

    For YCbCr input, only the Y (last) channel goes through the network;
    Cb/Cr are upsampled directly before converting back to RGB.
    """
    os.makedirs(os.path.dirname(args.testOutPatt), exist_ok=True)
    # BUG FIX: UpsamplingDataset's color parameter is named `c`, not `color`
    train_set = UpsamplingDataset(
        '.', 'input/out_view_%04d.png', None, c=args.color)
    training_data_loader = FastDataLoader(dataset=train_set,
                                          batch_size=args.testBatchSize,
                                          shuffle=False,
                                          drop_last=False)
    trainer = Solver(args, training_data_loader, training_data_loader,
                     SummaryWriter(run_dir))
    trainer.build_model(3 if args.color == color.RGB else 1)
    netio.load(args.test, trainer.model)
    # Inference only: disable autograd to cut memory and compute
    with torch.no_grad():
        for idx, input, _ in training_data_loader:
            if args.color == color.YCbCr:
                output_y = trainer.model(input[:, -1:])
                # interpolate() replaces the deprecated nn_f.upsample with
                # identical default behavior (mode='nearest')
                output_cbcr = nn_f.interpolate(input[:, 0:2], scale_factor=2)
                output = color.ycbcr2rgb(torch.cat([output_cbcr, output_y], -3))
            else:
                output = trainer.model(input)
            img.save(output, args.testOutPatt % idx)
def main():
    """Entry point: run inference when --test is given, otherwise train."""
    if args.test:
        test()
    else:
        train()


if __name__ == '__main__':
    main()
import os
import torch
import torchvision.transforms.functional as trans_f
from utils import device
from utils import color
from utils import img
class UpsamplingDataset(torch.utils.data.dataset.Dataset):
    """
    Dataset for the upsampling task: low-resolution input images paired with
    optional high-resolution ground-truth images, addressed by '%d' filename
    patterns. Yields (index, input, gt) tuples; gt is False when no ground
    truth pattern was given.
    """

    def __init__(self, data_dir: str, input_patt: str, gt_patt: str,
                 c: int = None, load_once: bool = True, *, color: int = None):
        """
        Initialize dataset for upsampling task

        :param data_dir: directory of dataset
        :param input_patt: file pattern for input (low resolution) images
        :param gt_patt: file pattern for ground truth (high resolution) images,
                        or None to skip loading ground truth
        :param c: color space constant (see utils.color)
        :param load_once: load all samples to current device at once to accelerate
                          training, suitable for small dataset
        :param color: keyword alias for ``c`` — FIX: callers in this repo pass
                      ``color=...`` while the parameter was named ``c``, which
                      raised TypeError; both spellings are now accepted
        """
        # Local alias for the utils.color module, since the name `color` is
        # taken by the keyword parameter in this scope.
        from utils import color as color_space
        if c is None:
            c = color
        if c is None:
            raise TypeError("missing color space: pass `c` or `color`")
        self.input_patt = os.path.join(data_dir, input_patt)
        self.gt_patt = os.path.join(data_dir, gt_patt) if gt_patt is not None else None
        # Count the input files matching the pattern that actually exist.
        self.n = len(list(filter(
            lambda file_name: os.path.exists(file_name),
            [self.input_patt % i for i in range(
                len(os.listdir(os.path.dirname(self.input_patt))))]
        )))
        self.load_once = load_once
        self.load_gt = self.gt_patt is not None
        self.color = c
        # Eagerly load every image to the default device when load_once is set.
        self.input = img.load([self.input_patt % i for i in range(self.n)]) \
            .to(device.default()) if self.load_once else None
        self.gt = img.load([self.gt_patt % i for i in range(self.n)]) \
            .to(device.default()) if self.load_once and self.load_gt else None
        # Convert preloaded tensors to the requested color space up front.
        if self.color == color_space.GRAY:
            self.input = trans_f.rgb_to_grayscale(self.input)
            self.gt = trans_f.rgb_to_grayscale(self.gt) \
                if self.gt is not None else None
        elif self.color == color_space.YCbCr:
            self.input = color_space.rgb2ycbcr(self.input)
            self.gt = color_space.rgb2ycbcr(self.gt) if self.gt is not None else None

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        # Fast path: everything is preloaded and already color-converted.
        if self.load_once:
            return idx, self.input[idx], self.gt[idx] if self.load_gt else False
        # Lazy path: load from disk per request (idx may be a tensor of indices).
        if isinstance(idx, torch.Tensor):
            input = img.load([self.input_patt % i for i in idx])
            gt = img.load([self.gt_patt % i for i in idx]) if self.load_gt else False
        else:
            input = img.load([self.input_patt % idx])
            gt = img.load([self.gt_patt % idx]) if self.load_gt else False
        if self.color == color.GRAY:
            input = trans_f.rgb_to_grayscale(input)
            gt = trans_f.rgb_to_grayscale(gt) if isinstance(gt, torch.Tensor) else False
        elif self.color == color.YCbCr:
            input = color.rgb2ycbcr(input)
            gt = color.rgb2ycbcr(gt) if isinstance(gt, torch.Tensor) else False
        # FIX: the original returned inside the GRAY/YCbCr branches only, so
        # the RGB case fell through and returned None; return uniformly here.
        return idx, input, gt
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment