Commit 4908be4

Add support for google/pegasus-xsum (#1325)
1 parent c7494db

1 file changed: +39 -12

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 39 additions & 12 deletions
@@ -41,7 +41,9 @@
     PegasusTokenizer,
     TrainingArguments,
     Trainer,
-    GPTQConfig
+    GPTQConfig,
+    PegasusForConditionalGeneration,
+    PegasusTokenizer,
 )
 import threading
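Aside (not part of the diff): AutoModelForSeq2SeqLM already resolves google/pegasus-xsum to PegasusForConditionalGeneration, so the concrete classes are imported here to let the later hunks bypass the Auto factories for this one checkpoint. Note also that the hunk re-adds PegasusTokenizer, which the context shows is already imported a few lines up; that is harmless in Python but redundant. A minimal sketch of the resolution, assuming transformers is installed:

    from transformers import AutoModelForSeq2SeqLM, PegasusForConditionalGeneration

    # The Auto factory and the concrete class load the same architecture
    # for this checkpoint; the printed type name is expected to be
    # PegasusForConditionalGeneration.
    auto_model = AutoModelForSeq2SeqLM.from_pretrained("google/pegasus-xsum")
    print(type(auto_model).__name__)

    direct_model = PegasusForConditionalGeneration.from_pretrained(
        "google/pegasus-xsum"
    )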

@@ -254,6 +256,8 @@ def __init__(self, model_name, **kwargs):
         if "use_auth_token" in kwargs:
             kwargs["token"] = kwargs.pop("use_auth_token")
 
+        self.model_name = model_name
+
         if (
             "task" in kwargs
             and model_name is not None
@@ -278,29 +282,55 @@ def __init__(self, model_name, **kwargs):
                     model_name, **kwargs
                 )
             elif self.task == "summarization" or self.task == "translation":
-                self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
+                if model_name == "google/pegasus-xsum":
+                    # HF auto model doesn't detect GPUs
+                    self.model = PegasusForConditionalGeneration.from_pretrained(
+                        model_name
+                    )
+                else:
+                    self.model = AutoModelForSeq2SeqLM.from_pretrained(
+                        model_name, **kwargs
+                    )
             elif self.task == "text-generation" or self.task == "conversational":
                 # See: https://huggingface.co/docs/transformers/main/quantization
                 if "quantization_config" in kwargs:
                     quantization_config = kwargs.pop("quantization_config")
                     quantization_config = GPTQConfig(**quantization_config)
-                    self.model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config, **kwargs)
+                    self.model = AutoModelForCausalLM.from_pretrained(
+                        model_name, quantization_config=quantization_config, **kwargs
+                    )
                 else:
-                    self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
+                    self.model = AutoModelForCausalLM.from_pretrained(
+                        model_name, **kwargs
+                    )
             else:
                 raise PgMLException(f"Unhandled task: {self.task}")
 
+            if model_name == "google/pegasus-xsum":
+                kwargs.pop("token", None)
+
             if "token" in kwargs:
                 self.tokenizer = AutoTokenizer.from_pretrained(
                     model_name, token=kwargs["token"]
                 )
             else:
-                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+                if model_name == "google/pegasus-xsum":
+                    self.tokenizer = PegasusTokenizer.from_pretrained(model_name)
+                else:
+                    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+            pipe_kwargs = {
+                "model": self.model,
+                "tokenizer": self.tokenizer,
+            }
+
+            # https://huggingface.co/docs/transformers/en/model_doc/pegasus
+            if model_name == "google/pegasus-xsum":
+                pipe_kwargs["device"] = kwargs.get("device", "cpu")
 
             self.pipe = transformers.pipeline(
                 self.task,
-                model=self.model,
-                tokenizer=self.tokenizer,
+                **pipe_kwargs,
             )
         else:
             self.pipe = transformers.pipeline(**kwargs)
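Not part of the diff: end to end, the new summarization branch behaves roughly like the standalone sketch below. The checkpoint and prompt are illustrative, and the device here is picked by probing CUDA, whereas the hunk above defaults to kwargs.get("device", "cpu") when the caller supplies nothing.

    import torch
    import transformers
    from transformers import PegasusForConditionalGeneration, PegasusTokenizer

    model_name = "google/pegasus-xsum"

    # Load the concrete Pegasus classes directly, as the hunk does,
    # instead of going through the Auto factories.
    model = PegasusForConditionalGeneration.from_pretrained(model_name)
    tokenizer = PegasusTokenizer.from_pretrained(model_name)

    # The commit passes an explicit device only for pegasus-xsum; this
    # sketch probes for a GPU instead of reading it from kwargs.
    device = 0 if torch.cuda.is_available() else "cpu"

    pipe = transformers.pipeline(
        "summarization",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )

    print(pipe("PostgresML runs machine learning models inside Postgres.")[0])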
@@ -320,7 +350,7 @@ def stream(self, input, timeout=None, **kwargs):
                 self.tokenizer,
                 timeout=timeout,
                 skip_prompt=True,
-                skip_special_tokens=True
+                skip_special_tokens=True,
             )
             if "chat_template" in kwargs:
                 input = self.tokenizer.apply_chat_template(
@@ -343,9 +373,7 @@ def stream(self, input, timeout=None, **kwargs):
             )
         else:
             streamer = TextIteratorStreamer(
-                self.tokenizer,
-                timeout=timeout,
-                skip_special_tokens=True
+                self.tokenizer, timeout=timeout, skip_special_tokens=True
             )
             input = self.tokenizer(input, return_tensors="pt", padding=True).to(
                 self.model.device
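The two stream() hunks above are formatting-only, but for context (not part of the diff), the TextIteratorStreamer pattern they touch works like this minimal sketch: generate() runs on a background thread while the streamer is consumed as an iterator. The checkpoint is illustrative, not one the commit uses.

    import threading

    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TextIteratorStreamer,
    )

    model_name = "gpt2"  # illustrative checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Same constructor shape as the reflowed call above.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=30.0, skip_special_tokens=True
    )
    inputs = tokenizer("PostgresML is", return_tensors="pt")

    # generate() blocks, so it runs on a thread; the streamer then
    # yields decoded text chunks as tokens are produced.
    thread = threading.Thread(
        target=model.generate,
        kwargs={**inputs, "streamer": streamer, "max_new_tokens": 20},
    )
    thread.start()
    for chunk in streamer:
        print(chunk, end="")
    thread.join()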
@@ -496,7 +524,6 @@ def embed(transformer, inputs, kwargs):
     return embed_using(model, transformer, inputs, kwargs)
 
 
-
 def clear_gpu_cache(memory_usage: None):
     if not torch.cuda.is_available():
         raise PgMLException(f"No GPU available")
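Context (not part of the diff): the hunk above only drops a blank line next to clear_gpu_cache, which guards on torch.cuda.is_available() and raises PgMLException otherwise. The usual PyTorch idiom behind such a helper, as a hedged sketch rather than the extension's implementation:

    import torch

    # Release cached allocator blocks back to the CUDA driver; the
    # guard avoids calling it on CPU-only hosts.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()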
