fovea_params_optim.py 1.97 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import itertools
from math import comb

import torch
from tqdm import tqdm

from utils import math

# Linear minimum-angle-of-resolution (MAR) model used by the search below:
# mar(e) = mar0 + mar_slope * e, with the eccentricity e in degrees.
mar0 = 1. / 48.
mar_slope = 0.0275
# Use the GPU when one is available, otherwise fall back to the CPU so the
# script still runs (more slowly) on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
# One cost weight per level, innermost level first; len(weights) defines the
# number of levels L.
weights = torch.tensor([1., .25, .25], device=device)

# VR configuration
res = (1440, 1600)  # (hor, ver)
fov = 110  # degrees, spanned by the res[1] vertical pixels
# Distance, in pixel units, from the viewpoint to an image plane on which the
# res[1] vertical pixels subtend `fov` degrees.
distance = .5 * res[1] / math.tan(.5 * math.radians(fov))
ratio = res[0] / res[1]  # hor / ver

# 360 / pi / distance == 2 * (degrees subtended by one pixel at the view
# center, small-angle approximation).
K = 360. / math.pi / distance
L = len(weights)

# Best candidate found so far; updated by the search loop.
min_sum = math.inf
x_of_min_sum = None
e_of_min_sum = None
s_of_min_sum = None
D_of_min_sum = None

# Exhaustive search for the level boundaries that minimise the weighted pixel
# count of all levels.
#
# A candidate is a strictly increasing tuple x of L - 1 boundary positions in
# pixels. x / distance is used as tan(e), so x is an offset from the view
# center. For each candidate:
#   e - eccentricity of each boundary, in degrees
#   s - per-level scale: 1 for the innermost level, otherwise
#       mar(e) * (1 + tan^2 e) / K at the level's inner boundary
#       (1 + tan^2 e is d(tan e)/de)
#   D - 2 * x_i / s_i for the inner levels, res[1] / s for the outermost
#   P - D^2, with the outermost level multiplied by ratio
#       (so it equals res[0] * res[1] / s^2)
# The objective is sum(weights * P). The smallest value and the candidate
# producing it are kept in min_sum and the *_of_min_sum variables.
#
# The first L - 3 boundaries are enumerated on the host. The last two are
# enumerated as one batch on the device, which keeps each batch at
# O(res[0]^2) candidates.
if L < 3:
    raise ValueError(f"the boundary search needs at least 3 levels, got L={L}")

device = weights.device  # build every tensor on the same device as `weights`
# Stop short of res[0] - 1 so two larger boundaries always fit above x1[-1].
x1_range = range(1, res[0] - 2)

# combinations() yields exactly the strictly increasing tuples, in the same
# lexicographic order a filtered product() would visit them.
for x1 in tqdm(itertools.combinations(x1_range, L - 3),
               total=comb(len(x1_range), L - 3)):
    # All strictly increasing pairs for the two outermost boundaries.
    # torch.combinations keeps lexicographic order, like itertools.
    x2 = torch.combinations(
        torch.arange(x1[-1] + 1 if x1 else 1, res[0], device=device), r=2)
    if x1:
        x = torch.cat([
            torch.tensor([x1], device=device).expand(x2.shape[0], -1),
            x2
        ], -1)
    else:
        x = x2
    tan_e = x / distance  # (N, L - 1)
    e = tan_e.arctan().rad2deg()  # (N, L - 1)
    mar = mar0 + mar_slope * e  # (N, L - 1)
    s = torch.cat([e.new_ones(e.shape[0], 1), mar * (1. + tan_e.pow(2.)) / K], -1)  # (N, L)
    D = torch.cat([x * 2. / s[:, :-1], res[1] / s[:, -1:]], -1)  # (N, L)
    P = D * D
    P[:, -1] *= ratio
    weighted_sum = (P * weights).sum(-1)
    min_value, min_index = weighted_sum.min(0)
    min_value = min_value.item()
    min_index = min_index.item()
    # Strict '<' keeps the earliest candidate on ties.
    if min_value < min_sum:
        min_sum = min_value
        x_of_min_sum = x[min_index]
        e_of_min_sum = e[min_index]
        s_of_min_sum = s[min_index]
        D_of_min_sum = D[min_index]

# Report the lowest weighted pixel count, followed by the boundaries (x),
# eccentricities (e), scales (s) and level sizes (D) of the winning candidate.
print(min_sum)
for label, value in (("x:", x_of_min_sum),
                     ("e:", e_of_min_sum),
                     ("s:", s_of_min_sum),
                     ("D:", D_of_min_sum)):
    print(label, value)