Skip to content

Commit 231d4c4

Browse files
authored
train test split (#299)
1 parent ba979c9 commit 231d4c4

File tree

1 file changed

+16
-7
lines changed
  • pgml-extension/pgml_rust/src

1 file changed

+16
-7
lines changed

pgml-extension/pgml_rust/src/lib.rs

Lines changed: 16 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,7 @@ mod pgml_rust {
9494
),
9595
};
9696

97-
let (mut x, mut y, mut num_rows) = (vec![], vec![], 0);
97+
let (mut x, mut y, mut num_rows, mut num_features) = (vec![], vec![], 0, 0);
9898

9999
let hyperparams = hyperparams.0;
100100

@@ -131,7 +131,7 @@ mod pgml_rust {
131131
.into_iter()
132132
.map(|column| format!("CAST({} AS REAL)", column))
133133
.collect::<Vec<String>>();
134-
134+
135135
let query = format!(
136136
"SELECT {}, CAST({} AS REAL) FROM {} ORDER BY RANDOM()",
137137
features.clone().join(", "),
@@ -151,11 +151,22 @@ mod pgml_rust {
151151
num_rows += 1;
152152
});
153153

154+
num_features = features.len();
155+
154156
Ok(Some(()))
155157
});
156158

157-
let mut dtrain = DMatrix::from_dense(&x, num_rows).unwrap();
158-
dtrain.set_labels(&y).unwrap();
159+
// TODO: parameterize the test split fraction instead of hard-coding 0.5
160+
let test_rows = (num_rows as f32 * 0.5).round() as usize;
161+
let train_rows = num_rows - test_rows;
162+
let mut dtrain = DMatrix::from_dense(&x[..train_rows * num_features], train_rows).unwrap();
163+
let mut dtest = DMatrix::from_dense(&x[train_rows * num_features..], test_rows).unwrap();
164+
dtrain.set_labels(&y[..train_rows]).unwrap();
165+
dtest.set_labels(&y[train_rows..]).unwrap();
166+
167+
168+
// specify datasets to evaluate against during training
169+
let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")];
159170

160171
// configure objectives, metrics, etc.
161172
let learning_params = parameters::learning::LearningTaskParametersBuilder::default()
@@ -186,8 +197,6 @@ mod pgml_rust {
186197
.build()
187198
.unwrap();
188199

189-
// specify datasets to evaluate against during training
190-
// let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")];
191200

192201
// overall configuration for training/evaluation
193202
let params = parameters::TrainingParametersBuilder::default()
@@ -197,7 +206,7 @@ mod pgml_rust {
197206
None => 2,
198207
}) // number of training iterations
199208
.booster_params(booster_params) // model parameters
200-
// .evaluation_sets(Some(evaluation_sets)) // optional datasets to evaluate against in each iteration
209+
.evaluation_sets(Some(evaluation_sets)) // optional datasets to evaluate against in each iteration
201210
.build()
202211
.unwrap();
203212

0 commit comments

Comments
 (0)