LSTM介绍

关于LSTM的具体原理,可以参考:

https://www.jb51.net/article/178582.htm

https://www.jb51.net/article/178423.htm

系列文章:

PyTorch搭建双向LSTM实现时间序列负荷预测

PyTorch搭建LSTM实现多变量多步长时序负荷预测

PyTorch搭建LSTM实现多变量时序负荷预测

PyTorch搭建LSTM实现时间序列负荷预测

LSTM参数

关于nn.LSTM的参数,官方文档给出的解释为:

总共有七个参数,其中只有前三个是必须的。由于大家普遍使用PyTorch的DataLoader来形成批量数据,因此batch_first也比较重要。LSTM的两个常见的应用场景为文本处理和时序预测,因此下面对每个参数我都会从这两个方面来进行具体解释。

  • input_size:在文本处理中,由于一个单词没法参与运算,因此我们得通过Word2Vec来对单词进行嵌入表示,将每一个单词表示成一个向量,此时input_size=embedding_size。
  • 比如每个句子中有五个单词,每个单词用一个100维向量来表示,那么这里input_size=100;
  • 在时间序列预测中,比如需要预测负荷,每一个负荷都是一个单独的值,都可以直接参与运算,因此并不需要将每一个负荷表示成一个向量,此时input_size=1。
  • 但如果我们使用多变量进行预测,比如我们利用前24小时每一时刻的[负荷、风速、温度、压强、湿度、天气、节假日信息]来预测下一时刻的负荷,那么此时input_size=7。
  • hidden_size:隐藏层节点个数。可以随意设置。
  • num_layers:层数。nn.LSTMCell与nn.LSTM相比,num_layers默认为1。
  • batch_first:默认为False,意义见后文。

Inputs

关于LSTM的输入,官方文档给出的定义为:

可以看到,输入由两部分组成:input、(初始的隐状态h_0,初始的单元状态c_0)

其中input:

input(seq_len, batch_size, input_size)
  • seq_len:在文本处理中,如果一个句子有7个单词,则seq_len=7;在时间序列预测中,假设我们用前24个小时的负荷来预测下一时刻负荷,则seq_len=24。
  • batch_size:一次性输入LSTM中的样本个数。在文本处理中,可以一次性输入很多个句子;在时间序列预测中,也可以一次性输入很多条数据。
  • input_size:见前文。

(h_0, c_0):

h_0(num_directions * num_layers, batch_size, hidden_size)
c_0(num_directions * num_layers, batch_size, hidden_size)

h_0和c_0的shape一致。

  • num_directions:如果是双向LSTM,则num_directions=2;否则num_directions=1。
  • num_layers:见前文。
  • batch_size:见前文。
  • hidden_size:见前文。

Outputs

关于LSTM的输出,官方文档给出的定义为:

可以看到,输出也由两部分组成:otput、(隐状态h_n,单元状态c_n)

其中output的shape为:

output(seq_len, batch_size, num_directions * hidden_size)

h_n和c_n的shape保持不变,参数解释见前文。

batch_first

如果在初始化LSTM时令batch_first=True,那么input和output的shape将由:

input(seq_len, batch_size, input_size)
output(seq_len, batch_size, num_directions * hidden_size)

变为:

input(batch_size, seq_len, input_size)
output(batch_size, seq_len, num_directions * hidden_size)

即batch_size提前。

案例

简单搭建一个LSTM如下所示:

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = output_size
        self.num_directions = 1 # 单向LSTM
        self.batch_size = batch_size
        self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
        self.linear = nn.Linear(self.hidden_size, self.output_size)
    def forward(self, input_seq):
        h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
        c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
        seq_len = input_seq.shape[1] # (5, 30)
        # input(batch_size, seq_len, input_size)
        input_seq = input_seq.view(self.batch_size, seq_len, 1)  # (5, 30, 1)
        # output(batch_size, seq_len, num_directions * hidden_size)
        output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)
        output = output.contiguous().view(self.batch_size * seq_len, self.hidden_size) # (5 * 30, 64)
        pred = self.linear(output) # pred(150, 1)
        pred = pred.view(self.batch_size, seq_len, -1) # (5, 30, 1)
        pred = pred[:, -1, :]  # (5, 1)
        return pred

其中定义模型的代码为:

self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
self.linear = nn.Linear(self.hidden_size, self.output_size)

我们加上具体的数字:

self.lstm = nn.LSTM(self.input_size=1, self.hidden_size=64, self.num_layers=5, batch_first=True)
self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)

再看前向传播:

def forward(self, input_seq):
    h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
    c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
    seq_len = input_seq.shape[1]  # (5, 30)
    # input(batch_size, seq_len, input_size)
    input_seq = input_seq.view(self.batch_size, seq_len, 1)  # (5, 30, 1)
    # output(batch_size, seq_len, num_directions * hidden_size)
    output, _ = self.lstm(input_seq, (h_0, c_0))  # output(5, 30, 64)
    output = output.contiguous().view(self.batch_size * seq_len, self.hidden_size)  # (5 * 30, 64)
    pred = self.linear(output) # (150, 1)
    pred = pred.view(self.batch_size, seq_len, -1)  # (5, 30, 1)
    pred = pred[:, -1, :]  # (5, 1)
    return pred

假设用前30个预测下一个,则seq_len=30,batch_size=5,由于设置了batch_first=True,因此,输入到LSTM中的input的shape应该为:

input(batch_size, seq_len, input_size) = input(5, 30, 1)

但实际上,经过DataLoader处理后的input_seq为:

input_seq(batch_size, seq_len) = input_seq(5, 30)

(5, 30)表示一共5条数据,每条数据的维度都为30。为了匹配LSTM的输入,我们需要对input_seq的shape进行变换:

input_seq = input_seq.view(self.batch_size, seq_len, 1)  # (5, 30, 1)

然后将input_seq送入LSTM:

output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)

根据前文,output的shape为:

output(batch_size, seq_len, num_directions * hidden_size) = output(5, 30, 64)

全连接层的定义为:

self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)

因此,我们需要将output的第二维度变换为64(150, 64):

output = output.contiguous().view(self.batch_size * seq_len, self.hidden_size) # (5 * 30, 64)

然后将output送入全连接层:

pred = self.linear(output) # pred(150, 1)

得到的预测值shape为(150, 1)。我们需要将其进行还原,变成(5, 30, 1):

pred = pred.view(self.batch_size, seq_len, -1) # (5, 30, 1)

在用DataLoader处理了数据后,得到的input_seq和label的shape分别为:

input_seq(batch_size, seq_len) = input_seq(5, 30)label(batch_size, output_size) = label(5, 1)

由于输出是输入右移,我们只需要取pred第二维度(time)中的最后一个数据:

pred = pred[:, -1, :] # (5, 1)

这样,我们就得到了预测值,然后与label求loss,然后再反向更新参数即可。

时间序列预测的一个真实案例请见:PyTorch搭建LSTM实现时间序列预测(负荷预测)

以上就是PyTorch深度学习LSTM从input输入到Linear输出的详细内容,更多关于LSTM input输入Linear输出的资料请关注Devmax其它相关文章!

PyTorch深度学习LSTM从input输入到Linear输出的更多相关文章

  1. HTML5 input新增type属性color颜色拾取器的实例代码

    type 属性规定 input 元素的类型。本文较详细的给大家介绍了HTML5 input新增type属性color颜色拾取器的实例代码,感兴趣的朋友跟随脚本之家小编一起看看吧

  2. 移动HTML5前端框架—MUI的使用

    这篇文章主要介绍了移动HTML5前端框架—MUI的使用的相关资料,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧

  3. 使用placeholder属性设置input文本框的提示信息

    这篇文章主要介绍了使用placeholder属性设置input文本框的提示信息,本文给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下

  4. Bootstrap File Input文件上传组件

    这篇文章主要介绍了Bootstrap File Input文件上传组件,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

  5. HTML5中input输入框默认提示文字向左向右移动的示例代码

    这篇文章主要介绍了HTML5中input输入框默认提示文字向左向右移动,本文通过实例代码给大家介绍的非常详细对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

  6. swift 正则表达式运用实例选自《swifter 100个swift开发必备tip 》

  7. Swift 2.0关键字guard

    viewmode=list前言:当一项新的技术出来的时候,第一参考自然是文档。文档链接guard语句guard语句的作用是:当某些条件不满足的情况下,跳出作用域举个例子:写个函数,保证输入小于10在playground输入如下可以看到输出上述方法和使用if一样但是使用guard有一个好处如果不使用return,break,continue,throw跳出当前作用域,编译器会报错所以,对那些对条件要求十分严格的地方,guard是不二之选。另外,guard也可以使用可选绑定也就是guardlet的格式例如如何

  8. Swift 柯里化(currying)和反柯里化(uncurrying)

    //DemoofcurryingfuncaddTwoNums(a:Int)(num:Int)->Int{returna+num}letaddToFour=addTwoNums(4)letresult=addToFour(num:6)print("result:\(result)")funcgreaterThan(comparor:Int)(input:Int)->Bool{returninput>

  9. swift – 上下文类型“AnyObject”不能与字典文字一起使用?

    我正在尝试将Objective-C示例转换为Swift2,但我遇到一个小问题。原来的Objective-C片段:我认为Swift代码应该是:结果错误是:在这种情况下,如何将Objective-C转换成Swift?因此,声明数组更具体在Swift3中用于JSON集合类型或字典/数组仅包含值类型使用

  10. swift – 我可以指定generic是值类型吗?

    我知道我们可以通过使用AnyObject来基本上指定我们的泛型是任何引用类型:但是有没有办法指定我们的泛型应该只是值类型,不允许引用类型?

随机推荐

  1. 10 个Python中Pip的使用技巧分享

    众所周知,pip 可以安装、更新、卸载 Python 的第三方库,非常方便。本文小编为大家总结了Python中Pip的使用技巧,需要的可以参考一下

  2. python数学建模之三大模型与十大常用算法详情

    这篇文章主要介绍了python数学建模之三大模型与十大常用算法详情,文章围绕主题展开详细的内容介绍,具有一定的参考价值,感想取得小伙伴可以参考一下

  3. Python爬取奶茶店数据分析哪家最好喝以及性价比

    这篇文章主要介绍了用Python告诉你奶茶哪家最好喝性价比最高,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习吧

  4. 使用pyinstaller打包.exe文件的详细教程

    PyInstaller是一个跨平台的Python应用打包工具,能够把 Python 脚本及其所在的 Python 解释器打包成可执行文件,下面这篇文章主要给大家介绍了关于使用pyinstaller打包.exe文件的相关资料,需要的朋友可以参考下

  5. 基于Python实现射击小游戏的制作

    这篇文章主要介绍了如何利用Python制作一个自己专属的第一人称射击小游戏,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起动手试一试

  6. Python list append方法之给列表追加元素

    这篇文章主要介绍了Python list append方法如何给列表追加元素,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

  7. Pytest+Request+Allure+Jenkins实现接口自动化

    这篇文章介绍了Pytest+Request+Allure+Jenkins实现接口自动化的方法,文中通过示例代码介绍的非常详细。对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

  8. 利用python实现简单的情感分析实例教程

    商品评论挖掘、电影推荐、股市预测……情感分析大有用武之地,下面这篇文章主要给大家介绍了关于利用python实现简单的情感分析的相关资料,文中通过示例代码介绍的非常详细,需要的朋友可以参考下

  9. 利用Python上传日志并监控告警的方法详解

    这篇文章将详细为大家介绍如何通过阿里云日志服务搭建一套通过Python上传日志、配置日志告警的监控服务,感兴趣的小伙伴可以了解一下

  10. Pycharm中运行程序在Python console中执行,不是直接Run问题

    这篇文章主要介绍了Pycharm中运行程序在Python console中执行,不是直接Run问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

返回
顶部