Merge pull request #1193 from erwincoumans/master

fix gym/envs/bullet/cartpole_bullet.py, 	improve getAABB.py
This commit is contained in:
erwincoumans
2017-06-16 19:34:42 -07:00
committed by GitHub
11 changed files with 279 additions and 64 deletions

View File

@@ -1308,12 +1308,25 @@ int b3GetStatusBodyIndex(b3SharedMemoryStatusHandle statusHandle)
return bodyId;
}
b3SharedMemoryCommandHandle b3RequestCollisionInfoCommandInit(b3PhysicsClientHandle physClient, int bodyUniqueId)
{
PhysicsClient* cl = (PhysicsClient* ) physClient;
b3Assert(cl);
b3Assert(cl->canSubmitCommand());
struct SharedMemoryCommand* command = cl->getAvailableSharedMemoryCommand();
b3Assert(command);
command->m_type =CMD_REQUEST_COLLISION_INFO;
command->m_updateFlags = 0;
command->m_requestCollisionInfoArgs.m_bodyUniqueId = bodyUniqueId;
return (b3SharedMemoryCommandHandle) command;
}
int b3GetStatusAABB(b3SharedMemoryStatusHandle statusHandle, int linkIndex, double aabbMin[3], double aabbMax[3])
{
const SharedMemoryStatus* status = (const SharedMemoryStatus* ) statusHandle;
const SendActualStateArgs &args = status->m_sendActualStateArgs;
btAssert(status->m_type == CMD_ACTUAL_STATE_UPDATE_COMPLETED);
if (status->m_type != CMD_ACTUAL_STATE_UPDATE_COMPLETED)
const b3SendCollisionInfoArgs &args = status->m_sendCollisionInfoArgs;
btAssert(status->m_type == CMD_REQUEST_COLLISION_INFO_COMPLETED);
if (status->m_type != CMD_REQUEST_COLLISION_INFO_COMPLETED)
return 0;
if (linkIndex==-1)
@@ -1330,13 +1343,13 @@ int b3GetStatusAABB(b3SharedMemoryStatusHandle statusHandle, int linkIndex, doub
if (linkIndex >= 0 && linkIndex < args.m_numLinks)
{
aabbMin[0] = args.m_linkWorldAABBsMin[0];
aabbMin[1] = args.m_linkWorldAABBsMin[1];
aabbMin[2] = args.m_linkWorldAABBsMin[2];
aabbMin[0] = args.m_linkWorldAABBsMin[linkIndex*3+0];
aabbMin[1] = args.m_linkWorldAABBsMin[linkIndex*3+1];
aabbMin[2] = args.m_linkWorldAABBsMin[linkIndex*3+2];
aabbMax[0] = args.m_linkWorldAABBsMax[0];
aabbMax[1] = args.m_linkWorldAABBsMax[1];
aabbMax[2] = args.m_linkWorldAABBsMax[2];
aabbMax[0] = args.m_linkWorldAABBsMax[linkIndex*3+0];
aabbMax[1] = args.m_linkWorldAABBsMax[linkIndex*3+1];
aabbMax[2] = args.m_linkWorldAABBsMax[linkIndex*3+2];
return 1;
}

View File

@@ -57,6 +57,8 @@ int b3GetStatusActualState(b3SharedMemoryStatusHandle statusHandle,
const double* actualStateQdot[],
const double* jointReactionForces[]);
b3SharedMemoryCommandHandle b3RequestCollisionInfoCommandInit(b3PhysicsClientHandle physClient, int bodyUniqueId);
int b3GetStatusAABB(b3SharedMemoryStatusHandle statusHandle, int linkIndex, double aabbMin[3], double aabbMax[3]);
///If you re-connected to an existing server, or server changed otherwise, sync the body info and user constraints etc.

View File

@@ -1079,6 +1079,15 @@ const SharedMemoryStatus* PhysicsClientSharedMemory::processServerStatus() {
b3Warning("Request createVisualShape failed");
break;
}
case CMD_REQUEST_COLLISION_INFO_COMPLETED:
{
break;
}
case CMD_REQUEST_COLLISION_INFO_FAILED:
{
b3Warning("Request getCollisionInfo failed");
break;
}
default: {
b3Error("Unknown server status %d\n", serverCmd.m_type);

View File

@@ -914,7 +914,16 @@ void PhysicsDirect::postProcessStatus(const struct SharedMemoryStatus& serverCmd
b3Warning("createMultiBody failed");
break;
}
case CMD_REQUEST_COLLISION_INFO_COMPLETED:
{
break;
}
case CMD_REQUEST_COLLISION_INFO_FAILED:
{
b3Warning("Request getCollisionInfo failed");
break;
}
default:
{
//b3Warning("Unknown server status type");

View File

@@ -4023,6 +4023,105 @@ bool PhysicsServerCommandProcessor::processCommand(const struct SharedMemoryComm
hasStatus = true;
break;
}
case CMD_REQUEST_COLLISION_INFO:
{
SharedMemoryStatus& serverCmd = serverStatusOut;
serverStatusOut.m_type = CMD_REQUEST_COLLISION_INFO_FAILED;
hasStatus=true;
int bodyUniqueId = clientCmd.m_requestCollisionInfoArgs.m_bodyUniqueId;
InteralBodyData* body = m_data->m_bodyHandles.getHandle(bodyUniqueId);
if (body && body->m_multiBody)
{
btMultiBody* mb = body->m_multiBody;
serverStatusOut.m_type = CMD_REQUEST_COLLISION_INFO_COMPLETED;
serverCmd.m_sendCollisionInfoArgs.m_numLinks = body->m_multiBody->getNumLinks();
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMin[0] = 0;
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMin[1] = 0;
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMin[2] = 0;
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMax[0] = -1;
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMax[1] = -1;
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMax[2] = -1;
if (body->m_multiBody->getBaseCollider())
{
btTransform tr;
tr.setOrigin(mb->getBasePos());
tr.setRotation(mb->getWorldToBaseRot().inverse());
btVector3 aabbMin,aabbMax;
body->m_multiBody->getBaseCollider()->getCollisionShape()->getAabb(tr,aabbMin,aabbMax);
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMin[0] = aabbMin[0];
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMin[1] = aabbMin[1];
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMin[2] = aabbMin[2];
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMax[0] = aabbMax[0];
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMax[1] = aabbMax[1];
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMax[2] = aabbMax[2];
}
for (int l=0;l<mb->getNumLinks();l++)
{
serverCmd.m_sendCollisionInfoArgs.m_linkWorldAABBsMin[3*l+0] = 0;
serverCmd.m_sendCollisionInfoArgs.m_linkWorldAABBsMin[3*l+1] = 0;
serverCmd.m_sendCollisionInfoArgs.m_linkWorldAABBsMin[3*l+2] = 0;
serverCmd.m_sendCollisionInfoArgs.m_linkWorldAABBsMax[3*l+0] = -1;
serverCmd.m_sendCollisionInfoArgs.m_linkWorldAABBsMax[3*l+1] = -1;
serverCmd.m_sendCollisionInfoArgs.m_linkWorldAABBsMax[3*l+2] = -1;
if (body->m_multiBody->getLink(l).m_collider)
{
btVector3 aabbMin,aabbMax;
body->m_multiBody->getLinkCollider(l)->getCollisionShape()->getAabb(mb->getLink(l).m_cachedWorldTransform,aabbMin,aabbMax);
serverCmd.m_sendCollisionInfoArgs.m_linkWorldAABBsMin[3*l+0] = aabbMin[0];
serverCmd.m_sendCollisionInfoArgs.m_linkWorldAABBsMin[3*l+1] = aabbMin[1];
serverCmd.m_sendCollisionInfoArgs.m_linkWorldAABBsMin[3*l+2] = aabbMin[2];
serverCmd.m_sendCollisionInfoArgs.m_linkWorldAABBsMax[3*l+0] = aabbMax[0];
serverCmd.m_sendCollisionInfoArgs.m_linkWorldAABBsMax[3*l+1] = aabbMax[1];
serverCmd.m_sendCollisionInfoArgs.m_linkWorldAABBsMax[3*l+2] = aabbMax[2];
}
}
}
else
{
if (body && body->m_rigidBody)
{
btRigidBody* rb = body->m_rigidBody;
SharedMemoryStatus& serverCmd = serverStatusOut;
serverStatusOut.m_type = CMD_REQUEST_COLLISION_INFO_COMPLETED;
serverCmd.m_sendCollisionInfoArgs.m_numLinks = 0;
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMin[0] = 0;
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMin[1] = 0;
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMin[2] = 0;
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMax[0] = -1;
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMax[1] = -1;
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMax[2] = -1;
if (rb->getCollisionShape())
{
btTransform tr = rb->getWorldTransform();
btVector3 aabbMin,aabbMax;
rb->getCollisionShape()->getAabb(tr,aabbMin,aabbMax);
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMin[0] = aabbMin[0];
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMin[1] = aabbMin[1];
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMin[2] = aabbMin[2];
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMax[0] = aabbMax[0];
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMax[1] = aabbMax[1];
serverCmd.m_sendCollisionInfoArgs.m_rootWorldAABBMax[2] = aabbMax[2];
}
}
}
break;
}
case CMD_REQUEST_ACTUAL_STATE:
{
BT_PROFILE("CMD_REQUEST_ACTUAL_STATE");
@@ -4077,28 +4176,7 @@ bool PhysicsServerCommandProcessor::processCommand(const struct SharedMemoryComm
serverCmd.m_sendActualStateArgs.m_rootLocalInertialFrame[6] =
body->m_rootLocalInertialFrame.getRotation()[3];
serverCmd.m_sendActualStateArgs.m_rootWorldAABBMin[0] = 0;
serverCmd.m_sendActualStateArgs.m_rootWorldAABBMin[1] = 0;
serverCmd.m_sendActualStateArgs.m_rootWorldAABBMin[2] = 0;
serverCmd.m_sendActualStateArgs.m_rootWorldAABBMax[0] = -1;
serverCmd.m_sendActualStateArgs.m_rootWorldAABBMax[1] = -1;
serverCmd.m_sendActualStateArgs.m_rootWorldAABBMax[2] = -1;
if (body->m_multiBody->getBaseCollider())
{
btVector3 aabbMin,aabbMax;
body->m_multiBody->getBaseCollider()->getCollisionShape()->getAabb(tr,aabbMin,aabbMax);
serverCmd.m_sendActualStateArgs.m_rootWorldAABBMin[0] = aabbMin[0];
serverCmd.m_sendActualStateArgs.m_rootWorldAABBMin[1] = aabbMin[1];
serverCmd.m_sendActualStateArgs.m_rootWorldAABBMin[2] = aabbMin[2];
serverCmd.m_sendActualStateArgs.m_rootWorldAABBMax[0] = aabbMax[0];
serverCmd.m_sendActualStateArgs.m_rootWorldAABBMax[1] = aabbMax[1];
serverCmd.m_sendActualStateArgs.m_rootWorldAABBMax[2] = aabbMax[2];
}
//base position in world space, carthesian
serverCmd.m_sendActualStateArgs.m_actualStateQ[0] = tr.getOrigin()[0];
@@ -4200,29 +4278,7 @@ bool PhysicsServerCommandProcessor::processCommand(const struct SharedMemoryComm
serverCmd.m_sendActualStateArgs.m_linkState[l*7+5] = linkCOMRotation.z();
serverCmd.m_sendActualStateArgs.m_linkState[l*7+6] = linkCOMRotation.w();
serverCmd.m_sendActualStateArgs.m_linkWorldAABBsMin[3*l+0] = 0;
serverCmd.m_sendActualStateArgs.m_linkWorldAABBsMin[3*l+1] = 0;
serverCmd.m_sendActualStateArgs.m_linkWorldAABBsMin[3*l+2] = 0;
serverCmd.m_sendActualStateArgs.m_linkWorldAABBsMax[3*l+0] = -1;
serverCmd.m_sendActualStateArgs.m_linkWorldAABBsMax[3*l+1] = -1;
serverCmd.m_sendActualStateArgs.m_linkWorldAABBsMax[3*l+2] = -1;
if (body->m_multiBody->getLink(l).m_collider)
{
btVector3 aabbMin,aabbMax;
body->m_multiBody->getLinkCollider(l)->getCollisionShape()->getAabb(mb->getLink(l).m_cachedWorldTransform,aabbMin,aabbMax);
serverCmd.m_sendActualStateArgs.m_linkWorldAABBsMin[3*l+0] = aabbMin[0];
serverCmd.m_sendActualStateArgs.m_linkWorldAABBsMin[3*l+1] = aabbMin[1];
serverCmd.m_sendActualStateArgs.m_linkWorldAABBsMin[3*l+2] = aabbMin[2];
serverCmd.m_sendActualStateArgs.m_linkWorldAABBsMax[3*l+0] = aabbMax[0];
serverCmd.m_sendActualStateArgs.m_linkWorldAABBsMax[3*l+1] = aabbMax[1];
serverCmd.m_sendActualStateArgs.m_linkWorldAABBsMax[3*l+2] = aabbMax[2];
}
btVector3 worldLinVel(0,0,0);
btVector3 worldAngVel(0,0,0);

View File

@@ -432,9 +432,7 @@ struct SendActualStateArgs
int m_numDegreeOfFreedomU;
double m_rootLocalInertialFrame[7];
double m_rootWorldAABBMin[3];
double m_rootWorldAABBMax[3];
//actual state is only written by the server, read-only access by client is expected
double m_actualStateQ[MAX_DEGREE_OF_FREEDOM];
double m_actualStateQdot[MAX_DEGREE_OF_FREEDOM];
@@ -447,10 +445,24 @@ struct SendActualStateArgs
double m_linkState[7*MAX_NUM_LINKS];
double m_linkWorldVelocities[6*MAX_NUM_LINKS];//linear velocity and angular velocity in world space (x/y/z each).
double m_linkLocalInertialFrames[7*MAX_NUM_LINKS];
};
struct b3SendCollisionInfoArgs
{
int m_numLinks;
double m_rootWorldAABBMin[3];
double m_rootWorldAABBMax[3];
double m_linkWorldAABBsMin[3*MAX_NUM_LINKS];
double m_linkWorldAABBsMax[3*MAX_NUM_LINKS];
};
struct b3RequestCollisionInfoArgs
{
int m_bodyUniqueId;
};
enum EnumSensorTypes
{
SENSOR_FORCE_TORQUE=1,
@@ -918,6 +930,7 @@ struct SharedMemoryCommand
struct b3CreateCollisionShapeArgs m_createCollisionShapeArgs;
struct b3CreateVisualShapeArgs m_createVisualShapeArgs;
struct b3CreateMultiBodyArgs m_createMultiBodyArgs;
struct b3RequestCollisionInfoArgs m_requestCollisionInfoArgs;
};
};
@@ -987,6 +1000,7 @@ struct SharedMemoryStatus
struct b3CreateCollisionShapeResultArgs m_createCollisionShapeResultArgs;
struct b3CreateVisualShapeResultArgs m_createVisualShapeResultArgs;
struct b3CreateMultiBodyResultArgs m_createMultiBodyResultArgs;
struct b3SendCollisionInfoArgs m_sendCollisionInfoArgs;
};
};

View File

@@ -65,6 +65,7 @@ enum EnumSharedMemoryClientCommand
CMD_CREATE_COLLISION_SHAPE,
CMD_CREATE_VISUAL_SHAPE,
CMD_CREATE_MULTI_BODY,
CMD_REQUEST_COLLISION_INFO,
//don't go beyond this command!
CMD_MAX_CLIENT_COMMANDS,
@@ -156,6 +157,8 @@ enum EnumSharedMemoryServerStatus
CMD_CREATE_VISUAL_SHAPE_COMPLETED,
CMD_CREATE_MULTI_BODY_FAILED,
CMD_CREATE_MULTI_BODY_COMPLETED,
CMD_REQUEST_COLLISION_INFO_COMPLETED,
CMD_REQUEST_COLLISION_INFO_FAILED,
//don't go beyond 'CMD_MAX_SERVER_COMMANDS!
CMD_MAX_SERVER_COMMANDS
};

View File

@@ -0,0 +1,77 @@
import pybullet as p
draw=1
printtext = 0
if (draw):
p.connect(p.GUI)
else:
p.connect(p.DIRECT)
r2d2 = p.loadURDF("r2d2.urdf")
def drawAABB(aabb):
f = [aabbMin[0],aabbMin[1],aabbMin[2]]
t = [aabbMax[0],aabbMin[1],aabbMin[2]]
p.addUserDebugLine(f,t,[1,0,0])
f = [aabbMin[0],aabbMin[1],aabbMin[2]]
t = [aabbMin[0],aabbMax[1],aabbMin[2]]
p.addUserDebugLine(f,t,[0,1,0])
f = [aabbMin[0],aabbMin[1],aabbMin[2]]
t = [aabbMin[0],aabbMin[1],aabbMax[2]]
p.addUserDebugLine(f,t,[0,0,1])
f = [aabbMin[0],aabbMin[1],aabbMax[2]]
t = [aabbMin[0],aabbMax[1],aabbMax[2]]
p.addUserDebugLine(f,t,[1,1,1])
f = [aabbMin[0],aabbMin[1],aabbMax[2]]
t = [aabbMax[0],aabbMin[1],aabbMax[2]]
p.addUserDebugLine(f,t,[1,1,1])
f = [aabbMax[0],aabbMin[1],aabbMin[2]]
t = [aabbMax[0],aabbMin[1],aabbMax[2]]
p.addUserDebugLine(f,t,[1,1,1])
f = [aabbMax[0],aabbMin[1],aabbMin[2]]
t = [aabbMax[0],aabbMax[1],aabbMin[2]]
p.addUserDebugLine(f,t,[1,1,1])
f = [aabbMax[0],aabbMax[1],aabbMin[2]]
t = [aabbMin[0],aabbMax[1],aabbMin[2]]
p.addUserDebugLine(f,t,[1,1,1])
f = [aabbMin[0],aabbMax[1],aabbMin[2]]
t = [aabbMin[0],aabbMax[1],aabbMax[2]]
p.addUserDebugLine(f,t,[1,1,1])
f = [aabbMax[0],aabbMax[1],aabbMax[2]]
t = [aabbMin[0],aabbMax[1],aabbMax[2]]
p.addUserDebugLine(f,t,[1.0,0.5,0.5])
f = [aabbMax[0],aabbMax[1],aabbMax[2]]
t = [aabbMax[0],aabbMin[1],aabbMax[2]]
p.addUserDebugLine(f,t,[1,1,1])
f = [aabbMax[0],aabbMax[1],aabbMax[2]]
t = [aabbMax[0],aabbMax[1],aabbMin[2]]
p.addUserDebugLine(f,t,[1,1,1])
aabb = p.getAABB(r2d2)
aabbMin = aabb[0]
aabbMax = aabb[1]
if (printtext):
print(aabbMin)
print(aabbMax)
if (draw==1):
drawAABB(aabb)
for i in range (p.getNumJoints(r2d2)):
aabb = p.getAABB(r2d2,i)
aabbMin = aabb[0]
aabbMax = aabb[1]
if (printtext):
print(aabbMin)
print(aabbMax)
if (draw==1):
drawAABB(aabb)
while(1):
a=0

View File

@@ -0,0 +1,29 @@
import gym
from baselines import deepq
from envs.bullet.cartpole_bullet import CartPoleBulletEnv
def main():
env = gym.make('CartPoleBulletEnv-v0')
act = deepq.load("cartpole_model.pkl")
while True:
obs, done = env.reset(), False
print("obs")
print(obs)
print("type(obs)")
print(type(obs))
episode_rew = 0
while not done:
env.render()
o = obs[None]
aa = act(o)
a = aa[0]
obs, rew, done, _ = env.step(a)
episode_rew += rew
print("Episode reward", episode_rew)
if __name__ == '__main__':
main()

View File

@@ -25,7 +25,7 @@ class CartPoleBulletEnv(gym.Env):
def __init__(self):
# start the bullet physics server
p.connect(p.GUI)
# p.connect(p.DIRECT)
#p.connect(p.DIRECT)
observation_high = np.array([
np.finfo(np.float32).max,
np.finfo(np.float32).max,
@@ -33,7 +33,7 @@ class CartPoleBulletEnv(gym.Env):
np.finfo(np.float32).max])
action_high = np.array([0.1])
self.action_space = spaces.Box(-action_high, action_high)
self.action_space = spaces.Discrete(5)
self.observation_space = spaces.Box(-observation_high, observation_high)
self.theta_threshold_radians = 1
@@ -55,8 +55,11 @@ class CartPoleBulletEnv(gym.Env):
# time.sleep(self.timeStep)
self.state = p.getJointState(self.cartpole, 1)[0:2] + p.getJointState(self.cartpole, 0)[0:2]
theta, theta_dot, x, x_dot = self.state
force = action
p.setJointMotorControl2(self.cartpole, 0, p.VELOCITY_CONTROL, targetVelocity=(action + self.state[3]))
dv = 0.4
deltav = [-2.*dv, -dv, 0, dv, 2.*dv][action]
p.setJointMotorControl2(self.cartpole, 0, p.VELOCITY_CONTROL, targetVelocity=(deltav + self.state[3]))
done = x < -self.x_threshold \
or x > self.x_threshold \

View File

@@ -2010,12 +2010,12 @@ static PyObject* pybullet_getAABB(PyObject* self, PyObject* args, PyObject* keyw
}
cmd_handle =
b3RequestActualStateCommandInit(sm, bodyUniqueId);
b3RequestCollisionInfoCommandInit(sm, bodyUniqueId);
status_handle =
b3SubmitClientCommandAndWaitStatus(sm, cmd_handle);
status_type = b3GetStatusType(status_handle);
if (status_type != CMD_ACTUAL_STATE_UPDATE_COMPLETED)
if (status_type != CMD_REQUEST_COLLISION_INFO_COMPLETED)
{
PyErr_SetString(SpamError, "getAABB failed.");
return NULL;