import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random


class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(DQN, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim)
        )

    def forward(self, x):
        return self.network(x)


class Agent:
    def __init__(self, state_dim, action_dim):
        self.state_dim = state_dim
        self.action_dim = action_dim

        # DQN networks: an evaluation net that is trained and a target net
        # that is periodically synced to stabilize the bootstrap targets
        self.eval_net = DQN(state_dim, action_dim)
        self.target_net = DQN(state_dim, action_dim)
        self.target_net.load_state_dict(self.eval_net.state_dict())

        # Training hyperparameters
        self.learning_rate = 0.001
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.memory = deque(maxlen=10000)
        self.batch_size = 64
        self.learn_step_counter = 0
        self.optimizer = optim.Adam(self.eval_net.parameters(), lr=self.learning_rate)

    def choose_action(self, state):
        if random.random() < self.epsilon:
            # Explore: pick a random action
            return random.randint(0, self.action_dim - 1)
        else:
            # Exploit: pick the action with the highest predicted Q-value
            state = torch.FloatTensor(state).unsqueeze(0)
            with torch.no_grad():
                q_values = self.eval_net(state)
            return torch.argmax(q_values).item()

    def store_transition(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def learn(self):
        if len(self.memory) < self.batch_size:
            return

        # Sample a random minibatch of transitions from replay memory
        batch = random.sample(self.memory, self.batch_size)
        states = torch.FloatTensor(np.array([x[0] for x in batch]))
        actions = torch.LongTensor([x[1] for x in batch])
        rewards = torch.FloatTensor([x[2] for x in batch])
        next_states = torch.FloatTensor(np.array([x[3] for x in batch]))
        dones = torch.FloatTensor([x[4] for x in batch])

        # Current Q-values for the actions that were actually taken
        current_q_values = self.eval_net(states).gather(1, actions.unsqueeze(1))

        # Target Q-values: bootstrap from the frozen target network,
        # zeroing the bootstrap term for terminal transitions
        next_q_values = self.target_net(next_states).detach()
        max_next_q = torch.max(next_q_values, dim=1)[0]
        target_q_values = rewards + (1 - dones) * self.gamma * max_next_q

        # TD-error loss
        loss = nn.MSELoss()(current_q_values.squeeze(1), target_q_values)

        # Gradient step on the evaluation network
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Decay the exploration rate
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        # Periodically sync the target network with the evaluation network
        if self.learn_step_counter % 100 == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter += 1
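

# ---------------------------------------------------------------------------
# A minimal training-loop sketch for the Agent above. This part is an
# illustrative assumption, not part of the original code: it presumes the
# Gymnasium API (env.reset() returning (obs, info), env.step() returning a
# 5-tuple) and uses CartPole-v1 and 200 episodes purely as example values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import gymnasium as gym  # assumed dependency for this sketch

    env = gym.make("CartPole-v1")
    agent = Agent(state_dim=env.observation_space.shape[0],
                  action_dim=env.action_space.n)

    for episode in range(200):
        state, _ = env.reset()
        episode_reward = 0.0
        done = False
        while not done:
            action = agent.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.store_transition(state, action, reward, next_state, done)
            agent.learn()
            state = next_state
            episode_reward += reward
        print(f"episode {episode:3d}  reward {episode_reward:6.1f}  "
              f"epsilon {agent.epsilon:.3f}")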