59 lines
1.6 KiB
Python
59 lines
1.6 KiB
Python
![]() |
from RL_brain import DeepQNetwork
|
|||
|
import os
|
|||
|
import sys
|
|||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|||
|
from env import PartitionMazeEnv
|
|||
|
|
|||
|
def run_maze(n_episodes=200, learn_start=200, learn_every=5,
             environment=None, agent=None):
    """Play ``n_episodes`` episodes, training the agent, then destroy the environment.

    Every transition is stored in the agent via ``store_transition``. Once the
    step counter exceeds ``learn_start``, ``agent.learn()`` is called on every
    step whose counter is a multiple of ``learn_every``. All defaults reproduce
    the previously hard-coded schedule (200 episodes, learn after step 200,
    every 5 steps).

    Args:
        n_episodes: Number of episodes to play.
        learn_start: The step counter must be strictly greater than this
            before any learning happens (lets the replay memory fill first).
        learn_every: Learn only on steps whose counter is divisible by this.
        environment: Object exposing ``reset()``, ``render()``,
            ``step(action) -> (observation_, reward, done)`` and ``destroy()``.
            Defaults to the module-level ``env``.
        agent: Object exposing ``choose_action``, ``store_transition`` and
            ``learn``. Defaults to the module-level ``RL``.
    """
    # Fall back to the globals created in the ``__main__`` block, so the
    # zero-argument callback ``env.after(100, run_maze)`` keeps working.
    if environment is None:
        environment = env
    if agent is None:
        agent = RL

    # Global step counter across all episodes. It delays learning until some
    # transitions have been accumulated in memory. The increment sits after the
    # ``break``, so the terminal step of each episode is not counted
    # (unchanged from the original behaviour).
    step = 0
    try:
        for _ in range(n_episodes):
            # initial observation
            observation = environment.reset()

            while True:
                # refresh env
                environment.render()

                # RL chooses an action based on the current observation
                action = agent.choose_action(observation)

                # RL takes the action and gets the next observation and reward
                observation_, reward, done = environment.step(action)

                # store the transition in the agent's memory
                agent.store_transition(observation, action, reward, observation_)

                # Once past the warm-up, learn on every ``learn_every``-th step.
                if step > learn_start and step % learn_every == 0:
                    agent.learn()

                # swap observation
                observation = observation_

                # break the while loop at the end of this episode
                if done:
                    break
                step += 1

        # end of game
        print("game over")
    finally:
        # Tear the environment down even if training raised. Previously an
        # exception skipped this call. NOTE(review): presumably this closes a
        # GUI window so the caller's mainloop can return; confirm against
        # PartitionMazeEnv.
        environment.destroy()
|
|||
|
|
|||
|
|
|||
|
if __name__ == "__main__":
    # Build the maze environment. ``run_maze`` reads ``env`` and ``RL`` as
    # module-level globals, so both names must be bound here.
    env = PartitionMazeEnv()

    # TODO: this code is not finished yet and will not run!!!
    dqn_hparams = {
        "learning_rate": 0.01,
        "reward_decay": 0.9,
        "e_greedy": 0.9,
        "replace_target_iter": 200,
        "memory_size": 2000,
    }
    RL = DeepQNetwork(env.n_actions, env.n_features, **dqn_hparams)

    # Schedule the training loop on the environment, then enter its event loop.
    # NOTE(review): presumably a Tkinter-style after(ms, callback) / mainloop();
    # confirm against PartitionMazeEnv.
    env.after(100, run_maze)
    env.mainloop()

    # Once the event loop returns, plot the agent's recorded cost.
    RL.plot_cost()
|
|||
|
|
|||
|
|