기계는 거짓말하지 않는다

torch load_state_dict model key 변경 본문

AI

torch load_state_dict model key 변경

KillinTime 2023. 2. 3. 23:39

torch의 load를 이용하여 model weight를 불러올 때,

module. model 등이 key 값에 더 붙어있거나, key 이름이 다른 경우 변경하여 가지고 올 수 있다.

단, model 구조는 같아야 한다.

state_dict = checkpoint[state_key]

new_state_dict = {}

# load된 model의 key에 model. 이 붙어있을 경우 제거
for k, v in state_dict.items():
    if "model." in k:
        name = k[6:]
        new_state_dict[name] = v

print(new_state_dict.keys())

if len(new_state_dict.keys()) == 0:
    model.load_state_dict(state_dict)
else:
    model.load_state_dict(new_state_dict)

model 구조가 다르고 존재하는 key 값만 불러오고 싶다면

load_state_dict 매개변수의 strict=False를 이용한다.

model.load_state_dict(new_state_dict, strict=False)
Comments