# Source code for bigdl.dataset.base

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import sys
import shutil
import tempfile
import time
from distutils.dir_util import mkpath
from six.moves.urllib.request import urlopen
import numpy as np


# Adapted from Keras.
class Progbar(object):
    """Text-mode progress bar (adapted from Keras).

    Renders a live progress bar on stdout and tracks running averages of
    any metric values reported through ``update``/``add``.
    """

    def __init__(self, target, width=30, verbose=1, interval=0.01):
        '''
        :param target: total number of steps expected
        :param width: width of the bar itself, in characters
        :param verbose: 0 = silent, 1 = live bar, 2 = one summary line at the end
        :param interval: minimum visual progress update interval (in seconds)
        '''
        self.width = width
        self.target = target
        self.sum_values = {}      # metric name -> [weighted sum, total weight]
        self.unique_values = []   # metric names in first-seen order
        self.start = time.time()
        self.last_update = 0
        self.interval = interval
        self.total_width = 0
        self.seen_so_far = 0
        self.verbose = verbose

    def update(self, current, values=None, force=False):
        '''
        :param current: index of current step
        :param values: list of tuples (name, value_for_last_step).
               The progress bar will display averages for these values.
        :param force: force visual progress update
        '''
        # Fix: `values=[]` was a shared mutable default argument.
        if values is None:
            values = []
        for k, v in values:
            if k not in self.sum_values:
                # Weight each reported value by the number of steps it covers.
                self.sum_values[k] = [v * (current - self.seen_so_far),
                                      current - self.seen_so_far]
                self.unique_values.append(k)
            else:
                self.sum_values[k][0] += v * (current - self.seen_so_far)
                self.sum_values[k][1] += (current - self.seen_so_far)
        self.seen_so_far = current

        now = time.time()
        if self.verbose == 1:
            # Throttle redraws to at most one per `interval` seconds.
            if not force and (now - self.last_update) < self.interval:
                return

            prev_total_width = self.total_width
            sys.stdout.write("\b" * prev_total_width)
            sys.stdout.write("\r")

            # max(..., 1) guards against log10(0) when target is 0.
            numdigits = int(np.floor(np.log10(max(self.target, 1)))) + 1
            barstr = '%%%dd/%%%dd [' % (numdigits, numdigits)
            bar = barstr % (current, self.target)
            # Guard against division by zero for an empty target.
            prog = float(current) / self.target if self.target else 1.0
            prog_width = int(self.width * prog)
            if prog_width > 0:
                bar += ('=' * (prog_width - 1))
                if current < self.target:
                    bar += '>'
                else:
                    bar += '='
            bar += ('.' * (self.width - prog_width))
            bar += ']'
            sys.stdout.write(bar)
            self.total_width = len(bar)

            if current:
                time_per_unit = (now - self.start) / current
            else:
                time_per_unit = 0
            eta = time_per_unit * (self.target - current)
            info = ''
            if current < self.target:
                info += ' - ETA: %ds' % eta
            else:
                info += ' - %ds' % (now - self.start)
            for k in self.unique_values:
                info += ' - %s:' % k
                if type(self.sum_values[k]) is list:
                    avg = self.sum_values[k][0] / max(1, self.sum_values[k][1])
                    if abs(avg) > 1e-3:
                        info += ' %.4f' % avg
                    else:
                        info += ' %.4e' % avg
                else:
                    info += ' %s' % self.sum_values[k]

            self.total_width += len(info)
            # Pad with spaces so a previously longer line is fully erased.
            if prev_total_width > self.total_width:
                info += ((prev_total_width - self.total_width) * " ")

            sys.stdout.write(info)
            sys.stdout.flush()

            if current >= self.target:
                sys.stdout.write("\n")

        if self.verbose == 2:
            # Print a single summary line only once the target is reached.
            if current >= self.target:
                info = '%ds' % (now - self.start)
                for k in self.unique_values:
                    info += ' - %s:' % k
                    avg = self.sum_values[k][0] / max(1, self.sum_values[k][1])
                    if avg > 1e-3:
                        info += ' %.4f' % avg
                    else:
                        info += ' %.4e' % avg
                sys.stdout.write(info + "\n")

        self.last_update = now

    def add(self, n, values=None):
        '''Advance the bar by ``n`` steps, reporting ``values`` for them.'''
        self.update(self.seen_so_far + n, values)
def display_table(rows, positions):
    """Print ``rows`` as a fixed-width table.

    Each cell of a row is truncated and/or right-padded with spaces so
    that column ``i`` ends exactly at character offset ``positions[i]``.

    :param rows: iterable of rows, each a sequence of printable cells
    :param positions: per-column end offsets (character positions)
    """
    def _print_row(cells, stops):
        # Build the line cell by cell, cutting/padding to each stop offset.
        text = ''
        for idx, cell in enumerate(cells):
            text += str(cell)
            text = text[:stops[idx]]
            text += ' ' * (stops[idx] - len(text))
        print(text)

    for row in rows:
        _print_row(row, positions)
# Adapted from Keras.
# Under Python 2, 'urlretrieve' relies on FancyURLopener from the legacy
# urllib module, known to have issues with proxy management.
if sys.version_info[0] == 2:
    def urlretrieve(url, filename, reporthook=None, data=None):
        """Python 2 replacement for urllib's ``urlretrieve``.

        :param url: url to download from
        :param filename: local path the downloaded data is written to
        :param reporthook: optional callable(count, block_size, total_size)
               invoked after each chunk is read
        :param data: optional payload forwarded to ``urlopen`` (POST body)
        """
        def chunk_read(response, chunk_size=8192, reporthook=None):
            # NOTE(review): assumes the server sends a Content-Length
            # header; a missing header would raise here — confirm callers
            # only hit well-behaved HTTP servers.
            total_size = response.info().get('Content-Length').strip()
            total_size = int(total_size)
            count = 0
            while 1:
                chunk = response.read(chunk_size)
                count += 1
                if not chunk:
                    # Fix: guard the final-chunk callback — previously this
                    # raised TypeError whenever reporthook was None.
                    if reporthook:
                        reporthook(count, total_size, total_size)
                    break
                if reporthook:
                    reporthook(count, chunk_size, total_size)
                yield chunk

        response = urlopen(url, data)
        with open(filename, 'wb') as fd:
            for chunk in chunk_read(response, reporthook=reporthook):
                fd.write(chunk)
else:
    from six.moves.urllib.request import urlretrieve
def maybe_download(filename, work_directory, source_url):
    """Download ``filename`` from ``source_url`` unless it is already cached.

    :param filename: name of the file inside ``work_directory``
    :param work_directory: local directory used as the download cache
    :param source_url: url to fetch the file from
    :return: full path of the (possibly pre-existing) local file
    """
    if not os.path.exists(work_directory):
        mkpath(work_directory)
    filepath = os.path.join(work_directory, filename)
    if not os.path.exists(filepath):
        print('Downloading data from', source_url)
        # Fix: keep the progress bar in a closure cell instead of a module
        # global, so concurrent/repeated downloads cannot clobber each other.
        state = [None]

        def dl_progress(count, block_size, total_size):
            if state[0] is None:
                state[0] = Progbar(total_size)
            # Fix: previously the very first callback only created the bar
            # and skipped the update; also clamp so the final partial block
            # cannot push the bar past total_size.
            state[0].update(min(count * block_size, total_size))

        # Download to a temp file first so a failed transfer never leaves a
        # truncated file at `filepath`.
        with tempfile.NamedTemporaryFile() as tmpfile:
            temp_file_name = tmpfile.name
            urlretrieve(source_url, temp_file_name, dl_progress)
            shutil.copy2(temp_file_name, filepath)
        size = os.path.getsize(filepath)
        print('Successfully downloaded', filename, size, 'bytes.')
    return filepath