@@ -41,7 +41,9 @@
     PegasusTokenizer,
     TrainingArguments,
     Trainer,
-    GPTQConfig
+    GPTQConfig,
+    PegasusForConditionalGeneration,
+    PegasusTokenizer,
 )
 import threading

@@ -254,6 +256,8 @@ def __init__(self, model_name, **kwargs):
         if "use_auth_token" in kwargs:
             kwargs["token"] = kwargs.pop("use_auth_token")

+        self.model_name = model_name
+
         if (
             "task" in kwargs
             and model_name is not None
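
Two small setup changes in this hunk: use_auth_token is the older transformers spelling of the Hugging Face auth kwarg, so it is renamed to the current token= before kwargs are forwarded, and model_name is now kept on the instance so code outside __init__ can branch on the checkpoint. A tiny illustration with hypothetical values, not from this commit:

# Illustrative only: the same normalization, outside the class.
kwargs = {"use_auth_token": "hf_xxx", "task": "summarization"}  # hypothetical values
if "use_auth_token" in kwargs:
    kwargs["token"] = kwargs.pop("use_auth_token")  # current kwarg name
assert kwargs == {"task": "summarization", "token": "hf_xxx"}
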
@@ -278,29 +282,55 @@ def __init__(self, model_name, **kwargs):
                     model_name, **kwargs
                 )
             elif self.task == "summarization" or self.task == "translation":
-                self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
+                if model_name == "google/pegasus-xsum":
+                    # HF auto model doesn't detect GPUs
+                    self.model = PegasusForConditionalGeneration.from_pretrained(
+                        model_name
+                    )
+                else:
+                    self.model = AutoModelForSeq2SeqLM.from_pretrained(
+                        model_name, **kwargs
+                    )
             elif self.task == "text-generation" or self.task == "conversational":
                 # See: https://huggingface.co/docs/transformers/main/quantization
                 if "quantization_config" in kwargs:
                     quantization_config = kwargs.pop("quantization_config")
                     quantization_config = GPTQConfig(**quantization_config)
-                    self.model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config, **kwargs)
+                    self.model = AutoModelForCausalLM.from_pretrained(
+                        model_name, quantization_config=quantization_config, **kwargs
+                    )
                 else:
-                    self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
+                    self.model = AutoModelForCausalLM.from_pretrained(
+                        model_name, **kwargs
+                    )
             else:
                 raise PgMLException(f"Unhandled task: {self.task}")

+            if model_name == "google/pegasus-xsum":
+                kwargs.pop("token", None)
+
             if "token" in kwargs:
                 self.tokenizer = AutoTokenizer.from_pretrained(
                     model_name, token=kwargs["token"]
                 )
             else:
-                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+                if model_name == "google/pegasus-xsum":
+                    self.tokenizer = PegasusTokenizer.from_pretrained(model_name)
+                else:
+                    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+            pipe_kwargs = {
+                "model": self.model,
+                "tokenizer": self.tokenizer,
+            }
+
+            # https://huggingface.co/docs/transformers/en/model_doc/pegasus
+            if model_name == "google/pegasus-xsum":
+                pipe_kwargs["device"] = kwargs.get("device", "cpu")

             self.pipe = transformers.pipeline(
                 self.task,
-                model=self.model,
-                tokenizer=self.tokenizer,
+                **pipe_kwargs,
             )
         else:
             self.pipe = transformers.pipeline(**kwargs)
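
Taken together, the branch above special-cases google/pegasus-xsum: the model and tokenizer come from the explicit Pegasus classes, the auth token is dropped, and the device is handed to the pipeline by hand because, per the hunk's comment, the auto model doesn't detect GPUs for this checkpoint. A minimal sketch of the resulting construction, outside the class; the prompt text and the "cpu" fallback are illustrative, and the checkpoint must be cached or downloadable:

import transformers
from transformers import PegasusForConditionalGeneration, PegasusTokenizer

model_name = "google/pegasus-xsum"
model = PegasusForConditionalGeneration.from_pretrained(model_name)
tokenizer = PegasusTokenizer.from_pretrained(model_name)

# Mirrors pipe_kwargs above: model/tokenizer always, device only for Pegasus.
pipe = transformers.pipeline(
    "summarization",
    model=model,
    tokenizer=tokenizer,
    device="cpu",  # or a CUDA device such as 0 / "cuda:0" when one is available
)
print(pipe("PostgresML runs machine learning inside Postgres.")[0]["summary_text"])
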
@@ -320,7 +350,7 @@ def stream(self, input, timeout=None, **kwargs):
                 self.tokenizer,
                 timeout=timeout,
                 skip_prompt=True,
-                skip_special_tokens=True
+                skip_special_tokens=True,
             )
             if "chat_template" in kwargs:
                 input = self.tokenizer.apply_chat_template(
@@ -343,9 +373,7 @@ def stream(self, input, timeout=None, **kwargs):
             )
         else:
             streamer = TextIteratorStreamer(
-                self.tokenizer,
-                timeout=timeout,
-                skip_special_tokens=True
+                self.tokenizer, timeout=timeout, skip_special_tokens=True
             )
             input = self.tokenizer(input, return_tensors="pt", padding=True).to(
                 self.model.device
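
For context on the TextIteratorStreamer calls compacted above, a hedged sketch of the streaming pattern they belong to: generate() blocks, so it runs on a worker thread (the file imports threading for this) while the streamer is consumed as an iterator; skip_prompt=True and skip_special_tokens=True keep the echoed prompt and special tokens out of the stream. The gpt2 checkpoint and generation settings here are assumptions, not from this commit:

import threading
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Postgres is", return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(
    tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
)

# generate() blocks, so it runs on a worker thread while we consume the iterator.
thread = threading.Thread(
    target=model.generate,
    kwargs={**inputs, "streamer": streamer, "max_new_tokens": 32},
)
thread.start()
for token_text in streamer:
    print(token_text, end="", flush=True)
thread.join()
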
@@ -496,7 +524,6 @@ def embed(transformer, inputs, kwargs):
     return embed_using(model, transformer, inputs, kwargs)


-
 def clear_gpu_cache(memory_usage: None):
     if not torch.cuda.is_available():
         raise PgMLException(f"No GPU available")
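
The last hunk only tightens spacing around clear_gpu_cache, which guards on GPU availability before doing any cleanup. For context, a hedged sketch of the usual PyTorch cache-clearing pattern such a helper builds on; the function name and return value here are illustrative, not the extension's API:

import torch

def clear_gpu_cache_sketch():
    # Illustrative guard; the real helper raises PgMLException instead.
    if not torch.cuda.is_available():
        raise RuntimeError("No GPU available")
    torch.cuda.empty_cache()       # release cached, unused blocks back to the driver
    torch.cuda.synchronize()       # wait for pending kernels to finish
    return torch.cuda.memory_allocated()  # bytes still held by live tensors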