apply_bpe.py 10.9 KB
Newer Older
Nikhilesh Bhatnagar's avatar
Nikhilesh Bhatnagar committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Rico Sennrich
# flake8: noqa

"""Use operations learned with learn_bpe.py to encode a new text.
The text will not be smaller, but it will use only a fixed vocabulary, with rare words
encoded as variable-length sequences of subword units.

Reference:
Rico Sennrich, Barry Haddow and Alexandra Birch (2015). Neural Machine Translation of Rare Words with Subword Units.
Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (ACL 2016). Berlin, Germany.
"""
# This file is retrieved from https://github.com/rsennrich/subword-nmt

from __future__ import unicode_literals, division

import sys
import codecs
import io
import argparse
import json
import re
from collections import defaultdict

# hack for python2/3 compatibility
from io import open
# Expose io.open (imported above as `open`) under the name `argparse.open`.
# NOTE(review): presumably this makes argparse.FileType open files through
# io.open on Python 2 as well -- confirm it is still needed on Python 3.
setattr(argparse, "open", open)


class BPE(object):
    """Apply learned BPE merge operations to whitespace-tokenized text.

    Args:
        codes: file-like object with one space-separated merge pair per line,
            optionally preceded by a ``#version: X.Y`` header line. It must
            support ``readline`` and, when there is no header, ``seek(0)``.
        separator: string appended to every non-final subword unit.
        vocab: optional vocabulary (e.g. the set built by ``read_vocabulary``);
            passed on to ``encode`` to revert merges producing unknown units.
        glossaries: optional list of strings that are kept as single units.
    """

    def __init__(self, codes, separator="@@", vocab=None, glossaries=None):
        # check version information
        firstline = codes.readline()
        if firstline.startswith("#version:"):
            # "0.2" -> (0, 2); trailing ".0" components are dropped ("0.2.0" -> (0, 2))
            self.version = tuple(
                int(x)
                for x in re.sub(r"(\.0+)*$", "", firstline.split()[-1]).split(".")
            )
        else:
            # no header: the first line is already a merge, so read it again
            self.version = (0, 1)
            codes.seek(0)

        # Skip blank lines (e.g. a trailing empty line): they carry no merge and
        # previously became an empty tuple that crashed the reverse table below.
        merges = [tuple(item.split()) for item in codes if item.strip()]

        # merge pair -> rank (lower rank = applied first); for duplicated pairs
        # only the first occurrence counts
        self.bpe_codes = {}
        for rank, pair in enumerate(merges):
            self.bpe_codes.setdefault(pair, rank)

        # merged string -> pair that produced it, used to undo merges. If two
        # pairs concatenate to the same string, the lowest-ranked pair wins.
        self.bpe_codes_reverse = {}
        for pair, _ in sorted(self.bpe_codes.items(), key=lambda kv: kv[1]):
            self.bpe_codes_reverse.setdefault(pair[0] + pair[1], pair)

        self.separator = separator

        self.vocab = vocab

        self.glossaries = glossaries if glossaries else []

        # word -> encoded units, shared across calls to segment()
        self.cache = {}

    def segment(self, sentence):
        """segment single sentence (whitespace-tokenized string) with BPE encoding

        Every unit except the last one of each word gets ``self.separator``
        appended; all units are joined with single spaces.
        """
        output = []
        for word in sentence.split():
            new_word = [
                out
                for segment in self._isolate_glossaries(word)
                for out in encode(
                    segment,
                    self.bpe_codes,
                    self.bpe_codes_reverse,
                    self.vocab,
                    self.separator,
                    self.version,
                    self.cache,
                    self.glossaries,
                )
            ]

            output.extend(item + self.separator for item in new_word[:-1])
            output.append(new_word[-1])

        return " ".join(output)

    def _isolate_glossaries(self, word):
        """Split ``word`` so that every glossary occurrence is its own segment."""
        word_segments = [word]
        for gloss in self.glossaries:
            word_segments = [
                out_segments
                for segment in word_segments
                for out_segments in isolate_glossary(segment, gloss)
            ]
        return word_segments


def create_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser with the options --input/-i, --codes/-c
        (required), --output/-o, --separator/-s, --vocabulary,
        --vocabulary-threshold and --glossaries.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # previously "learn ..." (copied from learn_bpe.py); this script applies codes
        description="apply BPE-based word segmentation",
    )

    parser.add_argument(
        "--input",
        "-i",
        type=argparse.FileType("r"),
        default=sys.stdin,
        metavar="PATH",
        help="Input file (default: standard input).",
    )
    parser.add_argument(
        "--codes",
        "-c",
        type=argparse.FileType("r"),
        metavar="PATH",
        required=True,
        help="File with BPE codes (created by learn_bpe.py).",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=argparse.FileType("w"),
        default=sys.stdout,
        metavar="PATH",
        help="Output file (default: standard output)",
    )
    parser.add_argument(
        "--separator",
        "-s",
        type=str,
        default="@@",
        metavar="STR",
        help="Separator between non-final subword units (default: '%(default)s')",
    )
    parser.add_argument(
        "--vocabulary",
        type=argparse.FileType("r"),
        default=None,
        metavar="PATH",
        help="Vocabulary file (built with get_vocab.py). If provided, this script reverts any merge operations that produce an OOV.",
    )
    parser.add_argument(
        "--vocabulary-threshold",
        type=int,
        default=None,
        metavar="INT",
        help="Vocabulary threshold. If vocabulary is provided, any word with frequency < threshold will be treated as OOV",
    )
    parser.add_argument(
        "--glossaries",
        type=str,
        nargs="+",
        default=None,
        metavar="STR",
        # the two literals used to be joined without a space ("affectedby")
        # and the parenthesis was never closed
        help="Glossaries. The strings provided in glossaries will not be affected "
        "by the BPE (i.e. they will neither be broken into subwords, nor concatenated with other subwords)",
    )

    return parser


def get_pairs(word):
    """Return set of symbol pairs in a word.

    word is represented as tuple of symbols (symbols being variable-length strings).
    A word with fewer than two symbols -- including an empty one, which used to
    raise IndexError -- has no pairs and yields an empty set.
    """
    # adjacent (previous, next) symbol pairs
    return set(zip(word, word[1:]))


def encode(
    orig,
    bpe_codes,
    bpe_codes_reverse,
    vocab,
    separator,
    version,
    cache,
    glossaries=None,
):
    """Encode word based on list of BPE merge operations, which are applied consecutively

    Args:
        orig: the word to encode (a non-empty string).
        bpe_codes: dict mapping a merge pair ``(left, right)`` to its rank;
            the lowest-ranked pair present in the word is merged first.
        bpe_codes_reverse: dict mapping a merged string to the pair that
            produced it; only used when ``vocab`` is truthy.
        vocab: optional vocabulary; when truthy the result is passed through
            ``check_vocab_and_split``.
        separator: subword separator, used for the vocabulary lookups.
        version: codes-file version tuple, ``(0, 1)`` or ``(0, 2)``.
        cache: dict memoizing results per word; updated in place.
        glossaries: optional collection of words returned as a single unit.

    Returns:
        The subword units without the ``</w>`` end-of-word marker: a tuple,
        or the list returned by ``check_vocab_and_split`` when ``vocab`` is used.

    Raises:
        NotImplementedError: for any other ``version``.
    """

    if orig in cache:
        return cache[orig]

    # glossaries defaults to None; `orig in None` used to raise TypeError
    if glossaries and orig in glossaries:
        cache[orig] = (orig,)
        return (orig,)

    if version == (0, 1):
        word = tuple(orig) + ("</w>",)
    elif version == (0, 2):  # more consistent handling of word-final segments
        word = tuple(orig[:-1]) + (orig[-1] + "</w>",)
    else:
        raise NotImplementedError

    pairs = get_pairs(word)

    if not pairs:
        # single-symbol word: nothing to merge. Return a tuple like every
        # other path (this used to return the bare string).
        return (orig,)

    while True:
        # highest-priority (lowest-rank) pair present in the word
        bigram = min(pairs, key=lambda pair: bpe_codes.get(pair, float("inf")))
        if bigram not in bpe_codes:
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
            except ValueError:
                # no further occurrence of `first`: copy the rest unchanged
                new_word.extend(word[i:])
                break
            new_word.extend(word[i:j])
            i = j

            # word[i] == first here; merge it with its right neighbour if that is `second`
            if i < len(word) - 1 and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        word = tuple(new_word)
        if len(word) == 1:
            break
        pairs = get_pairs(word)

    # don't print end-of-word symbols
    if word[-1] == "</w>":
        word = word[:-1]
    elif word[-1].endswith("</w>"):
        # strip only the trailing marker (str.replace removed every occurrence)
        word = word[:-1] + (word[-1][: -len("</w>")],)

    if vocab:
        word = check_vocab_and_split(word, bpe_codes_reverse, vocab, separator)

    cache[orig] = word
    return word


def recursive_split(segment, bpe_codes, vocab, separator, final=False):
    """Recursively split segment into smaller units (by reversing BPE merges)
    until all units are either in-vocabulary, or cannot be split further.

    Args:
        segment: the unit to split, without the ``</w>`` marker.
        bpe_codes: dict mapping a merged string to the ``(left, right)`` pair
            that produced it (the reverse merge table).
        vocab: collection of allowed units. Non-final units are looked up with
            ``separator`` appended; a word-final unit is looked up as-is.
        separator: subword separator string.
        final: True if ``segment`` is the last unit of its word; its merge is
            then looked up with ``</w>`` attached.

    Yields:
        The resulting units, in order.
    """

    try:
        if final:
            left, right = bpe_codes[segment + "</w>"]
            right = right[: -len("</w>")]
        else:
            left, right = bpe_codes[segment]
    except (KeyError, ValueError):
        # KeyError: no merge produced this segment; ValueError: the table
        # entry is not a pair. Either way it cannot be split further.
        yield segment
        return

    if left + separator in vocab:
        yield left
    else:
        for item in recursive_split(left, bpe_codes, vocab, separator, False):
            yield item

    # the right half inherits `final`: only a word-final unit is looked up
    # without the separator
    if (final and right in vocab) or (not final and right + separator in vocab):
        yield right
    else:
        for item in recursive_split(right, bpe_codes, vocab, separator, final):
            yield item


def check_vocab_and_split(orig, bpe_codes, vocab, separator):
    """Replace out-of-vocabulary units of an encoded word by smaller units.

    Each non-final unit is looked up in ``vocab`` with ``separator`` appended,
    and the final unit is looked up as-is. A unit that is not found is broken
    down by undoing BPE merges via ``recursive_split``.

    Args:
        orig: non-empty sequence of subword units of one word.
        bpe_codes: reverse merge table (merged string -> producing pair).
        vocab: collection of allowed units.
        separator: subword separator string.

    Returns:
        list of the resulting units, in order.
    """
    result = []

    for unit in orig[:-1]:
        if unit + separator in vocab:
            result.append(unit)
            continue
        # OOV non-final unit: undo merges until the pieces are known
        result.extend(recursive_split(unit, bpe_codes, vocab, separator, False))

    tail = orig[-1]
    if tail not in vocab:
        # OOV word-final unit
        result.extend(recursive_split(tail, bpe_codes, vocab, separator, True))
    else:
        result.append(tail)

    return result


def read_vocabulary(vocab_file, threshold):
    """read vocabulary file produced by get_vocab.py, and filter according to frequency threshold.

    Args:
        vocab_file: iterable of lines of the form ``<word> <frequency>``.
            Blank lines are ignored.
        threshold: minimum integer frequency for a word to be kept, or None
            to keep every word.

    Returns:
        set of the words that pass the threshold.

    Raises:
        ValueError: for a non-blank line that does not have exactly two fields,
            or whose frequency is not an integer.
    """

    vocabulary = set()

    for line in vocab_file:
        fields = line.split()
        if not fields:
            # tolerate blank lines (e.g. a trailing empty line) instead of
            # failing on the unpack below
            continue
        word, freq = fields
        freq = int(freq)
        if threshold is None or freq >= threshold:
            vocabulary.add(word)

    return vocabulary


def isolate_glossary(word, glossary):
    """
    Split ``word`` around every occurrence of ``glossary``.

    Returns a list of subwords in which each occurrence of ``glossary`` is a
    separate element. Empty pieces between or around occurrences are dropped,
    and every kept piece is whitespace-stripped.

    For example, if 'USA' is the glossary and '1934USABUSA' the word, the return value is:
        ['1934', 'USA', 'B', 'USA']
    """
    if word == glossary or glossary not in word:
        return [word]

    pieces = word.split(glossary)
    result = []
    for piece in pieces[:-1]:
        # text preceding this occurrence (if any), then the glossary itself
        if piece != "":
            result.append(piece.strip())
        result.append(glossary.strip())

    # text after the last occurrence
    if pieces[-1] != "":
        result.append(pieces[-1].strip())
    return result


if __name__ == "__main__":
    # python 2/3 compatibility: read and write the standard streams as UTF-8
    if sys.version_info < (3, 0):
        sys.stderr = codecs.getwriter("UTF-8")(sys.stderr)
        sys.stdout = codecs.getwriter("UTF-8")(sys.stdout)
        sys.stdin = codecs.getreader("UTF-8")(sys.stdin)
    else:
        sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding="utf-8")
        sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8")
        sys.stdout = io.TextIOWrapper(
            sys.stdout.buffer, encoding="utf-8", write_through=True, line_buffering=True
        )

    parser = create_parser()
    args = parser.parse_args()

    # Re-open the files argparse opened, as UTF-8 text. The argparse handles
    # are closed first so they are not leaked. io.open (imported as `open`) is
    # used instead of codecs.open: codecs stream readers also break lines at
    # Unicode separators such as U+0085 and U+2028, which made the number of
    # output lines differ from the number of input lines.
    args.codes.close()
    args.codes = open(args.codes.name, encoding="utf-8")
    if args.input.name != "<stdin>":
        args.input.close()
        args.input = open(args.input.name, encoding="utf-8")
    if args.output.name != "<stdout>":
        args.output.close()
        # newline="\n": write a bare "\n" on every platform, as before
        args.output = open(args.output.name, "w", encoding="utf-8", newline="\n")

    if args.vocabulary:
        args.vocabulary.close()
        with open(args.vocabulary.name, encoding="utf-8") as vocab_file:
            vocabulary = read_vocabulary(vocab_file, args.vocabulary_threshold)
    else:
        vocabulary = None

    with args.codes:
        bpe = BPE(args.codes, args.separator, vocabulary, args.glossaries)

    try:
        for line in args.input:
            args.output.write(bpe.segment(line).strip())
            args.output.write("\n")
    finally:
        # close only the files opened above, never the standard streams
        if args.input is not sys.stdin:
            args.input.close()
        if args.output is not sys.stdout:
            args.output.close()