30 lines
944 B
Python
30 lines
944 B
Python
"""Tests for google3.experimental.users.jietan.ARS.train_ars.

To build and run:

  blaze build -c opt //experimental/users/jietan/ARS:train_ars_test
  blaze-bin/experimental/users/jietan/ARS/train_ars_test
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from absl import flags
|
|
from google3.testing.pybase import googletest
|
|
from google3.experimental.users.jietan.ARS import train_ars
|
|
from google3.experimental.users.jietan.ARS import config_ars
|
|
|
|
# Handle to absl's global flag registry; the test reads `FLAGS.test_tmpdir`.
FLAGS = flags.FLAGS

# Reference value that `info["max_reward"]` is compared against in the test.
# NOTE(review): the name is misspelled ("ITEATIONS"); kept as-is because the
# test class references this exact identifier.
MAX_RETURN_AFTER_TWO_ITEATIONS = 0.0890905394617
|
|
|
|
class TrainArsTest(googletest.TestCase):
  """Regression test for `train_ars.run_ars` using MINITAUR_REACTIVE_CONFIG."""

  def testArsTwoStepResult(self):
    """Runs ARS with `num_iterations` set to 2 and checks the max reward.

    The value of `info["max_reward"]` returned by `train_ars.run_ars` is
    compared (via `assertAlmostEqual`) against the recorded reference value
    `MAX_RETURN_AFTER_TWO_ITEATIONS`.
    """
    # Function-scope import keeps this fix self-contained.
    import copy

    # Work on a shallow copy: assigning into the module-level config object
    # directly would leak the `num_iterations` override into every other user
    # of `config_ars.MINITAUR_REACTIVE_CONFIG` in the same process.
    config = copy.copy(config_ars.MINITAUR_REACTIVE_CONFIG)
    config['num_iterations'] = 2

    info = train_ars.run_ars(config=config, logdir=FLAGS.test_tmpdir)
    # Printed so the returned info is visible in the test log.
    print(info)
    self.assertAlmostEqual(info["max_reward"], MAX_RETURN_AFTER_TWO_ITEATIONS)
|
|
|
|
|
|
# Script entry point: hand control to the googletest runner when this file is
# executed directly rather than imported.
if __name__ == '__main__':
  googletest.main()
|