Skip to content

train test split #299

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 2, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions pgml-extension/pgml_rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ mod pgml_rust {
),
};

let (mut x, mut y, mut num_rows) = (vec![], vec![], 0);
let (mut x, mut y, mut num_rows, mut num_features) = (vec![], vec![], 0, 0);

let hyperparams = hyperparams.0;

Expand Down Expand Up @@ -131,7 +131,7 @@ mod pgml_rust {
.into_iter()
.map(|column| format!("CAST({} AS REAL)", column))
.collect::<Vec<String>>();

let query = format!(
"SELECT {}, CAST({} AS REAL) FROM {} ORDER BY RANDOM()",
features.clone().join(", "),
Expand All @@ -151,11 +151,22 @@ mod pgml_rust {
num_rows += 1;
});

num_features = features.len();

Ok(Some(()))
});

let mut dtrain = DMatrix::from_dense(&x, num_rows).unwrap();
dtrain.set_labels(&y).unwrap();
// todo parameterize test split instead of 0.5
let test_rows = (num_rows as f32 * 0.5).round() as usize;
let train_rows = num_rows - test_rows;
let mut dtrain = DMatrix::from_dense(&x[..train_rows * num_features], train_rows).unwrap();
let mut dtest = DMatrix::from_dense(&x[train_rows * num_features..], test_rows).unwrap();
dtrain.set_labels(&y[..train_rows]).unwrap();
dtest.set_labels(&y[train_rows..]).unwrap();


// specify datasets to evaluate against during training
let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")];

// configure objectives, metrics, etc.
let learning_params = parameters::learning::LearningTaskParametersBuilder::default()
Expand Down Expand Up @@ -186,8 +197,6 @@ mod pgml_rust {
.build()
.unwrap();

// specify datasets to evaluate against during training
// let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")];

// overall configuration for training/evaluation
let params = parameters::TrainingParametersBuilder::default()
Expand All @@ -197,7 +206,7 @@ mod pgml_rust {
None => 2,
}) // number of training iterations
.booster_params(booster_params) // model parameters
// .evaluation_sets(Some(evaluation_sets)) // optional datasets to evaluate against in each iteration
.evaluation_sets(Some(evaluation_sets)) // optional datasets to evaluate against in each iteration
.build()
.unwrap();

Expand Down