Skip to content

Working simple python thread metrics #1239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 13, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 63 additions & 5 deletions pgml-extension/src/bindings/transformers/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import shutil
import time
import queue
import sys
import json

import datasets
from InstructorEmbedding import INSTRUCTOR
Expand Down Expand Up @@ -40,7 +42,7 @@
TrainingArguments,
Trainer,
)
from threading import Thread
import threading

__cache_transformer_by_model_id = {}
__cache_sentence_transformer_by_name = {}
Expand All @@ -62,6 +64,26 @@
}


class WorkerThreads:
    """Registry of active streaming worker threads.

    Maps a native thread id to an opaque value (a JSON string in practice;
    see streaming_worker) so per-thread metrics can be inspected elsewhere.
    """

    def __init__(self):
        # thread id -> metadata value set by the worker thread itself
        self.worker_threads = {}

    def delete_thread(self, id):
        """Remove the entry for *id*; a missing id is silently ignored.

        pop() instead of del so cleanup paths (e.g. the finally block in
        streaming_worker) can never raise KeyError if the entry was not
        created or was already removed.
        """
        self.worker_threads.pop(id, None)

    def update_thread(self, id, value):
        """Create or overwrite the entry for *id*."""
        self.worker_threads[id] = value

    def get_thread(self, id):
        """Return the value stored for *id*, or None if unknown."""
        return self.worker_threads.get(id)


# Module-level singleton shared by all streaming pipelines in this process.
worker_threads = WorkerThreads()


class PgMLException(Exception):
    """Error raised by the PgML transformers bindings."""

Expand Down Expand Up @@ -105,6 +127,12 @@ def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs):
self.token_cache = []
self.text_index_cache = []

def set_worker_thread_id(self, id):
    # Record the native id of the worker thread driving this streamer so
    # callers can look the thread up in the worker_threads registry
    # (set from stream() right after the generation thread starts).
    self.worker_thread_id = id

def get_worker_thread_id(self):
    # Native id of the generation worker thread.
    # NOTE(review): raises AttributeError if called before
    # set_worker_thread_id has been invoked — confirm callers always
    # set the id immediately after starting the thread.
    return self.worker_thread_id

def put(self, values):
if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
Expand Down Expand Up @@ -149,6 +177,22 @@ def __next__(self):
return value


def streaming_worker(worker_threads, model, **kwargs):
    """Run model.generate() on the current (worker) thread, registering the
    thread in *worker_threads* for the duration of the generation.

    Args:
        worker_threads: WorkerThreads registry to record this thread in.
        model: the model whose generate() is invoked; output is delivered
            through the streamer contained in **kwargs, not returned.
        **kwargs: forwarded verbatim to model.generate().
    """
    thread_id = threading.get_native_id()
    try:
        worker_threads.update_thread(
            thread_id, json.dumps({"model": model.name_or_path})
        )
    except Exception:
        # Metrics bookkeeping must never stop generation; record a
        # placeholder instead. (Was a bare `except:`, which would also
        # have swallowed KeyboardInterrupt/SystemExit.)
        worker_threads.update_thread(thread_id, "Error setting data")
    try:
        model.generate(**kwargs)
    except BaseException as error:
        # BaseException on purpose: any failure on this worker thread must
        # reach stderr, since nothing else observes the thread's outcome.
        print(f"Error in streaming_worker: {error}", file=sys.stderr)
    finally:
        # Always deregister so the registry does not leak finished threads.
        worker_threads.delete_thread(thread_id)


class GGMLPipeline(object):
def __init__(self, model_name, **task):
import ctransformers
Expand Down Expand Up @@ -185,7 +229,7 @@ def do_work():
self.q.put(x)
self.done = True

thread = Thread(target=do_work)
thread = threading.Thread(target=do_work)
thread.start()

def __iter__(self):
Expand Down Expand Up @@ -283,7 +327,13 @@ def stream(self, input, timeout=None, **kwargs):
input, add_generation_prompt=True, tokenize=False
)
input = self.tokenizer(input, return_tensors="pt").to(self.model.device)
generation_kwargs = dict(input, streamer=streamer, **kwargs)
generation_kwargs = dict(
input,
worker_threads=worker_threads,
model=self.model,
streamer=streamer,
**kwargs,
)
else:
streamer = TextIteratorStreamer(
self.tokenizer,
Expand All @@ -292,9 +342,17 @@ def stream(self, input, timeout=None, **kwargs):
input = self.tokenizer(input, return_tensors="pt", padding=True).to(
self.model.device
)
generation_kwargs = dict(input, streamer=streamer, **kwargs)
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
generation_kwargs = dict(
input,
worker_threads=worker_threads,
model=self.model,
streamer=streamer,
**kwargs,
)
# thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread = threading.Thread(target=streaming_worker, kwargs=generation_kwargs)
thread.start()
streamer.set_worker_thread_id(thread.native_id)
return streamer

def __call__(self, inputs, **kwargs):
Expand Down