summaryrefslogtreecommitdiffstats
path: root/roles/caption/templates/process-captions.py
diff options
context:
space:
mode:
authorSacha Chua <sacha@sachachua.com>2022-10-25 09:07:36 -0400
committerSacha Chua <sacha@sachachua.com>2022-10-25 09:07:36 -0400
commit66b3f5f472ac5fddf0c0e43181f7718af9075d83 (patch)
tree20b8a35aa403c8b6f453d061f09185351b9addd0 /roles/caption/templates/process-captions.py
parent8637995c0f20672553c192907a36d3c8519b61d4 (diff)
downloademacsconf-ansible-66b3f5f472ac5fddf0c0e43181f7718af9075d83.tar.xz
emacsconf-ansible-66b3f5f472ac5fddf0c0e43181f7718af9075d83.zip
process-captions
Diffstat (limited to 'roles/caption/templates/process-captions.py')
-rwxr-xr-xroles/caption/templates/process-captions.py231
1 files changed, 231 insertions, 0 deletions
diff --git a/roles/caption/templates/process-captions.py b/roles/caption/templates/process-captions.py
new file mode 100755
index 0000000..6ad890a
--- /dev/null
+++ b/roles/caption/templates/process-captions.py
@@ -0,0 +1,231 @@
+#!/usr/bin/python3
+"""Use OpenAI Whisper to automatically generate captions for the video files in the specified directory."""
+
+# {{ ansible_managed }}
+
+# The MIT License (MIT)
+# Copyright © 2022 Sacha Chua <sacha@sachachua.com>
+
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation files
+# (the “Software”), to deal in the Software without restriction,
+# including without limitation the rights to use, copy, modify, merge,
+# publish, distribute, sublicense, and/or sell copies of the Software,
+# and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from collections import defaultdict
+import subprocess
+import datetime
+import sys
+import webvtt
+import xml.etree.ElementTree as ET
+from lhotse import RecordingSet, Recording, AudioSource, SupervisionSegment, SupervisionSet, create_cut_set_eager, align_with_torchaudio, CutSet, annotate_with_whisper
+from tqdm import tqdm
+import whisper
+import re
+import os
+import json
+import torch
+
+THREADS = 12
+VIDEO_REGEXP = '\.(webm|mov)$'
+AUDIO_REGEXP = '\.(ogg|opus)$'
+ALWAYS = False
+TRIM_AUDIO = False
+MODEL = os.environ.get('MODEL', 'large') # Set to tiny for testing
+JSON_FILE = '/data/emacsconf/2022/talks.json'
+
+def get_slug_from_filename(filename):
+ m = re.search('emacsconf-[0-9]+-([a-z]+)--', filename)
+ if m:
+ return m.group(1)
+ else:
+ return os.path.basename(os.path.dirname(filename))
+
+def get_files_to_work_on(directory):
+ """Return the list of audio files to work on.
+ The specified directory is checked recursively.
+ Skip any videos that already have caption files.
+
+ Convert any videos that don't already have audio files, and return the audio files instead.
+ When there are multiple videos and audio files for a talk, pick one.
+ """
+ info = defaultdict(lambda: {}, {})
+ directory = os.path.expanduser(directory)
+ for folder, subs, files in os.walk(directory):
+ for filename in files:
+ f = os.path.join(folder, filename)
+ slug = get_slug_from_filename(f)
+ info[slug]['slug'] = slug
+ if re.search(AUDIO_REGEXP, filename):
+ info[slug]['audio'] = f
+ elif re.search(VIDEO_REGEXP, filename):
+ info[slug]['video'] = f
+ elif re.search('vtt$', filename):
+ info[slug]['vtt'] = f
+ elif re.search('srv2$', filename):
+ info[slug]['srv2'] = f
+ needs_work = []
+ if JSON_FILE:
+ with open(JSON_FILE) as f:
+ talks = json.load(f)['talks']
+ for key, val in info.items():
+ if not 'video' in val and not 'audio' in val: continue
+ if talks:
+ talk = next(filter(lambda talk: talk['slug'] == val['slug'], talks), None)
+ if talk:
+ val['base'] = os.path.join(os.path.dirname(val['video'] or val['audio']),
+ base_name(talk['video-slug']))
+ else:
+ val['base'] = os.path.join(os.path.dirname(val['video'] or val['audio']),
+ base_name(val['video'] or val['audio']))
+ if ALWAYS or (not 'vtt' in val or not 'srv2' in val):
+ if not 'audio' in val and 'video' in val:
+ # No audio, need to convert it
+ val = extract_audio(val)
+ needs_work.append(val)
+ return needs_work
+
+def extract_audio(work):
+ output = subprocess.check_output(['ffprobe', video_file], stderr=subprocess.STDOUT)
+ extension = 'opus'
+ if 'Audio: vorbis' in output.decode():
+ extension = 'ogg'
+ new_file = os.path.join(os.path.dirname(video_file), base_name(video_file) + '.' + extension)
+ acodec = 'copy' if re.search('webm$', video_file) else 'libopus'
+ log("Extracting audio from %s acodec %s" % (video_file, acodec))
+ output = subprocess.check_output(['ffmpeg', '-y', '-i', video_file, '-acodec', acodec, '-vn', new_file], stderr=subprocess.STDOUT)
+ work['audio'] = new_file
+ return work
+
+def to_sec(time_str):
+ "Convert a WebVTT time into seconds."
+ h, m, s, ms = re.split('[\\.:]', time_str)
+ return int(h) * 3600 + int(m) * 60 + int(s) + (int(ms) / 1000)
+
+def log(s):
+ print(datetime.datetime.now(), s)
+
+def clean_up_timestamps(result):
+ segs = list(result['segments'])
+ seg_len = len(segs)
+ for i, seg in enumerate(segs[:-1]):
+ seg['end'] = min(segs[i + 1]['start'] - 0.001, seg['end'])
+ result['segments'] = segs
+ return result
+
+def generate_captions(work):
+ """Generate a VTT file based on the audio file."""
+ log("Generating captions")
+ new_file = work['base'] + '.vtt'
+ model = whisper.load_model(MODEL, device="cuda" if torch.cuda.is_available() else "cpu")
+ audio = whisper.load_audio(work['audio'])
+ if TRIM_AUDIO:
+ audio = whisper.pad_or_trim(audio)
+ result = model.transcribe(audio, verbose=True)
+ result = clean_up_timestamps(result)
+ with open(new_file, 'w') as vtt:
+ whisper.utils.write_vtt(result['segments'], file=vtt)
+ work['vtt'] = new_file
+ if 'srv2' in work: del work['srv2']
+ return work
+
+def generate_srv2(work):
+ """Generate a SRV2 file."""
+ log("Generating SRV2")
+ recs = RecordingSet.from_recordings([Recording.from_file(work['audio'])])
+ rec_id = recs[0].id
+ captions = []
+ for i, caption in enumerate(webvtt.read(work['vtt'])):
+ if TRIM_AUDIO and i > 2: break
+ captions.append(SupervisionSegment(id=rec_id + '-sup' + '%05d' % i, channel=recs[0].channel_ids[0], recording_id=rec_id, start=to_sec(caption.start), duration=to_sec(caption.end) - to_sec(caption.start), text=caption.text, language='English'))
+ sups = SupervisionSet.from_segments(captions)
+ main = CutSet.from_manifests(recordings=recs, supervisions=sups)
+ work['cuts'] = main.trim_to_supervisions(keep_overlapping=False,keep_all_channels=True)
+ cuts_aligned = align_with_torchaudio(work['cuts'])
+ root = ET.Element("timedtext")
+ doc = ET.SubElement(root, "window")
+ for line, aligned in enumerate(cuts_aligned):
+ # Numbers are weird
+ words = re.split(' ', captions[line].text)
+ tokenized_words = [re.sub('[^\'A-Z0-9]', '', w.upper()) for w in words]
+ if len(aligned.supervisions) == 0:
+ print(captions[line], aligned)
+ continue
+ aligned_words = list(aligned.supervisions[0].alignment['word'])
+ aligned_index = 0
+ aligned_len = len(aligned_words)
+ word_index = 0
+ word_len = len(words)
+ while word_index < word_len and aligned_index < aligned_len:
+ # log("Choosing %s %s" % (words[word_index], aligned_words[aligned_index].symbol))
+ ET.SubElement(doc, 'text',
+ t=str(float(aligned_words[aligned_index].start)*1000),
+ d=str(float(aligned_words[aligned_index].duration)*1000),
+ w="1",
+ append="1").text = words[word_index]
+ if tokenized_words[word_index] != aligned_words[aligned_index].symbol and word_index < word_len - 1:
+ # Scan ahead for a word that maches the next word, but don't go too far
+ cur_aligned = aligned_index
+ while aligned_index < aligned_len and aligned_index < cur_aligned + 5 and aligned_words[aligned_index].symbol != tokenized_words[word_index + 1]:
+ log("Sliding to match %s %d %s" % (tokenized_words[word_index + 1], aligned_index, aligned_words[aligned_index].symbol))
+ aligned_index = aligned_index + 1
+ if not aligned_words[aligned_index].symbol == tokenized_words[word_index + 1]:
+ log("Resetting, couldn't find")
+ aligned_index = cur_aligned + 1
+ else:
+ aligned_index = aligned_index + 1
+ word_index = word_index + 1
+ tree = ET.ElementTree(root)
+ work['srv2'] = work['base'] + '.srv2'
+ with open(work['srv2'], "w") as f:
+ tree.write(f.buffer)
+ return work
+
+def base_name(s):
+ """
+ Return the base name of file so that we can add extensions to it.
+ Remove tokens like --normalized, --recoded, etc.
+ Make sure the filename has either --main or --questions.
+ """
+ s = os.path.basename(s)
+ type = 'questions' if '--questions.' in s else 'main'
+ if TRIM_AUDIO:
+ type = 'test'
+ match = re.match('^(emacsconf-[0-9]+-[a-z]+--.*?--.*?)(--|\.)', s)
+ if (match):
+ return match.group(1) + '--' + type
+ else:
+ return os.path.splitext(s)[0] + '--' + type
+# assert(base_name('/home/sachac/current/sqlite/emacsconf-2022-sqlite--using-sqlite-as-a-data-source-a-framework-and-an-example--andrew-hyatt--normalized.webm.vtt') == 'emacsconf-2022-sqlite--using-sqlite-as-a-data-source-a-framework-and-an-example--andrew-hyatt--main')
+
+log(f"MODEL {MODEL} ALWAYS {ALWAYS} TRIM_AUDIO {TRIM_AUDIO}")
+directory = sys.argv[1] if len(sys.argv) > 1 else "~/current"
+needs_work = get_files_to_work_on(directory)
+if THREADS > 0:
+ torch.set_num_threads(THREADS)
+for work in needs_work:
+ log("Started processing %s" % work['base'])
+ if work['audio']:
+ if ALWAYS or not 'vtt' in work:
+ work = generate_captions(work)
+ if ALWAYS or not 'srv2' in work:
+ work = generate_srv2(work)
+# print("Aligning words", audio_file, datetime.datetime.now())
+# word_cuts = align_words(cuts)
+# convert_cuts_to_word_timing(audio_file, word_cuts)
+ log("Done %s" % str(work['base']))
+