import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Actor(nn.Module):
    """Deterministic policy network: maps a state to a bounded action.

    Architecture: state_dim -> net_width -> 300 -> action_dim, with the
    final tanh output scaled by `maxaction` to cover the env's action range.
    """

    def __init__(self, state_dim, action_dim, net_width, maxaction):
        super(Actor, self).__init__()
        self.l1 = nn.Linear(state_dim, net_width)
        self.l2 = nn.Linear(net_width, 300)
        self.l3 = nn.Linear(300, action_dim)
        self.maxaction = maxaction  # scale factor for the tanh-bounded output

    def forward(self, state):
        """Return the deterministic action for `state`, in [-maxaction, maxaction]."""
        a = torch.relu(self.l1(state))
        a = torch.relu(self.l2(a))
        # tanh bounds the raw output to (-1, 1); scaling maps it to the action range
        a = torch.tanh(self.l3(a)) * self.maxaction
        return a


class Q_Critic(nn.Module):
    """State-action value network: Q(s, a) -> scalar per sample."""

    def __init__(self, state_dim, action_dim, net_width):
        super(Q_Critic, self).__init__()
        self.l1 = nn.Linear(state_dim + action_dim, net_width)
        self.l2 = nn.Linear(net_width, 300)
        self.l3 = nn.Linear(300, 1)

    def forward(self, state, action):
        """Return Q-values of shape (batch, 1) for the given state-action pairs."""
        sa = torch.cat([state, action], 1)  # concatenate along the feature dim
        q = F.relu(self.l1(sa))
        q = F.relu(self.l2(q))
        q = self.l3(q)
        return q


def evaluate_policy(env, agent, turns=3):
    """Run `turns` deterministic evaluation episodes and return the mean score.

    Prints each episode's action series and the final state for inspection.
    The mean score is truncated to int for compact logging.
    """
    total_scores = 0
    for j in range(turns):
        # BUGFIX: env.step below unpacks the 5-tuple Gymnasium API, whose
        # reset() returns (obs, info) — the original assigned the whole tuple
        # to `s`, which would feed a tuple to the agent.
        s, info = env.reset()
        done = False
        action_series = []
        while not done:
            # Take deterministic actions at test time
            a = agent.select_action(s, deterministic=True)
            s_next, r, dw, tr, info = env.step(a)
            done = (dw or tr)
            action_series.append(a[0])
            total_scores += r
            s = s_next
        print('action series: ', np.round(action_series, 3))
        print('state: ', s)
    return int(total_scores/turns)


# Just ignore this function~
def str2bool(v):
    '''transfer str to bool for argparse'''
    if isinstance(v, bool):
        return v
    # Comparison is done on v.lower(), so only lowercase spellings can match;
    # the original's 'True'/'TRUE'/'False'/'FALSE' entries were dead and removed.
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')