TensorFlow Source Code Analysis: ExecutorState

Preface
This document annotates the meaning of each statement of the source code in as much detail as possible. TensorFlow version: 1.10. I go through the code in the order I am used to, or the order I find easiest to explain, so there may be mistakes; if you spot any, please leave a comment, thanks. These annotations are only a personal record and are not intended for any other use.
1. ScheduleReady
  // ScheduleReady is called (for example at the end of Process) to schedule
  // the nodes in `ready` for execution.
  void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
                                    TaggedNodeReadyQueue* inline_ready) {
    // A node finishing execution may make its downstream nodes ready: e.g. if
    // it has 3 successors, each successor's pending count drops by 1, and the
    // ones that reach 0 end up in `ready`. That list is collected in
    // PropagateOutputs before this function is called.

    // If `ready` is empty, nothing new became runnable, so there is nothing to
    // schedule and we return immediately.
    if (ready.empty()) return;

    int64 scheduled_nsec = 0;
    if (stats_collector_) {
      scheduled_nsec = nodestats::NowInNsec();
    }
    // inline_ready collects the nodes that the calling Process will keep
    // executing inline on the current thread.
    // If the caller did not pass one (inline_ready == nullptr):
    if (inline_ready == nullptr) {
      // Schedule to run all the ready ops in thread pool.
      // In that case every node in `ready` is handed off to the thread pool,
      // each executed by its own Process call.
      for (auto& tagged_node : ready) {
        runner_([=]() { Process(tagged_node, scheduled_nsec); });
      }
      return;
    }
    const GraphView& gview = impl_->gview_;
    // Tracks the most recently seen expensive (costly-to-compute) op.
    const TaggedNode* curr_expensive_node = nullptr;
    // Otherwise, walk the ready list and decide where each node should run.
    for (auto& tagged_node : ready) {
      const NodeItem& item = *gview.node(tagged_node.node->id());
      // The node is dead, or its kernel is cheap to compute.
      if (tagged_node.is_dead || !item.kernel_is_expensive) {
        // Inline this inexpensive node.
        // Cheap enough to run here: append it to inline_ready rather than
        // paying for a thread-pool dispatch.
        inline_ready->push_back(tagged_node);
      } else {
        // Otherwise this is an expensive node. The strategy:
        // every expensive node except the last one seen is handed to another
        // thread's Process; the current thread keeps at most one of them.
        if (curr_expensive_node) {
          // Dispatch to another thread since there is plenty of work to
          // do for this thread.
          // Dispatch the previously remembered expensive op to the thread pool.
          runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
                            scheduled_nsec));
        }
        // Remember the current expensive op.
        curr_expensive_node = &tagged_node;
        // Net effect: of all the expensive ops in `ready`, only the last one
        // can stay on the current CPU thread; the rest are dispatched. A bit
        // convoluted...
      }
    }
    // After the loop, curr_expensive_node (if set) is the last expensive op.
    if (curr_expensive_node) {
      // Whether it runs here depends on whether inline_ready already has work:
      if (inline_ready->empty()) {
        // Tail recursion optimization
        // No inline work pending, so run this expensive op on the current
        // thread by queueing it on inline_ready.
        inline_ready->push_back(*curr_expensive_node);
      } else {
        // There are inline nodes to run already. We dispatch this expensive
        // node to other thread.
        // The current thread already has inline nodes queued, so hand this
        // expensive op to another thread rather than let it delay them.
        runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
                          scheduled_nsec));
      }
    }
  }
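
To make the dispatch policy easier to see in isolation, here is a minimal, self-contained sketch of the same idea. It is not TensorFlow code: Node, Dispatch, Run and Schedule are invented names, and Dispatch runs its closure synchronously where the real runner_ would hand it to a thread pool.

#include <deque>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;
  bool expensive;
};

// Stand-in for runner_: the real one enqueues the closure on a thread pool;
// here it just runs it immediately and tags the output.
void Dispatch(const std::function<void()>& fn) {
  std::cout << "[thread pool] ";
  fn();
}

void Run(const Node& n) { std::cout << "run " << n.name << "\n"; }

// Mirrors ScheduleReady's policy: cheap nodes go to inline_ready, every
// expensive node except the last one is dispatched, and the last expensive
// node stays on the current thread only when there is no inline work.
void Schedule(const std::vector<Node>& ready, std::deque<Node>* inline_ready) {
  const Node* curr_expensive = nullptr;
  for (const Node& n : ready) {
    if (!n.expensive) {
      inline_ready->push_back(n);
    } else {
      if (curr_expensive) Dispatch([curr_expensive] { Run(*curr_expensive); });
      curr_expensive = &n;
    }
  }
  if (curr_expensive) {
    if (inline_ready->empty()) {
      inline_ready->push_back(*curr_expensive);  // reuse the current thread
    } else {
      Dispatch([curr_expensive] { Run(*curr_expensive); });
    }
  }
}

int main() {
  std::deque<Node> inline_ready;
  Schedule({{"a", false}, {"b", true}, {"c", true}, {"d", false}},
           &inline_ready);
  // What the caller (Process) would drain next on the current thread.
  for (const Node& n : inline_ready) Run(n);
}

With this sample input, b and c are printed with the "[thread pool]" tag (they are dispatched), while a and d remain in inline_ready for the current thread, matching the policy described above.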

2. Process
// Process executes one concrete op, and possibly several more ops via inline_ready. By the time we get here, tagged_node came from a ready list and is runnable.
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
    WithContext wc(context_);
    const GraphView& gview = impl_->gview_;
    // `ready` collects the nodes (ops) that newly become ready after this node runs.
    TaggedNodeSeq ready;
    // inline_ready holds tagged_node plus any nodes (ops) that will be executed inline on this same thread.
    TaggedNodeReadyQueue inline_ready;
	
    // Next, a batch of parameter setup.
    // Parameters passed to OpKernel::Compute.
    TensorValueVec inputs;
    DeviceContextVec input_device_contexts;
    AllocatorAttributeVec input_alloc_attrs;

    OpKernelContext::Params params;
    params.step_id = step_id_;
    Device* device = impl_->params_.device;
    params.device = device;
    params.log_memory = log_memory_;
    params.record_tensor_accesses = impl_->device_record_tensor_accesses_;
    params.rendezvous = rendezvous_;
    params.collective_executor = collective_executor_;
    params.session_state = session_state_;
    params.tensor_store = tensor_store_;
    params.cancellation_manager = cancellation_manager_;
    params.call_frame = call_frame_;
    params.function_library = impl_->params_.function_library;
    params.resource_manager = device->resource_manager();
    params.step_container = step_container_;
    params.slice_reader_cache = slice_reader_cache_;
    params.inputs = &inputs;
    params.input_device_contexts = &input_device_contexts;
    params.input_alloc_attrs = &input_alloc_attrs;
    params.runner = &runner_;
    params.stats_collector = stats_collector_;

    Status s;
    NodeExecStatsInterface* stats = nullptr;

    EntryVector outputs;
    bool completed = false;
    // First push tagged_node itself onto inline_ready so the loop below can
    // treat it like any other queued node.
    // At this point inline_ready therefore contains only tagged_node.
    inline_ready.push_back(tagged_node);
    while (!inline_ready.empty()) {
      // Pop the node at the front of the queue.
      tagged_node = inline_ready.front();
      inline_ready.pop_front();
      // Pull out the node's bookkeeping: the node itself, its input frame and
      // iteration, its id and its NodeItem.
      const Node* node = tagged_node.node;
      FrameState* input_frame = tagged_node.input_frame;
      const int64 input_iter = tagged_node.input_iter;
      const int id = node->id();
      // gview is an immutable view of the graph (similar in spirit to
      // string_view); node(id) returns the NodeItem for this node id.
      const NodeItem& item = *gview.node(id);

      // TODO(misard) Replace with a finer-grain enabling flag once we
      // add better optional debugging support.
      if (vlog_ && VLOG_IS_ON(1)) {
        mutex_lock l(input_frame->mu);
        input_frame->GetIteration(input_iter)->mark_started(item.pending_id);
      }

      // Set the device_context for this node id, if it exists.
      // If a device context was registered for this node id, install it.
      if (id < device_context_map_.size()) {
        params.op_device_context = device_context_map_[id];
      }

      params.track_allocations = false;
      stats = nullptr;
      if (stats_collector_ && !tagged_node.is_dead) {
        stats = stats_collector_->CreateNodeExecStats(node);
        // Track allocations if and only if we are collecting statistics, and
        // `stats` object is expecting allocations to be tracked.
        params.track_allocations = stats ? stats->TrackAllocations() : false;
        nodestats::SetScheduled(stats, scheduled_nsec);
        nodestats::SetAllStart(stats);
      }

      if (vlog_) {
        VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
                << SummarizeNode(*node) << (tagged_node.is_dead ? " is dead" : "")
                << " device: " << device->name();
      }

      Entry* input_tensors = GetInputTensors(input_frame, input_iter);
      Entry* first_input = input_tensors + item.input_start;
      outputs.clear();

      TensorReferenceVector accessed_tensors;
      DeviceContext* device_context = nullptr;
      // Only execute this node if it is not dead or it is a send/recv
      // transfer node. For transfer nodes, we need to propagate the "dead"
      // bit even when the node is dead.
      bool launched_asynchronously = false;
      if (tagged_node.is_dead && !IsTransferNode(node)) {
        outputs.resize(item.num_outputs);
      } else {
        // Prepares inputs.
        bool is_input_dead = false;
        s = PrepareInputs(item, first_input, &inputs, &input_device_contexts,
                          &input_alloc_attrs, &is_input_dead);
        if (!s.ok()) {
          // Clear inputs.
          int num_inputs = item.num_inputs;
          for (int i = 0; i < num_inputs; ++i) {
            (first_input + i)->ClearVal();
          }
          MaybeMarkCompleted(input_frame, input_iter, id);
          // Continue to process the nodes in 'inline_ready'.
          completed = NodeDone(s, item.node, ready, stats, &inline_ready);
          continue;
        }

        // Set up compute params.
        OpKernel* op_kernel = item.kernel;
        params.op_kernel = op_kernel;
        params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
        params.is_input_dead = is_input_dead;
        params.output_attr_array = item.output_attrs();
        params.forward_from_array = item.forward_from();

        if (item.kernel_is_async) {
          // Asynchronous computes.
          AsyncOpKernel* async = item.kernel->AsAsync();
          DCHECK(async != nullptr);
          launched_asynchronously = true;
          AsyncState* state =
              new AsyncState(params, tagged_node, &item, first_input, stats);

          auto done = [this, state]() {
            Device* device = impl_->params_.device;
            NodeExecStatsInterface* stats = state->stats;  // Shorthand
            Entry* first_input = state->first_input;       // Shorthand

            nodestats::SetOpEnd(stats);
            EntryVector outputs;
            Status s = ProcessOutputs(*state->item, &state->ctx, &outputs, stats);
            nodestats::SetMemory(stats, &state->ctx);
            if (vlog_) {
              VLOG(2) << "Async kernel done: " << state->item->node->id()
                      << " step " << step_id_ << " "
                      << SummarizeNode(*state->item->node)
                      << (state->tagged_node.is_dead ? " is dead" : "")
                      << " device: " << device->name();
            }

            // Clears inputs.
            const int num_inputs = state->item->num_inputs;
            for (int i = 0; i < num_inputs; ++i) {
              (first_input + i)->ClearVal();
            }
            FrameState* input_frame = state->tagged_node.input_frame;
            const int64 input_iter = state->tagged_node.input_iter;
            const int id = state->tagged_node.node->id();
            MaybeMarkCompleted(input_frame, input_iter, id);
            TaggedNodeSeq ready;
            if (s.ok()) {
              PropagateOutputs(state->tagged_node, state->item, &outputs, &ready);
            }
            outputs.clear();
            if (s.ok() && impl_->device_record_tensor_accesses_) {
              // Get the list of all tensors accessed during the execution
              TensorReferenceVector accessed;
              state->ctx.retrieve_accessed_tensors(&accessed);
              nodestats::SetReferencedTensors(stats, accessed);
              // callee takes ownership of the vector
              device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
                                                   accessed);
            }
            const bool completed =
                NodeDone(s, state->item->node, ready, stats, nullptr);
            delete state;
            if (completed) Finish();
          };
          nodestats::SetOpStart(stats);
          device->ComputeAsync(async, &state->ctx, done);
        } else {
          // Synchronous computes.
          OpKernelContext ctx(&params, item.num_outputs);
          nodestats::SetOpStart(stats);

          if (TF_PREDICT_FALSE(
                  MightTrace(item, event_collector_, trace_using_annotations_))) {
            const string& op_name = op_kernel->name();
            tracing::ScopedRegion region(tracing::EventCategory::kCompute,
                                         op_name);
            if (trace_using_annotations_) {
              // The OpKernel may create child activities (such as GPU kernel
              // launches), so use a `ScopedAnnotation` to relate these activities
              // in the trace.
              tracing::ScopedAnnotation activity(
                  op_name, strings::StrCat(op_kernel->type_string(),
                                           "#id=", step_id_, "#"));
              device->Compute(op_kernel, &ctx);
            } else {
              // Use the cheaper `ScopedActivity` to trace just the OpKernel
              // execution.
              tracing::ScopedActivity activity(
                  op_name,
                  strings::StrCat(op_kernel->type_string(), "#id=", step_id_,
                                  "#"),
                  item.kernel_is_expensive);
              device->Compute(op_kernel, &ctx);
            }
          } else {
            // In the common case, avoid creating any tracing objects.
            device->Compute(op_kernel, &ctx);
          }

          nodestats::SetOpEnd(stats);
          s = ProcessOutputs(item, &ctx, &outputs, stats);
          if (s.ok() && impl_->device_record_tensor_accesses_) {
            // Get the list of all tensors accessed during the execution
            ctx.retrieve_accessed_tensors(&accessed_tensors);
            device_context = ctx.op_device_context();
          }
          nodestats::SetMemory(stats, &ctx);
        }
      }

      if (!launched_asynchronously) {
        if (vlog_) {
          VLOG(2) << "Synchronous kernel done: " << id << " step "
                  << params.step_id << " " << SummarizeNode(*node)
                  << (tagged_node.is_dead ? " is dead: " : "")
                  << " device: " << device->name();
        }

        // Clears inputs.
        const int num_inputs = item.num_inputs;
        for (int i = 0; i < num_inputs; ++i) {
          (first_input + i)->ClearVal();
        }
        MaybeMarkCompleted(input_frame, input_iter, id);
        // Propagates outputs.
        if (s.ok()) {
          PropagateOutputs(tagged_node, &item, &outputs, &ready);
        }
        outputs.clear();
        if (!accessed_tensors.empty()) {
          nodestats::SetReferencedTensors(stats, accessed_tensors);
          // device_context is set above in synchronous computes
          device->ConsumeListOfAccessedTensors(device_context, accessed_tensors);
        }
        if (stats) {
          scheduled_nsec = nodestats::NowInNsec();
        }
        // Postprocess.
        completed = NodeDone(s, item.node, ready, stats, &inline_ready);
      }
    }  // while !inline_ready.empty()

    // This thread of computation is done if completed = true.
    if (completed) Finish();
  }
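
The asynchronous branch above relies on a common C++ pattern: because ComputeAsync returns before the kernel finishes, everything the completion callback needs (params, tagged_node, item, first_input, stats) is copied into a heap-allocated AsyncState, the done lambda captures that pointer, and the state is deleted only after post-processing. Below is a stripped-down sketch of just that ownership pattern, with invented names (FakeAsyncState, LaunchAsync, LaunchNode) that are not the TensorFlow ones.

#include <functional>
#include <iostream>
#include <string>
#include <thread>

struct FakeAsyncState {
  std::string node_name;  // stands in for tagged_node / item / first_input
  int num_outputs = 0;    // stands in for the OpKernelContext
};

// Stands in for device->ComputeAsync(async, &state->ctx, done): run the kernel
// elsewhere and invoke `done` when it completes. A real launch would not block.
void LaunchAsync(const std::function<void()>& done) {
  std::thread(done).join();  // placeholder for a genuinely asynchronous launch
}

void LaunchNode(const std::string& name) {
  // The state must outlive this stack frame, hence the heap allocation.
  auto* state = new FakeAsyncState{name, 1};
  auto done = [state]() {
    // Mirrors the done lambda in Process: process outputs, clear inputs,
    // propagate to successors, call NodeDone, then release the state.
    std::cout << "async kernel done: " << state->node_name << "\n";
    delete state;
  };
  LaunchAsync(done);  // in the real executor this returns immediately
}

int main() { LaunchNode("recv_op"); }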

3. NodeDone