Skip to content

Commit

Permalink
feat: use a generatic for the label's type
Browse files Browse the repository at this point in the history
  • Loading branch information
stropitek committed Jan 24, 2023
1 parent 6bd4669 commit 8e66373
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/__tests__/test.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,6 @@ describe('Confusion Matrix', () => {

it('should throw if trying to get the count for unexisting label', () => {
const CM = new ConfusionMatrix(full.matrix, full.labels);
expect(() => CM.getCount('A', 'B')).toThrow(/label does not exist/);
expect(() => CM.getCount(4, 5)).toThrow(/label does not exist/);
});
});
62 changes: 31 additions & 31 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
* @param matrix - The confusion matrix, a 2D Array. Rows represent the actual label and columns the predicted label.
* @param labels - Labels of the confusion matrix, a 1D Array
*/
export class ConfusionMatrix {
private labels: Label[];
export class ConfusionMatrix<T extends Label> {
private labels: T[];
private matrix: number[][];
constructor(matrix: number[][], labels: Label[]) {
constructor(matrix: number[][], labels: T[]) {
if (matrix.length !== matrix[0].length) {
throw new Error('Confusion matrix must be square');
}
Expand All @@ -34,15 +34,15 @@ export class ConfusionMatrix {
* @param [options.sort]
* @return Confusion matrix
*/
static fromLabels(
actual: Label[],
predicted: Label[],
options: FromLabelsOptions = {},
static fromLabels<T extends Label>(
actual: T[],
predicted: T[],
options: FromLabelsOptions<T> = {},
) {
if (predicted.length !== actual.length) {
throw new Error('predicted and actual must have the same length');
}
let distinctLabels: Set<Label>;
let distinctLabels: Set<T>;
if (options.labels) {
distinctLabels = new Set(options.labels);
} else {
Expand Down Expand Up @@ -117,7 +117,7 @@ export class ConfusionMatrix {
* Get the number of true positive predictions.
* @param label - The label that should be considered "positive"
*/
getTruePositiveCount(label: Label): number {
getTruePositiveCount(label: T): number {
const index = this.getIndex(label);
return this.matrix[index][index];
}
Expand All @@ -126,7 +126,7 @@ export class ConfusionMatrix {
* Get the number of true negative predictions.
* @param label - The label that should be considered "positive"
*/
getTrueNegativeCount(label: Label) {
getTrueNegativeCount(label: T) {
const index = this.getIndex(label);
let count = 0;
for (let i = 0; i < this.matrix.length; i++) {
Expand All @@ -143,7 +143,7 @@ export class ConfusionMatrix {
* Get the number of false positive predictions.
* @param label - The label that should be considered "positive"
*/
getFalsePositiveCount(label: Label) {
getFalsePositiveCount(label: T) {
const index = this.getIndex(label);
let count = 0;
for (let i = 0; i < this.matrix.length; i++) {
Expand All @@ -158,7 +158,7 @@ export class ConfusionMatrix {
* Get the number of false negative predictions.
* @param label - The label that should be considered "positive"
*/
getFalseNegativeCount(label: Label): number {
getFalseNegativeCount(label: T): number {
const index = this.getIndex(label);
let count = 0;
for (let i = 0; i < this.matrix.length; i++) {
Expand All @@ -173,15 +173,15 @@ export class ConfusionMatrix {
* Get the number of real positive samples.
* @param label - The label that should be considered "positive"
*/
getPositiveCount(label: Label) {
getPositiveCount(label: T) {
return this.getTruePositiveCount(label) + this.getFalseNegativeCount(label);
}

/**
* Get the number of real negative samples.
* @param label - The label that should be considered "positive"
*/
getNegativeCount(label: Label) {
getNegativeCount(label: T) {
return this.getTrueNegativeCount(label) + this.getFalsePositiveCount(label);
}

Expand All @@ -190,7 +190,7 @@ export class ConfusionMatrix {
* @param label - The label to search for
* @throws if the label is not found
*/
getIndex(label: Label): number {
getIndex(label: T): number {
const index = this.labels.indexOf(label);
if (index === -1) throw new Error('The label does not exist');
return index;
Expand All @@ -202,7 +202,7 @@ export class ConfusionMatrix {
* @param label - The label that should be considered "positive"
* @return The true positive rate [0-1]
*/
getTruePositiveRate(label: Label) {
getTruePositiveRate(label: T) {
return this.getTruePositiveCount(label) / this.getPositiveCount(label);
}

Expand All @@ -212,7 +212,7 @@ export class ConfusionMatrix {
* @param label - The label that should be considered "positive"
* @return The true negative rate a.k.a. specificity.
*/
getTrueNegativeRate(label: Label) {
getTrueNegativeRate(label: T) {
return this.getTrueNegativeCount(label) / this.getNegativeCount(label);
}

Expand All @@ -222,7 +222,7 @@ export class ConfusionMatrix {
* @param label - The label that should be considered "positive"
* @return the positive predictive value a.k.a. precision.
*/
getPositivePredictiveValue(label: Label) {
getPositivePredictiveValue(label: T) {
const TP = this.getTruePositiveCount(label);
return TP / (TP + this.getFalsePositiveCount(label));
}
Expand All @@ -232,7 +232,7 @@ export class ConfusionMatrix {
* {@link https://en.wikipedia.org/wiki/Positive_and_negative_predictive_values}
* @param label - The label that should be considered "positive"
*/
getNegativePredictiveValue(label: Label) {
getNegativePredictiveValue(label: T) {
const TN = this.getTrueNegativeCount(label);
return TN / (TN + this.getFalseNegativeCount(label));
}
Expand All @@ -242,7 +242,7 @@ export class ConfusionMatrix {
* {@link https://en.wikipedia.org/wiki/Type_I_and_type_II_errors#False_positive_and_false_negative_rates}
* @param label - The label that should be considered "positive"
*/
getFalseNegativeRate(label: Label) {
getFalseNegativeRate(label: T) {
return 1 - this.getTruePositiveRate(label);
}

Expand All @@ -251,7 +251,7 @@ export class ConfusionMatrix {
* {@link https://en.wikipedia.org/wiki/Type_I_and_type_II_errors#False_positive_and_false_negative_rates}
* @param label - The label that should be considered "positive"
*/
getFalsePositiveRate(label: Label) {
getFalsePositiveRate(label: T) {
return 1 - this.getTrueNegativeRate(label);
}

Expand All @@ -260,7 +260,7 @@ export class ConfusionMatrix {
* {@link https://en.wikipedia.org/wiki/False_discovery_rate}
* @param label - The label that should be considered "positive"
*/
getFalseDiscoveryRate(label: Label) {
getFalseDiscoveryRate(label: T) {
const FP = this.getFalsePositiveCount(label);
return FP / (FP + this.getTruePositiveCount(label));
}
Expand All @@ -269,7 +269,7 @@ export class ConfusionMatrix {
* False omission rate (FOR)
* @param label - The label that should be considered "positive"
*/
getFalseOmissionRate(label: Label) {
getFalseOmissionRate(label: T) {
const FN = this.getFalseNegativeCount(label);
return FN / (FN + this.getTruePositiveCount(label));
}
Expand All @@ -279,7 +279,7 @@ export class ConfusionMatrix {
* {@link https://en.wikipedia.org/wiki/F1_score}
* @param label - The label that should be considered "positive"
*/
getF1Score(label: Label) {
getF1Score(label: T) {
const TP = this.getTruePositiveCount(label);
return (
(2 * TP) /
Expand All @@ -294,7 +294,7 @@ export class ConfusionMatrix {
* {@link https://en.wikipedia.org/wiki/Matthews_correlation_coefficient}
* @param label - The label that should be considered "positive"
*/
getMatthewsCorrelationCoefficient(label: Label) {
getMatthewsCorrelationCoefficient(label: T) {
const TP = this.getTruePositiveCount(label);
const TN = this.getTrueNegativeCount(label);
const FP = this.getFalsePositiveCount(label);
Expand All @@ -310,7 +310,7 @@ export class ConfusionMatrix {
* {@link https://en.wikipedia.org/wiki/Youden%27s_J_statistic}
* @param label - The label that should be considered "positive"
*/
getInformedness(label: Label) {
getInformedness(label: T) {
return (
this.getTruePositiveRate(label) + this.getTrueNegativeRate(label) - 1
);
Expand All @@ -320,7 +320,7 @@ export class ConfusionMatrix {
* Markedness
* @param label - The label that should be considered "positive"
*/
getMarkedness(label: Label) {
getMarkedness(label: T) {
return (
this.getPositivePredictiveValue(label) +
this.getNegativePredictiveValue(label) -
Expand All @@ -333,7 +333,7 @@ export class ConfusionMatrix {
* @param label - The label that should be considered "positive"
* @return The 2x2 confusion table. [[TP, FN], [FP, TN]]
*/
getConfusionTable(label: Label) {
getConfusionTable(label: T) {
return [
[this.getTruePositiveCount(label), this.getFalseNegativeCount(label)],
[this.getFalsePositiveCount(label), this.getTrueNegativeCount(label)],
Expand Down Expand Up @@ -362,7 +362,7 @@ export class ConfusionMatrix {
* @param predicted - The predicted label
* @return The element in the confusion matrix
*/
getCount(actual: Label, predicted: Label) {
getCount(actual: T, predicted: T) {
const actualIndex = this.getIndex(actual);
const predictedIndex = this.getIndex(predicted);
return this.matrix[actualIndex][predictedIndex];
Expand All @@ -388,7 +388,7 @@ export class ConfusionMatrix {

type Label = boolean | number | string;

interface FromLabelsOptions {
labels?: Label[];
interface FromLabelsOptions<T extends Label> {
labels?: T[];
sort?: (...args: Label[]) => number;
}

0 comments on commit 8e66373

Please # to comment.