TensorFlow ソースコード解析:ExecutorState
64313 ワード
はじめに
この文書では、ソースコードの各文の意味をできるだけ詳細な注釈で示します。TensorFlow のバージョンは 1.10 です。解説は、私が慣れている順番や分かりやすいと感じる順番で、私が理解している範囲について書いています。間違いがあれば、コメントでご指摘ください。よろしくお願いします。このコードコメントは個人的な記録としてのみ使用し、それ以外の目的には使用しません。
1. ScheduleReady
2. Process
3. NodeDone
この文書では、ソースコードの各文の意味をできるだけ詳細な注釈で示します。TensorFlow のバージョンは 1.10 です。解説は、私が慣れている順番や分かりやすいと感じる順番で、私が理解している範囲について書いています。間違いがあれば、コメントでご指摘ください。よろしくお願いします。このコードコメントは個人的な記録としてのみ使用し、それ以外の目的には使用しません。
1. ScheduleReady
// Process ready node
void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
TaggedNodeReadyQueue* inline_ready) {
// Node Node ready 。
// node 3 node , 3 node 1 。 PropagateOutput ready 。
// ready , , ( , input ),whatever, , 。
if (ready.empty()) return;
int64 scheduled_nsec = 0;
if (stats_collector_) {
scheduled_nsec = nodestats::NowInNsec();
}
//inline_ready Process ready Node。
// :
if (inline_ready == nullptr) {
// Schedule to run all the ready ops in thread pool.
// , ready node Process 。
for (auto& tagged_node : ready) {
runner_([=]() { Process(tagged_node, scheduled_nsec); });
}
return;
}
const GraphView& gview = impl_->gview_;
// ( ) op
const TaggedNode* curr_expensive_node = nullptr;
// inlined_ready 。 node。
for (auto& tagged_node : ready) {
const NodeItem& item = *gview.node(tagged_node.node->id());
// ( , GPU )
if (tagged_node.is_dead || !item.kernel_is_expensive) {
// Inline this inexpensive node.
// inline_ready ( , )。
inline_ready->push_back(tagged_node);
} else {
// , 。 :
// Process op
if (curr_expensive_node) {
// Dispatch to another thread since there is plenty of work to
// do for this thread.
// op 。
runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
scheduled_nsec));
}
// op 。
curr_expensive_node = &tagged_node;
// , ready op cpu , 。 。。。
}
}
// , op
if (curr_expensive_node) {
// inline_ready , 。 :
if (inline_ready->empty()) {
// Tail recursion optimization
// op
inline_ready->push_back(*curr_expensive_node);
} else {
// There are inline nodes to run already. We dispatch this expensive
// node to other thread.
// , op 。 op 。
runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
scheduled_nsec));
}
}
}
2. Process
// Runs `tagged_node` on the calling thread, then keeps executing whatever
// nodes end up in the local `inline_ready` queue until that queue is empty.
//
// For each node the loop optionally creates stats, prepares the inputs, runs
// the kernel (synchronously through Device::Compute or asynchronously through
// Device::ComputeAsync), turns the kernel results into `outputs` with
// ProcessOutputs(), clears the consumed input entries, calls
// PropagateOutputs() to collect newly ready nodes, and finally reports the
// node through NodeDone().
//
// tagged_node:    first node to execute. It is taken by value and the local
//                 copy is overwritten with each node popped from
//                 `inline_ready`.
// scheduled_nsec: timestamp recorded with nodestats::SetScheduled() for nodes
//                 whose stats are collected. It is refreshed after every
//                 synchronously completed node that has stats.
//
// Finish() is called at the end if `completed`, as last assigned from
// NodeDone() inside the loop, is true.
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
  WithContext wc(context_);
  const GraphView& gview = impl_->gview_;
  // Filled by PropagateOutputs() and handed to NodeDone().
  TaggedNodeSeq ready;
  // Work queue drained by the loop below. NodeDone() receives a pointer to
  // it; presumably that is how follow-up nodes get queued for this thread
  // (NodeDone's body is not part of this excerpt -- confirm there).
  TaggedNodeReadyQueue inline_ready;

  // Parameters passed to OpKernel::Compute.
  TensorValueVec inputs;
  DeviceContextVec input_device_contexts;
  AllocatorAttributeVec input_alloc_attrs;

  // Fields that do not depend on the node are set once here; the per-node
  // fields are filled in inside the loop.
  OpKernelContext::Params params;
  params.step_id = step_id_;
  Device* device = impl_->params_.device;
  params.device = device;
  params.log_memory = log_memory_;
  params.record_tensor_accesses = impl_->device_record_tensor_accesses_;
  params.rendezvous = rendezvous_;
  params.collective_executor = collective_executor_;
  params.session_state = session_state_;
  params.tensor_store = tensor_store_;
  params.cancellation_manager = cancellation_manager_;
  params.call_frame = call_frame_;
  params.function_library = impl_->params_.function_library;
  params.resource_manager = device->resource_manager();
  params.step_container = step_container_;
  params.slice_reader_cache = slice_reader_cache_;
  params.inputs = &inputs;
  params.input_device_contexts = &input_device_contexts;
  params.input_alloc_attrs = &input_alloc_attrs;
  params.runner = &runner_;
  params.stats_collector = stats_collector_;

  Status s;
  NodeExecStatsInterface* stats = nullptr;
  EntryVector outputs;
  bool completed = false;

  // Seed the queue with the node this call was asked to run, so the loop
  // below treats it exactly like any node queued later.
  inline_ready.push_back(tagged_node);
  while (!inline_ready.empty()) {
    // Take the next node from the front of the queue.
    tagged_node = inline_ready.front();
    inline_ready.pop_front();

    // Unpack the node, its frame/iteration, and its NodeItem in the graph
    // view.
    const Node* node = tagged_node.node;
    FrameState* input_frame = tagged_node.input_frame;
    const int64 input_iter = tagged_node.input_iter;
    const int id = node->id();
    const NodeItem& item = *gview.node(id);

    // TODO(misard) Replace with a finer-grain enabling flag once we
    // add better optional debugging support.
    if (vlog_ && VLOG_IS_ON(1)) {
      // Mark the node as started in its iteration state, under the frame
      // mutex.
      mutex_lock l(input_frame->mu);
      input_frame->GetIteration(input_iter)->mark_started(item.pending_id);
    }

    // Set the device_context for this node id, if it exists.
    if (id < device_context_map_.size()) {
      params.op_device_context = device_context_map_[id];
    }

    params.track_allocations = false;
    stats = nullptr;
    // Stats are only created for live nodes, and only when a collector is
    // installed.
    if (stats_collector_ && !tagged_node.is_dead) {
      stats = stats_collector_->CreateNodeExecStats(node);
      // Track allocations if and only if we are collecting statistics, and
      // `stats` object is expecting allocations to be tracked.
      params.track_allocations = stats ? stats->TrackAllocations() : false;
      nodestats::SetScheduled(stats, scheduled_nsec);
      nodestats::SetAllStart(stats);
    }

    if (vlog_) {
      VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
              << SummarizeNode(*node) << (tagged_node.is_dead ? " is dead" : "")
              << " device: " << device->name();
    }

    // `first_input` points at this node's first input entry within the
    // input tensors of its frame/iteration.
    Entry* input_tensors = GetInputTensors(input_frame, input_iter);
    Entry* first_input = input_tensors + item.input_start;
    outputs.clear();

    TensorReferenceVector accessed_tensors;
    DeviceContext* device_context = nullptr;
    // Only execute this node if it is not dead or it is a send/recv
    // transfer node. For transfer nodes, we need to propagate the "dead"
    // bit even when the node is dead.
    bool launched_asynchronously = false;
    if (tagged_node.is_dead && !IsTransferNode(node)) {
      // Kernel is skipped; `outputs` just gets one default entry per output.
      outputs.resize(item.num_outputs);
    } else {
      // Prepares inputs.
      bool is_input_dead = false;
      s = PrepareInputs(item, first_input, &inputs, &input_device_contexts,
                        &input_alloc_attrs, &is_input_dead);
      if (!s.ok()) {
        // Clear inputs.
        int num_inputs = item.num_inputs;
        for (int i = 0; i < num_inputs; ++i) {
          (first_input + i)->ClearVal();
        }
        MaybeMarkCompleted(input_frame, input_iter, id);
        // Continue to process the nodes in 'inline_ready'.
        completed = NodeDone(s, item.node, ready, stats, &inline_ready);
        continue;
      }

      // Set up compute params.
      OpKernel* op_kernel = item.kernel;
      params.op_kernel = op_kernel;
      params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
      params.is_input_dead = is_input_dead;
      params.output_attr_array = item.output_attrs();
      params.forward_from_array = item.forward_from();

      if (item.kernel_is_async) {
        // Asynchronous computes.
        AsyncOpKernel* async = item.kernel->AsAsync();
        DCHECK(async != nullptr);
        launched_asynchronously = true;
        // `state` carries everything the completion callback needs; the
        // callback deletes it when it is finished with it.
        AsyncState* state =
            new AsyncState(params, tagged_node, &item, first_input, stats);

        // Completion callback. It repeats the post-processing of the
        // synchronous path below, but reads everything from `state` and
        // passes nullptr as the inline queue to NodeDone().
        auto done = [this, state]() {
          Device* device = impl_->params_.device;
          NodeExecStatsInterface* stats = state->stats;  // Shorthand
          Entry* first_input = state->first_input;       // Shorthand

          nodestats::SetOpEnd(stats);
          EntryVector outputs;
          Status s = ProcessOutputs(*state->item, &state->ctx, &outputs, stats);
          nodestats::SetMemory(stats, &state->ctx);
          if (vlog_) {
            VLOG(2) << "Async kernel done: " << state->item->node->id()
                    << " step " << step_id_ << " "
                    << SummarizeNode(*state->item->node)
                    << (state->tagged_node.is_dead ? " is dead" : "")
                    << " device: " << device->name();
          }

          // Clears inputs.
          const int num_inputs = state->item->num_inputs;
          for (int i = 0; i < num_inputs; ++i) {
            (first_input + i)->ClearVal();
          }
          FrameState* input_frame = state->tagged_node.input_frame;
          const int64 input_iter = state->tagged_node.input_iter;
          const int id = state->tagged_node.node->id();
          MaybeMarkCompleted(input_frame, input_iter, id);
          TaggedNodeSeq ready;
          if (s.ok()) {
            PropagateOutputs(state->tagged_node, state->item, &outputs, &ready);
          }
          outputs.clear();
          if (s.ok() && impl_->device_record_tensor_accesses_) {
            // Get the list of all tensors accessed during the execution
            TensorReferenceVector accessed;
            state->ctx.retrieve_accessed_tensors(&accessed);
            nodestats::SetReferencedTensors(stats, accessed);
            // callee takes ownership of the vector
            device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
                                                 accessed);
          }
          const bool completed =
              NodeDone(s, state->item->node, ready, stats, nullptr);
          delete state;
          if (completed) Finish();
        };
        nodestats::SetOpStart(stats);
        device->ComputeAsync(async, &state->ctx, done);
      } else {
        // Synchronous computes.
        // The kernel context is built from a pointer to the `params` struct
        // populated above.
        OpKernelContext ctx(&params, item.num_outputs);
        nodestats::SetOpStart(stats);

        if (TF_PREDICT_FALSE(
                MightTrace(item, event_collector_, trace_using_annotations_))) {
          const string& op_name = op_kernel->name();
          tracing::ScopedRegion region(tracing::EventCategory::kCompute,
                                       op_name);
          if (trace_using_annotations_) {
            // The OpKernel may create child activities (such as GPU kernel
            // launches), so use a `ScopedAnnotation` to relate these activities
            // in the trace.
            tracing::ScopedAnnotation activity(
                op_name, strings::StrCat(op_kernel->type_string(),
                                         "#id=", step_id_, "#"));
            device->Compute(op_kernel, &ctx);
          } else {
            // Use the cheaper `ScopedActivity` to trace just the OpKernel
            // execution.
            tracing::ScopedActivity activity(
                op_name,
                strings::StrCat(op_kernel->type_string(), "#id=", step_id_,
                                "#"),
                item.kernel_is_expensive);
            device->Compute(op_kernel, &ctx);
          }
        } else {
          // In the common case, avoid creating any tracing objects.
          device->Compute(op_kernel, &ctx);
        }

        nodestats::SetOpEnd(stats);
        s = ProcessOutputs(item, &ctx, &outputs, stats);
        if (s.ok() && impl_->device_record_tensor_accesses_) {
          // Get the list of all tensors accessed during the execution
          ctx.retrieve_accessed_tensors(&accessed_tensors);
          device_context = ctx.op_device_context();
        }
        nodestats::SetMemory(stats, &ctx);
      }
    }

    // Post-processing for everything that finished on this thread (dead
    // nodes and synchronous kernels). Asynchronously launched kernels do the
    // equivalent work in their `done` callback instead.
    if (!launched_asynchronously) {
      if (vlog_) {
        VLOG(2) << "Synchronous kernel done: " << id << " step "
                << params.step_id << " " << SummarizeNode(*node)
                << (tagged_node.is_dead ? " is dead: " : "")
                << " device: " << device->name();
      }

      // Clears inputs.
      const int num_inputs = item.num_inputs;
      for (int i = 0; i < num_inputs; ++i) {
        (first_input + i)->ClearVal();
      }
      MaybeMarkCompleted(input_frame, input_iter, id);
      // Propagates outputs.
      if (s.ok()) {
        PropagateOutputs(tagged_node, &item, &outputs, &ready);
      }
      outputs.clear();
      if (!accessed_tensors.empty()) {
        nodestats::SetReferencedTensors(stats, accessed_tensors);
        // device_context is set above in synchronous computes
        device->ConsumeListOfAccessedTensors(device_context, accessed_tensors);
      }
      if (stats) {
        // Refresh the timestamp used as the "scheduled" time of the nodes
        // processed in later iterations of this loop.
        scheduled_nsec = nodestats::NowInNsec();
      }
      // Postprocess.
      completed = NodeDone(s, item.node, ready, stats, &inline_ready);
    }
  }  // while !inline_ready.empty()

  // This thread of computation is done if completed = true.
  if (completed) Finish();
}
3. NodeDone