[PyTorch Walkthrough 1] Setting up PyTorch on Google Colab https://limitsinx.tistory.com/136
[PyTorch Walkthrough 2] Creating Tensors and Backward https://limitsinx.tistory.com/137
[PyTorch Walkthrough 3] Implementing Linear Regression with Gradient Descent https://limitsinx.tistory.com/138
[PyTorch Walkthrough 4] Implementing an Artificial Neural Network (ANN) https://limitsinx.tistory.com/139
[PyTorch Walkthrough 5] Implementing a Convolutional Neural Network (CNN) https://limitsinx.tistory.com/140
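※ Code and syntax covered in the earlier posts are not explained again here, so please refer to those posts as needed.
※ This post is best viewed on a PC.
I am studying PyTorch because I want to work on GAN-style deep learning.
In this post I implement Neural Style Transfer, which is arguably a starting point for that line of work.
The topic is large, so I will cover the details piece by piece in later posts, together with an explanation of the ResNet-50 paper (a CNN architecture).
This post contains only the code and the result.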
Code
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils as utils
import torch.utils.data as data
import torchvision.models as models
import torchvision.utils as v_utils
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
%matplotlib inline
content_layer_num = 1
image_size = 512
epoch = 5000
content_dir = "enter the path to the content image saved on your PC"
style_dir = "enter the path to the style image saved on your PC"
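def image_preprocess(img_dir):
    # Convert to RGB so that RGBA or grayscale files still yield 3 channels
    img = Image.open(img_dir).convert("RGB")
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # Only a per-channel mean is subtracted here (std = 1).
        # Note: torchvision's pretrained models were trained with
        # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] in RGB order.
        transforms.Normalize(mean=[0.40760392, 0.45795686, 0.48501961],
                             std=[1,1,1]),
        #transforms.Normalize([0.5], [0.5])
    ])
    # Add a batch dimension: (3, H, W) -> (1, 3, H, W)
    img = transform(img).view((-1,3,image_size,image_size))
    return img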
def image_postprocess(tensor):
    # Undo the mean subtraction done in image_preprocess
    transform = transforms.Normalize(mean=[-0.40760392, -0.45795686, -0.48501961],
                                     std=[1,1,1])
    #transform = transforms.Normalize([0.5], [0.5])
    img = transform(tensor.clone())
    img = img.clamp(0,1)
    # (C, H, W) -> (H, W, C), the layout matplotlib expects
    img = torch.transpose(img,0,1)
    img = torch.transpose(img,1,2)
    return img
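# Newer torchvision versions prefer the `weights=` argument over `pretrained=True`
resnet = models.resnet50(pretrained=True)
# List the top-level blocks: conv1, bn1, relu, maxpool, layer1..layer4, avgpool, fc
for name,module in resnet.named_children():
    print(name)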
class Resnet(nn.Module):
    def __init__(self):
        super(Resnet,self).__init__()
        # Split the pretrained network into stages so each stage's output can be returned
        self.layer0 = nn.Sequential(*list(resnet.children())[0:1])  # conv1
        self.layer1 = nn.Sequential(*list(resnet.children())[1:4])  # bn1, relu, maxpool
        self.layer2 = nn.Sequential(*list(resnet.children())[4:5])  # layer1
        self.layer3 = nn.Sequential(*list(resnet.children())[5:6])  # layer2
        self.layer4 = nn.Sequential(*list(resnet.children())[6:7])  # layer3
        self.layer5 = nn.Sequential(*list(resnet.children())[7:8])  # layer4
    def forward(self,x):
        out_0 = self.layer0(x)
        out_1 = self.layer1(out_0)
        out_2 = self.layer2(out_1)
        out_3 = self.layer3(out_2)
        out_4 = self.layer4(out_3)
        out_5 = self.layer5(out_4)
        return out_0, out_1, out_2, out_3, out_4, out_5
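class GramMatrix(nn.Module):
    def forward(self, input):
        b,c,h,w = input.size()
        # Flatten the spatial dimensions: (b, c, h, w) -> (b, c, h*w)
        F = input.view(b, c, h*w)
        # Channel-by-channel inner products: (b, c, c)
        G = torch.bmm(F, F.transpose(1,2))
        return G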
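device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# eval() makes BatchNorm use its stored running statistics instead of updating them
resnet = Resnet().to(device).eval()
# The network is a fixed feature extractor; only the image is optimized
for param in resnet.parameters():
    param.requires_grad = False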
class GramMSELoss(nn.Module):
    def forward(self, input, target):
        # target is already a Gram matrix; input is a raw feature map
        out = nn.MSELoss()(GramMatrix()(input), target)
        return out
content = image_preprocess(content_dir).to(device)
style = image_preprocess(style_dir).to(device)
# Start from a copy of the content image; this is the tensor being optimized
generated = content.clone().requires_grad_()
print(content.requires_grad,style.requires_grad,generated.requires_grad)
plt.imshow(image_postprocess(content[0].cpu()))
plt.show()
plt.imshow(image_postprocess(style[0].cpu()))
plt.show()
gen_img = image_postprocess(generated[0].cpu()).detach().numpy()
plt.imshow(gen_img)
plt.show()
style_target = list(GramMatrix().to(device)(i) for i in resnet(style))
content_target = resnet(content)[content_layer_num]
style_weight = [1/n**2 for n in [64,64,256,512,1024,2048]]
optimizer = optim.LBFGS([generated])
# A one-element list so the closure can update the counter
iteration = [0]
while iteration[0] < epoch:
    # LBFGS may evaluate the closure several times per step
    def closure():
        optimizer.zero_grad()
        out = resnet(generated)
        style_loss = [GramMSELoss().to(device)(out[i],style_target[i])*style_weight[i] for i in range(len(style_target))]
        content_loss = nn.MSELoss().to(device)(out[content_layer_num],content_target)
        total_loss = 1000 * sum(style_loss) + torch.sum(content_loss)
        total_loss.backward()
        if iteration[0] % 100 == 0:
            print(total_loss.item())
        iteration[0] += 1
        return total_loss
    optimizer.step(closure)
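gen_img = image_postprocess(generated[0].cpu()).detach().numpy()
plt.figure(figsize=(10,10))
plt.imshow(gen_img)
# Save before plt.show(); calling savefig afterwards can write an empty figure
plt.savefig('nvh_gen.png')
plt.show()
# gen_img is a NumPy array and has no .save(); plt.imsave can write it directly
#plt.imsave("drive/MyDrive/ST.jpg", gen_img)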
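The style loss in the code above compares Gram matrices of feature maps rather than the feature maps themselves. As a minimal sketch of what that matrix holds, the snippet below applies the same computation as `GramMatrix` to a tiny hand-made tensor. The function name `gram_matrix` and the example values are illustrative assumptions, not part of the original code.

```python
import torch

def gram_matrix(features: torch.Tensor) -> torch.Tensor:
    """Same computation as GramMatrix.forward above, written as a plain function."""
    b, c, h, w = features.size()
    flat = features.view(b, c, h * w)             # (b, c, h*w)
    return torch.bmm(flat, flat.transpose(1, 2))  # (b, c, c)

# Hypothetical feature map: batch of 1, 2 channels, 2x2 spatial positions
features = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]],
                          [[0.0, 1.0], [0.0, 1.0]]]])
gram = gram_matrix(features)
print(gram.shape)
print(gram)

# Shuffle the spatial positions identically in every channel
perm = torch.tensor([3, 0, 2, 1])
shuffled = features.view(1, 2, 4)[:, :, perm].view(1, 2, 2, 2)
gram_shuffled = gram_matrix(shuffled)
print(torch.allclose(gram, gram_shuffled))
```

Entry (i, j) is the inner product of channel i and channel j taken over all spatial positions. Shuffling the positions leaves the matrix unchanged, which is one way to see why it describes texture-like statistics rather than spatial layout.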
Results
I used Van Gogh's "The Starry Night" as the style image and applied it to a photo of my face,
producing an image of my face that looks as if Van Gogh had painted it.
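To keep a result like this as an image file at its own pixel size, rather than as a matplotlib figure, one option is to convert the array to 8-bit and save it with PIL. The sketch below assumes an H x W x 3 float array with values in [0, 1], which is the layout `image_postprocess` returns. The helper name `save_array_as_image` and the demo array are hypothetical.

```python
import os
import tempfile

import numpy as np
from PIL import Image

def save_array_as_image(array: np.ndarray, path: str) -> None:
    """Hypothetical helper: write an H x W x 3 float array in [0, 1] to an image file."""
    clipped = np.clip(array, 0.0, 1.0)
    as_uint8 = (clipped * 255).round().astype(np.uint8)
    Image.fromarray(as_uint8).save(path)

# Demo data standing in for gen_img: a 4 x 6 pure-red image
demo = np.zeros((4, 6, 3), dtype=np.float32)
demo[..., 0] = 1.0

out_path = os.path.join(tempfile.mkdtemp(), "demo_output.png")
save_array_as_image(demo, out_path)

reloaded = Image.open(out_path)
print(reloaded.size, reloaded.mode)  # PIL reports size as (width, height)
```

With the code from this post, the call would be `save_array_as_image(gen_img, "nvh_gen.png")`, assuming `gen_img` is the NumPy array computed just before plotting.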