-
Notifications
You must be signed in to change notification settings - Fork 12
/
fitting.py
31 lines (24 loc) · 1.02 KB
/
fitting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os
import torch
import argparse
from config.config import config
from lib.LandmarkDataset import LandmarkDataset
from lib.Recorder import Recorder
from lib.Fitter import Fitter
from lib.face_models import get_face_model
from lib.Camera import Camera
def parse_args(argv=None):
    """Parse the command-line options of the fitting script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` whose ``config`` attribute is the path of
        the config file to load.
    """
    parser = argparse.ArgumentParser(
        description='Run the fitting pipeline described by a config file.')
    parser.add_argument('--config', type=str, default='config/sample_video.yaml',
                        help='path to the YAML config file '
                             '(default: config/sample_video.yaml)')
    return parser.parse_args(argv)


def select_device(gpu_id):
    """Make CUDA device ``gpu_id`` current and return it as a ``torch.device``.

    Args:
        gpu_id: Index of the CUDA device to use.

    Returns:
        ``torch.device('cuda:<gpu_id>')``.

    Raises:
        RuntimeError: If CUDA is unavailable or ``gpu_id`` does not name a
            visible CUDA device. Checking up front gives a clear message
            instead of an opaque failure deep inside torch.
    """
    if not torch.cuda.is_available():
        raise RuntimeError(
            f'CUDA is not available, but the config requests gpu_id={gpu_id}.')
    device_count = torch.cuda.device_count()
    if not 0 <= gpu_id < device_count:
        raise RuntimeError(
            f'gpu_id={gpu_id} is out of range: '
            f'{device_count} CUDA device(s) visible.')
    torch.cuda.set_device(gpu_id)
    return torch.device(f'cuda:{gpu_id}')


def main(argv=None):
    """Load the config, build every pipeline component and run the fitter.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`;
            ``None`` means use ``sys.argv``.
    """
    args = parse_args(argv)

    # The loader object and the config it yields are kept under separate
    # names so it is clear which one later code is reading from.
    cfg_loader = config()
    cfg_loader.load(args.config)
    cfg = cfg_loader.get_cfg()

    device = select_device(cfg.gpu_id)

    dataset = LandmarkDataset(landmark_folder=cfg.landmark_folder,
                              camera_folder=cfg.camera_folder)
    # The face model is sized to the whole dataset: batch_size == len(dataset),
    # i.e. one batch entry per dataset item.
    face_model = get_face_model(cfg.face_model, batch_size=len(dataset),
                                device=device)
    camera = Camera(image_size=cfg.image_size)
    recorder = Recorder(save_folder=cfg.param_folder, camera=camera,
                        visualize=cfg.visualize,
                        save_vertices=cfg.save_vertices)
    fitter = Fitter(cfg, dataset, face_model, camera, recorder, device)
    fitter.run()


if __name__ == '__main__':
    main()