from env import Env
from dqn import Agent
import numpy as np
import matplotlib.pyplot as plt


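# NOTE: env.py and dqn.py are not shown here. Based only on how they are
# called below, this script assumes interfaces roughly like the following
# (an assumption, not a confirmed API):
#
#   env.observation_space.shape   -> tuple, e.g. (state_dim,)
#   env.reset()                   -> initial state
#   env.step(action)              -> (next_state, reward, done, info)
#
#   Agent(state_dim, action_dim)  -> DQN agent
#   agent.choose_action(state)    -> action (indexed as action[0], action[1] below)
#   agent.store_transition(state, action, reward, next_state, done)
#   agent.learn()                 -> one update step from stored experience

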
def train():
    # Create the environment and the agent
    env = Env()
    state_dim = env.observation_space.shape[0]
    action_dim = 10  # len(vertical cut counts) + len(horizontal cut counts)

    agent = Agent(state_dim, action_dim)

    # Training parameters
    episodes = 1000
    max_steps = 1000

    # Training history
    rewards_history = []
    best_reward = float('-inf')
    best_solution = None

    # Training loop
    for episode in range(episodes):
        state = env.reset()
        episode_reward = 0

        for step in range(max_steps):
            # Choose an action
            action = agent.choose_action(state)

            # Execute the action in the environment
            next_state, reward, done, _ = env.step(action)

            # Store the transition for experience replay
            agent.store_transition(state, action, reward, next_state, done)

            # Perform one learning step
            agent.learn()

            episode_reward += reward
            state = next_state

            if done:
                break

        # Record the total reward of this episode
        rewards_history.append(episode_reward)

        # Update the best solution found so far
        # (`action` here is the last action taken in this episode)
        if episode_reward > best_reward:
            best_reward = episode_reward
            best_solution = {
                'vertical_cuts': int(action[0]),
                'horizontal_cuts': int(action[1]),
                # 'offload_ratio': action[2],
                'total_time': -reward if reward != -1000 else float('inf'),
                'episode': episode
            }

        # Print training progress every 10 episodes
        if (episode + 1) % 10 == 0:
            avg_reward = np.mean(rewards_history[-10:])
            print(f"Episode {episode + 1}, Average Reward: {avg_reward:.2f}")

    return best_solution, rewards_history


def plot_training_results(rewards_history):
    plt.figure(figsize=(10, 5))
    plt.plot(rewards_history)
    plt.title('Training Progress')
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.grid(True)
    plt.show()


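# Optional helper (not part of the original script): train() already logs a
# 10-episode average reward, and the raw per-episode curve is usually noisy,
# so a smoothed view of the same quantity can be easier to read. A minimal
# sketch using a simple moving average:
def plot_smoothed_rewards(rewards_history, window=10):
    rewards = np.asarray(rewards_history, dtype=float)
    window = min(window, max(1, len(rewards)))
    # Moving average over `window` consecutive episodes
    smoothed = np.convolve(rewards, np.ones(window) / window, mode='valid')
    plt.figure(figsize=(10, 5))
    plt.plot(rewards, alpha=0.3, label='Per-episode reward')
    plt.plot(np.arange(window - 1, len(rewards)), smoothed,
             label=f'{window}-episode moving average')
    plt.title('Training Progress (smoothed)')
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.legend()
    plt.grid(True)
    plt.show()

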
def print_solution(solution):
    print("\nBest solution:")
    print(f"Found in episode {solution['episode']}")
    print(f"Vertical cuts: {solution['vertical_cuts']}")
    print(f"Horizontal cuts: {solution['horizontal_cuts']}")
    # 'offload_ratio' is commented out in train(), so only print it if present
    if 'offload_ratio' in solution:
        print(f"Task offloading ratio: {solution['offload_ratio']:.2f}")
    print(f"Total completion time: {solution['total_time']:.2f} seconds")


if __name__ == "__main__":
|
|
# 训练模型
|
|
best_solution, rewards_history = train()
|
|
|
|
# 显示结果
|
|
plot_training_results(rewards_history)
|
|
print_solution(best_solution)
|