diff --git a/examples/pybullet/gym/pybullet_envs/deep_mimic/env/goals.py b/examples/pybullet/gym/pybullet_envs/deep_mimic/env/goals.py index e506329f6..d696f52ae 100644 --- a/examples/pybullet/gym/pybullet_envs/deep_mimic/env/goals.py +++ b/examples/pybullet/gym/pybullet_envs/deep_mimic/env/goals.py @@ -1,4 +1,5 @@ from enum import Enum +from abc import ABC, abstractmethod import random import math import pybullet as pb @@ -25,75 +26,100 @@ class GoalType(Enum): except: raise NotImplementedError -class Goal: +class Goal(ABC): def __init__(self, goal_type: GoalType): self.goal_type = goal_type - self.follow_rot = False - self.is_hit_prev = False - self.generateGoalData() + @abstractmethod def generateGoalData(self, modelPos=[0,0,0], modelOrient=[0,0,0,1]): - if self.goal_type == GoalType.NoGoal: - self.goal_data = [] + pass - elif self.goal_type == GoalType.Strike: - # distance, height, rot - distance = randomVal(0.6, 0.8) - height = randomVal(0.8, 1.25) - rot = randomVal(-1, 1) # radians - - self.is_hit = False - - # The max distance from the target counting as a hit - self.hit_range = 0.2 + @abstractmethod + def getTFData(self): + pass - # Transform to xyz coordinates for placement in environment - x = distance * math.cos(rot) - y = distance * math.sin(rot) - z = height +class NoGoal(Goal): + def __init__(self): + super().__init__(GoalType.NoGoal) - # Y axis up, z axis in different direction - self.goal_data = [-x, z, y] - - if self.follow_rot: - # Take rotation of human model into account - eulerAngles = pb.getEulerFromQuaternion(modelOrient) - # Only Y angle matters - eulerAngles = [0, eulerAngles[1], 0] - yQuat = pb.getQuaternionFromEuler(eulerAngles) - rotMatList = pb.getMatrixFromQuaternion(yQuat) - rotMat = numpy.array([rotMatList[0:3], rotMatList[3:6], rotMatList[6:9]]) - vec = numpy.array(self.goal_data) - rotatedVec = numpy.dot(rotMat, vec) - self.world_pos = rotatedVec.tolist() - - self.world_pos = [ self.world_pos[0] + modelPos[0], - self.world_pos[1], - self.world_pos[2] + modelPos[2]] - else: - self.world_pos = [-x + modelPos[0], z, y + modelPos[2]] - - elif self.goal_type == GoalType.TargetHeading: - # Direction: 2D unit vector - # speed: max speed - random_rot = random.random() * 2 * math.pi - x = math.cos(random_rot) - y = math.sin(random_rot) - velocity = randomVal(0, 0.5) - self.goal_data = [x, y, velocity] - - elif self.goal_type == GoalType.Throw: - # TODO - raise NotImplementedError - elif self.goal_type == GoalType.TerrainTraversal: - # TODO - raise NotImplementedError + def generateGoalData(self, modelPos=[0,0,0], modelOrient=[0,0,0,1]): + self.goal_data = [] def getTFData(self): - if self.goal_type == GoalType.Strike: - x = 0.0 - if self.is_hit: - x = 1.0 - return [x] + self.goal_data - \ No newline at end of file + return self.goal_data + +class StrikeGoal(Goal): + def __init__(self): + self.follow_rot = False + self.is_hit_prev = False + super().__init__(GoalType.Strike) + + def generateGoalData(self, modelPos=[0,0,0], modelOrient=[0,0,0,1]): + # distance, height, rot + distance = randomVal(0.6, 0.8) + height = randomVal(0.8, 1.25) + rot = randomVal(-1, 1) # radians + + self.is_hit = False + + # The max distance from the target counting as a hit + self.hit_range = 0.2 + + # Transform to xyz coordinates for placement in environment + x = distance * math.cos(rot) + y = distance * math.sin(rot) + z = height + + # Y axis up, z axis in different direction + self.goal_data = [-x, z, y] + + if self.follow_rot: + # Take rotation of human model into account + eulerAngles = pb.getEulerFromQuaternion(modelOrient) + # Only Y angle matters + eulerAngles = [0, eulerAngles[1], 0] + yQuat = pb.getQuaternionFromEuler(eulerAngles) + rotMatList = pb.getMatrixFromQuaternion(yQuat) + rotMat = numpy.array([rotMatList[0:3], rotMatList[3:6], rotMatList[6:9]]) + vec = numpy.array(self.goal_data) + rotatedVec = numpy.dot(rotMat, vec) + self.world_pos = rotatedVec.tolist() + + self.world_pos = [ self.world_pos[0] + modelPos[0], + self.world_pos[1], + self.world_pos[2] + modelPos[2]] + else: + self.world_pos = [-x + modelPos[0], z, y + modelPos[2]] + + def getTFData(self): + x = 0.0 + if self.is_hit: + x = 1.0 + return [x] + self.goal_data + +class TargetHeadingGoal(Goal): + def __init__(self): + super().__init__(GoalType.TargetHeading) + + def generateGoalData(self, modelPos=[0,0,0], modelOrient=[0,0,0,1]): + # Direction: 2D unit vector + # speed: max speed + random_rot = random.random() * 2 * math.pi + x = math.cos(random_rot) + y = math.sin(random_rot) + velocity = randomVal(0, 0.5) + self.goal_data = [x, y, velocity] + + def getTFData(self): + return self.goal_data + +def createGoal(goal_type: GoalType) -> Goal: + if goal_type == GoalType.NoGoal: + return NoGoal() + elif goal_type == GoalType.Strike: + return StrikeGoal() + elif goal_type == GoalType.TargetHeading: + return TargetHeadingGoal() + else: + raise NotImplementedError diff --git a/examples/pybullet/gym/pybullet_envs/deep_mimic/env/pybullet_deep_mimic_env.py b/examples/pybullet/gym/pybullet_envs/deep_mimic/env/pybullet_deep_mimic_env.py index 1052288a1..26322f03c 100644 --- a/examples/pybullet/gym/pybullet_envs/deep_mimic/env/pybullet_deep_mimic_env.py +++ b/examples/pybullet/gym/pybullet_envs/deep_mimic/env/pybullet_deep_mimic_env.py @@ -6,7 +6,7 @@ from pybullet_utils import bullet_client import time from pybullet_envs.deep_mimic.env import motion_capture_data from pybullet_envs.deep_mimic.env import humanoid_stable_pd -from pybullet_envs.deep_mimic.env.goals import GoalType, Goal +from pybullet_envs.deep_mimic.env.goals import GoalType, Goal, createGoal from pybullet_envs.deep_mimic.env.humanoid_link_ids import HumanoidLinks import pybullet_data import pybullet as p1 @@ -397,7 +397,7 @@ class PyBulletDeepMimicEnv(Env): def getGoal(self): goal_type_str = self._arg_parser.parse_string("goal_type") - return Goal(GoalType.from_str(goal_type_str)) + return createGoal(GoalType.from_str(goal_type_str)) def calcStrikeGoalReward(self, linkPos): if self.goal.is_hit: @@ -424,7 +424,7 @@ class PyBulletDeepMimicEnv(Env): def drawStrikeGoal(self): vis_id = self._pybullet_client.createVisualShape( shapeType=self._pybullet_client.GEOM_SPHERE, - radius=0.2, + radius=self.goal.hit_range, rgbaColor=[1,0,0,0.5]) obj_id = self._pybullet_client.createMultiBody(