% plot_trpo_sigma_test.m
% Start from a clean state: remove all workspace variables and close every
% open figure.
clear;
close all;
% Add the plotter/ folder to the MATLAB search path.
addpath('plotter/');
% Run load_avg_standard_trpo, which is not defined in this file as shown.
% NOTE(review): presumably this is what provides trpo_standad_avg,
% trpo_first_300_mean, trpo_first_300_lowest and
% trpo_standard_first_300_20_lowest, which are used further down but never
% assigned in this script - confirm.
load_avg_standard_trpo
% Column descriptors: named 1-based column indices for the data matrices.
% Only train_mean_evaluation_reward is referenced in the rest of this
% script; the remaining names are defined but not used here.
% NOTE(review): the index-to-column mapping is assumed to match the CSV
% column order - verify against the exported files.
step = 1;
rollout_ep_len_mean = 2;
rollout_ep_rew_mean = 3;
time_fps = 4;
train_fitness_score = 5;
train_mean_evaluation_length = 6;
train_mean_evaluation_reward = 7;
train_real_mean_ep_len = 8;
train_explained_variance = 9;
train_is_line_search_success = 10;
train_kl_divergence_loss = 11;
train_learning_rate = 12;
train_policy_objective = 13;
train_std = 14;
train_value_loss = 15;
% Directory holding the per-run CSV files of the sigma test.
% Named data_dir rather than "path" so that MATLAB's built-in path function
% is not shadowed by a workspace variable once this script has run.
data_dir = "data_cpy/csv/TRPO_sigma_test_theta_1/";

% Load five runs (number_0 .. number_4) per sigma value. The variable
% suffixes 001/002/005/010/020 correspond to the sigma values
% 0.010/0.020/0.050/0.100/0.200 in the file names.

% sigma = 0.010
TRPO_sigma_001_number_0 = readmatrix(data_dir + "TRPO_sigma_0.010_number_0.csv");
TRPO_sigma_001_number_1 = readmatrix(data_dir + "TRPO_sigma_0.010_number_1.csv");
TRPO_sigma_001_number_2 = readmatrix(data_dir + "TRPO_sigma_0.010_number_2.csv");
TRPO_sigma_001_number_3 = readmatrix(data_dir + "TRPO_sigma_0.010_number_3.csv");
TRPO_sigma_001_number_4 = readmatrix(data_dir + "TRPO_sigma_0.010_number_4.csv");

% sigma = 0.020
TRPO_sigma_002_number_0 = readmatrix(data_dir + "TRPO_sigma_0.020_number_0.csv");
TRPO_sigma_002_number_1 = readmatrix(data_dir + "TRPO_sigma_0.020_number_1.csv");
TRPO_sigma_002_number_2 = readmatrix(data_dir + "TRPO_sigma_0.020_number_2.csv");
TRPO_sigma_002_number_3 = readmatrix(data_dir + "TRPO_sigma_0.020_number_3.csv");
TRPO_sigma_002_number_4 = readmatrix(data_dir + "TRPO_sigma_0.020_number_4.csv");

% sigma = 0.050
TRPO_sigma_005_number_0 = readmatrix(data_dir + "TRPO_sigma_0.050_number_0.csv");
TRPO_sigma_005_number_1 = readmatrix(data_dir + "TRPO_sigma_0.050_number_1.csv");
TRPO_sigma_005_number_2 = readmatrix(data_dir + "TRPO_sigma_0.050_number_2.csv");
TRPO_sigma_005_number_3 = readmatrix(data_dir + "TRPO_sigma_0.050_number_3.csv");
TRPO_sigma_005_number_4 = readmatrix(data_dir + "TRPO_sigma_0.050_number_4.csv");

% sigma = 0.100
TRPO_sigma_010_number_0 = readmatrix(data_dir + "TRPO_sigma_0.100_number_0.csv");
TRPO_sigma_010_number_1 = readmatrix(data_dir + "TRPO_sigma_0.100_number_1.csv");
TRPO_sigma_010_number_2 = readmatrix(data_dir + "TRPO_sigma_0.100_number_2.csv");
TRPO_sigma_010_number_3 = readmatrix(data_dir + "TRPO_sigma_0.100_number_3.csv");
TRPO_sigma_010_number_4 = readmatrix(data_dir + "TRPO_sigma_0.100_number_4.csv");

% sigma = 0.200
TRPO_sigma_020_number_0 = readmatrix(data_dir + "TRPO_sigma_0.200_number_0.csv");
TRPO_sigma_020_number_1 = readmatrix(data_dir + "TRPO_sigma_0.200_number_1.csv");
TRPO_sigma_020_number_2 = readmatrix(data_dir + "TRPO_sigma_0.200_number_2.csv");
TRPO_sigma_020_number_3 = readmatrix(data_dir + "TRPO_sigma_0.200_number_3.csv");
TRPO_sigma_020_number_4 = readmatrix(data_dir + "TRPO_sigma_0.200_number_4.csv");

% Element-wise average of the five run matrices for each sigma. The sums
% require the five matrices of a sigma value to have compatible sizes.
% NOTE(review): assumes every run logged the same number of rows - confirm.
TRPO_sigma_001_mean = (TRPO_sigma_001_number_0 + TRPO_sigma_001_number_1 ...
    + TRPO_sigma_001_number_2 + TRPO_sigma_001_number_3 + TRPO_sigma_001_number_4)/5;
TRPO_sigma_002_mean = (TRPO_sigma_002_number_0 + TRPO_sigma_002_number_1 ...
    + TRPO_sigma_002_number_2 + TRPO_sigma_002_number_3 + TRPO_sigma_002_number_4)/5;
TRPO_sigma_005_mean = (TRPO_sigma_005_number_0 + TRPO_sigma_005_number_1 ...
    + TRPO_sigma_005_number_2 + TRPO_sigma_005_number_3 + TRPO_sigma_005_number_4)/5;
TRPO_sigma_010_mean = (TRPO_sigma_010_number_0 + TRPO_sigma_010_number_1 ...
    + TRPO_sigma_010_number_2 + TRPO_sigma_010_number_3 + TRPO_sigma_010_number_4)/5;
TRPO_sigma_020_mean = (TRPO_sigma_020_number_0 + TRPO_sigma_020_number_1 ...
    + TRPO_sigma_020_number_2 + TRPO_sigma_020_number_3 + TRPO_sigma_020_number_4)/5;
% Figure 1: mean evaluation reward per iteration for the baseline and for
% the averaged runs of every sigma value.

% One x value per row of the averaged data.
% NOTE(review): the baseline is plotted against the same vector, so
% trpo_standad_avg is assumed to have the same number of rows - confirm.
rows = size(TRPO_sigma_001_mean, 1);
iterations_vector = 1:rows;

figure(1);
hold on;
% Baseline is drawn with a thicker line than the sigma curves.
plot(iterations_vector, trpo_standad_avg(:, train_mean_evaluation_reward), 'LineWidth', 2);
plot(iterations_vector, TRPO_sigma_001_mean(:, train_mean_evaluation_reward));
plot(iterations_vector, TRPO_sigma_002_mean(:, train_mean_evaluation_reward));
plot(iterations_vector, TRPO_sigma_005_mean(:, train_mean_evaluation_reward));
plot(iterations_vector, TRPO_sigma_010_mean(:, train_mean_evaluation_reward));
plot(iterations_vector, TRPO_sigma_020_mean(:, train_mean_evaluation_reward));

% Legend entries follow the plotting order above.
legend('baseline', '0.01', '0.02', '0.05', '0.10', '0.20', 'Location', 'northwest');
xticks(0:10:150);
title('TRPO: Sigma impact on mean evaluation reward');
xlabel('Iterations');
ylabel('Mean Evaluation Reward');
hold off;
% Per-run lookup for the value 300 in the mean-evaluation-reward column.
% NOTE(review): return_it_where_val_is_reached is defined outside this file;
% presumably it returns the iteration at which the given column first
% reaches the given value - confirm, including what it returns when the
% value is never reached.
iteration_reaching_300 = @(run_data) return_it_where_val_is_reached(run_data, 300, train_mean_evaluation_reward);

% sigma = 0.010
TRPO_sigma_001_number_0_first_300 = iteration_reaching_300(TRPO_sigma_001_number_0);
TRPO_sigma_001_number_1_first_300 = iteration_reaching_300(TRPO_sigma_001_number_1);
TRPO_sigma_001_number_2_first_300 = iteration_reaching_300(TRPO_sigma_001_number_2);
TRPO_sigma_001_number_3_first_300 = iteration_reaching_300(TRPO_sigma_001_number_3);
TRPO_sigma_001_number_4_first_300 = iteration_reaching_300(TRPO_sigma_001_number_4);

% sigma = 0.020
TRPO_sigma_002_number_0_first_300 = iteration_reaching_300(TRPO_sigma_002_number_0);
TRPO_sigma_002_number_1_first_300 = iteration_reaching_300(TRPO_sigma_002_number_1);
TRPO_sigma_002_number_2_first_300 = iteration_reaching_300(TRPO_sigma_002_number_2);
TRPO_sigma_002_number_3_first_300 = iteration_reaching_300(TRPO_sigma_002_number_3);
TRPO_sigma_002_number_4_first_300 = iteration_reaching_300(TRPO_sigma_002_number_4);

% sigma = 0.050
TRPO_sigma_005_number_0_first_300 = iteration_reaching_300(TRPO_sigma_005_number_0);
TRPO_sigma_005_number_1_first_300 = iteration_reaching_300(TRPO_sigma_005_number_1);
TRPO_sigma_005_number_2_first_300 = iteration_reaching_300(TRPO_sigma_005_number_2);
TRPO_sigma_005_number_3_first_300 = iteration_reaching_300(TRPO_sigma_005_number_3);
TRPO_sigma_005_number_4_first_300 = iteration_reaching_300(TRPO_sigma_005_number_4);

% sigma = 0.100
TRPO_sigma_010_number_0_first_300 = iteration_reaching_300(TRPO_sigma_010_number_0);
TRPO_sigma_010_number_1_first_300 = iteration_reaching_300(TRPO_sigma_010_number_1);
TRPO_sigma_010_number_2_first_300 = iteration_reaching_300(TRPO_sigma_010_number_2);
TRPO_sigma_010_number_3_first_300 = iteration_reaching_300(TRPO_sigma_010_number_3);
TRPO_sigma_010_number_4_first_300 = iteration_reaching_300(TRPO_sigma_010_number_4);

% sigma = 0.200
TRPO_sigma_020_number_0_first_300 = iteration_reaching_300(TRPO_sigma_020_number_0);
TRPO_sigma_020_number_1_first_300 = iteration_reaching_300(TRPO_sigma_020_number_1);
TRPO_sigma_020_number_2_first_300 = iteration_reaching_300(TRPO_sigma_020_number_2);
TRPO_sigma_020_number_3_first_300 = iteration_reaching_300(TRPO_sigma_020_number_3);
TRPO_sigma_020_number_4_first_300 = iteration_reaching_300(TRPO_sigma_020_number_4);

% Collect the five per-run results of each sigma into one row vector.
TRPO_sigma_001_first_300 = [TRPO_sigma_001_number_0_first_300, TRPO_sigma_001_number_1_first_300, ...
    TRPO_sigma_001_number_2_first_300, TRPO_sigma_001_number_3_first_300, TRPO_sigma_001_number_4_first_300];
TRPO_sigma_002_first_300 = [TRPO_sigma_002_number_0_first_300, TRPO_sigma_002_number_1_first_300, ...
    TRPO_sigma_002_number_2_first_300, TRPO_sigma_002_number_3_first_300, TRPO_sigma_002_number_4_first_300];
TRPO_sigma_005_first_300 = [TRPO_sigma_005_number_0_first_300, TRPO_sigma_005_number_1_first_300, ...
    TRPO_sigma_005_number_2_first_300, TRPO_sigma_005_number_3_first_300, TRPO_sigma_005_number_4_first_300];
TRPO_sigma_010_first_300 = [TRPO_sigma_010_number_0_first_300, TRPO_sigma_010_number_1_first_300, ...
    TRPO_sigma_010_number_2_first_300, TRPO_sigma_010_number_3_first_300, TRPO_sigma_010_number_4_first_300];
TRPO_sigma_020_first_300 = [TRPO_sigma_020_number_0_first_300, TRPO_sigma_020_number_1_first_300, ...
    TRPO_sigma_020_number_2_first_300, TRPO_sigma_020_number_3_first_300, TRPO_sigma_020_number_4_first_300];

% Mean over the five runs. There is no trailing semicolon, so MATLAB echoes
% each value to the command window.
TRPO_sigma_001_first_300_mean = mean(TRPO_sigma_001_first_300)
TRPO_sigma_002_first_300_mean = mean(TRPO_sigma_002_first_300)
TRPO_sigma_005_first_300_mean = mean(TRPO_sigma_005_first_300)
TRPO_sigma_010_first_300_mean = mean(TRPO_sigma_010_first_300)
TRPO_sigma_020_first_300_mean = mean(TRPO_sigma_020_first_300)

% Smallest value over the five runs, echoed in the same way.
TRPO_sigma_001_first_300_lowest = min(TRPO_sigma_001_first_300)
TRPO_sigma_002_first_300_lowest = min(TRPO_sigma_002_first_300)
TRPO_sigma_005_first_300_lowest = min(TRPO_sigma_005_first_300)
TRPO_sigma_010_first_300_lowest = min(TRPO_sigma_010_first_300)
TRPO_sigma_020_first_300_lowest = min(TRPO_sigma_020_first_300)
% Figure 3: bar chart of the "first_300" statistics (mean and lowest over
% the five runs) for the baseline and every sigma value, with two dashed
% horizontal reference lines taken from the baseline.
%
% The figure is selected before any hold command is issued. Calling
% "hold on" first would act on whichever axes were current beforehand
% instead of on this figure.
figure(3);

% Category labels, in the same order as the columns of the data rows below.
xticks_vec = ["baseline", "0.01" "0.02" "0.05" "0.1" "0.2"];
first_300_mean_row = [trpo_first_300_mean, TRPO_sigma_001_first_300_mean, ...
    TRPO_sigma_002_first_300_mean, TRPO_sigma_005_first_300_mean, ...
    TRPO_sigma_010_first_300_mean, TRPO_sigma_020_first_300_mean];
first_300_lowest_row = [trpo_first_300_lowest, TRPO_sigma_001_first_300_lowest, ...
    TRPO_sigma_002_first_300_lowest, TRPO_sigma_005_first_300_lowest, ...
    TRPO_sigma_010_first_300_lowest, TRPO_sigma_020_first_300_lowest];
% NOTE(review): the data matrix is 2-by-6 while there are 6 labels; the
% legend entries 'Mean'/'Lowest' imply one bar series per row - confirm
% that bar groups it that way in the MATLAB release in use.
bar(xticks_vec, [first_300_mean_row; first_300_lowest_row]);

title('TRPO: Iteration to get reward of 300 vs Sigma');
subtitle('Mean and lowest of 5 runs theta=1')
xlabel('Sigma');
ylabel('Iteration');

% Overlay the baseline reference lines on top of the bars.
hold on;
yline(trpo_first_300_mean, '--');
yline(trpo_standard_first_300_20_lowest, '--');
hold off;

% Legend order: the two bar series first, then the two reference lines.
legend('Mean', 'Lowest', 'baseline mean', 'baseline 20% lowest', 'Location', 'northwest');