기계는 거짓말하지 않는다

Pytorch torchvision transforms 본문

AI

Pytorch torchvision transforms

KillinTime 2021. 9. 5. 18:55

이미지 Augmentation을 할 때 사용할 수 있다.

Augmentation은 데이터의 양을 늘리기 위해 원본에 각종 변환을 적용하고 데이터를 늘릴 수 있다.

동일한 이미지들을 조금씩 변형시켜가며 학습하면서 Overfitting을 방지하는 데 도움이 된다.

학습 용도에 맞는 augmentation을 선택해서 사용하여야 한다.

보통 Training 단계에서 많이 사용되지만 Test 단계에서도 사용이 가능하며,

이를 Test Time-Augmentation(TTA)이라고 함

Dataloader를 이용하여 받으면 channel, height, width(c, h, w) shape가 되며

img.cpu().numpy().transpose(1, 2, 0).copy()를 이용하면 OpenCV에서 사용할 수 있는 (h, w, c) shape가 된다.

import os
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torchvision.utils import save_image
from PIL import Image

import cv2
import numpy as np

# image type
image_format = {".bmp", ".jpg", ".jpeg", ".png", ".tif", ".tiff", ".dom"}

# transforms show function
def transformsShow(name = "img",):
    def transforms_show(img):
        cv2.imshow(name, np.array(img))
        if cv2.waitKey(0) & 0xff == ord("q"):
            exit()
        return img
    return transforms_show

# 폴더에서 파일 가져옴
def file_search(folder_path):
    img_root = []

    for (path, dir, file) in os.walk(folder_path):
        # print(path, dir, file)
        for file_name in file:
            # image.jpeg -> .jpeg
            ext = os.path.splitext(file_name)[-1].lower()
            if ext in image_format:
                root = os.path.join(path, file_name)
                img_root.append(root)

    return img_root

img_path = "./data/"
data_path = file_search(img_path)

for path in data_path:
    print(path)

# custom dataset
class CustomDataset(Dataset):
    def __init__(self, path, transforms=None):
        self.path = path
        self.transfroms = transforms

    def __getitem__(self, item):
        path = self.path[item]
        img = cv2.imread(path)
        
        if self.transfroms is not None:
            img = self.transfroms(img)
        
        return img

    def __len__(self):
        return len(self.path)

# augmentation
my_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ColorJitter(brightness=0.5),
    transforms.Resize((255, 255)),
    transforms.CenterCrop((224, 224)),
    transforms.RandomHorizontalFlip(p=0.8),
    transforms.RandomGrayscale(p=0.4),
    transforms.RandomRotation(degrees=0.5),
    transforms.ToTensor()
])

dataset = CustomDataset(data_path, transforms=my_transforms)

img_num = 1
for _ in range(5):
    for img in dataset:
        save_image(img, "./data/img"+str(img_num) + ".png")
        img_num += 1

'AI' 카테고리의 다른 글

다층 퍼셉트론(Multi-Layer Perceptron) XOR  (0) 2021.10.18
Pytorch Multiclass Classification  (0) 2021.10.11
Pytorch Dataset, DataLoader  (0) 2021.09.05
Pytorch no_grad, eval  (0) 2021.08.29
단층 퍼셉트론(Single-Layer Perceptron) AND, OR, NAND  (0) 2021.08.19
Comments