# Source code for bigdl.dataset.news20
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tarfile
from bigdl.dataset import base
import os
import sys
# Download location of the 20 Newsgroups archive (20news-19997.tar.gz).
NEWS20_URL = "http://qwone.com/~jason/20Newsgroups/20news-19997.tar.gz"  # noqa
# Download location of the zipped pre-trained GloVe word vectors (glove.6B).
GLOVE_URL = "http://nlp.stanford.edu/data/glove.6B.zip"  # noqa
# Number of newsgroup classes in the corpus.
CLASS_NUM = 20
def download_news20(dest_dir):
    """
    Fetch the 20 Newsgroups archive and extract it into dest_dir if needed.

    The archive is obtained through ``base.maybe_download``. Extraction is
    skipped when ``dest_dir/20_newsgroups`` already exists.

    :param dest_dir: Directory the archive is stored in and extracted under.
    :return: Path ``dest_dir/20_newsgroups``, the directory the archive is
             expected to extract into.
    """
    file_name = "20news-19997.tar.gz"
    file_abs_path = base.maybe_download(file_name, dest_dir, NEWS20_URL)
    extracted_to = os.path.join(dest_dir, "20_newsgroups")
    if not os.path.exists(extracted_to):
        print("Extracting %s to %s" % (file_abs_path, extracted_to))
        # Only open the archive when extraction is actually needed; the
        # context manager closes it even if extraction raises.
        with tarfile.open(file_abs_path, "r:gz") as tar:
            # The archive comes from a plain-HTTP URL, so apply the "data"
            # extraction filter (rejects absolute paths and members that
            # would land outside dest_dir) when this Python provides it.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(dest_dir, filter="data")
            else:
                tar.extractall(dest_dir)
    return extracted_to
def download_glove_w2v(dest_dir):
    """
    Fetch the GloVe zip and extract it into dest_dir/glove.6B if needed.

    The zip is obtained through ``base.maybe_download``. Extraction is
    skipped when ``dest_dir/glove.6B`` already exists.

    :param dest_dir: Directory the zip is stored in and extracted under.
    :return: Path ``dest_dir/glove.6B`` holding the extracted files.
    """
    import zipfile
    file_name = "glove.6B.zip"
    file_abs_path = base.maybe_download(file_name, dest_dir, GLOVE_URL)
    extracted_to = os.path.join(dest_dir, "glove.6B")
    if not os.path.exists(extracted_to):
        print("Extracting %s to %s" % (file_abs_path, extracted_to))
        # Open the zip only when extraction is needed; the context manager
        # closes it even if extraction raises.
        with zipfile.ZipFile(file_abs_path, 'r') as zip_ref:
            zip_ref.extractall(extracted_to)
    return extracted_to
def get_news20(source_dir="/tmp/news20/"):
    """
    Parse news20, downloading and extracting it first if needed.

    Each sub-directory of the extracted corpus is one class. Directories are
    visited in sorted name order and given consecutive labels starting at 1.
    Within a class directory, only files whose names are all digits are read.

    :param source_dir: The directory storing news data.
    :return: A list of (text, label) tuples, where text is the raw file
             content (not tokenized) and label is a 1-based int class id.
    """
    news_dir = download_news20(source_dir)
    texts = []  # list of (text, label) samples
    label_id = 0
    for name in sorted(os.listdir(news_dir)):
        path = os.path.join(news_dir, name)
        # Only class directories consume a label id, so stray files
        # (e.g. .DS_Store) cannot shift the labels of later classes.
        if not os.path.isdir(path):
            continue
        label_id += 1
        for fname in sorted(os.listdir(path)):
            if not fname.isdigit():
                continue
            fpath = os.path.join(path, fname)
            if sys.version_info < (3,):
                f = open(fpath)
            else:
                # latin-1 maps every byte to a character, so decoding never fails.
                f = open(fpath, encoding='latin-1')
            with f:
                content = f.read()
            texts.append((content, label_id))
    print('Found %s texts.' % len(texts))
    return texts
def get_glove_w2v(source_dir="/tmp/news20/", dim=100):
    """
    Parse the pre-trained GloVe word2vec, downloading it first if needed.

    Reads ``glove.6B.<dim>d.txt`` from the extracted GloVe directory. Each
    line is split on single spaces: the first field is the word and the
    remaining fields are parsed as floats.

    :param source_dir: The directory storing the pre-trained word2vec
    :param dim: The dimension of a vector (selects which file is read)
    :return: A dict mapping from word to a list of floats
    """
    w2v_dir = download_glove_w2v(source_dir)
    w2v_path = os.path.join(w2v_dir, "glove.6B.%sd.txt" % dim)
    if sys.version_info < (3,):
        w2v_f = open(w2v_path)
    else:
        w2v_f = open(w2v_path, encoding='latin-1')
    pre_w2v = {}
    # The context manager closes the file even if a line fails to parse.
    # Iterating the file directly (rather than readlines()) avoids holding
    # the whole vector file in memory alongside the dict being built.
    with w2v_f:
        for line in w2v_f:
            items = line.split(" ")
            # float() ignores the trailing newline on the last field.
            pre_w2v[items[0]] = [float(i) for i in items[1:]]
    return pre_w2v
if __name__ == "__main__":
    # When run as a script, prepare both datasets under the default directory.
    _cache_dir = "/tmp/news20/"
    get_news20(_cache_dir)
    get_glove_w2v(_cache_dir)