본문 바로가기
의료 AI/[06] - 프로젝트

Pseudo 라벨링(Pseudo Labeling)이란 무엇인가?

by AI-BT 2024. 11. 24.
728x90
반응형

Pseudo 라벨링은 반지도 학습(Semi-supervised Learning)의 한 기법이다. 이 방법은 모델이 예측한 결과를 기반으로*라벨이 없는 데이터(unlabeled data)에 임시 라벨을 할당하여 학습에 활용하는 방식이다. Pseudo 라벨링은 특히 라벨링 비용이 높은 상황에서 데이터 효율성을 극대화하는 데 유용하다.


왜 Pseudo 라벨링을 사용하는가?

  1. 라벨링 비용 절감
    데이터에 라벨을 직접 붙이는 작업은 많은 비용과 시간이 소요된다. Pseudo 라벨링은 모델의 예측 결과를 사용하여 라벨링 작업을 간소화한다.
  2. 데이터 활용 극대화
    라벨이 없는 데이터도 학습에 포함할 수 있어, 전체 데이터셋의 크기를 증가시킨다. 이를 통해 모델의 일반화 성능을 높일 수 있다.
  3. 모델 성능 개선
    모델이 학습 데이터를 더 많이 활용하면서 기존 라벨 데이터와의 조화를 이루어 성능을 개선할 수 있다.

Pseudo 라벨링의 동작 원리

  1. 초기 모델 학습
    라벨이 있는 데이터를 사용하여 초기 모델을 학습시킨다.
  2. 라벨 없는 데이터 예측
    학습된 모델로 라벨이 없는 데이터의 라벨을 예측한다. 이 예측값을 pseudo label이라 부른다.
  3. 라벨 데이터로 확장
    예측된 pseudo label을 라벨 데이터로 추가하여 모델을 재학습한다. 이 과정을 반복하면서 모델의 성능을 점진적으로 향상시킨다.

 

아래 코드를 보면서 이해해보자.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# 1. 간단한 Dataset 정의
class CustomDataset(Dataset):
    def __init__(self, data, labels=None):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if self.labels is not None:  # 라벨이 있는 경우
            return self.data[idx], self.labels[idx]
        return self.data[idx]  # 라벨이 없는 경우

# 2. 초기 모델 학습
def train_model(model, dataloader, criterion, optimizer, num_epochs=5):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader):.4f}")

# 3. Pseudo 라벨링 적용
def generate_pseudo_labels(model, unlabeled_dataloader, threshold=0.8):
    model.eval()
    pseudo_data = []
    pseudo_labels = []
    with torch.no_grad():
        for inputs in unlabeled_dataloader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            probs = torch.softmax(outputs, dim=1)
            max_probs, preds = torch.max(probs, dim=1)
            for idx in range(len(max_probs)):
                if max_probs[idx] > threshold:  # 임계값 이상인 경우만 라벨링
                    pseudo_data.append(inputs[idx].cpu())
                    pseudo_labels.append(preds[idx].cpu())
    return torch.stack(pseudo_data), torch.tensor(pseudo_labels)

# 4. Pseudo 라벨링 예제 실행
device = "cuda" if torch.cuda.is_available() else "cpu"

# 데이터 및 모델 준비
labeled_data = torch.randn(100, 3, 32, 32)
labeled_labels = torch.randint(0, 10, (100,))
unlabeled_data = torch.randn(50, 3, 32, 32)

train_dataset = CustomDataset(labeled_data, labeled_labels)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

unlabeled_dataset = CustomDataset(unlabeled_data)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=16, shuffle=False)

model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 32 * 32, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 초기 학습
print("Step 1: Training with labeled data")
train_model(model, train_loader, criterion, optimizer)

# Pseudo 라벨 생성
print("Step 2: Generating pseudo labels")
pseudo_data, pseudo_labels = generate_pseudo_labels(model, unlabeled_loader)

# Pseudo 라벨 추가 학습
print("Step 3: Training with pseudo labeled data")
combined_data = torch.cat([labeled_data, pseudo_data])
combined_labels = torch.cat([labeled_labels, pseudo_labels])

combined_dataset = CustomDataset(combined_data, combined_labels)
combined_loader = DataLoader(combined_dataset, batch_size=16, shuffle=True)

train_model(model, combined_loader, criterion, optimizer)

Pseudo 라벨링의 장단점

장점

  • 라벨링 비용을 줄이면서도 성능을 향상시킬 수 있다.
  • 모델이 더 많은 데이터를 학습하게 되어 일반화 성능이 좋아진다.

단점

  • 잘못된 pseudo 라벨이 학습에 부정적인 영향을 줄 수 있다.
  • 모델의 초기 성능에 따라 pseudo 라벨의 품질이 좌우된다.

 

 

Pseudo 라벨링은 데이터 활용도를 극대화할 수 있는 강력한 방법이다. 라벨링이 어려운 상황에서 모델 성능을 개선하고 데이터 효율성을 높이는 데 효과적이다. 하지만 적절한 임계값 설정과 초기 모델 성능 관리가 중요하다. Pseudo 라벨링을 효과적으로 적용하면 비용 절감과 성능 향상의 두 마리 토끼를 잡을 수 있다.

728x90
반응형

댓글