반갑습니다... 요즘 몸도 안좋고 이리저리 바빠서 블로그는 쳐다볼 시간이 없네요.
최근 연구실에서 CT image 데이터셋을 다룰 일이 생겨서 작업하다 매우 귀찮은 작업이 있음을 인지하였습니다. CT 이미지는 아래 이미지 처럼 어떤 축을 기준으로 촬영되느냐에 따라서 볼 수 있는 구조가 달라지고, 따라서 목적에 따라 사용되는 이미지가 달라집니다. 따라서 '일반적인' 어디 공개되어 있는 CT 데이터들은 axial, coronal, sagittal인지 명시하거나 아예 다른 폴더에 저장되어 있는 경우가 많았던 것 같습니다. Kaggle과 같은 곳에서 볼 수 있는 오픈 데이터셋에서 여러 방향에서 촬영된 CT 이미지의 경우 대부분 그렇게 정리되어 있을 것 같습니다.
그런데 제가 병원으로부터 받은 데이터는 환자의 CT 이미지들이 CT 촬영 방향과 무관하게 한 폴더에 때려박혀져 있었습니다. 뿐만아니라 환자마다 촬영된 CT 이미지의 수도 달랐을 뿐만 아니라 모종의 이유로 여러번 촬영을 진행한 환자도 있었으며 나아가 frame들이 연속적으로 저장되어 있지 않는 환자들도 다수 존재했습니다. 나아가 한 환자에 대해서도 axial으로 촬영된 샘플의 수와 coronal으로 촬영된 데이터의 수도 달랐습니다. (의료데이터의 특성상 예시 이미지를 가지고 올 수 없어 아쉽습니다...)
데이터를 사용하는 사람들이야 데이터를 어떻게 저장해야 미래에 분석 및 이용에 용이할지를 고려하지만 현장에서 데이터를 수집하는 의사 선생님들이 책상 앞에 앉아 키보드만 두드리는 저희를 고려하는 것은 절대 일반적인 일이 아닌 것 같습니다. 여러 의사 선생님들이 정리한 데이터셋의 경우 아예 자동화가 불가능한 경우도 더러 있었기에 이정도면 양반이라고 생각했습니다.
전처리도 아니고 데이터셋 정리에 많은 시간을 쏟고 싶지 않았던 저는 가장 기본적으로 이웃하는 frame들은 spatial domain상에서 유사도를 가진다는 sequential characteristics를 이용해서 이미지들을 분류하려 하였습니다.
Classification based on sequential characteristics
def calculate_frame_diff(frames):
"""
두 이미지 프레임 간의 픽셀 값 차이의 총합을 계산하는 함수
Parameters:
frames (numpy.ndarray): shape이 (N, H, W, 3)인 이미지 프레임들의 numpy array
Returns:
numpy.ndarray: shape이 (N-1,)인 픽셀 값 차이의 평균치를 담은 numpy array
"""
num_frames = frames.shape[0]
# 픽셀 값 차이를 저장할 배열을 생성합니다.
frame_diff_avg = np.zeros(num_frames - 1)
for i in range(num_frames - 1):
# 현재 프레임과 다음 프레임 간의 픽셀 값 차이를 계산하고 평균을 구합니다.
frame_diff_avg[i] = np.average(np.abs(frames[i] - frames[i+1]))
#frame_diff_avg[i] = np.average(np.multiply(frames[i], frames[i+1]))
return frame_diff_avg
def find_indices_exceeding_threshold(arr, threshold):
"""
특정 threshold를 넘는 값을 가지는 인덱스를 반환하는 함수
Parameters:
arr (numpy.ndarray): shape이 (N-1,)인 numpy array
threshold (float): 찾고자 하는 threshold 값
Returns:
list: 특정 threshold를 넘는 값을 가지는 인덱스 리스트
"""
exceeding_indices = np.where(np.abs(np.diff(arr)) > threshold)[0]
return exceeding_indices
위 두 개의 함수를 활용하여서 축을 분류해보려고 했습니다. 보통 CT 영상들은 미세하게 이동해가며 신체 구조를 촬영하기에 이웃하는 두 frame의 이미지는 spatial domain 상에서 큰 차이를 가지지 않기에 axial으로 촬영된 frame에서 coronal축으로 넘어가는 순간 크게 이미지가 변화할 것이고 따라서 해당 부분의 인덱스를 기준으로 축을 구분하면 될거라고 생각했습니다.
실제로 위 그림과 같이 frame들이 잘 정렬되어 있는 환자의 이미지들에서는 sagittal에서 axial, axial에서 coronal로 넘어가는 부분에서 차이가 크게 발생하는 것을 확인할 수 있었습니다. 하지만 위 방식만으로는 완벽하게 이미지들을 분류하지 못했습니다. 이유는 크게 두가지였습니다.
첫번째는 당연히 예외 때문이었습니다. 중간중간에 촬영을 새로한 경우나 coronal frame들 사이에 axial frame이 섞여있는 경우 frame간 정렬이 제대로 되어 있지 않은 경우, 중간에 ambient frame이 있는 경우 등 상상을 뛰어 넘는 다양한 예외가 있었고 '그 정도 예외는 직접 정리 하지 뭐' 했던 생각을 접게 됐습니다. 나아가 각 환자마다 방향의 전환이 일어났다고 판단하는 threshold를 잡는 것도 쉬운 일이 아니였습니다.
결국 이렇게 frame들의 정렬이나 환자의 재촬영과 같은 가정이나 조건이 필요하지 않은 방식이 필요하다고 생각했습니다. CT 이미지의 sequential한 성질을 버려야 했고, 하나의 frame을 독립적으로 분류해야겠다고 생각했고, 결국 가장 쉽게 굴릴 수 있으면서 성능이 보장되는 deep neural network라는 선택지로 회귀하게 되었습니다.
Classification via Deep Neural Network
우선 간단하게 10명의 환자의 데이터를 직접 axial과 coronal로 분류하여 학습 데이터셋을 만들고 다른 3명의 데이터를 axial과 coronal로 직접 분류하여 validation set을 만들었습니다. Sagittal의 경우 환자별로 sagittal frame을 가지는 경우도 있고 없는 경우도 있었는데, 가지는 경우는 모두 가장 첫번째 index에 위치하였으며, 세 장 이상의 sagittal frame을 가지는 환자는 없었습니다. 저의 경우 sagittal은 분석의 대상이 아니였고 sagittal이 아닌 다른 방향의 경우에도 일부 데이터가 소실되어도 주변에 유사한 frame이 많이 있으니 괜찮다고 생각하고 일단 과감하게 가장 앞 3~5장의 frame을 날려버렸고 에 굳이 이를 추가적으로 분류할 필요가 없었습니다.
이후 아래의 trainer를 통해 학습을 진행하였습니다. 학부생 시절에 짜놓은 코드에서 크게 변한 것이 없어 수정해야 하는 부분이 많은 코드이지만 이 정도 간단한 task를 위해서는 충분한 것 같다고 생각했습니다.
import os
import torch
import numpy as np
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
import sklearn.metrics as metrics
from torch.utils.data import DataLoader
from torchvision.models import resnet50, ResNet50_Weights
import utils
class SupervisedTrainer(object):
def __init__(self):
self.save_path = f'./checkpoints/'
os.makedirs(self.save_path, exist_ok=True)
self.epoch = 0
self.epochs = 2
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1) # Imagenet1k로 pretrained 된 ResNet50
self.model.fc = nn.Linear(2048, 2) # 분류하고자 하는 class가 Axial/Coronal 2개이므로 head의 형태를 바꿔줌
self.model.to(self.device)
self.criterion = nn.CrossEntropyLoss()
self.optimizer = optim.AdamW(self.model.parameters(), lr=0.0001)
self.trainset = utils.load_dataset(is_train=True)
self.testset = utils.load_dataset(is_train=False)
self.train_loader = DataLoader(self.trainset, batch_size=128, shuffle=True , drop_last=True )
self.test_loader = DataLoader(self.testset , batch_size=128, shuffle=False, drop_last=False)
self.train_loss = []
self.test_loss = []
self.accs = []
total_params = sum(p.numel() for p in self.model.parameters())
print(f'model name:resnet\ndataset:axial-coronal\ndevice:{self.device}\nTotal parameter:{total_params:,}')
def train(self):
loss_trace = []
self.model.train()
for _, batch in enumerate(self.train_loader):
self.optimizer.zero_grad()
X, Y = batch
X, Y = X.to(self.device), Y.to(self.device)
pred = self.model(X)
loss = self.criterion(pred, Y)
loss_trace.append(loss.cpu().detach().numpy())
loss.backward()
self.optimizer.step()
self.train_loss.append(np.average(loss_trace))
@torch.no_grad()
def test(self):
self.model.eval()
loss_trace = []
result_pred, result_anno = [], []
for idx, batch in enumerate(self.test_loader):
X, Y = batch
X, Y = X.to(self.device), Y.to(self.device)
pred = self.model(X)
loss = self.criterion(pred, Y)
loss_trace.append(loss.cpu().detach().numpy())
pred_np = pred.to('cpu').detach().numpy()
pred_np = np.argmax(pred_np, axis=1).squeeze()
Y_np = Y.to('cpu').detach().numpy().reshape(-1, 1).squeeze()
result_pred = np.hstack((result_pred, pred_np))
result_anno = np.hstack((result_anno, Y_np))
acc = metrics.accuracy_score(y_true=result_anno, y_pred=result_pred)
self.test_loss.append(np.average(loss_trace))
self.accs.append(acc)
def save_model(self):
torch.save(self.model.state_dict(), f'{self.save_path}/{self.epoch+1}.pth')
def print_train_info(self):
print(f'({self.epoch+1:03}/{self.epochs}) Train Loss:{self.train_loss[self.epoch]:>6.4f} Test Loss:{self.test_loss[self.epoch]:>6.4f} Test Accuracy:{self.accs[self.epoch]*100:>5.2f}%')
제가 다루는 CT이미지의 경우 512*512의 resolution을 가지지만 ImageNet으로 pretrained된 model을 사용하기 위해서 아래와 같이 dataset을 호출할 때 이미지의 resolution을 224*224로 resize해주었으며, 하나의 채널값을 RGB 채널로 broadcast하도록 하여 pretrained된 모델에 집어넣을 수 있도록 하였습니다.
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
def load_dataset(is_train):
transform = transforms.Compose([transforms.Resize(224),
#transforms.RandomCrop(32, padding=2,padding_mode='reflect'),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),])
if is_train:
return ImageFolder('./train', transform=transform)
else:
return ImageFolder('./test', transform=transform)
많은 epoch을 요구하지도 않고 2번만에 test loss가 0.0000 아래로 떨어졌고 accuracy 또한 100%를 찍는 것을 확인할 수 있습니다. 애초에 Axial frame과 coronal frame이 워낙 두드러지는 차이를 가지고 있고, 이미 ImageNet이라는 큰 데이터셋으로 학습된 backbone을 가지고 있기에 저정도의 finetune으로도 충분히 좋은 성능을 낼 수 있었던 것 같습니다.
최종적으로 파일을 정리하는 코드입니다. 혹시나 원본을 덮어쓰거나 하는 실수를 하지 않기 위해서 (여러번 덮어쓰고 나서 그러지 않기로 다짐했습니다...) 파일 자체를 이동하지 않고 복사하는 방식을 선택했습니다.
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(2048, 2)
model.load_state_dict(torch.load('./checkpoints/2.pth'))
model.eval()
model.to('cuda')
import os
from PIL import Image
from tqdm import tqdm
from shutil import copyfile
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--s', type=int, default=0)
parser.add_argument('--e', type=int, default=2136)
import torchvision.transforms as transforms
transform = transforms.Compose([transforms.Resize(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),])
class_dict = {0:'axial', 1:'coronal'}
softmax = nn.Softmax(dim=-1)
users = [user for user in os.listdir('./dataset') if os.path.isdir(f'./dataset/{user}')]
with torch.no_grad():
for subjectID in tqdm(users[args.s:args.e]):
path = f'./dataset/{subjectID}/image/{subjectID}'
dest = {0:f'./dataset/{subjectID}/image/axial',
1:f'./dataset/{subjectID}/image/coronal'}
for i in range(2):
os.makedirs(dest[i], exist_ok=True)
files = [file for file in os.listdir(path) if file.endswith('.tif')]
files.sort()
files = files[5:] #remove sagittal image this may remove non-sagittal frame but let's omit this
imgs = torch.zeros((len(files),3,224,224)) # some imgs have alpha channel
for idx, file in enumerate(files):
imgs[idx] = transform(Image.open(f'{path}/{file}'))
imgs = imgs.to('cuda')
logit = model(imgs)
pred = torch.argmax(logit, dim=-1)
prob = torch.max(softmax(logit), dim=-1)
prob = prob[0].cpu().numpy()
pred = pred.detach().cpu().numpy()
for idx, p in enumerate(pred):
# print(f'{files[idx]}: {class_dict[p]:>10} {prob[idx]*100:.2f}%')
if prob[idx] > 0.75:
copyfile(f'{path}/{files[idx]}', f'{dest[p]}/{files[idx]}')
모든 유저의 image를 불러와 한 batch에 담아 학습시킨 resnet50에 무식하게 때려박고 있는 모습입니다... 마지막에 모델이 뱉은 logit을 softmax에 통과시켜 확률값이 0.75를 넘기는 경우에 한정하여 파일을 저장하고 있는데, 이는 확률이 낮게 나오는 경우 앞서 이야기 한 많은 예외들을 해결해주기 위한 것입니다. 사실 여러 모델을 학습하여 ensemble 하거나 bayesian inference를 통해 uncertainty를 측정하여 이러한 예외들을 처리할 수 있겠지만, 그건 너무 귀찮고 저의 목적을 위해서는 이 정도만 해도 충분한 것 같습니다. 닭 잡는데 소 잡는 칼을 쓸 필요는 없으니까요... (ResNet50 자체가 소 잡는 칼 같긴 함)
번외로 왜 굳이 resnet50을 사용했냐? 특별한 이유가 있는 것은 아니고 torchvision.models document에서 제일 먼저 눈에 들어와서 사용했습니다... 어떤 pretrained model을 사용했더라도 충분히 잘 분류했을 것이라고 생각이 듭니다. 굳이 parameter 수도 많은 resnet50을 써서 데이터 정리가 느려졌지만 그 틈을 타 오늘 한 뻘짓을 정리했으니 됐다고 생각합니다. 3090 GPU 2대에 1000명 정도 나누어 돌리니 2~3시간 정도 걸리는 것 같습니다...
코드는 https://github.com/leekichang/CT-axis-classification에 push 해두었으니 혹시 비슷한 문제가 있으시면 가져다 쓰시면 됩니다... Issue나 PR은 받을 여유가 없지 않을까... 아마 특별하게 관리하지는 않을 레포가 될 것 같습니다.
GitHub - leekichang/CT-axis-classification
Contribute to leekichang/CT-axis-classification development by creating an account on GitHub.
github.com
끗!
'삽질 노트' 카테고리의 다른 글
텐서보드 로그 파일 .csv로 정리하기 (Tensorboard log to csv) (0) | 2023.08.22 |
---|---|
TCP, HTTP, etc 터널링 with ngrok (0) | 2023.06.23 |
박사학위과정 전문연구요원 준비 (1) - 한국사능력검정시험 준비 후기 (1) | 2023.04.17 |
윈도우 caps lock -> 한/영 키로 만들기 (1) | 2023.01.04 |
윈도우 Terminal 꾸미기 (Windows terminal theme 설정, Oh-My-Posh, 배경이미지 넣기) (1) | 2023.01.04 |