Commit 1bc644a1 authored by Nianchen Deng's avatar Nianchen Deng

sync

parent 6294701e
#include <device_launch_parameters.h>
#include <glm/glm.hpp>
#define T_IDX threadIdx.x
#define T_IDX2 glm::uvec2(threadIdx.x, threadIdx.y)
#define T_IDX3 glm::uvec3(threadIdx.x, threadIdx.y, threadIdx.z)
#define B_IDX blockIdx.x
#define B_IDX2 glm::uvec2(blockIdx.x, blockIdx.y)
#define B_IDX3 glm::uvec3(blockIdx.x, blockIdx.y, blockIdx.z)
#define IDX (blockIdx.x * blockDim.x + threadIdx.x)
#define IDX2 glm::uvec2(blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y)
#define IDX3 \
    glm::uvec3(blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y, \
               blockIdx.z * blockDim.z + threadIdx.z)
#define FLAT_INDEX utils::cuda::flattenIdx(IDX3)
#define DEFINE_IDX(__var1__) uint __var1__ = blockIdx.x * blockDim.x + threadIdx.x;
#define DEFINE_IDX2(__var1__, __var2__) \
    uint __var1__ = blockIdx.x * blockDim.x + threadIdx.x; \
    uint __var2__ = blockIdx.y * blockDim.y + threadIdx.y;
#define DEFINE_IDX3(__var1__, __var2__, __var3__) \
    uint __var1__ = blockIdx.x * blockDim.x + threadIdx.x; \
    uint __var2__ = blockIdx.y * blockDim.y + threadIdx.y; \
    uint __var3__ = blockIdx.z * blockDim.z + threadIdx.z;
#define DEFINE_FLAT_INDEX(__var1__) uint __var1__ = FLAT_INDEX;
namespace utils::cuda {

__device__ __forceinline__ uint flattenIdx(glm::uvec3 idx3) {
    return idx3.x + idx3.y * blockDim.x * gridDim.x +
           idx3.z * blockDim.x * gridDim.x * blockDim.y * gridDim.y;
}

__device__ __forceinline__ uint flattenIdx(glm::uvec2 idx2) {
    return idx2.x + idx2.y * blockDim.x * gridDim.x;
}

} // namespace utils::cuda
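// Example (illustration only, not part of this commit; `cu_fill`, `out` and `n` are
// hypothetical): a kernel using the helpers above. DEFINE_FLAT_INDEX expands to
// `uint i = utils::cuda::flattenIdx(IDX3);`, the flattened global thread index, which
// must be bounds-checked before use.
__global__ void cu_fill(float *out, float value, uint n) {
    DEFINE_FLAT_INDEX(i)
    if (i >= n) // guard against the partial block at the end of the grid
        return;
    out[i] = value;
}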
#include <cuda_runtime.h>
#include <vector>
#include "error.h"
namespace utils::cuda {

class MapResourcesScope {
public:
    MapResourcesScope(const std::vector<cudaGraphicsResource_t> &resources,
                      cudaStream_t stream = nullptr)
        : _resources(resources), _stream(stream) {
        if (!_resources.empty())
            THROW_IF_FAILED(
                cudaGraphicsMapResources((int)_resources.size(), _resources.data(), _stream));
    }

    ~MapResourcesScope() {
        if (!_resources.empty())
            cudaGraphicsUnmapResources((int)_resources.size(), _resources.data(), _stream);
    }

private:
    std::vector<cudaGraphicsResource_t> _resources;
    cudaStream_t _stream;
};

} // namespace utils::cuda
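// Usage sketch (illustration only; `res` and `stream` are hypothetical handles): the RAII
// scope keeps the resources mapped for CUDA access exactly as long as the block is alive
// and unmaps them even if an exception is thrown.
void exampleUse(cudaGraphicsResource_t res, cudaStream_t stream) {
    utils::cuda::MapResourcesScope mapScope({res}, stream);
    void *devPtr = nullptr;
    size_t size = 0;
    THROW_IF_FAILED(cudaGraphicsResourceGetMappedPointer(&devPtr, &size, res));
    // ... launch kernels on `stream` that read/write devPtr ...
} // resources are unmapped here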
#pragma once
#include <cuda_runtime.h>
#include "../common.h"
#include "error.h"

namespace utils::cuda {

class Resource {
public:
    virtual ~Resource() {}
    virtual void *getBuffer() const = 0;
    virtual size_t size() const = 0;
};
class BufferResource : public Resource {
public:
    // Wraps an externally owned buffer: the no-op deleter ensures it is never freed here.
    BufferResource(void *buffer = nullptr, size_t size = 0)
        : _p_buffer(buffer, [](void *) {}), _size(size) {}

    // Allocates a device buffer owned by this resource (and its copies); the custom
    // deleter releases it with cudaFree exactly once, when the last shared reference
    // is destroyed, instead of the ill-formed default `delete` on a void pointer.
    BufferResource(size_t size) : _size(size) {
        void *p_buffer;
        THROW_IF_FAILED(cudaMalloc(&p_buffer, size));
        _p_buffer = std::shared_ptr<void>(p_buffer, [](void *p) {
            if (p != nullptr && cudaFree(p) != cudaSuccess)
                common::Logger::instance.warning("cudaFree failed while releasing buffer");
        });
    }

    virtual void *getBuffer() const { return _p_buffer.get(); }
    virtual size_t size() const { return _size; }

private:
    std::shared_ptr<void> _p_buffer;
    size_t _size;
};
class GraphicsResource : public Resource {
public:
    cudaGraphicsResource_t getHandler() { return _res; }

    virtual ~GraphicsResource() {
        if (_res == nullptr)
            return;
        try {
            THROW_IF_FAILED(cudaGraphicsUnregisterResource(_res));
        } catch (std::exception &ex) {
            common::Logger::instance.warning("Exception raised in destructor: %s", ex.what());
        }
        _res = nullptr;
    }

    virtual size_t size() const { return _size; }

protected:
    cudaGraphicsResource_t _res;
    size_t _size;

    GraphicsResource() : _res(nullptr), _size(0) {}
};

} // namespace utils::cuda
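// Ownership sketch (illustration only): copies of BufferResource share one device
// allocation through the shared_ptr, so the buffer is released exactly once, by the
// cudaFree deleter, when the last copy is destroyed.
void exampleOwnership() {
    utils::cuda::BufferResource a(1024 * sizeof(float)); // allocates a device buffer
    utils::cuda::BufferResource b = a;                   // shares the same buffer
    // `a` and `b` return the same pointer; cudaFree runs once after both are gone
}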
#include "resource.h"
#include <map>
#include <vector>
namespace utils::cuda {
class Resources : public std::map<std::string, Resource &> {
public:
std::vector<cudaGraphicsResource_t> getGraphicsResourceHandlers() {
std::vector<cudaGraphicsResource_t> handlers;
for (auto &&item : *this) {
auto gres = dynamic_cast<GraphicsResource *>(&item.second);
if (gres != nullptr)
handlers.push_back(gres->getHandler());
}
return handlers;
}
};
} // namespace utils::cuda
#include <cuda_runtime.h>
#include <memory>

namespace utils::cuda {

class Stream {
public:
    Stream() : _p_stream(std::make_shared<cudaStream_t>()) {
        cudaStreamCreate(_p_stream.get());
    }

    virtual ~Stream() {
        if (_p_stream.use_count() == 1)
            cudaStreamDestroy(*_p_stream);
    }

    operator cudaStream_t() { return *_p_stream; }

private:
    std::shared_ptr<cudaStream_t> _p_stream;
};

} // namespace utils::cuda
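// Usage sketch (illustration only; `cu_fill` as sketched earlier): Stream converts
// implicitly to cudaStream_t, so it can be passed wherever a raw handle is expected.
// The underlying stream is destroyed together with the last copy of the object.
void exampleStream(float *d_out, uint n) {
    utils::cuda::Stream stream;
    cu_fill<<<dim3((n + 1023) / 1024), dim3(1024), 0, stream>>>(d_out, 1.0f, n);
    cudaStreamSynchronize(stream);
}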
template <typename T, typename T2 = T>
void dumpArray(std::ostream &so, CudaArray<T> &arr, size_t maxDumpRows = 0,
               size_t elemsPerRow = 1) {
    int chns = sizeof(T) / sizeof(T2);
    T2 *hostArr = new T2[arr.n() * chns];
    cudaMemcpy(hostArr, arr.getBuffer(), arr.n() * sizeof(T), cudaMemcpyDeviceToHost);
    dumpHostBuffer<T2>(so, hostArr, arr.n() * sizeof(T), chns * elemsPerRow, maxDumpRows);
    delete[] hostArr;
}
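// Example (illustration only; `arr` is a hypothetical device array): dump a CudaArray of
// glm::vec4 as plain floats, i.e. 4 channels per element, showing at most 10 rows.
void exampleDump(CudaArray<glm::vec4> &arr) {
    dumpArray<glm::vec4, float>(std::cout, arr, 10);
}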
#pragma once
#include "Msl.h"

class Nmsl2 : public Msl
{
public:
    sptr<Resource> resRaw1;
    sptr<Resource> resRaw2;
    Net *fcNet1;
    Net *fcNet2;
    Net *catNet;
    unsigned int batchSize;
    unsigned int samples;

    Nmsl2(int batchSize, int samples);
    virtual bool load(const std::string &netDir);
    virtual void bindResources(Resource *resEncoded, Resource *resDepths, Resource *resColors);
    virtual bool infer();
    virtual void dispose();
};
#pragma once
#include "../utils/common.h"

class Sampler {
public:
    Sampler(glm::vec2 depthRange, unsigned int samples, bool outputRadius)
        : _dispRange(1.0f / depthRange, samples), _outputRadius(outputRadius) {}

    void sampleOnRays(sptr<CudaArray<float>> o_coords, sptr<CudaArray<float>> o_depths,
                      sptr<CudaArray<glm::vec3>> rays, glm::vec3 rayCenter);

private:
    Range _dispRange;
    bool _outputRadius;
};
#include "Msl.h"
#include <time.h>
Msl::Msl() : net(nullptr) {}
bool Msl::load(const std::string &netPath) {
net = new Net();
if (net->load(netPath))
return true;
dispose();
return false;
}
void Msl::bindResources(Resource *resEncoded, Resource *resDepths, Resource *resColors) {
net->bindResource("Encoded", resEncoded);
net->bindResource("Depths", resDepths);
net->bindResource("Colors", resColors);
}
bool Msl::infer() { return net->infer(); }
void Msl::dispose() {
if (net != nullptr) {
net->dispose();
delete net;
net = nullptr;
}
}
#pragma once
#include "../utils/common.h"
#include "Net.h"

class Msl {
public:
    Net *net;

    Msl();
    virtual bool load(const std::string &netDir);
    virtual void bindResources(Resource *resEncoded, Resource *resDepths, Resource *resColors);
    virtual bool infer();
    virtual void dispose();
};
#include "Nmsl2.h"
#include <time.h>
Nmsl2::Nmsl2(int batchSize, int samples)
: batchSize(batchSize),
samples(samples),
resRaw1(nullptr),
resRaw2(nullptr),
fcNet1(nullptr),
fcNet2(nullptr),
catNet(nullptr) {}
bool Nmsl2::load(const std::string &netDir) {
fcNet1 = new Net();
fcNet2 = new Net();
catNet = new Net();
if (!fcNet1->load(netDir + "fc1.trt") || !fcNet2->load(netDir + "fc2.trt") ||
!catNet->load(netDir + "cat.trt"))
return false;
resRaw1 = sptr<Resource>(new CudaBuffer(batchSize * samples / 2 * sizeof(float4)));
resRaw2 = sptr<Resource>(new CudaBuffer(batchSize * samples / 2 * sizeof(float4)));
return true;
}
void Nmsl2::bindResources(Resource *resEncoded, Resource *resDepths, Resource *resColors) {
fcNet1->bindResource("Encoded", resEncoded);
fcNet1->bindResource("Raw", resRaw1.get());
fcNet2->bindResource("Encoded", resEncoded);
fcNet2->bindResource("Raw", resRaw2.get());
catNet->bindResource("Raw1", resRaw1.get());
catNet->bindResource("Raw2", resRaw2.get());
catNet->bindResource("Depths", resDepths);
catNet->bindResource("Colors", resColors);
}
bool Nmsl2::infer() {
    // CudaStream stream1, stream2;
    if (!fcNet1->infer())
        return false;
    if (!fcNet2->infer())
        return false;
    if (!catNet->infer())
        return false;
    return true;
}
void Nmsl2::dispose() {
    if (fcNet1 != nullptr) {
        fcNet1->dispose();
        delete fcNet1;
        fcNet1 = nullptr;
    }
    if (fcNet2 != nullptr) {
        fcNet2->dispose();
        delete fcNet2;
        fcNet2 = nullptr;
    }
    if (catNet != nullptr) {
        catNet->dispose();
        delete catNet;
        catNet = nullptr;
    }
    resRaw1 = nullptr;
    resRaw2 = nullptr;
}
#include "Renderer.h"
#include "../utils/cuda.h"
/// Dispatch (n_rays, -)
__global__ void cu_render(glm::vec4 *o_colors, glm::vec4 *layeredColors, uint samples, uint nRays) {
    glm::uvec3 idx3 = IDX3;
    uint rayIdx = idx3.x;
    if (rayIdx >= nRays)
        return;
    glm::vec4 outColor(0.0f); // accumulate from transparent black
    for (int si = samples - 1; si >= 0; --si) {
        glm::vec4 c = layeredColors[rayIdx * samples + si];
        outColor = outColor * (1 - c.a) + c * c.a;
    }
    outColor.a = 1.0f;
    o_colors[rayIdx] = outColor;
}
Renderer::Renderer() {}

void Renderer::render(sptr<CudaArray<glm::vec4>> o_colors,
                      sptr<CudaArray<glm::vec4>> layeredColors) {
    dim3 blkSize(1024);
    dim3 grdSize(ceilDiv(o_colors->n(), blkSize.x));
    CU_INVOKE(cu_render)
    (*o_colors, *layeredColors, layeredColors->n() / o_colors->n(), o_colors->n());
    CHECK_EX(cudaGetLastError());
}
#pragma once
#include "../utils/common.h"

class Renderer {
public:
    Renderer();

    /**
     * @brief Blend layered sample colors along each ray into final opaque colors.
     *
     * @param o_colors output array of blended colors, one element per ray
     * @param layeredColors layered colors, `samples` consecutive elements per ray
     */
    void render(sptr<CudaArray<glm::vec4>> o_colors, sptr<CudaArray<glm::vec4>> layeredColors);
};
#include "Logger.h"
Logger Logger::instance;
#include <device_launch_parameters.h>
#include <glm/glm.hpp>

#define IDX2 glm::uvec2 { blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y }
#define IDX3 glm::uvec3 { blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y, blockIdx.z * blockDim.z + threadIdx.z }

__device__ __forceinline__ unsigned int flattenIdx(glm::uvec3 idx3)
{
    return idx3.x + idx3.y * blockDim.x * gridDim.x + idx3.z * blockDim.x * gridDim.x * blockDim.y * gridDim.y;
}

__device__ __forceinline__ unsigned int flattenIdx()
{
    return flattenIdx(IDX3);
}
from .dataset import DataDesc, Dataset
from .loader import RaysLoader, MultiScaleDataLoader
import json
import torch
import torch.utils.data
import torch.nn.functional as nn_f
from typing import Union
from operator import itemgetter
from pathlib import Path

try:
    from ..utils import view, img, math
    from ..utils.types import *
    from ..utils.misc import calculate_autosize
except ImportError:
    from utils import view, img, math
    from utils.types import *
    from utils.misc import calculate_autosize
class DataDesc(dict[str, Any]):
    path: Path

    @property
    def name(self) -> str:
        return self.path.stem

    @property
    def root(self) -> Path:
        return self.path.parent

    @property
    def coord_sys(self) -> str:
        return "gl" if self.get("gl_coord") else "dx"

    def __init__(self, path: PathLike):
        path = DataDesc.get_json_path(path)
        with open(path, 'r', encoding='utf-8') as file:
            data = json.loads(file.read())
        super().__init__(data)
        self.path = path

    @staticmethod
    def get_json_path(path: PathLike) -> Path:
        path = Path(path)
        if path.suffix != ".json":
            path = Path(f"{path}.json")
        return path.absolute()

    def get(self, key: str, fn=lambda x: x, default=None) -> Any | None:
        if key in self:
            return fn(self[key])
        return default

    def get_as_tensor(self, key: str, fn=lambda x: x, default=None, dtype=torch.float, device=None,
                      shape=None) -> torch.Tensor | None:
        raw_value = self.get(key, fn, default)
        if raw_value is None:
            return raw_value
        tensor_value = torch.tensor(raw_value, dtype=dtype, device=device)
        if shape is not None:
            tensor_value = tensor_value.reshape(shape)
        return tensor_value

    def get_path(self, name: str) -> str | None:
        path_pattern = self.get(f"{name}_file")
        if not path_pattern:
            return None
        if "/" not in path_pattern:
            path_pattern = f"{self.name}/{path_pattern}"
        return str(self.root / path_pattern)
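# A minimal description-file sketch (field names inferred from the accessors above and
# from Dataset._init_from_desc below; the exact schema is an assumption, not part of the
# commit):
#
#     {
#         "res": "1080x1920",
#         "cam": {...},
#         "depth_range": [1.0, 50.0],
#         "color_file": "view_%04d.png",
#         "centers": [[0.0, 0.0, 0.0]],
#         "views": [0]
#     }
#
# desc = DataDesc("data/scene0")        # loads data/scene0.json
# desc.name, desc.root, desc.coord_sys  # "scene0", Path("data"), "dx" unless "gl_coord" is set
# desc.get_path("color")                # ".../data/scene0/view_%04d.png"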
class Dataset(torch.utils.data.Dataset):
    root: Path
    """`Path` Root directory of the dataset"""

    name: str
    """`str` Name of the dataset"""

    color_mode: Color
    """`Color` Color mode of images in the dataset"""

    white_bg: bool
    """`bool` Whether images in the dataset have a white background"""

    level: int
    """`int` Level of this dataset"""

    res: Resolution
    """`Resolution` Resolution of each view as (rows, columns)"""

    coord_sys: str
    """`str` Coordinate system, must be 'dx' or 'gl'"""

    device: torch.device
    """`device` Device of tensors"""

    cam: view.Camera
    """`Camera?` Camera object"""

    depth_range: tuple[float, float] | None
    """`(float, float)?` Depth range of the scene as a guide to sampling"""

    bbox: tuple[tuple[float, ...], tuple[float, ...]] | None
    """`((float,...), (float,...))?` Bounding box of the scene as a guide to sampling"""

    trs_range: tuple[tuple[float, ...], tuple[float, ...]] | None
    """`((float,...), (float,...))?` Acceptable translation range"""

    rot_range: tuple[tuple[float, ...], tuple[float, ...]] | None
    """`((float,...), (float,...))?` Acceptable rotation range"""

    color_path: str | None
    """`str?` Path of image data"""

    depth_path: str | None
    """`str?` Path of depth data"""

    indices: torch.Tensor
    """`Tensor(N)` Indices for loading a specific subset of views in the dataset"""

    centers: torch.Tensor
    """`Tensor(N, 3)` Center positions of views"""

    rots: torch.Tensor | None
    """`Tensor(N, 3, 3)?` Rotation matrices of views"""

    @property
    def disparity_range(self) -> tuple[float, float] | None:
        return self.depth_range and (1 / self.depth_range[0], 1 / self.depth_range[1])

    @property
    def pixels_per_view(self) -> int:
        return self.cam.local_rays.shape[0]

    @property
    def tot_pixels(self) -> int:
        return len(self) * self.pixels_per_view

    @overload
    def __init__(self, desc: DataDesc, *,
                 res: Resolution | tuple[int, int] | None = None,
                 views_to_load: IndexSelector | None = None,
                 color_mode: Color = Color.rgb,
                 coord_sys: str = "gl",
                 device: torch.device = None) -> None:
        ...

    @overload
    def __init__(self, dataset: "Dataset", *,
                 views_to_load: IndexSelector | None = None) -> None:
        ...

    def __init__(self, dataset_or_desc: Union["Dataset", DataDesc, PathLike], *,
                 res: tuple[int, int] = None,
                 views_to_load: IndexSelector = None,
                 color_mode: Color = Color.rgb,
                 coord_sys: str = "gl",
                 device: torch.device = None) -> None:
        super().__init__()
        if isinstance(dataset_or_desc, Dataset):
            self._init_from_dataset(dataset_or_desc, views_to_load=views_to_load)
        else:
            self._init_from_desc(dataset_or_desc, res=res, views_to_load=views_to_load,
                                 color_mode=color_mode, coord_sys=coord_sys, device=device)

    def __getitem__(self, index: int | torch.Tensor | slice) -> dict[str, torch.Tensor]:
        if isinstance(index, torch.Tensor) and len(index.shape) == 0:
            index = index.item()
        view_index = self.indices[index]
        data = {
            "t": self.centers[index]
        }
        if self.rots is not None:
            data["r"] = self.rots[index]
        for image_type in ["color", "depth"]:
            image = self.load_images(image_type, view_index)
            if image is not None:
                data[image_type] = self.cam.get_pixels(image)
                if isinstance(index, int):
                    data[image_type].squeeze_(0)
        return data

    def __len__(self):
        return self.indices.shape[0]

    def load_images(self, type: str, indices: int | torch.Tensor | list[int]) -> torch.Tensor:
        if not getattr(self, f"{type}_path"):
            return None
        if isinstance(indices, int):
            raw_images = img.load(getattr(self, f"{type}_path") % indices)
        elif isinstance(indices, torch.Tensor) and len(indices.shape) == 0:
            raw_images = img.load(getattr(self, f"{type}_path") % indices.item())
        else:
            raw_images = img.load(*[getattr(self, f"{type}_path") % i for i in indices])
        raw_images = raw_images.to(device=self.device)
        if self.res != list(raw_images.shape[-2:]):
            raw_images = nn_f.interpolate(raw_images, self.res)
        if type == "color":
            return Color.cvt(raw_images, Color.rgb, self.color_mode)
        elif type == "depth":
            return math.lerp(1 - raw_images, self.disparity_range).reciprocal()
        return raw_images

    def split(self, *views: int) -> list["Dataset"]:
        views, _ = calculate_autosize(len(self), *views)
        sub_datasets: list["Dataset"] = []
        offset = 0
        for i in range(len(views)):
            end = offset + views[i]
            sub_datasets.append(Dataset(self, views_to_load=slice(offset, end)))
            offset = end
        return sub_datasets

    def _init_from_desc(self, desc_or_path: DataDesc | PathLike, res: tuple[int, int] | None,
                        views_to_load: IndexSelector | None, color_mode: Color,
                        coord_sys: str, device: torch.device) -> None:
        desc = desc_or_path if isinstance(desc_or_path, DataDesc) else DataDesc(desc_or_path)
        self.root = desc.root
        self.name = desc.name
        self.color_mode = color_mode
        self.white_bg = desc.get("white_bg", default=False)
        self.level = desc.get("level", default=0)
        self.color_path = desc.get_path("color")
        self.depth_path = desc.get_path("depth")
        self.res = Resolution(*res) if res else Resolution.from_str(desc["res"])
        self.coord_sys = coord_sys
        self.device = device
        self.cam = view.Camera.create(desc["cam"], self.res, coord_sys=self.coord_sys, device=device)
        self.depth_range = desc.get("depth_range")
        self.bbox = desc.get("bbox", lambda val: (tuple(val[:len(val) // 2]),
                                                  tuple(val[len(val) // 2:])))
        self.trs_range = desc.get("trs_range")
        self.rot_range = desc.get("rot_range")
        self.centers = desc.get_as_tensor("centers", device=device)
        self.rots = desc.get_as_tensor("rots", lambda rots: [
            view.euler_to_matrix(rot[1] if desc.coord_sys == "gl" else -rot[1], rot[0], 0)
            for rot in rots
        ] if len(rots[0]) == 2 else rots, shape=(-1, 3, 3), device=device)
        self.indices = desc.get_as_tensor("views", default=list(range(self.centers.shape[0])),
                                          dtype=torch.long, device=device)
        if views_to_load is not None:
            if isinstance(views_to_load, list):
                views_to_load = torch.tensor(views_to_load, device=device)
            self.indices = self.indices[views_to_load]
            self.centers = self.centers[views_to_load]
            self.rots = self.rots[views_to_load] if self.rots is not None else None
        if desc.coord_sys != self.coord_sys:
            self.centers[:, 2] *= -1
            if self.rots is not None:
                self.rots[:, 2] *= -1
                self.rots[..., 2] *= -1

    def _init_from_dataset(self, dataset: "Dataset", views_to_load: IndexSelector | None) -> None:
        """
        Clone an existing dataset or take a subset of its views.

        :param dataset `Dataset`: the dataset to clone
        :param views_to_load `IndexSelector?`: views to select, defaults to None (keep all views)
        """
        self.root = dataset.root
        self.name = dataset.name
        self.color_mode = dataset.color_mode
        self.level = dataset.level
        self.res = dataset.res
        self.coord_sys = dataset.coord_sys
        self.device = dataset.device
        self.cam = dataset.cam
        self.depth_range = dataset.depth_range
        self.bbox = dataset.bbox
        self.trs_range = dataset.trs_range
        self.rot_range = dataset.rot_range
        self.color_path = dataset.color_path
        self.depth_path = dataset.depth_path
        if views_to_load is not None:
            if isinstance(views_to_load, list):
                views_to_load = torch.tensor(views_to_load, device=dataset.device)
            self.indices = dataset.indices[views_to_load].clone()
            self.centers = dataset.centers[views_to_load].clone()
            self.rots = None if dataset.rots is None else dataset.rots[views_to_load].clone()
        else:
            self.indices = dataset.indices.clone()
            self.centers = dataset.centers.clone()
            self.rots = None if dataset.rots is None else dataset.rots.clone()
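# Usage sketch (hypothetical path; behavior as defined above): construct from a
# description file, then split by view counts into train/test subsets.
#
# dataset = Dataset("data/scene0", color_mode=Color.rgb, coord_sys="gl")
# train_set, test_set = dataset.split(len(dataset) - 10, 10)
# sample = dataset[0]  # {"t": ..., "r": ..., "color": ..., "depth": ...} (r/color/depth optional)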
import json
from pathlib import Path

import utils.device
from .utils import get_dataset_desc_path
from .pano_dataset import PanoDataset
from .view_dataset import ViewDataset


class DatasetFactory(object):
    @staticmethod
    def load(path: Path, device=None, **kwargs):
        device = device or utils.device.default()
        path = get_dataset_desc_path(path)
        with open(path, 'r', encoding='utf-8') as file:
            data_desc: dict = json.loads(file.read())
        if data_desc.get('type') == 'pano':
            dataset_class = PanoDataset
        else:
            dataset_class = ViewDataset
        dataset = dataset_class(data_desc, path.absolute(), device=device, **kwargs)
        return dataset
import torch.utils.data
from tqdm import tqdm
from collections import defaultdict
from .dataset import Dataset

try:
    from ..utils import math
    from ..utils.types import *
except ImportError:
    from utils import math
    from utils.types import *


class RaysLoader(object):
    class Iterator(object):
        def __init__(self, loader: "RaysLoader"):
            super().__init__()
            self.loader = loader
            self.offset = 0
            # Initialize ray indices
            #self.ray_indices = torch.randperm(self.loader.tot_pixels, device=self.loader.device)\
            #    if loader.shuffle else torch.arange(self.loader.tot_pixels, device=self.loader.device)
            self.ray_indices = torch.randperm(self.loader.tot_pixels, device="cpu")\
                if loader.shuffle else torch.arange(self.loader.tot_pixels, device="cpu")

        def __next__(self) -> Rays:
            if self.offset >= self.ray_indices.shape[0]:
                raise StopIteration()
            stop = min(self.offset + self.loader.batch_size, self.ray_indices.shape[0])
            rays = self._get_rays(self.ray_indices[self.offset:stop])
            self.offset = stop
            return rays

        def _get_rays(self, indices: torch.Tensor) -> Rays:
            indices_on_device = indices.to(self.loader.device)  # (B)
            view_idx = torch.div(indices_on_device, self.loader.pixels_per_view,
                                 rounding_mode="trunc")
            pix_idx = indices_on_device % self.loader.pixels_per_view
            rays_o = self.loader.centers[view_idx]  # (B, 3)
            rays_d = self.loader.local_rays[pix_idx]  # (B, 3)
            if self.loader.rots is not None:
                rays_d = (self.loader.rots[view_idx] @ rays_d[..., None])[..., 0]
            rays = Rays({
                'level': self.loader.level,
                'idx': indices_on_device,
                'rays_o': rays_o,
                'rays_d': rays_d
            })
            # "colors" and "depths" are on host memory; move only this batch to the device
            indices = indices.to("cpu")
            for image_type in ["color", "depth"]:
                if image_type in self.loader.data:
                    rays[image_type] = self.loader.data[image_type][indices].to(
                        self.loader.device, non_blocking=True)
            return rays

    def __init__(self, dataset: Dataset, batch_size: int, *,
                 shuffle: bool = False, num_workers: int = 8, device: torch.device = None):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.device = device
        self.shuffle = shuffle
        self.level = dataset.level
        self.n_views = len(dataset)
        self.pixels_per_view = dataset.pixels_per_view
        self.tot_pixels = self.n_views * self.pixels_per_view
        self.indices = dataset.indices.to(self.device)
        self.centers = dataset.centers.to(self.device)
        self.rots = dataset.rots.to(self.device) if dataset.rots is not None else None
        self.local_rays = dataset.cam.local_rays.to(self.device)

        # Load views from dataset
        self.data = defaultdict(list)
        views_loader = torch.utils.data.DataLoader(dataset, num_workers=num_workers,
                                                   pin_memory=True)
        for view_data in tqdm(views_loader, "Loading views", leave=False, dynamic_ncols=True):
            for key, val in view_data.items():
                self.data[key].append(val)
        print(f"{len(dataset)} views loaded.")
        self.data = {
            key: torch.cat(val).flatten(0, 1)
            for key, val in self.data.items()
            if key == "color" or key == "depth"
        }

    def __iter__(self):
        return RaysLoader.Iterator(self)

    def __len__(self):
        return math.ceil(self.tot_pixels / self.batch_size)


class MultiScaleDataLoader(object):
    class Iter(object):
        def __init__(self, sub_loaders: list[RaysLoader]):
            super().__init__()
            self.sub_loaders = sub_loaders
            self.end_flags = [False] * len(sub_loaders)

@@ -150,15 +131,13 @@ class MultiScaleDataLoader(object):
            return data_frags

    def __init__(self, dataset: Dataset, batch_size, *,
                 views_per_chunk=8, shuffle=False, num_workers=4, device: torch.device = None):
        super().__init__()
        self.batch_size = batch_size
        self.sub_loaders = [
            RaysLoader(sub_dataset, batch_size // len(dataset),
                       shuffle=shuffle, num_workers=num_workers, device=device)
            for sub_dataset in dataset
        ]
        # Sort by datasets' levels

@@ -178,10 +157,12 @@ class MultiScaleDataLoader(object):
        if not isinstance(self.active_sub_loaders, list):
            self.active_sub_loaders = [self.active_sub_loaders]


def get_loader(dataset, batch_size, *,
               views_per_chunk=8, shuffle=False, num_workers=4, device: torch.device = None):
    if isinstance(dataset, list):
        raise NotImplementedError()
    return RaysLoader(dataset, batch_size,
                      shuffle=shuffle, num_workers=num_workers, device=device)
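# Usage sketch (hypothetical `model` and `loss_fn`): RaysLoader yields `Rays` batches whose
# "rays_o"/"rays_d" live on `device`, while per-ray "color"/"depth" targets are sliced from
# pinned host memory and copied over batch by batch.
#
# loader = RaysLoader(train_set, batch_size=4096, shuffle=True, device=torch.device("cuda"))
# for rays in loader:
#     pred = model(rays["rays_o"], rays["rays_d"])
#     loss = loss_fn(pred, rays["color"])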
import os
import torch
import torch.nn.functional as nn_f
from typing import Dict, Tuple, Union
from operator import itemgetter
from pathlib import Path
from utils import img
from utils import color
from utils import sphere
from utils import math
from utils.mem_profiler import *
from .dataset import Dataset


class PanoDataset(Dataset):
    """
    Data loader for spherical view synthesis task

    Attributes
    --------
    data_dir ```str```: the directory of dataset\n
    view_file_pattern ```str```: the filename pattern of view images\n
    cam_params ```object```: camera intrinsic parameters\n
    centers ```Tensor(N, 3)```: centers of views\n
    view_rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
    images ```Tensor(N, 3, H, W)```: images of views\n
    view_depths ```Tensor(N, H, W)```: depths of views\n
    """

    class Chunk(object):
        @property
        def n_views(self):
            return self.indices.size(0)

        @property
        def n_pixels_per_view(self):
            return self.dataset.n_pixels_per_view

        def __init__(self, id: int, dataset, chunk_data: Dict[str, torch.Tensor], *,
                     color: int, **kwargs):
            """
            Initialize a chunk of views.

            :param id `int`: index of this chunk
            :param dataset `PanoDataset`: dataset object
            :param chunk_data `{str: Tensor}`: 'indices' and 'centers' of views in this chunk
            :param color `int`: color space to convert view images to
            """
            self.id = id
            self.dataset = dataset
            self.indices = chunk_data['indices']
            self.centers = chunk_data['centers']
            self.color = color
            self.colors_cpu = None
            self.colors = None
            self.loaded = False

        def release(self):
            self.colors = None
            self.loaded = False
            MemProfiler.print_memory_stats(f'Chunk #{self.id} released')

        def load(self):
            if self.dataset.image_path and self.colors_cpu is None:
                images = color.cvt(img.load(self.dataset.image_path % i for i in self.indices),
                                   color.RGB, self.color)
                if self.dataset.res != tuple(images.shape[-2:]):
                    images = nn_f.interpolate(images, self.dataset.res)
                self.colors_cpu = images.permute(
                    0, 2, 3, 1)[:, self.dataset.pixels[:, 0], self.dataset.pixels[:, 1]].flatten(0, 1)
            if self.colors_cpu is not None:
                self.colors = self.colors_cpu.to(self.dataset.device)
            self.loaded = True
            MemProfiler.print_memory_stats(
                f'Chunk #{self.id} ({self.n_views} views, '
                f'{self.colors.numel() * self.colors.element_size() / 1024 / 1024:.2f}MB) loaded')

        def __len__(self):
            return self.n_views * self.n_pixels_per_view

        def __getitem__(self, idx):
            if not self.loaded:
                self.load()
            view_idx = idx // self.n_pixels_per_view
            pix_idx = idx % self.n_pixels_per_view
            global_idx = self.indices[view_idx] * self.n_pixels_per_view + pix_idx
            rays_o = self.centers[view_idx]
            rays_d = self.dataset.rays[pix_idx]
            data = {
                'idx': global_idx,
                'rays_o': rays_o,
                'rays_d': rays_d,
                'level': self.dataset.level
            }
            if self.colors is not None:
                data['color'] = self.colors[idx]
            return data

    @property
    def n_pixels_per_view(self):
        return self.pixels.size(0)

    def __init__(self, desc: dict, desc_path: Path, *,
                 load_images: bool = True,
                 res: Tuple[int, int] = None,
                 views_to_load: Union[range, torch.Tensor] = None,
                 device: torch.device = None,
                 **kwargs):
        """
        Initialize data loader for spherical view synthesis task

        The dataset description file is a JSON file with following fields:

        - view_file_pattern: string, the path pattern of view images
        - view_res: { "x", "y" }, the resolution of view
        - depth_range: { "min", "max" }, the depth range
        - range: { "min": [...], "max": [...] }, the range of translation and rotation
        - centers: [ [ x, y, z ], ... ], centers of views

        :param desc `dict`: the dataset description
        :param desc_path `Path`: path to the data description file
        :param load_images `bool`: whether to load view images and return them in __getitem__()
        :param res `(int, int)`: resolution to resize views to
        :param views_to_load `range | Tensor`: subset of views to load
        :param device `torch.device`: device to place tensors on
        """
        super().__init__(desc, desc_path, res=res, views_to_load=views_to_load, device=device,
                         load_images=load_images)

    def get_data(self):
        return {
            'indices': self.indices,
            'centers': self.centers
        }

    def _load_desc(self, res: Tuple[int, int], views_to_load: Union[range, torch.Tensor],
                   load_images: bool):
        super()._load_desc(res, views_to_load)
        self.image_path = load_images and self._get_data_path("view")
        self.pixels, self.rays = self._get_pano_rays()

    def _get_pano_rays(self):
        """
        Get unprojected rays of pixels on a panorama

        :return `Tensor(N, 2)`: rays' pixel coordinates in pano image
        :return `Tensor(N, 3)`: rays' directions with unit length
        """
        phi = (torch.arange(self.res[0], device=self.device) + 0.5) / self.res[0] * math.pi  # (H)
        length = (phi.sin() * self.res[1] * 0.5).ceil() * 2
        cols = torch.arange(self.res[1], device=self.device)[None, :].expand(*self.res)  # (H, W)
        mask = torch.logical_and(cols >= (self.res[1] - length[:, None]) / 2,
                                 cols < (self.res[1] + length[:, None]) / 2)  # (H, W)
        pixs = mask.nonzero()  # (N, 2)
        pixs_phi = (0.5 - (pixs[:, 0] + 0.5) / self.res[0]) * math.pi
        pixs_theta = (pixs[:, 1] * 2 + 1 - self.res[1]) / length[pixs[:, 0]] * math.pi
        spher_coords = torch.stack([torch.ones_like(pixs_phi), pixs_theta, pixs_phi], dim=-1)
        return pixs, sphere.spherical2cartesian(spher_coords)  # (N, 3)
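# Geometry note (added for clarity): for a pixel at row r, column c of an (H, W) panorama,
# the code above uses
#     phi   = (0.5 - (r + 0.5) / H) * pi         # latitude in (-pi/2, pi/2)
#     theta = (2 * c + 1 - W) / length(r) * pi   # longitude, rescaled per row
# where length(r) = 2 * ceil(sin((r + 0.5) / H * pi) * W / 2) is the number of valid pixels
# in that row. Only pixels inside the per-row window are kept, and each row's window is
# stretched to the full longitude range before the spherical-to-Cartesian conversion.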