type.py 970 Bytes
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from typing import Any, Dict, Union
import torch

# Value type shared by the loosely-typed dictionaries below: either a tensor or
# any other object.
_TensorOrAny = Union[torch.Tensor, Any]

# String-keyed dictionaries whose values may be tensors or arbitrary objects.
InputData = Dict[str, _TensorOrAny]
ReturnData = Dict[str, _TensorOrAny]

# String-keyed dictionary whose values are strictly tensors.
NetOutput = Dict[str, torch.Tensor]


class NetInput:
    """Bundle of up to three optional tensors that are indexed together.

    ``x``, ``d`` and ``f`` may each be ``None``. Because ``__getitem__`` applies
    the same index to every tensor that is present, the tensors are expected
    to share their leading dimensions (only the last dimension may differ).

    Attributes:
        x: first tensor, or ``None``.
        d: second tensor, or ``None``.
        f: third tensor, or ``None``.
        shape: all-but-last dimensions of the first non-``None`` tensor among
            ``x``, ``d``, ``f`` (checked in that order); the list ``[0]`` when
            all three are ``None``.
    """

    def __init__(self, x: Union[torch.Tensor, None] = None,
                 d: Union[torch.Tensor, None] = None,
                 f: Union[torch.Tensor, None] = None) -> None:
        self.x = x
        self.d = d
        self.f = f
        # Take the leading shape from the first tensor that is present.  ``f``
        # is a valid fallback: __getitem__ indexes it with the same index as
        # x and d, so it shares their leading dimensions.  Previously an
        # f-only input was reported with the empty sentinel shape.
        ref = next((t for t in (x, d, f) if t is not None), None)
        self.shape = ref.shape[:-1] if ref is not None else [0]

    def __getitem__(self, index: Union[int, slice, list, tuple, torch.Tensor, None]) -> 'NetInput':
        """Return a new ``NetInput`` with ``index`` applied to every present tensor.

        Args:
            index: any index accepted by ``torch.Tensor.__getitem__``. A boolean
                tensor is first converted to a tuple of integer index tensors
                with ``nonzero(as_tuple=True)``.

        Returns:
            A new ``NetInput`` whose non-``None`` fields are the indexed
            tensors; fields that were ``None`` stay ``None``.
        """
        if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
            index = index.nonzero(as_tuple=True)
        return NetInput(
            *(t[index] if t is not None else None for t in (self.x, self.d, self.f))
        )