import re
import csv
import glm
import torch
import numpy as np
from pathlib import Path
from typing import SupportsFloat
from itertools import repeat

from . import math
from .types import *
from .device import *

gvec_type = [glm.dvec1, glm.dvec2, glm.dvec3, glm.dvec4]
gmat_type = [[glm.dmat2, glm.dmat2x3, glm.dmat2x4],
             [glm.dmat3x2, glm.dmat3, glm.dmat3x4],
             [glm.dmat4x2, glm.dmat4x3, glm.dmat4]]


def smooth_step(x0, x1, x):
    # Hermite interpolation between x0 and x1, clamped to [0, 1]
    y = torch.clamp((x - x0) / (x1 - x0), 0, 1)
    return y * y * (3 - 2 * y)


def torch2np(input: torch.Tensor) -> np.ndarray:
    return input.cpu().detach().numpy()


def torch2glm(input):
    input = input.squeeze()
    size = input.size()
    if len(size) == 1:
        if size[0] <= 0 or size[0] > 4:
            raise ValueError("Vector length must be between 1 and 4")
        return gvec_type[size[0] - 1](torch2np(input))
    if len(size) == 2:
        if size[0] <= 1 or size[0] > 4 or size[1] <= 1 or size[1] > 4:
            raise ValueError("Matrix dimensions must be between 2 and 4")
        return gmat_type[size[1] - 2][size[0] - 2](torch2np(input))
    raise ValueError("Tensor must be 1D or 2D after squeezing")


def glm2torch(val) -> torch.Tensor:
    return torch.from_numpy(np.array(val))


def grid2d(rows: int, cols: int | None = None, normalize: bool = False, indexing: str = "xy",
           device: torch.device | None = None) -> torch.Tensor:
    """
    Generate a 2D grid

    :param rows `int`: number of rows
    :param cols `int`: number of columns, defaults to `rows` if not specified
    :param normalize `bool`: whether to return coordinates in normalized space, defaults to False
    :param indexing `str`: order of the returned coordinates, either "xy" or "ij", defaults to "xy"
    :param device `torch.device`: device to create the grid on, defaults to None
    :return `Tensor(R, C, 2)`: the coordinates of the grid
    """
    if cols is None:
        cols = rows
    i, j = torch.meshgrid(torch.arange(rows, device=device),
                          torch.arange(cols, device=device),
                          indexing="ij")  # (R, C)
    if normalize:
        # Out-of-place division: the integer, memory-overlapping views returned by
        # meshgrid do not support in-place float division
        i = i / (rows - 1)
        j = j / (cols - 1)
    return torch.stack([j, i] if indexing == "xy" else [i, j], 2)  # (R, C, 2)


def get_angle(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Angle of the vector (x, y), computed via atan with a quadrant correction for y < 0
    angle = -torch.atan(x / y) - (y < 0) * math.pi + 0.5 * math.pi
    return angle


def format_time(seconds):
    days = int(seconds / 3600 / 24)
    seconds = seconds - days * 3600 * 24
    hours = int(seconds / 3600)
    seconds = seconds - hours * 3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes * 60
    seconds_final = int(seconds)
    seconds = seconds - seconds_final
    millis = int(seconds * 1000)
    if days > 0:
        output = f"{days}D{hours:0>2d}h{minutes:0>2d}m"
    elif hours > 0:
        output = f"{hours:0>2d}h{minutes:0>2d}m{seconds_final:0>2d}s"
    elif minutes > 0:
        output = f"{minutes:0>2d}m{seconds_final:0>2d}s"
    elif seconds_final > 0:
        output = f"{seconds_final:0>2d}s{millis:0>3d}ms"
    elif millis > 0:
        output = f"{millis:0>3d}ms"
    else:
        output = "0ms"
    return output
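# Illustrative usage sketch (not part of the module's API); the expected values
# below follow directly from the definitions of `grid2d` and `format_time` above:
#
#   >>> g = grid2d(3, normalize=True)  # Tensor(3, 3, 2), "xy" order, coords in [0, 1]
#   >>> g[0, 2]                        # top-right corner: (col/2, row/2) = (1, 0)
#   tensor([1., 0.])
#   >>> format_time(3725.5)
#   '01h02m05s'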
def masked_scatter(mask: torch.Tensor, value: torch.Tensor,
                   initial: torch.Tensor | SupportsFloat = 0):
    """
    Extend PyTorch's built-in `masked_scatter` function

    :param mask `Tensor(M...)`: the boolean mask
    :param value `Tensor(N, D...)`: the values to fill in with; must have at least as many
        elements as there are ones in `mask`
    :param initial `Tensor(M..., D...)|Number`: the initial values, either a tensor or a
        number. If given as a number, a new tensor filled with that number is created as
        the initial values. Defaults to 0
    :return `Tensor(M..., D...)`: the result tensor
    """
    M_ = mask.size()
    D_ = value.size()[1:]
    if not isinstance(initial, torch.Tensor):
        initial = value.new_full([*M_, *D_], initial)
    # Append singleton dims so the mask broadcasts over the value dims D...
    return initial.masked_scatter(mask.reshape(*M_, *repeat(1, len(D_))), value)


def rename_seqs_with_offset(dir: Path, file_pattern: str, offset: int):
    start, end = re.search(r'%0\dd', file_pattern).span()
    prefix, suffix = start, len(file_pattern) - end
    seqs = [
        # Slice by explicit end index; `[prefix:-suffix]` breaks when suffix is 0
        int(path.name[prefix:len(path.name) - suffix])
        for path in dir.glob(re.sub(r'%0\dd', "*", file_pattern))
    ]
    # Rename in descending order when shifting up so targets are not overwritten
    seqs.sort(reverse=offset > 0)
    for i in seqs:
        (dir / (file_pattern % i)).rename(dir / (file_pattern % (i + offset)))


def calculate_autosize(max_size: int, *sizes: int) -> tuple[list[int], int]:
    sizes = list(sizes)
    sum_size = sum(sizes)
    for i in range(len(sizes)):
        if sizes[i] == -1:
            # The -1 entry absorbs whatever space remains in max_size
            sizes[i] = max_size - sum_size - 1
            sum_size = max_size
            break
    if sum_size > max_size:
        raise ValueError("The sum of 'sizes' exceeds 'max_size'")
    if any(size < 0 for size in sizes):
        raise ValueError("At most one of 'sizes' may be -1; all others must be non-negative")
    return sizes, sum_size


def union(*tensors: torch.Tensor | SupportsFloat) -> torch.Tensor:
    try:
        first_tensor = next(item for item in tensors if isinstance(item, torch.Tensor))
    except StopIteration:
        raise ValueError("Arguments should contain at least one tensor")
    tensors = [
        item if isinstance(item, torch.Tensor) else first_tensor.new_tensor([item])
        for item in tensors
    ]
    if any(item.device != first_tensor.device or item.dtype != first_tensor.dtype
           for item in tensors):
        raise ValueError("All tensors must have the same dtype and reside on the same device")
    shape = torch.broadcast_shapes(*(item.shape[:-1] for item in tensors))
    return torch.cat([item.expand(*shape, -1) for item in tensors], dim=-1)


def split(tensor: torch.Tensor, *sizes: int) -> tuple[torch.Tensor, ...]:
    sizes, tot_size = calculate_autosize(tensor.shape[-1], *sizes)
    if tot_size < tensor.shape[-1]:
        # Split off the unclaimed remainder, then drop it
        sizes = [*sizes, tensor.shape[-1] - tot_size]
        return torch.split(tensor, sizes, -1)[:-1]
    else:
        return torch.split(tensor, sizes, -1)


def dump_tensors_to_csv(path, *tensors: torch.Tensor, open_mode="w"):
    data = []
    for tensor in tensors:
        # Flatten every tensor to 2D so each one contributes a block of columns
        if len(tensor.shape) == 1:
            tensor = tensor[:, None]
        elif len(tensor.shape) > 2:
            tensor = tensor.flatten(1, -1)
        tensor_data = tensor.tolist()
        if not data:
            data = [
                [
                    f"{value:.6e}" if isinstance(value, float) else f"{value}"
                    for value in tensor_data[i]
                ]
                for i in range(len(tensor_data))
            ]
        else:
            for i, row in enumerate(data):
                row += [
                    f"{value:.6e}" if isinstance(value, float) else f"{value}"
                    for value in tensor_data[i]
                ]
    with open(path, open_mode) as fp:
        csv_writer = csv.writer(fp)
        csv_writer.writerows(data)
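
# Minimal smoke test sketching the intended behavior of the tensor helpers above.
# Purely illustrative; the expected values are derived from the definitions in this file.
if __name__ == "__main__":
    mask = torch.tensor([True, False, True])
    value = torch.tensor([[1., 2.], [3., 4.]])
    # Rows where mask is True receive rows of `value` in order; others keep the initial 0
    assert masked_scatter(mask, value).tolist() == [[1., 2.], [0., 0.], [3., 4.]]

    # `union` promotes scalars to broadcastable tensors and concatenates along the last dim
    assert union(torch.zeros(2, 3), 1.0).shape == (2, 4)

    # `split` with -1 lets one chunk absorb the remaining channels
    a, b, c = split(torch.arange(10.), 3, -1, 2)
    assert (a.shape[-1], b.shape[-1], c.shape[-1]) == (3, 5, 2)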