From 3cdbc4cc29424c8d184239ca4887e26678af3f8e Mon Sep 17 00:00:00 2001 From: Erwin Coumans Date: Thu, 2 Jan 2020 19:33:57 -0800 Subject: [PATCH] fix CartPoleBulletEnv-v1 and add CartPoleContinuousBulletEnv-v0 (continuous version) --- .../pybullet/gym/pybullet_envs/__init__.py | 8 ++ .../gym/pybullet_envs/bullet/__init__.py | 1 + .../pybullet_envs/bullet/cartpole_bullet.py | 110 ++++++++++++++---- 3 files changed, 97 insertions(+), 22 deletions(-) diff --git a/examples/pybullet/gym/pybullet_envs/__init__.py b/examples/pybullet/gym/pybullet_envs/__init__.py index 42c5d7ccf..83b536fdb 100644 --- a/examples/pybullet/gym/pybullet_envs/__init__.py +++ b/examples/pybullet/gym/pybullet_envs/__init__.py @@ -25,6 +25,14 @@ register( reward_threshold=190.0, ) +register( + id='CartPoleContinuousBulletEnv-v0', + entry_point='pybullet_envs.bullet:CartPoleContinuousBulletEnv', + max_episode_steps=200, + reward_threshold=190.0, +) + + register( id='MinitaurBulletEnv-v0', entry_point='pybullet_envs.bullet:MinitaurBulletEnv', diff --git a/examples/pybullet/gym/pybullet_envs/bullet/__init__.py b/examples/pybullet/gym/pybullet_envs/bullet/__init__.py index 75c133276..42932a7ed 100644 --- a/examples/pybullet/gym/pybullet_envs/bullet/__init__.py +++ b/examples/pybullet/gym/pybullet_envs/bullet/__init__.py @@ -1,4 +1,5 @@ from pybullet_envs.bullet.cartpole_bullet import CartPoleBulletEnv +from pybullet_envs.bullet.cartpole_bullet import CartPoleContinuousBulletEnv from pybullet_envs.bullet.minitaur_gym_env import MinitaurBulletEnv from pybullet_envs.bullet.minitaur_duck_gym_env import MinitaurBulletDuckEnv from pybullet_envs.bullet.racecarGymEnv import RacecarGymEnv diff --git a/examples/pybullet/gym/pybullet_envs/bullet/cartpole_bullet.py b/examples/pybullet/gym/pybullet_envs/bullet/cartpole_bullet.py index f1b594cde..1fe2cfe65 100644 --- a/examples/pybullet/gym/pybullet_envs/bullet/cartpole_bullet.py +++ b/examples/pybullet/gym/pybullet_envs/bullet/cartpole_bullet.py @@ -15,8 +15,9 @@ from gym.utils import seeding import numpy as np import time import subprocess -import pybullet as p +import pybullet as p2 import pybullet_data +import pybullet_utils.bullet_client as bc from pkg_resources import parse_version logger = logging.getLogger(__name__) @@ -25,13 +26,13 @@ logger = logging.getLogger(__name__) class CartPoleBulletEnv(gym.Env): metadata = {'render.modes': ['human', 'rgb_array'], 'video.frames_per_second': 50} - def __init__(self, renders=True): + def __init__(self, renders=False, discrete_actions=True): # start the bullet physics server self._renders = renders - if (renders): - p.connect(p.GUI) - else: - p.connect(p.DIRECT) + self._discrete_actions = discrete_actions + self._render_height = 200 + self._render_width = 320 + self._physics_client_id = -1 self.theta_threshold_radians = 12 * 2 * math.pi / 360 self.x_threshold = 0.4 #2.4 high = np.array([ @@ -42,7 +43,13 @@ class CartPoleBulletEnv(gym.Env): self.force_mag = 10 - self.action_space = spaces.Discrete(2) + if self._discrete_actions: + self.action_space = spaces.Discrete(2) + else: + action_dim = 1 + action_high = np.array([self.force_mag] * action_dim) + self.action_space = spaces.Box(-action_high, action_high) + self.observation_space = spaces.Box(-high, high, dtype=np.float32) self.seed() @@ -58,7 +65,11 @@ class CartPoleBulletEnv(gym.Env): return [seed] def step(self, action): - force = self.force_mag if action == 1 else -self.force_mag + p = self._p + if self._discrete_actions: + force = self.force_mag if action == 1 else -self.force_mag + else: + force = action[0] p.setJointMotorControl2(self.cartpole, 0, p.TORQUE_CONTROL, force=force) p.stepSimulation() @@ -77,19 +88,27 @@ class CartPoleBulletEnv(gym.Env): def reset(self): # print("-----------reset simulation---------------") - p.resetSimulation() - self.cartpole = p.loadURDF(os.path.join(pybullet_data.getDataPath(), "cartpole.urdf"), - [0, 0, 0]) - p.changeDynamics(self.cartpole, -1, linearDamping=0, angularDamping=0) - p.changeDynamics(self.cartpole, 0, linearDamping=0, angularDamping=0) - p.changeDynamics(self.cartpole, 1, linearDamping=0, angularDamping=0) - self.timeStep = 0.02 - p.setJointMotorControl2(self.cartpole, 1, p.VELOCITY_CONTROL, force=0) - p.setJointMotorControl2(self.cartpole, 0, p.VELOCITY_CONTROL, force=0) - p.setGravity(0, 0, -9.8) - p.setTimeStep(self.timeStep) - p.setRealTimeSimulation(0) - + if self._physics_client_id < 0: + if self._renders: + self._p = bc.BulletClient(connection_mode=p2.GUI) + else: + self._p = bc.BulletClient() + self._physics_client_id = self._p._client + + p = self._p + p.resetSimulation() + self.cartpole = p.loadURDF(os.path.join(pybullet_data.getDataPath(), "cartpole.urdf"), + [0, 0, 0]) + p.changeDynamics(self.cartpole, -1, linearDamping=0, angularDamping=0) + p.changeDynamics(self.cartpole, 0, linearDamping=0, angularDamping=0) + p.changeDynamics(self.cartpole, 1, linearDamping=0, angularDamping=0) + self.timeStep = 0.02 + p.setJointMotorControl2(self.cartpole, 1, p.VELOCITY_CONTROL, force=0) + p.setJointMotorControl2(self.cartpole, 0, p.VELOCITY_CONTROL, force=0) + p.setGravity(0, 0, -9.8) + p.setTimeStep(self.timeStep) + p.setRealTimeSimulation(0) + p = self._p randstate = self.np_random.uniform(low=-0.05, high=0.05, size=(4,)) p.resetJointState(self.cartpole, 1, randstate[0], randstate[1]) p.resetJointState(self.cartpole, 0, randstate[2], randstate[3]) @@ -99,4 +118,51 @@ class CartPoleBulletEnv(gym.Env): return np.array(self.state) def render(self, mode='human', close=False): - return + if mode == "human": + self._renders = True + if mode != "rgb_array": + return np.array([]) + base_pos=[0,0,0] + self._cam_dist = 2 + self._cam_pitch = 0.3 + self._cam_yaw = 0 + if (self._physics_client_id>=0): + view_matrix = self._p.computeViewMatrixFromYawPitchRoll( + cameraTargetPosition=base_pos, + distance=self._cam_dist, + yaw=self._cam_yaw, + pitch=self._cam_pitch, + roll=0, + upAxisIndex=2) + proj_matrix = self._p.computeProjectionMatrixFOV(fov=60, + aspect=float(self._render_width) / + self._render_height, + nearVal=0.1, + farVal=100.0) + (_, _, px, _, _) = self._p.getCameraImage( + width=self._render_width, + height=self._render_height, + renderer=self._p.ER_BULLET_HARDWARE_OPENGL, + viewMatrix=view_matrix, + projectionMatrix=proj_matrix) + else: + px = np.array([[[255,255,255,255]]*self._render_width]*self._render_height, dtype=np.uint8) + rgb_array = np.array(px, dtype=np.uint8) + rgb_array = np.reshape(np.array(px), (self._render_height, self._render_width, -1)) + rgb_array = rgb_array[:, :, :3] + return rgb_array + + def configure(self, args): + pass + + def close(self): + if self._physics_client_id >= 0: + self._p.disconnect() + self._physics_client_id = -1 + +class CartPoleContinuousBulletEnv(CartPoleBulletEnv): + metadata = {'render.modes': ['human', 'rgb_array'], 'video.frames_per_second': 50} + + def __init__(self, renders=False): + # start the bullet physics server + CartPoleBulletEnv.__init__(self, renders, discrete_actions=False)