// cudadecoder/cuda-decoder-kernels.cu
//
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
// Hugo Braun, Justin Luitjens, Ryan Leary
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cub/cub.cuh>
#include "cuda-decoder-kernels.h"
#include "cuda-decoder-kernels-utils.h"
namespace kaldi {
namespace cuda_decoder {
// Initialize the hashmap with NO_VAL
// Called in InitDeviceData, when building the CudaDecoder object
__global__ void init_hashmap_kernel(DeviceParams cst_dev_params) {
const int max_nlanes = cst_dev_params.max_nlanes;
KALDI_CUDA_DECODER_BATCH_KERNEL_LOOP(ilane, max_nlanes) {
const int capacity = cst_dev_params.hashmap_capacity;
KALDI_CUDA_DECODER_1D_KERNEL_LOOP(idx, capacity) {
cst_dev_params.d_hashmap_values.lane(ilane)[idx] =
KALDI_CUDA_DECODER_HASHMAP_NO_VAL;
}
}
}
// Initializes the initial channel on the device
// Called by ComputeInitialChannel
// It is NOT called in InitDecoding:
// in InitDecoding we clone this initial channel into the channel we called
// InitDecoding on.
// Here we actually create that initial channel;
// we do that once, in the CudaDecoder constructor.
//
// The initial channel is the state of a channel right before
// it starts decoding a new utterance
// Threads (1, 1, 1)
// Blocks (1, 1, 1)
__global__ void initialize_initial_lane_kernel(DeviceParams cst_dev_params) {
const int init_ichannel = cst_dev_params.init_channel_id;
const int init_ilane = 0;
ChannelCounters *init_channel_counters =
cst_dev_params.d_channels_counters.channel(init_ichannel);
LaneCounters *lane_counters =
cst_dev_params.d_lanes_counters.lane(init_ilane);
  // Making the data look like an ExpandArcsEmitting just executed,
  // and putting the StartState in the aux_q. We will then pick up a normal
  // execution from there
  // (calling PruneAndPreprocess, then ExpandArcsNonEmitting, ...)
lane_counters->aux_q_end = 0;
lane_counters->aux_q_requested = 0;
lane_counters->post_expand_aux_q_end = 1;
lane_counters->main_q_global_offset = 0;
lane_counters->main_q_local_offset = 0;
lane_counters->main_q_n_extra_prev_tokens = 0;
lane_counters->int_cutoff = INT_MAX;
lane_counters->main_q_n_emitting_tokens = 0; // all non emitting
lane_counters->int_beam = floatToOrderedInt(cst_dev_params.default_beam);
lane_counters->main_q_narcs_and_end = {0, 0};
lane_counters->main_q_requested = 0;
lane_counters->prev_arg_min_int_cost = 0;
const StateId init_state = cst_dev_params.init_state;
const CostType init_cost = cst_dev_params.init_cost;
IntegerCostType int_init_cost = floatToOrderedInt(init_cost);
cst_dev_params.d_aux_q_state_and_cost.lane(init_ilane)[0] = {init_state,
int_init_cost};
lane_counters->min_int_cost = int_init_cost;
CostType cutoff = orderedIntToFloat(int_init_cost);
lane_counters->int_cutoff =
floatToOrderedInt(cutoff + cst_dev_params.default_beam);
cst_dev_params.d_aux_q_info.lane(init_ilane)[0] = {INT_MIN, -1};
}
// Called by InitDecoding
// Called when some channels are about to start decoding a new utterance.
// Does everything needed on the device to start decoding a new utterance
// with those channels:
// it clones the initial channel (created in initialize_initial_lane_kernel)
// into the channels we want to call InitDecoding on
__global__ void init_decoding_on_device_kernel(DeviceParams cst_dev_params,
KernelParams params) {
const int init_ichannel = cst_dev_params.init_channel_id;
const ChannelCounters *init_channel_counters =
cst_dev_params.d_channels_counters.channel(init_ichannel);
const int32 init_main_q_end =
init_channel_counters->prev_main_q_narcs_and_end.y;
const int32 nlanes = params.nlanes_used;
KALDI_CUDA_DECODER_BATCH_KERNEL_LOOP(ilane, nlanes) {
KALDI_CUDA_DECODER_1D_KERNEL_LOOP(idx, init_main_q_end) {
const LaneCounters *lane_counters =
cst_dev_params.d_lanes_counters.lane(ilane);
const int32 ichannel = lane_counters->channel_to_compute;
cst_dev_params.d_main_q_state_and_cost.channel(ichannel)[idx] =
cst_dev_params.d_main_q_state_and_cost.channel(init_ichannel)[idx];
cst_dev_params.d_main_q_degrees_prefix_sum.channel(ichannel)[idx] =
cst_dev_params.d_main_q_degrees_prefix_sum.channel(
init_ichannel)[idx];
cst_dev_params.d_main_q_arc_offsets.channel(ichannel)[idx] =
cst_dev_params.d_main_q_arc_offsets.channel(init_ichannel)[idx];
if (idx == 0) {
ChannelCounters *channel_counters =
cst_dev_params.d_channels_counters.channel(ichannel);
channel_counters->prev_main_q_narcs_and_end =
init_channel_counters->prev_main_q_narcs_and_end;
channel_counters->prev_main_q_n_extra_prev_tokens =
init_channel_counters->prev_main_q_n_extra_prev_tokens;
channel_counters->prev_main_q_global_offset = 0;
channel_counters->prev_main_q_extra_prev_tokens_global_offset = 0;
channel_counters->prev_beam = cst_dev_params.default_beam;
}
}
}
}
// Context switch : load
// Called by LoadChannelsStateToLanes
// THREADS : (1, 1, 1)
// BLOCKS : (1, nlanes_used, 1)
__global__ void load_channels_state_in_lanes_kernel(DeviceParams cst_dev_params,
KernelParams params) {
const int nlanes = params.nlanes_used;
KALDI_CUDA_DECODER_BATCH_KERNEL_LOOP(ilane, nlanes) {
LaneCounters *lane_counters = cst_dev_params.d_lanes_counters.lane(ilane);
const int32 ichannel = lane_counters->channel_to_compute;
const ChannelCounters *channel_counters =
cst_dev_params.d_channels_counters.channel(ichannel);
int2 main_q_narcs_and_end = channel_counters->prev_main_q_narcs_and_end;
lane_counters->main_q_narcs_and_end = main_q_narcs_and_end;
lane_counters->main_q_n_extra_prev_tokens =
channel_counters->prev_main_q_n_extra_prev_tokens;
CostType beam = channel_counters->prev_beam;
IntegerCostType int_beam = floatToOrderedInt(beam);
lane_counters->int_beam = int_beam;
lane_counters->adaptive_int_beam_with_validity_index.x = int_beam;
lane_counters->adaptive_int_beam_with_validity_index.y =
cst_dev_params.adaptive_beam_static_segment;
lane_counters->main_q_global_offset =
channel_counters
->prev_main_q_global_offset; // we'll update it after emitting
lane_counters->main_q_extra_prev_tokens_global_offset =
channel_counters->prev_main_q_extra_prev_tokens_global_offset;
}
}
// Context switch : store
// Called by SaveChannelsStateFromLanes
// THREADS : (1, 1, 1)
// BLOCKS : (1, nchannel_to_compute, 1)
__global__ void save_channels_state_from_lanes_kernel(
DeviceParams cst_dev_params, KernelParams params) {
const int nlanes = params.nlanes_used;
KALDI_CUDA_DECODER_BATCH_KERNEL_LOOP(ilane, nlanes) {
const LaneCounters *lane_counters =
cst_dev_params.d_lanes_counters.lane(ilane);
const int32 ichannel = lane_counters->channel_to_compute;
ChannelCounters *channel_counters =
cst_dev_params.d_channels_counters.channel(ichannel);
channel_counters->prev_main_q_global_offset =
lane_counters->main_q_global_offset;
channel_counters->prev_main_q_extra_prev_tokens_global_offset =
lane_counters->main_q_extra_prev_tokens_global_offset;
channel_counters->prev_main_q_narcs_and_end =
lane_counters->main_q_narcs_and_end;
channel_counters->prev_main_q_n_extra_prev_tokens =
lane_counters->main_q_n_extra_prev_tokens;
channel_counters->prev_beam = orderedIntToFloat(lane_counters->int_beam);
}
}
// compute_lane_offsets_kernel
// The kernel concatenate_lanes_data concatenates multiple arrays into a
// single contiguous array.
// compute_lane_offsets_kernel computes the offset of each per-lane array
// inside that contiguous array.
// This kernel is 1D : the lanes are on the X dimension, because we want to
// compute the offsets of those lanes
__global__ void compute_lane_offsets_kernel(DeviceParams cst_dev_params,
KernelParams params) {
typedef cub::BlockScan<int4, KALDI_CUDA_DECODER_1D_BLOCK> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
const int nlanes = params.nlanes_used;
int4 sum_so_far = {0, 0, 0, 0};
KALDI_CUDA_DECODER_1D_BLOCK_OFFSET_KERNEL_LOOP(
block_offset, thread_idx,
nlanes + 1) { // +1 because we are doing an exclusive sum, and we want
// all the values
int32 ilane = block_offset + thread_idx;
int4 zero4 = {0, 0, 0, 0};
int4 lane_offsets = zero4;
if (ilane < nlanes) { // nlanes, not nlanes+1, because we cannot read +1
// values (undefined)
LaneCounters *d_lane_counters =
cst_dev_params.d_lanes_counters.lane(ilane);
int32 main_q_end = d_lane_counters->main_q_narcs_and_end.y;
int32 n_emitting_tokens = d_lane_counters->main_q_n_emitting_tokens;
int32 main_q_n_extra_prev_tokens =
d_lane_counters->main_q_n_extra_prev_tokens;
lane_offsets = {main_q_end, n_emitting_tokens, main_q_n_extra_prev_tokens,
0};
}
int4 block_aggregate;
BlockScan(temp_storage)
.ExclusiveScan(lane_offsets, lane_offsets, zero4, PlusPlusPlusPlus(),
block_aggregate);
PlusPlusPlusPlus pppp;
lane_offsets = pppp(lane_offsets, sum_so_far);
sum_so_far = pppp(sum_so_far, block_aggregate);
if (ilane < (nlanes + 1)) { // nlanes+1, to write the output
LaneCounters *d_lane_counters =
cst_dev_params.d_lanes_counters.lane(ilane);
LaneCounters *h_lane_counters =
cst_dev_params.h_lanes_counters.lane(ilane);
h_lane_counters->main_q_end_lane_offset =
d_lane_counters->main_q_end_lane_offset = lane_offsets.x;
h_lane_counters->main_q_n_emitting_tokens_lane_offset =
d_lane_counters->main_q_n_emitting_tokens_lane_offset =
lane_offsets.y;
h_lane_counters->main_q_n_extra_prev_tokens_lane_offset =
d_lane_counters->main_q_n_extra_prev_tokens_lane_offset =
lane_offsets.z;
}
__syncthreads(); // reusing temp_storage
}
}
// concatenate_lanes_data
// Called by PerformConcatenatedCopy
// Builds a concatenated array in concat,
// by concatenating all the arrays src.lane(ilane)
// for ilane=0..params.nlanes_used
// Used to prepare data for the copy to host. We want to avoid small
// Device2Host copies.
template <typename T>
__global__ void concatenate_lanes_data_kernel(DeviceParams cst_dev_params,
KernelParams params,
LaneMatrixView<T> src, T *concat,
int32 *lane_offsets) {
const int nlanes = params.nlanes_used;
KALDI_CUDA_DECODER_BATCH_KERNEL_LOOP(ilane, nlanes) {
const int32 stride =
sizeof(LaneCounters) / sizeof(int32); // offsets are in LaneCounters
int32 beg = *(lane_offsets + ilane * stride);
int32 end = *(lane_offsets + (ilane + 1) * stride);
int32 vec_size = end - beg;
KALDI_CUDA_DECODER_1D_KERNEL_LOOP(idx, vec_size) {
T d = src.lane(ilane)[idx];
concat[beg + idx] = d;
}
}
}
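// Illustrative sketch (not part of the decoder): why we concatenate before
// copying. Once compute_lane_offsets_kernel and the kernel above have built a
// single contiguous buffer, the host only needs one large Device2Host copy
// instead of nlanes small ones. toy_copy_concatenated_to_host and its
// parameters are hypothetical names used only for this example.
template <typename T>
void toy_copy_concatenated_to_host(const T *d_concat, T *h_concat,
                                   int32 total_elements, cudaStream_t st) {
  // One large asynchronous copy of the already-concatenated buffer...
  cudaMemcpyAsync(h_concat, d_concat, total_elements * sizeof(T),
                  cudaMemcpyDeviceToHost, st);
  // ...instead of one small cudaMemcpyAsync per lane, which would pay the
  // per-copy launch latency nlanes times.
}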
// nonemitting_preprocess_and_contract_kernel
// Called from PruneAndPreprocess
// This kernel prunes the aux_q, moves the surviving tokens to the main_q,
// and adds the preprocessing information necessary for the next ExpandArcs
// (the expand that follows PruneAndPreprocess is always non-emitting)
// It prunes the tokens using the cutoff, and prepares the data necessary for
// ExpandArcs:
// d_main_q_degrees_prefix_sum, d_main_q_arc_offsets_
// The prefix sum is done in one pass here, using a trick (we compute the
// prefix sum as we fill the main_q).
// An illustrative sketch of that trick follows this kernel.
__global__ void nonemitting_preprocess_and_contract_kernel(
DeviceParams cst_dev_params, KernelParams params) {
typedef cub::BlockScan<int2, KALDI_CUDA_DECODER_1D_BLOCK> BlockScan;
__shared__ typename BlockScan::TempStorage sh_temp_storage;
  // We need to move the surviving tokens to the main_q
//
// sh_main_q_global_block_offset has two purposes :
// (1) to know where to store the survival tokens in the main_q
// (2) to perform the prefix sum degrees (of the survival tokens)
__shared__ int2 sh_main_q_global_block_offset;
const int nlanes = params.nlanes_used;
KALDI_CUDA_DECODER_BATCH_KERNEL_LOOP(ilane, nlanes) {
LaneCounters *lane_counters = cst_dev_params.d_lanes_counters.lane(ilane);
const int32 aux_q_end = lane_counters->post_expand_aux_q_end;
const IntegerCostType int_cutoff = lane_counters->int_cutoff;
// Keeping whole CTA alive. We'll use __syncthreads()
KALDI_CUDA_DECODER_1D_BLOCK_OFFSET_KERNEL_LOOP(block_offset, thread_idx,
aux_q_end) {
const int32 aux_q_idx = block_offset + thread_idx;
const int32 ichannel = lane_counters->channel_to_compute;
int32 degree = 0;
int32 arc_start = -1;
StateId token_state;
IntegerCostType token_int_cost;
      // We've kept the whole CTA alive. Now we keep only the threads with a
      // valid token
if (aux_q_idx < aux_q_end) {
const int2 both =
cst_dev_params.d_aux_q_state_and_cost.lane(ilane)[aux_q_idx];
token_state = both.x;
token_int_cost = both.y;
if (token_int_cost < int_cutoff) {
// We'll keep that token. Loading its arc degree/csr offset now.
arc_start = cst_dev_params.d_arc_ne_offsets[token_state];
const int32 arc_end =
cst_dev_params.d_arc_ne_offsets[token_state + 1];
degree = arc_end - arc_start;
}
}
// If we've set a different arc_start,
// this thread has a valid unpruned token
int32 is_pruned = (arc_start == -1);
// We now know which tokens will be moved to the main_q, the remaining
// will be pruned
// we now compute a prefix sum inside the CUDA block to determine the
// local indexes of the unpruned tokens
      // the first unpruned token will have an index of 0, the second 1, ...
// We also need to compute the prefix sum of the arc degrees
// we start by doing a local prefix sum inside the CUDA block
int2 block_prefix_sum_narcs_and_end = {degree, (is_pruned ? 0 : 1)};
const int2 zero2 = {0, 0};
// Computing the prefix sum (exclusive)
BlockScan(sh_temp_storage)
.ExclusiveScan(block_prefix_sum_narcs_and_end,
block_prefix_sum_narcs_and_end, zero2, PlusPlus());
if (KALDI_CUDA_DECODER_IS_LAST_1D_THREAD()) {
// This conditional branch is entered by the last thread
// Because it is the last, the prefix_sum of that thread contains the
// sum of all elements
// We also add the value from this thread - the prefix sum is exclusive
// For the sum, we want it inclusive
int2 block_sum = block_prefix_sum_narcs_and_end;
block_sum.x += degree;
block_sum.y += is_pruned ? 0 : 1;
        // Doing two things at the same time :
        // requesting a spot in the main_q to store the surviving tokens from
        // this CTA, and incrementing the narcs value. atomic64.x will contain
        // the number of arcs in the main_q up until the atomic64.y index.
        // That's all we need to finish our prefix sum: we add this global
        // offset.
        // First atomic to check that we are not overflowing the main_q.
int block_offset =
atomicAdd(&lane_counters->main_q_requested, block_sum.y);
// Verify that we do not overflow
if (block_offset + block_sum.y < cst_dev_params.main_q_capacity) {
          // we don't overflow; we can safely grab a spot in the main_q
sh_main_q_global_block_offset =
atomicAddI2(&lane_counters->main_q_narcs_and_end, block_sum);
} else {
// our update would overflow
lane_counters->q_overflow |= OVERFLOW_MAIN_Q; // for the host
sh_main_q_global_block_offset.y =
cst_dev_params.main_q_capacity; // used as flag to broadcast the
// information in the CTA
}
}
// Syncing because :
// - Broadcasting sh_main_q_global_block_offset
// - We may reuse sh_temp_storage (cf CUB doc)
__syncthreads();
// Checking if we are overflowing the main_q
// All threads are executing the next line
if (sh_main_q_global_block_offset.y == cst_dev_params.main_q_capacity)
goto end_lane; // done for this lane
// If we are executing the following lines it means that we are not
// overflowing the queue
// We then continue what we were doing
if (!is_pruned) {
bool moving_emitting_tokens = (lane_counters->main_q_local_offset == 0);
// we will move our unpruned token to the main_q, at index main_q_idx
InfoToken tok_info = cst_dev_params.d_aux_q_info.lane(ilane)[aux_q_idx];
const int32 main_q_idx =
sh_main_q_global_block_offset.y + block_prefix_sum_narcs_and_end.y;
CostType acoustic_cost = 0.0f;
if (moving_emitting_tokens && tok_info.arc_idx != -1) {
const int32 arc_ilabel =
cst_dev_params.d_arc_pdf_ilabels[tok_info.arc_idx];
acoustic_cost = -lane_counters->loglikelihoods[arc_ilabel];
}
cst_dev_params.d_main_q_info.lane(ilane)[main_q_idx] = tok_info;
// Moving the token to the main q
cst_dev_params.d_main_q_state_and_cost.channel(ichannel)[main_q_idx] = {
token_state, token_int_cost};
cst_dev_params.d_main_q_acoustic_cost.lane(ilane)[main_q_idx] =
acoustic_cost;
// Saving the global prefix sum
const int32 prefix_sum_narcs =
sh_main_q_global_block_offset.x + block_prefix_sum_narcs_and_end.x;
cst_dev_params.d_main_q_degrees_prefix_sum.channel(
ichannel)[main_q_idx] = prefix_sum_narcs;
        // Saving the CSR arc offset for that token's state
        // it will be used by the expand kernel, and avoids doing a new random
        // memory access in the expand kernel
cst_dev_params.d_main_q_arc_offsets.channel(ichannel)[main_q_idx] =
arc_start;
}
}
end_lane:; // empty statement
}
}
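// Illustrative sketch (not used by the decoder): the "contract + prefix sum
// in one pass" trick mentioned above, reduced to its core. Each CTA scans its
// own (degree, kept) pairs with CUB, then a single atomicAddI2 (the same
// helper used by the kernel above) reserves both the output slots and the
// global arc count, which completes the global prefix sum. The names
// ToyInt2Sum, toy_contract_with_prefix_sum_kernel, d_degrees and
// d_narcs_and_end are hypothetical and exist only for this example.
struct ToyInt2Sum {
  __device__ int2 operator()(const int2 &a, const int2 &b) const {
    return make_int2(a.x + b.x, a.y + b.y);
  }
};
template <int BLOCK_SIZE>
__global__ void toy_contract_with_prefix_sum_kernel(
    const int32 *d_degrees, int32 n, int32 *d_degrees_prefix_sum_out,
    int2 *d_narcs_and_end /* single int2, initialized to {0, 0} */) {
  typedef cub::BlockScan<int2, BLOCK_SIZE> BlockScan;
  __shared__ typename BlockScan::TempStorage sh_temp_storage;
  __shared__ int2 sh_block_offset;
  const int32 idx = blockIdx.x * BLOCK_SIZE + threadIdx.x;
  const int32 degree = (idx < n) ? d_degrees[idx] : 0;
  const int32 kept = (degree > 0) ? 1 : 0;  // stands in for "survives pruning"
  int2 local_prefix = {degree, kept};
  int2 block_sum;
  const int2 zero2 = {0, 0};
  // Exclusive scan inside the CTA: local_prefix.x = arcs before this thread,
  // local_prefix.y = kept tokens before this thread. block_sum is the CTA
  // total.
  BlockScan(sh_temp_storage)
      .ExclusiveScan(local_prefix, local_prefix, zero2, ToyInt2Sum(),
                     block_sum);
  if (threadIdx.x == 0) {
    // One atomic reserves both the output slots (.y) and the "arcs before
    // this CTA" count (.x), keeping the two offsets consistent.
    sh_block_offset = atomicAddI2(d_narcs_and_end, block_sum);
  }
  __syncthreads();  // broadcast sh_block_offset
  if (kept) {
    const int32 out_idx = sh_block_offset.y + local_prefix.y;
    // Global exclusive prefix sum of degrees, obtained in the same pass as
    // the compaction
    d_degrees_prefix_sum_out[out_idx] = sh_block_offset.x + local_prefix.x;
  }
}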
// UpdateAdaptiveBeam is used in ExpandArcs
// When we generate new tokens by traversing arcs,
// we can end up creating a lot of tokens, for instance if the loglikelihoods
// generated for the current frame are too uniform (we don't have
// any good tokens that would reduce the cutoff, so we end up generating
// a lot of tokens)
// To avoid overflowing the aux_q, we apply a decreasing beam.
// With aux_q_end being the current aux_q size, we have a decreasing function f, with
// adaptive_beam = f(aux_q_end)
// f is a decreasing piecewise constant function
// Please note that when processing tokens, we usually have tens of thousands of threads
// generating tokens. Those are already in flight, and will not reload the beam immediately.
// It means that we need to start reducing the beam as soon as we detect that we are
// generating more tokens than expected.
// We can configure the function f using KALDI_CUDA_DECODER_ADAPTIVE_BEAM_STATIC_SEGMENT
// and KALDI_CUDA_DECODER_ADAPTIVE_BEAM_NSTEPS.
// We will use default_beam for the first max_tokens_per_frame/KALDI_CUDA_DECODER_ADAPTIVE_BEAM_STATIC_SEGMENT
// tokens in the aux_q.
// Once we reach that number, we decrease the adaptive beam step by step from default_beam towards 0,
// using KALDI_CUDA_DECODER_ADAPTIVE_BEAM_NSTEPS steps
//
// x-axis : aux_q_end. How many tokens are already in the aux_q
// y-axis : adaptive_beam = f(aux_q_end)
//
// default_beam _| ________________
//               |                 |________
//               |                          |________
//            0 _|                                   |________
//               |_________________________________________________
//               |                 |                               |
//     aux_q_end= 0          static_segment            max_tokens_per_frame
// We have :
// static_segment = max_tokens_per_frame/KALDI_CUDA_DECODER_ADAPTIVE_BEAM_STATIC_SEGMENT
// and KALDI_CUDA_DECODER_ADAPTIVE_BEAM_NSTEPS = 3
__device__ void UpdateAdaptiveBeam(const DeviceParams &cst_dev_params,
const int aux_q_index_block_offset,
IntegerCostType min_int_cost,
int2 *adaptive_int_beam_with_validity_index,
LaneCounters *lane_counters) {
int32 beam_valid_until_idx = adaptive_int_beam_with_validity_index->y;
if (aux_q_index_block_offset < beam_valid_until_idx) return; // nothing to do
CostType beam = orderedIntToFloat(adaptive_int_beam_with_validity_index->x);
while (aux_q_index_block_offset >= beam_valid_until_idx) {
beam /= 2;
beam_valid_until_idx += cst_dev_params.adaptive_beam_bin_width;
}
IntegerCostType new_int_cutoff = (min_int_cost < INT_MAX)
? floatToOrderedInt(orderedIntToFloat(min_int_cost) + beam)
: INT_MAX;
IntegerCostType int_beam = floatToOrderedInt(beam);
adaptive_int_beam_with_validity_index->x = int_beam;
adaptive_int_beam_with_validity_index->y = beam_valid_until_idx;
  // We can have races between the two atomics.
  // However, the worst that can happen is that a CTA delays updating the
  // beam. This is not a critical bug. However, once we have a
  // floatToOrderedInt that generates unsigned ints, we could merge the two
  // atomics into a single atomic64.
atomicMin(&lane_counters->adaptive_int_beam_with_validity_index.x, int_beam);
atomicMax(&lane_counters->adaptive_int_beam_with_validity_index.y,
beam_valid_until_idx);
atomicMin(&lane_counters->int_cutoff, new_int_cutoff);
}
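// Illustrative sketch (not called by the decoder): a standalone version of
// the decay f(aux_q_end) described above, written as a pure function only to
// make its shape explicit. It assumes the same static segment / bin width
// parameters that UpdateAdaptiveBeam reads from cst_dev_params, and, like
// UpdateAdaptiveBeam, halves the beam for each bin past the static segment.
// The name toy_adaptive_beam is hypothetical.
__device__ __forceinline__ CostType toy_adaptive_beam(CostType default_beam,
                                                      int32 aux_q_end,
                                                      int32 static_segment,
                                                      int32 bin_width) {
  CostType beam = default_beam;
  // Flat segment: keep default_beam for the first static_segment tokens,
  // then apply one step of the piecewise-constant decay per bin
  for (int32 valid_until = static_segment; aux_q_end >= valid_until;
       valid_until += bin_width)
    beam /= 2;
  return beam;
}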
// One CTA / lane
__global__ void reset_for_frame_and_estimate_cutoff_kernel(
DeviceParams cst_dev_params, KernelParams params) {
typedef cub::BlockReduce<CostType, KALDI_CUDA_DECODER_1D_BLOCK> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
const int nlanes = params.nlanes_used;
KALDI_CUDA_DECODER_BATCH_KERNEL_LOOP(ilane, nlanes) {
LaneCounters *lane_counters = cst_dev_params.d_lanes_counters.lane(ilane);
const int32 ichannel = lane_counters->channel_to_compute;
ChannelCounters *channel_counters =
cst_dev_params.d_channels_counters.channel(ichannel);
if (threadIdx.x == 0) {
const CostType current_beam = orderedIntToFloat(lane_counters->int_beam);
// Do some initialization
lane_counters->q_overflow = OVERFLOW_NONE;
lane_counters->main_q_n_emitting_tokens = INT_MAX;
lane_counters->int_cutoff = INT_MAX;
lane_counters->min_int_cost = INT_MAX;
lane_counters->q_overflow = OVERFLOW_NONE;
lane_counters->int_relative_cost = INT_MAX;
lane_counters->aux_q_requested = 0;
lane_counters->main_q_requested = 0;
lane_counters->main_q_local_offset = 0;
lane_counters->compute_max_active =
false; // will be set to true if necessary
channel_counters->min_int_cost_and_arg_with_final.x =
INT_MAX; // it will be set with atomicMins
const CostType new_beam =
fmin(cst_dev_params.default_beam,
current_beam * KALDI_CUDA_DECODER_ADAPTIVE_BEAM_RECOVER_RATE);
lane_counters->int_beam = floatToOrderedInt(new_beam);
}
const int32 prev_arg_min = lane_counters->prev_arg_min_int_cost;
int2 both =
cst_dev_params.d_main_q_state_and_cost.channel(ichannel)[prev_arg_min];
int32 int_cost = both.y;
CostType previous_cost = orderedIntToFloat(int_cost);
const int32 prev_arg_min_state = both.x;
int32 arc_start = cst_dev_params.d_arc_e_offsets[prev_arg_min_state];
int32 arc_end = cst_dev_params.d_arc_e_offsets[prev_arg_min_state + 1];
int32 narcs = arc_end - arc_start;
// no loop - we only process the first KALDI_CUDA_DECODER_1D_BLOCK arcs
// we just want an estimate
CostType total_cost = FLT_MAX;
if (threadIdx.x < narcs) {
int32 iarc = arc_start + threadIdx.x;
CostType arc_fixed_cost = cst_dev_params.d_arc_weights[iarc];
const int32 arc_ilabel = cst_dev_params.d_arc_pdf_ilabels[iarc];
CostType acoustic_cost = -lane_counters->loglikelihoods[arc_ilabel];
total_cost = previous_cost + arc_fixed_cost +
acoustic_cost; // +0.0f, best prev cost is normalized to 0
}
KALDI_CUDA_DECODER_1D_KERNEL_LOOP(bin_id, KALDI_CUDA_DECODER_HISTO_NBINS) {
cst_dev_params.d_histograms.lane(ilane)[bin_id] = 0; // reset for this frame
}
CostType min = BlockReduce(temp_storage).Reduce(total_cost, cub::Min());
if (narcs > 0 && threadIdx.x == 0) {
// narcs > 0 to have at least one valid element in the reduce
CostType new_cutoff = min + orderedIntToFloat(lane_counters->int_beam);
IntegerCostType new_int_cutoff = floatToOrderedInt(new_cutoff);
lane_counters->int_cutoff = new_int_cutoff;
lane_counters->min_int_cost = floatToOrderedInt(min);
}
}
}
// ExpandArc kernel
// This kernel does the actual work of traversing arcs
//
// Pseudo code :
// for all tokens tok in main_q[main_q_offset...end]:
//     u = tok.next_state
//     for all arcs a(u->v) leaving u in the FST:
//         v_cost = tok.cost + a.cost + acoustic_cost
//
//         if v_cost < cutoff and v_cost < best_state_cost[v]:
//             generate the token associated to v, add it to the aux_q
//             if necessary update the cutoff
//             if the aux_q is getting full, reduce the beam
// For more information please refer to http://kaldi-asr.org/doc/decoders.html
//
// ExpandArc relies on some preprocessed data to be able to function :
// for instance, it needs the prefix sum of the arc degrees of all token.state
// in the main_q
// We need to call a Preprocess kernel before ExpandArc
//
// ExpandArc is used for both emitting and nonemitting phases
// Differences between emitting and nonemitting :
// 1) params.d_q_arc_offset contains offsets to either emitting or
// nonemitting arcs.
// It is transparent for this kernel. The differentiation was done in
// the Preprocess kernel,
// which is responsible for filling the params.d_q_arc_offset array
// 2) Computation of the acoustic cost. If nonemitting, it is equal to 0.
// If emitting, we need
// to use values from the acoustic model (through the d_loglikelihoods
// array)
//
// Note : ExpandArc is not the only kernel able to traverse arcs.
// FinalizeProcessNonemitting contains a simplified version of expand for only
// one CUDA block
template <bool IS_EMITTING>
__global__ void expand_arcs_kernel(DeviceParams cst_dev_params,
KernelParams params) {
// BlockScan that we will use to compute token indexes in the output queue,
// and to find the min cost in the block
typedef cub::BlockScan<int2, KALDI_CUDA_DECODER_1D_BLOCK> BlockScan;
__shared__ typename BlockScan::TempStorage sh_temp_storage_scan;
// This kernel writes the new token to the output queue aux_q
// We will request a spot to store all the new tokens created by threads in
// this CUDA block
// sh_aux_q_index_block_offset indicates where to store them in the aux_q
  // tokens created in this CUDA block will be stored in :
  // aux_q[sh_aux_q_index_block_offset], aux_q[sh_aux_q_index_block_offset + 1], ...
__shared__ int32 sh_aux_q_index_block_offset;
const int nlanes = params.nlanes_used;
KALDI_CUDA_DECODER_BATCH_KERNEL_LOOP(ilane, nlanes) {
LaneCounters *lane_counters = cst_dev_params.d_lanes_counters.lane(ilane);
const int32 main_q_offset = lane_counters->main_q_local_offset;
const int32 main_q_end = lane_counters->main_q_narcs_and_end.y;
const int32 total_narcs = lane_counters->main_q_narcs_and_end.x;
KALDI_CUDA_DECODER_1D_BLOCK_OFFSET_KERNEL_LOOP(block_offset, thread_idx,
total_narcs) {
int2 adaptive_int_beam_with_validity_index =
lane_counters->adaptive_int_beam_with_validity_index;
const int32 ichannel = lane_counters->channel_to_compute;
      // Important : this thread is not responsible for a token in the input
      // queue main_q,
      // but for an arc going out of a token in the main_q.
      // The main_q contains total_narcs arcs in total,
      // and this thread will compute the main_q_arc_index-th arc of the main_q
      // For instance, the first thread in the grid, with threadIdx.x == 0 and
      // blockIdx.x == 0,
      // will process the first arc of the token in main_q[main_q_offset + 0]
      // (if that token has at least one arc)
      //
      // This ensures a perfect one thread = one arc load balancing,
      // but we have work to do to know exactly which arc is the
      // main_q_arc_index-th arc
      // (what's its source ? its destination ? its arc_idx in the FST CSR ?)
int32 main_q_arc_index = block_offset + thread_idx;
// We'll need those variables later in the kernel
// we declare them outside of the "valid_input" scope
// to be able to access them later
int32 main_q_idx;
int32 arc_idx;
StateId arc_next_state;
IntegerCostType int_total_cost = INT_MAX;
if (main_q_arc_index < total_narcs) {
        // Current thread must take care of the main_q_arc_index-th arc
        // we need to know what the source of that arc is
        // ie which token.state in main_q does it start from ?
        // We use a binary search in the prefix sum of the token degrees to
        // get that information
//
// Example : main_q contains 3 tokens
        // - First token is associated to a state which has 3 outgoing arcs
        // - Second token is associated to a state which has 0 outgoing arcs
        // - Third token is associated to a state which has 2 outgoing arcs
//
// We store the degrees in an array :
// [3, 0, 2]
//
// We then compute the exclusive prefix sum of that array :
// [0, 3, 3, 5]
//
// In total, we have 5 arcs in the main_q. ExpandArc will use 5 threads.
//
// Let's say we are the fifth thread in ExpandArc.
// we have threadIdx.x == 4, and blockIdx.x == 0
// it gives us main_q_arc_index == 4
// From there we have no idea what we're supposed to do next, we need to
// have information about the
// arc that we're supposed to traverse
//
        // To do that, we look for the maximum index maxle_i in the prefix sum
        // array such that prefix_sum[maxle_i] <= 4
        //
        // [0, 3, 3, 5]
        //        ^
        //        here
        // maxle_i = 2
        // it means that our source token is at index 2 in the main_q
        // and we are computing the arc at index (main_q_arc_index -
        // prefix_sum[maxle_i]) of that token
        // ie the arc at index (4-3) = 1, the second arc of the third token
        // in main_q
// Searching for the source of the arc that we will process
// (main_q_arc_index)
// we could preprocess the search in the preprocess kernels - for now
// this kernel is fast enough
const int32 *degrees_prefix_sum =
cst_dev_params.d_main_q_degrees_prefix_sum.channel(ichannel);
main_q_idx = binsearch_maxle(degrees_prefix_sum, main_q_arc_index,
main_q_offset, main_q_end - 1);
// state_first_arc_idx_in_main_q
// d_main_q_degrees_prefix_sum contains the prefix sum of the
// degrees of all tokens in the main_q
        // d_main_q_degrees_prefix_sum[main_q_idx] contains the number of arcs
        // in the main_q until that token
const int32 state_first_arc_idx_in_main_q =
degrees_prefix_sum[main_q_idx];
// arc_offset_start is the offset in the CSR, to find the arcs
// related to the state main_q_state_[main_q_idx]
// it was set by the preprocess kernel
const int32 arc_offset_start =
cst_dev_params.d_main_q_arc_offsets.channel(ichannel)[main_q_idx];
// local_arc_index is the arc index for that state
// if local_arc_index == 2, we will process the second arc
// of state main_q_state_[main_q_idx]
const int32 local_arc_index =
main_q_arc_index - state_first_arc_idx_in_main_q;
// corresponding arc_idx in the FST
arc_idx = arc_offset_start + local_arc_index;
// Destination of that arc
arc_next_state = cst_dev_params.d_arc_nextstates[arc_idx];
// Building the total cost incrementally
// we'll add the acoustic cost and the old token's cost
const CostType arc_fixed_cost = cst_dev_params.d_arc_weights[arc_idx];
const CostType prev_token_cost = orderedIntToFloat(
cst_dev_params.d_main_q_state_and_cost.channel(ichannel)[main_q_idx]
.y);
CostType total_cost = prev_token_cost + arc_fixed_cost;
const int32 prev_state =
cst_dev_params.d_main_q_state_and_cost.channel(ichannel)[main_q_idx]
.x;
if (IS_EMITTING) {
const int32 arc_ilabel = cst_dev_params.d_arc_pdf_ilabels[arc_idx];
CostType acoustic_cost = -lane_counters->loglikelihoods[arc_ilabel];
total_cost += acoustic_cost;
}
int_total_cost = floatToOrderedInt(total_cost);
// If the total_cost is too large compared to our cutoff (beam search)
// then let's drop it
const IntegerCostType int_cutoff = lane_counters->int_cutoff;
if (int_total_cost >= int_cutoff) int_total_cost = INT_MAX;
}
// If int_total_cost < INT_MAX, it means that :
// - this thread had a valid input (main_q_arc_index < total_narcs)
// - the total_cost of the generated token is < cutoff
// We will then add that new token in the output queue, aux_q
// We need to know where to put that token in the aux_q
// we'll first compute its index inside the CUDA block
// the first valid output token in the CUDA block will have index 0,
// the second index 1... We compute that using a prefix sum
//
// We also need to find the overall min cost in the CUDA block
// a prefix sum is a scan operation, and a min a reduce operation
// we can perform a reduce operation using a scan (using the last value)
      // we compute the prefix sum and the min in one scan, using an int2
      // and the MinPlus operator
const int32 has_successor = (int_total_cost < INT_MAX) ? 1 : 0;
int2 int_cost_and_index = {int_total_cost, has_successor};
BlockScan(sh_temp_storage_scan)
.InclusiveScan(int_cost_and_index, int_cost_and_index, MinPlus());
if (KALDI_CUDA_DECODER_IS_LAST_1D_THREAD()) {
// We are in a divergent branch
// This is the last thread. The last value of the inclusive scan is the
// total
const int32 total_successors_in_block = int_cost_and_index.y;
// Requesting a spot of size total_successors_in_block in the aux_q
// note: using 2 atomics here to avoid adding another kernel
// first request more space
const int aux_q_index_block_offset = atomicAdd(
&lane_counters->aux_q_requested, total_successors_in_block);
// check for overflow in aux_q
        // We try to prevent an overflow from happening using an adaptive beam
        // (cf UpdateAdaptiveBeam)
if (aux_q_index_block_offset + total_successors_in_block <
cst_dev_params.aux_q_capacity) {
// no overflow
// grab the aux_q offset
sh_aux_q_index_block_offset =
atomicAdd(&lane_counters->aux_q_end, total_successors_in_block);
// We are not overflowing the queue, updating the global values
IntegerCostType global_min_int_cost = lane_counters->min_int_cost;
IntegerCostType local_min_int_cost = int_cost_and_index.x;
// if we found a lower min_cost, update the global value
if (local_min_int_cost < global_min_int_cost) {
global_min_int_cost = local_min_int_cost;
atomicMin(&lane_counters->min_int_cost, global_min_int_cost);
CostType beam =
orderedIntToFloat(adaptive_int_beam_with_validity_index.x);
IntegerCostType new_int_cutoff = floatToOrderedInt(
orderedIntToFloat(local_min_int_cost) + beam);
atomicMin(&lane_counters->int_cutoff, new_int_cutoff);
}
int32 beam_valid_until_idx =
adaptive_int_beam_with_validity_index.y;
if (aux_q_index_block_offset >= beam_valid_until_idx) {
// This beam is no longer valid. Updating it
UpdateAdaptiveBeam(
cst_dev_params, aux_q_index_block_offset, global_min_int_cost,
&adaptive_int_beam_with_validity_index, lane_counters);
}
} else {
// sh_aux_q_index_block_offset is in shared memory
// its value is currently invalid (overflow)
// we set it to a special value and use it as a flag to broadcast
// the fact that we have an overflow and that all threads should exit
sh_aux_q_index_block_offset = cst_dev_params.aux_q_capacity;
// Setting the flag for the host. It will be used to print a warning
// to stderr
lane_counters->q_overflow |= OVERFLOW_AUX_Q;
// We do not jump to end_lane now, because only
// the first thread (threadIdx.x == 0) is executing this
// We wait until the end of the divergent branch
}
}
// Sync'ing for two reasons :
// - Broadcasting sh_aux_q_index_block_offset
// - reusing sh_temp_storage (cf CUB's doc)
__syncthreads();
      // The only case where that condition can be met
      // is if we detected an overflow in the previous lines
if (sh_aux_q_index_block_offset == cst_dev_params.aux_q_capacity)
goto end_lane; // done for this lane
//
// If we're executing the following lines it means everything
// is valid and we are not overflowing the aux_q
//
int_cost_and_index.y -= has_successor; // we want the exclusive sum now
const int32 aux_q_block_index = int_cost_and_index.y;
const int32 aux_q_index = sh_aux_q_index_block_offset + aux_q_block_index;
if (has_successor) {
// We save the new token to the aux_q
cst_dev_params.d_aux_q_state_and_cost.lane(ilane)[aux_q_index] = {
arc_next_state, int_total_cost};
// Index of the parent token
// the parent is the token used as input (source of arc)
// that parent is at index main_q_idx in the GPU memory
        // However, the main_q is emptied before processing a new frame,
        // so we need to add the offset related to the previous frames :
        // we add lane_counters->main_q_global_offset
const int32 prev_token =
lane_counters->main_q_global_offset + main_q_idx;
assert(main_q_idx >= 0 && main_q_idx < cst_dev_params.main_q_capacity);
cst_dev_params.d_aux_q_info.lane(ilane)[aux_q_index] = {prev_token,
arc_idx};
}
}
end_lane:; // ";" is an empty statement
}
}
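// Illustrative sketch (not used by the decoder): the "maximum index i such
// that prefix_sum[i] <= value" binary search that the comments in
// expand_arcs_kernel describe. The decoder's actual binsearch_maxle comes
// from cuda-decoder-kernels-utils.h; this hypothetical toy_binsearch_maxle
// only restates the idea. With prefix_sum = [0, 3, 3, 5] and value = 4 it
// returns 2 : the arc with global index 4 belongs to the token at main_q
// index 2.
__device__ __forceinline__ int32 toy_binsearch_maxle(const int32 *prefix_sum,
                                                     int32 value, int32 low,
                                                     int32 high) {
  // Assumes prefix_sum[low] <= value, as in the kernel above
  while (low < high) {
    const int32 mid = low + (high - low + 1) / 2;  // upper mid avoids a loop
    if (prefix_sum[mid] <= value)
      low = mid;  // mid still satisfies the predicate, search to the right
    else
      high = mid - 1;  // mid is too far, search to the left
  }
  return low;
}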
// post_expand_kernel
// Called after expand_arcs_kernel
// Takes care of what needs to be done after an expand_arcs_kernel
// execution. Mostly resetting the beam (if adaptive beam was triggered,
// the max_active_ kernels will take care of selecting a good beam),
// resetting the number of arcs in the main_q (we've processed them all),
// etc.
// Threads (1,1,1)
// Blocks (1, nlanes_used, 1)
template <bool IS_EMITTING>
__global__ void post_expand_kernel(DeviceParams cst_dev_params,
KernelParams params) {
const int nlanes = params.nlanes_used;
KALDI_CUDA_DECODER_BATCH_KERNEL_LOOP(ilane, nlanes) {
LaneCounters *lane_counters = cst_dev_params.d_lanes_counters.lane(ilane);
LaneCounters *h_lane_counters = cst_dev_params.h_lanes_counters.lane(ilane);
const int prev_main_q_end = lane_counters->main_q_narcs_and_end.y;
const int prev_n_extra_prev_tokens =
lane_counters->main_q_n_extra_prev_tokens;
const int aux_q_end = lane_counters->aux_q_end;
CostType min_cost = orderedIntToFloat(lane_counters->min_int_cost);
// The next step is the contracting step from aux_q to main_q
// It will need the aux_q_end value. But it will also empty the aux_q
// We're resetting aux_q_end to 0 now, but we're saving its old value
// in another place
lane_counters->post_expand_aux_q_end = aux_q_end;
h_lane_counters->post_expand_aux_q_end = aux_q_end; // pinned memory
h_lane_counters->q_overflow = lane_counters->q_overflow; // pinned memory
lane_counters->aux_q_end = 0;
lane_counters->aux_q_requested = 0;
// We are done processing those arcs
lane_counters->main_q_narcs_and_end.x = 0;
// Resetting the adaptive beam
lane_counters->adaptive_int_beam_with_validity_index.x =
lane_counters->int_beam;
lane_counters->adaptive_int_beam_with_validity_index.y =
cst_dev_params.adaptive_beam_static_segment;
CostType beam = orderedIntToFloat(lane_counters->int_beam);
lane_counters->int_cutoff = floatToOrderedInt(min_cost + beam);
// If the adaptive beam kicked in, we want to reset the beam
// the max-active process will take care of selecting the right beam
if (IS_EMITTING) {
// the main_q contains the tokens from the previous frame
// after emitting, we won't use them anymore to create new tokens
// we reset the main_q
lane_counters->main_q_narcs_and_end = {0, 0};
lane_counters->main_q_requested = 0;
// The main_q was flushed - we need to update the global_offset
lane_counters->main_q_global_offset += prev_main_q_end;
if (threadIdx.x == 0 && blockIdx.x == 0)
lane_counters->main_q_extra_prev_tokens_global_offset +=
prev_n_extra_prev_tokens;
      // Moving local offset. Tokens created by the last expand
      // will be pruned, and survivors will be moved to the end
      // of the main_q. Those tokens will be placed after local_offset
lane_counters->main_q_requested = 0;
CostType min_cost = orderedIntToFloat(lane_counters->min_int_cost);
lane_counters->min_histo_cost = min_cost;
lane_counters->max_histo_cost = min_cost + beam;
lane_counters->histo_bin_width = beam / (KALDI_CUDA_DECODER_HISTO_NBINS-1);
} else {
lane_counters->main_q_local_offset = prev_main_q_end;
// reset requested to end of queue
lane_counters->main_q_requested = prev_main_q_end;
}
}
}
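// post_contract_and_preprocess_kernel
// Runs after the contraction into the main_q (as the name suggests).
// Copies the (narcs, main_q_end) pair and the overflow flag to pinned host
// memory, and lowers main_q_n_emitting_tokens to the current main_q end
// with an atomicMin.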
__global__ void post_contract_and_preprocess_kernel(DeviceParams cst_dev_params,
KernelParams params) {
const int nlanes = params.nlanes_used;
KALDI_CUDA_DECODER_BATCH_KERNEL_LOOP(ilane, nlanes) {
LaneCounters *lane_counters = cst_dev_params.d_lanes_counters.lane(ilane);
LaneCounters *h_lane_counters = cst_dev_params.h_lanes_counters.lane(ilane);
int2 main_q_narcs_and_end = lane_counters->main_q_narcs_and_end;
h_lane_counters->main_q_narcs_and_end =
main_q_narcs_and_end; // pinned memory
h_lane_counters->q_overflow = lane_counters->q_overflow; // pinned memory
atomicMin(&lane_counters->main_q_n_emitting_tokens, main_q_narcs_and_end.y);
}
}
// Meta-kernel (merging preprocess and expand), but only works with 1 CUDA
// block per lane
// Used to avoid calling multiple main kernels (such as expand_arcs_kernel)
// for the non-emitting tail (lots of iterations with a small number of arcs)
//
// Code is greatly simplified because we use only one CTA / lane
//
// Repeat until new queue empty:
// 1) Preprocess
// 2) Expand arcs
//
// The preprocess stage is not done on the first iteration, because it was
// already done by the ProcessAndContract kernel. We always call
// PruneAndPreprocess before calling FinalizeProcessNonemitting
//
// At the end, this kernel finalizes the computation for the current frame,
// so that it's ready for the next ProcessEmitting
//
// This kernel works, but can be greatly simplified now.
__launch_bounds__(KALDI_CUDA_DECODER_LARGEST_1D_BLOCK, 1) __global__
void finalize_process_non_emitting_kernel(DeviceParams cst_dev_params,
KernelParams params) {
typedef cub::BlockScan<int2, KALDI_CUDA_DECODER_LARGEST_1D_BLOCK>
Int2BlockScan;
typedef cub::BlockScan<int, KALDI_CUDA_DECODER_LARGEST_1D_BLOCK> IntBlockScan;
__shared__ typename IntBlockScan::TempStorage sh_temp_storage_int_scan;
__shared__ typename Int2BlockScan::TempStorage sh_temp_storage_int2_scan;
const int nlanes = params.nlanes_used;
KALDI_CUDA_DECODER_BATCH_KERNEL_LOOP(ilane, nlanes) {
LaneCounters *lane_counters = cst_dev_params.d_lanes_counters.lane(ilane);
const int32 ichannel = lane_counters->channel_to_compute;
ChannelCounters *channel_counters =
cst_dev_params.d_channels_counters.channel(ichannel);
int2 both = lane_counters->main_q_narcs_and_end;
int32 main_q_narcs = both.x;
int32 main_q_end = both.y;
int32 main_q_local_offset = lane_counters->main_q_local_offset;
const int32 main_q_global_offset = lane_counters->main_q_global_offset;
// aux_q is empty when this kernel is called
int32 aux_q_end = 0;
IntegerCostType int_cutoff = lane_counters->int_cutoff;
while (main_q_narcs > 0) {
// Step 1 : ExpandArcs
KALDI_CUDA_DECODER_1D_BLOCK_OFFSET_KERNEL_LOOP(offset, thread_idx,
main_q_narcs) {
const int32 main_q_arc_idx = offset + thread_idx;
// For details on how this code works, please refer to comments in
// expand_arcs
IntegerCostType total_int_cost = INT_MAX;
int32 arc_idx;
StateId arc_next_state;
int32 main_q_idx;
if (main_q_arc_idx < main_q_narcs) {
main_q_idx = binsearch_maxle(
cst_dev_params.d_main_q_degrees_prefix_sum.channel(ichannel),
main_q_arc_idx, main_q_local_offset, main_q_end - 1);
const int32 state_first_arc_idx_in_main_q =
cst_dev_params.d_main_q_degrees_prefix_sum.channel(
ichannel)[main_q_idx];
const int32 arc_offset_start =
cst_dev_params.d_main_q_arc_offsets.channel(ichannel)[main_q_idx];
arc_idx = arc_offset_start +
(main_q_arc_idx - state_first_arc_idx_in_main_q);