diff options
Diffstat (limited to 'roles/caption/templates')
-rwxr-xr-x | roles/caption/templates/caption.sh | 18 | ||||
-rwxr-xr-x | roles/caption/templates/process-captions.py | 231 |
2 files changed, 231 insertions, 18 deletions
diff --git a/roles/caption/templates/caption.sh b/roles/caption/templates/caption.sh deleted file mode 100755 index 9600a3c..0000000 --- a/roles/caption/templates/caption.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash -# {{ ansible_managed }} -FILE="$1" -MODEL="${2:small}" -AUDIO=$(basename "$FILE" | sed s/\\.[a-z][a-z][a-z][a-z]?$//).ogg -if [[ ! -f $AUDIO ]]; then - if [[ "$FILE" == *webm ]]; then - ffmpeg -y -i "$FILE" -acodec copy -vn $AUDIO - else - ffmpeg -y -i "$FILE" -acodec libvorbis -vn $AUDIO - fi -fi -date > $AUDIO-$MODEL.log -time whisper $AUDIO --model $MODEL --threads 12 >> $AUDIO-$MODEL.log -for EXT in vtt txt srt; do - mv $AUDIO.$EXT $(basename -s .webm.$EXT $AUDIO.$EXT) -done -date >> $AUDIO-$MODEL.log diff --git a/roles/caption/templates/process-captions.py b/roles/caption/templates/process-captions.py new file mode 100755 index 0000000..6ad890a --- /dev/null +++ b/roles/caption/templates/process-captions.py @@ -0,0 +1,231 @@ +#!/usr/bin/python3 +"""Use OpenAI Whisper to automatically generate captions for the video files in the specified directory.""" + +# {{ ansible_managed }} + +# The MIT License (MIT) +# Copyright © 2022 Sacha Chua <sacha@sachachua.com> + +# Permission is hereby granted, free of charge, to any person +# obtaining a copy of this software and associated documentation files +# (the “Software”), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: + +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS +# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from collections import defaultdict +import subprocess +import datetime +import sys +import webvtt +import xml.etree.ElementTree as ET +from lhotse import RecordingSet, Recording, AudioSource, SupervisionSegment, SupervisionSet, create_cut_set_eager, align_with_torchaudio, CutSet, annotate_with_whisper +from tqdm import tqdm +import whisper +import re +import os +import json +import torch + +THREADS = 12 +VIDEO_REGEXP = '\.(webm|mov)$' +AUDIO_REGEXP = '\.(ogg|opus)$' +ALWAYS = False +TRIM_AUDIO = False +MODEL = os.environ.get('MODEL', 'large') # Set to tiny for testing +JSON_FILE = '/data/emacsconf/2022/talks.json' + +def get_slug_from_filename(filename): + m = re.search('emacsconf-[0-9]+-([a-z]+)--', filename) + if m: + return m.group(1) + else: + return os.path.basename(os.path.dirname(filename)) + +def get_files_to_work_on(directory): + """Return the list of audio files to work on. + The specified directory is checked recursively. + Skip any videos that already have caption files. + + Convert any videos that don't already have audio files, and return the audio files instead. + When there are multiple videos and audio files for a talk, pick one. + """ + info = defaultdict(lambda: {}, {}) + directory = os.path.expanduser(directory) + for folder, subs, files in os.walk(directory): + for filename in files: + f = os.path.join(folder, filename) + slug = get_slug_from_filename(f) + info[slug]['slug'] = slug + if re.search(AUDIO_REGEXP, filename): + info[slug]['audio'] = f + elif re.search(VIDEO_REGEXP, filename): + info[slug]['video'] = f + elif re.search('vtt$', filename): + info[slug]['vtt'] = f + elif re.search('srv2$', filename): + info[slug]['srv2'] = f + needs_work = [] + if JSON_FILE: + with open(JSON_FILE) as f: + talks = json.load(f)['talks'] + for key, val in info.items(): + if not 'video' in val and not 'audio' in val: continue + if talks: + talk = next(filter(lambda talk: talk['slug'] == val['slug'], talks), None) + if talk: + val['base'] = os.path.join(os.path.dirname(val['video'] or val['audio']), + base_name(talk['video-slug'])) + else: + val['base'] = os.path.join(os.path.dirname(val['video'] or val['audio']), + base_name(val['video'] or val['audio'])) + if ALWAYS or (not 'vtt' in val or not 'srv2' in val): + if not 'audio' in val and 'video' in val: + # No audio, need to convert it + val = extract_audio(val) + needs_work.append(val) + return needs_work + +def extract_audio(work): + output = subprocess.check_output(['ffprobe', video_file], stderr=subprocess.STDOUT) + extension = 'opus' + if 'Audio: vorbis' in output.decode(): + extension = 'ogg' + new_file = os.path.join(os.path.dirname(video_file), base_name(video_file) + '.' + extension) + acodec = 'copy' if re.search('webm$', video_file) else 'libopus' + log("Extracting audio from %s acodec %s" % (video_file, acodec)) + output = subprocess.check_output(['ffmpeg', '-y', '-i', video_file, '-acodec', acodec, '-vn', new_file], stderr=subprocess.STDOUT) + work['audio'] = new_file + return work + +def to_sec(time_str): + "Convert a WebVTT time into seconds." + h, m, s, ms = re.split('[\\.:]', time_str) + return int(h) * 3600 + int(m) * 60 + int(s) + (int(ms) / 1000) + +def log(s): + print(datetime.datetime.now(), s) + +def clean_up_timestamps(result): + segs = list(result['segments']) + seg_len = len(segs) + for i, seg in enumerate(segs[:-1]): + seg['end'] = min(segs[i + 1]['start'] - 0.001, seg['end']) + result['segments'] = segs + return result + +def generate_captions(work): + """Generate a VTT file based on the audio file.""" + log("Generating captions") + new_file = work['base'] + '.vtt' + model = whisper.load_model(MODEL, device="cuda" if torch.cuda.is_available() else "cpu") + audio = whisper.load_audio(work['audio']) + if TRIM_AUDIO: + audio = whisper.pad_or_trim(audio) + result = model.transcribe(audio, verbose=True) + result = clean_up_timestamps(result) + with open(new_file, 'w') as vtt: + whisper.utils.write_vtt(result['segments'], file=vtt) + work['vtt'] = new_file + if 'srv2' in work: del work['srv2'] + return work + +def generate_srv2(work): + """Generate a SRV2 file.""" + log("Generating SRV2") + recs = RecordingSet.from_recordings([Recording.from_file(work['audio'])]) + rec_id = recs[0].id + captions = [] + for i, caption in enumerate(webvtt.read(work['vtt'])): + if TRIM_AUDIO and i > 2: break + captions.append(SupervisionSegment(id=rec_id + '-sup' + '%05d' % i, channel=recs[0].channel_ids[0], recording_id=rec_id, start=to_sec(caption.start), duration=to_sec(caption.end) - to_sec(caption.start), text=caption.text, language='English')) + sups = SupervisionSet.from_segments(captions) + main = CutSet.from_manifests(recordings=recs, supervisions=sups) + work['cuts'] = main.trim_to_supervisions(keep_overlapping=False,keep_all_channels=True) + cuts_aligned = align_with_torchaudio(work['cuts']) + root = ET.Element("timedtext") + doc = ET.SubElement(root, "window") + for line, aligned in enumerate(cuts_aligned): + # Numbers are weird + words = re.split(' ', captions[line].text) + tokenized_words = [re.sub('[^\'A-Z0-9]', '', w.upper()) for w in words] + if len(aligned.supervisions) == 0: + print(captions[line], aligned) + continue + aligned_words = list(aligned.supervisions[0].alignment['word']) + aligned_index = 0 + aligned_len = len(aligned_words) + word_index = 0 + word_len = len(words) + while word_index < word_len and aligned_index < aligned_len: + # log("Choosing %s %s" % (words[word_index], aligned_words[aligned_index].symbol)) + ET.SubElement(doc, 'text', + t=str(float(aligned_words[aligned_index].start)*1000), + d=str(float(aligned_words[aligned_index].duration)*1000), + w="1", + append="1").text = words[word_index] + if tokenized_words[word_index] != aligned_words[aligned_index].symbol and word_index < word_len - 1: + # Scan ahead for a word that maches the next word, but don't go too far + cur_aligned = aligned_index + while aligned_index < aligned_len and aligned_index < cur_aligned + 5 and aligned_words[aligned_index].symbol != tokenized_words[word_index + 1]: + log("Sliding to match %s %d %s" % (tokenized_words[word_index + 1], aligned_index, aligned_words[aligned_index].symbol)) + aligned_index = aligned_index + 1 + if not aligned_words[aligned_index].symbol == tokenized_words[word_index + 1]: + log("Resetting, couldn't find") + aligned_index = cur_aligned + 1 + else: + aligned_index = aligned_index + 1 + word_index = word_index + 1 + tree = ET.ElementTree(root) + work['srv2'] = work['base'] + '.srv2' + with open(work['srv2'], "w") as f: + tree.write(f.buffer) + return work + +def base_name(s): + """ + Return the base name of file so that we can add extensions to it. + Remove tokens like --normalized, --recoded, etc. + Make sure the filename has either --main or --questions. + """ + s = os.path.basename(s) + type = 'questions' if '--questions.' in s else 'main' + if TRIM_AUDIO: + type = 'test' + match = re.match('^(emacsconf-[0-9]+-[a-z]+--.*?--.*?)(--|\.)', s) + if (match): + return match.group(1) + '--' + type + else: + return os.path.splitext(s)[0] + '--' + type +# assert(base_name('/home/sachac/current/sqlite/emacsconf-2022-sqlite--using-sqlite-as-a-data-source-a-framework-and-an-example--andrew-hyatt--normalized.webm.vtt') == 'emacsconf-2022-sqlite--using-sqlite-as-a-data-source-a-framework-and-an-example--andrew-hyatt--main') + +log(f"MODEL {MODEL} ALWAYS {ALWAYS} TRIM_AUDIO {TRIM_AUDIO}") +directory = sys.argv[1] if len(sys.argv) > 1 else "~/current" +needs_work = get_files_to_work_on(directory) +if THREADS > 0: + torch.set_num_threads(THREADS) +for work in needs_work: + log("Started processing %s" % work['base']) + if work['audio']: + if ALWAYS or not 'vtt' in work: + work = generate_captions(work) + if ALWAYS or not 'srv2' in work: + work = generate_srv2(work) +# print("Aligning words", audio_file, datetime.datetime.now()) +# word_cuts = align_words(cuts) +# convert_cuts_to_word_timing(audio_file, word_cuts) + log("Done %s" % str(work['base'])) + |