添加价值评估的mask
This commit is contained in:
parent
3dba6e4a53
commit
c5023fb360
@ -65,7 +65,7 @@ class DQN_agent(object):
|
||||
if state[0][0] == 0:
|
||||
a = np.random.randint(0,10)
|
||||
else:
|
||||
a = np.random.randint(10,14)
|
||||
a = np.random.randint(10,15)
|
||||
else:
|
||||
if state[0][0] == 0:
|
||||
q_value = self.q_net(state)
|
||||
@ -84,10 +84,17 @@ class DQN_agent(object):
|
||||
'''Compute the target Q value'''
|
||||
with torch.no_grad():
|
||||
if self.Double:
|
||||
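# TODO: with two action phases, this Double-DQN branch must also apply the phase action mask used in the non-Double branch below (it currently takes argmax over all actions)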
|
||||
argmax_a = self.q_net(s_next).argmax(dim=1).unsqueeze(-1)
|
||||
max_q_next = self.q_target(s_next).gather(1,argmax_a)
|
||||
else:
|
||||
max_q_next = self.q_target(s_next).max(1)[0].unsqueeze(1)
|
||||
max_q_next = self.q_target(s_next)
|
||||
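# Apply the action mask before taking the max. NOTE(review): the phase flag is read from s_next[0][0] only, so every row in the batch gets the first sample's mask — confirm a batch never mixes phases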
|
||||
if s_next[0][0] == 0:
|
||||
max_q_next[:, 10:] = -float('inf')
|
||||
else:
|
||||
max_q_next[:, :10] = -float('inf')
|
||||
max_q_next = max_q_next.max(1)[0].unsqueeze(1)
|
||||
target_Q = r + (~dw) * self.gamma * max_q_next #dw: die or win
|
||||
|
||||
# Get current Q estimates
|
||||
|
@ -50,9 +50,9 @@ parser.add_argument('--exp_noise', type=float,
|
||||
default=0.2, help='explore noise')
|
||||
parser.add_argument('--noise_decay', type=float, default=0.99,
|
||||
help='decay rate of explore noise')
|
||||
parser.add_argument('--Double', type=str2bool, default=True,
|
||||
parser.add_argument('--Double', type=str2bool, default=False,
|
||||
help='Whether to use Double Q-learning')
|
||||
parser.add_argument('--Duel', type=str2bool, default=True,
|
||||
parser.add_argument('--Duel', type=str2bool, default=False,
|
||||
help='Whether to use Duel networks')
|
||||
opt = parser.parse_args()
|
||||
opt.dvc = torch.device(opt.dvc) # from str to torch.device
|
||||
|
2
env.py
2
env.py
@ -57,7 +57,7 @@ class PartitionMazeEnv(gym.Env):
|
||||
# 阶段 1 状态:区域访问状态向量(长度为(CUT_NUM/2+1)^2)
|
||||
max_regions = (self.CUT_NUM // 2 + 1) ** 2
|
||||
self.observation_space = spaces.Box(
|
||||
low=0.0, high=1.0, shape=(self.CUT_NUM + max_regions,), dtype=np.float32)
|
||||
low=0.0, high=100.0, shape=(self.CUT_NUM + max_regions,), dtype=np.float32)
|
||||
|
||||
# 切分阶段相关变量
|
||||
self.col_cuts = [] # 存储竖切位置(c₁, c₂),当值为0时表示不切
|
||||
|
@ -57,7 +57,7 @@ class PartitionMazeEnv(gym.Env):
|
||||
# 阶段 1 状态:车辆位置 (2D)
|
||||
max_regions = (self.CUT_NUM // 2 + 1) ** 2
|
||||
self.observation_space = spaces.Box(
|
||||
low=0.0, high=1.0, shape=(1 + self.CUT_NUM + max_regions,), dtype=np.float32)
|
||||
low=0.0, high=100.0, shape=(1 + self.CUT_NUM + max_regions,), dtype=np.float32)
|
||||
|
||||
# 切分阶段相关变量
|
||||
self.col_cuts = [] # 存储竖切位置(c₁, c₂),当值为0时表示不切
|
||||
|
Loading…
Reference in New Issue
Block a user