Pytorch v0.4のコードをv0.3で動かす際には.dataに注意[Pytorch]
Pytorchのコードを見ているとミニバッチごとのlossやaccuracyを計算する際、.dataを用いて値を取り出されることが頻繁にある。
よくある例:
for i in range(0, 2 * POS_NEG_SAMPLES, BATCH_SIZE):inp, target = dis_inp, dis_targetdis_opt.zero_grad()out = discriminator.batchClassify(inp)loss_fn = nn.BCELoss()loss = loss_fn(out, target)loss.backward()dis_opt.step()total_loss += loss.data # .dataを使ってlossから値を取り出しtotal_lossに蓄積total_acc += torch.sum((out>0.5)==(target>0.5)).data # .dataを使ってaccuracyを取り出しtotal_accに蓄積.dataを用いて値を取り出す上のコードは、pytorch v0.4で動かすならエラーがでない。
しかし、上のコードをpytorch v0.3で動かすと下のエラーが発生する。
# pytorch v0.3で動かした際のエラーTypeError: div_ received an invalid combination of arguments - got (float), but expected one of:* (int value)didn't match because some of the arguments have invalid types: (!float!)* (torch.cuda.ByteTensor other)didn't match because some of the arguments have invalid types: (!float!)なぜv0.4で動いていたコードがv0.3で動かなくなってしまうのか?
その原因は、pytorch v0.3でVariableの型の行列から値を取り出す際に.dataを使用するとデータの型がVariableからTensorに変わってしまうからである。
簡単な例で問題をみていこう。
# Pytorch v0.4で動かすと問題ないのだが、v0.3で動かすとエラーが発生するコード# 本コードではv0.3で動かしたことを想定import torchfrom torch.autograd import Variablex = Variable(torch.Tensor([1,2,3])) # ListからTensorに変換し、更にTensorをVariableに変換y = Variable(torch.Tensor([4,5,6])).data # .dataを使ってVariableから値を取り出す。# 返り値を表示。xにはVariableが格納されていることがわかる。x>>Variable containing:123
# 返り値を表示。yはVariableではなくTensorが格納されてしまった。y>>456
# TensorとVariableを加算するとエラーが発生x+y>>Traceback (most recent call last):File "", line 1, inx+yRuntimeError: add() received an invalid combination of arguments - got (torch.FloatTensor), but expected one of:* (float other, float alpha)* (Variable other, float alpha)上の例のようにpytorch v0.3では、.dataで値を取り出すと型違いによるエラーが発生してしまう為、v0.4で動いていたコードが動かなくなる。
v0.4ではVariableとTensorが統合された為、型違いによるエラーが発生しないようだ。
v0.4公式がまとめた変更点(英語)
v0.4の変更点が日本語でまとめられた記事
ちなみに、v0.4の公式ドキュメントでは.dataを使うことはunsafeだと言っており、代わりに.detachを使うことが推奨されている。
What about .data ? のセクションに .dataを使うことの危険性が述べられている。
.dataで値を取り出した場合、xに対する変更がautogradで追跡できない為、危険視されているようだ。
最後に.detachを使用して書いた、v0.3とv0.4ともにエラーが発生しないコードを載せる。
# v0.3とv0.4ともにエラーが発生しないコードimport torchfrom torch.autograd import Variablex = Variable(torch.Tensor([1,2,3])) # ListからTensorに変換し、更にTensorをVariableに変換z = Variable(torch.Tensor([7,8,9])).detach()z>>Variable containing:789
# v0.3とv0.4ともにエラー無しで加算が行えるx+z>>Variable containing:81012追記
あまり良い方法ではないのだが、lossやaccuracyなど1x1の値を取り出す場合には、.data[0]を使う手もある。
.data[0]は、下のIrfan_Buluさんの回答を見て知った。