-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot_trpo_k_test.m
163 lines (139 loc) · 10.8 KB
/
plot_trpo_k_test.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
% Setup for the TRPO "k test" plots: extends the MATLAB search path, runs the
% baseline loader script and defines the column indices of the TRPO CSV logs.

%clear;
%close all;

% Relative to the current folder, so the script has to be started from the
% directory that contains 'plotter/'.
addpath('plotter/');

% NOTE(review): presumably a script that defines the baseline variables used
% further down (trpo_standad_avg, trpo_first_300_mean, trpo_first_300_lowest,
% trpo_standard_first_300_20_lowest) -- confirm against load_avg_standard_trpo.
load_avg_standard_trpo

% Number of repeated runs per k value; used as the divisor when the runs are
% averaged.
number_of_evals = 5;

% define column descriptors TRPO (column indices into the matrices read from
% the TRPO CSV files)
TRPO_step = 1;
TRPO_rollout_ep_len_mean = 2;
TRPO_rollout_ep_rew_mean = 3;
TRPO_time_fps = 4;
TRPO_train_beta_loss = 5;
TRPO_train_fitness_score = 6;
TRPO_train_mean_episodic_Re3_reward = 7;
TRPO_train_mean_evaluation_length = 8;
TRPO_train_mean_evaluation_reward = 9;
TRPO_train_real_mean_ep_len = 10;
TRPO_train_explained_variance = 11;
TRPO_train_is_line_search_success = 12;
TRPO_train_kl_divergence_loss = 13;
TRPO_train_learning_rate = 14;
TRPO_train_policy_objective = 15;
TRPO_train_std = 16;
TRPO_train_value_loss = 17;
% Misspelled name kept as an alias so code that still refers to it keeps working.
TRPO_train_value_los = TRPO_train_value_loss;
% Load the CSV logs of all runs: six k values (3, 50, 100, 200, 500, 1000),
% five runs each (numbered 0-4).

% Directory that holds the CSV files. Deliberately not called `path`: a
% variable with that name hides MATLAB's built-in path function for as long as
% it stays in the workspace this script runs in.
data_dir = "data_cpy/csv/TRPO_k_test_3-1000/";

% Reads one run. k and run are numbers; adding them to a string appends their
% text form, giving <data_dir>TRPO_k_<k>_number_<run>.csv.
read_run = @(k, run) readmatrix(data_dir + "TRPO_k_" + k + "_number_" + run + ".csv");

TRPO_k_3_number_0 = read_run(3, 0);
TRPO_k_3_number_1 = read_run(3, 1);
TRPO_k_3_number_2 = read_run(3, 2);
TRPO_k_3_number_3 = read_run(3, 3);
TRPO_k_3_number_4 = read_run(3, 4);

TRPO_k_50_number_0 = read_run(50, 0);
TRPO_k_50_number_1 = read_run(50, 1);
TRPO_k_50_number_2 = read_run(50, 2);
TRPO_k_50_number_3 = read_run(50, 3);
TRPO_k_50_number_4 = read_run(50, 4);

TRPO_k_100_number_0 = read_run(100, 0);
TRPO_k_100_number_1 = read_run(100, 1);
TRPO_k_100_number_2 = read_run(100, 2);
TRPO_k_100_number_3 = read_run(100, 3);
TRPO_k_100_number_4 = read_run(100, 4);

TRPO_k_200_number_0 = read_run(200, 0);
TRPO_k_200_number_1 = read_run(200, 1);
TRPO_k_200_number_2 = read_run(200, 2);
TRPO_k_200_number_3 = read_run(200, 3);
TRPO_k_200_number_4 = read_run(200, 4);

TRPO_k_500_number_0 = read_run(500, 0);
TRPO_k_500_number_1 = read_run(500, 1);
TRPO_k_500_number_2 = read_run(500, 2);
TRPO_k_500_number_3 = read_run(500, 3);
TRPO_k_500_number_4 = read_run(500, 4);

TRPO_k_1000_number_0 = read_run(1000, 0);
TRPO_k_1000_number_1 = read_run(1000, 1);
TRPO_k_1000_number_2 = read_run(1000, 2);
TRPO_k_1000_number_3 = read_run(1000, 3);
TRPO_k_1000_number_4 = read_run(1000, 4);
% Element-wise average of the five runs of each k value. Whole matrices are
% added, so every column is averaged. The handle picks up the value that
% number_of_evals has at this point.
average_runs = @(r0, r1, r2, r3, r4) (r0 + r1 + r2 + r3 + r4) / number_of_evals;

TRPO_k_3_mean = average_runs(TRPO_k_3_number_0, TRPO_k_3_number_1, ...
    TRPO_k_3_number_2, TRPO_k_3_number_3, TRPO_k_3_number_4);
TRPO_k_50_mean = average_runs(TRPO_k_50_number_0, TRPO_k_50_number_1, ...
    TRPO_k_50_number_2, TRPO_k_50_number_3, TRPO_k_50_number_4);
TRPO_k_100_mean = average_runs(TRPO_k_100_number_0, TRPO_k_100_number_1, ...
    TRPO_k_100_number_2, TRPO_k_100_number_3, TRPO_k_100_number_4);
TRPO_k_200_mean = average_runs(TRPO_k_200_number_0, TRPO_k_200_number_1, ...
    TRPO_k_200_number_2, TRPO_k_200_number_3, TRPO_k_200_number_4);
TRPO_k_500_mean = average_runs(TRPO_k_500_number_0, TRPO_k_500_number_1, ...
    TRPO_k_500_number_2, TRPO_k_500_number_3, TRPO_k_500_number_4);
TRPO_k_1000_mean = average_runs(TRPO_k_1000_number_0, TRPO_k_1000_number_1, ...
    TRPO_k_1000_number_2, TRPO_k_1000_number_3, TRPO_k_1000_number_4);
% Figure 1: mean evaluation reward per iteration, baseline against every k.
rows = size(TRPO_k_3_mean, 1);
iterations_vec = 1:rows;

figure(1);
% Start from an empty figure. Without this, running the script again while
% figure 1 is still open stacks a second set of curves on top of the first and
% the legend no longer matches the lines.
clf;
hold on;

% NOTE(review): the baseline column is hard-coded as 7, while the k runs use
% TRPO_train_mean_evaluation_reward (9). Presumably the baseline matrix has a
% different column layout -- confirm against load_avg_standard_trpo.
plot(iterations_vec, trpo_standad_avg(:, 7), 'LineWidth', 2);

% Averaged runs in the same order as the legend entries below.
k_means = {TRPO_k_3_mean, TRPO_k_50_mean, TRPO_k_100_mean, ...
    TRPO_k_200_mean, TRPO_k_500_mean, TRPO_k_1000_mean};
for idx = 1:numel(k_means)
    plot(iterations_vec, k_means{idx}(:, TRPO_train_mean_evaluation_reward), 'LineWidth', 1);
end

legend('baseline', 'k = 3', 'k = 50', 'k = 100', 'k = 200', 'k = 500', 'k = 1000', ...
    'Location', 'northwest');
title('TRPO: Re3 k test');
xticks(0:10:150);
xlabel('Iterations');
ylabel('Mean evaluation reward');
hold off;
% For every run, ask return_it_where_val_is_reached for the value 300 in the
% mean evaluation reward column, then collect the per-run results per k and
% take their mean and minimum.
% NOTE(review): return_it_where_val_is_reached is defined outside this file;
% what it returns when 300 is never reached is not visible here -- confirm.
iter_to_300 = @(run_data) return_it_where_val_is_reached(run_data, 300, TRPO_train_mean_evaluation_reward);

% k = 3
TRPO_k_3_number_0_first_300 = iter_to_300(TRPO_k_3_number_0);
TRPO_k_3_number_1_first_300 = iter_to_300(TRPO_k_3_number_1);
TRPO_k_3_number_2_first_300 = iter_to_300(TRPO_k_3_number_2);
TRPO_k_3_number_3_first_300 = iter_to_300(TRPO_k_3_number_3);
TRPO_k_3_number_4_first_300 = iter_to_300(TRPO_k_3_number_4);
TRPO_k_3_first_300 = [TRPO_k_3_number_0_first_300, TRPO_k_3_number_1_first_300, ...
    TRPO_k_3_number_2_first_300, TRPO_k_3_number_3_first_300, TRPO_k_3_number_4_first_300];
TRPO_k_3_first_300_mean = mean(TRPO_k_3_first_300);
TRPO_k_3_first_300_lowest = min(TRPO_k_3_first_300);

% k = 50
TRPO_k_50_number_0_first_300 = iter_to_300(TRPO_k_50_number_0);
TRPO_k_50_number_1_first_300 = iter_to_300(TRPO_k_50_number_1);
TRPO_k_50_number_2_first_300 = iter_to_300(TRPO_k_50_number_2);
TRPO_k_50_number_3_first_300 = iter_to_300(TRPO_k_50_number_3);
TRPO_k_50_number_4_first_300 = iter_to_300(TRPO_k_50_number_4);
TRPO_k_50_first_300 = [TRPO_k_50_number_0_first_300, TRPO_k_50_number_1_first_300, ...
    TRPO_k_50_number_2_first_300, TRPO_k_50_number_3_first_300, TRPO_k_50_number_4_first_300];
TRPO_k_50_first_300_mean = mean(TRPO_k_50_first_300);
TRPO_k_50_first_300_lowest = min(TRPO_k_50_first_300);

% k = 100
TRPO_k_100_number_0_first_300 = iter_to_300(TRPO_k_100_number_0);
TRPO_k_100_number_1_first_300 = iter_to_300(TRPO_k_100_number_1);
TRPO_k_100_number_2_first_300 = iter_to_300(TRPO_k_100_number_2);
TRPO_k_100_number_3_first_300 = iter_to_300(TRPO_k_100_number_3);
TRPO_k_100_number_4_first_300 = iter_to_300(TRPO_k_100_number_4);
TRPO_k_100_first_300 = [TRPO_k_100_number_0_first_300, TRPO_k_100_number_1_first_300, ...
    TRPO_k_100_number_2_first_300, TRPO_k_100_number_3_first_300, TRPO_k_100_number_4_first_300];
TRPO_k_100_first_300_mean = mean(TRPO_k_100_first_300);
TRPO_k_100_first_300_lowest = min(TRPO_k_100_first_300);

% k = 200
TRPO_k_200_number_0_first_300 = iter_to_300(TRPO_k_200_number_0);
TRPO_k_200_number_1_first_300 = iter_to_300(TRPO_k_200_number_1);
TRPO_k_200_number_2_first_300 = iter_to_300(TRPO_k_200_number_2);
TRPO_k_200_number_3_first_300 = iter_to_300(TRPO_k_200_number_3);
TRPO_k_200_number_4_first_300 = iter_to_300(TRPO_k_200_number_4);
TRPO_k_200_first_300 = [TRPO_k_200_number_0_first_300, TRPO_k_200_number_1_first_300, ...
    TRPO_k_200_number_2_first_300, TRPO_k_200_number_3_first_300, TRPO_k_200_number_4_first_300];
TRPO_k_200_first_300_mean = mean(TRPO_k_200_first_300);
TRPO_k_200_first_300_lowest = min(TRPO_k_200_first_300);

% k = 500
TRPO_k_500_number_0_first_300 = iter_to_300(TRPO_k_500_number_0);
TRPO_k_500_number_1_first_300 = iter_to_300(TRPO_k_500_number_1);
TRPO_k_500_number_2_first_300 = iter_to_300(TRPO_k_500_number_2);
TRPO_k_500_number_3_first_300 = iter_to_300(TRPO_k_500_number_3);
TRPO_k_500_number_4_first_300 = iter_to_300(TRPO_k_500_number_4);
TRPO_k_500_first_300 = [TRPO_k_500_number_0_first_300, TRPO_k_500_number_1_first_300, ...
    TRPO_k_500_number_2_first_300, TRPO_k_500_number_3_first_300, TRPO_k_500_number_4_first_300];
TRPO_k_500_first_300_mean = mean(TRPO_k_500_first_300);
TRPO_k_500_first_300_lowest = min(TRPO_k_500_first_300);

% k = 1000
TRPO_k_1000_number_0_first_300 = iter_to_300(TRPO_k_1000_number_0);
TRPO_k_1000_number_1_first_300 = iter_to_300(TRPO_k_1000_number_1);
TRPO_k_1000_number_2_first_300 = iter_to_300(TRPO_k_1000_number_2);
TRPO_k_1000_number_3_first_300 = iter_to_300(TRPO_k_1000_number_3);
TRPO_k_1000_number_4_first_300 = iter_to_300(TRPO_k_1000_number_4);
TRPO_k_1000_first_300 = [TRPO_k_1000_number_0_first_300, TRPO_k_1000_number_1_first_300, ...
    TRPO_k_1000_number_2_first_300, TRPO_k_1000_number_3_first_300, TRPO_k_1000_number_4_first_300];
TRPO_k_1000_first_300_mean = mean(TRPO_k_1000_first_300);
TRPO_k_1000_first_300_lowest = min(TRPO_k_1000_first_300);
% Figure 2: grouped bar chart of the mean and lowest first-300 values for the
% baseline and every k, with the baseline values drawn as dashed reference
% lines.
% Select the figure before touching the hold state, so that `hold on` applies
% to figure 2 and not to whichever figure was current before.
figure(2);
% Start from an empty figure so a repeated run does not draw over old bars.
clf;
hold on;

xticks_vec = ["baseline", "3", "50", "100", "200", "500", "1000"];
first_300_means = [trpo_first_300_mean, TRPO_k_3_first_300_mean, TRPO_k_50_first_300_mean, ...
    TRPO_k_100_first_300_mean, TRPO_k_200_first_300_mean, TRPO_k_500_first_300_mean, ...
    TRPO_k_1000_first_300_mean];
first_300_lowest = [trpo_first_300_lowest, TRPO_k_3_first_300_lowest, TRPO_k_50_first_300_lowest, ...
    TRPO_k_100_first_300_lowest, TRPO_k_200_first_300_lowest, TRPO_k_500_first_300_lowest, ...
    TRPO_k_1000_first_300_lowest];

% One row per x label and one column per series ('Mean', 'Lowest'), so the
% matrix orientation matches the seven labels explicitly.
bar(xticks_vec, [first_300_means(:), first_300_lowest(:)]);

title('TRPO: Iteration to get reward of 300 vs k');
xlabel('k');
ylabel('Iterations');
yline(trpo_first_300_mean, '--');
yline(trpo_standard_first_300_20_lowest, '--');
legend('Mean', 'Lowest', 'baseline mean', 'baseline 20% lowest', 'Location', 'southeast');
hold off;