MNN classficationTopkEval


MNN classficationTopkEval.cppモデルとプロファイルを入力し、ImageNetデータセット上のモデルの分類精度をテストします.プログラムは3つの部分に分かれています.
  • ImageProcessは、画像を適切なフォーマットのTensorに変換する.
  • InterpreterモデルファイルからNetとSessionを作成し、セッションを実行します.
  • computeTopkAccは分類精度を計算する.

  • Interpreter::resizeTensorは入力テンソルを調整し、さらにInterpreter::resizeSessionは後続のテンソルを調整する.

    main


    main
    runEvaluation
    MNN_PRINTはAndroid JNIロゴに対応しています.モデルとプロファイルを入力し、runEvaluation評価モデルを入力します.
        if (argc < 3) {
         
            MNN_PRINT("Usage: ./classficationTopkEval.out model.mnn preTreatConfig.json
    "
    ); } const auto modelPath = argv[1]; const auto preTreatConfigFile = argv[2]; runEvaluation(modelPath, preTreatConfigFile); return 0;

    runEvaluation


    Created with Raphaël 2.2.0 runEvaluation modelPath, preTreatConfig Document Document::Parse Document::GetObject ImageProcess::Config ImageProcess::create ScheduleConfig Interpreter::createSession Interpreter::getSessionInput Interpreter::resizeTensor Interpreter::resizeSession Interpreter::getSessionOutput stbi_load Matrix Matrix::setTranslate ImageProcess::setMatrix ImageProcess::convert stbi_image_free Interpreter::runSession computeTopkAcc End
    ネットワークの入力および前処理パラメータはjsonファイルに格納される.RapidJSONはテンセントオープンソースのC++JSON解析器とジェネレータです.std::ifstream::rdbufはストリームバッファを取得します.内部filebufオブジェクトへのポインタを返します.GenericDocument::Parseは、読み取り専用文字列からjsonテキストを解析します(符号化変換付き).
        int height, width;
        std::string imagePath;
        std::string groundTruthIdFile;
        rapidjson::Document document;
        {
         
            std::ifstream fileNames(preTreatConfig);
            std::ostringstream output;
            output << fileNames.rdbuf();
            auto outputStr = output.str();
            document.Parse(outputStr.c_str());
            if (document.HasParseError()) {
         
                MNN_ERROR("Invalid json
    "
    ); return 0; } }

    GenericDocument::GetObjectはObjectオブジェクトを返します.ImageProcess::Config構造体記録画像フォーマットおよび変換時のパラメータ.GenericObject::HasMemberクエリーフィールド.GenericObjectオブジェクトリロードオペレータoperator[]ImageFormatは列挙タイプです.GenericValue::GetArray GenericValue::GetFloat
        auto picObj = document.GetObject();
        ImageProcess::Config config;
        config.filterType = BILINEAR;
        // defalut input image format
        config.destFormat = BGR;
        {
         
            if (picObj.HasMember("format")) {
         
                auto format = picObj["format"].GetString();
                static std::map<std::string, ImageFormat> formatMap{
         {
         "BGR", BGR}, {
         "RGB", RGB}, {
         "GRAY", GRAY}};
                if (formatMap.find(format) != formatMap.end()) {
         
                    config.destFormat = formatMap.find(format)->second;
                }
            }
        }
        config.sourceFormat = RGBA;
        {
         
            if (picObj.HasMember("mean")) {
         
                auto mean = picObj["mean"].GetArray();
                int cur   = 0;
                for (auto iter = mean.begin(); iter != mean.end(); iter++) {
         
                    config.mean[cur++] = iter->GetFloat();
                }
            }
            if (picObj.HasMember("normal")) {
         
                auto normal = picObj["normal"].GetArray();
                int cur     = 0;
                for (auto iter = normal.begin(); iter != normal.end(); iter++) {
         
                    config.normal[cur++] = iter->GetFloat();
                }
            }
            if (picObj.HasMember("width")) {
         
                width = picObj["width"].GetInt();
            }
            if (picObj.HasMember("height")) {
         
                height = picObj["height"].GetInt();
            }
            if (picObj.HasMember("imagePath")) {
         
                imagePath = picObj["imagePath"].GetString();
            }
            if (picObj.HasMember("groundTruthId")) {
         
                groundTruthIdFile = picObj["groundTruthId"].GetString();
            }
        }
    

    ImageProcess::create ImageProcess::Configに基づいてImageProcessオブジェクトを作成します.Interpreter::createFromFileモデルファイルからInterpreterオブジェクトを作成します.Interpreterはネットワークデータを持ち,複数のセッションが同じネットワークを共有できる.ScheduleConfig構造体は、セッションスケジュール表の構成を行います.Interpreter::createSessionはScheduleConfigの構成を使用してセッションを作成します.作成されたセッションは、ネットワークで管理されます.Interpreter::getSessionInput指定した名前の入力テンソルを取得します.Interpreter::resizeTensor入力テンソルを調整します.Interpreter::resizeSessionはこの関数を呼び出してテンソルを準備します.入力テンソルのサイズを調整したら、出力テンソルバッファ(hostまたはdeviceId)を復元します.Interpreter::getSessionOutput所定の名前の出力テンソルを取得します.lstatは、指定されたファイルのステータス情報を取得し、bufパラメータが指すメモリ領域に配置する.
        std::shared_ptr<ImageProcess> pretreat(ImageProcess::create(config));
    
        std::shared_ptr<Interpreter> classficationInterpreter(Interpreter::createFromFile(modelPath));
        ScheduleConfig classficationEvalConfig;
        classficationEvalConfig.type      = MNN_FORWARD_CPU;
        classficationEvalConfig.numThread = 4;
        auto classficationSession         = classficationInterpreter->createSession(classficationEvalConfig);
        auto inputTensor                  = classficationInterpreter->getSessionInput(classficationSession, nullptr);
        auto shape                        = inputTensor->shape();
        // the model has not input dimension
        if(shape.size() == 0){
         
            shape.resize(4);
            shape[0] = 1;
            shape[1] = 3;
            shape[2] = height;
            shape[3] = width;
        }
        // set batch to be 1
        shape[0] = 1;
        classficationInterpreter->resizeTensor(inputTensor, shape);
        classficationInterpreter->resizeSession(classficationSession);
    
        auto outputTensor = classficationInterpreter->getSessionOutput(classficationSession, nullptr);
    

    Opendir関数は、ディレクトリ名に対応するディレクトリストリームを開き、そのディレクトリストリームへのポインタを返します.ストリームはディレクトリの最初のアイテムにあります.readdirは、direntが指すディレクトリストリームの次のディレクトリエントリを表すdirp構造を指すポインタを返す.ディレクトリ・ストリームの最後に到達したり、エラーが発生したりした場合、NULLを返します.
        // read ground truth label id
        std::vector<int> groundTruthId;
        {
         
            std::ifstream inputOs(groundTruthIdFile);
            std::string line;
            while (std::getline(inputOs, line)) {
         
                groundTruthId.emplace_back(std::atoi(line.c_str()));
            }
        }
    
        // read images file path
        int count = 0;
        std::vector<std::string> files;
        {
         
            struct stat s;
            lstat(imagePath.c_str(), &s);
            struct dirent* filename;
            DIR* dir;
            dir = opendir(imagePath.c_str());
            while ((filename = readdir(dir)) != nullptr) {
         
                if (strcmp(filename->d_name, ".") == 0 || strcmp(filename->d_name, "..") == 0) {
         
                    continue;
                }
                files.push_back(filename->d_name);
                count++;
            }
            std::cout << "total: " << count << std::endl;
            std::sort(files.begin(), files.end());
        }
    
        if (count != groundTruthId.size()) {
         
            MNN_ERROR("The number of input images is not same with ground truth id
    "
    ); return 0; }

    stbi_loadはnothings/stbに由来する.Matrixは座標を変換するための3 x 3行列を含む.これにより、点とベクトルをマッピングするには、平行移動、スケール、傾斜、回転、パースを使用します.Matrix::setTranslate行の移動量を設定します.ImageProcess::setMatrixシミュレーション変換マトリクスを設定します.ImageProcess::convertはソースデータを所定のテンソルに変換します.stbi_image_freeロードされた画像を解放します.Interpreter::runSessionはネットワークを実行します.
    出力結果を並べ替え,computeTopkAccは前kビット精度を計算する.
        int test = 0;
        int top1 = 0;
        int topk = 0;
    
        const int outputTensorSize = outputTensor->elementSize();
        if (outputTensorSize != TOTAL_CLASS_NUM) {
         
            MNN_ERROR("Change the total class number, such as the result number of tensorflow mobilenetv1/v2 is 1001
    "
    ); return 0; } std::vector<std::pair<int, float>> sortedResult(outputTensorSize); for (const auto& file : files) { const auto img = imagePath + file; int h, w, channel; auto inputImage = stbi_load(img.c_str(), &w, &h, &channel, 4); if (!inputImage) { MNN_ERROR("Can't open %s
    "
    , img.c_str()); return 0; } // input image transform Matrix trans; // choose resize or crop // resize method // trans.setScale((float)(w-1) / (width-1), (float)(h-1) / (height-1)); // crop method trans.setTranslate(16.0f, 16.0f); pretreat->setMatrix(trans); pretreat->convert((uint8_t*)inputImage, h, w, 0, inputTensor); stbi_image_free(inputImage); classficationInterpreter->runSession(classficationSession); { // default float value auto outputDataPtr = outputTensor->host<float>(); for (int i = 0; i < outputTensorSize; ++i) { sortedResult[i] = std::make_pair(i, outputDataPtr[i]); } std::sort(sortedResult.begin(), sortedResult.end(), [](std::pair<int, float> a, std::pair<int, float> b) { return a.second > b.second; }); } computeTopkAcc(groundTruthId, sortedResult, test, &top1, &topk); test++; MNN_PRINT("==> tested: %f, Top1: %f, Topk: %f
    "
    , (float)test / (float)count * 100.0, (float)top1 / (float)test * 100.0, (float)topk / (float)test * 100.0); } return 0;

    computeTopkAcc

        const int label = groundTruthId[index];
        if (sortedResult[0].first == label) {
         
            (*top1)++;
        }
        for (int i = 0; i < TOPK; ++i) {
         
            if (label == sortedResult[i].first) {
         
                (*topk)++;
                break;
            }
        }
    

    参考資料:

  • MNN推理過程ソースコード分析ノート(一)主流程
  • MNNのtflite-MobilenetSSD-c++導入プロセス
  • の詳細
  • Android Stuido Ndk-Jni開発(二):Jniにロゴ情報を印刷する
  • Constraints and concepts (since C++20)
  • lstat(), lstat64() — Get status of file or symbolic link
  • FlatBuffersのSchema
  • に深く入る
  • FlatBuffersのEncode
  • に深く入る