From f7038e26d9a54c79a37a1fc16585349ef43b36e6 Mon Sep 17 00:00:00 2001
From: BobYeah <635596704@qq.com>
Date: Sat, 9 Jan 2021 09:19:03 +0800
Subject: [PATCH] sync

---
 data/loader.py | 39 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 39 insertions(+)
 create mode 100644 data/loader.py

diff --git a/data/loader.py b/data/loader.py
new file mode 100644
index 0000000..cd4d146
--- /dev/null
+++ b/data/loader.py
@@ -0,0 +1,39 @@
+import torch
+import math
+from ..my import device
+
+
+class FastDataLoader(object):
+
+    class Iter(object):
+
+        def __init__(self, dataset, batch_size, shuffle, drop_last) -> None:
+            super().__init__()
+            self.indices = torch.randperm(len(dataset), device=device.GetDevice()) \
+                if shuffle else torch.arange(len(dataset), device=device.GetDevice())
+            self.offset = 0
+            self.batch_size = batch_size
+            self.dataset = dataset
+            self.drop_last = drop_last
+
+        def __next__(self):
+            if self.offset + (self.batch_size - 1 if self.drop_last else 0) >= len(self.dataset):
+                raise StopIteration()
+            indices = self.indices[self.offset:self.offset + self.batch_size]
+            self.offset += self.batch_size
+            return self.dataset[indices]
+
+    def __init__(self, dataset, batch_size, shuffle, drop_last, **kwargs) -> None:
+        super().__init__()
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.drop_last = drop_last
+
+    def __iter__(self):
+        return FastDataLoader.Iter(self.dataset, self.batch_size,
+                                   self.shuffle, self.drop_last)
+
+    def __len__(self):
+        return math.floor(len(self.dataset) / self.batch_size) if self.drop_last \
+            else math.ceil(len(self.dataset) / self.batch_size)
--
GitLab
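
For reference, a minimal usage sketch (not part of the patch). It assumes the package that ships this file is on the import path, so that data.loader and the my.device helper used by its relative import both resolve, and it uses a hypothetical TensorSet dataset whose __getitem__ accepts a whole index tensor, since FastDataLoader fetches each batch with a single fancy-indexing call rather than collating individual samples:

    import torch

    from data.loader import FastDataLoader   # module added by this patch
    from my import device                    # project helper assumed by the patch


    class TensorSet:
        """Toy dataset whose __getitem__ accepts a whole tensor of indices."""

        def __init__(self, n):
            # Keep data on the same device as the index tensors FastDataLoader builds.
            self.x = torch.arange(n, dtype=torch.float32,
                                  device=device.GetDevice()).unsqueeze(1)

        def __len__(self):
            return self.x.size(0)

        def __getitem__(self, indices):
            # indices is a 1-D index tensor covering one batch.
            return self.x[indices]


    loader = FastDataLoader(TensorSet(10), batch_size=4, shuffle=True, drop_last=False)
    print(len(loader))        # 3: two full batches plus one partial batch
    for batch in loader:      # shapes (4, 1), (4, 1), (2, 1)
        print(batch.shape)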