sphere.py 3.7 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from typing import List, Union
import torch
import math
from . import misc


def cartesian2spherical(cart: torch.Tensor, inverse_r: bool = False) -> torch.Tensor:
    """
    Convert coordinates from Cartesian to Spherical

    The polar angle phi is measured from the +y axis (phi = acos(y / r)). The azimuth theta is
    computed by `misc.get_angle(x, z)`. Presumably this is the angle in the x-z plane; verify
    against `misc.get_angle`.

    :param cart `Tensor(..., 3)`: coordinates in Cartesian (x, y, z)
    :param inverse_r: if True, the first output component is 1 / r instead of r
    :return `Tensor(..., 3)`: coordinates in Spherical (r or 1/r, theta, phi)
    """
    radius = (cart * cart).sum(dim=-1).sqrt()
    theta = misc.get_angle(cart[..., 0], cart[..., 2])
    if not inverse_r:
        return torch.stack([radius, theta, torch.acos(cart[..., 1] / radius)], dim=-1)
    # Reuse the reciprocal both as the output component and to normalize y.
    inv_radius = radius.reciprocal()
    return torch.stack([inv_radius, theta, torch.acos(cart[..., 1] * inv_radius)], dim=-1)


Nianchen Deng's avatar
Nianchen Deng committed
25
def spherical2cartesian(spher: torch.Tensor, inverse_r: bool = False) -> torch.Tensor:
    """
    Convert coordinates from Spherical to Cartesian

    The +y axis is the polar axis: phi is measured from +y, and theta sweeps the x-z plane
    starting at +x (theta = 0) toward +z (theta = pi / 2).

    :param spher `Tensor(..., 3)`: coordinates in Spherical (r, theta, phi)
    :param inverse_r: if True, the first component of `spher` holds 1 / r instead of r
    :return `Tensor(..., 3)`: coordinates in Cartesian (x, y, z)
    """
    r = spher[..., 0].reciprocal() if inverse_r else spher[..., 0]
    theta, phi = spher[..., 1], spher[..., 2]
    sin_phi = torch.sin(phi)
    x = r * torch.cos(theta) * sin_phi
    y = r * torch.cos(phi)
    z = r * torch.sin(theta) * sin_phi
    return torch.stack((x, y, z), dim=-1)


def ray_sphere_intersect(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
    """
    Intersect each ray with each sphere centered at the origin

    Solves |p + t v|^2 = r^2 for t and keeps the larger root. For a ray starting inside a sphere,
    this is the hit in front of the origin. Rays that miss a sphere produce NaN, because the
    discriminant under the square root is negative. `v` need not be normalized; depths are then
    measured in units of |v|.

    :param p `Tensor(B, 3)`: ray origins
    :param v `Tensor(B, 3)`: ray directions
    :param r `Tensor(N)`: sphere radii
    :return `Tensor(B, N, 3)`: intersection points
    :return `Tensor(B, N)`: depth t of each intersection along its ray
    """
    origins = p[:, None, :]  # (B, 1, 3)
    dirs = v[:, None, :]  # (B, 1, 3)
    o_dot_o = torch.sum(origins * origins, dim=2)  # (B, 1)
    d_dot_d = torch.sum(dirs * dirs, dim=2)  # (B, 1)
    o_dot_d = torch.sum(origins * dirs, dim=2)  # (B, 1)
    # Quarter discriminant of  d.d t^2 + 2 o.d t + (o.o - r^2) = 0; broadcasts to (B, N)
    disc = o_dot_d * o_dot_d - d_dot_d * (o_dot_o - r * r)
    t = (torch.sqrt(disc) - o_dot_d) / d_dot_d
    return origins + t.unsqueeze(-1) * dirs, t


def get_rot_matrix(theta: Union[float, torch.Tensor], phi: Union[float, torch.Tensor]) -> torch.Tensor:
    """
    Get rotation matrix from angles in spherical space

    The rows of the result are the orthonormal basis (right, up, forward):
    - `forward` is the unit vector at spherical angles (theta, phi), as given by
      `spherical2cartesian`.
    - `right` is forward x (+y), normalized.
    - `up` is right x forward.

    When forward is parallel to +y (phi = 0 or pi), right and up come out as zero vectors.

    :param theta `Tensor(..., 1) | float`: rotation angles around y axis
    :param phi `Tensor(..., 1) | float`: rotation angles around x axis
    :return `Tensor(..., 3, 3)`: rotation matrices. A Python float is treated as a tensor of
        shape (1,), and theta and phi are broadcast against each other.
    """
    # Create float arguments on the other argument's device so torch.cat does not mix devices.
    if isinstance(theta, torch.Tensor):
        device = theta.device
    elif isinstance(phi, torch.Tensor):
        device = phi.device
    else:
        device = None
    if not isinstance(theta, torch.Tensor):
        theta = torch.tensor([theta], device=device)
    if not isinstance(phi, torch.Tensor):
        phi = torch.tensor([phi], device=device)
    theta, phi = torch.broadcast_tensors(theta, phi)
    spher = torch.cat([torch.ones_like(theta), theta, phi], dim=-1)
    forward = spherical2cartesian(spher)  # (..., 3), unit length since r = 1
    # The world up vector must share forward's dtype and device; otherwise torch.cross fails
    # for float64 or non-CPU inputs.
    world_up = forward.new_tensor([0.0, 1.0, 0.0]).expand_as(forward)
    # |forward x world_up| = |sin(phi)|, so normalize to keep the matrix orthonormal.
    # F.normalize divides by max(norm, eps), so the degenerate case yields zeros rather than NaN.
    right = torch.nn.functional.normalize(torch.cross(forward, world_up, dim=-1), dim=-1)
    # Both factors are unit length and orthogonal, so up is already unit length.
    up = torch.cross(right, forward, dim=-1)  # (..., 3)
    return torch.stack([right, up, forward], dim=-2)  # (..., 3, 3)


def calc_local_dir(dirs, spherical_coords, pts):
    """
    Re-express directions in a per-point local frame and return their (theta, phi) angles

    The local frame at each point consists of:
    - z: the point's position, normalized (radial direction);
    - x: the unit vector from the point toward the point whose theta is 0.1 degrees larger. This
      is a finite-difference approximation of the theta tangent.
    - y: x cross z.

    :param dirs `Tensor(B, 3)`: directions, one per batch item (shared by its N points)
    :param spherical_coords `Tensor(B, N, 3)`: presumably the spherical coordinates
        (r, theta, phi) of `pts`. TODO: confirm with callers.
    :param pts `Tensor(B, N, 3)`: Cartesian points
    :return `Tensor(B, N, 2)`: (theta, phi) components of `cartesian2spherical` applied to
        the transformed directions
    """
    theta_nudge = torch.tensor([0, math.radians(0.1), 0], device=spherical_coords.device)
    axis_z = pts / pts.norm(dim=-1, keepdim=True)
    axis_x = spherical2cartesian(spherical_coords + theta_nudge) - pts
    axis_x = axis_x / axis_x.norm(dim=-1, keepdim=True)
    axis_y = torch.cross(axis_x, axis_z, -1)
    frame = torch.stack([axis_x, axis_y, axis_z], dim=-2)  # (B, N, 3, 3), rows = local axes
    # NOTE(review): the row vector d multiplies `frame` on the left. This gives
    # d_x * axis_x + d_y * axis_y + d_z * axis_z, i.e. local -> world. Projecting d onto the
    # local axes would instead be d @ frame.transpose(-1, -2). Confirm which one is intended.
    transformed = torch.matmul(dirs[:, None, None, :], frame)  # (B, N, 1, 3)
    return cartesian2spherical(transformed).squeeze(-2)[..., 1:3]