본문 바로가기
의료 AI/[06] - 프로젝트

Segmentation Models Pytorch 사용법

by AI-BT 2024. 11. 25.
728x90
반응형

이번 포스팅에서는 Segmentation Models PyTorch(SMP) 라이브러리를 사용하여 UNet 기반의 모델을 학습하는 방법을 다룬다. SMP는 다양한 세그멘테이션 모델과 사전 학습된 백본(Backbone)을 제공하여 모델 개발을 간편하게 만든다. 이 글에서는 PyTorch와 SMP를 활용해 X-Ray 데이터를 학습하는 코드를 구성하고, 각 단계에서 사용한 기능과 로직을 설명하겠습니다.

 

SMP란 무엇인가?

Segmentation Models PyTorch는 PyTorch를 기반으로 구축된 세그멘테이션 모델 라이브러리이다. UNet, UNet++, DeepLabV3+와 같은 다양한 모델을 지원하며, EfficientNet, ResNet, RegNet 등 여러 백본을 선택할 수 있다. 특히 사전 학습된 백본을 통해 초기 학습 성능을 향상시킬 수 있는 것이 특징이다.


Train code 설명

 

1. Argument Parsing

argparse를 사용해 학습에 필요한 설정값을 동적으로 변경할 수 있다.

  • 모델 이름, 이미지 크기, 학습률, 배치 크기 등 학습에 필요한 설정값을 정의했다.
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="unet")
    parser.add_argument("--img_size", type=int, default=512)
    parser.add_argument("--tr_batch_size", type=int, default=2)
    parser.add_argument("--val_batch_size", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--fold", type=int, default=0)
    parser.add_argument("--seed", type=int, default=21)
    parser.add_argument("--pt_name", type=str, default="unetPP_all.pt")
    parser.add_argument("--log_name", type=str, default="unetPP_all")
    return parser.parse_args()

 

2. Augmentation

Albumentations 라이브러리를 사용해 데이터 증강(Augmentation)을 적용했다.

  • train_tf에는 밝기/대비 조정, CLAHE 필터 등을 추가하여 학습 데이터를 다양화했다.
  • val_tf는 검증 데이터에 기본적인 리사이즈만 적용했다.
train_tf = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
    A.CLAHE(clip_limit=(1, 4), tile_grid_size=(8, 8), p=0.5)
])

val_tf = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
])

 

3. SMP 모델 정의

SMP의 UNet++ 모델을 사용했다.

  • resnet152 백본을 사용하며, imagenet으로 사전 학습된 가중치를 적용했다.
  • 입력 채널은 3(RGB), 출력 클래스는 29로 설정했다.
model = smp.UnetPlusPlus(
    encoder_name="resnet152",
    encoder_weights="imagenet", 
    in_channels=3,
    classes=29
)

 

4. 손실 함수 및 옵티마이저

  • 손실 함수: BCE(Binary Cross Entropy)와 Dice Loss를 혼합한 BCE_Dice_loss를 사용했다.
    두 손실 함수를 결합해 픽셀 단위 손실과 전반적인 세그멘테이션 성능을 균형 있게 고려했다.
def BCE_Dice_loss(pred, target, bce_weight=0.5):
    bce = F.binary_cross_entropy_with_logits(pred, target)
    pred = F.sigmoid(pred)
    dice = dice_loss(pred, target)
    return bce * bce_weight + dice * (1 - bce_weight)
  • 옵티마이저: Adam를 사용했다.
    학습률은 1e-4로 설정하고, weight decay를 1e-5로 추가해 과적합을 방지했다.
optimizer = optim.Adam(params=model.parameters(), lr=LR, weight_decay=1e-5)

 

5. 학습 루프

학습 루프는 다음과 같은 흐름으로 구성되었다.

  1. 모델을 train 모드로 설정하고 데이터를 순차적으로 학습한다.
  2. 스케줄러를 사용해 학습률을 점진적으로 조정한다.
  3. 손실 값이 개선될 경우 모델을 저장한다.
def train(model, data_loader, criterion, optimizer):
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-8)
    best_loss = float('inf')  # Loss를 기준으로 모델 저장

    for epoch in range(NUM_EPOCHS):
        model.train()
        total_loss = 0

        for step, (images, masks) in enumerate(data_loader):
            images, masks = images.cuda(), masks.cuda()
            outputs = model(images)
            loss = criterion(outputs, masks)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        scheduler.step()
        avg_loss = total_loss / len(data_loader)

        if avg_loss < best_loss:
            best_loss = avg_loss
            save_model(model)

 

이번 포스팅에서는 SMP를 활용해 UNet++ 모델을 학습하는 방법을 살펴보았다. SMP의 간편한 인터페이스와 PyTorch의 유연성을 결합하면 고성능 세그멘테이션 모델을 효율적으로 개발할 수 있다. 추후 프로젝트가 끝나면 어떤 백본과 하이퍼파라미터를 사용했는지 공유 하겠습니다.

 

이상입니다.

끝.

 

728x90
반응형

댓글