Spark中常用工具类Utils的简明介绍

《深入理解Spark:核心思想与源码分析》一书前言的内容请看链接《深入理解SPARK:核心思想与源码分析》一书正式出版上市

《深入理解Spark:核心思想与源码分析》一书第一章的内容请看链接《第1章 环境准备》

《深入理解Spark:核心思想与源码分析》一书第二章的内容请看链接《第2章 SPARK设计理念与基本架构》

《深入理解Spark:核心思想与源码分析》一书第三章第一部分的内容请看链接《深入理解Spark:核心思想与源码分析》——SparkContext的初始化(伯篇)》

《深入理解Spark:核心思想与源码分析》一书第三章第二部分的内容请看链接《深入理解Spark:核心思想与源码分析》——SparkContext的初始化(仲篇)》

《深入理解Spark:核心思想与源码分析》一书第三章第三部分的内容请看链接《深入理解Spark:核心思想与源码分析》——SparkContext的初始化(叔篇)》

《深入理解Spark:核心思想与源码分析》一书第三章第四部分的内容请看链接《深入理解Spark:核心思想与源码分析》——SparkContext的初始化(季篇)》

 

Utils是Spark中最常用的工具类之一,如果不关心其实现,也不会对理解Spark有太多影响。但是对于Scala或者Spark的初学者来说,通过了解Utils工具类的实现,也是个不错的入门途径。下面将逐个介绍Utils工具类提供的常用方法。

1.localHostName

功能描述:获取本地机器名。

def localHostName(): String = { 
customHostname.getOrElse(localIpAddressHostname) 
} 

 

2.getDefaultPropertiesFile

功能描述:获取默认的Spark属性文件。

def getDefaultPropertiesFile(env: Map[String, String] = sys.env): String = { 
  env.get("SPARK_CONF_DIR") 
  .orElse(env.get("SPARK_HOME").map{ t => s"$t${File.separator}conf"}) 
  .map { t => new File(s"$t${File.separator}spark-defaults.conf")} 
  .filter(_.isFile) 
  .map(_.getAbsolutePath) 
  .orNull 
} 

 

3.loadDefaultSparkProperties

功能描述:加载指定文件中的Spark属性,如果没有指定文件,则加载默认Spark属性文件的属性。

def loadDefaultSparkProperties(conf:SparkConf, filePath: String = null):String = { 
  val path =Option(filePath).getOrElse(getDefaultPropertiesFile()) 
  Option(path).foreach { confFile => 
    getPropertiesFromFile(confFile).filter{ case (k,v) => 
    k.startsWith("spark.") 
    }.foreach { case (k, v) => 
      conf.setIfMissing(k, v) 
      sys.props.getOrElseUpdate(k, v) 
    } 
  } 
  path 
} 

4.getCallSite

功能描述:获取当前SparkContext的当前调用堆栈,将栈里最靠近栈底的属于spark或者Scala核心的类压入callStack的栈顶,并将此类的方法存入lastSparkMethod;将栈里最靠近栈顶的用户类放入callStack,将此类的行号存入firstUserLine,类名存入firstUserFile,最终返回的样例类CallSite存储了最短栈和长度默认为20的最长栈的样例类。在JavaWordCount例子中,获得的数据如下:
最短栈:JavaSparkContext at JavaWordCount.java:44;
最长栈:org.apache.spark.api.java.JavaSparkContext.<init>(JavaSparkContext.scala:61)org.apache.spark.examples.JavaWordCount.main(JavaWordCount.java:44)。

def getCallSite(skipClass: String => Boolean = coreExclusionFunction): CallSite = {
    val trace = Thread.currentThread.getStackTrace().filterNot { ste: StackTraceElement =>
      ste == null || ste.getMethodName == null || ste.getMethodName.contains("getStackTrace")
    }
    var lastSparkMethod = "<unknown>"
    var firstUserFile = "<unknown>"
    var firstUserLine = 0
    var insideSpark = true
    var callStack = new ArrayBuffer[String]() :+ "<unknown>"

    for (el <- trace) {
      if (insideSpark) {
        if (skipClass(el.getClassName)) {
          lastSparkMethod = if (el.getMethodName == "<init>") {
            el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1)
          } else {
            el.getMethodName
          }
          callStack(0) = el.toString // Put last Spark method on top of the stack trace.
        } else {
          firstUserLine = el.getLineNumber
          firstUserFile = el.getFileName
          callStack += el.toString
          insideSpark = false
        }
      } else {
        callStack += el.toString
      }
    }
    val callStackDepth = System.getProperty("spark.callstack.depth", "20").toInt
    CallSite(
      shortForm = s"$lastSparkMethod at $firstUserFile:$firstUserLine",
      longForm = callStack.take(callStackDepth).mkString("\n"))
  }

 

5.startServiceOnPort

功能描述:Scala跟其它脚本语言一样,函数也可以传递,此方法正是通过回调startService这个函数来启动服务,并最终返回startService返回的service地址及端口。如果启动过程有异常,还会多次重试,直到达到maxRetries表示的最大次数。

def startServiceOnPort[T](
      startPort: Int,
      startService: Int => (T, Int),
      conf: SparkConf,
      serviceName: String = ""): (T, Int) = {
    require(startPort == 0 || (1024 <= startPort && startPort < 65536),
      "startPort should be between 1024 and 65535 (inclusive), or 0 for a random free port.")
    val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'"
    val maxRetries = portMaxRetries(conf)
    for (offset <- 0 to maxRetries) {
      val tryPort = if (startPort == 0) {
        startPort
      } else {
        ((startPort + offset - 1024) % (65536 - 1024)) + 1024
      }
      try {
        val (service, port) = startService(tryPort)
        logInfo(s"Successfully started service$serviceString on port $port.")
        return (service, port)
      } catch {
        case e: Exception if isBindCollision(e) =>
          if (offset >= maxRetries) {
            val exceptionMessage =
              s"${e.getMessage}: Service$serviceString failed after $maxRetries retries!"
            val exception = new BindException(exceptionMessage)
            exception.setStackTrace(e.getStackTrace)
            throw exception
          }
          logWarning(s"Service$serviceString could not bind on port $tryPort. " +
            s"Attempting port ${tryPort + 1}.")
      }
    }
    throw new SparkException(s"Failed to start service$serviceString on port $startPort")
  }

6.createDirectory

功能描述:用spark+UUID的方式创建临时文件目录,如果创建失败会多次重试,最多重试10次。

def createDirectory(root: String, namePrefix: String = "spark"): File = {
    var attempts = 0
    val maxAttempts = MAX_DIR_CREATION_ATTEMPTS
    var dir: File = null
    while (dir == null) {
      attempts += 1
      if (attempts > maxAttempts) {
        throw new IOException("Failed to create a temp directory (under " + root + ") after " +
          maxAttempts + " attempts!")
      }
      try {
        dir = new File(root, "spark-" + UUID.randomUUID.toString)
        if (dir.exists() || !dir.mkdirs()) {
          dir = null
        }
      } catch { case e: SecurityException => dir = null; }
    }

    dir
  }

 

7.getOrCreateLocalRootDirs

功能描述:根据spark.local.dir的配置,作为本地文件的根目录,在创建一、二级目录之前要确保根目录是存在的。然后调用createDirectory创建一级目录。

private[spark] def getOrCreateLocalRootDirs(conf: SparkConf): Array[String] = {
    if (isRunningInYarnContainer(conf)) {
      getYarnLocalDirs(conf).split(",")
    } else {
      Option(conf.getenv("SPARK_LOCAL_DIRS"))
        .getOrElse(conf.get("spark.local.dir", System.getProperty("java.io.tmpdir")))
        .split(",")
        .flatMap { root =>
          try {
            val rootDir = new File(root)
            if (rootDir.exists || rootDir.mkdirs()) {
              val dir = createDirectory(root)
              chmod700(dir)
              Some(dir.getAbsolutePath)
            } else {
              logError(s"Failed to create dir in $root. Ignoring this directory.")
              None
            }
          } catch {
            case e: IOException =>
            logError(s"Failed to create local root dir in $root. Ignoring this directory.")
            None
          }
        }
        .toArray
    }
  }

8.getLocalDir

功能描述:查询Spark本地文件的一级目录。

def getLocalDir(conf: SparkConf): String = {
    getOrCreateLocalRootDirs(conf)(0)
  }

9.createTempDir

功能描述:在Spark一级目录下创建临时目录,并将目录注册到shutdownDeletePaths:scala.collection.mutable.HashSet[String]中。

def createTempDir(
      root: String = System.getProperty("java.io.tmpdir"),
      namePrefix: String = "spark"): File = {
    val dir = createDirectory(root, namePrefix)
    registerShutdownDeleteDir(dir)
    dir
  }

 

10.RegisterShutdownDeleteDir

功能描述:将目录注册到shutdownDeletePaths:scala.collection.mutable.HashSet[String]中,以便在进程退出时删除。

  def registerShutdownDeleteDir(file: File) {
    val absolutePath = file.getAbsolutePath()
    shutdownDeletePaths.synchronized {
      shutdownDeletePaths += absolutePath
    }
  }

11.hasRootAsShutdownDeleteDir

功能描述:判断文件是否匹配关闭时要删除的文件及目录,shutdownDeletePaths:scala.collection.mutable.HashSet[String]存储在进程关闭时要删除的文件及目录。

def hasRootAsShutdownDeleteDir(file: File): Boolean = {
    val absolutePath = file.getAbsolutePath()
    val retval = shutdownDeletePaths.synchronized {
      shutdownDeletePaths.exists { path =>
        !absolutePath.equals(path) && absolutePath.startsWith(path)
      }
    }
    if (retval) {
      logInfo("path = " + file + ", already present as root for deletion.")
    }
    retval
  }

 

12.deleteRecursively

功能描述:用于删除文件或者删除目录及其子目录、子文件,并且从shutdownDeletePaths:scala.collection.mutable.HashSet[String]中移除此文件或目录。

def deleteRecursively(file: File) {
    if (file != null) {
      try {
        if (file.isDirectory && !isSymlink(file)) {
          var savedIOException: IOException = null
          for (child <- listFilesSafely(file)) {
            try {
              deleteRecursively(child)
            } catch {
              case ioe: IOException => savedIOException = ioe
            }
          }
          if (savedIOException != null) {
            throw savedIOException
          }
          shutdownDeletePaths.synchronized {
            shutdownDeletePaths.remove(file.getAbsolutePath)
          }
        }
      } finally {
        if (!file.delete()) {
          if (file.exists()) {
            throw new IOException("Failed to delete: " + file.getAbsolutePath)
          }
        }
      }
    }
  }

 

13.getSparkClassLoader

功能描述:获取加载当前class的ClassLoader。

  def getSparkClassLoader = getClass.getClassLoader

 

14.getContextOrSparkClassLoader

功能描述:用于获取线程上下文的ClassLoader,没有设置时获取加载Spark的ClassLoader。

def getContextOrSparkClassLoader =
    Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader)

 

15.newDaemonCachedThreadPool

功能描述:使用Executors.newCachedThreadPool创建的缓存线程池。

  def newDaemonCachedThreadPool(prefix: String): ThreadPoolExecutor = {
    val threadFactory = namedThreadFactory(prefix)
    Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
  }

16.doFetchFile

功能描述:使用URLConnection通过http协议下载文件。

private def doFetchFile(url: String, targetDir: File, filename: String, conf: SparkConf,
      securityMgr: SecurityManager, hadoopConf: Configuration) {
    val tempFile = File.createTempFile("fetchFileTemp", null, new File(targetDir.getAbsolutePath))
    val targetFile = new File(targetDir, filename)
    val uri = new URI(url)
    val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false)
    Option(uri.getScheme).getOrElse("file") match {
      case "http" | "https" | "ftp" =>
        logInfo("Fetching " + url + " to " + tempFile)
        var uc: URLConnection = null
        if (securityMgr.isAuthenticationEnabled()) {
          logDebug("fetchFile with security enabled")
          val newuri = constructURIForAuthentication(uri, securityMgr)
          uc = newuri.toURL().openConnection()
          uc.setAllowUserInteraction(false)
        } else {
          logDebug("fetchFile not using security")
          uc = new URL(url).openConnection()
        }
        val timeout = conf.getInt("spark.files.fetchTimeout", 60) * 1000
        uc.setConnectTimeout(timeout)
        uc.setReadTimeout(timeout)
        uc.connect()
        val in = uc.getInputStream()
        downloadFile(url, in, tempFile, targetFile, fileOverwrite)
      case "file" =>
        val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url)
        copyFile(url, sourceFile, targetFile, fileOverwrite)
      case _ =>
        val fs = getHadoopFileSystem(uri, hadoopConf)
        val in = fs.open(new Path(uri))
        downloadFile(url, in, tempFile, targetFile, fileOverwrite)
    }
  }

 

17.fetchFile

功能描述:如果文件在本地有缓存,则从本地获取,否则通过HTTP远程下载。最后对.tar、.tar.gz等格式的文件解压缩后,调用shell命令行的chmod命令给文件增加a+x的权限。

def fetchFile(
      url: String,
      targetDir: File,
      conf: SparkConf,
      securityMgr: SecurityManager,
      hadoopConf: Configuration,
      timestamp: Long,
      useCache: Boolean) {
    val fileName = url.split("/").last
    val targetFile = new File(targetDir, fileName)
    val fetchCacheEnabled = conf.getBoolean("spark.files.useFetchCache", defaultValue = true)
    if (useCache && fetchCacheEnabled) {
      val cachedFileName = s"${url.hashCode}${timestamp}_cache"
      val lockFileName = s"${url.hashCode}${timestamp}_lock"
      val localDir = new File(getLocalDir(conf))
      val lockFile = new File(localDir, lockFileName)
      val raf = new RandomAccessFile(lockFile, "rw")
      val lock = raf.getChannel().lock()
      val cachedFile = new File(localDir, cachedFileName)
      try {
        if (!cachedFile.exists()) {
          doFetchFile(url, localDir, cachedFileName, conf, securityMgr, hadoopConf)
        }
      } finally {
        lock.release()
      }
      copyFile(
        url,
        cachedFile,
        targetFile,
        conf.getBoolean("spark.files.overwrite", false)
      )
    } else {
      doFetchFile(url, targetDir, fileName, conf, securityMgr, hadoopConf)
    }
    if (fileName.endsWith(".tar.gz") || fileName.endsWith(".tgz")) {
      logInfo("Untarring " + fileName)
      Utils.execute(Seq("tar", "-xzf", fileName), targetDir)
    } else if (fileName.endsWith(".tar")) {
      logInfo("Untarring " + fileName)
      Utils.execute(Seq("tar", "-xf", fileName), targetDir)
    }
    FileUtil.chmod(targetFile.getAbsolutePath, "a+x")
  }

 

18.executeAndGetOutput

功能描述:执行一条command命令,并且获取它的输出。调用stdoutThread的join方法,让当前线程等待stdoutThread执行完成。

def executeAndGetOutput(
      command: Seq[String],
      workingDir: File = new File("."),
      extraEnvironment: Map[String, String] = Map.empty): String = {
    val builder = new ProcessBuilder(command: _*)
        .directory(workingDir)
    val environment = builder.environment()
    for ((key, value) <- extraEnvironment) {
      environment.put(key, value)
    }
    val process = builder.start()
    new Thread("read stderr for " + command(0)) {
      override def run() {
        for (line <- Source.fromInputStream(process.getErrorStream).getLines()) {
          System.err.println(line)
        }
      }
    }.start()
    val output = new StringBuffer
    val stdoutThread = new Thread("read stdout for " + command(0)) {
      override def run() {
        for (line <- Source.fromInputStream(process.getInputStream).getLines()) {
          output.append(line)
        }
      }
    }
    stdoutThread.start()
    val exitCode = process.waitFor()
    stdoutThread.join()   // Wait for it to finish reading output
    if (exitCode != 0) {
      logError(s"Process $command exited with code $exitCode: $output")
      throw new SparkException(s"Process $command exited with code $exitCode")
    }
    output.toString
  }

 

19.memoryStringToMb

功能描述:将内存大小字符串转换为以MB为单位的整型值。

  def memoryStringToMb(str: String): Int = {
    val lower = str.toLowerCase
    if (lower.endsWith("k")) {
      (lower.substring(0, lower.length-1).toLong / 1024).toInt
    } else if (lower.endsWith("m")) {
      lower.substring(0, lower.length-1).toInt
    } else if (lower.endsWith("g")) {
      lower.substring(0, lower.length-1).toInt * 1024
    } else if (lower.endsWith("t")) {
      lower.substring(0, lower.length-1).toInt * 1024 * 1024
    } else {// no suffix, so it's just a number in bytes
      (lower.toLong / 1024 / 1024).toInt
    }
  }

 

posted @ 2016-03-23 14:44  泰山不老生  阅读(4589)  评论(0编辑  收藏  举报