100 lines
2.5 KiB
Python
100 lines
2.5 KiB
Python
"""
|
|
blaze run -c opt //experimental/users/jietan/ARS:eval_ars -- \
|
|
--logdir=/cns/ij-d/home/jietan/experiment/ARS/ars_react_nr01.191950338.191950550/ \
|
|
--checkpoint=lin_policy_plus_990.npz \
|
|
--num_rollouts=10
|
|
"""
|
|
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import os, inspect
|
|
import time
|
|
|
|
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
|
|
os.sys.path.insert(0,currentdir)
|
|
|
|
from absl import app
|
|
from absl import flags
|
|
|
|
import pdb
|
|
import os
|
|
import numpy as np
|
|
import gym
|
|
import config_ars
|
|
import utility
|
|
import policies
|
|
|
|
FLAGS = flags.FLAGS
|
|
|
|
flags.DEFINE_string('logdir', None, 'The path of the checkpoint.')
|
|
flags.DEFINE_string('checkpoint', None, 'The file name of the checkpoint.')
|
|
flags.DEFINE_integer('num_rollouts', 1, 'The number of rollouts.')
|
|
|
|
|
|
def main(argv):
|
|
del argv # Unused.
|
|
|
|
print('loading and building expert policy')
|
|
checkpoint_file = os.path.join(FLAGS.logdir, FLAGS.checkpoint)
|
|
lin_policy = np.load(checkpoint_file, encoding='bytes')
|
|
lin_policy = lin_policy.items()[0][1]
|
|
|
|
M = lin_policy[0]
|
|
# mean and std of state vectors estimated online by ARS.
|
|
mean = lin_policy[1]
|
|
std = lin_policy[2]
|
|
|
|
config = utility.load_config(FLAGS.logdir)
|
|
print("config=",config)
|
|
env = config['env'](hard_reset=True, render=True)
|
|
ob_dim = env.observation_space.shape[0]
|
|
ac_dim = env.action_space.shape[0]
|
|
|
|
# set policy parameters. Possible filters: 'MeanStdFilter' for v2, 'NoFilter' for v1.
|
|
policy_params = {
|
|
'type': 'linear',
|
|
'ob_filter': config['filter'],
|
|
'ob_dim': ob_dim,
|
|
'ac_dim': ac_dim,
|
|
"weights": M,
|
|
"mean": mean,
|
|
"std": std,
|
|
}
|
|
policy = policies.LinearPolicy(policy_params, update_filter=False)
|
|
returns = []
|
|
observations = []
|
|
actions = []
|
|
for i in range(FLAGS.num_rollouts):
|
|
print('iter', i)
|
|
obs = env.reset()
|
|
done = False
|
|
totalr = 0.
|
|
steps = 0
|
|
while not done:
|
|
action = policy.act(obs)
|
|
observations.append(obs)
|
|
actions.append(action)
|
|
|
|
obs, r, done, _ = env.step(action)
|
|
time.sleep(1./100.)
|
|
totalr += r
|
|
steps += 1
|
|
if steps % 100 == 0:
|
|
print('%i/%i' % (steps, config['rollout_length']))
|
|
if steps >= config['rollout_length']:
|
|
break
|
|
returns.append(totalr)
|
|
|
|
print('returns', returns)
|
|
print('mean return', np.mean(returns))
|
|
print('std of return', np.std(returns))
|
|
|
|
|
|
if __name__ == '__main__':
|
|
flags.mark_flag_as_required('logdir')
|
|
flags.mark_flag_as_required('checkpoint')
|
|
app.run(main)
|