Support get/set the whole row of metaheader+weight+optimizer from backend for checkpoint saving/loading #3148

Open · wants to merge 1 commit into main

Conversation

bobbyliujb

Summary:

Context

In our current KVZCH checkpoint loading flow, we hold the weight_id, weight, and optimizer tensors in memory for the entire checkpoint loading lifecycle; only once all of these tensors have been downloaded do we explicitly call "apply_state_dict" to write them to the backend chunk by chunk, which guarantees that id->weight and id->opt are mapped correctly. The problem is that with a large number of weights we run short of memory, since all three tensors must be held at once (a double-memory issue). To solve this, we save the whole row (metaheader + weight + opt) as a single "weight" tensor during checkpoint saving; when downloading the checkpoint we can then extract the id from the metaheader and write the weight+opt portion directly to the backend by id. When loading the checkpoint for the optimizer, we added a no-op KVTensor, so optimizer states do not need to be written to the backend a second time.
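For illustration only, here is a minimal sketch of the whole-row flow described above. The metaheader width, the id encoding, and the `write_by_id` backend call are assumptions made for this sketch, not the actual KVZCH row layout or API:

```python
import torch

# Assumed row layout: [metaheader | weight | optimizer_state], one byte tensor
# per row. We assume the row id is a little-endian 64-bit integer stored at
# the start of the metaheader; the real KVZCH metaheader format may differ.
METAHEADER_BYTES = 16
ID_BYTES = 8


def split_whole_row(row: torch.Tensor):
    """Extract (row_id, weight+opt payload) from one checkpointed row."""
    assert row.dtype == torch.uint8, "sketch operates on a raw byte view"
    row_id = int.from_bytes(bytes(row[:ID_BYTES].tolist()), byteorder="little")
    payload = row[METAHEADER_BYTES:]  # weight + optimizer_state, passed through
    return row_id, payload


def stream_rows_to_backend(rows: torch.Tensor, backend) -> None:
    """Write each row to the backend as it is read, instead of buffering
    separate weight_id, weight, and optimizer tensors for apply_state_dict."""
    for row in rows:
        row_id, payload = split_whole_row(row)
        # Hypothetical backend call: writes weight+opt for the given id, so
        # the metaheader is skipped on the write path.
        backend.write_by_id(row_id, payload)
```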

This diff

  • Added a `backend_return_whole_row` flag to the KVZCH params, with validation to ensure it can only be True when optimizer offloading is used (see the sketch after this list).
  • Added a `read_only_` flag to KVTensorWrapper for checkpoint calls; when read-only is True, all write operations on the KVT become no-ops.
  • Added a metadata recalculation for the optimizer state dict: since we now return a read-only KVT for the optimizer state dict, the model store needs to correct the global metadata before creating the save plan for the KVZCH optimizer tensors.
  • Updated the DRAM backend and memory pool so they can return metaheader + weight + optimizer_state together and set them back to the backend (using pointer offsets to skip the metaheader when writing weight+opt back).
  • Optimizer offloading and return-whole-row both default to False on trunk, so existing KVZCH runs should not break.
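
The first two bullets roughly translate to the sketch below. Class and method names are simplified stand-ins (the real KVZCH params and KVTensorWrapper carry more fields and partly live in C++), so treat this as an assumption-laden illustration rather than the actual interfaces:

```python
from dataclasses import dataclass

import torch


@dataclass
class KVZCHParamsSketch:
    """Illustrative stand-in for the KVZCH params; only the two flags that
    matter for this sketch are modeled."""

    enable_optimizer_offloading: bool = False
    backend_return_whole_row: bool = False

    def validate(self) -> None:
        # Returning the whole row only makes sense when the optimizer state is
        # offloaded to the backend alongside the weights.
        if self.backend_return_whole_row and not self.enable_optimizer_offloading:
            raise ValueError(
                "backend_return_whole_row=True requires optimizer offloading"
            )


class ReadOnlyKVTSketch:
    """Sketch of the read-only checkpoint wrapper: reads pass through to the
    wrapped tensor, writes silently become no-ops."""

    def __init__(self, inner: torch.Tensor, read_only: bool = True) -> None:
        self._inner = inner
        self._read_only = read_only

    def narrow(self, dim: int, start: int, length: int) -> torch.Tensor:
        return self._inner.narrow(dim, start, length)

    def set_range(self, dim: int, start: int, length: int, src: torch.Tensor) -> None:
        if self._read_only:
            return  # checkpoint load must not rewrite offloaded optimizer state
        self._inner.narrow(dim, start, length).copy_(src)
```

With a wrapper like this, the optimizer state dict can go through checkpoint loading without risking a second write of optimizer state to the backend, which is why the metadata recalculation in the third bullet is needed before the save plan is created.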

Differential Revision: D77604158

Privacy Context Container: L1138451

@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Jul 1, 2025
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D77604158

bobbyliujb pushed several commits referencing this pull request to bobbyliujb/torchrec and bobbyliujb/FBGEMM-1 on Jul 1, 2025 (Pull Request resolved: pytorch#3148 and pytorch#4429; X-link: facebookresearch/FBGEMM#1495). Each commit message repeats the summary above.
bobbyliujb pushed two further commits to bobbyliujb/FBGEMM-1 that referenced this pull request on Jul 2, 2025 (pytorch#4429; X-links: facebookresearch/FBGEMM#1495, pytorch/torchrec#3148). Their messages note that the FBGEMM side of this diff only contains the backend change: the DRAM backend and memory pool updates that return metaheader + weight + optimizer_state together and set them back to the backend while skipping the metaheader. Reviewed By: emlin. Differential Revision: D77604158
Labels: CLA Signed, fb-exported
2 participants