修复env_partion bug

This commit is contained in:
weixin_46229132 2025-03-29 10:48:47 +08:00
parent 2c88915112
commit ff2b914eb5
4 changed files with 13 additions and 10 deletions

View File

@ -28,7 +28,7 @@ parser.add_argument('--ModelIdex', type=int, default=500,
help='which model to load') help='which model to load')
parser.add_argument('--seed', type=int, default=0, help='random seed') parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--T_horizon', type=int, default=2000, parser.add_argument('--T_horizon', type=int, default=20,
help='lenth of long trajectory') help='lenth of long trajectory')
parser.add_argument('--Distribution', type=str, default='Beta', parser.add_argument('--Distribution', type=str, default='Beta',
help='Should be one of Beta ; GS_ms ; GS_m') help='Should be one of Beta ; GS_ms ; GS_m')
@ -36,7 +36,7 @@ parser.add_argument('--Max_train_steps', type=int,
default=int(5e8), help='Max training steps') default=int(5e8), help='Max training steps')
parser.add_argument('--save_interval', type=int, parser.add_argument('--save_interval', type=int,
default=int(5e5), help='Model saving interval, in steps.') default=int(5e5), help='Model saving interval, in steps.')
parser.add_argument('--eval_interval', type=int, default=int(5e3), parser.add_argument('--eval_interval', type=int, default=int(5e1),
help='Model evaluating interval, in steps.') help='Model evaluating interval, in steps.')
parser.add_argument('--gamma', type=float, default=0.99, parser.add_argument('--gamma', type=float, default=0.99,
@ -138,7 +138,6 @@ def main():
'''Store the current transition''' '''Store the current transition'''
agent.put_data(s, a, r, s_next, logprob_a, agent.put_data(s, a, r, s_next, logprob_a,
done, dw, idx=traj_lenth) done, dw, idx=traj_lenth)
s = s_next
traj_lenth += 1 traj_lenth += 1
total_steps += 1 total_steps += 1

View File

@ -17,11 +17,12 @@ class PartitionEnv(gym.Env):
############################## ##############################
# 可能需要手动修改的超参数 # 可能需要手动修改的超参数
############################## ##############################
self.params = 'params3' self.params = 'params2'
self.CUT_NUM = 2 self.CUT_NUM = 4
self.ROW_CUT_LIMIT = 1 self.ROW_CUT_LIMIT = 3
self.COL_CUT_LIMIT = 1 self.COL_CUT_LIMIT = 1
self.BASE_LINE = 5000 self.BASE_LINE = 12000
self.mTSP_STEPS = 10000
# 车队参数设置 # 车队参数设置
with open(self.params + '.yml', 'r', encoding='utf-8') as file: with open(self.params + '.yml', 'r', encoding='utf-8') as file:
@ -95,7 +96,7 @@ class PartitionEnv(gym.Env):
cols = sorted( cols = sorted(
set(v for v in self.partition_values[self.ROW_CUT_LIMIT:] if v > 0)) set(v for v in self.partition_values[self.ROW_CUT_LIMIT:] if v > 0))
rows = rows if rows else [] rows = rows if rows else []
cols = rows if cols else [] cols = cols if cols else []
# 边界:始终包含 0 和 1 # 边界:始终包含 0 和 1
self.row_cuts = [0.0] + rows + [1.0] self.row_cuts = [0.0] + rows + [1.0]
@ -150,7 +151,9 @@ class PartitionEnv(gym.Env):
tsp = mTSP(params=self.params, num_cities=cities.shape[1], cities=cities, num_cars=self.num_cars, tsp = mTSP(params=self.params, num_cities=cities.shape[1], cities=cities, num_cars=self.num_cars,
center_idx=center_idx, rectangles=self.rectangles) center_idx=center_idx, rectangles=self.rectangles)
best_time, best_path = tsp.train(10000) best_time, best_path = tsp.train(self.mTSP_STEPS)
print(best_time)
print(best_path)
reward += self.BASE_LINE - best_time reward += self.BASE_LINE - best_time

View File

@ -10,7 +10,7 @@ print('state:', state)
# action_series = [[0.67], [0], [0], [0], [0.7]] # action_series = [[0.67], [0], [0], [0], [0.7]]
# action_series = [0, 0, 3, 0, 10] # action_series = [0, 0, 3, 0, 10]
action_series = [[0.5], [0.5]] action_series = [[0.2], [0.4], [0.7], [0.5]]
for i in range(100): for i in range(100):
action = action_series[i] action = action_series[i]

View File

@ -28,6 +28,7 @@ class mTSP(object):
learning_rate (float) 学习率 learning_rate (float) 学习率
eps (float): 探索率值越大探索性越强但越难收敛 eps (float): 探索率值越大探索性越强但越难收敛
''' '''
np.random.seed(42)
self.num_cities = num_cities self.num_cities = num_cities
self.cities = cities self.cities = cities
self.num_cars = num_cars self.num_cars = num_cars