update goal every cycle
make distinction between world pos and relative goal pos
This commit is contained in:
@@ -27,8 +27,9 @@ class Goal:
|
||||
def __init__(self, goal_type: GoalType):
|
||||
self.goal_type = goal_type
|
||||
self.generateGoalData()
|
||||
self.is_hit_prev = False
|
||||
|
||||
def generateGoalData(self):
|
||||
def generateGoalData(self, modelPos=[0,0]):
|
||||
if self.goal_type == GoalType.NoGoal:
|
||||
self.goal_data = []
|
||||
|
||||
@@ -39,7 +40,7 @@ class Goal:
|
||||
rot = randomVal(-1, 1) # radians
|
||||
|
||||
self.is_hit = False
|
||||
self.is_hit_prev = False
|
||||
|
||||
# The max distance from the target counting as a hit
|
||||
self.hit_range = 0.2
|
||||
|
||||
@@ -50,6 +51,7 @@ class Goal:
|
||||
|
||||
# Y axis up, z axis in different direction
|
||||
self.goal_data = [-x, z, y]
|
||||
self.world_pos = [-x + modelPos[0], z, y + modelPos[1]]
|
||||
|
||||
elif self.goal_type == GoalType.TargetHeading:
|
||||
# Direction: 2D unit vector
|
||||
|
||||
@@ -94,13 +94,15 @@ class PyBulletDeepMimicEnv(Env):
|
||||
self.removeThrownObjects()
|
||||
|
||||
self._humanoid.setSimTime(startTime)
|
||||
self.prevCycleCount = self._humanoid.cycleCount
|
||||
|
||||
self._humanoid.resetPose()
|
||||
#this clears the contact points. Todo: add API to explicitly clear all contact points?
|
||||
#self._pybullet_client.stepSimulation()
|
||||
self._humanoid.resetPose()
|
||||
# generate new goal
|
||||
self.goal.generateGoalData()
|
||||
humanPos = self._humanoid.getLinkPosition(0)
|
||||
self.goal.generateGoalData([humanPos[0], humanPos[2]])
|
||||
self.needs_update_time = self.t - 1 #force update
|
||||
|
||||
def get_num_agents(self):
|
||||
@@ -288,7 +290,7 @@ class PyBulletDeepMimicEnv(Env):
|
||||
self._pybullet_client.setTimeStep(timeStep)
|
||||
self._humanoid._timeStep = timeStep
|
||||
self.updateGoal(self._humanoid.getLinkPosition(HumanoidLinks.rightAnkle))
|
||||
if self.target_id is not None:
|
||||
if self.target_id is not None: # TODO: check goal type
|
||||
self.updateDrawStrikeGoal()
|
||||
|
||||
for i in range(1):
|
||||
@@ -413,7 +415,13 @@ class PyBulletDeepMimicEnv(Env):
|
||||
|
||||
def updateGoal(self, linkPos):
|
||||
if self.goal.goal_type == GoalType.Strike:
|
||||
goalPos = self.goal.goal_data
|
||||
if self.prevCycleCount != self._humanoid.cycleCount:
|
||||
# generate new goal
|
||||
humanPos = self._humanoid.getLinkPosition(0)
|
||||
self.goal.generateGoalData([humanPos[0], humanPos[2]])
|
||||
self.prevCycleCount = self._humanoid.cycleCount
|
||||
|
||||
goalPos = self.goal.world_pos
|
||||
distance = sum([(x - y)**2 for (x, y) in zip(goalPos, linkPos)]) ** 0.5
|
||||
|
||||
if distance <= self.goal.hit_range:
|
||||
@@ -427,12 +435,12 @@ class PyBulletDeepMimicEnv(Env):
|
||||
|
||||
obj_id = self._pybullet_client.createMultiBody(
|
||||
baseVisualShapeIndex=vis_id,
|
||||
basePosition=self.goal.goal_data)
|
||||
basePosition=self.goal.world_pos)
|
||||
return obj_id
|
||||
|
||||
def updateDrawStrikeGoal(self):
|
||||
current_pos = self._pybullet_client.getBasePositionAndOrientation(self.target_id)[0]
|
||||
target_pos = self.goal.goal_data
|
||||
target_pos = self.goal.world_pos
|
||||
|
||||
if target_pos != current_pos:
|
||||
self._pybullet_client.resetBasePositionAndOrientation(
|
||||
|
||||
Reference in New Issue
Block a user