# HPCC2025/Duel_Double_DQN/DQN.py

import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import torch
import copy


def build_net(layer_shape, activation, output_activation):
    '''Build an MLP whose layer sizes are given by layer_shape; the last layer uses output_activation.'''
    layers = []
    for j in range(len(layer_shape) - 1):
        act = activation if j < len(layer_shape) - 2 else output_activation
        layers += [nn.Linear(layer_shape[j], layer_shape[j + 1]), act()]
    return nn.Sequential(*layers)


class Q_Net(nn.Module):
    def __init__(self, state_dim, action_dim, hid_shape):
        super(Q_Net, self).__init__()
        layers = [state_dim] + list(hid_shape) + [action_dim]
        self.Q = build_net(layers, nn.ReLU, nn.Identity)

    def forward(self, s):
        q = self.Q(s)
        return q


class Duel_Q_Net(nn.Module):
    def __init__(self, state_dim, action_dim, hid_shape):
        super(Duel_Q_Net, self).__init__()
        layers = [state_dim] + list(hid_shape)
        self.hidden = build_net(layers, nn.ReLU, nn.ReLU)
        self.V = nn.Linear(hid_shape[-1], 1)             # state-value head
        self.A = nn.Linear(hid_shape[-1], action_dim)    # advantage head

    def forward(self, s):
        s = self.hidden(s)
        Adv = self.A(s)
        V = self.V(s)
        Q = V + (Adv - torch.mean(Adv, dim=-1, keepdim=True))  # Q(s,a) = V(s) + A(s,a) - mean_a A(s,a)
        return Q


class DQN_agent(object):
    def __init__(self, **kwargs):
        # Init hyperparameters for the agent, e.g. self.gamma = opt.gamma, self.lambd = opt.lambd, ...
        self.__dict__.update(kwargs)
        self.tau = 0.005
        self.replay_buffer = ReplayBuffer(self.state_dim, self.dvc, max_size=int(1e6))
        if self.Duel:
            self.q_net = Duel_Q_Net(self.state_dim, self.action_dim, (self.net_width, self.net_width)).to(self.dvc)
        else:
            self.q_net = Q_Net(self.state_dim, self.action_dim, (self.net_width, self.net_width)).to(self.dvc)
        self.q_net_optimizer = torch.optim.Adam(self.q_net.parameters(), lr=self.lr)
        self.q_target = copy.deepcopy(self.q_net)
        # Freeze the target network with respect to the optimizer (it is only updated via Polyak averaging)
        for p in self.q_target.parameters():
            p.requires_grad = False

    def select_action(self, state, deterministic):
        # Only used when interacting with the env.
        # The first state feature acts as a mode flag: when it is 0, only actions 0-9 are valid;
        # otherwise only actions 10-14 are valid. Both exploration and greedy selection are masked accordingly.
        with torch.no_grad():
            state = torch.FloatTensor(state.reshape(1, -1)).to(self.dvc)
            # if deterministic:
            #     a = self.q_net(state).argmax().item()
            # else:
            if np.random.rand() < self.exp_noise:
                if state[0][0] == 0:
                    a = np.random.randint(0, 10)
                else:
                    a = np.random.randint(10, 15)
            else:
                q_value = self.q_net(state)
                if state[0][0] == 0:
                    q_value[0][10:] = -float('inf')
                else:
                    q_value[0][:10] = -float('inf')
                a = q_value.argmax().item()
        return a

    def train(self):
        s, a, r, s_next, dw = self.replay_buffer.sample(self.batch_size)

        '''Compute the target Q value'''
        with torch.no_grad():
            if self.Double:
                # TODO: if the two-stage action setting is used, remember to apply the same action mask here as well.
                argmax_a = self.q_net(s_next).argmax(dim=1).unsqueeze(-1)
                max_q_next = self.q_target(s_next).gather(1, argmax_a)
            else:
                max_q_next = self.q_target(s_next)
                # Apply the action mask per sample: first state feature 0 -> only actions 0-9
                # are valid, otherwise only actions 10-14 are valid.
                mask_low = (s_next[:, 0] == 0)
                max_q_next[mask_low, 10:] = -float('inf')
                max_q_next[~mask_low, :10] = -float('inf')
                max_q_next = max_q_next.max(1)[0].unsqueeze(1)
            target_Q = r + (~dw) * self.gamma * max_q_next  # dw: die or win (episode truly ended)

        # Get current Q estimates
        current_q = self.q_net(s)
        current_q_a = current_q.gather(1, a)

        q_loss = F.mse_loss(current_q_a, target_Q)
        self.q_net_optimizer.zero_grad()
        q_loss.backward()
        self.q_net_optimizer.step()

        # Soft-update the frozen target network (Polyak averaging)
        for param, target_param in zip(self.q_net.parameters(), self.q_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def save(self, algo, EnvName, steps):
        torch.save(self.q_net.state_dict(), "./weights/{}_{}_{}.pth".format(algo, EnvName, steps))

    def load(self, algo, EnvName, steps):
        self.q_net.load_state_dict(torch.load("./model/{}_{}_{}.pth".format(algo, EnvName, steps), map_location=self.dvc))
        self.q_target.load_state_dict(torch.load("./model/{}_{}_{}.pth".format(algo, EnvName, steps), map_location=self.dvc))


class ReplayBuffer(object):
    def __init__(self, state_dim, dvc, max_size=int(1e6)):
        self.max_size = max_size
        self.dvc = dvc
        self.ptr = 0
        self.size = 0
        self.s = torch.zeros((max_size, state_dim), dtype=torch.float, device=self.dvc)
        self.a = torch.zeros((max_size, 1), dtype=torch.long, device=self.dvc)
        self.r = torch.zeros((max_size, 1), dtype=torch.float, device=self.dvc)
        self.s_next = torch.zeros((max_size, state_dim), dtype=torch.float, device=self.dvc)
        self.dw = torch.zeros((max_size, 1), dtype=torch.bool, device=self.dvc)

    def add(self, s, a, r, s_next, dw):
        self.s[self.ptr] = torch.from_numpy(s).to(self.dvc)
        self.a[self.ptr] = a
        self.r[self.ptr] = r
        self.s_next[self.ptr] = torch.from_numpy(s_next).to(self.dvc)
        self.dw[self.ptr] = dw
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        ind = torch.randint(0, self.size, device=self.dvc, size=(batch_size,))
        return self.s[ind], self.a[ind], self.r[ind], self.s_next[ind], self.dw[ind]
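

# ---------------------------------------------------------------------------
# Minimal self-contained sketch (not part of the original project): it only
# shows how the pieces above fit together by training on random transitions.
# The dimensions, hyperparameter values, and the 15-action / mode-flag layout
# below are assumptions inferred from the masking logic in select_action
# (actions 0-9 when the first state feature is 0, actions 10-14 otherwise),
# not the real HPCC2025 environment or training script.
if __name__ == '__main__':
    state_dim, action_dim = 6, 15
    agent = DQN_agent(state_dim=state_dim, action_dim=action_dim, net_width=128,
                      lr=1e-4, gamma=0.99, batch_size=64, exp_noise=0.2,
                      Duel=True, Double=True, dvc=torch.device('cpu'))

    # Fill the replay buffer with random transitions, then run a few updates.
    for _ in range(256):
        s = np.random.rand(state_dim).astype(np.float32)
        s[0] = float(np.random.randint(2))        # mode flag: 0 or 1
        a = agent.select_action(s, deterministic=False)
        s_next = np.random.rand(state_dim).astype(np.float32)
        s_next[0] = s[0]
        agent.replay_buffer.add(s, a, float(np.random.rand()), s_next, False)
    for _ in range(10):
        agent.train()
    print('smoke test finished')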