Skip to content

Commit 016975f

Browse files
authored
Merge pull request #3 from postgresml/levkk-black-mvp
lint
2 parents 5f7f626 + 1ddd592 commit 016975f

File tree

4 files changed

+226
-69
lines changed

4 files changed

+226
-69
lines changed

pgml/pgml/model.py

Lines changed: 98 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,19 @@
99
from pgml.exceptions import PgMLException
1010
from pgml.sql import q
1111

12+
1213
class Project(object):
1314
"""
1415
Use projects to refine multiple models of a particular dataset on a specific objective.
15-
16+
1617
Attributes:
1718
id (int): a unique identifier
1819
name (str): a human friendly unique identifier
1920
objective (str): the purpose of this project
2021
created_at (Timestamp): when this project was created
2122
updated_at (Timestamp): when this project was last updated
2223
"""
23-
24+
2425
_cache = {}
2526

2627
def __init__(self):
@@ -36,11 +37,14 @@ def find(cls, id: int):
3637
Returns:
3738
Project or None: instantiated from the database if found
3839
"""
39-
result = plpy.execute(f"""
40+
result = plpy.execute(
41+
f"""
4042
SELECT *
4143
FROM pgml.projects
4244
WHERE id = {q(id)}
43-
""", 1)
45+
""",
46+
1,
47+
)
4448
if len(result) == 0:
4549
return None
4650

@@ -53,25 +57,28 @@ def find(cls, id: int):
5357
@classmethod
5458
def find_by_name(cls, name: str):
5559
"""
56-
Get a Project from the database by name.
57-
60+
Get a Project from the database by name.
61+
5862
This is the preferred API to retrieve projects, and they are cached by
5963
name to avoid needing to go to the database on every usage.
60-
64+
6165
Args:
6266
name (str): the project name
6367
Returns:
6468
Project or None: instantiated from the database if found
6569
"""
6670
if name in cls._cache:
6771
return cls._cache[name]
68-
69-
result = plpy.execute(f"""
72+
73+
result = plpy.execute(
74+
f"""
7075
SELECT *
7176
FROM pgml.projects
7277
WHERE name = {q(name)}
73-
""", 1)
74-
if len(result)== 0:
78+
""",
79+
1,
80+
)
81+
if len(result) == 0:
7582
return None
7683

7784
project = Project()
@@ -84,7 +91,7 @@ def find_by_name(cls, name: str):
8491
def create(cls, name: str, objective: str):
8592
"""
8693
Create a Project and save it to the database.
87-
94+
8895
Args:
8996
name (str): a human friendly identifier
9097
objective (str): valid values are ["regression", "classification"].
@@ -93,11 +100,16 @@ def create(cls, name: str, objective: str):
93100
"""
94101

95102
project = Project()
96-
project.__dict__ = dict(plpy.execute(f"""
103+
project.__dict__ = dict(
104+
plpy.execute(
105+
f"""
97106
INSERT INTO pgml.projects (name, objective)
98107
VALUES ({q(name)}, {q(objective)})
99108
RETURNING *
100-
""", 1)[0])
109+
""",
110+
1,
111+
)[0]
112+
)
101113
project.__init__()
102114
cls._cache[name] = project
103115
return project
@@ -112,10 +124,11 @@ def deployed_model(self):
112124
self._deployed_model = Model.find_deployed(self.id)
113125
return self._deployed_model
114126

127+
115128
class Snapshot(object):
116129
"""
117130
Snapshots capture a set of training & test data for repeatability.
118-
131+
119132
Attributes:
120133
id (int): a unique identifier
121134
relation_name (str): the name of the table or view to snapshot
@@ -126,11 +139,18 @@ class Snapshot(object):
126139
created_at (Timestamp): when this snapshot was created
127140
updated_at (Timestamp): when this snapshot was last updated
128141
"""
142+
129143
@classmethod
130-
def create(cls, relation_name: str, y_column_name: str, test_size: float or int, test_sampling: str):
144+
def create(
145+
cls,
146+
relation_name: str,
147+
y_column_name: str,
148+
test_size: float or int,
149+
test_sampling: str,
150+
):
131151
"""
132-
Create a Snapshot and save it to the database.
133-
152+
Create a Snapshot and save it to the database.
153+
134154
This creates both a metadata record in the snapshots table, as well as creating a new table
135155
that holds a snapshot of all the data currently present in the relation so that training
136156
runs may be repeated, or further analysis may be conducted against the input.
@@ -145,32 +165,46 @@ def create(cls, relation_name: str, y_column_name: str, test_size: float or int,
145165
"""
146166

147167
snapshot = Snapshot()
148-
snapshot.__dict__ = dict(plpy.execute(f"""
168+
snapshot.__dict__ = dict(
169+
plpy.execute(
170+
f"""
149171
INSERT INTO pgml.snapshots (relation_name, y_column_name, test_size, test_sampling, status)
150172
VALUES ({q(relation_name)}, {q(y_column_name)}, {q(test_size)}, {q(test_sampling)}, 'new')
151173
RETURNING *
152-
""", 1)[0])
153-
plpy.execute(f"""
174+
""",
175+
1,
176+
)[0]
177+
)
178+
plpy.execute(
179+
f"""
154180
CREATE TABLE pgml."snapshot_{snapshot.id}" AS
155181
SELECT * FROM "{snapshot.relation_name}";
156-
""")
157-
snapshot.__dict__ = dict(plpy.execute(f"""
182+
"""
183+
)
184+
snapshot.__dict__ = dict(
185+
plpy.execute(
186+
f"""
158187
UPDATE pgml.snapshots
159188
SET status = 'created'
160189
WHERE id = {q(snapshot.id)}
161190
RETURNING *
162-
""", 1)[0])
191+
""",
192+
1,
193+
)[0]
194+
)
163195
return snapshot
164196

165197
def data(self):
166198
"""
167199
Returns:
168200
list, list, list, list: All rows from the snapshot split into X_train, X_test, y_train, y_test sets.
169201
"""
170-
data = plpy.execute(f"""
202+
data = plpy.execute(
203+
f"""
171204
SELECT *
172205
FROM pgml."snapshot_{self.id}"
173-
""")
206+
"""
207+
)
174208

175209
print(data)
176210
# Sanity check the data
@@ -203,10 +237,10 @@ def data(self):
203237
y.append(y_)
204238

205239
# Split into training and test sets
206-
if self.test_sampling == 'random':
240+
if self.test_sampling == "random":
207241
return train_test_split(X, y, test_size=self.test_size, random_state=0)
208242
else:
209-
if self.test_sampling == 'first':
243+
if self.test_sampling == "first":
210244
X.reverse()
211245
y.reverse()
212246
if isinstance(split, float):
@@ -216,9 +250,9 @@ def data(self):
216250
split = int(self.test_size * X.len())
217251
return X[:split], X[split:], y[:split], y[split:]
218252

219-
220253
# TODO normalize and clean data
221254

255+
222256
class Model(object):
223257
"""Models use an algorithm on a snapshot of data to record the parameters learned.
224258
@@ -234,23 +268,26 @@ class Model(object):
234268
pickle (bytes): the serialized version of the model parameters
235269
algorithm: the in memory version of the model parameters that can make predictions
236270
"""
271+
237272
@classmethod
238273
def create(cls, project: Project, snapshot: Snapshot, algorithm_name: str):
239274
"""
240275
Create a Model and save it to the database.
241-
276+
242277
Args:
243-
project (str):
244-
snapshot (str):
278+
project (str):
279+
snapshot (str):
245280
algorithm_name (str):
246281
Returns:
247282
Model: instantiated from the database
248283
"""
249-
result = plpy.execute(f"""
284+
result = plpy.execute(
285+
f"""
250286
INSERT INTO pgml.models (project_id, snapshot_id, algorithm_name, status)
251287
VALUES ({q(project.id)}, {q(snapshot.id)}, {q(algorithm_name)}, 'new')
252288
RETURNING *
253-
""")
289+
"""
290+
)
254291
model = Model()
255292
model.__dict__ = dict(result[0])
256293
model.__init__()
@@ -265,15 +302,17 @@ def find_deployed(cls, project_id: int):
265302
Returns:
266303
Model: that should currently be used for predictions of the project
267304
"""
268-
result = plpy.execute(f"""
305+
result = plpy.execute(
306+
f"""
269307
SELECT models.*
270308
FROM pgml.models
271309
JOIN pgml.deployments
272310
ON deployments.model_id = models.id
273311
AND deployments.project_id = {q(project_id)}
274312
ORDER by deployments.created_at DESC
275313
LIMIT 1
276-
""")
314+
"""
315+
)
277316
if len(result) == 0:
278317
return None
279318

@@ -303,19 +342,19 @@ def algorithm(self):
303342
self._algorithm = pickle.loads(self.pickle)
304343
else:
305344
self._algorithm = {
306-
'linear_regression': LinearRegression,
307-
'random_forest_regression': RandomForestRegressor,
308-
'random_forest_classification': RandomForestClassifier
309-
}[self.algorithm_name + '_' + self.project.objective]()
310-
345+
"linear_regression": LinearRegression,
346+
"random_forest_regression": RandomForestRegressor,
347+
"random_forest_classification": RandomForestClassifier,
348+
}[self.algorithm_name + "_" + self.project.objective]()
349+
311350
return self._algorithm
312351

313352
def fit(self, snapshot: Snapshot):
314353
"""
315-
Learns the parameters of this model and records them in the database.
354+
Learns the parameters of this model and records them in the database.
316355
317-
Args:
318-
snapshot (Snapshot): dataset used to train this model
356+
Args:
357+
snapshot (Snapshot): dataset used to train this model
319358
"""
320359
X_train, X_test, y_train, y_test = snapshot.data()
321360

@@ -328,22 +367,28 @@ def fit(self, snapshot: Snapshot):
328367
r2 = r2_score(y_test, y_pred)
329368

330369
# Save the model
331-
self.__dict__ = dict(plpy.execute(f"""
370+
self.__dict__ = dict(
371+
plpy.execute(
372+
f"""
332373
UPDATE pgml.models
333374
SET pickle = '\\x{pickle.dumps(self.algorithm).hex()}',
334375
status = 'successful',
335376
mean_squared_error = {q(msq)},
336377
r2_score = {q(r2)}
337378
WHERE id = {q(self.id)}
338379
RETURNING *
339-
""")[0])
380+
"""
381+
)[0]
382+
)
340383

341384
def deploy(self):
342385
"""Promote this model to the active version for the project that will be used for predictions"""
343-
plpy.execute(f"""
386+
plpy.execute(
387+
f"""
344388
INSERT INTO pgml.deployments (project_id, model_id)
345389
VALUES ({q(self.project_id)}, {q(self.id)})
346-
""")
390+
"""
391+
)
347392

348393
def predict(self, data: list):
349394
"""Use the model for a set of features.
@@ -358,12 +403,12 @@ def predict(self, data: list):
358403

359404

360405
def train(
361-
project_name: str,
406+
project_name: str,
362407
objective: str,
363-
relation_name: str,
364-
y_column_name: str,
408+
relation_name: str,
409+
y_column_name: str,
365410
test_size: float or int = 0.1,
366-
test_sampling: str = "random"
411+
test_sampling: str = "random",
367412
):
368413
"""Create a regression model from a table or view filled with training data.
369414
@@ -390,5 +435,5 @@ def train(
390435
model.fit(snapshot)
391436
if best_error is None or model.mean_squared_error < best_error:
392437
best_error = model.mean_squared_error
393-
best_model = model
438+
best_model = model
394439
best_model.deploy()

pgml/pgml/sql.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from plpy import quote_literal
22

3+
34
def q(obj):
45
if type(obj) == str:
56
return quote_literal(obj)

pgml/tests/plpy.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22

33
execute_results = deque()
44

5+
56
def quote_literal(literal):
67
return "'" + literal + "'"
78

8-
def execute(sql, lines = 0):
9+
10+
def execute(sql, lines=0):
911
if len(execute_results) > 0:
1012
result = execute_results.popleft()
1113
return result
12-
else:
14+
else:
1315
return []
1416

17+
1518
def add_mock_result(result):
1619
execute_results.append(result)

0 commit comments

Comments
 (0)