-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Conversation
Hey @MoisesHer , Thanks for submitting the PR
CI supported jobs: [miscellaneous, edge, windows-cpu, windows-gpu, unix-cpu, website, unix-gpu, clang, sanity, centos-cpu, centos-gpu] Note: |
@@ -2319,3 +2319,21 @@ def test_fp16_spmm(): | |||
out = mxsps.dot(inp, weight) | |||
out_np = mx.nd.dot(inp, weight) | |||
assert_almost_equal(out.asnumpy(), out_np, rtol=1e-3, atol=1e-5) | |||
|
|||
@with_seed() | |||
@pytest.mark.serial |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need to mark as serial
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I noticed that the mark.serial
is triggered for all tests in this file. Thus, we may keep mark.serial
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We run tests based on the tag and it has nothing to do with file. serial is only needed when test invocation is long-running and consumes lots of memory. Since this is no longer the case through parametrizing the input, the serial tag is not needed.
please open a follow up PR to finish the change.
* Add GPU-optimization for split op * Complete operator * unit-test: use parametrize * fix lint * fix lint * fix lint
Description
Optimization of split operator on GPU
Checklist
Essentials
Changes
Performance for several scenarios can be find here:
https://docs.google.com/spreadsheets/d/1ksQcOetbs3MDAhT5pGaU3vKMoExQK-oVFqjXsol44eQ/edit?usp=sharing
When the last axis is smaller than 128, the new implementation performs in general worse than original version. Thus, in those cases we redirect those scenarios to run the original version.