Commit d31b6f4

SDK - Allow parallel batch uploads (#1465)
1 parent 6d061ed commit d31b6f4
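
Summary: `Collection::upsert_documents` now accepts a `parallel_batches` option next to `batch_size`. Batches of `batch_size` documents can be uploaded by up to `parallel_batches` concurrent Tokio tasks (default 1), each running the new private `_upsert_documents` helper inside its own transaction. A minimal caller sketch follows; the import paths, collection name, and document source are illustrative, and the options object mirrors the new test added in this commit:

// Sketch only: adjust the imports to however the pgml crate exposes these types.
use pgml::{types::Json, Collection};
use serde_json::json;

async fn upload(documents: Vec<Json>) -> anyhow::Result<()> {
    let mut collection = Collection::new("my_collection", None)?;
    collection
        .upsert_documents(
            documents,
            Some(
                json!({
                    "batch_size": 100,     // documents per batch/transaction (default 100)
                    "parallel_batches": 4  // concurrent batch uploads (default 1)
                })
                .into(),
            ),
        )
        .await?;
    Ok(())
}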

File tree

2 files changed, +224 -102 lines changed


pgml-sdks/pgml/src/collection.rs

Lines changed: 144 additions & 102 deletions
@@ -6,19 +6,21 @@ use rust_bridge::{alias, alias_methods};
 use sea_query::Alias;
 use sea_query::{Expr, NullOrdering, Order, PostgresQueryBuilder, Query};
 use sea_query_binder::SqlxBinder;
-use serde_json::json;
-use sqlx::Executor;
+use serde_json::{json, Value};
 use sqlx::PgConnection;
+use sqlx::{Executor, Pool, Postgres};
 use std::borrow::Cow;
 use std::collections::HashMap;
 use std::path::Path;
 use std::time::SystemTime;
 use std::time::UNIX_EPOCH;
+use tokio::task::JoinSet;
 use tracing::{instrument, warn};
 use walkdir::WalkDir;
 
 use crate::debug_sqlx_query;
 use crate::filter_builder::FilterBuilder;
+use crate::pipeline::FieldAction;
 use crate::search_query_builder::build_search_query;
 use crate::vector_search_query_builder::build_vector_search_query;
 use crate::{
@@ -496,28 +498,80 @@ impl Collection {
         // -> Insert the document
         // -> Foreach pipeline check if we need to resync the document and if so sync the document
         // -> Commit the transaction
+        let mut args = args.unwrap_or_default();
+        let args = args.as_object_mut().context("args must be a JSON object")?;
+
         self.verify_in_database(false).await?;
         let mut pipelines = self.get_pipelines().await?;
 
         let pool = get_or_initialize_pool(&self.database_url).await?;
 
-        let mut parsed_schemas = vec![];
         let project_info = &self.database_data.as_ref().unwrap().project_info;
+        let mut parsed_schemas = vec![];
         for pipeline in &mut pipelines {
             let parsed_schema = pipeline
                 .get_parsed_schema(project_info, &pool)
                 .await
                 .expect("Error getting parsed schema for pipeline");
             parsed_schemas.push(parsed_schema);
         }
-        let mut pipelines: Vec<(Pipeline, _)> = pipelines.into_iter().zip(parsed_schemas).collect();
+        let pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)> =
+            pipelines.into_iter().zip(parsed_schemas).collect();
 
-        let args = args.unwrap_or_default();
-        let args = args.as_object().context("args must be a JSON object")?;
+        let batch_size = args
+            .remove("batch_size")
+            .map(|x| x.try_to_u64())
+            .unwrap_or(Ok(100))?;
+
+        let parallel_batches = args
+            .get("parallel_batches")
+            .map(|x| x.try_to_u64())
+            .unwrap_or(Ok(1))? as usize;
 
         let progress_bar = utils::default_progress_bar(documents.len() as u64);
         progress_bar.println("Upserting Documents...");
 
+        let mut set = JoinSet::new();
+        for batch in documents.chunks(batch_size as usize) {
+            if set.len() < parallel_batches {
+                let local_self = self.clone();
+                let local_batch = batch.to_owned();
+                let local_args = args.clone();
+                let local_pipelines = pipelines.clone();
+                let local_pool = pool.clone();
+                set.spawn(async move {
+                    local_self
+                        ._upsert_documents(local_batch, local_args, local_pipelines, local_pool)
+                        .await
+                });
+            } else {
+                if let Some(res) = set.join_next().await {
+                    res??;
+                    progress_bar.inc(batch_size);
+                }
+            }
+        }
+
+        while let Some(res) = set.join_next().await {
+            res??;
+            progress_bar.inc(batch_size);
+        }
+
+        progress_bar.println("Done Upserting Documents\n");
+        progress_bar.finish();
+
+        Ok(())
+    }
+
+    async fn _upsert_documents(
+        self,
+        batch: Vec<Json>,
+        args: serde_json::Map<String, Value>,
+        mut pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)>,
+        pool: Pool<Postgres>,
+    ) -> anyhow::Result<()> {
+        let project_info = &self.database_data.as_ref().unwrap().project_info;
+
         let query = if args
             .get("merge")
             .map(|v| v.as_bool().unwrap_or(false))
@@ -539,111 +593,99 @@ impl Collection {
             )
         };
 
-        let batch_size = args
-            .get("batch_size")
-            .map(TryToNumeric::try_to_u64)
-            .unwrap_or(Ok(100))?;
-
-        for batch in documents.chunks(batch_size as usize) {
-            let mut transaction = pool.begin().await?;
-
-            let mut query_values = String::new();
-            let mut binding_parameter_counter = 1;
-            for _ in 0..batch.len() {
-                query_values = format!(
-                    "{query_values}, (${}, ${}, ${})",
-                    binding_parameter_counter,
-                    binding_parameter_counter + 1,
-                    binding_parameter_counter + 2
-                );
-                binding_parameter_counter += 3;
-            }
+        let mut transaction = pool.begin().await?;
 
-            let query = query.replace(
-                "{values_parameters}",
-                &query_values.chars().skip(1).collect::<String>(),
-            );
-            let query = query.replace(
-                "{binding_parameter}",
-                &format!("${binding_parameter_counter}"),
+        let mut query_values = String::new();
+        let mut binding_parameter_counter = 1;
+        for _ in 0..batch.len() {
+            query_values = format!(
+                "{query_values}, (${}, ${}, ${})",
+                binding_parameter_counter,
+                binding_parameter_counter + 1,
+                binding_parameter_counter + 2
             );
+            binding_parameter_counter += 3;
+        }
 
-            let mut query = sqlx::query_as(&query);
-
-            let mut source_uuids = vec![];
-            for document in batch {
-                let id = document
-                    .get("id")
-                    .context("`id` must be a key in document")?
-                    .to_string();
-                let md5_digest = md5::compute(id.as_bytes());
-                let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
-                source_uuids.push(source_uuid);
-
-                let start = SystemTime::now();
-                let timestamp = start
-                    .duration_since(UNIX_EPOCH)
-                    .expect("Time went backwards")
-                    .as_millis();
-
-                let versions: HashMap<String, serde_json::Value> = document
-                    .as_object()
-                    .context("document must be an object")?
-                    .iter()
-                    .try_fold(HashMap::new(), |mut acc, (key, value)| {
-                        let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes());
-                        let md5_digest = format!("{md5_digest:x}");
-                        acc.insert(
-                            key.to_owned(),
-                            serde_json::json!({
-                                "last_updated": timestamp,
-                                "md5": md5_digest
-                            }),
-                        );
-                        anyhow::Ok(acc)
-                    })?;
-                let versions = serde_json::to_value(versions)?;
-
-                query = query.bind(source_uuid).bind(document).bind(versions);
-            }
+        let query = query.replace(
+            "{values_parameters}",
+            &query_values.chars().skip(1).collect::<String>(),
+        );
+        let query = query.replace(
+            "{binding_parameter}",
+            &format!("${binding_parameter_counter}"),
+        );
 
-            let results: Vec<(i64, Option<Json>)> = query
-                .bind(source_uuids)
-                .fetch_all(&mut *transaction)
-                .await?;
+        let mut query = sqlx::query_as(&query);
+
+        let mut source_uuids = vec![];
+        for document in &batch {
+            let id = document
+                .get("id")
+                .context("`id` must be a key in document")?
+                .to_string();
+            let md5_digest = md5::compute(id.as_bytes());
+            let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
+            source_uuids.push(source_uuid);
+
+            let start = SystemTime::now();
+            let timestamp = start
+                .duration_since(UNIX_EPOCH)
+                .expect("Time went backwards")
+                .as_millis();
+
+            let versions: HashMap<String, serde_json::Value> = document
+                .as_object()
+                .context("document must be an object")?
+                .iter()
+                .try_fold(HashMap::new(), |mut acc, (key, value)| {
+                    let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes());
+                    let md5_digest = format!("{md5_digest:x}");
+                    acc.insert(
+                        key.to_owned(),
+                        serde_json::json!({
+                            "last_updated": timestamp,
+                            "md5": md5_digest
+                        }),
+                    );
+                    anyhow::Ok(acc)
+                })?;
+            let versions = serde_json::to_value(versions)?;
 
-            let dp: Vec<(i64, Json, Option<Json>)> = results
-                .into_iter()
-                .zip(batch)
-                .map(|((id, previous_document), document)| {
-                    (id, document.to_owned(), previous_document)
+            query = query.bind(source_uuid).bind(document).bind(versions);
+        }
+
+        let results: Vec<(i64, Option<Json>)> = query
+            .bind(source_uuids)
+            .fetch_all(&mut *transaction)
+            .await?;
+
+        let dp: Vec<(i64, Json, Option<Json>)> = results
+            .into_iter()
+            .zip(batch)
+            .map(|((id, previous_document), document)| (id, document.to_owned(), previous_document))
+            .collect();
+
+        for (pipeline, parsed_schema) in &mut pipelines {
+            let ids_to_run_on: Vec<i64> = dp
+                .iter()
+                .filter(|(_, document, previous_document)| match previous_document {
+                    Some(previous_document) => parsed_schema
+                        .iter()
+                        .any(|(key, _)| document[key] != previous_document[key]),
+                    None => true,
                 })
+                .map(|(document_id, _, _)| *document_id)
                 .collect();
-
-            for (pipeline, parsed_schema) in &mut pipelines {
-                let ids_to_run_on: Vec<i64> = dp
-                    .iter()
-                    .filter(|(_, document, previous_document)| match previous_document {
-                        Some(previous_document) => parsed_schema
-                            .iter()
-                            .any(|(key, _)| document[key] != previous_document[key]),
-                        None => true,
-                    })
-                    .map(|(document_id, _, _)| *document_id)
-                    .collect();
-                if !ids_to_run_on.is_empty() {
-                    pipeline
-                        .sync_documents(ids_to_run_on, project_info, &mut transaction)
-                        .await
-                        .expect("Failed to execute pipeline");
-                }
+            if !ids_to_run_on.is_empty() {
+                pipeline
+                    .sync_documents(ids_to_run_on, project_info, &mut transaction)
+                    .await
+                    .expect("Failed to execute pipeline");
             }
-
-            transaction.commit().await?;
-            progress_bar.inc(batch_size);
         }
-        progress_bar.println("Done Upserting Documents\n");
-        progress_bar.finish();
+
+        transaction.commit().await?;
         Ok(())
     }
 
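The concurrency above is plain tokio::task::JoinSet usage: spawn one task per batch, keep no more than `parallel_batches` tasks alive, and propagate both join errors and task errors with `res??`. Below is a generic, self-contained sketch of that cap-and-drain pattern; the helper name `run_bounded` and the job futures are illustrative, not SDK code:

use tokio::task::JoinSet;

// Run the given futures with at most `limit` of them in flight at a time.
async fn run_bounded<Fut>(jobs: Vec<Fut>, limit: usize) -> anyhow::Result<()>
where
    Fut: std::future::Future<Output = anyhow::Result<()>> + Send + 'static,
{
    let mut set: JoinSet<anyhow::Result<()>> = JoinSet::new();
    for job in jobs {
        // At the cap: drain one finished task before spawning another.
        if set.len() >= limit {
            if let Some(res) = set.join_next().await {
                res??; // first ? surfaces a JoinError (panic/cancel), second ? the task's own error
            }
        }
        set.spawn(job);
    }
    // Wait for the stragglers.
    while let Some(res) = set.join_next().await {
        res??;
    }
    Ok(())
}

In the commit itself, each spawned task owns clones of the collection, the parsed pipelines, and the connection pool (local_self, local_pipelines, local_pool), so nothing borrowed crosses the task boundary and each `_upsert_documents` call commits its own transaction.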

pgml-sdks/pgml/src/lib.rs

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,86 @@ mod tests {
431431
Ok(())
432432
}
433433

434+
#[tokio::test]
435+
async fn can_add_pipeline_and_upsert_documents_with_parallel_batches() -> anyhow::Result<()> {
436+
internal_init_logger(None, None).ok();
437+
let collection_name = "test_r_c_capaud_107";
438+
let pipeline_name = "test_r_p_capaud_6";
439+
let mut pipeline = Pipeline::new(
440+
pipeline_name,
441+
Some(
442+
json!({
443+
"title": {
444+
"semantic_search": {
445+
"model": "intfloat/e5-small"
446+
}
447+
},
448+
"body": {
449+
"splitter": {
450+
"model": "recursive_character",
451+
"parameters": {
452+
"chunk_size": 1000,
453+
"chunk_overlap": 40
454+
}
455+
},
456+
"semantic_search": {
457+
"model": "hkunlp/instructor-base",
458+
"parameters": {
459+
"instruction": "Represent the Wikipedia document for retrieval"
460+
}
461+
},
462+
"full_text_search": {
463+
"configuration": "english"
464+
}
465+
}
466+
})
467+
.into(),
468+
),
469+
)?;
470+
let mut collection = Collection::new(collection_name, None)?;
471+
collection.add_pipeline(&mut pipeline).await?;
472+
let documents = generate_dummy_documents(20);
473+
collection
474+
.upsert_documents(
475+
documents.clone(),
476+
Some(
477+
json!({
478+
"batch_size": 4,
479+
"parallel_batches": 5
480+
})
481+
.into(),
482+
),
483+
)
484+
.await?;
485+
let pool = get_or_initialize_pool(&None).await?;
486+
let documents_table = format!("{}.documents", collection_name);
487+
let queried_documents: Vec<models::Document> =
488+
sqlx::query_as(&query_builder!("SELECT * FROM %s", documents_table))
489+
.fetch_all(&pool)
490+
.await?;
491+
assert!(queried_documents.len() == 20);
492+
let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name);
493+
let title_chunks: Vec<models::Chunk> =
494+
sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table))
495+
.fetch_all(&pool)
496+
.await?;
497+
assert!(title_chunks.len() == 20);
498+
let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name);
499+
let body_chunks: Vec<models::Chunk> =
500+
sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table))
501+
.fetch_all(&pool)
502+
.await?;
503+
assert!(body_chunks.len() == 120);
504+
let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name);
505+
let tsvectors: Vec<models::TSVector> =
506+
sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table))
507+
.fetch_all(&pool)
508+
.await?;
509+
assert!(tsvectors.len() == 120);
510+
collection.archive().await?;
511+
Ok(())
512+
}
513+
434514
#[tokio::test]
435515
async fn can_upsert_documents_and_add_pipeline() -> anyhow::Result<()> {
436516
internal_init_logger(None, None).ok();
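
For reference, the "check if we need to resync" step that `_upsert_documents` keeps from the old code boils down to a per-field comparison against the previously stored document, restricted to the fields in each pipeline's parsed schema. A stripped-down sketch of that predicate, using a hypothetical `needs_resync` helper over plain serde_json::Value rather than the SDK's types:

use serde_json::{json, Value};

// True when the document is new, or when any field the pipeline indexes has changed.
fn needs_resync(schema_keys: &[&str], document: &Value, previous: Option<&Value>) -> bool {
    match previous {
        None => true,
        Some(previous) => schema_keys
            .iter()
            .any(|key| document[*key] != previous[*key]),
    }
}

fn main() {
    let before = json!({ "id": 1, "title": "old title", "body": "unchanged" });
    let after = json!({ "id": 1, "title": "new title", "body": "unchanged" });
    assert!(needs_resync(&["title"], &after, Some(&before))); // title changed -> resync
    assert!(!needs_resync(&["body"], &after, Some(&before))); // body untouched -> skip
    assert!(needs_resync(&["body"], &after, None)); // brand new document -> always sync
}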
