diff --git a/lib/gpt/ml/layer/linear.py b/lib/gpt/ml/layer/linear.py index b21cc30b..49bdceea 100644 --- a/lib/gpt/ml/layer/linear.py +++ b/lib/gpt/ml/layer/linear.py @@ -80,10 +80,12 @@ def _contract(self, w, f): return ret def __call__(self, weights, layer_input): + layer_input = g.util.to_list(layer_input) w = self._get_weight_list(weights) return self._contract(w, layer_input) def projected_gradient_adj(self, weights, layer_input, left): + layer_input = g.util.to_list(layer_input) left = g.util.to_list(left) assert len(weights) == 1