跑通PPO partition
This commit is contained in:
parent
8d79e8cc66
commit
2c88915112
@ -9,7 +9,7 @@ from PPO import PPO_agent
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from env import PartitionMazeEnv
|
||||
from env_partion import PartitionEnv
|
||||
# fmt: on
|
||||
|
||||
'''Hyperparameter Setting'''
|
||||
@ -68,16 +68,16 @@ print(opt)
|
||||
|
||||
|
||||
def main():
|
||||
EnvName = ['PartitionMaze_PPO_Continuous', 'Pendulum-v1', 'LunarLanderContinuous-v2',
|
||||
EnvName = ['Partition_PPO_Continuous', 'Pendulum-v1', 'LunarLanderContinuous-v2',
|
||||
'Humanoid-v4', 'HalfCheetah-v4', 'BipedalWalker-v3', 'BipedalWalkerHardcore-v3']
|
||||
BrifEnvName = ['PM_PPO_Con', 'PV1', 'LLdV2',
|
||||
BrifEnvName = ['Part_PPO_Con', 'PV1', 'LLdV2',
|
||||
'Humanv4', 'HCv4', 'BWv3', 'BWHv3']
|
||||
|
||||
# Build Env
|
||||
# env = gym.make(EnvName[opt.EnvIdex], render_mode = "human" if opt.render else None)
|
||||
env = PartitionMazeEnv()
|
||||
env = PartitionEnv()
|
||||
# eval_env = gym.make(EnvName[opt.EnvIdex])
|
||||
eval_env = PartitionMazeEnv()
|
||||
eval_env = PartitionEnv()
|
||||
opt.state_dim = env.observation_space.shape[0]
|
||||
opt.action_dim = env.action_space.shape[0]
|
||||
opt.max_action = float(env.action_space.high[0])
|
||||
|
168
env_partion.py
Normal file
168
env_partion.py
Normal file
@ -0,0 +1,168 @@
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
import numpy as np
|
||||
import yaml
|
||||
import math
|
||||
from mTSP_solver import mTSP
|
||||
|
||||
|
||||
class PartitionEnv(gym.Env):
    """Region-partition environment driven by 1-D continuous actions.

    An episode lasts ``CUT_NUM`` steps. Each step stores ``action[0]`` as
    one cut position in ``partition_values``: the first ``ROW_CUT_LIMIT``
    entries are row cuts, the remaining entries are column cuts. A value
    of 0 means "no cut". After the last step the resulting rectangles are
    checked against the energy limit and, if all are feasible, an ``mTSP``
    solver is run over the rectangle centers; the terminal reward is
    ``BASE_LINE - best_time``. An infeasible partition ends the episode
    with a reward of -100.
    """

    def __init__(self, config=None):
        """Load the parameter file and define the action/observation spaces.

        Args:
            config: Accepted for interface compatibility; not read here.
        """
        super(PartitionEnv, self).__init__()
        ##############################
        # Hyperparameters that may need to be edited by hand
        ##############################
        self.params = 'params3'
        self.CUT_NUM = 2
        self.ROW_CUT_LIMIT = 1
        self.COL_CUT_LIMIT = 1
        self.BASE_LINE = 5000

        # Fleet / cost parameters are read from '<params>.yml'.
        with open(self.params + '.yml', 'r', encoding='utf-8') as file:
            params = yaml.safe_load(file)

        self.H = params['H']
        self.W = params['W']
        self.center = (self.H/2, self.W/2)
        self.num_cars = params['num_cars']

        self.flight_time_factor = params['flight_time_factor']
        self.comp_time_factor = params['comp_time_factor']
        self.trans_time_factor = params['trans_time_factor']
        self.car_time_factor = params['car_time_factor']
        self.bs_time_factor = params['bs_time_factor']

        self.flight_energy_factor = params['flight_energy_factor']
        self.comp_energy_factor = params['comp_energy_factor']
        self.trans_energy_factor = params['trans_energy_factor']
        self.battery_energy_capacity = params['battery_energy_capacity']

        # Set here as well as in reset() so render() is safe before reset().
        self.phase = 0
        # Number of cut decisions taken so far, in [0, CUT_NUM].
        self.partition_step = 0
        # Cut values decided so far; undecided entries stay 0.
        self.partition_values = np.zeros(
            self.CUT_NUM, dtype=np.float32)

        # Every action is a single continuous value in [0, 1].
        self.action_space = spaces.Box(
            low=0.0, high=1.0, shape=(1,), dtype=np.float32)

        # The observation is the CUT_NUM-dimensional vector of cut values.
        self.observation_space = spaces.Box(
            low=0.0, high=1.0, shape=(self.CUT_NUM,), dtype=np.float32)

        # Partition bookkeeping, filled in by the final step().
        self.col_cuts = []  # column boundaries, always including 0 and 1
        self.row_cuts = []  # row boundaries, always including 0 and 1
        self.rectangles = []

    def reset(self, seed=None, options=None):
        """Reset all partition state and return the initial observation.

        NOTE: only the observation is returned (not an ``(obs, info)``
        pair); this return shape is kept for existing callers.

        Returns:
            A copy of the all-zero ``partition_values`` vector.
        """
        super().reset(seed=seed)
        self.phase = 0
        self.partition_step = 0
        self.partition_values = np.zeros(self.CUT_NUM, dtype=np.float32)
        self.col_cuts = []
        self.row_cuts = []
        self.rectangles = []

        # Return a copy so later in-place writes do not alter this state.
        return self.partition_values.copy()

    def _build_rectangles(self):
        """Append one entry per grid cell to ``self.rectangles``.

        Returns:
            False as soon as a cell has a negative energy-limited offload
            ratio (the cell is infeasible), True otherwise.
        """
        # Independent of the cell size, so computed once.
        rho_time_limit = (self.flight_time_factor - self.trans_time_factor) / \
            (self.comp_time_factor - self.trans_time_factor)

        for i in range(len(self.row_cuts) - 1):
            for j in range(len(self.col_cuts) - 1):
                # Cell size: width * height in the units of W and H.
                d = (self.col_cuts[j+1] - self.col_cuts[j]) * self.W * \
                    (self.row_cuts[i+1] - self.row_cuts[i]) * self.H
                rho_energy_limit = (self.battery_energy_capacity - self.flight_energy_factor * d - self.trans_energy_factor * d) / \
                    (self.comp_energy_factor * d -
                     self.trans_energy_factor * d)
                if rho_energy_limit < 0:
                    return False
                rho = min(rho_time_limit, rho_energy_limit)

                flight_time = self.flight_time_factor * d
                bs_time = self.bs_time_factor * (1 - rho) * d

                self.rectangles.append({
                    'center': ((self.row_cuts[i] + self.row_cuts[i+1]) * self.H / 2, (self.col_cuts[j+1] + self.col_cuts[j]) * self.W / 2),
                    'flight_time': flight_time,
                    'bs_time': bs_time,
                })
        return True

    def step(self, action):
        """Record one cut value; on the last cut, score the partition.

        Args:
            action: Indexable whose first element is the cut value.

        Returns:
            A 5-tuple ``(state, reward, terminated, truncated, info)``.
            ``state`` is a copy of ``partition_values``. Before the last
            cut the reward is 0.0 and ``info`` is ``{}``. On the last cut
            the episode terminates: an infeasible partition yields reward
            -100 and ``info == {}``; otherwise the reward is
            ``BASE_LINE - best_time`` and the fifth element is the
            ``best_path`` returned by the solver (kept as is for
            existing callers).
        """
        a = float(action[0])
        self.partition_values[self.partition_step] = a
        self.partition_step += 1

        # Still cutting: no reward yet.
        if self.partition_step < self.CUT_NUM:
            return self.partition_values.copy(), 0.0, False, False, {}

        # All cuts decided: drop zeros ("no cut"), de-duplicate and sort.
        rows = sorted(
            set(v for v in self.partition_values[:self.ROW_CUT_LIMIT] if v > 0))
        cols = sorted(
            set(v for v in self.partition_values[self.ROW_CUT_LIMIT:] if v > 0))

        # Boundaries always include 0 and 1. Columns use their own cuts
        # (previously they were overwritten with the row cuts).
        self.row_cuts = [0.0] + rows + [1.0]
        self.col_cuts = [0.0] + cols + [1.0]

        state = self.partition_values.copy()

        if not self._build_rectangles():
            return state, -100, True, False, {}

        # Route planning. cities has shape (2, n): row 0 holds the first
        # coordinate of every point, row 1 the second. Column 0 is the
        # depot (self.center), followed by the rectangle centers.
        rec_center_lt = [rec_info['center']
                         for rec_info in self.rectangles]
        cities = np.column_stack(rec_center_lt)
        cities = np.column_stack((self.center, cities))

        # One extra copy of the depot per additional car; their column
        # indices are handed to the solver as center_idx.
        center_idx = []
        for _ in range(self.num_cars - 1):
            cities = np.column_stack((cities, self.center))
            center_idx.append(cities.shape[1] - 1)

        tsp = mTSP(params=self.params, num_cities=cities.shape[1], cities=cities, num_cars=self.num_cars,
                   center_idx=center_idx, rectangles=self.rectangles)
        best_time, best_path = tsp.train(10000)

        reward = self.BASE_LINE - best_time

        return state, reward, True, False, best_path

    def render(self):
        """Print a short textual summary for phase 1 or phase 2.

        ``car_pos`` and ``car_traj`` are not set anywhere in this class,
        so they are read defensively and print as ``None`` when absent.
        """
        if self.phase == 1:
            print("Phase 1: Initialize maze environment.")
            print(f"Partition values so far: {self.partition_values}")
            print(f"Motorcade positon: {getattr(self, 'car_pos', None)}")
        elif self.phase == 2:
            print("Phase 2: Play maze.")
            print(f'Motorcade trajectory: {getattr(self, "car_traj", None)}')
|
@ -1,18 +1,21 @@
|
||||
from env import PartitionMazeEnv
|
||||
# from env import PartitionMazeEnv
|
||||
# from env_dis import PartitionMazeEnv
|
||||
from env_partion import PartitionEnv
|
||||
|
||||
# Smoke test: run one PartitionEnv episode with a fixed action sequence.
# env = PartitionMazeEnv()
env = PartitionEnv()

state = env.reset()
print('state:', state)

# One action per step; each action is a 1-element list.
# action_series = [[0.67], [0], [0], [0], [0.7]]
# action_series = [0, 0, 3, 0, 10]
action_series = [[0.5], [0.5]]

# Iterate over the actions themselves so the loop can never index past
# the end of action_series if the episode has not terminated.
for action in action_series:
    # step() returns (state, reward, terminated, truncated, info).
    state, reward, terminated, truncated, info = env.step(action)
    print('state:', state)
    print('reward:', reward)
    if terminated or truncated:
        break
|
||||
|
@ -1,17 +1,20 @@
|
||||
import numpy as np
|
||||
import yaml
|
||||
|
||||
|
||||
class TSP(object):
|
||||
class mTSP(object):
|
||||
'''
|
||||
用 Q-Learning 求解 TSP 问题
|
||||
作者 Surfer Zen @ https://www.zhihu.com/people/surfer-zen
|
||||
'''
|
||||
|
||||
def __init__(self,
|
||||
params='params',
|
||||
num_cities=15,
|
||||
cities=None,
|
||||
num_cars=2,
|
||||
center_idx=[0],
|
||||
rectangles=None,
|
||||
alpha=2,
|
||||
beta=1,
|
||||
learning_rate=0.001,
|
||||
@ -29,6 +32,7 @@ class TSP(object):
|
||||
self.cities = cities
|
||||
self.num_cars = num_cars
|
||||
self.center_idx = center_idx
|
||||
self.rectangles = rectangles
|
||||
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
@ -41,6 +45,24 @@ class TSP(object):
|
||||
self.best_path = None
|
||||
self.best_path_length = np.inf
|
||||
|
||||
with open(params+'.yml', 'r', encoding='utf-8') as file:
|
||||
params = yaml.safe_load(file)
|
||||
|
||||
self.H = params['H']
|
||||
self.W = params['W']
|
||||
self.num_cars = params['num_cars']
|
||||
|
||||
self.flight_time_factor = params['flight_time_factor']
|
||||
self.comp_time_factor = params['comp_time_factor']
|
||||
self.trans_time_factor = params['trans_time_factor']
|
||||
self.car_time_factor = params['car_time_factor']
|
||||
self.bs_time_factor = params['bs_time_factor']
|
||||
|
||||
self.flight_energy_factor = params['flight_energy_factor']
|
||||
self.comp_energy_factor = params['comp_energy_factor']
|
||||
self.trans_energy_factor = params['trans_energy_factor']
|
||||
self.battery_energy_capacity = params['battery_energy_capacity']
|
||||
|
||||
def get_dist_matrix(self):
|
||||
'''
|
||||
根据城市坐标,计算距离矩阵
|
||||
@ -137,13 +159,21 @@ class TSP(object):
|
||||
'''
|
||||
split_result = self.split_path(path)
|
||||
|
||||
length_lt = []
|
||||
time_lt = []
|
||||
for car_path in split_result:
|
||||
path_length = 0
|
||||
flight_time = 0
|
||||
bs_time = 0
|
||||
for fr, to in zip(car_path[:-1], car_path[1:]):
|
||||
path_length += self.distances[fr, to]
|
||||
length_lt.append(path_length)
|
||||
return max(length_lt)
|
||||
car_time = path_length * self.car_time_factor
|
||||
for offset_rec_idx in car_path[1:-1]:
|
||||
flight_time += self.rectangles[offset_rec_idx -
|
||||
1]['flight_time']
|
||||
bs_time += self.rectangles[offset_rec_idx - 1]['bs_time']
|
||||
system_time = max(flight_time + car_time, bs_time)
|
||||
time_lt.append(system_time)
|
||||
return max(time_lt)
|
||||
|
||||
def split_path(self, path):
|
||||
# 分割路径
|
||||
@ -197,23 +227,24 @@ class TSP(object):
|
||||
'''
|
||||
for epoch in range(num_epochs):
|
||||
self.train_for_one_rollout(start_city_id=0)
|
||||
return self.best_path_length, self.best_path
|
||||
|
||||
|
||||
def main():
|
||||
np.random.seed(42)
|
||||
center = np.array([0, 0])
|
||||
# cities: [[x1, x2, x3...], [y1, y2, y3...]] 城市坐标
|
||||
cites = np.random.random([2, 15]) * np.array([800, 600]).reshape(2, -1)
|
||||
# cites = np.array([[10, -10], [0, 0]])
|
||||
cites = np.column_stack((center, cites))
|
||||
cities = np.random.random([2, 15]) * np.array([800, 600]).reshape(2, -1)
|
||||
# cities = np.array([[10, -10], [0, 0]])
|
||||
cities = np.column_stack((center, cities))
|
||||
|
||||
num_cars = 2
|
||||
center_idx = []
|
||||
for i in range(num_cars - 1):
|
||||
cites = np.column_stack((cites, center))
|
||||
center_idx.append(cites.shape[1] - 1)
|
||||
cities = np.column_stack((cities, center))
|
||||
center_idx.append(cities.shape[1] - 1)
|
||||
|
||||
tsp = TSP(num_cities=cites.shape[1], cities=cites,
|
||||
tsp = mTSP(num_cities=cities.shape[1], cities=cities,
|
||||
num_cars=num_cars, center_idx=center_idx)
|
||||
|
||||
# 训练模型
|
@ -11,13 +11,13 @@ random.seed(42)
|
||||
# ---------------------------
|
||||
# 需要修改的超参数
|
||||
# ---------------------------
|
||||
num_iterations = 1000000
|
||||
num_iterations = 10000
|
||||
# 随机生成分区的行分段数与列分段数
|
||||
# R = random.randint(0, 3) # 行分段数
|
||||
# C = random.randint(0, 3) # 列分段数
|
||||
R = 3
|
||||
C = 3
|
||||
params_file = 'params2'
|
||||
R = 1
|
||||
C = 1
|
||||
params_file = 'params3'
|
||||
|
||||
|
||||
with open(params_file + '.yml', 'r', encoding='utf-8') as file:
|
||||
|
Loading…
Reference in New Issue
Block a user