-
Notifications
You must be signed in to change notification settings - Fork 398
/
Copy pathtd.py
1651 lines (1471 loc) · 72.6 KB
/
td.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import copy
import numpy as np
from collections import namedtuple
from typing import Union, Optional, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.hpc_rl import hpc_wrapper
from ding.rl_utils.value_rescale import value_transform, value_inv_transform
from ding.torch_utils import to_tensor
# Input container for ``q_1step_td_error``; field order matters because the function unpacks it positionally.
q_1step_td_data = namedtuple('q_1step_td_data', 'q next_q act next_act reward done weight')
def discount_cumsum(x, gamma: float = 1.0) -> np.ndarray:
    """
    Overview:
        Compute the discounted cumulative sum (return-to-go) of ``x`` along its first axis, \
        i.e. ``out[t] = x[t] + gamma * out[t + 1]`` with ``out[-1] = x[-1]``.
    Arguments:
        - x (:obj:`np.ndarray`): The input sequence; any array-like is converted with ``np.asarray``.
        - gamma (:obj:`float`): Discount factor; only (numerically) 1.0 is accepted, as in the original \
            decision transformer paper.
    Returns:
        - disc_cumsum (:obj:`np.ndarray`): Array with the same shape and dtype as ``x``; an empty input \
            yields an empty output.
    Examples:
        >>> discount_cumsum(np.array([1., 2., 3.]))
        array([6., 5., 3.])
    """
    assert abs(gamma - 1.) < 1e-5, "gamma equals to 1.0 in original decision transformer paper"
    x = np.asarray(x)
    disc_cumsum = np.zeros_like(x)
    if x.shape[0] == 0:
        # Guard: ``x[-1]`` below would raise IndexError on an empty sequence.
        return disc_cumsum
    disc_cumsum[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        disc_cumsum[t] = x[t] + gamma * disc_cumsum[t + 1]
    return disc_cumsum
def q_1step_td_error(
        data: namedtuple,
        gamma: float,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none')  # noqa
) -> torch.Tensor:
    """
    Overview:
        1 step td_error, support single agent case and multi agent case.
    Arguments:
        - data (:obj:`q_1step_td_data`): The input data, q_1step_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - criterion (:obj:`torch.nn.modules`): Loss function criterion, must use ``reduction='none'``
    Returns:
        - loss (:obj:`torch.Tensor`): 1step td error (weighted mean), 0-dim tensor
    Shapes:
        - data (:obj:`q_1step_td_data`): the q_1step_td_data containing\
            ['q', 'next_q', 'act', 'next_act', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - act (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_act (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor` or numeric tensor) :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> action_dim = 4
        >>> data = q_1step_td_data(
        >>>     q=torch.randn(3, action_dim),
        >>>     next_q=torch.randn(3, action_dim),
        >>>     act=torch.randint(0, action_dim, (3,)),
        >>>     next_act=torch.randint(0, action_dim, (3,)),
        >>>     reward=torch.randn(3),
        >>>     done=torch.randint(0, 2, (3,)).bool(),
        >>>     weight=torch.ones(3),
        >>> )
        >>> loss = q_1step_td_error(data, 0.99)
    """
    q, next_q, act, next_act, reward, done, weight = data
    assert len(act.shape) == 1, act.shape
    assert len(reward.shape) == 1, reward.shape
    batch_range = torch.arange(act.shape[0])
    if weight is None:
        weight = torch.ones_like(reward)
    q_s_a = q[batch_range, act]
    target_q_s_a = next_q[batch_range, next_act]
    # PyTorch rejects ``1 - bool_tensor``, so cast ``done`` to the value dtype before masking.
    not_done = 1 - done.to(target_q_s_a.dtype)
    target_q_s_a = gamma * not_done * target_q_s_a + reward
    return (criterion(q_s_a, target_q_s_a.detach()) * weight).mean()
# Input container for ``m_q_1step_td_error`` (Munchausen DQN); unpacked positionally by that function.
m_q_1step_td_data = namedtuple('m_q_1step_td_data', 'q target_q next_q act reward done weight')
def m_q_1step_td_error(
        data: namedtuple,
        gamma: float,
        tau: float,
        alpha: float,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none')  # noqa
) -> torch.Tensor:
    """
    Overview:
        Munchausen td_error for DQN algorithm, support 1 step td error.
    Arguments:
        - data (:obj:`m_q_1step_td_data`): The input data, m_q_1step_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - tau (:obj:`float`): Entropy factor for Munchausen DQN
        - alpha (:obj:`float`): Discount factor for Munchausen term
        - criterion (:obj:`torch.nn.modules`): Loss function criterion, must use ``reduction='none'``
    Returns:
        - loss (:obj:`torch.Tensor`): 1step td error (weighted mean), 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): :math:`(B, )`, unweighted per-sample td error
        - action_gap (:obj:`torch.Tensor`): 0-dim, mean gap between the two largest ``target_q`` values
        - clipfrac (:obj:`torch.Tensor`): :math:`(B, 1)`, 1.0 where the Munchausen addon was clipped, else 0.0
    Shapes:
        - data (:obj:`m_q_1step_td_data`): the m_q_1step_td_data containing\
            ['q', 'target_q', 'next_q', 'act', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - target_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - act (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor` or numeric tensor) :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> action_dim = 4
        >>> data = m_q_1step_td_data(
        >>>     q=torch.randn(3, action_dim),
        >>>     target_q=torch.randn(3, action_dim),
        >>>     next_q=torch.randn(3, action_dim),
        >>>     act=torch.randint(0, action_dim, (3,)),
        >>>     reward=torch.randn(3),
        >>>     done=torch.randint(0, 2, (3,)),
        >>>     weight=torch.ones(3),
        >>> )
        >>> loss, td_error_per_sample, action_gap, clipfrac = m_q_1step_td_error(data, 0.99, 0.01, 0.01)
    """
    q, target_q, next_q, act, reward, done, weight = data
    lower_bound = -1
    assert len(act.shape) == 1, act.shape
    assert len(reward.shape) == 1, reward.shape
    batch_range = torch.arange(act.shape[0])
    if weight is None:
        weight = torch.ones_like(reward)
    q_s_a = q[batch_range, act]
    # calculate munchausen addon
    # replay_log_policy: tau * log softmax(target_q / tau), shifted by the row max for numerical stability
    target_v_s = target_q[batch_range].max(1)[0].unsqueeze(-1)
    logsum = torch.logsumexp((target_q - target_v_s) / tau, 1).unsqueeze(-1)
    log_pi = target_q - target_v_s - tau * logsum
    act_get = act.unsqueeze(-1)
    # same to the last second tau_log_pi_a
    munchausen_addon = log_pi.gather(1, act_get)
    muchausen_term = alpha * torch.clamp(munchausen_addon, min=lower_bound, max=1)
    # replay_next_log_policy
    target_v_s_next = next_q[batch_range].max(1)[0].unsqueeze(-1)
    logsum_next = torch.logsumexp((next_q - target_v_s_next) / tau, 1).unsqueeze(-1)
    tau_log_pi_next = next_q - target_v_s_next - tau * logsum_next
    # do stable softmax == replay_next_policy; dim=1 is the action dim (the implicit choice for 2-D input,
    # made explicit to avoid PyTorch's deprecation warning)
    pi_target = F.softmax((next_q - target_v_s_next) / tau, dim=1)
    # PyTorch rejects ``1 - bool_tensor``, so cast ``done`` to the value dtype before masking.
    not_done = 1 - done.to(next_q.dtype).unsqueeze(-1)
    target_q_s_a = (gamma * (pi_target * (next_q - tau_log_pi_next) * not_done).sum(1)).unsqueeze(-1)
    target_q_s_a = reward.unsqueeze(-1) + muchausen_term + target_q_s_a
    td_error_per_sample = criterion(q_s_a.unsqueeze(-1), target_q_s_a.detach()).squeeze(-1)

    # calculate action_gap and clipfrac (diagnostics only, no gradient)
    with torch.no_grad():
        top2_q_s = target_q[batch_range].topk(2, dim=1, largest=True, sorted=True)[0]
        action_gap = (top2_q_s[:, 0] - top2_q_s[:, 1]).mean()

        clipped = munchausen_addon.gt(1) | munchausen_addon.lt(lower_bound)
        clipfrac = torch.as_tensor(clipped).float()

    return (td_error_per_sample * weight).mean(), td_error_per_sample, action_gap, clipfrac
# Input container for ``q_v_1step_td_error``; unpacked positionally by that function.
q_v_1step_td_data = namedtuple('q_v_1step_td_data', 'q v act reward done weight')
def q_v_1step_td_error(
        data: namedtuple, gamma: float, criterion: torch.nn.modules = nn.MSELoss(reduction='none')
) -> torch.Tensor:
    """
    Overview:
        td_error between q and v value for SAC algorithm, support 1 step td error. It is used in the \
        discrete SAC algorithm to calculate the td error between q and v value.
    Arguments:
        - data (:obj:`q_v_1step_td_data`): The input data, q_v_1step_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - criterion (:obj:`torch.nn.modules`): Loss function criterion, must use ``reduction='none'``
    Returns:
        - loss (:obj:`torch.Tensor`): 1step td error (weighted mean), 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): unweighted per-sample td error
    Shapes:
        - data (:obj:`q_v_1step_td_data`): the q_v_1step_td_data containing\
            ['q', 'v', 'act', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - v (:obj:`torch.FloatTensor`): :math:`(B, )`
        - act (:obj:`torch.LongTensor`): :math:`(B, )`; a 2-D ``act`` selects the multi-agent branch
        - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor` or numeric tensor) :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> action_dim = 4
        >>> data = q_v_1step_td_data(
        >>>     q=torch.randn(3, action_dim),
        >>>     v=torch.randn(3),
        >>>     act=torch.randint(0, action_dim, (3,)),
        >>>     reward=torch.randn(3),
        >>>     done=torch.randint(0, 2, (3,)),
        >>>     weight=torch.ones(3),
        >>> )
        >>> loss, td_error_per_sample = q_v_1step_td_error(data, 0.99)
    """
    q, v, act, reward, done, weight = data
    # PyTorch rejects ``1 - bool_tensor``, so cast ``done`` to the value dtype before masking.
    not_done = 1 - done.to(v.dtype)
    if len(act.shape) == 1:
        assert len(reward.shape) == 1, reward.shape
        batch_range = torch.arange(act.shape[0])
        if weight is None:
            weight = torch.ones_like(reward)
        q_s_a = q[batch_range, act]
        target_q_s_a = gamma * not_done * v + reward
    else:
        assert len(reward.shape) == 1, reward.shape
        batch_range = torch.arange(act.shape[0])
        actor_range = torch.arange(act.shape[1])
        batch_actor_range = torch.arange(act.shape[0] * act.shape[1])
        if weight is None:
            weight = torch.ones_like(act)
        # Flatten (batch, actor) so a single advanced-index picks q[b, a, act[b, a]].
        temp_q = q.reshape(act.shape[0] * act.shape[1], -1)
        temp_act = act.reshape(act.shape[0] * act.shape[1])
        q_s_a = temp_q[batch_actor_range, temp_act]
        q_s_a = q_s_a.reshape(act.shape[0], act.shape[1])
        target_q_s_a = gamma * not_done.unsqueeze(1) * v + reward.unsqueeze(1)
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
    return (td_error_per_sample * weight).mean(), td_error_per_sample
def view_similar(x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Overview:
        Return a view of ``x`` with trailing singleton dims appended until it has as many dims as \
        ``target``, so that ``x`` broadcasts against ``target`` along its leading dims. If ``x`` already \
        has at least as many dims, its shape is kept unchanged.
    Arguments:
        - x (:obj:`torch.Tensor`): The tensor to reshape.
        - target (:obj:`torch.Tensor`): The tensor whose number of dims should be matched.
    Returns:
        - view (:obj:`torch.Tensor`): A view of ``x``.
    """
    missing_dims = len(target.shape) - len(x.shape)
    return x.view(*x.shape, *([1] * missing_dims))
# Input container for ``nstep_return``; unpacked positionally by that function.
nstep_return_data = namedtuple('nstep_return_data', 'reward next_value done')
def nstep_return(data: namedtuple, gamma: Union[float, list], nstep: int, value_gamma: Optional[torch.Tensor] = None):
    '''
    Overview:
        Calculate nstep return for DQN algorithm, support single agent case and multi agent case.
    Arguments:
        - data (:obj:`nstep_return_data`): The input data, nstep_return_data to calculate loss
        - gamma (:obj:`Union[float, list]`): Discount factor. A Python number gives one shared discount; \
            a list (NGU policy case) gives per-sample discounts and must ``torch.stack`` to shape :math:`(B, )`.
        - nstep (:obj:`int`): nstep num
        - value_gamma (:obj:`torch.Tensor` or scalar): Discount applied to ``next_value``; if None, \
            ``gamma ** nstep`` is used. Only used when ``gamma`` is a number.
    Returns:
        - return (:obj:`torch.Tensor`): nstep return
    Raises:
        - TypeError: If ``gamma`` is neither a number nor a list.
    Shapes:
        - data (:obj:`nstep_return_data`): the nstep_return_data containing\
            ['reward', 'next_value', 'done']
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - next_value (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor` or numeric tensor) :math:`(B, )`, whether done in last timestep
    Examples:
        >>> data = nstep_return_data(
        >>>     reward=torch.randn(3, 3),
        >>>     next_value=torch.randn(3),
        >>>     done=torch.randint(0, 2, (3,)),
        >>> )
        >>> loss = nstep_return(data, 0.99, 3)
    '''
    reward, next_value, done = data
    assert reward.shape[0] == nstep
    device = reward.device
    # PyTorch rejects ``1 - bool_tensor``, so cast ``done`` to the value dtype before masking.
    done = done.to(next_value.dtype)

    if isinstance(gamma, (int, float)):
        # reward_factor[i] = gamma ** i
        reward_factor = torch.ones(nstep).to(device)
        for i in range(1, nstep):
            reward_factor[i] = gamma * reward_factor[i - 1]
        reward_factor = view_similar(reward_factor, reward)
        return_tmp = reward.mul(reward_factor).sum(0)
        if value_gamma is None:
            return_ = return_tmp + (gamma ** nstep) * next_value * (1 - done)
        else:
            if np.isscalar(value_gamma):
                value_gamma = torch.full_like(next_value, value_gamma)
            value_gamma = view_similar(value_gamma, next_value)
            done = view_similar(done, next_value)
            return_ = return_tmp + value_gamma * next_value * (1 - done)

    elif isinstance(gamma, list):
        # if gamma is list, for NGU policy case: each sample carries its own discount
        gamma_tensor = torch.stack(gamma, dim=0).to(device)
        reward_factor = torch.ones([nstep + 1, done.shape[0]]).to(device)
        for i in range(1, nstep + 1):
            reward_factor[i] = gamma_tensor * reward_factor[i - 1]
        reward_factor = view_similar(reward_factor, reward)
        return_tmp = reward.mul(reward_factor[:nstep]).sum(0)
        return_ = return_tmp + reward_factor[nstep] * next_value * (1 - done)
    else:
        raise TypeError("The type of gamma should be float or list")

    return return_
# Input container for ``dist_1step_td_error``; unpacked positionally by that function.
dist_1step_td_data = namedtuple('dist_1step_td_data', 'dist next_dist act next_act reward done weight')
def dist_1step_td_error(
        data: namedtuple,
        gamma: float,
        v_min: float,
        v_max: float,
        n_atom: int,
) -> torch.Tensor:
    """
    Overview:
        1 step td_error for distributed q-learning based algorithm
    Arguments:
        - data (:obj:`dist_1step_td_data`): The input data, dist_1step_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - v_min (:obj:`float`): The min value of support
        - v_max (:obj:`float`): The max value of support
        - n_atom (:obj:`int`): The num of atom
    Returns:
        - loss (:obj:`torch.Tensor`): 1step td error (weighted cross entropy), 0-dim tensor
    Shapes:
        - data (:obj:`dist_1step_td_data`): the dist_1step_td_data containing\
            ['dist', 'next_dist', 'act', 'next_act', 'reward', 'done', 'weight']
        - dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` i.e. [batch_size, action_dim, n_atom]
        - next_dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)`
        - act (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_act (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor` or numeric tensor) :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> dist = torch.randn(4, 3, 51).abs().requires_grad_(True)
        >>> next_dist = torch.randn(4, 3, 51).abs()
        >>> act = torch.randint(0, 3, (4,))
        >>> next_act = torch.randint(0, 3, (4,))
        >>> reward = torch.randn(4)
        >>> done = torch.randint(0, 2, (4,))
        >>> data = dist_1step_td_data(dist, next_dist, act, next_act, reward, done, None)
        >>> loss = dist_1step_td_error(data, 0.99, -10.0, 10.0, 51)
    """
    dist, next_dist, act, next_act, reward, done, weight = data
    device = reward.device
    assert len(reward.shape) == 1, reward.shape
    support = torch.linspace(v_min, v_max, n_atom).to(device)
    delta_z = (v_max - v_min) / (n_atom - 1)
    # PyTorch rejects ``1 - bool_tensor``, so cast ``done`` to the reward dtype before masking.
    done = done.to(reward.dtype)

    if len(act.shape) == 1:
        reward = reward.unsqueeze(-1)
        done = done.unsqueeze(-1)
        batch_size = act.shape[0]
        batch_range = torch.arange(batch_size)
        if weight is None:
            weight = torch.ones_like(reward)
        next_dist = next_dist[batch_range, next_act].detach()
    else:
        # Multi-agent: flatten (batch, agent) into one batch dim of size B * A.
        reward = reward.unsqueeze(-1).repeat(1, act.shape[1])
        done = done.unsqueeze(-1).repeat(1, act.shape[1])

        batch_size = act.shape[0] * act.shape[1]
        batch_range = torch.arange(act.shape[0] * act.shape[1])
        action_dim = dist.shape[2]
        dist = dist.reshape(act.shape[0] * act.shape[1], action_dim, -1)
        reward = reward.reshape(act.shape[0] * act.shape[1], -1)
        done = done.reshape(act.shape[0] * act.shape[1], -1)
        next_dist = next_dist.reshape(act.shape[0] * act.shape[1], action_dim, -1)

        next_act = next_act.reshape(act.shape[0] * act.shape[1])
        next_dist = next_dist[batch_range, next_act].detach()
        next_dist = next_dist.reshape(act.shape[0] * act.shape[1], -1)
        act = act.reshape(act.shape[0] * act.shape[1])
        if weight is None:
            weight = torch.ones_like(reward)
    # A per-sample weight of shape (B, ) must broadcast over the atom dim of the (B, n_atom) loss terms;
    # without this it either fails or silently weights the wrong axis when B == n_atom.
    if isinstance(weight, torch.Tensor) and len(weight.shape) == 1:
        weight = weight.unsqueeze(-1)

    # Project the shifted target support back onto the fixed support (categorical projection).
    target_z = reward + (1 - done) * gamma * support
    target_z = target_z.clamp(min=v_min, max=v_max)
    b = (target_z - v_min) / delta_z
    l = b.floor().long()
    u = b.ceil().long()
    # Fix disappearing probability mass when l = b = u (b is int)
    l[(u > 0) & (l == u)] -= 1
    u[(l < (n_atom - 1)) & (l == u)] += 1

    proj_dist = torch.zeros_like(next_dist)
    # Start index of each sample's row inside the flattened proj_dist.
    offset = (torch.arange(batch_size, device=device) * n_atom).unsqueeze(1).expand(batch_size, n_atom)
    proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1))
    proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1))

    log_p = torch.log(dist[batch_range, act])

    loss = -(log_p * proj_dist * weight).sum(-1).mean()

    return loss
# Input container for ``dist_nstep_td_error``. The typename must match the variable name: the previous
# copy-pasted typename 'dist_1step_td_data' gave a misleading repr and made instances unpicklable
# (pickle resolves the class by its typename and found the unrelated 1-step tuple instead).
dist_nstep_td_data = namedtuple(
    'dist_nstep_td_data', ['dist', 'next_n_dist', 'act', 'next_n_act', 'reward', 'done', 'weight']
)
def shape_fn_dntd(args, kwargs):
    r"""
    Overview:
        Shape function used by ``hpc_wrapper`` for ``dist_nstep_td_error``. The ``data`` argument is taken \
        from the first positional argument if any were given, otherwise from ``kwargs['data']``.
    Returns:
        shape: [T, B, N, n_atom], i.e. ``reward.shape[0]`` followed by every dim of ``dist``
    """
    data = args[0] if len(args) > 0 else kwargs['data']
    return [data.reward.shape[0], *data.dist.shape]
@hpc_wrapper(
    shape_fn=shape_fn_dntd,
    namedtuple_data=True,
    include_args=[0, 1, 2, 3],
    include_kwargs=['data', 'gamma', 'v_min', 'v_max']
)
def dist_nstep_td_error(
        data: namedtuple,
        gamma: float,
        v_min: float,
        v_max: float,
        n_atom: int,
        nstep: int = 1,
        value_gamma: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error for distributed q-learning based algorithm, support single\
        agent case and multi agent case.
    Arguments:
        - data (:obj:`dist_nstep_td_data`): The input data, dist_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - v_min (:obj:`float`): The min value of support
        - v_max (:obj:`float`): The max value of support
        - n_atom (:obj:`int`): The num of atom
        - nstep (:obj:`int`): nstep num, default set to 1
        - value_gamma (:obj:`torch.Tensor` or :obj:`float` or None): Discount applied to the bootstrapped \
            support; if None, ``gamma ** nstep`` is used
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error (weighted cross entropy), 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): unweighted per-sample cross entropy
    Shapes:
        - data (:obj:`dist_nstep_td_data`): the dist_nstep_td_data containing\
            ['dist', 'next_n_dist', 'act', 'next_n_act', 'reward', 'done', 'weight']
        - dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` i.e. [batch_size, action_dim, n_atom]
        - next_n_dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)`
        - act (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_act (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor` or numeric tensor) :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or :obj:`float` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> dist = torch.randn(4, 3, 51).abs().requires_grad_(True)
        >>> next_n_dist = torch.randn(4, 3, 51).abs()
        >>> done = torch.randint(0, 2, (4,))
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> reward = torch.randn(5, 4)
        >>> data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None)
        >>> loss, _ = dist_nstep_td_error(data, 0.95, -10.0, 10.0, 51, 5)
    """
    dist, next_n_dist, act, next_n_act, reward, done, weight = data
    device = reward.device
    # PyTorch rejects ``1 - bool_tensor``, so cast ``done`` to the reward dtype before masking.
    done = done.to(reward.dtype)
    # Discounted sum of the nstep rewards: sum_i gamma ** i * reward[i]
    reward_factor = torch.ones(nstep).to(device)
    for i in range(1, nstep):
        reward_factor[i] = gamma * reward_factor[i - 1]
    reward = torch.matmul(reward_factor, reward)
    support = torch.linspace(v_min, v_max, n_atom).to(device)
    delta_z = (v_max - v_min) / (n_atom - 1)
    if len(act.shape) == 1:
        reward = reward.unsqueeze(-1)
        done = done.unsqueeze(-1)
        batch_size = act.shape[0]
        batch_range = torch.arange(batch_size)
        if weight is None:
            weight = torch.ones_like(reward)
        elif isinstance(weight, float):
            weight = torch.tensor(weight, device=device)
        next_n_dist = next_n_dist[batch_range, next_n_act].detach()
    else:
        # Multi-agent: flatten (batch, agent) into one batch dim of size B * A.
        reward = reward.unsqueeze(-1).repeat(1, act.shape[1])
        done = done.unsqueeze(-1).repeat(1, act.shape[1])

        batch_size = act.shape[0] * act.shape[1]
        batch_range = torch.arange(act.shape[0] * act.shape[1])
        action_dim = dist.shape[2]
        dist = dist.reshape(act.shape[0] * act.shape[1], action_dim, -1)
        reward = reward.reshape(act.shape[0] * act.shape[1], -1)
        done = done.reshape(act.shape[0] * act.shape[1], -1)
        next_n_dist = next_n_dist.reshape(act.shape[0] * act.shape[1], action_dim, -1)

        next_n_act = next_n_act.reshape(act.shape[0] * act.shape[1])
        next_n_dist = next_n_dist[batch_range, next_n_act].detach()
        next_n_dist = next_n_dist.reshape(act.shape[0] * act.shape[1], -1)
        act = act.reshape(act.shape[0] * act.shape[1])
        if weight is None:
            weight = torch.ones_like(reward)
        elif isinstance(weight, float):
            weight = torch.tensor(weight, device=device)

    if value_gamma is None:
        target_z = reward + (1 - done) * (gamma ** nstep) * support
    elif isinstance(value_gamma, float):
        # Build on reward's device: a 1-element CPU tensor cannot be combined with CUDA tensors.
        value_gamma = torch.tensor(value_gamma, device=device).unsqueeze(-1)
        target_z = reward + (1 - done) * value_gamma * support
    else:
        value_gamma = value_gamma.unsqueeze(-1)
        target_z = reward + (1 - done) * value_gamma * support
    target_z = target_z.clamp(min=v_min, max=v_max)
    b = (target_z - v_min) / delta_z
    l = b.floor().long()
    u = b.ceil().long()
    # Fix disappearing probability mass when l = b = u (b is int)
    l[(u > 0) & (l == u)] -= 1
    u[(l < (n_atom - 1)) & (l == u)] += 1

    proj_dist = torch.zeros_like(next_n_dist)
    # Start index of each sample's row inside the flattened proj_dist.
    offset = (torch.arange(batch_size, device=device) * n_atom).unsqueeze(1).expand(batch_size, n_atom)
    proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_n_dist * (u.float() - b)).view(-1))
    proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_n_dist * (b - l.float())).view(-1))

    assert (dist[batch_range, act] > 0.0).all(), ("dist act", dist[batch_range, act], "dist:", dist)
    log_p = torch.log(dist[batch_range, act])

    # A per-sample weight of shape (B, ) must broadcast over the atom dim.
    if len(weight.shape) == 1:
        weight = weight.unsqueeze(-1)

    td_error_per_sample = -(log_p * proj_dist).sum(-1)

    loss = -(log_p * proj_dist * weight).sum(-1).mean()

    return loss, td_error_per_sample
# Input container for ``v_1step_td_error``; unpacked positionally by that function.
v_1step_td_data = namedtuple('v_1step_td_data', 'v next_v reward done weight')
def v_1step_td_error(
        data: namedtuple,
        gamma: float,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none')  # noqa
) -> torch.Tensor:
    '''
    Overview:
        1 step td_error for value based algorithm (state value ``v`` regressed onto a 1-step target)
    Arguments:
        - data (:obj:`v_1step_td_data`): The input data, v_1step_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - criterion (:obj:`torch.nn.modules`): Loss function criterion, must use ``reduction='none'``
    Returns:
        - loss (:obj:`torch.Tensor`): 1step td error (weighted mean), 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): unweighted per-sample td error
    Shapes:
        - data (:obj:`v_1step_td_data`): the v_1step_td_data containing\
            ['v', 'next_v', 'reward', 'done', 'weight']
        - v (:obj:`torch.FloatTensor`): :math:`(B, )` i.e. [batch_size, ]; if ``v`` has one more dim than \
            ``reward``, reward and done get a dim inserted at axis 1 to broadcast
        - next_v (:obj:`torch.FloatTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor`, numeric tensor or None) :math:`(B, )`, whether done in last timestep; \
            None means no termination masking
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> v = torch.randn(5).requires_grad_(True)
        >>> next_v = torch.randn(5)
        >>> reward = torch.rand(5)
        >>> done = torch.zeros(5)
        >>> data = v_1step_td_data(v, next_v, reward, done, None)
        >>> loss, td_error_per_sample = v_1step_td_error(data, 0.99)
    '''
    v, next_v, reward, done, weight = data
    if weight is None:
        weight = torch.ones_like(v)
    # When v carries an extra dim (e.g. multi agent), reward/done need a matching dim at axis 1.
    expand = len(v.shape) != len(reward.shape)
    if expand:
        reward = reward.unsqueeze(1)
    if done is None:
        target_v = gamma * next_v + reward
    else:
        # PyTorch rejects ``1 - bool_tensor``, so cast ``done`` to the value dtype before masking.
        not_done = 1 - done.to(next_v.dtype)
        if expand:
            not_done = not_done.unsqueeze(1)
        target_v = gamma * not_done * next_v + reward
    td_error_per_sample = criterion(v, target_v.detach())
    return (td_error_per_sample * weight).mean(), td_error_per_sample
# Input container for ``v_nstep_td_error``; unpacked positionally by that function.
v_nstep_td_data = namedtuple('v_nstep_td_data', 'v next_n_v reward done weight value_gamma')
def v_nstep_td_error(
        data: namedtuple,
        gamma: float,
        nstep: int = 1,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none')  # noqa
) -> torch.Tensor:
    """
    Overview:
        Multistep (n step) td_error for value based algorithm: ``v`` is regressed onto the nstep return \
        produced by ``nstep_return``.
    Arguments:
        - data (:obj:`v_nstep_td_data`): The input data, v_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - criterion (:obj:`torch.nn.modules`): Loss function criterion, must use ``reduction='none'``
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error (weighted mean), 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): unweighted per-sample td error
    Shapes:
        - data (:obj:`v_nstep_td_data`): The v_nstep_td_data containing \
            ['v', 'next_n_v', 'reward', 'done', 'weight', 'value_gamma']
        - v (:obj:`torch.FloatTensor`): :math:`(B, )` i.e. [batch_size, ]
        - next_n_v (:obj:`torch.FloatTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
        - value_gamma (:obj:`torch.Tensor`): If the remaining data in the buffer is less than n_step \
            we use value_gamma as the gamma discount value for next_v rather than gamma**n_step
    Examples:
        >>> v = torch.randn(5).requires_grad_(True)
        >>> next_v = torch.randn(5)
        >>> reward = torch.rand(5, 5)
        >>> done = torch.zeros(5)
        >>> data = v_nstep_td_data(v, next_v, reward, done, 0.9, 0.99)
        >>> loss, td_error_per_sample = v_nstep_td_error(data, 0.99, 5)
    """
    v, next_n_v, reward, done, weight, value_gamma = data
    return_data = nstep_return_data(reward, next_n_v, done)
    # The bootstrap target is a constant for this loss: no gradient flows through it.
    target_v = nstep_return(return_data, gamma, nstep, value_gamma).detach()
    td_error_per_sample = criterion(v, target_v)
    sample_weight = torch.ones_like(v) if weight is None else weight
    return (td_error_per_sample * sample_weight).mean(), td_error_per_sample
# Input container for ``q_nstep_td_error``; unpacked positionally by that function.
q_nstep_td_data = namedtuple(
    'q_nstep_td_data', ('q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight')
)

# Input container for the DQfD nstep td error. Field names (including ``new_n_q_one_step``) are part of
# the public interface and are kept exactly as-is.
dqfd_nstep_td_data = namedtuple(
    'dqfd_nstep_td_data',
    (
        'q',
        'next_n_q',
        'action',
        'next_n_action',
        'reward',
        'done',
        'done_one_step',
        'weight',
        'new_n_q_one_step',
        'next_n_action_one_step',
        'is_expert',
    ),
)
def shape_fn_qntd(args, kwargs):
    r"""
    Overview:
        Shape function used by ``hpc_wrapper`` for ``q_nstep_td_error``. The ``data`` argument is taken \
        from the first positional argument if any were given, otherwise from ``kwargs['data']``.
    Returns:
        shape: [T, B, N], i.e. ``reward.shape[0]`` followed by every dim of ``q``
    """
    data = args[0] if len(args) > 0 else kwargs['data']
    return [data.reward.shape[0], *data.q.shape]
@hpc_wrapper(shape_fn=shape_fn_qntd, namedtuple_data=True, include_args=[0, 1], include_kwargs=['data', 'gamma'])
def q_nstep_td_error(
        data: namedtuple,
        gamma: Union[float, list],
        nstep: int = 1,
        cum_reward: bool = False,
        value_gamma: Optional[torch.Tensor] = None,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error for q-learning based algorithm
    Arguments:
        - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
        - gamma (:obj:`Union[float, list]`): Discount factor; a list is forwarded to ``nstep_return`` \
            (per-sample discounts) and is only supported when ``cum_reward`` is False
        - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value
        - criterion (:obj:`torch.nn.modules`): Loss function criterion, must use ``reduction='none'``
        - nstep (:obj:`int`): nstep num, default set to 1
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error (weighted mean), 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): unweighted nstep td error, 1-dim tensor
    Shapes:
        - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor` or numeric tensor) :math:`(B, )`, whether done in last timestep
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
    Examples:
        >>> next_q = torch.randn(4, 3)
        >>> done = torch.randint(0, 2, (4, ))
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep = 3
        >>> q = torch.randn(4, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        >>> loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, weight = data
    if weight is None:
        weight = torch.ones_like(reward)
    # PyTorch rejects ``1 - bool_tensor``, so cast ``done`` to the q-value dtype before masking.
    done = done.to(q.dtype)

    if len(action.shape) == 1 or len(action.shape) < len(q.shape):
        # action has one dim fewer than q, so add a trailing dim for ``gather``
        # e.g. single agent case: action is [B, ] and q is [B, action_shape]
        # e.g. multi agent case: action is [B, agent_num] and q is [B, agent_num, action_shape]
        action = action.unsqueeze(-1)
    elif len(action.shape) > 1:  # MARL case
        reward = reward.unsqueeze(-1)
        weight = weight.unsqueeze(-1)
        done = done.unsqueeze(-1)
        if value_gamma is not None:
            value_gamma = value_gamma.unsqueeze(-1)

    q_s_a = q.gather(-1, action).squeeze(-1)

    target_q_s_a = next_n_q.gather(-1, next_n_action.unsqueeze(-1)).squeeze(-1)

    if cum_reward:
        # reward already holds the discounted nstep sum; only bootstrap from the nstep-ahead value.
        if value_gamma is None:
            target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
        else:
            target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
    else:
        target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
    return (td_error_per_sample * weight).mean(), td_error_per_sample
def bdq_nstep_td_error(
        data: namedtuple,
        gamma: Union[float, list],
        nstep: int = 1,
        cum_reward: bool = False,
        value_gamma: Optional[torch.Tensor] = None,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
) -> torch.Tensor:
    """
    Overview:
        Multistep (1-step or n-step) TD error for the BDQ algorithm, from the paper "Action Branching \
        Architectures for Deep Reinforcement Learning" (https://arxiv.org/pdf/1711.08946). The paper only \
        defines the 1-step TD error; this function extends it to n steps. Every action branch gets its own TD \
        target built from the shared per-sample reward/done signal, and the per-branch errors are averaged \
        into a single error per sample.
    Arguments:
        - data (:obj:`q_nstep_td_data`): Input tuple \
            ``(q, next_n_q, action, next_n_action, reward, done, weight)``
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - cum_reward (:obj:`bool`): If True, ``reward`` already holds the accumulated n-step reward computed \
            at collection time, and the target is ``reward + gamma ** nstep * target_q * (1 - done)`` \
            (``value_gamma`` replaces ``gamma ** nstep`` when given)
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for the target q value
        - criterion (:obj:`torch.nn.modules`): Loss criterion; its output is averaged over the last (branch) \
            axis, so it should be element-wise (e.g. ``reduction='none'``)
    Returns:
        - loss (:obj:`torch.Tensor`): Weighted mean nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): nstep td error averaged over branches, 1-dim tensor
    Shapes:
        - q (:obj:`torch.FloatTensor`): :math:`(B, D, N)` i.e. [batch_size, branch_num, action_bins_per_branch]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, D, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, D)`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, D)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep); :math:`(B, )` when \
            ``cum_reward`` is True
        - done (:obj:`torch.FloatTensor`): :math:`(B, )`, whether done in last timestep; ``1 - done`` is \
            computed, so a 0/1 float tensor is expected rather than a BoolTensor
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
    Examples:
        >>> action_per_branch = 3
        >>> next_q = torch.randn(8, 6, action_per_branch)
        >>> done = torch.randn(8)
        >>> action = torch.randint(0, action_per_branch, size=(8, 6))
        >>> next_action = torch.randint(0, action_per_branch, size=(8, 6))
        >>> nstep = 3
        >>> q = torch.randn(8, 6, action_per_branch).requires_grad_(True)
        >>> reward = torch.rand(nstep, 8)
        >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        >>> loss, td_error_per_sample = bdq_nstep_td_error(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, weight = data
    if weight is None:
        # Built from the un-expanded reward; it broadcasts against the (B, ) per-sample error below.
        weight = torch.ones_like(reward)

    def _select_bin(values: torch.Tensor, bins: torch.Tensor) -> torch.Tensor:
        # Pick the chosen bin of every branch: (B, D, N) indexed by (B, D) -> (B, D).
        return values.gather(-1, bins.unsqueeze(-1)).squeeze(-1)

    q_s_a = _select_bin(q, action)
    target_q_s_a = _select_bin(next_n_q, next_n_action)

    # Append a trailing axis so the per-sample reward/done/discount broadcast over the branch axis D.
    reward, done = reward.unsqueeze(-1), done.unsqueeze(-1)
    if value_gamma is not None:
        value_gamma = value_gamma.unsqueeze(-1)

    if not cum_reward:
        target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
    else:
        discount = gamma ** nstep if value_gamma is None else value_gamma
        target_q_s_a = reward + discount * target_q_s_a * (1 - done)

    # Element-wise error per branch, then averaged over branches -> one error per sample.
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach()).mean(-1)
    return (td_error_per_sample * weight).mean(), td_error_per_sample
def shape_fn_qntd_rescale(args, kwargs):
    r"""
    Overview:
        Shape function passed to ``hpc_wrapper`` for ``q_nstep_td_error_with_rescale``. It finds the ``data`` \
        argument (the first positional argument, otherwise the ``data`` keyword argument) and reports the \
        leading dimension of ``data.reward`` followed by every dimension of ``data.q``.
    Arguments:
        - args (:obj:`tuple`): Positional arguments of the wrapped call
        - kwargs (:obj:`dict`): Keyword arguments of the wrapped call
    Returns:
        - shape (:obj:`list`): [T, B, N]
    """
    data = kwargs['data'] if len(args) <= 0 else args[0]
    return [data.reward.shape[0], *data.q.shape]
@hpc_wrapper(
    shape_fn=shape_fn_qntd_rescale, namedtuple_data=True, include_args=[0, 1], include_kwargs=['data', 'gamma']
)
def q_nstep_td_error_with_rescale(
        data: namedtuple,
        gamma: Union[float, list],
        nstep: int = 1,
        value_gamma: Optional[torch.Tensor] = None,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
        trans_fn: Callable = value_transform,
        inv_trans_fn: Callable = value_inv_transform,
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error with value rescaling. The bootstrapped target q value is mapped \
        back with ``inv_trans_fn``, the n-step return is computed on it, and the result is mapped again with \
        ``trans_fn`` before being compared with the current q value.
    Arguments:
        - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value, forwarded to \
            ``nstep_return``
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
        - trans_fn (:obj:`Callable`): Value transform function, default to value_transform \
            (refer to rl_utils/value_rescale.py)
        - inv_trans_fn (:obj:`Callable`): Value inverse transform function, default to value_inv_transform \
            (refer to rl_utils/value_rescale.py)
    Returns:
        - loss (:obj:`torch.Tensor`): Weighted mean nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): nstep td error per sample, 1-dim tensor
    Shapes:
        - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing \
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.FloatTensor`): :math:`(B, )`, whether done in last timestep
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
    Examples:
        >>> next_q = torch.randn(4, 3)
        >>> done = torch.randn(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep = 3
        >>> q = torch.randn(4, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        >>> loss, td_error_per_sample = q_nstep_td_error_with_rescale(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, weight = data
    # Only the single-agent (B, ) action layout is supported here.
    assert len(action.shape) == 1, action.shape
    batch_range = torch.arange(action.shape[0])
    q_s_a = q[batch_range, action]
    target_q_s_a = next_n_q[batch_range, next_n_action]
    if weight is None:
        # Uniform weights in the same float dtype/device as the per-sample error (previously built from the
        # LongTensor `action`, which relied on implicit int -> float type promotion).
        weight = torch.ones_like(q_s_a)
    # Bootstrap in the un-transformed value space, then map the target back to the network's value space.
    target_q_s_a = inv_trans_fn(target_q_s_a)
    target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
    target_q_s_a = trans_fn(target_q_s_a)
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
    return (td_error_per_sample * weight).mean(), td_error_per_sample
def dqfd_nstep_td_error(
        data: namedtuple,
        gamma: float,
        lambda_n_step_td: float,
        lambda_supervised_loss: float,
        margin_function: float,
        lambda_one_step_td: float = 1.,
        nstep: int = 1,
        cum_reward: bool = False,
        value_gamma: Optional[torch.Tensor] = None,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
) -> torch.Tensor:
    """
    Overview:
        DQfD loss: multistep n-step td_error + 1-step td_error + supervised large-margin loss.
    Arguments:
        - data (:obj:`dqfd_nstep_td_data`): The input data, dqfd_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - lambda_n_step_td (:obj:`float`): Weight of the n-step td error term
        - lambda_supervised_loss (:obj:`float`): Weight of the supervised margin loss term
        - margin_function (:obj:`float`): Margin added to the q value of every action other than the taken one \
            when computing the supervised loss
        - lambda_one_step_td (:obj:`float`): Weight of the 1-step td error term, default set to 1.0
        - nstep (:obj:`int`): nstep num, default set to 1
        - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value; only used by the n-step term
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
    Returns:
        - loss (:obj:`torch.Tensor`): Weighted mean of n-step td_error + 1-step td_error + supervised margin \
            loss, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): Weighted sum of the absolute values of the three terms, \
            1-dim tensor
        - loss_statistics (:obj:`tuple`): Means of (n-step td_error, 1-step td_error, supervised margin loss)
    Shapes:
        - data (:obj:`dqfd_nstep_td_data`): the dqfd_nstep_td_data containing \
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'done_one_step', 'weight', \
            'new_n_q_one_step', 'next_n_action_one_step', 'is_expert']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.FloatTensor`): :math:`(B, )`, whether done in last timestep
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
        - new_n_q_one_step (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - next_n_action_one_step (:obj:`torch.LongTensor`): :math:`(B, )`
        - is_expert (:obj:`torch.Tensor`): :math:`(B, )`, 1 for expert transitions and 0 for agent ones
    Examples:
        >>> next_q = torch.randn(4, 3)
        >>> done = torch.randn(4)
        >>> done_1 = torch.randn(4)
        >>> next_q_one_step = torch.randn(4, 3)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> next_action_one_step = torch.randint(0, 3, size=(4, ))
        >>> is_expert = torch.ones((4))
        >>> nstep = 3
        >>> q = torch.randn(4, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> data = dqfd_nstep_td_data(
        >>>     q, next_q, action, next_action, reward, done, done_1, None,
        >>>     next_q_one_step, next_action_one_step, is_expert
        >>> )
        >>> loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error(
        >>>     data, 0.95, lambda_n_step_td=1, lambda_supervised_loss=1,
        >>>     margin_function=0.8, nstep=nstep
        >>> )
    """
    q, next_n_q, action, next_n_action, reward, done, done_one_step, weight, new_n_q_one_step, \
        next_n_action_one_step, is_expert = data  # is_expert flag: expert 1, agent 0
    assert len(action.shape) == 1, action.shape
    batch_range = torch.arange(action.shape[0])
    q_s_a = q[batch_range, action]
    target_q_s_a = next_n_q[batch_range, next_n_action]
    target_q_s_a_one_step = new_n_q_one_step[batch_range, next_n_action_one_step]
    if weight is None:
        # Uniform weights in the float dtype/device of the per-sample losses.
        weight = torch.ones_like(q_s_a)

    # n-step TD loss.
    if cum_reward:
        if value_gamma is None:
            target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
        else:
            target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
    else:
        target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())

    # 1-step TD loss: uses only the first-step reward, bootstraps from new_n_q_one_step with done_one_step,
    # and never applies value_gamma (the discount is plain gamma ** 1).
    # NOTE(review): with cum_reward=True, `reward` is presumably the accumulated (B, ) n-step reward, in which
    # case reward[0] selects batch element 0 rather than a per-sample one-step reward. Behavior kept as-is;
    # confirm the intended semantics.
    reward_one_step = reward[0].unsqueeze(0)
    if cum_reward:
        target_q_s_a_one_step = reward_one_step + gamma * target_q_s_a_one_step * (1 - done_one_step)
    else:
        target_q_s_a_one_step = nstep_return(
            nstep_return_data(reward_one_step, target_q_s_a_one_step, done_one_step), gamma, 1, None
        )
    td_error_one_step_per_sample = criterion(q_s_a, target_q_s_a_one_step.detach())

    # Supervised large-margin loss: max_a [Q(s, a) + l(a_taken, a)] - Q(s, a_taken), where l is
    # `margin_function` for every action except the taken one (0). Only counted for expert samples.
    # Built directly on q's device (no CPU round trip / legacy torch.LongTensor constructor).
    margin = margin_function * torch.ones_like(q)  # q shape (B, A), action shape (B, )
    margin.scatter_(1, action.long().unsqueeze(1), 0.)
    supervised_loss = is_expert * ((q + margin).max(dim=1)[0] - q_s_a)

    total_per_sample = (
        lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample +
        lambda_supervised_loss * supervised_loss
    )
    # Non-negative combination, returned as the per-sample td error.
    priority = (
        lambda_n_step_td * td_error_per_sample.abs() + lambda_one_step_td * td_error_one_step_per_sample.abs() +
        lambda_supervised_loss * supervised_loss.abs()
    )
    return (
        (total_per_sample * weight).mean(),
        priority,
        (td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), supervised_loss.mean()),
    )
def dqfd_nstep_td_error_with_rescale(
data: namedtuple,
gamma: float,
lambda_n_step_td: float,
lambda_supervised_loss: float,
lambda_one_step_td: float,
margin_function: float,
nstep: int = 1,
cum_reward: bool = False,
value_gamma: Optional[torch.Tensor] = None,
criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
trans_fn: Callable = value_transform,
inv_trans_fn: Callable = value_inv_transform,
) -> torch.Tensor:
"""