diff --git a/examples/pybullet/gym/pybullet_envs/gym_locomotion_envs.py b/examples/pybullet/gym/pybullet_envs/gym_locomotion_envs.py index 6220b28cf..93cd8af33 100644 --- a/examples/pybullet/gym/pybullet_envs/gym_locomotion_envs.py +++ b/examples/pybullet/gym/pybullet_envs/gym_locomotion_envs.py @@ -39,6 +39,9 @@ class WalkerBaseBulletEnv(MJCFBaseBulletEnv): return r + + def _isDone(self): + return self._alive < 0 def move_robot(self, init_x, init_y, init_z): "Used by multiplayer stadium to move sideways, to another running lane." @@ -60,8 +63,8 @@ class WalkerBaseBulletEnv(MJCFBaseBulletEnv): state = self.robot.calc_state() # also calculates self.joints_at_limit - alive = float(self.robot.alive_bonus(state[0]+self.robot.initial_z, self.robot.body_rpy[1])) # state[0] is body height above ground, body_rpy[1] is pitch - done = alive < 0 + self._alive = float(self.robot.alive_bonus(state[0]+self.robot.initial_z, self.robot.body_rpy[1])) # state[0] is body height above ground, body_rpy[1] is pitch + done = self._isDone() if not np.isfinite(state).all(): print("~INF~", state) done = True @@ -89,7 +92,7 @@ class WalkerBaseBulletEnv(MJCFBaseBulletEnv): debugmode=0 if(debugmode): print("alive=") - print(alive) + print(self._alive) print("progress") print(progress) print("electricity_cost") @@ -100,7 +103,7 @@ class WalkerBaseBulletEnv(MJCFBaseBulletEnv): print(feet_collision_cost) self.rewards = [ - alive, + self._alive, progress, electricity_cost, joints_at_limit_cost, @@ -135,6 +138,9 @@ class HalfCheetahBulletEnv(WalkerBaseBulletEnv): def __init__(self): self.robot = HalfCheetah() WalkerBaseBulletEnv.__init__(self, self.robot) + + def _isDone(self): + return False class AntBulletEnv(WalkerBaseBulletEnv): def __init__(self): @@ -172,4 +178,3 @@ class HumanoidFlagrunHarderBulletEnv(HumanoidBulletEnv): s = HumanoidBulletEnv.create_single_player_scene(self, bullet_client) s.zero_at_running_strip_start_line = False return s -