Skip to content

Commit 0e9a873

Browse files
authored
Move PythonIterator into its own function (#1156)
1 parent d5c8629 commit 0e9a873

File tree

2 files changed

+58
-49
lines changed

2 files changed

+58
-49
lines changed

pgml-extension/src/api.rs

Lines changed: 10 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ use std::str::FromStr;
44
use ndarray::Zip;
55
use pgrx::iter::{SetOfIterator, TableIterator};
66
use pgrx::*;
7-
use pyo3::prelude::*;
8-
use pyo3::types::{IntoPyDict, PyDict};
97

108
#[cfg(feature = "python")]
119
use serde_json::json;
@@ -634,40 +632,6 @@ pub fn transform_string(
634632
}
635633
}
636634

637-
struct TransformStreamIterator {
638-
locals: Py<PyDict>,
639-
}
640-
641-
impl TransformStreamIterator {
642-
fn new(python_iter: Py<PyAny>) -> Self {
643-
let locals = Python::with_gil(|py| -> Result<Py<PyDict>, PyErr> {
644-
Ok([("python_iter", python_iter)].into_py_dict(py).into())
645-
})
646-
.map_err(|e| error!("{e}"))
647-
.unwrap();
648-
Self { locals }
649-
}
650-
}
651-
652-
impl Iterator for TransformStreamIterator {
653-
type Item = String;
654-
fn next(&mut self) -> Option<Self::Item> {
655-
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
656-
Python::with_gil(|py| -> Result<Option<String>, PyErr> {
657-
let code = "next(python_iter)";
658-
let res: &PyAny = py.eval(code, Some(self.locals.as_ref(py)), None)?;
659-
if res.is_none() {
660-
Ok(None)
661-
} else {
662-
let res: String = res.extract()?;
663-
Ok(Some(res))
664-
}
665-
})
666-
.map_err(|e| error!("{e}"))
667-
.unwrap()
668-
}
669-
}
670-
671635
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
672636
#[pg_extern(immutable, parallel_safe, name = "transform_stream")]
673637
#[allow(unused_variables)] // cache is maintained for api compatibility
@@ -678,11 +642,11 @@ pub fn transform_stream_json(
678642
cache: default!(bool, false),
679643
) -> SetOfIterator<'static, String> {
680644
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
681-
let python_iter = crate::bindings::transformers::transform_stream(&task.0, &args.0, input)
682-
.map_err(|e| error!("{e}"))
683-
.unwrap();
684-
let res = TransformStreamIterator::new(python_iter);
685-
SetOfIterator::new(res)
645+
let python_iter =
646+
crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input)
647+
.map_err(|e| error!("{e}"))
648+
.unwrap();
649+
SetOfIterator::new(python_iter)
686650
}
687651

688652
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
@@ -696,11 +660,11 @@ pub fn transform_stream_string(
696660
) -> SetOfIterator<'static, String> {
697661
let task_json = json!({ "task": task });
698662
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
699-
let python_iter = crate::bindings::transformers::transform_stream(&task_json, &args.0, input)
700-
.map_err(|e| error!("{e}"))
701-
.unwrap();
702-
let res = TransformStreamIterator::new(python_iter);
703-
SetOfIterator::new(res)
663+
let python_iter =
664+
crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, input)
665+
.map_err(|e| error!("{e}"))
666+
.unwrap();
667+
SetOfIterator::new(python_iter)
704668
}
705669

706670
#[cfg(feature = "python")]

pgml-extension/src/bindings/transformers/transformers.rs

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,52 @@
11
use super::whitelist;
22
use super::TracebackError;
33
use anyhow::Result;
4+
use pgrx::*;
45
use pyo3::prelude::*;
5-
use pyo3::types::PyTuple;
6+
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
7+
68
create_pymodule!("/src/bindings/transformers/transformers.py");
79

10+
pub struct TransformStreamIterator {
11+
locals: Py<PyDict>,
12+
}
13+
14+
impl TransformStreamIterator {
15+
fn new(python_iter: Py<PyAny>) -> Self {
16+
let locals = Python::with_gil(|py| -> Result<Py<PyDict>, PyErr> {
17+
Ok([("python_iter", python_iter)].into_py_dict(py).into())
18+
})
19+
.map_err(|e| error!("{e}"))
20+
.unwrap();
21+
Self { locals }
22+
}
23+
}
24+
25+
impl Iterator for TransformStreamIterator {
26+
type Item = String;
27+
fn next(&mut self) -> Option<Self::Item> {
28+
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
29+
Python::with_gil(|py| -> Result<Option<String>, PyErr> {
30+
let code = "next(python_iter)";
31+
let res: &PyAny = py.eval(code, Some(self.locals.as_ref(py)), None)?;
32+
if res.is_none() {
33+
Ok(None)
34+
} else {
35+
let res: String = res.extract()?;
36+
Ok(Some(res))
37+
}
38+
})
39+
.map_err(|e| error!("{e}"))
40+
.unwrap()
41+
}
42+
}
43+
844
pub fn transform(
945
task: &serde_json::Value,
1046
args: &serde_json::Value,
1147
inputs: Vec<&str>,
1248
) -> Result<serde_json::Value> {
1349
crate::bindings::python::activate()?;
14-
1550
whitelist::verify_task(task)?;
1651

1752
let task = serde_json::to_string(task)?;
@@ -45,7 +80,6 @@ pub fn transform_stream(
4580
input: &str,
4681
) -> Result<Py<PyAny>> {
4782
crate::bindings::python::activate()?;
48-
4983
whitelist::verify_task(task)?;
5084

5185
let task = serde_json::to_string(task)?;
@@ -75,3 +109,14 @@ pub fn transform_stream(
75109
Ok(output)
76110
})
77111
}
112+
113+
pub fn transform_stream_iterator(
114+
task: &serde_json::Value,
115+
args: &serde_json::Value,
116+
input: &str,
117+
) -> Result<TransformStreamIterator> {
118+
let python_iter = transform_stream(task, args, input)
119+
.map_err(|e| error!("{e}"))
120+
.unwrap();
121+
Ok(TransformStreamIterator::new(python_iter))
122+
}

0 commit comments

Comments
 (0)