From 3599596ba245813d658f697a98656975b69d78f5 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Thu, 8 Feb 2024 17:29:54 +0100 Subject: [PATCH] ack model changes --- .../mdl-en-2019-Q3-librispeech/expected | 171 +++++++++--------- 1 file changed, 82 insertions(+), 89 deletions(-) diff --git a/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected b/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected index b8d110a84c..d396503a7f 100644 --- a/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected +++ b/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected @@ -26,20 +26,20 @@ fragment scan_body_0( ) -> (i"fastlstm1.c_new": tensor, i"fastlstm1.r_new": tensor, i"fastlstm1.h_new.W.split-over-1.128..256.prop_axis.a.input_0": tensor) { i"fastlstm1.peephole0.mul" = mul(i"fastlstm1.peephole0.mul.fix-rank-0-1", i"fastlstm1.c"); - i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.0..256" = tract_core_einsum([i"fastlstm1.r", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"], expr = "ka,kn->bn", acc = "f32", output = ""); + i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.0..256" = tract_core_einsum([i"fastlstm1.r", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"], expr = "mka,kn->bn", acc = "f32", output = ""); i"fastlstm1.four_parts.W.concat-einsum-k.add-1.split-over-1.0..256" = add(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.0..256"); i"fastlstm1.four_parts.split-over-1.0..256" = add(i"fastlstm1.four_parts.W.concat-einsum-k.add-1.split-over-1.0..256", i"fastlstm1.four_parts.split-1-over-1.0..256.slice"); i"fastlstm1.peephole0.output" = add(i"fastlstm1.peephole0.mul", i"fastlstm1.four_parts.split-over-1.0..256"); i"fastlstm1.peephole0.output.nolin" = sigmoid(i"fastlstm1.peephole0.output"); i"fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.output" = squeeze(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", axes = [0]); - i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.512..768" = tract_core_einsum([i"fastlstm1.r", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"], expr = "ka,kn->bn", acc = "f32", output = ""); + i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.512..768" = tract_core_einsum([i"fastlstm1.r", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"], expr = "mka,kn->bn", acc = "f32", output = ""); i"fastlstm1.four_parts.W.concat-einsum-k.add-1.split-over-1.512..768" = add(i"fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.output", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.512..768"); i"fastlstm1.four_parts.split-over-1.512..768" = add(i"fastlstm1.four_parts.W.concat-einsum-k.add-1.split-over-1.512..768", i"fastlstm1.four_parts.split-1-over-1.512..768.slice"); i"fastlstm1.four_parts.j.nolin" = tanh(i"fastlstm1.four_parts.split-over-1.512..768"); i"fastlstm1.c_update" = mul(i"fastlstm1.peephole0.output.nolin", i"fastlstm1.four_parts.j.nolin"); i"fastlstm1.peephole1.mul" = mul(i"fastlstm1.peephole1.mul.fix-rank-0-1", i"fastlstm1.c"); i"fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.output" = squeeze(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", axes = [0]); - i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.256..512" = tract_core_einsum([i"fastlstm1.r", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"], expr = "ka,kn->bn", acc = "f32", output = ""); + i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.256..512" = tract_core_einsum([i"fastlstm1.r", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"], expr = "mka,kn->bn", acc = "f32", output = ""); i"fastlstm1.four_parts.W.concat-einsum-k.add-1.split-over-1.256..512" = add(i"fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.output", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.256..512"); i"fastlstm1.four_parts.split-over-1.256..512" = add(i"fastlstm1.four_parts.W.concat-einsum-k.add-1.split-over-1.256..512", i"fastlstm1.four_parts.split-1-over-1.256..512.slice"); i"fastlstm1.peephole1.output" = add(i"fastlstm1.peephole1.mul", i"fastlstm1.four_parts.split-over-1.256..512"); @@ -49,13 +49,13 @@ fragment scan_body_0( i"fastlstm1.tanh_c" = tanh(i"fastlstm1.c_new"); i"fastlstm1.peephole2.mul" = mul(i"fastlstm1.peephole2.mul.fix-rank-0-1", i"fastlstm1.c_new"); i"fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.output" = squeeze(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", axes = [0]); - i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.768..1024" = tract_core_einsum([i"fastlstm1.r", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"], expr = "ka,kn->bn", acc = "f32", output = ""); + i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.768..1024" = tract_core_einsum([i"fastlstm1.r", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"], expr = "mka,kn->bn", acc = "f32", output = ""); i"fastlstm1.four_parts.W.concat-einsum-k.add-1.split-over-1.768..1024" = add(i"fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.output", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.768..1024"); i"fastlstm1.four_parts.split-over-1.768..1024" = add(i"fastlstm1.four_parts.W.concat-einsum-k.add-1.split-over-1.768..1024", i"fastlstm1.four_parts.split-1-over-1.768..1024.slice"); i"fastlstm1.peephole2.output" = add(i"fastlstm1.peephole2.mul", i"fastlstm1.four_parts.split-over-1.768..1024"); i"fastlstm1.peephole2.output.nolin" = sigmoid(i"fastlstm1.peephole2.output"); i"fastlstm1.m" = mul(i"fastlstm1.tanh_c", i"fastlstm1.peephole2.output.nolin"); - i"fastlstm1.h_new.W.split-over-1.0..128" = tract_core_einsum([i"fastlstm1.m", i"fastlstm1.h_new.W.split-1-over-1.0..128.slice"], expr = "bk,kn->na", acc = "f32", output = ""); + i"fastlstm1.h_new.W.split-over-1.0..128" = tract_core_einsum([i"fastlstm1.m", i"fastlstm1.h_new.W.split-1-over-1.0..128.slice"], expr = "bk,kn->mna", acc = "f32", output = ""); i"fastlstm1.h_new.split-over-1.0..128" = add(i"fastlstm1.h_new.W.split-over-1.0..128", i"fastlstm1.h_new.split-1-over-1.0..128.slice"); i"fastlstm1.h_new.W.split-over-1.128..256.prop_axis.a.input_0" = unsqueeze(i"fastlstm1.m", axes = [0]); i"fastlstm1.r_new" = i"fastlstm1.h_new.split-over-1.0..128"; @@ -84,20 +84,20 @@ fragment scan_body_1( ) -> (i"fastlstm2.c_new": tensor, i"fastlstm2.r_new": tensor, i"fastlstm2.h_new.W.split-over-1.128..256.prop_axis.a.input_0": tensor) { i"fastlstm2.peephole0.mul" = mul(i"fastlstm2.peephole0.mul.fix-rank-0-1", i"fastlstm2.c"); - i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.0..256" = tract_core_einsum([i"fastlstm2.r", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"], expr = "ak,kn->bn", acc = "f32", output = ""); + i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.0..256" = tract_core_einsum([i"fastlstm2.r", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"], expr = "mka,kn->bn", acc = "f32", output = ""); i"fastlstm2.four_parts.W.concat-einsum-k.add-1.split-over-1.0..256" = add(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.0..256"); i"fastlstm2.four_parts.split-over-1.0..256" = add(i"fastlstm2.four_parts.W.concat-einsum-k.add-1.split-over-1.0..256", i"fastlstm2.four_parts.split-1-over-1.0..256.slice"); i"fastlstm2.peephole0.output" = add(i"fastlstm2.peephole0.mul", i"fastlstm2.four_parts.split-over-1.0..256"); i"fastlstm2.peephole0.output.nolin" = sigmoid(i"fastlstm2.peephole0.output"); i"fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.output" = squeeze(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", axes = [0]); - i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.512..768" = tract_core_einsum([i"fastlstm2.r", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"], expr = "ak,kn->bn", acc = "f32", output = ""); + i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.512..768" = tract_core_einsum([i"fastlstm2.r", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"], expr = "mka,kn->bn", acc = "f32", output = ""); i"fastlstm2.four_parts.W.concat-einsum-k.add-1.split-over-1.512..768" = add(i"fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.output", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.512..768"); i"fastlstm2.four_parts.split-over-1.512..768" = add(i"fastlstm2.four_parts.W.concat-einsum-k.add-1.split-over-1.512..768", i"fastlstm2.four_parts.split-1-over-1.512..768.slice"); i"fastlstm2.four_parts.j.nolin" = tanh(i"fastlstm2.four_parts.split-over-1.512..768"); i"fastlstm2.c_update" = mul(i"fastlstm2.peephole0.output.nolin", i"fastlstm2.four_parts.j.nolin"); i"fastlstm2.peephole1.mul" = mul(i"fastlstm2.peephole1.mul.fix-rank-0-1", i"fastlstm2.c"); i"fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.output" = squeeze(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", axes = [0]); - i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.256..512" = tract_core_einsum([i"fastlstm2.r", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"], expr = "ak,kn->bn", acc = "f32", output = ""); + i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.256..512" = tract_core_einsum([i"fastlstm2.r", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"], expr = "mka,kn->bn", acc = "f32", output = ""); i"fastlstm2.four_parts.W.concat-einsum-k.add-1.split-over-1.256..512" = add(i"fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.output", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.256..512"); i"fastlstm2.four_parts.split-over-1.256..512" = add(i"fastlstm2.four_parts.W.concat-einsum-k.add-1.split-over-1.256..512", i"fastlstm2.four_parts.split-1-over-1.256..512.slice"); i"fastlstm2.peephole1.output" = add(i"fastlstm2.peephole1.mul", i"fastlstm2.four_parts.split-over-1.256..512"); @@ -107,13 +107,13 @@ fragment scan_body_1( i"fastlstm2.tanh_c" = tanh(i"fastlstm2.c_new"); i"fastlstm2.peephole2.mul" = mul(i"fastlstm2.peephole2.mul.fix-rank-0-1", i"fastlstm2.c_new"); i"fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.output" = squeeze(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", axes = [0]); - i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.768..1024" = tract_core_einsum([i"fastlstm2.r", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"], expr = "ak,kn->bn", acc = "f32", output = ""); + i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.768..1024" = tract_core_einsum([i"fastlstm2.r", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"], expr = "mka,kn->bn", acc = "f32", output = ""); i"fastlstm2.four_parts.W.concat-einsum-k.add-1.split-over-1.768..1024" = add(i"fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.output", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.768..1024"); i"fastlstm2.four_parts.split-over-1.768..1024" = add(i"fastlstm2.four_parts.W.concat-einsum-k.add-1.split-over-1.768..1024", i"fastlstm2.four_parts.split-1-over-1.768..1024.slice"); i"fastlstm2.peephole2.output" = add(i"fastlstm2.peephole2.mul", i"fastlstm2.four_parts.split-over-1.768..1024"); i"fastlstm2.peephole2.output.nolin" = sigmoid(i"fastlstm2.peephole2.output"); i"fastlstm2.m" = mul(i"fastlstm2.tanh_c", i"fastlstm2.peephole2.output.nolin"); - i"fastlstm2.h_new.W.split-over-1.0..128" = tract_core_einsum([i"fastlstm2.m", i"fastlstm2.h_new.W.split-1-over-1.0..128.slice"], expr = "bk,kn->an", acc = "f32", output = ""); + i"fastlstm2.h_new.W.split-over-1.0..128" = tract_core_einsum([i"fastlstm2.m", i"fastlstm2.h_new.W.split-1-over-1.0..128.slice"], expr = "bk,kn->mna", acc = "f32", output = ""); i"fastlstm2.h_new.split-over-1.0..128" = add(i"fastlstm2.h_new.W.split-over-1.0..128", i"fastlstm2.h_new.split-1-over-1.0..128.slice"); i"fastlstm2.h_new.W.split-over-1.128..256.prop_axis.a.input_0" = unsqueeze(i"fastlstm2.m", axes = [0]); i"fastlstm2.r_new" = i"fastlstm2.h_new.split-over-1.0..128"; @@ -135,63 +135,60 @@ graph network(input) -> (output) { i"lda.output_conv" = conv(i"lda.output_input", i"lda.kernel.0", i"lda.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]); i"lda.output" = transpose(i"lda.output_conv", axes = [0, 2, 1]); i"lda.output.rm_n" = squeeze(i"lda.output", axes = [0]); - i"tdnn1.affine.output.filters_as_co_ci" = variable(label = "tdnn1.affine.output.filters_as_co_ci", shape = [256, 200]); - i"tdnn1.affine.output.einsum" = matmul(i"tdnn1.affine.output.filters_as_co_ci", i"lda.output.rm_n", transposeA = false, transposeB = true); - i"tdnn1.affine.output.bias.reshape" = variable(label = "tdnn1.affine.output.bias.reshape", shape = [256, 1]); + i"tdnn1.affine.output.einsum.fix_a.0" = unsqueeze(i"lda.output.rm_n", axes = [0]); + i"tdnn1.affine.output.einsum.fix_b.0" = variable(label = "tdnn1.affine.output.einsum.fix_b.0", shape = [1, 256, 200]); + i"tdnn1.affine.output.einsum" = matmul(i"tdnn1.affine.output.einsum.fix_b.0", i"tdnn1.affine.output.einsum.fix_a.0", transposeA = false, transposeB = true); + i"tdnn1.affine.output.bias.reshape" = variable(label = "tdnn1.affine.output.bias.reshape", shape = [1, 256, 1]); i"tdnn1.affine.output" = add(i"tdnn1.affine.output.einsum", i"tdnn1.affine.output.bias.reshape"); - i"tdnn1.relu.output.low.cst" = [[0.0]]; + i"tdnn1.relu.output.low.cst" = [[[0.0]]]; i"tdnn1.relu.output.low" = max(i"tdnn1.affine.output", i"tdnn1.relu.output.low.cst"); i"tdnn1.renorm.reduced.sq" = square(i"tdnn1.relu.output.low"); - i"tdnn1.renorm.reduced.sum" = sum_reduce(i"tdnn1.renorm.reduced.sq", axes = [0]); - i"tdnn1.renorm.scaled-recip" = [[0.00390625]]; + i"tdnn1.renorm.reduced.sum" = sum_reduce(i"tdnn1.renorm.reduced.sq", axes = [1]); + i"tdnn1.renorm.scaled-recip" = [[[0.00390625]]]; i"tdnn1.renorm.scaled" = mul(i"tdnn1.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip"); i"tdnn1.renorm.output-recip" = rsqrt(i"tdnn1.renorm.scaled"); i"tdnn1.renorm.output" = mul(i"tdnn1.relu.output.low", i"tdnn1.renorm.output-recip"); - i"tdnn2.affine.output.delay" = tract_pulse_delay(i"tdnn1.renorm.output", axis = 1, delay = 0, overlap = 2); - i"tdnn2.affine.output.add_n" = unsqueeze(i"tdnn2.affine.output.delay", axes = [0]); + i"tdnn2.affine.output.delay" = tract_pulse_delay(i"tdnn1.renorm.output", axis = 2, delay = 0, overlap = 2); i"tdnn2.affine.kernel.0" = variable(label = "tdnn2.affine.kernel.0", shape = [256, 256, 3]); i"tdnn2.affine.bias.0" = variable(label = "tdnn2.affine.bias.0", shape = [256]); - i"tdnn2.affine.output_conv" = conv(i"tdnn2.affine.output.add_n", i"tdnn2.affine.kernel.0", i"tdnn2.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]); + i"tdnn2.affine.output_conv" = conv(i"tdnn2.affine.output.delay", i"tdnn2.affine.kernel.0", i"tdnn2.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]); i"tdnn2.affine.output" = i"tdnn2.affine.output_conv"; - i"tdnn2.affine.output.rm_n" = squeeze(i"tdnn2.affine.output", axes = [0]); - i"tdnn2.relu.output.low.cst" = [[0.0]]; - i"tdnn2.relu.output.low" = max(i"tdnn2.affine.output.rm_n", i"tdnn2.relu.output.low.cst"); + i"tdnn2.relu.output.low" = max(i"tdnn2.affine.output", i"tdnn1.relu.output.low.cst"); i"tdnn2.renorm.reduced.sq" = square(i"tdnn2.relu.output.low"); - i"tdnn2.renorm.reduced.sum" = sum_reduce(i"tdnn2.renorm.reduced.sq", axes = [0]); - i"tdnn2.renorm.scaled-recip" = [[0.00390625]]; - i"tdnn2.renorm.scaled" = mul(i"tdnn2.renorm.reduced.sum", i"tdnn2.renorm.scaled-recip"); + i"tdnn2.renorm.reduced.sum" = sum_reduce(i"tdnn2.renorm.reduced.sq", axes = [1]); + i"tdnn2.renorm.scaled" = mul(i"tdnn2.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip"); i"tdnn2.renorm.output-recip" = rsqrt(i"tdnn2.renorm.scaled"); i"tdnn2.renorm.output" = mul(i"tdnn2.relu.output.low", i"tdnn2.renorm.output-recip"); - i"tdnn3.affine.output.add_n" = unsqueeze(i"tdnn2.renorm.output", axes = [0]); i"tdnn3.affine.kernel.0" = variable(label = "tdnn3.affine.kernel.0", shape = [256, 256, 3]); i"tdnn3.affine.bias.0" = variable(label = "tdnn3.affine.bias.0", shape = [256]); - i"tdnn3.affine.output_conv" = conv(i"tdnn3.affine.output.add_n", i"tdnn3.affine.kernel.0", i"tdnn3.affine.bias.0", dilation = [1], stride = [3], border = "constant", groups = 1, padding = [(0, 0)]); + i"tdnn3.affine.output_conv" = conv(i"tdnn2.renorm.output", i"tdnn3.affine.kernel.0", i"tdnn3.affine.bias.0", dilation = [1], stride = [3], border = "constant", groups = 1, padding = [(0, 0)]); i"tdnn3.affine.output" = i"tdnn3.affine.output_conv"; - i"tdnn3.affine.output.rm_n" = squeeze(i"tdnn3.affine.output", axes = [0]); - i"tdnn3.relu.output.low.cst" = [[0.0]]; - i"tdnn3.relu.output.low" = max(i"tdnn3.affine.output.rm_n", i"tdnn3.relu.output.low.cst"); + i"tdnn3.relu.output.low" = max(i"tdnn3.affine.output", i"tdnn1.relu.output.low.cst"); i"tdnn3.renorm.reduced.sq" = square(i"tdnn3.relu.output.low"); - i"tdnn3.renorm.reduced.sum" = sum_reduce(i"tdnn3.renorm.reduced.sq", axes = [0]); - i"tdnn3.renorm.scaled-recip" = [[0.00390625]]; - i"tdnn3.renorm.scaled" = mul(i"tdnn3.renorm.reduced.sum", i"tdnn3.renorm.scaled-recip"); + i"tdnn3.renorm.reduced.sum" = sum_reduce(i"tdnn3.renorm.reduced.sq", axes = [1]); + i"tdnn3.renorm.scaled" = mul(i"tdnn3.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip"); i"tdnn3.renorm.output-recip" = rsqrt(i"tdnn3.renorm.scaled"); i"tdnn3.renorm.output" = mul(i"tdnn3.relu.output.low", i"tdnn3.renorm.output-recip"); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.prop_axis.a.input_1" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.prop_axis.a.input_1", shape = [256, 256]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256" = matmul(i"tdnn3.renorm.output", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.prop_axis.a.input_1", transposeA = true, transposeB = false); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", shape = [1, 256, 256]); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256" = matmul(i"tdnn3.renorm.output", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", transposeA = true, transposeB = false); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_c.0" = squeeze(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", axes = [0]); i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_a.0" = unsqueeze(i"tdnn3.renorm.output", axes = [0]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b.0" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b.0", shape = [1, 256, 256]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768" = matmul(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_a.0", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b.0", transposeA = true, transposeB = false); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.0" = transpose(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", axes = [1, 0, 2]); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b.1" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b.1", shape = [1, 1, 256, 256]); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768" = matmul(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_a.0", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b.1", transposeA = true, transposeB = false); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.0" = squeeze(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", axes = [1]); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.1" = transpose(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.0", axes = [1, 0, 2]); i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_a.0" = unsqueeze(i"tdnn3.renorm.output", axes = [0]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b.0" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b.0", shape = [1, 256, 256]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512" = matmul(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_a.0", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b.0", transposeA = true, transposeB = false); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.0" = transpose(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", axes = [1, 0, 2]); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b.1" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b.1", shape = [1, 1, 256, 256]); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512" = matmul(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_a.0", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b.1", transposeA = true, transposeB = false); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.0" = squeeze(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", axes = [1]); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.1" = transpose(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.0", axes = [1, 0, 2]); i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_a.0" = unsqueeze(i"tdnn3.renorm.output", axes = [0]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b.0" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b.0", shape = [1, 256, 256]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024" = matmul(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_a.0", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b.0", transposeA = true, transposeB = false); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.0" = transpose(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", axes = [1, 0, 2]); - i"incoming-94/0" = variable(label = "incoming-94/0", shape = [1, 256]); - i"incoming-40/0" = variable(label = "incoming-40/0", shape = [128, 1]); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b.1" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b.1", shape = [1, 1, 256, 256]); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024" = matmul(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_a.0", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b.1", transposeA = true, transposeB = false); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.0" = squeeze(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", axes = [1]); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.1" = transpose(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.0", axes = [1, 0, 2]); + i"tap.tap.fastlstm1.c_init.0-35/0-82/0" = variable(label = "tap.tap.fastlstm1.c_init.0-35/0-82/0", shape = [1, 256]); + i"tap.fastlstm1.r_init.0-36/0" = variable(label = "tap.fastlstm1.r_init.0-36/0", shape = [1, 128, 1]); i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice" = variable(label = "fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", shape = [128, 256]); i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice" = variable(label = "fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", shape = [128, 256]); i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice" = variable(label = "fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", shape = [128, 256]); @@ -201,64 +198,59 @@ graph network(input) -> (output) { i"fastlstm1.four_parts.split-1-over-1.512..768.slice" = variable(label = "fastlstm1.four_parts.split-1-over-1.512..768.slice", shape = [1, 256]); i"fastlstm1.four_parts.split-1-over-1.768..1024.slice" = variable(label = "fastlstm1.four_parts.split-1-over-1.768..1024.slice", shape = [1, 256]); i"fastlstm1.h_new.W.split-1-over-1.0..128.slice" = variable(label = "fastlstm1.h_new.W.split-1-over-1.0..128.slice", shape = [256, 128]); - i"fastlstm1.h_new.split-1-over-1.0..128.slice" = variable(label = "fastlstm1.h_new.split-1-over-1.0..128.slice", shape = [128, 1]); + i"fastlstm1.h_new.split-1-over-1.0..128.slice" = variable(label = "fastlstm1.h_new.split-1-over-1.0..128.slice", shape = [1, 128, 1]); i"fastlstm1.peephole0.mul.fix-rank-0-1" = variable(label = "fastlstm1.peephole0.mul.fix-rank-0-1", shape = [1, 256]); i"fastlstm1.peephole1.mul.fix-rank-0-1" = variable(label = "fastlstm1.peephole1.mul.fix-rank-0-1", shape = [1, 256]); i"fastlstm1.peephole2.mul.fix-rank-0-1" = variable(label = "fastlstm1.peephole2.mul.fix-rank-0-1", shape = [1, 256]); - ( i"fastlstm1.c_final", i"fastlstm1.c_final_1" ) = tract_core_scan(body = "scan_body_0", scan = [("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", 0, 1), ("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.0", 0, 1), ("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.0", 0, 1), ("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.0", 0, 1)], full = [("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"), ("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"), ("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"), ("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"), ("fastlstm1.four_parts.split-1-over-1.0..256.slice", i"fastlstm1.four_parts.split-1-over-1.0..256.slice"), ("fastlstm1.four_parts.split-1-over-1.256..512.slice", i"fastlstm1.four_parts.split-1-over-1.256..512.slice"), ("fastlstm1.four_parts.split-1-over-1.512..768.slice", i"fastlstm1.four_parts.split-1-over-1.512..768.slice"), ("fastlstm1.four_parts.split-1-over-1.768..1024.slice", i"fastlstm1.four_parts.split-1-over-1.768..1024.slice"), ("fastlstm1.h_new.W.split-1-over-1.0..128.slice", i"fastlstm1.h_new.W.split-1-over-1.0..128.slice"), ("fastlstm1.h_new.split-1-over-1.0..128.slice", i"fastlstm1.h_new.split-1-over-1.0..128.slice"), ("fastlstm1.peephole0.mul.fix-rank-0-1", i"fastlstm1.peephole0.mul.fix-rank-0-1"), ("fastlstm1.peephole1.mul.fix-rank-0-1", i"fastlstm1.peephole1.mul.fix-rank-0-1"), ("fastlstm1.peephole2.mul.fix-rank-0-1", i"fastlstm1.peephole2.mul.fix-rank-0-1")], state = [("fastlstm1.c", i"incoming-94/0", "fastlstm1.c_new"), ("fastlstm1.r", i"incoming-40/0", "fastlstm1.r_new")], output = [("fastlstm1.r_new", "full", 1, 1), ("fastlstm1.h_new.W.split-over-1.128..256.prop_axis.a.input_0", "full", 0, 1)], skip = 2, reset_every_turn = false); + ( i"fastlstm1.c_final", i"fastlstm1.c_final_1" ) = tract_core_scan(body = "scan_body_0", scan = [("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_c.0", 0, 1), ("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.1", 0, 1), ("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.1", 0, 1), ("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.1", 0, 1)], full = [("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"), ("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"), ("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"), ("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"), ("fastlstm1.four_parts.split-1-over-1.0..256.slice", i"fastlstm1.four_parts.split-1-over-1.0..256.slice"), ("fastlstm1.four_parts.split-1-over-1.256..512.slice", i"fastlstm1.four_parts.split-1-over-1.256..512.slice"), ("fastlstm1.four_parts.split-1-over-1.512..768.slice", i"fastlstm1.four_parts.split-1-over-1.512..768.slice"), ("fastlstm1.four_parts.split-1-over-1.768..1024.slice", i"fastlstm1.four_parts.split-1-over-1.768..1024.slice"), ("fastlstm1.h_new.W.split-1-over-1.0..128.slice", i"fastlstm1.h_new.W.split-1-over-1.0..128.slice"), ("fastlstm1.h_new.split-1-over-1.0..128.slice", i"fastlstm1.h_new.split-1-over-1.0..128.slice"), ("fastlstm1.peephole0.mul.fix-rank-0-1", i"fastlstm1.peephole0.mul.fix-rank-0-1"), ("fastlstm1.peephole1.mul.fix-rank-0-1", i"fastlstm1.peephole1.mul.fix-rank-0-1"), ("fastlstm1.peephole2.mul.fix-rank-0-1", i"fastlstm1.peephole2.mul.fix-rank-0-1")], state = [("fastlstm1.c", i"tap.tap.fastlstm1.c_init.0-35/0-82/0", "fastlstm1.c_new"), ("fastlstm1.r", i"tap.fastlstm1.r_init.0-36/0", "fastlstm1.r_new")], output = [("fastlstm1.r_new", "full", 2, 1), ("fastlstm1.h_new.W.split-over-1.128..256.prop_axis.a.input_0", "full", 0, 1)], skip = 2, reset_every_turn = false); i"fastlstm1.h_new.W.split-over-1.128..256.fix_a.0" = transpose(i"fastlstm1.c_final_1", axes = [1, 0, 2]); - i"fastlstm1.h_new.W.split-over-1.128..256.fix_b.0" = variable(label = "fastlstm1.h_new.W.split-over-1.128..256.fix_b.0", shape = [1, 256, 128]); - i"fastlstm1.h_new.W.split-over-1.128..256" = matmul(i"fastlstm1.h_new.W.split-over-1.128..256.fix_b.0", i"fastlstm1.h_new.W.split-over-1.128..256.fix_a.0", transposeA = true, transposeB = true); - i"fastlstm1.h_new.W.split-over-1.128..256.fix_c.0" = squeeze(i"fastlstm1.h_new.W.split-over-1.128..256", axes = [0]); - i"fastlstm1.c_final.fastlstm1.h_new.split-1-over-1.128..256.slice" = variable(label = "fastlstm1.c_final.fastlstm1.h_new.split-1-over-1.128..256.slice", shape = [128, 1]); + i"fastlstm1.h_new.W.split-over-1.128..256.fix_a.1" = unsqueeze(i"fastlstm1.h_new.W.split-over-1.128..256.fix_a.0", axes = [0]); + i"fastlstm1.h_new.W.split-over-1.128..256.fix_b.1" = variable(label = "fastlstm1.h_new.W.split-over-1.128..256.fix_b.1", shape = [1, 1, 256, 128]); + i"fastlstm1.h_new.W.split-over-1.128..256" = matmul(i"fastlstm1.h_new.W.split-over-1.128..256.fix_b.1", i"fastlstm1.h_new.W.split-over-1.128..256.fix_a.1", transposeA = true, transposeB = true); + i"fastlstm1.h_new.W.split-over-1.128..256.fix_c.0" = squeeze(i"fastlstm1.h_new.W.split-over-1.128..256", axes = [1]); + i"fastlstm1.c_final.fastlstm1.h_new.split-1-over-1.128..256.slice" = variable(label = "fastlstm1.c_final.fastlstm1.h_new.split-1-over-1.128..256.slice", shape = [1, 128, 1]); i"fastlstm1.h_new.split-over-1.128..256" = add(i"fastlstm1.h_new.W.split-over-1.128..256.fix_c.0", i"fastlstm1.c_final.fastlstm1.h_new.split-1-over-1.128..256.slice"); - i"fastlstm1.h_new.concat-1" = concat([i"fastlstm1.c_final", i"fastlstm1.h_new.split-over-1.128..256"], axis = 0); - i"tdnn4.affine.output.delay" = tract_pulse_delay(i"fastlstm1.h_new.concat-1", axis = 1, delay = 0, overlap = 2); - i"tdnn4.affine.output.add_n" = unsqueeze(i"tdnn4.affine.output.delay", axes = [0]); + i"fastlstm1.h_new.concat-1" = concat([i"fastlstm1.c_final", i"fastlstm1.h_new.split-over-1.128..256"], axis = 1); + i"tdnn4.affine.output.delay" = tract_pulse_delay(i"fastlstm1.h_new.concat-1", axis = 2, delay = 0, overlap = 2); i"tdnn4.affine.kernel.0" = variable(label = "tdnn4.affine.kernel.0", shape = [256, 256, 3]); i"tdnn4.affine.bias.0" = variable(label = "tdnn4.affine.bias.0", shape = [256]); - i"tdnn4.affine.output_conv" = conv(i"tdnn4.affine.output.add_n", i"tdnn4.affine.kernel.0", i"tdnn4.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]); + i"tdnn4.affine.output_conv" = conv(i"tdnn4.affine.output.delay", i"tdnn4.affine.kernel.0", i"tdnn4.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]); i"tdnn4.affine.output" = i"tdnn4.affine.output_conv"; - i"tdnn4.affine.output.rm_n" = squeeze(i"tdnn4.affine.output", axes = [0]); - i"tdnn4.relu.output.low.cst" = [[0.0]]; - i"tdnn4.relu.output.low" = max(i"tdnn4.affine.output.rm_n", i"tdnn4.relu.output.low.cst"); + i"tdnn4.relu.output.low" = max(i"tdnn4.affine.output", i"tdnn1.relu.output.low.cst"); i"tdnn4.renorm.reduced.sq" = square(i"tdnn4.relu.output.low"); - i"tdnn4.renorm.reduced.sum" = sum_reduce(i"tdnn4.renorm.reduced.sq", axes = [0]); - i"tdnn4.renorm.scaled-recip" = [[0.00390625]]; - i"tdnn4.renorm.scaled" = mul(i"tdnn4.renorm.reduced.sum", i"tdnn4.renorm.scaled-recip"); + i"tdnn4.renorm.reduced.sum" = sum_reduce(i"tdnn4.renorm.reduced.sq", axes = [1]); + i"tdnn4.renorm.scaled" = mul(i"tdnn4.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip"); i"tdnn4.renorm.output-recip" = rsqrt(i"tdnn4.renorm.scaled"); i"tdnn4.renorm.output" = mul(i"tdnn4.relu.output.low", i"tdnn4.renorm.output-recip"); - i"tdnn5.affine.output.delay" = tract_pulse_delay(i"tdnn4.renorm.output", axis = 1, delay = 0, overlap = 2); - i"tdnn5.affine.output.add_n" = unsqueeze(i"tdnn5.affine.output.delay", axes = [0]); + i"tdnn5.affine.output.delay" = tract_pulse_delay(i"tdnn4.renorm.output", axis = 2, delay = 0, overlap = 2); i"tdnn5.affine.kernel.0" = variable(label = "tdnn5.affine.kernel.0", shape = [256, 256, 3]); i"tdnn5.affine.bias.0" = variable(label = "tdnn5.affine.bias.0", shape = [256]); - i"tdnn5.affine.output_conv" = conv(i"tdnn5.affine.output.add_n", i"tdnn5.affine.kernel.0", i"tdnn5.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]); + i"tdnn5.affine.output_conv" = conv(i"tdnn5.affine.output.delay", i"tdnn5.affine.kernel.0", i"tdnn5.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]); i"tdnn5.affine.output" = i"tdnn5.affine.output_conv"; - i"tdnn5.affine.output.rm_n" = squeeze(i"tdnn5.affine.output", axes = [0]); - i"tdnn5.relu.output.low.cst" = [[0.0]]; - i"tdnn5.relu.output.low" = max(i"tdnn5.affine.output.rm_n", i"tdnn5.relu.output.low.cst"); + i"tdnn5.relu.output.low" = max(i"tdnn5.affine.output", i"tdnn1.relu.output.low.cst"); i"tdnn5.renorm.reduced.sq" = square(i"tdnn5.relu.output.low"); - i"tdnn5.renorm.reduced.sum" = sum_reduce(i"tdnn5.renorm.reduced.sq", axes = [0]); - i"tdnn5.renorm.scaled-recip" = [[0.00390625]]; - i"tdnn5.renorm.scaled" = mul(i"tdnn5.renorm.reduced.sum", i"tdnn5.renorm.scaled-recip"); + i"tdnn5.renorm.reduced.sum" = sum_reduce(i"tdnn5.renorm.reduced.sq", axes = [1]); + i"tdnn5.renorm.scaled" = mul(i"tdnn5.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip"); i"tdnn5.renorm.output-recip" = rsqrt(i"tdnn5.renorm.scaled"); i"tdnn5.renorm.output" = mul(i"tdnn5.relu.output.low", i"tdnn5.renorm.output-recip"); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.prop_axis.a.input_1" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.prop_axis.a.input_1", shape = [256, 256]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256" = matmul(i"tdnn5.renorm.output", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.prop_axis.a.input_1", transposeA = true, transposeB = false); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", shape = [1, 256, 256]); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256" = matmul(i"tdnn5.renorm.output", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", transposeA = true, transposeB = false); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_c.0" = squeeze(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", axes = [0]); i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_a.0" = unsqueeze(i"tdnn5.renorm.output", axes = [0]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b.0" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b.0", shape = [1, 256, 256]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768" = matmul(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_a.0", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b.0", transposeA = true, transposeB = false); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.0" = transpose(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", axes = [1, 0, 2]); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b.1" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b.1", shape = [1, 1, 256, 256]); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768" = matmul(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_a.0", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b.1", transposeA = true, transposeB = false); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.0" = squeeze(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", axes = [1]); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.1" = transpose(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.0", axes = [1, 0, 2]); i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_a.0" = unsqueeze(i"tdnn5.renorm.output", axes = [0]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b.0" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b.0", shape = [1, 256, 256]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512" = matmul(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_a.0", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b.0", transposeA = true, transposeB = false); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.0" = transpose(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", axes = [1, 0, 2]); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b.1" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b.1", shape = [1, 1, 256, 256]); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512" = matmul(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_a.0", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b.1", transposeA = true, transposeB = false); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.0" = squeeze(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", axes = [1]); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.1" = transpose(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.0", axes = [1, 0, 2]); i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_a.0" = unsqueeze(i"tdnn5.renorm.output", axes = [0]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b.0" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b.0", shape = [1, 256, 256]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024" = matmul(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_a.0", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b.0", transposeA = true, transposeB = false); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.0" = transpose(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", axes = [1, 0, 2]); - i"incoming-154/0" = variable(label = "incoming-154/0", shape = [1, 256]); - i"incoming-65/0" = variable(label = "incoming-65/0", shape = [1, 128]); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b.1" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b.1", shape = [1, 1, 256, 256]); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024" = matmul(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_a.0", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b.1", transposeA = true, transposeB = false); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.0" = squeeze(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", axes = [1]); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.1" = transpose(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.0", axes = [1, 0, 2]); i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice" = variable(label = "fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", shape = [128, 256]); i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice" = variable(label = "fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", shape = [128, 256]); i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice" = variable(label = "fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", shape = [128, 256]); @@ -268,13 +260,14 @@ graph network(input) -> (output) { i"fastlstm2.four_parts.split-1-over-1.512..768.slice" = variable(label = "fastlstm2.four_parts.split-1-over-1.512..768.slice", shape = [1, 256]); i"fastlstm2.four_parts.split-1-over-1.768..1024.slice" = variable(label = "fastlstm2.four_parts.split-1-over-1.768..1024.slice", shape = [1, 256]); i"fastlstm2.h_new.W.split-1-over-1.0..128.slice" = variable(label = "fastlstm2.h_new.W.split-1-over-1.0..128.slice", shape = [256, 128]); - i"fastlstm2.h_new.split-1-over-1.0..128.slice" = variable(label = "fastlstm2.h_new.split-1-over-1.0..128.slice", shape = [1, 128]); + i"fastlstm2.h_new.split-1-over-1.0..128.slice" = variable(label = "fastlstm2.h_new.split-1-over-1.0..128.slice", shape = [1, 128, 1]); i"fastlstm2.peephole0.mul.fix-rank-0-1" = variable(label = "fastlstm2.peephole0.mul.fix-rank-0-1", shape = [1, 256]); i"fastlstm2.peephole1.mul.fix-rank-0-1" = variable(label = "fastlstm2.peephole1.mul.fix-rank-0-1", shape = [1, 256]); i"fastlstm2.peephole2.mul.fix-rank-0-1" = variable(label = "fastlstm2.peephole2.mul.fix-rank-0-1", shape = [1, 256]); - ( i"fastlstm2.c_final", i"fastlstm2.c_final_1" ) = tract_core_scan(body = "scan_body_1", scan = [("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", 0, 1), ("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.0", 0, 1), ("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.0", 0, 1), ("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.0", 0, 1)], full = [("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"), ("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"), ("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"), ("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"), ("fastlstm2.four_parts.split-1-over-1.0..256.slice", i"fastlstm2.four_parts.split-1-over-1.0..256.slice"), ("fastlstm2.four_parts.split-1-over-1.256..512.slice", i"fastlstm2.four_parts.split-1-over-1.256..512.slice"), ("fastlstm2.four_parts.split-1-over-1.512..768.slice", i"fastlstm2.four_parts.split-1-over-1.512..768.slice"), ("fastlstm2.four_parts.split-1-over-1.768..1024.slice", i"fastlstm2.four_parts.split-1-over-1.768..1024.slice"), ("fastlstm2.h_new.W.split-1-over-1.0..128.slice", i"fastlstm2.h_new.W.split-1-over-1.0..128.slice"), ("fastlstm2.h_new.split-1-over-1.0..128.slice", i"fastlstm2.h_new.split-1-over-1.0..128.slice"), ("fastlstm2.peephole0.mul.fix-rank-0-1", i"fastlstm2.peephole0.mul.fix-rank-0-1"), ("fastlstm2.peephole1.mul.fix-rank-0-1", i"fastlstm2.peephole1.mul.fix-rank-0-1"), ("fastlstm2.peephole2.mul.fix-rank-0-1", i"fastlstm2.peephole2.mul.fix-rank-0-1")], state = [("fastlstm2.c", i"incoming-154/0", "fastlstm2.c_new"), ("fastlstm2.r", i"incoming-65/0", "fastlstm2.r_new")], output = [("fastlstm2.r_new", "full", 0, 1), ("fastlstm2.h_new.W.split-over-1.128..256.prop_axis.a.input_0", "full", 0, 1)], skip = 6, reset_every_turn = false); - i"output.affine.output.W.concat-einsum-slice-k.0.0..128" = variable(label = "output.affine.output.W.concat-einsum-slice-k.0.0..128", shape = [1690, 128]); - i"output.affine.output.W.concat-einsum-k.0..128" = matmul(i"fastlstm2.c_final", i"output.affine.output.W.concat-einsum-slice-k.0.0..128", transposeA = false, transposeB = true); + ( i"fastlstm2.c_final", i"fastlstm2.c_final_1" ) = tract_core_scan(body = "scan_body_1", scan = [("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_c.0", 0, 1), ("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.1", 0, 1), ("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.1", 0, 1), ("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.1", 0, 1)], full = [("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"), ("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"), ("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"), ("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"), ("fastlstm2.four_parts.split-1-over-1.0..256.slice", i"fastlstm2.four_parts.split-1-over-1.0..256.slice"), ("fastlstm2.four_parts.split-1-over-1.256..512.slice", i"fastlstm2.four_parts.split-1-over-1.256..512.slice"), ("fastlstm2.four_parts.split-1-over-1.512..768.slice", i"fastlstm2.four_parts.split-1-over-1.512..768.slice"), ("fastlstm2.four_parts.split-1-over-1.768..1024.slice", i"fastlstm2.four_parts.split-1-over-1.768..1024.slice"), ("fastlstm2.h_new.W.split-1-over-1.0..128.slice", i"fastlstm2.h_new.W.split-1-over-1.0..128.slice"), ("fastlstm2.h_new.split-1-over-1.0..128.slice", i"fastlstm2.h_new.split-1-over-1.0..128.slice"), ("fastlstm2.peephole0.mul.fix-rank-0-1", i"fastlstm2.peephole0.mul.fix-rank-0-1"), ("fastlstm2.peephole1.mul.fix-rank-0-1", i"fastlstm2.peephole1.mul.fix-rank-0-1"), ("fastlstm2.peephole2.mul.fix-rank-0-1", i"fastlstm2.peephole2.mul.fix-rank-0-1")], state = [("fastlstm2.c", i"tap.tap.fastlstm1.c_init.0-35/0-82/0", "fastlstm2.c_new"), ("fastlstm2.r", i"tap.fastlstm1.r_init.0-36/0", "fastlstm2.r_new")], output = [("fastlstm2.r_new", "full", 2, 1), ("fastlstm2.h_new.W.split-over-1.128..256.prop_axis.a.input_0", "full", 0, 1)], skip = 6, reset_every_turn = false); + i"output.affine.output.W.concat-einsum-k.0..128.fix_a.0" = variable(label = "output.affine.output.W.concat-einsum-k.0..128.fix_a.0", shape = [1, 1690, 128]); + i"output.affine.output.W.concat-einsum-k.0..128" = matmul(i"fastlstm2.c_final", i"output.affine.output.W.concat-einsum-k.0..128.fix_a.0", transposeA = true, transposeB = true); + i"output.affine.output.W.concat-einsum-k.0..128.fix_c.0" = squeeze(i"output.affine.output.W.concat-einsum-k.0..128", axes = [0]); i"fastlstm2.h_new.W.split-over-1.128..256.fix_a.0" = transpose(i"fastlstm2.c_final_1", axes = [1, 0, 2]); i"fastlstm2.h_new.W.split-over-1.128..256.fix_b.0" = variable(label = "fastlstm2.h_new.W.split-over-1.128..256.fix_b.0", shape = [1, 256, 128]); i"fastlstm2.h_new.W.split-over-1.128..256" = matmul(i"fastlstm2.h_new.W.split-over-1.128..256.fix_b.0", i"fastlstm2.h_new.W.split-over-1.128..256.fix_a.0", transposeA = true, transposeB = true); @@ -283,7 +276,7 @@ graph network(input) -> (output) { i"fastlstm2.h_new.split-over-1.128..256" = add(i"fastlstm2.h_new.W.split-over-1.128..256.fix_c.0", i"fastlstm2.c_final.fastlstm2.h_new.split-1-over-1.128..256.slice"); i"output.affine.output.W.concat-einsum-slice-k.0.128..256" = variable(label = "output.affine.output.W.concat-einsum-slice-k.0.128..256", shape = [1690, 128]); i"output.affine.output.W.concat-einsum-k.128..256" = matmul(i"fastlstm2.h_new.split-over-1.128..256", i"output.affine.output.W.concat-einsum-slice-k.0.128..256", transposeA = true, transposeB = true); - i"output.affine.output.W.concat-einsum-k.add-1" = add(i"output.affine.output.W.concat-einsum-k.0..128", i"output.affine.output.W.concat-einsum-k.128..256"); + i"output.affine.output.W.concat-einsum-k.add-1" = add(i"output.affine.output.W.concat-einsum-k.0..128.fix_c.0", i"output.affine.output.W.concat-einsum-k.128..256"); i"output.affine.bias.0" = variable(label = "output.affine.bias.0", shape = [1, 1690]); i"output.affine.output" = add(i"output.affine.output.W.concat-einsum-k.add-1", i"output.affine.bias.0"); output = i"output.affine.output";