In [6]:
! curl -L -o dataset.zip https://www.kaggle.com/api/v1/datasets/download/fuyadhasanbhoyan/knee-osteoarthritis-classification-224224
! unzip dataset.zip
A saída de streaming foi truncada nas últimas 5000 linhas.
inflating: Knee Osteoarthritis Classification/test/Osteoporosis/Osteoporosis 720.jpg
inflating: Knee Osteoarthritis Classification/test/Osteoporosis/Osteoporosis 721.jpg
inflating: Knee Osteoarthritis Classification/test/Osteoporosis/Osteoporosis 721_aug_0.jpeg
inflating: Knee Osteoarthritis Classification/test/Osteoporosis/Osteoporosis 722.jpg
inflating: Knee Osteoarthritis Classification/test/Osteoporosis/Osteoporosis 722_aug_0.jpeg
inflating: Knee Osteoarthritis Classification/test/Osteoporosis/Osteoporosis 723.jpg
inflating: Knee Osteoarthritis Classification/val/Osteoporosis/Osteoporosis 712_aug_0.jpeg
inflating: Knee Osteoarthritis Classification/val/Osteoporosis/Osteoporosis 713.jpg
inflating: Knee Osteoarthritis Classification/val/Osteoporosis/Osteoporosis 713_aug_0.jpeg
inflating: Knee Osteoarthritis Classification/val/Osteoporosis/Osteoporosis 714.jpg
In [7]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import imageio as io
from PIL import Image
import os
class_labels = ['normal', 'osteopenia', 'osteoporosis']
class KneeOsteoDataset(Dataset):
def __init__(self, path) -> None:
dirs = os.listdir(os.path.join(path))
data = []
for dir in dirs:
for file in os.listdir(os.path.join(path, dir)):
data.append({ 'cls': dir.lower(), 'path': os.path.join(path, dir, file) })
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((224, 224)),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
item = self.data[index]
try:
image = Image.open(item['path']).convert('RGB')
image = self.transform(image)
label = class_labels.index(item['cls'])
return (label, image)
except:
print(f'Error loading image {item["path"]}')
return self.__getitem__((index + 1) % len(self.data))
In [8]:
train_dataset = KneeOsteoDataset('./Knee Osteoarthritis Classification/train')
val_dataset = KneeOsteoDataset('./Knee Osteoarthritis Classification/val')
test_dataset = KneeOsteoDataset('./Knee Osteoarthritis Classification/test')
In [9]:
train_loader = DataLoader(train_dataset, batch_size=20,
shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=20,
shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=20,
shuffle=True, num_workers=0)
In [10]:
class_labels = ['normal', 'osteopenia', 'osteoporosis']
In [1]:
import torchvision.models as models
from torch import nn
from torch.nn import CrossEntropyLoss, Linear
from torch.optim import Adam
resnet18 = models.resnet18(pretrained=True)
resnet18.fc = nn.Sequential(
nn.Linear(resnet18.fc.in_features, 512),
nn.ReLU(),
nn.Dropout(0.4),
nn.Linear(512, 3)
)
/usr/local/lib/python3.11/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead. warnings.warn( /usr/local/lib/python3.11/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights. warnings.warn(msg) Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth 100%|██████████| 44.7M/44.7M [00:00<00:00, 110MB/s]
In [ ]:
In [12]:
from torch import cuda, max, sum, no_grad
import tqdm
train_accs = []
train_losses = []
val_accs = []
val_losses = []
def train(model, train_loader, val_loader, epochs):
use_cuda = cuda.is_available()
device = "cuda" if cuda.is_available() else "cpu"
model.to(device)
criterion = CrossEntropyLoss()
optimizer = Adam(resnet18.fc.parameters(), lr=3e-4, weight_decay=0.001)
for epoch in range(epochs):
model.train()
running_loss = 0.0
running_corrects = 0
for labels, inputs in tqdm.tqdm(train_loader):
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
_, preds = max(outputs, 1)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)
running_corrects += sum(preds == labels.data)
# Compute average training loss and accuracy for this epoch
train_loss = running_loss / len(train_loader.dataset)
train_acc = running_corrects.float() / len(train_loader.dataset)
train_losses.append(train_loss)
train_accs.append(train_acc)
model.eval()
running_loss = 0.0
running_corrects = 0
with no_grad():
for labels, inputs in val_loader:
inputs = inputs.to(device)
labels = labels.to(device)
outputs = model(inputs)
# Get the predicted class (with the highest score)
_, preds = max(outputs, 1)
# Compute the loss between the predictions and actual labels
loss = criterion(outputs, labels)
# Accumulate the running loss and the number of correct predictions
running_loss += loss.item() * inputs.size(0)
running_corrects += sum(preds == labels.data)
val_loss = running_loss / len(val_loader.dataset)
val_acc = running_corrects.float() / len(val_loader.dataset)
val_losses.append(val_loss)
val_accs.append(val_acc)
# Print the results for the current epoch
print(f'Epoch [{epoch+1}/{epochs}], train loss: {train_loss:.4f}, train acc: {train_acc:.4f}, val loss: {val_loss:.4f}, val acc: {val_acc:.4f}')
In [ ]:
[x.item() for x in train_accs]
Out[ ]:
[]
In [ ]:
train(resnet18, train_loader, val_loader, 100)
11%|█ | 18/171 [00:01<00:13, 11.63it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.48it/s]
Epoch [1/100], train loss: 0.9397, train acc: 0.5451, val loss: 0.8926, val acc: 0.6352
37%|███▋ | 64/171 [00:05<00:09, 11.62it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.52it/s]
Epoch [2/100], train loss: 0.8440, train acc: 0.6133, val loss: 0.8681, val acc: 0.6380
21%|██ | 36/171 [00:03<00:11, 11.44it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.44it/s]
Epoch [3/100], train loss: 0.8287, train acc: 0.6291, val loss: 0.8411, val acc: 0.6435
13%|█▎ | 22/171 [00:01<00:12, 11.83it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.65it/s]
Epoch [4/100], train loss: 0.7813, train acc: 0.6511, val loss: 0.8835, val acc: 0.5657
89%|████████▉ | 152/171 [00:13<00:01, 11.71it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.58it/s]
Epoch [5/100], train loss: 0.7639, train acc: 0.6537, val loss: 0.8210, val acc: 0.6398
85%|████████▌ | 146/171 [00:12<00:02, 11.63it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.61it/s]
Epoch [6/100], train loss: 0.7291, train acc: 0.6771, val loss: 0.8811, val acc: 0.6407
6%|▌ | 10/171 [00:00<00:13, 11.71it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.56it/s]
Epoch [7/100], train loss: 0.7321, train acc: 0.6733, val loss: 0.8254, val acc: 0.6537
53%|█████▎ | 90/171 [00:07<00:06, 11.74it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.49it/s]
Epoch [8/100], train loss: 0.7186, train acc: 0.6812, val loss: 0.8142, val acc: 0.6722
23%|██▎ | 40/171 [00:03<00:11, 11.56it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.60it/s]
Epoch [9/100], train loss: 0.6928, train acc: 0.6906, val loss: 0.8319, val acc: 0.6537
23%|██▎ | 40/171 [00:03<00:11, 11.21it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.47it/s]
Epoch [10/100], train loss: 0.6924, train acc: 0.7026, val loss: 0.8100, val acc: 0.6519
100%|██████████| 171/171 [00:14<00:00, 11.65it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg Epoch [11/100], train loss: 0.6660, train acc: 0.7076, val loss: 0.8058, val acc: 0.6861
71%|███████▏ | 122/171 [00:10<00:04, 11.94it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.63it/s]
Epoch [12/100], train loss: 0.6690, train acc: 0.7090, val loss: 0.7938, val acc: 0.6546
84%|████████▍ | 144/171 [00:12<00:02, 11.36it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.63it/s]
Epoch [13/100], train loss: 0.6433, train acc: 0.7166, val loss: 0.8485, val acc: 0.6028
41%|████ | 70/171 [00:06<00:09, 10.79it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.61it/s]
Epoch [14/100], train loss: 0.6478, train acc: 0.7114, val loss: 0.7690, val acc: 0.6843
39%|███▉ | 67/171 [00:05<00:08, 11.76it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.47it/s]
Epoch [15/100], train loss: 0.6338, train acc: 0.7181, val loss: 0.7881, val acc: 0.6824
7%|▋ | 12/171 [00:01<00:13, 11.83it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.67it/s]
Epoch [16/100], train loss: 0.6337, train acc: 0.7201, val loss: 0.8350, val acc: 0.6426
2%|▏ | 4/171 [00:00<00:13, 11.93it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.56it/s]
Epoch [17/100], train loss: 0.6240, train acc: 0.7260, val loss: 0.7856, val acc: 0.6444
5%|▍ | 8/171 [00:00<00:13, 11.87it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.60it/s]
Epoch [18/100], train loss: 0.6230, train acc: 0.7245, val loss: 0.7756, val acc: 0.6731
16%|█▋ | 28/171 [00:02<00:12, 11.81it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.61it/s]
Epoch [19/100], train loss: 0.6023, train acc: 0.7368, val loss: 0.7509, val acc: 0.7009
83%|████████▎ | 142/171 [00:12<00:02, 10.72it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.56it/s]
Epoch [20/100], train loss: 0.6041, train acc: 0.7395, val loss: 0.7978, val acc: 0.6657
90%|█████████ | 154/171 [00:13<00:01, 11.86it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.66it/s]
Epoch [21/100], train loss: 0.5907, train acc: 0.7436, val loss: 0.7749, val acc: 0.6954
83%|████████▎ | 142/171 [00:12<00:02, 11.71it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.53it/s]
Epoch [22/100], train loss: 0.5877, train acc: 0.7471, val loss: 0.7930, val acc: 0.6667
40%|███▉ | 68/171 [00:05<00:08, 11.82it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.60it/s]
Epoch [23/100], train loss: 0.5794, train acc: 0.7538, val loss: 0.7794, val acc: 0.6954
94%|█████████▎| 160/171 [00:13<00:00, 11.73it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.59it/s]
Epoch [24/100], train loss: 0.5788, train acc: 0.7488, val loss: 0.8036, val acc: 0.6417
36%|███▋ | 62/171 [00:05<00:09, 11.72it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.64it/s]
Epoch [25/100], train loss: 0.5751, train acc: 0.7477, val loss: 0.8680, val acc: 0.5981
6%|▌ | 10/171 [00:00<00:13, 11.78it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.64it/s]
Epoch [26/100], train loss: 0.5670, train acc: 0.7702, val loss: 0.7877, val acc: 0.6565
88%|████████▊ | 150/171 [00:12<00:01, 10.93it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.52it/s]
Epoch [27/100], train loss: 0.5591, train acc: 0.7635, val loss: 0.7901, val acc: 0.6685
70%|███████ | 120/171 [00:10<00:04, 11.52it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.57it/s]
Epoch [28/100], train loss: 0.5497, train acc: 0.7605, val loss: 0.7523, val acc: 0.6852
32%|███▏ | 54/171 [00:04<00:10, 11.52it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.53it/s]
Epoch [29/100], train loss: 0.5445, train acc: 0.7749, val loss: 0.7720, val acc: 0.6898
6%|▌ | 10/171 [00:00<00:13, 11.86it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.60it/s]
Epoch [30/100], train loss: 0.5318, train acc: 0.7722, val loss: 0.7824, val acc: 0.6639
44%|████▍ | 76/171 [00:06<00:08, 11.49it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.57it/s]
Epoch [31/100], train loss: 0.5063, train acc: 0.7886, val loss: 0.7523, val acc: 0.7019
96%|█████████▌| 164/171 [00:14<00:00, 11.83it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.64it/s]
Epoch [32/100], train loss: 0.5142, train acc: 0.7837, val loss: 0.7407, val acc: 0.6880
92%|█████████▏| 158/171 [00:13<00:01, 11.81it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.64it/s]
Epoch [33/100], train loss: 0.5183, train acc: 0.7790, val loss: 0.7964, val acc: 0.7176
87%|████████▋ | 148/171 [00:12<00:01, 11.73it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.49it/s]
Epoch [34/100], train loss: 0.5014, train acc: 0.7848, val loss: 0.7887, val acc: 0.6963
62%|██████▏ | 106/171 [00:09<00:06, 10.72it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.62it/s]
Epoch [35/100], train loss: 0.5042, train acc: 0.7869, val loss: 0.7620, val acc: 0.7056
99%|█████████▉| 170/171 [00:14<00:00, 11.67it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.61it/s]
Epoch [36/100], train loss: 0.4892, train acc: 0.7913, val loss: 0.7482, val acc: 0.7028
55%|█████▍ | 94/171 [00:08<00:06, 11.74it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.61it/s]
Epoch [37/100], train loss: 0.4964, train acc: 0.7910, val loss: 0.7700, val acc: 0.6898
64%|██████▍ | 110/171 [00:09<00:05, 11.77it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.62it/s]
Epoch [38/100], train loss: 0.4787, train acc: 0.7974, val loss: 0.7347, val acc: 0.7222
92%|█████████▏| 158/171 [00:13<00:01, 11.14it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.57it/s]
Epoch [39/100], train loss: 0.4859, train acc: 0.7924, val loss: 0.7652, val acc: 0.7019
87%|████████▋ | 148/171 [00:12<00:01, 11.82it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.69it/s]
Epoch [40/100], train loss: 0.4772, train acc: 0.7983, val loss: 0.7862, val acc: 0.6907
49%|████▉ | 84/171 [00:07<00:07, 11.65it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.54it/s]
Epoch [41/100], train loss: 0.4737, train acc: 0.8030, val loss: 0.7158, val acc: 0.7259
58%|█████▊ | 100/171 [00:08<00:06, 11.26it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.57it/s]
Epoch [42/100], train loss: 0.4751, train acc: 0.7892, val loss: 0.7744, val acc: 0.7046
16%|█▋ | 28/171 [00:02<00:12, 11.72it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.61it/s]
Epoch [43/100], train loss: 0.4524, train acc: 0.8112, val loss: 0.7758, val acc: 0.7102
67%|██████▋ | 114/171 [00:09<00:04, 11.53it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.51it/s]
Epoch [44/100], train loss: 0.4511, train acc: 0.8097, val loss: 0.8025, val acc: 0.6917
9%|▉ | 16/171 [00:01<00:13, 11.68it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.57it/s]
Epoch [45/100], train loss: 0.4536, train acc: 0.8115, val loss: 0.7699, val acc: 0.7028
37%|███▋ | 64/171 [00:05<00:09, 11.69it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.47it/s]
Epoch [46/100], train loss: 0.4517, train acc: 0.8068, val loss: 0.7903, val acc: 0.7139
73%|███████▎ | 124/171 [00:10<00:03, 11.76it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.59it/s]
Epoch [47/100], train loss: 0.4633, train acc: 0.8039, val loss: 0.8258, val acc: 0.6815
6%|▌ | 10/171 [00:00<00:13, 11.68it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.47it/s]
Epoch [48/100], train loss: 0.4457, train acc: 0.8150, val loss: 0.7984, val acc: 0.6796
43%|████▎ | 74/171 [00:06<00:08, 11.57it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
100%|██████████| 171/171 [00:14<00:00, 11.50it/s]
Epoch [49/100], train loss: 0.4314, train acc: 0.8290, val loss: 0.7295, val acc: 0.7120
40%|███▉ | 68/171 [00:06<00:09, 11.21it/s]
Error loading image ./Knee Osteoarthritis Classification/train/Normal/Normal 280_aug_0.jpeg
44%|████▍ | 76/171 [00:06<00:08, 11.14it/s]
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) <ipython-input-62-c807287d08b0> in <cell line: 0>() ----> 1 train(resnet18, train_loader, val_loader, 100) <ipython-input-61-33725f377b7d> in train(model, train_loader, val_loader, epochs) 35 optimizer.step() 36 ---> 37 running_loss += loss.item() * inputs.size(0) 38 running_corrects += sum(preds == labels.data) 39 KeyboardInterrupt:
In [ ]:
import matplotlib.pyplot as plt
plt.plot([x.item() for x in train_accs], label='train_acc')
plt.plot(train_losses, label='train_losses')
plt.legend()
plt.show()
In [ ]:
val_losses
Out[ ]:
[3, 2]
In [ ]:
import matplotlib.pyplot as plt
plt.plot([x.item() for x in val_accs], label='val_accs')
plt.plot(val_losses, label='val_losses')
plt.legend()
plt.show()
In [38]:
import numpy as np
import matplotlib.pyplot as plt
def predict(model, loader):
use_cuda = cuda.is_available()
device = "cuda" if cuda.is_available() else "cpu"
model.to(device)
running_corrects = 0
with no_grad():
for labels, inputs in loader:
inputs = inputs.to(device)
labels = labels.to(device)
# plot input
print(inputs[0][0].cpu().min())
outputs = model(inputs)
# Get the predicted class (with the highest score)
_, preds = max(outputs, 1)
# Compute the loss between the predictions and actual labels
# loss = criterion(outputs, labels)
# Accumulate the running loss and the number of correct predictions
# running_loss += loss.item() * inputs.size(0)
running_corrects += sum(preds == labels.data)
# val_loss = running_loss / len(val_loader.dataset)
acc = running_corrects.float() / len(loader.dataset)
return acc
In [39]:
predict(resnet18, test_loader)
tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179) tensor(-2.1179)
Out[39]:
tensor(0.7222, device='cuda:0')
In [ ]:
# Save model!
import torch
torch.save(resnet18.state_dict(), 'resnet18_knee.pth')
In [3]:
# Load model
import torch
resnet18.load_state_dict(torch.load('resnet18_knee.pth'))
Out[3]:
<All keys matched successfully>
Gradcam¶
In [4]:
!pip install grad-cam
Collecting grad-cam Downloading grad-cam-1.5.4.tar.gz (7.8 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0.0/7.8 MB ? eta -:--:-- ━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.5/7.8 MB 43.8 MB/s eta 0:00:01 ━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━ 3.6/7.8 MB 45.4 MB/s eta 0:00:01 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━ 6.8/7.8 MB 58.7 MB/s eta 0:00:01 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 7.8/7.8 MB 61.9 MB/s eta 0:00:01 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.8/7.8 MB 46.5 MB/s eta 0:00:00 Installing build dependencies ... done Getting requirements to build wheel ... done Preparing metadata (pyproject.toml) ... done Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from grad-cam) (2.0.2) Requirement already satisfied: Pillow in /usr/local/lib/python3.11/dist-packages (from grad-cam) (11.1.0) Requirement already satisfied: torch>=1.7.1 in /usr/local/lib/python3.11/dist-packages (from grad-cam) (2.6.0+cu124) Requirement already satisfied: torchvision>=0.8.2 in /usr/local/lib/python3.11/dist-packages (from grad-cam) (0.21.0+cu124) Collecting ttach (from grad-cam) Downloading ttach-0.0.3-py3-none-any.whl.metadata (5.2 kB) Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from grad-cam) (4.67.1) Requirement already satisfied: opencv-python in /usr/local/lib/python3.11/dist-packages (from grad-cam) (4.11.0.86) Requirement already satisfied: matplotlib in /usr/local/lib/python3.11/dist-packages (from grad-cam) (3.10.0) Requirement already satisfied: scikit-learn in /usr/local/lib/python3.11/dist-packages (from grad-cam) (1.6.1) Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch>=1.7.1->grad-cam) (3.18.0) Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch>=1.7.1->grad-cam) (4.12.2) Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=1.7.1->grad-cam) (3.4.2) Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=1.7.1->grad-cam) (3.1.6) Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch>=1.7.1->grad-cam) (2025.3.0) Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.7.1->grad-cam) Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB) Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.7.1->grad-cam) Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB) Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.7.1->grad-cam) Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB) Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.7.1->grad-cam) Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB) Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.7.1->grad-cam) Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB) Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=1.7.1->grad-cam) Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB) Collecting nvidia-curand-cu12==10.3.5.147 (from torch>=1.7.1->grad-cam) Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB) Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch>=1.7.1->grad-cam) Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB) Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch>=1.7.1->grad-cam) Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB) Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=1.7.1->grad-cam) (0.6.2) Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=1.7.1->grad-cam) (2.21.5) Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.7.1->grad-cam) (12.4.127) Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch>=1.7.1->grad-cam) Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB) Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=1.7.1->grad-cam) (3.2.0) Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=1.7.1->grad-cam) (1.13.1) Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=1.7.1->grad-cam) (1.3.0) Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->grad-cam) (1.3.1) Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib->grad-cam) (0.12.1) Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib->grad-cam) (4.56.0) Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->grad-cam) (1.4.8) Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib->grad-cam) (24.2) Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->grad-cam) (3.2.1) Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.11/dist-packages (from matplotlib->grad-cam) (2.8.2) Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn->grad-cam) (1.14.1) Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn->grad-cam) (1.4.2) Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn->grad-cam) (3.6.0) Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.7->matplotlib->grad-cam) (1.17.0) Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch>=1.7.1->grad-cam) (3.0.2) Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 363.4/363.4 MB 4.4 MB/s eta 0:00:00 Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.8/13.8 MB 51.7 MB/s eta 0:00:00 Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24.6/24.6 MB 22.2 MB/s eta 0:00:00 Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 883.7/883.7 kB 44.9 MB/s eta 0:00:00 Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 2.8 MB/s eta 0:00:00 Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 211.5/211.5 MB 5.4 MB/s eta 0:00:00 Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.3/56.3 MB 13.0 MB/s eta 0:00:00 Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 127.9/127.9 MB 7.0 MB/s eta 0:00:00 Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 207.5/207.5 MB 6.4 MB/s eta 0:00:00 Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 21.1/21.1 MB 87.3 MB/s eta 0:00:00 Downloading ttach-0.0.3-py3-none-any.whl (9.8 kB) Building wheels for collected packages: grad-cam Building wheel for grad-cam (pyproject.toml) ... done Created wheel for grad-cam: filename=grad_cam-1.5.4-py3-none-any.whl size=39682 sha256=351af1ea8326d231b2367f8d6f666b25e4647935de0f0efe3e24e1af54efdc06 Stored in directory: /root/.cache/pip/wheels/8b/0d/d2/b12bec1ccc028921fb98158042ade2d19dae73925dfc636954 Successfully built grad-cam Installing collected packages: ttach, nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, grad-cam Attempting uninstall: nvidia-nvjitlink-cu12 Found existing installation: nvidia-nvjitlink-cu12 12.5.82 Uninstalling nvidia-nvjitlink-cu12-12.5.82: Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82 Attempting uninstall: nvidia-curand-cu12 Found existing installation: nvidia-curand-cu12 10.3.6.82 Uninstalling nvidia-curand-cu12-10.3.6.82: Successfully uninstalled nvidia-curand-cu12-10.3.6.82 Attempting uninstall: nvidia-cufft-cu12 Found existing installation: nvidia-cufft-cu12 11.2.3.61 Uninstalling nvidia-cufft-cu12-11.2.3.61: Successfully uninstalled nvidia-cufft-cu12-11.2.3.61 Attempting uninstall: nvidia-cuda-runtime-cu12 Found existing installation: nvidia-cuda-runtime-cu12 12.5.82 Uninstalling nvidia-cuda-runtime-cu12-12.5.82: Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82 Attempting uninstall: nvidia-cuda-nvrtc-cu12 Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82 Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82: Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82 Attempting uninstall: nvidia-cuda-cupti-cu12 Found existing installation: nvidia-cuda-cupti-cu12 12.5.82 Uninstalling nvidia-cuda-cupti-cu12-12.5.82: Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82 Attempting uninstall: nvidia-cublas-cu12 Found existing installation: nvidia-cublas-cu12 12.5.3.2 Uninstalling nvidia-cublas-cu12-12.5.3.2: Successfully uninstalled nvidia-cublas-cu12-12.5.3.2 Attempting uninstall: nvidia-cusparse-cu12 Found existing installation: nvidia-cusparse-cu12 12.5.1.3 Uninstalling nvidia-cusparse-cu12-12.5.1.3: Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3 Attempting uninstall: nvidia-cudnn-cu12 Found existing installation: nvidia-cudnn-cu12 9.3.0.75 Uninstalling nvidia-cudnn-cu12-9.3.0.75: Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75 Attempting uninstall: nvidia-cusolver-cu12 Found existing installation: nvidia-cusolver-cu12 11.6.3.83 Uninstalling nvidia-cusolver-cu12-11.6.3.83: Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83 Successfully installed grad-cam-1.5.4 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 ttach-0.0.3
In [63]:
from pytorch_grad_cam import GradCAM
cam = GradCAM(model=resnet18, target_layers=[resnet18.layer4[-1]])
sample = next(iter(test_loader))
cam_out = cam(input_tensor=sample[1])
In [83]:
import matplotlib.pyplot as plt
idx = 10
plt.imshow(cam_out[idx], cmap='jet')
plt.imshow(sample[1][idx][0], cmap='gray', alpha=0.5)
plt.title(class_labels[sample[0][idx]])
Out[83]:
Text(0.5, 1.0, 'osteopenia')