forked from justkittenaround/tuts-ml-agent
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbase_script.py
More file actions
55 lines (41 loc) · 1.76 KB
/
base_script.py
File metadata and controls
55 lines (41 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import mlagents
from mlagents_envs.environment import UnityEnvironment as UE
# Launch/connect to the Unity build. seed=1 makes episodes reproducible;
# no side channels are needed for this simple rollout script.
env = UE(file_name='RollerballBuild', seed=1, side_channels=[])
# Guarantee env.close() runs even if the rollout raises — otherwise the
# spawned Unity process and its communicator port are leaked.
try:
    env.reset()

    # A "behavior" groups all agents sharing one policy; this build exposes
    # exactly one, so take the first key.
    behavior_name = list(env.behavior_specs)[0]
    print(f"Name of the behavior : {behavior_name}")
    spec = env.behavior_specs[behavior_name]

    # NOTE(review): observation_shapes / is_action_continuous /
    # create_random_action belong to the older mlagents_envs API — confirm
    # the pinned ml-agents release before upgrading.
    print("Number of observations : ", len(spec.observation_shapes))
    if spec.is_action_continuous():
        print("The action is continuous")
    if spec.is_action_discrete():
        print("The action is discrete")

    # One initial poll so we can inspect the raw observation arrays.
    decision_steps, terminal_steps = env.get_steps(behavior_name)
    print(decision_steps.obs)

    # Roll out three episodes with uniformly random actions, accumulating
    # the reward of a single tracked agent per episode.
    for episode in range(3):
        env.reset()
        decision_steps, terminal_steps = env.get_steps(behavior_name)
        tracked_agent = -1  # -1 indicates not yet tracking
        done = False  # For the tracked_agent
        episode_rewards = 0  # For the tracked_agent
        while not done:
            # Track the first agent we see if not tracking.
            # Note : len(decision_steps) = number of agents that requested
            # a decision this step.
            if tracked_agent == -1 and len(decision_steps) >= 1:
                tracked_agent = decision_steps.agent_id[0]
            # Generate a random action for every agent awaiting a decision.
            action = spec.create_random_action(len(decision_steps))
            env.set_actions(behavior_name, action)
            # Advance the simulation one step.
            env.step()
            # Poll the new simulation results.
            decision_steps, terminal_steps = env.get_steps(behavior_name)
            if tracked_agent in decision_steps:  # agent requested a decision
                episode_rewards += decision_steps[tracked_agent].reward
            if tracked_agent in terminal_steps:  # agent ended its episode
                episode_rewards += terminal_steps[tracked_agent].reward
                done = True
        print(f"Total rewards for episode {episode} is {episode_rewards}")
finally:
    # Shut the Unity process down unconditionally.
    env.close()
    print("Closed environment")