from typing import List
from collections import OrderedDict
import math
import torch
import torch.nn as nn

from utils.constants import *


class BatchLinear(nn.Linear):
    '''
    A linear meta-layer that can deal with batched weight matrices and biases, as for instance
    output by a hypernetwork.
    '''
    __doc__ = nn.Linear.__doc__

    def forward(self, input, params=None):
        # Fall back to this layer's own parameters when no external
        # (e.g. hypernetwork-generated) parameters are provided.
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get('bias', None)
        weight = params['weight']
        # Multiply by the transpose of the (possibly batched) weight matrix by swapping its
        # last two dimensions.
        output = input.matmul(weight.permute(*[i for i in range(len(weight.shape) - 2)], -1, -2))
        if bias is not None:
            output += bias.unsqueeze(-2)
        return output


class Sine(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, input):
        # SIREN activation; see the note on the factor 30 in sine_init() below.
        return torch.sin(30 * input)


class FcLayer(nn.Module):

    def __init__(self, in_chns: int, out_chns: int, act: str = 'linear', skip_chns: int = 0):
        super().__init__()
        nls_and_inits = {
            'sine': (Sine(), sine_init),
            'relu': (nn.ReLU(), None),
            'sigmoid': (nn.Sigmoid(), None),
            'tanh': (nn.Tanh(), None),
            'selu': (nn.SELU(), init_weights_selu),
            'softplus': (nn.Softplus(), init_weights_normal),
            'elu': (nn.ELU(), init_weights_elu),
            'softmax': (nn.Softmax(dim=-1), softmax_init),
            'logsoftmax': (nn.LogSoftmax(dim=-1), softmax_init),
            'linear': (None, None)
        }
        nl, nl_weight_init = nls_and_inits[act]
        self.net = nn.Sequential(
            nn.Linear(in_chns + skip_chns, out_chns),
            nl
        ) if nl else nn.Linear(in_chns + skip_chns, out_chns)
        self.skip = skip_chns != 0
        if nl_weight_init is not None:
            nl_weight_init(self.net if isinstance(self.net, nn.Linear) else self.net[0])
        else:
            self.init_params(act)

    def forward(self, x: torch.Tensor, x0: torch.Tensor = None) -> torch.Tensor:
        return self.net(torch.cat([x0, x], dim=-1) if self.skip else x)

    def get_params(self):
        linear_net = self.net if isinstance(self.net, nn.Linear) else self.net[0]
        return linear_net.weight, linear_net.bias

    def init_params(self, act):
        weight, bias = self.get_params()
        nn.init.xavier_normal_(weight, gain=nn.init.calculate_gain(act))
        nn.init.zeros_(bias)

    def copy_to(self, layer):
        weight, bias = self.get_params()
        dst_weight, dst_bias = layer.get_params()
        # In-place copy on leaf parameters must not be tracked by autograd.
        with torch.no_grad():
            dst_weight.copy_(weight)
            dst_bias.copy_(bias)


class FcNet(nn.Module):

    def __init__(self, *, in_chns: int, out_chns: int, nf: int, n_layers: int,
                 skips: List[int] = [], act: str = 'relu', out_act: str = 'linear'):
        """
        Initialize a fully-connected net.

        :kwarg in_chns: channels of input
        :kwarg out_chns: channels of output
        :kwarg nf: number of features in each hidden layer
        :kwarg n_layers: number of hidden layers
        :kwarg skips: create skip connections from the input to the layers in this list
        :kwarg act: activation of hidden layers
        :kwarg out_act: activation of the output layer
        """
        super().__init__()
        self.layers = [FcLayer(in_chns, nf, act)] + [
            FcLayer(nf, nf, act, skip_chns=in_chns if i in skips else 0)
            for i in range(n_layers - 1)
        ]
        if out_chns:
            self.layers.append(FcLayer(nf, out_chns, out_act))
        for i, layer in enumerate(self.layers):
            self.add_module(f"layer{i}", layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x0 = x
        for layer in self.layers:
            x = layer(x, x0)
        return x
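
# Note on skip connections: FcNet passes the raw network input x0 to every FcLayer, but only
# layers constructed with skip_chns != 0 actually concatenate it (along the last dimension)
# before their linear transform, in the style of NeRF-like MLPs with input re-injection.
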
########################
# Initialization methods

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # For PINNet, Raissi et al. 2019
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    # grab from upstream pytorch branch and paste here for now

    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def init_weights_trunc_normal(m):
    # For PINNet, Raissi et al. 2019
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    if type(m) == BatchLinear or type(m) == nn.Linear:
        if hasattr(m, 'weight'):
            fan_in = m.weight.size(1)
            fan_out = m.weight.size(0)
            std = math.sqrt(2.0 / float(fan_in + fan_out))
            mean = 0.
            # initialize with the same behavior as tf.truncated_normal
            # "The generated values follow a normal distribution with specified mean and
            # standard deviation, except that values whose magnitude is more than 2
            # standard deviations from the mean are dropped and re-picked."
            _no_grad_trunc_normal_(m.weight, mean, std, -2 * std, 2 * std)


def init_weights_normal(m):
    if type(m) == BatchLinear or type(m) == nn.Linear:
        if hasattr(m, 'weight'):
            nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')


def init_weights_selu(m):
    if type(m) == BatchLinear or type(m) == nn.Linear:
        if hasattr(m, 'weight'):
            num_input = m.weight.size(-1)
            nn.init.normal_(m.weight, std=1 / math.sqrt(num_input))


def init_weights_elu(m):
    if type(m) == BatchLinear or type(m) == nn.Linear:
        if hasattr(m, 'weight'):
            num_input = m.weight.size(-1)
            nn.init.normal_(m.weight, std=math.sqrt(1.5505188080679277) / math.sqrt(num_input))


def init_weights_xavier(m):
    if type(m) == BatchLinear or type(m) == nn.Linear:
        if hasattr(m, 'weight'):
            nn.init.xavier_normal_(m.weight)


def sine_init(m):
    with torch.no_grad():
        if hasattr(m, 'weight'):
            num_input = m.weight.size(-1)
            # See supplement Sec. 1.5 for discussion of factor 30
            m.weight.uniform_(-math.sqrt(6 / num_input) / 30, math.sqrt(6 / num_input) / 30)


def first_layer_sine_init(m):
    with torch.no_grad():
        if hasattr(m, 'weight'):
            num_input = m.weight.size(-1)
            # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
            m.weight.uniform_(-1 / num_input, 1 / num_input)


def softmax_init(m):
    with torch.no_grad():
        nn.init.normal_(m.weight, mean=0, std=0.01)
        nn.init.constant_(m.bias, val=0)
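

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): builds a small FcNet
# with one skip connection and runs a forward pass, then exercises BatchLinear
# with hypernetwork-style batched parameters. All shapes and hyperparameters
# below are illustrative assumptions, not values taken from the codebase.
if __name__ == "__main__":
    net = FcNet(in_chns=3, out_chns=4, nf=64, n_layers=4, skips=[2], act='relu')
    x = torch.rand(10, 3)                       # batch of 10 input vectors
    y = net(x)                                  # -> (10, 4)
    print(y.shape)

    # BatchLinear with externally supplied, batched weights and biases
    layer = BatchLinear(3, 4)
    params = {
        'weight': torch.rand(5, 4, 3),          # 5 sets of (out=4, in=3) weights
        'bias': torch.rand(5, 4),               # 5 sets of biases
    }
    out = layer(torch.rand(5, 10, 3), params)   # -> (5, 10, 4)
    print(out.shape)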