add ARS to train/eval Minitaur
This commit is contained in:
29
examples/pybullet/gym/pybullet_envs/ARS/train_ars_test.py
Normal file
29
examples/pybullet/gym/pybullet_envs/ARS/train_ars_test.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Tests for google3.experimental.users.jietan.ARS.train_ars.
|
||||
blaze build -c opt //experimental/users/jietan/ARS:train_ars_test
|
||||
blaze-bin/experimental/users/jietan/ARS/train_ars_test
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import flags
|
||||
from google3.testing.pybase import googletest
|
||||
from google3.experimental.users.jietan.ARS import train_ars
|
||||
from google3.experimental.users.jietan.ARS import config_ars
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
MAX_RETURN_AFTER_TWO_ITEATIONS = 0.0890905394617
|
||||
|
||||
class TrainArsTest(googletest.TestCase):
|
||||
|
||||
def testArsTwoStepResult(self):
|
||||
config = getattr(config_ars, "MINITAUR_REACTIVE_CONFIG")
|
||||
config['num_iterations'] = 2
|
||||
info = train_ars.run_ars(config=config, logdir=FLAGS.test_tmpdir)
|
||||
print (info)
|
||||
self.assertAlmostEqual(info["max_reward"], MAX_RETURN_AFTER_TWO_ITEATIONS)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
||||
Reference in New Issue
Block a user