update to CartPoleBulletEnv-v1 and check with latest baselines v0.1.5, works fine.
(make it more similar to classical control cartpole)
This commit is contained in:
@@ -59,9 +59,15 @@
|
|||||||
</visual>
|
</visual>
|
||||||
<inertial>
|
<inertial>
|
||||||
<origin xyz="0 0 0.5"/>
|
<origin xyz="0 0 0.5"/>
|
||||||
<mass value="10"/>
|
<mass value="0.1"/>
|
||||||
<inertia ixx="1.0" ixy="0.0" ixz="0.0" iyy="1.0" iyz="0.0" izz="1.0"/>
|
<inertia ixx="1.0" ixy="0.0" ixz="0.0" iyy="1.0" iyz="0.0" izz="1.0"/>
|
||||||
</inertial>
|
</inertial>
|
||||||
|
<collision>
|
||||||
|
<geometry>
|
||||||
|
<box size="0.05 0.05 1.0"/>
|
||||||
|
</geometry>
|
||||||
|
<origin rpy="0 0 0" xyz="0 0 0.5"/>
|
||||||
|
</collision>
|
||||||
</link>
|
</link>
|
||||||
|
|
||||||
<joint name="cart_to_pole" type="continuous">
|
<joint name="cart_to_pole" type="continuous">
|
||||||
|
|||||||
@@ -9,17 +9,17 @@ def register(id,*args,**kvargs):
|
|||||||
# ------------bullet-------------
|
# ------------bullet-------------
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='CartPoleBulletEnv-v0',
|
id='CartPoleBulletEnv-v1',
|
||||||
entry_point='pybullet_envs.bullet:CartPoleBulletEnv',
|
entry_point='pybullet_envs.bullet:CartPoleBulletEnv',
|
||||||
timestep_limit=1000,
|
max_episode_steps=200,
|
||||||
reward_threshold=950.0,
|
reward_threshold=190.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='MinitaurBulletEnv-v0',
|
id='MinitaurBulletEnv-v0',
|
||||||
entry_point='pybullet_envs.bullet:MinitaurBulletEnv',
|
entry_point='pybullet_envs.bullet:MinitaurBulletEnv',
|
||||||
timestep_limit=1000,
|
timestep_limit=1000,
|
||||||
reward_threshold=5.0,
|
reward_threshold=15.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
|
|||||||
@@ -5,12 +5,13 @@ parentdir = os.path.dirname(os.path.dirname(currentdir))
|
|||||||
os.sys.path.insert(0,parentdir)
|
os.sys.path.insert(0,parentdir)
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
|
import time
|
||||||
|
|
||||||
from baselines import deepq
|
from baselines import deepq
|
||||||
from pybullet_envs.bullet.cartpole_bullet import CartPoleBulletEnv
|
from pybullet_envs.bullet.cartpole_bullet import CartPoleBulletEnv
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
env = gym.make('CartPoleBulletEnv-v0')
|
env = gym.make('CartPoleBulletEnv-v1')
|
||||||
act = deepq.load("cartpole_model.pkl")
|
act = deepq.load("cartpole_model.pkl")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -28,6 +29,7 @@ def main():
|
|||||||
a = aa[0]
|
a = aa[0]
|
||||||
obs, rew, done, _ = env.step(a)
|
obs, rew, done, _ = env.step(a)
|
||||||
episode_rew += rew
|
episode_rew += rew
|
||||||
|
time.sleep(1./240.)
|
||||||
print("Episode reward", episode_rew)
|
print("Episode reward", episode_rew)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -34,24 +34,26 @@ class CartPoleBulletEnv(gym.Env):
|
|||||||
p.connect(p.GUI)
|
p.connect(p.GUI)
|
||||||
else:
|
else:
|
||||||
p.connect(p.DIRECT)
|
p.connect(p.DIRECT)
|
||||||
|
self.theta_threshold_radians = 12 * 2 * math.pi / 360
|
||||||
|
self.x_threshold = 0.4 #2.4
|
||||||
|
high = np.array([
|
||||||
|
self.x_threshold * 2,
|
||||||
|
np.finfo(np.float32).max,
|
||||||
|
self.theta_threshold_radians * 2,
|
||||||
|
np.finfo(np.float32).max])
|
||||||
|
|
||||||
observation_high = np.array([
|
self.force_mag = 10
|
||||||
np.finfo(np.float32).max,
|
|
||||||
np.finfo(np.float32).max,
|
|
||||||
np.finfo(np.float32).max,
|
|
||||||
np.finfo(np.float32).max])
|
|
||||||
action_high = np.array([0.1])
|
|
||||||
|
|
||||||
self.action_space = spaces.Discrete(9)
|
self.action_space = spaces.Discrete(2)
|
||||||
self.observation_space = spaces.Box(-observation_high, observation_high)
|
self.observation_space = spaces.Box(-high, high, dtype=np.float32)
|
||||||
|
|
||||||
self.theta_threshold_radians = 1
|
|
||||||
self.x_threshold = 2.4
|
|
||||||
self._seed()
|
self._seed()
|
||||||
# self.reset()
|
# self.reset()
|
||||||
self.viewer = None
|
self.viewer = None
|
||||||
self._configure()
|
self._configure()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _configure(self, display=None):
|
def _configure(self, display=None):
|
||||||
self.display = display
|
self.display = display
|
||||||
|
|
||||||
@@ -60,41 +62,43 @@ class CartPoleBulletEnv(gym.Env):
|
|||||||
return [seed]
|
return [seed]
|
||||||
|
|
||||||
def _step(self, action):
|
def _step(self, action):
|
||||||
|
force = self.force_mag if action==1 else -self.force_mag
|
||||||
|
|
||||||
|
p.setJointMotorControl2(self.cartpole, 0, p.TORQUE_CONTROL, force=force)
|
||||||
p.stepSimulation()
|
p.stepSimulation()
|
||||||
# time.sleep(self.timeStep)
|
|
||||||
self.state = p.getJointState(self.cartpole, 1)[0:2] + p.getJointState(self.cartpole, 0)[0:2]
|
self.state = p.getJointState(self.cartpole, 1)[0:2] + p.getJointState(self.cartpole, 0)[0:2]
|
||||||
theta, theta_dot, x, x_dot = self.state
|
theta, theta_dot, x, x_dot = self.state
|
||||||
|
|
||||||
dv = 0.1
|
|
||||||
deltav = [-10.*dv,-5.*dv, -2.*dv, -0.1*dv, 0, 0.1*dv, 2.*dv,5.*dv, 10.*dv][action]
|
|
||||||
|
|
||||||
p.setJointMotorControl2(self.cartpole, 0, p.VELOCITY_CONTROL, targetVelocity=(deltav + self.state[3]))
|
|
||||||
|
|
||||||
done = x < -self.x_threshold \
|
done = x < -self.x_threshold \
|
||||||
or x > self.x_threshold \
|
or x > self.x_threshold \
|
||||||
or theta < -self.theta_threshold_radians \
|
or theta < -self.theta_threshold_radians \
|
||||||
or theta > self.theta_threshold_radians
|
or theta > self.theta_threshold_radians
|
||||||
|
done = bool(done)
|
||||||
reward = 1.0
|
reward = 1.0
|
||||||
|
#print("state=",self.state)
|
||||||
return np.array(self.state), reward, done, {}
|
return np.array(self.state), reward, done, {}
|
||||||
|
|
||||||
def _reset(self):
|
def _reset(self):
|
||||||
# print("-----------reset simulation---------------")
|
# print("-----------reset simulation---------------")
|
||||||
p.resetSimulation()
|
p.resetSimulation()
|
||||||
self.cartpole = p.loadURDF(os.path.join(pybullet_data.getDataPath(),"cartpole.urdf"),[0,0,0])
|
self.cartpole = p.loadURDF(os.path.join(pybullet_data.getDataPath(),"cartpole.urdf"),[0,0,0])
|
||||||
self.timeStep = 0.01
|
p.changeDynamics(self.cartpole, -1, linearDamping=0, angularDamping=0)
|
||||||
|
p.changeDynamics(self.cartpole, 0, linearDamping=0, angularDamping=0)
|
||||||
|
p.changeDynamics(self.cartpole, 1, linearDamping=0, angularDamping=0)
|
||||||
|
self.timeStep = 0.02
|
||||||
p.setJointMotorControl2(self.cartpole, 1, p.VELOCITY_CONTROL, force=0)
|
p.setJointMotorControl2(self.cartpole, 1, p.VELOCITY_CONTROL, force=0)
|
||||||
p.setGravity(0,0, -10)
|
p.setJointMotorControl2(self.cartpole, 0, p.VELOCITY_CONTROL, force=0)
|
||||||
|
p.setGravity(0,0, -9.8)
|
||||||
p.setTimeStep(self.timeStep)
|
p.setTimeStep(self.timeStep)
|
||||||
p.setRealTimeSimulation(0)
|
p.setRealTimeSimulation(0)
|
||||||
|
|
||||||
initialCartPos = self.np_random.uniform(low=-0.5, high=0.5, size=(1,))
|
randstate = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
|
||||||
initialAngle = self.np_random.uniform(low=-0.5, high=0.5, size=(1,))
|
p.resetJointState(self.cartpole, 1, randstate[0], randstate[1])
|
||||||
p.resetJointState(self.cartpole, 1, initialAngle)
|
p.resetJointState(self.cartpole, 0, randstate[2], randstate[3])
|
||||||
p.resetJointState(self.cartpole, 0, initialCartPos)
|
#print("randstate=",randstate)
|
||||||
|
|
||||||
self.state = p.getJointState(self.cartpole, 1)[0:2] + p.getJointState(self.cartpole, 0)[0:2]
|
self.state = p.getJointState(self.cartpole, 1)[0:2] + p.getJointState(self.cartpole, 0)[0:2]
|
||||||
|
#print("self.state=", self.state)
|
||||||
return np.array(self.state)
|
return np.array(self.state)
|
||||||
|
|
||||||
def _render(self, mode='human', close=False):
|
def _render(self, mode='human', close=False):
|
||||||
|
|||||||
Reference in New Issue
Block a user