
[PyTorch Tutorial 7] Implementing pix2pix

by 노마드공학자, July 27, 2021

[PyTorch Tutorial 1] Setting up PyTorch in Google Colab https://limitsinx.tistory.com/136

[PyTorch Tutorial 2] Creating Tensors and Backward https://limitsinx.tistory.com/137

[PyTorch Tutorial 3] Implementing Linear Regression with Gradient Descent https://limitsinx.tistory.com/138

[PyTorch Tutorial 4] Implementing an Artificial Neural Network (ANN) https://limitsinx.tistory.com/139

[PyTorch Tutorial 5] Implementing a Convolutional Neural Network (CNN) https://limitsinx.tistory.com/140

[PyTorch Tutorial 6] Implementing Neural Style Transfer https://limitsinx.tistory.com/141

 

※ Code and syntax already covered in previous posts are not re-explained here, so please refer back to them.

※ This post is best viewed on a PC.

 


 

What is pix2pix?

 

pix2pix starts from the same idea as a GAN. Whereas Neural Style Transfer blended the input and output images, pix2pix puts more weight on the target image, transforming the input so that it takes on the properties of the targeted image.

The drawback of pix2pix is that it requires "paired images" that are identical in every respect except the one property you want to change.

That is, if you want to convert black-and-white images to color, you need pairs of images that are identical except for the presence or absence of color.

(CycleGAN makes image translation possible even without paired images.)
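
To make this concrete: the pix2pix generator is trained with an adversarial loss plus a pixel-wise L1 loss against the paired target. Below is a minimal sketch of that objective; the names G, D, bce, and real_label are placeholders, and everything is defined properly in the full code that follows (where the L1 weight lambda_pixel is 100).

# Sketch of the pix2pix generator objective (placeholder names; see the
# full implementation below)
fake_b = G(a)                                # translate input a
adv_loss = bce(D(fake_b, a), real_label)     # fool the discriminator
l1_loss = torch.mean(torch.abs(fake_b - b))  # stay close to the paired target
g_loss = adv_loss + 100 * l1_loss            # lambda_pixel = 100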

 

I trained on black-and-white/color versions of van Gogh's "Starry Night", then fed in a black-and-white zebra photo

to generate a zebra image that carries the color palette of Starry Night.

 

Images used for training

 

 

Code

 

from os import listdir

from os.path import join

import random

import matplotlib.pyplot as plt

%matplotlib inline

 

import numpy as np

import os

import time

from PIL import Image

import torch

import torch.nn as nn

from torch.utils.data import Dataset

from torch.utils.data import DataLoader

import torchvision.transforms as transforms

from torchvision.transforms.functional import to_pil_image

 

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Image Data Download

#!git clone https://github.com/mrzhu-cool/pix2pix-pytorch 

#!mkdir 'data' 

#!unzip /content/pix2pix-pytorch/dataset/facades.zip  -d /content/data; 

 

class FacadeDataset(Dataset):

    def __init__(self, path2img, direction='b2a', transform=False):

        super().__init__()

        self.direction = direction

        self.path2a = join(path2img, 'a')

        self.path2b = join(path2img, 'b')

        self.img_filenames = [x for x in listdir(self.path2a)]

        self.transform = transform

 

    def __getitem__(self, index):

        a = Image.open(join(self.path2a, self.img_filenames[index])).convert('RGB')

        b = Image.open(join(self.path2b, self.img_filenames[index])).convert('RGB')

        

        if self.transform:

            a = self.transform(a)

            b = self.transform(b)

 

        if self.direction == 'b2a':

            return b,a

        else:

            return a,b

 

    def __len__(self):

        return len(self.img_filenames)

 

# Image Transform

transform = transforms.Compose([

                    transforms.ToTensor(),

                    transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5]),

                    transforms.Resize((256,256))

])
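
# Note: Normalize([0.5]*3, [0.5]*3) maps pixel values from [0,1] to [-1,1],
# matching the generator's Tanh output range; images are mapped back for
# display with 0.5*x + 0.5, as in the plotting code below.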

 

# Dataset

path2img = 'directory on your PC containing the training images'

path2img_test = 'directory on your PC containing the test images'

train_ds = FacadeDataset(path2img, transform = transform)

test_ds = FacadeDataset(path2img_test, transform = transform)

 

a,b = train_ds[0]

plt.figure(figsize=(10,10))

plt.subplot(1,2,1)

plt.imshow(to_pil_image(0.5*a+0.5))

plt.axis('off')

plt.subplot(1,2,2)

plt.imshow(to_pil_image(0.5*b+0.5))

plt.axis('off')

 

train_dl = DataLoader(train_ds, batch_size = 1, shuffle=False)

test_dl = DataLoader(test_ds, batch_size = 1, shuffle = False)

 

class UNetDown(nn.Module):

    def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0):

        super().__init__()

 

        layers = [nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1, bias=False)]

 

        if normalize:

            layers.append(nn.InstanceNorm2d(out_channels))

 

        layers.append(nn.LeakyReLU(0.2))

 

        if dropout:

            layers.append(nn.Dropout(dropout))

 

        self.down = nn.Sequential(*layers)

 

    def forward(self, x):

        x = self.down(x)

        return x

 

# check

x = torch.randn(16, 3, 256,256, device=device)

model = UNetDown(3,64).to(device)

down_out = model(x)

print(down_out.shape)

 

class UNetUp(nn.Module):

    def __init__(self, in_channels, out_channels, dropout=0.0):

        super().__init__()

 

        layers = [

            nn.ConvTranspose2d(in_channels, out_channels,4,2,1,bias=False),

            nn.InstanceNorm2d(out_channels),

            nn.LeakyReLU()

        ]

 

        if dropout:

            layers.append(nn.Dropout(dropout))

 

        self.up = nn.Sequential(*layers)

 

    def forward(self,x,skip):

        x = self.up(x)

        x = torch.cat((x,skip),1)

        return x

 

# Model Check

x = torch.randn(16, 128, 64, 64, device=device)

model = UNetUp(128,64).to(device)

out = model(x,down_out)

print(out.shape)

 

# Generator

class GeneratorUNet(nn.Module):

    def __init__(self, in_channels=3, out_channels=3):

        super().__init__()

 

        self.down1 = UNetDown(in_channels, 64, normalize=False)

        self.down2 = UNetDown(64,128)                 

        self.down3 = UNetDown(128,256)               

        self.down4 = UNetDown(256,512,dropout=0.5) 

        self.down5 = UNetDown(512,512,dropout=0.5)      

        self.down6 = UNetDown(512,512,dropout=0.5)             

        self.down7 = UNetDown(512,512,dropout=0.5)              

        self.down8 = UNetDown(512,512,normalize=False,dropout=0.5)

 

        self.up1 = UNetUp(512,512,dropout=0.5)

        self.up2 = UNetUp(1024,512,dropout=0.5)

        self.up3 = UNetUp(1024,512,dropout=0.5)

        self.up4 = UNetUp(1024,512,dropout=0.5)

        self.up5 = UNetUp(1024,256)

        self.up6 = UNetUp(512,128)

        self.up7 = UNetUp(256,64)

        self.up8 = nn.Sequential(

            nn.ConvTranspose2d(128,3,4,stride=2,padding=1),

            nn.Tanh()

        )
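
        # Note on channel counts: each UNetUp concatenates its upsampled
        # output with the matching encoder feature map (skip connection),
        # so up1's output is 512+512=1024 channels. That is why up2..up7
        # take doubled in_channels relative to the previous conv output.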

 

    def forward(self, x):

        d1 = self.down1(x)

        d2 = self.down2(d1)

        d3 = self.down3(d2)

        d4 = self.down4(d3)

        d5 = self.down5(d4)

        d6 = self.down6(d5)

        d7 = self.down7(d6)

        d8 = self.down8(d7)

 

        u1 = self.up1(d8,d7)

        u2 = self.up2(u1,d6)

        u3 = self.up3(u2,d5)

        u4 = self.up4(u3,d4)

        u5 = self.up5(u4,d3)

        u6 = self.up6(u5,d2)

        u7 = self.up7(u6,d1)

        u8 = self.up8(u7)

 

        return u8

 

# check

x = torch.randn(16,3,256,256,device=device)

model = GeneratorUNet().to(device)

out = model(x)

print(out.shape)

 

class Dis_block(nn.Module):

    def __init__(self, in_channels, out_channels, normalize=True):

        super().__init__()

 

        layers = [nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1)]

        if normalize:

            layers.append(nn.InstanceNorm2d(out_channels))

        layers.append(nn.LeakyReLU(0.2))

    

        self.block = nn.Sequential(*layers)

 

    def forward(self, x):

        x = self.block(x)

        return x

 

# check

x = torch.randn(16,64,128,128,device=device)

model = Dis_block(64,128).to(device)

out = model(x)

print(out.shape)

 

class Discriminator(nn.Module):

    def __init__(self, in_channels=3):

        super().__init__()

 

        self.stage_1 = Dis_block(in_channels*2,64,normalize=False)

        self.stage_2 = Dis_block(64,128)

        self.stage_3 = Dis_block(128,256)

        self.stage_4 = Dis_block(256,512)

 

        self.patch = nn.Conv2d(512,1,3,padding=1) # 16x16 patch

 

    def forward(self,a,b):

        x = torch.cat((a,b),1)

        x = self.stage_1(x)

        x = self.stage_2(x)

        x = self.stage_3(x)

        x = self.stage_4(x)

        x = self.patch(x)

        x = torch.sigmoid(x)

        return x
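
# Note: this is a conditional discriminator - it scores the pair (a, b)
# concatenated along the channel dimension (hence in_channels*2 above),
# so it judges whether b is a plausible translation of a, not just
# whether b looks realistic on its own.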

# check

x = torch.randn(16,3,256,256,device=device)

model = Discriminator().to(device)

out = model(x,x)

print(out.shape)

 

model_gen = GeneratorUNet().to(device)

model_dis = Discriminator().to(device)

 

# Weight initialization

def initialize_weights(model):

    class_name = model.__class__.__name__

    if class_name.find('Conv') != -1:

        nn.init.normal_(model.weight.data, 0.0, 0.02)



# apply weight initialization

model_gen.apply(initialize_weights);

model_dis.apply(initialize_weights);

 

# loss function

loss_func_gan = nn.BCELoss()

loss_func_pix = nn.L1Loss()

 

# weight of the pixel-wise L1 loss

lambda_pixel = 100

 

# PatchGAN output shape (per-patch real/fake labels)

patch = (1,256//2**4,256//2**4)
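
# Each of the discriminator's four stride-2 convolutions halves the spatial
# size: 256 -> 128 -> 64 -> 32 -> 16, so the output is a 16x16 grid of
# real/fake scores (one per image patch), hence the (1, 16, 16) label shape.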

 

# Optimizer parameters

from torch import optim

lr = 2e-4

beta1 = 0.5

beta2 = 0.999
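
# lr=2e-4 with beta1=0.5 follows the Adam settings used in the pix2pix paper;
# the default beta1=0.9 is known to destabilize GAN training.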

 

opt_dis = optim.Adam(model_dis.parameters(),lr=lr,betas=(beta1,beta2))

opt_gen = optim.Adam(model_gen.parameters(),lr=lr,betas=(beta1,beta2))

 

# model training

model_gen.train()

model_dis.train()

 

batch_count = 0

num_epochs = 100

start_time = time.time()

 

loss_hist = {'gen':[],

             'dis':[]}

 

for epoch in range(num_epochs):

    for a, b in train_dl:

        ba_si = a.size(0)

 

        # real image

        real_a = a.to(device)

        real_b = b.to(device)

 

        # patch label

        real_label = torch.ones(ba_si, *patch, requires_grad=False).to(device)

        fake_label = torch.zeros(ba_si, *patch, requires_grad=False).to(device)

 

        # generator

        model_gen.zero_grad()

 

        fake_b = model_gen(real_a) 

        out_dis = model_dis(fake_b, real_a) # condition on input a, matching the discriminator update below

 

        gen_loss = loss_func_gan(out_dis, real_label)

        pixel_loss = loss_func_pix(fake_b, real_b)

 

        g_loss = gen_loss + lambda_pixel * pixel_loss

        g_loss.backward()

        opt_gen.step()

 

        # discriminator

        model_dis.zero_grad()

 

        out_dis = model_dis(real_b, real_a) 

        real_loss = loss_func_gan(out_dis,real_label)

        

        out_dis = model_dis(fake_b.detach(), real_a) 

        fake_loss = loss_func_gan(out_dis,fake_label)

 

        d_loss = (real_loss + fake_loss) / 2.

        d_loss.backward()

        opt_dis.step()

 

        loss_hist['gen'].append(g_loss.item())

        loss_hist['dis'].append(d_loss.item())

 

        batch_count += 1

        if batch_count % 100 == 0:

            print('Epoch: %.0f, G_Loss: %.6f, D_Loss: %.6f, time: %.2f min' %(epoch, g_loss.item(), d_loss.item(), (time.time()-start_time)/60))

 

# loss history

plt.figure(figsize=(10,5))

plt.title('Loss Progress')

plt.plot(loss_hist['gen'], label='Gen. Loss')

plt.plot(loss_hist['dis'], label='Dis. Loss')

plt.xlabel('batch count')

plt.ylabel('Loss')

plt.legend()

plt.show()

 

# Save weights

path2models = './models/'

os.makedirs(path2models, exist_ok=True)

path2weights_gen = os.path.join(path2models, 'weights_gen.pt')

path2weights_dis = os.path.join(path2models, 'weights_dis.pt')

 

torch.save(model_gen.state_dict(), path2weights_gen)

torch.save(model_dis.state_dict(), path2weights_dis)

 

# Load weights

weights = torch.load(path2weights_gen)

model_gen.load_state_dict(weights)

 

# evaluation mode

model_gen.eval()

 

# Image Generating

with torch.no_grad():

    for a, b in test_dl:

        fake_imgs = model_gen(a.to(device)).detach().cpu()

        #plt.savefig('fake_to_real.jpg')

        #real_imgs = b

        real_imgs = a

        break

 

# Generating Fake to real image

plt.figure(figsize=(15,15))

 

# batch_size is 1, so range(0,2,2) visits only ii=0 (a single real/fake pair)
for ii in range(0,2,2):

    plt.subplot(1,2,ii+1)

    plt.imshow(to_pil_image(0.5*real_imgs[ii]+0.5))

    plt.axis('off')

    plt.subplot(1,2,ii+2)

    plt.imshow(to_pil_image(0.5*fake_imgs[ii]+0.5))

    plt.axis('off')

 

 

Results

Training images

The model was trained so that the black-and-white drawing takes on the coloring of the color version.

 

Result image

Even though the model was trained on just a single image, the result is quite meaningful.

If I had split dozens of van Gogh's paintings into black-and-white/color pairs and trained on all of them, I suspect the output would have been even more refined. :)

 
