2025-04-01 17:46:23 +08:00
|
|
|
|
import json
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
import copy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_policy(env, agent, turns=3):
    """Evaluate the agent's deterministic policy and return its average score.

    Runs ``turns`` complete episodes. For every episode the environment is
    reset, the agent acts with ``deterministic=True`` until the episode ends,
    the collected action/reward series is printed, and the per-step ``info``
    dicts of that episode are passed to ``save_best_solution``.

    Args:
        env: Environment object. ``env.reset()`` must return the initial
            state and ``env.step(a)`` must return the 5-tuple
            ``(s_next, r, dw, tr, info)``.
        agent: Agent object exposing ``select_action(s, deterministic=True)``.
        turns: Number of evaluation episodes to run; must be at least 1.

    Returns:
        int: Sum of all rewards over all episodes divided by ``turns``,
        truncated to an integer.

    Raises:
        ValueError: If ``turns`` is less than 1.
    """
    # Fail before any rollout instead of dividing by zero afterwards.
    if turns < 1:
        raise ValueError(f"turns must be at least 1, got {turns}")

    total_scores = 0
    # The score is averaged over `turns`, so exactly `turns` episodes must be
    # played; running a single episode would under-report the score.
    for _ in range(turns):
        s = env.reset()
        done = False
        eval_info = {'action_series': [],
                     'reward_series': []}
        info_lt = []

        while not done:
            a = agent.select_action(s, deterministic=True)
            s_next, r, dw, tr, info = env.step(a)
            # The episode ends as soon as either flag is truthy
            # (presumably "terminated" / "truncated" -- confirm against env).
            done = (dw or tr)

            eval_info['action_series'].append(a)
            eval_info['reward_series'].append(r)
            # Deep copy so the stored snapshot stays intact even if the
            # environment reuses or mutates the same info object later.
            info_lt.append(copy.deepcopy(info))

            total_scores += r
            s = s_next

        print(eval_info)
        save_best_solution(info_lt)

    return int(total_scores / turns)
|
|
|
|
|
|
|
|
|
|
|
2025-04-01 17:46:23 +08:00
|
|
|
|
def save_best_solution(info_lt, path='solutions/dqn_params_100_100_6.json'):
    """Persist the best solution of an episode if it beats the stored one.

    The entry of ``info_lt`` with the smallest ``'best_time'`` is compared
    with the ``'best_time'`` stored in the JSON file at ``path``. When the
    new entry is strictly smaller (or no file exists yet), the file is
    rewritten with that entry plus a timestamp, and a message is printed.

    Args:
        info_lt: List of info dicts. Every dict needs a ``'best_time'`` key;
            the best one also needs ``'row_cuts'``, ``'col_cuts'`` and
            ``'best_path'``. The values must be JSON-serializable.
        path: Location of the JSON file holding the best known solution.
            Missing parent directories are created before writing.

    Returns:
        None. An empty ``info_lt`` is a no-op.
    """
    import os

    # Nothing was collected, so there is no candidate to compare or store.
    if not info_lt:
        return

    # Pick the best candidate of this round.
    best_info = min(info_lt, key=lambda x: x['best_time'])

    # Load the previously stored best time; a missing file means "no record".
    try:
        with open(path, 'r', encoding='utf-8') as f:
            saved_solution = json.load(f)
        saved_time = saved_solution['best_time']
    except FileNotFoundError:
        saved_time = float('inf')

    # Only a strictly better time replaces the stored solution.
    if best_info['best_time'] < saved_time:
        best_solution = {
            'best_time': best_info['best_time'],
            'row_cuts': best_info['row_cuts'],
            'col_cuts': best_info['col_cuts'],
            'best_path': best_info['best_path'],
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        }

        # open(..., 'w') does not create directories, so make sure the
        # target directory exists on the very first save.
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(best_solution, f, indent=4)

        print(f"发现新的最优解!时间: {best_info['best_time']}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compare_lists(list1, list2):
    """Return True when both sequences have equal length and equal items.

    Items are compared pairwise, position by position, with ``==``.
    """
    # Different lengths can never match; skip the element scan entirely.
    if len(list1) != len(list2):
        return False

    for left, right in zip(list1, list2):
        if not left == right:
            return False
    return True
|
|
|
|
|
|
|
|
|
|
# You can ignore this function; it is not related to the RL algorithm.
|
|
|
|
|
|
|
|
|
|
|
2025-03-18 21:16:48 +08:00
|
|
|
|
def str2bool(v):
    """Convert a command-line style value to ``bool`` (for argparse ``type=``).

    Args:
        v: A ``bool`` (returned unchanged) or any value whose string form,
            ignoring case and surrounding whitespace, is one of
            ``yes/true/t/y/1`` or ``no/false/f/n/0``.

    Returns:
        bool: The parsed boolean.

    Raises:
        ValueError: If the value is not a recognised boolean spelling.
            argparse catches ``ValueError`` from a ``type=`` callable and
            turns it into a regular usage error.
    """
    if isinstance(v, bool):
        return v

    # Normalise once; after lower-casing only lowercase spellings can match.
    text = str(v).strip().lower()
    if text in ('yes', 'true', 't', 'y', '1'):
        return True
    if text in ('no', 'false', 'f', 'n', '0'):
        return False

    # A bare `raise` here has no active exception to re-raise and would
    # surface as an unrelated RuntimeError, so raise a real error instead.
    raise ValueError(f"Wrong Input. Expected a boolean value, got {v!r}")
|