Fix env_partion bug (修复env_partion bug)
This commit is contained in:
parent 2c88915112
commit ff2b914eb5
@ -28,7 +28,7 @@ parser.add_argument('--ModelIdex', type=int, default=500,
|
|||||||
help='which model to load')
|
help='which model to load')
|
||||||
|
|
||||||
parser.add_argument('--seed', type=int, default=0, help='random seed')
|
parser.add_argument('--seed', type=int, default=0, help='random seed')
|
||||||
parser.add_argument('--T_horizon', type=int, default=2000,
|
parser.add_argument('--T_horizon', type=int, default=20,
|
||||||
help='lenth of long trajectory')
|
help='lenth of long trajectory')
|
||||||
parser.add_argument('--Distribution', type=str, default='Beta',
|
parser.add_argument('--Distribution', type=str, default='Beta',
|
||||||
help='Should be one of Beta ; GS_ms ; GS_m')
|
help='Should be one of Beta ; GS_ms ; GS_m')
|
||||||
@ -36,7 +36,7 @@ parser.add_argument('--Max_train_steps', type=int,
|
|||||||
default=int(5e8), help='Max training steps')
|
default=int(5e8), help='Max training steps')
|
||||||
parser.add_argument('--save_interval', type=int,
|
parser.add_argument('--save_interval', type=int,
|
||||||
default=int(5e5), help='Model saving interval, in steps.')
|
default=int(5e5), help='Model saving interval, in steps.')
|
||||||
parser.add_argument('--eval_interval', type=int, default=int(5e3),
|
parser.add_argument('--eval_interval', type=int, default=int(5e1),
|
||||||
help='Model evaluating interval, in steps.')
|
help='Model evaluating interval, in steps.')
|
||||||
|
|
||||||
parser.add_argument('--gamma', type=float, default=0.99,
|
parser.add_argument('--gamma', type=float, default=0.99,
|
||||||
@ -138,7 +138,6 @@ def main():
|
|||||||
'''Store the current transition'''
|
'''Store the current transition'''
|
||||||
agent.put_data(s, a, r, s_next, logprob_a,
|
agent.put_data(s, a, r, s_next, logprob_a,
|
||||||
done, dw, idx=traj_lenth)
|
done, dw, idx=traj_lenth)
|
||||||
s = s_next
|
|
||||||
|
|
||||||
traj_lenth += 1
|
traj_lenth += 1
|
||||||
total_steps += 1
|
total_steps += 1
|
||||||
|
@ -17,11 +17,12 @@ class PartitionEnv(gym.Env):
|
|||||||
##############################
|
##############################
|
||||||
# 可能需要手动修改的超参数
|
# 可能需要手动修改的超参数
|
||||||
##############################
|
##############################
|
||||||
self.params = 'params3'
|
self.params = 'params2'
|
||||||
self.CUT_NUM = 2
|
self.CUT_NUM = 4
|
||||||
self.ROW_CUT_LIMIT = 1
|
self.ROW_CUT_LIMIT = 3
|
||||||
self.COL_CUT_LIMIT = 1
|
self.COL_CUT_LIMIT = 1
|
||||||
self.BASE_LINE = 5000
|
self.BASE_LINE = 12000
|
||||||
|
self.mTSP_STEPS = 10000
|
||||||
|
|
||||||
# 车队参数设置
|
# 车队参数设置
|
||||||
with open(self.params + '.yml', 'r', encoding='utf-8') as file:
|
with open(self.params + '.yml', 'r', encoding='utf-8') as file:
|
||||||
@ -95,7 +96,7 @@ class PartitionEnv(gym.Env):
|
|||||||
cols = sorted(
|
cols = sorted(
|
||||||
set(v for v in self.partition_values[self.ROW_CUT_LIMIT:] if v > 0))
|
set(v for v in self.partition_values[self.ROW_CUT_LIMIT:] if v > 0))
|
||||||
rows = rows if rows else []
|
rows = rows if rows else []
|
||||||
cols = rows if cols else []
|
cols = cols if cols else []
|
||||||
|
|
||||||
# 边界:始终包含 0 和 1
|
# 边界:始终包含 0 和 1
|
||||||
self.row_cuts = [0.0] + rows + [1.0]
|
self.row_cuts = [0.0] + rows + [1.0]
|
||||||
@ -150,7 +151,9 @@ class PartitionEnv(gym.Env):
|
|||||||
|
|
||||||
tsp = mTSP(params=self.params, num_cities=cities.shape[1], cities=cities, num_cars=self.num_cars,
|
tsp = mTSP(params=self.params, num_cities=cities.shape[1], cities=cities, num_cars=self.num_cars,
|
||||||
center_idx=center_idx, rectangles=self.rectangles)
|
center_idx=center_idx, rectangles=self.rectangles)
|
||||||
best_time, best_path = tsp.train(10000)
|
best_time, best_path = tsp.train(self.mTSP_STEPS)
|
||||||
|
print(best_time)
|
||||||
|
print(best_path)
|
||||||
|
|
||||||
reward += self.BASE_LINE - best_time
|
reward += self.BASE_LINE - best_time
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ print('state:', state)
|
|||||||
|
|
||||||
# action_series = [[0.67], [0], [0], [0], [0.7]]
|
# action_series = [[0.67], [0], [0], [0], [0.7]]
|
||||||
# action_series = [0, 0, 3, 0, 10]
|
# action_series = [0, 0, 3, 0, 10]
|
||||||
action_series = [[0.5], [0.5]]
|
action_series = [[0.2], [0.4], [0.7], [0.5]]
|
||||||
|
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
action = action_series[i]
|
action = action_series[i]
|
||||||
|
@ -28,6 +28,7 @@ class mTSP(object):
|
|||||||
learning_rate (float): 学习率
|
learning_rate (float): 学习率
|
||||||
eps (float): 探索率,值越大,探索性越强,但越难收敛
|
eps (float): 探索率,值越大,探索性越强,但越难收敛
|
||||||
'''
|
'''
|
||||||
|
np.random.seed(42)
|
||||||
self.num_cities = num_cities
|
self.num_cities = num_cities
|
||||||
self.cities = cities
|
self.cities = cities
|
||||||
self.num_cars = num_cars
|
self.num_cars = num_cars
|
||||||
|
Loading…
Reference in New Issue
Block a user