HPCC2025/Duel_Double_DQN/utils.py
2025-04-12 22:55:01 +08:00

89 lines
2.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import copy
import json
import os
from datetime import datetime
def evaluate_policy(env, agent, turns=3):
    """
    Evaluate the agent's deterministic policy over several episodes.

    Runs ``turns`` full episodes, printing each episode's action/reward
    series, then passes the ``info`` dicts from every step of every episode
    to ``save_best_solution`` and returns the mean episode return.

    Args:
        env: environment object; ``env.reset()`` is expected to return the
            initial state and ``env.step(a)`` to return
            ``(s_next, r, dw, tr, info)``.
        agent: agent object exposing ``select_action(s, deterministic=True)``.
        turns: number of evaluation episodes; must be >= 1.
    Returns:
        int: mean episode return over ``turns`` episodes, truncated to int.
    Raises:
        ValueError: if ``turns`` < 1 (would otherwise divide by zero).
    """
    if turns < 1:
        raise ValueError(f"turns must be >= 1, got {turns}")
    total_scores = 0
    info_lt = []
    # Run every requested episode so the division below yields a true mean
    # (previously only one episode ran but the sum was still divided by turns).
    for _ in range(turns):
        episode_score, eval_info, episode_infos = _run_episode(env, agent)
        print(eval_info)
        total_scores += episode_score
        info_lt.extend(episode_infos)
    save_best_solution(info_lt)
    return int(total_scores / turns)


def _run_episode(env, agent):
    """Roll out one deterministic episode; return (score, series dict, step infos)."""
    s = env.reset()
    done = False
    score = 0
    eval_info = {'action_series': [],
                 'reward_series': []}
    infos = []
    while not done:
        a = agent.select_action(s, deterministic=True)
        s_next, r, dw, tr, info = env.step(a)
        # Stop as soon as either flag is set (presumably terminated / truncated).
        done = (dw or tr)
        eval_info['action_series'].append(a)
        eval_info['reward_series'].append(r)
        # Deep copy in case the env reuses or mutates the same info object.
        infos.append(copy.deepcopy(info))
        score += r
        s = s_next
    return score, eval_info, infos
def save_best_solution(info_lt, path='solutions/dqn_params_100_100_6.json'):
    """
    Persist the best solution in ``info_lt`` if it beats the one on disk.

    The same ``path`` is used both to read the previously saved record and
    to write the new one, so the comparison is always against the file that
    is actually being updated. (Previously the read and write paths differed,
    so a worse solution could overwrite the saved best.)

    Args:
        info_lt: list of step-info dicts; each must contain 'best_time',
            'row_cuts', 'col_cuts' and 'best_path'. Lower 'best_time' is better.
        path: JSON file holding the best solution found so far. Its parent
            directory is created if missing.
    Returns:
        bool: True if a new best solution was written, False otherwise
        (including when ``info_lt`` is empty).
    """
    if not info_lt:
        return False
    # Best (lowest best_time) candidate from this evaluation round.
    best_info = min(info_lt, key=lambda x: x['best_time'])

    # Read the previously saved best time, if any.
    try:
        with open(path, 'r', encoding='utf-8') as f:
            saved_time = json.load(f)['best_time']
    except FileNotFoundError:
        saved_time = float('inf')
    except (json.JSONDecodeError, KeyError):
        # A truncated/corrupt record would otherwise crash every later
        # evaluation; treat it as absent so it gets replaced.
        saved_time = float('inf')

    # Only overwrite when the new solution is strictly better.
    if not best_info['best_time'] < saved_time:
        return False

    best_solution = {
        'best_time': best_info['best_time'],
        'row_cuts': best_info['row_cuts'],
        'col_cuts': best_info['col_cuts'],
        'best_path': best_info['best_path'],
        'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    }
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(best_solution, f, indent=4)
    print(f"发现新的最优解!时间: {best_info['best_time']}")
    return True
def compare_lists(list1, list2):
    """
    Return True if two sequences have the same length and equal elements.

    Elements are compared pairwise with ``==`` in order; the container
    types themselves are not compared, so a list and a tuple holding
    equal items compare as equal.
    """
    if len(list1) != len(list2):
        return False
    for left, right in zip(list1, list2):
        if not left == right:
            return False
    return True
# You can safely ignore this function; it is not related to the RL logic.
def str2bool(v):
    """
    Convert a command-line string to a bool, for use as an argparse ``type=``.

    Case-insensitively accepts 'yes', 'true', 't', 'y', '1' as True and
    'no', 'false', 'f', 'n', '0' as False. A bool input is returned unchanged.

    Raises:
        argparse.ArgumentTypeError: for any other string, so argparse
            reports a usage error instead of a traceback. (The previous bare
            ``raise`` had no active exception and produced a RuntimeError.)
    """
    if isinstance(v, bool):
        return v
    # v is lower-cased first, so only lower-case spellings need listing.
    value = v.lower()
    if value in ('yes', 'true', 't', 'y', '1'):
        return True
    if value in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        f'Wrong Input. Boolean value expected, got {v!r}.')