diff --git a/data/loader.py b/data/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..cd4d14663e27e06106172b8ce9cd62632911bc4b --- /dev/null +++ b/data/loader.py @@ -0,0 +1,39 @@ +import torch +import math +from ..my import device + + +class FastDataLoader(object): + + class Iter(object): + + def __init__(self, dataset, batch_size, shuffle, drop_last) -> None: + super().__init__() + self.indices = torch.randperm(len(dataset), device=device.GetDevice()) \ + if shuffle else torch.arange(len(dataset), device=device.GetDevice()) + self.offset = 0 + self.batch_size = batch_size + self.dataset = dataset + self.drop_last = drop_last + + def __next__(self): + if self.offset + (self.batch_size if self.drop_last else 0) >= len(self.dataset): + raise StopIteration() + indices = self.indices[self.offset:self.offset + self.batch_size] + self.offset += self.batch_size + return self.dataset[indices] + + def __init__(self, dataset, batch_size, shuffle, drop_last, **kwargs) -> None: + super().__init__() + self.dataset = dataset + self.batch_size = batch_size + self.shuffle = shuffle + self.drop_last = drop_last + + def __iter__(self): + return FastDataLoader.Iter(self.dataset, self.batch_size, + self.shuffle, self.drop_last) + + def __len__(self): + return math.floor(len(self.dataset) / self.batch_size) if self.drop_last \ + else math.ceil(len(self.dataset) / self.batch_size)