Commit 1bc644a1 authored by Nianchen Deng

sync

parent 6294701e
@@ -37,20 +37,20 @@ for i in range(3):
     for j in range(2):
         out_desc_name = f'part{idx:d}'
         out_desc = dataset_desc.copy()
-        out_desc['view_file_pattern'] = f'{out_desc_name}/view_%04d.png'
+        out_desc['color_file'] = f'{out_desc_name}/view_%04d.png'
         n_x = out_desc['samples'][3] // 3
         n_y = out_desc['samples'][4] // 2
         views = indices[..., i * n_x:(i + 1) * n_x, j * n_y:(j + 1) * n_y].flatten().tolist()
         out_desc['samples'] = [len(views)]
         out_desc['views'] = views
-        out_desc['view_centers'] = np.array(dataset_desc['view_centers'])[views].tolist()
-        out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
+        out_desc['centers'] = np.array(dataset_desc['centers'])[views].tolist()
+        out_desc['rots'] = np.array(dataset_desc['rots'])[views].tolist()
         with open(os.path.join(data_dir, f'{out_desc_name}.json'), 'w') as fp:
             json.dump(out_desc, fp, indent=4)
         os.makedirs(os.path.join(data_dir, out_desc_name), exist_ok=True)
         for k in range(len(views)):
-            os.symlink(os.path.join('..', dataset_desc['view_file_pattern'] % views[k]),
-                       os.path.join(data_dir, out_desc['view_file_pattern'] % views[k]))
+            os.symlink(os.path.join('..', dataset_desc['color_file'] % views[k]),
+                       os.path.join(data_dir, out_desc['color_file'] % views[k]))
         idx += 1
 '''
@@ -60,24 +60,24 @@ for xi in range(0, 4, 2):
         for zi in range(0, 4, 2):
             out_desc_name = f'part{idx:d}'
             out_desc = dataset_desc.copy()
-            out_desc['view_file_pattern'] = f'{out_desc_name}/view_%04d.png'
+            out_desc['color_file'] = f'{out_desc_name}/view_%04d.png'
             views = indices[xi:xi + 2, yi:yi + 2, zi:zi + 2].flatten().tolist()
             out_desc['samples'] = [len(views)]
             out_desc['views'] = views
-            out_desc['view_centers'] = np.array(dataset_desc['view_centers'])[views].tolist()
-            out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
+            out_desc['centers'] = np.array(dataset_desc['centers'])[views].tolist()
+            out_desc['rots'] = np.array(dataset_desc['rots'])[views].tolist()
             with open(os.path.join(data_dir, f'{out_desc_name}.json'), 'w') as fp:
                 json.dump(out_desc, fp, indent=4)
             os.makedirs(os.path.join(data_dir, out_desc_name), exist_ok=True)
             for k in range(len(views)):
-                os.symlink(os.path.join('..', dataset_desc['view_file_pattern'] % views[k]),
-                           os.path.join(data_dir, out_desc['view_file_pattern'] % views[k]))
+                os.symlink(os.path.join('..', dataset_desc['color_file'] % views[k]),
+                           os.path.join(data_dir, out_desc['color_file'] % views[k]))
             idx += 1
 '''
 def extract_by_grid(*grid_indices):
-    indices = torch.arange(len(dataset_desc['view_centers'])).view(dataset_desc['samples'])
+    indices = torch.arange(len(dataset_desc['centers'])).view(dataset_desc['samples'])
     views = []
     for idx in product(*grid_indices):
         views += indices[idx].flatten().tolist()
@@ -86,11 +86,11 @@ def extract_by_grid(*grid_indices):
 def extract_by_trans(max_trans, max_views):
     if max_trans is not None:
-        centers = np.array(dataset_desc['view_centers'])
+        centers = np.array(dataset_desc['centers'])
         trans = np.linalg.norm(centers, axis=-1)
         indices = np.nonzero(trans <= max_trans)[0]
     else:
-        indices = np.arange(len(dataset_desc['view_centers']))
+        indices = np.arange(len(dataset_desc['centers']))
     if max_views is not None:
         indices = np.sort(indices[np.random.permutation(indices.shape[0])[:max_views]])
     return indices.tolist()
@@ -101,18 +101,18 @@ if args.grids:
 else:
     views = extract_by_trans(args.trans, args.views)
-image_path = dataset_desc['view_file_pattern']
+image_path = dataset_desc['color_file']
 if "/" not in image_path:
     image_path = in_name + "/" + image_path
 # Save new dataset
 out_desc = dataset_desc.copy()
-out_desc['view_file_pattern'] = image_path.split('/')[-1]
+out_desc['color_file'] = image_path.split('/')[-1]
 out_desc['samples'] = [len(views)]
 out_desc['views'] = views
-out_desc['view_centers'] = np.array(dataset_desc['view_centers'])[views].tolist()
-if 'view_rots' in dataset_desc:
-    out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
+out_desc['centers'] = np.array(dataset_desc['centers'])[views].tolist()
+if 'rots' in dataset_desc:
+    out_desc['rots'] = np.array(dataset_desc['rots'])[views].tolist()
 # Write new data desc
 with open(out_desc_path, 'w') as fp:
@@ -123,7 +123,7 @@ out_dir.mkdir()
 for k in range(len(views)):
     if out_dir.parent.absolute() == root_dir.absolute():
         os.symlink(Path("..") / (image_path % views[k]),
-                   out_dir / (out_desc['view_file_pattern'] % views[k]))
+                   out_dir / (out_desc['color_file'] % views[k]))
     else:
         os.symlink(root_dir.absolute() / (image_path % views[k]),
-                   out_dir / (out_desc['view_file_pattern'] % views[k]))
+                   out_dir / (out_desc['color_file'] % views[k]))
import sys
import os
import argparse
import numpy as np
import cv2
from tqdm import tqdm
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--start', type=int)
parser.add_argument('-t', '--duration', type=int)
parser.add_argument('--fps', type=str, required=True)
parser.add_argument('datadir', type=str)
args = parser.parse_args()
os.chdir(args.datadir)
rawK = np.array([
[1369.757446, 0., 1838.643555, 0., 1369.757446, 1524.068604, 0., 0., 1.],
[1367.517944, 0., 1840.157837, 0., 1367.517944, 1536.036133, 0., 0., 1.],
[1369.830322, 0., 1827.990723, 0., 1369.830322, 1514.463135, 0., 0., 1.],
[1368.966187, 0., 1829.976196, 0., 1368.966187, 1512.734375, 0., 0., 1.],
[1373.654297, 0., 1838.130859, 0., 1373.654297, 1534.985840, 0., 0., 1.],
[1365.853027, 0., 1835.100830, 0., 1365.853027, 1533.032959, 0., 0., 1.]
]).reshape(-1, 3, 3)
D = np.array([[-0.044752], [-0.006285], [0.000000], [0.000000]])
mean_focal = np.mean(rawK[:, 0, 0])
K = None # 1369.2632038333334, 1500.0, 1900.0
for i in range(6):
    # Extract frames from video
    os.makedirs(f"raw_images/{i + 1}", exist_ok=True)
    extra_args = []
    if args.start is not None:
        extra_args.append(f"-ss {args.start}")
    if args.duration is not None:
        extra_args.append(f"-t {args.duration}")
    extra_args = ' '.join(extra_args)
    os.system(f"ffmpeg -i raw_video/{i + 1:02d}.mov {extra_args} -f image2 -q:v 2 -vf fps={args.fps} "
              f"raw_images/{i + 1}/image%03d.png")
    # Undistort frames and collect
    os.makedirs(f"images", exist_ok=True)
    raw_image_files = os.listdir(f"raw_images/{i + 1}")
    map1, map2 = None, None
    for raw_file in tqdm(raw_image_files):
        raw_im = cv2.imread(f"raw_images/{i + 1}/{raw_file}")
        if K is None:
            K = np.array([[mean_focal, 0., raw_im.shape[1] / 2],
                          [0., mean_focal, raw_im.shape[0] / 2],
                          [0., 0., 1.]])
            tqdm.write(
                f"Intrinsic parameters: {mean_focal}, {raw_im.shape[0] / 2}, {raw_im.shape[1] / 2}")
        if map1 is None:
            map1, map2 = cv2.fisheye.initUndistortRectifyMap(
                rawK[i], D, None, K, (raw_im.shape[1], raw_im.shape[0]), cv2.CV_16SC2)
        im = cv2.remap(raw_im, map1, map2, interpolation=cv2.INTER_LINEAR,
                       borderMode=cv2.BORDER_CONSTANT)
        im = cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
        cv2.imwrite(f"images/image{i}{raw_file[5:]}", im)
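A typical invocation of this preprocessing script (the file name below is hypothetical; the diff does not show it) extracts 10 frames per second from raw_video/01.mov through raw_video/06.mov, starting at second 5 for 20 seconds, and writes the undistorted frames to images/ using the per-camera fisheye intrinsics above:
> python tools/extract_fisheye_frames.py --fps 10 -s 5 -t 20 /path/to/capture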
import sys
import os
import argparse
import json
import numpy as np
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
from utils import misc
from utils.colmap_read_model import read_model
parser = argparse.ArgumentParser()
parser.add_argument('dataset', type=str)
args = parser.parse_args()
data_dir = args.dataset
os.makedirs(data_dir, exist_ok=True)
out_desc_path = os.path.join(data_dir, "train.json")
cameras, images, points3D = read_model(os.path.join(data_dir, 'sparse/0'), '.bin')
print("Model loaded.")
print("num_cameras:", len(cameras))
print("num_images:", len(images))
print("num_points3D:", len(points3D))
cam = cameras[list(cameras.keys())[0]]
views = np.array([int(images[img_id].name[5:9]) for img_id in images])
view_centers = np.array([images[img_id].tvec for img_id in images])
view_rots = []
for img_id in images:
    im = images[img_id]
    R = im.qvec2rotmat()
    view_rots.append(R.reshape([9]).tolist())
view_rots = np.array(view_rots)
indices = np.argsort(views)
views = views[indices]
view_centers = view_centers[indices]
view_rots = view_rots[indices]
pts = np.array([points3D[pt_id].xyz for pt_id in points3D])
zvals = np.sqrt(np.sum(pts * pts, 1))
dataset_desc = {
    'view_file_pattern': f"images/image%04d.jpg",
    'gl_coord': True,
    'view_res': {
        'x': cam.width,
        'y': cam.height
    },
    'cam_params': {
        'fx': cam.params[0],
        'fy': cam.params[0],
        'cx': cam.params[1],
        'cy': cam.params[2]
    },
    'range': {
        'min': np.min(view_centers, 0).tolist() + [0, 0],
        'max': np.max(view_centers, 0).tolist() + [0, 0]
    },
    'depth_range': {
        'min': np.min(zvals),
        'max': np.max(zvals)
    },
    'samples': [len(view_centers)],
    'view_centers': view_centers.tolist(),
    'view_rots': view_rots.tolist(),
    'views': views.tolist()
}
with open(out_desc_path, 'w') as fp:
    json.dump(dataset_desc, fp, indent=4)
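Assuming the COLMAP reconstruction is stored as binary files under <dataset>/sparse/0 and the frames follow the images/image%04d.jpg pattern, the converter (script name hypothetical) is run as:
> python tools/colmap2desc.py /path/to/dataset
and writes the resulting description to <dataset>/train.json.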
@@ -6,8 +6,9 @@ import numpy as np
 sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
-from utils import seqs
-from utils import math
+from utils import seqs, math
+from utils.types import Resolution
 parser = argparse.ArgumentParser()
 parser.add_argument('-r', '--rot-range', nargs='+', type=int)
@@ -25,7 +26,7 @@ args = parser.parse_args()
 data_dir = args.dataset
 os.makedirs(data_dir, exist_ok=True)
-out_desc_path = os.path.join(data_dir, (args.out_desc if args.out_desc else f"{args.seq}.json"))
+out_desc_path = os.path.join(data_dir, args.out_desc or f"{args.seq}.json")
 if args.ref:
     with open(os.path.join(data_dir, args.ref), 'r') as fp:
@@ -37,34 +38,17 @@ else:
     ref_desc = None
 if args.trans_range:
-    trans_range = np.array(list(args.trans_range) * 3 if len(args.trans_range) == 1
-                           else args.trans_range)
+    trans_range = np.array(args.trans_range * 3 if len(args.trans_range) == 1 else args.trans_range)
 else:
-    trans_range = np.array(ref_desc['range']['max'][0:3]) - \
-        np.array(ref_desc['range']['min'][0:3])
+    trans_range = np.array(ref_desc["trs_range"])
 if args.rot_range:
-    rot_range = np.array(list(args.rot_range) * 2 if len(args.rot_range) == 1
-                         else args.rot_range)
+    rot_range = np.array(args.rot_range * 2 if len(args.rot_range) == 1 else args.rot_range)
 else:
-    rot_range = np.array(ref_desc['range']['max'][3:5]) - \
-        np.array(ref_desc['range']['min'][3:5])
+    rot_range = np.array(ref_desc["rot_range"])
 filter_range = np.concatenate([trans_range, rot_range])
-if args.fov:
-    cam_params = {
-        'fov': args.fov,
-        'cx': 0.5,
-        'cy': 0.5,
-        'normalized': True
-    }
-else:
-    cam_params = ref_desc['cam_params']
-if args.res:
-    res = tuple(int(s) for s in args.res.split('x'))
-    res = {'x': res[0], 'y': res[1]}
-else:
-    res = ref_desc['view_res']
+cam_params = { "fov": args.fov } if args.fov else ref_desc["cam"]
+res = Resolution.from_str(args.res or ref_desc["res"])
 if args.seq == 'helix':
     centers, rots = seqs.helix(trans_range, 4, args.views)
@@ -73,7 +57,7 @@ elif args.seq == 'scan_around':
 elif args.seq == 'look_around':
     centers, rots = seqs.look_around(trans_range, args.views)
-rots *= 180 / math.pi
+rots = np.degrees(rots)
 gl = args.gl or ref_desc and ref_desc.get('gl_coord')
 if gl:
     centers[:, 2] *= -1
@@ -81,15 +65,13 @@ if gl:
 dataset_desc = {
     'gl_coord': gl,
-    'view_res': res,
-    'cam_params': cam_params,
-    'range': {
-        'min': (-0.5 * filter_range).tolist(),
-        'max': (0.5 * filter_range).tolist()
-    },
+    'res': f"{res.w}x{res.h}",
+    'cam': cam_params,
+    "trs_range": trans_range.tolist(),
+    "rot_range": rot_range.tolist(),
     'samples': [args.views],
-    'view_centers': centers.tolist(),
-    'view_rots': rots.tolist()
+    'centers': centers.tolist(),
+    'rots': rots.tolist()
 }
 with open(out_desc_path, 'w') as fp:
@@ -76,8 +76,8 @@ print('Test set views: ', len(test_views))
 def create_subset(views, out_desc_name):
     views = views.tolist()
     subset_desc = dataset_desc.copy()
-    subset_desc['view_file_pattern'] = \
-        f"{out_desc_name}/{dataset_desc['view_file_pattern'].split('/')[-1]}"
+    subset_desc['color_file'] = \
+        f"{out_desc_name}/{dataset_desc['color_file'].split('/')[-1]}"
     subset_desc['range'] = {
         'min': list(-filter_range / 2),
         'max': list(filter_range / 2)
@@ -91,8 +91,8 @@ def create_subset(views, out_desc_name):
         json.dump(subset_desc, fp, indent=4)
     os.makedirs(os.path.join(out_data_dir, out_desc_name), exist_ok=True)
     for i in range(len(views)):
-        os.symlink(os.path.join('../../', dataset_desc['view_file_pattern'] % views[i]),
-                   os.path.join(out_data_dir, subset_desc['view_file_pattern'] % views[i]))
+        os.symlink(os.path.join('../../', dataset_desc['color_file'] % views[i]),
+                   os.path.join(out_data_dir, subset_desc['color_file'] % views[i]))
 os.makedirs(out_data_dir, exist_ok=True)
@@ -36,8 +36,8 @@ for i in range(len(input)):
     input_desc: Mapping = json.load(fp)
     dataset_desc['view_centers'] += input_desc['view_centers']
     dataset_desc['view_rots'] += input_desc['view_rots']
-    copy_images(get_data_path(input[i], input_desc['view_file_pattern']),
-                get_data_path(output, dataset_desc['view_file_pattern']),
+    copy_images(get_data_path(input[i], input_desc['color_file']),
+                get_data_path(output, dataset_desc['color_file']),
                 len(input_desc['view_centers']), n_views)
     n_views += len(input_desc['view_centers'])
@@ -2,59 +2,95 @@ import json
 import sys
 import os
 import argparse
+import torch
 from pathlib import Path
 sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
-from data import get_dataset_desc_path
+from data import DataDesc
+from utils.misc import calculate_autosize
 parser = argparse.ArgumentParser()
-parser.add_argument('-o', '--output', type=str, nargs="+", required=True)
-parser.add_argument("-v", "--views", type=int, nargs="+", required=True)
+parser.add_argument('-o', '--outputs', type=str, nargs="+", required=True,
+                    help="names of output datasets, leading with ~ to prepend the name of input dataset")
+parser.add_argument("-v", "--views", type=str, nargs="+", required=True,
+                    help="views of output datasets, could be -1, a positive number or a colon splited slice")
+parser.add_argument("--random", action="store_true")
 parser.add_argument('dataset', type=str)
+parser.usage = """
+Split a dataset into one or more datasets.
+Examples:
+> python split.py path_to_dataset.json -o train test -v 20 -1
+This will create two datasets "train" and "test" in the folder where input dataset locates,
+with the first 20 views in "train" and other views in "test".
+> python split.py path_to_dataset.json -o ~_train ~_test -v -1 ::8
+This will create two datasets "train" and "test" in the folder where input dataset locates,
+with every 8 views in "test" and other views in "train".
+"""
 args = parser.parse_args()
-input = get_dataset_desc_path(args.dataset)
+input = DataDesc.get_json_path(args.dataset)
 outputs = [
-    get_dataset_desc_path(input.with_name(f"{input.stem}_{appendix}"))
-    for appendix in args.output
+    DataDesc.get_json_path(input.with_name(
+        f"{input.stem}{appendix[1:]}" if appendix.startswith("~") else appendix))
+    for appendix in args.outputs
 ]
 with open(input, 'r') as fp:
     input_desc: dict = json.load(fp)
-n_views = len(input_desc['view_centers'])
+n_views = len(input_desc['centers'])
 assert(len(args.views) == len(outputs))
-sum_views = sum(args.views)
-for i in range(len(args.views)):
-    if args.views[i] == -1:
-        args.views[i] = n_views - sum_views - 1
-        sum_views = n_views
-        break
-assert(sum_views <= n_views)
-for i in range(len(args.views)):
-    assert(args.views[i] > 0)
+indices = torch.arange(n_views)
+indices_assigned = torch.zeros(n_views, dtype=torch.bool)
+output_dataset_indices: list[torch.Tensor] = [None] * len(outputs)
+output_dataset_views = {}
+for i, output_views in enumerate(args.views):
+    arr = output_views.split(":")
+    if len(arr) > 1:
+        view_slice = slice(*[int(value) if value != "" else None for value in arr])
+        output_dataset_indices[i] = indices[view_slice]
+        indices_assigned[view_slice] = True
+    else:
+        output_dataset_views[i] = int(arr[0])
+indices_remain = indices[indices_assigned.logical_not()]
+n_views_remain = len(indices_remain)
+output_dataset_views = {
+    key: value
+    for key, value in zip(output_dataset_views,
+                          calculate_autosize(n_views_remain, *output_dataset_views.values())[0])
+}
+if args.random:
+    indices_remain = indices_remain[torch.randperm(n_views_remain)]
 offset = 0
+for key, value in output_dataset_views.items():
+    output_dataset_indices[key] = indices_remain[offset:offset + value]
+    offset += value
+in_views = torch.tensor(input_desc["views"]) if "views" in input_desc else torch.arange(n_views)
+in_centers = torch.tensor(input_desc["centers"])
+in_rots = torch.tensor(input_desc["rots"]) if "rots" in input_desc else None
 for i in range(len(outputs)):
-    n = args.views[i]
-    end = offset + n
+    sub_indices = output_dataset_indices[i].sort()[0]
     output_desc = input_desc.copy()
-    output_desc['samples'] = args.views[i]
-    if 'views' in output_desc:
-        output_desc['views'] = output_desc['views'][offset:end]
-    else:
-        output_desc['views'] = list(range(offset, end))
-    output_desc['view_centers'] = output_desc['view_centers'][offset:end]
-    if 'view_rots' in output_desc:
-        output_desc['view_rots'] = output_desc['view_rots'][offset:end]
+    output_desc['samples'] = [len(sub_indices)]
+    output_desc['views'] = in_views[sub_indices].tolist()
+    output_desc['centers'] = in_centers[sub_indices].tolist()
+    if in_rots is not None:
+        output_desc['rots'] = in_rots[sub_indices].tolist()
     with open(outputs[i], 'w') as fp:
         json.dump(output_desc, fp, indent=4)
     # Create symbol links of images
     out_dir = outputs[i].with_suffix('')
     out_dir.mkdir(exist_ok=True)
-    for k in range(n):
-        os.symlink(Path("..") / input.stem / (output_desc['view_file_pattern'] % output_desc['views'][k]),
-                   out_dir / (input_desc['view_file_pattern'] % output_desc['views'][k]))
-    offset += args.views[i]
+    for k in range(len(sub_indices)):
+        os.symlink(Path("..") / input.stem / (output_desc['color_file'] % output_desc['views'][k]),
+                   out_dir / (input_desc['color_file'] % output_desc['views'][k]))
import json
import sys
import os
import argparse
import numpy as np
import shutil
from typing import List
from pathlib import Path
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
from data import DataDesc
from utils.misc import calculate_autosize
def run(dataset: str, outputs: list[str], views: list[int], random: bool = False):
    if len(views) != len(outputs):
        raise ValueError("")
    input = DataDesc.get_json_path(dataset)
    outputs = [
        DataDesc.get_json_path(input.with_name(f"{input.stem}_{appendix}"))
        for appendix in outputs
    ]
    with open(input, 'r') as fp:
        input_desc: dict = json.load(fp)
    n_views = len(input_desc['view_centers']) // 6
    assert(len(views) == len(outputs))
    views, sum_views = calculate_autosize(n_views, *views)
    if random:
        indices = np.random.permutation(n_views)
    else:
        indices = np.arange(n_views)
    in_views = np.array(input_desc["views"]) if "views" in input_desc else np.arange(n_views)
    in_centers = np.array(input_desc["view_centers"])
    in_rots = np.array(input_desc["view_rots"]) if "view_rots" in input_desc else None
    offset = 0
    for i in range(len(outputs)):
        n = views[i]
        end = offset + n
        sub_indices = np.sort(indices[offset:end])
        sub_indices = np.concatenate([sub_indices + j * n_views for j in range(6)], axis=0)
        output_desc = input_desc.copy()
        output_desc['samples'] = [views[i] * 6]
        output_desc['views'] = in_views[sub_indices].tolist()
        output_desc['view_centers'] = in_centers[sub_indices].tolist()
        if in_rots is not None:
            output_desc['view_rots'] = in_rots[sub_indices].tolist()
        with open(outputs[i], 'w') as fp:
            json.dump(output_desc, fp, indent=4)
        # Create symbol links of images
        out_dir: Path = outputs[i].with_suffix('')
        if out_dir.exists():
            shutil.rmtree(out_dir)
        out_dir.mkdir()
        for view in output_desc['views']:
            filename = output_desc['color_file'] % view
            os.symlink(Path("..") / input.stem / filename, out_dir / filename)
        offset += views[i]


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-o', '--outputs', type=str, nargs="+", required=True)
    parser.add_argument("-v", "--views", type=int, nargs="+", required=True)
    parser.add_argument("--random", action="store_true")
    parser.add_argument('dataset', type=str)
    args = parser.parse_args()
    run(args.dataset, args.outputs, args.views, args.random)
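A sketch of how this grouped variant might be called (script name hypothetical); it assumes the description stores 6 * N views, e.g. six camera groups, and splits whole groups, with calculate_autosize presumably expanding a -1 entry to the remaining views:
> python tools/split_grouped.py -o train test -v 50 -1 /path/to/dataset.json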
import sys
import os
import argparse
from pathlib import Path
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--start', type=int)
parser.add_argument('-t', '--duration', type=int)
parser.add_argument('--fps', type=str)
parser.add_argument('--subset', type=str)
parser.add_argument('datadir', type=str)
args = parser.parse_args()
os.chdir(args.datadir)
if args.subset is not None:
    video_dir = Path(f"videos/{args.subset}")
else:
    video_dir = Path(f"video")
for video_path in video_dir.glob("*.*"):
    # Extract frames from video
    image_dir = "images" if args.subset is None else f"images/{args.subset}"
    os.makedirs(f"{image_dir}/{video_path.stem}", exist_ok=True)
    extra_args = []
    if args.start is not None:
        extra_args.append(f"-ss {args.start}")
    if args.duration is not None:
        extra_args.append(f"-t {args.duration}")
    if args.fps is not None:
        extra_args.append(f"-vf fps={args.fps}")
    extra_args = ' '.join(extra_args)
    os.system(f"ffmpeg -i {video_path} {extra_args} -f image2 -q:v 2 "
              f"{image_dir}/{video_path.stem}/image%03d.png")
\ No newline at end of file
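An example run (script name hypothetical) that extracts 30 frames per second from every video under videos/seq1/ into images/seq1/<video stem>/:
> python tools/video2frames.py --fps 30 --subset seq1 /path/to/data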
import torch
import argparse
from operator import itemgetter
parser = argparse.ArgumentParser()
parser.add_argument("ckpt_path", type=str)
cli_args = parser.parse_args()
args, states = itemgetter("args", "states")(torch.load(cli_args.ckpt_path))
print(f"Model: {args['model']} >>>>")
for key, value in args["model_args"].items():
    print(f"{key}: {value}")
print("\n")
if args["trainer"]:
    print(f"Trainer: {args['trainer']} >>>>")
    for key, value in args["trainer_args"].items():
        print(f"{key}={value}")
    print("\n")
print("Model states >>>>")
for key, value in states["model"].items():
    print(f"{key}: Tensor{list(value.shape)}")
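This inspector assumes the .tar checkpoint is a dict with "args" and "states" entries, as produced by the training code in this commit; a typical call (the script name is hypothetical):
> python tools/print_checkpoint.py /path/to/_nets/<dataset>/<exp>/checkpoint_50.tar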
@@ -49,7 +49,7 @@ def load_net(path):
 def export_net(net: torch.nn.Module, name: str,
-               input: Mapping[str, List[int]], output_names: List[str]):
+               input: Mapping[str, list[int]], output_names: list[str]):
     outpath = os.path.join(opt.outdir, config.to_id(), name + ".onnx")
     input_tensors = tuple([
         torch.empty(size, device=device.default())
@@ -49,7 +49,7 @@ def load_net(path):
 def export_net(net: torch.nn.Module, name: str,
-               input: Mapping[str, List[int]], output_names: List[str]):
+               input: Mapping[str, list[int]], output_names: list[str]):
     outpath = os.path.join(opt.outdir, config.to_id(), name + ".onnx")
     input_tensors = tuple([
         torch.empty(size, device=device.default())
 import sys
-import os
 import argparse
 import torch
 import torch.optim
-from torch import onnx
-sys.path.append(os.path.abspath(sys.path[0] + '/../'))
+from pathlib import Path
+sys.path.append(str(Path(__file__).absolute().parent.parent))
+from utils import netio
+import model
 parser = argparse.ArgumentParser()
-parser.add_argument('--device', type=int, default=0,
-                    help='Which CUDA device to use.')
 parser.add_argument('--batch-size', type=str,
                     help='Resolution')
-parser.add_argument('--outdir', type=str, default='./',
+parser.add_argument('--outdir', type=str, default='onnx',
                     help='Output directory')
 parser.add_argument('model', type=str,
                     help='Path of model to export')
 opt = parser.parse_args()
-# Select device
-torch.cuda.set_device(opt.device)
-print("Set CUDA:%d as current device." % torch.cuda.current_device())
-from configs.spherical_view_syn import SphericalViewSynConfig
-from utils import device
-from utils import netio
-from utils import misc
-dir_path, model_file = os.path.split(opt.model)
-batch_size = eval(opt.batch_size)
-os.chdir(dir_path)
-config = SphericalViewSynConfig()
-def load_net(path):
-    name = os.path.splitext(os.path.basename(path))[0]
-    config.from_id(name)
-    config.sa['spherical'] = True
-    config.sa['perturb_sample'] = False
-    config.sa['n_samples'] = 4
-    config.print()
-    net = config.create_net().to(device.default())
-    netio.load(path, net)
-    return net, name
-if __name__ == "__main__":
-    with torch.no_grad():
-        # Load model
-        net, name = load_net(model_file)
-        # Input to the model
-        rays_o = torch.empty(batch_size, 3, device=device.default())
-        rays_d = torch.empty(batch_size, 3, device=device.default())
-        os.makedirs(opt.outdir, exist_ok=True)
-        # Export the model
-        outpath = os.path.join(opt.outdir, config.to_id() + ".onnx")
-        onnx.export(
-            net,                       # model being run
-            (rays_o, rays_d),          # model input (or a tuple for multiple inputs)
-            outpath,
-            export_params=True,        # store the trained parameter weights inside the model file
-            verbose=True,
-            opset_version=9,           # the ONNX version to export the model to
-            do_constant_folding=True,  # whether to execute constant folding
-            input_names=['Rays_o', 'Rays_d'],  # the model's input names
-            output_names=['Colors']    # the model's output names
-        )
-        print ('Model exported to ' + outpath)
+with torch.inference_mode():
+    states, model_path = netio.load_checkpoint(opt.model)
+    batch_size = opt.batch_size and eval(opt.batch_size)
+    out_dir = model_path.parent / opt.outdir
+    model.deserialize(states).eval().export_onnx(out_dir, batch_size)
+    print(f'Model exported to {out_dir}')
from pathlib import Path
import sys
import torch
import torch.optim
sys.path.append(str(Path(__file__).absolute().parent.parent))
import model
torch.set_grad_enabled(False)
m = model.load(
    "/home/dengnc/dvs/data/classroom/_nets/train_hr_pano_t0.8/_hr_snerf/checkpoint_50.tar").eval().to("cuda")
print(m.cores[0])
inputs = (
    torch.rand(10, 63, device="cuda"),
    torch.rand(10, 24, device="cuda")
)


def fn(*args, **kwargs):
    return m.cores[0].infer(*args, **kwargs)


sm = torch.jit.trace(fn, inputs)
torch.nn.Module.__call__
print(sm.infer(torch.rand(5, 63, device="cuda"), torch.rand(5, 24, device="cuda")))
sm.save("test.pt")
torch.onnx.export(sm.infer,            # model being run
                  inputs,              # model input (or a tuple for multiple inputs)
                  "core_0.onnx",       # where to save the model
                  export_params=True,  # store the trained parameter weights inside the model file
                  opset_version=10,    # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=["x", "d"],    # the model's input names
                  output_names=["densities", "colors"],  # the model's output names
                  dynamic_axes={
                      "x": [0],
                      "d": [0],
                      "densities": [0],
                      "colors": [0]
                  })                   # variable length axes
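One way to sanity-check the exported core_0.onnx (an assumption, onnxruntime is not used elsewhere in this repository) is to run it with inputs shaped like the trace inputs but with a different batch size, which also exercises the dynamic axes declared above:

import numpy as np
import onnxruntime as ort

# Load the exported network and run one forward pass on CPU
sess = ort.InferenceSession("core_0.onnx", providers=["CPUExecutionProvider"])
x = np.random.rand(7, 63).astype(np.float32)  # same channel count as the traced "x" input
d = np.random.rand(7, 24).astype(np.float32)  # same channel count as the traced "d" input
densities, colors = sess.run(None, {"x": x, "d": d})
print(densities.shape, colors.shape)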
@@ -61,8 +61,8 @@ def load_net():
     return net
-def export_net(net: torch.nn.Module, path: str, input: Mapping[str, List[int]],
-               output_names: List[str]):
+def export_net(net: torch.nn.Module, path: str, input: Mapping[str, list[int]],
+               output_names: list[str]):
     input_tensors = tuple([
         torch.empty(size, device=device.default())
         for size in input.values()
 import json
 import sys
 import os
+import csv
 import argparse
-import torch
 import shutil
+import torch
 import torch.nn as nn
 import torch.nn.functional as nn_f
+from tqdm import trange
 sys.path.append(os.path.abspath(sys.path[0] + '/../'))
 parser = argparse.ArgumentParser()
-parser.add_argument('-s', '--stereo', action='store_true')
-parser.add_argument('-R', '--replace', action='store_true')
-parser.add_argument('--noCE', action='store_true')
+parser.add_argument('-s', '--stereo', action='store_true',
+                    help="Render stereo video")
+parser.add_argument('-d', '--disparity', type=float, default=0.06,
+                    help="The stereo disparity")
+parser.add_argument('-R', '--replace', action='store_true',
+                    help="Replace the existed frames in the intermediate output directory")
+parser.add_argument('--noCE', action='store_true',
+                    help="Disable constrast enhancement")
 parser.add_argument('-i', '--input', type=str)
-parser.add_argument('-r', '--range', type=str)
-parser.add_argument('-f', '--fps', type=int)
-parser.add_argument('--device', type=int, default=0,
-                    help='Which CUDA device to use.')
-parser.add_argument('scene', type=str)
-parser.add_argument('view_file', type=str)
+parser.add_argument('-r', '--range', type=str,
+                    help="The range of frames to render, specified as format: start,end")
+parser.add_argument('-f', '--fps', type=int,
+                    help="The FPS of output video. if not specified, a sequence of images will be saved instead")
+parser.add_argument('-m', '--model', type=str,
+                    help="The directory containing fovea* and periph* model file")
+parser.add_argument('view_file', type=str,
+                    help="The path to .csv or .json file which contains a sequence of poses and gazes")
 opt = parser.parse_args()
-# Select device
-torch.cuda.set_device(opt.device)
-print("Set CUDA:%d as current device." % torch.cuda.current_device())
-torch.autograd.set_grad_enabled(False)
-from configs.spherical_view_syn import SphericalViewSynConfig
-from utils import netio
-from utils import misc
-from utils import img
-from utils import device
+from utils import netio, img, device
 from utils.view import *
-from utils import sphere
+from utils.types import *
 from components.fnr import FoveatedNeuralRenderer
-from utils.progress_bar import progress_bar
+from model import Model
-def load_net(path):
-    config = SphericalViewSynConfig()
-    config.from_id(path[:-4])
-    config.sa['perturb_sample'] = False
-    # config.print()
-    net = config.create_net().to(device.default())
-    netio.load(path, net)
-    return net
-def find_file(prefix):
-    for path in os.listdir():
-        if path.startswith(prefix):
-            return path
-    return None
-def clamp_gaze(gaze):
-    return gaze
-    scoord = sphere.cartesian2spherical(gaze)
+def load_model(path: Path) -> Model:
+    checkpoint, _ = netio.load_checkpoint(path)
+    checkpoint["model"][1]["sampler"]["perturb"] = False
+    model = Model.create(*checkpoint["model"])
+    model.load_state_dict(checkpoint["states"]["model"])
+    model.to(device.default()).eval()
+    return model
-def load_csv(data_desc_file) -> Tuple[Trans, torch.Tensor]:
+def load_csv(data_desc_file: Path) -> tuple[Trans, torch.Tensor]:
     def to_tensor(line_content):
         return torch.tensor([float(str) for str in line_content.split(',')])
@@ -84,9 +72,9 @@ def load_csv(data_desc_file) -> Tuple[Trans, torch.Tensor]:
         if lines[i + 1].startswith('0,0,0') or lines[i + 2].startswith('0,0,0'):
             continue
         j = i + 1
-        gaze_dirs[view_idx, 0] = clamp_gaze(to_tensor(lines[j]))
+        gaze_dirs[view_idx, 0] = to_tensor(lines[j])
         j += 1
-        gaze_dirs[view_idx, 1] = clamp_gaze(to_tensor(lines[j]))
+        gaze_dirs[view_idx, 1] = to_tensor(lines[j])
         j += 1
         if not old_fmt:
             gazes[view_idx, 0] = to_tensor(lines[j])
@@ -113,11 +101,16 @@ def load_csv(data_desc_file) -> Tuple[Trans, torch.Tensor]:
     return Trans(view_t, view_mats[:, 0, :3, :3]), gazes
-def load_json(data_desc_file) -> Tuple[Trans, torch.Tensor]:
+def load_json(data_desc_file: Path) -> tuple[Trans, torch.Tensor]:
     with open(data_desc_file, 'r', encoding='utf-8') as file:
         data = json.load(file)
     view_t = torch.tensor(data['view_centers'])
     view_r = torch.tensor(data['view_rots']).view(-1, 3, 3)
+    if data.get("gl_coord"):
+        view_t[:, 2] *= -1
+        view_r[:, 2] *= -1
+        view_r[..., 2] *= -1
     if data.get('gazes'):
         if len(data['gazes'][0]) == 2:
             gazes = torch.tensor(data['gazes']).view(-1, 1, 2).expand(-1, 2, -1)
@@ -127,68 +120,61 @@ def load_json(data_desc_file) -> Tuple[Trans, torch.Tensor]:
         gazes = torch.zeros(view_t.size(0), 2, 2)
     return Trans(view_t, view_r), gazes
-def load_views(data_desc_file: str) -> Tuple[Trans, torch.Tensor]:
-    if data_desc_file.endswith('.csv'):
-        return load_csv(data_desc_file)
-    return load_json(data_desc_file)
+def load_views_and_gazes(data_desc_file: Path) -> tuple[Trans, torch.Tensor]:
+    if data_desc_file.suffix == '.csv':
+        views, gazes = load_csv(data_desc_file)
+    else:
+        views, gazes = load_json(data_desc_file)
+    gazes[:, :, 1] = (gazes[:, :1, 1] + gazes[:, 1:, 1]) * 0.5
+    return views, gazes
-rot_range = {
-    'classroom': [120, 80],
-    'barbershop': [360, 80],
-    'lobby': [360, 80],
-    'stones': [360, 80]
-}
-trans_range = {
-    'classroom': 0.6,
-    'barbershop': 0.3,
-    'lobby': 1.0,
-    'stones': 1.0
-}
-fov_list = [20, 45, 110]
-res_list = [(256, 256), (256, 256), (400, 360)]
+torch.set_grad_enabled(False)
+view_file = Path(opt.view_file)
+stereo_disparity = opt.disparity
 res_full = (1600, 1440)
-stereo_disparity = 0.06
-cwd = os.getcwd()
-os.chdir(f"{sys.path[0]}/../data/__new/{opt.scene}_all")
-fovea_net = load_net(find_file('fovea'))
-periph_net = load_net(find_file('periph'))
-renderer = FoveatedNeuralRenderer(fov_list, res_list,
-                                  nn.ModuleList([fovea_net, periph_net, periph_net]),
-                                  res_full, device=device.default())
-os.chdir(cwd)
+if opt.model:
+    model_dir = Path(opt.model)
+    fov_list = [20.0, 45.0, 110.0]
+    res_list = [(256, 256), (256, 256), (256, 230)]
+    fovea_net = load_model(next(model_dir.glob("fovea*.tar")))
+    periph_net = load_model(next(model_dir.glob("periph*.tar")))
+    renderer = FoveatedNeuralRenderer(fov_list, res_list,
                                       nn.ModuleList([fovea_net, periph_net, periph_net]),
                                       res_full, device=device.default())
+else:
+    renderer = None
 # Load Dataset
-views, gazes = load_views(opt.view_file)
+views, gazes = load_views_and_gazes(Path(opt.view_file))
 if opt.range:
     opt.range = [int(val) for val in opt.range.split(",")]
     if len(opt.range) == 1:
         opt.range = [0, opt.range[0]]
-    views = views.get(range(opt.range[0], opt.range[1]))
+    views = views[opt.range[0]:opt.range[1]]
     gazes = gazes[opt.range[0]:opt.range[1]]
 views = views.to(device.default())
-n_views = views.size()[0]
+n_views = views.shape[0]
 print('Dataset loaded. Views:', n_views)
-videodir = os.path.dirname(os.path.abspath(opt.view_file))
-tempdir = '/dev/shm/dvs_tmp/video'
-videoname = f"{os.path.splitext(os.path.split(opt.view_file)[-1])[0]}_{'stereo' if opt.stereo else 'mono'}"
-gazeout = f"{videodir}/{videoname}_gaze.csv"
-if opt.noCE:
-    videoname += "_noCE"
+videodir = view_file.absolute().parent
+tempdir = Path('/dev/shm/dvs_tmp/video')
+if opt.input:
+    videoname = Path(opt.input).parent.stem
+else:
+    videoname = f"{view_file.stem}_{('stereo' if opt.stereo else 'mono')}"
+if opt.noCE:
+    videoname += "_noCE"
+gazeout = videodir / f"{videoname}_gaze.csv"
 if opt.fps:
-    if opt.input:
-        videoname = os.path.split(opt.input)[0]
     inferout = f"{videodir}/{opt.input}" if opt.input else f"{tempdir}/{videoname}/%04d.bmp"
     hintout = f"{tempdir}/{videoname}_hint/%04d.bmp"
 else:
     inferout = f"{videodir}/{opt.input}" if opt.input else f"{videodir}/{videoname}/%04d.png"
     hintout = f"{videodir}/{videoname}_hint/%04d.png"
-if opt.input:
-    scale = img.load(inferout % 0).shape[-1] / res_full[1]
-else:
-    scale = 1
+scale = img.load(inferout % 0).shape[-1] / res_full[1] if opt.input else 1
 hint = img.load(f"{sys.path[0]}/fovea_hint.png", with_alpha=True).to(device=device.default())
 hint = nn_f.interpolate(hint, mode='bilinear', scale_factor=scale, align_corners=False)
@@ -233,66 +219,44 @@ if not opt.replace:
     hint_offset = max(0, hint_offset - 1)
     infer_offset = n_views if opt.input else max(0, infer_offset - 1)
-if opt.stereo:
-    gazes_out = torch.empty(n_views, 4)
-    for view_idx in range(n_views):
-        shift = gazes[view_idx, 0, 0] - gazes[view_idx, 1, 0]
-        # print(shift.item())
-        gazel = ((gazes[view_idx, 1, 0] + 0.4 * shift).item(),
-                 0.5 * (gazes[view_idx, 0, 1] + gazes[view_idx, 1, 1]).item())
-        gazer = ((gazes[view_idx, 0, 0] - 0.4 * shift).item(), gazel[1])
-        # gazel = ((gazes[view_idx, 0, 0]).item(),
-        #          0.5 * (gazes[view_idx, 0, 1] + gazes[view_idx, 1, 1]).item())
-        #gazer = ((gazes[view_idx, 1, 0]).item(), gazel[1])
-        gazes_out[view_idx] = torch.tensor([gazel[0], gazel[1], gazer[0], gazer[1]])
-        if view_idx < hint_offset:
-            continue
-        if view_idx < infer_offset:
-            frame = img.load(inferout % view_idx).to(device=device.default())
-        else:
-            view_trans = views.get(view_idx)
-            left_images, right_images = renderer(view_trans, gazel, gazer,
-                                                 stereo_disparity=stereo_disparity,
-                                                 mono_periph_mode=3, ret_raw=True)
-            frame = torch.cat([
-                left_images['blended_raw'] if opt.noCE else left_images['blended'],
-                right_images['blended_raw'] if opt.noCE else right_images['blended']], -1)
-            frame = img.translate(frame, (0.5, 0.5))
-        img.save(frame, inferout % view_idx)
-        add_hint(frame, gazel, gazer)
-        img.save(frame, hintout % view_idx)
-        progress_bar(view_idx, n_views, 'Frame %4d inferred' % view_idx)
-else:
-    gazes_out = torch.empty(n_views, 2)
-    for view_idx in range(n_views):
-        gaze = 0.5 * (gazes[view_idx, 0] + gazes[view_idx, 1])
-        gaze = (gaze[0].item(), gaze[1].item())
-        gazes_out[view_idx] = torch.tensor([gaze[0], gaze[1]])
-        if view_idx < hint_offset:
-            continue
-        if view_idx < infer_offset:
-            frame = img.load(inferout % view_idx).to(device=device.default())
-        else:
-            view_trans = views.get(view_idx)
-            frame = renderer(view_trans, gaze,
-                             ret_raw=True)['blended_raw' if opt.noCE else 'blended']
-            frame = img.translate(frame, (0.5, 0.5))
-        img.save(frame, inferout % view_idx)
-        add_hint(frame, gaze)
-        img.save(frame, hintout % view_idx)
-        progress_bar(view_idx, n_views, 'Frame %4d inferred' % view_idx)
+for view_idx in trange(n_views):
+    if view_idx < hint_offset:
+        continue
+    gaze = gazes[view_idx]
+    if not opt.stereo:
+        gaze = gaze.sum(0, True) * 0.5
+    gaze = gaze.tolist()
+    if view_idx < infer_offset:
+        frame = img.load(inferout % view_idx).to(device=device.default())
+    else:
+        if renderer is None:
+            raise Exception
+        key = 'blended_raw' if opt.noCE else 'blended'
+        view_trans = views[view_idx]
+        if opt.stereo:
+            left_images, right_images = renderer(view_trans, *gaze,
                                                  stereo_disparity=stereo_disparity,
                                                  mono_periph_mode=3, ret_raw=True)
+            frame = torch.cat([left_images[key], right_images[key]], -1)
+        else:
+            frame = renderer(view_trans, *gaze, ret_raw=True)[key]
+        img.save(frame, inferout % view_idx)
+    add_hint(frame, *gaze)
+    img.save(frame, hintout % view_idx)
+gazes_out = gazes.reshape(-1, 4) if opt.stereo else gazes.sum(1) * 0.5
 with open(gazeout, 'w') as fp:
-    for i in range(n_views):
-        fp.write(','.join([f'{val.item()}' for val in gazes_out[i]]))
-        fp.write('\n')
+    csv_writer = csv.writer(fp)
+    csv_writer.writerows(gazes_out.tolist())
 if opt.fps:
     # Generate video without hint
     os.system(f'ffmpeg -y -r {opt.fps:d} -i {inferout} -c:v libx264 {videodir}/{videoname}.mp4')
-    if not opt.input:
-        shutil.rmtree(os.path.dirname(inferout))
     # Generate video with hint
     os.system(f'ffmpeg -y -r {opt.fps:d} -i {hintout} -c:v libx264 {videodir}/{videoname}_hint.mp4')
+    # Clean temp images
+    if not opt.input:
+        shutil.rmtree(os.path.dirname(inferout))
     shutil.rmtree(os.path.dirname(hintout))
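A possible invocation of the updated renderer script (its file name is not shown in this diff), producing a 30 FPS stereo video from a recorded pose/gaze sequence with the fovea*.tar and periph*.tar checkpoints found in a model directory:
> python tools/render_video.py -s -f 30 -m /path/to/models /path/to/poses.csv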
-import sys
 import os
-sys.path.append(os.path.abspath(sys.path[0] + '/../'))
 import argparse
 from PIL import Image
-from utils import misc
+from tqdm import tqdm
+from pathlib import Path
-def batch_scale(src, target, size):
-    os.makedirs(target, exist_ok=True)
-    for file_name in os.listdir(src):
-        postfix = os.path.splitext(file_name)[1]
-        if postfix == '.jpg' or postfix == '.png':
-            im = Image.open(os.path.join(src, file_name))
-            im = im.resize(size)
-            im.save(os.path.join(target, file_name))
+def run(src: Path, target: Path, root: Path, scale_factor: float = 1., width: int = -1, height: int = -1):
+    target.mkdir(exist_ok=True)
+    for file_name in tqdm(os.listdir(src), leave=False, desc=src.relative_to(root).__str__()):
+        if (src / file_name).is_dir():
+            run(src / file_name, target / file_name, root, scale_factor, width, height)
+        elif not (target / file_name).exists():
+            postfix = os.path.splitext(file_name)[1]
+            if postfix == '.jpg' or postfix == '.png':
+                im = Image.open(src / file_name)
+                if width == -1 and height == -1:
+                    width = round(im.width * scale_factor)
+                    height = round(im.height * scale_factor)
+                elif width == -1:
+                    width = round(im.width / im.height * height)
+                elif height == -1:
+                    height = round(im.height / im.width * width)
+                im = im.resize((width, height))
+                im.save(target / file_name)
 if __name__ == '__main__':
@@ -23,9 +31,11 @@ if __name__ == '__main__':
                         help='Source directory.')
     parser.add_argument('target', type=str,
                         help='Target directory.')
-    parser.add_argument('--width', type=int,
+    parser.add_argument('-x', '--scale-factor', type=float, default=1,
+                        help='Target directory.')
+    parser.add_argument('--width', type=int, default=-1,
                         help='Width of output images (pixel)')
-    parser.add_argument('--height', type=int,
+    parser.add_argument('--height', type=int, default=-1,
                         help='Height of output images (pixel)')
     opt = parser.parse_args()
-    batch_scale(opt.src, opt.target, (opt.width, opt.height))
+    run(Path(opt.src), Path(opt.target), Path(opt.src), opt.scale_factor, opt.width, opt.height)
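Example usage of the recursive resize tool (script name hypothetical), writing every .jpg/.png under images/ at half resolution into images_half/:
> python tools/batch_scale.py -x 0.5 images images_half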
@@ -40,7 +40,7 @@ for subdir in ['images', 'images_4', 'images_8']:
         print(f"Rename {src} to {tgt}")
         os.rename(src, tgt)
     out_desc = dataset_desc.copy()
-    out_desc['view_file_pattern'] = f"{subdir}/view_%04d.jpg"
+    out_desc['color_file'] = f"{subdir}/view_%04d.jpg"
     k = res[0] / dataset_desc['view_res']['y']
     out_desc['view_res'] = {
         'x': res[1],
-import argparse
-import logging
-import os
-import sys
-from pathlib import Path
-from typing import List
-import model as mdl
-import train
-from utils import device
-from utils import netio
-from data import *
-from utils.misc import print_and_log
+from operator import itemgetter
+from configargparse import ArgumentParser, SUPPRESS
+from model import Model
+from train import Trainer
+from utils import device, netio
+from utils.types import *
+from data import Dataset
-RAYS_PER_BATCH = 2 ** 12
-DATA_LOADER_CHUNK_SIZE = 1e8
-root_dir = Path(__file__).absolute().parent
+def load_dataset(data_path: Path, color: str, coord: str):
+    dataset = Dataset(data_path, color_mode=Color[color], coord_sys=coord)
+    print(f"Load dataset: {dataset.root}/{dataset.name}")
+    return dataset
-parser = argparse.ArgumentParser()
-parser.add_argument('-c', '--config', type=str,
-                    help='Net config files')
-parser.add_argument('-e', '--epochs', type=int,
-                    help='Max epochs for train')
-parser.add_argument('--perf', type=int,
-                    help='Performance measurement frames (0 for disabling performance measurement)')
-parser.add_argument('--prune', type=int, nargs='+',
-                    help='Prune voxels on every # epochs')
-parser.add_argument('--split', type=int, nargs='+',
-                    help='Split voxels on every # epochs')
-parser.add_argument('--freeze', type=int, nargs='+',
-                    help='freeze levels on epochs')
-parser.add_argument('--checkpoint-interval', type=int)
-parser.add_argument('--views', type=str,
-                    help='Specify the range of views to train')
-parser.add_argument('path', type=str,
-                    help='Dataset description file')
-args = parser.parse_args()
-views_to_load = range(*[int(val) for val in args.views.split('-')]) if args.views else None
-argpath = Path(args.path)
-# argpath: May be model path or data path
-# 1) model path: continue training on the specified model
-# 2) data path: train a new model using specified dataset
-def load_dataset(data_path: Path):
-    print(f"Loading dataset {data_path}")
-    try:
-        dataset = DatasetFactory.load(data_path, views_to_load=views_to_load)
-        print(f"Dataset loaded: {dataset.root}/{dataset.name}")
-        os.chdir(dataset.root)
-        return dataset, dataset.name
-    except FileNotFoundError:
-        return load_multiscale_dataset(data_path)
-def load_multiscale_dataset(data_path: Path):
-    if not data_path.is_dir():
-        raise ValueError(
-            f"Path {data_path} is not a directory")
-    dataset: List[Union[PanoDataset, ViewDataset]] = []
-    for sub_data_desc_path in data_path.glob("*.json"):
-        sub_dataset = DatasetFactory.load(sub_data_desc_path, views_to_load=views_to_load)
-        print(f"Sub-dataset loaded: {sub_dataset.root}/{sub_dataset.name}")
-        dataset.append(sub_dataset)
-    if len(dataset) == 0:
-        raise ValueError(f"Path {data_path} does not contain sub-datasets")
-    os.chdir(data_path.parent)
-    return dataset, data_path.name
-try:
-    states, checkpoint_path = netio.load_checkpoint(argpath)
-    # Infer dataset path from model path
-    # The model path follows such rule: <dataset_dir>/_nets/<dataset_name>/<model_name>/checkpoint_*.tar
-    model_name = checkpoint_path.parts[-2]
-    dataset, dataset_name = load_dataset(
-        Path(*checkpoint_path.parts[:-4]) / checkpoint_path.parts[-3])
-except Exception:
-    model_name = args.config
-    dataset, dataset_name = load_dataset(argpath)
-    # Load state 0 from specified configuration
-    with Path(f'{root_dir}/configs/{args.config}.json').open() as fp:
-        states = json.load(fp)
-    states['args']['bbox'] = dataset[0].bbox if isinstance(dataset, list) else dataset.bbox
-    states['args']['depth_range'] = dataset[0].depth_range if isinstance(dataset, list)\
-        else dataset.depth_range
-if 'train' not in states:
-    states['train'] = {}
-if args.prune is not None:
-    states['train']['prune_epochs'] = args.prune
-if args.split is not None:
-    states['train']['split_epochs'] = args.split
-if args.freeze is not None:
-    states['train']['freeze_epochs'] = args.freeze
-if args.perf is not None:
-    states['train']['perf_frames'] = args.perf
-if args.checkpoint_interval is not None:
-    states['train']['checkpoint_interval'] = args.checkpoint_interval
-if args.epochs is not None:
-    states['train']['max_epochs'] = args.epochs
-model = mdl.deserialize(states).to(device.default())
+initial_parser = ArgumentParser()
+initial_parser.add_argument('-c', '--config', type=str, default=SUPPRESS,
+                            help='Config name, ignored if path is a checkpoint path')
+initial_parser.add_argument('--expname', type=str, default=SUPPRESS,
+                            help='Experiment name, defaults to config name, ignored if path is a checkpoint path')
+initial_parser.add_argument('path', type=str,
+                            help='Path to dataset description file or checkpoint file')
+initial_args = vars(initial_parser.parse_known_args()[0])
-# Initialize run directory
-run_dir = Path(f"_nets/{dataset_name}/{model_name}")
-run_dir.mkdir(parents=True, exist_ok=True)
-# Initialize logging
-log_file = run_dir / "train.log"
-logging.basicConfig(format='%(asctime)s[%(levelname)s] %(message)s', level=logging.INFO,
-                    filename=log_file, filemode='a' if log_file.exists() else 'w')
-def log_exception(exc_type, exc_value, exc_traceback):
-    if not issubclass(exc_type, KeyboardInterrupt):
-        logging.exception(exc_value, exc_info=(exc_type, exc_value, exc_traceback))
-    sys.__excepthook__(exc_type, exc_value, exc_traceback)
-sys.excepthook = log_exception
-print_and_log(f"model: {model_name} ({model.cls})")
-print_and_log(f"args:")
-model.print_config()
-print(model)
-if __name__ == "__main__":
-    # 1. Initialize data loader
-    data_loader = get_loader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
-                             shuffle=True, enable_preload=False, color=model.color)
-    # 2. Initialize model and trainer
-    trainer = train.get_trainer(model, run_dir, states)
-    # 3. Train
-    trainer.train(data_loader)
+root_dir = Path(__file__).absolute().parent
+argpath = Path(initial_args["path"])  # May be checkpoint path or dataset path
+# 1) checkpoint path: continue training a model
+# 2) dataset path: train a new model using specified dataset
+ckpt_path = netio.find_checkpoint(argpath)
+if ckpt_path:
+    # Continue training from a checkpoint
+    print(f"Load checkpoint {ckpt_path}")
+    args, states = itemgetter("args", "states")(torch.load(ckpt_path))
+    # args: "model", "model_args", "trainer", "trainer_args"
+    ModelCls = Model.get_class(args["model"])
+    TrainerCls = Trainer.get_class(args["trainer"])
+    model_args = ModelCls.Args(**args["model_args"])
+    trainer_args = TrainerCls.Args(**args["trainer_args"]).parse()
+    trainset = load_dataset(trainer_args.trainset, model_args.color, model_args.coord)
+    run_dir = ckpt_path.parent
+else:
+    # Start a new train
+    expname = initial_args.get("expname", initial_args.get("config", "unnamed"))
+    if "config" in initial_args:
+        config_path = root_dir / "configs" / f"{initial_args['config']}.ini"
+        if not config_path.exists():
+            raise ValueError(f"Config {initial_args['config']} is not found in "
+                             f"{root_dir / 'configs'}.")
+        print(f"Load config {config_path}")
+    else:
+        config_path = None
+    # First parse model class and trainer class from config file or command-line arguments
+    parser = ArgumentParser(default_config_files=[f"{config_path}"] if config_path else [])
+    parser.add_argument('--color', type=str, default="rgb",
+                        help='The color mode')
+    parser.add_argument('--model', type=str, required=True,
+                        help='The model to train')
+    parser.add_argument('--trainer', type=str, default="Trainer",
+                        help='The trainer to use for training')
+    args = parser.parse_known_args()[0]
+    ModelCls = Model.get_class(args.model)
+    TrainerCls = Trainer.get_class(args.trainer)
+    trainset_path = argpath
+    trainset = load_dataset(trainset_path, args.color, "gl")
+    # Then parse model's and trainer's args
+    if trainset.depth_range:
+        model_args = ModelCls.Args(  # Some model's args are inferred from training dataset
+            color=trainset.color_mode.name,
+            near=trainset.depth_range[0],
+            far=trainset.depth_range[1],
+            white_bg=trainset.white_bg,
+            coord=trainset.coord_sys
+        )
+    else:
+        model_args = ModelCls.Args(white_bg=trainset.white_bg)
+    model_args.parse(config_path)
+    trainer_args = TrainerCls.Args(trainset=f"{trainset_path}").parse(config_path)
+    states = None
+    run_dir = trainset.root / "_nets" / trainset.name / expname
+run_dir.mkdir(parents=True, exist_ok=True)
+m = ModelCls(model_args).to(device.default())
+trainer = TrainerCls(m, run_dir, trainer_args)
+if states:
+    trainer.load_state_dict(states)
+# Start train
+trainer.train(trainset)