修改dqn
This commit is contained in:
parent
f19e8fbdbf
commit
2362de4c54
@ -69,11 +69,11 @@ class DQN_agent(object):
|
|||||||
else:
|
else:
|
||||||
if state[0][0] == 0:
|
if state[0][0] == 0:
|
||||||
q_value = self.q_net(state)
|
q_value = self.q_net(state)
|
||||||
q_value[:10] = - float('inf')
|
q_value[10:] = - float('inf')
|
||||||
a = q_value.argmax().item()
|
a = q_value.argmax().item()
|
||||||
else:
|
else:
|
||||||
q_value = self.q_net(state)
|
q_value = self.q_net(state)
|
||||||
q_value[10:] = - float('inf')
|
q_value[:10] = - float('inf')
|
||||||
a = q_value.argmax().item()
|
a = q_value.argmax().item()
|
||||||
return a
|
return a
|
||||||
|
|
||||||
|
@ -111,7 +111,7 @@ def main():
|
|||||||
print('EnvName:', BriefEnvName[opt.EnvIdex],
|
print('EnvName:', BriefEnvName[opt.EnvIdex],
|
||||||
'seed:', opt.seed, 'score:', score)
|
'seed:', opt.seed, 'score:', score)
|
||||||
else:
|
else:
|
||||||
total_steps = 0
|
total_steps = 1
|
||||||
while total_steps < opt.Max_train_steps:
|
while total_steps < opt.Max_train_steps:
|
||||||
# Do not use opt.seed directly, or it can overfit to opt.seed
|
# Do not use opt.seed directly, or it can overfit to opt.seed
|
||||||
s = env.reset(seed=env_seed)
|
s = env.reset(seed=env_seed)
|
||||||
@ -122,6 +122,7 @@ def main():
|
|||||||
while not done:
|
while not done:
|
||||||
# e-greedy exploration
|
# e-greedy exploration
|
||||||
if total_steps < opt.random_steps:
|
if total_steps < opt.random_steps:
|
||||||
|
# TODO: the values returned by sample() are problematic
|
||||||
a = env.action_space.sample()
|
a = env.action_space.sample()
|
||||||
else:
|
else:
|
||||||
a = agent.select_action(s, deterministic=False)
|
a = agent.select_action(s, deterministic=False)
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
from env import PartitionMazeEnv
|
# from env import PartitionMazeEnv
|
||||||
|
from env_dis import PartitionMazeEnv
|
||||||
|
|
||||||
env = PartitionMazeEnv()
|
env = PartitionMazeEnv()
|
||||||
|
|
||||||
state = env.reset()
|
state = env.reset()
|
||||||
print(state)
|
print(state)
|
||||||
|
|
||||||
action_series = [[0], [0.5], [0], [0.2], [0.4], [0.7], [0.3], [0.8], [0.5], [0.1], [0.7], [0.7], [0.9], [0.9], [0.1], [0.9], [0.9], [0.1]]
|
action_series = [0, 0, 3, 0, 0, 10]
|
||||||
|
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
action = action_series[i]
|
action = action_series[i]
|
||||||
|
Loading…
Reference in New Issue
Block a user