Python's __call__ lets an object be called like a function; it is the key to understanding the nn.Module source code.


Code example
import torch.nn as nn

class LSTMClassifier(nn.Module):
    """
    This is the simple RNN model we will be using to perform Sentiment Analysis.
    """

    def __init__(self, embedding_dim, hidden_dim, vocab_size):
        """
        Initialize the model by setting up the various layers.
        """
        super(LSTMClassifier, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.dense = nn.Linear(in_features=hidden_dim, out_features=1)
        self.sig = nn.Sigmoid()
        
        self.word_dict = None

    def forward(self, x):
        """
        Perform a forward pass of our model on some input.
        """
        x = x.t()
        lengths = x[0,:]
        reviews = x[1:,:]
        embeds = self.embedding(reviews)
        lstm_out, _ = self.lstm(embeds)
        out = self.dense(lstm_out)
        out = out[lengths - 1, range(len(lengths))]
        return self.sig(out.squeeze())
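
Reading forward() above, the model expects a very specific batch layout: after x.t(), row 0 holds each review's length and the remaining rows hold the padded token ids, so each input row must be [length, token_1, ..., token_N]. Below is a minimal sketch with dummy data; the batch size, sequence length, and random tokens are made up for illustration and are not from the original post.

import torch

# Hypothetical shapes: a batch of 4 reviews, each padded to 10 tokens.
batch_size, seq_len, vocab_size = 4, 10, 5000

# Column 0 is the true length of each review; columns 1..seq_len are token ids
# (0 is reserved for padding, matching padding_idx=0 in the embedding layer).
lengths = torch.randint(1, seq_len + 1, (batch_size, 1))
tokens = torch.randint(1, vocab_size, (batch_size, seq_len))
batch_X = torch.cat([lengths, tokens], dim=1)   # shape: (batch, 1 + seq_len)

model = LSTMClassifier(embedding_dim=32, hidden_dim=100, vocab_size=vocab_size)
probs = model(batch_X)                          # sigmoid outputs, shape: (batch,)
print(probs.shape)

The indexing out[lengths - 1, range(len(lengths))] in forward() then picks the LSTM output at each review's last real token before applying the sigmoid.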

 
 
def train(model, train_loader, epochs, optimizer, loss_fn, device):
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0
        for batch in train_loader:
            batch_X, batch_y = batch
            print("batch_X.shape=", batch_X.shape)
            print("batch_y.shape=", batch_y.shape)
            batch_X = batch_X.to(device)
            batch_y = batch_y.to(device)

            # TODO: Complete this train method to train the model provided.
            optimizer.zero_grad()
            # get predictions from model
            y_pred = model(batch_X)
            # perform backprop
            loss = loss_fn(y_pred, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print("Epoch: {}, BCELoss: {}".format(epoch, total_loss / len(train_loader)))
 
When an LSTMClassifier object model is invoked as model(batch_X), what actually runs is forward(): calling the object triggers nn.Module's __call__ method, and __call__ in turn calls forward().
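
As a plain-Python illustration (not from the original post), defining __call__ on any class makes its instances callable like functions; this is exactly the mechanism nn.Module relies on so that model(batch_X) ends up running forward():

class Greeter:
    def __init__(self, name):
        self.name = name

    def __call__(self, greeting):
        # Runs when the instance itself is invoked like a function.
        return "{}, {}!".format(greeting, self.name)

g = Greeter("PyTorch")
print(g("Hello"))   # equivalent to g.__call__("Hello") -> "Hello, PyTorch!"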
# The caller-side code that uses model looks like this:
import torch
import torch.optim as optim
from train.model import LSTMClassifier

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LSTMClassifier(32, 100, 5000).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.BCELoss()
train(model, train_sample_dl, 5, optimizer, loss_fn, device)
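
In the call above, train_sample_dl is not defined in this excerpt. Here is a minimal sketch of how such a DataLoader might be built, assuming the reviews are already encoded as fixed-length integer rows of the form [length, token ids...] that forward() expects; all tensor contents and sizes below are dummy data, not from the original post.

import torch
from torch.utils.data import TensorDataset, DataLoader

# Dummy encoded reviews: each row is [length, token_1, ..., token_500].
train_X = torch.randint(0, 5000, (1000, 501))
train_X[:, 0] = torch.randint(1, 501, (1000,))   # first column = review length
train_y = torch.randint(0, 2, (1000,)).float()   # binary sentiment labels for BCELoss

train_sample_dl = DataLoader(TensorDataset(train_X, train_y), batch_size=50)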
 
 
 
 
# In the parent class nn.Module, the __call__ method calls forward, as shown below:
def __call__(self, *input, **kwargs):
    for hook in self._forward_pre_hooks.values():
        result = hook(self, input)
        if result is not None:
            if not isinstance(result, tuple):
                result = (result,)
            input = result
    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    for hook in self._forward_hooks.values():
        hook_result = hook(self, input, result)
        if hook_result is not None:
            result = hook_result
    if len(self._backward_hooks) > 0:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in self._backward_hooks.values():
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
    return result
 
 
The keys to understanding this behavior:
__call__
_forward_pre_hooks
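
A short experiment with the hook machinery that __call__ drives (the layer and print messages are just for illustration, not from the original post): registering hooks on any nn.Module shows that __call__ runs the _forward_pre_hooks before forward() and the _forward_hooks after it.

import torch
import torch.nn as nn

layer = nn.Linear(4, 2)

def pre_hook(module, inputs):
    # Runs inside __call__ before forward().
    print("pre-hook: input shape", inputs[0].shape)

def post_hook(module, inputs, output):
    # Runs inside __call__ after forward() has produced its result.
    print("forward hook: output shape", output.shape)

layer.register_forward_pre_hook(pre_hook)
layer.register_forward_hook(post_hook)

_ = layer(torch.randn(3, 4))   # __call__ fires both hooks around forward()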
References
https://www.cnblogs.com/SBJBA/p/11355412.html
Step-by-step analysis of the nn.Module methods:
https://blog.csdn.net/lishuiwang/article/details/104505675/
https://www.jb51.net/article/184033.htm
nn.Module source:
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py