Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Tensor sparsity metric #1293

Closed
wants to merge 10 commits into from
54 changes: 47 additions & 7 deletions source/view.js
Original file line number Diff line number Diff line change
Expand Up @@ -2686,6 +2686,9 @@ view.ValueTextView = class extends view.Control {
line.innerText = item;
line.style.whiteSpace = style;
break;
case 'percentage':
line.innerText = `${(item * 100).toFixed(1)}%`;
break;
default:
line.innerText = item;
break;
Expand Down Expand Up @@ -3162,7 +3165,16 @@ view.TensorSidebar = class extends view.ObjectSidebar {
if (value.type) {
const item = new view.ValueView(this._view, value, '');
this.add('type', item);
item.toggle();
}

if (value.initializer && value.initializer.category === 'Initializer') {
lutzroeder marked this conversation as resolved.
Show resolved Hide resolved
this.addHeader('Metrics');

const tensor = new view.Tensor(value.initializer);
const metrics = tensor.metrics;
for (const metric of metrics) {
this.addProperty(metric.name, [metric.value], metric.style);
}
}
}
};
Expand Down Expand Up @@ -3623,10 +3635,11 @@ view.FindSidebar = class extends view.Control {

view.Argument = class {

constructor(name, value, type) {
constructor(name, value, type, style) {
this.name = name;
this.value = value;
this.type = type;
this.style = style;
}
};

Expand All @@ -3638,6 +3651,7 @@ view.Tensor = class {
this._encoding = tensor.encoding;
this._layout = tensor.type.layout;
this._stride = tensor.stride;
this._metrics = null;
switch (this._encoding) {
case undefined:
case '':
Expand Down Expand Up @@ -4119,12 +4133,38 @@ view.Tensor = class {
}

get metrics() {
const metrics = Array.from(this._tensor.metrics || []);
const keys = new Set(metrics.map((metrics) => metrics.name));
if (!keys.has('sparisity')) {
// metrics.push(new view.Argument('sparisity', 0, 'float32'));
if (this._metrics === null) {
const value = this.value;
lutzroeder marked this conversation as resolved.
Show resolved Hide resolved

const metrics = Array.from(this._tensor.metrics || []);
const keys = new Set(metrics.map((metrics) => metrics.name));
if (!keys.has('sparsity')) {
let num_zeros = 0;
let num_parameters = 0;
const stack = [value];
while (stack.length > 0) {
const val = stack.pop();
if (Array.isArray(val)) {
for (const element of val) {
stack.push(element);
}
} else {
num_zeros += Number(val === 0);
num_parameters += 1;
}
}

if (num_parameters > 0) {
metrics.push(new view.Argument('sparsity', num_zeros / num_parameters, 'float32', 'percentage'));
} else {
metrics.push(new view.Argument('sparsity', 0, 'float32', 'percentage'));
lutzroeder marked this conversation as resolved.
Show resolved Hide resolved
}
}

this._metrics = metrics;
}
return metrics;

return this._metrics;
}
};

Expand Down