Derivatives with a Computation Graph
Last updated
Was this helpful?
Last updated
Was this helpful?
上个视频中 我们看了一个使用计算图来计算函数J的例子 现在 让我们用一个简明的例子 说明如何用计算图来计算函数J 的导数 这是一个计算图 比方说你想计算J对于v的导数
它等于多少呢 也就是说 如果我们把v的值 改变一点点 J的值将会如何变化呢? J被定义为3乘以v 现在v等于11 因此如果我们把v提高一点点到11.001 那么J就从目前的33 提高到33.003 这里我们把v提高0.001 结果J增加了3个0.001 因此J对v的导数等于3 因为J的增量是v增量的3倍 事实上 这和我们之前视频的例子很相似 之前的视频中 f(a)=3a 我们推导得到一个简化的df/da 我们采用不那么严谨的写法 df/da=3 回到我们现在的例子,我们有J=3v 所以 dJ/dv=3 这里J的作用相当于之前例子里的f 而v相当于a 用反向传播这个术语来解释的话 如果你想要计算最终输出变量对于v的导数 而这也是你通常最关心的变量 这就是一步反向传播 我们把这个过程叫做图中的一步反向传播 现在我们来看另一个例子 dJ/da是什么 换句话说 如果我们增大一点a的值 J的值会如何变化呢
让我们仔细看一下这个例子,现在a=5 我们把它增大到5.001 那么对于v的影响,注意到v=a+u,原来是11 现在增加到11.001 正如我们在之前的例子中看到的那样 J现在从33增加到33.003 我们看到如果a增加0.001 J会增加0.003 刚才我说的增大a的意思是 在原来值为5的基础上 再加一个新的值 那么a的变化会在计算图中向右传播 结果J变成了33.003 因此J增量是a增量的3倍 这意味着J对a的导数为3 我们来分解一下这个过程 如果你改变了a v也会随之改变
v改变了 J也会改变 所以当你增大a时的时候J的改变量 也就是说当你将a的值改变一点点的时候
首先 因为a有变化所以v也随之变化 v的值增加多少呢 v的增加量取决于dv/da 然后v的改变使得J的值也改变了 在微积分中这叫做链式法则 a影响v v影响J 然后当你改变a的时候J的改变量等于 改变a时v的改变量乘以 改变v时J的改变量 再强调一下 在微积分中 这叫做链式法则 我们从这个例子中可以看到如果a增加0.001 v也变大了0.001 因此dv/da=1 事实上 如果把之前的式子代入 dv/dJ=3 dv/da=1 乘积是3乘1 我们得到dJ/da正确答案是3 这个小例子展示了如何通过计算dJ/dv 即J关于v的导数 来帮助你计算dJ/da 这是反向传播的另一步
接下来我要介绍另一种符号惯例 当你写反向传播代码的时候 那些你真正关心的 或者你想优化的最终输出变量 在这个例子里 最终输出变量是J 也就是计算图中的最后一个节点 因此你会做许多关于 最终输出变量的导数的计算 即这个FinalOutputVar(最终输出变量)对于其他变量 我们就叫它dvar 你会需要计算做许多关于最终输出变量导数的计算 在这个例子中是J 这会牵涉到许多中间变量 例如a b c u v 当你在程序中实现的时候 你给这些变量取什么名字呢 在Python中 你可以起一个很长的名字 比如dFinalOutputVar/dvar 但这是一个很长的变量名 你可以把这叫做dJdvar 但因为导数都是关于最终输出变量J的 我想引入一种新的记号 当你在代码中计算这个导数的时候 我们就用变量名dvar来代表这个值 所以dvar在你的代码里就代表 最终输出变量 比如J对它的导数 有时候 对于各种中间量的损失 在你代码的计算中 用dv来代表这个值 dv=3 在代码里 你用da代表这个值 da也等于3 通过这个计算图 我们介绍了一部分反向传播的知识 我们将在下一张幻灯片中继续这个例子 让我们换一张干净的图 让我们回顾一下 我们通过反向运算得到dv=3 dv只是一个变量名 它代表的其实是dJ/dv 我们已经计算出da=3 同样da也是dJ/da在代码中的变量名
我们推演了反向传播是如何在这两条边上实现的 现在让我们继续计算导数 现在看u的值 dJ/du是什么 我们来做一个跟之前类似的计算 从u=6开始 如果u从6变成6.001 那么v 原来是11 变成了11.001 J原来是33 变成了33.003 因此J增加了3倍u增加的量 关于u的分析与对a的分析非常相像 用dJ/dv乘以dv/du可以算出来 这一项dJ/dv我们已经算出来是3 这一项dv/du可以算出是1 所以我们又完成了一步反向传播 我们得出了du也等于3的结论 当然du指的是dJ/du 我们再来仔细的计算最后一个例子 dJ/db是什么 想一想如果你能改变b的值 你想通过改变b的值来最小化 或者最大化J的值 那么dJ/db这个导数 或者说当你稍微改变b的值 函数J的斜率是多少呢
使用链式法则来计算 dJ/db可以写成两项相乘 dJ/du乘上du/db 原因是如果你稍稍改变b的值 b从3变成3.001 b首先改变u 那么它会对u产生多大影响呢 u被定义为b乘c u一开始是6 当b=3时 u变成6.002 因为在我们的例子中c=2 这告诉我们du/db=2 当你把b增加0.001时 u增加两倍也就是0.002 因此du/db=2 现在我们知道u的变化是b的变化的2倍 那么dJ/du是什么 我们已经算出了这等于3 因此把这两项乘起来我们得到dJ/db=6 重复一下这个推导过程的第二部分 我们想知道当u增加了0.002的时候 J怎样变化 dJ/du=3告诉我们 当u增加0.002时 J应该增加这个数值的3倍 因此J应该增加0.006 这是因为dJ/du=3 如果举例计算一下 你会发现如果b变成3.001 那么u变成6.002 v变成11.002 这是a+u 也就是5+u 然后J等于3倍的v 结果就是33.006 以上就是你推导出dJ/db=6的过程 当我们反向传播的时候 这里是db=6 db是python中dJ/db的变量名 最后一个例子我就不详细说了 但如果你也算出dJ 这一项是dJ/du乘du 结果是9 也就是3乘3 这个例子我就不细讲了 通过最后一步 我们可以算出dc=9
这个视频 这个例子里最重要的东西是 当在计算导数的时候 最有效率的方式 是按红色箭头方向从右往左 特别的 我们先计算对v的导数 计算得到的结果 对于计算J对a的导数和J对u的导数很有用 然后J对u的导数 这一项和这一项 对于计算J对b的导数 和J对c的导数都有用 这就是计算图以及前向 或者说从左到右 计算代价函数 比如你想优化的J 以及如何反向或从右到左计算导数 如果你对微积分或者链式法则不熟悉 我知道有些细节可能过得很快 如果你没有跟上这些细节 别担心 这里这个 我们会在讲解逻辑回归时 再次复习这个视频中的内容 并且展示你要怎样实现一个计算图 在逻辑回归模型中计算导数 GTC字幕组 翻译