修改dqn
This commit is contained in:
parent
f19e8fbdbf
commit
2362de4c54
@ -69,11 +69,11 @@ class DQN_agent(object):
|
||||
else:
|
||||
if state[0][0] == 0:
|
||||
q_value = self.q_net(state)
|
||||
q_value[:10] = - float('inf')
|
||||
q_value[10:] = - float('inf')
|
||||
a = q_value.argmax().item()
|
||||
else:
|
||||
q_value = self.q_net(state)
|
||||
q_value[10:] = - float('inf')
|
||||
q_value[:10] = - float('inf')
|
||||
a = q_value.argmax().item()
|
||||
return a
|
||||
|
||||
|
@ -111,7 +111,7 @@ def main():
|
||||
print('EnvName:', BriefEnvName[opt.EnvIdex],
|
||||
'seed:', opt.seed, 'score:', score)
|
||||
else:
|
||||
total_steps = 0
|
||||
total_steps = 1
|
||||
while total_steps < opt.Max_train_steps:
|
||||
# Do not use opt.seed directly, or it can overfit to opt.seed
|
||||
s = env.reset(seed=env_seed)
|
||||
@ -122,6 +122,7 @@ def main():
|
||||
while not done:
|
||||
# e-greedy exploration
|
||||
if total_steps < opt.random_steps:
|
||||
# TODO sample取值有问题
|
||||
a = env.action_space.sample()
|
||||
else:
|
||||
a = agent.select_action(s, deterministic=False)
|
||||
|
@ -1,15 +1,16 @@
|
||||
from env import PartitionMazeEnv
|
||||
# from env import PartitionMazeEnv
|
||||
from env_dis import PartitionMazeEnv
|
||||
|
||||
env = PartitionMazeEnv()
|
||||
|
||||
state = env.reset()
|
||||
print(state)
|
||||
|
||||
action_series = [[0], [0.5], [0], [0.2], [0.4], [0.7], [0.3], [0.8], [0.5], [0.1], [0.7], [0.7], [0.9], [0.9], [0.1], [0.9], [0.9], [0.1]]
|
||||
action_series = [0, 0, 3, 0, 0, 10]
|
||||
|
||||
for i in range(100):
|
||||
action = action_series[i]
|
||||
state, reward, done, info, _ = env.step(action)
|
||||
print(state, reward, done, info)
|
||||
if done:
|
||||
break
|
||||
break
|
||||
|
Loading…
Reference in New Issue
Block a user