博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
PyTorch学习系列(九)——参数_初始化
阅读量:7223 次
发布时间:2019-06-29

本文共 987 字,大约阅读时间需要 3 分钟。

from:http://blog.csdn.net/VictoriaW/article/details/72872036

之前我学习了

那么如何在pytorch里实现呢。

PyTorch提供了多种参数初始化函数:

  • torch.nn.init.constant(tensor, val)
  • torch.nn.init.normal(tensor, mean=0, std=1)
  • torch.nn.init.xavier_uniform(tensor, gain=1)
  • 等等。详细请参考:

注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。

示例:

self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)init.xavier_uniform(self.conv1.weight)init.constant(self.conv1.bias, 0.1)

 

上面的语句是对网络的某一层参数进行初始化。如何对整个网络的参数进行初始化定制呢?

 
def weights_init(m):    classname=m.__class__.__name__    if classname.find('Conv') != -1:        xavier(m.weight.data)        xavier(m.bias.data)net = Net()net.apply(weights_init) #apply函数会递归地搜索网络内的所有module并把参数表示的函数应用到所有的module上。

不建议访问以下划线为前缀的成员,他们是内部的,如果有改变不会通知用户。更推荐的一种方法是检查某个module是否是某种类型:

def weights_init(m):    if isinstance(m, nn.Conv2d):        xavier(m.weight.data)        xavier(m.bias.data)

 

转载于:https://www.cnblogs.com/lindaxin/p/8037561.html

你可能感兴趣的文章
samba服务
查看>>
Oracle数据库迁移方案
查看>>
linux磁盘及文件系统管理的部分知识一
查看>>
shell daemon init.d/functions 守护进程 lock文件
查看>>
我的友情链接
查看>>
删除exchange2010公用文件夹数据库
查看>>
python3实现excel里面读数据进行排序
查看>>
我的友情链接
查看>>
C# XML 文档注释
查看>>
Asp.Net文件的上传和下载
查看>>
Linux进程管理
查看>>
Spring Boot cache backed guava/caffeine
查看>>
windows下nginx+tomcat集群,实现session复制共享
查看>>
EHCache工具类
查看>>
Spring简单整合ibatis
查看>>
python之web模块学习-- httplib
查看>>
系统服务Windows Management Instrumentation如何修复
查看>>
聊聊前端面试的那些事
查看>>
Linux抓包工具tcpdump详解
查看>>
Go编程笔记(9)
查看>>