본문 바로가기
의료 AI/[04] - 논문 리뷰

[01 X-ray Hand] - Maskformer

by AI-BT 2024. 11. 27.
728x90
반응형

 

MaskFormer는 이미지를 분할하기 위해 classification + segmentation의 통합 접근 방식을 제안한 모델이다. 전통적인 방식이 픽셀 단위의 분류에 초점을 맞췄다면, MaskFormer는 각 마스크를 하나의 객체로 취급하여 보다 효율적인 분할을 수행한다. 특히, 트랜스포머(Transformer)를 활용하여 객체 간의 상호작용을 학습하며, 마스크 예측(mask prediction)과 클래스 예측(class prediction)을 함께 수행한다.

 

MaskFormer의 주요 구조

MaskFormer는 크게 Pixel-Level Module, Transformer Module, 그리고 Segmentation Module로 구성된다.

1. Pixel-Level Module

  • 입력 이미지를 백본(ResNet, Swin Transformer 등)을 통해 이미지 특징(feature)을 추출한다.
  • 이 특징 맵을 Pixel Decoder를 통해 픽셀 단위의 임베딩으로 변환한다.

2. Transformer Module

  • 픽셀 임베딩과 N개의 Query를 Transformer Decoder에 입력하여 객체 간 상호작용을 학습한다.
  • Transformer의 출력은 MLP(멀티 레이어 퍼셉트론)를 통해 다음과 같은 두 가지 예측을 한다.
    • N개의 클래스(class predictions): 각 마스크가 어떤 클래스인지 예측
    • N개의 마스크(mask embeddings): 각 객체의 마스크 형태를 예측

3. Segmentation Module

  • Transformer에서 나온 결과를 기반으로 최종 semantic segmentation 결과를 도출한다.
  • 손실 함수는 두 가지로 구성된다.
    • Classification Loss: 클래스 예측 정확도
    • Binary Mask Loss: 마스크 형태의 정확도

MaskFormer의 장점

  1. 효율성: 객체 단위의 마스크를 학습함으로써 모든 픽셀에 대해 일일이 분류를 하지 않아도 된다.
  2. 유연성: 다양한 백본(ResNet, Swin Transformer 등)과 쉽게 결합 가능하다.
  3. 확장성: Semantic Segmentation뿐만 아니라 Instance Segmentation, Panoptic Segmentation에도 사용 가능하다.

PyTorch 기반 MaskFormer 학습 코드

MaskFormer의 각 구성 요소 (Pixel-Level Module, Transformer Module, Segmentation Module)를 분리하여 단계별로 설명한다. PyTorch로 작성되었으며, 학습에 필요한 기초적인 모델 구조만 포함하고 있다.

import torch
import torch.nn as nn
from torchvision.models import resnet50
from torch.nn.functional import interpolate

# 1. Pixel-Level Module
class PixelLevelModule(nn.Module):
    def __init__(self, backbone_channels=2048, pixel_decoder_channels=256):
        super(PixelLevelModule, self).__init__()
        self.backbone = resnet50(pretrained=True)  # Backbone: ResNet50
        self.pixel_decoder = nn.Sequential(
            nn.Conv2d(backbone_channels, pixel_decoder_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(pixel_decoder_channels, pixel_decoder_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        features = self.backbone(x)  # Backbone에서 Feature 추출
        pixel_embeddings = self.pixel_decoder(features)  # Pixel Decoder로 임베딩 생성
        return pixel_embeddings

# 2. Transformer Module
class TransformerModule(nn.Module):
    def __init__(self, num_queries=100, embed_dim=256):
        super(TransformerModule, self).__init__()
        self.query_embeddings = nn.Embedding(num_queries, embed_dim)  # Query 생성
        self.transformer_decoder = nn.Transformer(embed_dim, num_heads=8, num_decoder_layers=6)

    def forward(self, pixel_embeddings):
        batch_size, c, h, w = pixel_embeddings.shape
        pixel_embeddings_flat = pixel_embeddings.flatten(2).permute(2, 0, 1)  # Flatten (H*W, B, C)
        queries = self.query_embeddings.weight.unsqueeze(1).repeat(1, batch_size, 1)  # (N, B, C)
        transformer_output = self.transformer_decoder(queries, pixel_embeddings_flat)  # Transformer Decoder
        return transformer_output

# 3. Segmentation Module
class SegmentationModule(nn.Module):
    def __init__(self, num_classes, embed_dim=256):
        super(SegmentationModule, self).__init__()
        self.classifier = nn.Linear(embed_dim, num_classes)  # Class Prediction
        self.mask_predictor = nn.Conv2d(embed_dim, 1, kernel_size=1)  # Mask Prediction

    def forward(self, transformer_output, pixel_embeddings):
        # Class Predictions
        class_preds = self.classifier(transformer_output.mean(dim=0))  # (N, num_classes)
        
        # Mask Predictions
        batch_size, c, h, w = pixel_embeddings.shape
        mask_embeddings = transformer_output.permute(1, 2, 0).view(batch_size, c, h, w)
        masks = self.mask_predictor(mask_embeddings)  # (B, N, H, W)
        return class_preds, masks

# 4. 전체 MaskFormer 모델
class MaskFormer(nn.Module):
    def __init__(self, num_classes, num_queries=100):
        super(MaskFormer, self).__init__()
        self.pixel_level = PixelLevelModule()
        self.transformer = TransformerModule(num_queries=num_queries)
        self.segmentation = SegmentationModule(num_classes=num_classes)

    def forward(self, x):
        pixel_embeddings = self.pixel_level(x)  # Pixel-Level Embeddings
        transformer_output = self.transformer(pixel_embeddings)  # Transformer Output
        class_preds, masks = self.segmentation(transformer_output, pixel_embeddings)  # Final Predictions
        return class_preds, masks

# 모델 초기화
num_classes = 21  # 예: Pascal VOC
num_queries = 100  # MaskFormer의 Query 수
model = MaskFormer(num_classes=num_classes, num_queries=num_queries)

# 입력 테스트
dummy_input = torch.rand(2, 3, 256, 256)  # (B, C, H, W)
class_predictions, mask_predictions = model(dummy_input)

print("Class Predictions Shape:", class_predictions.shape)  # (B, num_classes)
print("Mask Predictions Shape:", mask_predictions.shape)  # (B, N, H, W)

 

 

  1. Pixel-Level Module:
    • ResNet50을 백본으로 사용하여 이미지 특징 맵을 추출한다.
    • Pixel Decoder는 특징 맵을 픽셀 단위의 임베딩으로 변환한다.
  2. Transformer Module:
    • Query Embeddings를 생성하며, 이들은 객체 정보를 학습한다.
    • Transformer Decoder는 Query와 Pixel Embeddings 간의 상호작용을 학습한다.
  3. Segmentation Module:
    • Transformer Output에서 Class PredictionsMask Predictions를 생성한다.
  4. 통합 모델:
    • 각 모듈을 연결하여 MaskFormer의 전체 구조를 구현한다.

 

끝. 이상입니다.

감사힙니다.

 

728x90
반응형

'의료 AI > [04] - 논문 리뷰' 카테고리의 다른 글

[01 X-ray Hand] - U-Net3++  (0) 2024.11.22
DeepLab v1 아키텍쳐 분석  (1) 2024.11.20
[01 X-ray Hand] - SegNet의 아키텍처  (2) 2024.11.18
U-Net 의 이해  (8) 2024.11.10

댓글