
Commit fbffa48

Refactor and add more metrics (#440)
1 parent fa96dad commit fbffa48

File tree

2 files changed: +194 −152 lines


pgml-extension/src/metrics.rs

Lines changed: 137 additions & 94 deletions
@@ -1,7 +1,15 @@
 /// Module providing various metrics used to rank the algorithms.
-use ndarray::{Array2, ArrayView1, ArrayView2};
+use ndarray::{Array2, ArrayView1};
 use std::collections::{BTreeSet, HashMap};

+#[derive(PartialEq, Copy, Clone, Debug)]
+pub enum Average {
+    Micro,
+    Macro,
+    Binary,
+}
+
+/// Confusion matrix metrics for a class.
 #[derive(Debug)]
 pub struct ConfusionMatrixMetrics {
     tp: f32,
@@ -13,6 +21,7 @@ pub struct ConfusionMatrixMetrics {
 }

 impl ConfusionMatrixMetrics {
+    /// Args: TP, FP, FN, TN.
     pub fn new(metrics: (f32, f32, f32, f32)) -> ConfusionMatrixMetrics {
         ConfusionMatrixMetrics {
             tp: metrics.0,
@@ -23,114 +32,146 @@ impl ConfusionMatrixMetrics {
         }
     }

-pub fn confusion_matrix(
-    ground_truth: &ArrayView1<usize>,
-    y_hat: &ArrayView1<usize>,
-    num_classes: usize,
-) -> Array2<f32> {
-    assert_eq!(ground_truth.len(), y_hat.len());
+pub struct ConfusionMatrix {
+    /// The confusion matrix in its raw form.
+    matrix: Array2<f32>,

-    // Distinct classes.
-    let mut classes = ground_truth.iter().collect::<BTreeSet<_>>();
-    classes.extend(&mut y_hat.iter().collect::<BTreeSet<_>>().into_iter());
+    /// Predicates calculated using the confusion matrix, indexed by class number.
+    metrics: Vec<ConfusionMatrixMetrics>,
+}

-    assert_eq!(num_classes, classes.len());
+impl ConfusionMatrix {
+    /// Construct a new confusion matrix from the ground truth
+    /// and the predictions.
+    /// `num_classes` is passed it to ensure that all classes
+    /// were present in the test set.
+    pub fn new(
+        ground_truth: &ArrayView1<usize>,
+        y_hat: &ArrayView1<usize>,
+        num_classes: usize,
+    ) -> ConfusionMatrix {
+        assert_eq!(ground_truth.len(), y_hat.len());
+
+        // Distinct classes.
+        let mut classes = ground_truth.iter().collect::<BTreeSet<_>>();
+        classes.extend(&mut y_hat.iter().collect::<BTreeSet<_>>().into_iter());
+
+        assert_eq!(num_classes, classes.len());
+
+        // Class value = index in the confusion matrix
+        // e.g. class value 5 will be index 4 if there are classes 1, 2, 3 and 4 present.
+        let indexes = classes
+            .iter()
+            .enumerate()
+            .map(|(a, b)| (**b, a))
+            .collect::<HashMap<usize, usize>>();
+
+        let mut matrix = Array2::zeros((num_classes, num_classes));
+
+        for (i, t) in ground_truth.iter().enumerate() {
+            let h = y_hat[i];
+
+            matrix[(indexes[t], indexes[&h])] += 1.0;
+        }

-    // Class value = index in the confusion matrix
-    // e.g. class value 5 will be index 4 if there are classes 1, 2, 3 and 4 present.
-    let indexes = classes
-        .iter()
-        .enumerate()
-        .map(|(a, b)| (**b, a))
-        .collect::<HashMap<usize, usize>>();
+        let mut metrics = Vec::new();

-    let mut matrix = Array2::zeros((num_classes, num_classes));
+        // Scikit confusion matrix starts from 1 and goes to 0,
+        // ours starts from 0 and goes to 1. No big deal,
+        // just flip everything lol.
+        if num_classes == 2 {
+            let tp = matrix[(1, 1)];
+            let fp = matrix[(0, 1)];
+            let fn_ = matrix[(1, 0)];
+            let tn = matrix[(0, 0)];

-    for (i, t) in ground_truth.iter().enumerate() {
-        let h = y_hat[i];
+            metrics.push(ConfusionMatrixMetrics::new((tp, fp, fn_, tn)));
+        } else {
+            for class in 0..num_classes {
+                let tp = matrix[(class, class)];
+                let fp = matrix.row(class).sum() - tp;
+                let fn_ = matrix.column(class).sum() - tp;
+                let tn = matrix.sum() - tp - fp - fn_;
+
+                metrics.push(ConfusionMatrixMetrics::new((tp, fp, fn_, tn)));
+            }
+        }

-        matrix[(indexes[t], indexes[&h])] += 1.0;
+        ConfusionMatrix { matrix, metrics }
     }

-    matrix
-}
-
-/// Return macro-averaged recall for the confusion matrix.
-pub fn metrics(matrix: &ArrayView2<f32>, num_classes: usize) -> Vec<ConfusionMatrixMetrics> {
-    let mut metrics = Vec::new();
-
-    // Scikit confusion matrix starts from 1 and goes to 0,
-    // ours starts from 0 and goes to 1. No big deal,
-    // just flip everything lol.
-    if num_classes == 2 {
-        let tp = matrix[(1, 1)];
-        let fp = matrix[(0, 1)];
-        let fn_ = matrix[(1, 0)];
-        let tn = matrix[(0, 0)];
-
-        metrics.push(ConfusionMatrixMetrics::new((tp, fp, fn_, tn)));
-    } else {
-        for class in 0..num_classes {
-            let tp = matrix[(class, class)];
-            let fp = matrix.row(class).sum() - tp;
-            let fn_ = matrix.column(class).sum() - tp;
-            let tn = matrix.sum() - tp - fp - fn_;
+    pub fn accuracy(&self) -> f32 {
+        let numerator = self.matrix.diag().sum();
+        let denominator = self.matrix.sum();

-            metrics.push(ConfusionMatrixMetrics::new((tp, fp, fn_, tn)));
-        }
+        numerator / denominator
     }

-    metrics
-}
-
-/// Return macro-averaged recall.
-pub fn recall(metrics: &Vec<ConfusionMatrixMetrics>) -> f32 {
-    let recalls = metrics
-        .iter()
-        .map(|m| m.tp / (m.tp + m.fn_))
-        .collect::<Vec<f32>>();
+    /// Average recall.
+    pub fn recall(&self) -> f32 {
+        let recalls = self
+            .metrics
+            .iter()
+            .map(|m| m.tp / (m.tp + m.fn_))
+            .collect::<Vec<f32>>();

-    recalls.iter().sum::<f32>() / recalls.len() as f32
-}
+        recalls.iter().sum::<f32>() / recalls.len() as f32
+    }

-pub fn precision(metrics: &Vec<ConfusionMatrixMetrics>) -> f32 {
-    let precisions = metrics
-        .iter()
-        .map(|m| m.tp / (m.tp + m.fp))
-        .collect::<Vec<f32>>();
+    /// Average precision.
+    pub fn precision(&self) -> f32 {
+        let precisions = self
+            .metrics
+            .iter()
+            .map(|m| m.tp / (m.tp + m.fp))
+            .collect::<Vec<f32>>();

-    precisions.iter().sum::<f32>() / precisions.len() as f32
-}
+        precisions.iter().sum::<f32>() / precisions.len() as f32
+    }

-pub fn f1(metrics: &Vec<ConfusionMatrixMetrics>) -> f32 {
-    let recalls = metrics
-        .iter()
-        .map(|m| m.tp / (m.tp + m.fn_))
-        .collect::<Vec<f32>>();
-    let precisions = metrics
-        .iter()
-        .map(|m| m.tp / (m.tp + m.fp))
-        .collect::<Vec<f32>>();
-
-    let mut f1s = Vec::new();
-
-    for (i, recall) in recalls.iter().enumerate() {
-        let precision = precisions[i];
-        f1s.push(2. * ((precision * recall) / (precision + recall)));
+    pub fn f1(&self, average: Average) -> f32 {
+        match average {
+            Average::Macro => self.f1_macro(),
+            Average::Micro | Average::Binary => self.f1_micro(), // micro = binary if num_classes = 2
+        }
     }

-    f1s.iter().sum::<f32>() / f1s.len() as f32
-}
+    /// Calculate the f1 using micro metrics, i.e. the sum of predicates.
+    /// This evaluates the classifier as a whole instead of evaluating it as a sum of individual parts.
+    fn f1_micro(&self) -> f32 {
+        let tp = self.metrics.iter().map(|m| m.tp).sum::<f32>();
+        let fn_ = self.metrics.iter().map(|m| m.fn_).sum::<f32>();
+        let fp = self.metrics.iter().map(|m| m.fp).sum::<f32>();

-pub fn f1_micro(metrics: &Vec<ConfusionMatrixMetrics>) -> f32 {
-    let tp = metrics.iter().map(|m| m.tp).sum::<f32>();
-    let fn_ = metrics.iter().map(|m| m.fn_).sum::<f32>();
-    let fp = metrics.iter().map(|m| m.fp).sum::<f32>();
+        let recall = tp / (tp + fn_);
+        let precision = tp / (tp + fp);

-    let recall = tp / (tp + fn_);
-    let precision = tp / (tp + fp);
+        2. * ((precision * recall) / (precision + recall))
+    }
+
+    /// Calculate f1 using the average of class f1's.
+    /// This gives equal opportunity to each class to impact the overall score.
+    fn f1_macro(&self) -> f32 {
+        let recalls = self
+            .metrics
+            .iter()
+            .map(|m| m.tp / (m.tp + m.fn_))
+            .collect::<Vec<f32>>();
+        let precisions = self
+            .metrics
+            .iter()
+            .map(|m| m.tp / (m.tp + m.fp))
+            .collect::<Vec<f32>>();
+
+        let mut f1s = Vec::new();
+
+        for (i, recall) in recalls.iter().enumerate() {
+            let precision = precisions[i];
+            f1s.push(2. * ((precision * recall) / (precision + recall)));
+        }

-    2. * ((precision * recall) / (precision + recall))
+        f1s.iter().sum::<f32>() / f1s.len() as f32
+    }
 }

 #[cfg(test)]
@@ -143,15 +184,17 @@ mod test {
         let ground_truth = array![1, 2, 3, 4, 4];
         let y_hat = array![1, 2, 3, 4, 4];

-        let mat = confusion_matrix(
+        let mat = ConfusionMatrix::new(
             &ArrayView1::from(&ground_truth),
             &ArrayView1::from(&y_hat),
             4,
         );
-        let metrics = metrics(&ArrayView2::from(&mat), 4);

-        assert_eq!(mat[(3, 3)], 2.0);
-        assert_eq!(f1(&metrics), 1.0);
-        assert_eq!(f1_micro(&metrics), 1.0);
+        let f1 = mat.f1(Average::Macro);
+        let f1_micro = mat.f1(Average::Micro);
+
+        assert_eq!(mat.matrix[(3, 3)], 2.0);
+        assert_eq!(f1, 1.0);
+        assert_eq!(f1_micro, 1.0);
     }
 }
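
As a usage sketch (not part of the commit), the refactored API can be exercised on a small imbalanced three-class example to see how the Average variants differ. The crate::metrics import path, the example data, and the quoted values are illustrative assumptions worked out from the formulas in the diff above, not output from the repository's tests:

    use ndarray::{array, ArrayView1};

    use crate::metrics::{Average, ConfusionMatrix}; // assumed path to this module

    fn f1_average_example() {
        // Class 0 dominates; classes 1 and 2 are rarer and predicted less accurately.
        let ground_truth = array![0, 0, 0, 0, 1, 1, 2, 2];
        let y_hat = array![0, 0, 0, 1, 1, 2, 2, 0];

        let matrix = ConfusionMatrix::new(
            &ArrayView1::from(&ground_truth),
            &ArrayView1::from(&y_hat),
            3,
        );

        // Micro F1 pools TP/FP/FN across classes, so the frequent class 0 dominates: 5/8 = 0.625.
        let f1_micro = matrix.f1(Average::Micro);

        // Macro F1 averages the per-class F1 scores (0.75, 0.5, 0.5 here), so the weaker
        // minority classes pull it down to roughly 0.583.
        let f1_macro = matrix.f1(Average::Macro);

        // With exactly one prediction per sample, micro F1 equals accuracy.
        assert!((f1_micro - matrix.accuracy()).abs() < 1e-6);
        assert!(f1_macro < f1_micro);
    }

The gap between the two averages is the point of the new Average enum: Micro answers "how many individual predictions were right overall", while Macro answers "how well does the model do on each class, counting every class equally".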

0 commit comments
