import torch
import numpy as np
import torch.nn.functional as nn_f
from . import view


class GuideRefinement(object):

    def __init__(self, guides_image, guides_view: view.Trans,
                 guides_cam: view.CameraParam, net) -> None:
        # Render each guide view through the network and store the relative
        # difference between the real guide images and the rendered ones;
        # this difference field drives the later refinement.
        rays_o, rays_d = guides_cam.get_global_rays(guides_view, flatten=True)
        guides_inferred = torch.stack([
            net(rays_o[i], rays_d[i]).view(
                guides_cam.res[0], guides_cam.res[1], -1).permute(2, 0, 1)
            for i in range(guides_image.size(0))
        ], 0)
        self.guides_diff = (guides_image - guides_inferred) / \
            (guides_inferred + 1e-5)
        self.guides_view = guides_view
        self.guides_cam = guides_cam

    def get_warp(self, rays_o, rays_d, depthmap, tgt_trans: view.Trans, tgt_cam):
        rays_size = list(depthmap.size()) + [3]
        rays_o = rays_o.view(rays_size)
        rays_d = rays_d.view(rays_size)
        # Back-project every pixel to a 3D point using its depth
        pcloud = rays_o + rays_d * depthmap[..., None]
        # Transform the point cloud into the target view and project it
        pcloud_in_tgt = tgt_trans.trans_point(pcloud, inverse=True)
        pixel_positions = tgt_cam.proj(pcloud_in_tgt)
        # Convert pixel coordinates to the [-1, 1] range expected by grid_sample
        pixel_positions[..., 0] /= tgt_cam.res[1] * 0.5
        pixel_positions[..., 1] /= tgt_cam.res[0] * 0.5
        pixel_positions -= 1
        return pixel_positions

    def refine_by_guide(self, image, depthmap, rays_o, rays_d, is_lr):
        if is_lr:
            # Upscale low-resolution inputs before warping
            # (nn_f.upsample is deprecated in favor of nn_f.interpolate)
            image = nn_f.interpolate(
                image[None, ...], scale_factor=2, mode='bicubic',
                align_corners=False)[0]
            depthmap = nn_f.interpolate(
                depthmap[None, None, ...], scale_factor=2, mode='bicubic',
                align_corners=False)[0, 0]
        warp = self.get_warp(rays_o, rays_d, depthmap,
                             self.guides_view, self.guides_cam)
        # Sample every guide's difference field at the warped positions and
        # apply the averaged relative correction
        warped_diff = nn_f.grid_sample(self.guides_diff, warp)
        avg_diff = torch.mean(warped_diff, 0)
        return image * (1 + avg_diff)
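
# A minimal usage sketch (an assumption, not code from the project: it
# presumes a trained `net` and pre-built `view.Trans` / `view.CameraParam`
# objects; the variable names are illustrative):
#
#   refiner = GuideRefinement(guides_image, guides_view, guides_cam, net)
#   refined = refiner.refine_by_guide(image, depthmap, rays_o, rays_d,
#                                     is_lr=False)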


def contrast_enhance(image, sigma, fe):
    # Unsharp-mask style enhancement: push each pixel away from its local
    # 3x3 mean by a contrast scale derived from sigma and fe
    kernel = (torch.ones(1, 1, 3, 3, device=image.device) / 9).repeat(3, 1, 1, 1)
    mean = nn_f.conv2d(image, kernel, padding=1, groups=3)
    c_scale = 1.0 + sigma * fe
    return torch.clamp(mean + (image - mean) * c_scale, 0, 1)
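
# Example (sketch; the `sigma` and `fe` values are illustrative, not tuned):
#
#   enhanced = contrast_enhance(images, sigma=0.5, fe=0.8)  # images: B x 3 x H x W in [0, 1]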

def get_grad(image):
    # Per-channel Sobel gradient magnitude, clamped to [0, 1]
    sobel_x = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]],
                           device=image.device).view(1, 1, 3, 3)
    sobel_y = torch.tensor([[-1., -2., -1.], [0., 0., 0.], [1., 2., 1.]],
                           device=image.device).view(1, 1, 3, 3)
    x_grad = nn_f.conv2d(image, sobel_x.repeat(3, 1, 1, 1), padding=1, groups=3)
    y_grad = nn_f.conv2d(image, sobel_y.repeat(3, 1, 1, 1), padding=1, groups=3)
    return (x_grad ** 2 + y_grad ** 2).sqrt().clamp(0, 1)


def getGaussianKernel(ksize, sigma=0):
    if sigma <= 0:
        # Derive the default sigma from the kernel size, matching OpenCV
        sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8
    center = ksize // 2
    xs = (np.arange(ksize, dtype=np.float32) - center)  # horizontal distance of each element from the kernel center
    kernel1d = np.exp(-(xs ** 2) / (2 * sigma ** 2))  # 1D Gaussian kernel
    # The Gaussian is separable, so the 2D kernel is the outer product
    # of the 1D kernel with itself
    kernel = kernel1d[..., None] @ kernel1d[None, ...]
    kernel = torch.from_numpy(kernel)
    kernel = kernel / kernel.sum()  # normalize
    return kernel.view(1, 1, ksize, ksize)
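
# Quick sanity sketch (the values reflect the OpenCV-compatible sigma rule
# above and are assumptions, not project tests):
#
#   k = getGaussianKernel(5)  # 1 x 1 x 5 x 5 tensor, normalized
#   assert abs(k.sum().item() - 1.0) < 1e-6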


def grad_aware_gaussian(image, ksize, sigma=0):
    # Blend between the original image and its Gaussian blur, keeping
    # detail where the gradient magnitude is high
    kernel = getGaussianKernel(ksize, sigma).to(image.device)
    blur = nn_f.conv2d(image, kernel.repeat(3, 1, 1, 1),
                       padding=ksize // 2, groups=3)
    grad = get_grad(image)
    return image * grad + blur * (1 - grad)
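
# Example (sketch; the kernel size is an illustrative choice): blur flat
# regions while keeping strong edges:
#
#   smoothed = grad_aware_gaussian(images, ksize=5)  # images: B x 3 x H x W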


def bilateral_filter(batch_img, ksize, sigmaColor=None, sigmaSpace=None):
    device = batch_img.device
    if sigmaSpace is None:
        sigmaSpace = 0.15 * ksize + 0.35  # 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8
    if sigmaColor is None:
        sigmaColor = sigmaSpace

    pad = (ksize - 1) // 2
    batch_img_pad = nn_f.pad(batch_img, pad=[pad, pad, pad, pad], mode='reflect')

    # batch_img is B x C x H x W, so unfold along dims 2 and 3
    # patches.shape: B x C x H x W x ksize x ksize
    patches = batch_img_pad.unfold(2, ksize, 1).unfold(3, ksize, 1)
    patch_dim = patches.dim()  # 6
    # Intensity difference between each pixel and its neighborhood
    diff_color = patches - batch_img.unsqueeze(-1).unsqueeze(-1)
    # Range weights computed from the intensity differences
    weights_color = torch.exp(-(diff_color ** 2) / (2 * sigmaColor ** 2))
    # Normalize the range weights
    weights_color = weights_color / weights_color.sum(dim=(-1, -2), keepdim=True)

    # Build the spatial Gaussian kernel and broadcast it to the shape of weights_color
    weights_space = getGaussianKernel(ksize, sigmaSpace).to(device)
    weights_space_dim = (patch_dim - 2) * (1,) + (ksize, ksize)
    weights_space = weights_space.view(*weights_space_dim).expand_as(weights_color)

    # Combine the spatial and range weights
    weights = weights_space * weights_color
    # Normalization factor for the combined weights
    weights_sum = weights.sum(dim=(-1, -2))
    # Weighted average over each neighborhood
    weighted_pix = (weights * patches).sum(dim=(-1, -2)) / weights_sum
    return weighted_pix
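
# A minimal self-test sketch for the filtering utilities. The shapes and
# parameter values below are illustrative assumptions, not values taken
# from the original project:
if __name__ == '__main__':
    img = torch.rand(2, 3, 64, 64)  # random B x C x H x W batch in [0, 1]
    print(contrast_enhance(img, sigma=0.5, fe=0.8).size())  # torch.Size([2, 3, 64, 64])
    print(grad_aware_gaussian(img, ksize=5).size())         # torch.Size([2, 3, 64, 64])
    print(bilateral_filter(img, ksize=5).size())            # torch.Size([2, 3, 64, 64])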