from ..__common__ import *

__all__ = ["Field"]


class Field(nn.Module):
    """Network mapping an input ``"x"`` to an output ``"f"`` via one ``nn.FcBlock``.

    ``forward`` simply delegates to the wrapped ``FcBlock``.
    """

    def __init__(self, x_chns: int, shape: list[int], skips: "list[int] | None" = None,
                 act: str = 'relu', with_ln: bool = False):
        """
        :param x_chns: number of input channels; registered as the width of input ``"x"``
            and passed to ``nn.FcBlock`` as its first argument
        :param shape: sequence unpacked into ``nn.FcBlock`` right after the in/out channel
            arguments; ``shape[1]`` is registered as the width of output ``"f"``.
            NOTE(review): presumably ``[n_layers, n_features]`` — confirm against FcBlock.
        :param skips: layer indices forwarded to ``nn.FcBlock`` (presumably skip
            connections); ``None`` means an empty list
        :param act: activation name forwarded to ``nn.FcBlock``
        :param with_ln: forwarded to ``nn.FcBlock`` as ``with_ln``
        """
        # A None sentinel replaces the former shared mutable ``[]`` default, so each
        # instance gets its own list and FcBlock can never mutate state shared
        # across Field instances.
        skips = [] if skips is None else skips
        super().__init__({"x": x_chns}, {"f": shape[1]})
        # An output channel count of 0 is passed to FcBlock as-is; presumably FcBlock
        # then ends on its last hidden layer of width shape[1] (TODO confirm).
        self.net = nn.FcBlock(x_chns, 0, *shape, skips, act, with_ln=with_ln)

    # stub method for type hint
    # NOTE(review): with a plain torch.nn.Module this override would shadow
    # Module.__call__ and return None; presumably the project's nn.Module handles
    # this stub — confirm before relying on it.
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        ...

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the wrapped ``FcBlock`` and return its output."""
        return self.net(x)