Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pull] master from tinygrad:master #62

Merged
merged 7 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions examples/webgpu/efficientnet/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
* { text-align: center; font-family: monospace; }
</style>
<title>tinygrad has WebGPU</title>
<script src="../../net.js"></script>
<link rel="icon" type="image/x-icon" href="https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/logo.png">
<script type="module">
import model from "../../net.js";
window.model = model;
</script>
</head>
<body>
<h1>WebGPU <a href="https://github.com/geohot/tinygrad">tinygrad</a> EfficientNet!</h1>
Expand Down Expand Up @@ -61,8 +64,6 @@ <h1>WebGPU <a href="https://github.com/geohot/tinygrad">tinygrad</a> EfficientNe

const getLabels = async () => (await fetch("https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json")).json();

const getSavetensorBuffer = async () => new Uint8Array(await (await fetch("../../net.safetensors")).arrayBuffer());

const reorderChannelsAndRemoveAlpha = (data) => {
const out = [];
let i = 0;
Expand Down Expand Up @@ -97,9 +98,8 @@ <h1>WebGPU <a href="https://github.com/geohot/tinygrad">tinygrad</a> EfficientNe
try {
resultText.innerHTML = "loading..."
labels = await getLabels();
const safetensor = await getSavetensorBuffer();
const device = await getDevice();
net = await timer(() => setupNet(device, safetensor), "(compilation)");
net = await timer(() => model.load(device, '../../net.safetensors'), "(compilation)");
resultText.innerHTML = "ready"
} catch (e) {
error(e)
Expand Down
2 changes: 1 addition & 1 deletion examples/webgpu/yolov8/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
yolo_infer = YOLOv8(w=0.25, r=2.0, d=0.33, num_classes=80)
state_dict = safe_load(get_weights_location(yolo_variant))
load_state_dict(yolo_infer, state_dict)
prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,416,416))
prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,416,416), model_name="yolov8")
dirname = Path(__file__).parent
safe_save(state, (dirname / "net.safetensors").as_posix())
with open(dirname / f"net.js", "w") as text_file:
Expand Down
7 changes: 5 additions & 2 deletions examples/webgpu/yolov8/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>YOLOv8 tinygrad WebGPU</title>
<script src="./net.js"></script>
<script type="module">
import yolov8 from "./net.js"
window.yolov8 = yolov8;
</script>
<style>
body {
text-align: center;
Expand Down Expand Up @@ -213,7 +216,7 @@ <h2 id="wgpu-error" style="display: none; color: red;">Error: WebGPU is not supp
wgpuError.style.display = "block";
loadingContainer.style.display = "none";
}
net = await loadNet(device);
net = await yolov8.load(device, "./net.safetensors");
loadingContainer.style.display = "none";
}
let start = performance.now();
Expand Down
41 changes: 21 additions & 20 deletions extra/export_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple, Dict, List
from typing import Tuple, Dict, List, Optional
from tinygrad.dtype import DType
from tinygrad.renderer import ProgramSpec
from tinygrad.tensor import Device, Tensor
Expand All @@ -9,17 +9,6 @@
import json

EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU"]
web_utils = {
"getTensorBuffer":
"""const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {
return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
}""",
"getTensorMetadata": """const getTensorMetadata = (safetensorBuffer) => {
const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}]));
};"""
}

def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
Expand Down Expand Up @@ -82,14 +71,15 @@ def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,in
def dtype_to_js_type(dtype: DType) -> str:
  """Return the name of the JavaScript TypedArray that matches *dtype*.

  Unsigned ints map to Uint*Array, signed ints and bool map to Int*Array,
  everything else maps to Float*Array; the element width is taken from
  the dtype's byte size (itemsize * 8 bits).
  """
  if dtype in dtypes.uints:
    prefix = "Uint"
  elif dtype in dtypes.sints or dtype == dtypes.bool:
    prefix = "Int"
  else:
    prefix = "Float"
  return f"{prefix}{8*dtype.itemsize}Array"

def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names) -> Tuple[str,int,int]:
def export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name) -> Tuple[str,int,int]:
exported_name = "model" if model_name == None else model_name
kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
kernel_names = ', '.join([name for (name, _args, _global_size, _local_size) in statements])
kernel_names = ', '.join([name for (name, _, _, _) in statements])
create_bind_group_layouts = ",".join([
"device.createBindGroupLayout({{entries: [{{binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: 'uniform' }}}}, {}]}})".format(
",".join([f"{{binding: {argIdx+1}, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: 'storage' }} }}" for argIdx, _ in enumerate(args)])
)
for i, (_name, args, global_size, _local_size) in enumerate(statements)
for _, (_, args, _, _) in enumerate(statements)
])
layouts = f"const layouts=[{create_bind_group_layouts}]"
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], layouts[{i}], infinityBuf, [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
Expand All @@ -103,9 +93,16 @@ def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names,
output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new {output_buffer_types[i]}(gpuReadBuffer{i}.size/{bufs[output_names[i]][1].itemsize});\n resultBuffer{i}.set(new {output_buffer_types[i]}(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
return f"""
{web_utils["getTensorBuffer"]}
const {exported_name} = (() => {{
const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
}};

{web_utils["getTensorMetadata"]}
const getTensorMetadata = (safetensorBuffer) => {{
const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
}};

const createEmptyBuf = (device, size) => {{
return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
Expand Down Expand Up @@ -187,9 +184,13 @@ def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names,
return {output_return};
}}
}}
""" + f"\n\nconst loadNet = async (device) => {{ return await fetch('net.safetensors').then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}"
const load = async (device, weight_path) => {{ return await fetch(weight_path).then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}
return {{ load }};
}})();
export default {exported_name};
"""

def export_model(model, target:str, *inputs):
def export_model(model, target:str, *inputs, model_name: Optional[str] = None):
assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CLANG, CUDA, GPU, METAL are supported"
with Context(JIT=2): run,special_names = jit_model(model, *inputs)
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
Expand All @@ -201,7 +202,7 @@ def export_model(model, target:str, *inputs):
if target == "clang":
prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names)
elif target == "webgpu":
prg = export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names)
prg = export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name)
else:
prg = json.dumps({
"backend": Device.DEFAULT,
Expand Down
11 changes: 6 additions & 5 deletions test/external/process_replay/process_replay.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
# compare kernels created by HEAD against master
import os, multiprocessing, logging, pickle, sqlite3, difflib, functools
import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings
from typing import Callable, List, Set, Tuple, Union, cast
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
from tinygrad.engine.schedule import ScheduleContext, full_ast_rewrite
Expand All @@ -25,6 +25,7 @@
if not getenv("ASSERT_PROCESS_REPLAY", 1): ASSERT_DIFF = 0
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
if REF == "master": SKIP_PROCESS_REPLAY = True
class ProcessReplayWarning(Warning): pass

# *** recreators

Expand Down Expand Up @@ -56,9 +57,8 @@ def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]:
with Context(**{k:v for k,v in args[-2].items() if k in ContextVar._cache and k != "DEBUG"}): good = fxn(*args[:-2])
if good is None: continue
except Exception as e:
logging.warning(f"FAILED TO RECREATE KERNEL {e}")
warnings.warn(f"FAILED TO RECREATE KERNEL {e}", ProcessReplayWarning)
for x in args[:-1]: logging.info(x)
if ASSERT_DIFF: return True
continue
# diff kernels
try: assert args[-1] == good
Expand All @@ -85,7 +85,7 @@ def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None:
cur = conn.cursor()
try: row_count = cur.execute(f"select count(*) from '{name}_{TABLE_NAME}'").fetchone()[0]
except sqlite3.OperationalError:
logging.warning(f"{name}_{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")
warnings.warn(f"{name}_{TABLE_NAME} isn't accessible in master, did DB_VERSION change?", ProcessReplayWarning)
return None
conn.commit()
cur.close()
Expand All @@ -100,7 +100,7 @@ def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None:
logging.info(f"{sum(changed)} kernels changed")
if sum(insertion) != 0: logging.info(colored(f"{sum(insertion)} insertions(+)", "green"))
if sum(deletions) != 0: logging.info(colored(f"{sum(deletions)} deletions(-)", "red"))
if any(changed) and ASSERT_DIFF: raise AssertionError("process replay detected changes")
if any(changed): warnings.warn("process replay detected changes", ProcessReplayWarning)

# *** main loop

Expand All @@ -109,6 +109,7 @@ def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None:
logging.info("skipping process replay.")
exit(0)

if ASSERT_DIFF: warnings.filterwarnings("error", category=ProcessReplayWarning)
for name,fxn in [("schedule", recreate_sched), ("kernel", recreate_kernel)]:
logging.info(f"***** {name} diff")
try: _pmap(name, fxn)
Expand Down
Loading
Loading