1
1
/// Module providing various metrics used to rank the algorithms.
2
- use ndarray:: { Array2 , ArrayView1 , ArrayView2 } ;
2
+ use ndarray:: { Array2 , ArrayView1 } ;
3
3
use std:: collections:: { BTreeSet , HashMap } ;
4
4
5
/// Averaging strategy used when reducing per-class scores to a single value.
#[derive(PartialEq, Copy, Clone, Debug)]
pub enum Average {
    /// Sum TP/FP/FN across all classes first, then compute the score once.
    Micro,
    /// Compute the score per class, then take the unweighted mean.
    Macro,
    /// Positive-class score for two-class problems.
    /// NOTE(review): currently routed to the micro computation — valid when
    /// num_classes == 2 (a single metrics entry), presumably not intended
    /// for multiclass input; confirm with callers.
    Binary,
}
11
+
12
+ /// Confusion matrix metrics for a class.
5
13
#[ derive( Debug ) ]
6
14
pub struct ConfusionMatrixMetrics {
7
15
tp : f32 ,
@@ -13,6 +21,7 @@ pub struct ConfusionMatrixMetrics {
13
21
}
14
22
15
23
impl ConfusionMatrixMetrics {
24
+ /// Args: TP, FP, FN, TN.
16
25
pub fn new ( metrics : ( f32 , f32 , f32 , f32 ) ) -> ConfusionMatrixMetrics {
17
26
ConfusionMatrixMetrics {
18
27
tp : metrics. 0 ,
@@ -23,114 +32,146 @@ impl ConfusionMatrixMetrics {
23
32
}
24
33
}
25
34
26
- pub fn confusion_matrix (
27
- ground_truth : & ArrayView1 < usize > ,
28
- y_hat : & ArrayView1 < usize > ,
29
- num_classes : usize ,
30
- ) -> Array2 < f32 > {
31
- assert_eq ! ( ground_truth. len( ) , y_hat. len( ) ) ;
35
/// A confusion matrix plus the per-class metrics derived from it.
pub struct ConfusionMatrix {
    /// The confusion matrix in its raw form.
    matrix: Array2<f32>,

    /// Per-class metrics (TP/FP/FN/TN) calculated from the confusion
    /// matrix, indexed by class number.
    metrics: Vec<ConfusionMatrixMetrics>,
}
36
42
37
- assert_eq ! ( num_classes, classes. len( ) ) ;
43
+ impl ConfusionMatrix {
44
+ /// Construct a new confusion matrix from the ground truth
45
+ /// and the predictions.
46
+ /// `num_classes` is passed it to ensure that all classes
47
+ /// were present in the test set.
48
+ pub fn new (
49
+ ground_truth : & ArrayView1 < usize > ,
50
+ y_hat : & ArrayView1 < usize > ,
51
+ num_classes : usize ,
52
+ ) -> ConfusionMatrix {
53
+ assert_eq ! ( ground_truth. len( ) , y_hat. len( ) ) ;
54
+
55
+ // Distinct classes.
56
+ let mut classes = ground_truth. iter ( ) . collect :: < BTreeSet < _ > > ( ) ;
57
+ classes. extend ( & mut y_hat. iter ( ) . collect :: < BTreeSet < _ > > ( ) . into_iter ( ) ) ;
58
+
59
+ assert_eq ! ( num_classes, classes. len( ) ) ;
60
+
61
+ // Class value = index in the confusion matrix
62
+ // e.g. class value 5 will be index 4 if there are classes 1, 2, 3 and 4 present.
63
+ let indexes = classes
64
+ . iter ( )
65
+ . enumerate ( )
66
+ . map ( |( a, b) | ( * * b, a) )
67
+ . collect :: < HashMap < usize , usize > > ( ) ;
68
+
69
+ let mut matrix = Array2 :: zeros ( ( num_classes, num_classes) ) ;
70
+
71
+ for ( i, t) in ground_truth. iter ( ) . enumerate ( ) {
72
+ let h = y_hat[ i] ;
73
+
74
+ matrix[ ( indexes[ t] , indexes[ & h] ) ] += 1.0 ;
75
+ }
38
76
39
- // Class value = index in the confusion matrix
40
- // e.g. class value 5 will be index 4 if there are classes 1, 2, 3 and 4 present.
41
- let indexes = classes
42
- . iter ( )
43
- . enumerate ( )
44
- . map ( |( a, b) | ( * * b, a) )
45
- . collect :: < HashMap < usize , usize > > ( ) ;
77
+ let mut metrics = Vec :: new ( ) ;
46
78
47
- let mut matrix = Array2 :: zeros ( ( num_classes, num_classes) ) ;
79
+ // Scikit confusion matrix starts from 1 and goes to 0,
80
+ // ours starts from 0 and goes to 1. No big deal,
81
+ // just flip everything lol.
82
+ if num_classes == 2 {
83
+ let tp = matrix[ ( 1 , 1 ) ] ;
84
+ let fp = matrix[ ( 0 , 1 ) ] ;
85
+ let fn_ = matrix[ ( 1 , 0 ) ] ;
86
+ let tn = matrix[ ( 0 , 0 ) ] ;
48
87
49
- for ( i, t) in ground_truth. iter ( ) . enumerate ( ) {
50
- let h = y_hat[ i] ;
88
+ metrics. push ( ConfusionMatrixMetrics :: new ( ( tp, fp, fn_, tn) ) ) ;
89
+ } else {
90
+ for class in 0 ..num_classes {
91
+ let tp = matrix[ ( class, class) ] ;
92
+ let fp = matrix. row ( class) . sum ( ) - tp;
93
+ let fn_ = matrix. column ( class) . sum ( ) - tp;
94
+ let tn = matrix. sum ( ) - tp - fp - fn_;
95
+
96
+ metrics. push ( ConfusionMatrixMetrics :: new ( ( tp, fp, fn_, tn) ) ) ;
97
+ }
98
+ }
51
99
52
- matrix[ ( indexes [ t ] , indexes [ & h ] ) ] += 1.0 ;
100
+ ConfusionMatrix { matrix, metrics }
53
101
}
54
102
55
- matrix
56
- }
57
-
58
- /// Return macro-averaged recall for the confusion matrix.
59
- pub fn metrics ( matrix : & ArrayView2 < f32 > , num_classes : usize ) -> Vec < ConfusionMatrixMetrics > {
60
- let mut metrics = Vec :: new ( ) ;
61
-
62
- // Scikit confusion matrix starts from 1 and goes to 0,
63
- // ours starts from 0 and goes to 1. No big deal,
64
- // just flip everything lol.
65
- if num_classes == 2 {
66
- let tp = matrix[ ( 1 , 1 ) ] ;
67
- let fp = matrix[ ( 0 , 1 ) ] ;
68
- let fn_ = matrix[ ( 1 , 0 ) ] ;
69
- let tn = matrix[ ( 0 , 0 ) ] ;
70
-
71
- metrics. push ( ConfusionMatrixMetrics :: new ( ( tp, fp, fn_, tn) ) ) ;
72
- } else {
73
- for class in 0 ..num_classes {
74
- let tp = matrix[ ( class, class) ] ;
75
- let fp = matrix. row ( class) . sum ( ) - tp;
76
- let fn_ = matrix. column ( class) . sum ( ) - tp;
77
- let tn = matrix. sum ( ) - tp - fp - fn_;
103
+ pub fn accuracy ( & self ) -> f32 {
104
+ let numerator = self . matrix . diag ( ) . sum ( ) ;
105
+ let denominator = self . matrix . sum ( ) ;
78
106
79
- metrics. push ( ConfusionMatrixMetrics :: new ( ( tp, fp, fn_, tn) ) ) ;
80
- }
107
+ numerator / denominator
81
108
}
82
109
83
- metrics
84
- }
85
-
86
- /// Return macro-averaged recall.
87
- pub fn recall ( metrics : & Vec < ConfusionMatrixMetrics > ) -> f32 {
88
- let recalls = metrics
89
- . iter ( )
90
- . map ( |m| m. tp / ( m. tp + m. fn_ ) )
91
- . collect :: < Vec < f32 > > ( ) ;
110
+ /// Average recall.
111
+ pub fn recall ( & self ) -> f32 {
112
+ let recalls = self
113
+ . metrics
114
+ . iter ( )
115
+ . map ( |m| m. tp / ( m. tp + m. fn_ ) )
116
+ . collect :: < Vec < f32 > > ( ) ;
92
117
93
- recalls. iter ( ) . sum :: < f32 > ( ) / recalls. len ( ) as f32
94
- }
118
+ recalls. iter ( ) . sum :: < f32 > ( ) / recalls. len ( ) as f32
119
+ }
95
120
96
- pub fn precision ( metrics : & Vec < ConfusionMatrixMetrics > ) -> f32 {
97
- let precisions = metrics
98
- . iter ( )
99
- . map ( |m| m. tp / ( m. tp + m. fp ) )
100
- . collect :: < Vec < f32 > > ( ) ;
121
+ /// Average precision.
122
+ pub fn precision ( & self ) -> f32 {
123
+ let precisions = self
124
+ . metrics
125
+ . iter ( )
126
+ . map ( |m| m. tp / ( m. tp + m. fp ) )
127
+ . collect :: < Vec < f32 > > ( ) ;
101
128
102
- precisions. iter ( ) . sum :: < f32 > ( ) / precisions. len ( ) as f32
103
- }
129
+ precisions. iter ( ) . sum :: < f32 > ( ) / precisions. len ( ) as f32
130
+ }
104
131
105
- pub fn f1 ( metrics : & Vec < ConfusionMatrixMetrics > ) -> f32 {
106
- let recalls = metrics
107
- . iter ( )
108
- . map ( |m| m. tp / ( m. tp + m. fn_ ) )
109
- . collect :: < Vec < f32 > > ( ) ;
110
- let precisions = metrics
111
- . iter ( )
112
- . map ( |m| m. tp / ( m. tp + m. fp ) )
113
- . collect :: < Vec < f32 > > ( ) ;
114
-
115
- let mut f1s = Vec :: new ( ) ;
116
-
117
- for ( i, recall) in recalls. iter ( ) . enumerate ( ) {
118
- let precision = precisions[ i] ;
119
- f1s. push ( 2. * ( ( precision * recall) / ( precision + recall) ) ) ;
132
+ pub fn f1 ( & self , average : Average ) -> f32 {
133
+ match average {
134
+ Average :: Macro => self . f1_macro ( ) ,
135
+ Average :: Micro | Average :: Binary => self . f1_micro ( ) , // micro = binary if num_classes = 2
136
+ }
120
137
}
121
138
122
- f1s. iter ( ) . sum :: < f32 > ( ) / f1s. len ( ) as f32
123
- }
139
+ /// Calculate the f1 using micro metrics, i.e. the sum of predicates.
140
+ /// This evaluates the classifier as a whole instead of evaluating it as a sum of individual parts.
141
+ fn f1_micro ( & self ) -> f32 {
142
+ let tp = self . metrics . iter ( ) . map ( |m| m. tp ) . sum :: < f32 > ( ) ;
143
+ let fn_ = self . metrics . iter ( ) . map ( |m| m. fn_ ) . sum :: < f32 > ( ) ;
144
+ let fp = self . metrics . iter ( ) . map ( |m| m. fp ) . sum :: < f32 > ( ) ;
124
145
125
- pub fn f1_micro ( metrics : & Vec < ConfusionMatrixMetrics > ) -> f32 {
126
- let tp = metrics. iter ( ) . map ( |m| m. tp ) . sum :: < f32 > ( ) ;
127
- let fn_ = metrics. iter ( ) . map ( |m| m. fn_ ) . sum :: < f32 > ( ) ;
128
- let fp = metrics. iter ( ) . map ( |m| m. fp ) . sum :: < f32 > ( ) ;
146
+ let recall = tp / ( tp + fn_) ;
147
+ let precision = tp / ( tp + fp) ;
129
148
130
- let recall = tp / ( tp + fn_) ;
131
- let precision = tp / ( tp + fp) ;
149
+ 2. * ( ( precision * recall) / ( precision + recall) )
150
+ }
151
+
152
+ /// Calculate f1 using the average of class f1's.
153
+ /// This gives equal opportunity to each class to impact the overall score.
154
+ fn f1_macro ( & self ) -> f32 {
155
+ let recalls = self
156
+ . metrics
157
+ . iter ( )
158
+ . map ( |m| m. tp / ( m. tp + m. fn_ ) )
159
+ . collect :: < Vec < f32 > > ( ) ;
160
+ let precisions = self
161
+ . metrics
162
+ . iter ( )
163
+ . map ( |m| m. tp / ( m. tp + m. fp ) )
164
+ . collect :: < Vec < f32 > > ( ) ;
165
+
166
+ let mut f1s = Vec :: new ( ) ;
167
+
168
+ for ( i, recall) in recalls. iter ( ) . enumerate ( ) {
169
+ let precision = precisions[ i] ;
170
+ f1s. push ( 2. * ( ( precision * recall) / ( precision + recall) ) ) ;
171
+ }
132
172
133
- 2. * ( ( precision * recall) / ( precision + recall) )
173
+ f1s. iter ( ) . sum :: < f32 > ( ) / f1s. len ( ) as f32
174
+ }
134
175
}
135
176
136
177
#[ cfg( test) ]
@@ -143,15 +184,17 @@ mod test {
143
184
let ground_truth = array ! [ 1 , 2 , 3 , 4 , 4 ] ;
144
185
let y_hat = array ! [ 1 , 2 , 3 , 4 , 4 ] ;
145
186
146
- let mat = confusion_matrix (
187
+ let mat = ConfusionMatrix :: new (
147
188
& ArrayView1 :: from ( & ground_truth) ,
148
189
& ArrayView1 :: from ( & y_hat) ,
149
190
4 ,
150
191
) ;
151
- let metrics = metrics ( & ArrayView2 :: from ( & mat) , 4 ) ;
152
192
153
- assert_eq ! ( mat[ ( 3 , 3 ) ] , 2.0 ) ;
154
- assert_eq ! ( f1( & metrics) , 1.0 ) ;
155
- assert_eq ! ( f1_micro( & metrics) , 1.0 ) ;
193
+ let f1 = mat. f1 ( Average :: Macro ) ;
194
+ let f1_micro = mat. f1 ( Average :: Micro ) ;
195
+
196
+ assert_eq ! ( mat. matrix[ ( 3 , 3 ) ] , 2.0 ) ;
197
+ assert_eq ! ( f1, 1.0 ) ;
198
+ assert_eq ! ( f1_micro, 1.0 ) ;
156
199
}
157
200
}
0 commit comments