Skip to content

Commit 79f9833

Browse files
authored
organize python related modules (#962)
1 parent 6bdcf00 commit 79f9833

File tree

14 files changed: 127 additions and 120 deletions.

pgml-extension/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@ sacremoses==0.0.53
1717
scikit-learn==1.3.0
1818
sentencepiece==0.1.99
1919
sentence-transformers==2.2.2
20+
tokenizers==0.13.3
2021
torch==2.0.1
2122
torchaudio==2.0.2
2223
torchvision==0.15.2

pgml-extension/src/api.rs

Lines changed: 14 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -6,11 +6,9 @@ use pgrx::iter::{SetOfIterator, TableIterator};
66
use pgrx::*;
77

88
#[cfg(feature = "python")]
9-
use pyo3::prelude::*;
109
use serde_json::json;
1110

1211
#[cfg(feature = "python")]
13-
use crate::bindings::sklearn::package_version;
1412
use crate::orm::*;
1513

1614
macro_rules! unwrap_or_error {
@@ -25,38 +23,13 @@ macro_rules! unwrap_or_error {
2523
#[cfg(feature = "python")]
2624
#[pg_extern]
2725
pub fn activate_venv(venv: &str) -> bool {
28-
unwrap_or_error!(crate::bindings::venv::activate_venv(venv))
26+
unwrap_or_error!(crate::bindings::python::activate_venv(venv))
2927
}
3028

3129
#[cfg(feature = "python")]
3230
#[pg_extern(immutable, parallel_safe)]
3331
pub fn validate_python_dependencies() -> bool {
34-
unwrap_or_error!(crate::bindings::venv::activate());
35-
36-
Python::with_gil(|py| {
37-
let sys = PyModule::import(py, "sys").unwrap();
38-
let version: String = sys.getattr("version").unwrap().extract().unwrap();
39-
info!("Python version: {version}");
40-
for module in ["xgboost", "lightgbm", "numpy", "sklearn"] {
41-
match py.import(module) {
42-
Ok(_) => (),
43-
Err(e) => {
44-
panic!(
45-
"The {module} package is missing. Install it with `sudo pip3 install {module}`\n{e}"
46-
);
47-
}
48-
}
49-
}
50-
});
51-
52-
let sklearn = unwrap_or_error!(package_version("sklearn"));
53-
let xgboost = unwrap_or_error!(package_version("xgboost"));
54-
let lightgbm = unwrap_or_error!(package_version("lightgbm"));
55-
let numpy = unwrap_or_error!(package_version("numpy"));
56-
57-
info!("Scikit-learn {sklearn}, XGBoost {xgboost}, LightGBM {lightgbm}, NumPy {numpy}",);
58-
59-
true
32+
unwrap_or_error!(crate::bindings::python::validate_dependencies())
6033
}
6134

6235
#[cfg(not(feature = "python"))]
@@ -66,8 +39,7 @@ pub fn validate_python_dependencies() {}
6639
#[cfg(feature = "python")]
6740
#[pg_extern]
6841
pub fn python_package_version(name: &str) -> String {
69-
unwrap_or_error!(crate::bindings::venv::activate());
70-
unwrap_or_error!(package_version(name))
42+
unwrap_or_error!(crate::bindings::python::package_version(name))
7143
}
7244

7345
#[cfg(not(feature = "python"))]
@@ -79,13 +51,19 @@ pub fn python_package_version(name: &str) {
7951
#[cfg(feature = "python")]
8052
#[pg_extern]
8153
pub fn python_pip_freeze() -> TableIterator<'static, (name!(package, String),)> {
82-
unwrap_or_error!(crate::bindings::venv::activate());
54+
unwrap_or_error!(crate::bindings::python::pip_freeze())
55+
}
8356

84-
let packages = unwrap_or_error!(crate::bindings::venv::freeze())
85-
.into_iter()
86-
.map(|package| (package,));
57+
#[cfg(feature = "python")]
58+
#[pg_extern]
59+
fn python_version() -> String {
60+
unwrap_or_error!(crate::bindings::python::version())
61+
}
8762

88-
TableIterator::new(packages)
63+
#[cfg(not(feature = "python"))]
64+
#[pg_extern]
65+
pub fn python_version() -> String {
66+
String::from("Python is not installed, recompile with `--features python`")
8967
}
9068

9169
#[pg_extern]
@@ -104,26 +82,6 @@ pub fn validate_shared_library() {
10482
}
10583
}
10684

107-
#[cfg(feature = "python")]
108-
#[pg_extern]
109-
fn python_version() -> String {
110-
unwrap_or_error!(crate::bindings::venv::activate());
111-
let mut version = String::new();
112-
113-
Python::with_gil(|py| {
114-
let sys = PyModule::import(py, "sys").unwrap();
115-
version = sys.getattr("version").unwrap().extract().unwrap();
116-
});
117-
118-
version
119-
}
120-
121-
#[cfg(not(feature = "python"))]
122-
#[pg_extern]
123-
pub fn python_version() -> String {
124-
String::from("Python is not installed, recompile with `--features python`")
125-
}
126-
12785
#[pg_extern(immutable, parallel_safe)]
12886
fn version() -> String {
12987
crate::VERSION.to_string()

pgml-extension/src/bindings/langchain.rs renamed to pgml-extension/src/bindings/langchain/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,10 +6,10 @@ use pyo3::types::PyTuple;
66

77
use crate::{bindings::TracebackError, create_pymodule};
88

9-
create_pymodule!("/src/bindings/langchain.py");
9+
create_pymodule!("/src/bindings/langchain/langchain.py");
1010

1111
pub fn chunk(splitter: &str, text: &str, kwargs: &serde_json::Value) -> Result<Vec<String>> {
12-
crate::bindings::venv::activate()?;
12+
crate::bindings::python::activate()?;
1313

1414
let kwargs = serde_json::to_string(kwargs).unwrap();
1515

pgml-extension/src/bindings/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -38,11 +38,11 @@ pub mod langchain;
3838
pub mod lightgbm;
3939
pub mod linfa;
4040
#[cfg(feature = "python")]
41+
pub mod python;
42+
#[cfg(feature = "python")]
4143
pub mod sklearn;
4244
#[cfg(feature = "python")]
4345
pub mod transformers;
44-
#[cfg(feature = "python")]
45-
pub mod venv;
4646
pub mod xgboost;
4747

4848
pub type Fit = fn(dataset: &Dataset, hyperparams: &Hyperparams) -> Result<Box<dyn Bindings>>;
pgml-extension/src/bindings/python/mod.rs

Lines changed: 91 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,91 @@
1+
//! Use virtualenv.
2+
3+
use anyhow::Result;
4+
use once_cell::sync::Lazy;
5+
use pgrx::iter::TableIterator;
6+
use pgrx::*;
7+
use pyo3::prelude::*;
8+
use pyo3::types::PyTuple;
9+
10+
use crate::config::get_config;
11+
use crate::{bindings::TracebackError, create_pymodule};
12+
13+
static CONFIG_NAME: &str = "pgml.venv";
14+
15+
create_pymodule!("/src/bindings/python/python.py");
16+
17+
pub fn activate_venv(venv: &str) -> Result<bool> {
18+
Python::with_gil(|py| {
19+
let activate_venv: Py<PyAny> = get_module!(PY_MODULE).getattr(py, "activate_venv")?;
20+
let result: Py<PyAny> =
21+
activate_venv.call1(py, PyTuple::new(py, &[venv.to_string().into_py(py)]))?;
22+
23+
Ok(result.extract(py)?)
24+
})
25+
}
26+
27+
pub fn activate() -> Result<bool> {
28+
match get_config(CONFIG_NAME) {
29+
Some(venv) => activate_venv(&venv),
30+
None => Ok(false),
31+
}
32+
}
33+
34+
pub fn pip_freeze() -> Result<TableIterator<'static, (name!(package, String),)>> {
35+
activate()?;
36+
let packages = Python::with_gil(|py| -> Result<Vec<String>> {
37+
let freeze = get_module!(PY_MODULE).getattr(py, "freeze")?;
38+
let result = freeze.call0(py)?;
39+
40+
Ok(result.extract(py)?)
41+
})?;
42+
43+
Ok(TableIterator::new(
44+
packages.into_iter().map(|package| (package,)),
45+
))
46+
}
47+
48+
pub fn validate_dependencies() -> Result<bool> {
49+
activate()?;
50+
Python::with_gil(|py| {
51+
let sys = PyModule::import(py, "sys").unwrap();
52+
let version: String = sys.getattr("version").unwrap().extract().unwrap();
53+
info!("Python version: {version}");
54+
for module in ["xgboost", "lightgbm", "numpy", "sklearn"] {
55+
match py.import(module) {
56+
Ok(_) => (),
57+
Err(e) => {
58+
panic!(
59+
"The {module} package is missing. Install it with `sudo pip3 install {module}`\n{e}"
60+
);
61+
}
62+
}
63+
}
64+
});
65+
66+
let sklearn = package_version("sklearn")?;
67+
let xgboost = package_version("xgboost")?;
68+
let lightgbm = package_version("lightgbm")?;
69+
let numpy = package_version("numpy")?;
70+
71+
info!("Scikit-learn {sklearn}, XGBoost {xgboost}, LightGBM {lightgbm}, NumPy {numpy}",);
72+
73+
Ok(true)
74+
}
75+
76+
pub fn version() -> Result<String> {
77+
activate()?;
78+
Python::with_gil(|py| {
79+
let sys = PyModule::import(py, "sys").unwrap();
80+
let version: String = sys.getattr("version").unwrap().extract().unwrap();
81+
Ok(version)
82+
})
83+
}
84+
85+
pub fn package_version(name: &str) -> Result<String> {
86+
activate()?;
87+
Python::with_gil(|py| {
88+
let package = py.import(name)?;
89+
Ok(package.getattr("__version__")?.extract()?)
90+
})
91+
}

pgml-extension/src/bindings/sklearn.rs renamed to pgml-extension/src/bindings/sklearn/mod.rs

Lines changed: 6 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -15,11 +15,13 @@ use once_cell::sync::Lazy;
1515
use pyo3::prelude::*;
1616
use pyo3::types::PyTuple;
1717

18-
use crate::bindings::Bindings;
18+
use crate::{
19+
bindings::{Bindings, TracebackError},
20+
create_pymodule,
21+
orm::*,
22+
};
1923

20-
use crate::{bindings::TracebackError, create_pymodule, orm::*};
21-
22-
create_pymodule!("/src/bindings/sklearn.py");
24+
create_pymodule!("/src/bindings/sklearn/sklearn.py");
2325

2426
macro_rules! wrap_fit {
2527
($fn_name:tt, $task:literal) => {
@@ -355,10 +357,3 @@ pub fn cluster_metrics(
355357
Ok(scores)
356358
})
357359
}
358-
359-
pub fn package_version(name: &str) -> Result<String> {
360-
Python::with_gil(|py| {
361-
let package = py.import(name)?;
362-
Ok(package.getattr("__version__")?.extract()?)
363-
})
364-
}

pgml-extension/src/bindings/transformers/mod.rs

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@ pub fn transform(
2424
args: &serde_json::Value,
2525
inputs: Vec<&str>,
2626
) -> Result<serde_json::Value> {
27-
crate::bindings::venv::activate()?;
27+
crate::bindings::python::activate()?;
2828

2929
whitelist::verify_task(task)?;
3030

@@ -70,7 +70,7 @@ pub fn embed(
7070
inputs: Vec<&str>,
7171
kwargs: &serde_json::Value,
7272
) -> Result<Vec<Vec<f32>>> {
73-
crate::bindings::venv::activate()?;
73+
crate::bindings::python::activate()?;
7474

7575
let kwargs = serde_json::to_string(kwargs)?;
7676
Python::with_gil(|py| -> Result<Vec<Vec<f32>>> {
@@ -101,7 +101,7 @@ pub fn tune(
101101
hyperparams: &JsonB,
102102
path: &Path,
103103
) -> Result<HashMap<String, f64>> {
104-
crate::bindings::venv::activate()?;
104+
crate::bindings::python::activate()?;
105105

106106
let task = task.to_string();
107107
let hyperparams = serde_json::to_string(&hyperparams.0)?;
@@ -131,7 +131,7 @@ pub fn tune(
131131
}
132132

133133
pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<String>> {
134-
crate::bindings::venv::activate()?;
134+
crate::bindings::python::activate()?;
135135

136136
Python::with_gil(|py| -> Result<Vec<String>> {
137137
let generate = get_module!(PY_MODULE)
@@ -219,7 +219,7 @@ pub fn load_dataset(
219219
limit: Option<usize>,
220220
kwargs: &serde_json::Value,
221221
) -> Result<usize> {
222-
crate::bindings::venv::activate()?;
222+
crate::bindings::python::activate()?;
223223

224224
let kwargs = serde_json::to_string(kwargs)?;
225225

@@ -376,7 +376,7 @@ pub fn load_dataset(
376376
}
377377

378378
pub fn clear_gpu_cache(memory_usage: Option<f32>) -> Result<bool> {
379-
crate::bindings::venv::activate().unwrap();
379+
crate::bindings::python::activate().unwrap();
380380

381381
Python::with_gil(|py| -> Result<bool> {
382382
let clear_gpu_cache: Py<PyAny> = get_module!(PY_MODULE)

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,8 @@
3434
DataCollatorWithPadding,
3535
DefaultDataCollator,
3636
GenerationConfig,
37+
PegasusForConditionalGeneration,
38+
PegasusTokenizer,
3739
TrainingArguments,
3840
Trainer,
3941
)

pgml-extension/src/bindings/venv.rs

Lines changed: 0 additions & 40 deletions
This file was deleted.

pgml-extension/src/orm/model.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,7 @@ impl Model {
8989
};
9090

9191
if runtime == Runtime::python {
92-
crate::bindings::venv::activate().unwrap();
92+
crate::bindings::python::activate().unwrap();
9393
}
9494

9595
let dataset = snapshot.tabular_dataset();

0 commit comments

Comments
 (0)