Support auto-generated FP8 meta for CKPT converters. #844

Open · wants to merge 3 commits into base: main
52 changes: 44 additions & 8 deletions rosetta/utils/te_pax_t5x_ckpt_converter/README.md
@@ -23,6 +23,9 @@
the size to chunk the kernel (weights) before storing; only supported with --fw=pax. Setting None means no chunking. (default: None)
--weight-only indicate if the source checkpoint only includes weights. (default: False)
--skip-ln indicate whether to skip the conversion for LayerNorm. (default: False)
--gen-fp8-meta indicate whether to generate the corresponding FP8 meta. Only works when --direction=fw2te (default: False)
--amax-history-len AMAX_HISTORY_LEN
the length of the amax history; only used when --gen-fp8-meta is specified. (default: 1024)
--pax-repeat indicate if the source Pax checkpoint enables Repeat. (default: False)
--t5x-fuse-qkv indicate if the source T5X checkpoint enables fused_qkv_params of TE. (default: False)
```
@@ -63,7 +66,7 @@ python converter/main.py \
--input-path=/your_path_to_src_ckpt \
--output-path=/your_path_to_output_ckpt \
--fw=pax \
--direction=fw2tw \
--direction=fw2te \
--pax-repeat \
--num-of-layer=8 \
--num-of-head=6 \
@@ -145,16 +148,49 @@ python converter/main.py \

### Notes
#### Running converted CKPTs with Transformer Engine (TE) + FP8
If you run the converted TE checkpoints, from the Pax or T5X frameworks, with FP8 enabled, you might encounter
an error saying that no FP8 meta was found in the given checkpoint at the restoring phase. That is because the
original checkpoints to convert do not contain information about FP8 meta. To address this issue, please run
a training process with the same model config on the target framework, plus TE and FP8, then store a checkpoint
at step 0. Next, use the converted checkpoint to replace the weights of the checkpoint from framework + TE + FP8, and
restore it to keep training.
We now support auto-generating FP8 meta for TE checkpoints converted from framework checkpoints, for further FP8 training.
To enable this feature, add `--gen-fp8-meta` to your command when running the converter.
Additionally, use `--amax-history-len` to specify the length of the amax history to be applied in subsequent FP8 training.

For example:
- Pax -> TE (Repeat) with FP8 and a 1024-step amax history:
```bash
python converter/main.py \
--input-path=/your_path_to_src_ckpt \
--output-path=/your_path_to_output_ckpt \
--fw=pax \
--direction=fw2te \
--pax-repeat \
--gen-fp8-meta \
--amax-history-len=1024 \
--num-of-layer=8 \
--num-of-head=6 \
--head-dim=64 \
--mlp-intermediate-dim=1024
```

- T5X -> TE/FusedQKV with FP8 and a 1024-step amax history:
```bash
python converter/main.py \
--input-path=/your_path_to_src_ckpt \
--output-path=/your_path_to_output_ckpt \
--fw=t5x \
--direction=fw2te \
--t5x-fuse-qkv \
--gen-fp8-meta \
--amax-history-len=1024 \
--embed-dim=512 \
--num-of-layer=8 \
--num-of-head=6 \
--head-dim=64 \
--mlp-intermediate-dim=1024
```

NOTE:
In the generated FP8 meta, only the amax of the weights is accurate. Therefore, please be aware that a few steps to adjust the FP8 meta
of inputs and gradients are needed when resuming training with the converted FP8 checkpoints.
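
As a rough illustration of why that is (this PR does not show the exact FP8-meta pytree layout TE uses, so the field names `amax_history`, `scale`, and `scale_inv` below are assumptions), the generated meta for a single weight can be sketched as an amax history of length `--amax-history-len` in which only the weight's own amax is known:

```python
import jax.numpy as jnp


def sketch_fp8_meta(weight, amax_history_len=1024):
    """Hypothetical sketch of a freshly generated FP8 meta for one weight tensor."""
    amax_history = jnp.zeros((amax_history_len,), dtype=jnp.float32)
    # The weight amax can be computed directly from the converted checkpoint ...
    amax_history = amax_history.at[0].set(jnp.max(jnp.abs(weight)))
    # ... but the amaxes of inputs and gradients are unknown at conversion time,
    # which is why a few warm-up steps are needed after resuming training.
    return {
        "amax_history": amax_history,
        "scale": jnp.ones((), dtype=jnp.float32),
        "scale_inv": jnp.ones((), dtype=jnp.float32),
    }
```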

#### The folder structure of CKPT by Pax and T5X
If you would like to run the converted CKPTs with the frameworks, you may expect the converted CKPTs to have the same folder
structure as the CKPTs stored by the frameworks. In this case, you can set `--output-path` to follow the same structure as the
CKPTs from the frameworks; there is no need to pre-create folders, since they are generated when needed.
For Pax, you could set `--output-path` to be like `/${your_path_to_output}/checkpoints/checkpoint_${step}`.
For T5X, you could set `--output-path` to be like `/${your_path_to_output}/checkpoint_${step}`.
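
For instance, a small hypothetical Python helper (not part of the converter) that assembles output paths matching these layouts:

```python
import os


def pax_output_path(base, step):
    # e.g. /your_path_to_output/checkpoints/checkpoint_100
    return os.path.join(base, "checkpoints", f"checkpoint_{step}")


def t5x_output_path(base, step):
    # e.g. /your_path_to_output/checkpoint_100
    return os.path.join(base, f"checkpoint_{step}")
```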
19 changes: 15 additions & 4 deletions rosetta/utils/te_pax_t5x_ckpt_converter/converter/main.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse

from paxml_converters import Pax2TEConvertHelper, Pax2TERepeatConvertHelper
@@ -109,6 +108,17 @@ def parse_args():
default=False,
help="indicate if skip the conversion for LayerNorm.")

parser.add_argument('--gen-fp8-meta',
action="store_true",
default=False,
help="indicate if generate corresponding FP8 meta."
" Only works when --direction=fw2te")
parser.add_argument(
'--amax-history-len',
type=int,
default=1024,
help="the length of amax history, which is only used when --gen-fp8-meta is specified.")

parser.add_argument('--pax-repeat',
action="store_true",
default=False,
@@ -129,7 +139,8 @@ def get_convert_helper(args):
def get_convert_helper(args):

model_config = ModelConfig(args.num_of_layer, args.embed_dim, args.num_of_head, args.head_dim,
args.mlp_intermediate_dim, args.kernel_chunk_size)
args.mlp_intermediate_dim, args.kernel_chunk_size,
args.amax_history_len)

convert_helper_cls = None

@@ -140,8 +151,8 @@ def get_convert_helper(args):
convert_helper_cls = T5X_CONVERT_HELPER_DICT[(args.direction, args.t5x_fuse_qkv)]

assert convert_helper_cls is not None, "Not Supported."
return convert_helper_cls(args.input_path, args.output_path, model_config,
args.weight_only, args.skip_ln)
return convert_helper_cls(args.input_path, args.output_path, model_config, args.weight_only,
args.skip_ln, args.gen_fp8_meta)


if __name__ == "__main__":
rosetta/utils/te_pax_t5x_ckpt_converter/converter/paxml_converters.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import jax.numpy as jnp

from utils import ConvertHelper
@@ -26,6 +25,10 @@ def catagories(self):
return ['mdl_vars.params']
return ['mdl_vars.params', "opt_states_0_2.m.params", "opt_states_0_2.v.params"]

@property
def fp8_meta_catagories(self):
return {'mdl_vars.params': 'mdl_vars.fp8_metas'}


class Pax2TEConvertHelper(PaxConvertHelperBase):

@@ -46,8 +49,11 @@ def _generate_ckpt_map(self):
f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1.linear.w":
self._get_convert_pkg(
f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wi_kernel",
(hidden_dim, mlp_intermediate_dim), 0,
lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1]))),
(hidden_dim, mlp_intermediate_dim),
0,
lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1])),
gen_fp8_meta=True,
fp8_meta_postfix='0'),
f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer2.bias.b":
self._get_convert_pkg(
f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wo_bias",
@@ -57,7 +63,10 @@ def _generate_ckpt_map(self):
f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer2.linear.w":
self._get_convert_pkg(
f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wo_kernel",
(mlp_intermediate_dim, hidden_dim), 1),
(mlp_intermediate_dim, hidden_dim),
1,
gen_fp8_meta=True,
fp8_meta_postfix='1'),
f"lm.transformer.x_layers_{i}.ff_layer.layer_norm.bias":
self._get_convert_pkg(
f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.ln_bias",
@@ -90,9 +99,12 @@ def _generate_ckpt_map(self):
f"lm.transformer.x_layers_{i}.self_attention.combined_qkv.w":
self._get_convert_pkg(
f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.qkv.kernel",
(3, hidden_dim, num_of_head, head_dim), 0,
(3, hidden_dim, num_of_head, head_dim),
0,
lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1])),
lambda x: jnp.transpose(x, (1, 0, 2))),
lambda x: jnp.transpose(x, (1, 0, 2)),
gen_fp8_meta=True,
fp8_meta_postfix='0'),
f"lm.transformer.x_layers_{i}.self_attention.post.b":
self._get_convert_pkg(
f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.out.bias",
@@ -102,9 +114,12 @@ def _generate_ckpt_map(self):
f"lm.transformer.x_layers_{i}.self_attention.post.w":
self._get_convert_pkg(
f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.out.kernel",
(hidden_dim, num_of_head, head_dim), 1,
(hidden_dim, num_of_head, head_dim),
1,
lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1])),
lambda x: jnp.transpose(x, (1, 0)))
lambda x: jnp.transpose(x, (1, 0)),
gen_fp8_meta=True,
fp8_meta_postfix='0')
})

return ckpt_map
@@ -199,6 +214,10 @@ def catagories(self):
f"opt_states_0.p#{num_of_layer}#i-1_2.v.params"
]

@property
def fp8_meta_catagories(self):
return {'mdl_vars.params': 'mdl_vars.fp8_metas'}


class Pax2TERepeatConvertHelper(PaxRepeatConvertHelperBase):

@@ -220,8 +239,12 @@ def _generate_ckpt_map(self):
'lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w':
self._get_convert_pkg(
'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wi_kernel',
(num_of_layer, hidden_dim, mlp_intermediate_dim), 1,
lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1]))),
(num_of_layer, hidden_dim, mlp_intermediate_dim),
1,
lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1])),
gen_fp8_meta=True,
fp8_meta_postfix='0',
fp8_meta_shape_prefix=(num_of_layer,)),
'lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.bias.b':
self._get_convert_pkg(
'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wo_bias',
@@ -231,7 +254,11 @@ def _generate_ckpt_map(self):
'lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.linear.w':
self._get_convert_pkg(
'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wo_kernel',
(num_of_layer, mlp_intermediate_dim, hidden_dim), 2),
(num_of_layer, mlp_intermediate_dim, hidden_dim),
2,
gen_fp8_meta=True,
fp8_meta_postfix='1',
fp8_meta_shape_prefix=(num_of_layer,)),
'lm.transformer.repeat.sub.x_layers_0.ff_layer.layer_norm.bias':
self._get_convert_pkg(
'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.ln_bias',
@@ -264,9 +291,13 @@ def _generate_ckpt_map(self):
'lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w':
self._get_convert_pkg(
'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.attention.qkv.kernel',
(num_of_layer, 3, hidden_dim, num_of_head, head_dim), 1,
(num_of_layer, 3, hidden_dim, num_of_head, head_dim),
1,
lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1])),
lambda x: jnp.transpose(x, (0, 2, 1, 3))),
lambda x: jnp.transpose(x, (0, 2, 1, 3)),
gen_fp8_meta=True,
fp8_meta_postfix='0',
fp8_meta_shape_prefix=(num_of_layer,)),
'lm.transformer.repeat.sub.x_layers_0.self_attention.post.b':
self._get_convert_pkg(
'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.attention.out.bias',
@@ -276,9 +307,13 @@ def _generate_ckpt_map(self):
'lm.transformer.repeat.sub.x_layers_0.self_attention.post.w':
self._get_convert_pkg(
'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.attention.out.kernel',
(num_of_layer, hidden_dim, num_of_head, head_dim), 2,
(num_of_layer, hidden_dim, num_of_head, head_dim),
2,
lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1])),
lambda x: jnp.transpose(x, (0, 2, 1)))
lambda x: jnp.transpose(x, (0, 2, 1)),
gen_fp8_meta=True,
fp8_meta_postfix='0',
fp8_meta_shape_prefix=(num_of_layer,))
})

return ckpt_map