Week4 Day3

18885 ワード

📋 SequenceSequenceSequence tototo SequenceSequenceSequence withwithwith AttentionAttentionAttention


📌 Seq2SeqSeq2SeqSeq2Seq ModelModelModel

  • コーデックと復号デコーダからなり、単語のシーケンスを入力として受信し、単語のシーケンスを出力とする.
  • 欠点:シーケンスシーケンスの長さが長くなっても、固定サイズのベクトルベクトルに情報+(Long term term dependency)依存項を含める.

    📌Seq2SeqSeq2SeqSeq2Seq ModelModelModel withwithwith AttentionAttentionAttention


    各時間ラベル上のデコーダがいくつかの重要なソースソースソースシーケンスに集中できるように、注意注意注意注意注意注意注意注意注意注意メカニズムが必要である.
  • によってボトルネックの問題も解決できます.
  • は、それぞれ内積コーデック非表示状態とデコーダ非表示状態とにより、得られた確率分布を用いてコーデック非表示状態重み付け和を求める.
  • とデコーダ非表示状態のconcatconcatを出力値として使用します.
  • により、損失計算による勾配は、エンコーダ非表示状態に直接伝達される.
  • class DotAttention(nn.Module):
      def __init__(self):
        super().__init__()
    
      def forward(self, decoder_hidden, encoder_outputs):  # (1, B, d_h), (S_L, B, d_h)
        query = decoder_hidden.squeeze(0)  # (B, d_h)
        key = encoder_outputs.transpose(0, 1)  # (B, S_L, d_h)
    
        energy = torch.sum(torch.mul(key, query.unsqueeze(1)), dim=-1)  # (B, S_L)
    
        attn_scores = F.softmax(energy, dim=-1)  # (B, S_L)
        attn_values = torch.sum(torch.mul(encoder_outputs.transpose(0, 1), attn_scores.unsqueeze(2)), dim=1)  # (B, d_h)
    
        return attn_values, attn_scores
        
    class Decoder(nn.Module):
      def forward(self, batch, encoder_outputs, hidden):  
        outputs, hidden = self.rnn(batch_emb, hidden)  # (1, B, d_h), (1, B, d_h)
    
        attn_values, attn_scores = self.attention(hidden, encoder_outputs)  # (B, d_h), (B, S_L)
        concat_outputs = torch.cat((outputs, attn_values.unsqueeze(0)), dim=-1)  # (1, B, 2d_h)
    
        return self.output_linear(concat_outputs).squeeze(0), hidden  # (B, V), (1, B, d_h)

    📌 DifferenctDifferenctDifferenct AttentionAttentionAttention MechanismsMechanismsMechanisms

  • はまた、内在的ではなく、追加の学習可能なパラメータ演算を含む類似性の測定を可能にする注意注意注意注意注意注意注意注意注意注意注意注意注意機構も存在する.
  • class ConcatAttention(nn.Module):
      def __init__(self):
        super().__init__()
    
        self.w = nn.Linear(2*hidden_size, hidden_size, bias=False)
        self.v = nn.Linear(hidden_size, 1, bias=False)
    
      def forward(self, decoder_hidden, encoder_outputs):  # (1, B, d_h), (S_L, B, d_h)
        src_max_len = encoder_outputs.shape[0]
    
        decoder_hidden = decoder_hidden.transpose(0, 1).repeat(1, src_max_len, 1)  # (B, S_L, d_h)
        encoder_outputs = encoder_outputs.transpose(0, 1)  # (B, S_L, d_h)
    
        concat_hiddens = torch.cat((decoder_hidden, encoder_outputs), dim=2)  # (B, S_L, 2d_h)
        energy = torch.tanh(self.w(concat_hiddens))  # (B, S_L, d_h)
    
        attn_scores = F.softmax(self.v(energy), dim=1)  # (B, S_L, 1)
        attn_values = torch.sum(torch.mul(encoder_outputs, attn_scores), dim=1)  # (B, d_h)
    
        return attn_values, attn_scores
        
    class Decoder(nn.Module):
      def __init__(self, attention):
        super().__init__()
    
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.attention = attention
        self.rnn = nn.GRU(
            embedding_size + hidden_size,
            hidden_size
        )
        self.output_linear = nn.Linear(hidden_size, vocab_size)
    
      def forward(self, batch, encoder_outputs, hidden):  # batch: (B), encoder_outputs: (S_L, B, d_h), hidden: (1, B, d_h)  
        batch_emb = self.embedding(batch)  # (B, d_w)
        batch_emb = batch_emb.unsqueeze(0)  # (1, B, d_w)
    
        attn_values, attn_scores = self.attention(hidden, encoder_outputs)  # (B, d_h), (B, S_L)
    
        concat_emb = torch.cat((batch_emb, attn_values.unsqueeze(0)), dim=-1)  # (1, B, d_w+d_h)
    
        outputs, hidden = self.rnn(concat_emb, hidden)  # (1, B, d_h), (1, B, d_h)
    
        return self.output_linear(outputs).squeeze(0), hidden  # (B, V), (1, B, d_h)

    📋 BeamBeamBeam searchsearchsearch

  • 理想的には、翻訳において、入力文が付与されると、出力文単語の接続接続確率確率確率確率確率確率確率確率確率が最大となる選択である.
  • しかし、現在の状態が最適であるGreedyGreedyGreedyを選択する方法は、最終的な最適選択ではない、
  • .
  • すべての状況に対する「取り尽くせない、使い尽くせない、使い尽くせない検索検索」には、過剰な演算量が必要である.
  • 📌 BeamBeamBeam searchsearchsearch

  • デコーダの各時間ラベルは、kk個の可能な部分翻訳を追跡する.
  • 📋 BLEUBLEUBLEU scorescorescore


    📌 PrecisionPrecisionPrecision andandand RecallRecallRecall

  • p r e c i o n p r e c i o n d e c i o n d e c i o n d e c i o n d e c i o n d e c i o n d e c i o n d e c i o n d e c i o n d e c i o n d e c i o n d e c i o n d e c i o n d e c i o n d e c i o n d
  • 回想:精度
  • は、実際に存在する正解に一致する情報がどれだけあるかを示す.
  • FFF-measurementmeasurement:precisionprecisionprecisionprecision精密度とリコールリコールリコール率の調和平均値において、小さい値に少し重み付けを加えた測定方法
  • しかし、これらの方法は語順
  • を測定することができない.

    📌 BiLingualBiLingualBiLingual EvaluationEvaluationEvaluation UnderstudyUnderstudyUnderstudy (BLEU)(BLEU)(BLEU)

  • -GramN-GramN図の重なりを測定し、
  • 四サイズnn-gramgram図の精密測定
  • 短い予測単語はprecisionprecisiondecificationに有利であるため、短いペナルティタイプが追加された.