Commit 1bc644a1 authored by Nianchen Deng

sync

parent 6294701e
@@ -140,7 +140,7 @@
 "\n",
 "# Load Dataset\n",
 "views = load_views('views.json')\n",
-"#ref_dataset = SphericalViewSynDataset('ref.json', load_images=False, calculate_rays=False)\n",
+"#ref_dataset = SphericalViewSynDataset('ref.json', load_colors=False, calculate_rays=False)\n",
 "print('Dataset loaded.')\n",
 "\n",
 "print('views:', views.size())\n",
@@ -226,4 +226,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 2
 }
\ No newline at end of file
@@ -44,7 +44,7 @@
 " return None\n",
 "\n",
 "\n",
-"def load_views(data_desc_file) -> Tuple[view.Trans, torch.Tensor]:\n",
+"def load_views(data_desc_file) -> tuple[view.Trans, torch.Tensor]:\n",
 " with open(data_desc_file, 'r', encoding='utf-8') as file:\n",
 " lines = file.readlines()\n",
 " n = len(lines) // 7\n",
...
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"source": [
"import sys\n",
"import os\n",
"import torch\n",
"import torch.nn.functional as nn_f\n",
"import matplotlib.pyplot as plt\n",
"\n",
"rootdir = os.path.abspath(sys.path[0] + '/../')\n",
"sys.path.append(rootdir)\n",
"torch.cuda.set_device(0)\n",
"print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
"torch.autograd.set_grad_enabled(False)\n",
"\n",
"from utils import img\n",
"from utils.view import *\n",
"\n",
"datadir = f\"{rootdir}/data/__new/__demo/for_crop\"\n",
"figs = ['our', 'gt', 'nerf', 'fgt']\n",
"crops = {\n",
" 'classroom_0': [[720, 790, 100], [370, 1160, 200]],\n",
" 'lobby_1': [[570, 1000, 100], [1300, 1000, 200]],\n",
" 'stones_2': [[720, 800, 100], [680, 1317, 200]],\n",
" 'barbershop_3': [[745, 810, 100], [950, 900, 200]]\n",
"}\n",
"colors = torch.tensor([[0, 1, 0, 1], [1, 1, 0, 1]], dtype=torch.float)\n",
"border = 10\n",
"\n",
"for scene in crops:\n",
" images = img.load([f\"{datadir}/origin/{scene}_{fig}.png\" for fig in figs])\n",
" halfw = images.size(-1) // 2\n",
" halfh = images.size(-2) // 2\n",
" crop = crops[scene]\n",
" fovea_patches = images[...,\n",
" crop[0][1] - crop[0][2] // 2: crop[0][1] + crop[0][2] // 2,\n",
" crop[0][0] - crop[0][2] // 2: crop[0][0] + crop[0][2] // 2]\n",
" periph_patches = images[...,\n",
" crop[1][1] - crop[1][2] // 2: crop[1][1] + crop[1][2] // 2,\n",
" crop[1][0] - crop[1][2] // 2: crop[1][0] + crop[1][2] // 2]\n",
" fovea_patches = nn_f.interpolate(fovea_patches, (128, 128))\n",
" periph_patches = nn_f.interpolate(periph_patches, (128, 128))\n",
" overlay = torch.zeros(1, 4, 1600, 1440)\n",
" mask = torch.zeros(2, 1600, 1440, dtype=torch.bool)\n",
" for i in range(2):\n",
" mask[i,\n",
" crop[i][1] - crop[i][2] // 2 - border: crop[i][1] + crop[i][2] // 2 + border,\n",
" crop[i][0] - crop[i][2] // 2 - border: crop[i][0] + crop[i][2] // 2 + border] = True\n",
" mask[i,\n",
" crop[i][1] - crop[i][2] // 2: crop[i][1] + crop[i][2] // 2,\n",
" crop[i][0] - crop[i][2] // 2: crop[i][0] + crop[i][2] // 2] = False\n",
" overlay[:, :, mask[0]] = colors[0][..., None]\n",
" overlay[:, :, mask[1]] = colors[1][..., None]\n",
" plt.figure(figsize=(12, 6))\n",
" plt.subplot(1, 2, 1)\n",
" img.plot(images[0])\n",
" plt.subplot(1, 2, 2)\n",
" img.plot(overlay)\n",
" plt.figure(figsize=(12, 6))\n",
" for i in range(4):\n",
" plt.subplot(2, 4, i + 1)\n",
" img.plot(fovea_patches[i])\n",
" for i in range(4):\n",
" plt.subplot(2, 4, i + 5)\n",
" img.plot(periph_patches[i])\n",
" img.save(fovea_patches, [f\"{datadir}/fovea/{scene}_{fig}.png\" for fig in figs])\n",
" img.save(periph_patches, [f\"{datadir}/periph/{scene}_{fig}.png\" for fig in figs])\n",
" img.save(torch.cat([fovea_patches, periph_patches], dim=-1),\n",
" [f\"{datadir}/patch/{scene}_{fig}.png\" for fig in figs])\n",
" img.save(overlay, f\"{datadir}/overlay/{scene}.png\")\n"
],
"outputs": [],
"metadata": {}
}
],
"metadata": {
"interpreter": {
"hash": "65406b00395a48e1d89cf658ae895e7869e05878f5469716b06a752a3915211c"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3.8.5 64-bit ('base': conda)"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"orig_nbformat": 2
},
"nbformat": 4,
"nbformat_minor": 2
}
\ No newline at end of file
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Performance of Randperm on CPU/GPU"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Node Random perm on host: host duration 6384.1ms, device duration 6384.4ms\n",
"Node Random perm on device: host duration 2525.0ms, device duration 2525.0ms\n"
]
}
],
"source": [
"from common import *\n",
"from utils.profile import debug_profile\n",
"from utils.mem_profiler import MemProfiler\n",
"\n",
"with debug_profile(\"Random perm on host\"):\n",
" torch.randperm(1024 * 1024 * 100)\n",
"\n",
"with debug_profile(\"Random perm on device\"),\\\n",
" MemProfiler(\"Random perm on host\", device=\"cuda:3\"):\n",
" torch.randperm(1024 * 1024 * 100, device=\"cuda:3\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# \\_\\_getattribute\\_\\_ Method"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"a.a: 1\n",
"a.b: 2\n"
]
},
{
"ename": "AttributeError",
"evalue": "'A' object has no attribute 'c'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-3-bdb259af4410>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"a.a:\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"a.b:\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"a.c:\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-3-bdb259af4410>\u001b[0m in \u001b[0;36m__getattribute__\u001b[0;34m(self, _A__name)\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0merr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0merr\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 19\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-3-bdb259af4410>\u001b[0m in \u001b[0;36m__getattribute__\u001b[0;34m(self, _A__name)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__getattribute__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m__name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__getattribute__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m__name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mAttributeError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAttributeError\u001b[0m: 'A' object has no attribute 'c'"
]
}
],
"source": [
"class A(object):\n",
"\n",
"\n",
" def __init__(self, a, **extra) -> None:\n",
" super().__init__()\n",
" self.a = a\n",
" self.extra = extra\n",
"\n",
" def __getattribute__(self, __name: str):\n",
" try:\n",
" return super().__getattribute__(__name)\n",
" except AttributeError as e:\n",
" try:\n",
" return self.extra[__name]\n",
" except KeyError:\n",
" pass\n",
" err = e\n",
" raise err\n",
"\n",
"\n",
"a = A(a=1, b=2)\n",
"\n",
"\n",
"print(\"a.a:\", a.a)\n",
"print(\"a.b:\", a.b)\n",
"print(\"a.c:\", a.c)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Performance of Various Select/Scatter Methods"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.010159730911254883 tensor(100., device='cuda:0')\n",
"0.011237859725952148 tensor(1400., device='cuda:0')\n",
"0.032263755798339844 tensor(2700., device='cuda:0')\n",
"0.02148723602294922 tensor(4.1723e-07, device='cuda:0')\n",
"0.009927511215209961 tensor(4.1723e-07, device='cuda:0')\n",
"Mask set 0.02173590660095215\n",
"Inplace mask scatter 0.0041882991790771484\n",
"Mask scatter 0.00580906867980957\n",
"Index set 0.03358888626098633\n",
"Index put 0.01044917106628418\n"
]
}
],
"source": [
"from common import *\n",
"from time import time\n",
"\n",
"a = torch.zeros(200000, device=\"cuda\")\n",
"i = torch.randint(0, a.shape[0], [500000], device=\"cuda\")\n",
"start = time()\n",
"for _ in range(100):\n",
" a[i] += 1\n",
"end = time()\n",
"print(end - start, a.max())\n",
"start = time()\n",
"for _ in range(100):\n",
" a.index_add_(0, i, torch.ones_like(i, dtype=torch.float))\n",
"end = time()\n",
"print(end - start, a.max())\n",
"\n",
"start = time()\n",
"for _ in range(100):\n",
" ui, n = i.unique(return_counts=True)\n",
" a[ui] += n\n",
"end = time()\n",
"print(end - start, a.max())\n",
"\n",
"\n",
"a = torch.rand(2000, 2000, device=\"cuda\") - .5\n",
"m = a > 0\n",
"\n",
"start = time()\n",
"for _ in range(100):\n",
" b = a[m]\n",
"end = time()\n",
"print(end - start, b.min())\n",
"\n",
"start = time()\n",
"for _1 in range(20):\n",
" m1 = m.nonzero(as_tuple=True)\n",
" for _ in range(5):\n",
" b = a[m1]\n",
"end = time()\n",
"print(end - start, b.min())\n",
"\n",
"\n",
"c = torch.rand_like(b)\n",
"\n",
"start = time()\n",
"for _ in range(100):\n",
" a[m] = c\n",
"end = time()\n",
"print(\"Mask set\", end - start)\n",
"\n",
"start = time()\n",
"for _ in range(100):\n",
" a.masked_scatter_(m, c)\n",
"end = time()\n",
"print(\"Inplace mask scatter\", end - start)\n",
"\n",
"\n",
"start = time()\n",
"for _ in range(100):\n",
" a = a.masked_scatter(m, c)\n",
"end = time()\n",
"print(\"Mask scatter\", end - start)\n",
"\n",
"start = time()\n",
"for _1 in range(20):\n",
" m1 = m.nonzero(as_tuple=True)\n",
" for _ in range(5):\n",
" a[m1] = b\n",
"end = time()\n",
"print(\"Index set\", end - start)\n",
"\n",
"\n",
"start = time()\n",
"for _1 in range(20):\n",
" m1 = m.nonzero(as_tuple=True)\n",
" for _ in range(5):\n",
" a.index_put_(m1, b)\n",
"end = time()\n",
"print(\"Index put\", end - start)"
]
}
],
"metadata": {
"interpreter": {
"hash": "65406b00395a48e1d89cf658ae895e7869e05878f5469716b06a752a3915211c"
},
"kernelspec": {
"display_name": "Python 3.8.12 ('base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
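# Shared bootstrap used by the notebooks via `from common import *`
# (presumably common.py): extend the import path to the repository root
# and select the default CUDA device.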
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
rootdir = Path(sys.path[0]).absolute().parents[1]
sys.path.append(str(rootdir))
torch.cuda.set_device(0)
\ No newline at end of file
@@ -6,21 +6,11 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"import sys\n",
-"import os\n",
-"import torch\n",
-"import matplotlib.pyplot as plt\n",
-"import torchvision.transforms.functional as trans_f\n",
-"\n",
-"rootdir = os.path.abspath(sys.path[0] + '/../')\n",
-"sys.path.append(rootdir)\n",
-"torch.cuda.set_device(2)\n",
-"print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
-"\n",
+"from common import *\n",
 "from components import refine\n",
 "from utils import img\n",
 "\n",
-"img = img.load(os.path.join(rootdir, \"data/gas_2021.01.04_all_in_one/output/mid_0536.png\"))\n",
+"img = img.load(rootdir / \"data/gas_2021.01.04_all_in_one/output/mid_0536.png\")\n",
 "\n",
 "fe = 0.2\n",
 "leng_sigma = [0,3,5]\n",
@@ -37,8 +27,7 @@
 "img.plot(enhanced)\n",
 "plt.title('Enhanced')\n",
 "plt.axis('off')\n",
-"plt.show()\n",
-"\n"
+"plt.show()"
 ]
 }
 ],
@@ -63,4 +52,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 2
 }
\ No newline at end of file
This source diff could not be displayed because it is too large.
@@ -6,28 +6,17 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# Image denoising and smoothing filters\n",
-"# Implemented with OpenCV built-in functions, compared against hand-written versions\n",
-"# Generates salt-and-pepper noise, Gaussian noise, etc.\n",
-"# Applies median, mean, Gaussian and box filters\n",
-"import sys\n",
-"import os\n",
-"rootdir = os.path.abspath(sys.path[0] + '/../')\n",
-"sys.path.append(rootdir)\n",
-"\n",
-"import numpy as np\n",
-"import math\n",
 "import cv2\n",
-"import matplotlib.pyplot as plt\n",
-"import torch.nn.functional as nn_f\n",
-"import torch\n",
-"from loss.perc_loss import *\n",
+"\n",
+"from common import *\n",
+"from utils import math\n",
+"from loss import *\n",
 "\n",
 "loss = VGGPerceptualLoss().to('cuda')\n",
 "\n",
 "def psnr(input, gt):\n",
 " input, gt = torch.from_numpy(input / 255).permute(2, 0, 1)[None, :].to('cuda', torch.float32), torch.from_numpy(gt / 255).permute(2, 0, 1)[None, :].to('cuda', torch.float32)\n",
-" rmse = math.sqrt(nn_f.mse_loss(input, gt))\n",
+" rmse = math.sqrt(mse_loss(input, gt))\n",
 " #diff = target / 255 - ref / 255\n",
 " #rmse = math.sqrt(np.mean(diff ** 2.))\n",
 " #return rmse\n",
@@ -35,8 +24,8 @@
 "\n",
 "\n",
 "for i in range(3):\n",
-" image_gt = cv2.imread(os.path.join (rootdir, 'data/gas_fovea_2020.12.31/train/view_%04d.png' % i))\n",
-" image = cv2.imread(os.path.join (rootdir, 'data/gas_fovea_2020.12.31/new_fovea_rgb@nmsl-rgb_e10_fc128x4_d1-50_s32/output/model-epoch_300/train/out_view_%04d.png' % i))\n",
+" image_gt = cv2.imread(rootdir / 'data/gas_fovea_2020.12.31/train/view_%04d.png' % i)\n",
+" image = cv2.imread(rootdir / 'data/gas_fovea_2020.12.31/new_fovea_rgb@nmsl-rgb_e10_fc128x4_d1-50_s32/output/model-epoch_300/train/out_view_%04d.png' % i)\n",
 " plt.figure(facecolor='white', figsize=(10,4))\n",
 " plt.subplot(2, 3, 1)\n",
 " plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))\n",
@@ -112,4 +101,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 2
 }
\ No newline at end of file
@@ -6,29 +6,18 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"import sys\n",
-"import os\n",
-"import torch\n",
+"%matplotlib inline\n",
 "import time\n",
-"import matplotlib.pyplot as plt\n",
-"import numpy as np\n",
-"\n",
-"rootdir = os.path.abspath('../')\n",
-"sys.path.append(rootdir)\n",
 "\n",
+"from common import *\n",
 "from utils import img\n",
 "from utils import sphere\n",
 "from utils import device\n",
 "from utils import misc\n",
 "from utils.mem_profiler import *\n",
-"from data.dataset_factory import DatasetFactory\n",
-"from data.loader import DataLoader\n",
-"\n",
-"# Select device\n",
-"torch.cuda.set_device(0)\n",
-"print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
+"from data import Dataset, RaysLoader\n",
 "\n",
-"MemProfiler.enable = True\n"
+"MemProfiler.enable = False"
 ]
 },
 {
@@ -57,27 +46,28 @@
 "#DATA_DESC_FILE = f'{rootdir}/data/__nerf/horns/images_4.json'\n",
 "#DATA_DESC_FILE = f'{rootdir}/data/__new/city_fovea_r360x80_t5.0/train1.json'\n",
 "#DATA_DESC_FILE = f'{rootdir}/data/__captured/room/train.json'\n",
-"DATA_DESC_FILE = f'{rootdir}/data/__pano/stones_fovea_t1.0/train.json'\n",
+"DATA_DESC_FILE = f'{rootdir}/data/__captured/bedroom/images4_train.json'\n",
 "\n",
 "MemProfiler.print_memory_stats('Start')\n",
 "\n",
-"dataset = DatasetFactory.load(DATA_DESC_FILE)\n",
+"dataset = Dataset(DATA_DESC_FILE)\n",
 "res = dataset.res\n",
-"data_loader = DataLoader(dataset, res[0] * res[1], chunk_max_items=6e8)\n",
+"data_loader = RaysLoader(dataset, res[0] * res[1], device=torch.device(\"cuda\"))\n",
 "\n",
 "MemProfiler.print_memory_stats('After dataset loaded')\n",
 "\n",
 "fig = plt.figure(figsize=(12, 6))\n",
 "i = 0\n",
-"for indices, rays_o, rays_d, extras in data_loader:\n",
+"for data in data_loader:\n",
 " if i >= 4:\n",
 " break\n",
 " plt.subplot(2, 2, i + 1)\n",
-" img.plot(extras['colors'].view(1, res[0], res[1], 3))\n",
+" img.plot(data['color'].view(1, res[0], res[1], 3))\n",
 " MemProfiler.print_memory_stats(f'After view {i} is plotted')\n",
 " i += 1\n",
 " time.sleep(1)\n",
 "\n",
+"#plt.show()\n",
 "MemProfiler.print_memory_stats(f'After all views are plotted')"
 ]
 },
@@ -88,23 +78,23 @@
 "outputs": [],
 "source": [
 "selector = torch.arange(res[0] * res[1]).reshape(res[0], res[1])\n",
-"selector = selector[1024-512:1024+512:10, 2048-512:2048+512:5].flatten()\n",
+"selector = selector[::3, ::3].flatten()\n",
 "idx_range = [0, 4, 20, 24, 62, 100, 104, 120, 124]\n",
-"for r in torch.arange(3, 3.5, 0.1):\n",
+"for r in torch.arange(11, 50, 5):\n",
 " p = None\n",
 " centers = None\n",
 " pixels = None\n",
 " idx = 0\n",
 " MemProfiler.print_memory_stats(f'Before iter')\n",
-" for indices, rays_o, rays_d, extras in data_loader:\n",
+" for data in data_loader:\n",
 " if idx > max(idx_range):\n",
 " break\n",
 " if idx not in idx_range:\n",
 " idx += 1\n",
 " continue\n",
-" colors = extras['colors'][selector]\n",
-" rays_o = rays_o[selector]\n",
-" rays_d = rays_d[selector]\n",
+" colors = data['color'][selector]\n",
+" rays_o = data['rays_o'][selector]\n",
+" rays_d = data['rays_d'][selector]\n",
 " r = torch.tensor([[r]], device=device.default())\n",
 " p_ = misc.torch2np(sphere.ray_sphere_intersect(rays_o, rays_d, r)[0].view(-1, 3))\n",
 " p = p_ if p is None else np.concatenate((p, p_), axis=0)\n",
@@ -118,7 +108,6 @@
 " plt.ylabel('z')\n",
 " plt.title('r = %f' % r)\n",
 " ax.scatter([0], [0], [0], color=\"k\", s=10)\n",
-" print(p.shape, pixels.shape)\n",
 " ax.scatter(p[:, 0], p[:, 2], p[:, 1], color=pixels, s=0.5)\n",
 " ax.view_init(elev=0, azim=-90)\n"
 ]
@@ -126,10 +115,10 @@
 ],
 "metadata": {
 "interpreter": {
-"hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5"
+"hash": "65406b00395a48e1d89cf658ae895e7869e05878f5469716b06a752a3915211c"
 },
 "kernelspec": {
-"display_name": "Python 3",
+"display_name": "Python 3.8.12 ('base')",
 "language": "python",
 "name": "python3"
 },
@@ -143,7 +132,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.8.5"
+"version": "3.8.12"
 }
 },
 "nbformat": 4,
...
This source diff could not be displayed because it is too large.
@@ -80,7 +80,7 @@
 "\n",
 "images = img.load_seq(f'{datadir}/img%02d.png', 16, permute=False)\n",
 "res=(int(p[0, 0]), int(p[0, 1]))\n",
-"cam = CameraParam({\"fy\":-p[0, 2], \"fx\": p[0, 2], \"cx\":res[1]//2, \"cy\":res[0]//2}, res)\n",
+"cam = Camera({\"fy\":-p[0, 2], \"fx\": p[0, 2], \"cx\":res[1]//2, \"cy\":res[0]//2}, res)\n",
 "views = Trans(torch.tensor(t, dtype=torch.float), torch.tensor(r, dtype=torch.float))\n",
 "_rays_o, _rays_d = cam.get_global_rays(views, flatten=True)\n",
 "_patches = images.flatten(1, 2)\n",
...
@@ -6,29 +6,28 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"%matplotlib inline\n",
 "import sys\n",
 "import os\n",
 "import struct\n",
 "import torch\n",
-"import torch.nn as nn\n",
 "import matplotlib.pyplot as plt\n",
 "\n",
 "rootdir = os.path.abspath(sys.path[0] + '/../')\n",
 "sys.path.append(rootdir)\n",
-"torch.cuda.set_device(0)\n",
-"print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
-"torch.autograd.set_grad_enabled(False)\n",
+"torch.set_grad_enabled(False)\n",
 "\n",
 "from utils import img\n",
-"from utils import device\n",
 "from utils.view import *\n",
 "from components.foveation import Foveation\n",
 "\n",
+"\n",
 "def gen_mask_file(layer_mask, filename):\n",
 " indices = torch.arange(layer_mask.size(0) * layer_mask.size(1),\n",
 " device=layer_mask.device).view_as(layer_mask)\n",
-" indices = indices[layer_mask>=0]\n",
-" inverseIndices = torch.ones(layer_mask.size(0) * layer_mask.size(1), device=layer_mask.device, dtype=torch.long) * -1\n",
+" indices = indices[layer_mask >= 0]\n",
+" inverseIndices = torch.ones(layer_mask.size(0) * layer_mask.size(1),\n",
+" device=layer_mask.device, dtype=torch.long) * -1\n",
 " inverseIndices[indices] = torch.arange(indices.size(0), device=layer_mask.device)\n",
 " with open(filename, 'wb') as fp:\n",
 " fp.write(indices.size(0).to_bytes(4, 'little'))\n",
@@ -36,32 +35,46 @@
 " fp.write(inverseIndices.size(0).to_bytes(4, 'little'))\n",
 " fp.write(struct.pack(f\"<{inverseIndices.size(0)}i\", *inverseIndices))\n",
 "\n",
+"\n",
 "foveation = Foveation([20, 45, 110], [(256, 256), (256, 256), (256, 230)], (1600, 1440))\n",
-"layers_mask = foveation.get_layers_mask()\n",
-"\n",
+"layers_mask = foveation.get_layers_mask(gaze=(0, 0))\n",
 "plt.figure(figsize=(12, 4))\n",
 "for i, mask in enumerate(layers_mask):\n",
 " colored_mask = torch.zeros(mask.size(0), mask.size(1), 3, device=mask.device)\n",
-" c = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1]], device=mask.device)\n",
+" c = torch.tensor([\n",
+" [1, 1, 1, 1, 1, 1],\n",
+" [1, 1, 1, 1, 1, 1],\n",
+" [1, 1, 1, 0, 0, 0]\n",
+" ], device=mask.device)\n",
 " for bi in range(3):\n",
 " region = torch.logical_and(mask >= bi, mask < bi + 1)\n",
-" colored_mask[region] = c[bi] + (c[-1] - c[bi]) * (mask[region][..., None] - bi)\n",
+" colored_mask[region] = c[bi, :3] + (c[bi, 3:] - c[bi, :3]) * (mask[region][..., None] - bi)\n",
 " plt.subplot(1, len(layers_mask), i + 1)\n",
 " img.plot(colored_mask)\n",
+" img.save(colored_mask, f\"blend_{i}.png\")\n",
 " n_skipped = torch.sum(mask < 0)\n",
 " n_tot = len(mask.flatten())\n",
-" print (f\"Layer {i}: {n_skipped}({n_skipped / n_tot * 100:.2f}%) pixels are masked as skipped, {n_tot - n_skipped} pixels need to be inferred\")\n",
+" print(f\"Layer {i}: {n_skipped}({n_skipped / n_tot * 100:.2f}%) pixels are masked as skipped, \"\n",
+" f\"{n_tot - n_skipped}({(n_tot - n_skipped) / n_tot * 100:.2f}%) pixels need to be inferred\")\n",
 "\n",
-"gen_mask_file(layers_mask[0], 'fovea.mask')\n",
-"gen_mask_file(layers_mask[1], 'mid.mask')"
+"plt.figure(figsize=(12, 4))\n",
+"for i, mask in enumerate(layers_mask):\n",
+" binary_mask = (mask >= 0).float().expand(3, -1, -1)\n",
+" plt.subplot(1, len(layers_mask), i + 1)\n",
+" img.plot(binary_mask)\n",
+" img.save(binary_mask, f\"mask_{i}.png\")\n",
+"#gen_mask_file(layers_mask[0], 'fovea.mask')\n",
+"#gen_mask_file(layers_mask[1], 'mid.mask')\n"
 ]
 }
 ],
 "metadata": {
 "interpreter": {
-"hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5"
+"hash": "65406b00395a48e1d89cf658ae895e7869e05878f5469716b06a752a3915211c"
 },
 "kernelspec": {
-"display_name": "Python 3",
+"display_name": "Python 3.8.12 ('base')",
 "language": "python",
 "name": "python3"
 },
@@ -75,7 +88,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.8.5"
+"version": "3.8.12"
 },
 "metadata": {
 "interpreter": {
...
@@ -25,18 +25,11 @@
 }
 ],
 "source": [
-"import sys\n",
-"import os\n",
-"import torch\n",
+"from common import *\n",
 "import torch.nn as nn\n",
 "\n",
-"rootdir = os.path.abspath(sys.path[0] + '/../')\n",
-"sys.path.append(rootdir)\n",
-"torch.cuda.set_device(0)\n",
-"print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
 "torch.autograd.set_grad_enabled(False)\n",
 "\n",
-"from configs.spherical_view_syn import SphericalViewSynConfig\n",
 "from utils import netio\n",
 "from utils import img\n",
 "from utils import device\n",
...
@@ -137,7 +137,7 @@
 "test_view = views.get(*view_coord)\n",
 "\n",
 "cams = [\n",
-" view.CameraParam({\n",
+" view.Camera({\n",
 " \"fov\": fov_list[i],\n",
 " \"cx\": 0.5,\n",
 " \"cy\": 0.5,\n",
@@ -147,7 +147,7 @@
 "]\n",
 "fovea_cam, mid_cam, periph_cam = cams[0], cams[1], cams[2]\n",
 "#guide_cam = ref_dataset.cam_params\n",
-"vr_cam = view.CameraParam({\n",
+"vr_cam = view.Camera({\n",
 " 'fov': fov_list[-1],\n",
 " 'cx': 0.5,\n",
 " 'cy': 0.5,\n",
...
@@ -63,7 +63,7 @@
 " view_dataset.samples)\n",
 "ref_indices = torch.arange(ref_dataset.n_views, device=device.default()).view(\n",
 " ref_dataset.samples)\n",
-"cam_params = view.CameraParam({\n",
+"cam_params = view.Camera({\n",
 " \"fov\": 20,\n",
 " \"cx\": 0.5,\n",
 " \"cy\": 0.5,\n",
@@ -76,7 +76,7 @@
 "netio.load(model_path, net)\n",
 "print('Net loaded.')\n",
 "\n",
-"vr_cam = view.CameraParam({\n",
+"vr_cam = view.Camera({\n",
 " 'fov': 110,\n",
 " 'cx': 0.5,\n",
 " 'cy': 0.5,\n",
...
@@ -6,13 +6,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"import torch\n",
-"import sys\n",
-"import os\n",
-"\n",
-"rootdir = os.path.abspath(sys.path[0] + '/../')\n",
-"sys.path.append(rootdir)\n",
-"\n",
+"from common import *\n",
 "from utils.voxels import *\n",
 "\n",
 "bbox, steps = torch.tensor([[-2, -3.14159, 1], [2, 3.14159, 0]]), torch.tensor([2, 3, 3])\n",
@@ -94,7 +88,7 @@
 "voxel_indices_of_new_corner = voxel_indices_in_grid[to_flat_indices(to_grid_coords(new_corners, bbox, steps).min(steps - 1), steps) + 1]\n",
 "print(voxel_indices_of_new_corner)\n",
 "p_of_new_corners = (new_corners - voxels[voxel_indices_of_new_corner]) / voxel_size + .5\n",
-"print(((new_corners - trilinear_interp(p_of_new_corners, emb(corner_indices[voxel_indices_of_new_corner]))) > 1e-6).sum())"
+"print(((new_corners - linear_interp(p_of_new_corners, emb(corner_indices[voxel_indices_of_new_corner]))) > 1e-6).sum())"
 ]
 }
 ],
...
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import matplotlib.pyplot as plt\n",
"from data.lf_syn import LightFieldSynDataset\n",
"from utils import img\n",
"from utils import math\n",
"from nets.trans_unet import LatentSpaceTransformer\n",
"\n",
"device = torch.device(\"cuda:2\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test data loader"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"DATA_DIR = '../data/lf_syn_2020.12.23'\n",
"TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
"\n",
"train_dataset = LightFieldSynDataset(TRAIN_DATA_DESC_FILE)\n",
"train_data_loader = torch.utils.data.DataLoader(\n",
" dataset=train_dataset,\n",
" batch_size=3,\n",
" num_workers=8,\n",
" pin_memory=True,\n",
" shuffle=True,\n",
" drop_last=False)\n",
"print(len(train_data_loader))\n",
"\n",
"print(train_dataset.cam_params)\n",
"print(train_dataset.sparse_view_positions)\n",
"print(train_dataset.diopter_of_layers)\n",
"plt.figure()\n",
"img.plot(train_dataset.sparse_view_images[0])\n",
"plt.figure()\n",
"img.plot(train_dataset.sparse_view_depths[0] / 255 * 10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test disparity wrapper"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"\n",
"transformer = LatentSpaceTransformer(train_dataset.sparse_view_images.size()[2],\n",
" train_dataset.cam_params,\n",
" train_dataset.diopter_of_layers,\n",
" train_dataset.sparse_view_positions)\n",
"novel_views = torch.stack([\n",
" train_dataset.view_positions[13],\n",
" train_dataset.view_positions[30],\n",
" train_dataset.view_positions[57],\n",
"], dim=0)\n",
"trans_images = transformer(train_dataset.sparse_view_images.to(device),\n",
" train_dataset.sparse_view_depths.to(device),\n",
" novel_views)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"mask = (torch.sum(trans_images[0], 1) > math.tiny).to(dtype=torch.float)\n",
"blended = torch.sum(trans_images[0], 0)\n",
"weight = torch.sum(mask, 0)\n",
"blended = blended / weight.unsqueeze(0)\n",
"\n",
"plt.figure(figsize=(6, 6))\n",
"img.plot(train_dataset.view_images[13])\n",
"plt.figure(figsize=(6, 6))\n",
"img.plot(blended)\n",
"plt.figure(figsize=(12, 6))\n",
"plt.subplot(2, 4, 1)\n",
"img.plot(train_dataset.sparse_view_images[0])\n",
"plt.subplot(2, 4, 2)\n",
"img.plot(train_dataset.sparse_view_images[1])\n",
"plt.subplot(2, 4, 3)\n",
"img.plot(train_dataset.sparse_view_images[2])\n",
"plt.subplot(2, 4, 4)\n",
"img.plot(train_dataset.sparse_view_images[3])\n",
"\n",
"plt.subplot(2, 4, 5)\n",
"img.plot(trans_images[0, 0])\n",
"plt.subplot(2, 4, 6)\n",
"img.plot(trans_images[0, 1])\n",
"plt.subplot(2, 4, 7)\n",
"img.plot(trans_images[0, 2])\n",
"plt.subplot(2, 4, 8)\n",
"img.plot(trans_images[0, 3])\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.9 64-bit ('pytorch': conda)",
"name": "python379jvsc74a57bd0660ca2a75467d3af74a68fcc6f40bc78ab96b99ff17d2f100b5ca821fbb183f2"
},
"language_info": {
"name": "python",
"version": ""
},
"orig_nbformat": 2
},
"nbformat": 4,
"nbformat_minor": 2
}
This source diff could not be displayed because it is too large.
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import matplotlib.pyplot as plt\n",
"\n",
"os.chdir('../')\n",
"sys.path.append(os.getcwd())\n",
"\n",
"from utils import img\n",
"from utils import color\n",
"\n",
"input_img = img.load('data/gas_fovea_2020.12.31/upsampling_test/input/out_view_0000.png')\n",
"ycbcr = color.rgb2ycbcr(input_img)\n",
"rgb = color.ycbcr2rgb(ycbcr)\n",
"\n",
"plt.figure()\n",
"img.plot(input_img)\n",
"plt.figure()\n",
"plt.subplot(1, 4, 1)\n",
"img.plot(ycbcr)\n",
"plt.subplot(1, 4, 2)\n",
"img.plot(ycbcr[:, 0])\n",
"plt.subplot(1, 4, 3)\n",
"img.plot(ycbcr[:, 1])\n",
"plt.subplot(1, 4, 4)\n",
"img.plot(ycbcr[:, 2])\n",
"plt.figure()\n",
"img.plot(rgb)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.9 64-bit ('pytorch': conda)",
"name": "python379jvsc74a57bd0660ca2a75467d3af74a68fcc6f40bc78ab96b99ff17d2f100b5ca821fbb183f2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
},
"orig_nbformat": 2
},
"nbformat": 4,
"nbformat_minor": 2
}
\ No newline at end of file