9
9
from pgml .exceptions import PgMLException
10
10
from pgml .sql import q
11
11
12
+
12
13
class Project (object ):
13
14
"""
14
15
Use projects to refine multiple models of a particular dataset on a specific objective.
15
-
16
+
16
17
Attributes:
17
18
id (int): a unique identifier
18
19
name (str): a human friendly unique identifier
19
20
objective (str): the purpose of this project
20
21
created_at (Timestamp): when this project was created
21
22
updated_at (Timestamp): when this project was last updated
22
23
"""
23
-
24
+
24
25
_cache = {}
25
26
26
27
def __init__ (self ):
@@ -36,11 +37,14 @@ def find(cls, id: int):
36
37
Returns:
37
38
Project or None: instantiated from the database if found
38
39
"""
39
- result = plpy .execute (f"""
40
+ result = plpy .execute (
41
+ f"""
40
42
SELECT *
41
43
FROM pgml.projects
42
44
WHERE id = { q (id )}
43
- """ , 1 )
45
+ """ ,
46
+ 1 ,
47
+ )
44
48
if len (result ) == 0 :
45
49
return None
46
50
@@ -53,25 +57,28 @@ def find(cls, id: int):
53
57
@classmethod
54
58
def find_by_name (cls , name : str ):
55
59
"""
56
- Get a Project from the database by name.
57
-
60
+ Get a Project from the database by name.
61
+
58
62
This is the preferred API to retrieve projects, and they are cached by
59
63
name to avoid needing to go to the database on every usage.
60
-
64
+
61
65
Args:
62
66
name (str): the project name
63
67
Returns:
64
68
Project or None: instantiated from the database if found
65
69
"""
66
70
if name in cls ._cache :
67
71
return cls ._cache [name ]
68
-
69
- result = plpy .execute (f"""
72
+
73
+ result = plpy .execute (
74
+ f"""
70
75
SELECT *
71
76
FROM pgml.projects
72
77
WHERE name = { q (name )}
73
- """ , 1 )
74
- if len (result )== 0 :
78
+ """ ,
79
+ 1 ,
80
+ )
81
+ if len (result ) == 0 :
75
82
return None
76
83
77
84
project = Project ()
@@ -84,7 +91,7 @@ def find_by_name(cls, name: str):
84
91
def create (cls , name : str , objective : str ):
85
92
"""
86
93
Create a Project and save it to the database.
87
-
94
+
88
95
Args:
89
96
name (str): a human friendly identifier
90
97
objective (str): valid values are ["regression", "classification"].
@@ -93,11 +100,16 @@ def create(cls, name: str, objective: str):
93
100
"""
94
101
95
102
project = Project ()
96
- project .__dict__ = dict (plpy .execute (f"""
103
+ project .__dict__ = dict (
104
+ plpy .execute (
105
+ f"""
97
106
INSERT INTO pgml.projects (name, objective)
98
107
VALUES ({ q (name )} , { q (objective )} )
99
108
RETURNING *
100
- """ , 1 )[0 ])
109
+ """ ,
110
+ 1 ,
111
+ )[0 ]
112
+ )
101
113
project .__init__ ()
102
114
cls ._cache [name ] = project
103
115
return project
@@ -112,10 +124,11 @@ def deployed_model(self):
112
124
self ._deployed_model = Model .find_deployed (self .id )
113
125
return self ._deployed_model
114
126
127
+
115
128
class Snapshot (object ):
116
129
"""
117
130
Snapshots capture a set of training & test data for repeatability.
118
-
131
+
119
132
Attributes:
120
133
id (int): a unique identifier
121
134
relation_name (str): the name of the table or view to snapshot
@@ -126,11 +139,18 @@ class Snapshot(object):
126
139
created_at (Timestamp): when this snapshot was created
127
140
updated_at (Timestamp): when this snapshot was last updated
128
141
"""
142
+
129
143
@classmethod
130
- def create (cls , relation_name : str , y_column_name : str , test_size : float or int , test_sampling : str ):
144
+ def create (
145
+ cls ,
146
+ relation_name : str ,
147
+ y_column_name : str ,
148
+ test_size : float or int ,
149
+ test_sampling : str ,
150
+ ):
131
151
"""
132
- Create a Snapshot and save it to the database.
133
-
152
+ Create a Snapshot and save it to the database.
153
+
134
154
This creates both a metadata record in the snapshots table, as well as creating a new table
135
155
that holds a snapshot of all the data currently present in the relation so that training
136
156
runs may be repeated, or further analysis may be conducted against the input.
@@ -145,32 +165,46 @@ def create(cls, relation_name: str, y_column_name: str, test_size: float or int,
145
165
"""
146
166
147
167
snapshot = Snapshot ()
148
- snapshot .__dict__ = dict (plpy .execute (f"""
168
+ snapshot .__dict__ = dict (
169
+ plpy .execute (
170
+ f"""
149
171
INSERT INTO pgml.snapshots (relation_name, y_column_name, test_size, test_sampling, status)
150
172
VALUES ({ q (relation_name )} , { q (y_column_name )} , { q (test_size )} , { q (test_sampling )} , 'new')
151
173
RETURNING *
152
- """ , 1 )[0 ])
153
- plpy .execute (f"""
174
+ """ ,
175
+ 1 ,
176
+ )[0 ]
177
+ )
178
+ plpy .execute (
179
+ f"""
154
180
CREATE TABLE pgml."snapshot_{ snapshot .id } " AS
155
181
SELECT * FROM "{ snapshot .relation_name } ";
156
- """ )
157
- snapshot .__dict__ = dict (plpy .execute (f"""
182
+ """
183
+ )
184
+ snapshot .__dict__ = dict (
185
+ plpy .execute (
186
+ f"""
158
187
UPDATE pgml.snapshots
159
188
SET status = 'created'
160
189
WHERE id = { q (snapshot .id )}
161
190
RETURNING *
162
- """ , 1 )[0 ])
191
+ """ ,
192
+ 1 ,
193
+ )[0 ]
194
+ )
163
195
return snapshot
164
196
165
197
def data (self ):
166
198
"""
167
199
Returns:
168
200
list, list, list, list: All rows from the snapshot split into X_train, X_test, y_train, y_test sets.
169
201
"""
170
- data = plpy .execute (f"""
202
+ data = plpy .execute (
203
+ f"""
171
204
SELECT *
172
205
FROM pgml."snapshot_{ self .id } "
173
- """ )
206
+ """
207
+ )
174
208
175
209
print (data )
176
210
# Sanity check the data
@@ -203,10 +237,10 @@ def data(self):
203
237
y .append (y_ )
204
238
205
239
# Split into training and test sets
206
- if self .test_sampling == ' random' :
240
+ if self .test_sampling == " random" :
207
241
return train_test_split (X , y , test_size = self .test_size , random_state = 0 )
208
242
else :
209
- if self .test_sampling == ' first' :
243
+ if self .test_sampling == " first" :
210
244
X .reverse ()
211
245
y .reverse ()
212
246
if isinstance (split , float ):
@@ -216,9 +250,9 @@ def data(self):
216
250
split = int (self .test_size * X .len ())
217
251
return X [:split ], X [split :], y [:split ], y [split :]
218
252
219
-
220
253
# TODO normalize and clean data
221
254
255
+
222
256
class Model (object ):
223
257
"""Models use an algorithm on a snapshot of data to record the parameters learned.
224
258
@@ -234,23 +268,26 @@ class Model(object):
234
268
pickle (bytes): the serialized version of the model parameters
235
269
algorithm: the in memory version of the model parameters that can make predictions
236
270
"""
271
+
237
272
@classmethod
238
273
def create (cls , project : Project , snapshot : Snapshot , algorithm_name : str ):
239
274
"""
240
275
Create a Model and save it to the database.
241
-
276
+
242
277
Args:
243
- project (str):
244
- snapshot (str):
278
+ project (str):
279
+ snapshot (str):
245
280
algorithm_name (str):
246
281
Returns:
247
282
Model: instantiated from the database
248
283
"""
249
- result = plpy .execute (f"""
284
+ result = plpy .execute (
285
+ f"""
250
286
INSERT INTO pgml.models (project_id, snapshot_id, algorithm_name, status)
251
287
VALUES ({ q (project .id )} , { q (snapshot .id )} , { q (algorithm_name )} , 'new')
252
288
RETURNING *
253
- """ )
289
+ """
290
+ )
254
291
model = Model ()
255
292
model .__dict__ = dict (result [0 ])
256
293
model .__init__ ()
@@ -265,15 +302,17 @@ def find_deployed(cls, project_id: int):
265
302
Returns:
266
303
Model: that should currently be used for predictions of the project
267
304
"""
268
- result = plpy .execute (f"""
305
+ result = plpy .execute (
306
+ f"""
269
307
SELECT models.*
270
308
FROM pgml.models
271
309
JOIN pgml.deployments
272
310
ON deployments.model_id = models.id
273
311
AND deployments.project_id = { q (project_id )}
274
312
ORDER by deployments.created_at DESC
275
313
LIMIT 1
276
- """ )
314
+ """
315
+ )
277
316
if len (result ) == 0 :
278
317
return None
279
318
@@ -303,19 +342,19 @@ def algorithm(self):
303
342
self ._algorithm = pickle .loads (self .pickle )
304
343
else :
305
344
self ._algorithm = {
306
- ' linear_regression' : LinearRegression ,
307
- ' random_forest_regression' : RandomForestRegressor ,
308
- ' random_forest_classification' : RandomForestClassifier
309
- }[self .algorithm_name + '_' + self .project .objective ]()
310
-
345
+ " linear_regression" : LinearRegression ,
346
+ " random_forest_regression" : RandomForestRegressor ,
347
+ " random_forest_classification" : RandomForestClassifier ,
348
+ }[self .algorithm_name + "_" + self .project .objective ]()
349
+
311
350
return self ._algorithm
312
351
313
352
def fit (self , snapshot : Snapshot ):
314
353
"""
315
- Learns the parameters of this model and records them in the database.
354
+ Learns the parameters of this model and records them in the database.
316
355
317
- Args:
318
- snapshot (Snapshot): dataset used to train this model
356
+ Args:
357
+ snapshot (Snapshot): dataset used to train this model
319
358
"""
320
359
X_train , X_test , y_train , y_test = snapshot .data ()
321
360
@@ -328,22 +367,28 @@ def fit(self, snapshot: Snapshot):
328
367
r2 = r2_score (y_test , y_pred )
329
368
330
369
# Save the model
331
- self .__dict__ = dict (plpy .execute (f"""
370
+ self .__dict__ = dict (
371
+ plpy .execute (
372
+ f"""
332
373
UPDATE pgml.models
333
374
SET pickle = '\\ x{ pickle .dumps (self .algorithm ).hex ()} ',
334
375
status = 'successful',
335
376
mean_squared_error = { q (msq )} ,
336
377
r2_score = { q (r2 )}
337
378
WHERE id = { q (self .id )}
338
379
RETURNING *
339
- """ )[0 ])
380
+ """
381
+ )[0 ]
382
+ )
340
383
341
384
def deploy (self ):
342
385
"""Promote this model to the active version for the project that will be used for predictions"""
343
- plpy .execute (f"""
386
+ plpy .execute (
387
+ f"""
344
388
INSERT INTO pgml.deployments (project_id, model_id)
345
389
VALUES ({ q (self .project_id )} , { q (self .id )} )
346
- """ )
390
+ """
391
+ )
347
392
348
393
def predict (self , data : list ):
349
394
"""Use the model for a set of features.
@@ -358,12 +403,12 @@ def predict(self, data: list):
358
403
359
404
360
405
def train (
361
- project_name : str ,
406
+ project_name : str ,
362
407
objective : str ,
363
- relation_name : str ,
364
- y_column_name : str ,
408
+ relation_name : str ,
409
+ y_column_name : str ,
365
410
test_size : float or int = 0.1 ,
366
- test_sampling : str = "random"
411
+ test_sampling : str = "random" ,
367
412
):
368
413
"""Create a regression model from a table or view filled with training data.
369
414
@@ -390,5 +435,5 @@ def train(
390
435
model .fit (snapshot )
391
436
if best_error is None or model .mean_squared_error < best_error :
392
437
best_error = model .mean_squared_error
393
- best_model = model
438
+ best_model = model
394
439
best_model .deploy ()
0 commit comments