Commit
add type promotion static T+T logit. (#60638)
* add type promotion static T+T logit.

* fix bug

* fix code comment

* add where op test for type promotion.

* fix

* fix bug

* fix

* fix path

* fix

* fix

* fix spelling problem.

* support paddle inference.

* add where grad
zxcd authored Jan 19, 2024
1 parent 5e87a34 commit defd2d6
Showing 7 changed files with 291 additions and 44 deletions.
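The commit wires static-graph type promotion into the executor and the inference save/load path: when a supported binary op receives operands of different floating dtypes, a cast to the common dtype is inserted when the program is executed (or saved/loaded for inference) instead of eagerly at graph-build time. A minimal sketch of the intended behaviour, assuming a Paddle build that contains this commit; float32/float64 are used because both kernels exist on CPU, and the expected output dtype follows the test cases added below:

import numpy as np
import paddle

paddle.enable_static()

main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name='x', shape=[2, 3], dtype='float32')
    y = paddle.static.data(name='y', shape=[2, 3], dtype='float64')
    out = paddle.add(x, y)  # mixed dtypes: promoted when the program runs

exe = paddle.static.Executor(paddle.CPUPlace())
(res,) = exe.run(
    main,
    feed={
        'x': np.ones([2, 3], dtype='float32'),
        'y': np.ones([2, 3], dtype='float64'),
    },
    fetch_list=[out],
)
print(res.dtype)  # expected: float64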
1 change: 1 addition & 0 deletions python/paddle/base/__init__.py
@@ -107,6 +107,7 @@
    is_compiled_with_rocm,
    is_compiled_with_xpu,
    name_scope,
    process_type_promotion,
    program_guard,
    require_version,
    set_flags,
3 changes: 3 additions & 0 deletions python/paddle/base/executor.py
@@ -41,6 +41,7 @@
    get_flags,
    in_pir_mode,
    paddle_type_to_proto_type,
    process_type_promotion,
    set_flags,
)
from .incubate.checkpoint import auto_checkpoint as acp
@@ -1770,6 +1771,8 @@ def run(
                return_numpy=return_numpy,
            )
        else:
            # do type promotion if necessary
            program = process_type_promotion(program)
            res = self._run_impl(
                program=program,
                feed=feed,
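The executor.py hunk above is the hook: for non-PIR programs, Executor.run now rewrites the program with process_type_promotion before dispatching to _run_impl. A sketch of observing that rewrite directly, assuming process_type_promotion is importable from paddle.base.framework as this commit adds:

import paddle
from paddle.base.framework import process_type_promotion

paddle.enable_static()

prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    a = paddle.static.data(name='a', shape=[4], dtype='float16')
    b = paddle.static.data(name='b', shape=[4], dtype='float32')
    c = paddle.add(a, b)

# Executor.run now applies this pass implicitly; calling it by hand just
# makes the inserted op easy to inspect.
prog = process_type_promotion(prog)
print([op.type for op in prog.global_block().ops])
# expected to contain a 'cast' op ahead of 'elementwise_add'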
107 changes: 107 additions & 0 deletions python/paddle/base/framework.py
@@ -56,6 +56,17 @@
CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName()
_global_flags_ = core.globals()

SUPPORT_PROMOTION_OPS_AND_INPUTNAME = {
    "elementwise_add": ['X', 'Y'],
    "elementwise_add_grad": ['X', 'Y'],
    "elementwise_sub": ['X', 'Y'],
    "elementwise_sub_grad": ['X', 'Y'],
    "elementwise_mul": ['X', 'Y'],
    "elementwise_mul_grad": ['X', 'Y'],
    "where": ['X', 'Y'],
    "where_grad": ['X', 'Y'],
}


def _global_flags():
    return _global_flags_
@@ -8141,3 +8152,99 @@ def _get_paddle_place_list(places):
        ret.append(p)

    return ret


def dtype_to_str(in_dtype):
    if in_dtype == core.VarDesc.VarType.FP16:
        return "fp16"
    elif in_dtype == core.VarDesc.VarType.BF16:
        return "bf16"
    elif in_dtype == core.VarDesc.VarType.FP32:
        return "fp32"
    elif in_dtype == core.VarDesc.VarType.FP64:
        return "fp64"
    else:
        return None


def add_cast_for_type_promotion(op, block, idx, var_name, out_dtype):
    op_device = op.attr('op_device')
    cast_name = var_name.name + '.cast_' + dtype_to_str(out_dtype)
    out_var = block.create_var(
        name=cast_name,
        dtype=out_dtype,
        persistable=False,
        stop_gradient=var_name.stop_gradient,
    )
    op_role = (
        int(core.op_proto_and_checker_maker.OpRole.Forward)
        if not op.has_attr('op_role')
        else op.attr('op_role')
    )
    block._insert_op_without_sync(
        idx,
        type="cast",
        inputs={"X": var_name},
        outputs={"Out": out_var},
        attrs={
            "in_dtype": var_name.dtype,
            "out_dtype": out_var.dtype,
            "op_device": op_device,
            "op_role": op_role,
        },
    )
    op.desc._rename_input(var_name.name, out_var.name)


def process_type_promotion(program):
    org_program = program
    if program is None:
        program = default_main_program()
    # PIR programs are not supported for now
    if not isinstance(program, Program):
        return org_program
    global_block = program.global_block()
    all_params = global_block.all_parameters()
    for block in program.blocks:
        ops = block.ops
        idx = 0
        while idx < len(ops):
            op = ops[idx]
            var_name = None
            all_dtypes = []
            all_input_name_need_cast = []

            need_transed_var_names = SUPPORT_PROMOTION_OPS_AND_INPUTNAME.get(
                op.type, None
            )
            # type promotion only supports certain dyadic ops
            if need_transed_var_names is None:
                idx += 1
                continue

            # collect the dtype and name of each promotable input
            for input_idx in range(len(op.input_arg_names)):
                if op.input_names[input_idx] in need_transed_var_names:
                    input_arg_name = op.input_arg_names[input_idx]
                    all_dtypes.append(
                        op.block._var_recursive(input_arg_name).dtype
                    )
                    all_input_name_need_cast.append(input_arg_name)

            # only promotion between floating-point dtypes is supported
            if core.need_type_promotion(*all_dtypes):
                common_dtype = core.get_promote_dtype(op.type, *all_dtypes)
                for input_name_need_cast in all_input_name_need_cast:
                    var_name = op.block._var_recursive(input_name_need_cast)
                    if var_name.dtype != common_dtype:
                        # insert a cast for inputs whose dtype differs from the common dtype
                        add_cast_for_type_promotion(
                            op,
                            block,
                            idx,
                            var_name,
                            common_dtype,
                        )
                        idx += 1
            idx += 1
    return program
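The pass above walks every block, looks up promotable ops in SUPPORT_PROMOTION_OPS_AND_INPUTNAME, and has add_cast_for_type_promotion insert a cast op whose output variable reuses the source name with a '.cast_<dtype>' suffix. A short sketch of that naming, under the same assumptions as the previous example; elementwise_sub is chosen only because it appears in the table:

import paddle
from paddle.base.framework import process_type_promotion

paddle.enable_static()

prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    x = paddle.static.data(name='x', shape=[4], dtype='float32')
    y = paddle.static.data(name='y', shape=[4], dtype='float64')
    out = paddle.subtract(x, y)  # elementwise_sub with mixed float dtypes

prog = process_type_promotion(prog)
print([name for name in prog.global_block().vars if '.cast_' in name])
# expected: ['x.cast_fp64'] -- the float32 input cast up to float64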
10 changes: 2 additions & 8 deletions python/paddle/base/layers/math_op_patch.py
@@ -534,19 +534,13 @@ def __impl__(self, other_var):
if lhs_dtype != rhs_dtype:
if method_name in SUPPORT_PROMOTION_OPS:
if core.need_type_promotion(lhs_dtype, rhs_dtype):
common_dtype = core.get_promote_dtype(
op_type, lhs_dtype, rhs_dtype
)
# only report warning here, real promotion deal in Executor
warnings.warn(
f"The input dtypes of OP {op_type} are {lhs_dtype} and {rhs_dtype}, the output will be auto-promoted to {common_dtype}"
f"The input dtypes of OP {op_type} are {lhs_dtype} and {rhs_dtype}, the output will be auto-promoted"
)
warnings.filterwarnings(
"ignore", message="The input dtypes of OP"
)
if rhs_dtype != common_dtype:
other_var = astype(other_var, common_dtype)
if lhs_dtype != common_dtype:
self = astype(self, common_dtype)
else:
# NOTE(zoooo0820): Currently, we still keep the old illogical \
# logic for compatibility reasons
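With the eager astype calls removed, the static-mode operator overloads now only emit a warning at graph-build time and leave the graph unchanged; the real cast is added later by the executor-side pass. A sketch of that build-time behaviour (the warning text follows the f-string above):

import warnings

import paddle

paddle.enable_static()

prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    x = paddle.static.data(name='x', shape=[4], dtype='float32')
    y = paddle.static.data(name='y', shape=[4], dtype='float64')
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        out = x + y  # overload warns only; no astype/cast is emitted here
    print([str(w.message) for w in caught])

print([op.type for op in prog.global_block().ops])
# expected: no 'cast' yet -- it appears only after the promotion pass runs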
17 changes: 16 additions & 1 deletion python/paddle/static/io.py
@@ -33,7 +33,12 @@
    unique_name,
)
from paddle.base.executor import Executor, global_scope
from paddle.base.framework import Parameter, dygraph_not_support, static_only
from paddle.base.framework import (
    Parameter,
    dygraph_not_support,
    process_type_promotion,
    static_only,
)
from paddle.base.log_helper import get_logger
from paddle.framework.io_utils import (
    _clone_var_in_block_,
@@ -587,6 +592,10 @@ def save_inference_model(
    _check_vars('fetch_vars', fetch_vars)

    program = _get_valid_program(kwargs.get('program', None))

    # do type promotion
    program = process_type_promotion(program)

    clip_extra = kwargs.get('clip_extra', True)
    program = normalize_program(
        program,
@@ -903,6 +912,9 @@ def load_inference_model(path_prefix, executor, **kwargs):
        # deserialize bytes to program
        program = deserialize_program(program_bytes)

        # do type promotion
        program = process_type_promotion(program)

        vars = list(filter(is_persistable, program.list_vars()))
        if len(vars) > 0:
            load_vars(
@@ -958,6 +970,9 @@ def load_inference_model(path_prefix, executor, **kwargs):
        # deserialize bytes to program
        program = deserialize_program(program_bytes)

        # do type promotion
        program = process_type_promotion(program)

        vars = list(filter(is_persistable, program.list_vars()))
        if len(vars) > 0:
            load_dirname = os.path.dirname(params_path)
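Because save_inference_model and load_inference_model both run the pass now, a mixed-dtype graph handed to Paddle Inference already contains the cast. A rough end-to-end sketch under the same assumptions as the earlier examples; the temporary path and the parameter name 'w' are illustrative only:

import os
import tempfile

import paddle

paddle.enable_static()

main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name='x', shape=[4], dtype='float32')
    w = paddle.static.create_parameter(shape=[4], dtype='float64', name='w')
    out = paddle.add(x, w)  # float32 data + float64 parameter

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)

path_prefix = os.path.join(tempfile.mkdtemp(), 'promotion_demo')
paddle.static.save_inference_model(path_prefix, [x], [out], exe, program=main)

loaded, feed_names, fetch_targets = paddle.static.load_inference_model(
    path_prefix, exe
)
print([op.type for op in loaded.global_block().ops])
# expected to include 'cast' before 'elementwise_add'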
85 changes: 50 additions & 35 deletions test/legacy_test/test_tensor_type_promotion.py
@@ -119,6 +119,31 @@ def test_dtype_is_expected(self):
)


class TestAPIAddInStatic(TestOperatorOverloadAddInStatic):
    def run_api(self):
        prog = paddle.static.Program()
        with paddle.static.program_guard(prog):
            self.generate_test_value()

            out = paddle.add(self.l_value, self.r_value)
            out_reverse = paddle.add(self.r_value, self.l_value)

        res = self.exe.run(prog, fetch_list=[out, out_reverse])
        return res


create_test_case(TestAPIAddInStatic, 'float16', 'float32', 'float32')
create_test_case(TestAPIAddInStatic, 'float16', 'float64', 'float64')

create_test_case(TestAPIAddInStatic, 'float32', 'float64', 'float64')


if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16():
    create_test_case(TestAPIAddInStatic, 'bfloat16', 'float16', 'float32')
    create_test_case(TestAPIAddInStatic, 'bfloat16', 'float32', 'float32')
    create_test_case(TestAPIAddInStatic, 'bfloat16', 'float64', 'float64')


class TestOperatorOverloadSubInStatic(TestOperatorOverloadAddInStatic):
    def run_api(self):
        prog = paddle.static.Program()
@@ -156,74 +181,64 @@ def run_api(self):
)


class TestOperatorOverloadMulInStatic(TestOperatorOverloadAddInStatic):
class TestAPISubInStatic(TestOperatorOverloadAddInStatic):
    def run_api(self):
        prog = paddle.static.Program()
        with paddle.static.program_guard(prog):
            self.generate_test_value()

            out = self.l_value * self.r_value
            out_reverse = self.r_value * self.l_value
            out = paddle.subtract(self.l_value, self.r_value)
            out_reverse = paddle.subtract(self.r_value, self.l_value)

        res = self.exe.run(prog, fetch_list=[out, out_reverse])
        return res


create_test_case(
    TestOperatorOverloadMulInStatic, 'float16', 'float32', 'float32'
)
create_test_case(
    TestOperatorOverloadMulInStatic, 'float16', 'float64', 'float64'
)
create_test_case(TestAPISubInStatic, 'float16', 'float32', 'float32')
create_test_case(TestAPISubInStatic, 'float16', 'float64', 'float64')

create_test_case(
    TestOperatorOverloadMulInStatic, 'float32', 'float64', 'float64'
)
create_test_case(TestAPISubInStatic, 'float32', 'float64', 'float64')

if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16():
    create_test_case(
        TestOperatorOverloadMulInStatic, 'bfloat16', 'float16', 'float32'
    )
    create_test_case(
        TestOperatorOverloadMulInStatic, 'bfloat16', 'float32', 'float32'
    )
    create_test_case(
        TestOperatorOverloadMulInStatic, 'bfloat16', 'float64', 'float64'
    )

if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16():
    create_test_case(TestAPISubInStatic, 'bfloat16', 'float16', 'float32')
    create_test_case(TestAPISubInStatic, 'bfloat16', 'float32', 'float32')
    create_test_case(TestAPISubInStatic, 'bfloat16', 'float64', 'float64')

class TestOperatorOverloadGTInStatic(TestOperatorOverloadAddInStatic):
    def set_dtype(self):
        self.ldtype = 'float32'
        self.rdtype = 'float64'
        self.expected_out_dtype = 'bool'

class TestOperatorOverloadMulInStatic(TestOperatorOverloadAddInStatic):
    def run_api(self):
        prog = paddle.static.Program()
        with paddle.static.program_guard(prog):
            self.generate_test_value()

            out = self.l_value > self.r_value
            out_reverse = self.r_value > self.l_value
            out = self.l_value * self.r_value
            out_reverse = self.r_value * self.l_value

        res = self.exe.run(prog, fetch_list=[out, out_reverse])
        return res


create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'float32', 'bool')
create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'float64', 'bool')
create_test_case(
    TestOperatorOverloadMulInStatic, 'float16', 'float32', 'float32'
)
create_test_case(
    TestOperatorOverloadMulInStatic, 'float16', 'float64', 'float64'
)

create_test_case(TestOperatorOverloadGTInStatic, 'float32', 'float64', 'bool')
create_test_case(
    TestOperatorOverloadMulInStatic, 'float32', 'float64', 'float64'
)

if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16():
    create_test_case(
        TestOperatorOverloadGTInStatic, 'bfloat16', 'float16', 'bool'
        TestOperatorOverloadMulInStatic, 'bfloat16', 'float16', 'float32'
    )
    create_test_case(
        TestOperatorOverloadGTInStatic, 'bfloat16', 'float32', 'bool'
        TestOperatorOverloadMulInStatic, 'bfloat16', 'float32', 'float32'
    )
    create_test_case(
        TestOperatorOverloadGTInStatic, 'bfloat16', 'float64', 'bool'
        TestOperatorOverloadMulInStatic, 'bfloat16', 'float64', 'float64'
    )


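The create_test_case helper used throughout this test is defined earlier in the file and is not part of the hunk; the following is a purely hypothetical sketch of how such a helper is commonly written in Paddle's test suite (the signature and naming scheme are assumptions, not the committed code):

def create_test_case(baseclass, ldtype, rdtype, expected_out_dtype=None):
    # Clone the base test with a fixed dtype pair and register it by name.
    class TestPromotion(baseclass):
        def set_dtype(self):
            self.ldtype = ldtype
            self.rdtype = rdtype
            self.expected_out_dtype = expected_out_dtype

    cls_name = f"{baseclass.__name__}{ldtype.capitalize()}{rdtype.capitalize()}"
    TestPromotion.__name__ = cls_name
    globals()[cls_name] = TestPromotion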