enable deepmimic training on mac
This commit is contained in:
@@ -141,4 +141,4 @@ class PDControllerStable(object):
|
|||||||
maxF = np.array(maxForces)
|
maxF = np.array(maxForces)
|
||||||
forces = np.clip(tau, -maxF , maxF )
|
forces = np.clip(tau, -maxF , maxF )
|
||||||
#print("c=",c)
|
#print("c=",c)
|
||||||
return tau
|
return tau
|
||||||
|
|||||||
@@ -22,8 +22,10 @@ class PyBulletDeepMimicEnv(Env):
|
|||||||
def reset(self):
|
def reset(self):
|
||||||
self.t = 0
|
self.t = 0
|
||||||
if not self._isInitialized:
|
if not self._isInitialized:
|
||||||
self._pybullet_client = bullet_client.BulletClient(connection_mode=p1.GUI)
|
if self.enable_draw:
|
||||||
#self._pybullet_client = bullet_client.BulletClient()
|
self._pybullet_client = bullet_client.BulletClient(connection_mode=p1.GUI)
|
||||||
|
else:
|
||||||
|
self._pybullet_client = bullet_client.BulletClient()
|
||||||
|
|
||||||
self._pybullet_client.setAdditionalSearchPath(pybullet_data.getDataPath())
|
self._pybullet_client.setAdditionalSearchPath(pybullet_data.getDataPath())
|
||||||
z2y = self._pybullet_client.getQuaternionFromEuler([-math.pi*0.5,0,0])
|
z2y = self._pybullet_client.getQuaternionFromEuler([-math.pi*0.5,0,0])
|
||||||
@@ -198,7 +200,7 @@ class PyBulletDeepMimicEnv(Env):
|
|||||||
state = self._humanoid.getState()
|
state = self._humanoid.getState()
|
||||||
state[1]=state[1]+0.008
|
state[1]=state[1]+0.008
|
||||||
#print("record_state=",state)
|
#print("record_state=",state)
|
||||||
return state
|
return np.array(state)
|
||||||
|
|
||||||
|
|
||||||
def record_goal(self, agent_id):
|
def record_goal(self, agent_id):
|
||||||
@@ -241,11 +243,11 @@ class PyBulletDeepMimicEnv(Env):
|
|||||||
def is_episode_end(self):
|
def is_episode_end(self):
|
||||||
isEnded = self._humanoid.terminates()
|
isEnded = self._humanoid.terminates()
|
||||||
#also check maximum time, 20 seconds (todo get from file)
|
#also check maximum time, 20 seconds (todo get from file)
|
||||||
print("self.t=",self.t)
|
#print("self.t=",self.t)
|
||||||
if (self.t>3):
|
if (self.t>3):
|
||||||
isEnded = True
|
isEnded = True
|
||||||
return isEnded
|
return isEnded
|
||||||
|
|
||||||
def check_valid_episode(self):
|
def check_valid_episode(self):
|
||||||
#could check if limbs exceed velocity threshold
|
#could check if limbs exceed velocity threshold
|
||||||
return true
|
return true
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
import json
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from learning.ppo_agent import PPOAgent
|
from learning.ppo_agent import PPOAgent
|
||||||
|
import pybullet_data
|
||||||
|
|
||||||
AGENT_TYPE_KEY = "AgentType"
|
AGENT_TYPE_KEY = "AgentType"
|
||||||
|
|
||||||
def build_agent(world, id, file):
|
def build_agent(world, id, file):
|
||||||
agent = None
|
agent = None
|
||||||
with open(file) as data_file:
|
with open(pybullet_data.getDataPath()+"/"+file) as data_file:
|
||||||
json_data = json.load(data_file)
|
json_data = json.load(data_file)
|
||||||
|
|
||||||
assert AGENT_TYPE_KEY in json_data
|
assert AGENT_TYPE_KEY in json_data
|
||||||
@@ -17,4 +18,4 @@ def build_agent(world, id, file):
|
|||||||
else:
|
else:
|
||||||
assert False, 'Unsupported agent type: ' + agent_type
|
assert False, 'Unsupported agent type: ' + agent_type
|
||||||
|
|
||||||
return agent
|
return agent
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import numpy as np
|
|||||||
import copy
|
import copy
|
||||||
from pybullet_utils.logger import Logger
|
from pybullet_utils.logger import Logger
|
||||||
import inspect as inspect
|
import inspect as inspect
|
||||||
from env.env import Env
|
from pybullet_envs.deep_mimic.env.env import Env
|
||||||
import pybullet_utils.math_util as MathUtil
|
import pybullet_utils.math_util as MathUtil
|
||||||
|
|
||||||
class ReplayBuffer(object):
|
class ReplayBuffer(object):
|
||||||
@@ -348,4 +348,4 @@ class SampleBuffer(object):
|
|||||||
count0 = np.sum(self.idx_to_slot == MathUtil.INVALID_IDX)
|
count0 = np.sum(self.idx_to_slot == MathUtil.INVALID_IDX)
|
||||||
count1 = np.sum(self.slot_to_idx == MathUtil.INVALID_IDX)
|
count1 = np.sum(self.slot_to_idx == MathUtil.INVALID_IDX)
|
||||||
valid &= count0 == count1
|
valid &= count0 == count1
|
||||||
return valid
|
return valid
|
||||||
|
|||||||
@@ -1,5 +1,11 @@
|
|||||||
|
import os
|
||||||
|
import inspect
|
||||||
|
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
|
||||||
|
parentdir = os.path.dirname(os.path.dirname(currentdir))
|
||||||
|
os.sys.path.insert(0,parentdir)
|
||||||
|
print("parentdir=",parentdir)
|
||||||
import json
|
import json
|
||||||
from learning.rl_world import RLWorld
|
from pybullet_envs.deep_mimic.learning.rl_world import RLWorld
|
||||||
from learning.ppo_agent import PPOAgent
|
from learning.ppo_agent import PPOAgent
|
||||||
|
|
||||||
import pybullet_data
|
import pybullet_data
|
||||||
@@ -36,7 +42,7 @@ print("bodies=",bodies)
|
|||||||
int_output_path = arg_parser.parse_string("int_output_path")
|
int_output_path = arg_parser.parse_string("int_output_path")
|
||||||
print("int_output_path=",int_output_path)
|
print("int_output_path=",int_output_path)
|
||||||
|
|
||||||
agent_files = arg_parser.parse_string("agent_files")
|
agent_files = pybullet_data.getDataPath()+"/"+arg_parser.parse_string("agent_files")
|
||||||
|
|
||||||
AGENT_TYPE_KEY = "AgentType"
|
AGENT_TYPE_KEY = "AgentType"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user