ddpg求解env_part (solve the env_part partition environment with DDPG)
This commit is contained in:
parent 0cf336c96d
commit f05f8400fb
@@ -9,7 +9,7 @@ import torch
|
|||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
from env import PartitionMazeEnv
|
from env_partion import PartitionEnv
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
'''Hyperparameter Setting'''
|
'''Hyperparameter Setting'''
|
||||||
@@ -54,16 +54,16 @@ print(opt)
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
EnvName = ['PartitionMaze_DDPG', 'Pendulum-v1', 'LunarLanderContinuous-v2', 'Humanoid-v4',
|
EnvName = ['Partition_DDPG', 'Pendulum-v1', 'LunarLanderContinuous-v2', 'Humanoid-v4',
|
||||||
'HalfCheetah-v4', 'BipedalWalker-v3', 'BipedalWalkerHardcore-v3']
|
'HalfCheetah-v4', 'BipedalWalker-v3', 'BipedalWalkerHardcore-v3']
|
||||||
BrifEnvName = ['PM_DDPG', 'PV1', 'LLdV2',
|
BrifEnvName = ['Part_DDPG', 'PV1', 'LLdV2',
|
||||||
'Humanv4', 'HCv4', 'BWv3', 'BWHv3']
|
'Humanv4', 'HCv4', 'BWv3', 'BWHv3']
|
||||||
|
|
||||||
# Build Env
|
# Build Env
|
||||||
# env = gym.make(EnvName[opt.EnvIdex], render_mode = "human" if opt.render else None)
|
# env = gym.make(EnvName[opt.EnvIdex], render_mode = "human" if opt.render else None)
|
||||||
env = PartitionMazeEnv()
|
env = PartitionEnv()
|
||||||
# eval_env = gym.make(EnvName[opt.EnvIdex])
|
# eval_env = gym.make(EnvName[opt.EnvIdex])
|
||||||
eval_env = PartitionMazeEnv()
|
eval_env = PartitionEnv()
|
||||||
opt.state_dim = env.observation_space.shape[0]
|
opt.state_dim = env.observation_space.shape[0]
|
||||||
opt.action_dim = env.action_space.shape[0]
|
opt.action_dim = env.action_space.shape[0]
|
||||||
# remark: action space【-max,max】
|
# remark: action space【-max,max】
|
||||||
|
@@ -18,11 +18,11 @@ class PartitionEnv(gym.Env):
|
|||||||
##############################
|
##############################
|
||||||
# 可能需要手动修改的超参数
|
# 可能需要手动修改的超参数
|
||||||
##############################
|
##############################
|
||||||
self.params = 'params3'
|
self.params = 'params2'
|
||||||
self.CUT_NUM = 2
|
self.CUT_NUM = 4
|
||||||
self.ROW_CUT_LIMIT = 1
|
self.ROW_CUT_LIMIT = 3
|
||||||
self.COL_CUT_LIMIT = 1
|
self.COL_CUT_LIMIT = 1
|
||||||
self.BASE_LINE = 5000
|
self.BASE_LINE = 10000
|
||||||
self.mTSP_STEPS = 10000
|
self.mTSP_STEPS = 10000
|
||||||
|
|
||||||
# 车队参数设置
|
# 车队参数设置
|
||||||
@@ -176,6 +176,7 @@ class PartitionEnv(gym.Env):
|
|||||||
# print(best_path)
|
# print(best_path)
|
||||||
|
|
||||||
reward += self.BASE_LINE - best_time
|
reward += self.BASE_LINE - best_time
|
||||||
|
print(reward)
|
||||||
|
|
||||||
return state, reward, True, False, best_path
|
return state, reward, True, False, best_path
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user