summaryrefslogtreecommitdiffstats
path: root/roles/caption/templates/process-captions.py
diff options
context:
space:
mode:
Diffstat (limited to 'roles/caption/templates/process-captions.py')
-rwxr-xr-xroles/caption/templates/process-captions.py84
1 files changed, 33 insertions, 51 deletions
diff --git a/roles/caption/templates/process-captions.py b/roles/caption/templates/process-captions.py
index 72e9ad2..66f39dd 100755
--- a/roles/caption/templates/process-captions.py
+++ b/roles/caption/templates/process-captions.py
@@ -32,16 +32,12 @@ import datetime
import sys
import webvtt
import xml.etree.ElementTree as ET
-from lhotse import RecordingSet, Recording, AudioSource, SupervisionSegment, SupervisionSet, create_cut_set_eager, align_with_torchaudio, CutSet, annotate_with_whisper
-from tqdm import tqdm
-import whisper
-import re
-import os
+from lhotse import RecordinRecording, AudioSource, SupervisionSegment, SupervisionSet, create_cut_set_e
import json
import torch
-THREADS = 12
-VIDEO_REGEXP = '\.(webm|mov)$'
+THREADS = {{ cpus }}
+VIDEO_REGEXP = '\.(webm|mov|mp4)$'
AUDIO_REGEXP = '\.(ogg|opus)$'
ALWAYS = False
TRIM_AUDIO = False
@@ -49,6 +45,8 @@ MODEL = os.environ.get('MODEL', 'large') # Set to tiny for testing
WORK_DIR = "{{ emacsconf_caption_dir }}"
JSON_FILE = os.path.join(WORK_DIR, 'talks.json')
+# ----------------------------------------------------------------
+
def get_slug_from_filename(filename):
m = re.search('emacsconf-[0-9]+-([a-z]+)--', filename)
if m:
@@ -155,41 +153,23 @@ def generate_srv2(work):
captions.append(SupervisionSegment(id=rec_id + '-sup' + '%05d' % i, channel=recs[0].channel_ids[0], recording_id=rec_id, start=to_sec(caption.start), duration=to_sec(caption.end) - to_sec(caption.start), text=caption.text, language='English'))
sups = SupervisionSet.from_segments(captions)
main = CutSet.from_manifests(recordings=recs, supervisions=sups)
- work['cuts'] = main.trim_to_supervisions(keep_overlapping=False,keep_all_channels=True)
+ work['cuts'] = main.trim_to_supervisions(keep_all_channels=True)
cuts_aligned = align_with_torchaudio(work['cuts'])
root = ET.Element("timedtext")
doc = ET.SubElement(root, "window")
for line, aligned in enumerate(cuts_aligned):
- # Numbers are weird
- words = re.split(' ', captions[line].text)
- tokenized_words = [re.sub('[^\'A-Z0-9]', '', w.upper()) for w in words]
- if len(aligned.supervisions) == 0:
- print(captions[line], aligned)
- continue
- aligned_words = list(aligned.supervisions[0].alignment['word'])
- aligned_index = 0
- aligned_len = len(aligned_words)
- word_index = 0
- word_len = len(words)
- while word_index < word_len and aligned_index < aligned_len:
- # log("Choosing %s %s" % (words[word_index], aligned_words[aligned_index].symbol))
- ET.SubElement(doc, 'text',
- t=str(float(aligned_words[aligned_index].start)*1000),
- d=str(float(aligned_words[aligned_index].duration)*1000),
- w="1",
- append="1").text = words[word_index]
- if tokenized_words[word_index] != aligned_words[aligned_index].symbol and word_index < word_len - 1:
- # Scan ahead for a word that maches the next word, but don't go too far
- cur_aligned = aligned_index
- while aligned_index < aligned_len and aligned_index < cur_aligned + 5 and aligned_words[aligned_index].symbol != tokenized_words[word_index + 1]:
- log("Sliding to match %s %d %s" % (tokenized_words[word_index + 1], aligned_index, aligned_words[aligned_index].symbol))
- aligned_index = aligned_index + 1
- if not aligned_words[aligned_index].symbol == tokenized_words[word_index + 1]:
- log("Resetting, couldn't find")
- aligned_index = cur_aligned + 1
- else:
- aligned_index = aligned_index + 1
- word_index = word_index + 1
+ if len(aligned.supervisions) > 0:
+ aligned_words = aligned.supervisions[0].alignment['word']
+ for w, word in enumerate(aligned_words):
+ el = ET.SubElement(doc, 'text',
+ t=str(float(word.start)*1000),
+ d=str(float(word.duration)*1000),
+ w="1",
+ append="1")
+ el.text = word.symbol
+ el.tail = "\n"
+ else:
+ print("No supervisions", aligned)
tree = ET.ElementTree(root)
work['srv2'] = work['base'] + '.srv2'
with open(work['srv2'], "w") as f:
@@ -218,18 +198,20 @@ directory = sys.argv[1] if len(sys.argv) > 1 else WORK_DIR
needs_work = get_files_to_work_on(directory)
if len(needs_work) > 0:
- if THREADS > 0:
- torch.set_num_threads(THREADS)
- for work in needs_work:
- log("Started processing %s" % work['base'])
- if work['audio']:
- if ALWAYS or not 'vtt' in work:
- work = generate_captions(work)
- if ALWAYS or not 'srv2' in work:
- work = generate_srv2(work)
- # print("Aligning words", audio_file, datetime.datetime.now())
- # word_cuts = align_words(cuts)
- # convert_cuts_to_word_timing(audio_file, word_cuts)
- log("Done %s" % str(work['base']))
+ while len(needs_work) > 0:
+ if THREADS > 0:
+ torch.set_num_threads(THREADS)
+ for work in needs_work:
+ log("Started processing %s" % work['base'])
+ if work['audio']:
+ if ALWAYS or not 'vtt' in work:
+ work = generate_captions(work)
+ if ALWAYS or not 'srv2' in work:
+ work = generate_srv2(work)
+ # print("Aligning words", audio_file, datetime.datetime.now())
+ # word_cuts = align_words(cuts)
+ # convert_cuts_to_word_timing(audio_file, word_cuts)
+ log("Done %s" % str(work['base']))
+ needs_work = get_files_to_work_on(directory)
else:
log("No work needed.")