Data preprocessing (2): converting the data into the BERT model's input format


Using the data processed in the previous section and stored under the match_data directory, we now reprocess it into the input format expected by the BERT model.
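
Each line of these files is expected to hold a source sentence, a target sentence and a label separated by tabs; the loader below skips any line that does not split into exactly three fields. A hypothetical line and how it is split (the sentences are only illustrative):

    # Hypothetical example of one line in train_A.txt (tab-separated: source, target, label)
    sample_line = "今天天气怎么样\t今天的天气如何\t1"
    source, target, label = sample_line.strip().split("\t")
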
  • Import the required packages.
    import torch
    import os
    import pickle as pkl
    from tqdm import tqdm
    from torch.utils.data import Dataset
    
    class TextMatchDataset(Dataset):
        def __init__(self, config, path):
            self.config = config
            self.path = path
            self.inference = False
            self.max_len = self.config.pad_size
            self.contents = self.load_dataset_match(config)
    
        def load_dataset_match(self, config):
            if "test" in self.path:
                self.inference = True
            if self.config.token_type:
                pad, cls, sep = '[PAD]', '[CLS]', '[SEP]'
            else:
                pad, cls, sep = '', '', ''
    
            contents = []
            length_count = []
            file_stream = open(self.path, 'r', encoding="utf-8")
            for line in tqdm(file_stream.readlines()):
                lin = line.strip()
                if not lin:
                    continue
                if len(lin.split("\t")) != 3:
                    print(line)
                    continue
                source, target, label = lin.split('\t')
                token_id_full = []
                mask_full = []
                # Truncate each sentence at the character level, then tokenize
                seq_source = config.tokenizer.tokenize(source[:(self.max_len - 2)])
                seq_target = config.tokenizer.tokenize(target[:(self.max_len - 1)])
                # Concatenate the pair: [CLS] source [SEP] target [SEP]
                seq_token = [cls] + seq_source + [sep] + seq_target + [sep]
                # Segment ids: 0 for the source part (with [CLS] and [SEP]), 1 for the target part
                seq_segment = [0] * (len(seq_source) + 2) + [1] * (len(seq_target) + 1)
                # Convert the tokens to vocabulary ids
                seq_idx = self.config.tokenizer.convert_tokens_to_ids(seq_token)
                # Padding needed to reach the fixed length max_len * 2
                padding = [0] * ((self.max_len * 2) - len(seq_idx))
                # Attention mask: 1 for real tokens, 0 for padding
                seq_mask = [1] * len(seq_idx) + padding
                # Pad the token ids
                seq_idx = seq_idx + padding
                # Pad the segment ids
                seq_segment = seq_segment + padding
                
                # print(seq_idx)
                # print(seq_mask)
                # print(seq_segment)
                # print(len(seq_idx))
                # print(len(seq_mask))
                # print(len(seq_segment))
                
                assert len(seq_idx) == self.max_len * 2
                assert len(seq_mask) == self.max_len * 2
                assert len(seq_segment) == self.max_len * 2
    
                token_id_full.append(seq_idx)
                token_id_full.append(seq_mask)
                token_id_full.append(seq_segment)
    
                if self.inference:
                    token_id_full.append(label)
                else:
                    token_id_full.append(int(label))
    
                contents.append(token_id_full)
    
            return contents
    
        def __getitem__(self, index):
            elements = self.contents[index]
            seq_idx = torch.LongTensor(elements[0])
            seq_mask = torch.LongTensor(elements[1])
            seq_segment = torch.LongTensor(elements[2])
            if not self.inference:
                label = torch.LongTensor([elements[3]])
            else:
                label = [elements[3]]
            return (seq_idx, seq_mask, seq_segment), label
    
        def __len__(self):
            return len(self.contents)
    
    from Param import Param
    if __name__ == '__main__':
        param = Param(base_path="./match_data", model_name="SimBERT_A")
        train_data = TextMatchDataset(param, param.dev_path)
        (token, mask, segment), label = train_data[0]
        print(train_data[4300])
        print(len(token))
        print(len(mask))
        print(len(segment))
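        # A minimal sketch of consuming the dataset with a DataLoader
        # (this assumes PyTorch's default collate function, which stacks the
        # tensors returned by __getitem__; the loop below is illustrative only).
        from torch.utils.data import DataLoader
        loader = DataLoader(train_data, batch_size=param.batch_size)
        for (token_ids, masks, segments), labels in loader:
            # token_ids / masks / segments: LongTensors of shape (batch_size, pad_size * 2)
            # labels: LongTensor of shape (batch_size, 1) for train/valid data
            print(token_ids.shape, masks.shape, segments.shape, labels.shape)
            break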
    
  • Param.py
  • import os.path as osp
    # from util import mkdir_if_no_dir
    import os
    from transformers import BertTokenizer, ElectraTokenizer, AutoTokenizer
    
    def mkdir_if_no_dir(path):
        """         """
        if not os.path.exists(path):
            os.mkdir(path)
    
    class Param:
        def __init__(self, base_path, model_name):
            if "A" in model_name:
                self.train_path = osp.join(base_path, 'train_A.txt')           # training set
                self.dev_path = osp.join(base_path, 'valid_A.txt')             # validation set
                self.test_path = osp.join(base_path, 'test_A.txt')             # test set
                self.result_path = osp.join(base_path, "predict_A.csv")
            else:
                self.train_path = osp.join(base_path, 'train_B.txt')           # training set
                self.dev_path = osp.join(base_path, 'valid_B.txt')             # validation set
                self.test_path = osp.join(base_path, 'test_B.txt')             # test set
                self.result_path = osp.join(base_path, "predict_B.csv")
            print([self.train_path, self.dev_path, self.test_path, self.result_path])
            mkdir_if_no_dir(osp.join(base_path, "saved_dict"))
            mkdir_if_no_dir(osp.join(base_path, "log"))
            self.save_path = osp.join(osp.join(base_path, 'saved_dict'), model_name + '.pt')  # model checkpoint path
            self.log_path = osp.join(osp.join(base_path, "log"), model_name)                  # log directory
            self.vocab_path = osp.join(base_path, "vocab.pkl")
            self.class_path = osp.join(base_path, "class.txt")
            self.vocab = {}
            self.device = None
            self.token_type = True
            self.model_name = "BERT"
            self.warmup_steps = 1000
            self.t_total = 100000
            self.class_list = {}
            with open(self.class_path, "r", encoding="utf-8") as fr:
                idx = 0
                for line in fr:
                    line = line.strip("
    "
    ) self.class_list[line] = idx idx += 1 self.class_list_verse = { v: k for k, v in self.class_list.items()} self.num_epochs = 5 # epoch self.batch_size = 32 # mini-batch self.pad_size = 256 # ( ) self.learning_rate = 1e-5 # self.require_improvement = 10000000 # 1000batch , self.multi_gpu = True self.device_ids = [0, 1] self.full_fine_tune = True self.use_adamW = True self.input_language = "multi" # ["eng", "original", "multi"] self.MAX_VOCAB_SIZE = 20000 self.min_vocab_freq = 1 if "BERT" in model_name: print("Load BERT Tokenizer") self.bert_path = "bert-base-chinese" self.tokenizer = BertTokenizer.from_pretrained(self.bert_path) else: print("Load BERT Tokenizer") self.bert_path = "bert-base-chinese" self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
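  • The tokenizer created in Param is what TextMatchDataset calls through config.tokenizer. A minimal sketch of those two calls, assuming the bert-base-chinese files can be downloaded (the example sentences are purely illustrative):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
    # Same pattern as in load_dataset_match: tokenize both sentences, join them
    # with the special tokens, then map the tokens to vocabulary ids.
    tokens = ['[CLS]'] + tokenizer.tokenize("今天天气怎么样") + ['[SEP]'] \
             + tokenizer.tokenize("今天的天气如何") + ['[SEP]']
    ids = tokenizer.convert_tokens_to_ids(tokens)
    print(tokens)
    print(ids)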
  • The class.txt file under match_data
  • 0
    1
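  • With this two-line class.txt, the loop in Param.__init__ builds the label mappings roughly as follows (written out here only for illustration):

    class_list = {"0": 0, "1": 1}
    class_list_verse = {0: "0", 1: "1"}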