
PyTorch Notes: 08) Writing Classical Chinese Poetry with an LSTM

Test environment:
CentOS 7 + Python 3.6 + PyTorch 0.4 + CUDA 9

Below is an acrostic poem generated by the model; the initial characters of the four lines spell 深度學習 ("deep learning"):

深宮昔時見,古貌多自有。
度日不相容,年年生一目。
學者若為霖,百姓貽憂厄。
習坎與天聰,優遊寧敢屢。

Training data
The dataset contains 57,580 poems. For the book *PyTorch入門與實踐*, the author preprocessed each poem to a fixed length of 125 characters (shorter poems are padded with blanks; poems over the limit are discarded).
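As a rough sketch of that preprocessing step (a hypothetical helper, not the book's actual code; the pad_poem name and pad_id parameter are assumptions):

def pad_poem(ids, maxlen=125, pad_id=0):
    # Hypothetical sketch: pad short poems to maxlen, drop longer ones.
    if len(ids) > maxlen:
        return None                              # over-length poems are discarded
    return ids + [pad_id] * (maxlen - len(ids))  # right-pad with a filler id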
The data.py file below loads the preprocessed data:

import numpy as np
import os

def get_data(conf):
    '''
    Load the dataset.
    :param conf: configuration options, a Config object
    :return: word2ix: index id of each character, e.g. u'月' -> 100
    :return: ix2word: character for each index id, e.g. 100 -> u'月'
    :return: data: each row holds the character ids of one poem
    '''
    if os.path.exists(conf.data_path):
        datas = np.load(conf.data_path)  # preprocessed numpy archive
        data = datas['data']
        ix2word = datas['ix2word'].item()
        word2ix = datas['word2ix'].item()
        return data, word2ix, ix2word
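A quick usage check (assuming tang.npz is present and using the Config class defined in the next section; the printed shape follows from the 57,580 poems × 125 characters described above):

conf = Config()
data, word2ix, ix2word = get_data(conf)
print(data.shape)               # (57580, 125): one row of ids per poem
print(word2ix[u'月'])           # index id of the character '月'
print(ix2word[word2ix[u'月']])  # round-trips back to '月'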

Configuration

class Config(object):
    """Base configuration class. For custom configurations, create a
    sub-class that inherits from this one and override the properties
    that need to change.
    """
    # prefix for saved model checkpoints (a checkpoint is written every few epochs)
    model_prefix = 'checkpoints/tang'
    # path of the saved model
    model_path = 'checkpoints/tang.pth'
    # start words
    start_words = '春江花月夜'
    # type of poem to generate; default is acrostic
    p_type = 'acrostic'
    # number of training epochs
    max_epoch = 200
    # path of the data file
    data_path = 'tang.npz'
    # mini-batch size
    batch_size = 128
    # number of worker processes the DataLoader uses
    num_workers = 4
    # number of LSTM layers
    num_layers = 2
    # word-embedding dimension
    embedding_dim = 128
    # LSTM hidden-state dimension
    hidden_dim = 256
    # save model weights and sample poems every N epochs
    save_every = 10
    # output path for poems generated during training
    out_path = 'out'
    # output path for poems generated at test time
    out_poetry_path = 'out/poetry.txt'
    # maximum length of a generated poem
    max_gen_len = 200
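Per the class docstring, a custom run overrides fields in a subclass (the MyConfig name below is just an illustration):

class MyConfig(Config):
    batch_size = 64          # e.g. smaller batches for a smaller GPU
    start_words = '深度學習'  # initial characters for the acrostic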

Model definition

import torch
import torch.nn as nn

class PoetryModel(nn.Module):
    def __init__(self, vocab_size, conf, device):
        super(PoetryModel, self).__init__()
        self.num_layers = conf.num_layers
        self.hidden_dim = conf.hidden_dim
        self.device = device
        # word-embedding layer
        self.embeddings = nn.Embedding(vocab_size, conf.embedding_dim)
        # 2-layer LSTM; batch_first is left at False, so inputs are (seq_len, batch, feature)
        self.lstm = nn.LSTM(conf.embedding_dim, conf.hidden_dim, num_layers=self.num_layers)
        # fully connected output layer; the softmax is applied later by CrossEntropyLoss
        self.linear_out = nn.Linear(conf.hidden_dim, vocab_size)

    def forward(self, input, hidden=None):
        '''
        :param input: character ids of shape (seq_len, batch)
        :return: output logits and the final hidden state
        '''
        seq_len, batch_size = input.size()
        # embeds: (seq_len, batch_size, embedding_dim)
        embeds = self.embeddings(input)
        if hidden is None:
            h0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(self.device)
            c0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(self.device)
        else:
            h0, c0 = hidden
        output, hidden = self.lstm(embeds, (h0, c0))
        # flatten time and batch dims: output becomes (seq_len*batch_size, vocab_size)
        output = self.linear_out(output.view(seq_len * batch_size, -1))
        return output, hidden
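A quick shape check with dummy inputs (the vocab size 8000 here is a made-up placeholder, not the dataset's real vocabulary size):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
conf = Config()
model = PoetryModel(vocab_size=8000, conf=conf, device=device).to(device)
dummy = torch.randint(0, 8000, (124, conf.batch_size)).to(device)  # (seq_len, batch)
output, hidden = model(dummy)
print(output.shape)  # (seq_len*batch, vocab_size) = (124*128, 8000)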

Model training

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train(**kwargs):
    conf = Config()
    # keyword arguments override the defaults in Config
    for k, v in kwargs.items():
        setattr(conf, k, v)
    # load the data
    data, word2ix, ix2word = get_data(conf)
    # build the DataLoader
    dataloader = DataLoader(dataset=data, batch_size=conf.batch_size,
                            shuffle=True,
                            drop_last=True,
                            num_workers=conf.num_workers)
    # build the model
    model = PoetryModel(len(word2ix), conf, device).to(device)
    # optimizer
    optimizer = Adam(model.parameters())
    # loss function (applies log-softmax internally)
    criterion = nn.CrossEntropyLoss()
    # training loop
    for epoch in range(conf.max_epoch):
        epoch_loss = 0
        for i, data in enumerate(dataloader):
            # reshape to (seq_len, batch) as the model expects
            data = data.long().transpose(1, 0).contiguous()
            # next-character prediction: the target is the input shifted by one step
            input, target = data[:-1, :], data[1:, :]
            input, target = input.to(device), target.to(device)
            optimizer.zero_grad()
            output, _ = model(input)
            loss = criterion(output, target.view(-1))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print("epoch_%d_loss:%0.4f" % (epoch, epoch_loss / conf.batch_size))
        if epoch % conf.save_every == 0:
            # periodically dump sample poems and a checkpoint
            fout = open('%s/p%d' % (conf.out_path, epoch), 'w')
            for word in list('春江花月夜'):
                # generate() builds one poem from a start character (see the sketch below)
                gen_poetry = generate(model, word, ix2word, word2ix, conf)
                fout.write("".join(gen_poetry) + "\n\n")
            fout.close()
            torch.save(model.state_dict(), "%s_%d.pth" % (conf.model_prefix, epoch))
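The generate() function used above is not reproduced in this post. Below is a minimal greedy sketch, assuming the dataset includes '<START>' and '<EOP>' marker tokens; real sampling (and the acrostic variant selected by p_type) would differ:

def generate(model, start_word, ix2word, word2ix, conf):
    # Greedy sketch: prime with <START> and the start characters, then
    # repeatedly feed back the argmax prediction until <EOP> or max length.
    results = list(start_word)
    input = torch.tensor([[word2ix['<START>']]], device=model.device)
    hidden = None
    for i in range(conf.max_gen_len):
        output, hidden = model(input, hidden)
        if i < len(start_word):
            w = results[i]                      # still feeding the prompt
        else:
            top_ix = output[0].argmax().item()  # greedy pick of the next id
            w = ix2word[top_ix]
            results.append(w)
        if w == '<EOP>':
            results.pop()                       # drop the end-of-poem marker
            break
        input = input.new_tensor([[word2ix[w]]])
    return results

With these pieces in place, training might be launched as train(max_epoch=100, batch_size=64), since train() forwards its keyword arguments onto Config.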

This material is based on Chen Yun's book *PyTorch入門與實踐*.