# Source code for bigdl.models.utils.model_broadcast
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
import gc
from tempfile import NamedTemporaryFile
from pyspark.cloudpickle import print_exec
from pyspark.broadcast import Broadcast
from pyspark.broadcast import _from_id
from bigdl.nn.layer import Model
def _from_id_and_type(bid, bigdl_type):
    """Unpickling hook for ModelBroadcast (see ``ModelBroadcast.__reduce__``).

    Looks up the broadcast registered under ``bid`` through pyspark's
    ``_from_id`` and wraps its backing file path in a fresh ModelBroadcast.
    That object loads the model lazily, with ``bigdl_type``, on first access
    to ``.value``.

    :param bid: id of the JVM broadcast, as emitted by ``__reduce__``.
    :param bigdl_type: numeric type string forwarded to ``Model.loadModel``.
    :return: a ModelBroadcast constructed from ``path`` only (no SparkContext).
    """
    registered = _from_id(bid)
    return ModelBroadcast(bigdl_type=bigdl_type, path=registered._path)
def broadcast_model(sc, layer):
    """Broadcast a BigDL model to the cluster.

    Prefer this over ``SparkContext.broadcast``. The returned ModelBroadcast
    writes the model to a file via its ``saveModel`` method rather than
    pickling it.

    :param sc: the SparkContext; its ``_pickled_broadcast_vars`` registry is
        handed to the broadcast so it is tracked when pickled into a job.
    :param layer: the BigDL model/layer to broadcast.
    :return: a ModelBroadcast whose ``value`` yields the model.
    """
    registry = sc._pickled_broadcast_vars
    return ModelBroadcast(sc, layer, registry)
class ModelBroadcast(Broadcast):
    """A Spark broadcast variable that carries a BigDL model.

    Instead of pickling the model, ``dump`` writes it to the broadcast file
    with the model's ``saveModel``. ``_load`` reads it back with
    ``Model.loadModel`` using the recorded ``bigdl_type``.
    """

    def __init__(self, sc=None, layer=None, pickle_registry=None, path=None, bigdl_type="float"):
        """
        Should not be called directly by users -- use ``broadcast_model(sc, layer)``
        instead.

        :param sc: SparkContext, or None when rebuilding from ``path``
            (as ``_from_id_and_type`` does).
        :param layer: the model to broadcast; when given, its own
            ``bigdl_type`` takes precedence over the ``bigdl_type`` argument.
        :param pickle_registry: forwarded to the base Broadcast; ``__reduce__``
            adds this object to the stored registry when pickled.
        :param path: path of an already-written broadcast file (used when
            ``sc`` is None).
        :param bigdl_type: numeric type used by ``_load`` when no ``layer``
            is supplied.
        """
        if layer is not None:
            self.bigdl_type = layer.bigdl_type
        else:
            self.bigdl_type = bigdl_type
        super(ModelBroadcast, self).__init__(sc, layer, pickle_registry, path)

    def dump(self, value, f):
        """Save the model ``value`` to the file named by ``f``.

        :param value: the model; must provide ``saveModel(path, over_write=...)``.
        :param f: an open file object whose ``name`` is the target path. It is
            closed whether or not saving succeeds.
        :return: ``f.name``.
        :raises ValueError: if ``saveModel`` raises; the original traceback is
            printed to stderr first.
        """
        try:
            # over_write=True: the file at f.name presumably already exists
            # because ``f`` refers to it.
            value.saveModel(f.name, over_write=True)
        except Exception as e:
            msg = "Could not serialize broadcast: %s" % e.__class__.__name__
            print_exec(sys.stderr)
            raise ValueError(msg)
        finally:
            # Close on the failure path as well; previously a failed save
            # left the file handle open.
            f.close()
        return f.name

    def _load(self, path):
        """Load the broadcast model from ``path`` using ``self.bigdl_type``."""
        return Model.loadModel(path, bigdl_type=self.bigdl_type)

    @property
    def value(self):
        """ Return the broadcasted value

        The model is loaded from ``self._path`` on first access and cached
        in ``self._value`` for later calls.
        """
        if not hasattr(self, "_value") and self._path is not None:
            self._value = self._load(self._path)
        return self._value

    def __reduce__(self):
        """Pickle as a call to ``_from_id_and_type(broadcast_id, bigdl_type)``.

        Allowed only when ``_jbroadcast`` is set (the driver-side object). The
        broadcast also registers itself in its pickle registry so it is
        tracked alongside the pickled job.

        :raises Exception: if ``_jbroadcast`` is None.
        """
        if self._jbroadcast is None:
            raise Exception("Broadcast can only be serialized in driver")
        self._pickle_registry.add(self)
        return _from_id_and_type, (self._jbroadcast.id(), self.bigdl_type)