Commit 6294701e authored by Nianchen Deng's avatar Nianchen Deng

sync

parent 2824f796
Subproject commit 10f49b1e7df38a58fd78451eac91d7ac1a21df64
import os
import sys
import argparse
import shutil
from typing import Mapping
from utils.constants import TINY_FLOAT
import torch
import torch.optim
import math
import time
from tensorboardX import SummaryWriter
from torch import nn
@@ -49,7 +46,7 @@ print("Set CUDA:%d as current device." % torch.cuda.current_device())
from utils import netio
from utils import misc
from utils import math
from utils import device
from utils import img
from utils import interact
@@ -243,10 +240,10 @@ def train_loop(data_loader, optimizer, perf, writer, epoch, iters):
for i in range(1, len(out)):
loss_value += loss_mse(out[i]['color'], gt)
if config.depth_ref:
disp_loss_value = loss_mse(torch.reciprocal(out[0]['depth'] + TINY_FLOAT), gt_disp)
disp_loss_value = loss_mse(torch.reciprocal(out[0]['depth'] + math.tiny), gt_disp)
for i in range(1, len(out)):
disp_loss_value += loss_mse(torch.reciprocal(
out[i]['depth'] + TINY_FLOAT), gt_disp)
out[i]['depth'] + math.tiny), gt_disp)
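# Normalize by the squared disparity span (1/near - 1/far)^2 of the dataset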
disp_loss_value = disp_loss_value / math.pow(
1 / dataset.depth_range[0] - 1 / dataset.depth_range[1], 2)
else:
......
@@ -2,102 +2,109 @@
"cells": [
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from utils.voxels import *\n",
"\n",
"bbox, steps = torch.tensor([[-2, -3.14159, 1], [2, 3.14159, 0]]), torch.tensor([2, 3, 3])\n",
"voxel_size = (bbox[1] - bbox[0]) / steps\n",
"voxels = init_voxels(bbox, steps)\n",
"corners, corner_indices = get_corners(voxels, bbox, steps)\n",
"voxel_indices_in_grid = torch.arange(voxels.shape[0])\n",
"emb = torch.nn.Embedding(corners.shape[0], 3, _weight=corners)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([11, 3]) tensor([ 0, -1, -1, 1, -1, -1, 2, 3, 4, -1, 5, 6, -1, 7, 8, -1, 9, 10])\n"
]
}
],
"source": [
"keeps = torch.tensor([True]*18)\n",
"keeps[torch.tensor([1,2,4,5,9,12,15])] = False\n",
"voxels = voxels[keeps]\n",
"corner_indices = corner_indices[keeps]\n",
"grid_indices, _ = to_grid_indices(voxels, bbox, steps=steps)\n",
"voxel_indices_in_grid = grid_indices.new_full([steps.prod().item()], -1)\n",
"voxel_indices_in_grid[grid_indices] = torch.arange(voxels.shape[0])\n",
"print(voxels.shape, voxel_indices_in_grid)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([88, 3]) torch.Size([185, 3]) torch.Size([88, 8])\n"
"Mixin.__init__\n",
"Base.__init__\n",
"Child.__init__\n",
"Mixin.fn\n",
"Mixin.fn\n",
"(<class '__main__.Child'>, <class '__main__.Base'>, <class '__main__.Mixin'>, <class '__main__.Obj'>, <class 'object'>)\n"
]
}
],
"source": [
"new_voxels = split_voxels(voxels, (bbox[1] - bbox[0]) / steps, 2, align_border=False).reshape(-1, 3)\n",
"new_corners, new_corner_indices = get_corners(new_voxels, bbox, steps * 2)\n",
"print(new_voxels.shape, new_corners.shape, new_corner_indices.shape)"
"class Obj:\n",
" def fn(self):\n",
" print(\"Obj.fn\")\n",
"class Base(Obj):\n",
" def __init__(self) -> None:\n",
" super().__init__()\n",
" print(\"Base.__init__\")\n",
"\n",
" def fn(self):\n",
" super().fn()\n",
" print(\"Base.fn\")\n",
" \n",
" def fn1(self):\n",
" self.fn()\n",
"\n",
"class Mixin(Obj):\n",
" def __init__(self) -> None:\n",
" print(\"Mixin.__init__\")\n",
" self.fn = self._fn\n",
" \n",
" def _fn(self):\n",
" print(\"Mixin.fn\")\n",
"\n",
"class Child(Base, Mixin):\n",
" def __init__(self) -> None:\n",
" super().__init__()\n",
" print(\"Child.__init__\")\n",
"\n",
" \n",
"\n",
"a = Child()\n",
"a.fn()\n",
"a.fn1()\n",
"print(Child.__mro__)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([ 0, 0, -1, 0, 0, -1, 1, 1, -1, 1, 1, -1, 2, 2, 3, 3, 4, 4,\n",
" 4, 2, 2, 3, 3, 4, 4, 4, 2, 2, 3, 3, 4, 4, 4, 0, 0, -1,\n",
" 0, 0, -1, 1, 1, -1, 1, 1, -1, 2, 2, 3, 3, 4, 4, 4, 2, 2,\n",
" 3, 3, 4, 4, 4, 2, 2, 3, 3, 4, 4, 4, -1, -1, 5, 5, 6, 6,\n",
" 6, -1, -1, 5, 5, 6, 6, 6, -1, -1, 7, 7, 8, 8, 8, -1, -1, 7,\n",
" 7, 8, 8, 8, -1, -1, 9, 9, 10, 10, 10, -1, -1, 9, 9, 10, 10, 10,\n",
" -1, -1, 9, 9, 10, 10, 10, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6, 7,\n",
" 7, 8, 8, 8, 7, 7, 8, 8, 8, 9, 9, 10, 10, 10, 9, 9, 10, 10,\n",
" 10, 9, 9, 10, 10, 10, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6, 7, 7,\n",
" 8, 8, 8, 7, 7, 8, 8, 8, 9, 9, 10, 10, 10, 9, 9, 10, 10, 10,\n",
" 9, 9, 10, 10, 10])\n",
"tensor(0)\n"
"Base.__init__\n",
"Child.fn: <__main__.Child object at 0x7f62583e0640>\n"
]
}
],
"source": [
"voxel_indices_of_new_corner = voxel_indices_in_grid[to_flat_indices(to_grid_coords(new_corners, bbox, steps=steps).min(steps - 1), steps)]\n",
"print(voxel_indices_of_new_corner)\n",
"p_of_new_corners = (new_corners - voxels[voxel_indices_of_new_corner]) / voxel_size + .5\n",
"print(((new_corners - trilinear_interp(p_of_new_corners, emb(corner_indices[voxel_indices_of_new_corner]))) > 1e-6).sum())"
"class Base:\n",
" def __init__(self) -> None:\n",
" print(\"Base.__init__\")\n",
" \n",
" def fn(self):\n",
" print(\"Base.fn\")\n",
"\n",
" def fn1(self):\n",
" self.fn()\n",
"\n",
"def createChildClass(name):\n",
" def __init__(self):\n",
" super(self.__class__, self).__init__()\n",
" \n",
" def fn(self):\n",
" print(f\"{name}.fn: {self}\")\n",
" \n",
" return type(name, (Base, ), {\n",
" \"__init__\": __init__,\n",
" \"fn\": fn\n",
" })\n",
"\n",
"Child = createChildClass(\"Child\")\n",
"\n",
"a = Child()\n",
"a.fn()"
]
}
],
"metadata": {
"interpreter": {
"hash": "08b118544df3cb8970a671e5837a88fd458f4d4c799ef1fb2709465a22a45b92"
"hash": "65406b00395a48e1d89cf658ae895e7869e05878f5469716b06a752a3915211c"
},
"kernelspec": {
"display_name": "Python 3.9.5 64-bit ('base': conda)",
"display_name": "Python 3.8.5 64-bit ('base': conda)",
"language": "python",
"name": "python3"
},
@@ -111,7 +118,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
"version": "3.8.5"
},
"orig_nbformat": 4
},
......
@@ -2,7 +2,8 @@ import os
import argparse
import torch
import torch.nn.functional as nn_f
from math import nan, ceil, prod
import cv2
import numpy as np
from pathlib import Path
parser = argparse.ArgumentParser()
@@ -10,12 +11,13 @@ parser.add_argument('-m', '--model', type=str,
help='The model file to load for testing')
parser.add_argument('-r', '--output-res', type=str,
help='Output resolution')
parser.add_argument('-o', '--output', nargs='+', type=str, default=['perf', 'color'],
parser.add_argument('-o', '--output', nargs='*', type=str, default=['perf', 'color'],
help='Specify what to output (perf, color, depth, all)')
parser.add_argument('--output-type', type=str, default='image',
help='Specify the output type (image, video, debug)')
parser.add_argument('--views', type=str,
help='Specify the range of views to test')
parser.add_argument('-s', '--samples', type=int)
parser.add_argument('-p', '--prompt', action='store_true',
help='Interactive prompt mode')
parser.add_argument('--time', action='store_true',
@@ -31,18 +33,18 @@ from utils import color
from utils import interact
from utils import device
from utils import img
from utils import netio
from utils import math
from utils.perf import Perf, enable_perf, get_perf_result
from utils.progress_bar import progress_bar
from data.dataset_factory import *
from data.loader import DataLoader
from utils.constants import HUGE_FLOAT
from data import *
RAYS_PER_BATCH = 2 ** 12
DATA_LOADER_CHUNK_SIZE = 1e8
torch.set_grad_enabled(False)
data_desc_path = DatasetFactory.get_dataset_desc_path(args.dataset)
data_desc_path = get_dataset_desc_path(args.dataset)
os.chdir(data_desc_path.parent)
nets_dir = Path("_nets")
data_desc_path = data_desc_path.name
@@ -81,18 +83,20 @@ dataset = DatasetFactory.load(data_desc_path, res=args.output_res,
views_to_load=args.views)
print(f"Dataset loaded: {dataset.root}/{dataset.name}")
RAYS_PER_BATCH = dataset.res[0] * dataset.res[1] // 4
model_path: Path = nets_dir / args.model
model_name = model_path.parent.name
model = mdl.load(model_path, {
"raymarching_early_stop_tolerance": 0.01,
# "raymarching_chunk_size_or_sections": [8],
"perturb_sample": False
})[0].to(device.default()).eval()
model_class = model.__class__.__name__
model_args = model.args
print(f"model: {model_name} ({model_class})")
print("args:", json.dumps(model.args0))
states, _ = netio.load_checkpoint(model_path)
if args.samples:
states['args']['n_samples'] = args.samples
model = mdl.deserialize(states,
raymarching_early_stop_tolerance=0.01,
raymarching_chunk_size_or_sections=None,
perturb_sample=False).to(device.default()).eval()
print(f"model: {model_name} ({model._get_name()})")
print("args:", json.dumps(model.args))
run_dir = model_path.parent
output_dir = run_dir / f"output_{int(model_path.stem.split('_')[-1])}"
@@ -102,71 +106,73 @@ output_dataset_id = '%s%s' % (
)
if __name__ == "__main__":
with torch.no_grad():
# 1. Initialize data loader
data_loader = DataLoader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
shuffle=False, enable_preload=True,
# 1. Initialize data loader
data_loader = DataLoader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
shuffle=False, enable_preload=not args.time,
color=color.from_str(model.args['color']))
# 3. Test on dataset
print("Begin test, batch size is %d" % RAYS_PER_BATCH)
# 3. Test on dataset
print("Begin test, batch size is %d" % RAYS_PER_BATCH)
i = 0
offset = 0
chns = model.chns('color')
n = dataset.n_views
total_pixels = prod([n, *dataset.res])
i = 0
offset = 0
chns = model.chns('color')
n = dataset.n_views
total_pixels = math.prod([n, *dataset.res])
out = {}
if args.output_flags['perf'] or args.output_flags['color']:
out = {}
if args.output_flags['perf'] or args.output_flags['color']:
out['color'] = torch.zeros(total_pixels, chns, device=device.default())
if args.output_flags['diffuse']:
if args.output_flags['diffuse']:
out['diffuse'] = torch.zeros(total_pixels, chns, device=device.default())
if args.output_flags['specular']:
if args.output_flags['specular']:
out['specular'] = torch.zeros(total_pixels, chns, device=device.default())
if args.output_flags['depth']:
out['depth'] = torch.full([total_pixels, 1], HUGE_FLOAT, device=device.default())
gt_images = torch.empty_like(out['color']) if dataset.image_path else None
tot_time = 0
tot_iters = len(data_loader)
progress_bar(i, tot_iters, 'Inferring...')
for _, rays_o, rays_d, extra in data_loader:
if args.output_flags['depth']:
out['depth'] = torch.full([total_pixels, 1], math.huge, device=device.default())
gt_images = torch.empty_like(out['color']) if dataset.image_path else None
tot_time = 0
tot_iters = len(data_loader)
progress_bar(i, tot_iters, 'Inferring...')
for data in data_loader:
if args.output_flags['perf']:
test_perf = Perf.Node("Test")
n_rays = rays_o.size(0)
n_rays = data['rays_o'].size(0)
idx = slice(offset, offset + n_rays)
ret = model(rays_o, rays_d, extra_outputs=[key for key in out.keys() if key != 'color'])
ret = model(data, *out.keys())
if ret is not None:
for key in out:
out[key][idx][ret['rays_mask']] = ret[key]
if key not in ret:
out[key] = None
else:
if 'rays_filter' in ret:
out[key][idx][ret['rays_filter']] = ret[key]
else:
out[key][idx] = ret[key]
if args.output_flags['perf']:
test_perf.close()
torch.cuda.synchronize()
tot_time += test_perf.duration()
if gt_images is not None:
gt_images[idx] = extra['color']
gt_images[idx] = data['color']
i += 1
progress_bar(i, tot_iters, 'Inferring...')
offset += n_rays
# 4. Save results
print('Saving results...')
output_dir.mkdir(parents=True, exist_ok=True)
# 4. Save results
print('Saving results...')
output_dir.mkdir(parents=True, exist_ok=True)
for key in out:
out = {key: value for key, value in out.items() if value is not None}
for key in out:
out[key] = out[key].reshape([n, *dataset.res, *out[key].shape[1:]])
if 'color' in out:
out['color'] = out['color'].permute(0, 3, 1, 2)
if 'diffuse' in out:
out['diffuse'] = out['diffuse'].permute(0, 3, 1, 2)
if 'specular' in out:
out['specular'] = out['specular'].permute(0, 3, 1, 2)
if key == 'color' or key == 'diffuse' or key == 'specular':
out[key] = out[key].permute(0, 3, 1, 2)
if args.output_flags['perf']:
perf_errors = torch.full([n], nan)
perf_ssims = torch.full([n], nan)
if args.output_flags['perf']:
perf_errors = torch.full([n], math.nan)
perf_ssims = torch.full([n], math.nan)
if gt_images is not None:
gt_images = gt_images.reshape(n, *dataset.res, chns).permute(0, 3, 1, 2)
for i in range(n):
@@ -189,8 +195,18 @@ if __name__ == "__main__":
for i in range(n)
])
for output_type in ['color', 'diffuse', 'specular']:
if not args.output_flags[output_type]:
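# Build per-view error heatmaps: per-pixel MSE, scaled so 1e-2 maps to full
# intensity, then colorized with OpenCV's JET colormap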
error_images = ((gt_images - out['color'])**2).sum(1, True) / chns
error_images = (error_images / 1e-2).clamp(0, 1) * 255
error_images = img.torch2np(error_images)
error_images = np.asarray(error_images, dtype=np.uint8)
output_subdir = output_dir / f"{output_dataset_id}_error"
output_subdir.mkdir(exist_ok=True)
for i in range(n):
heat_img = cv2.applyColorMap(error_images[i], cv2.COLORMAP_JET)  # note: cv2 produces the 3-channel heatmap in its native BGR channel order
cv2.imwrite(f'{output_subdir}/{dataset.indices[i]:0>4d}.png', heat_img)
for output_type in ['color', 'diffuse', 'specular']:
if output_type not in out:
continue
if args.output_type == 'video':
output_file = output_dir / f"{output_dataset_id}_{output_type}.mp4"
@@ -201,8 +217,8 @@ if __name__ == "__main__":
img.save(out[output_type],
[f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])
if args.output_flags['depth']:
colored_depths = img.colorize_depthmap(out['depth'][..., 0], model_args['sample_range'])
if 'depth' in out:
colored_depths = img.colorize_depthmap(out['depth'][..., 0], model.args['sample_range'])
if args.output_type == 'video':
output_file = output_dir / f"{output_dataset_id}_depth.mp4"
img.save_video(colored_depths, output_file, 30)
@@ -214,7 +230,7 @@ if __name__ == "__main__":
# output_dir.mkdir(exist_ok=True)
#img.save(out['bins'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])
if args.time:
if args.time:
s = "Performance Report ==>\n"
res = get_perf_result()
if res is None:
......
This source diff could not be displayed because it is too large.
import bpy
import json
import os
import math
import numpy as np
from itertools import product
scene = bpy.context.scene
cam_obj = scene.camera
cam = cam_obj.data
scene.cycles.device = 'GPU'
dataset_name = 'train'
tbox = [0.6, 0.6, 0.6]
rbox = [320, 40]
dataset_desc = {
'view_file_pattern': '%s/view_%%04d.png' % dataset_name,
"gl_coord": True,
'view_res': {
'x': 512,
'y': 512
},
'cam_params': {
'fov': 40.0,
'cx': 0.5,
'cy': 0.5,
'normalized': True
},
'range': {
'min': [-tbox[0] / 2, -tbox[1] / 2, -tbox[2] / 2, -rbox[0] / 2, -rbox[1] / 2],
'max': [tbox[0] / 2, tbox[1] / 2, tbox[2] / 2, rbox[0] / 2, rbox[1] / 2]
},
'samples': [5, 5, 5, 9, 2],
#'samples': [2000],
'view_centers': [],
'view_rots': []
}
data_desc_file = f'output/{dataset_name}.json'
if not os.path.exists('output'):
os.mkdir('output')
if os.path.exists(data_desc_file):
with open(data_desc_file, 'r') as fp:
dataset_desc.update(json.load(fp))
with open(data_desc_file, 'w') as fp:
json.dump(dataset_desc, fp, indent=4)
# Output resolution
scene.render.resolution_x = dataset_desc['view_res']['x']
scene.render.resolution_y = dataset_desc['view_res']['y']
# Field of view
cam.lens_unit = 'FOV'
cam.angle = math.radians(dataset_desc['cam_params']['fov'])
cam.dof.use_dof = False
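# Render a single view at the given camera pose; unless render_only is set, also
# record the pose in the dataset description and update the JSON file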
def add_sample(i, x, y, z, rx, ry, render_only=False):
cam_obj.location = [x, y, z]
cam_obj.rotation_euler = [math.radians(ry), math.radians(rx), 0]
scene.render.filepath = 'output/' + dataset_desc['view_file_pattern'] % i
bpy.ops.render.render(write_still=True)
if not render_only:
dataset_desc['view_centers'].append(list(cam_obj.location))
dataset_desc['view_rots'].append([rx, ry])
with open(data_desc_file, 'w') as fp:
json.dump(dataset_desc, fp, indent=4)
for i in range(len(dataset_desc['view_centers'])):
if not os.path.exists('output/' + dataset_desc['view_file_pattern'] % i):
add_sample(i, *dataset_desc['view_centers'][i], *dataset_desc['view_rots'][i], render_only=True)
start_view = len(dataset_desc['view_centers'])
if len(dataset_desc['samples']) == 1:
range_min = np.array(dataset_desc['range']['min'])
range_max = np.array(dataset_desc['range']['max'])
samples = (range_max - range_min) * np.random.rand(dataset_desc['samples'][0], 5) + range_min
for i in range(start_view, dataset_desc['samples'][0]):
add_sample(i, *list(samples[i]))
else:
ranges = [
np.linspace(dataset_desc['range']['min'][i],
dataset_desc['range']['max'][i],
dataset_desc['samples'][i])
for i in range(0, 3)
] + [
np.linspace(dataset_desc['range']['min'][i],
dataset_desc['range']['max'][i],
dataset_desc['samples'][i])
for i in range(3, 5)
]
i = 0
for x, y, z, rx, ry in product(*ranges):
if i >= start_view:
add_sample(i, x, y, z, rx, ry)
i += 1
import bpy
import json
import os
import math
import numpy as np
from itertools import product
scene = bpy.context.scene
cam_obj = scene.camera
cam = cam_obj.data
scene.cycles.device = 'GPU'
dataset_name = 'train'
tbox = [0.7, 0.7, 0.7]
rbox = [300, 120]
dataset_desc = {
'view_file_pattern': '%s/view_%%04d.png' % dataset_name,
"gl_coord": True,
'view_res': {
'x': 512,
'y': 512
},
'cam_params': {
'fov': 60.0,
'cx': 0.5,
'cy': 0.5,
'normalized': True
},
'range': {
'min': [-tbox[0] / 2, -tbox[1] / 2, -tbox[2] / 2, -rbox[0] / 2, -rbox[1] / 2],
'max': [tbox[0] / 2, tbox[1] / 2, tbox[2] / 2, rbox[0] / 2, rbox[1] / 2]
},
'samples': [5, 5, 5, 6, 3],
#'samples': [2000],
'view_centers': [],
'view_rots': []
}
data_desc_file = f'output/{dataset_name}.json'
if not os.path.exists('output'):
os.mkdir('output')
if os.path.exists(data_desc_file):
with open(data_desc_file, 'r') as fp:
dataset_desc.update(json.load(fp))
with open(data_desc_file, 'w') as fp:
json.dump(dataset_desc, fp, indent=4)
# Output resolution
scene.render.resolution_x = dataset_desc['view_res']['x']
scene.render.resolution_y = dataset_desc['view_res']['y']
# Field of view
cam.lens_unit = 'FOV'
cam.angle = math.radians(dataset_desc['cam_params']['fov'])
cam.dof.use_dof = False
def add_sample(i, x, y, z, rx, ry, render_only=False):
cam_obj.location = [x, y, z]
cam_obj.rotation_euler = [math.radians(ry), math.radians(rx), 0]
scene.render.filepath = 'output/' + dataset_desc['view_file_pattern'] % i
bpy.ops.render.render(write_still=True)
if not render_only:
dataset_desc['view_centers'].append(list(cam_obj.location))
dataset_desc['view_rots'].append([rx, ry])
with open(data_desc_file, 'w') as fp:
json.dump(dataset_desc, fp, indent=4)
for i in range(len(dataset_desc['view_centers'])):
if not os.path.exists('output/' + dataset_desc['view_file_pattern'] % i):
add_sample(i, *dataset_desc['view_centers'][i], *dataset_desc['view_rots'][i], render_only=True)
start_view = len(dataset_desc['view_centers'])
if len(dataset_desc['samples']) == 1:
range_min = np.array(dataset_desc['range']['min'])
range_max = np.array(dataset_desc['range']['max'])
samples = (range_max - range_min) * np.random.rand(dataset_desc['samples'][0], 5) + range_min
for i in range(start_view, dataset_desc['samples'][0]):
add_sample(i, *list(samples[i]))
else:
ranges = [
np.linspace(dataset_desc['range']['min'][i],
dataset_desc['range']['max'][i],
dataset_desc['samples'][i])
for i in range(0, 3)
] + [
np.linspace(dataset_desc['range']['min'][i],
dataset_desc['range']['max'][i],
dataset_desc['samples'][i])
for i in range(3, 5)
]
i = 0
for x, y, z, rx, ry in product(*ranges):
if i >= start_view:
add_sample(i, x, y, z, rx, ry)
i += 1
import json
import sys
import os
import argparse
import numpy as np
import torch
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
from utils import misc
parser = argparse.ArgumentParser()
parser.add_argument('dataset', type=str)
args = parser.parse_args()
data_desc_path = args.dataset
data_desc_name = os.path.splitext(os.path.basename(data_desc_path))[0]
data_dir = os.path.dirname(data_desc_path) + '/'
with open(data_desc_path, 'r') as fp:
dataset_desc = json.load(fp)
centers = np.array(dataset_desc['view_centers'])
t_max = np.max(centers, axis=0)
t_min = np.min(centers, axis=0)
dataset_desc['range'] = {
'min': [t_min[0], t_min[1], t_min[2], 0, 0],
'max': [t_max[0], t_max[1], t_max[2], 0, 0]
}
with open(data_desc_path, 'w') as fp:
json.dump(dataset_desc, fp, indent=4)
\ No newline at end of file
@@ -7,7 +7,7 @@ import torch
from itertools import product, repeat
from pathlib import Path
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
parser = argparse.ArgumentParser()
parser.add_argument('-o', '--output', type=str, default='train1')
......
import sys
import os
import argparse
import json
import numpy as np
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
from utils import misc
from utils.colmap_read_model import read_model
parser = argparse.ArgumentParser()
parser.add_argument('dataset', type=str)
args = parser.parse_args()
data_dir = args.dataset
os.makedirs(data_dir, exist_ok=True)
out_desc_path = os.path.join(data_dir, "train.json")
cameras, images, points3D = read_model(os.path.join(data_dir, 'sparse/0'), '.bin')
print("Model loaded.")
print("num_cameras:", len(cameras))
print("num_images:", len(images))
print("num_points3D:", len(points3D))
cam = cameras[list(cameras.keys())[0]]
views = np.array([int(images[img_id].name[5:9]) for img_id in images])
view_centers = np.array([images[img_id].tvec for img_id in images])
view_rots = []
for img_id in images:
im = images[img_id]
R = im.qvec2rotmat()
view_rots.append(R.reshape([9]).tolist())
view_rots = np.array(view_rots)
indices = np.argsort(views)
views = views[indices]
view_centers = view_centers[indices]
view_rots = view_rots[indices]
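# Estimate the scene depth range from the distances of COLMAP's sparse points to the origin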
pts = np.array([points3D[pt_id].xyz for pt_id in points3D])
zvals = np.sqrt(np.sum(pts * pts, 1))
dataset_desc = {
'view_file_pattern': "images/image%04d.jpg",
'gl_coord': True,
'view_res': {
'x': cam.width,
'y': cam.height
},
'cam_params': {
'fx': cam.params[0],
'fy': cam.params[0],
'cx': cam.params[1],
'cy': cam.params[2]
},
'range': {
'min': np.min(view_centers, 0).tolist() + [0, 0],
'max': np.max(view_centers, 0).tolist() + [0, 0]
},
'depth_range': {
'min': np.min(zvals),
'max': np.max(zvals)
},
'samples': [len(view_centers)],
'view_centers': view_centers.tolist(),
'view_rots': view_rots.tolist(),
'views': views.tolist()
}
with open(out_desc_path, 'w') as fp:
json.dump(dataset_desc, fp, indent=4)
import json
import numpy as np
from itertools import product
tbox = [0.1, 0.1, 0.1]
dataset_desc = {
"gl_coord": True,
'view_res': {
'x': 1440,
'y': 1600
},
'cam_params': {
'fov': 110.0,
'cx': 0.5,
'cy': 0.5,
'normalized': True
},
'range': {
'min': [-tbox[0] / 2, -tbox[1] / 2, -tbox[2] / 2, 0, 0],
'max': [tbox[0] / 2, tbox[1] / 2, tbox[2] / 2, 0, 0]
},
'samples': [3, 3, 3, 1, 1],
'view_centers': [],
'view_rots': []
}
panorama_pose_content = []
data_desc_file = 'for_panorama.json'
panorama_pose_file = "for_lipis.json"
ranges = [
np.linspace(dataset_desc['range']['min'][i],
dataset_desc['range']['max'][i],
dataset_desc['samples'][i])
for i in range(0, 3)
]
for x, y, z in product(*ranges):
dataset_desc['view_centers'].append([x, y, z])
dataset_desc['view_rots'].append([0, 0])
panorama_pose_content.append({
"rotation_euler": [0.0, 0.0, 0.0],
"location": [x, y, z]
})
with open(data_desc_file, 'w') as fp:
json.dump(dataset_desc, fp, indent=4)
with open(panorama_pose_file, 'w') as fp:
json.dump(panorama_pose_content, fp, indent=4)
import json
import sys
import os
import argparse
import numpy as np
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
from utils import seqs
from utils import math
parser = argparse.ArgumentParser()
parser.add_argument('-r', '--rot-range', nargs='+', type=int)
parser.add_argument('-t', '--trans-range', nargs='+', type=float)
parser.add_argument('--fov', type=float)
parser.add_argument('--res', type=str)
parser.add_argument('--gl', action='store_true')
parser.add_argument('-s', '--seq', type=str, required=True)
parser.add_argument('-n', '--views', type=int, required=True)
parser.add_argument('-o', '--out-desc', type=str)
parser.add_argument('--ref', type=str)
parser.add_argument('dataset', type=str)
args = parser.parse_args()
data_dir = args.dataset
os.makedirs(data_dir, exist_ok=True)
out_desc_path = os.path.join(data_dir, (args.out_desc if args.out_desc else f"{args.seq}.json"))
if args.ref:
with open(os.path.join(data_dir, args.ref), 'r') as fp:
ref_desc = json.load(fp)
else:
if not args.trans_range or not args.rot_range or not args.fov or not args.res:
print('-r, -t, --fov, --res options are required if --ref is not specified')
exit(-1)
ref_desc = None
if args.trans_range:
trans_range = np.array(list(args.trans_range) * 3 if len(args.trans_range) == 1
else args.trans_range)
else:
trans_range = np.array(ref_desc['range']['max'][0:3]) - \
np.array(ref_desc['range']['min'][0:3])
if args.rot_range:
rot_range = np.array(list(args.rot_range) * 2 if len(args.rot_range) == 1
else args.rot_range)
else:
rot_range = np.array(ref_desc['range']['max'][3:5]) - \
np.array(ref_desc['range']['min'][3:5])
filter_range = np.concatenate([trans_range, rot_range])
if args.fov:
cam_params = {
'fov': args.fov,
'cx': 0.5,
'cy': 0.5,
'normalized': True
}
else:
cam_params = ref_desc['cam_params']
if args.res:
res = tuple(int(s) for s in args.res.split('x'))
res = {'x': res[0], 'y': res[1]}
else:
res = ref_desc['view_res']
if args.seq == 'helix':
centers, rots = seqs.helix(trans_range, 4, args.views)
elif args.seq == 'scan_around':
centers, rots = seqs.scan_around(trans_range, 1, args.views)
elif args.seq == 'look_around':
centers, rots = seqs.look_around(trans_range, args.views)
rots *= 180 / math.pi
gl = args.gl or ref_desc and ref_desc.get('gl_coord')
if gl:
centers[:, 2] *= -1
rots[:, 0] *= -1
dataset_desc = {
'gl_coord': gl,
'view_res': res,
'cam_params': cam_params,
'range': {
'min': (-0.5 * filter_range).tolist(),
'max': (0.5 * filter_range).tolist()
},
'samples': [args.views],
'view_centers': centers.tolist(),
'view_rots': rots.tolist()
}
with open(out_desc_path, 'w') as fp:
json.dump(dataset_desc, fp, indent=4)
import json
import sys
import os
import argparse
import numpy as np
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
from utils import misc
parser = argparse.ArgumentParser()
parser.add_argument('-r', '--rot-range', nargs='+', type=int)
parser.add_argument('-t', '--trans-range', nargs='+', type=float)
parser.add_argument('-k', '--trainset-ratio', type=float, default=0.7)
parser.add_argument('dataset', type=str)
args = parser.parse_args()
data_desc_path = args.dataset
data_desc_name = os.path.splitext(os.path.basename(data_desc_path))[0]
data_dir = os.path.dirname(data_desc_path) + '/'
with open(data_desc_path, 'r') as fp:
dataset_desc = json.load(fp)
if args.trans_range:
trans_range = np.array(args.trans_range)
else:
trans_range = np.array(dataset_desc['range']['max'][0:3]) - \
np.array(dataset_desc['range']['min'][0:3])
if args.rot_range:
rot_range = np.array(args.rot_range)
else:
rot_range = np.array(dataset_desc['range']['max'][3:5]) - \
np.array(dataset_desc['range']['min'][3:5])
filter_range = np.concatenate([trans_range, rot_range])
out_data_dir = data_dir + 'r%dx%d_t%.1fx%.1fx%.1f/' % (
int(rot_range[0]), int(rot_range[1]),
trans_range[0], trans_range[1], trans_range[2]
)
dataset_version = 0
while True:
out_trainset_name = f'train_{dataset_version}'
out_testset_name = f'test_{dataset_version}'
if not os.path.exists(out_data_dir + out_trainset_name):
break
dataset_version += 1
def in_range(val, rng): return -rng / 2 <= val <= rng / 2
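# Keep only the views whose rotation and translation fall inside the given ranges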
views = []
for i in range(len(dataset_desc['view_centers'])):
if in_range(dataset_desc['view_rots'][i][0], rot_range[0]) and \
in_range(dataset_desc['view_rots'][i][1], rot_range[1]) and \
in_range(dataset_desc['view_centers'][i][0], trans_range[0]) and \
in_range(dataset_desc['view_centers'][i][1], trans_range[1]) and \
in_range(dataset_desc['view_centers'][i][2], trans_range[2]):
views.append(i)
if len(views) < 100:
print(f'Number of views in range is too small ({len(views)})')
exit()
views = np.random.permutation(views)
n_train_views = int(len(views) * args.trainset_ratio)
train_views = np.sort(views[:n_train_views])
test_views = np.sort(views[n_train_views:])
print('Train set views: ', len(train_views))
print('Test set views: ', len(test_views))
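# Write a subset description and symlink its images from the full dataset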
def create_subset(views, out_desc_name):
views = views.tolist()
subset_desc = dataset_desc.copy()
subset_desc['view_file_pattern'] = \
f"{out_desc_name}/{dataset_desc['view_file_pattern'].split('/')[-1]}"
subset_desc['range'] = {
'min': list(-filter_range / 2),
'max': list(filter_range / 2)
}
subset_desc['samples'] = [int(len(views))]
subset_desc['views'] = views
subset_desc['view_centers'] = np.array(dataset_desc['view_centers'])[views].tolist()
subset_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
with open(os.path.join(out_data_dir, f'{out_desc_name}.json'), 'w') as fp:
json.dump(subset_desc, fp, indent=4)
os.makedirs(os.path.join(out_data_dir, out_desc_name), exist_ok=True)
for i in range(len(views)):
os.symlink(os.path.join('../../', dataset_desc['view_file_pattern'] % views[i]),
os.path.join(out_data_dir, subset_desc['view_file_pattern'] % views[i]))
os.makedirs(out_data_dir, exist_ok=True)
create_subset(train_views, out_trainset_name)
create_subset(test_views, out_testset_name)
@@ -4,10 +4,18 @@ import os
import json
import argparse
from typing import Mapping
from data import get_dataset_desc_path, get_data_path
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
from utils import misc
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', type=str, nargs='+')
parser.add_argument('output', type=str)
args = parser.parse_args()
input = [get_dataset_desc_path(path) for path in args.input]
output = get_dataset_desc_path(args.output)
def copy_images(src_path, dst_path, n, offset=0):
@@ -16,38 +24,24 @@ def copy_images(src_path, dst_path, n, offset=0):
copy(src_path % i, dst_path % (i + offset))
input_data_desc_paths = [
'/home/dengnc/dvs/data/__new/barbershop_all/nerf_cvt.json',
'/home/dengnc/dvs/data/__new/__demo/fvvdp/0816_3_barbershop_cvt.json',
'/home/dengnc/dvs/data/__new/__demo/fvvdp/0816_1_barbershop_cvt.json',
]
output_data_desc_path = '/home/dengnc/dvs/data/__new/__demo/fvvdp/barbershop_3&1_nerf.json'
output_data_name = os.path.splitext(os.path.basename(output_data_desc_path))[0]
output_dir = os.path.dirname(output_data_desc_path)
with open(input_data_desc_paths[0], 'r') as fp:
with open(input[0], 'r') as fp:
dataset_desc: Mapping = json.load(fp)
n_views = 0
# Merge poses and copy images from all input datasets
for i in range(len(input_data_desc_paths)):
for i in range(len(input)):
if i == 0:
input_desc = dataset_desc
else:
with open(input_data_desc_paths[i], 'r') as fp:
with open(input[i], 'r') as fp:
input_desc: Mapping = json.load(fp)
dataset_desc['view_centers'] += input_desc['view_centers']
dataset_desc['view_rots'] += input_desc['view_rots']
copy_images(
os.path.join(os.path.dirname(input_data_desc_paths[i]), input_desc['view_file_pattern']),
os.path.join(output_dir, output_data_name, 'view_%04d.png'),
len(input_desc['view_centers']), n_views
)
copy_images(get_data_path(input[i], input_desc['view_file_pattern']),
get_data_path(output, dataset_desc['view_file_pattern']),
len(input_desc['view_centers']), n_views)
n_views += len(input_desc['view_centers'])
dataset_desc['samples'] = [n_views]
dataset_desc['view_file_pattern'] = os.path.join(output_data_name, 'view_%04d.png')
with open(output_data_desc_path, 'w') as fp:
with open(output, 'w') as fp:
json.dump(dataset_desc, fp, indent=4)
import json
import sys
import os
import argparse
from pathlib import Path
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
from data import get_dataset_desc_path
parser = argparse.ArgumentParser()
parser.add_argument('-o', '--output', type=str, nargs="+", required=True)
parser.add_argument("-v", "--views", type=int, nargs="+", required=True)
parser.add_argument('dataset', type=str)
args = parser.parse_args()
input = get_dataset_desc_path(args.dataset)
outputs = [
get_dataset_desc_path(input.with_name(f"{input.stem}_{appendix}"))
for appendix in args.output
]
with open(input, 'r') as fp:
input_desc: dict = json.load(fp)
n_views = len(input_desc['view_centers'])
assert(len(args.views) == len(outputs))
sum_views = sum(args.views)
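# A view count of -1 means "use all remaining views"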
for i in range(len(args.views)):
if args.views[i] == -1:
args.views[i] = n_views - sum_views - 1
sum_views = n_views
break
assert(sum_views <= n_views)
for i in range(len(args.views)):
assert(args.views[i] > 0)
offset = 0
for i in range(len(outputs)):
n = args.views[i]
end = offset + n
output_desc = input_desc.copy()
output_desc['samples'] = args.views[i]
if 'views' in output_desc:
output_desc['views'] = output_desc['views'][offset:end]
else:
output_desc['views'] = list(range(offset, end))
output_desc['view_centers'] = output_desc['view_centers'][offset:end]
if 'view_rots' in output_desc:
output_desc['view_rots'] = output_desc['view_rots'][offset:end]
with open(outputs[i], 'w') as fp:
json.dump(output_desc, fp, indent=4)
# Create symbol links of images
out_dir = outputs[i].with_suffix('')
out_dir.mkdir(exist_ok=True)
for k in range(n):
os.symlink(Path("..") / input.stem / (output_desc['view_file_pattern'] % output_desc['views'][k]),
out_dir / (input_desc['view_file_pattern'] % output_desc['views'][k]))
offset += args.views[i]
from pathlib import Path
import sys
import argparse
import math
import torch
import torchvision.transforms.functional as trans_F
sys.path.append(str(Path(sys.path[0]).parent.absolute()))
from utils import img
from utils import math
parser = argparse.ArgumentParser()
parser.add_argument('-o', '--output', type=str)
......
import cv2
import numpy as np
import os
from utils.constants import *
import sys
from utils import math
def genGaussiankernel(width, sigma):
@@ -87,7 +88,7 @@ def foveat_img(im, fixs):
# B
Bs = []
for i in range(1, prNum):
Bs.append((0.5 - Ts[i]) / (Ts[i-1] - Ts[i] + TINY_FLOAT))
Bs.append((0.5 - Ts[i]) / (Ts[i-1] - Ts[i] + math.tiny))
# M
Ms = np.zeros((prNum, R.shape[0], R.shape[1]))
......
import argparse
import logging
import os
from pathlib import Path
import sys
from pathlib import Path
from typing import List
import model as mdl
import train
from utils import color
from utils import device
from data.dataset_factory import *
from data.loader import DataLoader
from utils.misc import list_epochs, print_and_log
from utils import netio
from data import *
from utils.misc import print_and_log
RAYS_PER_BATCH = 2 ** 12
DATA_LOADER_CHUNK_SIZE = 1e8
root_dir = Path.cwd()
root_dir = Path(__file__).absolute().parent
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str,
help='Net config files')
parser.add_argument('-e', '--epochs', type=int, default=50,
parser.add_argument('-e', '--epochs', type=int,
help='Max epochs for train')
parser.add_argument('--perf', type=int, default=0,
parser.add_argument('--perf', type=int,
help='Performance measurement frames (0 for disabling performance measurement)')
parser.add_argument('--prune', type=int, default=5,
parser.add_argument('--prune', type=int, nargs='+',
help='Epoch(s) at which to prune voxels')
parser.add_argument('--split', type=int, default=10,
parser.add_argument('--split', type=int, nargs='+',
help='Epoch(s) at which to split voxels')
parser.add_argument('--freeze', type=int, nargs='+',
help='Epoch(s) at which to freeze levels')
parser.add_argument('--checkpoint-interval', type=int)
parser.add_argument('--views', type=str,
help='Specify the range of views to train')
parser.add_argument('path', type=str,
help='Dataset description file')
args = parser.parse_args()
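# "--views A-B" selects the half-open view index range [A, B)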
views_to_load = range(*[int(val) for val in args.views.split('-')]) if args.views else None
argpath = Path(args.path)
# argpath: May be model path or data path
# 1) model path: continue training on the specified model
# 2) data path: train a new model using specified dataset
if argpath.suffix == ".tar":
args.mdl_path = argpath
else:
existed_epochs = list_epochs(argpath, "checkpoint_*.tar")
args.mdl_path = argpath / f"checkpoint_{existed_epochs[-1]}.tar" if existed_epochs else None
if args.mdl_path:
def load_dataset(data_path: Path):
print(f"Loading dataset {data_path}")
try:
dataset = DatasetFactory.load(data_path, views_to_load=views_to_load)
print(f"Dataset loaded: {dataset.root}/{dataset.name}")
os.chdir(dataset.root)
return dataset, dataset.name
except FileNotFoundError:
return load_multiscale_dataset(data_path)
def load_multiscale_dataset(data_path: Path):
if not data_path.is_dir():
raise ValueError(
f"Path {data_path} is not a directory")
dataset: List[Union[PanoDataset, ViewDataset]] = []
for sub_data_desc_path in data_path.glob("*.json"):
sub_dataset = DatasetFactory.load(sub_data_desc_path, views_to_load=views_to_load)
print(f"Sub-dataset loaded: {sub_dataset.root}/{sub_dataset.name}")
dataset.append(sub_dataset)
if len(dataset) == 0:
raise ValueError(f"Path {data_path} does not contain sub-datasets")
os.chdir(data_path.parent)
return dataset, data_path.name
try:
states, checkpoint_path = netio.load_checkpoint(argpath)
# Infer dataset path from model path
# The model path follows this pattern: <dataset_dir>/_nets/<dataset_name>/<model_name>/checkpoint_*.tar
dataset_name = args.mdl_path.parent.parent.name
dataset_dir = args.mdl_path.parent.parent.parent.parent
args.data_path = dataset_dir / dataset_name
args.mdl_path = args.mdl_path.relative_to(dataset_dir)
else:
args.data_path = argpath
args.views = range(*[int(val) for val in args.views.split('-')]) if args.views else None
dataset = DatasetFactory.load(args.data_path, views_to_load=args.views)
print(f"Dataset loaded: {dataset.root}/{dataset.name}")
os.chdir(dataset.root)
if args.mdl_path:
# Load model to continue training
model, states = mdl.load(args.mdl_path)
model_name = args.mdl_path.parent.name
model_class = model.__class__.__name__
model_args = model.args
else:
# Create model from specified configuration
with Path(f'{root_dir}/configs/{args.config}.json').open() as fp:
config = json.load(fp)
model_name = checkpoint_path.parts[-2]
dataset, dataset_name = load_dataset(
Path(*checkpoint_path.parts[:-4]) / checkpoint_path.parts[-3])
except Exception:
model_name = args.config
model_class = config['model']
model_args = config['args']
model_args['bbox'] = dataset.bbox
model_args['depth_range'] = dataset.depth_range
model, states = mdl.create(model_class, model_args), None
model.to(device.default())
run_dir = Path(f"_nets/{dataset.name}/{model_name}")
dataset, dataset_name = load_dataset(argpath)
# Load state 0 from specified configuration
with Path(f'{root_dir}/configs/{args.config}.json').open() as fp:
states = json.load(fp)
states['args']['bbox'] = dataset[0].bbox if isinstance(dataset, list) else dataset.bbox
states['args']['depth_range'] = dataset[0].depth_range if isinstance(dataset, list)\
else dataset.depth_range
if 'train' not in states:
states['train'] = {}
if args.prune is not None:
states['train']['prune_epochs'] = args.prune
if args.split is not None:
states['train']['split_epochs'] = args.split
if args.freeze is not None:
states['train']['freeze_epochs'] = args.freeze
if args.perf is not None:
states['train']['perf_frames'] = args.perf
if args.checkpoint_interval is not None:
states['train']['checkpoint_interval'] = args.checkpoint_interval
if args.epochs is not None:
states['train']['max_epochs'] = args.epochs
model = mdl.deserialize(states).to(device.default())
# Initialize run directory
run_dir = Path(f"_nets/{dataset_name}/{model_name}")
run_dir.mkdir(parents=True, exist_ok=True)
# Initialize logging
log_file = run_dir / "train.log"
logging.basicConfig(format='%(asctime)s[%(levelname)s] %(message)s', level=logging.INFO,
filename=log_file, filemode='a' if log_file.exists() else 'w')
print_and_log(f"model: {model_name} ({model_class})")
print_and_log(f"args: {json.dumps(model.args0)}")
def log_exception(exc_type, exc_value, exc_traceback):
if not issubclass(exc_type, KeyboardInterrupt):
logging.exception(exc_value, exc_info=(exc_type, exc_value, exc_traceback))
sys.__excepthook__(exc_type, exc_value, exc_traceback)
sys.excepthook = log_exception
print_and_log(f"model: {model_name} ({model.cls})")
print_and_log(f"args:")
model.print_config()
print(model)
if __name__ == "__main__":
# 1. Initialize data loader
data_loader = DataLoader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
shuffle=True, enable_preload=True,
color=color.from_str(model.args['color']))
data_loader = get_loader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
shuffle=True, enable_preload=False, color=model.color)
# 2. Initialize model and trainer
trainer = train.get_trainer(model, run_dir=run_dir, states=states, perf_frames=args.perf,
pruning_loop=args.prune, splitting_loop=args.split)
trainer = train.get_trainer(model, run_dir, states)
# 3. Train
trainer.train(data_loader, args.epochs)
\ No newline at end of file
trainer.train(data_loader)
import importlib
import os
from pathlib import Path
from model.base import BaseModel
from . import base
from .train import train_classes, Train
# Automatically import any Python files in this directory
@@ -18,9 +19,9 @@ for file in os.listdir(package_dir):
def get_class(class_name: str) -> type:
return base.train_classes[class_name]
return train_classes[class_name]
def get_trainer(model: BaseModel, **kwargs) -> base.Train:
def get_trainer(model: BaseModel, run_dir: Path, states: dict) -> Train:
train_class = get_class(model.TrainerClass)
return train_class(model, **kwargs)
return train_class(model, run_dir, states)
import csv
import json
import logging
import sys
import time
import torch
import torch.nn.functional as nn_f
from typing import Dict
from typing import Any, Dict, Union
from pathlib import Path
import loss
from utils.constants import HUGE_FLOAT
from utils.misc import format_time
from utils import netio, math
from utils.misc import format_time, print_and_log
from utils.progress_bar import progress_bar
from utils.perf import Perf, checkpoint, enable_perf, perf, get_perf_result
from utils.perf import Perf, enable_perf, perf, get_perf_result
from utils.env import set_env
from utils.type import InputData, ReturnData
from data.loader import DataLoader
from model import serialize
from model.base import BaseModel
from model import save
train_classes = {}
@@ -34,62 +35,74 @@ class Train(object, metaclass=BaseTrainMeta):
def perf_mode(self):
return self.perf_frames > 0
def __init__(self, model: BaseModel, *,
run_dir: Path, states: dict = None, perf_frames: int = 0) -> None:
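# Read a training hyperparameter from the checkpoint states, falling back to a default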
def _arg(self, name: str, default=None):
return self.states.get("train", {}).get(name, default)
def __init__(self, model: BaseModel, run_dir: Path, states: dict) -> None:
super().__init__()
print_and_log(
f"Create trainer {__class__} with args: {json.dumps(states.get('train', {}))}")
self.model = model
self.epoch = 0
self.iters = 0
self.run_dir = run_dir
self.states = states
self.epoch = states.get("epoch", 0)
self.iters = states.get("iters", 0)
self.max_epochs = self._arg("max_epochs", 50)
self.checkpoint_interval = self._arg("checkpoint_interval", 10)
self.perf_frames = self._arg("perf_frames", 0)
self.model.trainer = self
self.model.train()
self.reset_optimizer()
if states:
if 'epoch' in states:
self.epoch = states['epoch']
if 'iters' in states:
self.iters = states['iters']
self.reset_optimizer()
if 'opti' in states:
self.optimizer.load_state_dict(states['opti'])
# For performance measurement
self.perf_frames = perf_frames
if self.perf_mode:
enable_perf()
self.env = {
"trainer": self
}
def reset_optimizer(self):
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)
def train(self, data_loader: DataLoader, max_epochs: int):
def train(self, data_loader: DataLoader):
set_env(self.env)
self.data_loader = data_loader
self.iters_per_epoch = self.perf_frames or len(data_loader)
print("Begin training...")
while self.epoch < max_epochs:
self.epoch += 1
print(f"Begin training... Max epochs: {self.max_epochs}")
while self.epoch < self.max_epochs:
self._train_epoch()
self._save_checkpoint()
print("Train finished")
def _save_checkpoint(self):
save(self.run_dir / f'checkpoint_{self.epoch}.tar', self.model, epoch=self.epoch,
iters=self.iters, opti=self.optimizer.state_dict())
(self.run_dir / '_misc').mkdir(exist_ok=True)
# Clean checkpoints
for i in range(1, self.epoch):
if i % 10 != 0:
(self.run_dir / f'checkpoint_{i}.tar').unlink(missing_ok=True)
def _show_progress(self, iters_in_epoch: int, loss: Dict[str, float] = {}):
loss_val = loss.get('val', 0)
loss_min = loss.get('min', 0)
loss_max = loss.get('max', 0)
loss_avg = loss.get('avg', 0)
if i % self.checkpoint_interval != 0:
checkpoint_path = self.run_dir / f'checkpoint_{i}.tar'
if checkpoint_path.exists():
checkpoint_path.rename(self.run_dir / f'_misc/checkpoint_{i}.tar')
# Save checkpoint
self.states.update({
**serialize(self.model),
"epoch": self.epoch,
"iters": self.iters,
"opti": self.optimizer.state_dict()
})
netio.save_checkpoint(self.states, self.run_dir, self.epoch)
def _show_progress(self, iters_in_epoch: int, avg_loss: float = 0, recent_loss: float = 0):
iters_per_epoch = self.perf_frames or len(self.data_loader)
progress_bar(iters_in_epoch, iters_per_epoch,
f"Loss: {loss_val:.2e} ({loss_min:.2e}/{loss_avg:.2e}/{loss_max:.2e})",
f"Epoch {self.epoch:<3d}",
f" {self.run_dir}")
f"Loss: {recent_loss:.2e} ({avg_loss:.2e})",
f"Epoch {self.epoch + 1:<3d}",
f" {self.run_dir.absolute()}")
def _show_perf(self):
s = "Performance Report ==>\n"
@@ -102,71 +115,100 @@ class Train(object, metaclass=BaseTrainMeta):
s += " " * (len(path_segs) - 1) + f"{path_segs[-1]}: {val:.1f}ms\n"
print(s)
@perf
def _train_iter(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
extra: Dict[str, torch.Tensor]) -> float:
out = self.model(rays_o, rays_d, extra_outputs=['energies', 'speculars'])
if 'rays_mask' in out:
extra = {key: value[out['rays_mask']] for key, value in extra.items()}
checkpoint("Forward")
self.optimizer.zero_grad()
loss_val = loss.mse_loss(out['color'], extra['color'])
if self.model.args.get('density_regularization_weight'):
loss_val += loss.cauchy_loss(out['energies'],
s=self.model.args['density_regularization_scale']) \
* self.model.args['density_regularization_weight']
if self.model.args.get('specular_regularization_weight'):
loss_val += loss.cauchy_loss(out['speculars'],
s=self.model.args['specular_regularization_scale']) \
* self.model.args['specular_regularization_weight']
checkpoint("Compute loss")
def _forward(self, data: InputData) -> ReturnData:
return self.model(data, 'color', 'energies', 'speculars')
@perf
def _train_iter(self, data: Dict[str, Union[torch.Tensor, Any]]) -> float:
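# Align ground-truth data with the rays the model actually processed, if it returned a filter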
def filtered_data(data, filter):
if filter is not None:
return data[filter]
return data
with perf("Forward"):
if isinstance(data, list):
out_colors = []
out_energies = []
out_speculars = []
gt_colors = []
for datum in data:
partial_out = self._forward(datum)
out_colors.append(partial_out['color'])
out_energies.append(partial_out['energies'].flatten())
if 'speculars' in partial_out:
out_speculars.append(partial_out['speculars'].flatten())
gt_colors.append(filtered_data(datum["color"], partial_out.get("rays_filter")))
out_colors = torch.cat(out_colors)
out_energies = torch.cat(out_energies)
out_speculars = torch.cat(out_speculars) if len(out_speculars) > 0 else None
gt_colors = torch.cat(gt_colors)
else:
out = self._forward(data)
out_colors = out['color']
out_energies = out['energies']
out_speculars = out.get('speculars')
gt_colors = filtered_data(data['color'], out.get("rays_filter"))
with perf("Compute loss"):
loss_val = loss.mse_loss(out_colors, gt_colors)
if self._arg("density_regularization_weight"):
loss_val += loss.cauchy_loss(out_energies, s=self._arg("density_regularization_scale"))\
* self._arg("density_regularization_weight")
if self._arg("specular_regularization_weight") and out_speculars is not None:
loss_val += loss.cauchy_loss(out_speculars, s=self._arg("specular_regularization_scale")) \
* self._arg("specular_regularization_weight")
#return loss_val.item() # TODO remove this line
with perf("Backward"):
self.optimizer.zero_grad(True)
loss_val.backward()
checkpoint("Backward")
with perf("Update"):
self.optimizer.step()
checkpoint("Update")
return loss_val.item()
def _train_epoch(self):
iters_in_epoch = 0
loss_min = HUGE_FLOAT
loss_max = 0
loss_avg = 0
train_epoch_node = Perf.Node("Train Epoch")
self._show_progress(iters_in_epoch, loss={'val': 0, 'min': 0, 'max': 0, 'avg': 0})
for idx, rays_o, rays_d, extra in self.data_loader:
loss_val = self._train_iter(rays_o, rays_d, extra)
recent_loss = []
tot_loss = 0
loss_min = min(loss_min, loss_val)
loss_max = max(loss_max, loss_val)
loss_avg = (loss_avg * iters_in_epoch + loss_val) / (iters_in_epoch + 1)
train_epoch_node = Perf.Node("Train Epoch")
self._show_progress(iters_in_epoch)
for data in self.data_loader:
loss_val = self._train_iter(data)
self.iters += 1
iters_in_epoch += 1
self._show_progress(iters_in_epoch, loss={
'val': loss_val,
'min': loss_min,
'max': loss_max,
'avg': loss_avg
})
recent_loss = (recent_loss + [loss_val])[-50:]
recent_avg_loss = sum(recent_loss) / len(recent_loss)
tot_loss += loss_val
avg_loss = tot_loss / iters_in_epoch
#loss_min = min(loss_min, loss_val)
#loss_max = max(loss_max, loss_val)
#loss_avg = (loss_avg * iters_in_epoch + loss_val) / (iters_in_epoch + 1)
self._show_progress(iters_in_epoch, avg_loss=avg_loss, recent_loss=recent_avg_loss)
if self.perf_mode and iters_in_epoch >= self.perf_frames:
self._show_perf()
exit()
train_epoch_node.close()
torch.cuda.synchronize()
self.epoch += 1
epoch_dur = train_epoch_node.duration() / 1000
logging.info(f"Epoch {self.epoch} spent {format_time(epoch_dur)} "
f"(Avg. {format_time(epoch_dur / self.iters_per_epoch)}/iter). "
f"Loss is {loss_min:.2e}/{loss_avg:.2e}/{loss_max:.2e}")
f"Loss is {avg_loss:.2e}")
#print(list(self.model.model(0).named_parameters())[2])
#print(list(self.model.model(1).named_parameters())[2])
def _train_epoch_debug(self): # TBR
iters_in_epoch = 0
loss_min = HUGE_FLOAT
loss_min = math.huge
loss_max = 0
loss_avg = 0
......