Skip to content

Fix bug that shape mismatch error in predict when changing objective to softmax and update rust-xgboost commit #1636

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pgml-extension/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pgml-extension/src/bindings/lightgbm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl Bindings for Estimator {
}

/// Deserialize self from bytes, with additional context
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
where
Self: Sized,
{
Expand Down
7 changes: 4 additions & 3 deletions pgml-extension/src/bindings/linfa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize};

use super::Bindings;
use crate::orm::*;
use pgrx::*;

#[derive(Debug, Serialize, Deserialize)]
pub struct LinearRegression {
Expand Down Expand Up @@ -58,7 +59,7 @@ impl Bindings for LinearRegression {
}

/// Deserialize self from bytes, with additional context
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
where
Self: Sized,
{
Expand Down Expand Up @@ -187,7 +188,7 @@ impl Bindings for LogisticRegression {
}

/// Deserialize self from bytes, with additional context
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
where
Self: Sized,
{
Expand Down Expand Up @@ -261,7 +262,7 @@ impl Bindings for Svm {
}

/// Deserialize self from bytes, with additional context
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
where
Self: Sized,
{
Expand Down
2 changes: 1 addition & 1 deletion pgml-extension/src/bindings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ pub trait Bindings: Send + Sync + Debug + AToAny {
fn to_bytes(&self) -> Result<Vec<u8>>;

/// Deserialize self from bytes, with additional context
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
where
Self: Sized;
}
Expand Down
2 changes: 1 addition & 1 deletion pgml-extension/src/bindings/sklearn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ impl Bindings for Estimator {
}

/// Deserialize self from bytes, with additional context
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
fn from_bytes(bytes: &[u8], _hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
where
Self: Sized,
{
Expand Down
23 changes: 20 additions & 3 deletions pgml-extension/src/bindings/xgboost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,18 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, objective: learning::Object
Err(e) => error!("Failed to train model:\n\n{}", e),
};

Ok(Box::new(Estimator { estimator: booster }))
let softmax_objective = match hyperparams.get("objective") {
Some(value) => match value.as_str().unwrap() {
"multi:softmax" => true,
_ => false,
},
None => false,
};
Ok(Box::new(Estimator { softmax_objective, estimator: booster }))
}

pub struct Estimator {
softmax_objective: bool,
estimator: xgboost::Booster,
}

Expand All @@ -308,6 +316,9 @@ impl Bindings for Estimator {
fn predict(&self, features: &[f32], num_features: usize, num_classes: usize) -> Result<Vec<f32>> {
let x = DMatrix::from_dense(features, features.len() / num_features)?;
let y = self.estimator.predict(&x)?;
if self.softmax_objective {
return Ok(y);
}
Ok(match num_classes {
0 => y,
_ => y
Expand Down Expand Up @@ -340,7 +351,7 @@ impl Bindings for Estimator {
}

/// Deserialize self from bytes, with additional context
fn from_bytes(bytes: &[u8]) -> Result<Box<dyn Bindings>>
fn from_bytes(bytes: &[u8], hyperparams: &JsonB) -> Result<Box<dyn Bindings>>
where
Self: Sized,
{
Expand All @@ -366,6 +377,12 @@ impl Bindings for Estimator {
.set_param("nthread", &concurrency.to_string())
.map_err(|e| anyhow!("could not set nthread XGBoost parameter: {e}"))?;

Ok(Box::new(Estimator { estimator }))
let objective_opt = hyperparams.0.get("objective").and_then(|v| v.as_str());
let softmax_objective = match objective_opt {
Some("multi:softmax") => true,
_ => false,
};

Ok(Box::new(Estimator { softmax_objective, estimator }))
}
}
18 changes: 11 additions & 7 deletions pgml-extension/src/orm/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
let mut runtime: Option<String> = None;
let mut algorithm: Option<String> = None;
let mut task: Option<String> = None;
let mut hyperparams: Option<JsonB> = None;

Spi::connect(|client| {
let result = client
Expand All @@ -39,7 +40,8 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
data,
runtime::TEXT,
algorithm::TEXT,
task::TEXT
task::TEXT,
hyperparams
FROM pgml.models
INNER JOIN pgml.files
ON models.id = files.model_id
Expand All @@ -66,6 +68,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
runtime = result.get(2).expect("Runtime for model is corrupted.");
algorithm = result.get(3).expect("Algorithm for model is corrupted.");
task = result.get(4).expect("Task for project is corrupted.");
hyperparams = result.get(5).expect("Hyperparams for model is corrupted.");
}
});

Expand All @@ -83,6 +86,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
let runtime = Runtime::from_str(&runtime.unwrap()).unwrap();
let algorithm = Algorithm::from_str(&algorithm.unwrap()).unwrap();
let task = Task::from_str(&task.unwrap()).unwrap();
let hyperparams = hyperparams.unwrap();

debug1!(
"runtime = {:?}, algorithm = {:?}, task = {:?}",
Expand All @@ -94,22 +98,22 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
let bindings: Box<dyn Bindings> = match runtime {
Runtime::rust => {
match algorithm {
Algorithm::xgboost => crate::bindings::xgboost::Estimator::from_bytes(&data)?,
Algorithm::lightgbm => crate::bindings::lightgbm::Estimator::from_bytes(&data)?,
Algorithm::xgboost => crate::bindings::xgboost::Estimator::from_bytes(&data, &hyperparams)?,
Algorithm::lightgbm => crate::bindings::lightgbm::Estimator::from_bytes(&data, &hyperparams)?,
Algorithm::linear => match task {
Task::regression => crate::bindings::linfa::LinearRegression::from_bytes(&data)?,
Task::regression => crate::bindings::linfa::LinearRegression::from_bytes(&data, &hyperparams)?,
Task::classification => {
crate::bindings::linfa::LogisticRegression::from_bytes(&data)?
crate::bindings::linfa::LogisticRegression::from_bytes(&data, &hyperparams)?
}
_ => error!("Rust runtime only supports `classification` and `regression` task types for linear algorithms."),
},
Algorithm::svm => crate::bindings::linfa::Svm::from_bytes(&data)?,
Algorithm::svm => crate::bindings::linfa::Svm::from_bytes(&data, &hyperparams)?,
_ => todo!(), //smartcore_load(&data, task, algorithm, &hyperparams),
}
}

#[cfg(feature = "python")]
Runtime::python => crate::bindings::sklearn::Estimator::from_bytes(&data)?,
Runtime::python => crate::bindings::sklearn::Estimator::from_bytes(&data, &hyperparams)?,

#[cfg(not(feature = "python"))]
Runtime::python => {
Expand Down
15 changes: 8 additions & 7 deletions pgml-extension/src/orm/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ impl Model {
)
.unwrap()
.unwrap();
let hyperparams = result.get(11).unwrap().unwrap();

let bindings: Box<dyn Bindings> = match runtime {
Runtime::openai => {
Expand All @@ -369,27 +370,27 @@ impl Model {
Runtime::rust => {
match algorithm {
Algorithm::xgboost => {
xgboost::Estimator::from_bytes(&data)?
xgboost::Estimator::from_bytes(&data, &hyperparams)?
}
Algorithm::lightgbm => {
lightgbm::Estimator::from_bytes(&data)?
lightgbm::Estimator::from_bytes(&data, &hyperparams)?
}
Algorithm::linear => match project.task {
Task::regression => {
linfa::LinearRegression::from_bytes(&data)?
linfa::LinearRegression::from_bytes(&data, &hyperparams)?
}
Task::classification => {
linfa::LogisticRegression::from_bytes(&data)?
linfa::LogisticRegression::from_bytes(&data, &hyperparams)?
}
_ => bail!("No default runtime available for tasks other than `classification` and `regression` when using a linear algorithm."),
},
Algorithm::svm => linfa::Svm::from_bytes(&data)?,
Algorithm::svm => linfa::Svm::from_bytes(&data, &hyperparams)?,
_ => todo!(), //smartcore_load(&data, task, algorithm, &hyperparams),
}
}

#[cfg(feature = "python")]
Runtime::python => sklearn::Estimator::from_bytes(&data)?,
Runtime::python => sklearn::Estimator::from_bytes(&data, &hyperparams)?,

#[cfg(not(feature = "python"))]
Runtime::python => {
Expand All @@ -409,7 +410,7 @@ impl Model {
snapshot_id,
algorithm,
runtime,
hyperparams: result.get(6).unwrap().unwrap(),
hyperparams: hyperparams,
status: Status::from_str(result.get(7).unwrap().unwrap()).unwrap(),
metrics: result.get(8).unwrap(),
search: result.get(9).unwrap().map(|search| Search::from_str(search).unwrap()),
Expand Down