update goal every cycle

make distinction between world pos and relative goal pos
This commit is contained in:
Bart Moyaers
2019-06-27 11:27:34 +02:00
parent 293a76e879
commit f73eab5803
2 changed files with 17 additions and 7 deletions

View File

@@ -27,8 +27,9 @@ class Goal:
def __init__(self, goal_type: GoalType): def __init__(self, goal_type: GoalType):
self.goal_type = goal_type self.goal_type = goal_type
self.generateGoalData() self.generateGoalData()
self.is_hit_prev = False
def generateGoalData(self): def generateGoalData(self, modelPos=[0,0]):
if self.goal_type == GoalType.NoGoal: if self.goal_type == GoalType.NoGoal:
self.goal_data = [] self.goal_data = []
@@ -39,7 +40,7 @@ class Goal:
rot = randomVal(-1, 1) # radians rot = randomVal(-1, 1) # radians
self.is_hit = False self.is_hit = False
self.is_hit_prev = False
# The max distance from the target counting as a hit # The max distance from the target counting as a hit
self.hit_range = 0.2 self.hit_range = 0.2
@@ -50,6 +51,7 @@ class Goal:
# Y axis up, z axis in different direction # Y axis up, z axis in different direction
self.goal_data = [-x, z, y] self.goal_data = [-x, z, y]
self.world_pos = [-x + modelPos[0], z, y + modelPos[1]]
elif self.goal_type == GoalType.TargetHeading: elif self.goal_type == GoalType.TargetHeading:
# Direction: 2D unit vector # Direction: 2D unit vector

View File

@@ -94,13 +94,15 @@ class PyBulletDeepMimicEnv(Env):
self.removeThrownObjects() self.removeThrownObjects()
self._humanoid.setSimTime(startTime) self._humanoid.setSimTime(startTime)
self.prevCycleCount = self._humanoid.cycleCount
self._humanoid.resetPose() self._humanoid.resetPose()
#this clears the contact points. Todo: add API to explicitly clear all contact points? #this clears the contact points. Todo: add API to explicitly clear all contact points?
#self._pybullet_client.stepSimulation() #self._pybullet_client.stepSimulation()
self._humanoid.resetPose() self._humanoid.resetPose()
# generate new goal # generate new goal
self.goal.generateGoalData() humanPos = self._humanoid.getLinkPosition(0)
self.goal.generateGoalData([humanPos[0], humanPos[2]])
self.needs_update_time = self.t - 1 #force update self.needs_update_time = self.t - 1 #force update
def get_num_agents(self): def get_num_agents(self):
@@ -288,7 +290,7 @@ class PyBulletDeepMimicEnv(Env):
self._pybullet_client.setTimeStep(timeStep) self._pybullet_client.setTimeStep(timeStep)
self._humanoid._timeStep = timeStep self._humanoid._timeStep = timeStep
self.updateGoal(self._humanoid.getLinkPosition(HumanoidLinks.rightAnkle)) self.updateGoal(self._humanoid.getLinkPosition(HumanoidLinks.rightAnkle))
if self.target_id is not None: if self.target_id is not None: # TODO: check goal type
self.updateDrawStrikeGoal() self.updateDrawStrikeGoal()
for i in range(1): for i in range(1):
@@ -407,7 +409,13 @@ class PyBulletDeepMimicEnv(Env):
def updateGoal(self, linkPos): def updateGoal(self, linkPos):
if self.goal.goal_type == GoalType.Strike: if self.goal.goal_type == GoalType.Strike:
goalPos = self.goal.goal_data if self.prevCycleCount != self._humanoid.cycleCount:
# generate new goal
humanPos = self._humanoid.getLinkPosition(0)
self.goal.generateGoalData([humanPos[0], humanPos[2]])
self.prevCycleCount = self._humanoid.cycleCount
goalPos = self.goal.world_pos
distance = sum([(x - y)**2 for (x, y) in zip(goalPos, linkPos)]) ** 0.5 distance = sum([(x - y)**2 for (x, y) in zip(goalPos, linkPos)]) ** 0.5
if distance <= self.goal.hit_range: if distance <= self.goal.hit_range:
@@ -421,12 +429,12 @@ class PyBulletDeepMimicEnv(Env):
obj_id = self._pybullet_client.createMultiBody( obj_id = self._pybullet_client.createMultiBody(
baseVisualShapeIndex=vis_id, baseVisualShapeIndex=vis_id,
basePosition=self.goal.goal_data) basePosition=self.goal.world_pos)
return obj_id return obj_id
def updateDrawStrikeGoal(self): def updateDrawStrikeGoal(self):
current_pos = self._pybullet_client.getBasePositionAndOrientation(self.target_id)[0] current_pos = self._pybullet_client.getBasePositionAndOrientation(self.target_id)[0]
target_pos = self.goal.goal_data target_pos = self.goal.world_pos
if target_pos != current_pos: if target_pos != current_pos:
self._pybullet_client.resetBasePositionAndOrientation( self._pybullet_client.resetBasePositionAndOrientation(