Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Semantic Segmentation Inference Take Too Long #279

Open
Mio-Atse opened this issue Oct 18, 2024 · 1 comment
Open

Semantic Segmentation Inference Take Too Long #279

Mio-Atse opened this issue Oct 18, 2024 · 1 comment

Comments

@Mio-Atse
Copy link

I decided to use semantic segmentation for a SLAM project. I wrote a segmentation code that takes xyz rgb inputs, from which I could call the already provided trained models. Unfortunately I didn't understand how to use the provided test code and had to improvise a bit. The problem I'm having right now is that the segmentation is extremely slow. As you can see, I am working with CUDA GPU and even if I choose 0.5 as Voxel Size and downsample, the segmentation takes at least 50 seconds. Does anyone have any suggestions? If you have a code that can do inference with xyz rgb data, I would be very grateful if you can share it. I am calling this script from main script that can't share but it is not the problem.

import torch
import numpy as np
import importlib
import os
import sys
import open3d as o3d
import argparse
import logging
from datetime import datetime
import cv2

# Add the directory containing your model to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

NUM_CLASSES = 13
BATCH_SIZE = 32
NUM_POINT = 4096  # Adjusted to match the original code

# Define color map
COLOR_MAP = {
    'ceiling': [0, 255, 0],
    'floor': [0, 0, 255],
    'wall': [0, 255, 255],
    'beam': [255, 255, 0],
    'column': [255, 0, 255],
    'window': [100, 100, 255],
    'door': [200, 200, 100],
    'table': [170, 120, 200],
    'chair': [255, 0, 0],
    'sofa': [200, 100, 100],
    'bookcase': [10, 200, 100],
    'board': [200, 200, 200],
    'clutter': [50, 50, 50]
}

LABEL_TO_NAMES = {i: name for i, name in enumerate(COLOR_MAP.keys())}

# Set up CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def split_point_cloud_into_blocks(xyz, rgb, num_point, block_size=1.0, stride=0.5):
    """
    Split the point cloud into blocks.
    """
    # Determine the min and max coordinates
    coord_min = torch.min(xyz, dim=0)[0]
    coord_max = torch.max(xyz, dim=0)[0]
    
    # Calculate the number of blocks in each dimension
    grid_x = int(torch.ceil((coord_max[0] - coord_min[0] - block_size) / stride)) + 1
    grid_y = int(torch.ceil((coord_max[1] - coord_min[1] - block_size) / stride)) + 1
    
    blocks = []
    
    for idx_y in range(grid_y):
        for idx_x in range(grid_x):
            s_x = coord_min[0] + idx_x * stride
            s_y = coord_min[1] + idx_y * stride
            e_x = s_x + block_size
            e_y = s_y + block_size
            
            # Find points within this block
            block_mask = (xyz[:, 0] >= s_x) & (xyz[:, 0] <= e_x) & \
                         (xyz[:, 1] >= s_y) & (xyz[:, 1] <= e_y)
            block_point_indices = torch.where(block_mask)[0]
            
            if block_point_indices.size(0) == 0:
                continue
            
            block_xyz = xyz[block_point_indices]
            block_rgb = rgb[block_point_indices]
            
            # Adjust coordinates within the block
            block_xyz_centered = block_xyz - torch.tensor([s_x + block_size / 2.0, s_y + block_size / 2.0, 0], device=device)
            
            # Normalize XYZ coordinates
            block_xyz_normalized = block_xyz / coord_max
            
            # Normalize RGB
            block_rgb_normalized = block_rgb / 255.0
            
            # Combine features
            block_points = torch.cat((block_xyz_centered, block_rgb_normalized, block_xyz_normalized), dim=1)
            
            # Pad or sample to num_point
            if block_points.size(0) >= num_point:
                # Randomly sample num_point points
                idx = torch.randperm(block_points.size(0))[:num_point]
                block_points = block_points[idx]
                block_point_indices = block_point_indices[idx]
            else:
                # Pad with duplicated points
                idx = torch.randint(block_points.size(0), (num_point - block_points.size(0),))
                block_points = torch.cat((block_points, block_points[idx]), dim=0)
                block_point_indices = torch.cat((block_point_indices, block_point_indices[idx]), dim=0)
            
            blocks.append((block_points, block_point_indices))
    
    return blocks

def segment_point_cloud_with_voting(classifier, xyz, rgb, num_point, num_votes=3, block_size=1.0, stride=0.5):
    """
    Perform segmentation with voting.
    """
    num_classes = NUM_CLASSES
    num_points = xyz.size(0)
    vote_label_pool = torch.zeros((num_points, num_classes), dtype=torch.float32, device=device)
    
    for vote in range(num_votes):
        # Introduce randomness in point cloud for each vote
        xyz_vote = xyz + torch.rand_like(xyz) * 0.02 - 0.01
        # Split the point cloud into blocks
        blocks = split_point_cloud_into_blocks(xyz_vote, rgb, num_point, block_size, stride)
        
        for block_points, block_point_indices in blocks:
            # Prepare data for the model
            batch_points_tensor = block_points.unsqueeze(0).transpose(2, 1)
            
            # Perform segmentation
            with torch.no_grad():
                seg_pred, _ = classifier(batch_points_tensor)
                batch_pred_label = seg_pred.argmax(dim=2)
            
            # Update vote label pool
            one_hot = torch.nn.functional.one_hot(batch_pred_label[0], num_classes=num_classes).float()
            vote_label_pool.index_add_(0, block_point_indices, one_hot)
    
    # Get final predicted labels
    final_pred_labels = vote_label_pool.argmax(dim=1)
    return final_pred_labels

def segment_frame(classifier, xyz, rgb, num_point, num_votes=3, block_size=1.0, stride=0.5):
    """
    Perform segmentation on a single frame of point cloud data.
    
    Args:
        classifier: The trained classifier model
        xyz: Nx3 array of point coordinates
        rgb: Nx3 array of point colors (0-255)
        num_point: Number of points per sample
        num_votes: Number of votes for segmentation
        block_size: Block size for splitting point cloud
        stride: Stride for splitting point cloud
    
    Returns:
        segmented_points: Nx6 array of segmented points (XYZ + RGB)
        labels: Nx1 array of predicted labels
    """
    # Convert inputs to PyTorch tensors on the appropriate device
    xyz = torch.tensor(xyz, dtype=torch.float32, device=device)
    rgb = torch.tensor(rgb, dtype=torch.float32, device=device)
    
    # Perform segmentation
    labels = segment_point_cloud_with_voting(classifier, xyz, rgb, num_point, num_votes, block_size, stride)
    
    # Apply color map to segmented points
    segmented_colors = torch.tensor([COLOR_MAP[LABEL_TO_NAMES[label.item()]] for label in labels], device=device)
    
    # Combine XYZ coordinates with segmented colors
    segmented_points = torch.cat((xyz, segmented_colors), dim=1)
    
    return segmented_points.cpu().numpy(), labels.cpu().numpy()

# Load the classifier (this can be done once and the classifier can be passed to segment_frame)

def load_classifier(model_path, model_name):
    MODEL = importlib.import_module(model_name)
    classifier = MODEL.get_model(NUM_CLASSES)
    state_dict = torch.load(model_path, map_location=device)['model_state_dict']
    new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    classifier.load_state_dict(new_state_dict)
    classifier = classifier.to(device)
    classifier.eval()
    return classifier  # Remove JIT compilation
@surpoloyang
Copy link

Pointnet is indeed slow

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants