sphere.py 3.63 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import math
from typing import List, Tuple, Union

import torch

from . import misc


def cartesian2spherical(cart: torch.Tensor, inverse_r: bool = False) -> torch.Tensor:
    """
    Convert coordinates from Cartesian to Spherical

    The radius is the Euclidean length of each vector and phi is the angle between the
    vector and the +y axis (`acos(y / r)`). Theta is whatever `misc.get_angle` returns for
    the x and z components (presumably the azimuth in the xz-plane - confirm against `misc`).
    A zero-length vector yields NaN for phi.

    :param cart `Tensor(..., 3)`: coordinates in Cartesian
    :param inverse_r: if True, the first output component is `1 / r` instead of `r`
    :return `Tensor(..., 3)`: coordinates in Spherical (r or 1 / r, theta, phi)
    """
    x, y, z = cart[..., 0], cart[..., 1], cart[..., 2]
    radius = torch.sqrt(torch.sum(cart * cart, dim=-1))
    theta = misc.get_angle(x, z)
    if not inverse_r:
        cos_phi = y / radius
    else:
        # Keep the reciprocal: it is both the returned radius term and the scale for y
        radius = radius.reciprocal()
        cos_phi = y * radius
    phi = torch.acos(cos_phi)
    return torch.stack([radius, theta, phi], dim=-1)


def spherical2cartesian(spher: torch.Tensor) -> torch.Tensor:
    """
    Convert coordinates from Spherical to Cartesian

    Uses the convention `x = r cos(theta) sin(phi)`, `y = r cos(phi)`,
    `z = r sin(theta) sin(phi)`, i.e. phi is measured from the +y axis.

    :param spher `Tensor(..., 3)`: coordinates in Spherical (r, theta, phi)
    :return `Tensor(..., 3)`: coordinates in Cartesian (x, y, z)
    """
    radius = spher[..., 0]
    # Evaluate sin/cos once over the (theta, phi) pair, then split per angle
    angles = spher[..., 1:3]
    sin_theta, sin_phi = torch.sin(angles).unbind(dim=-1)
    cos_theta, cos_phi = torch.cos(angles).unbind(dim=-1)
    return torch.stack([
        radius * cos_theta * sin_phi,
        radius * cos_phi,
        radius * sin_theta * sin_phi,
    ], dim=-1)


def ray_sphere_intersect(p: torch.Tensor, v: torch.Tensor,
                         r: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Calculate intersections of each rays and each spheres

    All spheres are centered at the origin. For a ray `p + t * v` the returned depth is the
    larger root `t` of `|p + t * v|^2 = r^2` (the far intersection). A ray that misses a
    sphere has a negative discriminant and yields NaN; a zero-length direction divides by zero.

    :param p `Tensor(..., 3)`: positions of rays
    :param v `Tensor(..., 3)`: directions of rays
    :param r `Tensor(N)`: radius of spheres
    :return `Tensor(..., N, 3)`: points of intersection
    :return `Tensor(..., N)`: depths of intersection along ray
    """
    # p, v: insert a sphere axis before the coordinate axis, (..., 3) -> (..., 1, 3)
    p = p.unsqueeze(-2)
    v = v.unsqueeze(-2)
    # pp, vv, pv: (..., 1), broadcast against r: (N) below
    pp = (p * p).sum(dim=-1)
    vv = (v * v).sum(dim=-1)
    pv = (p * v).sum(dim=-1)
    # Larger root of vv * t^2 + 2 * pv * t + (pp - r^2) = 0
    depths = ((pv * pv - vv * (pp - r * r)).sqrt() - pv) / vv
    return p + depths[..., None] * v, depths


def get_rot_matrix(theta: Union[float, torch.Tensor], phi: Union[float, torch.Tensor]) -> torch.Tensor:
    """
    Get rotation matrix from angles in spherical space

    The rows of each returned matrix are the right, up and forward vectors, where forward
    is the unit direction given by (theta, phi), right is `forward x (0, 1, 0)` and up is
    `right x forward`.

    NOTE: right and up are not re-normalized, so their length is `|sin(phi)|`; the result is
    orthonormal only when forward is perpendicular to the y axis.

    :param theta `Tensor(..., 1) | float`: rotation angles around y axis
    :param phi  `Tensor(..., 1) | float`: rotation angles around x axis
    :return: `Tensor(..., 3, 3)` rotation matrices
    """
    # Plain numbers follow the device of whichever argument already is a tensor (default
    # device when both are numbers), so the concatenation below never mixes devices
    device = next((a.device for a in (theta, phi) if isinstance(a, torch.Tensor)), None)
    if not isinstance(theta, torch.Tensor):
        theta = torch.tensor([theta], device=device)
    if not isinstance(phi, torch.Tensor):
        phi = torch.tensor([phi], device=device)
    # Lets a single angle be combined with a batch of angles; no-op for equal shapes
    theta, phi = torch.broadcast_tensors(theta, phi)
    spher = torch.cat([torch.ones_like(theta), theta, phi], dim=-1)
    forward = spherical2cartesian(spher)  # (..., 3)
    # Build the world up axis to match forward so torch.cross sees one dtype and device
    world_up = torch.tensor([0.0, 1.0, 0.0], dtype=forward.dtype, device=forward.device)
    forward, world_up = torch.broadcast_tensors(forward, world_up)
    right = torch.cross(forward, world_up, dim=-1)  # (..., 3)
    up = torch.cross(right, forward, dim=-1)  # (..., 3)
    return torch.stack([right, up, forward], dim=-2)  # (..., 3, 3)


def calc_local_dir(dirs, spherical_coords, pts):
    """
    Express directions as (theta, phi) angles using a per-point local frame

    For every point a 3x3 matrix is stacked from three row vectors: local_x, the normalized
    offset from `pts` to the position obtained by adding 0.1 degrees to the theta component
    of `spherical_coords`; local_y, the cross product `local_x x local_z`; and local_z, the
    normalized `pts`. `dirs` is multiplied as a row vector with that matrix and the result
    is converted to spherical coordinates, of which only theta and phi are returned.

    :param dirs `Tensor(B, 3)`: directions, shared by all N points of a batch entry
    :param spherical_coords `Tensor(B, N, 3)`: spherical coordinates (r, theta, phi);
        presumably those of `pts` - confirm against callers
    :param pts `Tensor(B, N, 3)`: Cartesian positions
    :return `Tensor(B, N, 2)`: (theta, phi) of the transformed directions
    """
    unit = lambda vec: vec / vec.norm(dim=-1, keepdim=True)

    local_z = unit(pts)
    # Small step in theta only, used to find a tangent direction at each point
    theta_step = torch.tensor([0, math.radians(0.1), 0], device=spherical_coords.device)
    local_x = unit(spherical2cartesian(spherical_coords + theta_step) - pts)
    local_y = torch.cross(local_x, local_z, dim=-1)
    local_rot = torch.stack([local_x, local_y, local_z], dim=-2)  # (B, N, 3, 3)

    row_dirs = dirs[:, None, None, :]  # (B, 1, 1, 3)
    local_spher = cartesian2spherical(torch.matmul(row_dirs, local_rot))  # (B, N, 1, 3)
    return local_spher.squeeze(-2)[..., 1:3]