728x90
반응형
이번 포스팅에서는 Segmentation Models PyTorch(SMP) 라이브러리를 사용하여 UNet 기반의 모델을 학습하는 방법을 다룬다. SMP는 다양한 세그멘테이션 모델과 사전 학습된 백본(Backbone)을 제공하여 모델 개발을 간편하게 만든다. 이 글에서는 PyTorch와 SMP를 활용해 X-Ray 데이터를 학습하는 코드를 구성하고, 각 단계에서 사용한 기능과 로직을 설명하겠습니다.
SMP란 무엇인가?
Segmentation Models PyTorch는 PyTorch를 기반으로 구축된 세그멘테이션 모델 라이브러리이다. UNet, UNet++, DeepLabV3+와 같은 다양한 모델을 지원하며, EfficientNet, ResNet, RegNet 등 여러 백본을 선택할 수 있다. 특히 사전 학습된 백본을 통해 초기 학습 성능을 향상시킬 수 있는 것이 특징이다.
Train code 설명
1. Argument Parsing
argparse를 사용해 학습에 필요한 설정값을 동적으로 변경할 수 있다.
- 모델 이름, 이미지 크기, 학습률, 배치 크기 등 학습에 필요한 설정값을 정의했다.
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="unet")
parser.add_argument("--img_size", type=int, default=512)
parser.add_argument("--tr_batch_size", type=int, default=2)
parser.add_argument("--val_batch_size", type=int, default=1)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--epochs", type=int, default=200)
parser.add_argument("--fold", type=int, default=0)
parser.add_argument("--seed", type=int, default=21)
parser.add_argument("--pt_name", type=str, default="unetPP_all.pt")
parser.add_argument("--log_name", type=str, default="unetPP_all")
return parser.parse_args()
2. Augmentation
Albumentations 라이브러리를 사용해 데이터 증강(Augmentation)을 적용했다.
- train_tf에는 밝기/대비 조정, CLAHE 필터 등을 추가하여 학습 데이터를 다양화했다.
- val_tf는 검증 데이터에 기본적인 리사이즈만 적용했다.
train_tf = A.Compose([
A.Resize(IMAGE_SIZE, IMAGE_SIZE),
A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
A.CLAHE(clip_limit=(1, 4), tile_grid_size=(8, 8), p=0.5)
])
val_tf = A.Compose([
A.Resize(IMAGE_SIZE, IMAGE_SIZE),
])
3. SMP 모델 정의
SMP의 UNet++ 모델을 사용했다.
- resnet152 백본을 사용하며, imagenet으로 사전 학습된 가중치를 적용했다.
- 입력 채널은 3(RGB), 출력 클래스는 29로 설정했다.
model = smp.UnetPlusPlus(
encoder_name="resnet152",
encoder_weights="imagenet",
in_channels=3,
classes=29
)
4. 손실 함수 및 옵티마이저
- 손실 함수: BCE(Binary Cross Entropy)와 Dice Loss를 혼합한 BCE_Dice_loss를 사용했다.
두 손실 함수를 결합해 픽셀 단위 손실과 전반적인 세그멘테이션 성능을 균형 있게 고려했다.
def BCE_Dice_loss(pred, target, bce_weight=0.5):
bce = F.binary_cross_entropy_with_logits(pred, target)
pred = F.sigmoid(pred)
dice = dice_loss(pred, target)
return bce * bce_weight + dice * (1 - bce_weight)
- 옵티마이저: Adam를 사용했다.
학습률은 1e-4로 설정하고, weight decay를 1e-5로 추가해 과적합을 방지했다.
optimizer = optim.Adam(params=model.parameters(), lr=LR, weight_decay=1e-5)
5. 학습 루프
학습 루프는 다음과 같은 흐름으로 구성되었다.
- 모델을 train 모드로 설정하고 데이터를 순차적으로 학습한다.
- 스케줄러를 사용해 학습률을 점진적으로 조정한다.
- 손실 값이 개선될 경우 모델을 저장한다.
def train(model, data_loader, criterion, optimizer):
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-8)
best_loss = float('inf') # Loss를 기준으로 모델 저장
for epoch in range(NUM_EPOCHS):
model.train()
total_loss = 0
for step, (images, masks) in enumerate(data_loader):
images, masks = images.cuda(), masks.cuda()
outputs = model(images)
loss = criterion(outputs, masks)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
scheduler.step()
avg_loss = total_loss / len(data_loader)
if avg_loss < best_loss:
best_loss = avg_loss
save_model(model)
이번 포스팅에서는 SMP를 활용해 UNet++ 모델을 학습하는 방법을 살펴보았다. SMP의 간편한 인터페이스와 PyTorch의 유연성을 결합하면 고성능 세그멘테이션 모델을 효율적으로 개발할 수 있다. 추후 프로젝트가 끝나면 어떤 백본과 하이퍼파라미터를 사용했는지 공유 하겠습니다.
이상입니다.
끝.
728x90
반응형
'의료 AI > [06] - 프로젝트' 카테고리의 다른 글
Pseudo 라벨링(Pseudo Labeling)이란 무엇인가? (0) | 2024.11.24 |
---|---|
Semantic Segmentation 대회에서 사용하는 방법들 (1) | 2024.11.23 |
[꼬꼬무] - PyTorch 에서 메모리 부족할때 해결하는 방법! (Autocast, GradScaler) (0) | 2024.11.21 |
댓글