Source code for bigdl.util.engine

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
import os
import glob
import warnings


def exist_pyspark():
    # check whether the pyspark package exists
    try:
        import pyspark
        return True
    except ImportError:
        return False

def check_spark_source_conflict(spark_home, pyspark_path):
    # check whether both the SPARK_HOME env variable and the pyspark package exist
    # and trigger a warning if the two Spark sources do not match
    if spark_home and not pyspark_path.startswith(spark_home):
        warning_msg = "Found both SPARK_HOME and pyspark. You may need to check whether they " + \
                      "match with each other. SPARK_HOME environment variable is set to: " + spark_home + \
                      ", and pyspark is found in: " + pyspark_path + ". If they do not match, " + \
                      "please use one source only to avoid conflict. " + \
                      "For example, you can unset SPARK_HOME and use pyspark only."
        warnings.warn(warning_msg)

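# Illustration only (not part of the original module): a minimal sketch of how
# exist_pyspark() and check_spark_source_conflict() could be exercised together,
# mirroring what __prepare_spark_env() below does.
def _example_check_spark_source():
    if exist_pyspark():
        import pyspark
        # Warns when pyspark was imported from a location outside SPARK_HOME.
        check_spark_source_conflict(os.environ.get('SPARK_HOME', None), pyspark.__file__)
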
def __sys_path_insert(file_path):
    if file_path not in sys.path:
        print("Prepending %s to sys.path" % file_path)
        sys.path.insert(0, file_path)


def __prepare_spark_env():
    spark_home = os.environ.get('SPARK_HOME', None)
    if exist_pyspark():
        # use pyspark as the spark source
        import pyspark
        check_spark_source_conflict(spark_home, pyspark.__file__)
    else:
        # use SPARK_HOME as the spark source
        if not spark_home:
            raise ValueError(
                """Could not find Spark. Please make sure SPARK_HOME env is set:
                   export SPARK_HOME=path to your spark home directory.""")
        print("Using %s" % spark_home)
        py4j = glob.glob(os.path.join(spark_home, 'python/lib', 'py4j-*.zip'))[0]
        pyspark = glob.glob(os.path.join(spark_home, 'python/lib', 'pyspark*.zip'))[0]
        __sys_path_insert(py4j)
        __sys_path_insert(pyspark)


def __prepare_bigdl_env():
    jar_dir = os.path.abspath(__file__ + "/../../")
    conf_paths = glob.glob(os.path.join(jar_dir, "share/conf/*.conf"))
    bigdl_classpath = get_bigdl_classpath()

    def append_path(env_var_name, jar_path):
        try:
            if jar_path not in os.environ[env_var_name].split(":"):
                print("Adding %s to %s" % (jar_path, env_var_name))
                os.environ[env_var_name] = jar_path + ":" + os.environ[env_var_name]  # noqa
        except KeyError:
            os.environ[env_var_name] = jar_path

    if bigdl_classpath:
        append_path("BIGDL_JARS", bigdl_classpath)

    if conf_paths:
        assert len(conf_paths) == 1, "Expecting one conf: %s" % len(conf_paths)
        __sys_path_insert(conf_paths[0])

    if os.environ.get("BIGDL_JARS", None) and is_spark_below_2_2():
        for jar in os.environ["BIGDL_JARS"].split(":"):
            append_path("SPARK_CLASSPATH", jar)

    if os.environ.get("BIGDL_PACKAGES", None):
        for package in os.environ["BIGDL_PACKAGES"].split(":"):
            __sys_path_insert(package)

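# Illustration only (not part of the original module): the append_path helper
# above prepends to a colon-separated env variable without duplicating entries.
# A minimal sketch of that same idiom, using a hypothetical variable name and jar path.
def _example_prepend_env_path(jar_path="/tmp/example/bigdl.jar"):
    existing = os.environ.get("EXAMPLE_JARS", "")
    if jar_path not in existing.split(":"):
        os.environ["EXAMPLE_JARS"] = jar_path + (":" + existing if existing else "")
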
def get_bigdl_classpath():
    """
    Get and return the jar path for bigdl if it exists.
    """
    if os.getenv("BIGDL_CLASSPATH"):
        return os.environ["BIGDL_CLASSPATH"]
    jar_dir = os.path.abspath(__file__ + "/../../")
    jar_paths = glob.glob(os.path.join(jar_dir, "share/lib/*.jar"))
    if jar_paths:
        assert len(jar_paths) == 1, "Expecting one jar: %s" % len(jar_paths)
        return jar_paths[0]
    return ""

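# Illustration only (not part of the original module): get_bigdl_classpath()
# prefers the BIGDL_CLASSPATH env variable and falls back to the jar bundled
# under share/lib, returning "" when neither is found. A minimal sketch:
def _example_get_bigdl_classpath():
    path = get_bigdl_classpath()
    print("BigDL jar: %s" % (path or "<not found>"))
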
def is_spark_below_2_2():
    """
    Check whether the Spark version is below 2.2.
    """
    import pyspark
    if hasattr(pyspark, "version"):
        full_version = pyspark.version.__version__
        # We only need the general Spark version (e.g. 1.6, 2.2).
        parts = full_version.split(".")
        spark_version = parts[0] + "." + parts[1]
        if compare_version(spark_version, "2.2") >= 0:
            return False
    return True

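# Illustration only (not part of the original module): is_spark_below_2_2()
# imports pyspark, so a guarded sketch checks for it first.
def _example_spark_version_check():
    if exist_pyspark():
        print("Spark below 2.2: %s" % is_spark_below_2_2())
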
def compare_version(version1, version2):
    """
    Compare two version strings.

    :param version1:
    :param version2:
    :return: 1 if version1 is after version2; -1 if version1 is before version2;
             0 if the two versions are the same.
    """
    v1Arr = version1.split(".")
    v2Arr = version2.split(".")
    len1 = len(v1Arr)
    len2 = len(v2Arr)
    lenMax = max(len1, len2)
    for x in range(lenMax):
        v1Token = 0
        if x < len1:
            v1Token = int(v1Arr[x])
        v2Token = 0
        if x < len2:
            v2Token = int(v2Arr[x])
        if v1Token < v2Token:
            return -1
        if v1Token > v2Token:
            return 1
    return 0

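# Illustration only (not part of the original module): compare_version() pads
# missing components with 0, so "2.2" and "2.2.0" compare as equal.
def _example_compare_version():
    assert compare_version("2.4.0", "2.2") == 1
    assert compare_version("1.6", "2.2") == -1
    assert compare_version("2.2", "2.2.0") == 0
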
def prepare_env():
    __prepare_spark_env()
    __prepare_bigdl_env()

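# Illustration only (not part of the original module): prepare_env() is
# typically called once, before creating a SparkContext, so that py4j, pyspark
# and the BigDL jar/conf become discoverable via sys.path and BIGDL_JARS.
if __name__ == "__main__":
    prepare_env()
    from pyspark import SparkContext  # importable whether Spark came from pip or SPARK_HOME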