From 40cb8006eeb01c9210baa09095c52ef7d486fb2f Mon Sep 17 00:00:00 2001 From: Erwin Coumans Date: Fri, 16 Jun 2017 17:06:11 -0700 Subject: [PATCH] fix gym/envs/bullet/cartpole_bullet.py (velocity hyperparameter still needs to be tuned) add enjoy_pybullet_cartpole.py --- .../pybullet/gym/enjoy_pybullet_cartpole.py | 29 +++++++++++++++++++ .../gym/envs/bullet/cartpole_bullet.py | 11 ++++--- 2 files changed, 36 insertions(+), 4 deletions(-) create mode 100644 examples/pybullet/gym/enjoy_pybullet_cartpole.py diff --git a/examples/pybullet/gym/enjoy_pybullet_cartpole.py b/examples/pybullet/gym/enjoy_pybullet_cartpole.py new file mode 100644 index 000000000..77fc29c53 --- /dev/null +++ b/examples/pybullet/gym/enjoy_pybullet_cartpole.py @@ -0,0 +1,29 @@ +import gym + +from baselines import deepq +from envs.bullet.cartpole_bullet import CartPoleBulletEnv + +def main(): + env = gym.make('CartPoleBulletEnv-v0') + act = deepq.load("cartpole_model.pkl") + + while True: + obs, done = env.reset(), False + print("obs") + print(obs) + print("type(obs)") + print(type(obs)) + episode_rew = 0 + while not done: + env.render() + + o = obs[None] + aa = act(o) + a = aa[0] + obs, rew, done, _ = env.step(a) + episode_rew += rew + print("Episode reward", episode_rew) + + +if __name__ == '__main__': + main() diff --git a/examples/pybullet/gym/envs/bullet/cartpole_bullet.py b/examples/pybullet/gym/envs/bullet/cartpole_bullet.py index b1f1a1e35..f117e8f71 100644 --- a/examples/pybullet/gym/envs/bullet/cartpole_bullet.py +++ b/examples/pybullet/gym/envs/bullet/cartpole_bullet.py @@ -25,7 +25,7 @@ class CartPoleBulletEnv(gym.Env): def __init__(self): # start the bullet physics server p.connect(p.GUI) -# p.connect(p.DIRECT) + #p.connect(p.DIRECT) observation_high = np.array([ np.finfo(np.float32).max, np.finfo(np.float32).max, @@ -33,7 +33,7 @@ class CartPoleBulletEnv(gym.Env): np.finfo(np.float32).max]) action_high = np.array([0.1]) - self.action_space = spaces.Box(-action_high, action_high) + self.action_space = spaces.Discrete(5) self.observation_space = spaces.Box(-observation_high, observation_high) self.theta_threshold_radians = 1 @@ -55,8 +55,11 @@ class CartPoleBulletEnv(gym.Env): # time.sleep(self.timeStep) self.state = p.getJointState(self.cartpole, 1)[0:2] + p.getJointState(self.cartpole, 0)[0:2] theta, theta_dot, x, x_dot = self.state - force = action - p.setJointMotorControl2(self.cartpole, 0, p.VELOCITY_CONTROL, targetVelocity=(action + self.state[3])) + + dv = 0.4 + deltav = [-2.*dv, -dv, 0, dv, 2.*dv][action] + + p.setJointMotorControl2(self.cartpole, 0, p.VELOCITY_CONTROL, targetVelocity=(deltav + self.state[3])) done = x < -self.x_threshold \ or x > self.x_threshold \