Torch - nn 模块学习

Torch 模块学习

1. Serialization

Torch 提供了 4 种序列化/反序列化 Lua/Torch objects 的方法.

  • torch.save(filename, object [, format, referenced])

    将 object 写入 filename 文件.

    format 可以是 ascii 或 binary(默认).

    用例:

    -- arbitrary object:
    obj = {
     mat = torch.randn(10,10),
     name = '10',
     test = {
        entry = 1
     }
    }
    
    -- save to disk:
    torch.save('test.dat', obj)
  • [object] torch.load(filename [, format, referenced])

    从 filename 文件读入 objects.

    用例:

    -- given serialized object from section above, reload:
    obj = torch.load('test.dat')
    
    print(obj)
    -- will print:
    -- {[mat]  = DoubleTensor - size: 10x10
    --  [name] = string : "10"
    --  [test] = table - size: 0}
  • [str] torch.serialize(object)

    将 object 序列化为 string.

    用例:

    -- arbitrary object:
    obj = {
     mat = torch.randn(10,10),
     name = '10',
     test = {
        entry = 1
     }
    }
    
    -- serialize:
    str = torch.serialize(obj)
  • [object] torch.deserialize(str)

    从 string 反序列化 object

    用例:

    -- given serialized object from section above, deserialize:
    obj = torch.deserialize(str)
    
    print(obj)
    -- will print:
    -- {[mat]  = DoubleTensor - size: 10x10
    --  [name] = string : "10"
    --  [test] = table - size: 0}

2. Module

Module 是抽象类,定义了训练神经网络所必要的基础方法,通过成员函数来设计网络结构.

主要包含变量的两种状态:outputgradInput.

  • [output] forward(input)

    根据输入 input,计算对应的输出 output. 一般 input 和 output 是 Tensors.

    不推荐修改 forward 函数. 需要实现 updateOutput(input) 函数.

  • [gradInput] backward(input, gradOutput)

    根据输入 input,计算关于 input 的梯度操作;需要是在 forward 之后调用;网络优化;

    一般 input, gradOutput, gradInput 是 Tensors.

    BP 包含种类梯度的计算及对应的函数:

    关于输入变量的梯度计算 - updateGradInput(input, gradOutput)

    关于网络参数的梯度计算 - accGradParameters(input,gradOutput,scale), gradParamaters * scale,累加

  • zeroGradParameters()

    如果网络 module 已经有参数,该函数将关于其参数的累积梯度gradParameters清零.

  • updateParameters(learningRate)

    更新网络 module 参数:

    parameters=parameterslearningRategradients_wrt_parameters

  • accUpdateGradParameters(input, gradOutput, learningRate)

    累积参数梯度,并更新参数.

  • share(mlp,s1,s2,…,sn)

    修改 module

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值