From 115ecb2b19a94c0f570b7ce0321d91b41e8a10c5 Mon Sep 17 00:00:00 2001 From: aryapratinavseth Date: Thu, 9 Jan 2025 12:13:51 +0000 Subject: [PATCH] updated conv1d --- .../pytorch_backtrace/backtrace/backtrace.py | 18 +- .../backtrace/utils/contrast.py | 116 ++++++++++-- .../pytorch_backtrace/backtrace/utils/prop.py | 168 ++++++++++++++++-- .../tf_backtrace/backtrace/backtrace.py | 46 +++-- .../backtrace/utils/utils_contrast.py | 104 +++++++++-- .../backtrace/utils/utils_prop.py | 150 ++++++++++++---- 6 files changed, 518 insertions(+), 84 deletions(-) diff --git a/dl_backtrace/pytorch_backtrace/backtrace/backtrace.py b/dl_backtrace/pytorch_backtrace/backtrace/backtrace.py index c247873..a2e7376 100644 --- a/dl_backtrace/pytorch_backtrace/backtrace/backtrace.py +++ b/dl_backtrace/pytorch_backtrace/backtrace/backtrace.py @@ -427,6 +427,13 @@ def proportional_eval( b1 = l1.state_dict()['bias'] pad1 = l1.padding[0] strides1 = l1.stride[0] + dilation1 = l1.dilation + groups1 = l1.groups + if not isinstance(b1, np.ndarray): + b1 = b1.numpy() + if not isinstance(w1, np.ndarray): + w1 = w1.numpy() # Convert PyTorch tensor to NumPy array + temp_wt = UP.calculate_wt_conv_1d( all_wt[start_layer], all_out[child_nodes[0]][0], @@ -434,6 +441,8 @@ def proportional_eval( b1, pad1, strides1, + dilation1, + groups1, activation_dict[model_resource[1][start_layer]["name"]], ) all_wt[child_nodes[0]] += temp_wt.T @@ -713,10 +722,17 @@ def contrast_eval(self, all_out, multiplier=100.0, b1 = l1.state_dict()['bias'] pad1 = l1.padding[0] strides1 = l1.stride[0] + dilation1 = l1.dilation + groups1 = l1.groups + if not isinstance(b1, np.ndarray): + b1 = b1.numpy() + if not isinstance(w1, np.ndarray): + w1 = w1.numpy() # Convert PyTorch tensor to NumPy array + temp_wt_pos,temp_wt_neg = UC.calculate_wt_conv_1d(all_wt_pos[start_layer], all_wt_neg[start_layer], all_out[child_nodes[0]][0], - w1,b1, pad1, strides1, + w1,b1, pad1, strides1,dilation1,groups1, activation_dict[model_resource[1][start_layer]['name']]) all_wt_pos[child_nodes[0]] += temp_wt_pos.T all_wt_neg[child_nodes[0]] += temp_wt_neg.T diff --git a/dl_backtrace/pytorch_backtrace/backtrace/utils/contrast.py b/dl_backtrace/pytorch_backtrace/backtrace/utils/contrast.py index 4d1c159..66c7a6d 100644 --- a/dl_backtrace/pytorch_backtrace/backtrace/utils/contrast.py +++ b/dl_backtrace/pytorch_backtrace/backtrace/utils/contrast.py @@ -271,7 +271,6 @@ def calculate_lstm_wt(self, input_data): output.append(h) return output - class LSTM_backtrace(object): def __init__( self, num_cells, units, weights, return_sequence=False, go_backwards=False @@ -798,7 +797,6 @@ def calculate_wt_maxpool(wts, inp, pool_size): test_wt = test_wt[0 : inp.shape[0], 0 : inp.shape[1], :] return test_wt - def calculate_wt_avgpool(wts_pos, wts_neg, inp, pool_size): pad1 = pool_size[0] pad2 = pool_size[1] @@ -852,7 +850,6 @@ def calculate_wt_avgpool(wts_pos, wts_neg, inp, pool_size): test_wt_neg = test_wt_neg[0 : inp.shape[0], 0 : inp.shape[1], :] return test_wt_pos, test_wt_neg - def calculate_wt_gavgpool(wts_pos, wts_neg, inp): channels = wts_pos.shape[0] wt_mat_pos = np.zeros_like(inp) @@ -947,26 +944,115 @@ def calculate_wt_conv_unit_1d(patch, wts_pos, wts_neg, w, b, act): return wt_mat_pos, wt_mat_neg -def calculate_wt_conv_1d(wts_pos, wts_neg, inp, w, b, padding, stride, act): +def calculate_padding_1d_v2(kernel_size, input_length, padding, strides, dilation=1, const_val=0.0): + """ + Calculate and apply padding to match TensorFlow Keras behavior for 'same', 'valid', and custom padding. + + Parameters: + kernel_size (int): Size of the convolutional kernel. + input_length (int): Length of the input along the spatial dimension. + padding (str/int/tuple): Padding type. Can be: + - 'valid': No padding. + - 'same': Pads to maintain output length equal to input length (stride=1). + - int: Symmetric padding on both sides. + - tuple/list: Explicit padding [left, right]. + strides (int): Stride size of the convolution. + dilation (int): Dilation rate for the kernel. + const_val (float): Value used for padding. Defaults to 0.0. + + Returns: + padded_length (int): Length of the input after padding. + paddings (list): Padding applied on left and right sides. + """ + effective_kernel_size = (kernel_size - 1) * dilation + 1 # Effective size considering dilation + + if padding == 'valid': + return input_length, [0, 0] + elif padding == 'same': + # Total padding required to keep output size same as input + pad_total = max(0, (input_length - 1) * strides + effective_kernel_size - input_length) + pad_left = pad_total // 2 + pad_right = pad_total - pad_left + elif isinstance(padding, int): + pad_left = padding + pad_right = padding + elif isinstance(padding, (list, tuple)) and len(padding) == 2: + pad_left, pad_right = padding + else: + raise ValueError("Invalid padding. Use 'valid', 'same', an integer, or a tuple/list of two integers.") + + padded_length = input_length + pad_left + pad_right + return padded_length, [pad_left, pad_right] + +def calculate_wt_conv_unit_1d_v2(patch, wts_pos, wts_neg, w, b, act): + """ + Calculate the weights for a single patch of the input with positive and negative contributions. + """ + k = w + bias = b + conv_out = np.einsum("ijk,ij->ijk", k, patch) + p_ind = conv_out>0 + p_ind = conv_out*p_ind + p_sum = np.einsum("ijk->k",p_ind) + n_ind = conv_out<0 + n_ind = conv_out*n_ind + n_sum = np.einsum("ijk->k",n_ind)*-1.0 + p_agg_wt_pos, p_agg_wt_neg, n_agg_wt_pos, n_agg_wt_neg, p_sum, n_sum = calculate_base_wt_array(p_sum, n_sum, bias, wts_pos, wts_neg) + wt_mat_pos = np.zeros_like(k) + wt_mat_neg = np.zeros_like(k) + + wt_mat_pos += (p_ind / p_sum) * p_agg_wt_pos + wt_mat_pos += (n_ind / n_sum) * n_agg_wt_pos * -1.0 + wt_mat_neg += (p_ind / p_sum) * p_agg_wt_neg + wt_mat_neg += (n_ind / n_sum) * n_agg_wt_neg * -1.0 + + wt_mat_pos = np.sum(wt_mat_pos, axis=-1) + wt_mat_neg = np.sum(wt_mat_neg, axis=-1) + + return wt_mat_pos, wt_mat_neg + +def calculate_wt_conv_1d(wts_pos, wts_neg, inp, w, b, padding, stride, dilation, groups, act): + """ + Perform relevance propagation for 1D convolution with dilation and groups. + """ wts_pos=wts_pos.T wts_neg=wts_neg.T inp=inp.T w = w.T - input_padded, paddings = calculate_padding_1d(w.shape[0], inp, padding, stride) - out_ds_pos = np.zeros_like(input_padded) - out_ds_neg = np.zeros_like(input_padded) - for ind in range(wts_pos.shape[0]): - indexes = np.arange(ind * stride, ind * stride + w.shape[0]) - tmp_patch = input_padded[indexes] - updates_pos,updates_neg = calculate_wt_conv_unit_1d(tmp_patch, wts_pos[ind, :], wts_neg[ind, :], w, b, act) - out_ds_pos[indexes] += updates_pos - out_ds_neg[indexes] += updates_neg + kernel_size = w.shape[0] + input_length = inp.shape[0] - out_ds_pos = out_ds_pos[paddings[0][0]:(paddings[0][0] + inp.shape[0])] - out_ds_neg = out_ds_neg[paddings[0][0]:(paddings[0][0] + inp.shape[0])] + # Compute and apply padding + padded_length, paddings = calculate_padding_1d_v2(kernel_size, input_length, padding, stride, dilation) + inp_padded = np.pad(inp, ((paddings[0], paddings[1]), (0, 0)), 'constant', constant_values=0) + + out_ds_pos = np.zeros_like(inp_padded) + out_ds_neg = np.zeros_like(inp_padded) + + input_channels_per_group = inp.shape[1] // groups + output_channels_per_group = wts_pos.shape[1] // groups + + # Handle grouped convolutions + for g in range(groups): + input_start = g * input_channels_per_group + input_end = (g + 1) * input_channels_per_group + output_start = g * output_channels_per_group + output_end = (g + 1) * output_channels_per_group + + for ind in range(wts_pos.shape[0]): + start_idx = ind * stride + tmp_patch = inp_padded[start_idx:start_idx + kernel_size * dilation:dilation, input_start:input_end] + updates_pos, updates_neg = calculate_wt_conv_unit_1d_v2(tmp_patch, wts_pos[ind, output_start:output_end], wts_neg[ind, output_start:output_end], w[:, :, output_start:output_end], b[output_start:output_end], act) + out_ds_pos[start_idx:start_idx + kernel_size * dilation:dilation, input_start:input_end] += updates_pos + out_ds_neg[start_idx:start_idx + kernel_size * dilation:dilation, input_start:input_end] += updates_neg + + out_ds_pos = out_ds_pos[paddings[0]:(paddings[0] + inp.shape[0]), :] + out_ds_neg = out_ds_neg[paddings[0]:(paddings[0] + inp.shape[0]), :] return out_ds_pos, out_ds_neg + + def calculate_wt_max_unit_1d(patch, wts, pool_size): pmax = np.max(patch, axis=0) indexes = (patch-pmax)==0 diff --git a/dl_backtrace/pytorch_backtrace/backtrace/utils/prop.py b/dl_backtrace/pytorch_backtrace/backtrace/utils/prop.py index 94e321f..916b8a0 100644 --- a/dl_backtrace/pytorch_backtrace/backtrace/utils/prop.py +++ b/dl_backtrace/pytorch_backtrace/backtrace/utils/prop.py @@ -824,21 +824,167 @@ def calculate_wt_conv_unit_1d(patch, wts, w, b, act): wt_mat = np.sum(wt_mat, axis=-1) return wt_mat -def calculate_wt_conv_1d(wts, inp, w, b, padding, stride, act): + +def calculate_padding_1d_v2(kernel_size, input_length, padding, strides, dilation=1, const_val=0.0): + """ + Calculate and apply padding to match TensorFlow Keras behavior for 'same', 'valid', and custom padding. + + Parameters: + kernel_size (int): Size of the convolutional kernel. + input_length (int): Length of the input along the spatial dimension. + padding (str/int/tuple): Padding type. Can be: + - 'valid': No padding. + - 'same': Pads to maintain output length equal to input length (stride=1). + - int: Symmetric padding on both sides. + - tuple/list: Explicit padding [left, right]. + strides (int): Stride size of the convolution. + dilation (int): Dilation rate for the kernel. + const_val (float): Value used for padding. Defaults to 0.0. + + Returns: + padded_length (int): Length of the input after padding. + paddings (list): Padding applied on left and right sides. + """ + effective_kernel_size = (kernel_size - 1) * dilation + 1 # Effective size considering dilation + + if padding == 'valid': + return input_length, [0, 0] + elif padding == 'same': + # Total padding required to keep output size same as input + pad_total = max(0, (input_length - 1) * strides + effective_kernel_size - input_length) + pad_left = pad_total // 2 + pad_right = pad_total - pad_left + elif isinstance(padding, int): + pad_left = padding + pad_right = padding + elif isinstance(padding, (list, tuple)) and len(padding) == 2: + pad_left, pad_right = padding + else: + raise ValueError("Invalid padding. Use 'valid', 'same', an integer, or a tuple/list of two integers.") + + padded_length = input_length + pad_left + pad_right + return padded_length, [pad_left, pad_right] + + +def calculate_wt_conv_unit_1d_v2(patch, wts, w, b, act): + """ + Compute relevance for a single patch of the input tensor. + + Parameters: + patch (ndarray): Patch of input corresponding to the receptive field of the kernel. + wts (ndarray): Relevance values from the next layer for this patch. + w (ndarray): Weights of the convolutional kernel. + b (ndarray): Bias values for the convolution. + act (dict): Activation function details. Should contain: + - "type": Type of activation ('mono' or 'non_mono'). + - "range": Range dictionary with "l" (lower bound) and "u" (upper bound). + - "func": Function to apply for activation. + + Returns: + wt_mat (ndarray): Weighted relevance matrix for the patch. + """ + kernel = w + bias = b + wt_mat = np.zeros_like(kernel) + # Compute convolution output + conv_out = np.einsum("ijk,ij->ijk", kernel, patch) + # Separate positive and negative contributions + p_ind = conv_out > 0 + p_ind = conv_out * p_ind + p_sum = np.einsum("ijk->k",p_ind) + n_ind = conv_out < 0 + n_ind = conv_out * n_ind + n_sum = np.einsum("ijk->k",n_ind) * -1.0 + t_sum = p_sum + n_sum + # Handle positive and negative bias + bias_pos = bias * (bias > 0) + bias_neg = bias * (bias < 0) * -1.0 + # Activation handling (saturate weights if specified) + p_saturate = p_sum > 0 + n_saturate = n_sum > 0 + if act["type"] == 'mono': + if act["range"]["l"]: + temp_ind = t_sum > act["range"]["l"] + p_saturate = temp_ind + if act["range"]["u"]: + temp_ind = t_sum < act["range"]["u"] + n_saturate = temp_ind + elif act["type"] == 'non_mono': + t_act = act["func"](t_sum) + p_act = act["func"](p_sum + bias_pos) + n_act = act["func"](-1 * (n_sum + bias_neg)) + if act["range"]["l"]: + temp_ind = t_sum > act["range"]["l"] + p_saturate = p_saturate * temp_ind + if act["range"]["u"]: + temp_ind = t_sum < act["range"]["u"] + n_saturate = n_saturate * temp_ind + temp_ind = np.abs(t_act - p_act) > 1e-5 + n_saturate = n_saturate * temp_ind + temp_ind = np.abs(t_act - n_act) > 1e-5 + p_saturate = p_saturate * temp_ind + + # Aggregate weights + p_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * p_saturate + n_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * n_saturate + + wt_mat = wt_mat + (p_ind * p_agg_wt) + wt_mat = wt_mat + (n_ind * n_agg_wt * -1.0) + wt_mat = np.sum(wt_mat, axis=-1) + + return wt_mat + +def calculate_wt_conv_1d(wts, inp, w, b, padding, stride, dilation, groups, act): + """ + Perform relevance propagation for a 1D convolution layer with support for groups and dilation. + + Parameters: + wts (ndarray): Relevance values from the next layer (shape: [output_length, output_channels]). + inp (ndarray): Input tensor for the current layer (shape: [input_length, input_channels]). + w (ndarray): Weights of the convolutional kernel (shape: [kernel_size, input_channels/groups, output_channels/groups]). + b (ndarray): Bias values for the convolution (shape: [output_channels]). + padding (str/int/tuple): Padding mode. Supports 'same', 'valid', integer, or tuple of (left, right). + stride (int): Stride of the convolution. + dilation (int): Dilation rate for the kernel. + groups (int): Number of groups for grouped convolution. + act (dict): Activation function details. + + Returns: + out_ds (ndarray): Propagated relevance for the input tensor. + """ wts = wts.T inp = inp.T - w = w.T - stride=stride - input_padded, paddings = calculate_padding_1d(w.shape[0], inp, padding, stride) - out_ds = np.zeros_like(input_padded) - for ind in range(wts.shape[0]): - indexes = np.arange(ind * stride, ind * stride + w.shape[0]) - tmp_patch = input_padded[indexes] - updates = calculate_wt_conv_unit_1d(tmp_patch, wts[ind, :], w, b, act) - out_ds[indexes] += updates - out_ds = out_ds[paddings[0][0]:(paddings[0][0] + inp.shape[0])] + w = w.T + kernel_size = w.shape[0] + input_length = inp.shape[0] + + # Compute and apply padding + padded_length, paddings = calculate_padding_1d_v2(kernel_size, input_length, padding, stride, dilation) + inp_padded = np.pad(inp, ((paddings[0], paddings[1]), (0, 0)), 'constant', constant_values=0) + # Initialize output relevance map + out_ds = np.zeros_like(inp_padded) + + # Handle grouped convolution + input_channels_per_group = inp.shape[1] // groups + output_channels_per_group = wts.shape[1] // groups + + for g in range(groups): + input_start = g * input_channels_per_group + input_end = (g + 1) * input_channels_per_group + output_start = g * output_channels_per_group + output_end = (g + 1) * output_channels_per_group + + for ind in range(wts.shape[0]): + start_idx = ind * stride + tmp_patch = inp_padded[start_idx:start_idx + kernel_size * dilation:dilation, input_start:input_end] + updates = calculate_wt_conv_unit_1d_v2(tmp_patch, wts[ind, output_start:output_end], w[:, :, output_start:output_end], b[output_start:output_end], act) + out_ds[start_idx:start_idx + kernel_size * dilation:dilation, input_start:input_end] += updates + + # Remove padding + out_ds = out_ds[paddings[0]:(paddings[0] + input_length), :] return out_ds + def calculate_wt_max_unit_1d(patch, wts): pmax = np.max(patch, axis=0) indexes = (patch - pmax) == 0 diff --git a/dl_backtrace/tf_backtrace/backtrace/backtrace.py b/dl_backtrace/tf_backtrace/backtrace/backtrace.py index 98e4708..0e7a169 100644 --- a/dl_backtrace/tf_backtrace/backtrace/backtrace.py +++ b/dl_backtrace/tf_backtrace/backtrace/backtrace.py @@ -301,14 +301,23 @@ def proportional_eval(self, all_out, start_wt=[] , b1 = l1.weights[1] pad1 = l1.padding strides1 = l1.strides[0] + dilation1 = l1.dilation_rate[0] + groups1 = l1.groups + if not isinstance(b1, np.ndarray): + b1 = b1.numpy() + if not isinstance(w1, np.ndarray): + w1 = w1.numpy() # Convert PyTorch tensor to NumPy array + temp_wt = UP.calculate_wt_conv_1d( - all_wt[start_layer], - all_out[child_nodes[0]][0], - w1, - b1, - pad1, - strides1, - activation_dict[model_resource["graph"][start_layer]["name"]], + wts=all_wt[start_layer], + inp=all_out[child_nodes[0]][0], + w=w1, + b=b1, + padding=pad1, + stride=strides1, + dilation=dilation1, + groups=groups1, + act=activation_dict[model_resource["graph"][start_layer]["name"]], ) all_wt[child_nodes[0]] += temp_wt elif model_resource["graph"][start_layer]["class"] == "Conv1DTranspose": @@ -591,11 +600,24 @@ def contrast_eval(self, all_out , b1 = l1.weights[1] pad1 = l1.padding strides1 = l1.strides[0] - temp_wt_pos,temp_wt_neg = UC.calculate_wt_conv_1d(all_wt_pos[start_layer], - all_wt_neg[start_layer], - all_out[child_nodes[0]][0], - w1,b1, pad1, strides1, - activation_dict[model_resource["graph"][start_layer]['name']]) + dilation1 = l1.dilation_rate[0] + groups1 = l1.groups + if not isinstance(b1, np.ndarray): + b1 = b1.numpy() + if not isinstance(w1, np.ndarray): + w1 = w1.numpy() # Convert PyTorch tensor to NumPy array + + temp_wt_pos,temp_wt_neg = UC.calculate_wt_conv_1d( + wts_pos=all_wt_pos[start_layer], + wts_neg=all_wt_neg[start_layer], + inp=all_out[child_nodes[0]][0], + w=w1, + b=b1, + padding=pad1, + stride=strides1, + dilation=dilation1, + groups=groups1, + act=activation_dict[model_resource["graph"][start_layer]['name']]) all_wt_pos[child_nodes[0]] += temp_wt_pos all_wt_neg[child_nodes[0]] += temp_wt_neg elif model_resource["graph"][start_layer]["class"] == "Conv1DTranspose": diff --git a/dl_backtrace/tf_backtrace/backtrace/utils/utils_contrast.py b/dl_backtrace/tf_backtrace/backtrace/utils/utils_contrast.py index 533a775..37a6fb8 100644 --- a/dl_backtrace/tf_backtrace/backtrace/utils/utils_contrast.py +++ b/dl_backtrace/tf_backtrace/backtrace/utils/utils_contrast.py @@ -840,20 +840,100 @@ def calculate_wt_conv_unit_1d(patch, wts_pos, wts_neg, w, b, act): return wt_mat_pos, wt_mat_neg -def calculate_wt_conv_1d(wts_pos, wts_neg, inp, w, b, padding, stride, act): - input_padded, paddings = calculate_padding_1d(w.shape[0], inp, padding, stride) - out_ds_pos = np.zeros_like(input_padded) - out_ds_neg = np.zeros_like(input_padded) - for ind in range(wts_pos.shape[0]): - indexes = np.arange(ind * stride, ind * stride + w.shape[0]) - tmp_patch = input_padded[indexes] - updates_pos,updates_neg = calculate_wt_conv_unit_1d(tmp_patch, wts_pos[ind, :], wts_neg[ind, :], w, b, act) +def calculate_padding_1d_v2(kernel_size, input_length, padding, strides, dilation=1, const_val=0.0): + """ + Calculate and apply padding to match TensorFlow Keras behavior for 'same', 'valid', and custom padding. + + Parameters: + kernel_size (int): Size of the convolutional kernel. + input_length (int): Length of the input along the spatial dimension. + padding (str/int/tuple): Padding type. Can be: + - 'valid': No padding. + - 'same': Pads to maintain output length equal to input length (stride=1). + - int: Symmetric padding on both sides. + - tuple/list: Explicit padding [left, right]. + strides (int): Stride size of the convolution. + dilation (int): Dilation rate for the kernel. + const_val (float): Value used for padding. Defaults to 0.0. + + Returns: + padded_length (int): Length of the input after padding. + paddings (list): Padding applied on left and right sides. + """ + effective_kernel_size = (kernel_size - 1) * dilation + 1 # Effective size considering dilation - out_ds_pos[indexes] += updates_pos - out_ds_neg[indexes] += updates_neg + if padding == 'valid': + return input_length, [0, 0] + elif padding == 'same': + # Total padding required to keep output size same as input + pad_total = max(0, (input_length - 1) * strides + effective_kernel_size - input_length) + pad_left = pad_total // 2 + pad_right = pad_total - pad_left + elif isinstance(padding, int): + pad_left = padding + pad_right = padding + elif isinstance(padding, (list, tuple)) and len(padding) == 2: + pad_left, pad_right = padding + else: + raise ValueError("Invalid padding. Use 'valid', 'same', an integer, or a tuple/list of two integers.") - out_ds_pos = out_ds_pos[paddings[0]:(paddings[0] + inp.shape[0])] - out_ds_neg = out_ds_neg[paddings[0]:(paddings[0] + inp.shape[0])] + padded_length = input_length + pad_left + pad_right + return padded_length, [pad_left, pad_right] + + +def calculate_wt_conv_unit_1d_v2(patch, wts_pos, wts_neg, w, b, act): + """ + Calculate the weights for a single patch of the input with positive and negative contributions. + """ + k = w + bias = b + conv_out = np.einsum("ijk,ij->ijk", k, patch) + p_ind = conv_out>0 + p_ind = conv_out*p_ind + p_sum = np.einsum("ijk->k",p_ind) + n_ind = conv_out<0 + n_ind = conv_out*n_ind + n_sum = np.einsum("ijk->k",n_ind)*-1.0 + p_agg_wt_pos, p_agg_wt_neg, n_agg_wt_pos, n_agg_wt_neg, p_sum, n_sum = calculate_base_wt_array(p_sum, n_sum, bias, wts_pos, wts_neg) + wt_mat_pos = np.zeros_like(k) + wt_mat_neg = np.zeros_like(k) + wt_mat_pos = wt_mat_pos+((p_ind/p_sum)*p_agg_wt_pos) + wt_mat_pos = wt_mat_pos+((n_ind/n_sum)*n_agg_wt_pos)*-1.0 + wt_mat_neg = wt_mat_neg+((p_ind/p_sum)*p_agg_wt_neg) + wt_mat_neg = wt_mat_neg+((n_ind/n_sum)*n_agg_wt_neg)*-1.0 + wt_mat_pos = np.sum(wt_mat_pos,axis=-1) + wt_mat_neg = np.sum(wt_mat_neg,axis=-1) + return wt_mat_pos, wt_mat_neg + +def calculate_wt_conv_1d(wts_pos, wts_neg, inp, w, b, padding, stride, dilation, groups, act): + """ + Perform relevance propagation for 1D convolution with dilation and groups. + """ + kernel_size = w.shape[0] + input_length = inp.shape[0] + padded_length, paddings = calculate_padding_1d_v2(kernel_size, input_length, padding, stride, dilation) + inp_padded = np.pad(inp, ((paddings[0], paddings[1]), (0, 0)), 'constant', constant_values=0) + out_ds_pos = np.zeros_like(inp_padded) + out_ds_neg = np.zeros_like(inp_padded) + + input_channels_per_group = inp.shape[1] // groups + output_channels_per_group = wts_pos.shape[1] // groups + + for g in range(groups): + input_start = g * input_channels_per_group + input_end = (g + 1) * input_channels_per_group + output_start = g * output_channels_per_group + output_end = (g + 1) * output_channels_per_group + + for ind in range(wts_pos.shape[0]): + start_idx = ind * stride + tmp_patch = inp_padded[start_idx:start_idx + kernel_size * dilation:dilation, input_start:input_end] + updates_pos, updates_neg = calculate_wt_conv_unit_1d_v2(tmp_patch, wts_pos[ind, output_start:output_end], wts_neg[ind, output_start:output_end], w[:, :, output_start:output_end], b[output_start:output_end], act) + out_ds_pos[start_idx:start_idx + kernel_size * dilation:dilation, input_start:input_end] += updates_pos + out_ds_neg[start_idx:start_idx + kernel_size * dilation:dilation, input_start:input_end] += updates_neg + + out_ds_pos = out_ds_pos[paddings[0]:(paddings[0] + inp.shape[0]), :] + out_ds_neg = out_ds_neg[paddings[0]:(paddings[0] + inp.shape[0]), :] return out_ds_pos, out_ds_neg def calculate_wt_max_unit_1d(patch, wts, pool_size): diff --git a/dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py b/dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py index c33727f..a1ab009 100644 --- a/dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py +++ b/dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py @@ -754,33 +754,72 @@ def calculate_padding_1d(kernel_size, inp, padding, strides, const_val=0.0): pad_left = int(np.floor(pad_total / 2.0)) pad_right = int(np.ceil(pad_total / 2.0)) - inp_pad = np.pad(inp, (pad_left, pad_right), 'constant', constant_values=const_val) - return inp_pad, [pad_left, pad_right] -def calculate_padding_1d(kernel_size, inp, padding, strides, const_val=0.0): - if padding == 'valid': - return inp, [0, 0] - else: - remainder = inp.shape[0] % strides - if remainder == 0: - pad_total = max(0, kernel_size - strides) - else: - pad_total = max(0, kernel_size - remainder) - - pad_left = int(np.floor(pad_total / 2.0)) - pad_right = int(np.ceil(pad_total / 2.0)) - inp_pad = np.pad(inp, ((pad_left, pad_right),(0,0)), 'constant', constant_values=const_val) return inp_pad, [pad_left, pad_right] +def calculate_padding_1d_v2(kernel_size, input_length, padding, strides, dilation=1, const_val=0.0): + """ + Calculate and apply padding to match TensorFlow Keras behavior for 'same', 'valid', and custom padding. + + Parameters: + kernel_size (int): Size of the convolutional kernel. + input_length (int): Length of the input along the spatial dimension. + padding (str/int/tuple): Padding type. Can be: + - 'valid': No padding. + - 'same': Pads to maintain output length equal to input length (stride=1). + - int: Symmetric padding on both sides. + - tuple/list: Explicit padding [left, right]. + strides (int): Stride size of the convolution. + dilation (int): Dilation rate for the kernel. + const_val (float): Value used for padding. Defaults to 0.0. + + Returns: + padded_length (int): Length of the input after padding. + paddings (list): Padding applied on left and right sides. + """ + effective_kernel_size = (kernel_size - 1) * dilation + 1 # Effective size considering dilation -def calculate_wt_conv_unit_1d(patch, wts, w, b, act): - k = w.numpy() - bias = b.numpy() - b_ind = bias > 0 - bias_pos = bias * b_ind - b_ind = bias < 0 - bias_neg = bias * b_ind * -1.0 - conv_out = np.einsum("ijk,ij->ijk", k, patch) + if padding == 'valid': + return input_length, [0, 0] + elif padding == 'same': + # Total padding required to keep output size same as input + pad_total = max(0, (input_length - 1) * strides + effective_kernel_size - input_length) + pad_left = pad_total // 2 + pad_right = pad_total - pad_left + elif isinstance(padding, int): + pad_left = padding + pad_right = padding + elif isinstance(padding, (list, tuple)) and len(padding) == 2: + pad_left, pad_right = padding + else: + raise ValueError("Invalid padding. Use 'valid', 'same', an integer, or a tuple/list of two integers.") + + padded_length = input_length + pad_left + pad_right + return padded_length, [pad_left, pad_right] + +def calculate_wt_conv_unit_1d_v2(patch, wts, w, b, act): + """ + Compute relevance for a single patch of the input tensor. + + Parameters: + patch (ndarray): Patch of input corresponding to the receptive field of the kernel. + wts (ndarray): Relevance values from the next layer for this patch. + w (ndarray): Weights of the convolutional kernel. + b (ndarray): Bias values for the convolution. + act (dict): Activation function details. Should contain: + - "type": Type of activation ('mono' or 'non_mono'). + - "range": Range dictionary with "l" (lower bound) and "u" (upper bound). + - "func": Function to apply for activation. + + Returns: + wt_mat (ndarray): Weighted relevance matrix for the patch. + """ + kernel = w + bias = b + wt_mat = np.zeros_like(kernel) + # Compute convolution output + conv_out = np.einsum("ijk,ij->ijk", kernel, patch) + # Separate positive and negative contributions p_ind = conv_out > 0 p_ind = conv_out * p_ind p_sum = np.einsum("ijk->k",p_ind) @@ -788,7 +827,10 @@ def calculate_wt_conv_unit_1d(patch, wts, w, b, act): n_ind = conv_out * n_ind n_sum = np.einsum("ijk->k",n_ind) * -1.0 t_sum = p_sum + n_sum - wt_mat = np.zeros_like(k) + # Handle positive and negative bias + bias_pos = bias * (bias > 0) + bias_neg = bias * (bias < 0) * -1.0 + # Activation handling (saturate weights if specified) p_saturate = p_sum > 0 n_saturate = n_sum > 0 if act["type"] == 'mono': @@ -812,23 +854,65 @@ def calculate_wt_conv_unit_1d(patch, wts, w, b, act): n_saturate = n_saturate * temp_ind temp_ind = np.abs(t_act - n_act) > 1e-5 p_saturate = p_saturate * temp_ind + + # Aggregate weights p_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * p_saturate n_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * n_saturate wt_mat = wt_mat + (p_ind * p_agg_wt) wt_mat = wt_mat + (n_ind * n_agg_wt * -1.0) wt_mat = np.sum(wt_mat, axis=-1) + return wt_mat -def calculate_wt_conv_1d(wts, inp, w, b, padding, stride, act): - input_padded, paddings = calculate_padding_1d(w.shape[0], inp, padding, stride) - out_ds = np.zeros_like(input_padded) - for ind in range(wts.shape[0]): - indexes = np.arange(ind * stride, ind * stride + w.shape[0]) - tmp_patch = input_padded[indexes] - updates = calculate_wt_conv_unit_1d(tmp_patch, wts[ind, :], w, b, act) - out_ds[indexes] += updates - out_ds = out_ds[paddings[0]:(paddings[0] + inp.shape[0])] +def calculate_wt_conv_1d(wts, inp, w, b, padding, stride, dilation, groups, act): + """ + Perform relevance propagation for a 1D convolution layer with support for groups and dilation. + + Parameters: + wts (ndarray): Relevance values from the next layer (shape: [output_length, output_channels]). + inp (ndarray): Input tensor for the current layer (shape: [input_length, input_channels]). + w (ndarray): Weights of the convolutional kernel (shape: [kernel_size, input_channels/groups, output_channels/groups]). + b (ndarray): Bias values for the convolution (shape: [output_channels]). + padding (str/int/tuple): Padding mode. Supports 'same', 'valid', integer, or tuple of (left, right). + stride (int): Stride of the convolution. + dilation (int): Dilation rate for the kernel. + groups (int): Number of groups for grouped convolution. + act (dict): Activation function details. + + Returns: + out_ds (ndarray): Propagated relevance for the input tensor. + """ + kernel_size = w.shape[0] + input_length = inp.shape[0] + + # Compute and apply padding + padded_length, paddings = calculate_padding_1d_v2(kernel_size, input_length, padding, stride, dilation) + inp_padded = np.pad(inp, ((paddings[0], paddings[1]), (0, 0)), 'constant', constant_values=0) + # Initialize output relevance map + out_ds = np.zeros_like(inp_padded) + + # Handle grouped convolution + input_channels_per_group = inp.shape[1] // groups + output_channels_per_group = wts.shape[1] // groups + + for g in range(groups): + input_start = g * input_channels_per_group + input_end = (g + 1) * input_channels_per_group + output_start = g * output_channels_per_group + output_end = (g + 1) * output_channels_per_group + + for ind in range(wts.shape[0]): + start_idx = ind * stride + tmp_patch = inp_padded[start_idx:start_idx + kernel_size * dilation:dilation, input_start:input_end] + updates = calculate_wt_conv_unit_1d_v2(tmp_patch, + wts[ind, output_start:output_end], + w[:, :, output_start:output_end], + b[output_start:output_end], + act) + out_ds[start_idx:start_idx + kernel_size * dilation:dilation, input_start:input_end] += updates + + out_ds = out_ds[paddings[0]:(paddings[0] + input_length), :] return out_ds def calculate_wt_max_unit_1d(patch, wts):