From 65c69c6cf1e50a5b72f92a8666ae4406ba688306 Mon Sep 17 00:00:00 2001 From: Charles H Martin Date: Wed, 5 Jun 2019 16:16:49 -0700 Subject: [PATCH] Adding support for GPT --- weightwatcher/weightwatcher.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index 295b3fa..be113cb 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -129,7 +129,7 @@ def model_is_valid(self, model=None): return True - + # test with https://github.com/osmr/imgclsmob/blob/master/README.md def analyze(self, model=None, layers=[], min_size=50, max_size=0, compute_alphas=False, compute_lognorms=True, normalize=False, @@ -150,10 +150,25 @@ def analyze(self, model=None, layers=[], min_size=50, max_size=0, compute_lognorms: Compute the log norms of the weight matrices. """ + model = model or self.model res = {} + # Treats Custom Conv1D / Attention Layers (ex: GPT, BERT) + # since they have custom subclass from nn.Module (OpenAIGPTModel) + def isPyTorchLinearOrConv1D(l): + tf = False + import torch.nn as nn + if isinstance(l, nn.Conv1d): + tf = True + if isinstance(l, nn.Module): + if hasattr(l, 'weight'): + w = l.weight.detach().numpy() + if len(w.shape)==2: # Linear + tf = True + return tf + if not isinstance(layers, list): layers = [layers] layer_ids = [x for x in layers if str(x).isdigit()] @@ -220,6 +235,18 @@ def analyze(self, model=None, layers=[], min_size=50, max_size=0, # weights = weigths[0]+weights[1] # CONV1D layer + elif (isPyTorchLinearOrConv1D(l)): + res[i] = {"layer_type": LAYER_TYPE.CONV1D} + + if (len(layer_types) > 0 and + not any(layer_type & LAYER_TYPE.CONV1D for layer_type in layer_types)): + msg = "Skipping (Layer type not requested to analyze)" + self.debug("Layer {}: {}".format(i+1, msg)) + res[i]["message"] = msg + continue + + weights = [np.array(l.weight.data.clone().cpu())] + elif (isinstance(l, keras.layers.convolutional.Conv1D)): res[i] = {"layer_type": LAYER_TYPE.CONV1D} @@ -289,7 +316,8 @@ def get_details(self, results=None): Return a pandas dataframe """ df = self.compute_details(results=results) - return df[:-1].dropna(axis=1, how='all').set_index("layer_id") # prune the last line summary + details = df[:-1].dropna(axis=1, how='all').set_index("layer_id") # prune the last line summary + return details[details.layer_type.notna()] def compute_details(self, results=None): """