인공지능/구현

PyTorch Dataset (1) - Dataset

전공생 2024. 1. 22. 16:08

이 글은 pytorch에서 데이터셋을 불러오는데 아주 중요한 역할을 하는 Dataset 클래스에 대해 이전에 공부했던 내용을 정리한 글이다.

참고 : https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset

Dataset 클래스

모든 데이터셋은 key로 데이터 샘플을 매핑해서 표현되고 이 Dataset 클래스를 상속받아야 한다(subclass). subclass들은 모두 주어진 key에 대해 데이터 샘플에 접근할 수 있는 __getitem__을 덮어써야(overwrtie)한다. subclass들은 많은 ~torch.utils.data.Sampler클래스에 의해 구현된 데이터셋의 크기를 반환하는 __len__을 덮어쓸 수 있고(이 말인 즉슨 전체 데이터로부터 Sampler를 통해 뽑아낸 부분적인 데이터셋을 사용하게 되어 그의 크기를 __len__로 얻을 수 있다는 뜻인것 같다.) ~torch.utils.data.DataLoader클래스의 default option이 된다. subclass들은 배치의 샘플들의 로딩 속도를 빠르게 하기 위해 선택적으로 __getitems__를 구현할 수 있다.

 

[note] ~torch.utils.data.DataLoader 클래스는 integral(필수적인) index(index값을 알아야만 데이터에 접근할 수 있기 때문에 필수적인(integral) index라고 하는 것 같다.)들을 생성하는 index sampler를 디폴트로 생성한다. 이것을 non-integral indices/keys로 된 map-style의 데이터셋으로 만들기 위해서는 custom sampler를 만들어야 한다.

 

*list, map 자료구조

더보기
  • map: 키(key)랑 값(value)으로 나눠서 데이터를 관리하는 자료구조. 순서가 없으며 키에 대한 중복이 없다. 인덱스 대신 키값으로 값을 찾는다. iterator 클래스를 이용해서 키 값을 순서대로 iterator에 저장해두면 순서대로 데이터 추출이 가능하다.
  • list: 순서가 있고 중복을 허용하는 자료구조. 각 값에 0부터 시작하는 숫자값을 인덱스로 지정한 후 인덱스를 통해 값을 찾는다.

__add__ 함수: ConcatDataset(self, other: Dataset) -> return ConcatDataset([self, other])


Dataset Types: Map-style과 iterable-style로 나뉜다.

Map-style

  • _getitem__()__len__() 프로토콜을 구현하고 non-integral indices/keys로부터 데이터 샘플로의 map으로 표현되는 dataset이다. 예를 들면, dataset[idx]로 데이터셋에 접근할 때 idx번째의 이미지와 디스크의 폴더로부터 대응되는 라벨을 읽을 수 있다.
  • Dataset 클래스는 map-style의 데이터셋에 대한 클래스다. Iterable-style의 클래스는 IterableDataset으로 사용할 수 있다.

데이터 로딩 순서

  • torch.utils.data.Sampler 클래스는 데이터 로딩에 사용되는 인덱스 혹은 key의 순서를 구체적으로 명시하는데 사용된다.
  • sequential 혹은 shuffled sampler는 자동으로 DataLoader로 들어가는 shuffle인자 기반으로 만들어진다. 혹은 대안적으로 사용자들은 fetch할 다음 인덱스 혹은 key를 매시간 만들도록 하는 커스텀 Sampler 오브젝트를 생성하여 sampler 인자에 넣을수도 있다.

Sampler

  • 데이터셋에서 데이터를 뽑는 녀셕이다.
  • 데이터를 뽑는 방법에 따라 SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler, BatchSampler, DistributedSampler 처럼 다양하게 Sampler가 구성된다.
  • 모든 Sampler는 데이터셋의 인덱스들 혹은 인덱스들의 리스트를 순회하는 방법을 제공하는 __iter__()와 리턴된 iterator의 길이를 리턴하는 __len__()를 상속받는다.
  • 샘플링 개념 자체가 통계에서 전체 집단을 표현할 좀 더 작은 집단을 뽑아서 볼때 사용되기 때문에 파이토치의 sampler역시 batch와 같은 더 작은 데이터셋을 전체 데이터셋으로부터 뽑을때 사용되는 것 같다.

Iterable-style

  • iterable-style 데이터셋은 __iter__() 프로토콜을 구현하고 데이터 샘플을 반복할 수 있도록 표현하는 IterableDataset의 subclass의 인스턴스이다. 이 타입의 데이터셋은 특히 random 읽기가 계산이 많거나 불가능할때와 배치 크기가 fetched data에 의존할 경우 사용하기 적합하다. 예를 들면, iter(dataset)로 호출된 데이터셋은 데이터셋으로부터 읽어들인 데이터의 스트림을 반환하고, 서버에 원격접속하고, 실시간으로 로그를 생성할 수 있다.
  • 데이터 샘플들이 iterable하게 표현되는 모든 데이터셋은 이 torch.utils.data.IterableDataset을 상속해야 한다. 데이터가 stream으로부터 올때 유용한 데이터셋이다. 모든 subclass는 데이터셋의 샘플에 대한 iterator를 반환하는 __iter__()를 덮어써야 한다.

데이터 로딩 순서:

  • iterable-style의 데이터셋의 경우, 데이터 로딩 순서는 전체적으로 사용자 정의의 iterable에 의해 통제된다. 이는 chunk-reading이나 배치로 나뉘어진 샘플을 매시간 생산함으로써 동적 배치 사이즈의 구현을 더 쉽게 해준다.
  • subclass가 DataLoader와 사용될 때 데이터셋의 각 item은 DataLoader iterator로부터 생산되게 된다. num_workers>0일때, 각 worker 프로세스는 데이터셋 object의 다른 copy를 가지게 되고, 그래서 종종 worker들로부터 반환되는 중복데이터를 피하기 위해 각 copy를 독립적으로 환경을 설정한다. 이 말의 뜻은, worker 별로 __iter__() 함수를 독립적으로 불러오기 때문에 동일한 데이터셋을 불러와 다른 worker가 처리하는 데이터까지 중복으로 불러오는 문제가 생긴다는 것이다. 따라서 worker당 할당된 데이터셋만 불러오도록 추가적인 데이터 설정(재분배)이 필요하다. get_worker_info() 는 worker에 대한 정보를 반환하는데, dataset의 __iter__() 메써드나 DataLoader의 각 copy의 behavior를 수정하기 위한 worker_init_fn 옵션에 사용된다.

example1) dataset의 __iter__()를 수정하여 모든 worker에 대한 workload를 나누기

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super(MyIterableDataset).__init__()
        assert end > start, "this example code only works with end >= start"
        self.start = start
        self.end = end

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None: # single-process data loading, return the full iterator
            iter_start = self.start
            iter_end = self.end
        else: # in a worker process
            #split worload
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) # workload per worker
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))


>>> ds = MyIterableDataset(start=3, end=7)

 

example2) worker_init_fn을 사용하여 모든 worker에 대한 workload 나누기

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super(MyIterableDataset).__init__()
        assert end > start, "this example code only works with end >= start"
        self.start = start
        self.end = end 

    def __iter__(self):
        return iter(range(self.start, self.end))

>>> ds = MyIterableDataset(start=3, end=7)

# Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]

# multi-process loading을 할 경우 중복 데이터가 생성된다.
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

# 각각의 데이터셋을 다르게 copy하게끔 고려해서 worker_init_fn을 정의한다.
>>> def worker_init_fn(worker_id):
            worker_info = torch.utils.data.get_worker_info()
            dataset = worker_info.dataset
            overall_start = dataset.start
            overall_end = dataset.end
            per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            dataset.start = overall_start + per_worker * worker_id
            dataset.end = min(dataset.start + per_worker, overall_end)

>>> # custom한 `worker_init_fn`으로 multi-process loading을 동작
>>> # Worker 0 fetched [3, 4], Worker 1 fetched [5, 6]
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

example2 -> worker 개수가 여러개 일때 데이터가 중복되어 복사되는 오류를 파하기 위해 위와 같은 코드를 추가하여 조심하라는 뜻인 것 같다.

Memory Pinning

pinned(page-locked) 메모리를 사용하여 데이터를 로딩하면 데이터를 더 빠르게 CUDA를 사용할 수 있는 GPU로 옮길 수 있다.

이 부분은 나중에 GPU를 효율적으로 다룰때 공부해볼 것.

Question

worker가 여러개인 경우는 어떤 경우인가?

더보기

example 2) worker_init_fn을 사용하여 모든 worker에 대한 workload 나누기 

참고

dataloader는 dataset을 어떻게 사용해서 매 에폭마다 데이터를 불러오는가?

더보기

torch.utils.data.Dataset → map-style dataset

  • __init__()
  • __getitem__()
  • __len__()
  • __getitems__() : option. batched sample 로딩 속도를 높이기 위해 구현할수도 있다. 이는 배치 내의 샘플들의 인덱스 리스트를 사용하거나 샘플 리스트를 반환한다.

 

'인공지능 > 구현' 카테고리의 다른 글

PyTorch Dataset (2) - ImageFolder, DataFolder  (0) 2024.01.22