본문 바로가기
DeepLearning Framework & Coding/Pytorch

[pytorch 따라하기-6] Neural Style Transfer 구현(이미지 합성)

by 노마드공학자 2021. 7. 26.

[pytorch 따라하기-1] 구글 Colab에 pytorch 세팅하기 https://limitsinx.tistory.com/136

[pytorch 따라하기-2] Tensor생성 및 Backward https://limitsinx.tistory.com/137 

[pytorch 따라하기-3] 경사하강법을 통한 선형회귀 구현 https://limitsinx.tistory.com/138

[pytorch 따라하기-4] 인공신경망(ANN) 구현 https://limitsinx.tistory.com/139

[pytorch 따라하기-5] 합성곱신경망(CNN) 구현 https://limitsinx.tistory.com/140

 

※이전 글에서 정리한 코드/문법은 다시 설명하지 않으므로, 위 글들을 참고 부탁드립니다.

※해당 글은 PC에서 보기에 최적화 되어있습니다.

 


제가 PyTorch를 공부하는 이유는 GAN 계열의 딥러닝을 하기 위해서인데요,

그 시초라고 할 수 있는 Neural Style Transfer를 구현해 보겠습니다.

내용이 워낙 방대하여, ResNet-50이라는 CNN 계열 모델의 논문 설명과 함께 차후 하나씩 정리해 나가도록 하겠습니다.

이번 글에서는 코드와 결과만 정리해 보겠습니다.

 

코드

import torch

import torch.nn as nn

import torch.optim as optim

import torch.utils as utils

import torch.utils.data as data

import torchvision.models as models

import torchvision.utils as v_utils

import torchvision.transforms as transforms

import matplotlib.pyplot as plt

import numpy as np

from PIL import Image

%matplotlib inline

 

# Index into the tuple returned by the feature extractor; selects the
# activation that is compared in the content loss.
content_layer_num = 1

# Side length (pixels) used for both the resize and the square center crop.
image_size = 512

# The optimisation loop runs until the closure has been evaluated this many
# times (the counter is incremented once per closure call).
epoch = 5000



# Paths passed to Image.open(); replace the placeholders with real image files.
content_dir = "본인 PC에 저장된 이미지 디렉토리 입력"

style_dir = "본인 PC에 저장된 이미지 디렉토리 입력"



def image_preprocess(img_dir):
    """Load an image file and return it as a mean-shifted 4-D tensor.

    The image is resized with ``image_size``, center cropped to an
    ``image_size`` x ``image_size`` square, converted to a tensor and shifted
    by a fixed per-channel mean (std is 1, so no rescaling happens).

    Args:
        img_dir: Path of the image file to open (a file path, despite the
            parameter name).

    Returns:
        Tensor of shape ``(1, 3, image_size, image_size)``.
    """
    # Force three channels. Grayscale, palette and RGBA files would otherwise
    # produce 1 or 4 channels and break the 3-channel Normalize below.
    # The context manager also closes the underlying file handle.
    with Image.open(img_dir) as raw:
        img = raw.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # image_postprocess() adds exactly these values back; keep both in sync.
        # NOTE(review): these are not the RGB ImageNet mean/std that torchvision
        # documents for its pretrained models -- confirm this is intended.
        transforms.Normalize(mean=[0.40760392, 0.45795686, 0.48501961],
                             std=[1, 1, 1]),
    ])

    # Add the leading batch dimension expected by the network.
    return transform(img).unsqueeze(0)



def image_postprocess(tensor):
    """Turn a preprocessed image tensor back into something ``imshow`` can draw.

    Adds the per-channel mean back (on a copy of ``tensor``), clamps values to
    [0, 1] and moves the first axis to the third position.

    Args:
        tensor: A single preprocessed image; the callers in this file pass
            ``batch[0]``, i.e. a channels-first tensor.

    Returns:
        The restored image with its axes reordered for plotting.
    """
    channel_mean = [0.40760392, 0.45795686, 0.48501961]
    # Normalising with the negated mean and unit std adds the mean back.
    undo_shift = transforms.Normalize(mean=[-m for m in channel_mean],
                                      std=[1, 1, 1])
    restored = undo_shift(tensor.clone()).clamp(0, 1)
    # Two swaps: axis 0 <-> 1, then axis 1 <-> 2 (channels-first -> channels-last).
    return restored.transpose(0, 1).transpose(1, 2)

 

# Pretrained ResNet-50 from torchvision; the Resnet wrapper class slices its
# top-level children into stages.
resnet = models.resnet50(pretrained=True)

# List the top-level child module names so the slice indices used by the
# wrapper can be checked by eye.
for child_name, _child in resnet.named_children():
    print(child_name)



class Resnet(nn.Module):
    """Feature extractor that returns six intermediate activations.

    The top-level children of a backbone network are grouped into six
    consecutive stages (children ``[0:1]``, ``[1:4]``, ``[4:5]``, ``[5:6]``,
    ``[6:7]`` and ``[7:8]``); ``forward`` runs them in order and returns the
    output of every stage.

    Args:
        backbone: Module whose ``children()`` are sliced into stages. When
            omitted, the module-level ``resnet`` is used, so ``Resnet()``
            behaves as before. Passing the backbone explicitly avoids
            depending on that global, which this script later rebinds to the
            wrapper instance itself.
    """

    def __init__(self, backbone=None):
        super().__init__()
        if backbone is None:
            backbone = resnet
        # Materialise the child list once instead of once per stage.
        children = list(backbone.children())
        self.layer0 = nn.Sequential(*children[0:1])
        self.layer1 = nn.Sequential(*children[1:4])
        self.layer2 = nn.Sequential(*children[4:5])
        self.layer3 = nn.Sequential(*children[5:6])
        self.layer4 = nn.Sequential(*children[6:7])
        self.layer5 = nn.Sequential(*children[7:8])

    def forward(self, x):
        """Return a 6-tuple with the output of each stage, in order."""
        stages = (self.layer0, self.layer1, self.layer2,
                  self.layer3, self.layer4, self.layer5)
        outputs = []
        for stage in stages:
            # Each stage consumes the previous stage's output.
            x = stage(x)
            outputs.append(x)
        return tuple(outputs)



class GramMatrix(nn.Module):
    """Compute the (unnormalised) Gram matrix of a batch of feature maps."""

    def forward(self, input):
        """Return the channel-by-channel inner products of ``input``.

        Args:
            input: Tensor of shape ``(b, c, h, w)``.

        Returns:
            Tensor of shape ``(b, c, c)`` where entry ``[n, i, j]`` is the dot
            product of the flattened channels ``i`` and ``j`` of sample ``n``.
        """
        # The signature previously read ``forward(selfinput)`` (comma lost),
        # which made ``input`` resolve to the builtin and broke every call.
        b, c, h, w = input.size()
        flat = input.view(b, c, h * w)
        return torch.bmm(flat, flat.transpose(1, 2))

 

# Use the first CUDA device when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Wrap the pretrained backbone and move it to the chosen device.
# eval() matters here: the backbone contains BatchNorm layers, and in the
# default training mode they normalise with per-batch statistics and update
# their running averages on every forward pass. The extracted features would
# then drift during optimisation instead of coming from fixed weights.
resnet = Resnet().to(device).eval()

# Only the generated image is optimised; the network weights stay frozen.
for param in resnet.parameters():
    param.requires_grad_(False)




class GramMSELoss(nn.Module):
    """Style loss: MSE between the Gram matrix of ``input`` and a target Gram."""

    def forward(self, input, target):
        """Return the mean squared error between Gram(input) and ``target``.

        Args:
            input: Feature maps whose Gram matrix is computed with GramMatrix.
            target: Precomputed Gram matrix to compare against.

        Returns:
            Scalar loss tensor.
        """
        # The signature previously read ``forward(selfinputtarget)`` (commas
        # lost), leaving ``target`` undefined and ``input`` bound to the builtin.
        return nn.MSELoss()(GramMatrix()(input), target)



# Load both source images as batched tensors on the compute device.
content = image_preprocess(content_dir).to(device)
style = image_preprocess(style_dir).to(device)

# The generated image starts as a copy of the content image and is the only
# tensor the optimiser updates.
generated = content.clone().requires_grad_().to(device)

print(content.requires_grad, style.requires_grad, generated.requires_grad)

# Preview the content image, then the style image.
for source in (content, style):
    plt.imshow(image_postprocess(source[0].cpu()))
    plt.show()

# Preview the starting point; .data drops the autograd link so numpy() works.
gen_img = image_postprocess(generated[0].cpu()).data.numpy()
plt.imshow(gen_img)
plt.show()



# Gram matrices of the style image at every extracted stage.
gram_fn = GramMatrix().to(device)
style_target = [gram_fn(feature) for feature in resnet(style)]

# Content features at the stage chosen by content_layer_num.
content_target = resnet(content)[content_layer_num]

# One weight per stage: the inverse square of the listed value for that stage.
style_weight = [1 / channels ** 2 for channels in (64, 64, 256, 512, 1024, 2048)]

# L-BFGS optimises the pixels of the generated image directly.
optimizer = optim.LBFGS([generated])

 

# One-element list so closure() can update the counter in place.
iteration = [0]

# Loss modules have no state, so build them once and reuse them.
gram_mse = GramMSELoss().to(device)
mse = nn.MSELoss().to(device)


def closure():
    """Evaluate the total loss for L-BFGS and backpropagate to ``generated``."""
    optimizer.zero_grad()
    features = resnet(generated)

    # Weighted Gram-matrix loss for every extracted stage.
    style_loss = [
        gram_mse(feature, target) * weight
        for feature, target, weight in zip(features, style_target, style_weight)
    ]
    content_loss = mse(features[content_layer_num], content_target)

    total_loss = 1000 * sum(style_loss) + torch.sum(content_loss)
    total_loss.backward()

    # Report on every 100th evaluation, counted before the increment.
    step = iteration[0]
    if step % 100 == 0:
        print(total_loss)
    iteration[0] = step + 1

    return total_loss


# Each step() may call closure() several times, so the counter can advance by
# more than one per loop pass.
while iteration[0] < epoch:
    optimizer.step(closure)

 

# Convert the optimised image to a NumPy array for plotting.
gen_img = image_postprocess(generated[0].cpu()).data.numpy()

plt.figure(figsize=(10, 10))
plt.imshow(gen_img)

# Save BEFORE show(): show() releases the current figure in notebooks and in
# blocking script backends, so saving afterwards writes an empty image.
plt.savefig('nvh_gen.png')
plt.show()

 

결과

 

고흐의 "별이 빛나는 밤"을 Style image로 하고, 제 얼굴 사진을 Content image로 적용하여

마치 고흐가 그린 듯한 제 얼굴 이미지를 만들어 보았습니다.

 

Style Image
Content Image
Generating Image

댓글