Skip to content

Commit a8d8218

Browse files
authored
fix and test preprocessing examples (#1520)
1 parent c3a8514 commit a8d8218

File tree

14 files changed

+70
-26
lines changed

14 files changed

+70
-26
lines changed

.github/workflows/ubuntu-packages-and-docker-image.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ on:
44
workflow_dispatch:
55
inputs:
66
packageVersion:
7-
default: "2.8.2"
7+
default: "2.9.1"
88
jobs:
99
#
1010
# PostgresML extension.

pgml-cms/docs/resources/developer-docs/contributing.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ SELECT pgml.version();
127127
postgres=# select pgml.version();
128128
version
129129
-------------------
130-
2.7.4
130+
2.9.1
131131
(1 row)
132132
```
133133
{% endtab %}

pgml-cms/docs/resources/developer-docs/installation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ CREATE EXTENSION
132132
pgml_test=# SELECT pgml.version();
133133
version
134134
---------
135-
2.7.4
135+
2.9.1
136136
(1 row)
137137
```
138138

pgml-cms/docs/resources/developer-docs/quick-start-with-docker.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ Time: 41.520 ms
8080
postgresml=# SELECT pgml.version();
8181
version
8282
---------
83-
2.7.13
83+
2.9.1
8484
(1 row)
8585
```
8686

pgml-cms/docs/resources/developer-docs/self-hosting/pooler.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,6 @@ Type "help" for help.
115115
postgresml=> SELECT pgml.version();
116116
version
117117
---------
118-
2.7.9
118+
2.9.1
119119
(1 row)
120120
```

pgml-extension/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-extension/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgml"
3-
version = "2.9.0"
3+
version = "2.9.1"
44
edition = "2021"
55

66
[lib]
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
-- load the diamonds dataset, that contains text categorical variables
2+
SELECT pgml.load_dataset('jdxcosta/diamonds');
3+
4+
-- view the data
5+
SELECT * FROM pgml."jdxcosta/diamonds" LIMIT 10;
6+
7+
-- drop the Unnamed column, since it's not useful for training (you could create a view instead)
8+
ALTER TABLE pgml."jdxcosta/diamonds" DROP COLUMN "Unnamed: 0";
9+
10+
-- train a model using preprocessors to scale the numeric variables, and target encode the categoricals
11+
SELECT pgml.train(
12+
project_name => 'Diamond prices',
13+
task => 'regression',
14+
relation_name => 'pgml.jdxcosta/diamonds',
15+
y_column_name => 'price',
16+
algorithm => 'lightgbm',
17+
preprocess => '{
18+
"carat": {"scale": "standard"},
19+
"depth": {"scale": "standard"},
20+
"table": {"scale": "standard"},
21+
"cut": {"encode": "target", "scale": "standard"},
22+
"color": {"encode": "target", "scale": "standard"},
23+
"clarity": {"encode": "target", "scale": "standard"}
24+
}'
25+
);
26+
27+
-- run some predictions, notice we're passing a heterogeneous row (tuple) as input, rather than a homogeneous ARRAY[].
28+
SELECT price, pgml.predict('Diamond prices', (carat, cut, color, clarity, depth, "table", x, y, z)) AS prediction
29+
FROM pgml."jdxcosta/diamonds"
30+
LIMIT 10;
31+
32+
-- This is a difficult dataset for most algorithms, which makes it a good challenge for preprocessing, and additional
33+
-- feature engineering. What's next?

pgml-extension/sql/pgml--2.9.0--2.9.1.sql

Whitespace-only changes.

pgml-extension/src/bindings/transformers/mod.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ pub fn load_dataset(
380380
.ok_or(anyhow!("dataset `data` key is not an object"))?;
381381
let column_names = types
382382
.iter()
383-
.map(|(name, _type)| name.clone())
383+
.map(|(name, _type)| format!("\"{}\"", name))
384384
.collect::<Vec<String>>()
385385
.join(", ");
386386
let column_types = types
@@ -393,13 +393,14 @@ pub fn load_dataset(
393393
"int64" => "INT8",
394394
"int32" => "INT4",
395395
"int16" => "INT2",
396+
"int8" => "INT2",
396397
"float64" => "FLOAT8",
397398
"float32" => "FLOAT4",
398399
"float16" => "FLOAT4",
399400
"bool" => "BOOLEAN",
400401
_ => bail!("unhandled dataset feature while reading dataset: {type_}"),
401402
};
402-
Ok(format!("{name} {type_}"))
403+
Ok(format!("\"{name}\" {type_}"))
403404
})
404405
.collect::<Result<Vec<String>>>()?
405406
.join(", ");
@@ -455,7 +456,7 @@ pub fn load_dataset(
455456
.into_datum(),
456457
)),
457458
"dict" | "list" => row.push((PgBuiltInOids::JSONBOID.oid(), JsonB(value.clone()).into_datum())),
458-
"int64" | "int32" | "int16" => row.push((
459+
"int64" | "int32" | "int16" | "int8" => row.push((
459460
PgBuiltInOids::INT8OID.oid(),
460461
value
461462
.as_i64()

pgml-extension/src/orm/model.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -344,13 +344,12 @@ impl Model {
344344
).unwrap().first();
345345

346346
if !result.is_empty() {
347-
let project_id = result.get(2).unwrap().unwrap();
348-
let project = Project::find(project_id).unwrap();
349-
let snapshot_id = result.get(3).unwrap().unwrap();
350-
let snapshot = Snapshot::find(snapshot_id).unwrap();
351-
let algorithm = Algorithm::from_str(result.get(4).unwrap().unwrap()).unwrap();
352-
let runtime = Runtime::from_str(result.get(5).unwrap().unwrap()).unwrap();
353-
347+
let project_id = result.get(2).unwrap().expect("project_id is i64");
348+
let project = Project::find(project_id).expect("project doesn't exist");
349+
let snapshot_id = result.get(3).unwrap().expect("snapshot_id is i64");
350+
let snapshot = Snapshot::find(snapshot_id).expect("snapshot doesn't exist");
351+
let algorithm = Algorithm::from_str(result.get(4).unwrap().unwrap()).expect("algorithm is malformed");
352+
let runtime = Runtime::from_str(result.get(5).unwrap().unwrap()).expect("runtime is malformed");
354353
let data = Spi::get_one_with_args::<Vec<u8>>(
355354
"
356355
SELECT data

pgml-extension/src/orm/sampling.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ impl Sampling {
5555
Sampling::stratified => {
5656
format!(
5757
"
58-
SELECT *
58+
SELECT {col_string}
5959
FROM (
6060
SELECT
6161
*,
@@ -125,7 +125,7 @@ mod tests {
125125
let columns = get_column_fixtures();
126126
let sql = sampling.get_sql("my_table", columns);
127127
let expected_sql = "
128-
SELECT *
128+
SELECT \"col1\", \"col2\"
129129
FROM (
130130
SELECT
131131
*,

pgml-extension/src/orm/snapshot.rs

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -230,16 +230,24 @@ impl Column {
230230
if self.preprocessor.encode == Encode::target {
231231
let categories = self.statistics.categories.as_mut().unwrap();
232232
let mut sums = vec![0_f32; categories.len() + 1];
233+
let mut total = 0.;
233234
Zip::from(array).and(target).for_each(|&value, &target| {
235+
total += target;
234236
sums[value as usize] += target;
235237
});
238+
let avg_target = total / categories.len() as f32;
236239
for category in categories.values_mut() {
237-
let sum = sums[category.value as usize];
238-
category.value = sum / category.members as f32;
240+
if category.members > 0 {
241+
let sum = sums[category.value as usize];
242+
category.value = sum / category.members as f32;
243+
} else {
244+
// use avg target for categories w/ no members, e.g. __NULL__ category in a complete dataset
245+
category.value = avg_target;
246+
}
239247
}
240248
}
241249

242-
// Data is filtered for NaN because it is not well defined statistically, and they are counted as separate stat
250+
// Data is filtered for NaN because it is not well-defined statistically, and they are counted as separate stat
243251
let mut data = array
244252
.iter()
245253
.filter_map(|n| if n.is_nan() { None } else { Some(*n) })
@@ -404,7 +412,8 @@ impl Snapshot {
404412
.first();
405413
if !result.is_empty() {
406414
let jsonb: JsonB = result.get(7).unwrap().unwrap();
407-
let columns: Vec<Column> = serde_json::from_value(jsonb.0).unwrap();
415+
let columns: Vec<Column> =
416+
serde_json::from_value(jsonb.0).expect("invalid json description of columns");
408417
// let jsonb: JsonB = result.get(8).unwrap();
409418
// let analysis: Option<IndexMap<String, f32>> = Some(serde_json::from_value(jsonb.0).unwrap());
410419
let mut s = Snapshot {
@@ -500,9 +509,10 @@ impl Snapshot {
500509

501510
let preprocessors: HashMap<String, Preprocessor> = serde_json::from_value(preprocess.0).expect("is valid");
502511

512+
let mut position = 0; // Postgres column positions are not updated when other columns are dropped, but we expect consecutive positions when we read the table.
503513
Spi::connect(|mut client| {
504514
let mut columns: Vec<Column> = Vec::new();
505-
client.select("SELECT column_name::TEXT, udt_name::TEXT, is_nullable::BOOLEAN, ordinal_position::INTEGER FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 ORDER BY ordinal_position ASC",
515+
client.select("SELECT column_name::TEXT, udt_name::TEXT, is_nullable::BOOLEAN FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 ORDER BY ordinal_position ASC",
506516
None,
507517
Some(vec![
508518
(PgBuiltInOids::TEXTOID.oid(), schema_name.into_datum()),
@@ -520,7 +530,7 @@ impl Snapshot {
520530
pg_type = pg_type[1..].to_string() + "[]";
521531
}
522532
let nullable = row[3].value::<bool>().unwrap().unwrap();
523-
let position = row[4].value::<i32>().unwrap().unwrap() as usize;
533+
position += 1;
524534
let label = match y_column_name {
525535
Some(ref y_column_name) => y_column_name.contains(&name),
526536
None => false,
@@ -1158,7 +1168,7 @@ impl Snapshot {
11581168
pub fn numeric_encoded_dataset(&mut self) -> Dataset {
11591169
let mut data = None;
11601170
Spi::connect(|client| {
1161-
// Postgres Arrays arrays are 1 indexed and so are SPI tuples...
1171+
// Postgres arrays are 1 indexed and so are SPI tuples...
11621172
let result = client.select(&self.select_sql(), None, None).unwrap();
11631173
let num_rows = result.len();
11641174
let (num_train_rows, num_test_rows) = self.train_test_split(num_rows);

pgml-extension/tests/test.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,6 @@ SELECT pgml.load_dataset('wine');
3030
\i examples/regression.sql
3131
\i examples/vectors.sql
3232
\i examples/chunking.sql
33+
\i examples/preprocessing.sql
3334
-- transformers are generally too slow to run in the test suite
3435
--\i examples/transformers.sql

0 commit comments

Comments
 (0)