Commit 6294701e authored by Nianchen Deng

sync

parent 2824f796
Subproject commit 10f49b1e7df38a58fd78451eac91d7ac1a21df64
import os
import sys
import argparse
import shutil
from typing import Mapping
import torch
import torch.optim
import time
from tensorboardX import SummaryWriter
from torch import nn
@@ -49,7 +46,7 @@ print("Set CUDA:%d as current device." % torch.cuda.current_device())
from utils import netio
from utils import math
from utils import device
from utils import img
from utils import interact
@@ -243,10 +240,10 @@ def train_loop(data_loader, optimizer, perf, writer, epoch, iters):
    for i in range(1, len(out)):
        loss_value += loss_mse(out[i]['color'], gt)
    if config.depth_ref:
        disp_loss_value = loss_mse(torch.reciprocal(out[0]['depth'] + math.tiny), gt_disp)
        for i in range(1, len(out)):
            disp_loss_value += loss_mse(torch.reciprocal(
                out[i]['depth'] + math.tiny), gt_disp)
        disp_loss_value = disp_loss_value / math.pow(
            1 / dataset.depth_range[0] - 1 / dataset.depth_range[1], 2)
    else:
...
@@ -2,102 +2,109 @@
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mixin.__init__\n",
      "Base.__init__\n",
      "Child.__init__\n",
      "Mixin.fn\n",
      "Mixin.fn\n",
      "(<class '__main__.Child'>, <class '__main__.Base'>, <class '__main__.Mixin'>, <class '__main__.Obj'>, <class 'object'>)\n"
     ]
    }
   ],
   "source": [
    "class Obj:\n",
    "    def fn(self):\n",
    "        print(\"Obj.fn\")\n",
    "\n",
    "class Base(Obj):\n",
    "    def __init__(self) -> None:\n",
    "        super().__init__()\n",
    "        print(\"Base.__init__\")\n",
    "\n",
    "    def fn(self):\n",
    "        super().fn()\n",
    "        print(\"Base.fn\")\n",
    "\n",
    "    def fn1(self):\n",
    "        self.fn()\n",
    "\n",
    "class Mixin(Obj):\n",
    "    def __init__(self) -> None:\n",
    "        print(\"Mixin.__init__\")\n",
    "        self.fn = self._fn\n",
    "\n",
    "    def _fn(self):\n",
    "        print(\"Mixin.fn\")\n",
    "\n",
    "class Child(Base, Mixin):\n",
    "    def __init__(self) -> None:\n",
    "        super().__init__()\n",
    "        print(\"Child.__init__\")\n",
    "\n",
    "a = Child()\n",
    "a.fn()\n",
    "a.fn1()\n",
    "print(Child.__mro__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Base.__init__\n",
      "Child.fn: <__main__.Child object at 0x7f62583e0640>\n"
     ]
    }
   ],
   "source": [
    "class Base:\n",
    "    def __init__(self) -> None:\n",
    "        print(\"Base.__init__\")\n",
    "\n",
    "    def fn(self):\n",
    "        print(\"Base.fn\")\n",
    "\n",
    "    def fn1(self):\n",
    "        self.fn()\n",
    "\n",
    "def createChildClass(name):\n",
    "    def __init__(self):\n",
    "        super(self.__class__, self).__init__()\n",
    "\n",
    "    def fn(self):\n",
    "        print(f\"{name}.fn: {self}\")\n",
    "\n",
    "    return type(name, (Base, ), {\n",
    "        \"__init__\": __init__,\n",
    "        \"fn\": fn\n",
    "    })\n",
    "\n",
    "Child = createChildClass(\"Child\")\n",
    "\n",
    "a = Child()\n",
    "a.fn()"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "65406b00395a48e1d89cf658ae895e7869e05878f5469716b06a752a3915211c"
  },
  "kernelspec": {
   "display_name": "Python 3.8.5 64-bit ('base': conda)",
   "language": "python",
   "name": "python3"
  },
@@ -111,7 +118,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "orig_nbformat": 4
 },
...
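A gloss on the two cells above: in the first, `Mixin.__init__` assigns `self.fn = self._fn`, and instance attributes shadow class methods during attribute lookup, so both `a.fn()` and `a.fn1()` print "Mixin.fn" even though `Base` precedes `Mixin` in `Child.__mro__`. A minimal sketch of that shadowing rule (hypothetical class, not from the notebook):

class A:
    def fn(self):
        print("A.fn")

a = A()
a.fn()                            # A.fn: resolved on the class
a.fn = lambda: print("shadowed")  # the instance attribute takes precedence
a.fn()                            # shadowed
del a.fn
a.fn()                            # A.fn again

In the second cell, `super(self.__class__, self).__init__()` only works because the class built by `type()` is never subclassed; if it were, `self.__class__` would name the subclass and the call would recurse forever, which is why closing over the created class object is the usual pattern.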
@@ -2,7 +2,8 @@ import os
import argparse
import torch
import torch.nn.functional as nn_f
import cv2
import numpy as np
from pathlib import Path

parser = argparse.ArgumentParser()
@@ -10,12 +11,13 @@ parser.add_argument('-m', '--model', type=str,
                    help='The model file to load for testing')
parser.add_argument('-r', '--output-res', type=str,
                    help='Output resolution')
parser.add_argument('-o', '--output', nargs='*', type=str, default=['perf', 'color'],
                    help='Specify what to output (perf, color, depth, all)')
parser.add_argument('--output-type', type=str, default='image',
                    help='Specify the output type (image, video, debug)')
parser.add_argument('--views', type=str,
                    help='Specify the range of views to test')
parser.add_argument('-s', '--samples', type=int)
parser.add_argument('-p', '--prompt', action='store_true',
                    help='Interactive prompt mode')
parser.add_argument('--time', action='store_true',
@@ -31,18 +33,18 @@ from utils import color
from utils import interact
from utils import device
from utils import img
from utils import netio
from utils import math
from utils.perf import Perf, enable_perf, get_perf_result
from utils.progress_bar import progress_bar
from data import *

DATA_LOADER_CHUNK_SIZE = 1e8

torch.set_grad_enabled(False)

data_desc_path = get_dataset_desc_path(args.dataset)
os.chdir(data_desc_path.parent)
nets_dir = Path("_nets")
data_desc_path = data_desc_path.name
@@ -81,18 +83,20 @@ dataset = DatasetFactory.load(data_desc_path, res=args.output_res,
                              views_to_load=args.views)
print(f"Dataset loaded: {dataset.root}/{dataset.name}")

RAYS_PER_BATCH = dataset.res[0] * dataset.res[1] // 4

model_path: Path = nets_dir / args.model
model_name = model_path.parent.name
states, _ = netio.load_checkpoint(model_path)
if args.samples:
    states['args']['n_samples'] = args.samples
model = mdl.deserialize(states,
                        raymarching_early_stop_tolerance=0.01,
                        raymarching_chunk_size_or_sections=None,
                        perturb_sample=False).to(device.default()).eval()
print(f"model: {model_name} ({model._get_name()})")
print("args:", json.dumps(model.args))

run_dir = model_path.parent
output_dir = run_dir / f"output_{int(model_path.stem.split('_')[-1])}"
@@ -102,125 +106,137 @@ output_dataset_id = '%s%s' % (
)

# 1. Initialize data loader
data_loader = DataLoader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
                         shuffle=False, enable_preload=not args.time,
                         color=color.from_str(model.args['color']))

# 3. Test on dataset
print("Begin test, batch size is %d" % RAYS_PER_BATCH)

i = 0
offset = 0
chns = model.chns('color')
n = dataset.n_views
total_pixels = math.prod([n, *dataset.res])

out = {}
if args.output_flags['perf'] or args.output_flags['color']:
    out['color'] = torch.zeros(total_pixels, chns, device=device.default())
if args.output_flags['diffuse']:
    out['diffuse'] = torch.zeros(total_pixels, chns, device=device.default())
if args.output_flags['specular']:
    out['specular'] = torch.zeros(total_pixels, chns, device=device.default())
if args.output_flags['depth']:
    out['depth'] = torch.full([total_pixels, 1], math.huge, device=device.default())
gt_images = torch.empty_like(out['color']) if dataset.image_path else None

tot_time = 0
tot_iters = len(data_loader)
progress_bar(i, tot_iters, 'Inferring...')
for data in data_loader:
    if args.output_flags['perf']:
        test_perf = Perf.Node("Test")
    n_rays = data['rays_o'].size(0)
    idx = slice(offset, offset + n_rays)
    ret = model(data, *out.keys())
    if ret is not None:
        for key in out:
            if key not in ret:
                out[key] = None
            else:
                if 'rays_filter' in ret:
                    out[key][idx][ret['rays_filter']] = ret[key]
                else:
                    out[key][idx] = ret[key]
    if args.output_flags['perf']:
        test_perf.close()
        torch.cuda.synchronize()
        tot_time += test_perf.duration()
    if gt_images is not None:
        gt_images[idx] = data['color']
    i += 1
    progress_bar(i, tot_iters, 'Inferring...')
    offset += n_rays

# 4. Save results
print('Saving results...')
output_dir.mkdir(parents=True, exist_ok=True)

out = {key: value for key, value in out.items() if value is not None}
for key in out:
    out[key] = out[key].reshape([n, *dataset.res, *out[key].shape[1:]])
    if key == 'color' or key == 'diffuse' or key == 'specular':
        out[key] = out[key].permute(0, 3, 1, 2)

if args.output_flags['perf']:
    perf_errors = torch.full([n], math.nan)
    perf_ssims = torch.full([n], math.nan)
    if gt_images is not None:
        gt_images = gt_images.reshape(n, *dataset.res, chns).permute(0, 3, 1, 2)
        for i in range(n):
            perf_errors[i] = nn_f.mse_loss(gt_images[i], out['color'][i]).item()
            perf_ssims[i] = ssim(gt_images[i:i + 1], out['color'][i:i + 1]).item() * 100
    perf_mean_time = tot_time / n
    perf_mean_error = torch.mean(perf_errors).item()
    perf_name = f'perf_{output_dataset_id}_{perf_mean_time:.1f}ms_{perf_mean_error:.2e}.csv'
    # Remove old performance reports
    for file in output_dir.glob(f'perf_{output_dataset_id}*'):
        file.unlink()
    # Save new performance reports
    with (output_dir / perf_name).open('w') as fp:
        fp.write('View, PSNR, SSIM\n')
        fp.writelines([
            f'{dataset.indices[i]}, '
            f'{img.mse2psnr(perf_errors[i].item()):.2f}, {perf_ssims[i].item():.2f}\n'
            for i in range(n)
        ])
    error_images = ((gt_images - out['color'])**2).sum(1, True) / chns
    error_images = (error_images / 1e-2).clamp(0, 1) * 255
    error_images = img.torch2np(error_images)
    error_images = np.asarray(error_images, dtype=np.uint8)
    output_subdir = output_dir / f"{output_dataset_id}_error"
    output_subdir.mkdir(exist_ok=True)
    for i in range(n):
        heat_img = cv2.applyColorMap(error_images[i], cv2.COLORMAP_JET)  # note: the 3-channel heatmap from cv2 uses its own BGR channel order
        cv2.imwrite(f'{output_subdir}/{dataset.indices[i]:0>4d}.png', heat_img)

for output_type in ['color', 'diffuse', 'specular']:
    if output_type not in out:
        continue
    if args.output_type == 'video':
        output_file = output_dir / f"{output_dataset_id}_{output_type}.mp4"
        img.save_video(out[output_type], output_file, 30)
    else:
        output_subdir = output_dir / f"{output_dataset_id}_{output_type}"
        output_subdir.mkdir(exist_ok=True)
        img.save(out[output_type],
                 [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])

if 'depth' in out:
    colored_depths = img.colorize_depthmap(out['depth'][..., 0], model.args['sample_range'])
    if args.output_type == 'video':
        output_file = output_dir / f"{output_dataset_id}_depth.mp4"
        img.save_video(colored_depths, output_file, 30)
    else:
        output_subdir = output_dir / f"{output_dataset_id}_depth"
        output_subdir.mkdir(exist_ok=True)
        img.save(colored_depths, [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])
    #output_subdir = output_dir / f"{output_dataset_id}_bins"
    # output_dir.mkdir(exist_ok=True)
    #img.save(out['bins'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])

if args.time:
    s = "Performance Report ==>\n"
    res = get_perf_result()
    if res is None:
        s += "No available data.\n"
    else:
        for key, val in res.items():
            path_segs = key.split("/")
            s += " " * (len(path_segs) - 1) + f"{path_segs[-1]}: {val:.1f}ms\n"
    print(s)
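For reference, the PSNR column in the report above comes from `img.mse2psnr`, whose implementation is not part of this diff; a minimal stand-in under the assumption that images are normalized to [0, 1]:

import math

def mse2psnr(mse: float, max_val: float = 1.0) -> float:
    # PSNR = 20*log10(MAX) - 10*log10(MSE); with MAX = 1 this is -10*log10(MSE)
    return 20 * math.log10(max_val) - 10 * math.log10(max(mse, 1e-10))

The 1e-10 floor only guards against log10(0) on a perfect reconstruction.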
import bpy
import json
import os
import math
import numpy as np
from itertools import product
scene = bpy.context.scene
cam_obj = scene.camera
cam = cam_obj.data
scene.cycles.device = 'GPU'
dataset_name = 'train'
tbox = [0.6, 0.6, 0.6]
rbox = [320, 40]
dataset_desc = {
'view_file_pattern': '%s/view_%%04d.png' % dataset_name,
"gl_coord": True,
'view_res': {
'x': 512,
'y': 512
},
'cam_params': {
'fov': 40.0,
'cx': 0.5,
'cy': 0.5,
'normalized': True
},
'range': {
'min': [-tbox[0] / 2, -tbox[1] / 2, -tbox[2] / 2, -rbox[0] / 2, -rbox[1] / 2],
'max': [tbox[0] / 2, tbox[1] / 2, tbox[2] / 2, rbox[0] / 2, rbox[1] / 2]
},
'samples': [5, 5, 5, 9, 2],
#'samples': [2000],
'view_centers': [],
'view_rots': []
}
data_desc_file = f'output/{dataset_name}.json'
if not os.path.exists('output'):
os.mkdir('output')
if os.path.exists(data_desc_file):
with open(data_desc_file, 'r') as fp:
dataset_desc.update(json.load(fp))
with open(data_desc_file, 'w') as fp:
json.dump(dataset_desc, fp, indent=4)
# Output resolution
scene.render.resolution_x = dataset_desc['view_res']['x']
scene.render.resolution_y = dataset_desc['view_res']['y']
# Field of view
cam.lens_unit = 'FOV'
cam.angle = math.radians(dataset_desc['cam_params']['fov'])
cam.dof.use_dof = False
def add_sample(i, x, y, z, rx, ry, render_only=False):
cam_obj.location = [x, y, z]
cam_obj.rotation_euler = [math.radians(ry), math.radians(rx), 0]
scene.render.filepath = 'output/' + dataset_desc['view_file_pattern'] % i
bpy.ops.render.render(write_still=True)
if not render_only:
dataset_desc['view_centers'].append(list(cam_obj.location))
dataset_desc['view_rots'].append([rx, ry])
with open(data_desc_file, 'w') as fp:
json.dump(dataset_desc, fp, indent=4)
for i in range(len(dataset_desc['view_centers'])):
if not os.path.exists('output/' + dataset_desc['view_file_pattern'] % i):
add_sample(i, *dataset_desc['view_centers'][i], *dataset_desc['view_rots'][i], render_only=True)
start_view = len(dataset_desc['view_centers'])
if len(dataset_desc['samples']) == 1:
range_min = np.array(dataset_desc['range']['min'])
range_max = np.array(dataset_desc['range']['max'])
samples = (range_max - range_min) * np.random.rand(dataset_desc['samples'][0], 5) + range_min
for i in range(start_view, dataset_desc['samples'][0]):
add_sample(i, *list(samples[i]))
else:
ranges = [
np.linspace(dataset_desc['range']['min'][i],
dataset_desc['range']['max'][i],
dataset_desc['samples'][i])
for i in range(0, 3)
] + [
np.linspace(dataset_desc['range']['min'][i],
dataset_desc['range']['max'][i],
dataset_desc['samples'][i])
for i in range(3, 5)
]
i = 0
for x, y, z, rx, ry in product(*ranges):
if i >= start_view:
add_sample(i, x, y, z, rx, ry)
i += 1
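Since the grid branch above renders the Cartesian product of the per-axis linspaces, the total view count is the product of the 'samples' entries; a quick standalone check:

from functools import reduce

samples = [5, 5, 5, 9, 2]                   # x, y, z, rx, ry subdivisions
print(reduce(lambda a, b: a * b, samples))  # 2250 views in total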
import bpy
import json
import os
import math
import numpy as np
from itertools import product
scene = bpy.context.scene
cam_obj = scene.camera
cam = cam_obj.data
scene.cycles.device = 'GPU'
dataset_name = 'train'
tbox = [0.7, 0.7, 0.7]
rbox = [300, 120]
dataset_desc = {
'view_file_pattern': '%s/view_%%04d.png' % dataset_name,
"gl_coord": True,
'view_res': {
'x': 512,
'y': 512
},
'cam_params': {
'fov': 60.0,
'cx': 0.5,
'cy': 0.5,
'normalized': True
},
'range': {
'min': [-tbox[0] / 2, -tbox[1] / 2, -tbox[2] / 2, -rbox[0] / 2, -rbox[1] / 2],
'max': [tbox[0] / 2, tbox[1] / 2, tbox[2] / 2, rbox[0] / 2, rbox[1] / 2]
},
'samples': [5, 5, 5, 6, 3],
#'samples': [2000],
'view_centers': [],
'view_rots': []
}
data_desc_file = f'output/{dataset_name}.json'
if not os.path.exists('output'):
os.mkdir('output')
if os.path.exists(data_desc_file):
with open(data_desc_file, 'r') as fp:
dataset_desc.update(json.load(fp))
with open(data_desc_file, 'w') as fp:
json.dump(dataset_desc, fp, indent=4)
# Output resolution
scene.render.resolution_x = dataset_desc['view_res']['x']
scene.render.resolution_y = dataset_desc['view_res']['y']
# Field of view
cam.lens_unit = 'FOV'
cam.angle = math.radians(dataset_desc['cam_params']['fov'])
cam.dof.use_dof = False
def add_sample(i, x, y, z, rx, ry, render_only=False):
cam_obj.location = [x, y, z]
cam_obj.rotation_euler = [math.radians(ry), math.radians(rx), 0]
scene.render.filepath = 'output/' + dataset_desc['view_file_pattern'] % i
bpy.ops.render.render(write_still=True)
if not render_only:
dataset_desc['view_centers'].append(list(cam_obj.location))
dataset_desc['view_rots'].append([rx, ry])
with open(data_desc_file, 'w') as fp:
json.dump(dataset_desc, fp, indent=4)
for i in range(len(dataset_desc['view_centers'])):
if not os.path.exists('output/' + dataset_desc['view_file_pattern'] % i):
add_sample(i, *dataset_desc['view_centers'][i], *dataset_desc['view_rots'][i], render_only=True)
start_view = len(dataset_desc['view_centers'])
if len(dataset_desc['samples']) == 1:
range_min = np.array(dataset_desc['range']['min'])
range_max = np.array(dataset_desc['range']['max'])
samples = (range_max - range_min) * np.random.rand(dataset_desc['samples'][0], 5) + range_min
for i in range(start_view, dataset_desc['samples'][0]):
add_sample(i, *list(samples[i]))
else:
ranges = [
np.linspace(dataset_desc['range']['min'][i],
dataset_desc['range']['max'][i],
dataset_desc['samples'][i])
for i in range(0, 3)
] + [
np.linspace(dataset_desc['range']['min'][i],
dataset_desc['range']['max'][i],
dataset_desc['samples'][i])
for i in range(3, 5)
]
i = 0
for x, y, z, rx, ry in product(*ranges):
if i >= start_view:
add_sample(i, x, y, z, rx, ry)
i += 1
import json
import sys
import os
import argparse
import numpy as np
import torch
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
from utils import misc
parser = argparse.ArgumentParser()
parser.add_argument('dataset', type=str)
args = parser.parse_args()
data_desc_path = args.dataset
data_desc_name = os.path.splitext(os.path.basename(data_desc_path))[0]
data_dir = os.path.dirname(data_desc_path) + '/'
with open(data_desc_path, 'r') as fp:
dataset_desc = json.load(fp)
centers = np.array(dataset_desc['view_centers'])
t_max = np.max(centers, axis=0)
t_min = np.min(centers, axis=0)
# float() converts numpy scalars to Python floats so json.dump can serialize them
dataset_desc['range'] = {
    'min': [float(t_min[0]), float(t_min[1]), float(t_min[2]), 0, 0],
    'max': [float(t_max[0]), float(t_max[1]), float(t_max[2]), 0, 0]
}
with open(data_desc_path, 'w') as fp:
json.dump(dataset_desc, fp, indent=4)
\ No newline at end of file
@@ -7,7 +7,7 @@ import torch
from itertools import product, repeat
from pathlib import Path

sys.path.append(os.path.abspath(sys.path[0] + '/../../'))

parser = argparse.ArgumentParser()
parser.add_argument('-o', '--output', type=str, default='train1')
...
import sys
import os
import argparse
import json
import numpy as np
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
from utils import misc
from utils.colmap_read_model import read_model
parser = argparse.ArgumentParser()
parser.add_argument('dataset', type=str)
args = parser.parse_args()
data_dir = args.dataset
os.makedirs(data_dir, exist_ok=True)
out_desc_path = os.path.join(data_dir, "train.json")
cameras, images, points3D = read_model(os.path.join(data_dir, 'sparse/0'), '.bin')
print("Model loaded.")
print("num_cameras:", len(cameras))
print("num_images:", len(images))
print("num_points3D:", len(points3D))
cam = cameras[list(cameras.keys())[0]]
views = np.array([int(images[img_id].name[5:9]) for img_id in images])
view_centers = np.array([images[img_id].tvec for img_id in images])
view_rots = []
for img_id in images:
im = images[img_id]
R = im.qvec2rotmat()
view_rots.append(R.reshape([9]).tolist())
view_rots = np.array(view_rots)
indices = np.argsort(views)
views = views[indices]
view_centers = view_centers[indices]
view_rots = view_rots[indices]
pts = np.array([points3D[pt_id].xyz for pt_id in points3D])
zvals = np.sqrt(np.sum(pts * pts, 1))
dataset_desc = {
'view_file_pattern': f"images/image%04d.jpg",
'gl_coord': True,
'view_res': {
'x': cam.width,
'y': cam.height
},
'cam_params': {
'fx': cam.params[0],
'fy': cam.params[0],
'cx': cam.params[1],
'cy': cam.params[2]
},
'range': {
'min': np.min(view_centers, 0).tolist() + [0, 0],
'max': np.max(view_centers, 0).tolist() + [0, 0]
},
    'depth_range': {
        # float() keeps json.dump from choking on numpy scalars
        'min': float(np.min(zvals)),
        'max': float(np.max(zvals))
    },
'samples': [len(view_centers)],
'view_centers': view_centers.tolist(),
'view_rots': view_rots.tolist(),
'views': views.tolist()
}
with open(out_desc_path, 'w') as fp:
json.dump(dataset_desc, fp, indent=4)
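One convention worth verifying here: COLMAP's qvec/tvec describe the world-to-camera transform, so a camera's center in world space is -Rᵀ·t rather than tvec itself, which this script stores directly. If exported centers look mirrored or offset, the usual conversion is the following sketch (a hypothetical helper over the same `images` entries):

import numpy as np

def camera_center(im) -> np.ndarray:
    # COLMAP stores world-to-camera rotation R and translation t;
    # the camera position in world coordinates is C = -R^T @ t
    R = im.qvec2rotmat()
    return -R.T @ im.tvec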
import json
import numpy as np
from itertools import product
tbox = [0.1, 0.1, 0.1]
dataset_desc = {
"gl_coord": True,
'view_res': {
'x': 1440,
'y': 1600
},
'cam_params': {
'fov': 110.0,
'cx': 0.5,
'cy': 0.5,
'normalized': True
},
'range': {
'min': [-tbox[0] / 2, -tbox[1] / 2, -tbox[2] / 2, 0, 0],
'max': [tbox[0] / 2, tbox[1] / 2, tbox[2] / 2, 0, 0]
},
'samples': [3, 3, 3, 1, 1],
'view_centers': [],
'view_rots': []
}
panorama_pose_content = []
data_desc_file = 'for_panorama.json'
panorama_pose_file = "for_lipis.json"
ranges = [
np.linspace(dataset_desc['range']['min'][i],
dataset_desc['range']['max'][i],
dataset_desc['samples'][i])
for i in range(0, 3)
]
for x, y, z in product(*ranges):
dataset_desc['view_centers'].append([x, y, z])
dataset_desc['view_rots'].append([0, 0])
panorama_pose_content.append({
"rotation_euler": [0.0, 0.0, 0.0],
"location": [x, y, z]
})
with open(data_desc_file, 'w') as fp:
json.dump(dataset_desc, fp, indent=4)
with open(panorama_pose_file, 'w') as fp:
json.dump(panorama_pose_content, fp, indent=4)
import json
import sys
import os
import argparse
import numpy as np
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
from utils import seqs
from utils import math
parser = argparse.ArgumentParser()
parser.add_argument('-r', '--rot-range', nargs='+', type=int)
parser.add_argument('-t', '--trans-range', nargs='+', type=float)
parser.add_argument('--fov', type=float)
parser.add_argument('--res', type=str)
parser.add_argument('--gl', action='store_true')
parser.add_argument('-s', '--seq', type=str, required=True)
parser.add_argument('-n', '--views', type=int, required=True)
parser.add_argument('-o', '--out-desc', type=str)
parser.add_argument('--ref', type=str)
parser.add_argument('dataset', type=str)
args = parser.parse_args()
data_dir = args.dataset
os.makedirs(data_dir, exist_ok=True)
out_desc_path = os.path.join(data_dir, (args.out_desc if args.out_desc else f"{args.seq}.json"))
if args.ref:
with open(os.path.join(data_dir, args.ref), 'r') as fp:
ref_desc = json.load(fp)
else:
if not args.trans_range or not args.rot_range or not args.fov or not args.res:
print('-r, -t, --fov, --res options are required if --ref is not specified')
exit(-1)
ref_desc = None
if args.trans_range:
trans_range = np.array(list(args.trans_range) * 3 if len(args.trans_range) == 1
else args.trans_range)
else:
trans_range = np.array(ref_desc['range']['max'][0:3]) - \
np.array(ref_desc['range']['min'][0:3])
if args.rot_range:
rot_range = np.array(list(args.rot_range) * 2 if len(args.rot_range) == 1
else args.rot_range)
else:
rot_range = np.array(ref_desc['range']['max'][3:5]) - \
np.array(ref_desc['range']['min'][3:5])
filter_range = np.concatenate([trans_range, rot_range])
if args.fov:
cam_params = {
'fov': args.fov,
'cx': 0.5,
'cy': 0.5,
'normalized': True
}
else:
cam_params = ref_desc['cam_params']
if args.res:
res = tuple(int(s) for s in args.res.split('x'))
res = {'x': res[0], 'y': res[1]}
else:
res = ref_desc['view_res']
if args.seq == 'helix':
centers, rots = seqs.helix(trans_range, 4, args.views)
elif args.seq == 'scan_around':
centers, rots = seqs.scan_around(trans_range, 1, args.views)
elif args.seq == 'look_around':
    centers, rots = seqs.look_around(trans_range, args.views)
else:
    # Without this guard, an unknown sequence name raises a NameError on 'centers' below
    raise ValueError(f"Unknown sequence type: {args.seq}")
rots *= 180 / math.pi
gl = args.gl or ref_desc and ref_desc.get('gl_coord')
if gl:
centers[:, 2] *= -1
rots[:, 0] *= -1
dataset_desc = {
'gl_coord': gl,
'view_res': res,
'cam_params': cam_params,
'range': {
'min': (-0.5 * filter_range).tolist(),
'max': (0.5 * filter_range).tolist()
},
'samples': [args.views],
'view_centers': centers.tolist(),
'view_rots': rots.tolist()
}
with open(out_desc_path, 'w') as fp:
json.dump(dataset_desc, fp, indent=4)
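As a worked example of the range block written above (hypothetical argument values): -t 0.6 -r 320 40 expands to a symmetric range of ±0.3 per translation axis and ±160°/±20° for the two rotation axes:

import numpy as np

trans_range = np.array([0.6] * 3)        # -t with one value is repeated for x, y, z
rot_range = np.array([320, 40])          # -r 320 40
filter_range = np.concatenate([trans_range, rot_range])
print((-0.5 * filter_range).tolist())    # [-0.3, -0.3, -0.3, -160.0, -20.0]
print((0.5 * filter_range).tolist())     # [0.3, 0.3, 0.3, 160.0, 20.0]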
import json
import sys
import os
import argparse
import numpy as np
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
from utils import misc
parser = argparse.ArgumentParser()
parser.add_argument('-r', '--rot-range', nargs='+', type=int)
parser.add_argument('-t', '--trans-range', nargs='+', type=float)
parser.add_argument('-k', '--trainset-ratio', type=float, default=0.7)
parser.add_argument('dataset', type=str)
args = parser.parse_args()
data_desc_path = args.dataset
data_desc_name = os.path.splitext(os.path.basename(data_desc_path))[0]
data_dir = os.path.dirname(data_desc_path) + '/'
with open(data_desc_path, 'r') as fp:
dataset_desc = json.load(fp)
if args.trans_range:
trans_range = np.array(args.trans_range)
else:
trans_range = np.array(dataset_desc['range']['max'][0:3]) - \
np.array(dataset_desc['range']['min'][0:3])
if args.rot_range:
rot_range = np.array(args.rot_range)
else:
rot_range = np.array(dataset_desc['range']['max'][3:5]) - \
np.array(dataset_desc['range']['min'][3:5])
filter_range = np.concatenate([trans_range, rot_range])
out_data_dir = data_dir + 'r%dx%d_t%.1fx%.1fx%.1f/' % (
int(rot_range[0]), int(rot_range[1]),
trans_range[0], trans_range[1], trans_range[2]
)
dataset_version = 0
while True:
out_trainset_name = f'train_{dataset_version}'
out_testset_name = f'test_{dataset_version}'
if not os.path.exists(out_data_dir + out_trainset_name):
break
dataset_version += 1
def in_range(val, range): return val >= -range / 2 and val <= range / 2
views = []
for i in range(len(dataset_desc['view_centers'])):
if in_range(dataset_desc['view_rots'][i][0], rot_range[0]) and \
in_range(dataset_desc['view_rots'][i][1], rot_range[1]) and \
in_range(dataset_desc['view_centers'][i][0], trans_range[0]) and \
in_range(dataset_desc['view_centers'][i][1], trans_range[1]) and \
in_range(dataset_desc['view_centers'][i][2], trans_range[2]):
views.append(i)
if len(views) < 100:
print(f'Number of views in range is too small ({len(views)})')
exit()
views = np.random.permutation(views)
n_train_views = int(len(views) * args.trainset_ratio)
train_views = np.sort(views[:n_train_views])
test_views = np.sort(views[n_train_views:])
print('Train set views: ', len(train_views))
print('Test set views: ', len(test_views))
def create_subset(views, out_desc_name):
views = views.tolist()
subset_desc = dataset_desc.copy()
subset_desc['view_file_pattern'] = \
f"{out_desc_name}/{dataset_desc['view_file_pattern'].split('/')[-1]}"
    # .tolist() yields plain Python floats, which json.dump can serialize
    subset_desc['range'] = {
        'min': (-filter_range / 2).tolist(),
        'max': (filter_range / 2).tolist()
    }
subset_desc['samples'] = [int(len(views))]
subset_desc['views'] = views
subset_desc['view_centers'] = np.array(dataset_desc['view_centers'])[views].tolist()
subset_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
with open(os.path.join(out_data_dir, f'{out_desc_name}.json'), 'w') as fp:
json.dump(subset_desc, fp, indent=4)
os.makedirs(os.path.join(out_data_dir, out_desc_name), exist_ok=True)
for i in range(len(views)):
os.symlink(os.path.join('../../', dataset_desc['view_file_pattern'] % views[i]),
os.path.join(out_data_dir, subset_desc['view_file_pattern'] % views[i]))
os.makedirs(out_data_dir, exist_ok=True)
create_subset(train_views, out_trainset_name)
create_subset(test_views, out_testset_name)
@@ -4,10 +4,18 @@ import os
import json
import argparse
from typing import Mapping
from data import get_dataset_desc_path, get_data_path

sys.path.append(os.path.abspath(sys.path[0] + '/../../'))

parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', type=str, nargs='+')
parser.add_argument('output', type=str)
args = parser.parse_args()

input = [get_dataset_desc_path(path) for path in args.input]
output = get_dataset_desc_path(args.output)


def copy_images(src_path, dst_path, n, offset=0):
@@ -16,38 +24,24 @@ def copy_images(src_path, dst_path, n, offset=0):
        copy(src_path % i, dst_path % (i + offset))


with open(input[0], 'r') as fp:
    dataset_desc: Mapping = json.load(fp)

n_views = 0
for i in range(len(input)):
    if i == 0:
        input_desc = dataset_desc
    else:
        with open(input[i], 'r') as fp:
            input_desc: Mapping = json.load(fp)
        dataset_desc['view_centers'] += input_desc['view_centers']
        dataset_desc['view_rots'] += input_desc['view_rots']
    copy_images(get_data_path(input[i], input_desc['view_file_pattern']),
                get_data_path(output, dataset_desc['view_file_pattern']),
                len(input_desc['view_centers']), n_views)
    n_views += len(input_desc['view_centers'])

dataset_desc['samples'] = [n_views]

with open(output, 'w') as fp:
    json.dump(dataset_desc, fp, indent=4)
import json
import sys
import os
import argparse
from pathlib import Path
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
from data import get_dataset_desc_path
parser = argparse.ArgumentParser()
parser.add_argument('-o', '--output', type=str, nargs="+", required=True)
parser.add_argument("-v", "--views", type=int, nargs="+", required=True)
parser.add_argument('dataset', type=str)
args = parser.parse_args()
input = get_dataset_desc_path(args.dataset)
outputs = [
get_dataset_desc_path(input.with_name(f"{input.stem}_{appendix}"))
for appendix in args.output
]
with open(input, 'r') as fp:
input_desc: dict = json.load(fp)
n_views = len(input_desc['view_centers'])
assert(len(args.views) == len(outputs))
sum_views = sum(args.views)
for i in range(len(args.views)):
if args.views[i] == -1:
args.views[i] = n_views - sum_views - 1
sum_views = n_views
break
assert(sum_views <= n_views)
for i in range(len(args.views)):
assert(args.views[i] > 0)
offset = 0
for i in range(len(outputs)):
n = args.views[i]
end = offset + n
output_desc = input_desc.copy()
    output_desc['samples'] = [args.views[i]]  # keep 'samples' a list, consistent with the other descriptors
if 'views' in output_desc:
output_desc['views'] = output_desc['views'][offset:end]
else:
output_desc['views'] = list(range(offset, end))
output_desc['view_centers'] = output_desc['view_centers'][offset:end]
if 'view_rots' in output_desc:
output_desc['view_rots'] = output_desc['view_rots'][offset:end]
with open(outputs[i], 'w') as fp:
json.dump(output_desc, fp, indent=4)
# Create symbol links of images
out_dir = outputs[i].with_suffix('')
out_dir.mkdir(exist_ok=True)
for k in range(n):
os.symlink(Path("..") / input.stem / (output_desc['view_file_pattern'] % output_desc['views'][k]),
out_dir / (input_desc['view_file_pattern'] % output_desc['views'][k]))
offset += args.views[i]
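The -v option's -1 sentinel absorbs all remaining views; the extra -1 in the formula compensates for the -1 that is already counted in the sum. A standalone trace of the logic above (hypothetical numbers):

n_views = 100
views = [70, -1]
sum_views = sum(views)                       # 69: the -1 is part of the sum
for i, v in enumerate(views):
    if v == -1:
        views[i] = n_views - sum_views - 1   # 100 - 69 - 1 = 30
        sum_views = n_views
        break
print(views, sum_views)                      # [70, 30] 100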
from pathlib import Path
import sys
import argparse
import torch
import torchvision.transforms.functional as trans_F

sys.path.append(str(Path(sys.path[0]).parent.absolute()))
from utils import img
from utils import math

parser = argparse.ArgumentParser()
parser.add_argument('-o', '--output', type=str)
...
import cv2
import numpy as np
import os
import sys
from utils import math


def genGaussiankernel(width, sigma):
@@ -87,7 +88,7 @@ def foveat_img(im, fixs):
    # B
    Bs = []
    for i in range(1, prNum):
        Bs.append((0.5 - Ts[i]) / (Ts[i-1] - Ts[i] + math.tiny))
    # M
    Ms = np.zeros((prNum, R.shape[0], R.shape[1]))
...
import argparse
import json
import logging
import os
import sys
from pathlib import Path
from typing import List, Union

import model as mdl
import train
from utils import device
from utils import netio
from data import *
from utils.misc import print_and_log

RAYS_PER_BATCH = 2 ** 12
DATA_LOADER_CHUNK_SIZE = 1e8
root_dir = Path(__file__).absolute().parent

parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str,
                    help='Net config files')
parser.add_argument('-e', '--epochs', type=int,
                    help='Max epochs for train')
parser.add_argument('--perf', type=int,
                    help='Performance measurement frames (0 for disabling performance measurement)')
parser.add_argument('--prune', type=int, nargs='+',
                    help='Prune voxels on every # epochs')
parser.add_argument('--split', type=int, nargs='+',
                    help='Split voxels on every # epochs')
parser.add_argument('--freeze', type=int, nargs='+',
                    help='freeze levels on epochs')
parser.add_argument('--checkpoint-interval', type=int)
parser.add_argument('--views', type=str,
                    help='Specify the range of views to train')
parser.add_argument('path', type=str,
                    help='Dataset description file')
args = parser.parse_args()

views_to_load = range(*[int(val) for val in args.views.split('-')]) if args.views else None
argpath = Path(args.path)
# argpath: May be model path or data path
# 1) model path: continue training on the specified model
# 2) data path: train a new model using specified dataset


def load_dataset(data_path: Path):
    print(f"Loading dataset {data_path}")
    try:
        dataset = DatasetFactory.load(data_path, views_to_load=views_to_load)
        print(f"Dataset loaded: {dataset.root}/{dataset.name}")
        os.chdir(dataset.root)
        return dataset, dataset.name
    except FileNotFoundError:
        return load_multiscale_dataset(data_path)


def load_multiscale_dataset(data_path: Path):
    if not data_path.is_dir():
        raise ValueError(f"Path {data_path} is not a directory")
    dataset: List[Union[PanoDataset, ViewDataset]] = []
    for sub_data_desc_path in data_path.glob("*.json"):
        sub_dataset = DatasetFactory.load(sub_data_desc_path, views_to_load=views_to_load)
        print(f"Sub-dataset loaded: {sub_dataset.root}/{sub_dataset.name}")
        dataset.append(sub_dataset)
    if len(dataset) == 0:
        raise ValueError(f"Path {data_path} does not contain sub-datasets")
    os.chdir(data_path.parent)
    return dataset, data_path.name


try:
    states, checkpoint_path = netio.load_checkpoint(argpath)
    # Infer dataset path from model path
    # The model path follows such rule: <dataset_dir>/_nets/<dataset_name>/<model_name>/checkpoint_*.tar
    model_name = checkpoint_path.parts[-2]
    dataset, dataset_name = load_dataset(
        Path(*checkpoint_path.parts[:-4]) / checkpoint_path.parts[-3])
except Exception:
    model_name = args.config
    dataset, dataset_name = load_dataset(argpath)

    # Load state 0 from specified configuration
    with Path(f'{root_dir}/configs/{args.config}.json').open() as fp:
        states = json.load(fp)
    states['args']['bbox'] = dataset[0].bbox if isinstance(dataset, list) else dataset.bbox
    states['args']['depth_range'] = dataset[0].depth_range if isinstance(dataset, list)\
        else dataset.depth_range

if 'train' not in states:
    states['train'] = {}
if args.prune is not None:
    states['train']['prune_epochs'] = args.prune
if args.split is not None:
    states['train']['split_epochs'] = args.split
if args.freeze is not None:
    states['train']['freeze_epochs'] = args.freeze
if args.perf is not None:
    states['train']['perf_frames'] = args.perf
if args.checkpoint_interval is not None:
    states['train']['checkpoint_interval'] = args.checkpoint_interval
if args.epochs is not None:
    states['train']['max_epochs'] = args.epochs

model = mdl.deserialize(states).to(device.default())

# Initialize run directory
run_dir = Path(f"_nets/{dataset_name}/{model_name}")
run_dir.mkdir(parents=True, exist_ok=True)

# Initialize logging
log_file = run_dir / "train.log"
logging.basicConfig(format='%(asctime)s[%(levelname)s] %(message)s', level=logging.INFO,
                    filename=log_file, filemode='a' if log_file.exists() else 'w')


def log_exception(exc_type, exc_value, exc_traceback):
    if not issubclass(exc_type, KeyboardInterrupt):
        logging.exception(exc_value, exc_info=(exc_type, exc_value, exc_traceback))
    sys.__excepthook__(exc_type, exc_value, exc_traceback)


sys.excepthook = log_exception

print_and_log(f"model: {model_name} ({model.cls})")
print_and_log("args:")
model.print_config()
print(model)

if __name__ == "__main__":
    # 1. Initialize data loader
    data_loader = get_loader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
                             shuffle=True, enable_preload=False, color=model.color)

    # 2. Initialize model and trainer
    trainer = train.get_trainer(model, run_dir, states)

    # 3. Train
    trainer.train(data_loader)
\ No newline at end of file
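The path arithmetic in the try-branch follows the stated rule <dataset_dir>/_nets/<dataset_name>/<model_name>/checkpoint_*.tar; a worked example with a hypothetical path:

from pathlib import Path

ckpt = Path("data/scene1/_nets/train_0/my_model/checkpoint_50.tar")  # hypothetical
model_name = ckpt.parts[-2]                             # 'my_model'
dataset_path = Path(*ckpt.parts[:-4]) / ckpt.parts[-3]  # data/scene1/train_0
print(model_name, dataset_path)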
import importlib
import os
from pathlib import Path

from model.base import BaseModel
from .train import train_classes, Train

# Automatically import any python files in this directory
@@ -18,9 +19,9 @@ for file in os.listdir(package_dir):


def get_class(class_name: str) -> type:
    return train_classes[class_name]


def get_trainer(model: BaseModel, run_dir: Path, states: dict) -> Train:
    train_class = get_class(model.TrainerClass)
    return train_class(model, run_dir, states)
import csv
import json
import logging
import torch
import torch.nn.functional as nn_f
from typing import Any, Dict, Union
from pathlib import Path

import loss
from utils import netio, math
from utils.misc import format_time, print_and_log
from utils.progress_bar import progress_bar
from utils.perf import Perf, enable_perf, perf, get_perf_result
from utils.env import set_env
from utils.type import InputData, ReturnData
from data.loader import DataLoader
from model import serialize
from model.base import BaseModel

train_classes = {}
@@ -34,62 +35,74 @@ class Train(object, metaclass=BaseTrainMeta):

    def perf_mode(self):
        return self.perf_frames > 0

    def _arg(self, name: str, default=None):
        return self.states.get("train", {}).get(name, default)

    def __init__(self, model: BaseModel, run_dir: Path, states: dict) -> None:
        super().__init__()
        print_and_log(
            f"Create trainer {__class__} with args: {json.dumps(states.get('train', {}))}")
        self.model = model
        self.run_dir = run_dir
        self.states = states
        self.epoch = states.get("epoch", 0)
        self.iters = states.get("iters", 0)
        self.max_epochs = self._arg("max_epochs", 50)
        self.checkpoint_interval = self._arg("checkpoint_interval", 10)
        self.perf_frames = self._arg("perf_frames", 0)
        self.model.trainer = self
        self.model.train()

        self.reset_optimizer()
        if 'opti' in states:
            self.optimizer.load_state_dict(states['opti'])

        # For performance measurement
        if self.perf_mode:
            enable_perf()

        self.env = {
            "trainer": self
        }

    def reset_optimizer(self):
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)

    def train(self, data_loader: DataLoader):
        set_env(self.env)
        self.data_loader = data_loader
        self.iters_per_epoch = self.perf_frames or len(data_loader)
        print(f"Begin training... Max epochs: {self.max_epochs}")
        while self.epoch < self.max_epochs:
            self._train_epoch()
            self._save_checkpoint()
        print("Train finished")

    def _save_checkpoint(self):
        (self.run_dir / '_misc').mkdir(exist_ok=True)
        # Clean checkpoints
        for i in range(1, self.epoch):
            if i % self.checkpoint_interval != 0:
                checkpoint_path = self.run_dir / f'checkpoint_{i}.tar'
                if checkpoint_path.exists():
                    checkpoint_path.rename(self.run_dir / f'_misc/checkpoint_{i}.tar')

        # Save checkpoint
        self.states.update({
            **serialize(self.model),
            "epoch": self.epoch,
            "iters": self.iters,
            "opti": self.optimizer.state_dict()
        })
        netio.save_checkpoint(self.states, self.run_dir, self.epoch)

    def _show_progress(self, iters_in_epoch: int, avg_loss: float = 0, recent_loss: float = 0):
        iters_per_epoch = self.perf_frames or len(self.data_loader)
        progress_bar(iters_in_epoch, iters_per_epoch,
                     f"Loss: {recent_loss:.2e} ({avg_loss:.2e})",
                     f"Epoch {self.epoch + 1:<3d}",
                     f" {self.run_dir.absolute()}")

    def _show_perf(self):
        s = "Performance Report ==>\n"
@@ -102,71 +115,100 @@ class Train(object, metaclass=BaseTrainMeta):
                s += " " * (len(path_segs) - 1) + f"{path_segs[-1]}: {val:.1f}ms\n"
        print(s)

    def _forward(self, data: InputData) -> ReturnData:
        return self.model(data, 'color', 'energies', 'speculars')

    @perf
    def _train_iter(self, data: Dict[str, Union[torch.Tensor, Any]]) -> float:
        def filtered_data(data, filter):
            if filter is not None:
                return data[filter]
            return data

        with perf("Forward"):
            if isinstance(data, list):
                out_colors = []
                out_energies = []
                out_speculars = []
                gt_colors = []
                for datum in data:
                    partial_out = self._forward(datum)
                    out_colors.append(partial_out['color'])
                    out_energies.append(partial_out['energies'].flatten())
                    if 'speculars' in partial_out:
                        out_speculars.append(partial_out['speculars'].flatten())
                    gt_colors.append(filtered_data(datum["color"], partial_out.get("rays_filter")))
                out_colors = torch.cat(out_colors)
                out_energies = torch.cat(out_energies)
                out_speculars = torch.cat(out_speculars) if len(out_speculars) > 0 else None
                gt_colors = torch.cat(gt_colors)
            else:
                out = self._forward(data)
                out_colors = out['color']
                out_energies = out['energies']
                out_speculars = out.get('speculars')
                gt_colors = filtered_data(data['color'], out.get("rays_filter"))

        with perf("Compute loss"):
            loss_val = loss.mse_loss(out_colors, gt_colors)
            if self._arg("density_regularization_weight"):
                loss_val += loss.cauchy_loss(out_energies, s=self._arg("density_regularization_scale"))\
                    * self._arg("density_regularization_weight")
            if self._arg("specular_regularization_weight") and out_speculars is not None:
                loss_val += loss.cauchy_loss(out_speculars, s=self._arg("specular_regularization_scale")) \
                    * self._arg("specular_regularization_weight")

        #return loss_val.item() # TODO remove this line

        with perf("Backward"):
            self.optimizer.zero_grad(True)
            loss_val.backward()

        with perf("Update"):
            self.optimizer.step()

        return loss_val.item()

    def _train_epoch(self):
        iters_in_epoch = 0
        recent_loss = []
        tot_loss = 0

        train_epoch_node = Perf.Node("Train Epoch")

        self._show_progress(iters_in_epoch)
        for data in self.data_loader:
            loss_val = self._train_iter(data)
            self.iters += 1
            iters_in_epoch += 1

            recent_loss = (recent_loss + [loss_val])[-50:]
            recent_avg_loss = sum(recent_loss) / len(recent_loss)
            tot_loss += loss_val
            avg_loss = tot_loss / iters_in_epoch
            self._show_progress(iters_in_epoch, avg_loss=avg_loss, recent_loss=recent_avg_loss)

            if self.perf_mode and iters_in_epoch >= self.perf_frames:
                self._show_perf()
                exit()

        train_epoch_node.close()
        torch.cuda.synchronize()
        self.epoch += 1
        epoch_dur = train_epoch_node.duration() / 1000
        logging.info(f"Epoch {self.epoch} spent {format_time(epoch_dur)} "
                     f"(Avg. {format_time(epoch_dur / self.iters_per_epoch)}/iter). "
                     f"Loss is {avg_loss:.2e}")

    def _train_epoch_debug(self):  # TBR
        iters_in_epoch = 0
        loss_min = math.huge
        loss_max = 0
        loss_avg = 0
...
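The smoothed loss in `_train_epoch` keeps the last 50 values by re-slicing a list every iteration; a `collections.deque` with `maxlen` is an equivalent constant-time alternative (sketch):

from collections import deque

recent = deque(maxlen=50)                # the oldest value is dropped automatically
for loss_val in [0.9, 0.5, 0.3, 0.2]:    # stand-in for per-iteration losses
    recent.append(loss_val)
    print(sum(recent) / len(recent))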