Skip to content

Commit 8354cd1

Browse files
authored
Fully qualify table name (#445)
1 parent d4b44e1 commit 8354cd1

File tree

6 files changed

+73
-18
lines changed

6 files changed

+73
-18
lines changed

pgml-dashboard/pgml_dashboard/settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@
108108
DATABASES = {
109109
"default": {
110110
"ENGINE": "django.db.backends.postgresql",
111-
"OPTIONS": {"options": "-c search_path=pgml,public"},
111+
"OPTIONS": {"options": "-c search_path=public,pgml"},
112112
"NAME": database.path[1:],
113113
"USER": database.username,
114114
"PASSWORD": database.password,

pgml-extension/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-extension/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgml"
3-
version = "2.0.1"
3+
version = "2.0.2"
44
edition = "2021"
55

66
[lib]

pgml-extension/docker/entrypoint.sh

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
#!/bin/bash
22

33
# Exit on error, real CI
4-
set -e
5-
64
echo "Starting Postgres..."
75
service postgresql start
86

pgml-extension/src/api.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,7 @@ mod tests {
538538
#[pg_test]
539539
fn test_snapshot_lifecycle() {
540540
load_diabetes(Some(25));
541+
541542
let snapshot = Snapshot::create(
542543
"pgml.diabetes",
543544
vec!["target".to_string()],
@@ -547,6 +548,19 @@ mod tests {
547548
assert!(snapshot.id > 0);
548549
}
549550

551+
#[pg_test]
552+
#[should_panic]
553+
fn test_not_fully_qualified_table() {
554+
load_diabetes(Some(25));
555+
556+
let result = std::panic::catch_unwind(|| {
557+
let _snapshot =
558+
Snapshot::create("diabetes", vec!["target".to_string()], 0.5, Sampling::last);
559+
});
560+
561+
assert!(result.is_err());
562+
}
563+
550564
#[pg_test]
551565
fn test_train_regression() {
552566
load_diabetes(None);

pgml-extension/src/orm/snapshot.rs

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,10 @@ impl Snapshot {
193193
) -> Snapshot {
194194
let mut snapshot: Option<Snapshot> = None;
195195
let status = Status::in_progress;
196+
197+
// Validate table exists.
198+
let (_schema_name, _table_name) = Self::fully_qualified_table(relation_name);
199+
196200
Spi::connect(|client| {
197201
let result = client.select("INSERT INTO pgml.snapshots (relation_name, y_column_name, test_size, test_sampling, status) VALUES ($1, $2, $3, $4::pgml.sampling, $5::pgml.status) RETURNING id, relation_name, y_column_name, test_size, test_sampling::TEXT, status::TEXT, columns, analysis, created_at, updated_at;",
198202
Some(1),
@@ -256,22 +260,61 @@ impl Snapshot {
256260
.unwrap() as usize
257261
}
258262

263+
fn fully_qualified_table(relation_name: &str) -> (String, String) {
264+
info!("Validating relation: {}", relation_name);
265+
266+
let parts = relation_name
267+
.split('.')
268+
.map(|name| name.to_string())
269+
.collect::<Vec<String>>();
270+
271+
let (schema_name, table_name) = match parts.len() {
272+
1 => (None, parts[0].clone()),
273+
2 => (Some(parts[0].clone()), parts[1].clone()),
274+
_ => error!(
275+
"Relation name \"{}\" is not parsable into schema name and table name",
276+
relation_name
277+
),
278+
};
279+
280+
match schema_name {
281+
None => {
282+
let table_count = Spi::get_one_with_args::<i64>("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'public'", vec![
283+
(PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum())
284+
]).unwrap();
285+
286+
let error = format!("Relation \"{}\" could not be found in the public schema. Please specify the table schema, e.g. pgml.{}", table_name, table_name);
287+
288+
match table_count {
289+
0 => error!("{}", error),
290+
1 => (String::from("public"), table_name),
291+
_ => error!("{}", error),
292+
}
293+
}
294+
295+
Some(schema_name) => {
296+
let exists = Spi::get_one_with_args::<i64>("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = $2", vec![
297+
(PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum()),
298+
(PgBuiltInOids::TEXTOID.oid(), schema_name.clone().into_datum()),
299+
]).unwrap();
300+
301+
if exists == 1 {
302+
(schema_name, table_name)
303+
} else {
304+
error!(
305+
"Relation \"{}\".\"{}\" doesn't exist",
306+
schema_name, table_name
307+
);
308+
}
309+
}
310+
}
311+
}
312+
259313
#[allow(clippy::format_push_string)]
260314
fn analyze(&mut self) {
315+
let (schema_name, table_name) = Self::fully_qualified_table(&self.relation_name);
316+
261317
Spi::connect(|client| {
262-
let parts = self
263-
.relation_name
264-
.split('.')
265-
.map(|name| name.to_string())
266-
.collect::<Vec<String>>();
267-
let (schema_name, table_name) = match parts.len() {
268-
1 => (String::from("public"), parts[0].clone()),
269-
2 => (parts[0].clone(), parts[1].clone()),
270-
_ => error!(
271-
"Relation name {} is not parsable into schema name and table name",
272-
self.relation_name
273-
),
274-
};
275318
let mut columns: Vec<Column> = Vec::new();
276319
client.select("SELECT column_name::TEXT, udt_name::TEXT, is_nullable::BOOLEAN, ordinal_position::INTEGER FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 ORDER BY ordinal_position ASC",
277320
None,

0 commit comments

Comments
 (0)