Commit 1bc644a1 authored by Nianchen Deng

sync

parent 6294701e
@@ -37,20 +37,20 @@ for i in range(3):
for j in range(2):
out_desc_name = f'part{idx:d}'
out_desc = dataset_desc.copy()
out_desc['view_file_pattern'] = f'{out_desc_name}/view_%04d.png'
out_desc['color_file'] = f'{out_desc_name}/view_%04d.png'
n_x = out_desc['samples'][3] // 3
n_y = out_desc['samples'][4] // 2
views = indices[..., i * n_x:(i + 1) * n_x, j * n_y:(j + 1) * n_y].flatten().tolist()
out_desc['samples'] = [len(views)]
out_desc['views'] = views
out_desc['view_centers'] = np.array(dataset_desc['view_centers'])[views].tolist()
out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
out_desc['centers'] = np.array(dataset_desc['centers'])[views].tolist()
out_desc['rots'] = np.array(dataset_desc['rots'])[views].tolist()
with open(os.path.join(data_dir, f'{out_desc_name}.json'), 'w') as fp:
json.dump(out_desc, fp, indent=4)
os.makedirs(os.path.join(data_dir, out_desc_name), exist_ok=True)
for k in range(len(views)):
os.symlink(os.path.join('..', dataset_desc['view_file_pattern'] % views[k]),
os.path.join(data_dir, out_desc['view_file_pattern'] % views[k]))
os.symlink(os.path.join('..', dataset_desc['color_file'] % views[k]),
os.path.join(data_dir, out_desc['color_file'] % views[k]))
idx += 1
'''
@@ -60,24 +60,24 @@ for xi in range(0, 4, 2):
for zi in range(0, 4, 2):
out_desc_name = f'part{idx:d}'
out_desc = dataset_desc.copy()
out_desc['view_file_pattern'] = f'{out_desc_name}/view_%04d.png'
out_desc['color_file'] = f'{out_desc_name}/view_%04d.png'
views = indices[xi:xi + 2, yi:yi + 2, zi:zi + 2].flatten().tolist()
out_desc['samples'] = [len(views)]
out_desc['views'] = views
out_desc['view_centers'] = np.array(dataset_desc['view_centers'])[views].tolist()
out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
out_desc['centers'] = np.array(dataset_desc['centers'])[views].tolist()
out_desc['rots'] = np.array(dataset_desc['rots'])[views].tolist()
with open(os.path.join(data_dir, f'{out_desc_name}.json'), 'w') as fp:
json.dump(out_desc, fp, indent=4)
os.makedirs(os.path.join(data_dir, out_desc_name), exist_ok=True)
for k in range(len(views)):
os.symlink(os.path.join('..', dataset_desc['view_file_pattern'] % views[k]),
os.path.join(data_dir, out_desc['view_file_pattern'] % views[k]))
os.symlink(os.path.join('..', dataset_desc['color_file'] % views[k]),
os.path.join(data_dir, out_desc['color_file'] % views[k]))
idx += 1
'''
def extract_by_grid(*grid_indices):
indices = torch.arange(len(dataset_desc['view_centers'])).view(dataset_desc['samples'])
indices = torch.arange(len(dataset_desc['centers'])).view(dataset_desc['samples'])
views = []
for idx in product(*grid_indices):
views += indices[idx].flatten().tolist()
@@ -86,11 +86,11 @@ def extract_by_grid(*grid_indices):
def extract_by_trans(max_trans, max_views):
if max_trans is not None:
centers = np.array(dataset_desc['view_centers'])
centers = np.array(dataset_desc['centers'])
trans = np.linalg.norm(centers, axis=-1)
indices = np.nonzero(trans <= max_trans)[0]
else:
indices = np.arange(len(dataset_desc['view_centers']))
indices = np.arange(len(dataset_desc['centers']))
if max_views is not None:
indices = np.sort(indices[np.random.permutation(indices.shape[0])[:max_views]])
return indices.tolist()
@@ -101,18 +101,18 @@ if args.grids:
else:
views = extract_by_trans(args.trans, args.views)
image_path = dataset_desc['view_file_pattern']
image_path = dataset_desc['color_file']
if "/" not in image_path:
image_path = in_name + "/" + image_path
# Save new dataset
out_desc = dataset_desc.copy()
out_desc['view_file_pattern'] = image_path.split('/')[-1]
out_desc['color_file'] = image_path.split('/')[-1]
out_desc['samples'] = [len(views)]
out_desc['views'] = views
out_desc['view_centers'] = np.array(dataset_desc['view_centers'])[views].tolist()
if 'view_rots' in dataset_desc:
out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
out_desc['centers'] = np.array(dataset_desc['centers'])[views].tolist()
if 'rots' in dataset_desc:
out_desc['rots'] = np.array(dataset_desc['rots'])[views].tolist()
# Write new data desc
with open(out_desc_path, 'w') as fp:
@@ -123,7 +123,7 @@ out_dir.mkdir()
for k in range(len(views)):
if out_dir.parent.absolute() == root_dir.absolute():
os.symlink(Path("..") / (image_path % views[k]),
out_dir / (out_desc['view_file_pattern'] % views[k]))
out_dir / (out_desc['color_file'] % views[k]))
else:
os.symlink(root_dir.absolute() / (image_path % views[k]),
out_dir / (out_desc['view_file_pattern'] % views[k]))
out_dir / (out_desc['color_file'] % views[k]))
import sys
import os
import argparse
import numpy as np
import cv2
from tqdm import tqdm
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--start', type=int)
parser.add_argument('-t', '--duration', type=int)
parser.add_argument('--fps', type=str, required=True)
parser.add_argument('datadir', type=str)
args = parser.parse_args()
os.chdir(args.datadir)
rawK = np.array([
[1369.757446, 0., 1838.643555, 0., 1369.757446, 1524.068604, 0., 0., 1.],
[1367.517944, 0., 1840.157837, 0., 1367.517944, 1536.036133, 0., 0., 1.],
[1369.830322, 0., 1827.990723, 0., 1369.830322, 1514.463135, 0., 0., 1.],
[1368.966187, 0., 1829.976196, 0., 1368.966187, 1512.734375, 0., 0., 1.],
[1373.654297, 0., 1838.130859, 0., 1373.654297, 1534.985840, 0., 0., 1.],
[1365.853027, 0., 1835.100830, 0., 1365.853027, 1533.032959, 0., 0., 1.]
]).reshape(-1, 3, 3)
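# Fisheye distortion coefficients (k1, k2, k3, k4 in OpenCV's cv2.fisheye convention), shared by all six cameras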
D = np.array([[-0.044752], [-0.006285], [0.000000], [0.000000]])
mean_focal = np.mean(rawK[:, 0, 0])
K = None # 1369.2632038333334, 1500.0, 1900.0
for i in range(6):
# Extract frames from video
os.makedirs(f"raw_images/{i + 1}", exist_ok=True)
extra_args = []
if args.start is not None:
extra_args.append(f"-ss {args.start}")
if args.duration is not None:
extra_args.append(f"-t {args.duration}")
extra_args = ' '.join(extra_args)
os.system(f"ffmpeg -i raw_video/{i + 1:02d}.mov {extra_args} -f image2 -q:v 2 -vf fps={args.fps} "
f"raw_images/{i + 1}/image%03d.png")
# Undistort frames and collect
os.makedirs(f"images", exist_ok=True)
raw_image_files = os.listdir(f"raw_images/{i + 1}")
map1, map2 = None, None
for raw_file in tqdm(raw_image_files):
raw_im = cv2.imread(f"raw_images/{i + 1}/{raw_file}")
if K is None:
K = np.array([[mean_focal, 0., raw_im.shape[1] / 2],
[0., mean_focal, raw_im.shape[0] / 2],
[0., 0., 1.]])
tqdm.write(
f"Intrinsic parameters: {mean_focal}, {raw_im.shape[0] / 2}, {raw_im.shape[1] / 2}")
if map1 is None:
map1, map2 = cv2.fisheye.initUndistortRectifyMap(
rawK[i], D, None, K, (raw_im.shape[1], raw_im.shape[0]), cv2.CV_16SC2)
im = cv2.remap(raw_im, map1, map2, interpolation=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT)
im = cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
cv2.imwrite(f"images/image{i}{raw_file[5:]}", im)
import sys
import os
import argparse
import json
import numpy as np
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
from utils import misc
from utils.colmap_read_model import read_model
parser = argparse.ArgumentParser()
parser.add_argument('dataset', type=str)
args = parser.parse_args()
data_dir = args.dataset
os.makedirs(data_dir, exist_ok=True)
out_desc_path = os.path.join(data_dir, "train.json")
cameras, images, points3D = read_model(os.path.join(data_dir, 'sparse/0'), '.bin')
print("Model loaded.")
print("num_cameras:", len(cameras))
print("num_images:", len(images))
print("num_points3D:", len(points3D))
cam = cameras[list(cameras.keys())[0]]
views = np.array([int(images[img_id].name[5:9]) for img_id in images])
view_centers = np.array([images[img_id].tvec for img_id in images])
view_rots = []
for img_id in images:
im = images[img_id]
R = im.qvec2rotmat()
view_rots.append(R.reshape([9]).tolist())
view_rots = np.array(view_rots)
indices = np.argsort(views)
views = views[indices]
view_centers = view_centers[indices]
view_rots = view_rots[indices]
pts = np.array([points3D[pt_id].xyz for pt_id in points3D])
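# Distance of each sparse point from the origin; used below to estimate the scene's depth range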
zvals = np.sqrt(np.sum(pts * pts, 1))
dataset_desc = {
'view_file_pattern': f"images/image%04d.jpg",
'gl_coord': True,
'view_res': {
'x': cam.width,
'y': cam.height
},
'cam_params': {
'fx': cam.params[0],
'fy': cam.params[0],
'cx': cam.params[1],
'cy': cam.params[2]
},
'range': {
'min': np.min(view_centers, 0).tolist() + [0, 0],
'max': np.max(view_centers, 0).tolist() + [0, 0]
},
'depth_range': {
'min': np.min(zvals),
'max': np.max(zvals)
},
'samples': [len(view_centers)],
'view_centers': view_centers.tolist(),
'view_rots': view_rots.tolist(),
'views': views.tolist()
}
with open(out_desc_path, 'w') as fp:
json.dump(dataset_desc, fp, indent=4)
@@ -6,8 +6,9 @@ import numpy as np
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
from utils import seqs
from utils import math
from utils import seqs, math
from utils.types import Resolution
parser = argparse.ArgumentParser()
parser.add_argument('-r', '--rot-range', nargs='+', type=int)
@@ -25,7 +26,7 @@ args = parser.parse_args()
data_dir = args.dataset
os.makedirs(data_dir, exist_ok=True)
out_desc_path = os.path.join(data_dir, (args.out_desc if args.out_desc else f"{args.seq}.json"))
out_desc_path = os.path.join(data_dir, args.out_desc or f"{args.seq}.json")
if args.ref:
with open(os.path.join(data_dir, args.ref), 'r') as fp:
@@ -37,34 +38,17 @@ else:
ref_desc = None
if args.trans_range:
trans_range = np.array(list(args.trans_range) * 3 if len(args.trans_range) == 1
else args.trans_range)
trans_range = np.array(args.trans_range * 3 if len(args.trans_range) == 1 else args.trans_range)
else:
trans_range = np.array(ref_desc['range']['max'][0:3]) - \
np.array(ref_desc['range']['min'][0:3])
trans_range = np.array(ref_desc["trs_range"])
if args.rot_range:
rot_range = np.array(list(args.rot_range) * 2 if len(args.rot_range) == 1
else args.rot_range)
rot_range = np.array(args.rot_range * 2 if len(args.rot_range) == 1 else args.rot_range)
else:
rot_range = np.array(ref_desc['range']['max'][3:5]) - \
np.array(ref_desc['range']['min'][3:5])
rot_range = np.array(ref_desc["rot_range"])
filter_range = np.concatenate([trans_range, rot_range])
if args.fov:
cam_params = {
'fov': args.fov,
'cx': 0.5,
'cy': 0.5,
'normalized': True
}
else:
cam_params = ref_desc['cam_params']
if args.res:
res = tuple(int(s) for s in args.res.split('x'))
res = {'x': res[0], 'y': res[1]}
else:
res = ref_desc['view_res']
cam_params = { "fov": args.fov } if args.fov else ref_desc["cam"]
res = Resolution.from_str(args.res or ref_desc["res"])
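# Note: res is a "WIDTHxHEIGHT" string (e.g. "1600x1440"); Resolution.from_str is assumed to parse that format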
if args.seq == 'helix':
centers, rots = seqs.helix(trans_range, 4, args.views)
@@ -73,7 +57,7 @@ elif args.seq == 'scan_around':
elif args.seq == 'look_around':
centers, rots = seqs.look_around(trans_range, args.views)
rots *= 180 / math.pi
rots = np.degrees(rots)
gl = args.gl or ref_desc and ref_desc.get('gl_coord')
if gl:
centers[:, 2] *= -1
@@ -81,15 +65,13 @@ if gl:
dataset_desc = {
'gl_coord': gl,
'view_res': res,
'cam_params': cam_params,
'range': {
'min': (-0.5 * filter_range).tolist(),
'max': (0.5 * filter_range).tolist()
},
'res': f"{res.w}x{res.h}",
'cam': cam_params,
"trs_range": trans_range.tolist(),
"rot_range": rot_range.tolist(),
'samples': [args.views],
'view_centers': centers.tolist(),
'view_rots': rots.tolist()
'centers': centers.tolist(),
'rots': rots.tolist()
}
with open(out_desc_path, 'w') as fp:
......
@@ -76,8 +76,8 @@ print('Test set views: ', len(test_views))
def create_subset(views, out_desc_name):
views = views.tolist()
subset_desc = dataset_desc.copy()
subset_desc['view_file_pattern'] = \
f"{out_desc_name}/{dataset_desc['view_file_pattern'].split('/')[-1]}"
subset_desc['color_file'] = \
f"{out_desc_name}/{dataset_desc['color_file'].split('/')[-1]}"
subset_desc['range'] = {
'min': list(-filter_range / 2),
'max': list(filter_range / 2)
@@ -91,8 +91,8 @@ def create_subset(views, out_desc_name):
json.dump(subset_desc, fp, indent=4)
os.makedirs(os.path.join(out_data_dir, out_desc_name), exist_ok=True)
for i in range(len(views)):
os.symlink(os.path.join('../../', dataset_desc['view_file_pattern'] % views[i]),
os.path.join(out_data_dir, subset_desc['view_file_pattern'] % views[i]))
os.symlink(os.path.join('../../', dataset_desc['color_file'] % views[i]),
os.path.join(out_data_dir, subset_desc['color_file'] % views[i]))
os.makedirs(out_data_dir, exist_ok=True)
......
@@ -36,8 +36,8 @@ for i in range(len(input)):
input_desc: Mapping = json.load(fp)
dataset_desc['view_centers'] += input_desc['view_centers']
dataset_desc['view_rots'] += input_desc['view_rots']
copy_images(get_data_path(input[i], input_desc['view_file_pattern']),
get_data_path(output, dataset_desc['view_file_pattern']),
copy_images(get_data_path(input[i], input_desc['color_file']),
get_data_path(output, dataset_desc['color_file']),
len(input_desc['view_centers']), n_views)
n_views += len(input_desc['view_centers'])
......
@@ -2,59 +2,95 @@ import json
import sys
import os
import argparse
import torch
from pathlib import Path
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
from data import get_dataset_desc_path
from data import DataDesc
from utils.misc import calculate_autosize
parser = argparse.ArgumentParser()
parser.add_argument('-o', '--output', type=str, nargs="+", required=True)
parser.add_argument("-v", "--views", type=int, nargs="+", required=True)
parser.add_argument('-o', '--outputs', type=str, nargs="+", required=True,
help="names of output datasets, leading with ~ to prepend the name of input dataset")
parser.add_argument("-v", "--views", type=str, nargs="+", required=True,
help="views of output datasets, could be -1, a positive number or a colon splited slice")
parser.add_argument("--random", action="store_true")
parser.add_argument('dataset', type=str)
parser.usage = """
Split a dataset into one or more datasets.
Examples:
> python split.py path_to_dataset.json -o train test -v 20 -1
This will create two datasets "train" and "test" in the folder where the input dataset is located,
with the first 20 views in "train" and the remaining views in "test".
> python split.py path_to_dataset.json -o ~_train ~_test -v -1 ::8
This will create two datasets "<input>_train" and "<input>_test" in the folder where the input dataset is located,
with every 8th view in "<input>_test" and the remaining views in "<input>_train".
"""
args = parser.parse_args()
input = get_dataset_desc_path(args.dataset)
input = DataDesc.get_json_path(args.dataset)
outputs = [
get_dataset_desc_path(input.with_name(f"{input.stem}_{appendix}"))
for appendix in args.output
DataDesc.get_json_path(input.with_name(
f"{input.stem}{appendix[1:]}" if appendix.startswith("~") else appendix))
for appendix in args.outputs
]
with open(input, 'r') as fp:
input_desc: dict = json.load(fp)
n_views = len(input_desc['view_centers'])
n_views = len(input_desc['centers'])
assert(len(args.views) == len(outputs))
sum_views = sum(args.views)
for i in range(len(args.views)):
if args.views[i] == -1:
args.views[i] = n_views - sum_views - 1
sum_views = n_views
break
assert(sum_views <= n_views)
for i in range(len(args.views)):
assert(args.views[i] > 0)
indices = torch.arange(n_views)
indices_assigned = torch.zeros(n_views, dtype=torch.bool)
output_dataset_indices: list[torch.Tensor] = [None] * len(outputs)
output_dataset_views = {}
for i, output_views in enumerate(args.views):
arr = output_views.split(":")
if len(arr) > 1:
view_slice = slice(*[int(value) if value != "" else None for value in arr])
output_dataset_indices[i] = indices[view_slice]
indices_assigned[view_slice] = True
else:
output_dataset_views[i] = int(arr[0])
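# e.g. with "-v -1 ::8", the spec "::8" becomes slice(None, None, 8) and claims every 8th view here,
# while "-1" is resolved later via calculate_autosize to take all remaining (unassigned) views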
indices_remain = indices[indices_assigned.logical_not()]
n_views_remain = len(indices_remain)
output_dataset_views = {
key: value
for key, value in zip(output_dataset_views,
calculate_autosize(n_views_remain, *output_dataset_views.values())[0])
}
if args.random:
indices_remain = indices_remain[torch.randperm(n_views_remain)]
offset = 0
for key, value in output_dataset_views.items():
output_dataset_indices[key] = indices_remain[offset:offset + value]
offset += value
in_views = torch.tensor(input_desc["views"]) if "views" in input_desc else torch.arange(n_views)
in_centers = torch.tensor(input_desc["centers"])
in_rots = torch.tensor(input_desc["rots"]) if "rots" in input_desc else None
for i in range(len(outputs)):
n = args.views[i]
end = offset + n
sub_indices = output_dataset_indices[i].sort()[0]
output_desc = input_desc.copy()
output_desc['samples'] = args.views[i]
if 'views' in output_desc:
output_desc['views'] = output_desc['views'][offset:end]
else:
output_desc['views'] = list(range(offset, end))
output_desc['view_centers'] = output_desc['view_centers'][offset:end]
if 'view_rots' in output_desc:
output_desc['view_rots'] = output_desc['view_rots'][offset:end]
output_desc['samples'] = [len(sub_indices)]
output_desc['views'] = in_views[sub_indices].tolist()
output_desc['centers'] = in_centers[sub_indices].tolist()
if in_rots is not None:
output_desc['rots'] = in_rots[sub_indices].tolist()
with open(outputs[i], 'w') as fp:
json.dump(output_desc, fp, indent=4)
# Create symbol links of images
out_dir = outputs[i].with_suffix('')
out_dir.mkdir(exist_ok=True)
for k in range(n):
os.symlink(Path("..") / input.stem / (output_desc['view_file_pattern'] % output_desc['views'][k]),
out_dir / (input_desc['view_file_pattern'] % output_desc['views'][k]))
offset += args.views[i]
for k in range(len(sub_indices)):
os.symlink(Path("..") / input.stem / (output_desc['color_file'] % output_desc['views'][k]),
out_dir / (input_desc['color_file'] % output_desc['views'][k]))
import json
import sys
import os
import argparse
import numpy as np
import shutil
from typing import List
from pathlib import Path
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
from data import DataDesc
from utils.misc import calculate_autosize
def run(dataset: str, outputs: list[str], views: list[int], random: bool = False):
if len(views) != len(outputs):
raise ValueError("The number of view specs must match the number of outputs")
input = DataDesc.get_json_path(dataset)
outputs = [
DataDesc.get_json_path(input.with_name(f"{input.stem}_{appendix}"))
for appendix in outputs
]
with open(input, 'r') as fp:
input_desc: dict = json.load(fp)
n_views = len(input_desc['view_centers']) // 6
assert(len(views) == len(outputs))
views, sum_views = calculate_autosize(n_views, *views)
if random:
indices = np.random.permutation(n_views)
else:
indices = np.arange(n_views)
in_views = np.array(input_desc["views"]) if "views" in input_desc else np.arange(n_views)
in_centers = np.array(input_desc["view_centers"])
in_rots = np.array(input_desc["view_rots"]) if "view_rots" in input_desc else None
offset = 0
for i in range(len(outputs)):
n = views[i]
end = offset + n
sub_indices = np.sort(indices[offset:end])
sub_indices = np.concatenate([sub_indices + j * n_views for j in range(6)], axis=0)
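# The descriptor appears to store views as 6 consecutive blocks of n_views (one block per camera),
# so replicate the selected indices for each block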
output_desc = input_desc.copy()
output_desc['samples'] = [views[i] * 6]
output_desc['views'] = in_views[sub_indices].tolist()
output_desc['view_centers'] = in_centers[sub_indices].tolist()
if in_rots is not None:
output_desc['view_rots'] = in_rots[sub_indices].tolist()
with open(outputs[i], 'w') as fp:
json.dump(output_desc, fp, indent=4)
# Create symbol links of images
out_dir: Path = outputs[i].with_suffix('')
if out_dir.exists():
shutil.rmtree(out_dir)
out_dir.mkdir()
for view in output_desc['views']:
filename = output_desc['color_file'] % view
os.symlink(Path("..") / input.stem / filename, out_dir / filename)
offset += views[i]
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-o', '--outputs', type=str, nargs="+", required=True)
parser.add_argument("-v", "--views", type=int, nargs="+", required=True)
parser.add_argument("--random", action="store_true")
parser.add_argument('dataset', type=str)
args = parser.parse_args()
run(args.dataset, args.outputs, args.views, args.random)
import sys
import os
import argparse
from pathlib import Path
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--start', type=int)
parser.add_argument('-t', '--duration', type=int)
parser.add_argument('--fps', type=str)
parser.add_argument('--subset', type=str)
parser.add_argument('datadir', type=str)
args = parser.parse_args()
os.chdir(args.datadir)
if args.subset is not None:
video_dir = Path(f"videos/{args.subset}")
else:
video_dir = Path(f"video")
for video_path in video_dir.glob("*.*"):
# Extract frames from video
image_dir = "images" if args.subset is None else f"images/{args.subset}"
os.makedirs(f"{image_dir}/{video_path.stem}", exist_ok=True)
extra_args = []
if args.start is not None:
extra_args.append(f"-ss {args.start}")
if args.duration is not None:
extra_args.append(f"-t {args.duration}")
if args.fps is not None:
extra_args.append(f"-vf fps={args.fps}")
extra_args = ' '.join(extra_args)
os.system(f"ffmpeg -i {video_path} {extra_args} -f image2 -q:v 2 "
f"{image_dir}/{video_path.stem}/image%03d.png")
\ No newline at end of file
import torch
import argparse
from operator import itemgetter
parser = argparse.ArgumentParser()
parser.add_argument("ckpt_path", type=str)
cli_args = parser.parse_args()
args, states = itemgetter("args", "states")(torch.load(cli_args.ckpt_path))
print(f"Model: {args['model']} >>>>")
for key, value in args["model_args"].items():
print(f"{key}: {value}")
print("\n")
if args["trainer"]:
print(f"Trainer: {args['trainer']} >>>>")
for key, value in args["trainer_args"].items():
print(f"{key}={value}")
print("\n")
print("Model states >>>>")
for key, value in states["model"].items():
print(f"{key}: Tensor{list(value.shape)}")
@@ -49,7 +49,7 @@ def load_net(path):
def export_net(net: torch.nn.Module, name: str,
input: Mapping[str, List[int]], output_names: List[str]):
input: Mapping[str, list[int]], output_names: list[str]):
outpath = os.path.join(opt.outdir, config.to_id(), name + ".onnx")
input_tensors = tuple([
torch.empty(size, device=device.default())
......
@@ -49,7 +49,7 @@ def load_net(path):
def export_net(net: torch.nn.Module, name: str,
input: Mapping[str, List[int]], output_names: List[str]):
input: Mapping[str, list[int]], output_names: list[str]):
outpath = os.path.join(opt.outdir, config.to_id(), name + ".onnx")
input_tensors = tuple([
torch.empty(size, device=device.default())
......
import sys
import os
import argparse
import torch
import torch.optim
from torch import onnx
from pathlib import Path
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
sys.path.append(str(Path(__file__).absolute().parent.parent))
from utils import netio
import model
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=0,
help='Which CUDA device to use.')
parser.add_argument('--batch-size', type=str,
help='Resolution')
parser.add_argument('--outdir', type=str, default='./',
parser.add_argument('--outdir', type=str, default='onnx',
help='Output directory')
parser.add_argument('model', type=str,
help='Path of model to export')
opt = parser.parse_args()
# Select device
torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())
from configs.spherical_view_syn import SphericalViewSynConfig
from utils import device
from utils import netio
from utils import misc
dir_path, model_file = os.path.split(opt.model)
batch_size = eval(opt.batch_size)
os.chdir(dir_path)
config = SphericalViewSynConfig()
def load_net(path):
name = os.path.splitext(os.path.basename(path))[0]
config.from_id(name)
config.sa['spherical'] = True
config.sa['perturb_sample'] = False
config.sa['n_samples'] = 4
config.print()
net = config.create_net().to(device.default())
netio.load(path, net)
return net, name
if __name__ == "__main__":
with torch.no_grad():
# Load model
net, name = load_net(model_file)
# Input to the model
rays_o = torch.empty(batch_size, 3, device=device.default())
rays_d = torch.empty(batch_size, 3, device=device.default())
with torch.inference_mode():
states, model_path = netio.load_checkpoint(opt.model)
batch_size = opt.batch_size and eval(opt.batch_size)
out_dir = model_path.parent / opt.outdir
os.makedirs(out_dir, exist_ok=True)
model.deserialize(states).eval().export_onnx(out_dir, batch_size)
# Export the model
outpath = os.path.join(opt.outdir, config.to_id() + ".onnx")
onnx.export(
net, # model being run
(rays_o, rays_d), # model input (or a tuple for multiple inputs)
outpath,
export_params=True, # store the trained parameter weights inside the model file
verbose=True,
opset_version=9, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding
input_names=['Rays_o', 'Rays_d'], # the model's input names
output_names=['Colors'] # the model's output names
)
print ('Model exported to ' + outpath)
print(f'Model exported to {out_dir}')
from pathlib import Path
import sys
import torch
import torch.optim
sys.path.append(str(Path(__file__).absolute().parent.parent))
import model
torch.set_grad_enabled(False)
m = model.load(
"/home/dengnc/dvs/data/classroom/_nets/train_hr_pano_t0.8/_hr_snerf/checkpoint_50.tar").eval().to("cuda")
print(m.cores[0])
inputs = (
torch.rand(10, 63, device="cuda"),
torch.rand(10, 24, device="cuda")
)
def fn(*args, **kwargs):
return m.cores[0].infer(*args, **kwargs)
sm = torch.jit.trace(fn, inputs)
torch.nn.Module.__call__
print(sm.infer(torch.rand(5, 63, device="cuda"), torch.rand(5, 24, device="cuda")))
sm.save("test.pt")
torch.onnx.export(sm.infer, # model being run
inputs, # model input (or a tuple for multiple inputs)
"core_0.onnx", # where to save the model
export_params=True, # store the trained parameter weights inside the model file
opset_version=10, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=["x", "d"], # the model's input names
output_names=["densities", "colors"], # the model's output names
dynamic_axes={
"x": [0],
"d": [0],
"densities": [0],
"colors": [0]
}) # variable length axes
@@ -61,8 +61,8 @@ def load_net():
return net
def export_net(net: torch.nn.Module, path: str, input: Mapping[str, List[int]],
output_names: List[str]):
def export_net(net: torch.nn.Module, path: str, input: Mapping[str, list[int]],
output_names: list[str]):
input_tensors = tuple([
torch.empty(size, device=device.default())
for size in input.values()
......
import json
import sys
import os
import csv
import argparse
import torch
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as nn_f
from tqdm import trange
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--stereo', action='store_true')
parser.add_argument('-R', '--replace', action='store_true')
parser.add_argument('--noCE', action='store_true')
parser.add_argument('-s', '--stereo', action='store_true',
help="Render stereo video")
parser.add_argument('-d', '--disparity', type=float, default=0.06,
help="The stereo disparity")
parser.add_argument('-R', '--replace', action='store_true',
help="Replace the existed frames in the intermediate output directory")
parser.add_argument('--noCE', action='store_true',
help="Disable constrast enhancement")
parser.add_argument('-i', '--input', type=str)
parser.add_argument('-r', '--range', type=str)
parser.add_argument('-f', '--fps', type=int)
parser.add_argument('--device', type=int, default=0,
help='Which CUDA device to use.')
parser.add_argument('scene', type=str)
parser.add_argument('view_file', type=str)
parser.add_argument('-r', '--range', type=str,
help="The range of frames to render, specified as format: start,end")
parser.add_argument('-f', '--fps', type=int,
help="The FPS of output video. if not specified, a sequence of images will be saved instead")
parser.add_argument('-m', '--model', type=str,
help="The directory containing fovea* and periph* model file")
parser.add_argument('view_file', type=str,
help="The path to .csv or .json file which contains a sequence of poses and gazes")
opt = parser.parse_args()
# Select device
torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())
torch.autograd.set_grad_enabled(False)
from configs.spherical_view_syn import SphericalViewSynConfig
from utils import netio
from utils import misc
from utils import img
from utils import device
from utils import netio, img, device
from utils.view import *
from utils import sphere
from utils.types import *
from components.fnr import FoveatedNeuralRenderer
from utils.progress_bar import progress_bar
def load_net(path):
config = SphericalViewSynConfig()
config.from_id(path[:-4])
config.sa['perturb_sample'] = False
# config.print()
net = config.create_net().to(device.default())
netio.load(path, net)
return net
from model import Model
def find_file(prefix):
for path in os.listdir():
if path.startswith(prefix):
return path
return None
def load_model(path: Path) -> Model:
checkpoint, _ = netio.load_checkpoint(path)
checkpoint["model"][1]["sampler"]["perturb"] = False
model = Model.create(*checkpoint["model"])
model.load_state_dict(checkpoint["states"]["model"])
model.to(device.default()).eval()
return model
def clamp_gaze(gaze):
return gaze
scoord = sphere.cartesian2spherical(gaze)
def load_csv(data_desc_file) -> Tuple[Trans, torch.Tensor]:
def load_csv(data_desc_file: Path) -> tuple[Trans, torch.Tensor]:
def to_tensor(line_content):
return torch.tensor([float(str) for str in line_content.split(',')])
@@ -84,9 +72,9 @@ def load_csv(data_desc_file) -> Tuple[Trans, torch.Tensor]:
if lines[i + 1].startswith('0,0,0') or lines[i + 2].startswith('0,0,0'):
continue
j = i + 1
gaze_dirs[view_idx, 0] = clamp_gaze(to_tensor(lines[j]))
gaze_dirs[view_idx, 0] = to_tensor(lines[j])
j += 1
gaze_dirs[view_idx, 1] = clamp_gaze(to_tensor(lines[j]))
gaze_dirs[view_idx, 1] = to_tensor(lines[j])
j += 1
if not old_fmt:
gazes[view_idx, 0] = to_tensor(lines[j])
@@ -113,11 +101,16 @@ def load_csv(data_desc_file) -> Tuple[Trans, torch.Tensor]:
return Trans(view_t, view_mats[:, 0, :3, :3]), gazes
def load_json(data_desc_file) -> Tuple[Trans, torch.Tensor]:
def load_json(data_desc_file: Path) -> tuple[Trans, torch.Tensor]:
with open(data_desc_file, 'r', encoding='utf-8') as file:
data = json.load(file)
view_t = torch.tensor(data['view_centers'])
view_r = torch.tensor(data['view_rots']).view(-1, 3, 3)
if data.get("gl_coord"):
view_t[:, 2] *= -1
view_r[:, 2] *= -1
view_r[..., 2] *= -1
if data.get('gazes'):
if len(data['gazes'][0]) == 2:
gazes = torch.tensor(data['gazes']).view(-1, 1, 2).expand(-1, 2, -1)
@@ -127,68 +120,61 @@ def load_json(data_desc_file) -> Tuple[Trans, torch.Tensor]:
gazes = torch.zeros(view_t.size(0), 2, 2)
return Trans(view_t, view_r), gazes
def load_views(data_desc_file: str) -> Tuple[Trans, torch.Tensor]:
if data_desc_file.endswith('.csv'):
return load_csv(data_desc_file)
return load_json(data_desc_file)
rot_range = {
'classroom': [120, 80],
'barbershop': [360, 80],
'lobby': [360, 80],
'stones': [360, 80]
}
trans_range = {
'classroom': 0.6,
'barbershop': 0.3,
'lobby': 1.0,
'stones': 1.0
}
fov_list = [20, 45, 110]
res_list = [(256, 256), (256, 256), (400, 360)]
def load_views_and_gazes(data_desc_file: Path) -> tuple[Trans, torch.Tensor]:
if data_desc_file.suffix == '.csv':
views, gazes = load_csv(data_desc_file)
else:
views, gazes = load_json(data_desc_file)
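# Average the second gaze component (presumably the vertical coordinate) of the two eyes and use it for both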
gazes[:, :, 1] = (gazes[:, :1, 1] + gazes[:, 1:, 1]) * 0.5
return views, gazes
torch.set_grad_enabled(False)
view_file = Path(opt.view_file)
stereo_disparity = opt.disparity
res_full = (1600, 1440)
stereo_disparity = 0.06
cwd = os.getcwd()
os.chdir(f"{sys.path[0]}/../data/__new/{opt.scene}_all")
fovea_net = load_net(find_file('fovea'))
periph_net = load_net(find_file('periph'))
renderer = FoveatedNeuralRenderer(fov_list, res_list,
nn.ModuleList([fovea_net, periph_net, periph_net]),
res_full, device=device.default())
os.chdir(cwd)
if opt.model:
model_dir = Path(opt.model)
fov_list = [20.0, 45.0, 110.0]
res_list = [(256, 256), (256, 256), (256, 230)]
fovea_net = load_model(next(model_dir.glob("fovea*.tar")))
periph_net = load_model(next(model_dir.glob("periph*.tar")))
renderer = FoveatedNeuralRenderer(fov_list, res_list,
nn.ModuleList([fovea_net, periph_net, periph_net]),
res_full, device=device.default())
else:
renderer = None
# Load Dataset
views, gazes = load_views(opt.view_file)
views, gazes = load_views_and_gazes(Path(opt.view_file))
if opt.range:
opt.range = [int(val) for val in opt.range.split(",")]
if len(opt.range) == 1:
opt.range = [0, opt.range[0]]
views = views.get(range(opt.range[0], opt.range[1]))
views = views[opt.range[0]:opt.range[1]]
gazes = gazes[opt.range[0]:opt.range[1]]
views = views.to(device.default())
n_views = views.size()[0]
n_views = views.shape[0]
print('Dataset loaded. Views:', n_views)
videodir = os.path.dirname(os.path.abspath(opt.view_file))
tempdir = '/dev/shm/dvs_tmp/video'
videoname = f"{os.path.splitext(os.path.split(opt.view_file)[-1])[0]}_{'stereo' if opt.stereo else 'mono'}"
gazeout = f"{videodir}/{videoname}_gaze.csv"
if opt.noCE:
videoname += "_noCE"
videodir = view_file.absolute().parent
tempdir = Path('/dev/shm/dvs_tmp/video')
if opt.input:
videoname = Path(opt.input).parent.stem
else:
videoname = f"{view_file.stem}_{('stereo' if opt.stereo else 'mono')}"
if opt.noCE:
videoname += "_noCE"
gazeout = videodir / f"{videoname}_gaze.csv"
if opt.fps:
if opt.input:
videoname = os.path.split(opt.input)[0]
inferout = f"{videodir}/{opt.input}" if opt.input else f"{tempdir}/{videoname}/%04d.bmp"
hintout = f"{tempdir}/{videoname}_hint/%04d.bmp"
else:
inferout = f"{videodir}/{opt.input}" if opt.input else f"{videodir}/{videoname}/%04d.png"
hintout = f"{videodir}/{videoname}_hint/%04d.png"
if opt.input:
scale = img.load(inferout % 0).shape[-1] / res_full[1]
else:
scale = 1
scale = img.load(inferout % 0).shape[-1] / res_full[1] if opt.input else 1
hint = img.load(f"{sys.path[0]}/fovea_hint.png", with_alpha=True).to(device=device.default())
hint = nn_f.interpolate(hint, mode='bilinear', scale_factor=scale, align_corners=False)
@@ -233,66 +219,44 @@ if not opt.replace:
hint_offset = max(0, hint_offset - 1)
infer_offset = n_views if opt.input else max(0, infer_offset - 1)
if opt.stereo:
gazes_out = torch.empty(n_views, 4)
for view_idx in range(n_views):
shift = gazes[view_idx, 0, 0] - gazes[view_idx, 1, 0]
# print(shift.item())
gazel = ((gazes[view_idx, 1, 0] + 0.4 * shift).item(),
0.5 * (gazes[view_idx, 0, 1] + gazes[view_idx, 1, 1]).item())
gazer = ((gazes[view_idx, 0, 0] - 0.4 * shift).item(), gazel[1])
# gazel = ((gazes[view_idx, 0, 0]).item(),
# 0.5 * (gazes[view_idx, 0, 1] + gazes[view_idx, 1, 1]).item())
#gazer = ((gazes[view_idx, 1, 0]).item(), gazel[1])
gazes_out[view_idx] = torch.tensor([gazel[0], gazel[1], gazer[0], gazer[1]])
if view_idx < hint_offset:
continue
if view_idx < infer_offset:
frame = img.load(inferout % view_idx).to(device=device.default())
else:
view_trans = views.get(view_idx)
left_images, right_images = renderer(view_trans, gazel, gazer,
stereo_disparity=stereo_disparity,
mono_periph_mode=3, ret_raw=True)
frame = torch.cat([
left_images['blended_raw'] if opt.noCE else left_images['blended'],
right_images['blended_raw'] if opt.noCE else right_images['blended']], -1)
frame = img.translate(frame, (0.5, 0.5))
img.save(frame, inferout % view_idx)
add_hint(frame, gazel, gazer)
img.save(frame, hintout % view_idx)
progress_bar(view_idx, n_views, 'Frame %4d inferred' % view_idx)
else:
gazes_out = torch.empty(n_views, 2)
for view_idx in range(n_views):
gaze = 0.5 * (gazes[view_idx, 0] + gazes[view_idx, 1])
gaze = (gaze[0].item(), gaze[1].item())
gazes_out[view_idx] = torch.tensor([gaze[0], gaze[1]])
if view_idx < hint_offset:
continue
if view_idx < infer_offset:
frame = img.load(inferout % view_idx).to(device=device.default())
for view_idx in trange(n_views):
if view_idx < hint_offset:
continue
gaze = gazes[view_idx]
if not opt.stereo:
gaze = gaze.sum(0, True) * 0.5
gaze = gaze.tolist()
if view_idx < infer_offset:
frame = img.load(inferout % view_idx).to(device=device.default())
else:
if renderer is None:
raise Exception("--model is required to render frames that are not already inferred")
key = 'blended_raw' if opt.noCE else 'blended'
view_trans = views[view_idx]
if opt.stereo:
left_images, right_images = renderer(view_trans, *gaze,
stereo_disparity=stereo_disparity,
mono_periph_mode=3, ret_raw=True)
frame = torch.cat([left_images[key], right_images[key]], -1)
else:
view_trans = views.get(view_idx)
frame = renderer(view_trans, gaze,
ret_raw=True)['blended_raw' if opt.noCE else 'blended']
frame = img.translate(frame, (0.5, 0.5))
img.save(frame, inferout % view_idx)
add_hint(frame, gaze)
img.save(frame, hintout % view_idx)
progress_bar(view_idx, n_views, 'Frame %4d inferred' % view_idx)
frame = renderer(view_trans, *gaze, ret_raw=True)[key]
img.save(frame, inferout % view_idx)
add_hint(frame, *gaze)
img.save(frame, hintout % view_idx)
gazes_out = gazes.reshape(-1, 4) if opt.stereo else gazes.sum(1) * 0.5
with open(gazeout, 'w') as fp:
for i in range(n_views):
fp.write(','.join([f'{val.item()}' for val in gazes_out[i]]))
fp.write('\n')
csv_writer = csv.writer(fp)
csv_writer.writerows(gazes_out.tolist())
if opt.fps:
# Generate video without hint
os.system(f'ffmpeg -y -r {opt.fps:d} -i {inferout} -c:v libx264 {videodir}/{videoname}.mp4')
if not opt.input:
shutil.rmtree(os.path.dirname(inferout))
# Generate video with hint
os.system(f'ffmpeg -y -r {opt.fps:d} -i {hintout} -c:v libx264 {videodir}/{videoname}_hint.mp4')
# Clean temp images
if not opt.input:
shutil.rmtree(os.path.dirname(inferout))
shutil.rmtree(os.path.dirname(hintout))
import sys
import os
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
import argparse
from PIL import Image
from utils import misc
from tqdm import tqdm
from pathlib import Path
def batch_scale(src, target, size):
os.makedirs(target, exist_ok=True)
for file_name in os.listdir(src):
postfix = os.path.splitext(file_name)[1]
if postfix == '.jpg' or postfix == '.png':
im = Image.open(os.path.join(src, file_name))
im = im.resize(size)
im.save(os.path.join(target, file_name))
def run(src: Path, target: Path, root: Path, scale_factor: float = 1., width: int = -1, height: int = -1):
target.mkdir(exist_ok=True)
for file_name in tqdm(os.listdir(src), leave=False, desc=src.relative_to(root).__str__()):
if (src / file_name).is_dir():
run(src / file_name, target / file_name, root, scale_factor, width, height)
elif not (target / file_name).exists():
postfix = os.path.splitext(file_name)[1]
if postfix == '.jpg' or postfix == '.png':
im = Image.open(src / file_name)
if width == -1 and height == -1:
width = round(im.width * scale_factor)
height = round(im.height * scale_factor)
elif width == -1:
width = round(im.width / im.height * height)
elif height == -1:
height = round(im.height / im.width * width)
im = im.resize((width, height))
im.save(target / file_name)
if __name__ == '__main__':
@@ -23,9 +31,11 @@ if __name__ == '__main__':
help='Source directory.')
parser.add_argument('target', type=str,
help='Target directory.')
parser.add_argument('--width', type=int,
parser.add_argument('-x', '--scale-factor', type=float, default=1,
help='Scale factor applied to output images.')
parser.add_argument('--width', type=int, default=-1,
help='Width of output images (pixel)')
parser.add_argument('--height', type=int,
parser.add_argument('--height', type=int, default=-1,
help='Height of output images (pixel)')
opt = parser.parse_args()
batch_scale(opt.src, opt.target, (opt.width, opt.height))
run(Path(opt.src), Path(opt.target), Path(opt.src), opt.scale_factor, opt.width, opt.height)
@@ -40,7 +40,7 @@ for subdir in ['images', 'images_4', 'images_8']:
print(f"Rename {src} to {tgt}")
os.rename(src, tgt)
out_desc = dataset_desc.copy()
out_desc['view_file_pattern'] = f"{subdir}/view_%04d.jpg"
out_desc['color_file'] = f"{subdir}/view_%04d.jpg"
k = res[0] / dataset_desc['view_res']['y']
out_desc['view_res'] = {
'x': res[1],
......
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import List
from operator import itemgetter
from configargparse import ArgumentParser, SUPPRESS
import model as mdl
import train
from utils import device
from utils import netio
from data import *
from utils.misc import print_and_log
from model import Model
from train import Trainer
from utils import device, netio
from utils.types import *
from data import Dataset
RAYS_PER_BATCH = 2 ** 12
DATA_LOADER_CHUNK_SIZE = 1e8
root_dir = Path(__file__).absolute().parent
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str,
help='Net config files')
parser.add_argument('-e', '--epochs', type=int,
help='Max epochs for train')
parser.add_argument('--perf', type=int,
help='Performance measurement frames (0 for disabling performance measurement)')
parser.add_argument('--prune', type=int, nargs='+',
help='Prune voxels on every # epochs')
parser.add_argument('--split', type=int, nargs='+',
help='Split voxels on every # epochs')
parser.add_argument('--freeze', type=int, nargs='+',
help='freeze levels on epochs')
parser.add_argument('--checkpoint-interval', type=int)
parser.add_argument('--views', type=str,
help='Specify the range of views to train')
parser.add_argument('path', type=str,
help='Dataset description file')
args = parser.parse_args()
views_to_load = range(*[int(val) for val in args.views.split('-')]) if args.views else None
argpath = Path(args.path)
# argpath: May be model path or data path
# 1) model path: continue training on the specified model
# 2) data path: train a new model using specified dataset
def load_dataset(data_path: Path):
print(f"Loading dataset {data_path}")
try:
dataset = DatasetFactory.load(data_path, views_to_load=views_to_load)
print(f"Dataset loaded: {dataset.root}/{dataset.name}")
os.chdir(dataset.root)
return dataset, dataset.name
except FileNotFoundError:
return load_multiscale_dataset(data_path)
def load_multiscale_dataset(data_path: Path):
if not data_path.is_dir():
raise ValueError(
f"Path {data_path} is not a directory")
dataset: List[Union[PanoDataset, ViewDataset]] = []
for sub_data_desc_path in data_path.glob("*.json"):
sub_dataset = DatasetFactory.load(sub_data_desc_path, views_to_load=views_to_load)
print(f"Sub-dataset loaded: {sub_dataset.root}/{sub_dataset.name}")
dataset.append(sub_dataset)
if len(dataset) == 0:
raise ValueError(f"Path {data_path} does not contain sub-datasets")
os.chdir(data_path.parent)
return dataset, data_path.name
try:
states, checkpoint_path = netio.load_checkpoint(argpath)
# Infer dataset path from model path
# The model path follows this convention: <dataset_dir>/_nets/<dataset_name>/<model_name>/checkpoint_*.tar
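# e.g. (hypothetical) data/classroom/_nets/train/_hr_snerf/checkpoint_50.tar
#   -> dataset path "data/classroom/train", model name "_hr_snerf"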
model_name = checkpoint_path.parts[-2]
dataset, dataset_name = load_dataset(
Path(*checkpoint_path.parts[:-4]) / checkpoint_path.parts[-3])
except Exception:
model_name = args.config
dataset, dataset_name = load_dataset(argpath)
# Load state 0 from specified configuration
with Path(f'{root_dir}/configs/{args.config}.json').open() as fp:
states = json.load(fp)
states['args']['bbox'] = dataset[0].bbox if isinstance(dataset, list) else dataset.bbox
states['args']['depth_range'] = dataset[0].depth_range if isinstance(dataset, list)\
else dataset.depth_range
def load_dataset(data_path: Path, color: str, coord: str):
dataset = Dataset(data_path, color_mode=Color[color], coord_sys=coord)
print(f"Load dataset: {dataset.root}/{dataset.name}")
return dataset
if 'train' not in states:
states['train'] = {}
if args.prune is not None:
states['train']['prune_epochs'] = args.prune
if args.split is not None:
states['train']['split_epochs'] = args.split
if args.freeze is not None:
states['train']['freeze_epochs'] = args.freeze
if args.perf is not None:
states['train']['perf_frames'] = args.perf
if args.checkpoint_interval is not None:
states['train']['checkpoint_interval'] = args.checkpoint_interval
if args.epochs is not None:
states['train']['max_epochs'] = args.epochs
model = mdl.deserialize(states).to(device.default())
initial_parser = ArgumentParser()
initial_parser.add_argument('-c', '--config', type=str, default=SUPPRESS,
help='Config name, ignored if path is a checkpoint path')
initial_parser.add_argument('--expname', type=str, default=SUPPRESS,
help='Experiment name, defaults to config name, ignored if path is a checkpoint path')
initial_parser.add_argument('path', type=str,
help='Path to dataset description file or checkpoint file')
initial_args = vars(initial_parser.parse_known_args()[0])
# Initialize run directory
run_dir = Path(f"_nets/{dataset_name}/{model_name}")
run_dir.mkdir(parents=True, exist_ok=True)
# Initialize logging
log_file = run_dir / "train.log"
logging.basicConfig(format='%(asctime)s[%(levelname)s] %(message)s', level=logging.INFO,
filename=log_file, filemode='a' if log_file.exists() else 'w')
def log_exception(exc_type, exc_value, exc_traceback):
if not issubclass(exc_type, KeyboardInterrupt):
logging.exception(exc_value, exc_info=(exc_type, exc_value, exc_traceback))
sys.__excepthook__(exc_type, exc_value, exc_traceback)
sys.excepthook = log_exception
print_and_log(f"model: {model_name} ({model.cls})")
print_and_log(f"args:")
model.print_config()
print(model)
if __name__ == "__main__":
# 1. Initialize data loader
data_loader = get_loader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
shuffle=True, enable_preload=False, color=model.color)
# 2. Initialize model and trainer
trainer = train.get_trainer(model, run_dir, states)
# 3. Train
trainer.train(data_loader)
root_dir = Path(__file__).absolute().parent
argpath = Path(initial_args["path"]) # May be checkpoint path or dataset path
# 1) checkpoint path: continue training a model
# 2) dataset path: train a new model using specified dataset
ckpt_path = netio.find_checkpoint(argpath)
if ckpt_path:
# Continue training from a checkpoint
print(f"Load checkpoint {ckpt_path}")
args, states = itemgetter("args", "states")(torch.load(ckpt_path))
# args: "model", "model_args", "trainer", "trainer_args"
ModelCls = Model.get_class(args["model"])
TrainerCls = Trainer.get_class(args["trainer"])
model_args = ModelCls.Args(**args["model_args"])
trainer_args = TrainerCls.Args(**args["trainer_args"]).parse()
trainset = load_dataset(trainer_args.trainset, model_args.color, model_args.coord)
run_dir = ckpt_path.parent
else:
# Start a new train
expname = initial_args.get("expname", initial_args.get("config", "unnamed"))
if "config" in initial_args:
config_path = root_dir / "configs" / f"{initial_args['config']}.ini"
if not config_path.exists():
raise ValueError(f"Config {initial_args['config']} is not found in "
f"{root_dir / 'configs'}.")
print(f"Load config {config_path}")
else:
config_path = None
# First parse model class and trainer class from config file or command-line arguments
parser = ArgumentParser(default_config_files=[f"{config_path}"] if config_path else [])
parser.add_argument('--color', type=str, default="rgb",
help='The color mode')
parser.add_argument('--model', type=str, required=True,
help='The model to train')
parser.add_argument('--trainer', type=str, default="Trainer",
help='The trainer to use for training')
args = parser.parse_known_args()[0]
ModelCls = Model.get_class(args.model)
TrainerCls = Trainer.get_class(args.trainer)
trainset_path = argpath
trainset = load_dataset(trainset_path, args.color, "gl")
# Then parse model's and trainer's args
if trainset.depth_range:
model_args = ModelCls.Args( # Some model's args are inferred from training dataset
color=trainset.color_mode.name,
near=trainset.depth_range[0],
far=trainset.depth_range[1],
white_bg=trainset.white_bg,
coord=trainset.coord_sys
)
else:
model_args = ModelCls.Args(white_bg=trainset.white_bg)
model_args.parse(config_path)
trainer_args = TrainerCls.Args(trainset=f"{trainset_path}").parse(config_path)
states = None
run_dir = trainset.root / "_nets" / trainset.name / expname
run_dir.mkdir(parents=True, exist_ok=True)
m = ModelCls(model_args).to(device.default())
trainer = TrainerCls(m, run_dir, trainer_args)
if states:
trainer.load_state_dict(states)
# Start train
trainer.train(trainset)