기계는 거짓말하지 않는다

Pytorch Dataset, DataLoader 본문

AI

Pytorch Dataset, DataLoader

KillinTime 2021. 9. 5. 15:06

PyTorch는 torch.utils.data.Datasettorch.utils.data.DataLoader의 두 가지 데이터셋 라이브러리를 제공하며

미리 준비된(pre-loaded) 데이터셋 뿐 아니라 가지고 있는 데이터를 사용할 수 있다.

Dataset은 data와 label을 저장한다.

DataLoader는 Dataset에 쉽게 접근할 수 있도록 iterable 객체로 만들어 주고

sampler, shuffle, batch_size 등 다양한 매개변수를 설정 할 수 있다.

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class CustomDataset(Dataset):
    def __init__(self, data, transforms=None):
        self.x = [i[0] for i in data]
        self.y = [i[1] for i in data]

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        x = self.x[idx]
        y = self.y[idx]

        return x, y

data = [[2, 0], [4, 0], [6, 0], [8, 1], [10, 1], [12, 11], [15, 9]]

train_dataset = CustomDataset(data, transforms=None) # tansfroms 이미지 증폭 등

# x, y 각각 tensor로 shuffle=True 이므로 batch_size 만큼 리스트로 섞어서 출력.
# batch_size 마지막에 남는건 그대로 출력(drop_last=False)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, drop_last=False)

for x, y in train_loader:
    print(x, y)

 

'AI' 카테고리의 다른 글

Pytorch Multiclass Classification  (0) 2021.10.11
Pytorch torchvision transforms  (0) 2021.09.05
Pytorch no_grad, eval  (0) 2021.08.29
단층 퍼셉트론(Single-Layer Perceptron) AND, OR, NAND  (0) 2021.08.19
Classification 성능평가  (0) 2021.08.18
Comments