fix a few pybullet Gym environments for rendering in stable_baselines

if PYBULLET_EGL environment is set, try to enable EGL for faster rendering
bump up pybullet to 2.6.2
This commit is contained in:
Erwin Coumans
2020-01-01 18:47:46 -08:00
parent 528bd28e34
commit b6dea7ba64
6 changed files with 44 additions and 21 deletions

View File

@@ -1,10 +1,17 @@
import gym, gym.spaces, gym.utils, gym.utils.seeding
import numpy as np
import pybullet
import os
from pybullet_utils import bullet_client
from pkg_resources import parse_version
try:
if os.environ["PYBULLET_EGL"]:
import pkgutil
except:
pass
class MJCFBaseBulletEnv(gym.Env):
"""
@@ -31,6 +38,7 @@ class MJCFBaseBulletEnv(gym.Env):
self.action_space = robot.action_space
self.observation_space = robot.observation_space
#self.reset()
def configure(self, args):
self.robot.args = args
@@ -48,7 +56,19 @@ class MJCFBaseBulletEnv(gym.Env):
self._p = bullet_client.BulletClient(connection_mode=pybullet.GUI)
else:
self._p = bullet_client.BulletClient()
self._p.resetSimulation()
#optionally enable EGL for faster headless rendering
try:
if os.environ["PYBULLET_EGL"]:
con_mode = self._p.getConnectionInfo()['connectionMethod']
if con_mode==self._p.DIRECT:
egl = pkgutil.get_loader('eglRenderer')
if (egl):
self._p.loadPlugin(egl.get_filename(), "_eglRendererPlugin")
else:
self._p.loadPlugin("eglRendererPlugin")
except:
pass
self.physicsClientId = self._p._client
self._p.configureDebugVisualizer(pybullet.COV_ENABLE_GUI, 0)
@@ -77,24 +97,35 @@ class MJCFBaseBulletEnv(gym.Env):
if (hasattr(self, 'robot')):
if (hasattr(self.robot, 'body_xyz')):
base_pos = self.robot.body_xyz
view_matrix = self._p.computeViewMatrixFromYawPitchRoll(cameraTargetPosition=base_pos,
if (self.physicsClientId>=0):
view_matrix = self._p.computeViewMatrixFromYawPitchRoll(cameraTargetPosition=base_pos,
distance=self._cam_dist,
yaw=self._cam_yaw,
pitch=self._cam_pitch,
roll=0,
upAxisIndex=2)
proj_matrix = self._p.computeProjectionMatrixFOV(fov=60,
proj_matrix = self._p.computeProjectionMatrixFOV(fov=60,
aspect=float(self._render_width) /
self._render_height,
nearVal=0.1,
farVal=100.0)
(_, _, px, _, _) = self._p.getCameraImage(width=self._render_width,
(_, _, px, _, _) = self._p.getCameraImage(width=self._render_width,
height=self._render_height,
viewMatrix=view_matrix,
projectionMatrix=proj_matrix,
renderer=pybullet.ER_BULLET_HARDWARE_OPENGL)
rgb_array = np.array(px)
try:
# Keep the previous orientation of the camera set by the user.
con_mode = self._p.getConnectionInfo()['connectionMethod']
if con_mode==self._p.SHARED_MEMORY or con_mode == self._p.GUI:
[yaw, pitch, dist] = self._p.getDebugVisualizerCamera()[8:11]
self._p.resetDebugVisualizerCamera(dist, yaw, pitch, base_pos)
except:
pass
else:
px = np.array([[[255,255,255,255]]*self._render_width]*self._render_height, dtype=np.uint8)
rgb_array = np.array(px, dtype=np.uint8)
rgb_array = np.reshape(np.array(px), (self._render_height, self._render_width, -1))
rgb_array = rgb_array[:, :, :3]
return rgb_array

View File

@@ -9,12 +9,12 @@ class WalkerBaseBulletEnv(MJCFBaseBulletEnv):
def __init__(self, robot, render=False):
# print("WalkerBase::__init__ start")
MJCFBaseBulletEnv.__init__(self, robot, render)
self.camera_x = 0
self.walk_target_x = 1e3 # kilometer away
self.walk_target_y = 0
self.stateId = -1
MJCFBaseBulletEnv.__init__(self, robot, render)
def create_single_player_scene(self, bullet_client):
self.stadium_scene = SinglePlayerStadiumScene(bullet_client,

View File

@@ -32,7 +32,7 @@ class StadiumScene(Scene):
for i in self.ground_plane_mjcf:
self._p.changeDynamics(i, -1, lateralFriction=0.8, restitution=0.5)
self._p.changeVisualShape(i, -1, rgbaColor=[1, 1, 1, 0.8])
self._p.configureDebugVisualizer(pybullet.COV_ENABLE_PLANAR_REFLECTION, 1)
self._p.configureDebugVisualizer(pybullet.COV_ENABLE_PLANAR_REFLECTION,i)
# for j in range(p.getNumJoints(i)):
# self._p.changeDynamics(i,j,lateralFriction=0)

View File

@@ -42,17 +42,6 @@ class BulletClient(object):
"""Inject the client id into Bullet functions."""
attribute = getattr(pybullet, name)
if inspect.isbuiltin(attribute):
if name not in [
"invertTransform",
"multiplyTransforms",
"getMatrixFromQuaternion",
"getEulerFromQuaternion",
"computeViewMatrixFromYawPitchRoll",
"computeProjectionMatrixFOV",
"getQuaternionFromEuler",
]: # A temporary hack for now.
attribute = functools.partial(attribute, physicsClientId=self._client)
if name=="disconnect":
self._client = -1
return attribute