diff --git a/python/paddle/base/__init__.py b/python/paddle/base/__init__.py index 9a6d8914feddb..83fe57b21ce4c 100644 --- a/python/paddle/base/__init__.py +++ b/python/paddle/base/__init__.py @@ -107,6 +107,7 @@ is_compiled_with_rocm, is_compiled_with_xpu, name_scope, + process_type_promotion, program_guard, require_version, set_flags, diff --git a/python/paddle/base/executor.py b/python/paddle/base/executor.py index 4a7b24d6618c8..f73b2c999b227 100755 --- a/python/paddle/base/executor.py +++ b/python/paddle/base/executor.py @@ -41,6 +41,7 @@ get_flags, in_pir_mode, paddle_type_to_proto_type, + process_type_promotion, set_flags, ) from .incubate.checkpoint import auto_checkpoint as acp @@ -1770,6 +1771,8 @@ def run( return_numpy=return_numpy, ) else: + # do type promotion if necessary + program = process_type_promotion(program) res = self._run_impl( program=program, feed=feed, diff --git a/python/paddle/base/framework.py b/python/paddle/base/framework.py index e00a1827361aa..8fa7af0ef291c 100644 --- a/python/paddle/base/framework.py +++ b/python/paddle/base/framework.py @@ -56,6 +56,17 @@ CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName() _global_flags_ = core.globals() +SUPPORT_PROMOTION_OPS_AND_INPUTNAME = { + "elementwise_add": ['X', 'Y'], + "elementwise_add_grad": ['X', 'Y'], + "elementwise_sub": ['X', 'Y'], + "elementwise_sub_grad": ['X', 'Y'], + "elementwise_mul": ['X', 'Y'], + "elementwise_mul_grad": ['X', 'Y'], + "where": ['X', 'Y'], + "where_grad": ['X', 'Y'], +} + def _global_flags(): return _global_flags_ @@ -8141,3 +8152,99 @@ def _get_paddle_place_list(places): ret.append(p) return ret + + +def dtype_to_str(in_dtype): + if in_dtype == core.VarDesc.VarType.FP16: + return "fp16" + elif in_dtype == core.VarDesc.VarType.BF16: + return "bf16" + elif in_dtype == core.VarDesc.VarType.FP32: + return "fp32" + elif in_dtype == core.VarDesc.VarType.FP64: + return "fp64" + else: + return None + + +def add_cast_for_type_promotion(op, block, idx, var_name, out_dtype): + op_device = op.attr('op_device') + cast_name = var_name.name + '.cast_' + dtype_to_str(out_dtype) + out_var = block.create_var( + name=cast_name, + dtype=out_dtype, + persistable=False, + stop_gradient=var_name.stop_gradient, + ) + op_role = ( + int(core.op_proto_and_checker_maker.OpRole.Forward) + if not op.has_attr('op_role') + else op.attr('op_role') + ) + block._insert_op_without_sync( + idx, + type="cast", + inputs={"X": var_name}, + outputs={"Out": out_var}, + attrs={ + "in_dtype": var_name.dtype, + "out_dtype": out_var.dtype, + "op_device": op_device, + "op_role": op_role, + }, + ) + op.desc._rename_input(var_name.name, out_var.name) + + +def process_type_promotion(program): + org_program = program + if program is None: + program = default_main_program() + # not support pir for now + if not isinstance(program, Program): + return org_program + global_block = program.global_block() + all_params = global_block.all_parameters() + for block in program.blocks: + ops = block.ops + idx = 0 + while idx < len(ops): + op = ops[idx] + var_name = None + all_dtypes = [] + all_input_name_need_cast = [] + + need_transed_var_names = SUPPORT_PROMOTION_OPS_AND_INPUTNAME.get( + op.type, None + ) + # type promotion only support some dyadic api + if need_transed_var_names is None: + idx += 1 + continue + + # get all dtype and input_name + for input_idx in range(len(op.input_arg_names)): + if op.input_names[input_idx] in need_transed_var_names: + input_arg_name = op.input_arg_names[input_idx] + all_dtypes.append( + 
op.block._var_recursive(input_arg_name).dtype + ) + all_input_name_need_cast.append(input_arg_name) + + # only support promote between float + if core.need_type_promotion(*all_dtypes): + common_dtype = core.get_promote_dtype(op.type, *all_dtypes) + for input_name_need_cast in all_input_name_need_cast: + var_name = op.block._var_recursive(input_name_need_cast) + if var_name.dtype != common_dtype: + # add cast op for different dtype + add_cast_for_type_promotion( + op, + block, + idx, + var_name, + common_dtype, + ) + idx += 1 + idx += 1 + return program diff --git a/python/paddle/base/layers/math_op_patch.py b/python/paddle/base/layers/math_op_patch.py index d4634367b6be5..758f0410285a4 100644 --- a/python/paddle/base/layers/math_op_patch.py +++ b/python/paddle/base/layers/math_op_patch.py @@ -534,19 +534,13 @@ def __impl__(self, other_var): if lhs_dtype != rhs_dtype: if method_name in SUPPORT_PROMOTION_OPS: if core.need_type_promotion(lhs_dtype, rhs_dtype): - common_dtype = core.get_promote_dtype( - op_type, lhs_dtype, rhs_dtype - ) + # only report warning here, real promotion deal in Executor warnings.warn( - f"The input dtypes of OP {op_type} are {lhs_dtype} and {rhs_dtype}, the output will be auto-promoted to {common_dtype}" + f"The input dtypes of OP {op_type} are {lhs_dtype} and {rhs_dtype}, the output will be auto-promoted" ) warnings.filterwarnings( "ignore", message="The input dtypes of OP" ) - if rhs_dtype != common_dtype: - other_var = astype(other_var, common_dtype) - if lhs_dtype != common_dtype: - self = astype(self, common_dtype) else: # NOTE(zoooo0820): Currently, we still keep the old illogical \ # logic for compatibility reasons diff --git a/python/paddle/static/io.py b/python/paddle/static/io.py index 64ee380b4b392..3d3d4f30fa2d4 100644 --- a/python/paddle/static/io.py +++ b/python/paddle/static/io.py @@ -33,7 +33,12 @@ unique_name, ) from paddle.base.executor import Executor, global_scope -from paddle.base.framework import Parameter, dygraph_not_support, static_only +from paddle.base.framework import ( + Parameter, + dygraph_not_support, + process_type_promotion, + static_only, +) from paddle.base.log_helper import get_logger from paddle.framework.io_utils import ( _clone_var_in_block_, @@ -587,6 +592,10 @@ def save_inference_model( _check_vars('fetch_vars', fetch_vars) program = _get_valid_program(kwargs.get('program', None)) + + # do type promotion + program = process_type_promotion(program) + clip_extra = kwargs.get('clip_extra', True) program = normalize_program( program, @@ -903,6 +912,9 @@ def load_inference_model(path_prefix, executor, **kwargs): # deserialize bytes to program program = deserialize_program(program_bytes) + # do type promotion + program = process_type_promotion(program) + vars = list(filter(is_persistable, program.list_vars())) if len(vars) > 0: load_vars( @@ -958,6 +970,9 @@ def load_inference_model(path_prefix, executor, **kwargs): # deserialize bytes to program program = deserialize_program(program_bytes) + # do type promotion + program = process_type_promotion(program) + vars = list(filter(is_persistable, program.list_vars())) if len(vars) > 0: load_dirname = os.path.dirname(params_path) diff --git a/test/legacy_test/test_tensor_type_promotion.py b/test/legacy_test/test_tensor_type_promotion.py index c47bfe8e5d1d5..19d26048f6997 100644 --- a/test/legacy_test/test_tensor_type_promotion.py +++ b/test/legacy_test/test_tensor_type_promotion.py @@ -119,6 +119,31 @@ def test_dtype_is_expected(self): ) +class 
TestAPIAddInStatic(TestOperatorOverloadAddInStatic):
+    def run_api(self):
+        prog = paddle.static.Program()
+        with paddle.static.program_guard(prog):
+            self.generate_test_value()
+
+            out = paddle.add(self.l_value, self.r_value)
+            out_reverse = paddle.add(self.r_value, self.l_value)
+
+            res = self.exe.run(prog, fetch_list=[out, out_reverse])
+        return res
+
+
+create_test_case(TestAPIAddInStatic, 'float16', 'float32', 'float32')
+create_test_case(TestAPIAddInStatic, 'float16', 'float64', 'float64')
+
+create_test_case(TestAPIAddInStatic, 'float32', 'float64', 'float64')
+
+
+if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16():
+    create_test_case(TestAPIAddInStatic, 'bfloat16', 'float16', 'float32')
+    create_test_case(TestAPIAddInStatic, 'bfloat16', 'float32', 'float32')
+    create_test_case(TestAPIAddInStatic, 'bfloat16', 'float64', 'float64')
+
+
 class TestOperatorOverloadSubInStatic(TestOperatorOverloadAddInStatic):
     def run_api(self):
         prog = paddle.static.Program()
@@ -156,74 +181,64 @@ def run_api(self):
     )
 
 
-class TestOperatorOverloadMulInStatic(TestOperatorOverloadAddInStatic):
+class TestAPISubInStatic(TestOperatorOverloadAddInStatic):
     def run_api(self):
         prog = paddle.static.Program()
         with paddle.static.program_guard(prog):
             self.generate_test_value()
 
-            out = self.l_value * self.r_value
-            out_reverse = self.r_value * self.l_value
+            out = paddle.subtract(self.l_value, self.r_value)
+            out_reverse = paddle.subtract(self.r_value, self.l_value)
 
             res = self.exe.run(prog, fetch_list=[out, out_reverse])
         return res
 
 
-create_test_case(
-    TestOperatorOverloadMulInStatic, 'float16', 'float32', 'float32'
-)
-create_test_case(
-    TestOperatorOverloadMulInStatic, 'float16', 'float64', 'float64'
-)
+create_test_case(TestAPISubInStatic, 'float16', 'float32', 'float32')
+create_test_case(TestAPISubInStatic, 'float16', 'float64', 'float64')
 
-create_test_case(
-    TestOperatorOverloadMulInStatic, 'float32', 'float64', 'float64'
-)
+create_test_case(TestAPISubInStatic, 'float32', 'float64', 'float64')
 
 
-if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16():
-    create_test_case(
-        TestOperatorOverloadMulInStatic, 'bfloat16', 'float16', 'float32'
-    )
-    create_test_case(
-        TestOperatorOverloadMulInStatic, 'bfloat16', 'float32', 'float32'
-    )
-    create_test_case(
-        TestOperatorOverloadMulInStatic, 'bfloat16', 'float64', 'float64'
-    )
+if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16():
+    create_test_case(TestAPISubInStatic, 'bfloat16', 'float16', 'float32')
+    create_test_case(TestAPISubInStatic, 'bfloat16', 'float32', 'float32')
+    create_test_case(TestAPISubInStatic, 'bfloat16', 'float64', 'float64')
 
 
-class TestOperatorOverloadGTInStatic(TestOperatorOverloadAddInStatic):
-    def set_dtype(self):
-        self.ldtype = 'float32'
-        self.rdtype = 'float64'
-        self.expected_out_dtype = 'bool'
-
+class TestOperatorOverloadMulInStatic(TestOperatorOverloadAddInStatic):
     def run_api(self):
         prog = paddle.static.Program()
         with paddle.static.program_guard(prog):
             self.generate_test_value()
 
-            out = self.l_value > self.r_value
-            out_reverse = self.r_value > self.l_value
+            out = self.l_value * self.r_value
+            out_reverse = self.r_value * self.l_value
 
             res = self.exe.run(prog, fetch_list=[out, out_reverse])
         return res
 
 
-create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'float32', 'bool')
-create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'float64', 'bool')
+create_test_case(
+    TestOperatorOverloadMulInStatic, 'float16', 'float32', 'float32'
+)
+create_test_case(
+    
TestOperatorOverloadMulInStatic, 'float16', 'float64', 'float64' +) -create_test_case(TestOperatorOverloadGTInStatic, 'float32', 'float64', 'bool') +create_test_case( + TestOperatorOverloadMulInStatic, 'float32', 'float64', 'float64' +) if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): create_test_case( - TestOperatorOverloadGTInStatic, 'bfloat16', 'float16', 'bool' + TestOperatorOverloadMulInStatic, 'bfloat16', 'float16', 'float32' ) create_test_case( - TestOperatorOverloadGTInStatic, 'bfloat16', 'float32', 'bool' + TestOperatorOverloadMulInStatic, 'bfloat16', 'float32', 'float32' ) create_test_case( - TestOperatorOverloadGTInStatic, 'bfloat16', 'float64', 'bool' + TestOperatorOverloadMulInStatic, 'bfloat16', 'float64', 'float64' ) diff --git a/test/legacy_test/test_where_op.py b/test/legacy_test/test_where_op.py index 53abe8a99f732..b338d6df0e378 100644 --- a/test/legacy_test/test_where_op.py +++ b/test/legacy_test/test_where_op.py @@ -318,6 +318,61 @@ def __test_where_with_broadcast_static(self, cond_shape, x_shape, y_shape): expect = np.where(cond_data, x_data, y_data) np.testing.assert_array_equal(out[0], expect) + def __test_where_with_type_promotion( + self, x_dtype, y_dtype, expeced_dtype=None + ): + paddle.enable_static() + main_program = paddle.static.Program() + shape = [3, 10] + with paddle.static.program_guard(main_program): + cond = paddle.static.data(name='cond', shape=[3, 10], dtype='bool') + x = paddle.static.data(name='x', shape=shape, dtype=x_dtype) + y = paddle.static.data(name='y', shape=shape, dtype=y_dtype) + cond_data_tmp = np.random.random(size=shape).astype('float32') + cond_data = cond_data_tmp < 0.3 + + if x_dtype != 'bfloat16': + x_data = np.random.random(size=shape).astype(x_dtype) + else: + x_data = convert_float_to_uint16( + np.random.random(size=shape).astype('float32') + ) + if y_dtype != 'bfloat16': + y_data = np.random.random(size=shape).astype(y_dtype) + else: + y_data = convert_float_to_uint16( + np.random.random(size=shape).astype('float32') + ) + result = paddle.where(condition=cond, x=x, y=y) + for use_cuda in [False, True]: + if use_cuda and (not base.core.is_compiled_with_cuda()): + return + place = base.CUDAPlace(0) if use_cuda else base.CPUPlace() + exe = base.Executor(place) + out = exe.run( + paddle.static.default_main_program(), + feed={'cond': cond_data, 'x': x_data, 'y': y_data}, + fetch_list=[result], + ) + if x_dtype == 'bfloat16' or y_dtype == 'bfloat16': + x_data_convert = ( + convert_uint16_to_float(x_data) + if x_dtype == 'bfloat16' + else x_data + ) + y_data_convert = ( + convert_uint16_to_float(y_data) + if y_dtype == 'bfloat16' + else y_data + ) + expect = np.where(cond_data, x_data_convert, y_data_convert) + np.testing.assert_array_equal(out[0], expect) + self.assertEqual(out[0].dtype.__str__(), expeced_dtype) + else: + expect = np.where(cond_data, x_data, y_data) + np.testing.assert_array_equal(out[0], expect) + self.assertEqual(out[0].dtype, expect.dtype) + @test_with_pir_api def test_static_api_broadcast_1(self): cond_shape = [2, 4] @@ -374,6 +429,63 @@ def test_static_api_broadcast_8(self): b_shape = [2, 2, 1] self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) + def test_static_api_type_promotion_fp16_fp32(self): + x_dtype = 'float16' + y_dtype = 'float32' + self.__test_where_with_type_promotion(x_dtype, y_dtype) + self.__test_where_with_type_promotion(y_dtype, x_dtype) + + def test_static_api_type_promotion_fp16_fp64(self): + x_dtype = 'float16' + y_dtype = 'float64' + 
self.__test_where_with_type_promotion(x_dtype, y_dtype) + self.__test_where_with_type_promotion(y_dtype, x_dtype) + + def test_static_api_type_promotion_fp32_fp64(self): + x_dtype = 'float32' + y_dtype = 'float64' + self.__test_where_with_type_promotion(x_dtype, y_dtype) + self.__test_where_with_type_promotion(y_dtype, x_dtype) + + @unittest.skipIf( + not ( + paddle.is_compiled_with_cuda() + and paddle.base.core.supports_bfloat16() + ), + "bf16 is not supported in current device", + ) + def test_static_api_type_promotion_bf16_fp16(self): + x_dtype = 'bfloat16' + y_dtype = 'float16' + self.__test_where_with_type_promotion(x_dtype, y_dtype, 'float32') + self.__test_where_with_type_promotion(y_dtype, x_dtype, 'float32') + + @unittest.skipIf( + not ( + paddle.is_compiled_with_cuda() + and paddle.base.core.supports_bfloat16() + ), + "bf16 is not supported in current device", + ) + def test_static_api_type_promotion_bf16_fp32(self): + x_dtype = 'bfloat16' + y_dtype = 'float32' + self.__test_where_with_type_promotion(x_dtype, y_dtype, 'float32') + self.__test_where_with_type_promotion(y_dtype, x_dtype, 'float32') + + @unittest.skipIf( + not ( + paddle.is_compiled_with_cuda() + and paddle.base.core.supports_bfloat16() + ), + "bf16 is not supported in current device", + ) + def test_static_api_type_promotion_bf16_fp64(self): + x_dtype = 'bfloat16' + y_dtype = 'float64' + self.__test_where_with_type_promotion(x_dtype, y_dtype, 'float64') + self.__test_where_with_type_promotion(y_dtype, x_dtype, 'float64') + class TestWhereDygraphAPI(unittest.TestCase): def test_api(self):