Pythonを使用してデータセットをtrain_dataとtest_dataに分割する

2194 ワード

画像分類のために収集したクラス別のデータをランダムにシャッフルし、指定した比率で訓練セットとテストセットに分割します。以下はコードです。
import os
import random
import shutil

def data_random_split(current_dir, ratio_train):
    """Randomly split the entries of a directory into train and test lists.

    Args:
        current_dir: Directory whose entries (as returned by
            ``os.listdir``) are split.
        ratio_train: Fraction of the entries assigned to the training
            list; must lie in ``[0, 1]``.

    Returns:
        A ``(train_names, test_names)`` tuple of entry names (not full
        paths). After shuffling, the first
        ``int(len(entries) * ratio_train)`` names form the training list
        and the remaining names form the test list.

    Raises:
        ValueError: If ``ratio_train`` is outside ``[0, 1]``.
    """
    # A negative ratio would turn into a negative slice index and a ratio
    # above 1 would silently empty the test set, so reject both up front.
    if not 0 <= ratio_train <= 1:
        raise ValueError(
            "ratio_train must be between 0 and 1, got %r" % (ratio_train,)
        )

    names = os.listdir(current_dir)
    random.shuffle(names)
    train_len = int(len(names) * ratio_train)
    return names[:train_len], names[train_len:]

def _copy_entries(names, src_dir, dst_dir):
    """Copy each entry in ``names`` from ``src_dir`` into ``dst_dir``."""
    for entry in names:
        shutil.copy(os.path.join(src_dir, entry), os.path.join(dst_dir, entry))


def data_generator(root, data_dir, ratio_train=0.7):
    """Copy a class-per-folder dataset into train/test directory trees.

    Every sub-directory of ``root`` is treated as one class. Its entries
    are split by ``data_random_split`` and copied to
    ``<data_dir>/train_data/<class>`` and ``<data_dir>/test_data/<class>``.
    The source files are copied, not moved.

    Args:
        root: Directory containing one sub-directory per class.
        data_dir: Output directory; ``train_data`` and ``test_data`` are
            created inside it.
        ratio_train: Fraction of each class's entries copied to
            ``train_data``; the rest go to ``test_data``.
    """
    # List the classes before creating the output folders so that they are
    # never picked up as classes if data_dir happens to be root itself.
    class_names = os.listdir(root)

    train_data_dir = os.path.join(data_dir, "train_data")
    test_data_dir = os.path.join(data_dir, "test_data")
    # exist_ok lets the script be re-run against an existing output tree
    # instead of failing with FileExistsError.
    os.makedirs(train_data_dir, exist_ok=True)
    os.makedirs(test_data_dir, exist_ok=True)

    for name in class_names:
        print(name)
        current_dir = os.path.join(root, name)
        print(current_dir)
        # Stray files directly under root are not classes; listing them
        # as directories would raise, so skip them.
        if not os.path.isdir(current_dir):
            continue

        train_names, test_names = data_random_split(current_dir, ratio_train)

        class_train_dir = os.path.join(train_data_dir, name)
        class_test_dir = os.path.join(test_data_dir, name)
        os.makedirs(class_train_dir, exist_ok=True)
        os.makedirs(class_test_dir, exist_ok=True)

        _copy_entries(train_names, current_dir, class_train_dir)
        _copy_entries(test_names, current_dir, class_test_dir)

    print('ok')
# Script entry: split the dataset under F:\maize_disease into F:\1 using the
# default train ratio.
data_generator(root=r'F:\maize_disease', data_dir=r'F:\1')