add ARS to train/eval Minitaur
397
examples/pybullet/gym/pybullet_envs/ARS/ars.py
Normal file
@@ -0,0 +1,397 @@
"""Internal implementation of the Augmented Random Search method."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import inspect
import os
import sys

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
sys.path.insert(0, currentdir)

from concurrent import futures
import copy
import time

import gym
import numpy as np

import logz
import optimizers
import policies
import shared_noise
import utility
import utils


class Worker(object):
  """Object class for parallel rollout generation."""

  def __init__(self,
               env_seed,
               env_callback,
               policy_params=None,
               deltas=None,
               rollout_length=1000,
               delta_std=0.02):

    # initialize OpenAI gym environment for each worker
    self.env = env_callback()
    self.env.seed(env_seed)

    # each worker gets access to the shared noise table,
    # with an independent random stream for sampling
    # from the shared noise table
    self.deltas = shared_noise.SharedNoiseTable(deltas, env_seed + 7)

    self.policy_params = policy_params
    if policy_params['type'] == 'linear':
      self.policy = policies.LinearPolicy(policy_params)
    else:
      raise NotImplementedError

    self.delta_std = delta_std
    self.rollout_length = rollout_length
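
  # Design note: all workers index into one shared noise table rather than
  # each generating its own perturbations, so only integer indices need to be
  # exchanged with the master. This is the same cheap-communication trick used
  # in OpenAI's distributed evolution strategies.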

  def get_weights_plus_stats(self):
    """Returns current policy weights and current statistics of past states."""
    assert self.policy_params['type'] == 'linear'
    return self.policy.get_weights_plus_stats()

  def rollout(self, shift=0., rollout_length=None):
    """Performs one rollout of at most rollout_length steps.

    At each timestep it subtracts shift from the reward.
    """
    if rollout_length is None:
      rollout_length = self.rollout_length

    total_reward = 0.
    steps = 0

    ob = self.env.reset()
    for i in range(rollout_length):
      action = self.policy.act(ob)
      ob, reward, done, _ = self.env.step(action)
      steps += 1
      total_reward += (reward - shift)
      if done:
        break

    return total_reward, steps
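
  # Note on `shift`: subtracting a constant from every per-step reward removes
  # a constant "alive bonus" that some environments add at each timestep;
  # without this, random search can get stuck on policies that merely survive
  # without moving (this is the role of the shift hyperparameter in the ARS
  # paper, Mania et al. 2018).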

  def do_rollouts(self, w_policy, num_rollouts=1, shift=1, evaluate=False):
    """Generates multiple rollouts with a policy parametrized by w_policy."""
    print('Doing {} rollouts'.format(num_rollouts))
    rollout_rewards, deltas_idx = [], []
    steps = 0

    for i in range(num_rollouts):

      if evaluate:
        self.policy.update_weights(w_policy)
        deltas_idx.append(-1)

        # set to false so that evaluation rollouts are not used for updating
        # state statistics
        self.policy.update_filter = False

        # for evaluation we do not shift the rewards (shift = 0) and we use
        # the default rollout length (1000 for the MuJoCo locomotion tasks)
        reward, r_steps = self.rollout(
            shift=0., rollout_length=self.rollout_length)
        rollout_rewards.append(reward)

      else:
        idx, delta = self.deltas.get_delta(w_policy.size)

        delta = (self.delta_std * delta).reshape(w_policy.shape)
        deltas_idx.append(idx)

        # set to true so that state statistics are updated
        self.policy.update_filter = True

        # compute reward and number of timesteps used for the positive
        # perturbation rollout
        self.policy.update_weights(w_policy + delta)
        pos_reward, pos_steps = self.rollout(shift=shift)

        # compute reward and number of timesteps used for the negative
        # perturbation rollout
        self.policy.update_weights(w_policy - delta)
        neg_reward, neg_steps = self.rollout(shift=shift)
        steps += pos_steps + neg_steps

        rollout_rewards.append([pos_reward, neg_reward])

    return {
        'deltas_idx': deltas_idx,
        'rollout_rewards': rollout_rewards,
        'steps': steps
    }
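
  # Each sampled direction delta is evaluated twice above, at w + delta and
  # at w - delta (mirrored / antithetic sampling). Differencing the paired
  # returns cancels noise shared by the two rollouts, which is what makes the
  # finite-difference estimate in ARSLearner.aggregate_rollouts usable.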

  def stats_increment(self):
    self.policy.observation_filter.stats_increment()

  def get_weights(self):
    return self.policy.get_weights()

  def get_filter(self):
    return self.policy.observation_filter

  def sync_filter(self, other):
    self.policy.observation_filter.sync(other)
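

# A minimal sketch of driving one Worker directly, e.g. for debugging
# (`make_env` and the exact policy_params keys are illustrative assumptions,
# not defined in this module):
#
#   noise = shared_noise.create_shared_noise()
#   worker = Worker(env_seed=0,
#                   env_callback=make_env,
#                   policy_params={'type': 'linear', ...},
#                   deltas=noise,
#                   rollout_length=1000)
#   result = worker.do_rollouts(worker.get_weights(), num_rollouts=4, shift=0)
#   # result['rollout_rewards'] is a list of [pos_reward, neg_reward] pairs.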


class ARSLearner(object):
  """Object class implementing the ARS algorithm."""

  def __init__(self,
               env_callback,
               policy_params=None,
               num_workers=32,
               num_deltas=320,
               deltas_used=320,
               delta_std=0.02,
               logdir=None,
               rollout_length=1000,
               step_size=0.01,
               shift='constant zero',
               params=None,
               seed=123):

    logz.configure_output_dir(logdir)
    # params_to_save = copy.deepcopy(params)
    # params_to_save['env'] = None
    # logz.save_params(params_to_save)
    utility.save_config(params, logdir)
    env = env_callback()

    self.timesteps = 0
    self.action_size = env.action_space.shape[0]
    self.ob_size = env.observation_space.shape[0]
    self.num_deltas = num_deltas
    self.deltas_used = deltas_used
    self.rollout_length = rollout_length
    self.step_size = step_size
    self.delta_std = delta_std
    self.logdir = logdir
    self.shift = shift
    self.params = params
    self.max_past_avg_reward = float('-inf')
    self.num_episodes_used = float('inf')

    # create shared table for storing noise
    print('Creating deltas table.')
    deltas = shared_noise.create_shared_noise()
    self.deltas = shared_noise.SharedNoiseTable(deltas, seed=seed + 3)
    print('Created deltas table.')

    # initialize workers with different random seeds
    print('Initializing workers.')
    self.num_workers = num_workers
    self.workers = [
        Worker(
            seed + 7 * i,
            env_callback=env_callback,
            policy_params=policy_params,
            deltas=deltas,
            rollout_length=rollout_length,
            delta_std=delta_std) for i in range(num_workers)
    ]

    # initialize policy
    if policy_params['type'] == 'linear':
      self.policy = policies.LinearPolicy(policy_params)
      self.w_policy = self.policy.get_weights()
    else:
      raise NotImplementedError

    # initialize optimization algorithm
    self.optimizer = optimizers.SGD(self.w_policy, self.step_size)
    print('Initialization of ARS complete.')
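
  # Seeding note: worker i is constructed with env_seed = seed + 7 * i and
  # draws its noise stream from env_seed + 7, while the master's noise table
  # uses seed + 3, so every worker steps its environment and samples
  # perturbation indices from a distinct random stream.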

  def aggregate_rollouts(self, num_rollouts=None, evaluate=False):
    """Aggregates update step from rollouts generated in parallel."""
    if num_rollouts is None:
      num_deltas = self.num_deltas
    else:
      num_deltas = num_rollouts

    results_one = []  # rollout_ids_one
    results_two = []  # rollout_ids_two

    t1 = time.time()
    num_rollouts = int(num_deltas / self.num_workers)
    # if num_rollouts > 0:
    #   with futures.ThreadPoolExecutor(
    #       max_workers=self.num_workers) as executor:
    #     workers = [
    #         executor.submit(
    #             worker.do_rollouts,
    #             self.w_policy,
    #             num_rollouts=num_rollouts,
    #             shift=self.shift,
    #             evaluate=evaluate) for worker in self.workers
    #     ]
    #     for worker in futures.as_completed(workers):
    #       results_one.append(worker.result())
    #
    #     workers = [
    #         executor.submit(
    #             worker.do_rollouts,
    #             self.w_policy,
    #             num_rollouts=1,
    #             shift=self.shift,
    #             evaluate=evaluate)
    #         for worker in self.workers[:(num_deltas % self.num_workers)]
    #     ]
    #     for worker in futures.as_completed(workers):
    #       results_two.append(worker.result())

    # serial generation of rollouts (the ThreadPoolExecutor variant above is
    # kept for reference but disabled)
    rollout_ids_one = [
        worker.do_rollouts(
            self.w_policy,
            num_rollouts=num_rollouts,
            shift=self.shift,
            evaluate=evaluate) for worker in self.workers
    ]

    rollout_ids_two = [
        worker.do_rollouts(
            self.w_policy, num_rollouts=1, shift=self.shift, evaluate=evaluate)
        for worker in self.workers[:(num_deltas % self.num_workers)]
    ]
    results_one = rollout_ids_one
    results_two = rollout_ids_two

    # gather results
    rollout_rewards, deltas_idx = [], []

    for result in results_one:
      if not evaluate:
        self.timesteps += result['steps']
      deltas_idx += result['deltas_idx']
      rollout_rewards += result['rollout_rewards']

    for result in results_two:
      if not evaluate:
        self.timesteps += result['steps']
      deltas_idx += result['deltas_idx']
      rollout_rewards += result['rollout_rewards']

    deltas_idx = np.array(deltas_idx)
    rollout_rewards = np.array(rollout_rewards, dtype=np.float64)

    print('Maximum reward of collected rollouts:', rollout_rewards.max())
    info_dict = {'max_reward': rollout_rewards.max()}
    t2 = time.time()

    print('Time to generate rollouts:', t2 - t1)

    if evaluate:
      return rollout_rewards

    # select top performing directions if deltas_used < num_deltas
    max_rewards = np.max(rollout_rewards, axis=1)
    if self.deltas_used > self.num_deltas:
      self.deltas_used = self.num_deltas

    idx = np.arange(max_rewards.size)[max_rewards >= np.percentile(
        max_rewards, 100 * (1 - (self.deltas_used / self.num_deltas)))]
    deltas_idx = deltas_idx[idx]
    rollout_rewards = rollout_rewards[idx, :]

    # normalize rewards by their standard deviation
    rollout_rewards /= np.std(rollout_rewards)

    t1 = time.time()
    # aggregate rollouts to form g_hat, the gradient used to compute SGD step
    g_hat, count = utils.batched_weighted_sum(
        rollout_rewards[:, 0] - rollout_rewards[:, 1],
        (self.deltas.get(idx, self.w_policy.size) for idx in deltas_idx),
        batch_size=500)
    g_hat /= deltas_idx.size
    t2 = time.time()
    print('Time to aggregate rollouts:', t2 - t1)
    return g_hat, info_dict
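
  # The value computed above is the ARS update direction
  #
  #   g_hat = (1 / (N * sigma_R)) * sum_k (r_pos_k - r_neg_k) * delta_k
  #
  # where N = deltas_idx.size is the number of directions kept, sigma_R is
  # the standard deviation of the kept rewards (applied by the normalization
  # step above), and delta_k is the k-th sampled perturbation.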

  def train_step(self):
    """Performs one update step of the policy weights."""
    g_hat, info_dict = self.aggregate_rollouts()
    print('Euclidean norm of update step:', np.linalg.norm(g_hat))
    self.w_policy -= self.optimizer._compute_step(g_hat).reshape(
        self.w_policy.shape)
    return info_dict
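
  # Assumption: as in the reference ARS implementation, the optimizer's
  # _compute_step is expected to return the *negated* scaled gradient, so the
  # subtraction in train_step performs gradient ascent on expected reward.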

  def train(self, num_iter):

    start = time.time()
    for i in range(num_iter):

      t1 = time.time()
      info_dict = self.train_step()
      t2 = time.time()
      print('total time of one step', t2 - t1)
      print('iter ', i, ' done')

      # record statistics every 10 iterations
      if i % 10 == 0:

        rewards = self.aggregate_rollouts(num_rollouts=8, evaluate=True)
        w = self.workers[0].get_weights_plus_stats()

        checkpoint_filename = os.path.join(
            self.logdir, 'lin_policy_plus_{:03d}.npz'.format(i))
        print('Save checkpoints to {}...'.format(checkpoint_filename))
        with open(checkpoint_filename, 'wb') as checkpoint_file:
          np.savez(checkpoint_file, w)
        print('End save checkpoints.')

        print(sorted(self.params.items()))
        logz.log_tabular('Time', time.time() - start)
        logz.log_tabular('Iteration', i + 1)
        logz.log_tabular('AverageReward', np.mean(rewards))
        logz.log_tabular('StdRewards', np.std(rewards))
        logz.log_tabular('MaxRewardRollout', np.max(rewards))
        logz.log_tabular('MinRewardRollout', np.min(rewards))
        logz.log_tabular('timesteps', self.timesteps)
        logz.dump_tabular()

      t1 = time.time()
      # get statistics from all workers
      for j in range(self.num_workers):
        self.policy.observation_filter.update(self.workers[j].get_filter())
      self.policy.observation_filter.stats_increment()

      # make sure master filter buffer is clear
      self.policy.observation_filter.clear_buffer()
      # sync all workers
      #filter_id = ray.put(self.policy.observation_filter)
      setting_filters_ids = [
          worker.sync_filter(self.policy.observation_filter)
          for worker in self.workers
      ]
      # waiting for sync of all workers
      #ray.get(setting_filters_ids)

      increment_filters_ids = [
          worker.stats_increment() for worker in self.workers
      ]
      # waiting for increment of all workers
      #ray.get(increment_filters_ids)
      t2 = time.time()
      print('Time to sync statistics:', t2 - t1)

    return info_dict
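
For completeness, a minimal sketch of how this learner might be driven on Minitaur follows. The `make_minitaur_env` factory, the hyperparameter values, and the exact `policy_params` keys expected by `policies.LinearPolicy` are illustrative assumptions, not part of this commit; treat the sketch as a starting point, not the canonical entry point.

import pybullet_envs.bullet.minitaur_gym_env as minitaur_gym_env

def make_minitaur_env():
  # headless Minitaur environment; render=True opens the GUI instead
  return minitaur_gym_env.MinitaurBulletEnv(render=False)

env = make_minitaur_env()
policy_params = {
    'type': 'linear',
    'ob_dim': env.observation_space.shape[0],  # assumed key, check policies.py
    'ac_dim': env.action_space.shape[0],  # assumed key, check policies.py
}
learner = ARSLearner(
    env_callback=make_minitaur_env,
    policy_params=policy_params,
    num_workers=4,
    num_deltas=16,
    deltas_used=8,
    logdir='/tmp/ars_minitaur',
    params={'policy_type': 'linear'},  # hypothetical config dict
    seed=123)
learner.train(num_iter=100)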