Refactor and add more metrics #440

Merged 2 commits on Oct 19, 2022
pgml-extension/src/metrics.rs (231 changes: 137 additions, 94 deletions)
@@ -1,7 +1,15 @@
 /// Module providing various metrics used to rank the algorithms.
-use ndarray::{Array2, ArrayView1, ArrayView2};
+use ndarray::{Array2, ArrayView1};
 use std::collections::{BTreeSet, HashMap};
 
+#[derive(PartialEq, Copy, Clone, Debug)]
+pub enum Average {
+    Micro,
+    Macro,
+    Binary,
+}
+
+/// Confusion matrix metrics for a class.
 #[derive(Debug)]
 pub struct ConfusionMatrixMetrics {
     tp: f32,
@@ -13,6 +21,7 @@ pub struct ConfusionMatrixMetrics {
 }
 
 impl ConfusionMatrixMetrics {
+    /// Args: TP, FP, FN, TN.
     pub fn new(metrics: (f32, f32, f32, f32)) -> ConfusionMatrixMetrics {
         ConfusionMatrixMetrics {
             tp: metrics.0,
@@ -23,114 +32,146 @@ impl ConfusionMatrixMetrics {
     }
 }
 
-pub fn confusion_matrix(
-    ground_truth: &ArrayView1<usize>,
-    y_hat: &ArrayView1<usize>,
-    num_classes: usize,
-) -> Array2<f32> {
-    assert_eq!(ground_truth.len(), y_hat.len());
-
-    // Distinct classes.
-    let mut classes = ground_truth.iter().collect::<BTreeSet<_>>();
-    classes.extend(&mut y_hat.iter().collect::<BTreeSet<_>>().into_iter());
-
-    assert_eq!(num_classes, classes.len());
-
-    // Class value = index in the confusion matrix
-    // e.g. class value 5 will be index 4 if there are classes 1, 2, 3 and 4 present.
-    let indexes = classes
-        .iter()
-        .enumerate()
-        .map(|(a, b)| (**b, a))
-        .collect::<HashMap<usize, usize>>();
-
-    let mut matrix = Array2::zeros((num_classes, num_classes));
-
-    for (i, t) in ground_truth.iter().enumerate() {
-        let h = y_hat[i];
-
-        matrix[(indexes[t], indexes[&h])] += 1.0;
-    }
-
-    matrix
-}
-
-/// Return macro-averaged recall for the confusion matrix.
-pub fn metrics(matrix: &ArrayView2<f32>, num_classes: usize) -> Vec<ConfusionMatrixMetrics> {
-    let mut metrics = Vec::new();
-
-    // Scikit confusion matrix starts from 1 and goes to 0,
-    // ours starts from 0 and goes to 1. No big deal,
-    // just flip everything lol.
-    if num_classes == 2 {
-        let tp = matrix[(1, 1)];
-        let fp = matrix[(0, 1)];
-        let fn_ = matrix[(1, 0)];
-        let tn = matrix[(0, 0)];
-
-        metrics.push(ConfusionMatrixMetrics::new((tp, fp, fn_, tn)));
-    } else {
-        for class in 0..num_classes {
-            let tp = matrix[(class, class)];
-            let fp = matrix.row(class).sum() - tp;
-            let fn_ = matrix.column(class).sum() - tp;
-            let tn = matrix.sum() - tp - fp - fn_;
-
-            metrics.push(ConfusionMatrixMetrics::new((tp, fp, fn_, tn)));
-        }
-    }
-
-    metrics
-}
-
-/// Return macro-averaged recall.
-pub fn recall(metrics: &Vec<ConfusionMatrixMetrics>) -> f32 {
-    let recalls = metrics
-        .iter()
-        .map(|m| m.tp / (m.tp + m.fn_))
-        .collect::<Vec<f32>>();
-
-    recalls.iter().sum::<f32>() / recalls.len() as f32
-}
-
-pub fn precision(metrics: &Vec<ConfusionMatrixMetrics>) -> f32 {
-    let precisions = metrics
-        .iter()
-        .map(|m| m.tp / (m.tp + m.fp))
-        .collect::<Vec<f32>>();
-
-    precisions.iter().sum::<f32>() / precisions.len() as f32
-}
-
-pub fn f1(metrics: &Vec<ConfusionMatrixMetrics>) -> f32 {
-    let recalls = metrics
-        .iter()
-        .map(|m| m.tp / (m.tp + m.fn_))
-        .collect::<Vec<f32>>();
-    let precisions = metrics
-        .iter()
-        .map(|m| m.tp / (m.tp + m.fp))
-        .collect::<Vec<f32>>();
-
-    let mut f1s = Vec::new();
-
-    for (i, recall) in recalls.iter().enumerate() {
-        let precision = precisions[i];
-        f1s.push(2. * ((precision * recall) / (precision + recall)));
-    }
-
-    f1s.iter().sum::<f32>() / f1s.len() as f32
-}
-
-pub fn f1_micro(metrics: &Vec<ConfusionMatrixMetrics>) -> f32 {
-    let tp = metrics.iter().map(|m| m.tp).sum::<f32>();
-    let fn_ = metrics.iter().map(|m| m.fn_).sum::<f32>();
-    let fp = metrics.iter().map(|m| m.fp).sum::<f32>();
-
-    let recall = tp / (tp + fn_);
-    let precision = tp / (tp + fp);
-
-    2. * ((precision * recall) / (precision + recall))
-}
+pub struct ConfusionMatrix {
+    /// The confusion matrix in its raw form.
+    matrix: Array2<f32>,
+
+    /// Predicates calculated using the confusion matrix, indexed by class number.
+    metrics: Vec<ConfusionMatrixMetrics>,
+}
+
+impl ConfusionMatrix {
+    /// Construct a new confusion matrix from the ground truth
+    /// and the predictions.
+    /// `num_classes` is passed in to ensure that all classes
+    /// were present in the test set.
+    pub fn new(
+        ground_truth: &ArrayView1<usize>,
+        y_hat: &ArrayView1<usize>,
+        num_classes: usize,
+    ) -> ConfusionMatrix {
+        assert_eq!(ground_truth.len(), y_hat.len());
+
+        // Distinct classes.
+        let mut classes = ground_truth.iter().collect::<BTreeSet<_>>();
+        classes.extend(&mut y_hat.iter().collect::<BTreeSet<_>>().into_iter());
+
+        assert_eq!(num_classes, classes.len());
+
+        // Class value = index in the confusion matrix
+        // e.g. class value 5 will be index 4 if there are classes 1, 2, 3 and 4 present.
+        let indexes = classes
+            .iter()
+            .enumerate()
+            .map(|(a, b)| (**b, a))
+            .collect::<HashMap<usize, usize>>();
+
+        let mut matrix = Array2::zeros((num_classes, num_classes));
+
+        for (i, t) in ground_truth.iter().enumerate() {
+            let h = y_hat[i];
+
+            matrix[(indexes[t], indexes[&h])] += 1.0;
+        }
+
+        let mut metrics = Vec::new();
+
+        // Scikit confusion matrix starts from 1 and goes to 0,
+        // ours starts from 0 and goes to 1. No big deal,
+        // just flip everything lol.
+        if num_classes == 2 {
+            let tp = matrix[(1, 1)];
+            let fp = matrix[(0, 1)];
+            let fn_ = matrix[(1, 0)];
+            let tn = matrix[(0, 0)];
+
+            metrics.push(ConfusionMatrixMetrics::new((tp, fp, fn_, tn)));
+        } else {
+            for class in 0..num_classes {
+                let tp = matrix[(class, class)];
+                let fp = matrix.row(class).sum() - tp;
+                let fn_ = matrix.column(class).sum() - tp;
+                let tn = matrix.sum() - tp - fp - fn_;
+
+                metrics.push(ConfusionMatrixMetrics::new((tp, fp, fn_, tn)));
+            }
+        }
+
+        ConfusionMatrix { matrix, metrics }
+    }
+
+    pub fn accuracy(&self) -> f32 {
+        let numerator = self.matrix.diag().sum();
+        let denominator = self.matrix.sum();
+
+        numerator / denominator
+    }
+
+    /// Average recall.
+    pub fn recall(&self) -> f32 {
+        let recalls = self
+            .metrics
+            .iter()
+            .map(|m| m.tp / (m.tp + m.fn_))
+            .collect::<Vec<f32>>();
+
+        recalls.iter().sum::<f32>() / recalls.len() as f32
+    }
+
+    /// Average precision.
+    pub fn precision(&self) -> f32 {
+        let precisions = self
+            .metrics
+            .iter()
+            .map(|m| m.tp / (m.tp + m.fp))
+            .collect::<Vec<f32>>();
+
+        precisions.iter().sum::<f32>() / precisions.len() as f32
+    }
+
+    pub fn f1(&self, average: Average) -> f32 {
+        match average {
+            Average::Macro => self.f1_macro(),
+            Average::Micro | Average::Binary => self.f1_micro(), // micro = binary if num_classes = 2
+        }
+    }
+
+    /// Calculate the f1 using micro metrics, i.e. the sum of predicates.
+    /// This evaluates the classifier as a whole instead of evaluating it as a sum of individual parts.
+    fn f1_micro(&self) -> f32 {
+        let tp = self.metrics.iter().map(|m| m.tp).sum::<f32>();
+        let fn_ = self.metrics.iter().map(|m| m.fn_).sum::<f32>();
+        let fp = self.metrics.iter().map(|m| m.fp).sum::<f32>();
+
+        let recall = tp / (tp + fn_);
+        let precision = tp / (tp + fp);
+
+        2. * ((precision * recall) / (precision + recall))
+    }
+
+    /// Calculate f1 using the average of class f1's.
+    /// This gives each class an equal opportunity to impact the overall score.
+    fn f1_macro(&self) -> f32 {
+        let recalls = self
+            .metrics
+            .iter()
+            .map(|m| m.tp / (m.tp + m.fn_))
+            .collect::<Vec<f32>>();
+        let precisions = self
+            .metrics
+            .iter()
+            .map(|m| m.tp / (m.tp + m.fp))
+            .collect::<Vec<f32>>();
+
+        let mut f1s = Vec::new();
+
+        for (i, recall) in recalls.iter().enumerate() {
+            let precision = precisions[i];
+            f1s.push(2. * ((precision * recall) / (precision + recall)));
+        }
+
+        f1s.iter().sum::<f32>() / f1s.len() as f32
+    }
+}
 
 #[cfg(test)]
@@ -143,15 +184,17 @@ mod test {
         let ground_truth = array![1, 2, 3, 4, 4];
         let y_hat = array![1, 2, 3, 4, 4];
 
-        let mat = confusion_matrix(
+        let mat = ConfusionMatrix::new(
             &ArrayView1::from(&ground_truth),
             &ArrayView1::from(&y_hat),
             4,
         );
-        let metrics = metrics(&ArrayView2::from(&mat), 4);
 
-        assert_eq!(mat[(3, 3)], 2.0);
-        assert_eq!(f1(&metrics), 1.0);
-        assert_eq!(f1_micro(&metrics), 1.0);
+        let f1 = mat.f1(Average::Macro);
+        let f1_micro = mat.f1(Average::Micro);
+
+        assert_eq!(mat.matrix[(3, 3)], 2.0);
+        assert_eq!(f1, 1.0);
+        assert_eq!(f1_micro, 1.0);
     }
 }
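
For reference, a minimal usage sketch of the refactored API (not part of the diff): it mirrors the test above and assumes this module plus ndarray's array! macro are in scope. All metrics are now methods on ConfusionMatrix rather than free functions.

    use ndarray::{array, ArrayView1};

    let ground_truth = array![1, 2, 3, 4, 4];
    let y_hat = array![1, 2, 3, 4, 4];

    // Build the matrix once; every metric is derived from it.
    let mat = ConfusionMatrix::new(
        &ArrayView1::from(&ground_truth),
        &ArrayView1::from(&y_hat),
        4,
    );

    let accuracy = mat.accuracy();          // diagonal / total: 1.0 for a perfect prediction
    let recall = mat.recall();              // average of per-class recalls
    let precision = mat.precision();        // average of per-class precisions
    let f1_macro = mat.f1(Average::Macro);
    let f1_micro = mat.f1(Average::Micro);  // Average::Binary takes the same path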
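Macro and micro F1 only diverge when classes are imbalanced and the errors fall unevenly across them. A hand-worked sketch on hypothetical data, with values computed from the definitions above:

    // Six samples, three classes; class 0 dominates and takes the only error.
    let ground_truth = array![0, 0, 0, 0, 1, 2];
    let y_hat = array![0, 0, 0, 2, 1, 2];

    let mat = ConfusionMatrix::new(
        &ArrayView1::from(&ground_truth),
        &ArrayView1::from(&y_hat),
        3,
    );

    // Per-class f1: class 0 ≈ 0.857, class 1 = 1.0, class 2 ≈ 0.667.
    // Macro averages them: (0.857 + 1.0 + 0.667) / 3 ≈ 0.841.
    let f1_macro = mat.f1(Average::Macro);

    // Micro pools the counts first: TP = 5, FP = 1, FN = 1,
    // so precision = recall = 5/6 and f1 ≈ 0.833.
    let f1_micro = mat.f1(Average::Micro);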