fix CartPoleBulletEnv-v1 and add CartPoleContinuousBulletEnv-v0 (continuous version)

This commit is contained in:
Erwin Coumans
2020-01-02 19:33:57 -08:00
parent ea3857c2c4
commit 3cdbc4cc29
3 changed files with 97 additions and 22 deletions

View File

@@ -25,6 +25,14 @@ register(
reward_threshold=190.0, reward_threshold=190.0,
) )
register(
id='CartPoleContinuousBulletEnv-v0',
entry_point='pybullet_envs.bullet:CartPoleContinuousBulletEnv',
max_episode_steps=200,
reward_threshold=190.0,
)
register( register(
id='MinitaurBulletEnv-v0', id='MinitaurBulletEnv-v0',
entry_point='pybullet_envs.bullet:MinitaurBulletEnv', entry_point='pybullet_envs.bullet:MinitaurBulletEnv',

View File

@@ -1,4 +1,5 @@
from pybullet_envs.bullet.cartpole_bullet import CartPoleBulletEnv from pybullet_envs.bullet.cartpole_bullet import CartPoleBulletEnv
from pybullet_envs.bullet.cartpole_bullet import CartPoleContinuousBulletEnv
from pybullet_envs.bullet.minitaur_gym_env import MinitaurBulletEnv from pybullet_envs.bullet.minitaur_gym_env import MinitaurBulletEnv
from pybullet_envs.bullet.minitaur_duck_gym_env import MinitaurBulletDuckEnv from pybullet_envs.bullet.minitaur_duck_gym_env import MinitaurBulletDuckEnv
from pybullet_envs.bullet.racecarGymEnv import RacecarGymEnv from pybullet_envs.bullet.racecarGymEnv import RacecarGymEnv

View File

@@ -15,8 +15,9 @@ from gym.utils import seeding
import numpy as np import numpy as np
import time import time
import subprocess import subprocess
import pybullet as p import pybullet as p2
import pybullet_data import pybullet_data
import pybullet_utils.bullet_client as bc
from pkg_resources import parse_version from pkg_resources import parse_version
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -25,13 +26,13 @@ logger = logging.getLogger(__name__)
class CartPoleBulletEnv(gym.Env): class CartPoleBulletEnv(gym.Env):
metadata = {'render.modes': ['human', 'rgb_array'], 'video.frames_per_second': 50} metadata = {'render.modes': ['human', 'rgb_array'], 'video.frames_per_second': 50}
def __init__(self, renders=True): def __init__(self, renders=False, discrete_actions=True):
# start the bullet physics server # start the bullet physics server
self._renders = renders self._renders = renders
if (renders): self._discrete_actions = discrete_actions
p.connect(p.GUI) self._render_height = 200
else: self._render_width = 320
p.connect(p.DIRECT) self._physics_client_id = -1
self.theta_threshold_radians = 12 * 2 * math.pi / 360 self.theta_threshold_radians = 12 * 2 * math.pi / 360
self.x_threshold = 0.4 #2.4 self.x_threshold = 0.4 #2.4
high = np.array([ high = np.array([
@@ -42,7 +43,13 @@ class CartPoleBulletEnv(gym.Env):
self.force_mag = 10 self.force_mag = 10
self.action_space = spaces.Discrete(2) if self._discrete_actions:
self.action_space = spaces.Discrete(2)
else:
action_dim = 1
action_high = np.array([self.force_mag] * action_dim)
self.action_space = spaces.Box(-action_high, action_high)
self.observation_space = spaces.Box(-high, high, dtype=np.float32) self.observation_space = spaces.Box(-high, high, dtype=np.float32)
self.seed() self.seed()
@@ -58,7 +65,11 @@ class CartPoleBulletEnv(gym.Env):
return [seed] return [seed]
def step(self, action): def step(self, action):
force = self.force_mag if action == 1 else -self.force_mag p = self._p
if self._discrete_actions:
force = self.force_mag if action == 1 else -self.force_mag
else:
force = action[0]
p.setJointMotorControl2(self.cartpole, 0, p.TORQUE_CONTROL, force=force) p.setJointMotorControl2(self.cartpole, 0, p.TORQUE_CONTROL, force=force)
p.stepSimulation() p.stepSimulation()
@@ -77,19 +88,27 @@ class CartPoleBulletEnv(gym.Env):
def reset(self): def reset(self):
# print("-----------reset simulation---------------") # print("-----------reset simulation---------------")
p.resetSimulation() if self._physics_client_id < 0:
self.cartpole = p.loadURDF(os.path.join(pybullet_data.getDataPath(), "cartpole.urdf"), if self._renders:
[0, 0, 0]) self._p = bc.BulletClient(connection_mode=p2.GUI)
p.changeDynamics(self.cartpole, -1, linearDamping=0, angularDamping=0) else:
p.changeDynamics(self.cartpole, 0, linearDamping=0, angularDamping=0) self._p = bc.BulletClient()
p.changeDynamics(self.cartpole, 1, linearDamping=0, angularDamping=0) self._physics_client_id = self._p._client
self.timeStep = 0.02
p.setJointMotorControl2(self.cartpole, 1, p.VELOCITY_CONTROL, force=0) p = self._p
p.setJointMotorControl2(self.cartpole, 0, p.VELOCITY_CONTROL, force=0) p.resetSimulation()
p.setGravity(0, 0, -9.8) self.cartpole = p.loadURDF(os.path.join(pybullet_data.getDataPath(), "cartpole.urdf"),
p.setTimeStep(self.timeStep) [0, 0, 0])
p.setRealTimeSimulation(0) p.changeDynamics(self.cartpole, -1, linearDamping=0, angularDamping=0)
p.changeDynamics(self.cartpole, 0, linearDamping=0, angularDamping=0)
p.changeDynamics(self.cartpole, 1, linearDamping=0, angularDamping=0)
self.timeStep = 0.02
p.setJointMotorControl2(self.cartpole, 1, p.VELOCITY_CONTROL, force=0)
p.setJointMotorControl2(self.cartpole, 0, p.VELOCITY_CONTROL, force=0)
p.setGravity(0, 0, -9.8)
p.setTimeStep(self.timeStep)
p.setRealTimeSimulation(0)
p = self._p
randstate = self.np_random.uniform(low=-0.05, high=0.05, size=(4,)) randstate = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
p.resetJointState(self.cartpole, 1, randstate[0], randstate[1]) p.resetJointState(self.cartpole, 1, randstate[0], randstate[1])
p.resetJointState(self.cartpole, 0, randstate[2], randstate[3]) p.resetJointState(self.cartpole, 0, randstate[2], randstate[3])
@@ -99,4 +118,51 @@ class CartPoleBulletEnv(gym.Env):
return np.array(self.state) return np.array(self.state)
def render(self, mode='human', close=False): def render(self, mode='human', close=False):
return if mode == "human":
self._renders = True
if mode != "rgb_array":
return np.array([])
base_pos=[0,0,0]
self._cam_dist = 2
self._cam_pitch = 0.3
self._cam_yaw = 0
if (self._physics_client_id>=0):
view_matrix = self._p.computeViewMatrixFromYawPitchRoll(
cameraTargetPosition=base_pos,
distance=self._cam_dist,
yaw=self._cam_yaw,
pitch=self._cam_pitch,
roll=0,
upAxisIndex=2)
proj_matrix = self._p.computeProjectionMatrixFOV(fov=60,
aspect=float(self._render_width) /
self._render_height,
nearVal=0.1,
farVal=100.0)
(_, _, px, _, _) = self._p.getCameraImage(
width=self._render_width,
height=self._render_height,
renderer=self._p.ER_BULLET_HARDWARE_OPENGL,
viewMatrix=view_matrix,
projectionMatrix=proj_matrix)
else:
px = np.array([[[255,255,255,255]]*self._render_width]*self._render_height, dtype=np.uint8)
rgb_array = np.array(px, dtype=np.uint8)
rgb_array = np.reshape(np.array(px), (self._render_height, self._render_width, -1))
rgb_array = rgb_array[:, :, :3]
return rgb_array
def configure(self, args):
pass
def close(self):
if self._physics_client_id >= 0:
self._p.disconnect()
self._physics_client_id = -1
class CartPoleContinuousBulletEnv(CartPoleBulletEnv):
metadata = {'render.modes': ['human', 'rgb_array'], 'video.frames_per_second': 50}
def __init__(self, renders=False):
# start the bullet physics server
CartPoleBulletEnv.__init__(self, renders, discrete_actions=False)