본문 바로가기

Data Science/Deep Learning

EfficientNet Pytorch 버전 사용

728x90

- feature extraction block의 weight가 변하지 않도록 freeze했을 때

def get_model(model_name='efficientnet-b0'):
    model = EfficientNet.from_pretrained(model_name)
    # In case you want to freeze the feature extraction blocks from EfficientNet, you need to add these two lines
    for param in model.parameters():
        param.requires_grad = False

    del model._fc
    # # # use the same head as the baseline notebook.
    model._fc = nn.Linear(1280, NUM_CLASSES)
    
    return model

마지막 Linear Layer를 제외한 모든 weight가 변하지 않도록 설정됨

반응형