[04 Segmentation] - Segmentation Models Pytorch 사용법

2024. 11. 25. 12:13·딥러닝 (Deep Learning)/[08] - 프로젝트
728x90
반응형

이번 포스팅에서는 Segmentation Models PyTorch(SMP) 라이브러리를 사용하여 UNet 기반의 모델을 학습하는 방법을 다룬다. SMP는 다양한 세그멘테이션 모델과 사전 학습된 백본(Backbone)을 제공하여 모델 개발을 간편하게 만든다. 이 글에서는 PyTorch와 SMP를 활용해 X-Ray 데이터를 학습하는 코드를 구성하고, 각 단계에서 사용한 기능과 로직을 설명하겠습니다.

 

SMP란 무엇인가?

Segmentation Models PyTorch는 PyTorch를 기반으로 구축된 세그멘테이션 모델 라이브러리이다. UNet, UNet++, DeepLabV3+와 같은 다양한 모델을 지원하며, EfficientNet, ResNet, RegNet 등 여러 백본을 선택할 수 있다. 특히 사전 학습된 백본을 통해 초기 학습 성능을 향상시킬 수 있는 것이 특징이다.


Train code 설명

 

1. Argument Parsing

argparse를 사용해 학습에 필요한 설정값을 동적으로 변경할 수 있다.

  • 모델 이름, 이미지 크기, 학습률, 배치 크기 등 학습에 필요한 설정값을 정의했다.
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="unet")
    parser.add_argument("--img_size", type=int, default=512)
    parser.add_argument("--tr_batch_size", type=int, default=2)
    parser.add_argument("--val_batch_size", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--fold", type=int, default=0)
    parser.add_argument("--seed", type=int, default=21)
    parser.add_argument("--pt_name", type=str, default="unetPP_all.pt")
    parser.add_argument("--log_name", type=str, default="unetPP_all")
    return parser.parse_args()

 

2. Augmentation

Albumentations 라이브러리를 사용해 데이터 증강(Augmentation)을 적용했다.

  • train_tf에는 밝기/대비 조정, CLAHE 필터 등을 추가하여 학습 데이터를 다양화했다.
  • val_tf는 검증 데이터에 기본적인 리사이즈만 적용했다.
train_tf = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
    A.CLAHE(clip_limit=(1, 4), tile_grid_size=(8, 8), p=0.5)
])

val_tf = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
])

 

3. SMP 모델 정의

SMP의 UNet++ 모델을 사용했다.

  • resnet152 백본을 사용하며, imagenet으로 사전 학습된 가중치를 적용했다.
  • 입력 채널은 3(RGB), 출력 클래스는 29로 설정했다.
model = smp.UnetPlusPlus(
    encoder_name="resnet152",
    encoder_weights="imagenet", 
    in_channels=3,
    classes=29
)

 

4. 손실 함수 및 옵티마이저

  • 손실 함수: BCE(Binary Cross Entropy)와 Dice Loss를 혼합한 BCE_Dice_loss를 사용했다.
    두 손실 함수를 결합해 픽셀 단위 손실과 전반적인 세그멘테이션 성능을 균형 있게 고려했다.
def BCE_Dice_loss(pred, target, bce_weight=0.5):
    bce = F.binary_cross_entropy_with_logits(pred, target)
    pred = F.sigmoid(pred)
    dice = dice_loss(pred, target)
    return bce * bce_weight + dice * (1 - bce_weight)
  • 옵티마이저: Adam를 사용했다.
    학습률은 1e-4로 설정하고, weight decay를 1e-5로 추가해 과적합을 방지했다.
optimizer = optim.Adam(params=model.parameters(), lr=LR, weight_decay=1e-5)

 

5. 학습 루프

학습 루프는 다음과 같은 흐름으로 구성되었다.

  1. 모델을 train 모드로 설정하고 데이터를 순차적으로 학습한다.
  2. 스케줄러를 사용해 학습률을 점진적으로 조정한다.
  3. 손실 값이 개선될 경우 모델을 저장한다.
def train(model, data_loader, criterion, optimizer):
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-8)
    best_loss = float('inf')  # Loss를 기준으로 모델 저장

    for epoch in range(NUM_EPOCHS):
        model.train()
        total_loss = 0

        for step, (images, masks) in enumerate(data_loader):
            images, masks = images.cuda(), masks.cuda()
            outputs = model(images)
            loss = criterion(outputs, masks)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        scheduler.step()
        avg_loss = total_loss / len(data_loader)

        if avg_loss < best_loss:
            best_loss = avg_loss
            save_model(model)

 

이번 포스팅에서는 SMP를 활용해 UNet++ 모델을 학습하는 방법을 살펴보았다. SMP의 간편한 인터페이스와 PyTorch의 유연성을 결합하면 고성능 세그멘테이션 모델을 효율적으로 개발할 수 있다. 추후 프로젝트가 끝나면 어떤 백본과 하이퍼파라미터를 사용했는지 공유 하겠습니다.

 

이상입니다.

끝.

 

728x90
반응형
저작자표시 비영리 변경금지 (새창열림)

'딥러닝 (Deep Learning) > [08] - 프로젝트' 카테고리의 다른 글

[05 RE-ID] - 기술 도입의 한계와 고려사항 그리고 해결책  (0) 2024.12.19
[05 RE-ID] - Re-Identification 기술 이해하기!  (1) 2024.12.19
[04 Segmentation] - Semantic Segmentation 대회에서 사용하는 방법들  (1) 2024.11.23
[04 Segmentation] - PyTorch 에서 메모리 부족할때 해결하는 방법! (Autocast, GradScaler)  (1) 2024.11.21
[04 Segmentation] - MMsegmentation 사용법  (11) 2024.11.09
'딥러닝 (Deep Learning)/[08] - 프로젝트' 카테고리의 다른 글
  • [05 RE-ID] - 기술 도입의 한계와 고려사항 그리고 해결책
  • [05 RE-ID] - Re-Identification 기술 이해하기!
  • [04 Segmentation] - Semantic Segmentation 대회에서 사용하는 방법들
  • [04 Segmentation] - PyTorch 에서 메모리 부족할때 해결하는 방법! (Autocast, GradScaler)
AI-BT
AI-BT
인공지능 (AI)과 블록체인에 관심있는 블로그
  • AI-BT
    AI-BLACK-TIGER
    AI-BT
  • 전체
    오늘
    어제
    • 분류 전체보기 (133)
      • 딥러닝 (Deep Learning) (81)
        • [01] - 딥러닝 이란? (5)
        • [02] - 데이터 (4)
        • [03] - 모델 (17)
        • [04] - 학습 및 최적화 (14)
        • [05] - 논문 리뷰 (17)
        • [06] - 평가 및 결과 분석 (4)
        • [07] - Serving (6)
        • [08] - 프로젝트 (14)
      • 머신러닝 & 딥러닝 개념 (0)
        • 머신러닝 (0)
        • 딥러닝 (0)
      • Quant 투자 (12)
        • 경제 (9)
        • 퀀트 알고리즘 & 전략 개요 (3)
      • 딥러닝 Math (4)
      • AI Naver boost camp (22)
        • 회고 (19)
        • CV 프로젝트 가이드 (3)
      • Python (1)
      • 개발 및 IT 용어 (6)
        • IT 용어 (2)
        • VS Code (1)
      • 코인 정보 (7)
  • 인기 글

  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
AI-BT
[04 Segmentation] - Segmentation Models Pytorch 사용법
상단으로

티스토리툴바