add mpi_run version (not working yet)
This commit is contained in:
@@ -0,0 +1,52 @@
|
|||||||
|
import numpy as np
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import inspect
|
||||||
|
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
|
||||||
|
parentdir = os.path.dirname(os.path.dirname(currentdir))
|
||||||
|
os.sys.path.insert(0,parentdir)
|
||||||
|
print("parentdir=",parentdir)
|
||||||
|
|
||||||
|
|
||||||
|
from pybullet_envs.deep_mimic.env.pybullet_deep_mimic_env import PyBulletDeepMimicEnv
|
||||||
|
from pybullet_envs.deep_mimic.learning.rl_world import RLWorld
|
||||||
|
from pybullet_utils.logger import Logger
|
||||||
|
from testrl import update_world, update_timestep, build_world
|
||||||
|
import pybullet_utils.mpi_util as MPIUtil
|
||||||
|
|
||||||
|
args = []
|
||||||
|
world = None
|
||||||
|
|
||||||
|
def run():
|
||||||
|
global update_timestep
|
||||||
|
global world
|
||||||
|
|
||||||
|
done = False
|
||||||
|
while not (done):
|
||||||
|
update_world(world, update_timestep)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
def shutdown():
|
||||||
|
global world
|
||||||
|
|
||||||
|
Logger.print2('Shutting down...')
|
||||||
|
world.shutdown()
|
||||||
|
return
|
||||||
|
|
||||||
|
def main():
|
||||||
|
global args
|
||||||
|
global world
|
||||||
|
|
||||||
|
# Command line arguments
|
||||||
|
args = sys.argv[1:]
|
||||||
|
|
||||||
|
world = build_world(args, enable_draw=False)
|
||||||
|
|
||||||
|
run()
|
||||||
|
shutdown()
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@@ -248,7 +248,7 @@ class PyBulletDeepMimicEnv(Env):
|
|||||||
isEnded = self._humanoid.terminates()
|
isEnded = self._humanoid.terminates()
|
||||||
#also check maximum time, 20 seconds (todo get from file)
|
#also check maximum time, 20 seconds (todo get from file)
|
||||||
#print("self.t=",self.t)
|
#print("self.t=",self.t)
|
||||||
if (self.t>3):
|
if (self.t>20):
|
||||||
isEnded = True
|
isEnded = True
|
||||||
return isEnded
|
return isEnded
|
||||||
|
|
||||||
|
|||||||
23
examples/pybullet/gym/pybullet_envs/deep_mimic/mpi_run.py
Normal file
23
examples/pybullet/gym/pybullet_envs/deep_mimic/mpi_run.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
import sys
|
||||||
|
import subprocess
|
||||||
|
from pybullet_utils.arg_parser import ArgParser
|
||||||
|
from pybullet_utils.logger import Logger
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Command line arguments
|
||||||
|
args = sys.argv[1:]
|
||||||
|
arg_parser = ArgParser()
|
||||||
|
arg_parser.load_args(args)
|
||||||
|
|
||||||
|
num_workers = arg_parser.parse_int('num_workers', 1)
|
||||||
|
assert(num_workers > 0)
|
||||||
|
|
||||||
|
Logger.print2('Running with {:d} workers'.format(num_workers))
|
||||||
|
cmd = 'mpiexec -n {:d} python3 DeepMimic_Optimizer.py '.format(num_workers)
|
||||||
|
cmd += ' '.join(args)
|
||||||
|
Logger.print2('cmd: ' + cmd)
|
||||||
|
subprocess.call(cmd, shell=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@@ -15,6 +15,18 @@ from pybullet_envs.deep_mimic.env.pybullet_deep_mimic_env import PyBulletDeepMim
|
|||||||
import sys
|
import sys
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
def update_world(world, time_elapsed):
|
||||||
|
timeStep = 1./600.
|
||||||
|
world.update(timeStep)
|
||||||
|
reward = world.env.calc_reward(agent_id=0)
|
||||||
|
#print("reward=",reward)
|
||||||
|
end_episode = world.env.is_episode_end()
|
||||||
|
if (end_episode):
|
||||||
|
world.end_episode()
|
||||||
|
world.reset()
|
||||||
|
return
|
||||||
|
|
||||||
def build_arg_parser(args):
|
def build_arg_parser(args):
|
||||||
arg_parser = ArgParser()
|
arg_parser = ArgParser()
|
||||||
arg_parser.load_args(args)
|
arg_parser.load_args(args)
|
||||||
@@ -28,43 +40,40 @@ def build_arg_parser(args):
|
|||||||
return arg_parser
|
return arg_parser
|
||||||
|
|
||||||
args = sys.argv[1:]
|
args = sys.argv[1:]
|
||||||
arg_parser = build_arg_parser(args)
|
|
||||||
|
|
||||||
render=False#True
|
|
||||||
env = PyBulletDeepMimicEnv (args,render)
|
|
||||||
|
|
||||||
world = RLWorld(env, arg_parser)
|
|
||||||
|
|
||||||
motion_file = arg_parser.parse_string("motion_file")
|
def build_world(args, enable_draw, playback_speed=1):
|
||||||
print("motion_file=",motion_file)
|
arg_parser = build_arg_parser(args)
|
||||||
bodies = arg_parser.parse_ints("fall_contact_bodies")
|
env = PyBulletDeepMimicEnv(args, enable_draw)
|
||||||
print("bodies=",bodies)
|
world = RLWorld(env, arg_parser)
|
||||||
int_output_path = arg_parser.parse_string("int_output_path")
|
#world.env.set_playback_speed(playback_speed)
|
||||||
print("int_output_path=",int_output_path)
|
|
||||||
|
|
||||||
agent_files = pybullet_data.getDataPath()+"/"+arg_parser.parse_string("agent_files")
|
motion_file = arg_parser.parse_string("motion_file")
|
||||||
|
print("motion_file=",motion_file)
|
||||||
|
bodies = arg_parser.parse_ints("fall_contact_bodies")
|
||||||
|
print("bodies=",bodies)
|
||||||
|
int_output_path = arg_parser.parse_string("int_output_path")
|
||||||
|
print("int_output_path=",int_output_path)
|
||||||
|
agent_files = pybullet_data.getDataPath()+"/"+arg_parser.parse_string("agent_files")
|
||||||
|
|
||||||
AGENT_TYPE_KEY = "AgentType"
|
AGENT_TYPE_KEY = "AgentType"
|
||||||
|
|
||||||
print("agent_file=",agent_files)
|
print("agent_file=",agent_files)
|
||||||
with open(agent_files) as data_file:
|
with open(agent_files) as data_file:
|
||||||
json_data = json.load(data_file)
|
json_data = json.load(data_file)
|
||||||
print("json_data=",json_data)
|
print("json_data=",json_data)
|
||||||
assert AGENT_TYPE_KEY in json_data
|
assert AGENT_TYPE_KEY in json_data
|
||||||
agent_type = json_data[AGENT_TYPE_KEY]
|
agent_type = json_data[AGENT_TYPE_KEY]
|
||||||
print("agent_type=",agent_type)
|
print("agent_type=",agent_type)
|
||||||
agent = PPOAgent(world, id, json_data)
|
agent = PPOAgent(world, id, json_data)
|
||||||
|
|
||||||
agent.set_enable_training(True)
|
agent.set_enable_training(True)
|
||||||
world.reset()
|
|
||||||
while (world.env._pybullet_client.isConnected()):
|
|
||||||
|
|
||||||
timeStep = 1./600.
|
|
||||||
world.update(timeStep)
|
|
||||||
reward = world.env.calc_reward(agent_id=0)
|
|
||||||
#print("reward=",reward)
|
|
||||||
|
|
||||||
end_episode = world.env.is_episode_end()
|
|
||||||
if (end_episode):
|
|
||||||
world.end_episode()
|
|
||||||
world.reset()
|
world.reset()
|
||||||
|
return world
|
||||||
|
|
||||||
|
world = build_world(args, True)
|
||||||
|
while (world.env._pybullet_client.isConnected()):
|
||||||
|
timeStep = 1./600.
|
||||||
|
update_world(world, timeStep)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user