Skip to content
Snippets Groups Projects
Commit f7038e26 authored by BobYeah's avatar BobYeah
Browse files

sync

parent 72636b3e
No related merge requests found
import torch
import math
from ..my import device
class FastDataLoader(object):
class Iter(object):
def __init__(self, dataset, batch_size, shuffle, drop_last) -> None:
super().__init__()
self.indices = torch.randperm(len(dataset), device=device.GetDevice()) \
if shuffle else torch.arange(len(dataset), device=device.GetDevice())
self.offset = 0
self.batch_size = batch_size
self.dataset = dataset
self.drop_last = drop_last
def __next__(self):
if self.offset + (self.batch_size if self.drop_last else 0) >= len(self.dataset):
raise StopIteration()
indices = self.indices[self.offset:self.offset + self.batch_size]
self.offset += self.batch_size
return self.dataset[indices]
def __init__(self, dataset, batch_size, shuffle, drop_last, **kwargs) -> None:
super().__init__()
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
def __iter__(self):
return FastDataLoader.Iter(self.dataset, self.batch_size,
self.shuffle, self.drop_last)
def __len__(self):
return math.floor(len(self.dataset) / self.batch_size) if self.drop_last \
else math.ceil(len(self.dataset) / self.batch_size)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment