HPCC2025/DQN/dqn.py

import random
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


class DQN(nn.Module):
    """Fully connected Q-network mapping a state vector to one Q-value per action."""

    def __init__(self, state_dim, action_dim):
        super(DQN, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim)
        )

    def forward(self, x):
        return self.network(x)


class Agent:
    def __init__(self, state_dim, action_dim):
        self.state_dim = state_dim
        self.action_dim = action_dim

        # DQN networks: online (evaluation) net and a periodically synced target net
        self.eval_net = DQN(state_dim, action_dim)
        self.target_net = DQN(state_dim, action_dim)
        self.target_net.load_state_dict(self.eval_net.state_dict())

        # Training hyperparameters
        self.learning_rate = 0.001
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.memory = deque(maxlen=10000)
        self.batch_size = 64
        self.learn_step_counter = 0  # counts learn() calls to schedule target-net updates
        self.optimizer = optim.Adam(self.eval_net.parameters(), lr=self.learning_rate)

    def choose_action(self, state):
        if random.random() < self.epsilon:
            # Explore: pick a random action
            return random.randint(0, self.action_dim - 1)
        else:
            # Exploit: pick the action with the highest predicted Q-value
            state = torch.FloatTensor(state).unsqueeze(0)
            with torch.no_grad():
                q_values = self.eval_net(state)
            return torch.argmax(q_values).item()

    def store_transition(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def learn(self):
        if len(self.memory) < self.batch_size:
            return

        # Sample a random minibatch from the replay buffer
        batch = random.sample(self.memory, self.batch_size)
        states = torch.FloatTensor(np.array([x[0] for x in batch]))
        actions = torch.LongTensor([x[1] for x in batch])
        rewards = torch.FloatTensor([x[2] for x in batch])
        next_states = torch.FloatTensor(np.array([x[3] for x in batch]))
        dones = torch.FloatTensor([float(x[4]) for x in batch])

        # Current Q-values for the actions that were actually taken
        current_q_values = self.eval_net(states).gather(1, actions.unsqueeze(1))

        # Target Q-values from the frozen target network
        next_q_values = self.target_net(next_states).detach()
        max_next_q = torch.max(next_q_values, dim=1)[0]
        target_q_values = rewards + (1 - dones) * self.gamma * max_next_q

        # Temporal-difference loss
        loss = nn.MSELoss()(current_q_values.squeeze(1), target_q_values)

        # Update the evaluation network
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Decay epsilon toward its floor
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        # Periodically copy the evaluation weights into the target network
        if self.learn_step_counter % 100 == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter += 1
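

# --- Minimal usage sketch (illustrative, not part of the original file) ---
# Drives the Agent end to end so the interplay of choose_action / store_transition /
# learn is visible. RandomWalkEnv, the episode and step counts, and the reward scheme
# below are assumptions made for demonstration only, not part of the original project.
class RandomWalkEnv:
    """Toy 1-D random walk: the state is the position, the episode ends at |pos| >= 5."""

    def __init__(self):
        self.pos = 0.0

    def reset(self):
        self.pos = 0.0
        return np.array([self.pos], dtype=np.float32)

    def step(self, action):
        # Action 0 moves left, action 1 moves right
        self.pos += 1.0 if action == 1 else -1.0
        done = abs(self.pos) >= 5.0
        reward = 1.0 if self.pos >= 5.0 else 0.0
        return np.array([self.pos], dtype=np.float32), reward, done


if __name__ == "__main__":
    env = RandomWalkEnv()
    agent = Agent(state_dim=1, action_dim=2)
    for episode in range(50):
        state = env.reset()
        total_reward = 0.0
        for _ in range(50):
            action = agent.choose_action(state)
            next_state, reward, done = env.step(action)
            agent.store_transition(state, action, reward, next_state, done)
            agent.learn()
            state = next_state
            total_reward += reward
            if done:
                break
        print(f"episode {episode:3d}  reward {total_reward:.1f}  epsilon {agent.epsilon:.3f}")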