add goal training
This commit is contained in:
76
examples/pybullet/gym/pybullet_envs/deep_mimic/env/goals.py
vendored
Normal file
76
examples/pybullet/gym/pybullet_envs/deep_mimic/env/goals.py
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
from enum import Enum
|
||||
import random
|
||||
import math
|
||||
|
||||
def randomVal(lowerBound, upperbound):
    """Draw a uniformly distributed float between the two bounds.

    Args:
        lowerBound: Value produced when the underlying random draw is 0.
        upperbound: Limit approached as the underlying draw nears 1.

    Returns:
        ``lowerBound + (upperbound - lowerBound) * r`` where ``r`` comes
        from the module-level ``random`` generator.
    """
    return random.uniform(lowerBound, upperbound)
|
||||
|
||||
class GoalType(Enum):
    """Kinds of task goal that can be selected for training."""

    NoGoal = 0
    Strike = 1
    TargetHeading = 2
    Throw = 3
    TerrainTraversal = 4

    @staticmethod
    def from_str(goal_type_str):
        """Look up a goal type by its member name.

        Args:
            goal_type_str: Member name such as ``"Strike"``. The empty
                string selects ``GoalType.NoGoal``.

        Returns:
            The matching ``GoalType`` member.

        Raises:
            NotImplementedError: If ``goal_type_str`` does not name a
                member of this enum.
        """
        if goal_type_str == '':
            return GoalType.NoGoal
        try:
            return GoalType[goal_type_str]
        except (KeyError, TypeError) as err:
            # Only failed lookups are translated. A bare ``except`` would
            # also swallow KeyboardInterrupt/SystemExit and hide the cause.
            raise NotImplementedError(
                "Unknown goal type: {!r}".format(goal_type_str)) from err
|
||||
|
||||
class Goal:
    """A task goal of a given type together with its sampled parameters.

    Attributes:
        goal_type: The ``GoalType`` passed to the constructor.
        goal_data: List of floats describing the current goal; empty for
            ``GoalType.NoGoal``.
        is_hit, is_hit_prev, hit_range: Strike bookkeeping. These are only
            assigned when ``goal_type`` is ``GoalType.Strike``.
    """

    def __init__(self, goal_type: GoalType):
        """Store the goal type and sample an initial set of goal data."""
        self.goal_type = goal_type
        self.generateGoalData()

    def generateGoalData(self):
        """Sample fresh goal parameters for the configured goal type.

        Raises:
            NotImplementedError: For ``GoalType.Throw`` and
                ``GoalType.TerrainTraversal``, which are not implemented.
        """
        if self.goal_type == GoalType.NoGoal:
            self.goal_data = []

        elif self.goal_type == GoalType.Strike:
            # distance, height, rot
            distance = randomVal(0.6, 0.8)
            height = randomVal(0.8, 1.25)
            rot = randomVal(-1, 1)  # radians

            self.is_hit = False
            self.is_hit_prev = False
            # The max distance from the target counting as a hit
            self.hit_range = 0.2

            # Transform to xyz coordinates for placement in environment
            x = distance * math.cos(rot)
            y = distance * math.sin(rot)
            z = height

            # Y axis up, z axis in different direction
            self.goal_data = [-x, z, y]

        elif self.goal_type == GoalType.TargetHeading:
            # Direction: 2D unit vector
            # speed: max speed
            random_rot = random.random() * 2 * math.pi
            x = math.cos(random_rot)
            y = math.sin(random_rot)
            velocity = randomVal(0, 0.5)
            self.goal_data = [x, y, velocity]

        elif self.goal_type == GoalType.Throw:
            # TODO
            raise NotImplementedError

        elif self.goal_type == GoalType.TerrainTraversal:
            # TODO
            raise NotImplementedError

    def getTFData(self):
        """Return the goal encoded as a flat list of numbers.

        Returns:
            For ``GoalType.Strike``: a hit flag (1.0 once hit, else 0.0)
            followed by ``goal_data``. For every other goal type: a copy
            of ``goal_data`` (an empty list for ``GoalType.NoGoal``).
        """
        if self.goal_type == GoalType.Strike:
            hit_flag = 1.0 if self.is_hit else 0.0
            return [hit_flag] + self.goal_data
        # Always hand back a list: callers take ``len()`` of the result and
        # wrap it in an array, which an implicit ``None`` would break.
        return list(self.goal_data)
|
||||
|
||||
17
examples/pybullet/gym/pybullet_envs/deep_mimic/env/humanoid_link_ids.py
vendored
Normal file
17
examples/pybullet/gym/pybullet_envs/deep_mimic/env/humanoid_link_ids.py
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
from enum import IntEnum
|
||||
|
||||
class HumanoidLinks(IntEnum):
    """Named integer ids for the humanoid's links.

    Members are plain ints, so they can be handed to any call that takes
    a link id (for example ``rightAnkle`` is used for the strike goal).
    """

    # torso and head
    chest = 1
    neck = 2

    # right leg
    rightHip = 3
    rightKnee = 4
    rightAnkle = 5

    # right arm
    rightShoulder = 6
    rightElbow = 7
    rightWrist = 8

    # left leg
    leftHip = 9
    leftKnee = 10
    leftAnkle = 11

    # left arm
    leftShoulder = 12
    leftElbow = 13
    leftWrist = 14
|
||||
@@ -807,4 +807,7 @@ class HumanoidStablePD(object):
|
||||
|
||||
def getSimModelBasePosition(self):
    """Query the simulated model's base pose from the physics client.

    Returns:
        The unmodified result of ``getBasePositionAndOrientation`` for
        ``self._sim_model``.
        NOTE(review): despite the method name, nothing is stripped from
        that result here. In PyBullet the call yields a
        (position, orientation) pair; confirm callers expect both.
    """
    # Issue the query exactly once: chaining it onto its own result would
    # fail, because the result is pose data rather than a client object.
    client = self._pybullet_client
    return client.getBasePositionAndOrientation(self._sim_model)
|
||||
|
||||
def getLinkPosition(self, link_id):
    """Return the first entry of the physics client's state for a link.

    Args:
        link_id: Link index on ``self._sim_model`` to query.

    Returns:
        Element 0 of ``getLinkState(self._sim_model, link_id)``.
        NOTE(review): in PyBullet that element is the link's world
        position; confirm against the client in use.
    """
    link_state = self._pybullet_client.getLinkState(self._sim_model, link_id)
    return link_state[0]
|
||||
|
||||
@@ -6,6 +6,8 @@ from pybullet_utils import bullet_client
|
||||
import time
|
||||
from pybullet_envs.deep_mimic.env import motion_capture_data
|
||||
from pybullet_envs.deep_mimic.env import humanoid_stable_pd
|
||||
from pybullet_envs.deep_mimic.env.goals import GoalType, Goal
|
||||
from pybullet_envs.deep_mimic.env.humanoid_link_ids import HumanoidLinks
|
||||
import pybullet_data
|
||||
import pybullet as p1
|
||||
import random
|
||||
@@ -20,8 +22,14 @@ class PyBulletDeepMimicEnv(Env):
|
||||
self._isInitialized = False
|
||||
self._useStablePD = True
|
||||
self._arg_parser = arg_parser
|
||||
self.goal = self.getGoal()
|
||||
self.target_id = None
|
||||
|
||||
self.reset()
|
||||
|
||||
if self.goal.goal_type == GoalType.Strike:
|
||||
self.target_id = self.drawStrikeGoal()
|
||||
|
||||
def reset(self):
|
||||
|
||||
if not self._isInitialized:
|
||||
@@ -91,6 +99,8 @@ class PyBulletDeepMimicEnv(Env):
|
||||
#this clears the contact points. Todo: add API to explicitly clear all contact points?
|
||||
#self._pybullet_client.stepSimulation()
|
||||
self._humanoid.resetPose()
|
||||
# generate new goal
|
||||
self.goal.generateGoalData()
|
||||
self.needs_update_time = self.t - 1 #force update
|
||||
|
||||
def get_num_agents(self):
|
||||
@@ -145,7 +155,7 @@ class PyBulletDeepMimicEnv(Env):
|
||||
return np.array(out_scale)
|
||||
|
||||
def get_goal_size(self, agent_id):
    """Return how many values the current goal encoding contains.

    Args:
        agent_id: Not used by this method.

    Returns:
        ``len(self.goal.getTFData())``, or 0 when the goal reports no
        encoding (``None``).
    """
    # Derive the size from the encoding itself so it always tracks
    # getTFData(); a hard-coded 0 would hide the goal from the learner.
    goal_values = self.goal.getTFData()
    return 0 if goal_values is None else len(goal_values)
|
||||
|
||||
def get_action_size(self, agent_id):
|
||||
ctrl_size = 43 #numDof
|
||||
@@ -153,13 +163,16 @@ class PyBulletDeepMimicEnv(Env):
|
||||
return ctrl_size - root_size
|
||||
|
||||
def build_goal_norm_groups(self, agent_id):
    """Return one normalization-group id per goal value.

    Args:
        agent_id: Not used by this method.

    Returns:
        A numpy array with one ``-1`` per entry of
        ``self.goal.getTFData()``; empty when there is no goal data.
    """
    goal_values = self.goal.getTFData()
    goal_size = 0 if goal_values is None else len(goal_values)
    # Perform no normalization on goal data
    return np.array([-1] * goal_size)
|
||||
|
||||
def build_goal_offset(self, agent_id):
    """Return the additive offset applied to each goal value.

    Args:
        agent_id: Not used by this method.

    Returns:
        A numpy array with one ``0`` per entry of
        ``self.goal.getTFData()``; empty when there is no goal data.
    """
    goal_values = self.goal.getTFData()
    goal_size = 0 if goal_values is None else len(goal_values)
    # no offset
    return np.array([0] * goal_size)
|
||||
|
||||
def build_goal_scale(self, agent_id):
    """Return the multiplicative scale applied to each goal value.

    Args:
        agent_id: Not used by this method.

    Returns:
        A numpy array with one ``1`` per entry of
        ``self.goal.getTFData()``; empty when there is no goal data.
    """
    goal_values = self.goal.getTFData()
    goal_size = 0 if goal_values is None else len(goal_values)
    # no scale
    return np.array([1] * goal_size)
|
||||
|
||||
def build_action_offset(self, agent_id):
|
||||
out_offset = [0] * self.get_action_size(agent_id)
|
||||
@@ -232,11 +245,19 @@ class PyBulletDeepMimicEnv(Env):
|
||||
return np.array(state)
|
||||
|
||||
def record_goal(self, agent_id):
    """Return the current goal encoding as a numpy array.

    Args:
        agent_id: Not used by this method.

    Returns:
        ``np.array`` of ``self.goal.getTFData()``; an empty array when the
        goal reports no encoding (``None`` or an empty list).
    """
    goal_values = self.goal.getTFData()
    if goal_values is None:
        # np.array(None) would be a 0-d object array, not an empty vector.
        goal_values = []
    return np.array(goal_values)
|
||||
|
||||
def calc_reward(self, agent_id):
    """Compute the step reward, blending in the strike goal when active.

    Args:
        agent_id: Not used by this method.

    Returns:
        The humanoid's own reward for the pose computed at the current
        frame fraction. For a ``GoalType.Strike`` goal this is weighted
        0.7 and combined with 0.3 times the strike-goal reward evaluated
        at the right ankle's position.
    """
    humanoid = self._humanoid
    reference_pose = humanoid.computePose(humanoid._frameFraction)
    imitation_reward = humanoid.getReward(reference_pose)

    if self.goal.goal_type != GoalType.Strike:
        return imitation_reward

    mimic_weight = 0.7
    goal_weight = 0.3
    ankle_pos = self._humanoid.getLinkPosition(HumanoidLinks.rightAnkle)
    strike_reward = self.calcStrikeGoalReward(ankle_pos)
    return mimic_weight * imitation_reward + goal_weight * strike_reward
|
||||
|
||||
def set_action(self, agent_id, action):
|
||||
@@ -255,7 +276,7 @@ class PyBulletDeepMimicEnv(Env):
|
||||
self.desiredPose[6] = 0
|
||||
target_pose = np.array(self.desiredPose)
|
||||
|
||||
np.savetxt("pb_target_pose.csv", target_pose, delimiter=",")
|
||||
# np.savetxt("pb_target_pose.csv", target_pose, delimiter=",")
|
||||
|
||||
#print("set_action: desiredPose=", self.desiredPose)
|
||||
|
||||
@@ -266,6 +287,9 @@ class PyBulletDeepMimicEnv(Env):
|
||||
#print("pybullet_deep_mimic_env:update timeStep=",timeStep," t=",self.t)
|
||||
self._pybullet_client.setTimeStep(timeStep)
|
||||
self._humanoid._timeStep = timeStep
|
||||
self.updateGoal(self._humanoid.getLinkPosition(HumanoidLinks.rightAnkle))
|
||||
if self.target_id is not None:
|
||||
self.updateDrawStrikeGoal()
|
||||
|
||||
for i in range(1):
|
||||
self.t += timeStep
|
||||
@@ -367,4 +391,62 @@ class PyBulletDeepMimicEnv(Env):
|
||||
|
||||
self._humanoid.thrown_body_ids = []
|
||||
|
||||
self._pybullet_client.configureDebugVisualizer(self._pybullet_client.COV_ENABLE_RENDERING,1)
|
||||
self._pybullet_client.configureDebugVisualizer(self._pybullet_client.COV_ENABLE_RENDERING,1)
|
||||
|
||||
def getGoal(self):
    """Build the goal selected by the ``goal_type`` argument.

    Returns:
        A new ``Goal`` whose type is resolved by ``GoalType.from_str``
        from the string the argument parser holds for ``"goal_type"``.
    """
    requested_type = GoalType.from_str(
        self._arg_parser.parse_string("goal_type"))
    return Goal(requested_type)
|
||||
|
||||
def calcStrikeGoalReward(self, linkPos):
    """Score how close a link is to the strike target.

    Args:
        linkPos: Coordinates of the striking link, compared element-wise
            against ``self.goal.goal_data``.

    Returns:
        ``1`` once the goal has been hit; otherwise ``exp(-4 * d2)`` where
        ``d2`` is the squared distance between target and link.
    """
    goal = self.goal
    if goal.is_hit:
        return 1

    squared_gap = sum(
        (target - actual) ** 2
        for target, actual in zip(goal.goal_data, linkPos))
    return math.exp(-4 * squared_gap)
|
||||
|
||||
def updateGoal(self, linkPos):
    """Mark the strike goal as hit once the link is within range.

    Does nothing unless the goal type is ``GoalType.Strike``. The hit
    flag is only ever set here, never cleared.

    Args:
        linkPos: Coordinates of the striking link, compared element-wise
            against ``self.goal.goal_data``.
    """
    goal = self.goal
    if goal.goal_type != GoalType.Strike:
        return

    squared_gap = sum(
        [(target - actual) ** 2
         for target, actual in zip(goal.goal_data, linkPos)])
    if squared_gap ** 0.5 <= goal.hit_range:
        goal.is_hit = True
|
||||
|
||||
def drawStrikeGoal(self):
    """Create the visual marker for the strike target.

    Builds a sphere visual shape (radius 0.2, half-transparent red) and a
    multi-body that uses it, placed at ``self.goal.goal_data``.

    Returns:
        The id returned by the client's ``createMultiBody`` call.
    """
    client = self._pybullet_client
    sphere_shape = client.createVisualShape(
        shapeType=client.GEOM_SPHERE,
        radius=0.2,
        rgbaColor=[1, 0, 0, 0.5])
    return client.createMultiBody(
        baseVisualShapeIndex=sphere_shape,
        basePosition=self.goal.goal_data)
|
||||
|
||||
def updateDrawStrikeGoal(self):
    """Keep the strike marker's position and colour in sync with the goal.

    Moves the marker body ``self.target_id`` to ``self.goal.goal_data``
    when the two differ. When the goal's hit flag has changed since the
    last call, it records the new state in ``is_hit_prev`` and recolours
    the marker: green when hit, red otherwise.
    """
    client = self._pybullet_client
    goal = self.goal

    current_pos = client.getBasePositionAndOrientation(self.target_id)[0]
    target_pos = goal.goal_data

    # Compare element-wise via lists: goal_data is a list, and if the
    # client hands back a tuple a direct ``!=`` would always be True,
    # forcing a pointless reset on every step.
    if list(target_pos) != list(current_pos):
        client.resetBasePositionAndOrientation(
            self.target_id,
            target_pos,
            [0, 0, 0, 1]
        )

    if goal.is_hit != goal.is_hit_prev:
        goal.is_hit_prev = goal.is_hit
        # Green once hit, red while still pending.
        marker_rgba = [0, 1, 0, 0.5] if goal.is_hit else [1, 0, 0, 0.5]
        client.changeVisualShape(
            self.target_id,
            -1,
            rgbaColor=marker_rgba
        )
|
||||
|
||||
Reference in New Issue
Block a user