From c5c0d0277d6bfb68bea0d146713a93c09e7c3233 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 13 Dec 2024 10:43:47 +0200 Subject: [PATCH 1/7] flatten buffer args, delete dtype [pr] (#8202) --- tinygrad/ops.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index abb8cefa6cd8..fd2df9644a2e 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -281,7 +281,7 @@ def full_shape(self) -> Tuple[sint, ...]: @property def shape(self) -> Tuple[sint, ...]: return unwrap(self.st).shape @property - def size(self) -> int: return self.arg[1][1] if self.op is Ops.BUFFER else unwrap(self.st).size + def size(self) -> int: return self.arg[2] if self.op is Ops.BUFFER else unwrap(self.st).size # *** uop evaluation *** @@ -416,7 +416,7 @@ def metaop(op:Ops, shape:Tuple[sint, ...], dtype:DType, device:str, arg=None, sr from tinygrad.shape.shapetracker import ShapeTracker # NOTE: we embed device on CONST with a fake BUFFER uop if op is Ops.CONST: - fake = UOp(Ops.BUFFER, dtype.ptr(), (), (-1, (device, 1, dtype))) + fake = UOp(Ops.BUFFER, dtype.ptr(), (), (-1, device, 1)) return UOp(Ops.VIEW, dtype, (fake, arg if isinstance(arg, UOp) else UOp.const(dtype, unwrap(arg))), ShapeTracker.from_shape(())).reshape((1,)*len(shape)).expand(shape) # otherwise it's a contiguous st @@ -477,12 +477,12 @@ def stride(self, arg:Tuple[int, ...]): return self.view(unwrap(self.st).stride(a buffer_num = itertools.count(0) @staticmethod - def new_buffer(device:str, size:int, dtype:DType) -> UOp: return UOp(Ops.BUFFER, dtype.ptr(), (), (next(UOp.buffer_num), (device, size, dtype))) + def new_buffer(device:str, size:int, dtype:DType) -> UOp: return UOp(Ops.BUFFER, dtype.ptr(), (), (next(UOp.buffer_num), device, size)) @property def device(self) -> str: return unwrap(self._device) @functools.cached_property def _device(self) -> Optional[str]: - return self.arg[1][0] if self.op is Ops.BUFFER else dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None + return self.arg[1] if self.op is Ops.BUFFER else dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None @property def buf_uop(self) -> UOp: if self.op is Ops.BUFFER: return self @@ -497,7 +497,7 @@ def buffer(self) -> Buffer: return self.src[0].buffer assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}" from tinygrad.device import Buffer - buffers[self] = ret = Buffer(*self.arg[1]) + buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base) return ret @property def realized(self) -> Optional[Buffer]: From 5864627abe6dffdccf3ebafb230eb65c1bd40e2c Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 13 Dec 2024 11:43:43 +0200 Subject: [PATCH 2/7] process replay filter warnings [pr] (#8199) --- test/external/process_replay/process_replay.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index 95e56de660ee..c0d80fee1f9f 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # compare kernels created by HEAD against master -import os, multiprocessing, logging, pickle, sqlite3, difflib, functools +import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings from typing import Callable, List, Set, Tuple, Union, cast from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm from tinygrad.engine.schedule import ScheduleContext, full_ast_rewrite @@ -25,6 +25,7 @@ if not getenv("ASSERT_PROCESS_REPLAY", 1): ASSERT_DIFF = 0 SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "") if REF == "master": SKIP_PROCESS_REPLAY = True +class ProcessReplayWarning(Warning): pass # *** recreators @@ -56,9 +57,8 @@ def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]: with Context(**{k:v for k,v in args[-2].items() if k in ContextVar._cache and k != "DEBUG"}): good = fxn(*args[:-2]) if good is None: continue except Exception as e: - logging.warning(f"FAILED TO RECREATE KERNEL {e}") + warnings.warn(f"FAILED TO RECREATE KERNEL {e}", ProcessReplayWarning) for x in args[:-1]: logging.info(x) - if ASSERT_DIFF: return True continue # diff kernels try: assert args[-1] == good @@ -85,7 +85,7 @@ def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None: cur = conn.cursor() try: row_count = cur.execute(f"select count(*) from '{name}_{TABLE_NAME}'").fetchone()[0] except sqlite3.OperationalError: - logging.warning(f"{name}_{TABLE_NAME} isn't accessible in master, did DB_VERSION change?") + warnings.warn(f"{name}_{TABLE_NAME} isn't accessible in master, did DB_VERSION change?", ProcessReplayWarning) return None conn.commit() cur.close() @@ -100,7 +100,7 @@ def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None: logging.info(f"{sum(changed)} kernels changed") if sum(insertion) != 0: logging.info(colored(f"{sum(insertion)} insertions(+)", "green")) if sum(deletions) != 0: logging.info(colored(f"{sum(deletions)} deletions(-)", "red")) - if any(changed) and ASSERT_DIFF: raise AssertionError("process replay detected changes") + if any(changed): warnings.warn("process replay detected changes", ProcessReplayWarning) # *** main loop @@ -109,6 +109,7 @@ def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None: logging.info("skipping process replay.") exit(0) + if ASSERT_DIFF: warnings.filterwarnings("error", category=ProcessReplayWarning) for name,fxn in [("schedule", recreate_sched), ("kernel", recreate_kernel)]: logging.info(f"***** {name} diff") try: _pmap(name, fxn) From 651f72442c515c89079f9639084915c02977b104 Mon Sep 17 00:00:00 2001 From: Ahmed Harmouche Date: Fri, 13 Dec 2024 10:55:37 +0100 Subject: [PATCH 3/7] encapsulate the exported webgpu model (#8203) --- examples/webgpu/efficientnet/index.html | 10 +++--- examples/webgpu/yolov8/compile.py | 2 +- examples/webgpu/yolov8/index.html | 7 +++-- extra/export_model.py | 41 +++++++++++++------------ 4 files changed, 32 insertions(+), 28 deletions(-) diff --git a/examples/webgpu/efficientnet/index.html b/examples/webgpu/efficientnet/index.html index fccd7e5744fc..32540141c42c 100644 --- a/examples/webgpu/efficientnet/index.html +++ b/examples/webgpu/efficientnet/index.html @@ -17,8 +17,11 @@ * { text-align: center; font-family: monospace; } tinygrad has WebGPU - +

WebGPU tinygrad EfficientNet!

@@ -61,8 +64,6 @@

WebGPU tinygrad EfficientNe const getLabels = async () => (await fetch("https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json")).json(); - const getSavetensorBuffer = async () => new Uint8Array(await (await fetch("../../net.safetensors")).arrayBuffer()); - const reorderChannelsAndRemoveAlpha = (data) => { const out = []; let i = 0; @@ -97,9 +98,8 @@

WebGPU tinygrad EfficientNe try { resultText.innerHTML = "loading..." labels = await getLabels(); - const safetensor = await getSavetensorBuffer(); const device = await getDevice(); - net = await timer(() => setupNet(device, safetensor), "(compilation)"); + net = await timer(() => model.load(device, '../../net.safetensors'), "(compilation)"); resultText.innerHTML = "ready" } catch (e) { error(e) diff --git a/examples/webgpu/yolov8/compile.py b/examples/webgpu/yolov8/compile.py index a5f9d2bc5cbb..667b5ac75a56 100644 --- a/examples/webgpu/yolov8/compile.py +++ b/examples/webgpu/yolov8/compile.py @@ -12,7 +12,7 @@ yolo_infer = YOLOv8(w=0.25, r=2.0, d=0.33, num_classes=80) state_dict = safe_load(get_weights_location(yolo_variant)) load_state_dict(yolo_infer, state_dict) - prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,416,416)) + prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,416,416), model_name="yolov8") dirname = Path(__file__).parent safe_save(state, (dirname / "net.safetensors").as_posix()) with open(dirname / f"net.js", "w") as text_file: diff --git a/examples/webgpu/yolov8/index.html b/examples/webgpu/yolov8/index.html index 39d8231d0f18..89d111aeee07 100644 --- a/examples/webgpu/yolov8/index.html +++ b/examples/webgpu/yolov8/index.html @@ -4,7 +4,10 @@ YOLOv8 tinygrad WebGPU - +