# Source code for bigdl.models.utils.model_broadcast
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
import gc
from tempfile import NamedTemporaryFile
from pyspark.cloudpickle import print_exec
from pyspark.broadcast import Broadcast
from pyspark.broadcast import _from_id
from bigdl.nn.layer import Model
def _from_id_and_type(bid, bigdl_type):
    """Rebuild a ModelBroadcast from a broadcast id and a BigDL type.

    This is the callable that ModelBroadcast.__reduce__ hands to pickle, so it
    runs when a pickled ModelBroadcast is restored.

    :param bid: broadcast id, looked up through pyspark's ``_from_id``.
    :param bigdl_type: BigDL type tag (e.g. "float") stored on the new object.
    :return: a ModelBroadcast that points at the looked-up broadcast's path.
    """
    looked_up = _from_id(bid)
    return ModelBroadcast(bigdl_type=bigdl_type, path=looked_up._path)
def broadcast_model(sc, layer):
    """Broadcast a BigDL model and return the resulting ModelBroadcast.

    :param sc: the context to broadcast through (presumably a SparkContext);
        its ``_pickled_broadcast_vars`` is passed on as the pickle registry.
    :param layer: the model to broadcast.
    :return: a ModelBroadcast wrapping ``layer``.
    """
    return ModelBroadcast(sc, layer, sc._pickled_broadcast_vars)
class ModelBroadcast(Broadcast):
    """Broadcast variable whose payload is a BigDL model.

    The model is written with its own ``saveModel`` method and read back with
    ``Model.loadModel`` (see ``dump`` and ``_load``).
    """

    def __init__(self, sc=None, layer=None, pickle_registry=None, path=None, bigdl_type="float"):
        """
        Should not be called directly by users -- use broadcast_model()
        instead.

        :param sc: forwarded unchanged to ``Broadcast.__init__``.
        :param layer: the model to broadcast; when given, its ``bigdl_type``
            takes precedence over the ``bigdl_type`` argument.
        :param pickle_registry: forwarded unchanged to ``Broadcast.__init__``.
        :param path: forwarded unchanged to ``Broadcast.__init__``.
        :param bigdl_type: type tag used only when ``layer`` is None.
        """
        if layer is not None:
            self.bigdl_type = layer.bigdl_type
        else:
            self.bigdl_type = bigdl_type
        super(ModelBroadcast, self).__init__(sc, layer, pickle_registry, path)

    def dump(self, value, f):
        """Save ``value`` to the path of the file object ``f``.

        :param value: the model; must provide ``saveModel(path, over_write=...)``.
        :param f: a named file object; only its ``name`` is used as the target
            path, and it is closed before this method returns or raises.
        :return: ``f.name``, the path the model was saved to.
        :raises ValueError: if saving the model fails for any reason.
        """
        try:
            value.saveModel(f.name, over_write=True)
        except Exception as e:
            msg = "Could not serialize broadcast: %s" % e.__class__.__name__
            # Report the failure on stderr through pyspark's helper before
            # replacing the exception with a ValueError.
            print_exec(sys.stderr)
            raise ValueError(msg)
        finally:
            # Close the handle on the failure path too, so a failed save does
            # not leak the open file object.
            f.close()
        return f.name

    def _load(self, path):
        """Load the model stored at ``path`` using this broadcast's bigdl_type."""
        return Model.loadModel(path, bigdl_type=self.bigdl_type)

    @property
    def value(self):
        """ Return the broadcasted value

        The model is loaded from ``self._path`` on first access and cached in
        ``self._value`` for later accesses.
        """
        if not hasattr(self, "_value") and self._path is not None:
            self._value = self._load(self._path)
        return self._value

    def __reduce__(self):
        """Pickle this broadcast as its broadcast id plus its bigdl_type.

        The object is also added to its pickle registry as a side effect.

        :return: ``(_from_id_and_type, (broadcast id, bigdl_type))``.
        :raises Exception: if ``self._jbroadcast`` is None.
        """
        if self._jbroadcast is None:
            raise Exception("Broadcast can only be serialized in driver")
        self._pickle_registry.add(self)
        return _from_id_and_type, (self._jbroadcast.id(), self.bigdl_type)