Add an action mask for value evaluation (添加价值评估的mask)
parent 3dba6e4a53
commit c5023fb360
@@ -65,7 +65,7 @@ class DQN_agent(object):
             if state[0][0] == 0:
                 a = np.random.randint(0,10)
             else:
-                a = np.random.randint(10,14)
+                a = np.random.randint(10,15)
         else:
             if state[0][0] == 0:
                 q_value = self.q_net(state)
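For context, the change above fixes an off-by-one in the exploration branch: np.random.randint's upper bound is exclusive, so randint(10,14) could never sample action 14. A minimal sketch of the resulting state-conditional random action, assuming a 15-action space split at index 10 by the phase flag in state[0][0]; the function name is illustrative, not from the repository:

import numpy as np

def sample_masked_random_action(state: np.ndarray) -> int:
    # Sketch only: state[0][0] is read as a phase flag, as in the diff above.
    if state[0][0] == 0:
        return np.random.randint(0, 10)   # phase-0 actions: 0..9
    # high is exclusive in np.random.randint, so 15 is needed to reach action 14
    return np.random.randint(10, 15)      # phase-1 actions: 10..14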
@@ -84,10 +84,17 @@ class DQN_agent(object):
         '''Compute the target Q value'''
         with torch.no_grad():
             if self.Double:
+                # TODO: if there are two phases, remember to update this branch as well
                 argmax_a = self.q_net(s_next).argmax(dim=1).unsqueeze(-1)
                 max_q_next = self.q_target(s_next).gather(1,argmax_a)
             else:
-                max_q_next = self.q_target(s_next).max(1)[0].unsqueeze(1)
+                max_q_next = self.q_target(s_next)
+                # Apply the action mask
+                if s_next[0][0] == 0:
+                    max_q_next[:, 10:] = -float('inf')
+                else:
+                    max_q_next[:, :10] = -float('inf')
+                max_q_next = max_q_next.max(1)[0].unsqueeze(1)
             target_Q = r + (~dw) * self.gamma * max_q_next #dw: die or win
 
         # Get current Q estimates
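This hunk is the mask the commit message refers to: before taking the max over next-state Q-values, actions that are invalid for the next state's phase are pushed to -inf so they can never become the bootstrap target. A self-contained sketch under the same assumptions (15 actions, split at index 10 by a phase flag), which also shows what the TODO asks for by masking the online net's argmax in the Double branch; the function name and signature are illustrative, not the repository's code:

import torch

def masked_target_q(q_net, q_target, s_next, r, dw, gamma, double=False):
    # Sketch only: assumes actions 0..9 are valid when the phase flag
    # s_next[:, 0] == 0 and actions 10..14 are valid otherwise.
    with torch.no_grad():
        # Per-sample boolean mask of INVALID actions for the next state's phase
        phase0 = (s_next[:, 0] == 0).unsqueeze(1)                   # (B, 1)
        idx = torch.arange(15, device=s_next.device).unsqueeze(0)   # (1, 15)
        invalid = torch.where(phase0, idx >= 10, idx < 10)          # (B, 15)

        if double:
            # Double DQN (the TODO above): mask the online net's argmax too,
            # so the bootstrap action is always a valid one
            q_online = q_net(s_next).masked_fill(invalid, float('-inf'))
            argmax_a = q_online.argmax(dim=1, keepdim=True)
            max_q_next = q_target(s_next).gather(1, argmax_a)
        else:
            q_next = q_target(s_next).masked_fill(invalid, float('-inf'))
            max_q_next = q_next.max(1)[0].unsqueeze(1)

        return r + (~dw) * gamma * max_q_next  # dw: die or win

Note that the committed code keys the mask on s_next[0][0], i.e. the phase of the first transition in the batch; the per-sample mask in the sketch is the safer variant if a minibatch can mix phases.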
@@ -50,9 +50,9 @@ parser.add_argument('--exp_noise', type=float,
                     default=0.2, help='explore noise')
 parser.add_argument('--noise_decay', type=float, default=0.99,
                     help='decay rate of explore noise')
-parser.add_argument('--Double', type=str2bool, default=True,
+parser.add_argument('--Double', type=str2bool, default=False,
                     help='Whether to use Double Q-learning')
-parser.add_argument('--Duel', type=str2bool, default=True,
+parser.add_argument('--Duel', type=str2bool, default=False,
                     help='Whether to use Duel networks')
 opt = parser.parse_args()
 opt.dvc = torch.device(opt.dvc) # from str to torch.device
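The --Double and --Duel flags are parsed through str2bool, which this diff references but does not define. A typical implementation of such a helper is sketched below as an assumption about the repository (argparse's type=bool is a known trap, since bool('False') is True):

import argparse

def str2bool(v):
    # Common argparse boolean converter; the repository's helper may differ.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected')

With both defaults now False, the masked else branch above is the path actually exercised unless --Double True is passed explicitly, which is consistent with the TODO left in the Double branch.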
env.py
@@ -57,7 +57,7 @@ class PartitionMazeEnv(gym.Env):
         # Phase 1 state: region-visit state vector (length (CUT_NUM/2+1)^2)
         max_regions = (self.CUT_NUM // 2 + 1) ** 2
         self.observation_space = spaces.Box(
-            low=0.0, high=1.0, shape=(self.CUT_NUM + max_regions,), dtype=np.float32)
+            low=0.0, high=100.0, shape=(self.CUT_NUM + max_regions,), dtype=np.float32)
 
         # Variables for the cutting phase
         self.col_cuts = []  # stores vertical cut positions (c₁, c₂); a value of 0 means no cut
@@ -57,7 +57,7 @@ class PartitionMazeEnv(gym.Env):
         # Phase 1 state: vehicle position (2D)
         max_regions = (self.CUT_NUM // 2 + 1) ** 2
         self.observation_space = spaces.Box(
-            low=0.0, high=1.0, shape=(1 + self.CUT_NUM + max_regions,), dtype=np.float32)
+            low=0.0, high=100.0, shape=(1 + self.CUT_NUM + max_regions,), dtype=np.float32)
 
         # Variables for the cutting phase
         self.col_cuts = []  # stores vertical cut positions (c₁, c₂); a value of 0 means no cut
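Both env hunks widen only the upper bound of the observation space, from high=1.0 to high=100.0, keeping shape and dtype. A minimal sketch of what the new declaration admits, using the shape from the second hunk and an illustrative CUT_NUM (the "1 +" component is presumably the phase flag read as state[0][0] elsewhere in the diff; that reading is an assumption):

import numpy as np
from gym import spaces  # matches the gym.Env base class shown above

CUT_NUM = 4                              # illustrative value, not from the repo
max_regions = (CUT_NUM // 2 + 1) ** 2    # 9 for CUT_NUM = 4
observation_space = spaces.Box(
    low=0.0, high=100.0,
    shape=(1 + CUT_NUM + max_regions,),  # phase flag + cut positions + region states
    dtype=np.float32)

# The wider bound matters because Box.contains() rejects values outside
# [low, high]: any observation component that can grow past 1.0 (a count,
# a coordinate, a timer) would have violated the old high=1.0 declaration.
obs = np.full(1 + CUT_NUM + max_regions, 50.0, dtype=np.float32)
assert observation_space.contains(obs)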