update goal every cycle
make distinction between world pos and relative goal pos
This commit is contained in:
@@ -27,8 +27,9 @@ class Goal:
|
|||||||
def __init__(self, goal_type: GoalType):
|
def __init__(self, goal_type: GoalType):
|
||||||
self.goal_type = goal_type
|
self.goal_type = goal_type
|
||||||
self.generateGoalData()
|
self.generateGoalData()
|
||||||
|
self.is_hit_prev = False
|
||||||
|
|
||||||
def generateGoalData(self):
|
def generateGoalData(self, modelPos=[0,0]):
|
||||||
if self.goal_type == GoalType.NoGoal:
|
if self.goal_type == GoalType.NoGoal:
|
||||||
self.goal_data = []
|
self.goal_data = []
|
||||||
|
|
||||||
@@ -39,7 +40,7 @@ class Goal:
|
|||||||
rot = randomVal(-1, 1) # radians
|
rot = randomVal(-1, 1) # radians
|
||||||
|
|
||||||
self.is_hit = False
|
self.is_hit = False
|
||||||
self.is_hit_prev = False
|
|
||||||
# The max distance from the target counting as a hit
|
# The max distance from the target counting as a hit
|
||||||
self.hit_range = 0.2
|
self.hit_range = 0.2
|
||||||
|
|
||||||
@@ -50,6 +51,7 @@ class Goal:
|
|||||||
|
|
||||||
# Y axis up, z axis in different direction
|
# Y axis up, z axis in different direction
|
||||||
self.goal_data = [-x, z, y]
|
self.goal_data = [-x, z, y]
|
||||||
|
self.world_pos = [-x + modelPos[0], z, y + modelPos[1]]
|
||||||
|
|
||||||
elif self.goal_type == GoalType.TargetHeading:
|
elif self.goal_type == GoalType.TargetHeading:
|
||||||
# Direction: 2D unit vector
|
# Direction: 2D unit vector
|
||||||
|
|||||||
@@ -94,13 +94,15 @@ class PyBulletDeepMimicEnv(Env):
|
|||||||
self.removeThrownObjects()
|
self.removeThrownObjects()
|
||||||
|
|
||||||
self._humanoid.setSimTime(startTime)
|
self._humanoid.setSimTime(startTime)
|
||||||
|
self.prevCycleCount = self._humanoid.cycleCount
|
||||||
|
|
||||||
self._humanoid.resetPose()
|
self._humanoid.resetPose()
|
||||||
#this clears the contact points. Todo: add API to explicitly clear all contact points?
|
#this clears the contact points. Todo: add API to explicitly clear all contact points?
|
||||||
#self._pybullet_client.stepSimulation()
|
#self._pybullet_client.stepSimulation()
|
||||||
self._humanoid.resetPose()
|
self._humanoid.resetPose()
|
||||||
# generate new goal
|
# generate new goal
|
||||||
self.goal.generateGoalData()
|
humanPos = self._humanoid.getLinkPosition(0)
|
||||||
|
self.goal.generateGoalData([humanPos[0], humanPos[2]])
|
||||||
self.needs_update_time = self.t - 1 #force update
|
self.needs_update_time = self.t - 1 #force update
|
||||||
|
|
||||||
def get_num_agents(self):
|
def get_num_agents(self):
|
||||||
@@ -288,7 +290,7 @@ class PyBulletDeepMimicEnv(Env):
|
|||||||
self._pybullet_client.setTimeStep(timeStep)
|
self._pybullet_client.setTimeStep(timeStep)
|
||||||
self._humanoid._timeStep = timeStep
|
self._humanoid._timeStep = timeStep
|
||||||
self.updateGoal(self._humanoid.getLinkPosition(HumanoidLinks.rightAnkle))
|
self.updateGoal(self._humanoid.getLinkPosition(HumanoidLinks.rightAnkle))
|
||||||
if self.target_id is not None:
|
if self.target_id is not None: # TODO: check goal type
|
||||||
self.updateDrawStrikeGoal()
|
self.updateDrawStrikeGoal()
|
||||||
|
|
||||||
for i in range(1):
|
for i in range(1):
|
||||||
@@ -407,7 +409,13 @@ class PyBulletDeepMimicEnv(Env):
|
|||||||
|
|
||||||
def updateGoal(self, linkPos):
|
def updateGoal(self, linkPos):
|
||||||
if self.goal.goal_type == GoalType.Strike:
|
if self.goal.goal_type == GoalType.Strike:
|
||||||
goalPos = self.goal.goal_data
|
if self.prevCycleCount != self._humanoid.cycleCount:
|
||||||
|
# generate new goal
|
||||||
|
humanPos = self._humanoid.getLinkPosition(0)
|
||||||
|
self.goal.generateGoalData([humanPos[0], humanPos[2]])
|
||||||
|
self.prevCycleCount = self._humanoid.cycleCount
|
||||||
|
|
||||||
|
goalPos = self.goal.world_pos
|
||||||
distance = sum([(x - y)**2 for (x, y) in zip(goalPos, linkPos)]) ** 0.5
|
distance = sum([(x - y)**2 for (x, y) in zip(goalPos, linkPos)]) ** 0.5
|
||||||
|
|
||||||
if distance <= self.goal.hit_range:
|
if distance <= self.goal.hit_range:
|
||||||
@@ -421,12 +429,12 @@ class PyBulletDeepMimicEnv(Env):
|
|||||||
|
|
||||||
obj_id = self._pybullet_client.createMultiBody(
|
obj_id = self._pybullet_client.createMultiBody(
|
||||||
baseVisualShapeIndex=vis_id,
|
baseVisualShapeIndex=vis_id,
|
||||||
basePosition=self.goal.goal_data)
|
basePosition=self.goal.world_pos)
|
||||||
return obj_id
|
return obj_id
|
||||||
|
|
||||||
def updateDrawStrikeGoal(self):
|
def updateDrawStrikeGoal(self):
|
||||||
current_pos = self._pybullet_client.getBasePositionAndOrientation(self.target_id)[0]
|
current_pos = self._pybullet_client.getBasePositionAndOrientation(self.target_id)[0]
|
||||||
target_pos = self.goal.goal_data
|
target_pos = self.goal.world_pos
|
||||||
|
|
||||||
if target_pos != current_pos:
|
if target_pos != current_pos:
|
||||||
self._pybullet_client.resetBasePositionAndOrientation(
|
self._pybullet_client.resetBasePositionAndOrientation(
|
||||||
|
|||||||
Reference in New Issue
Block a user