保存当前状态
This commit is contained in:
parent
01c6a71b4f
commit
e7a4395340
43
MDP/test_mdpsovler.py
Normal file
43
MDP/test_mdpsovler.py
Normal file
@ -0,0 +1,43 @@
|
||||
import mdpsolver
|
||||
import random
|
||||
import sys
|
||||
import numpy as np
|
||||
from random import randint
|
||||
|
||||
#TEST 1
|
||||
#Simple MDP with 3 states and 2 actions in each state.
|
||||
|
||||
#---------------------------------------
|
||||
# CONFIGURATION 1
|
||||
#---------------------------------------
|
||||
|
||||
#rewards
|
||||
#1st index: from (current) states
|
||||
#2nd index: actions
|
||||
rewards = [[5,-1],
|
||||
[1,-2],
|
||||
[50,0]]
|
||||
|
||||
#transition probabilities
|
||||
#1st index: from (current) states
|
||||
#2nd index: actions
|
||||
#3rd index: to (next) states
|
||||
tranMatWithZeros = [[[0.9,0.1,0.0],[0.1,0.9,0.0]],
|
||||
[[0.4,0.5,0.1],[0.3,0.5,0.2]],
|
||||
[[0.2,0.2,0.6],[0.5,0.5,0.0]]]
|
||||
|
||||
#initial policy
|
||||
random.seed(10)
|
||||
initPolicy = [randint(0, 1) for p in range(0, 3)]
|
||||
|
||||
#Model 1a (discounted reward, parallel)
|
||||
mdl1a = mdpsolver.model()
|
||||
mdl1a.mdp(discount=0.95,
|
||||
rewards=rewards,
|
||||
tranMatWithZeros=tranMatWithZeros)
|
||||
mdl1a.solve(algorithm="mpi",
|
||||
update="standard",
|
||||
parallel=True,
|
||||
initPolicy=initPolicy)
|
||||
|
||||
print(mdl1a.getPolicy())
|
Loading…
Reference in New Issue
Block a user