import torch
import math
from ..my import device


class FastDataLoader(object):
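    """DataLoader-like batch loader that gathers each batch with a single
    tensor indexing operation, keeping the index tensor on the target device."""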

    class Iter(object):
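        """Single-epoch iterator that yields dataset[indices] batches."""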

        def __init__(self, dataset, batch_size, shuffle, drop_last) -> None:
            super().__init__()
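            # Epoch sample order, created directly on the target device:
            # a random permutation when shuffling, otherwise sequential.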
            self.indices = torch.randperm(len(dataset), device=device.GetDevice()) \
                if shuffle else torch.arange(len(dataset), device=device.GetDevice())
            self.offset = 0
            self.batch_size = batch_size
            self.dataset = dataset
            self.drop_last = drop_last

        def __iter__(self):
            return self

        def __next__(self):
            # Stop when the dataset is exhausted, or when only a partial
            # batch remains and drop_last is requested.
            remaining = len(self.dataset) - self.offset
            if remaining <= 0 or (self.drop_last and remaining < self.batch_size):
                raise StopIteration()
            indices = self.indices[self.offset:self.offset + self.batch_size]
            self.offset += self.batch_size
            return self.dataset[indices]

    def __init__(self, dataset, batch_size, shuffle, drop_last, **kwargs) -> None:
        super().__init__()
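        # Extra DataLoader-style keyword arguments are accepted but ignored.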
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

    def __iter__(self):
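        # A new iterator (and a new shuffle, when enabled) for every epoch.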
        return FastDataLoader.Iter(self.dataset, self.batch_size,
                                   self.shuffle, self.drop_last)

    def __len__(self):
        # Number of batches per epoch, consistent with drop_last.
        if self.drop_last:
            return math.floor(len(self.dataset) / self.batch_size)
        return math.ceil(len(self.dataset) / self.batch_size)
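

# Usage sketch (illustrative, not part of the module): FastDataLoader expects a
# dataset that supports indexing by a 1-D index tensor and returns a whole batch
# at once. TensorBatchDataset below is a hypothetical example of such a dataset.
#
#   class TensorBatchDataset:
#       def __init__(self, data, labels):
#           self.data, self.labels = data, labels
#
#       def __len__(self):
#           return self.data.shape[0]
#
#       def __getitem__(self, indices):
#           return self.data[indices], self.labels[indices]
#
#   dataset = TensorBatchDataset(
#       torch.randn(1000, 8, device=device.GetDevice()),
#       torch.randint(0, 10, (1000,), device=device.GetDevice()))
#   loader = FastDataLoader(dataset, batch_size=128, shuffle=True, drop_last=False)
#   for batch_data, batch_labels in loader:
#       pass  # each iteration yields one on-device batch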