
Optimize model device selection, add CPU inference support, and refactor geometry functions for the device type #84


Open · wants to merge 1 commit into main
211 changes: 134 additions & 77 deletions moge/model/v2.py
@@ -55,6 +55,8 @@ def __init__(self,
if scale_head is not None:
self.scale_head = MLP(**scale_head)

self._initialize_device_functions()

@property
def device(self) -> torch.device:
return next(self.parameters()).device
@@ -133,33 +135,50 @@ def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
base_h, base_w = int((num_tokens / aspect_ratio) ** 0.5), int((num_tokens * aspect_ratio) ** 0.5)
num_tokens = base_h * base_w

# Convert image to numpy when not running on CUDA
if device.type != 'cuda':
image_np = image.cpu().numpy()

# Backbones encoding
features, cls_token = self.encoder(image, base_h, base_w, return_class_token=True)
features = [features, None, None, None, None]

# Concat UVs for aspect ratio input
# Concat UVs for aspect ratio input, using the device-specific function
for level in range(5):
uv = normalized_view_plane_uv(width=base_w * 2 ** level, height=base_h * 2 ** level, aspect_ratio=aspect_ratio, dtype=dtype, device=device)
uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1)
if features[level] is None:
features[level] = uv
if device.type == 'cuda':
uv = self.geometry['normalized_view_plane_uv'](
width=base_w * 2 ** level,
height=base_h * 2 ** level,
aspect_ratio=aspect_ratio,
dtype=dtype,
device=device
)
else:
features[level] = torch.concat([features[level], uv], dim=1)
uv = torch.from_numpy(
self.geometry['normalized_view_plane_uv'](
width=base_w * 2 ** level,
height=base_h * 2 ** level,
aspect_ratio=aspect_ratio
)
).to(device)

# Expand UV over the batch and concat; levels without backbone features get the UV map alone
uv = uv.permute(2, 0, 1)[None].expand(batch_size, -1, -1, -1)  # (B, 2, H, W) to match the feature layout
features[level] = uv if features[level] is None else torch.cat([features[level], uv], dim=1)

# Process features through heads
features = self.neck(features)
points = self.points_head(features)[0] if self.points_head is not None else None
normal = self.normal_head(features)[0] if self.normal_head is not None else None
mask = self.mask_head(features)[0] if self.mask_head is not None else None
metric_scale = self.scale_head(cls_token)[0] if self.scale_head is not None else None

# Heads decoding
points, normal, mask = (getattr(self, head)(features)[-1] if hasattr(self, head) else None for head in ['points_head', 'normal_head', 'mask_head'])
metric_scale = self.scale_head(cls_token) if hasattr(self, 'scale_head') else None

# Resize
# Resize to original resolution
points, normal, mask = (F.interpolate(v, (img_h, img_w), mode='bilinear', align_corners=False, antialias=False) if v is not None else None for v in [points, normal, mask])

# Remap output
# Handle output based on device
if points is not None:
points = points.permute(0, 2, 3, 1)
points = self._remap_points(points) # slightly improves the performance in case of very large output values
points = self._remap_points(points)
if normal is not None:
normal = normal.permute(0, 2, 3, 1)
normal = F.normalize(normal, dim=-1)
@@ -169,13 +188,12 @@ def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
metric_scale = metric_scale.squeeze(1).exp()

return_dict = {
'points': points,
'normal': normal,
'mask': mask,
'metric_scale': metric_scale
}
return_dict = {k: v for k, v in return_dict.items() if v is not None}


return return_dict

@torch.inference_mode()
@@ -208,83 +226,122 @@ def infer(
- `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
- `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
"""
if image.dim() == 3:
if image.ndim == 3:
omit_batch_dim = True
image = image.unsqueeze(0)
else:
omit_batch_dim = False
image = image.to(dtype=self.dtype, device=self.device)

# Convert to appropriate device and dtype
image = image.to(device=self.device, dtype=self.dtype)
use_fp16 = use_fp16 and self.device.type == 'cuda' # Only use fp16 on CUDA devices

original_height, original_width = image.shape[-2:]
area = original_height * original_width
aspect_ratio = original_width / original_height

# Determine the number of base tokens to use

if num_tokens is None:
min_tokens, max_tokens = self.num_tokens_range
num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))
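# resolution_level 0 selects min_tokens and 9 selects max_tokens; values above 9 (e.g. 'Ultra' = 30 in app.py) extrapolate beyond max_tokens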

# Forward pass
with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=use_fp16 and self.dtype != torch.float16):
# Forward pass with appropriate precision
with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=use_fp16):
output = self.forward(image, num_tokens=num_tokens)

points, normal, mask, metric_scale = (output.get(k, None) for k in ['points', 'normal', 'mask', 'metric_scale'])

# Always process the output in fp32 precision
points, normal, mask, metric_scale, fov_x = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [points, normal, mask, metric_scale, fov_x])
with torch.autocast(device_type=self.device.type, dtype=torch.float32):
if mask is not None:
mask_binary = mask > 0.5
else:
mask_binary = None

if points is not None:
# Convert affine point map to camera-space. Recover depth and intrinsics from point map.
# NOTE: Focal here is the focal length relative to half the image diagonal
if fov_x is None:
# Recover focal and shift from predicted point map
focal, shift = recover_focal_shift(points, mask_binary)
# Process output in fp32 precision
if points is not None:
points = points.float()
if normal is not None:
normal = normal.float()
if mask is not None:
mask = mask.float()
mask_binary = mask > 0.5
else:
mask_binary = None
if metric_scale is not None:
metric_scale = metric_scale.float()
if isinstance(fov_x, torch.Tensor):
fov_x = fov_x.float()

# Process points and compute camera parameters
if points is not None:
# Convert to numpy for CPU operations if needed
if self.device.type != 'cuda':
points_np = points.cpu().numpy()
mask_binary_np = mask_binary.cpu().numpy() if mask_binary is not None else None

# Handle focal length and FOV
if fov_x is None:
if self.device.type == 'cuda':
focal = (1 + aspect_ratio ** 2) ** -0.5 / (points[..., 0].std(-1).std(-1) + 1e-5)
else:
# Focal is known, recover shift only
focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2))
if focal.ndim == 0:
focal = focal[None].expand(points.shape[0])
_, shift = recover_focal_shift(points, mask_binary, focal=focal)
fx, fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio, focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
points[..., 2] += shift[..., None, None]
if mask_binary is not None:
mask_binary &= points[..., 2] > 0 # in case depth contains negative values (which should never happen in practice)
depth = points[..., 2].clone()
focal = (1 + aspect_ratio ** 2) ** -0.5 / (np.std(points_np[..., 0]) + 1e-5)
else:
depth, intrinsics = None, None
focal = 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2))
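# Pinhole relation: tan(fov_x / 2) = 0.5 / focal, with focal in width-normalized units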

# If projection constraint is forced, recompute the point map using the actual depth map & intrinsics
if force_projection and depth is not None:
points = utils3d.torch.depth_to_points(depth, intrinsics=intrinsics)
# Convert scalar focal to tensor if needed
if not isinstance(focal, torch.Tensor):
focal = torch.tensor(focal, device=points.device, dtype=points.dtype)
if focal.ndim == 0:
focal = focal[None].expand(points.shape[0])

# Apply metric scale
if metric_scale is not None:
if points is not None:
points *= metric_scale[:, None, None, None]
if depth is not None:
depth *= metric_scale[:, None, None]
# Build camera intrinsics
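# cx = cy = 0.5 places the principal point at the image center in normalized coordinates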
fx = focal * aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
fy = focal / (1 + aspect_ratio ** 2) ** 0.5
intrinsics = torch.zeros((*points.shape[:-3], 3, 3), device=points.device, dtype=points.dtype)
intrinsics[..., 0, 0] = fx
intrinsics[..., 1, 1] = fy
intrinsics[..., 0, 2] = intrinsics[..., 1, 2] = 0.5
intrinsics[..., 2, 2] = 1

# Apply mask
if apply_mask and mask_binary is not None:
points = torch.where(mask_binary[..., None], points, torch.inf) if points is not None else None
depth = torch.where(mask_binary, depth, torch.inf) if depth is not None else None
normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal)) if normal is not None else None

return_dict = {
'points': points,
'intrinsics': intrinsics,
'depth': depth,
'mask': mask_binary,
'normal': normal
}
return_dict = {k: v for k, v in return_dict.items() if v is not None}

if omit_batch_dim:
return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}
# Process depth
if force_projection:
if self.device.type == 'cuda':
depth = self.geometry['points_to_depth'](points)
points = self.geometry['depth_to_points'](depth, intrinsics=intrinsics)
else:
depth = torch.from_numpy(self.geometry['points_to_depth'](points_np)).to(self.device)
points = torch.from_numpy(
self.geometry['depth_to_points'](depth.cpu().numpy(), intrinsics=intrinsics.cpu().numpy())
).to(self.device)
else:
depth = points[..., 2]
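# Without forced projection, depth is taken directly as the z channel of the point map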

# Assemble output dictionary
return_dict = {}
for k, v in [('points', points), ('depth', depth), ('normal', normal),
('mask', mask if apply_mask else None), ('intrinsics', intrinsics)]:
if v is not None:
if omit_batch_dim:
v = v.squeeze(0)
return_dict[k] = v

return return_dict

def _initialize_device_functions(self):
"""Initialize device-specific geometry functions."""
if self.device.type == 'cuda':
from ..utils.geometry_torch import (
normalized_view_plane_uv,
depth_to_points,
points_to_depth,
points_to_normals,
gaussian_blur_2d
)
else:
from ..utils.geometry_cpu import (
normalized_view_plane_uv_cpu as normalized_view_plane_uv,
depth_to_points_cpu as depth_to_points,
points_to_depth_cpu as points_to_depth,
points_to_normals_cpu as points_to_normals,
gaussian_blur_2d_cpu as gaussian_blur_2d
)

self.geometry = {
'normalized_view_plane_uv': normalized_view_plane_uv,
'depth_to_points': depth_to_points,
'points_to_depth': points_to_depth,
'points_to_normals': points_to_normals,
'gaussian_blur_2d': gaussian_blur_2d
}
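
For reference, a minimal standalone sketch of the dispatch pattern above, with hypothetical simplified stand-ins for the geometry_torch / geometry_cpu functions (the real ones live under moge/utils/). Resolving the table per call, as below, is one way to keep it consistent even if the model is later moved with .to(device), since _initialize_device_functions only runs once in __init__:

    from typing import Callable, Dict
    import numpy as np
    import torch

    def torch_uv(width: int, height: int, device=None, dtype=None) -> torch.Tensor:
        # Hypothetical simplified stand-in for geometry_torch.normalized_view_plane_uv
        u = torch.linspace(-1, 1, width, device=device, dtype=dtype)
        v = torch.linspace(-1, 1, height, device=device, dtype=dtype)
        return torch.stack(torch.meshgrid(v, u, indexing='ij'), dim=-1)  # (H, W, 2)

    def numpy_uv(width: int, height: int) -> np.ndarray:
        # Hypothetical simplified stand-in for geometry_cpu.normalized_view_plane_uv_cpu
        u = np.linspace(-1, 1, width, dtype=np.float32)
        v = np.linspace(-1, 1, height, dtype=np.float32)
        return np.stack(np.meshgrid(v, u, indexing='ij'), axis=-1)  # (H, W, 2)

    def get_geometry(device: torch.device) -> Dict[str, Callable]:
        # Re-resolve on every call so the table follows the model's current device
        if device.type == 'cuda':
            return {'normalized_view_plane_uv': torch_uv}
        return {'normalized_view_plane_uv': numpy_uv}

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    uv = get_geometry(device)['normalized_view_plane_uv'](width=64, height=48)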
19 changes: 16 additions & 3 deletions moge/scripts/app.py
@@ -51,8 +51,10 @@ def main(share: bool, pretrained_model_name_or_path: str, model_version: str, us
"v2": "Ruicheng/moge-2-vitl-normal",
}
pretrained_model_name_or_path = DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION[model_version]
model = import_model_class_by_version(model_version).from_pretrained(pretrained_model_name_or_path).cuda().eval()
if use_fp16:
# Automatically select the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = import_model_class_by_version(model_version).from_pretrained(pretrained_model_name_or_path).to(device).eval()
if use_fp16 and device.type == 'cuda':
model.half()
thread_pool_executor = ThreadPoolExecutor(max_workers=1)

@@ -76,6 +78,13 @@ def run_with_gpu(image: np.ndarray, resolution_level: int, apply_mask: bool) ->
output = {k: v.cpu().numpy() for k, v in output.items()}
return output

# Inference on CPU
def run_with_cpu(image: np.ndarray, resolution_level: int, apply_mask: bool) -> Dict[str, np.ndarray]:
image_tensor = torch.tensor(image, dtype=torch.float32, device=torch.device('cpu')).permute(2, 0, 1) / 255
output = model.infer(image_tensor, apply_mask=apply_mask, resolution_level=resolution_level, use_fp16=False)
output = {k: v.cpu().numpy() for k, v in output.items()}
return output

# Full inference pipeline
def run(image: np.ndarray, max_size: int = 800, resolution_level: str = 'High', apply_mask: bool = True, remove_edge: bool = True, request: gr.Request = None):
larger_size = max(image.shape[:2])
@@ -86,7 +95,11 @@ def run(image: np.ndarray, max_size: int = 800, resolution_level: str = 'High',
height, width = image.shape[:2]

resolution_level_int = {'Low': 0, 'Medium': 5, 'High': 9, 'Ultra': 30}.get(resolution_level, 9)
output = run_with_gpu(image, resolution_level_int, apply_mask)
# Check which device the model is on
if model.device.type == 'cuda':
output = run_with_gpu(image, resolution_level_int, apply_mask)
else:
output = run_with_cpu(image, resolution_level_int, apply_mask)

points, depth, mask, normal = output['points'], output['depth'], output['mask'], output.get('normal', None)

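
A minimal usage sketch of the new CPU-compatible path, mirroring run_with_cpu above. It assumes the Ruicheng/moge-2-vitl-normal checkpoint and this script's import_model_class_by_version helper; the zero image is only a placeholder input:

    import numpy as np
    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = import_model_class_by_version('v2').from_pretrained('Ruicheng/moge-2-vitl-normal').to(device).eval()

    image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder H x W x 3 RGB image
    image_tensor = torch.tensor(image, dtype=torch.float32, device=device).permute(2, 0, 1) / 255
    output = model.infer(image_tensor, apply_mask=True, resolution_level=9, use_fp16=(device.type == 'cuda'))
    output = {k: v.cpu().numpy() for k, v in output.items()}  # points, depth, mask, intrinsics (and normal if predicted)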