Add SpectralOps CPU implementation for ARM/PowerPC processors (where MKL is not available) #41592
Comments
Looks like this is a PyTorch issue.
I am not sure if we support the Jetson Nano and ARM64 processors. Will try to find someone who knows more...
The Jetson Nano wheels are compiled by NVIDIA, so maybe they will know who is responsible for this.
This sounds like an enhancement to me: at the moment, there is only an MKL-accelerated CPU implementation of these spectral operations. MKL is not available on ARM, so in order to achieve feature parity, one would have to rely on other implementations, for example http://www.fftw.org/.
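A quick way to check whether a given PyTorch build was compiled with MKL at all (on ARM wheels this returns `False`, which is why the MKL-only FFT path errors out):

```python
import torch

# Reports whether this PyTorch build links against MKL;
# on ARM builds it returns False, and the spectral ops then fail.
print(torch.backends.mkl.is_available())
```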
Thanks @malfet. I was able to compute the FFT on ARM by running the transform on the CUDA device instead of the CPU:
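The snippet itself was not preserved in this copy of the thread; a minimal sketch of that workaround, assuming torchaudio and a CUDA-enabled Jetson build, might look like:

```python
import torch
import torchaudio

waveform = torch.randn(1, 16000).to("cuda")          # stand-in for a real clip
transform = torchaudio.transforms.Spectrogram().to("cuda")
spec = transform(waveform)   # runs through cuFFT, so no MKL is required
print(spec.shape)
```

On the CPU the same call fails, because the CPU path is MKL-only in these builds.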
Test code from #42426:
I'd suggest changing the title to: "Support torch.fft without MKL support"
Is there any way Torch could incorporate https://developer.arm.com/tools-and-software/server-and-hpc/downloads/arm-performance-libraries? Or is there a cross-platform equivalent, since that seems to be 64-bit only? OpenBLAS, FFTW...? I was overjoyed that someone had done some great work for the Raspberry Pi, but my first test used FFT (and was also 32-bit).
It might be interesting to build a PyTorch-compatible library that can use the linked software, but I think the ARM community would be expected to drive any effort to support fft-like functionality in PyTorch on their hardware.
Another alternative is to support fftw3?
The reality is we should all be supporting open source such as OpenBLAS and FFTW, not hardware-specific libraries such as Intel MKL. Arm Performance Libraries, Intel MKL, and anything AMD might want to submit should be feature requests on top of a normal open-source core, not the other way round.
If Intel MKL uses industry-standard C and Fortran APIs, why are we using 'Intel MKL' and not the 'industry-standard C and Fortran APIs' as the core?
@StuartIanNaylor I encourage you to produce a PyTorch-compatible library using these alternatives or a PR that discusses their impact on build size and performance.
The rationale of my reply is that this seems illogical: surely the industry standards should be the core, and maybe hardware specifics such as Intel MKL should be the add-ons. What I am questioning is the necessity of a separate PyTorch-compatible library: why is it needed when industry-standard APIs and libraries exist?
Sorry, I'm not sure what you're getting at, @StuartIanNaylor. As mentioned, we would consider adopting other math libraries or even entirely native implementations for these operations, but someone needs to do the work and demonstrate the correctness and performance of these alternatives. The best way to advocate for these changes is by doing that work.
@mruberry I apologise if you are not sure what I am getting at, but I find it really strange that you use a specific hardware vendor's math libraries when industry-standard cross-platform libraries exist. What is the point of being able to raise or comment on an issue if the response is "if you want it, do it yourself"? I am asking why Intel MKL was chosen for the core, as that honestly confuses me; unfortunately, I don't have the ability to implement standard math libraries or ARM-specific ones. Also, why do we need to demonstrate anything when Intel MKL is in the title? Being exclusive to one architecture is worse than bad performance.
I see. I think we're straying far from the original issue here. These more general questions are best asked on our forum: https://discuss.pytorch.org/.
My 2c: I think the misunderstanding comes from the reply:
That is reasonable as an extension to this issue: support high(er)-performance specialized implementations for specific hardware (ARM). However, the basic suggestion was that PyTorch should support a cross-architecture FFT library like FFTW by default, rather than defaulting to a "special implementation for specific hardware" (Intel MKL), which renders it unusable when running on the CPU on non-Intel-x86 hardware (e.g. AMD [bad performance], Power, ARM, ...). See my comment #41592 (comment)
Supporting fftw (or any particular library) is interesting. The questions we should answer when considering an alternative to our current approach are:
These can be tricky questions to answer. To the last question, however: is there not a PyTorch-compatible library calling fftw already available? PyTorch CPU tensors can be converted to NumPy arrays without copying memory. Is there no fftw package that operates on NumPy arrays?
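A minimal sketch of the interop described above, assuming the third-party pyFFTW package (`pip install pyfftw`) as the FFTW binding and a PyTorch version with complex-tensor support:

```python
import pyfftw.interfaces.numpy_fft as fftw
import torch

x = torch.randn(1024)                # CPU tensor
x_np = x.numpy()                     # NumPy view sharing the tensor's memory
spectrum = fftw.rfft(x_np)           # FFTW computes the FFT on that view
result = torch.from_numpy(spectrum)  # wrap the result back without a copy
print(result.shape, result.dtype)
```

This works today without any changes to PyTorch, at the cost of leaving PyTorch's autograd and dispatcher out of the loop.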
@mruberry I think that is a little too adversarial ;) There are several places in PyTorch where we have multiple libraries providing implementations of one function, with some decisions about when to select which one. If we are serious about ARM (which we should be!) then it's not a hard call to say that we should add another library to cover FFT support in this situation. Now, obviously there is work to figure out which library is appropriate and whether or not we should even ship it with our regular CPU binaries (probably not, if MKL fft is universally better), and as core developers we might not prioritize this work, but if my job were to make PyTorch work as well as possible on ARM, this would probably be part of the mandate.
I did a few simple tests using numpy, pyfft and mkl-fft in Python, and this seems to be true for x86. But again: MKL does not work at all on non-x86. I'm actually surprised, because I expected to see a slowdown on an AMD Rome processor due to the known "downgrading" of MKL performance on non-Intel processors, but I was not able to verify that for FFTs using Python as the interface.
yeah, sorry, I should have specified, regular x86 cpu binaries :)
Thanks for working on it @malfet
The issue also happens on mobile: https://discuss.pytorch.org/t/fft-operations-on-mobile/119598. Please keep us posted on any updates!
We are likely going to use pocketfft on non-x86 platforms
pocketfft would be a great inclusion, but any choice other than vendor-specific libraries will do. Does anyone have an ETA? I would really like to use this on ARM64; I have tried, as it seems have many others, and still hit this error.
I just tried all the great wheels supplied by https://mathinf.eu/pytorch/arm64/2021-01/
🐛 Bug
fft: ATen not compiled with MKL support
A RuntimeError is thrown when trying to compute a Spectrogram on a Jetson Nano, which uses an ARM64 processor.

To Reproduce
Code sample:
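The original sample was not preserved here; a minimal stand-in that exercises the same failing code path, assuming torchaudio, would be:

```python
import torch
import torchaudio

waveform = torch.randn(1, 16000)  # stand-in for an audio clip loaded from disk
spec = torchaudio.transforms.Spectrogram()(waveform)
# On an ARM64 build without MKL, the call above raises:
# RuntimeError: fft: ATen not compiled with MKL support
```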
Stack trace:
Expected behavior
A Spectrogram is created from the waveform.
Environment
Commands used to install PyTorch:
Commands used to install torchaudio:
sox:
torchaudio:
torchaudio.__version__
output: 0.7.0a0+102174e
collect_env.py
output:
Other relevant information:
MKL is not installed because it is not supported on ARM processors; oneDNN is installed.
Additional context
I did not install MKL because it is not supported on ARM processors, so building PyTorch from source with MKL support is not possible. Is there any workaround to this problem?
cc @malfet @seemethere @walterddr @mruberry @peterbell10 @ezyang