Cartpole(역진자운동)
Deep Q Network의 원논문 "Playing Atari with Reinforcement Learning" 중 Atari는 이것입니다.
어렸을적 TV에 연결해서 하던 게임셋을 Atari라고 하는데요, 여기에 존재하던 많은 게임들을 Reinforcement Learning을 적용하여 해결해보자라는 것입니다!
(게임이 보상이 확실하기에 강화학습을 적용하기에 아주좋죠!)
이중, Cartpole이란 게임은 좌/우 방향키만 움직여서 얼마나 오랫동안 버틸수있을까?를 경쟁하는 게임입니다 :)
이 게임에 DQN을 적용하여 확인해보고자합니다.
Cartpole의 현재 state는 4개요소로 정의합니다.
DQN 하이퍼파라미터 세팅
DQN의 하이퍼파라미터로 EPISODE는 한게임을 처음부터 끝날때까지 진행하는것을 의미합니다.
즉, 100번 Cartpole게임을 반복하는것입니다.
EPS는 Epsilon으로 E-Greedy 탐색법에서 0.9부터 0.05까지 EPISODE를 반복할수록 낮추어가며 진행한다는 의미입니다.
Learning Rate는 0.001이며, mini batch를 통해 진행하고자 batch_size를 32로 둡니다.
DQN class 설정
DQN-Agent 클래스를 선언하고,
현재의 Cartpole State를 의미하는 4가지 요소를 Input으로 받고 2개의 Output(좌/우)을 내는 신경망을 세팅합니다.
Mini batch를 적용하기위한 memory를 설정하는 부분이 self.memory=deque입니다.
Deque는 Double Ended Queue로, 데이터 저장공간이 들어오는문, 나가는문이 모두 열려있는것을 나타냅니다.
(즉, memory size를 10000으로 두고 9999개에서 4개가 더들어오는순간, deque는 10000개가되고 가장 처음들어왔던 3개는 사라집니다.)
DQN-Agent의 행동 함수부분으로,
E-greedy 탐색법을 위한 epsilon을 바꾸어가고 스텝카운트를 하며,
epsilon보다 확률이 낮은 경우, random하게 action을 선택하고,
epsilon보다 확률이 높은 경우, Q-value를 Max로하는 action을 선택
batch Size(32)보다 작은경우에는, 학습 미진행(bias 생길수도 있으므로)
-DQN 개념설명글 참고
https://limitsinx.tistory.com/154
학습은expected_q와 current_q의 MSE(Mean Square Error)로 진행하며, 벨만방정식에 의거하여 Q-Learning기반 해석에서 나온 수식입니다.
Cartpole게임 환경설정
OpenAI에서 제공하는 gym 라이브러리에서 Cartpole을 불러와서 학습을 위한 환경을 세팅합니다.
DQN-Cartpole 메인
Episode 100회를 반복하며, 균형을 맞추었을 경우 보상(reward)를 계속 +1씩주며, 균형을 맞추지 못한순간 보상을 -1을 주는것으로 전체 함수가 돌아갑니다.
결과
EPISODE를 100번 진행하며 점점 Cartpole이 잘되는것을 확인할 수 있습니다.
끝내는 200번의 Cartpole까지 다가가는데요
아주잘 동작하는 것을 확인할 수 있습니다.
마침 Auto Encoder를 공부하고있어서, 오토인코더처럼 좀 복잡하게 신경망을 구성하여 DQN을 돌려보았습니다.
요렇게 신경망을 조금 복잡하게 짜서 돌려보았더니.. 간단한 신경망대비 200(최대Score)까지는 빨리 도달하지만, 시스템이 안정화가 되지못해서 등락폭이 크게 유지되는것을 확인할 수 있었습니다.
즉, 신경망을 마냥 복잡하게 짠다고 성능이 향상되는것은 아니다라는것을 확인할 수 있었습니다 :)
Code
import gym
import random
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import matplotlib.pyplot as plt
#Hyper Parameter
EPISODES = 100
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
GAMMA = 0.8
LR = 0.001
BATCH_SIZE = 32
class DQNAgent:
def __init__(self):
self.model = nn.Sequential(
nn.Linear(4, 128),
nn.ReLU(),
nn.Linear(128,256),
nn.ReLU(),
nn.Linear(256,512),
nn.ReLU(),
nn.Linear(512,256),
nn.ReLU(),
nn.Linear(256,128),
nn.ReLU(),
nn.Linear(128, 2)
)
self.optimizer = optim.Adam(self.model.parameters(), LR)
self.steps_done = 0
self.memory = deque(maxlen=10000)
def memorize(self, state, action, reward, next_state):
self.memory.append((state,
action,
torch.FloatTensor([reward]),
torch.FloatTensor([next_state])))
def act(self, state):
eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * self.steps_done / EPS_DECAY)
self.steps_done += 1
if random.random() > eps_threshold:
return self.model(state).data.max(1)[1].view(1, 1)
else:
return torch.LongTensor([[random.randrange(2)]])
def learn(self):
if len(self.memory) < BATCH_SIZE:
return
batch = random.sample(self.memory, BATCH_SIZE)
states, actions, rewards, next_states = zip(*batch)
states = torch.cat(states)
actions = torch.cat(actions)
rewards = torch.cat(rewards)
next_states = torch.cat(next_states)
current_q = self.model(states).gather(1, actions)
max_next_q = self.model(next_states).detach().max(1)[0]
expected_q = rewards + (GAMMA * max_next_q)
loss = F.mse_loss(current_q.squeeze(), expected_q)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
env = gym.make('CartPole-v0')
agent = DQNAgent()
score_history = []
for e in range(1, EPISODES+1):
state = env.reset()
steps = 0
while True:
state = torch.FloatTensor([state])
action = agent.act(state)
next_state, reward, done, _ = env.step(action.item())
if done:
reward = -1
agent.memorize(state, action, reward, next_state)
agent.learn()
state = next_state
steps += 1
if done:
print("Eposide:{0} Score: {1}".format(e, steps))
score_history.append(steps)
break
plt.plot(score_history)
plt.ylabel('score')
plt.show()
[참고]
3분딥러닝 파이토치맛
https://github.com/yellowjs0304/3-min-pytorch_study
https://yjs-program.tistory.com/173
'AI > Reinforcement Learning' 카테고리의 다른 글
[Reinforcement Learning-5] Deep Q-Network으로 최적경로 찾기 (2) | 2021.09.05 |
---|---|
[Reinforcement Learning-4] Deep Q-Network(DQN)에 대한 간단한 이해 (0) | 2021.08.29 |
[Reinforcement Learning-3] Q-Learning으로 최적경로 찾기 (0) | 2021.08.28 |
[Reinforcement Learning-2] Q-Learning에 대한 간단한 이해 (0) | 2021.08.26 |
[Reinforcement Learning-1] Thompson sampling model (2) | 2021.08.26 |
댓글