more fixes for pybullet
This commit is contained in:
@@ -1,31 +0,0 @@
|
|||||||
"""One-line documentation for gym_example module.
|
|
||||||
|
|
||||||
A detailed description of gym_example.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import gym
|
|
||||||
from envs.bullet.cartpole_bullet import CartPoleBulletEnv
|
|
||||||
import setuptools
|
|
||||||
import time
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
w = [0.3, 0.02, 0.02, 0.012]
|
|
||||||
|
|
||||||
def main():
|
|
||||||
env = gym.make('CartPoleBulletEnv-v0')
|
|
||||||
for i_episode in range(1):
|
|
||||||
observation = env.reset()
|
|
||||||
done = False
|
|
||||||
t = 0
|
|
||||||
while not done:
|
|
||||||
print(observation)
|
|
||||||
action = np.array([np.inner(observation, w)])
|
|
||||||
print(action)
|
|
||||||
observation, reward, done, info = env.step(action)
|
|
||||||
t = t + 1
|
|
||||||
if done:
|
|
||||||
print("Episode finished after {} timesteps".format(t+1))
|
|
||||||
break
|
|
||||||
|
|
||||||
main()
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pybullet as p
|
import pybullet as p
|
||||||
from .. import pybullet_envs
|
import pybullet_envs
|
||||||
import time
|
import time
|
||||||
|
|
||||||
def relu(x):
|
def relu(x):
|
||||||
|
|||||||
@@ -1,88 +0,0 @@
|
|||||||
'''
|
|
||||||
A test for minitaurGymEnv
|
|
||||||
'''
|
|
||||||
|
|
||||||
import gym
|
|
||||||
import numpy as np
|
|
||||||
import math
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import tensorflow as tf
|
|
||||||
|
|
||||||
from pybullet_envs.bullet.minitaurGymEnv import MinitaurGymEnv
|
|
||||||
|
|
||||||
try:
|
|
||||||
import sonnet
|
|
||||||
from agents import simpleAgentWithSonnet as agent_lib
|
|
||||||
except ImportError:
|
|
||||||
from agents import simpleAgent as agent_lib
|
|
||||||
|
|
||||||
|
|
||||||
def testSinePolicy():
|
|
||||||
"""Tests sine policy
|
|
||||||
"""
|
|
||||||
np.random.seed(47)
|
|
||||||
|
|
||||||
environment = MinitaurGymEnv(render=True)
|
|
||||||
sum_reward = 0
|
|
||||||
steps = 1000
|
|
||||||
amplitude1Bound = 0.5
|
|
||||||
amplitude2Bound = 0.15
|
|
||||||
speed = 40
|
|
||||||
|
|
||||||
for stepCounter in range(steps):
|
|
||||||
t = float(stepCounter) * environment._timeStep
|
|
||||||
|
|
||||||
if (t < 1):
|
|
||||||
amplitude1 = 0
|
|
||||||
amplitude2 = 0
|
|
||||||
else:
|
|
||||||
amplitude1 = amplitude1Bound
|
|
||||||
amplitude2 = amplitude2Bound
|
|
||||||
a1 = math.sin(t*speed)*amplitude1
|
|
||||||
a2 = math.sin(t*speed+3.14)*amplitude1
|
|
||||||
a3 = math.sin(t*speed)*amplitude2
|
|
||||||
a4 = math.sin(t*speed+3.14)*amplitude2
|
|
||||||
|
|
||||||
action = [a1, a2, a2, a1, a3, a4, a4, a3]
|
|
||||||
|
|
||||||
state, reward, done, info = environment.step(action)
|
|
||||||
sum_reward += reward
|
|
||||||
if done:
|
|
||||||
environment.reset()
|
|
||||||
print("sum reward: ", sum_reward)
|
|
||||||
|
|
||||||
|
|
||||||
def testDDPGPolicy():
|
|
||||||
"""Tests sine policy
|
|
||||||
"""
|
|
||||||
environment = MinitaurGymEnv(render=True)
|
|
||||||
sum_reward = 0
|
|
||||||
steps = 1000
|
|
||||||
ckpt_path = 'data/agent/tf_graph_data/tf_graph_data_converted.ckpt-0'
|
|
||||||
observation_shape = (28,)
|
|
||||||
action_size = 8
|
|
||||||
actor_layer_size = (297, 158)
|
|
||||||
n_steps = 0
|
|
||||||
tf.reset_default_graph()
|
|
||||||
with tf.Session() as session:
|
|
||||||
agent = agent_lib.SimpleAgent(session=session, ckpt_path=ckpt_path, actor_layer_size=actor_layer_size)
|
|
||||||
state = environment.reset()
|
|
||||||
action = agent(state)
|
|
||||||
for _ in range(steps):
|
|
||||||
n_steps += 1
|
|
||||||
state, reward, done, info = environment.step(action)
|
|
||||||
action = agent(state)
|
|
||||||
sum_reward += reward
|
|
||||||
if done:
|
|
||||||
environment.reset()
|
|
||||||
n_steps += 1
|
|
||||||
print("total reward: ", sum_reward)
|
|
||||||
print("total steps: ", n_steps)
|
|
||||||
sum_reward = 0
|
|
||||||
n_steps = 0
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
testDDPGPolicy()
|
|
||||||
#testSinePolicy()
|
|
||||||
Reference in New Issue
Block a user