diff --git a/examples/pybullet/gym/envs/bullet/cartpole_bullet.py b/examples/pybullet/gym/envs/bullet/cartpole_bullet.py index f117e8f71..3e7f29daf 100644 --- a/examples/pybullet/gym/envs/bullet/cartpole_bullet.py +++ b/examples/pybullet/gym/envs/bullet/cartpole_bullet.py @@ -22,10 +22,14 @@ class CartPoleBulletEnv(gym.Env): 'video.frames_per_second' : 50 } - def __init__(self): + def __init__(self, renders=True): # start the bullet physics server - p.connect(p.GUI) - #p.connect(p.DIRECT) + self._renders = renders + if (renders): + p.connect(p.GUI) + else: + p.connect(p.DIRECT) + observation_high = np.array([ np.finfo(np.float32).max, np.finfo(np.float32).max, @@ -33,7 +37,7 @@ class CartPoleBulletEnv(gym.Env): np.finfo(np.float32).max]) action_high = np.array([0.1]) - self.action_space = spaces.Discrete(5) + self.action_space = spaces.Discrete(9) self.observation_space = spaces.Box(-observation_high, observation_high) self.theta_threshold_radians = 1 @@ -56,8 +60,8 @@ class CartPoleBulletEnv(gym.Env): self.state = p.getJointState(self.cartpole, 1)[0:2] + p.getJointState(self.cartpole, 0)[0:2] theta, theta_dot, x, x_dot = self.state - dv = 0.4 - deltav = [-2.*dv, -dv, 0, dv, 2.*dv][action] + dv = 0.1 + deltav = [-10.*dv,-5.*dv, -2.*dv, -0.1*dv, 0, 0.1*dv, 2.*dv,5.*dv, 10.*dv][action] p.setJointMotorControl2(self.cartpole, 0, p.VELOCITY_CONTROL, targetVelocity=(deltav + self.state[3])) diff --git a/examples/pybullet/gym/train_pybullet_cartpole.py b/examples/pybullet/gym/train_pybullet_cartpole.py index 353fa4a86..4b7f4839e 100644 --- a/examples/pybullet/gym/train_pybullet_cartpole.py +++ b/examples/pybullet/gym/train_pybullet_cartpole.py @@ -12,7 +12,7 @@ def callback(lcl, glb): def main(): - env = gym.make('CartPoleBulletEnv-v0') + env = CartPoleBulletEnv(renders=False) model = deepq.models.mlp([64]) act = deepq.learn( env,