add ARS to train/eval Minitaur
This commit is contained in:
99
examples/pybullet/gym/pybullet_envs/ARS/eval_ars.py
Normal file
99
examples/pybullet/gym/pybullet_envs/ARS/eval_ars.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
blaze run -c opt //experimental/users/jietan/ARS:eval_ars -- \
|
||||
--logdir=/cns/ij-d/home/jietan/experiment/ARS/ars_react_nr01.191950338.191950550/ \
|
||||
--checkpoint=lin_policy_plus_990.npz \
|
||||
--num_rollouts=10
|
||||
"""
|
||||
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os, inspect
|
||||
import time
|
||||
|
||||
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
|
||||
os.sys.path.insert(0,currentdir)
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import pdb
|
||||
import os
|
||||
import numpy as np
|
||||
import gym
|
||||
import config_ars
|
||||
import utility
|
||||
import policies
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string('logdir', None, 'The path of the checkpoint.')
|
||||
flags.DEFINE_string('checkpoint', None, 'The file name of the checkpoint.')
|
||||
flags.DEFINE_integer('num_rollouts', 1, 'The number of rollouts.')
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv # Unused.
|
||||
|
||||
print('loading and building expert policy')
|
||||
checkpoint_file = os.path.join(FLAGS.logdir, FLAGS.checkpoint)
|
||||
lin_policy = np.load(checkpoint_file, encoding='bytes')
|
||||
lin_policy = lin_policy.items()[0][1]
|
||||
|
||||
M = lin_policy[0]
|
||||
# mean and std of state vectors estimated online by ARS.
|
||||
mean = lin_policy[1]
|
||||
std = lin_policy[2]
|
||||
|
||||
config = utility.load_config(FLAGS.logdir)
|
||||
print("config=",config)
|
||||
env = config['env'](hard_reset=True, render=True)
|
||||
ob_dim = env.observation_space.shape[0]
|
||||
ac_dim = env.action_space.shape[0]
|
||||
|
||||
# set policy parameters. Possible filters: 'MeanStdFilter' for v2, 'NoFilter' for v1.
|
||||
policy_params = {
|
||||
'type': 'linear',
|
||||
'ob_filter': config['filter'],
|
||||
'ob_dim': ob_dim,
|
||||
'ac_dim': ac_dim,
|
||||
"weights": M,
|
||||
"mean": mean,
|
||||
"std": std,
|
||||
}
|
||||
policy = policies.LinearPolicy(policy_params, update_filter=False)
|
||||
returns = []
|
||||
observations = []
|
||||
actions = []
|
||||
for i in range(FLAGS.num_rollouts):
|
||||
print('iter', i)
|
||||
obs = env.reset()
|
||||
done = False
|
||||
totalr = 0.
|
||||
steps = 0
|
||||
while not done:
|
||||
action = policy.act(obs)
|
||||
observations.append(obs)
|
||||
actions.append(action)
|
||||
|
||||
obs, r, done, _ = env.step(action)
|
||||
time.sleep(1./100.)
|
||||
totalr += r
|
||||
steps += 1
|
||||
if steps % 100 == 0:
|
||||
print('%i/%i' % (steps, config['rollout_length']))
|
||||
if steps >= config['rollout_length']:
|
||||
break
|
||||
returns.append(totalr)
|
||||
|
||||
print('returns', returns)
|
||||
print('mean return', np.mean(returns))
|
||||
print('std of return', np.std(returns))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.mark_flag_as_required('logdir')
|
||||
flags.mark_flag_as_required('checkpoint')
|
||||
app.run(main)
|
||||
Reference in New Issue
Block a user