Skip to content

Commit 4612b4f

Browse files
committed
Refactor the initialization of GUC parameters.
Managing GUC parameters in different places is hard to maintain. This patch organizes GUC definitions in a single place. Also, we use define_xxx_guc() APIs to define these parameters and it will allow us to manage GucContext, GucFlags in future. P.S., the test case test_trusted_model doesn't seem correct. I fixed it in this patch.
1 parent 0842673 commit 4612b4f

File tree

6 files changed

+88
-44
lines changed

6 files changed

+88
-44
lines changed

pgml-extension/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-extension/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ serde = { version = "1.0" }
4949
serde_json = { version = "1.0", features = ["preserve_order"] }
5050
typetag = "0.2"
5151
xgboost = { git = "https://github.com/postgresml/rust-xgboost", branch = "master" }
52+
lazy_static = "1.4.0"
5253

5354
[dev-dependencies]
5455
pgrx-tests = "=0.11.2"

pgml-extension/src/bindings/python/mod.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@ use pgrx::*;
66
use pyo3::prelude::*;
77
use pyo3::types::PyTuple;
88

9-
use crate::config::get_config;
9+
use crate::config::PGML_VENV;
1010
use crate::create_pymodule;
1111

12-
static CONFIG_NAME: &str = "pgml.venv";
13-
1412
create_pymodule!("/src/bindings/python/python.py");
1513

1614
pub fn activate_venv(venv: &str) -> Result<bool> {
@@ -23,8 +21,8 @@ pub fn activate_venv(venv: &str) -> Result<bool> {
2321
}
2422

2523
pub fn activate() -> Result<bool> {
26-
match get_config(CONFIG_NAME) {
27-
Some(venv) => activate_venv(&venv),
24+
match PGML_VENV.1.get() {
25+
Some(venv) => activate_venv(&venv.to_string_lossy()),
2826
None => Ok(false),
2927
}
3028
}

pgml-extension/src/bindings/transformers/whitelist.rs

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,54 @@
11
use anyhow::{bail, Error};
2+
use pgrx::GucSetting;
23
#[cfg(any(test, feature = "pg_test"))]
34
use pgrx::{pg_schema, pg_test};
45
use serde_json::Value;
6+
use std::ffi::CStr;
57

6-
use crate::config::get_config;
7-
8-
static CONFIG_HF_WHITELIST: &str = "pgml.huggingface_whitelist";
9-
static CONFIG_HF_TRUST_REMOTE_CODE_BOOL: &str = "pgml.huggingface_trust_remote_code";
10-
static CONFIG_HF_TRUST_WHITELIST: &str = "pgml.huggingface_trust_remote_code_whitelist";
8+
use crate::config::{PGML_HF_TRUST_REMOTE_CODE, PGML_HF_TRUST_WHITELIST, PGML_HF_WHITELIST};
119

1210
/// Verify that the model in the task JSON is allowed based on the huggingface whitelists.
1311
pub fn verify_task(task: &Value) -> Result<(), Error> {
1412
let task_model = match get_model_name(task) {
1513
Some(model) => model.to_string(),
1614
None => return Ok(()),
1715
};
18-
let whitelisted_models = config_csv_list(CONFIG_HF_WHITELIST);
16+
let whitelisted_models = config_csv_list(&PGML_HF_WHITELIST.1);
1917

2018
let model_is_allowed = whitelisted_models.is_empty() || whitelisted_models.contains(&task_model);
2119
if !model_is_allowed {
22-
bail!("model {task_model} is not whitelisted. Consider adding to {CONFIG_HF_WHITELIST} in postgresql.conf");
20+
bail!(
21+
"model {} is not whitelisted. Consider adding to {} in postgresql.conf",
22+
task_model,
23+
PGML_HF_WHITELIST.0
24+
);
2325
}
2426

2527
let task_trust = get_trust_remote_code(task);
26-
let trust_remote_code = get_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL)
27-
.map(|v| v == "true")
28-
.unwrap_or(true);
28+
let trust_remote_code = PGML_HF_TRUST_REMOTE_CODE.1.get();
2929

30-
let trusted_models = config_csv_list(CONFIG_HF_TRUST_WHITELIST);
30+
let trusted_models = config_csv_list(&PGML_HF_TRUST_WHITELIST.1);
3131

3232
let model_is_trusted = trusted_models.is_empty() || trusted_models.contains(&task_model);
3333

3434
let remote_code_allowed = trust_remote_code && model_is_trusted;
3535
if !remote_code_allowed && task_trust == Some(true) {
36-
bail!("model {task_model} is not trusted to run remote code. Consider setting {CONFIG_HF_TRUST_REMOTE_CODE_BOOL} = 'true' or adding {task_model} to {CONFIG_HF_TRUST_WHITELIST}");
36+
bail!(
37+
"model {} is not trusted to run remote code. Consider setting {} = 'true' or adding {} to {}",
38+
task_model,
39+
PGML_HF_TRUST_REMOTE_CODE.0,
40+
task_model,
41+
PGML_HF_TRUST_WHITELIST.0
42+
);
3743
}
3844

3945
Ok(())
4046
}
4147

42-
fn config_csv_list(name: &str) -> Vec<String> {
43-
match get_config(name) {
48+
fn config_csv_list(csv_list: &GucSetting<Option<&'static CStr>>) -> Vec<String> {
49+
match csv_list.get() {
4450
Some(value) => value
51+
.to_string_lossy()
4552
.trim_matches('"')
4653
.split(',')
4754
.filter_map(|s| if s.is_empty() { None } else { Some(s.to_string()) })
@@ -122,7 +129,7 @@ mod tests {
122129
#[pg_test]
123130
fn test_empty_whitelist() {
124131
let model = "Salesforce/xgen-7b-8k-inst";
125-
set_config(CONFIG_HF_WHITELIST, "").unwrap();
132+
set_config(PGML_HF_WHITELIST.0, "").unwrap();
126133
let task_json = format!(json_template!(), model, false);
127134
let task: Value = serde_json::from_str(&task_json).unwrap();
128135
assert!(verify_task(&task).is_ok());
@@ -131,12 +138,12 @@ mod tests {
131138
#[pg_test]
132139
fn test_nonempty_whitelist() {
133140
let model = "Salesforce/xgen-7b-8k-inst";
134-
set_config(CONFIG_HF_WHITELIST, model).unwrap();
141+
set_config(PGML_HF_WHITELIST.0, model).unwrap();
135142
let task_json = format!(json_template!(), model, false);
136143
let task: Value = serde_json::from_str(&task_json).unwrap();
137144
assert!(verify_task(&task).is_ok());
138145

139-
set_config(CONFIG_HF_WHITELIST, "other_model").unwrap();
146+
set_config(PGML_HF_WHITELIST.0, "other_model").unwrap();
140147
let task_json = format!(json_template!(), model, false);
141148
let task: Value = serde_json::from_str(&task_json).unwrap();
142149
assert!(verify_task(&task).is_err());
@@ -145,18 +152,18 @@ mod tests {
145152
#[pg_test]
146153
fn test_trusted_model() {
147154
let model = "Salesforce/xgen-7b-8k-inst";
148-
set_config(CONFIG_HF_WHITELIST, model).unwrap();
149-
set_config(CONFIG_HF_TRUST_WHITELIST, model).unwrap();
155+
set_config(PGML_HF_WHITELIST.0, model).unwrap();
156+
set_config(PGML_HF_TRUST_WHITELIST.0, model).unwrap();
150157

151158
let task_json = format!(json_template!(), model, false);
152159
let task: Value = serde_json::from_str(&task_json).unwrap();
153160
assert!(verify_task(&task).is_ok());
154161

155162
let task_json = format!(json_template!(), model, true);
156163
let task: Value = serde_json::from_str(&task_json).unwrap();
157-
assert!(verify_task(&task).is_ok());
164+
assert!(verify_task(&task).is_err());
158165

159-
set_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL, "true").unwrap();
166+
set_config(PGML_HF_TRUST_REMOTE_CODE.0, "true").unwrap();
160167
let task_json = format!(json_template!(), model, false);
161168
let task: Value = serde_json::from_str(&task_json).unwrap();
162169
assert!(verify_task(&task).is_ok());
@@ -169,8 +176,8 @@ mod tests {
169176
#[pg_test]
170177
fn test_untrusted_model() {
171178
let model = "Salesforce/xgen-7b-8k-inst";
172-
set_config(CONFIG_HF_WHITELIST, model).unwrap();
173-
set_config(CONFIG_HF_TRUST_WHITELIST, "other_model").unwrap();
179+
set_config(PGML_HF_WHITELIST.0, model).unwrap();
180+
set_config(PGML_HF_TRUST_WHITELIST.0, "other_model").unwrap();
174181

175182
let task_json = format!(json_template!(), model, false);
176183
let task: Value = serde_json::from_str(&task_json).unwrap();
@@ -180,7 +187,7 @@ mod tests {
180187
let task: Value = serde_json::from_str(&task_json).unwrap();
181188
assert!(verify_task(&task).is_err());
182189

183-
set_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL, "true").unwrap();
190+
set_config(PGML_HF_TRUST_REMOTE_CODE.0, "true").unwrap();
184191
let task_json = format!(json_template!(), model, false);
185192
let task: Value = serde_json::from_str(&task_json).unwrap();
186193
assert!(verify_task(&task).is_ok());

pgml-extension/src/config.rs

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,58 @@
1+
use lazy_static::lazy_static;
2+
use pgrx::{GucContext, GucFlags, GucRegistry, GucSetting};
13
use std::ffi::CStr;
24

35
#[cfg(any(test, feature = "pg_test"))]
46
use pgrx::{pg_schema, pg_test};
5-
use pgrx_pg_sys::AsPgCStr;
67

7-
pub fn get_config(name: &str) -> Option<String> {
8-
// SAFETY: name is not null because it is a Rust reference.
9-
let ptr = unsafe { pgrx_pg_sys::GetConfigOption(name.as_pg_cstr(), true, false) };
10-
(!ptr.is_null()).then(move || {
11-
// SAFETY: assuming pgrx_pg_sys is providing a valid, null terminated pointer.
12-
unsafe { CStr::from_ptr(ptr) }.to_string_lossy().to_string()
13-
})
8+
lazy_static! {
9+
pub static ref PGML_VENV: (&'static str, GucSetting<Option<&'static CStr>>) =
10+
("pgml.venv", GucSetting::<Option<&'static CStr>>::new(None));
11+
pub static ref PGML_HF_WHITELIST: (&'static str, GucSetting<Option<&'static CStr>>) = (
12+
"pgml.huggingface_whitelist",
13+
GucSetting::<Option<&'static CStr>>::new(None),
14+
);
15+
pub static ref PGML_HF_TRUST_REMOTE_CODE: (&'static str, GucSetting<bool>) =
16+
("pgml.huggingface_trust_remote_code", GucSetting::<bool>::new(false));
17+
pub static ref PGML_HF_TRUST_WHITELIST: (&'static str, GucSetting<Option<&'static CStr>>) = (
18+
"pgml.huggingface_trust_remote_code_whitelist",
19+
GucSetting::<Option<&'static CStr>>::new(None),
20+
);
21+
}
22+
23+
pub fn initialize_server_params() {
24+
GucRegistry::define_string_guc(
25+
PGML_VENV.0,
26+
"Python's virtual environment path",
27+
"",
28+
&PGML_VENV.1,
29+
GucContext::Userset,
30+
GucFlags::default(),
31+
);
32+
GucRegistry::define_string_guc(
33+
PGML_HF_WHITELIST.0,
34+
"Models allowed to be downloaded from huggingface",
35+
"",
36+
&PGML_HF_WHITELIST.1,
37+
GucContext::Userset,
38+
GucFlags::default(),
39+
);
40+
GucRegistry::define_bool_guc(
41+
PGML_HF_TRUST_REMOTE_CODE.0,
42+
"Whether model can execute remote codes",
43+
"",
44+
&PGML_HF_TRUST_REMOTE_CODE.1,
45+
GucContext::Userset,
46+
GucFlags::default(),
47+
);
48+
GucRegistry::define_string_guc(
49+
PGML_HF_TRUST_WHITELIST.0,
50+
"Models allowed to execute remote codes when pgml.hugging_face_trust_remote_code = 'on'",
51+
"",
52+
&PGML_HF_TRUST_WHITELIST.1,
53+
GucContext::Userset,
54+
GucFlags::default(),
55+
);
1456
}
1557

1658
#[cfg(any(test, feature = "pg_test"))]
@@ -26,17 +68,11 @@ pub fn set_config(name: &str, value: &str) -> Result<(), pgrx::spi::Error> {
2668
mod tests {
2769
use super::*;
2870

29-
#[pg_test]
30-
fn read_config_max_connections() {
31-
let name = "max_connections";
32-
assert_eq!(get_config(name), Some("100".into()));
33-
}
34-
3571
#[pg_test]
3672
fn read_pgml_huggingface_whitelist() {
3773
let name = "pgml.huggingface_whitelist";
3874
let value = "meta-llama/Llama-2-7b";
3975
set_config(name, value).unwrap();
40-
assert_eq!(get_config(name), Some(value.into()));
76+
assert_eq!(PGML_HF_WHITELIST.1.get().unwrap().to_string_lossy(), value);
4177
}
4278
}

pgml-extension/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ extension_sql_file!("../sql/schema.sql", name = "schema");
2424
#[cfg(not(feature = "use_as_lib"))]
2525
#[pg_guard]
2626
pub extern "C" fn _PG_init() {
27+
config::initialize_server_params();
2728
bindings::python::activate().expect("Error setting python venv");
2829
orm::project::init();
2930
}

0 commit comments

Comments
 (0)