import csv
import logging
import sys
import time
import torch
import torch.nn.functional as nn_f
from typing import Dict, Optional
from pathlib import Path

import loss
from utils.constants import HUGE_FLOAT
from utils.misc import format_time
from utils.progress_bar import progress_bar
from utils.perf import Perf, checkpoint, enable_perf, perf, get_perf_result
from data.loader import DataLoader
from model.base import BaseModel
from model import save

# Registry of every class created with BaseTrainMeta, keyed by class name.
# Populated automatically by the metaclass below.
train_classes = {}


class BaseTrainMeta(type):
    """Metaclass that auto-registers each new trainer class in `train_classes`."""

    def __new__(cls, name, bases, attrs):
        new_cls = type.__new__(cls, name, bases, attrs)
        train_classes[name] = new_cls
        return new_cls


class Train(object, metaclass=BaseTrainMeta):
    """Epoch-based training driver for a BaseModel.

    Responsibilities: iterate a DataLoader, compute losses, step an Adam
    optimizer, display a progress bar, checkpoint each epoch, and optionally
    run a short performance-measurement pass (`perf_frames`).
    """

    @property
    def perf_mode(self) -> bool:
        # Performance-measurement mode is active when a positive frame budget
        # was supplied at construction time.
        return self.perf_frames > 0

    def __init__(self, model: BaseModel, *, run_dir: Path, states: dict = None,
                 perf_frames: int = 0) -> None:
        """Initialize the trainer and attach it to the model.

        :param model: model to train; the trainer registers itself as
            ``model.trainer`` and switches the model to train mode
        :param run_dir: directory where checkpoints and debug dumps are written
        :param states: optional resume-state dict; recognized keys are
            'epoch', 'iters' and 'opti' (optimizer state dict)
        :param perf_frames: if > 0, run only this many iterations, print a
            performance report, and exit (see `_train_epoch`)
        """
        super().__init__()
        self.model = model
        self.epoch = 0
        self.iters = 0
        self.run_dir = run_dir
        self.model.trainer = self
        self.model.train()
        self.reset_optimizer()
        if states:
            if 'epoch' in states:
                self.epoch = states['epoch']
            if 'iters' in states:
                self.iters = states['iters']
            if 'opti' in states:
                self.optimizer.load_state_dict(states['opti'])

        # For performance measurement
        self.perf_frames = perf_frames
        if self.perf_mode:
            enable_perf()

    def reset_optimizer(self):
        """(Re)create the Adam optimizer over the model's current parameters."""
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)

    def train(self, data_loader: DataLoader, max_epochs: int):
        """Train until `max_epochs` epochs have completed, checkpointing each epoch.

        Resumes from ``self.epoch`` when the trainer was constructed with a
        resume state.
        """
        self.data_loader = data_loader
        # In perf mode the epoch length is capped at `perf_frames` iterations.
        self.iters_per_epoch = self.perf_frames or len(data_loader)

        print("Begin training...")
        while self.epoch < max_epochs:
            self.epoch += 1
            self._train_epoch()
            self._save_checkpoint()
        print("Train finished")

    def _save_checkpoint(self):
        """Write the current epoch's checkpoint and prune older ones.

        Only every 10th epoch's checkpoint is retained; all other earlier
        checkpoints are removed (missing files are ignored).
        """
        save(self.run_dir / f'checkpoint_{self.epoch}.tar', self.model,
             epoch=self.epoch, iters=self.iters, opti=self.optimizer.state_dict())
        for i in range(1, self.epoch):
            if i % 10 != 0:
                (self.run_dir / f'checkpoint_{i}.tar').unlink(missing_ok=True)

    def _show_progress(self, iters_in_epoch: int,
                       loss: Optional[Dict[str, float]] = None):
        """Render the progress bar with current/min/avg/max loss figures.

        :param loss: stats dict with optional 'val', 'min', 'max', 'avg' keys;
            missing keys default to 0.
        """
        # FIX: the original used a mutable default argument (`loss={}`);
        # a None sentinel avoids sharing one dict across calls.
        if loss is None:
            loss = {}
        loss_val = loss.get('val', 0)
        loss_min = loss.get('min', 0)
        loss_max = loss.get('max', 0)
        loss_avg = loss.get('avg', 0)
        iters_per_epoch = self.perf_frames or len(self.data_loader)
        progress_bar(iters_in_epoch, iters_per_epoch,
                     f"Loss: {loss_val:.2e} ({loss_min:.2e}/{loss_avg:.2e}/{loss_max:.2e})",
                     f"Epoch {self.epoch:<3d}",
                     f" {self.run_dir}")

    def _show_perf(self):
        """Print the hierarchical performance report collected by utils.perf."""
        s = "Performance Report ==>\n"
        res = get_perf_result()
        if res is None:
            s += "No available data.\n"
        else:
            for key, val in res.items():
                # Keys are '/'-separated paths; indent by nesting depth.
                path_segs = key.split("/")
                s += " " * (len(path_segs) - 1) + f"{path_segs[-1]}: {val:.1f}ms\n"
        print(s)

    @perf
    def _train_iter(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                    extra: Dict[str, torch.Tensor]) -> float:
        """Run one optimization step on a batch of rays and return the loss value.

        :param rays_o: ray origins
        :param rays_d: ray directions
        :param extra: per-ray supervision tensors; must contain 'color'
        :return: scalar loss value of this iteration
        """
        out = self.model(rays_o, rays_d, extra_outputs=['energies', 'speculars'])
        if 'rays_mask' in out:
            # Model dropped some rays; restrict supervision to the kept ones.
            extra = {key: value[out['rays_mask']] for key, value in extra.items()}
        checkpoint("Forward")

        self.optimizer.zero_grad()
        loss_val = loss.mse_loss(out['color'], extra['color'])
        # Optional regularizers, enabled via model args.
        if self.model.args.get('density_regularization_weight'):
            loss_val += loss.cauchy_loss(out['energies'],
                                         s=self.model.args['density_regularization_scale']) \
                * self.model.args['density_regularization_weight']
        if self.model.args.get('specular_regularization_weight'):
            loss_val += loss.cauchy_loss(out['speculars'],
                                         s=self.model.args['specular_regularization_scale']) \
                * self.model.args['specular_regularization_weight']
        checkpoint("Compute loss")

        loss_val.backward()
        checkpoint("Backward")

        self.optimizer.step()
        checkpoint("Update")
        return loss_val.item()

    def _train_epoch(self):
        """Train over one full pass of the data loader, tracking loss statistics.

        In perf mode, stops after `perf_frames` iterations, prints the perf
        report, and exits the process.
        """
        iters_in_epoch = 0
        loss_min = HUGE_FLOAT
        loss_max = 0
        loss_avg = 0
        train_epoch_node = Perf.Node("Train Epoch")

        self._show_progress(iters_in_epoch, loss={'val': 0, 'min': 0, 'max': 0, 'avg': 0})
        for idx, rays_o, rays_d, extra in self.data_loader:
            loss_val = self._train_iter(rays_o, rays_d, extra)
            loss_min = min(loss_min, loss_val)
            loss_max = max(loss_max, loss_val)
            # Running mean over the iterations seen so far this epoch.
            loss_avg = (loss_avg * iters_in_epoch + loss_val) / (iters_in_epoch + 1)
            self.iters += 1
            iters_in_epoch += 1
            self._show_progress(iters_in_epoch, loss={
                'val': loss_val, 'min': loss_min, 'max': loss_max, 'avg': loss_avg
            })
            if self.perf_mode and iters_in_epoch >= self.perf_frames:
                self._show_perf()
                # FIX: was bare `exit()` — that helper comes from the `site`
                # module and is not guaranteed in all runtimes; sys.exit()
                # raises the same SystemExit reliably.
                sys.exit()

        train_epoch_node.close()
        torch.cuda.synchronize()
        epoch_dur = train_epoch_node.duration() / 1000
        logging.info(f"Epoch {self.epoch} spent {format_time(epoch_dur)} "
                     f"(Avg. {format_time(epoch_dur / self.iters_per_epoch)}/iter). "
                     f"Loss is {loss_min:.2e}/{loss_avg:.2e}/{loss_max:.2e}")

    def _train_epoch_debug(self):  # TBR
        """Debug-only epoch pass: no optimization, dumps per-ray data to CSV.

        NOTE(review): marked TBR (to be removed); contains unreachable code
        after the `return` below.
        """
        iters_in_epoch = 0
        loss_min = HUGE_FLOAT
        loss_max = 0
        loss_avg = 0

        self._show_progress(iters_in_epoch, loss={'val': 0, 'min': 0, 'max': 0, 'avg': 0})
        indices = []
        debug_data = []
        for idx, rays_o, rays_d, extra in self.data_loader:
            out = self.model(rays_o, rays_d, extra_outputs=['layers', 'weights'])
            loss_val = nn_f.mse_loss(out['color'], extra['color']).item()
            loss_min = min(loss_min, loss_val)
            loss_max = max(loss_max, loss_val)
            loss_avg = (loss_avg * iters_in_epoch + loss_val) / (iters_in_epoch + 1)
            self.iters += 1
            iters_in_epoch += 1
            self._show_progress(iters_in_epoch, loss={
                'val': loss_val, 'min': loss_min, 'max': loss_max, 'avg': loss_avg
            })
            indices.append(idx)
            debug_data.append(torch.cat([
                extra['view_idx'][..., None],
                extra['pix_idx'][..., None],
                rays_d,
                #out['samples'].pts[:, 215:225].reshape(idx.size(0), -1),
                #out['samples'].dirs[:, :3].reshape(idx.size(0), -1),
                #out['samples'].voxel_indices[:, 215:225],
                out['states'].densities[:, 210:230].detach().reshape(idx.size(0), -1),
                out['states'].energies[:, 210:230].detach().reshape(idx.size(0), -1)
                # out['color'].detach()
            ], dim=-1))
            # states: VolumnRenderer.States = out['states']  # TBR

        # Sort dumped rows by dataset index so rand/seq runs are comparable.
        indices = torch.cat(indices, dim=0)
        debug_data = torch.cat(debug_data, dim=0)
        indices, sort = indices.sort()
        debug_data = debug_data[sort]
        name = "rand.csv" if self.data_loader.shuffle else "seq.csv"
        with (self.run_dir / name).open("w") as fp:
            csv_writer = csv.writer(fp)
            csv_writer.writerows(
                torch.cat([indices[:20, None], debug_data[:20]], dim=-1).tolist())
        return

        # --- UNREACHABLE below this point (early return above). ---
        # NOTE(review): `states` is undefined here — its assignment above is
        # commented out — so this code would raise NameError if re-enabled.
        with (self.run_dir / 'states.csv').open("w") as fp:
            csv_writer = csv.writer(fp)
            for chunk_info in states.chunk_infos:
                csv_writer.writerow(
                    [*chunk_info['range'], chunk_info['hits'], chunk_info['core_i']])
                if chunk_info['hits'] > 0:
                    csv_writer.writerows(torch.cat([
                        chunk_info['samples'].pts,
                        chunk_info['samples'].dirs,
                        chunk_info['samples'].voxel_indices[:, None],
                        chunk_info['colors'],
                        chunk_info['energies']
                    ], dim=-1).tolist())
                    csv_writer.writerow([])