Skip to content

Commit aca505c

Browse files
committed
add vllm inference docs; fix logic
1 parent 74ce6ae commit aca505c

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

pgml-extension/src/bindings/vllm/inference.rs

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,28 @@ pub fn vllm_inference(task: &Value, inputs: &[&str]) -> PyResult<Value> {
4040
Ok(json!(outputs))
4141
}
4242

43+
/// Determine if the "model" specified in the task is the same model as the one cached.
44+
///
45+
/// # Panic
46+
/// This function panics if:
47+
/// - `task` is not an object
48+
/// - "model" key is missing from `task` object
49+
/// - "model" value is not a str
4350
fn get_model_name<M>(model: &M, task: &Value) -> ModelName
4451
where
4552
M: std::ops::Deref<Target = Option<LLM>>,
4653
{
47-
match task
48-
.as_object()
49-
.and_then(|obj| obj.get("model").and_then(|m| m.as_str()))
50-
{
51-
Some(name) => match model.as_ref() {
52-
Some(llm) if llm.model() == name => ModelName::Same,
53-
_ => ModelName::Different(name.to_string()),
54-
},
55-
None => ModelName::Same,
54+
let name = task.as_object()
55+
.expect("`task` is an object")
56+
.get("model")
57+
.expect("model key is present")
58+
.as_str()
59+
.expect("model value is a str");
60+
61+
if matches!(model.as_ref(), Some(llm) if llm.model() == name) {
62+
ModelName::Same
63+
} else {
64+
ModelName::Different(name.to_string())
5665
}
5766
}
5867

0 commit comments

Comments
 (0)