기계는 거짓말하지 않는다

Pytorch RuntimeError: Input type (~) and weight type (~) should be the same 본문

AI

Pytorch RuntimeError: Input type (~) and weight type (~) should be the same

KillinTime 2023. 5. 13. 00:10

Pytorch 모델 사용 시, 입력 텐서와 가중치 텐서의 데이터 유형이 서로 일치하지 않을 때 발생할 수 있는 오류이다.

아래는 오류 내용들이다.

 

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same ~

 

to() method를 이용해 같은 데이터 유형으로 만들어주면 된다.

# 예시 1
input = input.to(torch.float32)

# 예시 2
# 모델 가중치, 연산을 반 정밀도(half precision)로 변환할 때
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.half().to(device)
input = input.half().to(device)
Comments