【経験共有】mindsporeモデル開発&modelarts;マルチカードトレーニング経験の共有

2460 ワード

転載先:https://bbs.huaweicloud.com/f...
作者:陳覇
モデルを開発する訓練部分は,データ処理,ネットワーク,損失関数,訓練に大別される.主にpytorchからmindsporeへの再現を実現しているので,両者の同じ部分は詳細に説明しない.
1.データ処理:
                    mindspore.dateset    ,    。     mindspore         。               MindRecord(MindSpore       ,      、          。),         .mindrecord    。     mindspore.dateset.GeneratorDataset       。  MindRecord           ,               ,           ,          cv.imread  ,            ,          ,                  。  ,               GeneratorDataset。


2.ネットワーク:
        pytorch    ,      ,       ,mindspore       ,      kernel_meta,        ,         。


3.損失関数:
    pytorch  ,          ,         ,  nn.cell,  init  construct。         mindspore    ,     init      , construct   。       ,                    。          :

pytorchのInterpolateはops.ResizeBilinear実装;
pytorchの2つのテンソルa,bはa[b>0]を実現するこの操作mindsporeではselect演算子を用いることができ、例えばcond=Tensor([True,False])x=Tensor([4,9])#あなたのテンソルAy=Tensor([0,0])#臨時select=P.select()z=select(cond,x,y)であり、最後にshapeとpytorchは異なるが、最後に和を求めることができる.
損失には一部の値のみが計算され、例えば(10,1)はその中の5つの値の損失のみを計算する必要がある.この場合は他の値を0に、equal演算子を用いて計算するindexを得ることができ、select演算子を用いてよいでしょう不要な値を0にすることができます.演算子を使用するにはTensorのデータ型に注意する必要があり、floatのみをサポートするものもあります.
cast演算子を使えばいいでしょうTensorデータ型は互いに変換します;
4.トレーニング:
          :    ,      ,     ,          model.train   ,                   。                 img,label      WithLossCell            ,  TrainOneStepCell                     model.train 。

5.カスタムcallback関数の使用方法:
   Mindspore  Callback      ,    ,       ,         ,     ,          。          Callback         Mindspore               。

公式Callbackクラス:
一般的には、公式に定義されたCallback関数の使用方法を使用します.
1.png
2.png
3.png
実行結果は次のとおりです.
4.png
カスタムCallback:
一般的に私にとって、使用は後ろの6つの関数で、主にstep_です.begin,step_end.関数名からその関数がいつ呼び出されたかがわかりますが、ここではこれ以上説明しません.Callbackは主にrun_にcontextの内容はある程度理解して、これは自分でソースコードを見ることができます.ここでは主に,自分の運転中にepochステップごとに平均損失に及ぼす影響を理解したいので,このカスタマイズを用いた.
5.png
6.png
7.png
実行結果は次のとおりです.
8.png
6.modelartsマルチカードトレーニング:
                     mix          obs  mix    catch ,      ckpt     mix    obs  。