기계는 거짓말하지 않는다

PyTorch Windows 환경 DataLoader num_workers 본문

AI

PyTorch Windows 환경 DataLoader num_workers

KillinTime 2025. 5. 22. 19:19

Windows 환경에서 multiprocessing(DataLoader의 num_workers>0)이 내부적으로 프로세스를

spawn(push)할 때 발생 할 수 있는 오류의 예이다.

EOFError: Ran out of input
...
ForkingPickler(file, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <function <lambda> at 0x0000029A93FCF820>: attribute lookup <lambda> on __main__ failed

익명 함수(lambda)나 로컬 함수 등 피클로 직렬화할 수 없는 객체를 읽어 들이려 하기 때문이다

num_workers=0으로 설정하면 해결은 되지만 여러 workers를 사용하여 데이터를 읽으려 할 때는 이렇게 할 수 없다.

아래는 Windows 환경에서 DataLoader의 num_workers를 2 이상으로 사용하는 방법이다.

if __name__ == "__main__" 적용

Windows는 자식 프로세스를 spawn 방식으로 실행하기 때문에, 모듈 최상단의 실행문이 그대로 다시 실행된다.

이로 인해 무한 재귀나 피클링 오류가 발생한다.

따라서 모든 초기화 코드는 if __name__ == "__main__" 블록 안에서 수행한다.

이 구조는 자식 프로세스가 main()을 재실행하지 않는다.

# import, 다른 코드 등 ...
def main():
    dataset = CustomDataset(...)
    loader  = DataLoader(
        dataset,
        batch_size=32,
        num_workers=4,
        collate_fn=my_collate
    )
    for data in loader:
        # 학습 또는 추론
        pass

if __name__ == "__main__":
    main()

람다(lambda) 제거 및 전역 함수/클래스 사용

collate_fn이나 worker_init_fn 등에 람다나 로컬 함수를 넘기면 spawn 모드에서 피클링이 불가능하다.

반드시 모듈 최상단에 정의된 전역 함수나 callable 클래스를 사용해야 한다.

lambda를 사용한 함수, 변수 등을 변경해야한다.

# import, 다른 코드 등 ...
class CustomCollator:
    def __init__(self, dataset):
        self.dataset = dataset

    def __call__(self, batch):
        # custom_collate_fn는 임의로 만든 collate_fn
        return custom_collate_fn(
            batch,
            self.dataset,
        )

def main(loader):
    for data in loader:
        # 학습 또는 추론
        pass

if __name__ == "__main__":
    dataset = CustomDataset(...)
    collator = CustomCollator(dataset)
    loader = DataLoader(
        train_dataset,
        batch_size=32,
        num_workers=4,
        collate_fn=collator,
    )

    main(loader)
Comments