인공지능/구현

PyTorch Dataset (2) - ImageFolder, DataFolder

전공생 2024. 1. 22. 16:09

데이터셋은 학습에 필요한 데이터셋, 검증에 필요한 데이터셋, 평가에 필요한 데이터셋이 필요하다. train.py와 test.py(혹은 eval.py)로 나누어 코드를 짜는 경우엔 train.py에는 학습 데이터셋, 검증(평가) 데이터셋을 구현하여 사용하면 되고 test.py(혹은 eval.py)에는 평가 데이터셋을 구현해야 한다.

train_set = CustomDataset(...)
validation_set = CustomDataset(...)
test_set = CustomDataset(...)

 

보통 데이터는 다음과 같은 트리구조로 되어있다.

data
  |-- train
    |       |-- cat 
    |       |       |-- cat1.jpg
    |       |       |-- cat2.jpg
    |       |       |-- cat3.jpg
    |         |        |         ...
    |       |       |-- cat4000.jpg
    |       |-- dog 
    |       |       |-- dog1.jpg
    |       |       |-- dog2.jpg
    |       |       |-- dog3.jpg
    |         |        |         ...
    |       |       |-- dog4000.jpg
  |-- test
    |       |-- 1.jpg 
    |       |-- 2.jpg 
    |       |-- 3.jpg 
    |       |    ...
    |       |-- 100.jpg

 

이 파일 트리구조 정보를 이용해서 데이터셋을 불러온다. - ImageFolder

ImageFolder

torchvision.datasets.ImageFolder

DatasetFolder 상속받는데, DataFolder의 똑같은 methods를 원하는 데이터셋에 커스터마이징한다고 볼 수 있다.

DataFolder의 self.imgs = self.samples, DataFolder에서 요구하는 IMG_EXTENSIONS 인자를 받지 않는다.

DataFolder

VisionDataset을 상속받는다.(VisionDatasetDataset을 상속받음 → __getitem__, __len__ 구현됨)

# find_classes(self.root) method를 통해 class, class_to_idx 생성
classes, class_to_idx = self.find_classes(self.root)

# 데이터 파일 경로, class당 label 정보를 받아 dataset을 만든다.
samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

 

 

find_classes

이 함수는 데이터셋 구조가 다음과 같을때 클래스 폴더를 찾아낸다.

directory/
    ├── class_x
    │   ├── xxx.ext
    │   ├── xxy.ext
    │   └── ...
    │       └── xxz.ext
    └── class_y
        ├── 123.ext
        ├── nsdf3.ext
        └── ...
        └── asd932_.ext

디렉토리명 리스트를 얻기 위해

os.listdir()을 사용할 수도 있고, os.scandir(), 혹은 pathlib.Path()를 사용할 수도 있다.

os.listdir()는 해당 디렉토리 내의 파일, 폴더 명까지 모두 리스트형태로 반환한다.

os.scandir()는 해당 디렉토리의 모든 항목을 Iterator 형태로 포인팅한다. 해당 디렉토리의 각 항목의 name 속성을 통해 파일명에 접근할 수 있다. for entry.name in os.scandir(dir):

pathlib.Path() 역시 os.scandir()과 마찬가지로 iterator 형태로 해당 디렉토리에 대한 정보를 포함하여 반환한다.

파일 유형이나 파일 속성 정보까지 필요할 땐 os.scandir() 이나 pathlib.Path()를 사용하는 것이 좋고, 디렉토리 목록만 필요할 경우 os.listdir()를 사용해도 괜찮을것 같다.

def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
        # directory 내의 폴더 명(=>클래스 명) 추출.
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

        # 클래스명과 라벨을 이어서 딕셔너리 형태로 만들어줌 {"클래스명" : 라벨, ..} 
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

 

make_dataset

def make_dataset(
        directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Union[str, Tuple[str, ...]]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    See :class:`DatasetFolder` for details.

    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
    by default.
    """
        ## dir : "~/data"
        ## dir = os.path.expanduser(dir) 
        ## dir : "User/{user}/data" # ~ 기호를 홈 디렉토리 경로로 확장해준다.
        ## 즉 절대 경로를 얻으려고 사용. `os.getcwd()`를 통해 절대경로를 얻을수도 있다.
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:  # None이 아닌 경우를 말하는 건가?
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        # target_class -> class_index
                # ex) 'cat' -> 2
                class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    item = path, class_index
                    instances.append(item)

                    if target_class not in available_classes:
                        available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances

'인공지능 > 구현' 카테고리의 다른 글

PyTorch Dataset (1) - Dataset  (0) 2024.01.22