Update gym_locomotion_envs.py
As suggested in https://github.com/bulletphysics/bullet3/pull/1759. The default _isDone returns done = alive < 0, and a special case is made for HalfCheetah, forcing done = False. To make this possible, I had to store the 'alive' value as an additional attribute of WalkerBaseBulletEnv.
This commit is contained in:
committed by
GitHub
parent
997211650e
commit
a57c480f28
@@ -14,6 +14,7 @@ class WalkerBaseBulletEnv(MJCFBaseBulletEnv):
|
|||||||
self.walk_target_x = 1e3 # kilometer away
|
self.walk_target_x = 1e3 # kilometer away
|
||||||
self.walk_target_y = 0
|
self.walk_target_y = 0
|
||||||
self.stateId=-1
|
self.stateId=-1
|
||||||
|
self.alive = None
|
||||||
|
|
||||||
|
|
||||||
def create_single_player_scene(self, bullet_client):
|
def create_single_player_scene(self, bullet_client):
|
||||||
@@ -40,6 +41,12 @@ class WalkerBaseBulletEnv(MJCFBaseBulletEnv):
|
|||||||
|
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
def _isDone(self):
    """Report whether the current episode should terminate.

    The decision is based on ``self.alive``, the most recent alive-bonus
    value stored on the environment. Subclasses may override this method
    to apply a different termination rule.

    Returns:
        True when ``self.alive`` is negative; False when it is ``None``
        (no value recorded yet) or non-negative.
    """
    # Read ``self.alive`` -- the attribute the environment actually assigns.
    # Reading ``self._alive`` (never assigned anywhere) raised AttributeError
    # on the very first call.
    alive = self.alive
    if alive is None:
        return False
    return alive < 0
|
||||||
|
|
||||||
def move_robot(self, init_x, init_y, init_z):
|
def move_robot(self, init_x, init_y, init_z):
|
||||||
"Used by multiplayer stadium to move sideways, to another running lane."
|
"Used by multiplayer stadium to move sideways, to another running lane."
|
||||||
self.cpp_robot.query_position()
|
self.cpp_robot.query_position()
|
||||||
@@ -60,8 +67,8 @@ class WalkerBaseBulletEnv(MJCFBaseBulletEnv):
|
|||||||
|
|
||||||
state = self.robot.calc_state() # also calculates self.joints_at_limit
|
state = self.robot.calc_state() # also calculates self.joints_at_limit
|
||||||
|
|
||||||
alive = float(self.robot.alive_bonus(state[0]+self.robot.initial_z, self.robot.body_rpy[1])) # state[0] is body height above ground, body_rpy[1] is pitch
|
self.alive = float(self.robot.alive_bonus(state[0]+self.robot.initial_z, self.robot.body_rpy[1])) # state[0] is body height above ground, body_rpy[1] is pitch
|
||||||
done = alive < 0
|
done = self._isDone()
|
||||||
if not np.isfinite(state).all():
|
if not np.isfinite(state).all():
|
||||||
print("~INF~", state)
|
print("~INF~", state)
|
||||||
done = True
|
done = True
|
||||||
@@ -136,6 +143,9 @@ class HalfCheetahBulletEnv(WalkerBaseBulletEnv):
|
|||||||
self.robot = HalfCheetah()
|
self.robot = HalfCheetah()
|
||||||
WalkerBaseBulletEnv.__init__(self, self.robot)
|
WalkerBaseBulletEnv.__init__(self, self.robot)
|
||||||
|
|
||||||
|
def _isDone(self):
    """Never signal termination for this environment.

    Overrides the base-class check so that the episode is not ended by
    this method, whatever the robot's current state is.

    Returns:
        Always False.
    """
    return False
|
||||||
|
|
||||||
class AntBulletEnv(WalkerBaseBulletEnv):
|
class AntBulletEnv(WalkerBaseBulletEnv):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.robot = Ant()
|
self.robot = Ant()
|
||||||
@@ -172,4 +182,3 @@ class HumanoidFlagrunHarderBulletEnv(HumanoidBulletEnv):
|
|||||||
s = HumanoidBulletEnv.create_single_player_scene(self, bullet_client)
|
s = HumanoidBulletEnv.create_single_player_scene(self, bullet_client)
|
||||||
s.zero_at_running_strip_start_line = False
|
s.zero_at_running_strip_start_line = False
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user