Commit 1bc644a1 authored by Nianchen Deng

sync

parent 6294701e
#include <device_launch_parameters.h>
#include <glm/glm.hpp>
#define T_IDX threadIdx.x
#define T_IDX2 glm::uvec2(threadIdx.x, threadIdx.y)
#define T_IDX3 glm::uvec3(threadIdx.x, threadIdx.y, threadIdx.z)
#define B_IDX blockIdx.x
#define B_IDX2 glm::uvec2(blockIdx.x, blockIdx.y)
#define B_IDX3 glm::uvec3(blockIdx.x, blockIdx.y, blockIdx.z)
#define IDX (blockIdx.x * blockDim.x + threadIdx.x)
#define IDX2 glm::uvec2(blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y)
#define IDX3 \
glm::uvec3(blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y, \
blockIdx.z * blockDim.z + threadIdx.z)
#define FLAT_INDEX utils::cuda::flattenIdx(IDX3)
#define DEFINE_IDX(var1) uint var1 = blockIdx.x * blockDim.x + threadIdx.x;
#define DEFINE_IDX2(var1, var2) \
uint var1 = blockIdx.x * blockDim.x + threadIdx.x; \
uint var2 = blockIdx.y * blockDim.y + threadIdx.y;
#define DEFINE_IDX3(var1, var2, var3) \
uint var1 = blockIdx.x * blockDim.x + threadIdx.x; \
uint var2 = blockIdx.y * blockDim.y + threadIdx.y; \
uint var3 = blockIdx.z * blockDim.z + threadIdx.z;
#define DEFINE_FLAT_INDEX(var1) uint var1 = FLAT_INDEX;
namespace utils::cuda {
__device__ __forceinline__ uint flattenIdx(glm::uvec3 idx3) {
return idx3.x + idx3.y * blockDim.x * gridDim.x +
idx3.z * blockDim.x * gridDim.x * blockDim.y * gridDim.y;
}
__device__ __forceinline__ uint flattenIdx(glm::uvec2 idx2) {
return idx2.x + idx2.y * blockDim.x * gridDim.x;
}
} // namespace utils::cuda
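// Illustrative sketch (not part of the commit): a minimal kernel using the index
// helpers above. The kernel name and its parameters are assumptions.
__global__ void example_scale(float *data, float factor, uint n) {
    DEFINE_IDX(i) // expands to: uint i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n)   // guard the partial last block
        return;
    data[i] *= factor;
}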
#include <cuda_runtime.h>
#include <vector>
#include "error.h"
namespace utils::cuda {
class MapResourcesScope {
public:
MapResourcesScope(const std::vector<cudaGraphicsResource_t> &resources,
cudaStream_t stream = nullptr)
: _resources(resources), _stream(stream) {
if (!_resources.empty())
THROW_IF_FAILED(
cudaGraphicsMapResources((int)_resources.size(), _resources.data(), _stream));
}
~MapResourcesScope() {
if (!_resources.empty())
cudaGraphicsUnmapResources((int)_resources.size(), _resources.data(), _stream);
}
private:
std::vector<cudaGraphicsResource_t> _resources;
cudaStream_t _stream;
};
} // namespace utils::cuda
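// Usage sketch (assumed example): mapping is tied to a scope, so resources are
// always unmapped on scope exit, even on early return or exception. The function
// and parameter names below are hypothetical.
inline void withMappedResources(const std::vector<cudaGraphicsResource_t> &handlers,
                                cudaStream_t stream) {
    utils::cuda::MapResourcesScope mapScope(handlers, stream);
    // ... launch kernels that access the mapped resources on `stream` ...
} // ~MapResourcesScope unmaps the resources here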
#pragma once
#include <cuda_runtime.h>
#include <memory>
#include "../common.h"
#include "error.h"
namespace utils::cuda {
class Resource {
public:
virtual ~Resource() {}
virtual void *getBuffer() const = 0;
virtual size_t size() const = 0;
};
class BufferResource : public Resource {
public:
BufferResource(void *buffer = nullptr, size_t size = 0)
: _p_buffer(buffer, [](void *) {}), _ownBuffer(false), _size(size) {}
BufferResource(size_t size) : _ownBuffer(true), _size(size) {
void *p_buffer;
THROW_IF_FAILED(cudaMalloc(&p_buffer, size));
// No-op deleter: the default deleter would perform an ill-formed `delete` on a
// void pointer; the buffer is instead freed explicitly in the destructor below.
_p_buffer = std::shared_ptr<void>(p_buffer, [](void *) {});
}
virtual ~BufferResource() {
if (!_ownBuffer || _p_buffer.use_count() > 1)
return;
try {
THROW_IF_FAILED(cudaFree(_p_buffer.get()));
} catch (std::exception &ex) {
common::Logger::instance.warning("Exception raised in destructor: %s", ex.what());
}
}
virtual void *getBuffer() const { return _p_buffer.get(); }
virtual size_t size() const { return _size; }
private:
std::shared_ptr<void> _p_buffer;
bool _ownBuffer;
size_t _size;
};
class GraphicsResource : public Resource {
public:
cudaGraphicsResource_t getHandler() { return _res; }
virtual ~GraphicsResource() {
if (_res == nullptr)
return;
try {
THROW_IF_FAILED(cudaGraphicsUnregisterResource(_res));
} catch (std::exception &ex) {
common::Logger::instance.warning("Exception raised in destructor: %s", ex.what());
}
_res = nullptr;
}
virtual size_t size() const { return _size; }
protected:
cudaGraphicsResource_t _res;
size_t _size;
GraphicsResource() : _res(nullptr), _size(0) {}
};
} // namespace utils::cuda
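// Usage sketch (assumed example): BufferResource either allocates device memory
// itself (owning) or wraps an externally managed pointer (non-owning). The
// function and sizes below are hypothetical.
inline void exampleBuffers(void *externalDevPtr) {
    utils::cuda::BufferResource owned(1024 * sizeof(float)); // cudaMalloc'd, freed by the last owner
    utils::cuda::BufferResource borrowed(externalDevPtr, 1024 * sizeof(float)); // never freed here
    void *p = owned.getBuffer(); // raw device pointer, e.g. for kernel arguments
    (void)p;
}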
#include "resource.h"
#include <map>
#include <vector>
namespace utils::cuda {
class Resources : public std::map<std::string, Resource &> {
public:
std::vector<cudaGraphicsResource_t> getGraphicsResourceHandlers() {
std::vector<cudaGraphicsResource_t> handlers;
for (auto &&item : *this) {
auto gres = dynamic_cast<GraphicsResource *>(&item.second);
if (gres != nullptr)
handlers.push_back(gres->getHandler());
}
return handlers;
}
};
} // namespace utils::cuda
#include <cuda_runtime.h>
#include <memory>
namespace utils::cuda {
class Stream {
public:
Stream() : _p_stream(std::make_shared<cudaStream_t>()) {
cudaStreamCreate(_p_stream.get());
}
virtual ~Stream() {
if (_p_stream.use_count() == 1)
cudaStreamDestroy(*_p_stream);
}
operator cudaStream_t() { return *_p_stream; }
private:
std::shared_ptr<cudaStream_t> _p_stream;
};
} // namespace utils::cuda
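// Usage sketch (assumed example): copies of Stream share one cudaStream_t, which
// is destroyed together with the last copy; the conversion operator lets a
// Stream be passed wherever a cudaStream_t is expected.
inline void exampleStream(void *dst, const void *src, size_t bytes) {
    utils::cuda::Stream stream;
    cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, stream); // implicit conversion
    cudaStreamSynchronize(stream);
}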
template <typename T, typename T2 = T>
void dumpArray(std::ostream &so, CudaArray<T> &arr, size_t maxDumpRows = 0,
size_t elemsPerRow = 1) {
int chns = sizeof(T) / sizeof(T2);
T2 *hostArr = new T2[arr.n() * chns];
cudaMemcpy(hostArr, arr.getBuffer(), arr.n() * sizeof(T), cudaMemcpyDeviceToHost);
dumpHostBuffer<T2>(so, hostArr, arr.n() * sizeof(T), chns * elemsPerRow, maxDumpRows);
delete[] hostArr;
}
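// Usage sketch (assumed example): dump a device array of float4 as plain floats
// (T2 = float, so chns = 4), printing at most the first 10 rows.
inline void exampleDump(std::ostream &so, CudaArray<float4> &arr) {
    dumpArray<float4, float>(so, arr, 10);
}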
#pragma once
#include "Msl.h"
class Nmsl2 : public Msl
{
public:
sptr<Resource> resRaw1;
sptr<Resource> resRaw2;
Net *fcNet1;
Net *fcNet2;
Net *catNet;
unsigned int batchSize;
unsigned int samples;
Nmsl2(int batchSize, int samples);
virtual bool load(const std::string &netDir);
virtual void bindResources(Resource *resEncoded, Resource *resDepths, Resource *resColors);
virtual bool infer();
virtual void dispose();
};
#pragma once
#include "../utils/common.h"
class Sampler {
public:
Sampler(glm::vec2 depthRange, unsigned int samples, bool outputRadius)
: _dispRange(1.0f / depthRange, samples), _outputRadius(outputRadius) {}
void sampleOnRays(sptr<CudaArray<float>> o_coords, sptr<CudaArray<float>> o_depths,
sptr<CudaArray<glm::vec3>> rays, glm::vec3 rayCenter);
private:
Range _dispRange;
bool _outputRadius;
};
#include "Msl.h"
#include <time.h>
Msl::Msl() : net(nullptr) {}
bool Msl::load(const std::string &netPath) {
net = new Net();
if (net->load(netPath))
return true;
dispose();
return false;
}
void Msl::bindResources(Resource *resEncoded, Resource *resDepths, Resource *resColors) {
net->bindResource("Encoded", resEncoded);
net->bindResource("Depths", resDepths);
net->bindResource("Colors", resColors);
}
bool Msl::infer() { return net->infer(); }
void Msl::dispose() {
if (net != nullptr) {
net->dispose();
delete net;
net = nullptr;
}
}
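// Usage sketch (assumed example): the typical lifecycle of an Msl network.
// `encoded`, `depths`, `colors` and the network path are hypothetical.
inline bool exampleMsl(Resource *encoded, Resource *depths, Resource *colors) {
    Msl msl;
    if (!msl.load("nets/msl.trt"))
        return false;
    msl.bindResources(encoded, depths, colors);
    bool ok = msl.infer();
    msl.dispose();
    return ok;
}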
#pragma once
#include "../utils/common.h"
#include "Net.h"
class Msl {
public:
Net *net;
Msl();
virtual bool load(const std::string &netDir);
virtual void bindResources(Resource *resEncoded, Resource *resDepths, Resource *resColors);
virtual bool infer();
virtual void dispose();
};
#include "Nmsl2.h"
#include <time.h>
Nmsl2::Nmsl2(int batchSize, int samples)
: batchSize(batchSize),
samples(samples),
resRaw1(nullptr),
resRaw2(nullptr),
fcNet1(nullptr),
fcNet2(nullptr),
catNet(nullptr) {}
bool Nmsl2::load(const std::string &netDir) {
fcNet1 = new Net();
fcNet2 = new Net();
catNet = new Net();
if (!fcNet1->load(netDir + "fc1.trt") || !fcNet2->load(netDir + "fc2.trt") ||
!catNet->load(netDir + "cat.trt")) {
dispose(); // release any networks that did load successfully
return false;
}
resRaw1 = sptr<Resource>(new CudaBuffer(batchSize * samples / 2 * sizeof(float4)));
resRaw2 = sptr<Resource>(new CudaBuffer(batchSize * samples / 2 * sizeof(float4)));
return true;
}
void Nmsl2::bindResources(Resource *resEncoded, Resource *resDepths, Resource *resColors) {
fcNet1->bindResource("Encoded", resEncoded);
fcNet1->bindResource("Raw", resRaw1.get());
fcNet2->bindResource("Encoded", resEncoded);
fcNet2->bindResource("Raw", resRaw2.get());
catNet->bindResource("Raw1", resRaw1.get());
catNet->bindResource("Raw2", resRaw2.get());
catNet->bindResource("Depths", resDepths);
catNet->bindResource("Colors", resColors);
}
bool Nmsl2::infer() {
// CudaStream stream1, stream2;
if (!fcNet1->infer())
return false;
if (!fcNet2->infer())
return false;
if (!catNet->infer())
return false;
return true;
}
void Nmsl2::dispose() {
if (fcNet1 != nullptr) {
fcNet1->dispose();
delete fcNet1;
fcNet1 = nullptr;
}
if (fcNet2 != nullptr) {
fcNet2->dispose();
delete fcNet2;
fcNet2 = nullptr;
}
if (catNet != nullptr) {
catNet->dispose();
delete catNet;
catNet = nullptr;
}
resRaw1 = nullptr;
resRaw2 = nullptr;
}
#include "Renderer.h"
#include "../utils/cuda.h"
/// Dispatch (n_rays, -)
__global__ void cu_render(glm::vec4 *o_colors, glm::vec4 *layeredColors, uint samples, uint nRays) {
glm::uvec3 idx3 = IDX3;
uint rayIdx = idx3.x;
if (rayIdx >= nRays)
return;
glm::vec4 outColor(0.0f); // must be zero-initialized; glm leaves it undefined by default
// Blend samples back to front ("over" compositing)
for (int si = samples - 1; si >= 0; --si) {
glm::vec4 c = layeredColors[rayIdx * samples + si];
outColor = outColor * (1 - c.a) + c * c.a;
}
outColor.a = 1.0f;
o_colors[rayIdx] = outColor;
}
Renderer::Renderer() {}
void Renderer::render(sptr<CudaArray<glm::vec4>> o_colors,
sptr<CudaArray<glm::vec4>> layeredColors) {
dim3 blkSize(1024);
dim3 grdSize(ceilDiv(o_colors->n(), blkSize.x));
CU_INVOKE(cu_render)
(*o_colors, *layeredColors, layeredColors->n() / o_colors->n(), o_colors->n());
CHECK_EX(cudaGetLastError());
}
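// Usage sketch (assumed example): blend `samples` layered colors per ray into a
// single color per ray; the sample count is inferred from the array sizes
// (layeredColors->n() / o_colors->n()).
inline void exampleRender(sptr<CudaArray<glm::vec4>> rayColors,
                          sptr<CudaArray<glm::vec4>> layeredColors) {
    Renderer renderer;
    renderer.render(rayColors, layeredColors);
}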
#pragma once
#include "../utils/common.h"
class Renderer {
public:
Renderer();
/**
* @brief Blend layered colors along each ray into a single color per ray.
*
* @param o_colors output array of blended colors, one per ray
* @param layeredColors input array of per-sample colors, `samples` entries per ray
*/
void render(sptr<CudaArray<glm::vec4>> o_colors, sptr<CudaArray<glm::vec4>> layeredColors);
};
#include "Logger.h"
Logger Logger::instance;
#include <device_launch_parameters.h>
#include <glm/glm.hpp>
#define IDX2 glm::uvec2 { blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y }
#define IDX3 glm::uvec3 { blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y, blockIdx.z * blockDim.z + threadIdx.z }
__device__ __forceinline__ unsigned int flattenIdx(glm::uvec3 idx3)
{
return idx3.x + idx3.y * blockDim.x * gridDim.x + idx3.z * blockDim.x * gridDim.x * blockDim.y * gridDim.y;
}
__device__ __forceinline__ unsigned int flattenIdx()
{
return flattenIdx(IDX3);
}
from .utils import *
from .dataset_factory import *
from .loader import *
from .dataset import DataDesc, Dataset
from .loader import RaysLoader, MultiScaleDataLoader
import json
import torch
import torch.utils.data
import torch.nn.functional as nn_f
from typing import Union
from operator import itemgetter
from typing import Tuple, Union
from pathlib import Path
from utils import view
from .utils import get_data_path
try:
from ..utils import view, img, math
from ..utils.types import *
from ..utils.misc import calculate_autosize
except ImportError:
from utils import view, img, math
from utils.types import *
from utils.misc import calculate_autosize
class Dataset(object):
desc: dict
desc_path: Path
device: torch.device
class DataDesc(dict[str, Any]):
path: Path
@property
def name(self):
return self.desc_path.stem
def name(self) -> str:
return self.path.stem
@property
def root(self):
return self.desc_path.parent
def root(self) -> Path:
return self.path.parent
@property
def n_views(self):
return self.centers.size(0)
def coord_sys(self) -> str:
return "gl" if self.get("gl_coord") else "dx"
def __init__(self, path: PathLike):
path = DataDesc.get_json_path(path)
with open(path, 'r', encoding='utf-8') as file:
data = json.loads(file.read())
super().__init__(data)
self.path = path
@staticmethod
def get_json_path(path: PathLike) -> Path:
path = Path(path)
if path.suffix != ".json":
path = Path(f"{path}.json")
return path.absolute()
def get(self, key: str, fn=lambda x: x, default=None) -> Any | None:
if key in self:
return fn(self[key])
return default
def get_as_tensor(self, key: str, fn=lambda x: x, default=None, dtype=torch.float, device=None,
shape=None) -> torch.Tensor | None:
raw_value = self.get(key, fn, default)
if raw_value is None:
return raw_value
tensor_value = torch.tensor(raw_value, dtype=dtype, device=device)
if shape is not None:
tensor_value = tensor_value.reshape(shape)
return tensor_value
def get_path(self, name: str) -> str | None:
path_pattern = self.get(f"{name}_file")
if not path_pattern:
return None
if "/" not in path_pattern:
path_pattern = f"{self.name}/{path_pattern}"
return str(self.root / path_pattern)
class Dataset(torch.utils.data.Dataset):
root: Path
"""`Path` Root directory of the dataset"""
name: str
"""`str` Name of the dataset"""
color_mode: Color
"""`Color` Color mode of images in the dataset"""
white_bg: bool
"""`bool` Images in the dataset should have white background"""
level: int
"""`int` Level of this dataset"""
res: Resolution
"""`Resolution` Resolution of each view as (rows, columns)"""
coord_sys: str
"""`str` Coordinate system, must be 'dx' or 'gl'"""
device: torch.device
"""`device` Device of tensors"""
cam: view.Camera
"""`Camera?` Camera object"""
depth_range: tuple[float, float] | None
"""`(float, float)?` Depth range of the scene as a guide to sampling"""
bbox: tuple[tuple[float, ...], tuple[float, ...]] | None
"""`((float,...), (float,...))?` Bounding box of the scene as a guide to sampling"""
trans_range: tuple[tuple[float, ...], tuple[float, ...]] | None
"""`((float,...), (float,...))?` Acceptable Translation (and optional rotation) range"""
color_path: str | None
"""`str?` Path of image data"""
depth_path: str | None
"""`str?` Path of depth data"""
indices: torch.Tensor
"""`Tensor(N)` Indices for loading specific subset of views in the dataset"""
centers: torch.Tensor
"""`Tensor(N, 3)` Center positions of views"""
rots: torch.Tensor | None
"""`Tensor(N, 3, 3)?` Rotation matrices of views"""
@property
def n_pixels_per_view(self):
return self.res[0] * self.res[1]
def disparity_range(self) -> tuple[float, float] | None:
return self.depth_range and (1 / self.depth_range[0], 1 / self.depth_range[1])
@property
def n_pixels(self):
return self.n_views * self.n_pixels_per_view
def pixels_per_view(self) -> int:
return self.cam.local_rays.shape[0]
@property
def tot_pixels(self) -> int:
return len(self) * self.pixels_per_view
@overload
def __init__(self, desc: DataDesc, *,
res: Resolution | tuple[int, int] | None = None,
views_to_load: IndexSelector | None = None,
color_mode: Color = Color.rgb,
coord_sys: str = "gl",
device: torch.device = None) -> None:
...
def __init__(self, desc: dict, desc_path: Path, *,
res: Tuple[int, int] = None,
views_to_load: Union[range, torch.Tensor] = None,
device: torch.device = None, **kwargs) -> None:
@overload
def __init__(self, dataset: "Dataset", *,
views_to_load: IndexSelector | None = None) -> None:
...
def __init__(self, dataset_or_desc: Union["Dataset", DataDesc, PathLike], *,
res: tuple[int, int] = None,
views_to_load: IndexSelector = None,
color_mode: Color = Color.rgb,
coord_sys: str = "gl",
device: torch.device = None) -> None:
super().__init__()
self.desc = desc
self.desc_path = desc_path.absolute()
self.device = device
self._load_desc(res, views_to_load, **kwargs)
if isinstance(dataset_or_desc, Dataset):
self._init_from_dataset(dataset_or_desc, views_to_load=views_to_load)
else:
self._init_from_desc(dataset_or_desc, res=res, views_to_load=views_to_load,
color_mode=color_mode, coord_sys=coord_sys, device=device)
def get_data(self):
def __getitem__(self, index: int | torch.Tensor | slice) -> dict[str, torch.Tensor]:
if isinstance(index, torch.Tensor) and len(index.shape) == 0:
index = index.item()
view_index = self.indices[index]
data = {
'indices': self.indices,
'centers': self.centers
"t": self.centers[index]
}
if self.rots is not None:
data['rots'] = self.rots
data["r"] = self.rots[index]
for image_type in ["color", "depth"]:
image = self.load_images(image_type, view_index)
if image is not None:
data[image_type] = self.cam.get_pixels(image)
if isinstance(index, int):
data[image_type].squeeze_(0)
return data
def _get_data_path(self, name: str) -> str:
path_pattern = self.desc.get(f"{name}_file_pattern", None)
return path_pattern and get_data_path(self.desc_path, path_pattern)
def _load_desc(self, res: Tuple[int, int], views_to_load: Union[range, torch.Tensor],
**kwargs):
self.level = self.desc.get('level', 0)
self.res = res or itemgetter("y", "x")(self.desc['view_res'])
self.cam = view.CameraParam(self.desc['cam_params'], self.res, device=self.device)\
if 'cam_params' in self.desc else None
self.depth_range = itemgetter("min", "max")(self.desc['depth_range']) \
if 'depth_range' in self.desc else None
self.range = itemgetter("min", "max")(self.desc['range']) if 'range' in self.desc else None
self.bbox = self.desc.get('bbox')
self.samples = self.desc.get('samples')
self.centers = torch.tensor(self.desc['view_centers'], device=self.device) # (N, 3)
self.rots = torch.tensor(
[
view.euler_to_matrix([rot[1] if self.desc.get('gl_coord') else -rot[1], rot[0], 0])
for rot in self.desc['view_rots']
]
if len(self.desc['view_rots'][0]) == 2 else self.desc['view_rots'],
device=self.device).view(-1, 3, 3) if 'view_rots' in self.desc else None # (N, 3, 3)
self.indices = torch.tensor(self.desc.get('views') or [*range(self.centers.size(0))],
device=self.device)
def __len__(self):
return self.indices.shape[0]
def load_images(self, type: str, indices: int | torch.Tensor | list[int]) -> torch.Tensor:
if not getattr(self, f"{type}_path"):
return None
if isinstance(indices, int):
raw_images = img.load(getattr(self, f"{type}_path") % indices)
elif isinstance(indices, torch.Tensor) and len(indices.shape) == 0:
raw_images = img.load(getattr(self, f"{type}_path") % indices.item())
else:
raw_images = img.load(*[getattr(self, f"{type}_path") % i for i in indices])
raw_images = raw_images.to(device=self.device)
if self.res != list(raw_images.shape[-2:]):
raw_images = nn_f.interpolate(raw_images, self.res)
if type == "color":
return Color.cvt(raw_images, Color.rgb, self.color_mode)
elif type == "depth":
return math.lerp(1 - raw_images, self.disparity_range).reciprocal()
return raw_images
def split(self, *views: int) -> list["Dataset"]:
views, _ = calculate_autosize(len(self), *views)
sub_datasets: list["Dataset"] = []
offset = 0
for i in range(len(views)):
end = offset + views[i]
sub_datasets.append(Dataset(self, views_to_load=slice(offset, end)))
offset = end
return sub_datasets
def _init_from_desc(self, desc_or_path: DataDesc | PathLike, res: tuple[int, int] | None,
views_to_load: IndexSelector | None, color_mode: Color,
coord_sys: str, device: torch.device) -> None:
desc = desc_or_path if isinstance(desc_or_path, DataDesc) else DataDesc(desc_or_path)
self.root = desc.root
self.name = desc.name
self.color_mode = color_mode
self.white_bg = desc.get("white_bg", default=False)
self.level = desc.get('level', default=0)
self.color_path = desc.get_path("color")
self.depth_path = desc.get_path("depth")
self.res = Resolution(*res) if res else Resolution.from_str(desc["res"])
self.coord_sys = coord_sys
self.device = device
self.cam = view.Camera.create(desc["cam"], self.res, coord_sys=self.coord_sys, device=device)
self.depth_range = desc.get("depth_range")
self.bbox = desc.get("bbox", lambda val: (tuple(val[:len(val) // 2]),
tuple(val[len(val) // 2:])))
self.trs_range = desc.get("trs_range")
self.rot_range = desc.get("rot_range")
self.centers = desc.get_as_tensor("centers", device=device)
self.rots = desc.get_as_tensor("rots", lambda rots: [
view.euler_to_matrix(rot[1] if desc.coord_sys == "gl" else -rot[1], rot[0], 0)
for rot in rots
] if len(rots[0]) == 2 else rots, shape=(-1, 3, 3), device=device)
self.indices = desc.get_as_tensor("views", default=list(range(self.centers.shape[0])),
dtype=torch.long, device=device)
if views_to_load is not None:
if isinstance(views_to_load, list):
views_to_load = torch.tensor(views_to_load, device=device)
self.indices = self.indices[views_to_load]
self.centers = self.centers[views_to_load]
self.rots = self.rots[views_to_load] if self.rots is not None else None
self.indices = self.indices[views_to_load]
if self.desc.get('gl_coord'):
print('Convert from OGL coordinate to DX coordinate (i.e. flip z axis)')
if desc.coord_sys != self.coord_sys:
self.centers[:, 2] *= -1
if self.cam is not None:
if not self.desc['cam_params'].get('fov'):
self.cam.f[1] *= -1
if self.rots is not None:
self.rots[:, 2] *= -1
self.rots[..., 2] *= -1
def _init_from_dataset(self, dataset: "Dataset", views_to_load: IndexSelector | None) -> None:
"""
Clone or get subset of an existed dataset
:param dataset `Dataset`: _description_
:param views_to_load `IndexSelector?`: _description_, defaults to None
"""
self.root = dataset.root
self.name = dataset.name
self.color_mode = dataset.color_mode
self.level = dataset.level
self.res = dataset.res
self.coord_sys = dataset.coord_sys
self.device = dataset.device
self.cam = dataset.cam
self.depth_range = dataset.depth_range
self.bbox = dataset.bbox
self.trs_range = dataset.trs_range
self.rot_range = dataset.rot_range
self.color_path = dataset.color_path
self.depth_path = dataset.depth_path
if views_to_load is not None:
if isinstance(views_to_load, list):
views_to_load = torch.tensor(views_to_load, device=dataset.device)
self.indices = dataset.indices[views_to_load].clone()
self.centers = dataset.centers[views_to_load].clone()
self.rots = None if dataset.rots is None else dataset.rots[views_to_load].clone()
else:
self.indices = dataset.indices.clone()
self.centers = dataset.centers.clone()
self.rots = None if dataset.rots is None else dataset.rots.clone()
import json
from pathlib import Path
import utils.device
from .utils import get_dataset_desc_path
from .pano_dataset import PanoDataset
from .view_dataset import ViewDataset
class DatasetFactory(object):
@staticmethod
def load(path: Path, device=None, **kwargs):
device = device or utils.device.default()
path = get_dataset_desc_path(path)
with open(path, 'r', encoding='utf-8') as file:
data_desc: dict = json.loads(file.read())
if data_desc.get('type') == 'pano':
dataset_class = PanoDataset
else:
dataset_class = ViewDataset
dataset = dataset_class(data_desc, path.absolute(), device=device, **kwargs)
return dataset
import threading
import torch
import math
from logging import *
from typing import Dict, List
import torch.utils.data
from tqdm import tqdm
from collections import defaultdict
from .dataset import Dataset
try:
from ..utils import math
from ..utils.types import *
except ImportError:
from utils import math
from utils.types import *
class Preloader(object):
def __init__(self, device=None) -> None:
super().__init__()
self.stream = torch.cuda.Stream(device)
self.event_chunk_loaded = None
def preload_chunk(self, chunk):
if self.event_chunk_loaded is not None:
self.event_chunk_loaded.wait()
if chunk.loaded:
return
# print(f'Preloader: preload chunk #{chunk.id}')
self.event_chunk_loaded = threading.Event()
threading.Thread(target=Preloader._load_chunk, args=(self, chunk)).start()
def _load_chunk(self, chunk):
with torch.cuda.stream(self.stream):
chunk.load()
self.event_chunk_loaded.set()
# print(f'Preloader: chunk #{chunk.id} is loaded')
class DataLoader(object):
class RaysLoader(object):
class Iter(object):
class Iterator(object):
def __init__(self, chunks, batch_size, shuffle, device: torch.device, preloader: Preloader):
def __init__(self, loader: "RaysLoader"):
super().__init__()
self.batch_size = batch_size
self.chunks = chunks
self.offset = -1
self.chunk_idx = -1
self.current_chunk = None
self.shuffle = shuffle
self.device = device
self.preloader = preloader
self.loader = loader
self.offset = 0
def __del__(self):
#print('DataLoader.Iter: clean chunks')
if self.preloader is not None and self.preloader.event_chunk_loaded is not None:
self.preloader.event_chunk_loaded.wait()
chunks_to_reserve = 1 if self.preloader is None else 2
for i in range(chunks_to_reserve, len(self.chunks)):
if self.chunks[i].loaded:
self.chunks[i].release()
# Initialize ray indices
#self.ray_indices = torch.randperm(self.loader.tot_pixels, device=self.loader.device)\
# if loader.shuffle else torch.arange(self.loader.tot_pixels, device=self.loader.device)
self.ray_indices = torch.randperm(self.loader.tot_pixels, device="cpu")\
if loader.shuffle else torch.arange(self.loader.tot_pixels, device="cpu")
def __next__(self):
if self.offset == -1:
self._next_chunk()
stop = min(self.offset + self.batch_size, len(self.current_chunk))
if self.indices is not None:
indices = self.indices[self.offset:stop]
else:
indices = torch.arange(self.offset, stop, device=self.device)
self.offset = stop
if self.offset >= len(self.current_chunk):
self.offset = -1
return self.current_chunk[indices]
def _next_chunk(self):
if self.current_chunk is not None:
chunks_to_reserve = 1 if self.preloader is None else 2
if len(self.chunks) > chunks_to_reserve:
self.current_chunk.release()
if self.chunk_idx >= len(self.chunks) - 1:
def __next__(self) -> Rays:
if self.offset >= self.ray_indices.shape[0]:
raise StopIteration()
self.chunk_idx += 1
self.current_chunk = self.chunks[self.chunk_idx]
self.offset = 0
self.indices = torch.randperm(len(self.current_chunk)).to(device=self.device) \
if self.shuffle else None
if self.preloader is not None:
self.preloader.preload_chunk(self.chunks[(self.chunk_idx + 1) % len(self.chunks)])
def __init__(self, dataset, batch_size, *,
chunk_max_items=None, shuffle=False, enable_preload=True, **chunk_args):
stop = min(self.offset + self.loader.batch_size, self.ray_indices.shape[0])
rays = self._get_rays(self.ray_indices[self.offset:stop])
self.offset = stop
return rays
def _get_rays(self, indices: torch.Tensor) -> Rays:
indices_on_device = indices.to(self.loader.device) # (B)
view_idx = torch.div(indices_on_device, self.loader.pixels_per_view, rounding_mode="trunc")
pix_idx = indices_on_device % self.loader.pixels_per_view
rays_o = self.loader.centers[view_idx] # (B, 3)
rays_d = self.loader.local_rays[pix_idx] # (B, 3)
if self.loader.rots is not None:
rays_d = (self.loader.rots[view_idx] @ rays_d[..., None])[..., 0]
rays = Rays({
'level': self.loader.level,
'idx': indices_on_device,
'rays_o': rays_o,
'rays_d': rays_d
})
# "colors" and "depths" are on host memory. Move part of them to device memory
indices = indices.to("cpu")
for image_type in ["color", "depth"]:
if image_type in self.loader.data:
rays[image_type] = self.loader.data[image_type][indices].to(
self.loader.device, non_blocking=True)
return rays
def __init__(self, dataset: Dataset, batch_size: int, *,
shuffle: bool = False, num_workers: int = 8, device: torch.device = None):
super().__init__()
self.dataset = dataset
self.batch_size = batch_size
self.device = device
self.shuffle = shuffle
self.chunk_args = chunk_args
self.preloader = Preloader(self.dataset.device) if enable_preload else None
self._init_chunks(chunk_max_items)
self.level = dataset.level
self.n_views = len(dataset)
self.pixels_per_view = dataset.pixels_per_view
self.tot_pixels = self.n_views * self.pixels_per_view
self.indices = dataset.indices.to(self.device)
self.centers = dataset.centers.to(self.device)
self.rots = dataset.rots.to(self.device) if dataset.rots is not None else None
self.local_rays = dataset.cam.local_rays.to(self.device)
# Load views from dataset
self.data = defaultdict(list)
views_loader = torch.utils.data.DataLoader(dataset, num_workers=num_workers,
pin_memory=True)
for view_data in tqdm(views_loader, "Loading views", leave=False, dynamic_ncols=True):
for key, val in view_data.items():
self.data[key].append(val)
print(f"{len(dataset)} views loaded.")
self.data = {
key: torch.cat(val).flatten(0, 1)
for key, val in self.data.items()
if key == "color" or key == "depth"
}
def __iter__(self):
return DataLoader.Iter(self.chunks, self.batch_size, self.shuffle, self.dataset.device,
self.preloader)
return RaysLoader.Iterator(self)
def __len__(self):
return sum(math.ceil(len(chunk) / self.batch_size) for chunk in self.chunks)
def _init_chunks(self, chunk_max_items):
data: Dict[str, torch.Tensor] = self.dataset.get_data()
if self.shuffle:
rand_seq = torch.randperm(self.dataset.n_views).to(device=self.dataset.device)
data = {key: val[rand_seq] for key, val in data.items()}
self.chunks = []
n_chunks = 1 if chunk_max_items is None else \
math.ceil(self.dataset.n_pixels / chunk_max_items)
views_per_chunk = math.ceil(self.dataset.n_views / n_chunks)
for offset in range(0, self.dataset.n_views, views_per_chunk):
sel = slice(offset, offset + views_per_chunk)
chunk_data = {key: val[sel] for key, val in data.items()}
self.chunks.append(self.dataset.Chunk(len(self.chunks), self.dataset,
chunk_data=chunk_data, **self.chunk_args))
if self.preloader is not None:
self.preloader.preload_chunk(self.chunks[0])
return math.ceil(self.tot_pixels / self.batch_size)
class MultiScaleDataLoader(object):
class Iter(object):
def __init__(self, sub_loaders: List[DataLoader]):
def __init__(self, sub_loaders: list[RaysLoader]):
super().__init__()
self.sub_loaders = sub_loaders
self.end_flags = [False] * len(sub_loaders)
@@ -150,15 +131,13 @@ class MultiScaleDataLoader(object):
return data_frags
def __init__(self, dataset, batch_size, *,
chunk_max_items=None, shuffle=False, enable_preload=True, **chunk_args):
def __init__(self, dataset: Dataset, batch_size, *,
views_per_chunk=8, shuffle=False, num_workers=4, device: torch.device = None):
super().__init__()
self.batch_size = batch_size
self.sub_loaders = [
DataLoader(sub_dataset, batch_size // len(dataset),
chunk_max_items=chunk_max_items // len(dataset)
if chunk_max_items is not None else None,
shuffle=shuffle, enable_preload=enable_preload, **chunk_args)
RaysLoader(sub_dataset, batch_size // len(dataset), views_per_chunk=views_per_chunk,
shuffle=shuffle, num_workers=num_workers, device=device)
for sub_dataset in dataset
]
# Sort by datasets' levels
@@ -178,10 +157,12 @@ class MultiScaleDataLoader(object):
if not isinstance(self.active_sub_loaders, list):
self.active_sub_loaders = [self.active_sub_loaders]
def get_loader(dataset, batch_size, *,
chunk_max_items=None, shuffle=False, enable_preload=True, **chunk_args):
views_per_chunk=8, shuffle=False, num_workers=4, device: torch.device = None):
if isinstance(dataset, list):
return MultiScaleDataLoader(dataset, batch_size, chunk_max_items=chunk_max_items,
shuffle=shuffle, enable_preload=enable_preload, **chunk_args)
return DataLoader(dataset, batch_size, chunk_max_items=chunk_max_items,
shuffle=shuffle, enable_preload=enable_preload, **chunk_args)
raise NotImplementedError()
return MultiScaleDataLoader(dataset, batch_size, views_per_chunk=views_per_chunk,
shuffle=shuffle, num_workers=num_workers, device=device)
return RaysLoader(dataset, batch_size, views_per_chunk=views_per_chunk,
shuffle=shuffle, num_workers=num_workers, device=device)
import os
import torch
import torch.nn.functional as nn_f
from typing import Dict, Tuple, Union
from operator import itemgetter
from pathlib import Path
from utils import img
from utils import color
from utils import sphere
from utils import math
from utils.mem_profiler import *
from .dataset import Dataset
class PanoDataset(Dataset):
"""
Data loader for spherical view synthesis task
Attributes
--------
data_dir ```str```: the directory of dataset\n
view_file_pattern ```str```: the filename pattern of view images\n
cam_params ```object```: camera intrinsic parameters\n
centers ```Tensor(N, 3)```: centers of views\n
view_rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
images ```Tensor(N, 3, H, W)```: images of views\n
view_depths ```Tensor(N, H, W)```: depths of views\n
"""
class Chunk(object):
@property
def n_views(self):
return self.indices.size(0)
@property
def n_pixels_per_view(self):
return self.dataset.n_pixels_per_view
def __init__(self, id: int, dataset, chunk_data: Dict[str, torch.Tensor], *,
color: int, **kwargs):
"""
[summary]
:param dataset `PanoDataset`: dataset object
:param indices `Tensor(N)`: indices of views
:param centers `Tensor(N, 3)`: centers of views
"""
self.id = id
self.dataset = dataset
self.indices = chunk_data['indices']
self.centers = chunk_data['centers']
self.color = color
self.colors_cpu = None
self.colors = None
self.loaded = False
def release(self):
self.colors = None
self.loaded = False
MemProfiler.print_memory_stats(f'Chunk #{self.id} released')
def load(self):
if self.dataset.image_path and self.colors_cpu is None:
images = color.cvt(img.load(self.dataset.image_path % i for i in self.indices),
color.RGB, self.color)
if self.dataset.res != tuple(images.shape[-2:]):
images = nn_f.interpolate(images, self.dataset.res)
self.colors_cpu = images.permute(
0, 2, 3, 1)[:, self.dataset.pixels[:, 0], self.dataset.pixels[:, 1]].flatten(0, 1)
if self.colors_cpu is not None:
self.colors = self.colors_cpu.to(self.dataset.device)
self.loaded = True
MemProfiler.print_memory_stats(
f'Chunk #{self.id} ({self.n_views} views, '
f'{self.colors.numel() * self.colors.element_size() / 1024 / 1024:.2f}MB) loaded')
def __len__(self):
return self.n_views * self.n_pixels_per_view
def __getitem__(self, idx):
if not self.loaded:
self.load()
view_idx = idx // self.n_pixels_per_view
pix_idx = idx % self.n_pixels_per_view
global_idx = self.indices[view_idx] * self.n_pixels_per_view + pix_idx
rays_o = self.centers[view_idx]
rays_d = self.dataset.rays[pix_idx]
data = {
'idx': global_idx,
'rays_o': rays_o,
'rays_d': rays_d,
'level': self.dataset.level
}
if self.colors is not None:
data['color'] = self.colors[idx]
return data
@property
def n_pixels_per_view(self):
return self.pixels.size(0)
def __init__(self, desc: dict, desc_path: Path, *,
load_images: bool = True,
res: Tuple[int, int] = None,
views_to_load: Union[range, torch.Tensor] = None,
device: torch.device = None,
**kwargs):
"""
Initialize data loader for spherical view synthesis task
The dataset description file is a JSON file with following fields:
- view_file_pattern: string, the path pattern of view images
- view_res: { "x", "y" }, the resolution of view
- depth_range: { "min", "max" }, the depth range
- range: { "min": [...], "max": [...] }, the range of translation and rotation
- centers: [ [ x, y, z ], ... ], centers of views
:param desc_path ```str```: path to the data description file
:param load_images ```bool```: whether load view images and return in __getitem__()
:param c ```int```: color space to convert view images to
:param calculate_rays ```bool```: whether calculate rays
"""
super().__init__(desc, desc_path, res=res, views_to_load=views_to_load, device=device,
load_images=load_images)
def get_data(self):
return {
'indices': self.indices,
'centers': self.centers
}
def _load_desc(self, res: Tuple[int, int], views_to_load: Union[range, torch.Tensor],
load_images: bool):
super()._load_desc(res, views_to_load)
self.image_path = load_images and self._get_data_path("view")
self.pixels, self.rays = self._get_pano_rays()
def _get_pano_rays(self):
"""
Get unprojected rays of pixels on a panorama
:return `Tensor(N, 2)`: rays' pixel coordinates in pano image
:return `Tensor(N, 3)`: rays' directions with one unit length
"""
phi = (torch.arange(self.res[0], device=self.device) + 0.5) / self.res[0] * math.pi # (H)
length = (phi.sin() * self.res[1] * 0.5).ceil() * 2
cols = torch.arange(self.res[1], device=self.device)[None, :].expand(*self.res) # (H, W)
mask = torch.logical_and(cols >= (self.res[1] - length[:, None]) / 2,
cols < (self.res[1] + length[:, None]) / 2) # (H, W)
pixs = mask.nonzero() # (N, 2)
pixs_phi = (0.5 - (pixs[:, 0] + 0.5) / self.res[0]) * math.pi
pixs_theta = (pixs[:, 1] * 2 + 1 - self.res[1]) / length[pixs[:, 0]] * math.pi
spher_coords = torch.stack([torch.ones_like(pixs_phi), pixs_theta, pixs_phi], dim=-1)
return pixs, sphere.spherical2cartesian(spher_coords) # (N, 3)