@@ -6,19 +6,21 @@ use rust_bridge::{alias, alias_methods};
 use sea_query::Alias;
 use sea_query::{Expr, NullOrdering, Order, PostgresQueryBuilder, Query};
 use sea_query_binder::SqlxBinder;
-use serde_json::json;
-use sqlx::Executor;
+use serde_json::{json, Value};
 use sqlx::PgConnection;
+use sqlx::{Executor, Pool, Postgres};
 use std::borrow::Cow;
 use std::collections::HashMap;
 use std::path::Path;
 use std::time::SystemTime;
 use std::time::UNIX_EPOCH;
+use tokio::task::JoinSet;
 use tracing::{instrument, warn};
 use walkdir::WalkDir;
 
 use crate::debug_sqlx_query;
 use crate::filter_builder::FilterBuilder;
+use crate::pipeline::FieldAction;
 use crate::search_query_builder::build_search_query;
 use crate::vector_search_query_builder::build_vector_search_query;
 use crate::{
@@ -496,28 +498,80 @@ impl Collection {
         // -> Insert the document
         // -> Foreach pipeline check if we need to resync the document and if so sync the document
         // -> Commit the transaction
+        let mut args = args.unwrap_or_default();
+        let args = args.as_object_mut().context("args must be a JSON object")?;
+
         self.verify_in_database(false).await?;
         let mut pipelines = self.get_pipelines().await?;
 
         let pool = get_or_initialize_pool(&self.database_url).await?;
 
-        let mut parsed_schemas = vec![];
         let project_info = &self.database_data.as_ref().unwrap().project_info;
+        let mut parsed_schemas = vec![];
         for pipeline in &mut pipelines {
             let parsed_schema = pipeline
                 .get_parsed_schema(project_info, &pool)
                 .await
                 .expect("Error getting parsed schema for pipeline");
             parsed_schemas.push(parsed_schema);
         }
-        let mut pipelines: Vec<(Pipeline, _)> = pipelines.into_iter().zip(parsed_schemas).collect();
+        let pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)> =
+            pipelines.into_iter().zip(parsed_schemas).collect();
 
-        let args = args.unwrap_or_default();
-        let args = args.as_object().context("args must be a JSON object")?;
+        let batch_size = args
+            .remove("batch_size")
+            .map(|x| x.try_to_u64())
+            .unwrap_or(Ok(100))?;
+
+        let parallel_batches = args
+            .get("parallel_batches")
+            .map(|x| x.try_to_u64())
+            .unwrap_or(Ok(1))? as usize;
 
         let progress_bar = utils::default_progress_bar(documents.len() as u64);
         progress_bar.println("Upserting Documents...");
 
+        let mut set = JoinSet::new();
+        for batch in documents.chunks(batch_size as usize) {
+            // Wait for a task to finish once `parallel_batches` are in flight.
+            if set.len() >= parallel_batches {
+                if let Some(res) = set.join_next().await {
+                    res??;
+                    progress_bar.inc(batch_size);
+                }
+            }
+            let local_self = self.clone();
+            let local_batch = batch.to_owned();
+            let local_args = args.clone();
+            let local_pipelines = pipelines.clone();
+            let local_pool = pool.clone();
+            set.spawn(async move {
+                local_self
+                    ._upsert_documents(local_batch, local_args, local_pipelines, local_pool)
+                    .await
+            });
+        }
+
+        while let Some(res) = set.join_next().await {
+            res??;
+            progress_bar.inc(batch_size);
+        }
+
+        progress_bar.println("Done Upserting Documents\n");
+        progress_bar.finish();
+
+        Ok(())
+    }
+
+    async fn _upsert_documents(
+        self,
+        batch: Vec<Json>,
+        args: serde_json::Map<String, Value>,
+        mut pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)>,
+        pool: Pool<Postgres>,
+    ) -> anyhow::Result<()> {
+        let project_info = &self.database_data.as_ref().unwrap().project_info;
+
         let query = if args
             .get("merge")
             .map(|v| v.as_bool().unwrap_or(false))
@@ -539,111 +593,99 @@ impl Collection {
             )
         };
 
-        let batch_size = args
-            .get("batch_size")
-            .map(TryToNumeric::try_to_u64)
-            .unwrap_or(Ok(100))?;
-
-        for batch in documents.chunks(batch_size as usize) {
-            let mut transaction = pool.begin().await?;
-
-            let mut query_values = String::new();
-            let mut binding_parameter_counter = 1;
-            for _ in 0..batch.len() {
-                query_values = format!(
-                    "{query_values}, (${}, ${}, ${})",
-                    binding_parameter_counter,
-                    binding_parameter_counter + 1,
-                    binding_parameter_counter + 2
-                );
-                binding_parameter_counter += 3;
-            }
+        let mut transaction = pool.begin().await?;
 
-            let query = query.replace(
-                "{values_parameters}",
-                &query_values.chars().skip(1).collect::<String>(),
-            );
-            let query = query.replace(
-                "{binding_parameter}",
-                &format!("${binding_parameter_counter}"),
+        let mut query_values = String::new();
+        let mut binding_parameter_counter = 1;
+        for _ in 0..batch.len() {
+            query_values = format!(
+                "{query_values}, (${}, ${}, ${})",
+                binding_parameter_counter,
+                binding_parameter_counter + 1,
+                binding_parameter_counter + 2
             );
+            binding_parameter_counter += 3;
+        }
 
-            let mut query = sqlx::query_as(&query);
-
-            let mut source_uuids = vec![];
-            for document in batch {
-                let id = document
-                    .get("id")
-                    .context("`id` must be a key in document")?
-                    .to_string();
-                let md5_digest = md5::compute(id.as_bytes());
-                let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
-                source_uuids.push(source_uuid);
-
-                let start = SystemTime::now();
-                let timestamp = start
-                    .duration_since(UNIX_EPOCH)
-                    .expect("Time went backwards")
-                    .as_millis();
-
-                let versions: HashMap<String, serde_json::Value> = document
-                    .as_object()
-                    .context("document must be an object")?
-                    .iter()
-                    .try_fold(HashMap::new(), |mut acc, (key, value)| {
-                        let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes());
-                        let md5_digest = format!("{md5_digest:x}");
-                        acc.insert(
-                            key.to_owned(),
-                            serde_json::json!({
-                                "last_updated": timestamp,
-                                "md5": md5_digest
-                            }),
-                        );
-                        anyhow::Ok(acc)
-                    })?;
-                let versions = serde_json::to_value(versions)?;
-
-                query = query.bind(source_uuid).bind(document).bind(versions);
-            }
+        let query = query.replace(
+            "{values_parameters}",
+            &query_values.chars().skip(1).collect::<String>(),
+        );
+        let query = query.replace(
+            "{binding_parameter}",
+            &format!("${binding_parameter_counter}"),
+        );
 
-            let results: Vec<(i64, Option<Json>)> = query
-                .bind(source_uuids)
-                .fetch_all(&mut *transaction)
-                .await?;
+        let mut query = sqlx::query_as(&query);
+
+        let mut source_uuids = vec![];
+        for document in &batch {
+            let id = document
+                .get("id")
+                .context("`id` must be a key in document")?
+                .to_string();
+            let md5_digest = md5::compute(id.as_bytes());
+            let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
+            source_uuids.push(source_uuid);
+
+            let start = SystemTime::now();
+            let timestamp = start
+                .duration_since(UNIX_EPOCH)
+                .expect("Time went backwards")
+                .as_millis();
+
+            let versions: HashMap<String, serde_json::Value> = document
+                .as_object()
+                .context("document must be an object")?
+                .iter()
+                .try_fold(HashMap::new(), |mut acc, (key, value)| {
+                    let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes());
+                    let md5_digest = format!("{md5_digest:x}");
+                    acc.insert(
+                        key.to_owned(),
+                        serde_json::json!({
+                            "last_updated": timestamp,
+                            "md5": md5_digest
+                        }),
+                    );
+                    anyhow::Ok(acc)
+                })?;
+            let versions = serde_json::to_value(versions)?;
 
-            let dp: Vec<(i64, Json, Option<Json>)> = results
-                .into_iter()
-                .zip(batch)
-                .map(|((id, previous_document), document)| {
-                    (id, document.to_owned(), previous_document)
+            query = query.bind(source_uuid).bind(document).bind(versions);
+        }
+
+        let results: Vec<(i64, Option<Json>)> = query
+            .bind(source_uuids)
+            .fetch_all(&mut *transaction)
+            .await?;
+
+        let dp: Vec<(i64, Json, Option<Json>)> = results
+            .into_iter()
+            .zip(batch)
+            .map(|((id, previous_document), document)| (id, document.to_owned(), previous_document))
+            .collect();
+
+        for (pipeline, parsed_schema) in &mut pipelines {
+            let ids_to_run_on: Vec<i64> = dp
+                .iter()
+                .filter(|(_, document, previous_document)| match previous_document {
+                    Some(previous_document) => parsed_schema
+                        .iter()
+                        .any(|(key, _)| document[key] != previous_document[key]),
+                    None => true,
                 })
+                .map(|(document_id, _, _)| *document_id)
                 .collect();
-
-            for (pipeline, parsed_schema) in &mut pipelines {
-                let ids_to_run_on: Vec<i64> = dp
-                    .iter()
-                    .filter(|(_, document, previous_document)| match previous_document {
-                        Some(previous_document) => parsed_schema
-                            .iter()
-                            .any(|(key, _)| document[key] != previous_document[key]),
-                        None => true,
-                    })
-                    .map(|(document_id, _, _)| *document_id)
-                    .collect();
-                if !ids_to_run_on.is_empty() {
-                    pipeline
-                        .sync_documents(ids_to_run_on, project_info, &mut transaction)
-                        .await
-                        .expect("Failed to execute pipeline");
-                }
+            if !ids_to_run_on.is_empty() {
+                pipeline
+                    .sync_documents(ids_to_run_on, project_info, &mut transaction)
+                    .await
+                    .expect("Failed to execute pipeline");
             }
-
-            transaction.commit().await?;
-            progress_bar.inc(batch_size);
         }
-        progress_bar.println("Done Upserting Documents\n");
-        progress_bar.finish();
+
+        transaction.commit().await?;
         Ok(())
     }
 
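For reference, the new `batch_size` and `parallel_batches` arguments are read out of the `args` JSON before any batches are spawned, with `merge` still handled per batch. A minimal usage sketch, assuming the public `upsert_documents(documents, args)` method shown above, a `&mut Collection` receiver, and the crate's usual `Json: From<serde_json::Value>` conversion (the helper name is illustrative, not part of this commit):

```rust
use serde_json::json;

// Illustrative helper: upsert with up to 4 batches of 100 documents in flight.
async fn upsert_all(collection: &mut Collection, documents: Vec<Json>) -> anyhow::Result<()> {
    collection
        .upsert_documents(
            documents,
            Some(
                json!({
                    "batch_size": 100,      // documents per INSERT/transaction (default 100)
                    "parallel_batches": 4,  // concurrent batch tasks (default 1)
                    "merge": true           // merge into existing documents instead of replacing
                })
                .into(),
            ),
        )
        .await
}
```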
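The spawning loop itself is the usual bounded-concurrency pattern over `tokio::task::JoinSet`: spawn until the limit is reached, reap one completed task before spawning the next, then drain the stragglers. A self-contained sketch of the same pattern with placeholder batch work (names here are illustrative):

```rust
use tokio::task::JoinSet;

/// Run one async job per batch, keeping at most `limit` jobs in flight.
async fn run_bounded(batches: Vec<Vec<u64>>, limit: usize) -> anyhow::Result<()> {
    let mut set = JoinSet::new();
    for batch in batches {
        // Wait for one completed task whenever the set is at capacity,
        // so every batch still gets spawned exactly once.
        if set.len() >= limit {
            if let Some(res) = set.join_next().await {
                res??; // propagate both JoinError (panic/cancel) and the task's own error
            }
        }
        set.spawn(async move {
            // Stand-in for the real per-batch work (one transaction per batch above).
            anyhow::Ok(batch.len())
        });
    }
    // Drain whatever is still running.
    while let Some(res) = set.join_next().await {
        res??;
    }
    Ok(())
}
```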