refactor goals

This commit is contained in:
Bart Moyaers
2019-07-05 15:49:15 +02:00
parent e8d6b24933
commit e6d4cfc7da
2 changed files with 91 additions and 65 deletions

View File

@@ -1,4 +1,5 @@
from enum import Enum
from abc import ABC, abstractmethod
import random
import math
import pybullet as pb
@@ -25,19 +26,36 @@ class GoalType(Enum):
except:
raise NotImplementedError
class Goal:
class Goal(ABC):
def __init__(self, goal_type: GoalType):
self.goal_type = goal_type
self.follow_rot = False
self.is_hit_prev = False
self.generateGoalData()
@abstractmethod
def generateGoalData(self, modelPos=[0,0,0], modelOrient=[0,0,0,1]):
pass
@abstractmethod
def getTFData(self):
pass
class NoGoal(Goal):
def __init__(self):
super().__init__(GoalType.NoGoal)
def generateGoalData(self, modelPos=[0,0,0], modelOrient=[0,0,0,1]):
if self.goal_type == GoalType.NoGoal:
self.goal_data = []
elif self.goal_type == GoalType.Strike:
def getTFData(self):
return self.goal_data
class StrikeGoal(Goal):
def __init__(self):
self.follow_rot = False
self.is_hit_prev = False
super().__init__(GoalType.Strike)
def generateGoalData(self, modelPos=[0,0,0], modelOrient=[0,0,0,1]):
# distance, height, rot
distance = randomVal(0.6, 0.8)
height = randomVal(0.8, 1.25)
@@ -74,7 +92,17 @@ class Goal:
else:
self.world_pos = [-x + modelPos[0], z, y + modelPos[2]]
elif self.goal_type == GoalType.TargetHeading:
def getTFData(self):
x = 0.0
if self.is_hit:
x = 1.0
return [x] + self.goal_data
class TargetHeadingGoal(Goal):
def __init__(self):
super().__init__(GoalType.TargetHeading)
def generateGoalData(self, modelPos=[0,0,0], modelOrient=[0,0,0,1]):
# Direction: 2D unit vector
# speed: max speed
random_rot = random.random() * 2 * math.pi
@@ -83,17 +111,15 @@ class Goal:
velocity = randomVal(0, 0.5)
self.goal_data = [x, y, velocity]
elif self.goal_type == GoalType.Throw:
# TODO
raise NotImplementedError
elif self.goal_type == GoalType.TerrainTraversal:
# TODO
raise NotImplementedError
def getTFData(self):
if self.goal_type == GoalType.Strike:
x = 0.0
if self.is_hit:
x = 1.0
return [x] + self.goal_data
return self.goal_data
def createGoal(goal_type: GoalType) -> Goal:
if goal_type == GoalType.NoGoal:
return NoGoal()
elif goal_type == GoalType.Strike:
return StrikeGoal()
elif goal_type == GoalType.TargetHeading:
return TargetHeadingGoal()
else:
raise NotImplementedError

View File

@@ -6,7 +6,7 @@ from pybullet_utils import bullet_client
import time
from pybullet_envs.deep_mimic.env import motion_capture_data
from pybullet_envs.deep_mimic.env import humanoid_stable_pd
from pybullet_envs.deep_mimic.env.goals import GoalType, Goal
from pybullet_envs.deep_mimic.env.goals import GoalType, Goal, createGoal
from pybullet_envs.deep_mimic.env.humanoid_link_ids import HumanoidLinks
import pybullet_data
import pybullet as p1
@@ -403,7 +403,7 @@ class PyBulletDeepMimicEnv(Env):
def getGoal(self):
goal_type_str = self._arg_parser.parse_string("goal_type")
return Goal(GoalType.from_str(goal_type_str))
return createGoal(GoalType.from_str(goal_type_str))
def calcStrikeGoalReward(self, linkPos):
if self.goal.is_hit:
@@ -430,7 +430,7 @@ class PyBulletDeepMimicEnv(Env):
def drawStrikeGoal(self):
vis_id = self._pybullet_client.createVisualShape(
shapeType=self._pybullet_client.GEOM_SPHERE,
radius=0.2,
radius=self.goal.hit_range,
rgbaColor=[1,0,0,0.5])
obj_id = self._pybullet_client.createMultiBody(