diff --git a/Duel_Double_DQN/DQN.py b/Duel_Double_DQN/DQN.py index 92773a8..df3313a 100644 --- a/Duel_Double_DQN/DQN.py +++ b/Duel_Double_DQN/DQN.py @@ -65,7 +65,7 @@ class DQN_agent(object): if state[0][0] == 0: a = np.random.randint(0,10) else: - a = np.random.randint(10,14) + a = np.random.randint(10,15) else: if state[0][0] == 0: q_value = self.q_net(state) @@ -84,10 +84,17 @@ class DQN_agent(object): '''Compute the target Q value''' with torch.no_grad(): if self.Double: + # TODO 如果有两个过程,这里记得也要更新 argmax_a = self.q_net(s_next).argmax(dim=1).unsqueeze(-1) max_q_next = self.q_target(s_next).gather(1,argmax_a) else: - max_q_next = self.q_target(s_next).max(1)[0].unsqueeze(1) + max_q_next = self.q_target(s_next) + # 添加动作掩码操作 + if s_next[0][0] == 0: + max_q_next[:, 10:] = -float('inf') + else: + max_q_next[:, :10] = -float('inf') + max_q_next = max_q_next.max(1)[0].unsqueeze(1) target_Q = r + (~dw) * self.gamma * max_q_next #dw: die or win # Get current Q estimates diff --git a/Duel_Double_DQN/main.py b/Duel_Double_DQN/main.py index e470bdf..693c9ba 100644 --- a/Duel_Double_DQN/main.py +++ b/Duel_Double_DQN/main.py @@ -50,9 +50,9 @@ parser.add_argument('--exp_noise', type=float, default=0.2, help='explore noise') parser.add_argument('--noise_decay', type=float, default=0.99, help='decay rate of explore noise') -parser.add_argument('--Double', type=str2bool, default=True, +parser.add_argument('--Double', type=str2bool, default=False, help='Whether to use Double Q-learning') -parser.add_argument('--Duel', type=str2bool, default=True, +parser.add_argument('--Duel', type=str2bool, default=False, help='Whether to use Duel networks') opt = parser.parse_args() opt.dvc = torch.device(opt.dvc) # from str to torch.device diff --git a/env.py b/env.py index 789c4bb..b1b84cc 100644 --- a/env.py +++ b/env.py @@ -57,7 +57,7 @@ class PartitionMazeEnv(gym.Env): # 阶段 1 状态:区域访问状态向量(长度为(CUT_NUM/2+1)^2) max_regions = (self.CUT_NUM // 2 + 1) ** 2 self.observation_space = spaces.Box( - low=0.0, high=1.0, shape=(self.CUT_NUM + max_regions,), dtype=np.float32) + low=0.0, high=100.0, shape=(self.CUT_NUM + max_regions,), dtype=np.float32) # 切分阶段相关变量 self.col_cuts = [] # 存储竖切位置(c₁, c₂),当值为0时表示不切 diff --git a/env_dis.py b/env_dis.py index dd06411..91ac312 100644 --- a/env_dis.py +++ b/env_dis.py @@ -57,7 +57,7 @@ class PartitionMazeEnv(gym.Env): # 阶段 1 状态:车辆位置 (2D) max_regions = (self.CUT_NUM // 2 + 1) ** 2 self.observation_space = spaces.Box( - low=0.0, high=1.0, shape=(1 + self.CUT_NUM + max_regions,), dtype=np.float32) + low=0.0, high=100.0, shape=(1 + self.CUT_NUM + max_regions,), dtype=np.float32) # 切分阶段相关变量 self.col_cuts = [] # 存储竖切位置(c₁, c₂),当值为0时表示不切