Skip to content

Commit 28eff54

Browse files
committed
Batch upsert documents
1 parent fa9639f commit 28eff54

File tree

5 files changed

+147
-18
lines changed

5 files changed

+147
-18
lines changed

pgml-sdks/pgml/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-sdks/pgml/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgml"
3-
version = "1.1.0"
3+
version = "1.2.0"
44
edition = "2021"
55
authors = ["PosgresML <team@postgresml.org>"]
66
homepage = "https://postgresml.org/"

pgml-sdks/pgml/src/batch.rs

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
//! Upsert documents in batches.
2+
3+
#[cfg(feature = "rust_bridge")]
4+
use rust_bridge::{alias, alias_methods};
5+
6+
use tracing::instrument;
7+
8+
use crate::{types::Json, Collection};
9+
10+
#[cfg(feature = "python")]
11+
use crate::{collection::CollectionPython, types::JsonPython};
12+
13+
#[cfg(feature = "c")]
14+
use crate::{collection::CollectionC, languages::c::JsonC};
15+
16+
/// A batch of documents staged for upsert
17+
#[cfg_attr(feature = "rust_bridge", derive(alias))]
18+
#[derive(Debug, Clone)]
19+
pub struct Batch {
20+
collection: Collection,
21+
pub(crate) documents: Vec<Json>,
22+
pub(crate) size: i64,
23+
pub(crate) args: Option<Json>,
24+
}
25+
26+
#[cfg_attr(feature = "rust_bridge", alias_methods(new, upsert_documents, finish,))]
27+
impl Batch {
28+
/// Create a new upsert batch.
29+
///
30+
/// # Arguments
31+
///
32+
/// * `collection` - The collection to upsert documents to.
33+
/// * `size` - The size of the batch.
34+
/// * `args` - Optional arguments to pass to the upsert operation.
35+
///
36+
/// # Example
37+
///
38+
/// ```
39+
/// use pgml::{Collection, Batch};
40+
///
41+
/// let collection = Collection::new("my_collection");
42+
/// let batch = Batch::new(&collection, 100, None);
43+
/// ```
44+
pub fn new(collection: &Collection, size: i64, args: Option<Json>) -> Self {
45+
Self {
46+
collection: collection.clone(),
47+
args,
48+
documents: Vec::new(),
49+
size,
50+
}
51+
}
52+
53+
/// Upsert documents into the collection. If the batch is full, save the documents.
54+
///
55+
/// When using this method, remember to call [finish](Batch::finish) to save any remaining documents
56+
/// in the last batch.
57+
///
58+
/// # Arguments
59+
///
60+
/// * `documents` - The documents to upsert.
61+
///
62+
/// # Example
63+
///
64+
/// ```
65+
/// use pgml::{Collection, Batch};
66+
/// use serde_json::json;
67+
///
68+
/// let collection = Collection::new("my_collection");
69+
/// let mut batch = Batch::new(&collection, 100, None);
70+
///
71+
/// batch.upsert_documents(vec![json!({"id": 1}), json!({"id": 2})]).await?;
72+
/// batch.finish().await?;
73+
/// ```
74+
#[instrument(skip(self))]
75+
pub async fn upsert_documents(&mut self, documents: Vec<Json>) -> anyhow::Result<()> {
76+
for document in documents {
77+
if self.size as usize >= self.documents.len() {
78+
self.collection
79+
.upsert_documents(std::mem::take(&mut self.documents), self.args.clone())
80+
.await?;
81+
self.documents.clear();
82+
}
83+
84+
self.documents.push(document);
85+
}
86+
87+
Ok(())
88+
}
89+
90+
/// Save any remaining documents in the last batch.
91+
#[instrument(skip(self))]
92+
pub async fn finish(&mut self) -> anyhow::Result<()> {
93+
if !self.documents.is_empty() {
94+
self.collection
95+
.upsert_documents(std::mem::take(&mut self.documents), self.args.clone())
96+
.await?;
97+
}
98+
99+
Ok(())
100+
}
101+
}

pgml-sdks/pgml/src/collection.rs

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ impl Collection {
208208
.all(|c| c.is_alphanumeric() || c.is_whitespace() || c == '-' || c == '_')
209209
{
210210
anyhow::bail!(
211-
"Name must only consist of letters, numebers, white space, and '-' or '_'"
211+
"Collection name must only consist of letters, numbers, white space, and '-' or '_'"
212212
)
213213
}
214214
let (pipelines_table_name, documents_table_name) = Self::generate_table_names(name);
@@ -264,21 +264,43 @@ impl Collection {
264264
} else {
265265
let mut transaction = pool.begin().await?;
266266

267-
let project_id: i64 = sqlx::query_scalar("INSERT INTO pgml.projects (name, task) VALUES ($1, 'embedding'::pgml.task) ON CONFLICT (name) DO UPDATE SET task = EXCLUDED.task RETURNING id, task::TEXT")
268-
.bind(&self.name)
269-
.fetch_one(&mut *transaction)
270-
.await?;
267+
let project_id: i64 = sqlx::query_scalar(
268+
"
269+
INSERT INTO pgml.projects (
270+
name,
271+
task
272+
) VALUES (
273+
$1,
274+
'embedding'::pgml.task
275+
) ON CONFLICT (name)
276+
DO UPDATE SET
277+
task = EXCLUDED.task
278+
RETURNING id, task::TEXT",
279+
)
280+
.bind(&self.name)
281+
.fetch_one(&mut *transaction)
282+
.await?;
271283

272284
transaction
273285
.execute(query_builder!("CREATE SCHEMA IF NOT EXISTS %s", self.name).as_str())
274286
.await?;
275287

276-
let c: models::Collection = sqlx::query_as("INSERT INTO pgml.collections (name, project_id, sdk_version) VALUES ($1, $2, $3) ON CONFLICT (name) DO NOTHING RETURNING *")
277-
.bind(&self.name)
278-
.bind(project_id)
279-
.bind(crate::SDK_VERSION)
280-
.fetch_one(&mut *transaction)
281-
.await?;
288+
let c: models::Collection = sqlx::query_as(
289+
"
290+
INSERT INTO pgml.collections (
291+
name,
292+
project_id,
293+
sdk_version
294+
) VALUES (
295+
$1, $2, $3
296+
) ON CONFLICT (name) DO NOTHING
297+
RETURNING *",
298+
)
299+
.bind(&self.name)
300+
.bind(project_id)
301+
.bind(crate::SDK_VERSION)
302+
.fetch_one(&mut *transaction)
303+
.await?;
282304

283305
let collection_database_data = CollectionDatabaseData {
284306
id: c.id,
@@ -353,23 +375,25 @@ impl Collection {
353375
.await?;
354376

355377
if exists {
356-
warn!("Pipeline {} already exists not adding", pipeline.name);
378+
warn!("Pipeline {} already exists, not adding", pipeline.name);
357379
} else {
358-
// We want to intentially throw an error if they have already added this pipeline
380+
// We want to intentionally throw an error if they have already added this pipeline
359381
// as we don't want to casually resync
382+
let mp = MultiProgress::new();
383+
mp.println(format!("Adding pipeline {}...", pipeline.name))?;
384+
360385
pipeline
361386
.verify_in_database(project_info, true, &pool)
362387
.await?;
363388

364-
let mp = MultiProgress::new();
365-
mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?;
389+
mp.println(format!("Added pipeline {}, now syncing...", pipeline.name))?;
366390

367391
// TODO: Revisit this. If the pipeline is added but fails to sync, then it will be "out of sync" with the documents in the table
368392
// This is rare, but could happen
369393
pipeline
370394
.resync(project_info, pool.acquire().await?.as_mut())
371395
.await?;
372-
mp.println(format!("Done Syncing {}\n", pipeline.name))?;
396+
mp.println(format!("Done syncing {}\n", pipeline.name))?;
373397
}
374398
Ok(())
375399
}

pgml-sdks/pgml/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use tokio::runtime::{Builder, Runtime};
1414
use tracing::Level;
1515
use tracing_subscriber::FmtSubscriber;
1616

17+
mod batch;
1718
mod builtins;
1819
#[cfg(any(feature = "python", feature = "javascript"))]
1920
mod cli;
@@ -40,6 +41,7 @@ mod utils;
4041
mod vector_search_query_builder;
4142

4243
// Re-export
44+
pub use batch::Batch;
4345
pub use builtins::Builtins;
4446
pub use collection::Collection;
4547
pub use model::Model;
@@ -217,6 +219,7 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> {
217219
m.add_class::<builtins::BuiltinsPython>()?;
218220
m.add_class::<transformer_pipeline::TransformerPipelinePython>()?;
219221
m.add_class::<open_source_ai::OpenSourceAIPython>()?;
222+
m.add_class::<batch::BatchPython>()?;
220223
Ok(())
221224
}
222225

@@ -275,6 +278,7 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> {
275278
"newOpenSourceAI",
276279
open_source_ai::OpenSourceAIJavascript::new,
277280
)?;
281+
cx.export_function("newBatch", batch::BatchJavascript::new)?;
278282
Ok(())
279283
}
280284

0 commit comments

Comments
 (0)