upsampling.py 3.01 KB
Newer Older
BobYeah's avatar
sync    
BobYeah committed
1
2
import os
from numpy.core.fromnumeric import trace
Nianchen Deng's avatar
sync    
Nianchen Deng committed
3
from numpy.lib.arraysetops import isin
BobYeah's avatar
sync    
BobYeah committed
4
5
6
7
import torch
import torchvision.transforms.functional as trans_f
from ..my import util
from ..my import device
Nianchen Deng's avatar
sync    
Nianchen Deng committed
8
from ..my import color_mode
BobYeah's avatar
sync    
BobYeah committed
9
10
11
12
13
14
15
16
17


class UpsamplingDataset(torch.utils.data.dataset.Dataset):
    """
    Dataset for upsampling task

    Pairs input (low resolution) images with optional ground truth (high
    resolution) images. File paths are produced by ``%``-formatting the
    configured patterns with the sample index. Every item is a tuple
    ``(idx, input, gt)``; ``gt`` is ``False`` when no ground truth pattern
    was given.
    """

    def __init__(self, data_dir: str, input_patt: str, gt_patt: str,
                 color: int, load_once: bool = True):
        """
        Initialize dataset for upsampling task

        :param data_dir: directory of dataset
        :param input_patt: file pattern for input (low resolution) images,
            formatted with the sample index via ``%``
        :param gt_patt: file pattern for ground truth (high resolution) images,
            or ``None`` to run without ground truth
        :param color: color mode, compared against ``color_mode.GRAY`` and
            ``color_mode.YCbCr``; any other value leaves images unconverted
        :param load_once: load all samples to current device at once to accelerate
            training, suitable for small dataset
        """
        self.input_patt = os.path.join(data_dir, input_patt)
        self.gt_patt = os.path.join(data_dir, gt_patt) if gt_patt is not None else None

        # The number of directory entries is an upper bound for the sample
        # index; count how many indices below it resolve to an existing file.
        # Fall back to the current directory when the pattern has no directory
        # part (os.listdir('') would raise).
        input_dir = os.path.dirname(self.input_patt) or os.curdir
        n_candidates = len(os.listdir(input_dir))
        self.n = sum(
            1 for i in range(n_candidates)
            if os.path.exists(self.input_patt % i)
        )

        self.load_once = load_once
        self.load_gt = self.gt_patt is not None
        self.color = color

        # Preloaded tensors stay None in lazy mode; color conversion is only
        # applied to data that was actually loaded.
        self.input = None
        self.gt = None
        if self.load_once:
            self.input = self._convert_color(
                util.ReadImageTensor([self.input_patt % i for i in range(self.n)])
                .to(device.GetDevice()))
            if self.load_gt:
                self.gt = self._convert_color(
                    util.ReadImageTensor([self.gt_patt % i for i in range(self.n)])
                    .to(device.GetDevice()))

    def _convert_color(self, images):
        """
        Convert RGB images to the color mode configured for this dataset

        :param images: image tensor as returned by ``util.ReadImageTensor``
        :return: converted tensor for GRAY / YCbCr modes, otherwise ``images`` unchanged
        """
        if self.color == color_mode.GRAY:
            return trans_f.rgb_to_grayscale(images)
        if self.color == color_mode.YCbCr:
            return util.rgb2ycbcr(images)
        return images

    def __len__(self):
        """Return the number of samples found at initialization"""
        return self.n

    def __getitem__(self, idx):
        """
        Fetch sample(s) by index

        :param idx: an integer index, or a tensor of indices
        :return: tuple ``(idx, input, gt)``; ``gt`` is ``False`` when the
            dataset has no ground truth
        """
        if self.load_once:
            # Preloaded data is already color-converted; just index into it.
            return idx, self.input[idx], self.gt[idx] if self.load_gt else False

        # Lazy mode: read the requested files now. Tensor indices are
        # flattened to plain ints so they format cleanly into the patterns
        # (this also accepts a 0-d tensor).
        if isinstance(idx, torch.Tensor):
            indices = idx.reshape(-1).tolist()
        else:
            indices = [idx]

        input = self._convert_color(
            util.ReadImageTensor([self.input_patt % i for i in indices]))
        if self.load_gt:
            gt = self._convert_color(
                util.ReadImageTensor([self.gt_patt % i for i in indices]))
        else:
            gt = False
        # Always return a sample, including for color modes that need no
        # conversion.
        return idx, input, gt