HPCC2025/DQN/run_dqn.py
from env import Env
from dqn import Agent
import numpy as np
import matplotlib.pyplot as plt
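# The two project-local modules imported above (env.py, dqn.py) are not shown in
# this file. The sketch below only documents the interfaces this script assumes
# they expose, based on how they are called later on; it is illustrative
# (typing.Protocol is not checked at runtime) and the real Env / Agent classes
# may differ.
from typing import Any, Protocol, Tuple


class EnvLike(Protocol):
    observation_space: Any  # .shape[0] is used as the state dimension

    def reset(self) -> Any: ...
    def step(self, action: Any) -> Tuple[Any, float, bool, dict]: ...


class AgentLike(Protocol):
    def choose_action(self, state: Any) -> Any: ...
    def store_transition(self, state, action, reward, next_state, done) -> None: ...
    def learn(self) -> None: ...
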
def train():
    # Create the environment and the agent
    env = Env()
    state_dim = env.observation_space.shape[0]
    action_dim = 10  # len(vertical cut counts) + len(horizontal cut counts)
    agent = Agent(state_dim, action_dim)

    # Training parameters
    episodes = 1000
    max_steps = 1000

    # Track the training process
    rewards_history = []
    best_reward = float('-inf')
    best_solution = None

    # Start training
    for episode in range(episodes):
        state = env.reset()
        episode_reward = 0

        for step in range(max_steps):
            # Choose an action
            action = agent.choose_action(state)
            # Execute the action
            next_state, reward, done, _ = env.step(action)
            # Store the transition
            agent.store_transition(state, action, reward, next_state, done)
            # Learn from stored experience
            agent.learn()

            episode_reward += reward
            state = next_state

            if done:
                break

        # Record the total reward of each episode
        rewards_history.append(episode_reward)

        # Update the best solution found so far
        # (uses the last action/reward of the episode and assumes the action is
        # indexable as [vertical_cuts, horizontal_cuts])
        if episode_reward > best_reward:
            best_reward = episode_reward
            best_solution = {
                'vertical_cuts': int(action[0]),
                'horizontal_cuts': int(action[1]),
                # 'offload_ratio': action[2],
                'total_time': -reward if reward != -1000 else float('inf'),
                'episode': episode
            }

        # Print training progress
        if (episode + 1) % 10 == 0:
            avg_reward = np.mean(rewards_history[-10:])
            print(f"Episode {episode + 1}, Average Reward: {avg_reward:.2f}")

    return best_solution, rewards_history

def plot_training_results(rewards_history):
    plt.figure(figsize=(10, 5))
    plt.plot(rewards_history)
    plt.title('Training Progress')
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.grid(True)
    plt.show()

def print_solution(solution):
    print("\nBest solution:")
    print(f"Found in episode {solution['episode']}")
    print(f"Vertical cuts: {solution['vertical_cuts']}")
    print(f"Horizontal cuts: {solution['horizontal_cuts']}")
    # 'offload_ratio' is only present if it is recorded in train() above
    if 'offload_ratio' in solution:
        print(f"Task offloading ratio: {solution['offload_ratio']:.2f}")
    print(f"Total completion time: {solution['total_time']:.2f}")

if __name__ == "__main__":
    # Train the model
    best_solution, rewards_history = train()
    # Show the results
    plot_training_results(rewards_history)
    print_solution(best_solution)