Commit

fix and hub for cpu
yangsenius committed Mar 19, 2021
1 parent f9ae6e8 commit 904eb4b
Showing 2 changed files with 10 additions and 10 deletions.
README.md (8 changes: 4 additions & 4 deletions)
@@ -1,6 +1,6 @@
## Introduction

-**[TransPose](https://arxiv.org/abs/2012.14214)** is a human pose estimation model based on a CNN feature extractor, a Transformer Encoder, and a prediction head. Given an image, the attention layers built in Transformer can efficiently capture long-range spatial relationships between keypoints and explain what dependencies the predicted keypoints locations highly rely on.
+**[TransPose](https://arxiv.org/abs/2012.14214)** is a human pose estimation model based on a CNN feature extractor, a Transformer Encoder, and a prediction head. Given an image, the attention layers built in Transformer can capture long-range spatial relationships between keypoints and explain what dependencies the predicted keypoints locations highly rely on.

![Architecture](transpose_architecture.png)
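The paragraph above describes the overall pipeline: a CNN extracts a feature map, the Transformer Encoder attends over the flattened feature positions, and a prediction head produces keypoint heatmaps. Below is a minimal illustrative sketch of that data flow, not the repository's actual implementation; the stand-in backbone, channel sizes, and the 17-keypoint head are assumptions chosen only to make the shapes concrete.

```python
import torch
import torch.nn as nn

class TransPoseSketch(nn.Module):
    """Illustrative only: CNN backbone -> Transformer encoder -> heatmap head."""
    def __init__(self, backbone_channels=256, d_model=96, n_heads=1,
                 n_layers=4, n_keypoints=17):
        super().__init__()
        # Stand-in CNN feature extractor (the real model uses ResNet / HRNet variants).
        self.backbone = nn.Sequential(
            nn.Conv2d(3, backbone_channels, 3, stride=4, padding=1),
            nn.ReLU(),
        )
        self.reduce = nn.Conv2d(backbone_channels, d_model, 1)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=192)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        # Prediction head: one heatmap channel per keypoint.
        self.head = nn.Conv2d(d_model, n_keypoints, 1)

    def forward(self, x):
        f = self.reduce(self.backbone(x))        # B x d x H x W feature map
        b, d, h, w = f.shape
        seq = f.flatten(2).permute(2, 0, 1)      # (H*W) x B x d token sequence
        seq = self.encoder(seq)                  # attention over all positions
        f = seq.permute(1, 2, 0).reshape(b, d, h, w)
        return self.head(f)                      # B x K x H x W keypoint heatmaps

heatmaps = TransPoseSketch()(torch.rand(1, 3, 256, 192))
print(heatmaps.shape)
```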

@@ -18,9 +18,9 @@ We choose two types of CNNs as the backbone candidates: ResNet and HRNet. The de
| TransPose-H-A4 | HRNet-S-W48 | 4 | 96 | 192 | 1 | 17.3Mb | 77.5 | [model](https://github.com/yangsenius/TransPose/releases/download/Hub/tp_h_48_256x192_enc4_d96_h192_mh1.pth) |
| TransPose-H-A6 | HRNet-S-W48 | 6 | 96 | 192 | 1 | 17.5Mb | 78.1 | [model](https://github.com/yangsenius/TransPose/releases/download/Hub/tp_h_48_256x192_enc6_d96_h192_mh1.pth) |

-## News
+### News

-- [2021-3-19] ***TransPose-H-A6*** achieves **93.9%** accuracy on MPII test-dev set, with *256x192* input resolution. Details will be published.
+- [2021-3-19]: ***TransPose-H-A6*** achieves **93.9%** accuracy on MPII test set, with *256x256* input resolution. Details will be published.

### Quick use

@@ -68,7 +68,7 @@ Given an input image, a pretrained TransPose model, and the predicted locations,
`TransPose-H-A4` with `threshold=0.00`
![example](attention_map_image_dependency_transposeh_thres_0.0.jpg)

-`TransPose-H-A4` with `threshold=0.00075
+`TransPose-H-A4` with `threshold=0.00075`
![example](attention_map_image_dependency_transposeh_thres_0.00075.jpg)
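The images above visualize which image positions a predicted keypoint location depends on, keeping only attention scores above the chosen `threshold`. A hedged sketch of that kind of thresholding follows; the repository's actual visualization code may differ, and the function and argument names here are illustrative.

```python
import numpy as np

def dependency_mask(attention, keypoint_index, threshold=0.00075):
    """Keep only positions whose attention score toward the queried
    keypoint location exceeds `threshold` (illustrative sketch)."""
    # attention: (H*W, H*W) matrix from the last encoder layer; row i holds
    # the scores with which position i attends to every other position.
    scores = attention[keypoint_index]          # dependencies of one location
    return np.where(scores > threshold, scores, 0.0)

# Toy usage: random "attention" over a 48-position grid, query position 20.
attn = np.random.rand(48, 48)
attn /= attn.sum(axis=1, keepdims=True)
masked = dependency_mask(attn, keypoint_index=20)
print((masked > 0).sum(), "positions kept above the threshold")
```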

## Getting started
hubconf.py (12 changes: 6 additions & 6 deletions)
@@ -50,7 +50,7 @@ def tpr_a4_256x192(pretrained=False, **kwargs):
    if pretrained:
        if cfg.TEST.MODEL_FILE and osp.isfile(cfg.TEST.MODEL_FILE):
            print(">>Load pretrained weights from {}".format(cfg.TEST.MODEL_FILE))
-            pretrained_state_dict = torch.load(cfg.TEST.MODEL_FILE)
+            pretrained_state_dict = torch.load(cfg.TEST.MODEL_FILE, map_location=torch.device('cpu'))
            model.load_state_dict(pretrained_state_dict, strict=True)
        else:
            ### for pytorch 1.7 ###
@@ -64,9 +64,9 @@ def tpr_a4_256x192(pretrained=False, **kwargs):
            if not osp.isfile(local_path):
                torch.hub.download_url_to_file(
                    web_url, local_path, hash_prefix=None, progress=True)
-            checkpoint = torch.load(local_path)
+            checkpoint = torch.load(local_path, map_location=torch.device('cpu'))
            model.load_state_dict(checkpoint)
-            print("Successfully loaded pretrained weights!")
+            print("Successfully loaded model (on cpu) with pretrained weights!")
    return model
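The hunks above remap every loaded checkpoint to the CPU. A short standalone illustration of why `map_location` matters follows; the checkpoint path is a placeholder, not a file from this commit.

```python
import torch

# Without map_location, a checkpoint whose tensors were saved on a CUDA device
# tries to deserialize back onto CUDA and fails on CPU-only machines.
# Remapping every storage to the CPU makes the same file loadable anywhere.
state_dict = torch.load("some_checkpoint.pth",          # placeholder path
                        map_location=torch.device('cpu'))
```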


@@ -82,7 +82,7 @@ def tph_a4_256x192(pretrained=False, **kwargs):
    if pretrained:
        if cfg.TEST.MODEL_FILE and osp.isfile(cfg.TEST.MODEL_FILE):
            print(">>Load pretrained weights from {}".format(cfg.TEST.MODEL_FILE))
-            pretrained_state_dict = torch.load(cfg.TEST.MODEL_FILE)
+            pretrained_state_dict = torch.load(cfg.TEST.MODEL_FILE, map_location=torch.device('cpu'))
            model.load_state_dict(pretrained_state_dict, strict=True)
        else:
            ### for pytorch 1.7 ###
@@ -96,7 +96,7 @@ def tph_a4_256x192(pretrained=False, **kwargs):
            if not osp.isfile(local_path):
                torch.hub.download_url_to_file(
                    web_url, local_path, hash_prefix=None, progress=True)
-            checkpoint = torch.load(local_path)
+            checkpoint = torch.load(local_path, map_location=torch.device('cpu'))
            model.load_state_dict(checkpoint)
-            print("Successfully loaded pretrained weights!")
+            print("Successfully loaded model (on cpu) with pretrained weights!")
    return model
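With the same change applied to `tph_a4_256x192`, both hub entrypoints can be loaded on CPU-only machines. A hedged usage sketch follows; the `yangsenius/TransPose` repo string passed to `torch.hub.load` is an assumption, while the entrypoint name comes from `hubconf.py` above.

```python
import torch

# Hedged sketch: loading the TransPose-H-A4 entrypoint via torch.hub on a
# CPU-only machine. 'yangsenius/TransPose' is an assumed repo string;
# 'tph_a4_256x192' is the entrypoint defined in hubconf.py above.
model = torch.hub.load('yangsenius/TransPose', 'tph_a4_256x192', pretrained=True)
model.eval()

img = torch.rand(1, 3, 256, 192)      # dummy input at the expected resolution
with torch.no_grad():
    heatmaps = model(img)              # per-keypoint heatmaps
print(heatmaps.shape)
```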
