@@ -94,7 +94,7 @@ mod pgml_rust {
94
94
) ,
95
95
} ;
96
96
97
- let ( mut x, mut y, mut num_rows) = ( vec ! [ ] , vec ! [ ] , 0 ) ;
97
+ let ( mut x, mut y, mut num_rows, mut num_features ) = ( vec ! [ ] , vec ! [ ] , 0 , 0 ) ;
98
98
99
99
let hyperparams = hyperparams. 0 ;
100
100
@@ -131,7 +131,7 @@ mod pgml_rust {
131
131
. into_iter ( )
132
132
. map ( |column| format ! ( "CAST({} AS REAL)" , column) )
133
133
. collect :: < Vec < String > > ( ) ;
134
-
134
+
135
135
let query = format ! (
136
136
"SELECT {}, CAST({} AS REAL) FROM {} ORDER BY RANDOM()" ,
137
137
features. clone( ) . join( ", " ) ,
@@ -151,11 +151,22 @@ mod pgml_rust {
151
151
num_rows += 1 ;
152
152
} ) ;
153
153
154
+ num_features = features. len ( ) ;
155
+
154
156
Ok ( Some ( ( ) ) )
155
157
} ) ;
156
158
157
- let mut dtrain = DMatrix :: from_dense ( & x, num_rows) . unwrap ( ) ;
158
- dtrain. set_labels ( & y) . unwrap ( ) ;
159
+ // todo parameterize test split instead of 0.5
160
+ let test_rows = ( num_rows as f32 * 0.5 ) . round ( ) as usize ;
161
+ let train_rows = num_rows - test_rows;
162
+ let mut dtrain = DMatrix :: from_dense ( & x[ ..train_rows * num_features] , train_rows) . unwrap ( ) ;
163
+ let mut dtest = DMatrix :: from_dense ( & x[ train_rows * num_features..] , test_rows) . unwrap ( ) ;
164
+ dtrain. set_labels ( & y[ ..train_rows] ) . unwrap ( ) ;
165
+ dtest. set_labels ( & y[ train_rows..] ) . unwrap ( ) ;
166
+
167
+
168
+ // specify datasets to evaluate against during training
169
+ let evaluation_sets = & [ ( & dtrain, "train" ) , ( & dtest, "test" ) ] ;
159
170
160
171
// configure objectives, metrics, etc.
161
172
let learning_params = parameters:: learning:: LearningTaskParametersBuilder :: default ( )
@@ -186,8 +197,6 @@ mod pgml_rust {
186
197
. build ( )
187
198
. unwrap ( ) ;
188
199
189
- // specify datasets to evaluate against during training
190
- // let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")];
191
200
192
201
// overall configuration for training/evaluation
193
202
let params = parameters:: TrainingParametersBuilder :: default ( )
@@ -197,7 +206,7 @@ mod pgml_rust {
197
206
None => 2 ,
198
207
} ) // number of training iterations
199
208
. booster_params ( booster_params) // model parameters
200
- // .evaluation_sets(Some(evaluation_sets)) // optional datasets to evaluate against in each iteration
209
+ . evaluation_sets ( Some ( evaluation_sets) ) // optional datasets to evaluate against in each iteration
201
210
. build ( )
202
211
. unwrap ( ) ;
203
212
0 commit comments