修改dqn

This commit is contained in:
weixin_46229132 2025-03-19 01:04:03 +08:00
parent f19e8fbdbf
commit 2362de4c54
3 changed files with 8 additions and 6 deletions

View File

@ -69,11 +69,11 @@ class DQN_agent(object):
else: else:
if state[0][0] == 0: if state[0][0] == 0:
q_value = self.q_net(state) q_value = self.q_net(state)
q_value[:10] = - float('inf') q_value[10:] = - float('inf')
a = q_value.argmax().item() a = q_value.argmax().item()
else: else:
q_value = self.q_net(state) q_value = self.q_net(state)
q_value[10:] = - float('inf') q_value[:10] = - float('inf')
a = q_value.argmax().item() a = q_value.argmax().item()
return a return a

View File

@ -111,7 +111,7 @@ def main():
print('EnvName:', BriefEnvName[opt.EnvIdex], print('EnvName:', BriefEnvName[opt.EnvIdex],
'seed:', opt.seed, 'score:', score) 'seed:', opt.seed, 'score:', score)
else: else:
total_steps = 0 total_steps = 1
while total_steps < opt.Max_train_steps: while total_steps < opt.Max_train_steps:
# Do not use opt.seed directly, or it can overfit to opt.seed # Do not use opt.seed directly, or it can overfit to opt.seed
s = env.reset(seed=env_seed) s = env.reset(seed=env_seed)
@ -122,6 +122,7 @@ def main():
while not done: while not done:
# e-greedy exploration # e-greedy exploration
if total_steps < opt.random_steps: if total_steps < opt.random_steps:
# TODO sample取值有问题
a = env.action_space.sample() a = env.action_space.sample()
else: else:
a = agent.select_action(s, deterministic=False) a = agent.select_action(s, deterministic=False)

View File

@ -1,11 +1,12 @@
from env import PartitionMazeEnv # from env import PartitionMazeEnv
from env_dis import PartitionMazeEnv
env = PartitionMazeEnv() env = PartitionMazeEnv()
state = env.reset() state = env.reset()
print(state) print(state)
action_series = [[0], [0.5], [0], [0.2], [0.4], [0.7], [0.3], [0.8], [0.5], [0.1], [0.7], [0.7], [0.9], [0.9], [0.1], [0.9], [0.9], [0.1]] action_series = [0, 0, 3, 0, 0, 10]
for i in range(100): for i in range(100):
action = action_series[i] action = action_series[i]