# Source code for bigdl.examples.imageframe.inception_validation

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from bigdl.util.common import *

from bigdl.transform.vision.image import *

from bigdl.optim.optimizer import *

from pyspark import SparkContext

from bigdl.nn.layer import *

from optparse import OptionParser

import sys

# Command-line flags for this script: data folder, model path and batch size.
# Every default is "empty" ("" or 0), so a consumer must check that a value was
# actually supplied before relying on it.
parser = OptionParser()

# Location of the input data (Hadoop sequence files).
parser.add_option(
    "-f", "--folder",
    dest="folder",
    type=str,
    default="",
    help="url of hdfs folder store the hadoop sequence files",
)

# Path of the serialized model to evaluate.
parser.add_option(
    "--model",
    dest="model",
    type=str,
    default="",
    help="model path",
)

# Total (global) batch size used for evaluation.
parser.add_option(
    "-b", "--batchSize",
    dest="batchSize",
    type=int,
    default=0,
    help="total batch size",
)
def get_data(url, sc=None, data_type="val", class_num=1000):
    """Load an ImageFrame from the sequence files under ``<url>/<data_type>``.

    :param url: root folder; ``data_type`` is joined onto it with
        ``os.path.join`` to form the actual folder that is read.
    :param sc: SparkContext forwarded to ``SeqFileFolder.files_to_image_frame``
        (may be None, in which case that helper decides what to do).
    :param data_type: sub-folder name to read; defaults to ``"val"``.
    :param class_num: number of classes forwarded to
        ``SeqFileFolder.files_to_image_frame``; defaults to 1000, the value
        this function previously hard-coded.
    :return: whatever ``SeqFileFolder.files_to_image_frame`` returns
        (presumably an ImageFrame; see the BigDL docs to confirm).
    """
    # Local import: ``os`` was previously only reachable via a star import.
    import os

    path = os.path.join(url, data_type)
    return SeqFileFolder.files_to_image_frame(url=path, sc=sc, class_num=class_num)
def run(image_path, model_path, batch_size):
    """Evaluate a saved model's top-1 accuracy on a folder of sequence files.

    :param image_path: root folder passed to ``get_data`` (its ``val``
        sub-folder is read).
    :param model_path: path passed to ``Model.loadModel``.
    :param batch_size: batch size for evaluation; anything accepted by
        ``int()`` (e.g. the raw command-line string) works.
    :return: None; the first evaluation result is printed.
    """
    spark_conf = create_spark_conf().setAppName("test_validation")
    sc = get_spark_context(spark_conf)
    init_engine()

    # Pre-processing: decode bytes -> resize to 256x256 -> center-crop 224x224
    # -> per-channel mean subtraction -> tensor -> Sample keyed on
    # "imageTensor" (input) and "label" (target).
    transformer = Pipeline([
        PixelBytesToMat(),
        Resize(256, 256),
        CenterCrop(224, 224),
        ChannelNormalize(123.0, 117.0, 104.0),
        MatToTensor(),
        ImageFrameToSample(input_keys=["imageTensor"], target_keys=["label"]),
    ])

    raw_image_frame = get_data(image_path, sc)
    transformed = transformer(raw_image_frame)

    model = Model.loadModel(model_path)
    result = model.evaluate(transformed, int(batch_size), [Top1Accuracy()])

    # Single-argument print() behaves the same on Python 2 and Python 3.
    print("top1 accuracy %s" % (result[0],))
if __name__ == "__main__":
    # Accept either the original positional form
    #   <imagePath> <modelPath> <batchSize>
    # or the flags declared on ``parser`` (-f/--folder, --model, -b/--batchSize).
    (options, args) = parser.parse_args()

    if len(args) == 3:
        # Positional form. Previously this checked len(sys.argv) != 3 while
        # reading sys.argv[3], so valid input was rejected and then crashed.
        image_path, model_path, batch_size = args
    elif not args and options.folder and options.model and options.batchSize > 0:
        image_path = options.folder
        model_path = options.model
        batch_size = options.batchSize
    else:
        print("parameters needed : <imagePath> <modelPath> <batchSize>"
              " (or -f <imagePath> --model <modelPath> -b <batchSize>)")
        # Stop here instead of falling through with missing arguments.
        sys.exit(1)

    run(image_path, model_path, batch_size)