Halide::Funcをcout


はじめに

だいぶ時間が空きました.
前置きはここと同じなので省略.
HalideでデバッグするときにFuncの中身を覗きたいときに役に立つことに関して書きました.

Halide::Funcの本体と出力

Halideで関数を表すHalide::Funcですが,その本体はHalide::Internal::Functionにあり,更に関数の定義はargsvaluesによってなされています
例えばFunc ff(x,y)=2*x+yのとき,argsは{x, y}でvaluesは{2*x+y}です.
argsは純粋定義の場合vector<string>(純粋定義はHalide::Varを用いて行われるが,実際は各Varの固有な名前しか使われていない?),更新定義の場合vector<Halide::Expr>で,ベクトル要素数=関数の次元数です.
valuesは純粋定義,更新定義ともにvector<Halide::Expr>で,ベクトル要素数=Funcの出力をTupleと見たときの要素数です.
Halide::Exprはostreamへ出力可能なので,argsvaluesを引っ張ってこれれば,Funcの出力ができそうです.

using namespace Halide;

// Halide::Internal::Functionの出力
ostream& operator<<(ostream& os, const Halide::Internal::Function f)
{
    if (f.has_pure_definition())
    {
        // 純粋定義の出力
        os << f.name() << "(";
        for (int i = 0; i < f.args().size(); i++)
        {
            os << f.args()[i];
            if (i < f.args().size() - 1) os << ", ";
        }
        os << ")=";
        if (f.values().size() > 1) // 関数がTupleの場合
        {
            os << "Tuple(";
            for (int i = 0; i < f.values().size(); i++)
            {
                os << f.values()[i];
                if (i < f.values().size() - 1) os << ", ";
            }
            os << ")" << endl;
        }
        else
        {
            os << f.values()[0] << endl;
        }
    }
    // 更新定義の出力
    for (int u = 0; u < f.updates().size(); u++)
    {
        os << f.name() << "(";
        Halide::Internal::Definition u_def = f.update(u);
        for (int i = 0; i < u_def.args().size(); i++)
        {
            os << u_def.args()[i];
            if (i < u_def.args().size() - 1) os << ", ";
        }
        os << ")=";
        if (u_def.values().size() > 1)
        {
            os << "Tuple(";
            for (int i = 0; i < u_def.values().size(); i++)
            {
                os << u_def.values()[i];
                if (i < u_def.values().size() - 1) os << ", ";
            }
            os << ")" << endl;
        }
        else
        {
            os << u_def.values()[0] << endl;
        }
    }
    return os;
}

// Halide::Funcの出力
ostream& operator<<(ostream& os, const Halide::Func f)
{
    os << f.function();
    return os;
}

上のコードでは,Halide::Internal::Functionのostreamへの<<演算子をオーバロードして,出力の本体部分を定義します.
Funcからは.function()で内部のFunctionにアクセスできるので,そのままos << f.function()でok.

実験

Var x("x"), y("y");
Func clamped = BoundaryConditions::repeat_edge(src);

RDom r(-rad, 2 * rad + 1, "r");
Func blur_x("blur_x"), blur_y("blur_y"), total("total"), kernel("kernel");

Expr d = -1.f / (2.f * sigma * sigma);
total() = sum(fast_exp((r * r) * d),"total_sum");

kernel(x) = fast_exp((x * x) * d);
blur_x(x, y) = sum(kernel(r) * clamped(x + r, y), "blurx_sum") / total();
blur_y(x, y) = sum(kernel(r) * blur_x(x, y + r), "blury_sum") / total();

cout << "blur_yは..."
cout << blur_y; // 出力
出力結果
blur_yは...blur_y(x, y)=((float32)blury_sum(x, y)/(float32)total())

ちゃんと出力できました.

入れ子関数も見たい

先ほどの実験ですと,一番外側のblur_yだけ出力されて,入れ子になっているblury_sumtotalまでは見れません.
そこで,Halide::Internal::populate_environmentを使います.
Halide::Internal::populate_environmentFunctionに対して,そのFunctionで入れ子になっているFunctionたちのstd::map<std::string, Halide::Internal::Function>を生成してくれます.
下のコードはFunctionFuncについて,入れ子になっているFunctionを出力する関数です.

void print_all_ref_functions(Halide::Internal::Function f)
{
    map<string, Halide::Internal::Function> env;
    Halide::Internal::populate_environment(f, env);
    for (auto it = env.begin(); it != env.end(); it++)
    {
        cout << it->first << "...\n";
        cout << it->second;
        cout << endl;
    }
}

void print_all_ref_funcs(Halide::Func f)
{
    print_all_ref_functions(f.function());
}

実験

先ほどのblur_yについて,入れ子関数まで出力させてみます.

Var x("x"), y("y");
Func clamped = BoundaryConditions::repeat_edge(src);

RDom r(-rad, 2 * rad + 1, "r");
Func blur_x("blur_x"), blur_y("blur_y"), total("total"), kernel("kernel");

Expr d = -1.f / (2.f * sigma * sigma);
total() = sum(fast_exp((r * r) * d),"total_sum");

kernel(x) = fast_exp((x * x) * d);
blur_x(x, y) = sum(kernel(r) * clamped(x + r, y), "blurx_sum") / total();
blur_y(x, y) = sum(kernel(r) * blur_x(x, y + r), "blury_sum") / total();

print_all_ref_funcs(blur_y); // 出力
出力結果
blur_x...
blur_x(x, y)=((float32)blurx_sum(x, y)/(float32)total())

blur_y...
blur_y(x, y)=((float32)blury_sum(x, y)/(float32)total())

blurx_sum...
blurx_sum(x, y)=0.000000f
blurx_sum(x, y)=((float32)blurx_sum(x, y) + ((float32)kernel(r$x)*(float32)repeat_edge(x + r$x, y)))

blury_sum...
blury_sum(x, y)=0.000000f
blury_sum(x, y)=((float32)blury_sum(x, y) + ((float32)kernel(r$x)*(float32)blur_x(x, y + r$x)))

kernel...
kernel(x)=(let t16 = float32((x*x)) in (let t17 = (float32)floor_f32((t16*-0.055556f)/0.693147f) in (let t18 = ((t16*-0.055556f) - (t17*0.693147f)) in (let t19 = (t18*t18) in (((((((0.013144f*t19) + 0.168739f)*t19) + 1.000000f)*t18) + ((((0.036690f*t19) + 0.499705f)*t19) + 1.000000f))*(float32)reinterpret(shift_left(max(min(int32(t17) + 127, 255), 0), (uint32)23)))))))

lambda_0...
lambda_0(_0, _1)=(float32)b0(_0, _1)

repeat_edge...
repeat_edge(_0, _1)=(float32)lambda_0(max(min(likely(_0), (0 + 512) - 1), 0), max(min(likely(_1), (0 + 512) - 1), 0))

total...
total()=(float32)total_sum()

total_sum...
total_sum()=0.000000f
total_sum()=(let t8 = float32((r$x*r$x)) in (let t9 = (float32)floor_f32((t8*-0.055556f)/0.693147f) in (let t10 = ((t8*-0.055556f) - (t9*0.693147f)) in (let t11 = (t10*t10) in ((float32)total_sum() + (((((((0.013144f*t11) + 0.168739f)*t11) + 1.000000f)*t10) + ((((0.036690f*t11) + 0.499705f)*t11) + 1.000000f))*(float32)reinterpret(shift_left(max(min(int32(t9) + 127, 255), 0), (uint32)23))))))))

こんな感じで,blur_yを構成するFuncを全て吐き出してくれます.
普段sumを使うと,総和の計算をするFuncはHalide側で自動的に生成され,その参照のExprが返されるので,sumはほぼブラックボックス状態ですが,このように引っ張り出してくることも可能です(応用すればsumのスケジューリングも変更可能!).
他にもBoundaryConditions::repeat_edgefast_expの実装が覗けます(ソースファイル見てもわかるけど...).

最後に

HalideのFuncの中身を覗く方法を紹介しました.
この方法を用いればブラックボックス状態のものの中身が分かるため,デバッグが捗るかと思います.