add ARS to train/eval Minitaur
This commit is contained in:
64
examples/pybullet/gym/pybullet_envs/ARS/train_ars.py
Normal file
64
examples/pybullet/gym/pybullet_envs/ARS/train_ars.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Train an ARS policy using a named configuration from config_ars.
|
||||
|
||||
blaze build -c opt //experimental/users/jietan/ARS:train_ars
|
||||
blaze-bin/experimental/users/jietan/ARS/train_ars \
|
||||
--logdir=/cns/ij-d/home/jietan/experiment/ARS/test1 \
|
||||
--config_name=MINITAUR_GYM_CONFIG
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
import os
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import ars
|
||||
import config_ars
|
||||
|
||||
# Command-line flags. Both default to None and are marked as required in the
# __main__ guard, before app.run parses the command line.
FLAGS = flags.FLAGS
flags.DEFINE_string(
    name='logdir',
    default=None,
    help='The directory to write the log file.',
)
flags.DEFINE_string(
    name='config_name',
    default=None,
    help='The name of the config dictionary',
)
|
||||
|
||||
|
||||
def run_ars(config, logdir):
    """Build an ARS learner from ``config`` and train it.

    Args:
      config: Dictionary of settings. Keys read here: 'env' (a zero-argument
        callable that returns an environment exposing ``observation_space``
        and ``action_space``), 'filter', 'num_directions', 'deltas_used',
        'step_size', 'delta_std', 'rollout_length', 'shift', 'seed' and
        'num_iterations'. The whole dictionary is also forwarded to the
        learner as ``params``.
      logdir: Forwarded to the learner as its ``logdir`` argument.

    Returns:
      Whatever ``ars.ARSLearner.train`` returns.
    """
    # A throwaway environment is created only to read the observation and
    # action dimensions; the learner receives the callback, not this instance.
    probe_env = config['env']()
    try:
        ob_dim = probe_env.observation_space.shape[0]
        ac_dim = probe_env.action_space.shape[0]
    finally:
        # Release the probe environment right away instead of keeping it
        # alive for the whole training run. The lookup is defensive so that
        # environments without a close() method keep working as before.
        close_probe = getattr(probe_env, 'close', None)
        if callable(close_probe):
            close_probe()

    # Set policy parameters. Possible filters: 'MeanStdFilter' for v2,
    # 'NoFilter' for v1.
    policy_params = {
        'type': 'linear',
        'ob_filter': config['filter'],
        'ob_dim': ob_dim,
        'ac_dim': ac_dim,
    }

    learner = ars.ARSLearner(
        env_callback=config['env'],
        policy_params=policy_params,
        num_deltas=config['num_directions'],
        deltas_used=config['deltas_used'],
        step_size=config['step_size'],
        delta_std=config['delta_std'],
        logdir=logdir,
        rollout_length=config['rollout_length'],
        shift=config['shift'],
        params=config,
        seed=config['seed'])

    return learner.train(config['num_iterations'])
|
||||
|
||||
|
||||
def main(argv):
    """Look up the config named by --config_name and run ARS training.

    Args:
      argv: Positional command-line arguments handed over by ``app.run``;
        unused.

    Raises:
      app.UsageError: If ``config_ars`` has no attribute with the name given
        in --config_name.
    """
    del argv  # Unused.
    try:
        config = getattr(config_ars, FLAGS.config_name)
    except AttributeError:
        # Report a mistyped config name as a usage error rather than a bare
        # AttributeError traceback.
        raise app.UsageError(
            'Unknown --config_name {!r}: config_ars has no such '
            'attribute.'.format(FLAGS.config_name))
    run_ars(config=config, logdir=FLAGS.logdir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Both flags default to None, so require them before handing control to
    # absl, which parses the command line and then calls main().
    for required_flag in ('logdir', 'config_name'):
        flags.mark_flag_as_required(required_flag)
    app.run(main)
|
||||
Reference in New Issue
Block a user