diff --git a/examples/pybullet/gym/pybullet_envs/ARS/ars.py b/examples/pybullet/gym/pybullet_envs/ARS/ars.py
index 760ebd732..6ebe28149 100644
--- a/examples/pybullet/gym/pybullet_envs/ARS/ars.py
+++ b/examples/pybullet/gym/pybullet_envs/ARS/ars.py
@@ -26,7 +26,7 @@ class Hp():
     self.episode_length = 1000
     self.learning_rate = 0.02
     self.nb_directions = 16
-    self.nb_best_directions = 16
+    self.nb_best_directions = 8
     assert self.nb_best_directions <= self.nb_directions
     self.noise = 0.03
     self.seed = 1
@@ -194,7 +194,7 @@ def train(env, policy, normalizer, hp, parentPipes, args):

   # Sorting the rollouts by the max(r_pos, r_neg) and selecting the best directions
   scores = {k:max(r_pos, r_neg) for k,(r_pos,r_neg) in enumerate(zip(positive_rewards, negative_rewards))}
-  order = sorted(scores.keys(), key = lambda x:scores[x])[:hp.nb_best_directions]
+  order = sorted(scores.keys(), key = lambda x:-scores[x])[:hp.nb_best_directions]
   rollouts = [(positive_rewards[k], negative_rewards[k], deltas[k]) for k in order]

   # Updating our policy