Programming/Python
[Pytorch] 전이학습
콩순이컴퓨터
2023. 2. 14. 20:04
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
alexnet = models.alexnet().to(device)
resnet18 = models.resnet18().to(device)
vgg16 = models.vgg16().to(device)
densenet = models.densenet161().to(device)
inception = models.inception_v3().to(device)
googlenet = models.googlenet().to(device)
shufflenet = models.shufflenet_v2_x1_0().to(device)
mobilenet_v2 = models.mobilenet_v2().to(device)
resnext50_32x4d = models.resnext50_32x4d().to(device)
wide_resnet50_2 = models.wide_resnet50_2().to(device)
mnasnet = models.mnasnet1_0().to(device)
📚 사전 학습 모델
Resnet18 불러오기
model = torchvision.models.resnet18(pretrained=True)
# 필터 사이즈 7x7 -> 3x3
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
# 마지막 출력 노드가 1000이기 때문에 데이터에 맞춰 10으로 변경
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)
model = model.to(device)
📚 모델 프리징
모델 불러오기
# 모델을 불러온 후, output layer의 노드를 클래스 수와 맞춘다.
model = torchvision.models.alexnet(pretrained=True)
num_ftrs = model.classifier[6].in_features # 4096
model.classifier[6] = nn.Linear(num_ftrs,10)
model = model.to(device)
모델 파라미터명 확인하기
for i, (name, param) in enumerate(model.named_parameters()) :
print(i, name)
변수 프리징하기
# requires_grad를 False로 하여 학습 시 업데이트가 되지 않게 한다.
for i, (name, parma) in enumerate(model.named_parameters()) :
param.requires_grad = False
# 합성곱 층에 대한 가중치와 편향(9번까지만)만 프리징 되면 for문을 멈춘다.
if i == 9 :
print('end')
break
Requires_grad 확인하기
f_list = [0, 3, 6, 8, 10]
c_list = [1, 4, 6]
for i in f_list :
print(model.features[i].weight.requires_grad)
print(model.features[i].bias.requires_grad)
for j in c_list :
print(model.classifier[j].weight.requires_grad)
print(model.classifier[j].bias.requires_grad)