Skip to content

SDK - Added re-ranking into document search #1527

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pgml-sdks/pgml/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 15 additions & 5 deletions pgml-sdks/pgml/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,7 @@ mod tests {
#[tokio::test]
async fn can_search_with_local_embeddings() -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
let collection_name = "test_r_c_cswle_123";
let collection_name = "test_r_c_cswle_126";
let mut collection = Collection::new(collection_name, None)?;
let documents = generate_dummy_documents(10);
collection.upsert_documents(documents.clone(), None).await?;
Expand Down Expand Up @@ -1038,7 +1038,12 @@ mod tests {
"full_text_search": {
"title": {
"query": "test 9",
"boost": 4.0
"boost": 4.0,
"rerank": {
"query": "Test document 2",
"model": "mixedbread-ai/mxbai-rerank-base-v1",
"num_documents_to_rerank": 100
}
},
"body": {
"query": "Test",
Expand All @@ -1051,7 +1056,12 @@ mod tests {
"parameters": {
"prompt": "query: ",
},
"boost": 2.0
"boost": 2.0,
"rerank": {
"query": "Test document 2",
"model": "mixedbread-ai/mxbai-rerank-base-v1",
"num_documents_to_rerank": 100
}
},
"body": {
"query": "This is the body test",
Expand Down Expand Up @@ -1086,7 +1096,7 @@ mod tests {
.iter()
.map(|r| r["document"]["id"].as_u64().unwrap())
.collect();
assert_eq!(ids, vec![9, 3, 4, 7, 5]);
assert_eq!(ids, vec![2, 9, 3, 8, 4]);

let pool = get_or_initialize_pool(&None).await?;

Expand All @@ -1111,7 +1121,7 @@ mod tests {
// Document ids are 1-based in the db, not 0-based like they are here
assert_eq!(
search_results.iter().map(|sr| sr.2).collect::<Vec<i64>>(),
vec![10, 4, 5, 8, 6]
vec![3, 10, 4, 9, 5]
);

let event = json!({"clicked": true});
Expand Down
177 changes: 161 additions & 16 deletions pgml-sdks/pgml/src/search_query_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@ struct ValidSemanticSearchAction {
query: String,
parameters: Option<Json>,
boost: Option<f32>,
rerank: Option<ValidRerank>,
}

/// Options accepted for a single key under `full_text_search` in a search query.
///
/// Unknown fields are rejected during deserialization (`deny_unknown_fields`).
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
struct ValidFullTextSearchAction {
/// The full-text query string for this key.
query: String,
/// Optional score multiplier; `build_search_query` falls back to 1.0 when it is absent.
boost: Option<f32>,
/// Optional re-ranking step. When present, `build_search_query` emits the raw scores under a
/// `pre_rerank_`-prefixed CTE and derives this key's final scores from a `pgml.rank` call.
rerank: Option<ValidRerank>,
}

#[derive(Debug, Deserialize)]
Expand All @@ -42,6 +44,20 @@ struct ValidQueryActions {
filter: Option<Json>,
}

/// Fallback for `num_documents_to_rerank` when a rerank request omits it.
///
/// Referenced by name from `#[serde(default = "default_num_documents_to_rerank")]`.
const fn default_num_documents_to_rerank() -> u64 {
    const DEFAULT_NUM_DOCUMENTS_TO_RERANK: u64 = 10;
    DEFAULT_NUM_DOCUMENTS_TO_RERANK
}

/// Validated `rerank` options that can be attached to a semantic or full-text search action.
///
/// Unknown fields are rejected during deserialization (`deny_unknown_fields`).
#[derive(Debug, Deserialize, Clone)]
#[serde(deny_unknown_fields)]
struct ValidRerank {
/// Query text passed to `pgml.rank` (its second argument) to score the candidate chunks against.
query: String,
/// Name of the re-ranking model passed to `pgml.rank` (its first argument).
model: String,
/// Number of candidate rows kept (via `LIMIT`) as input to the re-ranker.
/// Defaults to 10 when omitted.
#[serde(default = "default_num_documents_to_rerank")]
num_documents_to_rerank: u64,
/// Optional extra JSON merged with `||` onto the built-in `return_documents` / `top_k`
/// arguments of `pgml.rank`; the default `Json` value is used when absent.
parameters: Option<Json>,
}

/// Returns the default result limit: 10.
///
/// NOTE(review): presumably the serde default for the query's `limit` field; the attribute
/// that references it is not visible in this excerpt, so confirm against the struct.
const fn default_limit() -> u64 {
    const DEFAULT_LIMIT: u64 = 10;
    DEFAULT_LIMIT
}
Expand Down Expand Up @@ -106,7 +122,11 @@ pub async fn build_search_query(
// Build the CTE we actually use later
let embeddings_table = format!("{}_{}.{}_embeddings", collection.name, pipeline.name, key);
let chunks_table = format!("{}_{}.{}_chunks", collection.name, pipeline.name, key);
let cte_name = format!("{key}_embedding_score");
let cte_name = if vsa.rerank.is_some() {
format!("pre_rerank_{key}_embedding_score")
} else {
format!("{key}_embedding_score")
};
let boost = vsa.boost.unwrap_or(1.);
let mut score_cte_non_recursive = Query::select();
let mut score_cte_recurisive = Query::select();
Expand All @@ -131,6 +151,7 @@ pub async fn build_search_query(
score_cte_non_recursive
.from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings"))
.column((SIden::Str("documents"), SIden::Str("id")))
.column((SIden::Str("chunks"), SIden::Str("chunk")))
.join_as(
JoinType::InnerJoin,
chunks_table.to_table_tuple(),
Expand All @@ -157,6 +178,7 @@ pub async fn build_search_query(
score_cte_recurisive
.from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings"))
.column((SIden::Str("documents"), SIden::Str("id")))
.column((SIden::Str("chunks"), SIden::Str("chunk")))
.expr(Expr::cust(format!(r#""{cte_name}".previous_document_ids || documents.id"#)))
.expr(Expr::cust(format!(
r#"(1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost} AS score"#
Expand Down Expand Up @@ -213,6 +235,7 @@ pub async fn build_search_query(
score_cte_non_recursive
.from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings"))
.column((SIden::Str("documents"), SIden::Str("id")))
.column((SIden::Str("chunks"), SIden::Str("chunk")))
.expr(Expr::cust("ARRAY[documents.id] as previous_document_ids"))
.expr(Expr::cust_with_values(
format!("(1 - (embeddings.embedding <=> $1::vector)) * {boost} AS score"),
Expand Down Expand Up @@ -249,6 +272,7 @@ pub async fn build_search_query(
Expr::cust("1 = 1"),
)
.column((SIden::Str("documents"), SIden::Str("id")))
.column((SIden::Str("chunks"), SIden::Str("chunk")))
.expr(Expr::cust(format!(
r#""{cte_name}".previous_document_ids || documents.id"#
)))
Expand Down Expand Up @@ -295,18 +319,75 @@ pub async fn build_search_query(
.from_subquery(score_cte_non_recursive, Alias::new("non_recursive"))
.union(sea_query::UnionType::All, score_cte_recurisive)
.to_owned();

let mut score_cte = CommonTableExpression::from_select(score_cte);
score_cte.table_name(Alias::new(&cte_name));
with_clause.cte(score_cte);

if let Some(rerank) = vsa.rerank {
// Add our row_number_pre_rerank CTE
let mut row_number_pre_rerank = Query::select();
row_number_pre_rerank
.column(SIden::Str("id"))
.column(SIden::Str("chunk"))
.from(SIden::String(cte_name.clone()))
.expr_as(Expr::cust("ROW_NUMBER() OVER ()"), Alias::new("row_number"))
.limit(rerank.num_documents_to_rerank);
let mut row_number_pre_rerank_cte =
CommonTableExpression::from_select(row_number_pre_rerank);
row_number_pre_rerank_cte.table_name(Alias::new(format!("row_number_{cte_name}")));
with_clause.cte(row_number_pre_rerank_cte);

// Our actual CTE
let mut query = Query::select();
query.column(SIden::Str("id"));
query.expr_as(
Expr::cust(format!("(rank).score * {boost}")),
Alias::new("score"),
);

// Build the pgml.rank() call over the aggregated chunks, then join its output back to the numbered rows by position
let mut sub_query_rank_call = Query::select();
let model_expr = Expr::cust_with_values("$1", [rerank.model.clone()]);
let query_expr = Expr::cust_with_values("$1", [rerank.query.clone()]);
let parameters_expr =
Expr::cust_with_values("$1", [rerank.parameters.clone().unwrap_or_default().0]);
sub_query_rank_call.expr_as(Expr::cust_with_exprs(
format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit),
[model_expr, query_expr, parameters_expr],
), Alias::new("rank"))
.from(SIden::String(format!("row_number_{cte_name}")));

let mut sub_query = Query::select();
sub_query
.columns([SIden::Str("id"), SIden::Str("rank")])
.from_as(
SIden::String(format!("row_number_{cte_name}")),
Alias::new("rnsv1"),
)
.join_subquery(
JoinType::InnerJoin,
sub_query_rank_call,
Alias::new("rnsv2"),
Expr::cust("((rank).corpus_id + 1) = rnsv1.row_number"),
);

query.from_subquery(sub_query, Alias::new("sub_query"));
let mut query_cte = CommonTableExpression::from_select(query);
query_cte.table_name(Alias::new(format!("{key}_embedding_score")));
with_clause.cte(query_cte);
}

// Add to the sum expression
sum_expression = if let Some(expr) = sum_expression {
Some(expr.add(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#))))
Some(expr.add(Expr::cust(format!(
r#"COALESCE("{key}_embedding_score".score, 0.0)"#
))))
} else {
Some(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#)))
Some(Expr::cust(format!(
r#"COALESCE("{key}_embedding_score".score, 0.0)"#
)))
};
score_table_names.push(cte_name);
score_table_names.push(format!("{key}_embedding_score"));
}

for (key, vma) in valid_query.query.full_text_search.unwrap_or_default() {
Expand All @@ -315,10 +396,15 @@ pub async fn build_search_query(
let boost = vma.boost.unwrap_or(1.0);

// Build the score CTE
let cte_name = format!("{key}_tsvectors_score");
let cte_name = if vma.rerank.is_some() {
format!("pre_rerank_{key}_tsvectors_score")
} else {
format!("{key}_tsvectors_score")
};

let mut score_cte_non_recursive = Query::select()
.column((SIden::Str("documents"), SIden::Str("id")))
.column((SIden::Str("chunks"), SIden::Str("chunk")))
.expr_as(
Expr::cust_with_values(
format!(
Expand Down Expand Up @@ -361,6 +447,7 @@ pub async fn build_search_query(

let mut score_cte_recursive = Query::select()
.column((SIden::Str("documents"), SIden::Str("id")))
.column((SIden::Str("chunks"), SIden::Str("chunk")))
.expr_as(
Expr::cust_with_values(
format!(
Expand Down Expand Up @@ -425,13 +512,71 @@ pub async fn build_search_query(
score_cte.table_name(Alias::new(&cte_name));
with_clause.cte(score_cte);

if let Some(rerank) = vma.rerank {
// Add our row_number_pre_rerank CTE
let mut row_number_pre_rerank = Query::select();
row_number_pre_rerank
.column(SIden::Str("id"))
.column(SIden::Str("chunk"))
.from(SIden::String(cte_name.clone()))
.expr_as(Expr::cust("ROW_NUMBER() OVER ()"), Alias::new("row_number"))
.limit(rerank.num_documents_to_rerank);
let mut row_number_pre_rerank_cte =
CommonTableExpression::from_select(row_number_pre_rerank);
row_number_pre_rerank_cte.table_name(Alias::new(format!("row_number_{cte_name}")));
with_clause.cte(row_number_pre_rerank_cte);

// Our actual CTE
let mut query = Query::select();
query.column(SIden::Str("id"));
query.expr_as(
Expr::cust(format!("(rank).score * {boost}")),
Alias::new("score"),
);

// Build the pgml.rank() call over the aggregated chunks, then join its output back to the numbered rows by position
let mut sub_query_rank_call = Query::select();
let model_expr = Expr::cust_with_values("$1", [rerank.model.clone()]);
let query_expr = Expr::cust_with_values("$1", [rerank.query.clone()]);
let parameters_expr =
Expr::cust_with_values("$1", [rerank.parameters.clone().unwrap_or_default().0]);
sub_query_rank_call.expr_as(Expr::cust_with_exprs(
format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit),
[model_expr, query_expr, parameters_expr],
), Alias::new("rank"))
.from(SIden::String(format!("row_number_{cte_name}")));

let mut sub_query = Query::select();
sub_query
.columns([SIden::Str("id"), SIden::Str("rank")])
.from_as(
SIden::String(format!("row_number_{cte_name}")),
Alias::new("rnsv1"),
)
.join_subquery(
JoinType::InnerJoin,
sub_query_rank_call,
Alias::new("rnsv2"),
Expr::cust("((rank).corpus_id + 1) = rnsv1.row_number"),
);

query.from_subquery(sub_query, Alias::new("sub_query"));
let mut query_cte = CommonTableExpression::from_select(query);
query_cte.table_name(Alias::new(format!("{key}_tsvectors_score")));
with_clause.cte(query_cte);
}

// Add to the sum expression
sum_expression = if let Some(expr) = sum_expression {
Some(expr.add(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#))))
Some(expr.add(Expr::cust(format!(
r#"COALESCE("{key}_tsvectors_score".score, 0.0)"#
))))
} else {
Some(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#)))
Some(Expr::cust(format!(
r#"COALESCE("{key}_tsvectors_score".score, 0.0)"#
)))
};
score_table_names.push(cte_name);
score_table_names.push(format!("{key}_tsvectors_score"));
}

let query = if let Some(select_from) = score_table_names.first() {
Expand All @@ -440,9 +585,9 @@ pub async fn build_search_query(
.into_iter()
.map(|t| Expr::col((SIden::String(t), SIden::Str("id"))).into())
.collect();
let mut main_query = Query::select();
let mut joined_query = Query::select();
for i in 1..score_table_names_e.len() {
main_query.full_outer_join(
joined_query.full_outer_join(
SIden::String(score_table_names[i].to_string()),
Expr::col((
SIden::String(score_table_names[i].to_string()),
Expand All @@ -455,7 +600,8 @@ pub async fn build_search_query(

let sum_expression = sum_expression
.context("query requires some scoring through full_text_search or semantic_search")?;
main_query

joined_query
.expr_as(Expr::expr(id_select_expression.clone()), Alias::new("id"))
.expr_as(sum_expression, Alias::new("score"))
.column(SIden::Str("document"))
Expand All @@ -468,10 +614,9 @@ pub async fn build_search_query(
)
.order_by(SIden::Str("score"), Order::Desc)
.limit(limit);

let mut main_query = CommonTableExpression::from_select(main_query);
main_query.table_name(Alias::new("main"));
with_clause.cte(main_query);
let mut joined_query = CommonTableExpression::from_select(joined_query);
joined_query.table_name(Alias::new("main"));
with_clause.cte(joined_query);

// Insert into searches table
let searches_table = format!("{}_{}.searches", collection.name, pipeline.name);
Expand Down