from collections import defaultdict
from statistics import mean

import torch
from tqdm import tqdm
from tensorboardX import SummaryWriter
from torch.optim import Optimizer, Adam
from torch.optim.lr_scheduler import _LRScheduler, ExponentialLR

from utils import netio, logging, loss, misc
from utils.args import BaseArgs
from utils.profile import Profiler, enable_profile, profile
from utils.types import *
from data import Dataset, RaysLoader
from model import Model

# Registry of available trainer classes, keyed by class name
trainer_classes: dict[str, Type["Trainer"]] = {}


class Trainer:
    class Args(BaseArgs):
        max_iters: int | None
        max_epochs: int = 20
        checkpoint_interval: int | None
        batch_size: int = 4096
        loss: list[str] = ["Color_L2", "CoarseColor_L2"]
        lr: float = 5e-4
        lr_decay: float | None
        profile_iters: int | None
        trainset: str

    args: Args
    states: dict[str, Any]
    optimizer: Optimizer
    scheduler: _LRScheduler | None

    # Each loss term declares the model outputs it requires so that only those are requested
    loss_defs = {
        "Color_L2": {
            "fn": lambda out, gt: loss.mse_loss(out["color"], gt["color"]),
            "required_outputs": ["color"]
        },
        "CoarseColor_L2": {
            "fn": lambda out, gt: loss.mse_loss(out["coarse_color"], gt["color"]),
            "required_outputs": ["coarse_color"]
        },
        "Density_Reg": {
            "fn": lambda out, gt: loss.cauchy_loss(out["densities"], 1e4) * 1e-4,
            "required_outputs": ["densities"]
        }
    }

    @property
    def profile_mode(self) -> bool:
        return self.profile_iters is not None

    @staticmethod
    def get_class(typename: str) -> Type["Trainer"] | None:
        return trainer_classes.get(typename)

    def __init__(self, model: Model, run_dir: Path, args: Args | None = None) -> None:
        self.model = model.train()
        self.run_dir = run_dir
        self.args = args or self.__class__.Args()
        self.epoch = 0
        self.iters = 0
        self.profile_warmup_iters = 10
        self.profile_iters = self.args.profile_iters

        if self.profile_mode:
            # Profile mode: run a fixed number of warm-up and measured iterations, no checkpoints
            self.max_iters = self.profile_warmup_iters + self.profile_iters
            self.max_epochs = None
            self.checkpoint_interval = None
        elif self.args.max_iters:
            # Iters mode
            self.max_iters = self.args.max_iters
            self.max_epochs = None
            self.checkpoint_interval = self.args.checkpoint_interval or 10000
        else:
            # Epochs mode
            self.max_iters = None
            self.max_epochs = self.args.max_epochs
            self.checkpoint_interval = self.args.checkpoint_interval or 10

        self._init_optimizer()
        self._init_scheduler()

        # Collect the set of model outputs required by the selected loss terms
        self.required_outputs = []
        for key in self.args.loss:
            self.required_outputs += self.loss_defs[key]["required_outputs"]
        self.required_outputs = list(set(self.required_outputs))

        if self.profile_mode:
            # Enable performance measurement in profile mode
            def handle_profile_result(result: Profiler.ProfileResult):
                print(result.get_report())
                exit()

            enable_profile(self.profile_warmup_iters, self.profile_iters, handle_profile_result)
        else:
            # Enable logging (Tensorboard & txt) in normal mode
            tb_log_dir = self.run_dir / "_log"
            tb_log_dir.mkdir(exist_ok=True)
            self.tb_writer = SummaryWriter(tb_log_dir, purge_step=0)
            logging.initialize(self.run_dir / "train.log")
            logging.print_and_log(f"Model arguments: {self.model.args}")
            logging.print_and_log(f"Trainer arguments: {self.args}")

        # Debug: print model structure
        print(model)

    def state_dict(self) -> dict[str, Any]:
        return {
            "model": self.model.state_dict(),
            "epoch": self.epoch,
            "iters": self.iters,
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict() if self.scheduler else None
        }

    def load_state_dict(self, state_dict: dict[str, Any]):
        self.epoch = state_dict.get("epoch", self.epoch)
        self.iters = state_dict.get("iters", self.iters)
        if "model" in state_dict:
            self.model.load_state_dict(state_dict["model"])
        if "optimizer" in state_dict:
            self.optimizer.load_state_dict(state_dict["optimizer"])
        if self.scheduler and "scheduler" in state_dict:
            self.scheduler.load_state_dict(state_dict["scheduler"])

    def reset_optimizer(self):
        self._init_optimizer()
        if self.scheduler is not None:
            # Rebuild the scheduler around the new optimizer, keeping its progress
            scheduler_state = self.scheduler.state_dict()
            self._init_scheduler()
            self.scheduler.load_state_dict(scheduler_state)

    def train(self, dataset: Dataset):
        self.rays_loader = RaysLoader(dataset, self.args.batch_size, shuffle=True,
                                      device=self.model.device)
        self.forward_chunk_size = self.args.batch_size
        if self.max_iters:
            print(f"Begin training... Max iters: {self.max_iters}")
            self.progress = tqdm(total=self.max_iters, dynamic_ncols=True)
            self.rays_iter = iter(self.rays_loader)
            while self.iters < self.max_iters:
                # `checkpoint_interval` is None in profile mode; fall back to `max_iters`
                self._train_iters(min(self.checkpoint_interval or self.max_iters,
                                      self.max_iters - self.iters))
                self._save_checkpoint()
        else:
            print(f"Begin training... Max epochs: {self.max_epochs}")
            while self.epoch < self.max_epochs:
                self._train_epoch()
                self._save_checkpoint()
        print("Train finished")

    @staticmethod
    def create(model: Model, run_dir: PathLike, typename: str,
               args: dict[str, Any] | None = None) -> "Trainer":
        if typename not in trainer_classes:
            raise ValueError(f"Class {typename} is not found")
        return trainer_classes[typename](model, run_dir, args)

    def _init_scheduler(self):
        self.scheduler = ExponentialLR(self.optimizer, self.args.lr_decay) \
            if self.args.lr_decay else None

    def _init_optimizer(self):
        self.optimizer = Adam(self.model.parameters(), lr=self.args.lr)

    def _save_checkpoint(self):
        if self.checkpoint_interval is None:
            return
        ckpt = {
            "args": {
                "model": self.model.__class__.__name__,
                "model_args": vars(self.model.args),
                "trainer": self.__class__.__name__,
                "trainer_args": vars(self.args)
            },
            "states": self.state_dict()
        }
        if self.max_iters:
            # For iters mode, a checkpoint will be saved every `checkpoint_interval` iterations
            netio.save_checkpoint(ckpt, self.run_dir, self.iters)
        else:
            # For epochs mode, a checkpoint will be saved every epoch.
            # Checkpoints which don't match `checkpoint_interval` will be cleaned later
            netio.clean_checkpoint(self.run_dir, self.checkpoint_interval)
            netio.save_checkpoint(ckpt, self.run_dir, self.epoch)

    def _update_progress(self, loss: float = 0):
        self.progress.set_postfix_str(f"Loss: {loss:.2e}" if loss > 0 else "")
        self.progress.update()

    @profile("Forward")
    def _forward(self, rays: Rays) -> ReturnData:
        return self.model(rays, *self.required_outputs)

    @profile("Compute Loss")
    def _compute_loss(self, rays: Rays, out: ReturnData) -> dict[str, torch.Tensor]:
        gt = rays.select(out["rays_filter"]) if "rays_filter" in out else rays
        loss_terms: dict[str, torch.Tensor] = {}
        for key in self.args.loss:
            try:
                loss_terms[key] = self.loss_defs[key]["fn"](out, gt)
            except KeyError:
                pass  # Skip loss terms whose required outputs are missing
        # Debug: print loss terms
        # self.progress.write(",".join([f"{key}: {value.item():.2e}" for key, value in loss_terms.items()]))
        return loss_terms

    @profile("Train iteration")
    def _train_iter(self, rays: Rays) -> float:
        try:
            self.optimizer.zero_grad(set_to_none=True)
            loss_terms = defaultdict(list)
            # Split the batch into chunks and accumulate gradients to bound memory usage
            for offset in range(0, rays.shape[0], self.forward_chunk_size):
                rays_chunk = rays.select(slice(offset, offset + self.forward_chunk_size))
                out_chunk = self._forward(rays_chunk)
                loss_chunk = self._compute_loss(rays_chunk, out_chunk)
                loss_value = sum(loss_chunk.values())
                with profile("Backward"):
                    loss_value.backward()
                loss_terms["Overall_Loss"].append(loss_value.item())
                for key, value in loss_chunk.items():
                    loss_terms[key].append(value.item())
            loss_terms = {key: mean(value) for key, value in loss_terms.items()}
            with profile("Update"):
                self.optimizer.step()
                if self.scheduler:
                    self.scheduler.step()
            # Debug: print lr
            # self.progress.write(f"Learning rate: {self.optimizer.param_groups[0]['lr']}")
            self.iters += 1
            if hasattr(self, "tb_writer"):
                for key, value in loss_terms.items():
                    self.tb_writer.add_scalar(f"Loss/{key}", value, self.iters)
            return loss_terms["Overall_Loss"]
        except RuntimeError as e:
            if not str(e).startswith("CUDA out of memory"):
                raise
            self.progress.write("CUDA out of memory, halve forward chunk and retry.")
            logging.warning("CUDA out of memory, halve forward chunk and retry.")
            self.forward_chunk_size //= 2
            torch.cuda.empty_cache()
            return self._train_iter(rays)

    def _train_iters(self, iters: int):
        recent_loss_list = []
        tot_loss = 0
        train_iters_node = Profiler.Node("Train Iterations")
        for _ in range(iters):
            try:
                rays = next(self.rays_iter)
            except StopIteration:
                self.rays_iter = iter(self.rays_loader)  # A new epoch
                rays = next(self.rays_iter)
            loss_val = self._train_iter(rays)
            recent_loss_list = (recent_loss_list + [loss_val])[-50:]  # Keep recent 50 iterations
            recent_avg_loss = sum(recent_loss_list) / len(recent_loss_list)
            tot_loss += loss_val
            self._update_progress(recent_avg_loss)
        train_iters_node.close()
        torch.cuda.synchronize()
        avg_time = train_iters_node.device_duration / 1000 / iters
        avg_loss = tot_loss / iters
        state_str = f"Iter {self.iters}: Avg. {misc.format_time(avg_time)}/iter; Loss: {avg_loss:.2e}"
        self.progress.write(state_str)
        logging.info(state_str)

    def _train_epoch(self):
        iters_per_epoch = len(self.rays_loader)
        recent_loss_list = []
        tot_loss = 0
        self.progress = tqdm(total=iters_per_epoch, desc=f"Epoch {self.epoch + 1:<3d}",
                             leave=False, dynamic_ncols=True)
        train_epoch_node = Profiler.Node("Train Epoch")
        for rays in self.rays_loader:
            with profile("Train iteration"):
                loss_val = self._train_iter(rays)
            recent_loss_list = (recent_loss_list + [loss_val])[-50:]
            recent_avg_loss = sum(recent_loss_list) / len(recent_loss_list)
            tot_loss += loss_val
            self._update_progress(recent_avg_loss)
        self.progress.close()
        train_epoch_node.close()
        torch.cuda.synchronize()
        self.epoch += 1
        epoch_time = train_epoch_node.device_duration / 1000
        avg_time = epoch_time / iters_per_epoch
        avg_loss = tot_loss / iters_per_epoch
        state_str = f"Epoch {self.epoch} spent {misc.format_time(epoch_time)} "\
                    f"(Avg. {misc.format_time(avg_time)}/iter). Loss is {avg_loss:.2e}."
        logging.print_and_log(state_str)
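
# Note: `trainer_classes` is expected to be populated by concrete Trainer subclasses so that
# `Trainer.get_class()` and `Trainer.create()` can resolve them by name. The registration code
# lives outside this file; a minimal sketch, assuming a hypothetical `BasicTrainer` subclass:
#
#     class BasicTrainer(Trainer):
#         ...
#
#     trainer_classes["BasicTrainer"] = BasicTrainer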