每一步都加一个奖励(每次有效切分后都计算奖励,而不是只在最后一步计算;最后一步的奖励乘以 3)

This commit is contained in:
weixin_46229132 2025-03-29 16:53:03 +08:00
parent f347ca8276
commit 3e6887c655
3 changed files with 70 additions and 63 deletions

View File

@ -3,6 +3,7 @@ import random
# import matplotlib.pyplot as plt # import matplotlib.pyplot as plt
import numpy as np import numpy as np
# np.random.seed(42)
class GA(object): class GA(object):

View File

@ -108,12 +108,31 @@ class PartitionEnv(gym.Env):
# 出现无效调整,直接结束 # 出现无效调整,直接结束
if not valid_adjust: if not valid_adjust:
return state, reward, True, False, {} return state, reward, True, False, {}
# 调整合理,计算当前时间
else: else:
rectangles = self.if_valid_partition()
if not rectangles:
reward = -10
return state, reward, True, False, {}
else:
# 继续进行路径规划
# 使用遗传算法解多旅行商
best_time, best_path = self.ga_solver(rectangles)
# print(best_time)
# print(best_path)
reward = self.BASE_LINE / best_time
if self.partition_step < self.CUT_NUM: if self.partition_step < self.CUT_NUM:
return state, 0.0, False, False, {} done = False
else: else:
# 完成 4 步后,判断分区是否合理,并计算各个分区的任务卸载率ρ done = True
valid_partition = True reward = reward * 3
return state, reward, done, False, best_path
def if_valid_partition(self):
rectangles = []
for i in range(len(self.ori_row_cuts) - 1): for i in range(len(self.ori_row_cuts) - 1):
for j in range(len(self.ori_col_cuts) - 1): for j in range(len(self.ori_col_cuts) - 1):
d = (self.ori_col_cuts[j+1] - self.ori_col_cuts[j]) * self.W * \ d = (self.ori_col_cuts[j+1] - self.ori_col_cuts[j]) * self.W * \
@ -125,30 +144,24 @@ class PartitionEnv(gym.Env):
(self.comp_energy_factor * d - (self.comp_energy_factor * d -
self.trans_energy_factor * d) self.trans_energy_factor * d)
if rho_energy_limit < 0: if rho_energy_limit < 0:
valid_partition = False return []
break
rho = min(rho_time_limit, rho_energy_limit) rho = min(rho_time_limit, rho_energy_limit)
flight_time = self.flight_time_factor * d flight_time = self.flight_time_factor * d
bs_time = self.bs_time_factor * (1 - rho) * d bs_time = self.bs_time_factor * (1 - rho) * d
self.rectangles.append({ rectangles.append({
'center': ((self.ori_row_cuts[i] + self.ori_row_cuts[i+1]) * self.H / 2, (self.ori_col_cuts[j+1] + self.ori_col_cuts[j]) * self.W / 2), 'center': ((self.ori_row_cuts[i] + self.ori_row_cuts[i+1]) * self.H / 2, (self.ori_col_cuts[j+1] + self.ori_col_cuts[j]) * self.W / 2),
'flight_time': flight_time, 'flight_time': flight_time,
'bs_time': bs_time, 'bs_time': bs_time,
}) })
if not valid_partition: return rectangles
break
if not valid_partition: # def q_learning_solver(self):
reward = -10
return state, reward, True, False, {}
else:
# 继续进行路径规划
# 使用q_learning解多旅行商 # 使用q_learning解多旅行商
# cities: [[x1, x2, x3...], [y1, y2, y3...]] 城市坐标 # cities: [[x1, x2, x3...], [y1, y2, y3...]] 城市坐标
# rec_center_lt = [rec_info['center'] # rec_center_lt = [rec_info['center']
# for rec_info in self.rectangles] # for rec_info in rectangles]
# cities = np.column_stack(rec_center_lt) # cities = np.column_stack(rec_center_lt)
# cities = np.column_stack((self.center, cities)) # cities = np.column_stack((self.center, cities))
@ -158,13 +171,13 @@ class PartitionEnv(gym.Env):
# center_idx.append(cities.shape[1] - 1) # center_idx.append(cities.shape[1] - 1)
# tsp = mTSP(params=self.params, num_cities=cities.shape[1], cities=cities, num_cars=self.num_cars, # tsp = mTSP(params=self.params, num_cities=cities.shape[1], cities=cities, num_cars=self.num_cars,
# center_idx=center_idx, rectangles=self.rectangles) # center_idx=center_idx, rectangles=rectangles)
# best_time, best_path = tsp.train(self.mTSP_STEPS) # best_time, best_path = tsp.train(self.mTSP_STEPS)
# 使用遗传算法解多旅行商 def ga_solver(self, rectangles):
cities = [self.center] cities = [self.center]
for rec in self.rectangles: for rec in rectangles:
cities.append(rec['center']) cities.append(rec['center'])
cities = np.array(cities) cities = np.array(cities)
@ -174,16 +187,9 @@ class PartitionEnv(gym.Env):
center_idx.append(cities.shape[0] - 1) center_idx.append(cities.shape[0] - 1)
ga = GA(num_drones=self.num_cars, num_city=cities.shape[0], num_total=20, ga = GA(num_drones=self.num_cars, num_city=cities.shape[0], num_total=20,
data=cities, to_process_idx=center_idx, rectangles=self.rectangles) data=cities, to_process_idx=center_idx, rectangles=rectangles)
best_path, best_time = ga.run() best_path, best_time = ga.run()
return best_time, best_path
# print(best_time)
# print(best_path)
reward = self.BASE_LINE / best_time
return state, reward, True, False, best_path
def render(self): def render(self):
if self.phase == 1: if self.phase == 1:

View File

@ -10,7 +10,7 @@ print('state:', state)
# action_series = [[0.67], [0], [0], [0], [0.7]] # action_series = [[0.67], [0], [0], [0], [0.7]]
# action_series = [0, 0, 3, 0, 10] # action_series = [0, 0, 3, 0, 10]
action_series = [[0.2], [0.4], [0.7], [0.5]] # action_series = [[0.2], [0.4], [0.7], [0.5]]
action_series = [[-0.1], [0], [0], [0]] action_series = [[-0.1], [0], [0], [0]]
for i in range(100): for i in range(100):