Discrete case
parent 3e6887c655
commit 84f69f4293
.gitignore (vendored, 2 changed lines)
@@ -8,7 +8,7 @@ __pycache__/
*.so

# Pytorch weights
-weights/
+model/
logs/

# Distribution / packaging
@ -1,6 +1,6 @@
|
||||
from DDPG import DDPG_agent
|
||||
from datetime import datetime
|
||||
from utils import str2bool, evaluate_policy
|
||||
from utils import str2bool, evaluate_policy, Action_adapter
|
||||
import gymnasium as gym
|
||||
import shutil
|
||||
import argparse
|
||||
@@ -32,7 +32,7 @@ parser.add_argument('--Max_train_steps', type=int,
default=5e8, help='Max training steps')
parser.add_argument('--save_interval', type=int, default=1e5,
help='Model saving interval, in steps.')
-parser.add_argument('--eval_interval', type=int, default=2e3,
+parser.add_argument('--eval_interval', type=int, default=2e2,
help='Model evaluating interval, in steps.')

parser.add_argument('--gamma', type=float, default=0.99,
@@ -45,7 +45,7 @@ parser.add_argument('--c_lr', type=float, default=1e-3,
help='Learning rate of critic')
parser.add_argument('--batch_size', type=int, default=128,
help='batch_size of training')
-parser.add_argument('--random_steps', type=int, default=5e4,
+parser.add_argument('--random_steps', type=int, default=5e3,
help='random steps before training')
parser.add_argument('--noise', type=float, default=0.1, help='exploring noise')
opt = parser.parse_args()
@@ -113,6 +113,7 @@ def main():
a = env.action_space.sample()
else:
a = agent.select_action(s, deterministic=False)
+a = Action_adapter(a, opt.max_action)
s_next, r, dw, tr, info = env.step(
a)  # dw: dead&win; tr: truncated
done = (dw or tr)
@@ -127,7 +128,7 @@ def main():

'''record & log'''
if total_steps % opt.eval_interval == 0:
-ep_r = evaluate_policy(eval_env, agent, turns=1)
+ep_r = evaluate_policy(eval_env, agent, opt.max_action, turns=1)
if opt.write:
writer.add_scalar(
'ep_r', ep_r, global_step=total_steps)
@@ -36,7 +36,7 @@ class Q_Critic(nn.Module):
q = self.l3(q)
return q

-def evaluate_policy(env, agent, turns = 3):
+def evaluate_policy(env, agent, max_action, turns = 3):
total_scores = 0
for j in range(turns):
s = env.reset()
@@ -45,11 +45,13 @@ def evaluate_policy(env, agent, turns = 3):
while not done:
# Take deterministic actions at test time
a = agent.select_action(s, deterministic=True)
-s_next, r, dw, tr, info = env.step(a)
+act = Action_adapter(a, max_action)
+s_next, r, dw, tr, info = env.step(act)
done = (dw or tr)
-action_series.append(a[0])
+action_series.append(act[0])
total_scores += r
s = s_next
+print('origin action: ', a)
print('action series: ', np.round(action_series, 3))
print('state: ', s)
return int(total_scores/turns)
@@ -66,3 +68,7 @@ def str2bool(v):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
+
+def Action_adapter(a, max_action):
+    # from [0,1] to [-max,max]
+    return 2*(a-0.5)*max_action
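A quick sanity check of the adapter added above (a standalone sketch, not part of the diff):

import numpy as np
from utils import Action_adapter

print(Action_adapter(np.array([0.75]), max_action=2.0))  # [1.]  since 2*(0.75-0.5)*2.0 = 1.0
print(Action_adapter(np.array([0.5]), max_action=2.0))   # [0.]  the midpoint of [0,1] maps to 0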
@@ -58,23 +58,13 @@ class DQN_agent(object):
def select_action(self, state, deterministic):#only used when interact with the env
with torch.no_grad():
state = torch.FloatTensor(state.reshape(1, -1)).to(self.dvc)
-# if deterministic:
-#     a = self.q_net(state).argmax().item()
-# else:
+if deterministic:
+a = self.q_net(state).argmax().item()
+else:
if np.random.rand() < self.exp_noise:
-if state[0][0] == 0:
-a = np.random.randint(0,10)
+a = np.random.randint(0,self.action_dim)
-else:
-a = np.random.randint(10,15)
else:
-if state[0][0] == 0:
-q_value = self.q_net(state)
-q_value[0][10:] = - float('inf')
-a = q_value.argmax().item()
-else:
-q_value = self.q_net(state)
-q_value[0][:10] = - float('inf')
-a = q_value.argmax().item()
+a = self.q_net(state).argmax().item()
return a
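The masking removed above hard-coded two action ranges, [0, 10) and [10, 15), keyed on state[0][0]. If masking is ever needed again, a more generic sketch is shown below; this is not code from the repo, and q_net, state, and valid_mask are assumed inputs (a Q-network, a 1 x state_dim tensor, and a per-action boolean mask):

import torch

def masked_greedy_action(q_net, state, valid_mask):
    # valid_mask has shape (action_dim,); False entries can never be selected.
    with torch.no_grad():
        q = q_net(state).squeeze(0).clone()
        q[~torch.as_tensor(valid_mask, dtype=torch.bool)] = -float('inf')
        return q.argmax().item()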
@@ -84,17 +74,10 @@ class DQN_agent(object):
'''Compute the target Q value'''
with torch.no_grad():
if self.Double:
-# TODO: if there are two processes, remember to update this part as well
argmax_a = self.q_net(s_next).argmax(dim=1).unsqueeze(-1)
max_q_next = self.q_target(s_next).gather(1,argmax_a)
else:
-max_q_next = self.q_target(s_next)
-# Apply the action mask
-if s_next[0][0] == 0:
-max_q_next[:, 10:] = -float('inf')
-else:
-max_q_next[:, :10] = -float('inf')
-max_q_next = max_q_next.max(1)[0].unsqueeze(1)
+max_q_next = self.q_target(s_next).max(1)[0].unsqueeze(1)
target_Q = r + (~dw) * self.gamma * max_q_next  # dw: die or win

# Get current Q estimates
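For reference, the unchanged target line bootstraps only through non-terminal transitions; a tiny numeric illustration with invented values:

import torch

r = torch.tensor([[1.0], [1.0]])
dw = torch.tensor([[False], [True]])       # second transition ended the episode (die or win)
max_q_next = torch.tensor([[2.0], [2.0]])
gamma = 0.99

target_Q = r + (~dw) * gamma * max_q_next  # tensor([[2.98], [1.00]]): no bootstrapping past dw
print(target_Q)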
@@ -112,7 +95,7 @@ class DQN_agent(object):

def save(self,algo,EnvName,steps):
-torch.save(self.q_net.state_dict(), "./weights/{}_{}_{}.pth".format(algo,EnvName,steps))
+torch.save(self.q_net.state_dict(), "./model/{}_{}_{}.pth".format(algo,EnvName,steps))

def load(self,algo,EnvName,steps):
self.q_net.load_state_dict(torch.load("./model/{}_{}_{}.pth".format(algo,EnvName,steps),map_location=self.dvc))
@ -1,16 +1,16 @@
|
||||
from DQN import DQN_agent
|
||||
from datetime import datetime
|
||||
from utils import evaluate_policy, str2bool
|
||||
from datetime import datetime
|
||||
from DQN import DQN_agent
|
||||
import gymnasium as gym
|
||||
import os
|
||||
import shutil
|
||||
import argparse
|
||||
import torch
|
||||
import numpy as np
|
||||
# fmt: off
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from env_dis import PartitionMazeEnv
|
||||
from env_partion_dist import PartitionEnv
|
||||
# fmt: on
|
||||
|
||||
'''Hyperparameter Setting'''
|
||||
@@ -27,17 +27,17 @@ parser.add_argument('--Loadmodel', type=str2bool,
parser.add_argument('--ModelIdex', type=int, default=100,
help='which model to load')

-parser.add_argument('--seed', type=int, default=42, help='random seed')
+parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--Max_train_steps', type=int,
default=int(1e8), help='Max training steps')
parser.add_argument('--save_interval', type=int,
-default=int(50e3), help='Model saving interval, in steps.')
+default=int(5e3), help='Model saving interval, in steps.')
parser.add_argument('--eval_interval', type=int, default=int(2e3),
help='Model evaluating interval, in steps.')
parser.add_argument('--random_steps', type=int, default=int(3e3),
help='steps for random policy to explore')
parser.add_argument('--update_every', type=int,
-default=50, help='training frequency')
+default=10, help='training frequency')

parser.add_argument('--gamma', type=float, default=0.99,
help='Discounted Factor')
@@ -60,14 +60,13 @@ print(opt)

def main():
-EnvName = ['CartPole-v1', 'LunarLander-v2']
-BriefEnvName = ['PM_DQN', 'CPV1', 'LLdV2']
-# env = gym.make(EnvName[opt.EnvIdex], render_mode = "human" if opt.render else None)
-# eval_env = gym.make(EnvName[opt.EnvIdex])
-env = PartitionMazeEnv()
-eval_env = PartitionMazeEnv()
+EnvName = ['PartitionEnv']
+BriefEnvName = ['PartEnv']
+env = PartitionEnv()
+eval_env = PartitionEnv()
opt.state_dim = env.observation_space.shape[0]
opt.action_dim = env.action_space.n
+opt.max_e_steps = env.MAX_ADJUST_STEP

# Algorithm Setting
if opt.Duel:
@ -88,7 +87,7 @@ def main():
|
||||
print("Random Seed: {}".format(opt.seed))
|
||||
|
||||
print('Algorithm:', algo_name, ' Env:', BriefEnvName[opt.EnvIdex], ' state_dim:', opt.state_dim,
|
||||
' action_dim:', opt.action_dim, ' Random Seed:', opt.seed, '\n')
|
||||
' action_dim:', opt.action_dim, ' Random Seed:', opt.seed, ' max_e_steps:', opt.max_e_steps, '\n')
|
||||
|
||||
if opt.write:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
@@ -117,18 +116,14 @@ def main():
while total_steps < opt.Max_train_steps:
# Do not use opt.seed directly, or it can overfit to opt.seed
s = env.reset(seed=env_seed)
env_seed += 1
done = False

'''Interact & train'''
while not done:
# e-greedy exploration
if total_steps < opt.random_steps:
-# if s[0] == 0:
-#     a = np.random.randint(0, 10)
-# else:
-#     a = np.random.randint(10, 14)
-action_series = [0, 0, 3, 0, 10]
-a = action_series[total_steps % 5]
+a = env.action_space.sample()
else:
a = agent.select_action(s, deterministic=False)
s_next, r, dw, tr, info = env.step(a)
@@ -1,5 +1,3 @@
-import numpy as np
-
def evaluate_policy(env, agent, turns = 3):
total_scores = 0
for j in range(turns):
@@ -14,7 +12,7 @@ def evaluate_policy(env, agent, turns = 3):
action_series.append(a)
total_scores += r
s = s_next
-print('action series: ', np.round(action_series, 3))
+print('action series: ', action_series)
print('state: ', s)
return int(total_scores/turns)
GA/ga.py (16 changed lines)
@@ -3,8 +3,8 @@ import random

# import matplotlib.pyplot as plt
import numpy as np
-# np.random.seed(42)
-
+np.random.seed(42)
+random.seed(42)

class GA(object):
def __init__(self, num_drones, num_city, num_total, data, to_process_idx, rectangles):
@@ -16,7 +16,7 @@ class GA(object):
self.location = data
self.to_process_idx = to_process_idx
self.rectangles = rectangles
-self.epochs = 500
+self.epochs = 1000
self.ga_choose_ratio = 0.2
self.mutate_ratio = 0.05
# fruits stores each individual as a list of indices
@@ -98,7 +98,7 @@ class GA(object):
return dis_mat

# Compute the path length
-def compute_pathlen(self, tmp_path, dis_mat):
+def compute_pathlen(self, tmp_path):
path = tmp_path.copy()
if path[0] not in self.to_process_idx:
path.insert(0, 0)
@@ -138,7 +138,7 @@ class GA(object):
if a in self.to_process_idx and b in self.to_process_idx:
car_time += 0
else:
-car_time += dis_mat[a][b] * 100  # TODO: replace with the corresponding parameter
+car_time += self.dis_mat[a][b] * 100  # TODO: replace with the corresponding parameter
car_move_info = {'car_path': car_path, 'car_time': car_time}
car_infos.append(car_move_info)
@@ -179,7 +179,7 @@ class GA(object):
import pdb

pdb.set_trace()
-length = self.compute_pathlen(fruit, self.dis_mat)
+length = self.compute_pathlen(fruit)
adp.append(1.0 / length)
return np.array(adp)
@@ -287,8 +287,8 @@ class GA(object):
gene_x_new = self.ga_mutate(gene_x_new)
if np.random.rand() < self.mutate_ratio:
gene_y_new = self.ga_mutate(gene_y_new)
-x_adp = 1.0 / self.compute_pathlen(gene_x_new, self.dis_mat)
-y_adp = 1.0 / self.compute_pathlen(gene_y_new, self.dis_mat)
+x_adp = 1.0 / self.compute_pathlen(gene_x_new)
+y_adp = 1.0 / self.compute_pathlen(gene_y_new)
# Put the fitter individual into the population
if x_adp > y_adp and (not gene_x_new in fruits):
fruits.append(gene_x_new)
@ -37,7 +37,7 @@ parser.add_argument('--Max_train_steps', type=int,
|
||||
default=int(5e8), help='Max training steps')
|
||||
parser.add_argument('--save_interval', type=int,
|
||||
default=int(5e5), help='Model saving interval, in steps.')
|
||||
parser.add_argument('--eval_interval', type=int, default=int(5e3),
|
||||
parser.add_argument('--eval_interval', type=int, default=int(5e1),
|
||||
help='Model evaluating interval, in steps.')
|
||||
|
||||
parser.add_argument('--gamma', type=float, default=0.99,
|
||||
|
@ -149,6 +149,7 @@ def evaluate_policy(env, agent, max_action, turns):
|
||||
action_series.append(act[0])
|
||||
total_scores += r
|
||||
s = s_next
|
||||
print('origin action:', a)
|
||||
print('action series: ', np.round(action_series, 3))
|
||||
print('state: ', s)
|
||||
|
||||
|
@ -29,7 +29,7 @@ class PartitionEnv(gym.Env):
|
||||
|
||||
# 定义动作空间:全部动作均为 1 维连续 [0,1]
|
||||
self.action_space = spaces.Box(
|
||||
low=0.0, high=1.0, shape=(1,), dtype=np.float32)
|
||||
low=-0.1, high=0.1, shape=(1,), dtype=np.float32)
|
||||
|
||||
# 定义观察空间为8维向量
|
||||
# 前 4 维表示已决策的切分值(未决策部分为 0)
|
||||
|
env_partion_dist.py (new file, 232 lines)
@@ -0,0 +1,232 @@
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import yaml
import math
from mTSP_solver import mTSP
from GA.ga import GA


class PartitionEnv(gym.Env):
    """
    Custom environment with two stages:
    region partitioning, where each cut is a continuous value in (0, 1)
    """

    def __init__(self, config=None):
        super(PartitionEnv, self).__init__()
        ##############################
        # Hyperparameters that may need to be changed manually
        ##############################
        self.params = 'params2'
        self.ORI_ROW_CUTS = [0, 0.2, 0.4, 0.7, 1]
        self.ORI_COL_CUTS = [0, 0.5, 1]
        self.CUT_NUM = 4
        self.BASE_LINE = 9100
        self.MAX_ADJUST_STEP = 50
        self.ADJUST_THRESHOLD = 0.1
        # self.mTSP_STEPS = 10000

        # Move a cut position by +/-0.01
        self.action_space = spaces.Discrete(self.CUT_NUM*2 + 1)
        # Define the observation space as an 8-dimensional vector
        self.observation_space = spaces.Box(
            low=0.0, high=1.0, shape=(self.CUT_NUM + 4,), dtype=np.float32)

        self.row_cuts = self.ORI_ROW_CUTS[:]
        self.col_cuts = self.ORI_COL_CUTS[:]
        self.rectangles = []
        self.adjust_step = 0

        # Vehicle fleet parameters
        with open(self.params + '.yml', 'r', encoding='utf-8') as file:
            params = yaml.safe_load(file)

        self.H = params['H']
        self.W = params['W']
        self.center = (self.H/2, self.W/2)
        self.num_cars = params['num_cars']

        self.flight_time_factor = params['flight_time_factor']
        self.comp_time_factor = params['comp_time_factor']
        self.trans_time_factor = params['trans_time_factor']
        self.car_time_factor = params['car_time_factor']
        self.bs_time_factor = params['bs_time_factor']

        self.flight_energy_factor = params['flight_energy_factor']
        self.comp_energy_factor = params['comp_energy_factor']
        self.trans_energy_factor = params['trans_energy_factor']
        self.battery_energy_capacity = params['battery_energy_capacity']

    def reset(self, seed=None, options=None):
        # Reset all variables and go back to the partitioning stage (phase 0)
        self.row_cuts = self.ORI_ROW_CUTS[:]
        self.col_cuts = self.ORI_COL_CUTS[:]
        self.rectangles = []
        self.adjust_step = 0

        # State: the first 4 dims are the partition values, the rest are the region visit status (initially all 0)
        state = np.array(self.row_cuts + self.col_cuts)

        return state
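With the defaults above, the observation returned by reset() is just the row cuts concatenated with the column cuts:

import numpy as np
# reset() returns np.array(self.row_cuts + self.col_cuts); with the default cuts this is:
state = np.array([0, 0.2, 0.4, 0.7, 1] + [0, 0.5, 1])   # shape (8,), i.e. CUT_NUM + 4 entries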
    def step(self, action):
        if action == 1:
            self.row_cuts[1] += 0.01
        elif action == 2:
            self.row_cuts[1] -= 0.01
        elif action == 3:
            self.row_cuts[2] += 0.01
        elif action == 4:
            self.row_cuts[2] -= 0.01
        elif action == 5:
            self.row_cuts[3] += 0.01
        elif action == 6:
            self.row_cuts[3] -= 0.01
        elif action == 7:
            self.col_cuts[1] += 0.01
        elif action == 8:
            self.col_cuts[1] -= 0.01
        elif action == 9:
            pass

        self.adjust_step += 1
        state = np.array(self.row_cuts + self.col_cuts)

        if self.row_cuts[0] < self.row_cuts[1] < self.row_cuts[2] < self.row_cuts[3] < self.row_cuts[4] and self.col_cuts[0] < self.col_cuts[1] < self.col_cuts[2]:
            # The adjustment is legal; check whether the partition satisfies the constraints
            rectangles = self.if_valid_partition()

            if not rectangles:
                # Constraints not satisfied, terminate
                reward = -10000
                return state, reward, True, False, {}
            else:
                # Constraints satisfied, continue with path planning

                # Recompute the route every 10 steps (and on the first step), recording the best path
                if self.adjust_step % 10 == 0 or self.adjust_step == 1:
                    best_time, self.best_path = self.ga_solver(rectangles)
                else:
                    # Compute the current time along the stored best path
                    best_time = self.get_best_time(self.best_path, rectangles)

                reward = self.BASE_LINE - best_time

                if self.adjust_step < self.MAX_ADJUST_STEP:
                    done = False
                else:
                    done = True
                return state, reward, done, False, self.best_path
        else:
            # Illegal adjustment, terminate
            return state, -10, True, False, {}
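The if/elif chain above can also be read as a lookup table. The sketch below is only an illustration of the same ±0.01 scheme; the ACTION_TABLE and apply_action names are assumptions, not part of the file:

# Hypothetical restatement of step()'s action handling: action -> (cut list, index, delta).
ACTION_TABLE = {
    1: ('row', 1, +0.01), 2: ('row', 1, -0.01),
    3: ('row', 2, +0.01), 4: ('row', 2, -0.01),
    5: ('row', 3, +0.01), 6: ('row', 3, -0.01),
    7: ('col', 1, +0.01), 8: ('col', 1, -0.01),
}

def apply_action(row_cuts, col_cuts, action):
    # Any action not in the table (0 or 9 here) leaves the cuts unchanged, as in step().
    if action in ACTION_TABLE:
        which, idx, delta = ACTION_TABLE[action]
        (row_cuts if which == 'row' else col_cuts)[idx] += delta
    return row_cuts, col_cuts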
    def if_valid_partition(self):
        rectangles = []
        for i in range(len(self.row_cuts) - 1):
            for j in range(len(self.col_cuts) - 1):
                d = (self.col_cuts[j+1] - self.col_cuts[j]) * self.W * \
                    (self.row_cuts[i+1] - self.row_cuts[i]) * self.H
                rho_time_limit = (self.flight_time_factor - self.trans_time_factor) / \
                    (self.comp_time_factor - self.trans_time_factor)
                rho_energy_limit = (self.battery_energy_capacity - self.flight_energy_factor * d - self.trans_energy_factor * d) / \
                    (self.comp_energy_factor * d - self.trans_energy_factor * d)
                if rho_energy_limit < 0:
                    return []
                rho = min(rho_time_limit, rho_energy_limit)

                flight_time = self.flight_time_factor * d
                bs_time = self.bs_time_factor * (1 - rho) * d

                rectangles.append({
                    'center': ((self.row_cuts[i] + self.row_cuts[i+1]) * self.H / 2, (self.col_cuts[j+1] + self.col_cuts[j]) * self.W / 2),
                    'flight_time': flight_time,
                    'bs_time': bs_time,
                })
        return rectangles
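To make the rho computation concrete, a worked example with made-up factors (these numbers are illustrative only, not the values in params2.yml):

flight_time_factor, comp_time_factor, trans_time_factor = 3.0, 5.0, 1.0
flight_energy_factor, comp_energy_factor, trans_energy_factor = 0.05, 0.1, 0.01
battery_energy_capacity, d = 20.0, 100.0

rho_time_limit = (flight_time_factor - trans_time_factor) / \
    (comp_time_factor - trans_time_factor)                       # (3-1)/(5-1) = 0.5
rho_energy_limit = (battery_energy_capacity - flight_energy_factor*d - trans_energy_factor*d) / \
    (comp_energy_factor*d - trans_energy_factor*d)                # (20-5-1)/(10-1) ~= 1.56
rho = min(rho_time_limit, rho_energy_limit)                       # 0.5, so the time limit binds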
    def check_adjustment_threshold(self, threshold=0.1):
        """
        Check whether the difference between the current cut positions
        and the original cut positions exceeds the threshold
        Args:
            threshold (float): maximum allowed adjustment
        Returns:
            bool: True if the adjustment of the cut positions exceeds the threshold
        """
        # Check the row cut positions
        delta = 0
        for i in range(len(self.row_cuts)):
            delta += abs(self.row_cuts[i] - self.ORI_ROW_CUTS[i])

        # Check the column cut positions
        for i in range(len(self.col_cuts)):
            delta += abs(self.col_cuts[i] - self.ORI_COL_CUTS[i])

        if delta > threshold:
            return True

        return False
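The quantity being thresholded is just the summed absolute drift of all cuts from their originals; a standalone sketch with invented numbers:

def total_drift(cuts, original_cuts):
    # Same sum that check_adjustment_threshold() accumulates into delta.
    return sum(abs(c - o) for c, o in zip(cuts, original_cuts))

# Two row cuts moved by +0.01 and -0.02 relative to the defaults:
print(total_drift([0, 0.21, 0.38, 0.7, 1], [0, 0.2, 0.4, 0.7, 1]))  # ~0.03, below the 0.1 default threshold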
    # def q_learning_solver(self):
        # Solve the multi-traveling-salesman problem with Q-learning
        # cities: [[x1, x2, x3...], [y1, y2, y3...]] city coordinates
        # rec_center_lt = [rec_info['center']
        #                  for rec_info in rectangles]
        # cities = np.column_stack(rec_center_lt)
        # cities = np.column_stack((self.center, cities))

        # center_idx = []
        # for i in range(self.num_cars - 1):
        #     cities = np.column_stack((cities, self.center))
        #     center_idx.append(cities.shape[1] - 1)

        # tsp = mTSP(params=self.params, num_cities=cities.shape[1], cities=cities, num_cars=self.num_cars,
        #            center_idx=center_idx, rectangles=rectangles)

        # best_time, best_path = tsp.train(self.mTSP_STEPS)

    def ga_solver(self, rectangles):
        cities = [self.center]
        for rec in rectangles:
            cities.append(rec['center'])
        cities = np.array(cities)

        center_idx = [0]
        for i in range(self.num_cars - 1):
            cities = np.row_stack((cities, self.center))
            center_idx.append(cities.shape[0] - 1)

        ga = GA(num_drones=self.num_cars, num_city=cities.shape[0], num_total=20,
                data=cities, to_process_idx=center_idx, rectangles=rectangles)
        best_path, best_time = ga.run()
        return best_time, best_path

    def get_best_time(self, best_path, rectangles):
        cities = [self.center]
        for rec in rectangles:
            cities.append(rec['center'])
        cities = np.array(cities)

        center_idx = [0]
        for i in range(self.num_cars - 1):
            cities = np.row_stack((cities, self.center))
            center_idx.append(cities.shape[0] - 1)

        ga = GA(num_drones=self.num_cars, num_city=cities.shape[0], num_total=20,
                data=cities, to_process_idx=center_idx, rectangles=rectangles)
        best_time = ga.compute_pathlen(best_path)
        return best_time

    def render(self):
        if self.phase == 1:
            print("Phase 1: Initialize maze environment.")
            print(f"Partition values so far: {self.partition_values}")
            print(f"Motorcade position: {self.car_pos}")
            # input('1111')
        elif self.phase == 2:
            print("Phase 2: Play maze.")
            print(f'Motorcade trajectory: {self.car_traj}')
            # input('2222')
@ -1,6 +1,6 @@
|
||||
# from env import PartitionMazeEnv
|
||||
# from env_dis import PartitionMazeEnv
|
||||
from env_partion import PartitionEnv
|
||||
from env_partion_dist import PartitionEnv
|
||||
|
||||
# env = PartitionMazeEnv()
|
||||
env = PartitionEnv()
|
||||
@ -9,9 +9,10 @@ state = env.reset()
|
||||
print('state:', state)
|
||||
|
||||
# action_series = [[0.67], [0], [0], [0], [0.7]]
|
||||
# action_series = [0, 0, 3, 0, 10]
|
||||
action_series = [1, 1, 1, 1, 1, 1]
|
||||
action_series = [1] * 30
|
||||
# action_series = [[0.2], [0.4], [0.7], [0.5]]
|
||||
action_series = [[-0.1], [0], [0], [0]]
|
||||
# action_series = [[-0.08], [-0.08], [0], [0]]
|
||||
|
||||
for i in range(100):
|
||||
action = action_series[i]
|
||||
|