import csv
import json
import logging
import torch
import torch.nn.functional as nn_f
from typing import Any, Dict, Union
from pathlib import Path

import loss
from utils import netio, math
from utils.misc import format_time, print_and_log
from utils.progress_bar import progress_bar
from utils.perf import Perf, enable_perf, perf, get_perf_result
from utils.env import set_env
from utils.type import InputData, ReturnData
from data.loader import DataLoader
from model import serialize
from model.base import BaseModel


train_classes = {}


class BaseTrainMeta(type):

    def __new__(cls, name, bases, attrs):
        new_cls = type.__new__(cls, name, bases, attrs)
        train_classes[name] = new_cls
        return new_cls


class Train(object, metaclass=BaseTrainMeta):

    @property
    def perf_mode(self):
        return self.perf_frames > 0

    def _arg(self, name: str, default=None):
        return self.states.get("train", {}).get(name, default)
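    # Expected layout of the `states` dict (inferred from how this class reads it;
    # other keys may be present and are preserved when checkpointing):
    #   "train": per-trainer args (max_epochs, checkpoint_interval, perf_frames,
    #            density/specular regularization weights and scales, ...)
    #   "epoch", "iters": resume counters
    #   "opti": optimizer state_dict to restore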
    def __init__(self, model: BaseModel, run_dir: Path, states: dict) -> None:
        super().__init__()
        print_and_log(
            f"Create trainer {__class__} with args: {json.dumps(states.get('train', {}))}")
        self.model = model
        self.run_dir = run_dir
        self.states = states
        self.epoch = states.get("epoch", 0)
        self.iters = states.get("iters", 0)
        self.max_epochs = self._arg("max_epochs", 50)
        self.checkpoint_interval = self._arg("checkpoint_interval", 10)
        self.perf_frames = self._arg("perf_frames", 0)
        self.model.train()
        self.reset_optimizer()
        if 'opti' in states:
            self.optimizer.load_state_dict(states['opti'])

        # For performance measurement
        if self.perf_mode:
            enable_perf()

        self.env = {
            "trainer": self
        }

    def reset_optimizer(self):
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)

    def train(self, data_loader: DataLoader):
        set_env(self.env)
        self.data_loader = data_loader
        self.iters_per_epoch = self.perf_frames or len(data_loader)

        print(f"Begin training...\nMax epochs: {self.max_epochs}")
        while self.epoch < self.max_epochs:
            self._train_epoch()
            self._save_checkpoint()
        print("Train finished")

    def _save_checkpoint(self):
        (self.run_dir / '_misc').mkdir(exist_ok=True)

        # Clean checkpoints
        for i in range(1, self.epoch):
            if i % self.checkpoint_interval != 0:
                checkpoint_path = self.run_dir / f'checkpoint_{i}.tar'
                if checkpoint_path.exists():
                    checkpoint_path.rename(self.run_dir / f'_misc/checkpoint_{i}.tar')

        # Save checkpoint
        self.states.update({
            **serialize(self.model),
            "epoch": self.epoch,
            "iters": self.iters,
            "opti": self.optimizer.state_dict()
        })
        netio.save_checkpoint(self.states, self.run_dir, self.epoch)

    def _show_progress(self, iters_in_epoch: int, avg_loss: float = 0, recent_loss: float = 0):
        iters_per_epoch = self.perf_frames or len(self.data_loader)
        progress_bar(iters_in_epoch, iters_per_epoch,
                     f"Loss: {recent_loss:.2e} ({avg_loss:.2e})",
                     f"Epoch {self.epoch + 1:<3d}",
                     f" {self.run_dir.absolute()}")

    def _show_perf(self):
        s = "Performance Report ==>\n"
        res = get_perf_result()
        if res is None:
            s += "No available data.\n"
        else:
            for key, val in res.items():
                path_segs = key.split("/")
                s += " " * (len(path_segs) - 1) + f"{path_segs[-1]}: {val:.1f}ms\n"
        print(s)

    def _forward(self, data: InputData) -> ReturnData:
        return self.model(data, 'color', 'energies', 'speculars')
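    # Per-iteration objective (as assembled in `_train_iter` below): an MSE term on
    # predicted vs. ground-truth colors, plus optional Cauchy regularizers on the
    # rendered energies and speculars, weighted and scaled by the
    # "density_regularization_*" / "specular_regularization_*" train args.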
    @perf
    def _train_iter(self, data: Dict[str, Union[torch.Tensor, Any]]) -> float:
        def filtered_data(data, filter):
            if filter is not None:
                return data[filter]
            return data

        with perf("Forward"):
            if isinstance(data, list):
                out_colors = []
                out_energies = []
                out_speculars = []
                gt_colors = []
                for datum in data:
                    partial_out = self._forward(datum)
                    out_colors.append(partial_out['color'])
                    out_energies.append(partial_out['energies'].flatten())
                    if 'speculars' in partial_out:
                        out_speculars.append(partial_out['speculars'].flatten())
                    gt_colors.append(filtered_data(datum["color"], partial_out.get("rays_filter")))
                out_colors = torch.cat(out_colors)
                out_energies = torch.cat(out_energies)
                out_speculars = torch.cat(out_speculars) if len(out_speculars) > 0 else None
                gt_colors = torch.cat(gt_colors)
            else:
                out = self._forward(data)
                out_colors = out['color']
                out_energies = out['energies']
                out_speculars = out.get('speculars')
                gt_colors = filtered_data(data['color'], out.get("rays_filter"))

        with perf("Compute loss"):
            loss_val = loss.mse_loss(out_colors, gt_colors)
            if self._arg("density_regularization_weight"):
                loss_val += loss.cauchy_loss(out_energies, s=self._arg("density_regularization_scale")) \
                    * self._arg("density_regularization_weight")
            if self._arg("specular_regularization_weight") and out_speculars is not None:
                loss_val += loss.cauchy_loss(out_speculars, s=self._arg("specular_regularization_scale")) \
                    * self._arg("specular_regularization_weight")

        #return loss_val.item()  # TODO remove this line

        with perf("Backward"):
            self.optimizer.zero_grad(True)
            loss_val.backward()

        with perf("Update"):
            self.optimizer.step()

        return loss_val.item()

    def _train_epoch(self):
        iters_in_epoch = 0
        recent_loss = []
        tot_loss = 0
        train_epoch_node = Perf.Node("Train Epoch")
        self._show_progress(iters_in_epoch)

        for data in self.data_loader:
            loss_val = self._train_iter(data)
            self.iters += 1
            iters_in_epoch += 1
            recent_loss = (recent_loss + [loss_val])[-50:]
            recent_avg_loss = sum(recent_loss) / len(recent_loss)
            tot_loss += loss_val
            avg_loss = tot_loss / iters_in_epoch
            #loss_min = min(loss_min, loss_val)
            #loss_max = max(loss_max, loss_val)
            #loss_avg = (loss_avg * iters_in_epoch + loss_val) / (iters_in_epoch + 1)
            self._show_progress(iters_in_epoch, avg_loss=avg_loss, recent_loss=recent_avg_loss)
            if self.perf_mode and iters_in_epoch >= self.perf_frames:
                self._show_perf()
                exit()

        train_epoch_node.close()
        torch.cuda.synchronize()
        self.epoch += 1
        epoch_dur = train_epoch_node.duration() / 1000
        logging.info(f"Epoch {self.epoch} spent {format_time(epoch_dur)} "
                     f"(Avg. {format_time(epoch_dur / self.iters_per_epoch)}/iter). "
                     f"Loss is {avg_loss:.2e}")
        #print(list(self.model.model(0).named_parameters())[2])
        #print(list(self.model.model(1).named_parameters())[2])

    def _train_epoch_debug(self):  # TBR
        iters_in_epoch = 0
        loss_min = math.huge
        loss_max = 0
        loss_avg = 0
        self._show_progress(iters_in_epoch)
        indices = []
        debug_data = []

        for idx, rays_o, rays_d, extra in self.data_loader:
            out = self.model(rays_o, rays_d, extra_outputs=['layers', 'weights'])
            loss_val = nn_f.mse_loss(out['color'], extra['color']).item()
            loss_min = min(loss_min, loss_val)
            loss_max = max(loss_max, loss_val)
            loss_avg = (loss_avg * iters_in_epoch + loss_val) / (iters_in_epoch + 1)
            self.iters += 1
            iters_in_epoch += 1
            self._show_progress(iters_in_epoch, avg_loss=loss_avg, recent_loss=loss_val)
            indices.append(idx)
            debug_data.append(torch.cat([
                extra['view_idx'][..., None],
                extra['pix_idx'][..., None],
                rays_d,
                #out['samples'].pts[:, 215:225].reshape(idx.size(0), -1),
                #out['samples'].dirs[:, :3].reshape(idx.size(0), -1),
                #out['samples'].voxel_indices[:, 215:225],
                out['states'].densities[:, 210:230].detach().reshape(idx.size(0), -1),
                out['states'].energies[:, 210:230].detach().reshape(idx.size(0), -1)
                # out['color'].detach()
            ], dim=-1))
            # states: VolumnRenderer.States = out['states']  # TBR

        indices = torch.cat(indices, dim=0)
        debug_data = torch.cat(debug_data, dim=0)
        indices, sort = indices.sort()
        debug_data = debug_data[sort]

        name = "rand.csv" if self.data_loader.shuffle else "seq.csv"
        with (self.run_dir / name).open("w") as fp:
            csv_writer = csv.writer(fp)
            csv_writer.writerows(torch.cat([indices[:20, None], debug_data[:20]], dim=-1).tolist())
        return

        # Unreachable: kept for reference (TBR); `states` would come from the
        # commented-out assignment in the loop above.
        with (self.run_dir / 'states.csv').open("w") as fp:
            csv_writer = csv.writer(fp)
            for chunk_info in states.chunk_infos:
                csv_writer.writerow(
                    [*chunk_info['range'], chunk_info['hits'], chunk_info['core_i']])
                if chunk_info['hits'] > 0:
                    csv_writer.writerows(torch.cat([
                        chunk_info['samples'].pts,
                        chunk_info['samples'].dirs,
                        chunk_info['samples'].voxel_indices[:, None],
                        chunk_info['colors'],
                        chunk_info['energies']
                    ], dim=-1).tolist())
                    csv_writer.writerow([])
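
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not executed by this module). `SomeModel`
# and the DataLoader construction are hypothetical placeholders; only the
# Train calls mirror the class above.
#
#   model = SomeModel(...)                           # any BaseModel subclass
#   loader = DataLoader(...)                         # dataset-specific setup
#   states = {"train": {"max_epochs": 50, "checkpoint_interval": 10}}
#   trainer = Train(model, Path("runs/exp0"), states)
#   trainer.train(loader)
# ---------------------------------------------------------------------------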