Programming/Python

[Pytorch] 전이학습

콩순이컴퓨터 2023. 2. 14. 20:04
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

alexnet = models.alexnet().to(device)
resnet18 = models.resnet18().to(device)
vgg16 = models.vgg16().to(device)
densenet = models.densenet161().to(device)
inception = models.inception_v3().to(device)
googlenet = models.googlenet().to(device)
shufflenet = models.shufflenet_v2_x1_0().to(device)
mobilenet_v2 = models.mobilenet_v2().to(device)
resnext50_32x4d = models.resnext50_32x4d().to(device)
wide_resnet50_2 = models.wide_resnet50_2().to(device)
mnasnet = models.mnasnet1_0().to(device)

 

 

📚 사전 학습 모델

Resnet18 불러오기

model = torchvision.models.resnet18(pretrained=True)

# 필터 사이즈 7x7 -> 3x3
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)

# 마지막 출력 노드가 1000이기 때문에 데이터에 맞춰 10으로 변경
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)
model = model.to(device)

 

 


📚 모델 프리징

모델 불러오기

# 모델을 불러온 후, output layer의 노드를 클래스 수와 맞춘다. 
model = torchvision.models.alexnet(pretrained=True)
num_ftrs = model.classifier[6].in_features # 4096
model.classifier[6] = nn.Linear(num_ftrs,10)
model = model.to(device)

 

 

 

 

모델 파라미터명 확인하기

for i, (name, param) in enumerate(model.named_parameters()) :
    print(i, name)

 

 

 

 

 

변수 프리징하기

# requires_grad를 False로 하여 학습 시 업데이트가 되지 않게 한다.
for i, (name, parma) in enumerate(model.named_parameters()) :
    param.requires_grad = False
    # 합성곱 층에 대한 가중치와 편향(9번까지만)만 프리징 되면 for문을 멈춘다.
    if i == 9 :
        print('end')
        break

 

 

 

 

 

Requires_grad 확인하기

f_list = [0, 3, 6, 8, 10]
c_list = [1, 4, 6]
for i in f_list :
    print(model.features[i].weight.requires_grad)
    print(model.features[i].bias.requires_grad)
    
for j in c_list : 
    print(model.classifier[j].weight.requires_grad)
    print(model.classifier[j].bias.requires_grad)