98 lines
2.6 KiB
Python
98 lines
2.6 KiB
Python
|
import logging
|
||
|
|
||
|
import config
|
||
|
|
||
|
|
||
|
def tokenize(sent):
    """Split a sentence string into tokens on single-space boundaries.

    Uses ``str.split(" ")``, so runs of consecutive spaces produce
    empty-string tokens and an empty input yields ``[""]``.
    """
    tokens = sent.split(" ")
    return tokens
|
||
|
|
||
|
|
||
|
class Lang:
    """Represents the vocabulary

    Maintains the word<->index mappings and per-word counts. Indices 0 and 1
    are reserved for the padding and unknown tokens from ``config``.
    """

    def __init__(self, name):
        self.name = name
        # Reserved slots: PAD at index 0, UNK at index 1.
        self.word2index = {config.PAD: 0, config.UNK: 1}
        self.word2count = {}
        self.index2word = {0: config.PAD, 1: config.UNK}
        # Vocabulary size, counting the two reserved tokens.
        self.n_words = 2

    def add_sentence(self, sentence):
        """Register every token of an already-tokenized sentence."""
        assert isinstance(
            sentence, (list, tuple)
        ), "input to add_sentence must be tokenized"
        for token in sentence:
            self.add_word(token)

    def add_word(self, word):
        """Add one token to the vocabulary, or bump its count if known."""
        if word in self.word2index:
            self.word2count[word] += 1
            return
        idx = self.n_words
        self.word2index[word] = idx
        self.word2count[word] = 1
        self.index2word[idx] = word
        self.n_words = idx + 1

    def __add__(self, other):
        """Returns a new Lang object containing the vocabulary from this and
        the other Lang object
        """
        merged = Lang(f"{self.name}_{other.name}")

        # Union of both vocabularies (counts are provisional here).
        for vocab in (self.word2count, other.word2count):
            for word in vocab:
                merged.add_word(word)

        # Second pass: overwrite each count with the true combined total.
        for word in merged.word2count:
            merged.word2count[word] = (
                self.word2count.get(word, 0) + other.word2count.get(word, 0)
            )

        return merged
|
||
|
|
||
|
|
||
|
def load_google(split, max_len=None):
    """Load the Google Sentence Compression Dataset.

    Each non-blank line of the source file is ``"<word>\\t<keep-flag>"``;
    blank lines separate sentences.

    Args:
        split: One of ``"train"``, ``"val"`` or ``"test"``.
        max_len: If given, sentences with more than ``max_len`` tokens are
            dropped.

    Returns:
        ``(data, lang)`` where ``data`` is a list of ``[tokens, mask]`` pairs
        (parallel lists of words and integer keep-flags) and ``lang`` is a
        ``Lang`` built from the kept sentences.

    Raises:
        ValueError: If ``split`` is not a recognized split name.
    """
    # Logger name fixed to match this function (was "load_compression").
    logger = logging.getLogger(f"{__name__}.load_google")
    lang = Lang("compression")

    if split == "train":
        path = config.google_train_path
    elif split == "val":
        path = config.google_dev_path
    elif split == "test":
        path = config.google_test_path
    else:
        # Previously an unknown split fell through and later raised a
        # confusing NameError on `path`; fail fast with a clear message.
        raise ValueError(f"unknown split: {split!r}")

    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info("loading %s from %s", split, path)

    data = []
    sent = []
    mask = []

    def _flush():
        # Commit the accumulated sentence, honoring the max_len filter.
        # (The original appended the final sentence unconditionally and as
        # tuples, inconsistent with the in-loop list form — unified here.)
        if sent and (max_len is None or len(sent) <= max_len):
            data.append([sent, mask])
            lang.add_sentence(sent)

    with open(path, encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if line:
                w, d = line.split("\t")
                sent.append(w)
                mask.append(int(d))
            else:
                _flush()
                sent = []
                mask = []

    # A file that does not end with a blank line still has a pending sentence.
    _flush()

    return data, lang
|