본문 바로가기
AI/딥러닝 프로젝트

[01 - ImageNet Sketch 데이터] - Grad Cam [3]

by AI-BT 2024. 9. 26.
728x90
반응형

Grad-CAM을 활용한 이미지 분류 모델 시각화

이번 포스팅에서는 Grad-CAM을 활용하여 이미지 분류 모델의 시각화를 구현한 사례를 소개합니다. Grad-CAM(Gradient-weighted Class Activation Mapping)은 신경망이 이미지를 분류할 때 주목한 부분을 시각화할 수 있는 방법입니다. 특히 복잡한 CNN(Convolutional Neural Network) 모델에서, 모델이 어떤 부분을 중점적으로 보았는지 확인할 수 있어 모델 해석에 큰 도움이 됩니다.

 

Grad-CAM을 적용한 전체 코드 개요

아래 코드에서는 ResNet101 모델을 사용하여 이미지 분류를 진행한 후, Grad-CAM을 통해 모델이 주목한 이미지 영역을 시각화하는 과정을 설명합니다. Augmentation이 적용된 데이터를 사용해 Grad-CAM을 출력하고, 각 이미지에 대한 원본 이미지, Grad-CAM 결과 이미지, 오버레이된 이미지를 시각화합니다.

 


1. 필요한 라이브러리 및 클래스 정의

Grad-CAM을 사용하려면 PyTorch, Albumentations, TorchCAM, 그리고 OpenCV와 같은 라이브러리들이 필요합니다. 먼저 기본적인 데이터 로딩, Augmentation 적용, 모델 선택 등을 위한 클래스를 정의합니다.

Custom Dataset 및 Albumentations Transform

CustomDataset 클래스는 이미지 경로와 타겟 정보를 포함하는 DataFrame을 통해 데이터를 로딩하는 역할을 합니다. 여기서는 Albumentations를 사용하여 다양한 이미지 변환을 적용합니다.

class CustomDataset(Dataset):
    def __init__(self, root_dir: str, info_df: pd.DataFrame, transform: Callable, is_inference: bool = False):
        self.root_dir = root_dir
        self.transform = transform
        self.is_inference = is_inference
        self.image_paths = info_df['image_path'].tolist()
        
        if not self.is_inference:
            self.targets = info_df['target'].tolist()

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, index: int) -> Union[Tuple[torch.Tensor, int], torch.Tensor]:
        img_path = os.path.join(self.root_dir, self.image_paths[index])
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transform(image)

        if self.is_inference:
            return image
        else:
            target = self.targets[index]
            return image, target

Albumentations Transform

AlbumentationsTransform 클래스는 학습 또는 테스트에 사용할 변환(Transform)을 정의합니다. 학습 시에는 데이터 증강(Augmentation)을 적용하고, 테스트 시에는 리사이즈와 정규화만 수행합니다.

class AlbumentationsTransform:
    def __init__(self, is_train: bool = True):
        common_transforms = [
            A.Resize(224, 224),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ]
        
        if is_train:
            self.transform = A.Compose(
                [
                    A.OneOf([A.RandomCrop(64, 64), A.Resize(224, 224)], p=0.5),
                    A.HorizontalFlip(p=0.5),
                    A.Rotate(limit=30),
                    A.RandomBrightnessContrast(brightness_limit=(-0.2, -0.2), contrast_limit=0, p=1),
                    A.Sharpen(alpha=(0.2, 0.5), lightness=(0.7, 1.3), p=1),
                    A.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.5),
                ] + common_transforms
            )
        else:
            self.transform = A.Compose(common_transforms)

    def __call__(self, image) -> torch.Tensor:
        transformed = self.transform(image=image)
        return transformed['image']

 


2. Grad-CAM을 통한 모델 시각화 함수

Grad-CAM을 적용하기 위해 visualize_gradcam 함수를 정의합니다. 이 함수는 Grad-CAM을 사용하여 모델이 주목한 영역을 시각화하고, 결과 이미지를 출력합니다.

def visualize_gradcam(model: torch.nn.Module, device: torch.device, dataloader: DataLoader, target_layer: str, num_images: int, random_indices: list = None, num_classes: int = 500):
    cam_extractor = GradCAM(model, target_layer)
    model.eval()

    total_images = len(dataloader.dataset)
    if random_indices is None:
        random_indices = random.sample(range(total_images), num_images)
    else:
        random_indices = random_indices[:num_images]

    fig, axes = plt.subplots(num_images, 3, figsize=(18, 6 * num_images))

    current_index = 0
    selected_indices = set(random_indices)

    for inputs in dataloader:
        inputs = inputs.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        batch_size = inputs.size(0)
        for j in range(batch_size):
            if current_index in selected_indices:
                class_idx = preds[j].item()
                if class_idx >= num_classes:
                    class_idx = random.randint(0, num_classes - 1)

                cam = cam_extractor(class_idx=class_idx, scores=outputs)[0]
                cam = cam.mean(dim=0).cpu().numpy()
                cam = cv2.resize(cam, (inputs[j].shape[2], inputs[j].shape[1]))
                cam = (cam - cam.min()) / (cam.max() - cam.min())
                cam = np.uint8(255 * cam)
                cam = cv2.applyColorMap(cam, cv2.COLORMAP_JET)
                cam = cv2.cvtColor(cam, cv2.COLOR_BGR2RGB)

                input_image = inputs[j].cpu().numpy().transpose((1, 2, 0))
                input_image = (input_image - input_image.min()) / (input_image.max() - input_image.min())
                input_image = (input_image * 255).astype(np.uint8)

                img_idx = list(selected_indices).index(current_index)
                ax_row = axes[img_idx] if num_images > 1 else axes

                ax_row[0].imshow(input_image)
                ax_row[0].set_title(f"Original Image {current_index}")
                ax_row[0].axis('off')

                ax_row[1].imshow(cam)
                ax_row[1].set_title("Grad-CAM Image")
                ax_row[1].axis('off')

                overlay = cv2.addWeighted(input_image, 0.5, cam, 0.5, 0)
                ax_row[2].imshow(overlay)
                ax_row[2].set_title("Overlay Image")
                ax_row[2].axis('off')

            current_index += 1
            if current_index >= total_images:
                break

    plt.tight_layout()
    plt.show()

 


 

3. Grad-CAM 시각화 실행

위에서 정의한 visualize_gradcam 함수를 통해 모델이 예측에 사용한 이미지의 주요 영역을 시각화합니다. 레이어를 선택하여 각 단계에서 모델이 주목한 부분을 확인할 수 있습니다.

layers_to_visualize = [
    'layer1.0.conv1',  # 초기 레이어
    'layer2.1.conv2',  # 중간 레이어
    'layer3.2.conv3',  # 더 깊은 레이어
    'layer4.2.act2'    # 마지막 레이어
]

for layer in layers_to_visualize:
    print(f"Visualizing Grad-CAM for layer: {layer}")
    visualize_gradcam(model.model, device, test_loader, target_layer=layer, num_images=3)

 

 

 


Insight

Grad-CAM을 활용한 시각화는 CNN 모델이 이미지를 분류할 때 주목하는 영역을 직관적으로 파악할 수 있어 모델 해석에 큰 도움을 준다. 특히, Augmentation이 적용된 데이터를 활용한 Grad-CAM 시각화는 모델이 다양한 데이터 변형에도 어떻게 반응하는지 확인할 수 있어 학습 성능을 개선하는 데 중요한 인사이트를 제공한다.

Grad-CAM을 통해 모델이 예측할 때 집중하는 부분을 명확하게 파악할 수 있었고, 이를 바탕으로 모델의 약점과 개선할 방향을 찾아낼 수 있었다. 예를 들어, 모델이 주목해야 할 핵심 영역이 아닌 다른 부분에 집중한다면, 데이터 증강 기법이나 아키텍처 수정이 필요할 수 있음을 시사한다. 이러한 피드백을 통해 모델을 보다 정확하고 효율적으로 발전시킬 수 있으며, 궁극적으로 더 높은 성능을 가진 딥러닝 모델로 개선할 수 있다.

Grad-CAM은 단순한 시각화 도구를 넘어, 모델 개선과 최적화를 위한 전략적 방향을 제시하는 중요한 도구로 활용 될 수 있다는 점이 인상깊었다.

 

감사합니다.

 

 
 

 

728x90
반응형

댓글