diff --git a/moge/model/v2.py b/moge/model/v2.py
index eee351c..eeaf3c8 100644
--- a/moge/model/v2.py
+++ b/moge/model/v2.py
@@ -55,6 +55,8 @@ def __init__(self,
         if scale_head is not None:
             self.scale_head = MLP(**scale_head)
 
+        self._initialize_device_functions()
+
     @property
     def device(self) -> torch.device:
         return next(self.parameters()).device
@@ -133,33 +135,50 @@ def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
         base_h, base_w = int((num_tokens / aspect_ratio) ** 0.5), int((num_tokens * aspect_ratio) ** 0.5)
         num_tokens = base_h * base_w
 
+        # Refresh the cached geometry functions if the model was moved to another device after __init__
+        if getattr(self, '_geometry_device', None) != device.type:
+            self._initialize_device_functions()
+
         # Backbones encoding
         features, cls_token = self.encoder(image, base_h, base_w, return_class_token=True)
         features = [features, None, None, None, None]
 
-        # Concat UVs for aspect ratio input
+        # Concat UVs for aspect ratio input - use the device-specific function
         for level in range(5):
-            uv = normalized_view_plane_uv(width=base_w * 2 ** level, height=base_h * 2 ** level, aspect_ratio=aspect_ratio, dtype=dtype, device=device)
-            uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1)
-            if features[level] is None:
-                features[level] = uv
-            else:
-                features[level] = torch.concat([features[level], uv], dim=1)
+            if device.type == 'cuda':
+                uv = self.geometry['normalized_view_plane_uv'](
+                    width=base_w * 2 ** level,
+                    height=base_h * 2 ** level,
+                    aspect_ratio=aspect_ratio,
+                    dtype=dtype,
+                    device=device
+                )
+            else:
+                uv = torch.from_numpy(
+                    self.geometry['normalized_view_plane_uv'](
+                        width=base_w * 2 ** level,
+                        height=base_h * 2 ** level,
+                        aspect_ratio=aspect_ratio
+                    )
+                ).to(device=device, dtype=dtype)
+            # Levels without backbone features receive the plain UV map
+            uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1)
+            if features[level] is None:
+                features[level] = uv
+            else:
+                features[level] = torch.cat([features[level], uv], dim=1)
 
         # Shared neck
         features = self.neck(features)
 
         # Heads decoding
-        points, normal, mask = (getattr(self, head)(features)[-1] if hasattr(self, head) else None for head in ['points_head', 'normal_head', 'mask_head'])
+        points = self.points_head(features)[-1] if hasattr(self, 'points_head') else None
+        normal = self.normal_head(features)[-1] if hasattr(self, 'normal_head') else None
+        mask = self.mask_head(features)[-1] if hasattr(self, 'mask_head') else None
         metric_scale = self.scale_head(cls_token) if hasattr(self, 'scale_head') else None
 
-        # Resize
+        # Resize to the original resolution
         points, normal, mask = (F.interpolate(v, (img_h, img_w), mode='bilinear', align_corners=False, antialias=False) if v is not None else None for v in [points, normal, mask])
 
         # Remap output
         if points is not None:
             points = points.permute(0, 2, 3, 1)
             points = self._remap_points(points)    # slightly improves the performance in case of very large output values
         if normal is not None:
             normal = normal.permute(0, 2, 3, 1)
             normal = F.normalize(normal, dim=-1)
@@ -169,13 +188,12 @@ def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
             metric_scale = metric_scale.squeeze(1).exp()
 
         return_dict = {
-            'points': points, 
+            'points': points,
             'normal': normal,
             'mask': mask,
             'metric_scale': metric_scale
         }
-        return_dict = {k: v for k, v in return_dict.items() if v is not None}
-
+
         return return_dict
 
     @torch.inference_mode()
@@ -208,83 +226,122 @@ def infer(
            - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
            - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
        """
-        if image.dim() == 3:
+        if image.ndim == 3:
             omit_batch_dim = True
             image = image.unsqueeze(0)
         else:
             omit_batch_dim = False
-        image = image.to(dtype=self.dtype, device=self.device)
+
+        # Convert to the appropriate device and dtype
+        image = image.to(device=self.device, dtype=self.dtype)
+        use_fp16 = use_fp16 and self.device.type == 'cuda'    # only use fp16 on CUDA devices
 
         original_height, original_width = image.shape[-2:]
-        area = original_height * original_width
         aspect_ratio = original_width / original_height
 
         # Determine the number of base tokens to use
         if num_tokens is None:
             min_tokens, max_tokens = self.num_tokens_range
             num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))
 
-        # Forward pass
+        # Forward pass with the appropriate precision
         with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=use_fp16 and self.dtype != torch.float16):
             output = self.forward(image, num_tokens=num_tokens)
+        points, normal, mask, metric_scale = (output.get(k, None) for k in ['points', 'normal', 'mask', 'metric_scale'])
 
         # Always process the output in fp32 precision
-        points, normal, mask, metric_scale, fov_x = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [points, normal, mask, metric_scale, fov_x])
-        with torch.autocast(device_type=self.device.type, dtype=torch.float32):
-            if mask is not None:
-                mask_binary = mask > 0.5
-            else:
-                mask_binary = None
-
-            if points is not None:
-                # Convert affine point map to camera-space. Recover depth and intrinsics from point map.
-                # NOTE: Focal here is the focal length relative to half the image diagonal
-                if fov_x is None:
-                    # Recover focal and shift from predicted point map
-                    focal, shift = recover_focal_shift(points, mask_binary)
-                else:
-                    # Focal is known, recover shift only
-                    focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2))
-                    if focal.ndim == 0:
-                        focal = focal[None].expand(points.shape[0])
-                    _, shift = recover_focal_shift(points, mask_binary, focal=focal)
-                fx, fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio, focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
-                intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
-                points[..., 2] += shift[..., None, None]
-                if mask_binary is not None:
-                    mask_binary &= points[..., 2] > 0    # in case depth contains negative values (which should never happen in practice)
-                depth = points[..., 2].clone()
-            else:
-                depth, intrinsics = None, None
-
-        # If projection constraint is forced, recompute the point map using the actual depth map & intrinsics
-        if force_projection and depth is not None:
-            points = utils3d.torch.depth_to_points(depth, intrinsics=intrinsics)
+        points = points.float() if points is not None else None
+        normal = normal.float() if normal is not None else None
+        metric_scale = metric_scale.float() if metric_scale is not None else None
+        fov_x = fov_x.float() if isinstance(fov_x, torch.Tensor) else fov_x
+        if mask is not None:
+            mask = mask.float()
+            mask_binary = mask > 0.5
+        else:
+            mask_binary = None
+
+        # Recover depth and intrinsics from the affine point map
+        if points is not None:
+            # Convert once to numpy for the CPU geometry functions
+            if self.device.type != 'cuda':
+                points_np = points.cpu().numpy()
+
+            # NOTE: focal here is the focal length relative to half the image diagonal
+            if fov_x is None:
+                if self.device.type == 'cuda':
+                    focal = (1 + aspect_ratio ** 2) ** -0.5 / (points[..., 0].std(-1).std(-1) + 1e-5)
+                else:
+                    focal = (1 + aspect_ratio ** 2) ** -0.5 / (points_np[..., 0].std() + 1e-5)
+            else:
+                focal = 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2))
+
+            # Convert a scalar focal to a batched tensor if needed
+            if not isinstance(focal, torch.Tensor):
+                focal = torch.tensor(focal, device=points.device, dtype=points.dtype)
+            if focal.ndim == 0:
+                focal = focal[None].expand(points.shape[0])
+
+            # Build normalized camera intrinsics
+            fx = focal * aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
+            fy = focal / (1 + aspect_ratio ** 2) ** 0.5
+            intrinsics = torch.zeros((*points.shape[:-3], 3, 3), device=points.device, dtype=points.dtype)
+            intrinsics[..., 0, 0] = fx
+            intrinsics[..., 1, 1] = fy
+            intrinsics[..., 0, 2] = intrinsics[..., 1, 2] = 0.5
+            intrinsics[..., 2, 2] = 1
+
+            # If the projection constraint is forced, recompute the point map from depth & intrinsics
+            if force_projection:
+                if self.device.type == 'cuda':
+                    depth = self.geometry['points_to_depth'](points)
+                    points = self.geometry['depth_to_points'](depth, intrinsics=intrinsics)
+                else:
+                    depth = torch.from_numpy(self.geometry['points_to_depth'](points_np)).to(self.device)
+                    points = torch.from_numpy(
+                        self.geometry['depth_to_points'](depth.cpu().numpy(), intrinsics=intrinsics.cpu().numpy())
+                    ).to(self.device)
+            else:
+                depth = points[..., 2]
+        else:
+            depth, intrinsics = None, None
 
         # Apply metric scale
         if metric_scale is not None:
             if points is not None:
                 points *= metric_scale[:, None, None, None]
             if depth is not None:
                 depth *= metric_scale[:, None, None]
 
         # Apply mask
         if apply_mask and mask_binary is not None:
             points = torch.where(mask_binary[..., None], points, torch.inf) if points is not None else None
             depth = torch.where(mask_binary, depth, torch.inf) if depth is not None else None
             normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal)) if normal is not None else None
 
-        return_dict = {
-            'points': points,
-            'intrinsics': intrinsics,
-            'depth': depth,
-            'mask': mask_binary,
-            'normal': normal
-        }
-        return_dict = {k: v for k, v in return_dict.items() if v is not None}
-
-        if omit_batch_dim:
-            return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}
+        # Assemble the output dictionary, dropping absent entries
+        return_dict = {}
+        for k, v in [('points', points), ('depth', depth), ('normal', normal),
+                     ('mask', mask_binary), ('intrinsics', intrinsics)]:
+            if v is not None:
+                if omit_batch_dim:
+                    v = v.squeeze(0)
+                return_dict[k] = v
 
         return return_dict
+
+    def _initialize_device_functions(self):
+        """Select the geometry functions matching the model's current device."""
+        if self.device.type == 'cuda':
+            from ..utils.geometry_torch import (
+                normalized_view_plane_uv,
+                depth_to_points,
+                points_to_depth,
+                points_to_normals,
+                gaussian_blur_2d
+            )
+        else:
+            from ..utils.geometry_cpu import (
+                normalized_view_plane_uv_cpu as normalized_view_plane_uv,
+                depth_to_points_cpu as depth_to_points,
+                points_to_depth_cpu as points_to_depth,
+                points_to_normals_cpu as points_to_normals,
+                gaussian_blur_2d_cpu as gaussian_blur_2d
+            )
+
+        self._geometry_device = self.device.type
+        self.geometry = {
+            'normalized_view_plane_uv': normalized_view_plane_uv,
+            'depth_to_points': depth_to_points,
+            'points_to_depth': points_to_depth,
+            'points_to_normals': points_to_normals,
+            'gaussian_blur_2d': gaussian_blur_2d
+        }
diff --git a/moge/scripts/app.py b/moge/scripts/app.py
index ba66024..208d502 100644
--- a/moge/scripts/app.py
+++ b/moge/scripts/app.py
@@ -51,8 +51,10 @@ def main(share: bool, pretrained_model_name_or_path: str, model_version: str, use_fp16: bool):
         "v2": "Ruicheng/moge-2-vitl-normal",
     }
     pretrained_model_name_or_path = DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION[model_version]
-    model = import_model_class_by_version(model_version).from_pretrained(pretrained_model_name_or_path).cuda().eval()
-    if use_fp16:
+    # Automatically select the device
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model = import_model_class_by_version(model_version).from_pretrained(pretrained_model_name_or_path).to(device).eval()
+    if use_fp16 and device.type == 'cuda':
         model.half()
 
     thread_pool_executor = ThreadPoolExecutor(max_workers=1)
@@ -76,6 +78,13 @@ def run_with_gpu(image: np.ndarray, resolution_level: int, apply_mask: bool) -> Dict[str, np.ndarray]:
         output = {k: v.cpu().numpy() for k, v in output.items()}
         return output
 
+    # Inference on CPU
+    def run_with_cpu(image: np.ndarray, resolution_level: int, apply_mask: bool) -> Dict[str, np.ndarray]:
+        image_tensor = torch.tensor(image, dtype=torch.float32, device=torch.device('cpu')).permute(2, 0, 1) / 255
+        output = model.infer(image_tensor, apply_mask=apply_mask, resolution_level=resolution_level, use_fp16=False)
+        output = {k: v.cpu().numpy() for k, v in output.items()}
+        return output
+
     # Full inference pipeline
     def run(image: np.ndarray, max_size: int = 800, resolution_level: str = 'High', apply_mask: bool = True, remove_edge: bool = True, request: gr.Request = None):
         larger_size = max(image.shape[:2])
@@ -86,7 +95,11 @@ def run(image: np.ndarray, max_size: int = 800, resolution_level: str = 'High', apply_mask: bool = True, remove_edge: bool = True, request: gr.Request = None):
         height, width = image.shape[:2]
 
         resolution_level_int = {'Low': 0, 'Medium': 5, 'High': 9, 'Ultra': 30}.get(resolution_level, 9)
-        output = run_with_gpu(image, resolution_level_int, apply_mask)
+        # Dispatch on the model's current device
+        if model.device.type == 'cuda':
+            output = run_with_gpu(image, resolution_level_int, apply_mask)
+        else:
+            output = run_with_cpu(image, resolution_level_int, apply_mask)
 
         points, depth, mask, normal = output['points'], output['depth'], output['mask'], output.get('normal', None)
diff --git a/moge/utils/geometry_cpu.py b/moge/utils/geometry_cpu.py
new file mode 100644
index 0000000..80b6f0b
--- /dev/null
+++ b/moge/utils/geometry_cpu.py
@@ -0,0 +1,103 @@
+from typing import *
+import numpy as np
+
+
+def normalized_view_plane_uv_cpu(width: int, height: int, aspect_ratio: float = None) -> np.ndarray:
+    """CPU (NumPy) counterpart of normalized_view_plane_uv."""
+    if aspect_ratio is None:
+        aspect_ratio = width / height
+
+    # Match the diagonal normalization used by the torch version and by infer():
+    # UV spans are scaled so the half-diagonal of the view plane has length 1.
+    span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
+    span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
+
+    u = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=np.float32)
+    v = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=np.float32)
+    u, v = np.meshgrid(u, v)
+    return np.stack([u, v], axis=-1)
+
+
+def depth_to_points_cpu(depth: np.ndarray, intrinsics: np.ndarray = None) -> np.ndarray:
+    """Unproject a depth map to 3D points in camera space (CPU version)."""
+    height, width = depth.shape[-2:]
+    if intrinsics is None:
+        # Default normalized intrinsics: unit focal, principal point at the image center
+        focal_x = focal_y = 1.0
+        center_x = center_y = 0.5
+    else:
+        # Keep trailing singleton axes so batched intrinsics broadcast over (H, W)
+        focal_x, focal_y = intrinsics[..., 0, 0, None, None], intrinsics[..., 1, 1, None, None]
+        center_x, center_y = intrinsics[..., 0, 2, None, None], intrinsics[..., 1, 2, None, None]
+
+    y, x = np.meshgrid(
+        np.arange(height, dtype=np.float32),
+        np.arange(width, dtype=np.float32),
+        indexing='ij'
+    )
+
+    # Intrinsics use normalized [0, 1] image coordinates, sampled at pixel centers
+    x = ((x + 0.5) / width - center_x) / focal_x
+    y = ((y + 0.5) / height - center_y) / focal_y
+
+    points = np.stack(np.broadcast_arrays(x * depth, y * depth, depth), axis=-1)
+    return points
+
+
+def points_to_depth_cpu(points: np.ndarray) -> np.ndarray:
+    """Extract the depth (z) channel from a point map (CPU version)."""
+    return points[..., 2]
+
+
+def points_to_normals_cpu(points: np.ndarray, mask: np.ndarray = None) -> Tuple[np.ndarray, np.ndarray]:
+    """Compute surface normals from a point map via finite differences (CPU version)."""
+    # Tangent vectors along the height and width axes
+    dx = points[..., 1:, :, :] - points[..., :-1, :, :]    # (..., H-1, W, 3)
+    dy = points[..., :, 1:, :] - points[..., :, :-1, :]    # (..., H, W-1, 3)
+
+    # Normals from the cross product of the two tangents on the shared (H-1, W-1) grid
+    normal = np.cross(dx[..., :, :-1, :], dy[..., :-1, :, :], axis=-1)
+    normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-10)
+
+    # Pad normals back to the input resolution
+    normal_pad = np.pad(
+        normal,
+        tuple((0, 0) for _ in range(normal.ndim - 3)) + ((0, 1), (0, 1), (0, 0)),
+        mode='edge'
+    )
+
+    if mask is not None:
+        # A normal is valid only where all four contributing points are valid
+        mask_valid = mask[..., 1:, 1:] & mask[..., :-1, 1:] & mask[..., 1:, :-1] & mask[..., :-1, :-1]
+        mask_pad = np.pad(
+            mask_valid,
+            tuple((0, 0) for _ in range(mask.ndim - 2)) + ((0, 1), (0, 1)),
+            mode='constant'
+        )
+    else:
+        mask_pad = np.ones(points.shape[:-1], dtype=bool)
+
+    return normal_pad, mask_pad
+
+
+def gaussian_blur_2d_cpu(x: np.ndarray, kernel_size: int, sigma: float) -> np.ndarray:
+    """Apply a 2D Gaussian blur to a (B, C, H, W) array (CPU version)."""
+    # Separable 1D Gaussian kernel and its 2D outer product
+    kernel = np.exp(-(np.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1) ** 2) / (2 * sigma ** 2))
+    kernel = kernel / kernel.sum()
+    kernel_2d = kernel[:, None] * kernel[None, :]
+
+    padding = kernel_size // 2
+    x_pad = np.pad(x, ((0, 0), (0, 0), (padding, padding), (padding, padding)), mode='reflect')
+
+    # Blur each channel independently; reflect padding + 'valid' convolution preserves the size
+    output = np.zeros_like(x)
+    for i in range(x.shape[1]):
+        output[:, i] = conv2d_cpu(x_pad[:, i], kernel_2d)
+    return output
+
+
+def conv2d_cpu(x: np.ndarray, kernel: np.ndarray) -> np.ndarray:
+    """'Valid' 2D convolution of a batch of single-channel (B, H, W) images (CPU version)."""
+    from scipy.signal import convolve2d
+    out = np.zeros(
+        (x.shape[0], x.shape[-2] - kernel.shape[-2] + 1, x.shape[-1] - kernel.shape[-1] + 1),
+        dtype=x.dtype
+    )
+    for b in range(x.shape[0]):
+        # convolve2d flips the kernel; the Gaussian kernel is symmetric, so this equals correlation
+        out[b] = convolve2d(x[b], kernel, mode='valid')
+    return out