Skip to content

Commit 9303cb4

Browse files
authored
Fix bug that shape mismatch error in predict when changing objective to softmax and update rust-xgboost commit (#1636)
1 parent b6cd734 commit 9303cb4

File tree

8 files changed

+48
-25
lines changed

8 files changed

+48
-25
lines changed

pgml-extension/Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-extension/src/bindings/lightgbm.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ impl Bindings for Estimator {
100100
}
101101

102102
/// Deserialize self from bytes, with additional context
103-
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
103+
fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
104104
where
105105
Self: Sized,
106106
{

pgml-extension/src/bindings/linfa.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize};
88

99
use super::Bindings;
1010
use crate::orm::*;
11+
use pgrx::*;
1112

1213
#[derive(Debug, Serialize, Deserialize)]
1314
pub struct LinearRegression {
@@ -58,7 +59,7 @@ impl Bindings for LinearRegression {
5859
}
5960

6061
/// Deserialize self from bytes, with additional context
61-
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
62+
fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
6263
where
6364
Self: Sized,
6465
{
@@ -187,7 +188,7 @@ impl Bindings for LogisticRegression {
187188
}
188189

189190
/// Deserialize self from bytes, with additional context
190-
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
191+
fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
191192
where
192193
Self: Sized,
193194
{
@@ -261,7 +262,7 @@ impl Bindings for Svm {
261262
}
262263

263264
/// Deserialize self from bytes, with additional context
264-
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
265+
fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
265266
where
266267
Self: Sized,
267268
{

pgml-extension/src/bindings/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ pub trait Bindings: Send + Sync + Debug + AToAny {
106106
fn to_bytes(&self) -> Result<Vec<u8>>;
107107

108108
/// Deserialize self from bytes, with additional context
109-
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
109+
fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
110110
where
111111
Self: Sized;
112112
}

pgml-extension/src/bindings/sklearn/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ impl Bindings for Estimator {
197197
}
198198

199199
/// Deserialize self from bytes, with additional context
200-
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
200+
fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
201201
where
202202
Self: Sized,
203203
{

pgml-extension/src/bindings/xgboost.rs

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,10 +288,18 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, objective: learning::Object
288288
Err(e) => error!("Failed to train model:\n\n{}", e),
289289
};
290290

291-
Ok(Box::new(Estimator { estimator: booster }))
291+
let softmax_objective = match hyperparams.get("objective") {
292+
Some(value) => match value.as_str().unwrap() {
293+
"multi:softmax" => true,
294+
_ => false,
295+
},
296+
None => false,
297+
};
298+
Ok(Box::new(Estimator { softmax_objective, estimator: booster }))
292299
}
293300

294301
pub struct Estimator {
302+
softmax_objective: bool,
295303
estimator: xgboost::Booster,
296304
}
297305

@@ -308,6 +316,9 @@ impl Bindings for Estimator {
308316
fn predict(&self, features: &[f32], num_features: usize, num_classes: usize) -> Result<Vec<f32>> {
309317
let x = DMatrix::from_dense(features, features.len() / num_features)?;
310318
let y = self.estimator.predict(&x)?;
319+
if self.softmax_objective {
320+
return Ok(y);
321+
}
311322
Ok(match num_classes {
312323
0 => y,
313324
_ => y
@@ -340,7 +351,7 @@ impl Bindings for Estimator {
340351
}
341352

342353
/// Deserialize self from bytes, with additional context
343-
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
354+
fn from_bytes(bytes: &[u8], hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
344355
where
345356
Self: Sized,
346357
{
@@ -366,6 +377,12 @@ impl Bindings for Estimator {
366377
.set_param("nthread", &concurrency.to_string())
367378
.map_err(|e| anyhow!("could not set nthread XGBoost parameter: {e}"))?;
368379

369-
Ok(Box::new(Estimator { estimator }))
380+
let objective_opt = hyperparams.0.get("objective").and_then(|v| v.as_str());
381+
let softmax_objective = match objective_opt {
382+
Some("multi:softmax") => true,
383+
_ => false,
384+
};
385+
386+
Ok(Box::new(Estimator { softmax_objective, estimator }))
370387
}
371388
}

pgml-extension/src/orm/file.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
3131
let mut runtime: Option<String> = None;
3232
let mut algorithm: Option<String> = None;
3333
let mut task: Option<String> = None;
34+
let mut hyperparams: Option<JsonB> = None;
3435

3536
Spi::connect(|client| {
3637
let result = client
@@ -39,7 +40,8 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
3940
data,
4041
runtime::TEXT,
4142
algorithm::TEXT,
42-
task::TEXT
43+
task::TEXT,
44+
hyperparams
4345
FROM pgml.models
4446
INNER JOIN pgml.files
4547
ON models.id = files.model_id
@@ -66,6 +68,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
6668
runtime = result.get(2).expect("Runtime for model is corrupted.");
6769
algorithm = result.get(3).expect("Algorithm for model is corrupted.");
6870
task = result.get(4).expect("Task for project is corrupted.");
71+
hyperparams = result.get(5).expect("Hyperparams for model is corrupted.");
6972
}
7073
});
7174

@@ -83,6 +86,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
8386
let runtime = Runtime::from_str(&runtime.unwrap()).unwrap();
8487
let algorithm = Algorithm::from_str(&algorithm.unwrap()).unwrap();
8588
let task = Task::from_str(&task.unwrap()).unwrap();
89+
let hyperparams = hyperparams.unwrap();
8690

8791
debug1!(
8892
"runtime = {:?}, algorithm = {:?}, task = {:?}",
@@ -94,22 +98,22 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
9498
let bindings: Box<dyn Bindings> = match runtime {
9599
Runtime::rust => {
96100
match algorithm {
97-
Algorithm::xgboost => crate::bindings::xgboost::Estimator::from_bytes(&data)?,
98-
Algorithm::lightgbm => crate::bindings::lightgbm::Estimator::from_bytes(&data)?,
101+
Algorithm::xgboost => crate::bindings::xgboost::Estimator::from_bytes(&data, &hyperparams)?,
102+
Algorithm::lightgbm => crate::bindings::lightgbm::Estimator::from_bytes(&data, &hyperparams)?,
99103
Algorithm::linear => match task {
100-
Task::regression => crate::bindings::linfa::LinearRegression::from_bytes(&data)?,
104+
Task::regression => crate::bindings::linfa::LinearRegression::from_bytes(&data, &hyperparams)?,
101105
Task::classification => {
102-
crate::bindings::linfa::LogisticRegression::from_bytes(&data)?
106+
crate::bindings::linfa::LogisticRegression::from_bytes(&data, &hyperparams)?
103107
}
104108
_ => error!("Rust runtime only supports `classification` and `regression` task types for linear algorithms."),
105109
},
106-
Algorithm::svm => crate::bindings::linfa::Svm::from_bytes(&data)?,
110+
Algorithm::svm => crate::bindings::linfa::Svm::from_bytes(&data, &hyperparams)?,
107111
_ => todo!(), //smartcore_load(&data, task, algorithm, &hyperparams),
108112
}
109113
}
110114

111115
#[cfg(feature = "python")]
112-
Runtime::python => crate::bindings::sklearn::Estimator::from_bytes(&data)?,
116+
Runtime::python => crate::bindings::sklearn::Estimator::from_bytes(&data, &hyperparams)?,
113117

114118
#[cfg(not(feature = "python"))]
115119
Runtime::python => {

pgml-extension/src/orm/model.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ impl Model {
360360
)
361361
.unwrap()
362362
.unwrap();
363+
let hyperparams = result.get(11).unwrap().unwrap();
363364

364365
let bindings: Box<dyn Bindings> = match runtime {
365366
Runtime::openai => {
@@ -369,27 +370,27 @@ impl Model {
369370
Runtime::rust => {
370371
match algorithm {
371372
Algorithm::xgboost => {
372-
xgboost::Estimator::from_bytes(&data)?
373+
xgboost::Estimator::from_bytes(&data, &hyperparams)?
373374
}
374375
Algorithm::lightgbm => {
375-
lightgbm::Estimator::from_bytes(&data)?
376+
lightgbm::Estimator::from_bytes(&data, &hyperparams)?
376377
}
377378
Algorithm::linear => match project.task {
378379
Task::regression => {
379-
linfa::LinearRegression::from_bytes(&data)?
380+
linfa::LinearRegression::from_bytes(&data, &hyperparams)?
380381
}
381382
Task::classification => {
382-
linfa::LogisticRegression::from_bytes(&data)?
383+
linfa::LogisticRegression::from_bytes(&data, &hyperparams)?
383384
}
384385
_ => bail!("No default runtime available for tasks other than `classification` and `regression` when using a linear algorithm."),
385386
},
386-
Algorithm::svm => linfa::Svm::from_bytes(&data)?,
387+
Algorithm::svm => linfa::Svm::from_bytes(&data, &hyperparams)?,
387388
_ => todo!(), //smartcore_load(&data, task, algorithm, &hyperparams),
388389
}
389390
}
390391

391392
#[cfg(feature = "python")]
392-
Runtime::python => sklearn::Estimator::from_bytes(&data)?,
393+
Runtime::python => sklearn::Estimator::from_bytes(&data, &hyperparams)?,
393394

394395
#[cfg(not(feature = "python"))]
395396
Runtime::python => {
@@ -409,7 +410,7 @@ impl Model {
409410
snapshot_id,
410411
algorithm,
411412
runtime,
412-
hyperparams: result.get(6).unwrap().unwrap(),
413+
hyperparams: hyperparams,
413414
status: Status::from_str(result.get(7).unwrap().unwrap()).unwrap(),
414415
metrics: result.get(8).unwrap(),
415416
search: result.get(9).unwrap().map(|search| Search::from_str(search).unwrap()),

0 commit comments

Comments
 (0)