asr_prep_json.py 3.69 KB
Newer Older
Nikhilesh Bhatnagar's avatar
Nikhilesh Bhatnagar committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import concurrent.futures
import json
import multiprocessing
import os
from collections import namedtuple
from itertools import chain

import sentencepiece as spm
from fairseq.data import Dictionary


# Conversion factor: multiply a duration in milliseconds by this to get
# seconds (equivalently, divide seconds by it to get milliseconds).
MILLISECONDS_TO_SECONDS = 0.001


def process_sample(aud_path, lable, utt_id, sp, tgt_dict):
    """Build the JSON manifest entry for a single utterance.

    Args:
        aud_path: path of the audio file; passed to ``torchaudio.info`` and
            stored verbatim in the entry.
        lable: transcript text for the utterance (the spelling is kept so the
            signature stays unchanged).
        utt_id: key under which the entry is returned.
        sp: object exposing ``EncodeAsPieces(text)``; used to split the
            transcript into pieces.
        tgt_dict: object exposing ``encode_line(text, append_eos=False)``;
            used to map the space-joined pieces to ids.

    Returns:
        ``{utt_id: {"input": {...}, "output": {...}}}`` where ``input`` holds
        ``length_ms`` and ``path`` and ``output`` holds ``text``, ``token``
        and ``tokenid``.
    """
    # Imported lazily so the module can be loaded without torchaudio.
    import torchaudio

    # torchaudio.info is unpacked as a pair here; only the first element
    # (exposing .length, .channels and .rate) is needed.
    signal_info, _ = torchaudio.info(aud_path)
    # NOTE(review): presumably .length counts samples across all channels,
    # hence the division by .channels -- confirm against the torchaudio version.
    per_channel_length = signal_info.length / signal_info.channels
    duration_seconds = per_channel_length / signal_info.rate
    audio_entry = {
        "length_ms": int(duration_seconds / MILLISECONDS_TO_SECONDS),
        "path": aud_path,
    }

    pieces = sp.EncodeAsPieces(lable)
    token = " ".join(pieces)
    ids = tgt_dict.encode_line(token, append_eos=False)
    text_entry = {
        "text": lable,
        "token": token,
        # Each id is converted individually via .tolist() before stringifying.
        "tokenid": ", ".join(str(t.tolist()) for t in ids),
    }
    return {utt_id: {"input": audio_entry, "output": text_entry}}


def main():
    """Command-line entry point: write a JSON manifest of labelled audio.

    Reads ``<ID LABEL>`` lines from ``--labels``, walks ``--audio-dirs`` for
    files with the ``--audio-format`` extension whose stem is a known ID,
    runs ``process_sample`` on each of them in a thread pool and dumps
    ``{"utts": {...}}`` to ``--output``.

    Raises:
        Exception: if the labels file yields no labels.
        ValueError: if a non-blank labels line has no ``ID LABEL`` separator.
    """
    args = _parse_args()

    # Validate the cheap input first so a bad labels file fails before the
    # sentencepiece model and dictionary are loaded.
    labels = _read_labels(args.labels)
    if not labels:
        raise Exception("No labels found in {}".format(args.labels.name))

    sp = spm.SentencePieceProcessor()
    sp.Load(args.spm_model.name)

    tgt_dict = Dictionary.load(args.dictionary)

    samples = _collect_samples(args.audio_dirs, args.audio_format, labels)

    utts = {}
    num_cpu = multiprocessing.cpu_count()
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
        futures = [
            executor.submit(
                process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict
            )
            for s in samples
        ]
        for future in concurrent.futures.as_completed(futures):
            # Best effort: a sample that fails is reported and left out of the
            # manifest instead of aborting the whole run.
            try:
                data = future.result()
            except Exception as exc:
                print("generated an exception: ", exc)
            else:
                utts.update(data)
    json.dump({"utts": utts}, args.output, indent=4)


# Lightweight record describing one audio file selected for processing.
_Sample = namedtuple("Sample", "aud_path utt_id")


def _parse_args():
    """Define the command-line options and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--audio-dirs",
        nargs="+",
        required=True,
        help="input directories with audio files",
    )
    parser.add_argument(
        "--labels",
        required=True,
        help="aggregated input labels with format <ID LABEL> per line",
        type=argparse.FileType("r", encoding="UTF-8"),
    )
    parser.add_argument(
        "--spm-model",
        required=True,
        help="sentencepiece model to use for encoding",
        type=argparse.FileType("r", encoding="UTF-8"),
    )
    parser.add_argument(
        "--dictionary",
        required=True,
        help="file to load fairseq dictionary from",
        type=argparse.FileType("r", encoding="UTF-8"),
    )
    parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
    parser.add_argument(
        "--output",
        required=True,
        type=argparse.FileType("w"),
        help="path to save json output",
    )
    return parser.parse_args()


def _read_labels(lines):
    """Parse ``<ID LABEL>`` lines into a dict mapping utterance ID to label.

    Blank lines are skipped and the line terminator is not kept in the label.

    Raises:
        ValueError: if a non-blank line has no space separating ID and label.
    """
    labels = {}
    for line_no, raw_line in enumerate(lines, start=1):
        line = raw_line.rstrip("\r\n")
        if not line.strip():
            continue
        utt_id, separator, label = line.partition(" ")
        if not separator:
            raise ValueError(
                "Expect <ID LABEL> on line {}. Got: {!r}".format(line_no, line)
            )
        labels[utt_id] = label
    return labels


def _collect_samples(audio_dirs, audio_format, labels):
    """Walk *audio_dirs* and return a sample for every labelled audio file.

    A file is selected when its extension is exactly ``.<audio_format>`` and
    its name without that extension is a key of *labels*.
    """
    wanted_ext = "." + audio_format
    samples = []
    for root, _, files in chain.from_iterable(os.walk(d) for d in audio_dirs):
        for file_name in files:
            utt_id, ext = os.path.splitext(file_name)
            if ext != wanted_ext or utt_id not in labels:
                continue
            samples.append(_Sample(os.path.join(root, file_name), utt_id))
    return samples


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()