Skip to content

Commit 0be25d0

Browse files
authored
Added OpenSourceAI and conversational support in the extension (#1206)
1 parent dd18739 commit 0be25d0

File tree

16 files changed

+1140
-161
lines changed

16 files changed

+1140
-161
lines changed

pgml-extension/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgml"
3-
version = "2.8.0"
3+
version = "2.8.1"
44
edition = "2021"
55

66
[lib]
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
-- pgml::api::transform_conversational_json
2+
CREATE FUNCTION pgml."transform"(
3+
"task" jsonb, /* pgrx::datum::json::JsonB */
4+
"args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
5+
"inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec<pgrx::datum::json::JsonB> */
6+
"cache" bool DEFAULT false /* bool */
7+
) RETURNS jsonb /* alloc::string::String */
8+
IMMUTABLE STRICT PARALLEL SAFE
9+
LANGUAGE c /* Rust */
10+
AS 'MODULE_PATHNAME', 'transform_conversational_json_wrapper';
11+
12+
-- pgml::api::transform_conversational_string
13+
CREATE FUNCTION pgml."transform"(
14+
"task" TEXT, /* alloc::string::String */
15+
"args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
16+
"inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec<pgrx::datum::json::JsonB> */
17+
"cache" bool DEFAULT false /* bool */
18+
) RETURNS jsonb /* alloc::string::String */
19+
IMMUTABLE STRICT PARALLEL SAFE
20+
LANGUAGE c /* Rust */
21+
AS 'MODULE_PATHNAME', 'transform_conversational_string_wrapper';
22+
23+
-- pgml::api::transform_stream_string
24+
DROP FUNCTION IF EXISTS pgml."transform_stream"(text,jsonb,text,boolean);
25+
CREATE FUNCTION pgml."transform_stream"(
26+
"task" TEXT, /* alloc::string::String */
27+
"args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
28+
"input" TEXT DEFAULT '', /* &str */
29+
"cache" bool DEFAULT false /* bool */
30+
) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */
31+
IMMUTABLE STRICT PARALLEL SAFE
32+
LANGUAGE c /* Rust */
33+
AS 'MODULE_PATHNAME', 'transform_stream_string_wrapper';
34+
35+
-- pgml::api::transform_stream_json
36+
DROP FUNCTION IF EXISTS pgml."transform_stream"(jsonb,jsonb,text,boolean);
37+
CREATE FUNCTION pgml."transform_stream"(
38+
"task" jsonb, /* pgrx::datum::json::JsonB */
39+
"args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
40+
"input" TEXT DEFAULT '', /* &str */
41+
"cache" bool DEFAULT false /* bool */
42+
) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */
43+
IMMUTABLE STRICT PARALLEL SAFE
44+
LANGUAGE c /* Rust */
45+
AS 'MODULE_PATHNAME', 'transform_stream_json_wrapper';
46+
47+
-- pgml::api::transform_stream_conversational_json
48+
CREATE FUNCTION pgml."transform_stream"(
49+
"task" TEXT, /* alloc::string::String */
50+
"args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
51+
"inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec<pgrx::datum::json::JsonB> */
52+
"cache" bool DEFAULT false /* bool */
53+
) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */
54+
IMMUTABLE STRICT PARALLEL SAFE
55+
LANGUAGE c /* Rust */
56+
AS 'MODULE_PATHNAME', 'transform_stream_conversational_string_wrapper';
57+
58+
-- pgml::api::transform_stream_conversational_string
59+
CREATE FUNCTION pgml."transform_stream"(
60+
"task" jsonb, /* pgrx::datum::json::JsonB */
61+
"args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
62+
"inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec<pgrx::datum::json::JsonB> */
63+
"cache" bool DEFAULT false /* bool */
64+
) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */
65+
IMMUTABLE STRICT PARALLEL SAFE
66+
LANGUAGE c /* Rust */
67+
AS 'MODULE_PATHNAME', 'transform_stream_conversational_json_wrapper';

pgml-extension/src/api.rs

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,50 @@ pub fn transform_string(
632632
}
633633
}
634634

635+
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
636+
#[pg_extern(immutable, parallel_safe, name = "transform")]
637+
#[allow(unused_variables)] // cache is maintained for api compatibility
638+
pub fn transform_conversational_json(
639+
task: JsonB,
640+
args: default!(JsonB, "'{}'"),
641+
inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
642+
cache: default!(bool, false),
643+
) -> JsonB {
644+
if !task.0["task"]
645+
.as_str()
646+
.is_some_and(|v| v == "conversational")
647+
{
648+
error!(
649+
"ARRAY[]::JSONB inputs for transform should only be used with a conversational task"
650+
);
651+
}
652+
match crate::bindings::transformers::transform(&task.0, &args.0, inputs) {
653+
Ok(output) => JsonB(output),
654+
Err(e) => error!("{e}"),
655+
}
656+
}
657+
658+
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
659+
#[pg_extern(immutable, parallel_safe, name = "transform")]
660+
#[allow(unused_variables)] // cache is maintained for api compatibility
661+
pub fn transform_conversational_string(
662+
task: String,
663+
args: default!(JsonB, "'{}'"),
664+
inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
665+
cache: default!(bool, false),
666+
) -> JsonB {
667+
if task != "conversational" {
668+
error!(
669+
"ARRAY[]::JSONB inputs for transform should only be used with a conversational task"
670+
);
671+
}
672+
let task_json = json!({ "task": task });
673+
match crate::bindings::transformers::transform(&task_json, &args.0, inputs) {
674+
Ok(output) => JsonB(output),
675+
Err(e) => error!("{e}"),
676+
}
677+
}
678+
635679
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
636680
#[pg_extern(immutable, parallel_safe, name = "transform_stream")]
637681
#[allow(unused_variables)] // cache is maintained for api compatibility
@@ -640,7 +684,7 @@ pub fn transform_stream_json(
640684
args: default!(JsonB, "'{}'"),
641685
input: default!(&str, "''"),
642686
cache: default!(bool, false),
643-
) -> SetOfIterator<'static, String> {
687+
) -> SetOfIterator<'static, JsonB> {
644688
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
645689
let python_iter =
646690
crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input)
@@ -657,7 +701,7 @@ pub fn transform_stream_string(
657701
args: default!(JsonB, "'{}'"),
658702
input: default!(&str, "''"),
659703
cache: default!(bool, false),
660-
) -> SetOfIterator<'static, String> {
704+
) -> SetOfIterator<'static, JsonB> {
661705
let task_json = json!({ "task": task });
662706
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
663707
let python_iter =
@@ -667,6 +711,54 @@ pub fn transform_stream_string(
667711
SetOfIterator::new(python_iter)
668712
}
669713

714+
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
715+
#[pg_extern(immutable, parallel_safe, name = "transform_stream")]
716+
#[allow(unused_variables)] // cache is maintained for api compatibility
717+
pub fn transform_stream_conversational_json(
718+
task: JsonB,
719+
args: default!(JsonB, "'{}'"),
720+
inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
721+
cache: default!(bool, false),
722+
) -> SetOfIterator<'static, JsonB> {
723+
if !task.0["task"]
724+
.as_str()
725+
.is_some_and(|v| v == "conversational")
726+
{
727+
error!(
728+
"ARRAY[]::JSONB inputs for transform_stream should only be used with a conversational task"
729+
);
730+
}
731+
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
732+
let python_iter =
733+
crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, inputs)
734+
.map_err(|e| error!("{e}"))
735+
.unwrap();
736+
SetOfIterator::new(python_iter)
737+
}
738+
739+
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
740+
#[pg_extern(immutable, parallel_safe, name = "transform_stream")]
741+
#[allow(unused_variables)] // cache is maintained for api compatibility
742+
pub fn transform_stream_conversational_string(
743+
task: String,
744+
args: default!(JsonB, "'{}'"),
745+
inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
746+
cache: default!(bool, false),
747+
) -> SetOfIterator<'static, JsonB> {
748+
if task != "conversational" {
749+
error!(
750+
"ARRAY::JSONB inputs for transform_stream should only be used with a conversational task"
751+
);
752+
}
753+
let task_json = json!({ "task": task });
754+
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
755+
let python_iter =
756+
crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, inputs)
757+
.map_err(|e| error!("{e}"))
758+
.unwrap();
759+
SetOfIterator::new(python_iter)
760+
}
761+
670762
#[cfg(feature = "python")]
671763
#[pg_extern(immutable, parallel_safe, name = "generate")]
672764
fn generate(project_name: &str, inputs: &str, config: default!(JsonB, "'{}'")) -> String {

pgml-extension/src/bindings/transformers/transform.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,28 +23,28 @@ impl TransformStreamIterator {
2323
}
2424

2525
impl Iterator for TransformStreamIterator {
26-
type Item = String;
26+
type Item = JsonB;
2727
fn next(&mut self) -> Option<Self::Item> {
2828
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
29-
Python::with_gil(|py| -> Result<Option<String>, PyErr> {
29+
Python::with_gil(|py| -> Result<Option<JsonB>, PyErr> {
3030
let code = "next(python_iter)";
3131
let res: &PyAny = py.eval(code, Some(self.locals.as_ref(py)), None)?;
3232
if res.is_none() {
3333
Ok(None)
3434
} else {
35-
let res: String = res.extract()?;
36-
Ok(Some(res))
35+
let res: Vec<String> = res.extract()?;
36+
Ok(Some(JsonB(serde_json::to_value(res).unwrap())))
3737
}
3838
})
3939
.map_err(|e| error!("{e}"))
4040
.unwrap()
4141
}
4242
}
4343

44-
pub fn transform(
44+
pub fn transform<T: serde::Serialize>(
4545
task: &serde_json::Value,
4646
args: &serde_json::Value,
47-
inputs: Vec<&str>,
47+
inputs: T,
4848
) -> Result<serde_json::Value> {
4949
crate::bindings::python::activate()?;
5050
whitelist::verify_task(task)?;
@@ -74,17 +74,17 @@ pub fn transform(
7474
Ok(serde_json::from_str(&results)?)
7575
}
7676

77-
pub fn transform_stream(
77+
pub fn transform_stream<T: serde::Serialize>(
7878
task: &serde_json::Value,
7979
args: &serde_json::Value,
80-
input: &str,
80+
input: T,
8181
) -> Result<Py<PyAny>> {
8282
crate::bindings::python::activate()?;
8383
whitelist::verify_task(task)?;
8484

8585
let task = serde_json::to_string(task)?;
8686
let args = serde_json::to_string(args)?;
87-
let inputs = serde_json::to_string(&vec![input])?;
87+
let input = serde_json::to_string(&input)?;
8888

8989
Python::with_gil(|py| -> Result<Py<PyAny>> {
9090
let transform: Py<PyAny> = get_module!(PY_MODULE)
@@ -99,7 +99,7 @@ pub fn transform_stream(
9999
&[
100100
task.into_py(py),
101101
args.into_py(py),
102-
inputs.into_py(py),
102+
input.into_py(py),
103103
true.into_py(py),
104104
],
105105
),
@@ -110,10 +110,10 @@ pub fn transform_stream(
110110
})
111111
}
112112

113-
pub fn transform_stream_iterator(
113+
pub fn transform_stream_iterator<T: serde::Serialize>(
114114
task: &serde_json::Value,
115115
args: &serde_json::Value,
116-
input: &str,
116+
input: T,
117117
) -> Result<TransformStreamIterator> {
118118
let python_iter = transform_stream(task, args, input)
119119
.map_err(|e| error!("{e}"))

0 commit comments

Comments
 (0)