Spark Source Code Analysis -- Task


Task is the interface between DAGScheduler and TaskScheduler: DAGScheduler needs to wrap each partition of every stage in the DAG into a task, and finally submit the resulting TaskSet to TaskScheduler.
 
/**
 * A task to execute on a worker node.
 */
private[spark] abstract class Task[T](val stageId: Int) extends Serializable {
  def run(attemptId: Long): T  // The task's core logic; implemented by each concrete Task subclass
  def preferredLocations: Seq[TaskLocation] = Nil // Spark supports data locality; these are the preferred locations to run this task
  var epoch: Long = -1   // Map output tracker epoch. Will be set by TaskScheduler.
  var metrics: Option[TaskMetrics] = None
}

 
TaskContext
Used to record the TaskMetrics and the callbacks registered by a task.
For example, HadoopRDD needs to close its input stream when the task completes.
package org.apache.spark
class TaskContext(
  val stageId: Int,
  val splitId: Int,
  val attemptId: Long,
  val runningLocally: Boolean = false,
  val taskMetrics: TaskMetrics = TaskMetrics.empty() // TaskMetrics collects the statistics gathered while this task runs
) extends Serializable {

  @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]

  // Add a callback function to be executed on task completion. An example use
  // is for HadoopRDD to register a callback to close the input stream.
  def addOnCompleteCallback(f: () => Unit) {
    onCompleteCallbacks += f
  }

  def executeOnCompleteCallbacks() {
    onCompleteCallbacks.foreach{_()}
  }
}
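
To make the callback mechanism concrete, here is a minimal sketch of how a record reader could hook into TaskContext, mirroring the HadoopRDD case mentioned above. The stream, file path, and wrapper object are hypothetical stand-ins; only the TaskContext class shown above is assumed.

import java.io.{BufferedInputStream, FileInputStream}
import org.apache.spark.TaskContext

object TaskContextCallbackSketch {
  def main(args: Array[String]): Unit = {
    val context = new TaskContext(stageId = 0, splitId = 0, attemptId = 0L)

    // Hypothetical input stream that must be released when the task finishes,
    // just as HadoopRDD closes its record reader's input stream.
    val in = new BufferedInputStream(new FileInputStream("/tmp/part-00000"))
    context.addOnCompleteCallback(() => in.close())

    try {
      // Consume the stream to produce the task's records (here just one byte).
      println(in.read())
    } finally {
      // The executor calls this when the task body finishes; every registered
      // callback runs, including the close above.
      context.executeOnCompleteCallbacks()
    }
  }
}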

 
ResultTask
Corresponds to a result stage; its run produces the final result directly.
package org.apache.spark.scheduler
private[spark] class ResultTask[T, U](
    stageId: Int,
    var rdd: RDD[T],
    var func: (TaskContext, Iterator[T]) => U,
    var partition: Int,
    @transient locs: Seq[TaskLocation],
    var outputId: Int)
  extends Task[U](stageId) with Externalizable {

  override def run(attemptId: Long): U = {  // For a ResultTask, run directly computes the result of func, e.g. the count of a partition
    val context = new TaskContext(stageId, partition, attemptId, runningLocally = false)
    metrics = Some(context.taskMetrics)
    try {
      func(context, rdd.iterator(split, context)) // Apply func to the iterator of this RDD partition (split is rdd.partitions(partition)), e.g. counting its elements
    } finally {
      context.executeOnCompleteCallbacks()
    }
  }
}
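
To make the func parameter concrete, the sketch below shows roughly what the closure for a count-like action could look like; countFunc and the surrounding object are made up for illustration, and Spark's real count implementation differs in detail.

import org.apache.spark.TaskContext

object ResultTaskFuncSketch {
  // Hypothetical func for a count-like action: it consumes the iterator of one
  // partition and returns how many elements it contains. DAGScheduler would wrap
  // such a closure into one ResultTask per missing partition.
  val countFunc: (TaskContext, Iterator[Int]) => Long = (context, iter) => {
    var count = 0L
    while (iter.hasNext) { iter.next(); count += 1 }
    count
  }

  def main(args: Array[String]): Unit = {
    // ResultTask.run does essentially this: feed the partition's iterator to func.
    val context = new TaskContext(stageId = 0, splitId = 0, attemptId = 0L)
    val n = countFunc(context, Iterator(1, 2, 3)) // n == 3
    println(n)
  }
}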

 
ShuffleMapTask
Corresponds to a ShuffleMap stage; its output becomes the input of other stages.
package org.apache.spark.scheduler
private[spark] class ShuffleMapTask(
    stageId: Int,
    var rdd: RDD[_],
    var dep: ShuffleDependency[_,_],
    var partition: Int,
    @transient private var locs: Seq[TaskLocation])
  extends Task[MapStatus](stageId)
  with Externalizable
  with Logging {

  override def run(attemptId: Long): MapStatus = {
    val numOutputSplits = dep.partitioner.numPartitions // The partitioner of the ShuffleDependency determines how many shuffle output partitions (buckets) there are

    val taskContext = new TaskContext(stageId, partition, attemptId, runningLocally = false)
    metrics = Some(taskContext.taskMetrics)

    val blockManager = SparkEnv.get.blockManager // Shuffle output is written out through the blockManager
    var shuffle: ShuffleBlocks = null
    var buckets: ShuffleWriterGroup = null

    try {
      // Obtain all the block writers for shuffle blocks.
      val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
      shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser) // From the shuffleBlockManager, get the ShuffleBlocks for this shuffleId with numOutputSplits output partitions
      buckets = shuffle.acquireWriters(partition) // Acquire the bucket writers for this map partition, one writer per output partition

      // Write the map output to its associated buckets.
      for (elem <- rdd.iterator(split, taskContext)) { // Iterate over every element of this RDD partition
        val pair = elem.asInstanceOf[Product2[Any, Any]]
        val bucketId = dep.partitioner.getPartition(pair._1) // Shuffle on the pair's key to decide which bucket it belongs to
        buckets.writers(bucketId).write(pair) // Write the pair into that bucket
      }
      
      // Commit the bucket blocks: for this shuffleId, each map partition produces one block per output partition, which reducers will later fetch.
      // Commit the writes. Get the size of each bucket block (total block size).
      var totalBytes = 0L
      val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter => // Commit every bucket writer and record each block's data size (in compressed form)
        writer.commit()
        writer.close()
        val size = writer.size()
        totalBytes += size
        MapOutputTracker.compressSize(size)
      }

      // Update shuffle metrics.
      val shuffleMetrics = new ShuffleWriteMetrics
      shuffleMetrics.shuffleBytesWritten = totalBytes
      metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)

      return new MapStatus(blockManager.blockManagerId, compressedSizes) // Return a MapStatus with this blockManagerId and the block sizes; it will be registered with the MapOutputTracker
    } catch { case e: Exception =>
      // If there is an exception from running the task, revert the partial writes
      // and throw the exception upstream to Spark.
      if (buckets != null) {
        buckets.writers.foreach(_.revertPartialWrites())
      }
      throw e
    } finally {
      // Release the writers back to the shuffle block manager.
      if (shuffle != null && buckets != null) {
        shuffle.releaseWriters(buckets)
      }
      // Execute the callbacks on task completion.
      taskContext.executeOnCompleteCallbacks()
    }
  }
}
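
The core of the write loop above is routing each key/value pair to a bucket through dep.partitioner. The standalone sketch below imitates that routing with a simplified hash partitioner and plain buffers standing in for the BlockObjectWriters; it is an illustration, not the actual HashPartitioner or ShuffleWriterGroup code.

import scala.collection.mutable.ArrayBuffer

object BucketRoutingSketch {
  // Simplified hash partitioner: maps a key to one of numOutputSplits buckets,
  // keeping the result non-negative even for negative hashCodes.
  def getPartition(key: Any, numOutputSplits: Int): Int = {
    val mod = key.hashCode % numOutputSplits
    if (mod < 0) mod + numOutputSplits else mod
  }

  def main(args: Array[String]): Unit = {
    val numOutputSplits = 3
    // In ShuffleMapTask each bucket is a BlockObjectWriter; here a plain buffer
    // stands in for it so the routing logic can be seen on its own.
    val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(String, Int)])

    val records = Iterator(("a", 1), ("b", 2), ("c", 3), ("a", 4))
    for (pair <- records) {
      val bucketId = getPartition(pair._1, numOutputSplits)
      buckets(bucketId) += pair // write the pair into its bucket
    }

    buckets.zipWithIndex.foreach { case (b, i) => println(s"bucket $i: $b") }
  }
}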
 
TaskSet
Used to wrap all the tasks of a stage so they can be submitted to the TaskScheduler together (a simplified sketch of this hand-off follows the TaskSet code below).
package org.apache.spark.scheduler
/**
 * A set of tasks submitted together to the low-level TaskScheduler, usually representing
 * missing partitions of a particular stage.
 */
private[spark] class TaskSet(
    val tasks: Array[Task[_]],
    val stageId: Int,
    val attempt: Int,
    val priority: Int,
    val properties: Properties) {
  val id: String = stageId + "." + attempt

  override def toString: String = "TaskSet " + id
}
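
Putting the pieces together, the sketch below shows roughly how DAGScheduler-style code could wrap the missing partitions of a result stage into ResultTasks and hand the TaskSet to the TaskScheduler. The method and its parameters are hypothetical, these scheduler classes are private[spark] (so such code only compiles inside Spark's own packages), and the real DAGScheduler.submitMissingTasks does considerably more (task serialization, locality, epoch handling).

import java.util.Properties

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{ResultTask, Task, TaskLocation, TaskScheduler, TaskSet}

object SubmitResultStageSketch {
  // Hypothetical helper: one ResultTask per missing partition, bundled into a
  // TaskSet and handed to the TaskScheduler, which decides where each task runs.
  def submitResultStage[T, U](
      stageId: Int,
      finalRdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      missingPartitions: Seq[Int],
      taskScheduler: TaskScheduler): Unit = {
    val tasks: Array[Task[_]] = missingPartitions.zipWithIndex.map {
      case (partition, outputId) =>
        new ResultTask(stageId, finalRdd, func, partition, Seq.empty[TaskLocation], outputId): Task[_]
    }.toArray

    // One TaskSet per stage attempt; its id is "<stageId>.<attempt>" as shown above.
    val taskSet = new TaskSet(tasks, stageId, attempt = 0, priority = 0, new Properties)
    taskScheduler.submitTasks(taskSet)
  }
}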