
[BUGFIX]try avoid the error in operator/tensor/amp_cast.h #20188

Merged
merged 3 commits into from Apr 30, 2021

Conversation

@Neutron3529 (Contributor) commented Apr 19, 2021

I'm trying to avoid the error generated by amp using bfloat16

The error is due to:

```
/me/prog/prog-amp.py:77: UserWarning: All children of this Sequential layer 'compose1_' are HybridBlocks. Consider using HybridSequential for the best performance.
  transform_test.hybridize(static_alloc=True,static_shape=True)
Traceback (most recent call last):
  File "/me/prog/prog-amp.py", line 359, in <module>
    loss0   = loss_fn(output, label)
  File "/me/incubator-mxnet/python/mxnet/ndarray/ndarray.py", line 314, in __mul__
    return multiply(self, other)
  File "/me/incubator-mxnet/python/mxnet/ndarray/ndarray.py", line 3757, in multiply
    return _ufunc_helper(
  File "/me/incubator-mxnet/python/mxnet/ndarray/ndarray.py", line 3576, in _ufunc_helper
    return fn_array(lhs, rhs)
  File "/me/incubator-mxnet/python/mxnet/contrib/amp/amp.py", line 109, in _new_fun
    return f(*args, **kwargs)
  File "<string>", line 52, in broadcast_mul
  File "/me/incubator-mxnet/python/mxnet/_ctypes/ndarray.py", line 82, in _imperative_invoke
    check_call(_LIB.MXImperativeInvokeEx(
  File "/me/incubator-mxnet/python/mxnet/base.py", line 246, in check_call
    raise get_last_ffi_error()
mxnet.base.MXNetError: Traceback (most recent call last):
  File "/me/incubator-mxnet/src/io/../operator/elemwise_op_common.h", line 135
MXNetError: Check failed: assign(&dattr, vec.at(i)): Incompatible attr in node  at 1-th input: expected bfloat16, got float32
Error in atexit._run_exitfuncs:
Traceback (most recent call last):
  File "/me/incubator-mxnet/python/mxnet/base.py", line 587, in _notify_shutdown
    check_call(_LIB.MXNotifyShutdown())
  File "/me/incubator-mxnet/python/mxnet/base.py", line 246, in check_call
    raise get_last_ffi_error()
mxnet.base.MXNetError: Traceback (most recent call last):
  File "/me/incubator-mxnet/src/operator/tensor/./amp_cast.h", line 136
MXNetError: Unknown type enum 12
```

This was tested under MXNet v1.x, but it seems to also affect v2.0.

Since 30-series RTX cards support bfloat16, there is no need to explicitly disable it with `#ifndef __NVCC__`.

I don't know whether it works, but things could not be worse.
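
To make the failure concrete: in mshadow's type enumeration, 12 is bfloat16, so "Unknown type enum 12" means the type switch compiled for GPU in amp_cast.h simply had no bfloat16 case. A minimal sketch of a trigger (my assumption, not taken from the PR):

```
import mxnet as mx

# mshadow enumerates bfloat16 as type 12; before this patch the GPU build of
# the amp_cast type switch had no case for it, so a cast like this raised
# "MXNetError: Unknown type enum 12".
x = mx.nd.ones((2, 2), ctx=mx.gpu())
y = mx.nd.amp_cast(x, dtype='bfloat16')  # assumed minimal trigger on a GPU build
```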

Description

```
import mxnet as mx
try:
    from mxnet.contrib import amp   # MXNet 1.x
except ImportError:                 # bare except narrowed to a missing module
    from mxnet import amp           # MXNet 2.0

# amp.init('float16')  # OK
amp.init('bfloat16')    # raised the error above before this patch

net = mx.gluon.nn.Sequential()
net.add(mx.gluon.nn.Dense(400))
net.add(mx.gluon.nn.Dense(10))
ctx = mx.gpu()
net.initialize(ctx=ctx)
trainer = mx.gluon.Trainer(net.collect_params(), 'sgd')
amp.init_trainer(trainer)

with mx.autograd.record():
    # a random (32, 28, 28) batch; the original passed [32, 28, 28] as values,
    # not a shape. Dense flattens the trailing axes by default.
    loss = net(mx.nd.random.uniform(shape=(32, 28, 28), ctx=ctx))

loss.backward()  # backward() on a non-scalar uses all-ones head gradients
trainer.step(1)
```

Such code fails in previous versions of MXNet; this PR provides a workaround. Further work on bf16 support is still needed.

Checklist

Essentials

  • PR's title starts with a category (e.g. [BUGFIX], [MODEL], [TUTORIAL], [FEATURE], [DOC], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

Changes

  • Workaround: do not disable bfloat16 in operator/tensor/amp_cast.h under `#ifndef __NVCC__`

Comments

Actually this PR does almost nothing by itself. Further bf16 support (including a very important operator, convolution) is required, but I know nothing about cuDNN.

@mxnet-bot

Hey @Neutron3529, thanks for submitting the PR.
All tests are already queued to run once. If tests fail, you can trigger one or more tests again with the following commands:

  • To trigger all jobs: @mxnet-bot run ci [all]
  • To trigger specific jobs: @mxnet-bot run ci [job1, job2]

CI supported jobs: [centos-gpu, unix-cpu, unix-gpu, clang, website, windows-cpu, edge, sanity, centos-cpu, windows-gpu, miscellaneous]


Note:
Only the following 3 categories can trigger CI: PR Author, MXNet Committer, Jenkins Admin.
All CI tests must pass before the PR can be merged.

@Neutron3529 changed the title from "try avoid the error in operator/tensor/amp_cast.h" to "[BUGFIX]try avoid the error in operator/tensor/amp_cast.h" on Apr 19, 2021
@szha requested a review from ptrendx on April 19, 2021 14:49
@Neutron3529 (Contributor, Author)

@mxnet-bot run ci [unix-gpu]

@mseth10 added the pr-awaiting-testing, pr-work-in-progress, and pr-awaiting-review labels and removed the pr-awaiting-testing and pr-work-in-progress labels on Apr 21, 2021
@mseth10 added the pr-awaiting-testing and pr-work-in-progress labels and removed the pr-awaiting-review and pr-awaiting-testing labels on Apr 24, 2021
@Neutron3529 (Contributor, Author)

@mxnet-bot run ci [unix-gpu, unix-cpu, centos-cpu]

@mxnet-bot

Jenkins CI successfully triggered: [unix-gpu, unix-cpu, centos-cpu]

@mseth10 added the pr-awaiting-testing and pr-work-in-progress labels and removed the pr-work-in-progress and pr-awaiting-testing labels on Apr 24, 2021
@Neutron3529 (Contributor, Author)

@mxnet-bot run ci [unix-gpu, centos-cpu]

@mxnet-bot

Jenkins CI successfully triggered: [centos-cpu, unix-gpu]

@mseth10 added the pr-awaiting-testing label and removed the pr-work-in-progress and pr-awaiting-testing labels on Apr 29, 2021
@mseth10 added the pr-awaiting-review label on Apr 29, 2021
@szha merged commit 0946c8e into apache:master on Apr 30, 2021
@szha (Member) commented Apr 30, 2021

@Neutron3529 thank you!

@Neutron3529 (Contributor, Author)

> @Neutron3529 thank you!

You're welcome.

Since I am in China, it is not very convenient to visit GitHub.
You may open a new issue about the AMP support with bfloat16.

This commit does not work, at least with convolution layers.
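
For illustration, a minimal sketch of the kind of convolutional model this refers to (hypothetical, not from this thread; the exact failing operators may differ):

```
import mxnet as mx
from mxnet.contrib import amp  # MXNet 1.x import path

amp.init('bfloat16')
ctx = mx.gpu()

# A tiny conv net. Per the comment above, convolution still lacks bfloat16
# support, so a model like this reportedly still errors after this patch.
net = mx.gluon.nn.HybridSequential()
net.add(mx.gluon.nn.Conv2D(channels=8, kernel_size=3, activation='relu'))
net.add(mx.gluon.nn.Dense(10))
net.initialize(ctx=ctx)

with mx.autograd.record():
    out = net(mx.nd.random.uniform(shape=(1, 3, 28, 28), ctx=ctx))
out.backward()
```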

@Neutron3529 deleted the patch-3 branch on May 2, 2021 14:43
chinakook pushed a commit to chinakook/mxnet that referenced this pull request Aug 4, 2021
* try avoid the error in operator/tensor/amp_cast.h

* forgive my garbage coding, I'm not a computer scientist

* revert all the modification of base.h

Co-authored-by: Neutron3529 <qweytr1@mail.ustc.edu.cn>