PyTorchで高階偏微分係数

1pt   2018-11-09 14:41
IT技術情報局

1f606-1.png はじめに

この記事ではPyTorchを使って高階の偏微分を求める方法を説明しています。

高階ではない微分係数

例えば、関数$f(x)=x^3$について、$left. frac{df(x)}{dx} right|_{x=4}$を求めたいとします。

>>> import torch >>> x = torch.tensor([4.], requires_grad=True) >>> f = x ** 3 >>> f.backward() >>> print(x.grad) tensor([48.])

backward()を呼べばいいだけで、これは皆さん普通に実行していることです。
(以下、import torchは省略します。)

1つの変数だけを扱う場合

2階の微分係数

先ほどの関数$f(x)=x^3$について、$x=4$における2階の微分係数$left. frac{d^2f(x)}{dx^2} right|_{x=4}$を求めたいとします。

>>> x = torch.tensor([4.], requires_grad=True) >>> f = x ** 3 >>> g = torch.autograd.grad(f, x, create_graph=True) >>> g (tensor([48.], grad_fn=<ThMulBackward>),) >>> g.backward() Traceback (most recent call last): File "<stdin>", line 1, in <module> AttributeError: 'tuple' object has no attribute 'backward' >>> g[0].backward() >>> x.grad tensor([24.])

torch.autograd.grad()を呼ぶときにcreate_graph=Trueとしているのがポイントです。
こうすると、微分係数(上の場合は48)だけでなく、
$f$の$x$に関する微分について計算グラフを作って、それも返してくれます。
すると、その計算グラフを使うことで、2階の微分係数が計算できるようになります。

上の例で、gと入力したとき

(tensor([48.], grad_fn=<ThMulBackward>),)

と表示されています。
48は1階の微分係数です。create_graph=Trueと設定しなかったら、これしか返してきません。実際、

>>> x = torch.tensor([4.], requires_grad=True) >>> f = x ** 3 >>> g = torch.autograd.grad(f, x) >>> g (tensor([48.]),)

となります。先ほどの例にあるgrad_fn=<ThMulBackward>というあやしい(?)表示が、
計算グラフも一緒に作ってくれている証拠です。

また、わざとエラーを出してみました1f606.png?resize=20%2C20&ssl=1
今は、$x$というひとつの変数についてしか微分していないのですが、
torch.autograd.grad()では、複数の変数の各々で偏微分する状況を扱うのが基本です。
そのため、返って来るのがいつでもtupleなんですね。
ひとつの変数しか考慮していなくても、要素数がひとつだけのtupleが帰ってきます。
そのため、

>>> g[0].backward()

と、tupleの最初の要素を使う旨、添え字0で指定しないと、エラーになります。

3階の微分係数

先ほどの関数$f(x)=x^3$について、$x=4$における3階の微分係数$left. frac{d^3f(x)}{dx^3} right|_{x=4}$を求めたいとします。

>>> x = torch.tensor([4.], requires_grad=True) >>> f = x ** 3 >>> g = torch.autograd.grad(f, x, create_graph=True) >>> h = torch.autograd.grad(g, x, create_graph=True) >>> h (tensor([24.], grad_fn=<ThMulBackward>),) >>> h[0].backward() >>> x.grad tensor([6.])

最初にtorch.autograd.grad()を呼んで戻ってきた計算グラフをそのまま、
次のtorch.autograd.grad()の呼び出しで使っています。
こうすると、この2回目の呼び出しは、2階の微分の計算グラフを返してきます。
それについてbackward()すれば、3階の微分係数を計算できます。
答えは6.と表示されています。これは$x^3$を3回$x$で微分すると定数$6$になるためです。

2個以上の変数を扱う場合

2階の偏微分係数

関数$f(x,y)=(x+2w)^3$について、まず、$left.frac{partial^2f(x,y)}{partial x^2}right|_{x=4,y=3}$と

$left. frac{partial^2 f(x,y)}{partial x partial y} right|_{x=4, y=3}$を求めてみます。

>>> x = torch.tensor([4.], requires_grad=True) >>> y = torch.tensor([3.], requires_grad=True) >>> f = (x + 2 * y) ** 3 >>> g = torch.autograd.grad(f, x, create_graph=True) >>> g (tensor([300.], grad_fn=<ThMulBackward>),) >>> g[0].backward() >>> x.grad tensor([60.]) >>> y.grad tensor([120.])

$left.frac{partial^2f(x,y)}{partial x^2}right|_{x=4,y=3} = 60$、そして

$left. frac{partial^2 f(x,y)}{partial x partial y} right|_{x=4, y=3} = 120$です。

では、次に、$left.frac{partial^2f(x,y)}{partial y partial x}right|_{x=4,y=3}$と

$left. frac{partial^2 f(x,y)}{partial y^2} right|_{x=4, y=3}$を求めてみます。

>>> x = torch.tensor([4.], requires_grad=True) >>> y = torch.tensor([3.], requires_grad=True) >>> f = (x + 2 * y) ** 3 >>> g = torch.autograd.grad(f, y, create_graph=True) >>> g (tensor([600.], grad_fn=<MulBackward>),) >>> g[0].backward() >>> x.grad tensor([120.]) >>> y.grad tensor([240.])

当然ですが、$left. frac{partial^2 f(x,y)}{partial x partial y} right|_{x=4, y=3}$と

$left.frac{partial^2f(x,y)}{partial y partial x}right|_{x=4,y=3}$は、同じ値120になります。

上のふたつの作業をまとめて実行しようとすると・・・

>>> x = torch.tensor([4.], requires_grad=True) >>> y = torch.tensor([3.], requires_grad=True) >>> f = (x + 2 * y) ** 3 >>> g = torch.autograd.grad(f, (x, y), create_graph=True) >>> g (tensor([300.], grad_fn=<ThMulBackward>), tensor([600.], grad_fn=<MulBackward>)) >>> g[0].backward() >>> x.grad tensor([60.]) >>> y.grad tensor([120.]) >>> g[1].backward() Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/data10/masada/anaconda3/lib/python3.6/site-packages/torch/tensor.py", line 93, in backward torch.autograd.backward(self, gradient, retain_graph, create_graph) File "/data10/masada/anaconda3/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward allow_unreachable=True) # allow_unreachable flag RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time. >>> g = torch.autograd.grad(f, (x, y), create_graph=True) >>> g[1].backward() >>> x.grad tensor([180.]) >>> y.grad tensor([360.])

と、エラーが出てしまいました。
それに関して偏微分を求めたい2つの変数を、(x,y)とtupleにしてtorch.autograd.grad()に渡すと、

>>> g (tensor([300.], grad_fn=<ThMulBackward>), tensor([600.], grad_fn=<MulBackward>))

という箇所にあるように、それぞれの変数に関する1階の微分係数をtupleにして返してくれます。
しかし、g[0].backward()のあとに続けてg[1].backward()を実行したため、エラーになりました。
上の例では

>>> g = torch.autograd.grad(f, (x, y), create_graph=True)

と、もう一度torch.autograd.grad()を呼んでいます。

しかし、こういう使い方はしないでしょう。

>>> x = torch.tensor([4.], requires_grad=True) >>> y = torch.tensor([3.], requires_grad=True) >>> f = (x + 2 * y) ** 3 >>> g = torch.autograd.grad(f, (x, y), create_graph=True) >>> h = g[0] + g[1] >>> h.backward() >>> x.grad tensor([180.]) >>> y.grad tensor([360.])

のように、g[0]とg[1]を含む計算グラフを作って、それについてbackward()を呼べば、
特に問題はありません。普通はこういう使い方をします。

torch.autograd.grad()は複数の変数に関して偏微分をとった結果、つまり勾配をtupleとして返してきますが、
このtupleの要素を組み合わせて作った計算グラフについてまた微分をとる、という使い方が普通です。

例えば、WGANで使う勾配のL2ノルムの場合も、
torch.autograd.grad()が返してきたtupleの要素を組み合わせて計算グラフを作っていることになります。

おわりに

PyTorchの自動微分、便利です。

Source: python tag

   ITアンテナトップページへ
情報処理/ITの話題が沢山。