Skip to content

Commit 84f400b

Browse files
authored
Merge pull request #8 from postgresml/levkk-retrain
Allow retraining the same project
2 parents e291d94 + c24419e commit 84f400b

File tree

3 files changed

+29
-9
lines changed

3 files changed

+29
-9
lines changed

pgml/pgml/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
def version():
2-
return "0.3"
2+
return "0.4"

pgml/pgml/model.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -420,23 +420,43 @@ def train(
420420
test_size (float or int, optional): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test samples. If None, the value is set to the complement of the train size. If train_size is also None, it will be set to 0.25.
421421
test_sampling (str, optional): How to sample to create the test data. Defaults to "random". Valid values are ["first", "last", "random"].
422422
"""
423-
project = Project.create(project_name, objective)
424-
snapshot = Snapshot.create(relation_name, y_column_name, test_size, test_sampling)
425-
best_model = None
426-
best_error = None
427423
if objective == "regression":
428424
algorithms = ["linear", "random_forest"]
429425
elif objective == "classification":
430426
algorithms = ["random_forest"]
431427
else:
432428
raise PgMLException(
433-
f"Unknown objective '{objective}', available options are: regression, classification"
429+
f"Unknown objective `{objective}`, available options are: regression, classification."
430+
)
431+
432+
try:
433+
project = Project.find_by_name(project_name)
434+
except PgMLException:
435+
project = Project.create(project_name, objective)
436+
437+
if project.objective != objective:
438+
raise PgMLException(
439+
f"Project `{project_name}` already exists with a different objective: `{project.objective}`. Create a new project instead."
434440
)
435441

442+
snapshot = Snapshot.create(relation_name, y_column_name, test_size, test_sampling)
443+
deployed = Model.find_deployed(project.id)
444+
445+
# Let's assume that the deployed model is better for now.
446+
best_model = deployed
447+
best_error = best_model.mean_squared_error if best_model else None
448+
436449
for algorithm_name in algorithms:
437450
model = Model.create(project, snapshot, algorithm_name)
438451
model.fit(snapshot)
452+
453+
# Find the better model and deploy that.
439454
if best_error is None or model.mean_squared_error < best_error:
440455
best_error = model.mean_squared_error
441456
best_model = model
442-
best_model.deploy()
457+
458+
if deployed and deployed.id == best_model.id:
459+
return "rolled back"
460+
else:
461+
best_model.deploy()
462+
return "deployed"

sql/install.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,9 @@ RETURNS TABLE(project_name TEXT, objective TEXT, status TEXT)
109109
AS $$
110110
from pgml.model import train
111111

112-
train(project_name, objective, relation_name, y_column_name)
112+
status = train(project_name, objective, relation_name, y_column_name)
113113

114-
return [(project_name, objective, "deployed")]
114+
return [(project_name, objective, status)]
115115
$$ LANGUAGE plpython3u;
116116

117117
---

0 commit comments

Comments
 (0)