Commit 1bc644a1 authored by Nianchen Deng's avatar Nianchen Deng
Browse files

sync

parent 6294701e
from operator import countOf
from types import UnionType
from typing_extensions import Self
from configargparse import ArgumentParser, Namespace
from .types import *
class BaseArgs(Namespace):
@property
def defaults(self) -> dict[str, Any]:
return {
key: getattr(self.__class__, key)
for key in self.__annotations__ if hasattr(self.__class__, key)
}
def __init__(self, **kwargs) -> None:
super().__init__(**self.defaults | kwargs)
def merge_with(self, dict: dict[str, Any]) -> Self:
return self.__class__(**vars(self) | dict)
def parse(self, config_path: PathLike = None, debug: bool = False) -> Self:
parser = ArgumentParser(default_config_files=[f"{config_path}"] if config_path else [])
self.setup_parser(parser, debug)
return parser.parse_known_args(namespace=self)[0]
def setup_parser(self, parser: ArgumentParser, debug: bool = False):
def build_debug_str(key: str, params_for_parser: dict[str, Any], prefix="parser") -> str:
def to_str(value): return value.__name__ if isinstance(value, Type) else (
f"\"{value}\"" if isinstance(value, str) else value.__str__())
params_str = ", ".join([
f"{name}={to_str(value)}" for name, value in params_for_parser.items()
])
return f"{prefix}.add_argument(\"--{key}\", {params_str})"
def add_argument(parser: ArgumentParser, key: str, type: Type, required: bool, **kwargs):
params = {}
if type == bool:
bool_group = parser.add_mutually_exclusive_group()
bool_group.add_argument(f"--{key}", action="store_true")
bool_group.add_argument(f"--no-{key}", action="store_false", dest=key)
if debug:
print("bool_group = parser.add_mutually_exclusive_group()")
print(build_debug_str(key, {"action": "store_true"}, "bool_group"))
print(build_debug_str(f"no-{key}", {"action": "store_false", "dest": key},
"bool_group"))
else:
params["type"] = type
if "nargs" in kwargs:
params["nargs"] = kwargs["nargs"]
if "default" in kwargs:
params["default"] = kwargs["default"]
elif required:
params["required"] = True
parser.add_argument(f"--{key}", **params)
if debug:
print(build_debug_str(key, params))
for key, arg_type in self.__annotations__.items():
required = True
kwargs = {}
if isinstance(arg_type, UnionType):
if len(arg_type.__args__) != 2 or countOf(arg_type.__args__, type(None)) != 1:
raise ValueError(f"{key} cannot be union of two or more different types")
arg_type = arg_type.__args__[0] if arg_type.__args__[1] == type(None) \
else arg_type.__args__[1]
required = False
if getattr(arg_type, "__origin__", None) == list:
arg_type = arg_type.__args__[0]
kwargs["nargs"] = "*"
elif getattr(arg_type, "__origin__", None) == tuple:
arg_type = arg_type.__args__[0]
if any([arg != arg_type for arg in arg_type.__args__]):
raise ValueError(f"{key} cannot be tuple of different types")
kwargs["nargs"] = len(arg_type.__args__)
if hasattr(self, key):
kwargs["default"] = getattr(self, key)
add_argument(parser, key, arg_type, required, **kwargs)
......@@ -257,7 +257,7 @@ def read_points3d_binary(path_to_model_file):
return points3D
def read_model(path, ext):
def read_model(path, ext) -> tuple[dict[int, Camera], dict[int, Image], dict[int, Point3D]]:
if ext == ".txt":
cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
images = read_images_text(os.path.join(path, "images" + ext))
......
import json5
from .types import *
__all__ = ["load_from_json", "get_type_and_args"]
def load_from_json(json_path: PathLike) -> dict[str, Any]:
try:
with Path(json_path).open() as fp:
config: dict[str, Any] = json5.load(fp)
except Exception:
raise ValueError(f"{json_path} is not a valid json file")
if "parent" in config:
parent_config = load_from_json((json_path.parent / config.pop("parent")).with_suffix(".json"))
config["model"][1] = parent_config["model"][1] | config["model"][1]
config["train"][1] = parent_config["train"][1] | config["train"][1]
return config
def get_type_and_args(config_item: dict[str, Any] | list | str, default_type=None, default_args={}) -> tuple[str, dict[str, Any]]:
match config_item:
case None:
return default_type, default_args
case str() as type_name:
return type_name, default_args
case dict() as args:
return default_type, default_args | args
case str() as type_name, dict() as args:
return type_name, default_args | args
case _:
raise ValueError("\"config_item\" is invalid")
env = None
env = {}
def get_env():
return env
def get(key: str):
return env.get(key)
def get_all() -> dict:
return env
def set_env(new_env: dict):
def set(**kwargs):
global env
env = new_env
env |= kwargs
import torch
import itertools
from os import PathLike
from typing import List, Dict
from .nn import Module
class ModelExporter(object):
inputs: dict[str, list[int]]
output_names: list[str]
module: Module
@property
def input_names(self) -> list[str]:
return list(self.inputs.keys())
@property
def module(self) -> Module:
return self.fn.__self__
def __init__(self, fn, *outputs: str, **inputs: list[int]) -> None:
super().__init__()
self.inputs = inputs
self.output_names = list(outputs)
self.fn = fn
def prepare_inputs(self, batch_size: int = None):
return tuple(
torch.rand(batch_size or 1, *size, device=self.module.device)
for size in self.inputs.values()
)
def export_onnx(self, path: PathLike, batch_size: int = None, **kwargs):
dynamic_axes = {
name: {0: "batch_size"}
for name in itertools.chain(self.input_names, self.output_names)
} if not batch_size else None
# Replace module's forward method with target method and recover later
self.module.forward = self.fn
kwargs = {
"export_params": True, # store the trained parameter weights inside the model file
"opset_version": 10, # the ONNX version to export the model to
"do_constant_folding": True, # whether to execute constant folding for optimization
**kwargs
}
torch.onnx.export(self.module, self.prepare_inputs(batch_size), path,
input_names=self.input_names, # the model's input names
output_names=self.output_names, # the model's output names
dynamic_axes=dynamic_axes, # variable length axes
**kwargs)
self.module.forward = self.module._forward
......@@ -3,7 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Union
import numpy as np
import torch
import torch.nn.functional as F
......@@ -215,7 +214,7 @@ def pruning_points(feats, points, scores, depth=0, th=0.5):
return feats, points
def offset_points(point_xyz: torch.Tensor, half_voxel: Union[torch.Tensor, int, float] = 1,
def offset_points(point_xyz: torch.Tensor, half_voxel: torch.Tensor | int | float = 1,
offset_only: bool = False, bits: int = 2) -> torch.Tensor:
"""
[summary]
......
import os
from pathlib import Path
import shutil
import torch
import uuid
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as nn_f
from typing import List, Tuple, Union
from . import misc
from . import math
from . import misc, math
from .types import *
def is_image_file(filename):
"""
......@@ -19,7 +19,7 @@ def is_image_file(filename):
return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])
def np2torch(img, permute=True):
def np2torch(img: np.ndarray, permute: bool = True) -> torch.Tensor:
"""
Convert numpy-images(s) to torch-image(s), permute channels dim if `permute=True`
......@@ -28,8 +28,7 @@ def np2torch(img, permute=True):
"""
batch_input = len(img.shape) == 4
if permute:
t = torch.from_numpy(np.transpose(
img, [0, 3, 1, 2] if batch_input else [2, 0, 1]))
t = torch.from_numpy(np.transpose(img, [0, 3, 1, 2] if batch_input else [2, 0, 1]))
else:
t = torch.from_numpy(img)
if not batch_input:
......@@ -76,49 +75,51 @@ def load_seq(path: str, n: int, permute=True, with_alpha=False) -> torch.Tensor:
return load([path % i for i in range(n)], permute=permute, with_alpha=with_alpha)
def save(input: torch.Tensor, *paths: Union[str, Path, List[Union[str, Path]]]):
def save(input: torch.Tensor | np.ndarray, *paths: PathLike | list[PathLike]):
"""
Save one or multiple torch-image(s) to `paths`
:param input `torch.Tensor`: torch-image(s) to save
:param input `Tensor|ndarray`: torch-image(s) to save
:param *paths `str...`: paths to save torch-image(s) to
:raises `ValueError`: if number of paths does not match batches of input image(s)
"""
new_paths = []
for path in paths:
new_paths += [path] if isinstance(path, (str, Path)) else list(path)
if len(input.size()) < 4:
if len(input.shape) < 4:
input = input[None]
if input.size(0) != len(new_paths):
if input.shape[0] != len(new_paths):
raise ValueError
np_img = torch2np(input)
np_img = torch2np(input) if isinstance(input, torch.Tensor) else input
if np_img.dtype.kind == 'f':
np_img = np.clip(np_img, 0, 1)
if np_img.shape[-1] == 1:
np_img = np.repeat(np_img, 3, axis=-1)
if not np_img.flags['C_CONTIGUOUS']:
np_img = np.ascontiguousarray(np_img)
for i, path in enumerate(new_paths):
plt.imsave(path, np_img[i])
def save_seq(input: torch.Tensor, path: Union[str, Path]):
def save_seq(input: torch.Tensor, path: str | Path):
n = 1 if len(input.size()) <= 3 else input.size(0)
return save(input, [str(path) % i for i in range(n)])
def plot(input: torch.Tensor, *, ax: plt.Axes = None):
def plot(input: torch.Tensor | np.ndarray, *, ax: plt.Axes = None):
"""
Plot a torch-image using matplotlib
:param input `Tensor(HW|[B]CHW|[B]HWC)`: 2D, 3D or 4D torch-image(s)
:param ax `plt.Axes`: (optional) specify the axes to plot image
"""
im = torch2np(input)
im = torch2np(input) if isinstance(input, torch.Tensor) else input
if len(im.shape) == 4:
im = im[0]
return plt.imshow(im) if ax is None else ax.imshow(im)
def save_video(frames: torch.Tensor, path: Union[str, Path], fps: int,
def save_video(frames: torch.Tensor, path: str | Path, fps: int,
repeat: int = 1, pingpong: bool = False):
"""
Encode and save a sequence of frames as video file
......@@ -136,14 +137,14 @@ def save_video(frames: torch.Tensor, path: Union[str, Path], fps: int,
frames = frames.expand(repeat, -1, -1, -1, -1).flatten(0, 1)
path = Path(path)
tempdir = Path('/dev/shm/dvs_tmp/video')
inferout = tempdir / path.stem / f"%04d.bmp"
os.makedirs(inferout.parent, exist_ok=True)
os.makedirs(path.parent, exist_ok=True)
tempdir = Path(f'/dev/shm/dvs_tmp/video/{uuid.uuid4().hex}')
temp_frame_files = tempdir / f"%04d.bmp"
path.parent.mkdir(parents=True, exist_ok=True)
tempdir.mkdir(parents=True, exist_ok=True)
save_seq(frames, inferout)
os.system(f'ffmpeg -y -r {fps:d} -i {inferout} -c:v libx264 {path}')
shutil.rmtree(inferout.parent)
save_seq(frames, temp_frame_files)
os.system(f'ffmpeg -y -r {fps:d} -i {temp_frame_files} -c:v libx264 {path}')
shutil.rmtree(tempdir)
def horizontal_shift(input: torch.Tensor, offset: int, dim=-1) -> torch.Tensor:
......@@ -165,7 +166,7 @@ def horizontal_shift(input: torch.Tensor, offset: int, dim=-1) -> torch.Tensor:
return shifted
def translate(input: torch.Tensor, offset: Tuple[float, float]) -> torch.Tensor:
def translate(input: torch.Tensor, offset: tuple[float, float]) -> torch.Tensor:
theta = torch.tensor([
[1, 0, -offset[0] / input.size(-1) * 2],
[0, 1, -offset[1] / input.size(-2) * 2]
......
......@@ -62,7 +62,7 @@ def input_ex(prompt, *actions, default=None):
return s
def input_enum(prompt, complete_list: List[str], *, err_msg: str, default=None):
def input_enum(prompt, complete_list: list[str], *, err_msg: str, default=None):
readline.set_completer(make_completer(complete_list))
prompt_default = '(Default: %s) ' % default if default != None else ''
while True:
......
import sys
from logging import *
from pathlib import Path
enable_logging = False
def _log_exception(exc_type, exc_value, exc_traceback):
if not issubclass(exc_type, KeyboardInterrupt):
exception(exc_value, exc_info=(exc_type, exc_value, exc_traceback))
sys.__excepthook__(exc_type, exc_value, exc_traceback)
def initialize(path: Path):
global enable_logging
basicConfig(format='%(asctime)s[%(levelname)s] %(message)s', level=INFO,
filename=path, filemode='a' if path.exists() else 'w')
sys.excepthook = _log_exception
enable_logging = True
def print_and_log(msg: str):
print(msg)
if enable_logging:
info(msg)
......@@ -4,7 +4,7 @@ from torch import nn
class CombinedLoss(nn.Module):
def __init__(self, loss_modules: List[nn.Module], weights: List[float]):
def __init__(self, loss_modules: list[nn.Module], weights: list[float]):
super().__init__()
self.loss_modules = nn.ModuleList(loss_modules)
self.weights = weights
......
from torch.nn import L1Loss, MSELoss
from torch.nn.functional import l1_loss, mse_loss
from .ssim import SSIM
from .ssim import ssim, SSIM
from .perc_loss import VGGPerceptualLoss
from .cauchy import cauchy_loss, CauchyLoss
\ No newline at end of file
from .cauchy import cauchy_loss, CauchyLoss
from .lpips import lpips_loss, LpipsLoss
\ No newline at end of file
import torch
def cauchy_loss(input: torch.Tensor, target: torch.Tensor = None, *, s = 1.0):
def cauchy_loss(input: torch.Tensor, target: torch.Tensor = None, *, s=1.0, sum=False):
x = input - target if target is not None else input
return (s * x * x * 0.5 + 1).log().mean()
y = (s * x * x * 0.5 + 1).log()
return y.sum() if sum else y.mean()
class CauchyLoss(torch.nn.Module):
def __init__(self, s = 1.0):
def __init__(self, s=1.0, sum=False):
super().__init__()
self.s = s
self.sum = sum
def forward(self, input: torch.Tensor, target: torch.Tensor = None):
return cauchy_loss(input, target, s=self.s)
return cauchy_loss(input, target, s=self.s, sum=self.sum)
import torch
from lpips import LPIPS
class LpipsLoss(torch.nn.Module):
def __init__(self, net: str = "alex") -> None:
super().__init__()
self.fn = LPIPS(net)
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return self.fn(input * 2. - 1., target * 2. - 1.)
default_loss_fn = None
def lpips_loss(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
global default_loss_fn
if default_loss_fn is None:
default_loss_fn = LpipsLoss().to(input.device)
return default_loss_fn(input, target)
import numpy as np
import torch
from typing import TypeVar
from math import *
TensorType = TypeVar('TensorType', torch.Tensor, np.ndarray)
huge = 1e10
tiny = 1e-6
tiny = 1e-5
def expected_sin(x: torch.Tensor, x_var: torch.Tensor):
......@@ -93,3 +97,12 @@ def cylinder_to_gaussian(d: torch.Tensor, t0: float, t1: float, radius: float, d
r_var = radius**2 / 4
t_var = (t1 - t0)**2 / 12
return lift_gaussian(d, t_mean, t_var, r_var, diag)
def lerp(t, range):
return t * (range[1] - range[0]) + range[0]
def normalize(value: TensorType) -> TensorType:
if isinstance(value, torch.Tensor):
return value / torch.norm(value, dim=-1, keepdim=True)
else:
return value / np.linalg.norm(value, axis=-1, keepdims=True)
\ No newline at end of file
from cgitb import enable
import torch
from .device import *
def simple_memory_state(device: torch.device = None) -> str:
return f"PyTorch allocates {torch.cuda.memory_allocated(device)/1024/1024:.2f}MB "\
f"(peak is {torch.cuda.max_memory_allocated(device)/1024/1024:.2f}MB) and "\
f"reserves {torch.cuda.memory_reserved(device)/1024/1024:.2f}MB memory"
class MemProfiler:
enable = False
......@@ -20,8 +25,7 @@ class MemProfiler:
else:
delta_str = ''
print(f'{prefix}: {delta_str}currently PyTorch allocates {torch.cuda.memory_allocated(device)/1024/1024:.2f}MB and '
f'reserves {torch.cuda.memory_reserved(device)/1024/1024:.2f}MB memory')
f'reserves {torch.cuda.memory_reserved(device)/1024/1024:.2f}MB memory')
def __init__(self, name, device=None) -> None:
self.name = name
......@@ -33,4 +37,4 @@ class MemProfiler:
return self
def __exit__(self, exc_type, exc_val, exc_traceback):
MemProfiler.print_memory_stats(self.name, self.alloc0, self.device)
\ No newline at end of file
MemProfiler.print_memory_stats(self.name, self.alloc0, self.device)
from itertools import repeat
import logging
from pathlib import Path
import re
import shutil
import torch
import glm
import csv
import numpy as np
from typing import List, Tuple, Union
from torch.types import Number
from typing import SupportsFloat
from itertools import repeat
from . import math
from .types import *
from .device import *
......@@ -46,23 +43,26 @@ def glm2torch(val) -> torch.Tensor:
return torch.from_numpy(np.array(val))
def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False, device: torch.device = None) -> torch.Tensor:
def grid2d(rows: int, cols: int = None, normalize: bool = False, indexing: str = "xy",
device: torch.device = None) -> torch.Tensor:
"""
Generate a mesh grid
:param *size: grid size (rows, columns)
:param normalize: return coords in normalized space? defaults to False
:param swap_dim: if True, return coords in (y, x) order, defaults to False
:return: rows x columns x 2 tensor
Generate a 2D grid
:param rows `int`: number of rows
:param cols `int`: number of columns
:param normalize `bool`: whether return coords in normalized space, defaults to False
:param indexing `str`: specify the order of returned coordinates. Optional values are "xy" and "ij",
defaults to "xy"
:return `Tensor(R, C, 2)`: the coordinates of the grid
"""
if len(size) == 1:
size = (size[0], size[0])
y, x = torch.meshgrid(torch.arange(size[0], device=device),
torch.arange(size[1], device=device))
if cols is None:
cols = rows
i, j = torch.meshgrid(torch.arange(rows, device=device),
torch.arange(cols, device=device), indexing="ij") # (R, C)
if normalize:
x.div_(size[1] - 1.)
y.div_(size[0] - 1.)
return torch.stack([y, x], 2) if swap_dim else torch.stack([x, y], 2)
i.div_(rows - 1)
j.div_(cols - 1)
return torch.stack([j, i] if indexing == "xy" else [i, j], 2) # (R, C, 2)
def get_angle(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
......@@ -70,52 +70,6 @@ def get_angle(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return angle
def broadcast_cat(input: torch.Tensor,
s: Union[Number, List[Number], torch.Tensor],
dim=-1,
append: bool = True) -> torch.Tensor:
"""
Concatenate a tensor with a scalar along last dimension
:param input `Tensor(..., N)`: input tensor
:param s: scalar
:param append: append or prepend the scalar to input tensor
:return: `Tensor(..., N+1)`
"""
if dim != -1:
raise NotImplementedError('currently only support the last dimension')
if isinstance(s, torch.Tensor):
x = s
elif isinstance(s, list):
x = torch.tensor(s, dtype=input.dtype, device=input.device)
else:
x = torch.tensor([s], dtype=input.dtype, device=input.device)
expand_shape = list(input.size())
expand_shape[dim] = -1
x = x.expand(expand_shape)
return torch.cat([input, x] if append else [x, input], dim)
def save_2d_tensor(path, x):
with open(path, 'w', encoding='utf-8', newline='') as f:
csv_writer = csv.writer(f)
for i in range(x.shape[0]):
csv_writer.writerow(x[i])
def view_like(input: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
"""
Reshape input to be the same size as ref except the last dimension
:param input `Tensor(..., C)`: input tensor
:param ref `Tensor(B.., *): reference tensor
:return `Tensor(B.., C)`: reshaped tensor
"""
out_shape = list(ref.size())
out_shape[-1] = -1
return input.view(out_shape)
def format_time(seconds):
days = int(seconds / 3600 / 24)
seconds = seconds - days * 3600 * 24
......@@ -142,24 +96,17 @@ def format_time(seconds):
return output
def print_and_log(s):
print(s)
logging.info(s)
def masked_scatter(mask: torch.Tensor, value: torch.Tensor, initial: Union[torch.Tensor, Number] = 0):
def masked_scatter(mask: torch.Tensor, value: torch.Tensor, initial: torch.Tensor | SupportsFloat = 0):
"""
Extend PyTorch's built-in `masked_scatter` function
:param mask `Tensor(M...)`: the boolean mask
:param value `Tensor(N, D...)`: the value to fill in with, should have at least as many elements
as the number of ones in `mask`
:param destination `Tensor(M..., D...)`: (optional) the destination tensor to fill,
if not specified, a new tensor filled with
`empty_value` will be created and used as destination
:param empty_value `Number`: the initial elements in the newly created destination tensor,
defaults to 0
:return `Tensor(M..., D...)`: the destination tensor after filled
as the number of ones in `mask`
:param initial `Tensor(M..., D...)|Number`: the initial values. Could be a tensor or a number.
If specified by a number, a new tensor filled with the number will be created as the initial values.
Defaults to 0
:return `Tensor(M..., D...)`: the result tensor
"""
M_ = mask.size()
D_ = value.size()[1:]
......@@ -181,28 +128,40 @@ def rename_seqs_with_offset(dir: Path, file_pattern: str, offset: int):
(dir / (file_pattern % i)).rename(dir / (file_pattern % (i + offset)))
def merge(*args: dict, **kwargs) -> dict:
ret_args = {}
for arg in args:
ret_args.update(arg)
ret_args.update(kwargs)
return ret_args
def union(*tensors: torch.Tensor) -> torch.Tensor:
return torch.cat(tensors, dim=-1)
def split(tensor: torch.Tensor, *sizes: int) -> Tuple[torch.Tensor, ...]:
def calculate_autosize(max_size: int, *sizes: int) -> tuple[list[int], int]:
sizes = list(sizes)
tot_size = sum(sizes)
sum_size = sum(sizes)
for i in range(len(sizes)):
if sizes[i] == -1:
sizes[i] = tensor.shape[-1] - tot_size - 1
tot_size = tensor.shape[-1]
sizes[i] = max_size - sum_size - 1
sum_size = max_size
break
if tot_size > tensor.shape[-1]:
raise ValueError("The total number of sizes is larger than the last dim of input tensor")
if sum_size > max_size:
raise ValueError("The sum of 'sizes' exceeds 'max_size'")
if any([size < 0 for size in sizes]):
raise ValueError("Only one element in sizes could be -1")
raise ValueError(
"Only one of the 'sizes' could be -1 and all others must be positive or zero")
return sizes, sum_size
def union(*tensors: torch.Tensor | SupportsFloat) -> torch.Tensor:
try:
first_tensor = next((item for item in tensors if isinstance(item, torch.Tensor)))
except StopIteration:
raise ValueError("Arguments should contain at least one tensor")
tensors = [
item if isinstance(item, torch.Tensor) else first_tensor.new_tensor([item])
for item in tensors
]
if any(item.device != first_tensor.device or item.dtype != first_tensor.dtype
for item in tensors):
raise ValueError("All tensors should have same dtype and locate on same device")
shape = torch.broadcast_shapes(*(item.shape[:-1] for item in tensors))
return torch.cat([item.expand(*shape, -1) for item in tensors], dim=-1)
def split(tensor: torch.Tensor, *sizes: int) -> tuple[torch.Tensor, ...]:
sizes, tot_size = calculate_autosize(tensor.shape[-1], *sizes)
if tot_size < tensor.shape[-1]:
sizes = [*sizes, tensor.shape[-1] - tot_size]
return torch.split(tensor, sizes, -1)[:-1]
......@@ -210,12 +169,28 @@ def split(tensor: torch.Tensor, *sizes: int) -> Tuple[torch.Tensor, ...]:
return torch.split(tensor, sizes, -1)
def dump_tensors_to_csv(path, *tensors: torch.Tensor, open_mode = "w"):
for i in range(len(tensors)):
if len(tensors[i].shape) == 1:
tensors[i] = tensors[i][:, None]
elif len(tensors[i].shape) > 2:
tensors[i] = tensors[i].flatten(1, -1)
def dump_tensors_to_csv(path, *tensors: torch.Tensor, open_mode="w"):
data = []
for tensor in tensors:
if len(tensor.shape) == 1:
tensor = tensor[:, None]
elif len(tensor.shape) > 2:
tensor = tensor.flatten(1, -1)
tensor_data = tensor.tolist()
if not data:
data = [
[
f"{value:.6e}" if isinstance(value, float) else f"{value}"
for value in tensor_data[i]
] for i in range(len(tensor_data))
]
else:
for i, row in enumerate(data):
row += [
f"{value:.6e}" if isinstance(value, float) else f"{value}"
for value in tensor_data[i]
]
with open(path, open_mode) as fp:
csv_writer = csv.writer(fp)
csv_writer.writerows(torch.cat(tensors, -1).tolist())
\ No newline at end of file
csv_writer.writerows(data)
from typing import List, Tuple, Union
import torch
import numpy as np
from pathlib import Path
from utils.types import *
checkpoint_file_prefix = "checkpoint_"
......@@ -12,8 +11,7 @@ def get_checkpoint_filename(epoch):
return f"{checkpoint_file_prefix}{epoch}{checkpoint_file_suffix}"
def list_epochs(directory: Union[str, Path]) -> List[int]:
directory = Path(directory)
def list_epochs(directory: Path) -> list[int]:
epoch_list = [
int(file_path.stem[len(checkpoint_file_prefix):])
for file_path in directory.glob(get_checkpoint_filename("*"))
......@@ -22,20 +20,32 @@ def list_epochs(directory: Union[str, Path]) -> List[int]:
return epoch_list
def load_checkpoint(path: Union[str, Path]) -> Tuple[dict, Path]:
path = Path(path)
def find_checkpoint(path: Path) -> Path | None:
if path.suffix != checkpoint_file_suffix:
existed_epochs = list_epochs(path)
if len(existed_epochs) == 0:
raise FileNotFoundError(f"{path} does not contain checkpoint files")
path = path / get_checkpoint_filename(existed_epochs[-1])
return path / get_checkpoint_filename(existed_epochs[-1]) if existed_epochs else None
return path if path.exists() else None
def load_checkpoint(path: Path) -> tuple[dict, Path]:
path = find_checkpoint(path)
if path is None:
raise FileNotFoundError(f"{path} does not contain checkpoint files")
return torch.load(path), path
def save_checkpoint(states_dict: dict, directory: Union[str, Path], epoch: int):
def save_checkpoint(states_dict: dict, directory: Path, epoch: int):
torch.save(states_dict, Path(directory) / get_checkpoint_filename(epoch))
def clean_checkpoint(directory: Path, keep_interval: int):
(directory / '_misc').mkdir(exist_ok=True)
for file in directory.glob(f"{checkpoint_file_prefix}*{checkpoint_file_suffix}"):
i = int(file.name[len(checkpoint_file_prefix):-len(checkpoint_file_suffix)])
if i % keep_interval != 0:
file.rename(directory / "_misc" / file.name)
def log(model):
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment