Files
bullet3/examples/pybullet/gym/pybullet_envs/ARS/eval_ars.py
2018-04-11 18:14:56 -07:00

100 lines
2.5 KiB
Python

"""
blaze run -c opt //experimental/users/jietan/ARS:eval_ars -- \
--logdir=/cns/ij-d/home/jietan/experiment/ARS/ars_react_nr01.191950338.191950550/ \
--checkpoint=lin_policy_plus_990.npz \
--num_rollouts=10
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os, inspect
import time
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
os.sys.path.insert(0,currentdir)
from absl import app
from absl import flags
import pdb
import os
import numpy as np
import gym
import config_ars
import utility
import policies
FLAGS = flags.FLAGS
# --logdir is a directory: main() joins it with --checkpoint to find the
# policy file and also loads the config from it.
flags.DEFINE_string(
    'logdir', None,
    'Directory containing the policy checkpoint and its saved config.')
flags.DEFINE_string('checkpoint', None,
                    'File name of the checkpoint inside --logdir.')
flags.DEFINE_integer('num_rollouts', 1, 'The number of rollouts.')
def _load_linear_policy(checkpoint_file):
  """Return ``(weights, mean, std)`` stored in an ARS ``.npz`` checkpoint.

  Reads the first array stored in the archive and returns its elements at
  indices 0, 1 and 2.

  Args:
    checkpoint_file: Path to the ``.npz`` checkpoint.

  Returns:
    A tuple ``(weights, mean, std)``.
  """
  # ``encoding`` only affects pickled (object) arrays, which indicates the
  # checkpoint stores the policy as an object array. numpy >= 1.16.3 refuses
  # to unpickle unless allow_pickle=True.
  # SECURITY: unpickling can execute arbitrary code; only load trusted files.
  # The ``with`` block closes the underlying archive once the array is read.
  with np.load(checkpoint_file, encoding='bytes', allow_pickle=True) as data:
    lin_policy = data[data.files[0]]
  return lin_policy[0], lin_policy[1], lin_policy[2]


def _run_rollout(env, policy, rollout_length):
  """Run one episode and return the sum of its per-step rewards.

  The episode stops when the environment reports ``done`` or after
  ``rollout_length`` steps, whichever comes first.

  Args:
    env: Environment exposing ``reset()`` and ``step(action)``.
    policy: Object whose ``act(obs)`` returns an action.
    rollout_length: Maximum number of steps in the episode.

  Returns:
    The undiscounted total reward of the episode.
  """
  obs = env.reset()
  done = False
  total_reward = 0.
  steps = 0
  while not done:
    action = policy.act(obs)
    obs, reward, done, _ = env.step(action)
    # Presumably paces playback so the rendered episode is watchable.
    time.sleep(1. / 100.)
    total_reward += reward
    steps += 1
    if steps % 100 == 0:
      print('%i/%i' % (steps, rollout_length))
    if steps >= rollout_length:
      break
  return total_reward


def main(argv):
  """Load a trained ARS linear policy and evaluate it with rendered rollouts.

  Loads ``FLAGS.checkpoint`` from ``FLAGS.logdir`` together with the config
  stored there. Builds the environment with rendering enabled and runs
  ``FLAGS.num_rollouts`` episodes. Prints each episode's return, then their
  mean and standard deviation.

  Args:
    argv: Remaining command-line arguments from ``absl.app``; unused.
  """
  del argv  # Unused.
  print('loading and building expert policy')
  checkpoint_file = os.path.join(FLAGS.logdir, FLAGS.checkpoint)
  # mean and std of state vectors estimated online by ARS.
  weights, mean, std = _load_linear_policy(checkpoint_file)

  config = utility.load_config(FLAGS.logdir)
  print("config=", config)
  env = config['env'](hard_reset=True, render=True)
  ob_dim = env.observation_space.shape[0]
  ac_dim = env.action_space.shape[0]
  # set policy parameters. Possible filters: 'MeanStdFilter' for v2, 'NoFilter' for v1.
  policy_params = {
      'type': 'linear',
      'ob_filter': config['filter'],
      'ob_dim': ob_dim,
      'ac_dim': ac_dim,
      "weights": weights,
      "mean": mean,
      "std": std,
  }
  policy = policies.LinearPolicy(policy_params, update_filter=False)

  rollout_length = config['rollout_length']
  returns = []
  for i in range(FLAGS.num_rollouts):
    print('iter', i)
    returns.append(_run_rollout(env, policy, rollout_length))

  print('returns', returns)
  # Avoid NaN results and RuntimeWarnings from numpy when no rollouts ran.
  if returns:
    print('mean return', np.mean(returns))
    print('std of return', np.std(returns))
if __name__ == '__main__':
  # Both flags default to None, so insist on them before absl parses argv
  # and dispatches to main().
  for required_flag in ('logdir', 'checkpoint'):
    flags.mark_flag_as_required(required_flag)
  app.run(main)