保存当前状态
This commit is contained in:
parent
01c6a71b4f
commit
e7a4395340
43
MDP/test_mdpsovler.py
Normal file
43
MDP/test_mdpsovler.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import mdpsolver
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
from random import randint
|
||||||
|
|
||||||
|
#TEST 1
|
||||||
|
#Simple MDP with 3 states and 2 actions in each state.
|
||||||
|
|
||||||
|
#---------------------------------------
|
||||||
|
# CONFIGURATION 1
|
||||||
|
#---------------------------------------
|
||||||
|
|
||||||
|
#rewards
|
||||||
|
#1st index: from (current) states
|
||||||
|
#2nd index: actions
|
||||||
|
rewards = [[5,-1],
|
||||||
|
[1,-2],
|
||||||
|
[50,0]]
|
||||||
|
|
||||||
|
#transition probabilities
|
||||||
|
#1st index: from (current) states
|
||||||
|
#2nd index: actions
|
||||||
|
#3rd index: to (next) states
|
||||||
|
tranMatWithZeros = [[[0.9,0.1,0.0],[0.1,0.9,0.0]],
|
||||||
|
[[0.4,0.5,0.1],[0.3,0.5,0.2]],
|
||||||
|
[[0.2,0.2,0.6],[0.5,0.5,0.0]]]
|
||||||
|
|
||||||
|
#initial policy
|
||||||
|
random.seed(10)
|
||||||
|
initPolicy = [randint(0, 1) for p in range(0, 3)]
|
||||||
|
|
||||||
|
#Model 1a (discounted reward, parallel)
|
||||||
|
mdl1a = mdpsolver.model()
|
||||||
|
mdl1a.mdp(discount=0.95,
|
||||||
|
rewards=rewards,
|
||||||
|
tranMatWithZeros=tranMatWithZeros)
|
||||||
|
mdl1a.solve(algorithm="mpi",
|
||||||
|
update="standard",
|
||||||
|
parallel=True,
|
||||||
|
initPolicy=initPolicy)
|
||||||
|
|
||||||
|
print(mdl1a.getPolicy())
|
Loading…
Reference in New Issue
Block a user