sphere.py 4.06 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
from . import math, misc
from .types import *
Nianchen Deng's avatar
sync    
Nianchen Deng committed
3
4
5
6
7
8


def cartesian2spherical(cart: torch.Tensor, inverse_r: bool = False) -> torch.Tensor:
    """
    Convert coordinates from Cartesian to Spherical

    The radius is the Euclidean length of each point. Phi is the elevation angle
    `asin(y / r)`, which lies in [-pi/2, pi/2]. Theta is obtained from
    `misc.get_angle(z, x)`; presumably this is the azimuth in the xz-plane
    (TODO confirm against `misc.get_angle`).

    :param cart `Tensor([N...,] 3)`: coordinates in Cartesian
    :param inverse_r: whether to output 1/r instead of r, defaults to `False`
    :return `Tensor([N...,] 3)`: coordinates in Spherical ([r | 1/r], theta, phi)
    """
    y = cart[..., 1]
    radius = (cart * cart).sum(dim=-1).sqrt()
    azimuth = misc.get_angle(cart[..., 2], cart[..., 0])
    if not inverse_r:
        sin_elevation = y / radius
    else:
        # Store the reciprocal radius and reuse it, so y * (1/r) gives sin(phi).
        radius = radius.reciprocal()
        sin_elevation = y * radius
    return torch.stack((radius, azimuth, torch.asin(sin_elevation)), dim=-1)


def spherical2cartesian(spher: torch.Tensor, inverse_r: bool = False) -> torch.Tensor:
    """
    Convert coordinates from Spherical to Cartesian

    Computes x = r*sin(theta)*cos(phi), y = r*sin(phi) and z = r*cos(theta)*cos(phi).
    Theta is therefore measured in the xz-plane from +z towards +x, and phi is the
    elevation above the xz-plane.

    :param spher `Tensor([N...,] 3)`: coordinates in Spherical ([r | 1/r], theta, phi)
    :param inverse_r `bool`: whether r is in reciprocal form, defaults to `False`
    :return `Tensor([N...,] 3)`: coordinates in Cartesian
    """
    radius = spher[..., 0].reciprocal() if inverse_r else spher[..., 0]
    theta = spher[..., 1]
    phi = spher[..., 2]
    sin_t, cos_t = torch.sin(theta), torch.cos(theta)
    sin_p, cos_p = torch.sin(phi), torch.cos(phi)
    return torch.stack([radius * sin_t * cos_p,
                        radius * sin_p,
                        radius * cos_t * cos_p], dim=-1)


def ray_sphere_intersect(rays: Rays, r: torch.Tensor) -> torch.Tensor:
    """
    Calculate intersections of each ray with each origin-centred sphere

    Solves |o + t*d|^2 = r^2 for t and returns the larger root:
    t = (sqrt((o.d)^2 - (d.d) * (o.o - r^2)) - o.d) / (d.d).
    The result can be negative if the sphere lies behind the ray origin.
    Where a ray misses a sphere, the discriminant is negative and the depth is NaN.

    :param rays `Rays(B)`: rays (reads `rays_o` and `rays_d`, presumably `Tensor(B, 3)`)
    :param r `Tensor(P)`: radius of spheres
    :return `Tensor(B, P)`: depths of intersections along rays
    """
    origin = rays.rays_o[:, None]  # (B, 1, 3)
    direction = rays.rays_d[:, None]  # (B, 1, 3)
    o_dot_d = torch.sum(origin * direction, dim=2)  # (B, 1)
    d_dot_d = torch.sum(direction * direction, dim=2)  # (B, 1)
    o_dot_o = torch.sum(origin * origin, dim=2)  # (B, 1)
    discriminant = o_dot_d * o_dot_d - d_dot_d * (o_dot_o - r * r)  # (B, P)
    return (discriminant.sqrt() - o_dot_d) / d_dot_d


def get_rot_matrix(theta: float | torch.Tensor, phi: float | torch.Tensor) -> torch.Tensor:
    """
    Get rotation matrix from angles in spherical space

    The "forward" row is the unit vector at spherical angles (theta, phi), as produced by
    `spherical2cartesian` with r = 1. The "right" and "up" rows are built by cross
    products with the world y axis and normalised, so every row has unit length.

    NOTE(review): at theta = phi = 0 the rows are (-1, 0, 0), (0, 1, 0) and (0, 0, 1),
    so the determinant is -1 (a left-handed frame). Confirm this convention is intended.

    :param theta `Tensor([N...,] 1) | float`: rotation angles around y axis
    :param phi `Tensor([N...,] 1) | float`: rotation angles around x axis
    :return `Tensor([N...,] 3, 3)`: matrices whose rows are (right, up, forward).
        A Python float argument is treated as a `Tensor(1)`. Tensor arguments are
        broadcast against each other.
    """
    # Place scalar angles on the same device as a tensor argument (if any) so they can be
    # combined with it below.
    device = next((a.device for a in (theta, phi) if isinstance(a, torch.Tensor)), None)
    if not isinstance(theta, torch.Tensor):
        theta = torch.tensor([theta], device=device)
    if not isinstance(phi, torch.Tensor):
        phi = torch.tensor([phi], device=device)
    theta, phi = torch.broadcast_tensors(theta, phi)
    spher = torch.cat([torch.ones_like(theta), theta, phi], dim=-1)
    forward = spherical2cartesian(spher)  # ([N...,] 3), unit length because r == 1
    # Match forward's dtype and device (a bare torch.tensor would always be on the CPU).
    up = forward.new_tensor([0.0, 1.0, 0.0]).expand_as(forward)
    right = torch.cross(forward, up, dim=-1)  # ([N...,] 3)
    # |forward x y| == cos(phi), so normalise to keep the row unit length. The clamp avoids
    # 0/0 when forward is parallel to the y axis.
    right = right / right.norm(dim=-1, keepdim=True).clamp_min(1e-12)
    # right is perpendicular to forward and both are unit vectors, so this is unit length too.
    up = torch.cross(right, forward, dim=-1)  # ([N...,] 3)
    return torch.stack([right, up, forward], dim=-2)  # ([N...,] 3, 3)


def calc_local_dir(dirs, spherical_coords, pts):
    """
    Combine directions with a local frame attached to each point and return the
    resulting (theta, phi) angles.

    The frame is built as follows:
    - z axis: `pts` normalised, i.e. the radial direction.
    - x axis: the normalised finite-difference direction obtained by advancing theta by
      0.1 degree. This assumes `pts == spherical2cartesian(spherical_coords)` — TODO confirm
      with callers.
    - y axis: x cross z.

    NOTE(review): `dirs @ frame` yields d0*x + d1*y + d2*z, which applies the *transpose*
    of the stacked frame. If the intent is to express world directions in the local frame,
    this should probably be `frame @ dirs` instead. Verify before relying on it.

    :param dirs `Tensor(B, 3)`: directions, one per batch item
    :param spherical_coords `Tensor(B, N, 3)`: spherical coordinates (r, theta, phi) of `pts`
    :param pts `Tensor(B, N, 3)`: Cartesian points
    :return `Tensor(B, N, 2)`: (theta, phi) of each transformed direction
    """
    theta_step = torch.tensor([0, math.radians(0.1), 0], device=spherical_coords.device)
    tangent = spherical2cartesian(spherical_coords + theta_step) - pts
    axis_x = tangent / tangent.norm(dim=-1, keepdim=True)
    axis_z = pts / pts.norm(dim=-1, keepdim=True)
    axis_y = torch.cross(axis_x, axis_z, dim=-1)
    frame = torch.stack([axis_x, axis_y, axis_z], dim=-2)  # (B, N, 3, 3)
    transformed = dirs[:, None, None, :] @ frame  # (B, N, 1, 3)
    angles = cartesian2spherical(transformed).squeeze(-2)  # (B, N, 3)
    return angles[..., 1:3]