添加价值评估的mask

This commit is contained in:
weixin_46229132 2025-03-19 21:52:33 +08:00
parent 3dba6e4a53
commit c5023fb360
4 changed files with 13 additions and 6 deletions

View File

@@ -65,7 +65,7 @@ class DQN_agent(object):
if state[0][0] == 0: if state[0][0] == 0:
a = np.random.randint(0,10) a = np.random.randint(0,10)
else: else:
a = np.random.randint(10,14) a = np.random.randint(10,15)
else: else:
if state[0][0] == 0: if state[0][0] == 0:
q_value = self.q_net(state) q_value = self.q_net(state)
@@ -84,10 +84,17 @@ class DQN_agent(object):
'''Compute the target Q value''' '''Compute the target Q value'''
with torch.no_grad(): with torch.no_grad():
if self.Double: if self.Double:
# TODO 如果有两个过程,这里记得也要更新
argmax_a = self.q_net(s_next).argmax(dim=1).unsqueeze(-1) argmax_a = self.q_net(s_next).argmax(dim=1).unsqueeze(-1)
max_q_next = self.q_target(s_next).gather(1,argmax_a) max_q_next = self.q_target(s_next).gather(1,argmax_a)
else: else:
max_q_next = self.q_target(s_next).max(1)[0].unsqueeze(1) max_q_next = self.q_target(s_next)
# 添加动作掩码操作
if s_next[0][0] == 0:
max_q_next[:, 10:] = -float('inf')
else:
max_q_next[:, :10] = -float('inf')
max_q_next = max_q_next.max(1)[0].unsqueeze(1)
target_Q = r + (~dw) * self.gamma * max_q_next #dw: die or win target_Q = r + (~dw) * self.gamma * max_q_next #dw: die or win
# Get current Q estimates # Get current Q estimates

View File

@@ -50,9 +50,9 @@ parser.add_argument('--exp_noise', type=float,
default=0.2, help='explore noise') default=0.2, help='explore noise')
parser.add_argument('--noise_decay', type=float, default=0.99, parser.add_argument('--noise_decay', type=float, default=0.99,
help='decay rate of explore noise') help='decay rate of explore noise')
parser.add_argument('--Double', type=str2bool, default=True, parser.add_argument('--Double', type=str2bool, default=False,
help='Whether to use Double Q-learning') help='Whether to use Double Q-learning')
parser.add_argument('--Duel', type=str2bool, default=True, parser.add_argument('--Duel', type=str2bool, default=False,
help='Whether to use Duel networks') help='Whether to use Duel networks')
opt = parser.parse_args() opt = parser.parse_args()
opt.dvc = torch.device(opt.dvc) # from str to torch.device opt.dvc = torch.device(opt.dvc) # from str to torch.device

2
env.py
View File

@@ -57,7 +57,7 @@ class PartitionMazeEnv(gym.Env):
# 阶段 1 状态:区域访问状态向量(长度为(CUT_NUM/2+1)^2 # 阶段 1 状态:区域访问状态向量(长度为(CUT_NUM/2+1)^2
max_regions = (self.CUT_NUM // 2 + 1) ** 2 max_regions = (self.CUT_NUM // 2 + 1) ** 2
self.observation_space = spaces.Box( self.observation_space = spaces.Box(
low=0.0, high=1.0, shape=(self.CUT_NUM + max_regions,), dtype=np.float32) low=0.0, high=100.0, shape=(self.CUT_NUM + max_regions,), dtype=np.float32)
# 切分阶段相关变量 # 切分阶段相关变量
self.col_cuts = [] # 存储竖切位置c₁, c₂当值为0时表示不切 self.col_cuts = [] # 存储竖切位置c₁, c₂当值为0时表示不切

View File

@@ -57,7 +57,7 @@ class PartitionMazeEnv(gym.Env):
# 阶段 1 状态:车辆位置 (2D) # 阶段 1 状态:车辆位置 (2D)
max_regions = (self.CUT_NUM // 2 + 1) ** 2 max_regions = (self.CUT_NUM // 2 + 1) ** 2
self.observation_space = spaces.Box( self.observation_space = spaces.Box(
low=0.0, high=1.0, shape=(1 + self.CUT_NUM + max_regions,), dtype=np.float32) low=0.0, high=100.0, shape=(1 + self.CUT_NUM + max_regions,), dtype=np.float32)
# 切分阶段相关变量 # 切分阶段相关变量
self.col_cuts = [] # 存储竖切位置c₁, c₂当值为0时表示不切 self.col_cuts = [] # 存储竖切位置c₁, c₂当值为0时表示不切