import math
import threading
from typing import Dict, List

import torch


class Preloader(object):
    """Loads a data chunk on a separate CUDA stream in a background thread."""

    def __init__(self, device=None) -> None:
        super().__init__()
        self.stream = torch.cuda.Stream(device)
        self.event_chunk_loaded = None

    def preload_chunk(self, chunk):
        # Wait for any pending load before scheduling a new one
        if self.event_chunk_loaded is not None:
            self.event_chunk_loaded.wait()
        if chunk.loaded:
            return
        # print(f'Preloader: preload chunk #{chunk.id}')
        self.event_chunk_loaded = threading.Event()
        threading.Thread(target=self._load_chunk, args=(chunk,)).start()

    def _load_chunk(self, chunk):
        # Run the load on a dedicated stream so host-to-device copies can
        # overlap with compute on the default stream
        with torch.cuda.stream(self.stream):
            chunk.load()
        self.event_chunk_loaded.set()
        # print(f'Preloader: chunk #{chunk.id} is loaded')


class DataLoader(object):
    """Iterates a dataset chunk by chunk, yielding batches of indices into the
    current chunk and (optionally) preloading the next chunk in the background."""

    class Iter(object):

        def __init__(self, chunks, batch_size, shuffle, device: torch.device,
                     preloader: Preloader):
            super().__init__()
            self.batch_size = batch_size
            self.chunks = chunks
            self.offset = -1
            self.chunk_idx = -1
            self.current_chunk = None
            self.indices = None
            self.shuffle = shuffle
            self.device = device
            self.preloader = preloader

        def __del__(self):
            # print('DataLoader.Iter: clean chunks')
            if self.preloader is not None and self.preloader.event_chunk_loaded is not None:
                self.preloader.event_chunk_loaded.wait()
            # Keep the first chunk (and the preloaded one, if any) resident so
            # the next epoch can start immediately; release the rest
            chunks_to_reserve = 1 if self.preloader is None else 2
            for i in range(chunks_to_reserve, len(self.chunks)):
                if self.chunks[i].loaded:
                    self.chunks[i].release()

        def __next__(self):
            if self.offset == -1:
                self._next_chunk()
            stop = min(self.offset + self.batch_size, len(self.current_chunk))
            if self.indices is not None:
                indices = self.indices[self.offset:stop]
            else:
                indices = torch.arange(self.offset, stop, device=self.device)
            self.offset = stop
            if self.offset >= len(self.current_chunk):
                self.offset = -1  # signal that the next call should advance the chunk
            return self.current_chunk[indices]

        def _next_chunk(self):
            if self.current_chunk is not None:
                chunks_to_reserve = 1 if self.preloader is None else 2
                if len(self.chunks) > chunks_to_reserve:
                    self.current_chunk.release()
            if self.chunk_idx >= len(self.chunks) - 1:
                raise StopIteration()
            self.chunk_idx += 1
            self.current_chunk = self.chunks[self.chunk_idx]
            self.offset = 0
            self.indices = torch.randperm(len(self.current_chunk)).to(device=self.device) \
                if self.shuffle else None
            if self.preloader is not None:
                # Preload the next chunk; wrap around so the first chunk is
                # ready again when a new epoch starts
                self.preloader.preload_chunk(self.chunks[(self.chunk_idx + 1) % len(self.chunks)])

    def __init__(self, dataset, batch_size, *, chunk_max_items=None, shuffle=False,
                 enable_preload=True, **chunk_args):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.chunk_args = chunk_args
        self.preloader = Preloader(self.dataset.device) if enable_preload else None
        self._init_chunks(chunk_max_items)

    def __iter__(self):
        return DataLoader.Iter(self.chunks, self.batch_size, self.shuffle,
                               self.dataset.device, self.preloader)

    def __len__(self):
        return sum(math.ceil(len(chunk) / self.batch_size) for chunk in self.chunks)

    def _init_chunks(self, chunk_max_items):
        data: Dict[str, torch.Tensor] = self.dataset.get_data()
        if self.shuffle:
            # Shuffle views once up front; per-batch shuffling within a chunk
            # is handled by Iter
            rand_seq = torch.randperm(self.dataset.n_views).to(device=self.dataset.device)
            data = {key: val[rand_seq] for key, val in data.items()}
        self.chunks = []
        n_chunks = 1 if chunk_max_items is None \
            else math.ceil(self.dataset.n_pixels / chunk_max_items)
        views_per_chunk = math.ceil(self.dataset.n_views / n_chunks)
        for offset in range(0, self.dataset.n_views, views_per_chunk):
            sel = slice(offset, offset + views_per_chunk)
            chunk_data = {key: val[sel] for key, val in data.items()}
            self.chunks.append(self.dataset.Chunk(len(self.chunks), self.dataset,
                                                  chunk_data=chunk_data, **self.chunk_args))
        if self.preloader is not None:
            self.preloader.preload_chunk(self.chunks[0])


class MultiScaleDataLoader(object):
    """Drives several per-scale DataLoaders in lockstep, yielding one batch
    fragment per scale on every step."""

    class Iter(object):

        def __init__(self, sub_loaders: List[DataLoader]):
            super().__init__()
            self.sub_loaders = sub_loaders
            self.end_flags = [False] * len(sub_loaders)
            self.sub_iters = [iter(sub_loader) for sub_loader in sub_loaders]

        def __del__(self):
            self.sub_iters.clear()

        def __next__(self):
            # Try to move every iterator forward and collect the data
            data_frags = []
            for i in range(len(self.sub_iters)):
                try:
                    data_frags.append(next(self.sub_iters[i]))
                except StopIteration:
                    self.end_flags[i] = True
                    data_frags.append(None)
            # Stop only when every iterator has reached its end at least once
            if all(self.end_flags):
                raise StopIteration()
            # Restart exhausted (shorter) iterators so each step still yields
            # a fragment from every scale
            for i in range(len(self.sub_iters)):
                if data_frags[i] is None:
                    self.sub_iters[i] = iter(self.sub_loaders[i])
                    data_frags[i] = next(self.sub_iters[i])
            return data_frags

    def __init__(self, dataset, batch_size, *, chunk_max_items=None, shuffle=False,
                 enable_preload=True, **chunk_args):
        super().__init__()
        self.batch_size = batch_size
        # Split the batch and chunk budgets evenly across the sub-datasets
        self.sub_loaders = [
            DataLoader(sub_dataset, batch_size // len(dataset),
                       chunk_max_items=chunk_max_items // len(dataset)
                       if chunk_max_items is not None else None,
                       shuffle=shuffle, enable_preload=enable_preload, **chunk_args)
            for sub_dataset in dataset
        ]
        # Sort by the datasets' levels
        self.sub_loaders.sort(key=lambda loader: loader.dataset.level)
        self.active_sub_loaders = self.sub_loaders

    def __iter__(self):
        for loader in self.active_sub_loaders:
            loader.batch_size = self.batch_size // len(self.active_sub_loaders)
        return MultiScaleDataLoader.Iter(self.active_sub_loaders)

    def __len__(self):
        return max(len(loader) for loader in self.active_sub_loaders)

    def set_active_sub_loaders(self, idx):
        # Accept either a single index or a slice
        self.active_sub_loaders = self.sub_loaders[idx]
        if not isinstance(self.active_sub_loaders, list):
            self.active_sub_loaders = [self.active_sub_loaders]


def get_loader(dataset, batch_size, *, chunk_max_items=None, shuffle=False,
               enable_preload=True, **chunk_args):
    # A list of datasets selects the multi-scale loader; a single dataset
    # gets the plain chunked loader
    if isinstance(dataset, list):
        return MultiScaleDataLoader(dataset, batch_size, chunk_max_items=chunk_max_items,
                                    shuffle=shuffle, enable_preload=enable_preload,
                                    **chunk_args)
    return DataLoader(dataset, batch_size, chunk_max_items=chunk_max_items,
                      shuffle=shuffle, enable_preload=enable_preload, **chunk_args)
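

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of the dataset interface these loaders assume: `device`,
# `n_views`, `n_pixels`, `get_data()` returning a dict of per-view tensors,
# and a nested `Chunk` type with `load()`, `release()`, `loaded`, `__len__`,
# and `__getitem__`. `ToyDataset` and all of its internals are hypothetical
# stand-ins; the real datasets in this codebase provide their own Chunk
# implementations. Preloading is disabled here so the sketch also runs on
# CPU-only machines (Preloader unconditionally creates a CUDA stream).
if __name__ == '__main__':
    class ToyDataset:
        class Chunk:
            def __init__(self, id, dataset, *, chunk_data):
                self.id = id
                self.data = chunk_data
                self.loaded = False

            def load(self):
                # Real chunks would move data to the GPU or decode it here
                self.loaded = True

            def release(self):
                self.loaded = False

            def __len__(self):
                return next(iter(self.data.values())).shape[0]

            def __getitem__(self, indices):
                # Load lazily: without a Preloader, nothing else calls load()
                if not self.loaded:
                    self.load()
                return {key: val[indices] for key, val in self.data.items()}

        def __init__(self, n_views=8, pixels_per_view=16):
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            self.n_views = n_views
            self.n_pixels = n_views * pixels_per_view
            self._rays = torch.randn(n_views, pixels_per_view, 3, device=self.device)

        def get_data(self):
            return {'rays': self._rays}

    # chunk_max_items=64 splits the 8 views into two 4-view chunks;
    # batch_size=4 then yields one batch of views per chunk
    loader = get_loader(ToyDataset(), batch_size=4, chunk_max_items=64,
                        shuffle=True, enable_preload=False)
    for batch in loader:
        print({key: val.shape for key, val in batch.items()})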