Skip to content

Commit 3807382

Browse files
authored
handle errors cleanly rather than crashing (#1010)
1 parent a60abcf commit 3807382

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

pgml-extension/src/metrics.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
/// Module providing various metrics used to rank the algorithms.
2+
use pgrx::*;
23
use std::collections::{BTreeSet, HashMap};
34

45
use ndarray::{Array2, ArrayView1};
@@ -51,13 +52,17 @@ impl ConfusionMatrix {
5152
y_hat: &ArrayView1<usize>,
5253
num_classes: usize,
5354
) -> ConfusionMatrix {
54-
assert_eq!(ground_truth.len(), y_hat.len());
55+
if ground_truth.len() != y_hat.len() {
56+
error!("Can't compute metrics when the ground truth labels are a different size than the predicted labels. {} != {}", ground_truth.len(), y_hat.len())
57+
};
5558

5659
// Distinct classes.
5760
let mut classes = ground_truth.iter().collect::<BTreeSet<_>>();
5861
classes.extend(&mut y_hat.iter().collect::<BTreeSet<_>>().into_iter());
5962

60-
assert_eq!(num_classes, classes.len());
63+
if num_classes != classes.len() {
64+
error!("Can't compute metrics when the number of classes in the test set is different than the number of classes in the training set. {} != {}", num_classes, classes.len())
65+
};
6166

6267
// Class value = index in the confusion matrix
6368
// e.g. class value 5 will be index 4 if there are classes 1, 2, 3 and 4 present.

0 commit comments

Comments
 (0)