-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
webnn: implement
gatherElements
for DirectML backend
This CL also adds some WPT conformance tests to verify the implementation. webmachinelearning/webnn#375 (comment) Bug: 40206287 Change-Id: I88e6bbdf1fd6156421d8b190ed6be6d3b216962b Cq-Include-Trybots: luci.chromium.try:win11-blink-rel, mac14.arm64-blink-rel, mac14-blink-rel, mac15.arm64-blink-rel, mac15-blink-rel, linux-blink-rel Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5811264 Auto-Submit: Shiyi Zou <[email protected]> Commit-Queue: Weizhong Xia <[email protected]> Reviewed-by: ningxin hu <[email protected]> Reviewed-by: Weizhong Xia <[email protected]> Cr-Commit-Position: refs/heads/main@{#1348676}
- Loading branch information
1 parent
f91870c
commit eb162cb
Showing
1 changed file
with
142 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
// META: title=test WebNN API gatherElements operation | ||
// META: global=window,dedicatedworker | ||
// META: variant=?cpu | ||
// META: variant=?gpu | ||
// META: variant=?npu | ||
// META: script=../resources/utils.js | ||
// META: timeout=long | ||
|
||
'use strict'; | ||
|
||
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-gatherElements | ||
// Gather values of the input tensor along an axis according to the indices. | ||
// | ||
// dictionary MLGatherOptions { | ||
// [EnforceRange] unsigned long axis = 0; | ||
// }; | ||
// | ||
// MLOperand gatherElements( | ||
// MLOperand input, MLOperand indices, | ||
// optional MLGatherOptions options = {}); | ||
|
||
|
||
const getGatherElementsPrecisionTolerance = () => { | ||
return {metricType: 'ULP', value: 0}; | ||
}; | ||
|
||
const gatherElementsTests = [ | ||
{ | ||
'name': 'gatherElements float32 2D input and uint32 indices options.axis=1', | ||
'graph': { | ||
'inputs': { | ||
'gatherElementsInput': { | ||
'data': [ | ||
-66.05901336669922, -68.9197006225586, -77.02045440673828, | ||
-26.158037185668945, 89.0337142944336, -45.89653396606445, | ||
43.84803771972656, 48.81806945800781, 51.79948425292969 | ||
], | ||
'descriptor': {'dimensions': [3, 3], 'dataType': 'float32'} | ||
}, | ||
'gatherElementsIndices': { | ||
'data': [1, 0, 2, 2, 1, 0], | ||
'descriptor': {'dimensions': [3, 2], 'dataType': 'uint32'}, | ||
'constant': true | ||
} | ||
}, | ||
'operators': [{ | ||
'name': 'gatherElements', | ||
'arguments': [ | ||
{'input': 'gatherElementsInput'}, | ||
{'indices': 'gatherElementsIndices'}, {'options': {'axis': 1}} | ||
], | ||
'outputs': 'gatherElementsOutput' | ||
}], | ||
'expectedOutputs': { | ||
'gatherElementsOutput': { | ||
'data': [ | ||
-68.9197006225586, -66.05901336669922, -45.89653396606445, | ||
-45.89653396606445, 48.81806945800781, 43.84803771972656 | ||
], | ||
'descriptor': {'dimensions': [3, 2], 'dataType': 'float32'} | ||
} | ||
} | ||
} | ||
}, | ||
{ | ||
'name': 'gatherElements float32 3D input and int32 negative indices', | ||
'graph': { | ||
'inputs': { | ||
'gatherElementsInput': { | ||
'data': [ | ||
-66.05901336669922, -68.9197006225586, -77.02045440673828, | ||
-26.158037185668945, 89.0337142944336, -45.89653396606445, | ||
43.84803771972656, 48.81806945800781 | ||
], | ||
'descriptor': {'dimensions': [2, 2, 2], 'dataType': 'float32'} | ||
}, | ||
'gatherElementsIndices': { | ||
'data': [-1, 0, 0, -1], | ||
'descriptor': {'dimensions': [1, 2, 2], 'dataType': 'int32'}, | ||
'constant': true | ||
} | ||
}, | ||
'operators': [{ | ||
'name': 'gatherElements', | ||
'arguments': [ | ||
{'input': 'gatherElementsInput'}, {'indices': 'gatherElementsIndices'} | ||
], | ||
'outputs': 'gatherElementsOutput' | ||
}], | ||
'expectedOutputs': { | ||
'gatherElementsOutput': { | ||
'data': [ | ||
89.0337142944336, -68.9197006225586, -77.02045440673828, | ||
48.81806945800781 | ||
], | ||
'descriptor': {'dimensions': [1, 2, 2], 'dataType': 'float32'} | ||
} | ||
} | ||
} | ||
}, | ||
{ | ||
'name': 'gatherElements float32 1D input and uint32 out-of-bounds indices', | ||
'graph': { | ||
'inputs': { | ||
'gatherElementsInput': { | ||
'data': [ | ||
-26.158037185668945, 89.0337142944336, -45.89653396606445, | ||
43.84803771972656, 48.81806945800781, 51.79948425292969 | ||
], | ||
'descriptor': {'dimensions': [6], 'dataType': 'float32'} | ||
}, | ||
'gatherElementsIndices': { | ||
'data': [7], | ||
'descriptor': {'dimensions': [1], 'dataType': 'uint32'}, | ||
'constant': true | ||
} | ||
}, | ||
'operators': [{ | ||
'name': 'gatherElements', | ||
'arguments': [ | ||
{'input': 'gatherElementsInput'}, {'indices': 'gatherElementsIndices'} | ||
], | ||
'outputs': 'gatherElementsOutput' | ||
}], | ||
'expectedOutputs': { | ||
'gatherElementsOutput': { | ||
'data': [51.79948425292969], | ||
'descriptor': {'dimensions': [1], 'dataType': 'float32'} | ||
} | ||
} | ||
} | ||
} | ||
]; | ||
|
||
if (navigator.ml) { | ||
gatherElementsTests.forEach((test) => { | ||
webnn_conformance_test( | ||
buildGraphAndCompute, getGatherElementsPrecisionTolerance, test); | ||
}); | ||
} else { | ||
test(() => assert_implements(navigator.ml, 'missing navigator.ml')); | ||
} |