2025-03-14 15:27:05 +08:00
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch.nn as nn
|
|
|
|
import argparse
|
|
|
|
import torch
|
2025-03-19 10:58:43 +08:00
|
|
|
import numpy as np
|
2025-03-14 15:27:05 +08:00
|
|
|
|
|
|
|
class Actor(nn.Module):
    """Deterministic policy network mapping a state to a bounded action.

    Three linear layers: ``state_dim -> net_width -> 300 -> action_dim``.
    The output is squashed by ``tanh`` and multiplied by ``maxaction``.
    """

    def __init__(self, state_dim, action_dim, net_width, maxaction):
        super().__init__()

        # Layers are created in l1, l2, l3 order so seeded initialisation and
        # state_dict keys stay the same.
        self.l1 = nn.Linear(in_features=state_dim, out_features=net_width)
        self.l2 = nn.Linear(in_features=net_width, out_features=300)
        self.l3 = nn.Linear(in_features=300, out_features=action_dim)

        self.maxaction = maxaction

    def forward(self, state):
        """Return the action for ``state``, scaled by ``maxaction`` after tanh."""
        hidden = F.relu(self.l1(state))
        hidden = F.relu(self.l2(hidden))
        squashed = torch.tanh(self.l3(hidden))
        return squashed * self.maxaction
|
|
|
|
|
|
|
|
|
|
|
|
class Q_Critic(nn.Module):
    """Q-value network scoring a (state, action) pair with a single scalar.

    Three linear layers: ``state_dim + action_dim -> net_width -> 300 -> 1``.
    """

    def __init__(self, state_dim, action_dim, net_width):
        super().__init__()

        # Layers are created in l1, l2, l3 order so seeded initialisation and
        # state_dict keys stay the same.
        self.l1 = nn.Linear(in_features=state_dim + action_dim, out_features=net_width)
        self.l2 = nn.Linear(in_features=net_width, out_features=300)
        self.l3 = nn.Linear(in_features=300, out_features=1)

    def forward(self, state, action):
        """Return the Q estimate for ``state`` and ``action`` joined along dim 1."""
        joint = torch.cat((state, action), dim=1)
        hidden = torch.relu(self.l1(joint))
        hidden = torch.relu(self.l2(hidden))
        # The last layer has no activation.
        return self.l3(hidden)
|
|
|
|
|
|
|
|
def evaluate_policy(env, agent, turns = 3):
    """Run the agent deterministically for ``turns`` episodes and return the mean score.

    Args:
        env: Environment whose ``reset()`` returns the initial state and whose
            ``step(a)`` returns ``(s_next, r, dw, tr, info)``.
        agent: Object exposing ``select_action(state, deterministic=True)``;
            the returned action must be indexable (``a[0]`` is logged).
        turns: Number of evaluation episodes to average over.

    Returns:
        Sum of all rewards across episodes divided by ``turns``, truncated to int.
    """
    total_scores = 0
    for _ in range(turns):
        # NOTE(review): Gymnasium's reset() returns (obs, info); this code assumes
        # env.reset() returns the state alone -- confirm against the env in use.
        s = env.reset()
        done = False
        # First action component of every step, collected for the debug print below.
        action_series = []
        while not done:
            # Take deterministic actions at test time
            a = agent.select_action(s, deterministic=True)
            s_next, r, dw, tr, _ = env.step(a)
            done = (dw or tr)
            action_series.append(a[0])
            total_scores += r
            s = s_next
        print('action series: ', np.round(action_series, 3))
        # The f prefix was missing, so the literal text "{s_next}" was printed
        # instead of the final state of the episode.
        print(f'state: {s_next}')
    return int(total_scores/turns)
|
|
|
|
|
|
|
|
|
|
|
|
#Just ignore this function~
|
|
|
|
def str2bool(v):
    """Convert an argparse string value to a bool.

    Booleans are returned unchanged. Text is lower-cased and matched against
    the accepted yes/no spellings; anything else raises
    ``argparse.ArgumentTypeError``.
    """
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')

    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
|