Move pybullet.connect into the Gym environment.

If you would like to enable rendering, call env.render(mode="human") before the first call to env.reset().
This commit is contained in:
Erwin Coumans
2017-08-26 14:58:48 -07:00
parent 51b7e1040f
commit e267f5c3d2
10 changed files with 44 additions and 53 deletions

View File

@@ -15,11 +15,11 @@ class MJCFBaseBulletEnv(gym.Env):
'video.frames_per_second': 60 'video.frames_per_second': 60
} }
def __init__(self, robot): def __init__(self, robot, render=False):
self.scene = None self.scene = None
self.physicsClientId=-1
self.camera = Camera() self.camera = Camera()
self.isRender = render
self.robot = robot self.robot = robot
self._seed() self._seed()
@@ -33,6 +33,15 @@ class MJCFBaseBulletEnv(gym.Env):
return [seed] return [seed]
def _reset(self): def _reset(self):
print("self.isRender=")
print(self.isRender)
if (self.physicsClientId<0):
if (self.isRender):
self.physicsClientId = p.connect(p.GUI)
else:
self.physicsClientId = p.connect(p.DIRECT)
p.configureDebugVisualizer(p.COV_ENABLE_GUI,0)
if self.scene is None: if self.scene is None:
self.scene = self.create_single_player_scene() self.scene = self.create_single_player_scene()
if not self.scene.multiplayer: if not self.scene.multiplayer:
@@ -49,7 +58,13 @@ class MJCFBaseBulletEnv(gym.Env):
return s return s
def _render(self, mode, close): def _render(self, mode, close):
pass if (mode=="human"):
self.isRender = True
def _close(self):
if (self.physicsClientId>=0):
p.disconnect(self.physicsClientId)
self.physicsClientId = -1
def HUD(self, state, a, done): def HUD(self, state, a, done):
pass pass

View File

@@ -30,14 +30,10 @@ class SmallReactivePolicy:
def main(): def main():
env = gym.make("AntBulletEnv-v0") env = gym.make("AntBulletEnv-v0")
env.render(mode="human")
cid = p.connect(p.GUI)
p.configureDebugVisualizer(p.COV_ENABLE_GUI,0)
pi = SmallReactivePolicy(env.observation_space, env.action_space) pi = SmallReactivePolicy(env.observation_space, env.action_space)
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,0)
env.reset() env.reset()
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,1)
torsoId = -1 torsoId = -1
for i in range (p.getNumBodies()): for i in range (p.getNumBodies()):
print(p.getBodyInfo(i)) print(p.getBodyInfo(i))
@@ -49,10 +45,8 @@ def main():
frame = 0 frame = 0
score = 0 score = 0
restart_delay = 0 restart_delay = 0
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,0)
obs = env.reset() obs = env.reset()
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,1)
while 1: while 1:
time.sleep(0.001) time.sleep(0.001)
a = pi.act(obs) a = pi.act(obs)

View File

@@ -29,15 +29,11 @@ class SmallReactivePolicy:
def main(): def main():
env = gym.make("HalfCheetahBulletEnv-v0") env = gym.make("HalfCheetahBulletEnv-v0")
env.render(mode="human")
cid = p.connect(p.GUI)
pi = SmallReactivePolicy(env.observation_space, env.action_space) pi = SmallReactivePolicy(env.observation_space, env.action_space)
#disable rendering during reset, makes loading much faster #disable rendering during reset, makes loading much faster
p.configureDebugVisualizer(p.COV_ENABLE_GUI,0)
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,0)
env.reset() env.reset()
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,1)
torsoId = -1 torsoId = -1
for i in range (p.getNumBodies()): for i in range (p.getNumBodies()):
print(p.getBodyInfo(i)) print(p.getBodyInfo(i))
@@ -52,10 +48,8 @@ def main():
frame = 0 frame = 0
score = 0 score = 0
restart_delay = 0 restart_delay = 0
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,0)
obs = env.reset() obs = env.reset()
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,1)
while 1: while 1:
time.sleep(0.001) time.sleep(0.001)
a = pi.act(obs) a = pi.act(obs)

View File

@@ -32,13 +32,10 @@ class SmallReactivePolicy:
def main(): def main():
env = gym.make("HopperBulletEnv-v0") env = gym.make("HopperBulletEnv-v0")
env.render(mode="human")
cid = p.connect(p.GUI)
p.configureDebugVisualizer(p.COV_ENABLE_GUI,0)
pi = SmallReactivePolicy(env.observation_space, env.action_space) pi = SmallReactivePolicy(env.observation_space, env.action_space)
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,0)
env.reset() env.reset()
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,1)
for i in range (p.getNumBodies()): for i in range (p.getNumBodies()):
print(p.getBodyInfo(i)) print(p.getBodyInfo(i))
if (p.getBodyInfo(i)[1].decode() == "hopper"): if (p.getBodyInfo(i)[1].decode() == "hopper"):
@@ -52,10 +49,8 @@ def main():
score = 0 score = 0
restart_delay = 0 restart_delay = 0
#disable rendering during reset, makes loading much faster #disable rendering during reset, makes loading much faster
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,0)
obs = env.reset() obs = env.reset()
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,1)
while 1: while 1:
time.sleep(0.001) time.sleep(0.001)
a = pi.act(obs) a = pi.act(obs)

View File

@@ -30,13 +30,9 @@ class SmallReactivePolicy:
def main(): def main():
env = gym.make("HumanoidBulletEnv-v0") env = gym.make("HumanoidBulletEnv-v0")
env.render(mode="human")
cid = p.connect(p.GUI)
p.configureDebugVisualizer(p.COV_ENABLE_GUI,0)
pi = SmallReactivePolicy(env.observation_space, env.action_space) pi = SmallReactivePolicy(env.observation_space, env.action_space)
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,0)
env.reset() env.reset()
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,1)
torsoId = -1 torsoId = -1
for i in range (p.getNumBodies()): for i in range (p.getNumBodies()):
print(p.getBodyInfo(i)) print(p.getBodyInfo(i))
@@ -47,10 +43,8 @@ def main():
frame = 0 frame = 0
score = 0 score = 0
restart_delay = 0 restart_delay = 0
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,0)
obs = env.reset() obs = env.reset()
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,1)
while 1: while 1:
time.sleep(0.001) time.sleep(0.001)
a = pi.act(obs) a = pi.act(obs)

View File

@@ -29,9 +29,8 @@ class SmallReactivePolicy:
def main(): def main():
env = gym.make("InvertedDoublePendulumBulletEnv-v0") env = gym.make("InvertedDoublePendulumBulletEnv-v0")
env.render(mode="human")
cid = p.connect(p.GUI)
pi = SmallReactivePolicy(env.observation_space, env.action_space) pi = SmallReactivePolicy(env.observation_space, env.action_space)
while 1: while 1:

View File

@@ -30,8 +30,7 @@ class SmallReactivePolicy:
def main(): def main():
print("create env") print("create env")
env = gym.make("InvertedPendulumBulletEnv-v0") env = gym.make("InvertedPendulumBulletEnv-v0")
print("connecting") env.render(mode="human")
cid = p.connect(p.GUI)
pi = SmallReactivePolicy(env.observation_space, env.action_space) pi = SmallReactivePolicy(env.observation_space, env.action_space)
while 1: while 1:

View File

@@ -29,9 +29,8 @@ class SmallReactivePolicy:
def main(): def main():
env = gym.make("InvertedPendulumSwingupBulletEnv-v0") env = gym.make("InvertedPendulumSwingupBulletEnv-v0")
env.render(mode="human")
cid = p.connect(p.GUI)
pi = SmallReactivePolicy(env.observation_space, env.action_space) pi = SmallReactivePolicy(env.observation_space, env.action_space)
while 1: while 1:

View File

@@ -29,15 +29,10 @@ class SmallReactivePolicy:
def main(): def main():
env = gym.make("Walker2DBulletEnv-v0") env = gym.make("Walker2DBulletEnv-v0")
env.render(mode="human")
cid = p.connect(p.GUI)
p.configureDebugVisualizer(p.COV_ENABLE_GUI,0)
pi = SmallReactivePolicy(env.observation_space, env.action_space) pi = SmallReactivePolicy(env.observation_space, env.action_space)
p.configureDebugVisualizer(p.COV_ENABLE_GUI,0)
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,0)
env.reset() env.reset()
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,1)
torsoId = -1 torsoId = -1
for i in range (p.getNumBodies()): for i in range (p.getNumBodies()):
print(p.getBodyInfo(i)) print(p.getBodyInfo(i))

View File

@@ -1,12 +1,15 @@
from .scene_stadium import SinglePlayerStadiumScene from .scene_stadium import SinglePlayerStadiumScene
from .env_bases import MJCFBaseBulletEnv from .env_bases import MJCFBaseBulletEnv
import numpy as np import numpy as np
import pybullet as p
from robot_locomotors import Hopper, Walker2D, HalfCheetah, Ant, Humanoid from robot_locomotors import Hopper, Walker2D, HalfCheetah, Ant, Humanoid
class WalkerBaseBulletEnv(MJCFBaseBulletEnv): class WalkerBaseBulletEnv(MJCFBaseBulletEnv):
def __init__(self, robot): def __init__(self, robot, render=False):
MJCFBaseBulletEnv.__init__(self, robot) print("WalkerBase::__init__")
MJCFBaseBulletEnv.__init__(self, robot, render)
self.camera_x = 0 self.camera_x = 0
self.walk_target_x = 1e3 # kilometer away self.walk_target_x = 1e3 # kilometer away
self.walk_target_y = 0 self.walk_target_y = 0
@@ -16,11 +19,15 @@ class WalkerBaseBulletEnv(MJCFBaseBulletEnv):
return self.stadium_scene return self.stadium_scene
def _reset(self): def _reset(self):
r = MJCFBaseBulletEnv._reset(self) r = MJCFBaseBulletEnv._reset(self)
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,0)
self.parts, self.jdict, self.ordered_joints, self.robot_body = self.robot.addToScene( self.parts, self.jdict, self.ordered_joints, self.robot_body = self.robot.addToScene(
self.stadium_scene.ground_plane_mjcf) self.stadium_scene.ground_plane_mjcf)
self.ground_ids = set([(self.parts[f].bodies[self.parts[f].bodyIndex], self.parts[f].bodyPartIndex) for f in self.ground_ids = set([(self.parts[f].bodies[self.parts[f].bodyIndex], self.parts[f].bodyPartIndex) for f in
self.foot_ground_object_names]) self.foot_ground_object_names])
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING,1)
return r return r
def move_robot(self, init_x, init_y, init_z): def move_robot(self, init_x, init_y, init_z):