diff --git a/egs/aspire/s5/cmd.sh b/egs/aspire/s5/cmd.sh index 3ca5230bd5e..afc99eee4a4 100755 --- a/egs/aspire/s5/cmd.sh +++ b/egs/aspire/s5/cmd.sh @@ -6,11 +6,11 @@ # the number of cpus on your machine. #a) JHU cluster options -export train_cmd="queue.pl -l arch=*64" -export decode_cmd="queue.pl -l arch=*64,mem_free=2G,ram_free=2G" -export mkgraph_cmd="queue.pl -l arch=*64,ram_free=4G,mem_free=4G" +export train_cmd="queue.pl" +export decode_cmd="queue.pl --mem 2G" +export mkgraph_cmd="queue.pl --mem 4G" -export cuda_cmd="queue.pl -l gpu=1 -q g.q" +export cuda_cmd="queue.pl --gpu 1" #b) BUT cluster options diff --git a/egs/aspire/s5/conf/fbank.conf b/egs/aspire/s5/conf/fbank.conf new file mode 100644 index 00000000000..563960c088b --- /dev/null +++ b/egs/aspire/s5/conf/fbank.conf @@ -0,0 +1,6 @@ +# config for high-resolution Fbank features +--use-energy=false # do not add energy +--sample-frequency=8000 # Switchboard is sampled at 8kHz +--num-mel-bins=40 # similar to Google's setup. +--low-freq=40 # low cutoff frequency for mel bins +--high-freq=-200 # high cutoff frequently, relative to Nyquist of 4000 (=3800) diff --git a/egs/aspire/s5/conf/fbank_bp.conf b/egs/aspire/s5/conf/fbank_bp.conf new file mode 100644 index 00000000000..55522d9d687 --- /dev/null +++ b/egs/aspire/s5/conf/fbank_bp.conf @@ -0,0 +1,8 @@ +# config for high-resolution Fbank features +--use-energy=false # do not add energy +--sample-frequency=8000 # Switchboard is sampled at 8kHz +--num-mel-bins=152 # similar to Google's setup. 
+--num-fft-bins=512 +--low-freq=330 # low cutoff frequency for mel bins +--high-freq=-1000 # high cutoff frequently, relative to Nyquist of 4000 (=3000) + diff --git a/egs/aspire/s5/conf/mfcc_diarization.conf b/egs/aspire/s5/conf/mfcc_diarization.conf new file mode 100644 index 00000000000..cace02fbaff --- /dev/null +++ b/egs/aspire/s5/conf/mfcc_diarization.conf @@ -0,0 +1,6 @@ +--sample-frequency=8000 +--frame-length=25 # the default is 25, but we usually use 20 for SID +--low-freq=20 # the default. +--high-freq=3700 # the default is zero meaning use the Nyquist (4k in this case). +--num-ceps=20 # higher than the default which is 12. +--snip-edges=false diff --git a/egs/aspire/s5/conf/mfcc_hires_bp.conf b/egs/aspire/s5/conf/mfcc_hires_bp.conf new file mode 100644 index 00000000000..0f06a6fb571 --- /dev/null +++ b/egs/aspire/s5/conf/mfcc_hires_bp.conf @@ -0,0 +1,12 @@ +# config for high-resolution MFCC features, intended for neural network training. +# Note: we keep all cepstra, so it has the same info as filterbank features, +# but MFCC is more easily compressible (because less correlated) which is why +# we prefer this method. +--use-energy=false # use average of log energy, not energy. +--sample-frequency=8000 # Switchboard is sampled at 8kHz +--num-mel-bins=152 # similar to Google's setup. +--num-ceps=152 # there is no dimensionality reduction. +--num-fft-bins=512 +--low-freq=330 # low cutoff frequency for mel bins +--high-freq=-1000 # high cutoff frequently, relative to Nyquist of 4000 (=3000) + diff --git a/egs/aspire/s5/conf/mfcc_vad.conf b/egs/aspire/s5/conf/mfcc_vad.conf new file mode 100644 index 00000000000..a5c9243eee0 --- /dev/null +++ b/egs/aspire/s5/conf/mfcc_vad.conf @@ -0,0 +1,5 @@ +--sample-frequency=8000 +--frame-length=25 # the default is 25. +--low-freq=20 # the default. +--high-freq=-300 # the default is zero meaning use the Nyquist (4k in this case). +--num-ceps=13 # higher than the default which is 12. 
diff --git a/egs/aspire/s5/conf/segmentation.conf b/egs/aspire/s5/conf/segmentation.conf new file mode 100644 index 00000000000..d9153f41635 --- /dev/null +++ b/egs/aspire/s5/conf/segmentation.conf @@ -0,0 +1,24 @@ +method=Viterbi + +# General segmentation options +max_intersegment_length=50 # Merge nearby speech segments if the silence + # between them is less than this many frames. +max_relabel_length=10 # maximum duration of speech that will be removed as part + # of smoothing process. This is only if there are no other + # speech segments nearby. +pad_length=5 # Pad speech segments by this many frames on either side +max_segment_length=1000 # Segments that are longer than this are split into + # overlapping frames. +overlap_length=100 # Overlapping frames when segments are split. + # See the above option. + +# Viterbi options +min_silence_duration=30 # minimum number of frames for silence +min_speech_duration=30 # minimum number of frames for speech +speech_to_sil_ratio=1 # the prior on speech vs silence + +# Decoding options +acwt=1 +beam=10 +max_active=7000 + diff --git a/egs/aspire/s5/conf/segmentation_aspire.conf b/egs/aspire/s5/conf/segmentation_aspire.conf new file mode 100644 index 00000000000..c703560bff4 --- /dev/null +++ b/egs/aspire/s5/conf/segmentation_aspire.conf @@ -0,0 +1,25 @@ +method=Viterbi + +# General segmentation options +max_intersegment_length=50 # Merge nearby speech segments if the silence + # between them is less than this many frames. +max_relabel_length=10 # maximum duration of speech that will be removed as part + # of smoothing process. This is only if there are no other + # speech segments nearby. +pad_length=5 # Pad speech segments by this many frames on either side +max_segment_length=1000 # Segments that are longer than this are split into + # overlapping frames. +overlap_length=100 # Overlapping frames when segments are split. + # See the above option. 
+ +# Viterbi options +min_silence_duration=30 # minimum number of frames for silence +min_speech_duration=30 # minimum number of frames for speech +speech_to_sil_ratio=1 # the prior on speech vs silence + +# Decoding options +acwt=1 +beam=10 +max_active=7000 + + diff --git a/egs/aspire/s5/conf/segmentation_babel.conf b/egs/aspire/s5/conf/segmentation_babel.conf new file mode 100644 index 00000000000..36395bb27dd --- /dev/null +++ b/egs/aspire/s5/conf/segmentation_babel.conf @@ -0,0 +1,26 @@ +method=Viterbi + +# General segmentation options +max_intersegment_length=100 # Merge nearby speech segments if the silence + # between them is less than this many frames. +max_relabel_length=10 # maximum duration of speech that will be removed as part + # of smoothing process. This is only if there are no other + # speech segments nearby. +pad_length=10 # Pad speech segments by this many frames on either side +post_pad_length=10 # Pad speech segments by this many frames on either side +max_segment_length=1000 # Segments that are longer than this are split into + # overlapping frames. +overlap_length=100 # Overlapping frames when segments are split. + # See the above option. 
+ +# Viterbi options +min_silence_duration=30 # minimum number of frames for silence +min_speech_duration=30 # minimum number of frames for speech +speech_to_sil_ratio=1 # the prior on speech vs silence + +# Decoding options +acwt=1 +beam=10 +max_active=7000 + + diff --git a/egs/aspire/s5/conf/vad_icsi_babel.conf b/egs/aspire/s5/conf/vad_icsi_babel.conf new file mode 100644 index 00000000000..70f651403f5 --- /dev/null +++ b/egs/aspire/s5/conf/vad_icsi_babel.conf @@ -0,0 +1,39 @@ +## Features paramters +window_size=10 # 100 ms +frames_per_gaussian=200 + +## Phase 1 parameters +num_frames_init_silence=2000 # 20s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_sound=10000 # 100s - Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=2000 # 20s - Highest zero crossing frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=6 +sil_gauss_incr=0 +sound_gauss_incr=2 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + + diff --git a/egs/aspire/s5/conf/vad_icsi_babel_3models.conf b/egs/aspire/s5/conf/vad_icsi_babel_3models.conf new file mode 100644 index 00000000000..1196f0d2aff --- /dev/null +++ b/egs/aspire/s5/conf/vad_icsi_babel_3models.conf @@ -0,0 +1,54 @@ +## Features paramters +window_size=10 # 100 ms +frames_per_gaussian=200 + +## Phase 1 parameters +num_frames_init_silence=2000 # 20s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_sound=10000 # 100s - 
Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=2000 # 20s - Highest zero crossing frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=6 +sil_gauss_incr=0 +sound_gauss_incr=2 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 + +## Phase 3 parameters +num_frames_silence_phase3_init=2000 +num_frames_speech_phase3_init=2000 +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +sil_max_gauss_phase4=8 +speech_max_gauss_phase4=16 +sil_gauss_incr_phase3=1 +sil_gauss_incr_phase4=1 +speech_gauss_incr_phase4=2 +num_iters_phase3=5 +num_iters_phase4=5 + +## Phase 5 parameters +sil_num_gauss_init_phase5=2 +speech_num_gauss_init_phase5=2 +sil_max_gauss_phase5=5 +speech_max_gauss_phase5=12 +sil_gauss_incr_phase5=1 +speech_gauss_incr_phase5=2 +num_iters_phase5=7 + + + diff --git a/egs/aspire/s5/conf/vad_icsi_rt.conf b/egs/aspire/s5/conf/vad_icsi_rt.conf new file mode 100644 index 00000000000..c2964d5171d --- /dev/null +++ b/egs/aspire/s5/conf/vad_icsi_rt.conf @@ -0,0 +1,41 @@ +## Features paramters +window_size=10 # 1s +frames_per_gaussian=2000 + +## Phase 1 parameters +num_frames_init_silence=2000 +num_frames_init_sound=10000 +num_frames_init_sound_next=2000 +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=6 +sil_gauss_incr=0 +sound_gauss_incr=2 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +num_frames_init_speech=10000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 
+speech_gauss_incr_phase2=2 +num_iters_phase2=5 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + + + diff --git a/egs/aspire/s5/conf/weights_segmentation_aspire.conf b/egs/aspire/s5/conf/weights_segmentation_aspire.conf new file mode 100644 index 00000000000..122e061f5f8 --- /dev/null +++ b/egs/aspire/s5/conf/weights_segmentation_aspire.conf @@ -0,0 +1,26 @@ +method=Viterbi + +# General segmentation options +max_intersegment_length=0 # Merge nearby speech segments if the silence + # between them is less than this many frames. +max_relabel_length=0 # maximum duration of speech that will be removed as part + # of smoothing process. This is only if there are no other + # speech segments nearby. +pad_length=0 # Pad speech segments by this many frames on either side +max_segment_length=2000 # Segments that are longer than this are split into + # overlapping frames. +overlap_length=0 # Overlapping frames when segments are split. + # See the above option. + +# Viterbi options +min_silence_duration=30 # minimum number of frames for silence +min_speech_duration=30 # minimum number of frames for speech +speech_to_sil_ratio=0.1 # the prior on speech vs silence + +# Decoding options +acwt=1 +beam=10 +max_active=7000 + + + diff --git a/egs/aspire/s5/conf/weights_segmentation_babel.conf b/egs/aspire/s5/conf/weights_segmentation_babel.conf new file mode 100644 index 00000000000..15901b3c23d --- /dev/null +++ b/egs/aspire/s5/conf/weights_segmentation_babel.conf @@ -0,0 +1,25 @@ +method=Viterbi + +# General segmentation options +max_intersegment_length=0 # Merge nearby speech segments if the silence + # between them is less than this many frames. +max_relabel_length=0 # maximum duration of speech that will be removed as part + # of smoothing process. This is only if there are no other + # speech segments nearby. 
+pad_length=0 # Pad speech segments by this many frames on either side +max_segment_length=2000 # Segments that are longer than this are split into + # overlapping frames. +overlap_length=0 # Overlapping frames when segments are split. + # See the above option. + +# Viterbi options +min_silence_duration=30 # minimum number of frames for silence +min_speech_duration=30 # minimum number of frames for speech +speech_to_sil_ratio=0.1 # the prior on speech vs silence + +# Decoding options +acwt=1 +beam=10 +max_active=7000 + + diff --git a/egs/aspire/s5/conf/zc_vad.conf b/egs/aspire/s5/conf/zc_vad.conf new file mode 100644 index 00000000000..1475967e7b1 --- /dev/null +++ b/egs/aspire/s5/conf/zc_vad.conf @@ -0,0 +1,4 @@ +--sample-frequency=8000 +--frame-length=25 # the default is 25. +--dither=0.0 +--zero-crossing-threshold=1e-5 diff --git a/egs/aspire/s5/diarization b/egs/aspire/s5/diarization new file mode 120000 index 00000000000..ba78a9126af --- /dev/null +++ b/egs/aspire/s5/diarization @@ -0,0 +1 @@ +../../sre08/v1/diarization \ No newline at end of file diff --git a/egs/aspire/s5/local/multi_condition/combine_ali_dirs.sh b/egs/aspire/s5/local/multi_condition/combine_ali_dirs.sh index 1bb276e7dae..09e479c8258 100755 --- a/egs/aspire/s5/local/multi_condition/combine_ali_dirs.sh +++ b/egs/aspire/s5/local/multi_condition/combine_ali_dirs.sh @@ -6,6 +6,7 @@ # Begin configuration section. extra_files= # specify addtional files in 'src-data-dir' to merge, ex. "file1 file2 ..." ref_data_dir= # data directory to be used as reference for rearranging alignments +cmd=run.pl # End configuration section. echo "$0 $@" # Print the command line for logging @@ -78,7 +79,7 @@ if [ ! 
-z "$ref_data_dir" ]; then awk -v p=\$ali_file '{printf "%s %s %s\n", \$1, p, NR}' > $temp_dir/ali_utt_index.\$JOB EOF chmod +x $temp_dir/create_ali_utt_index.sh - $decode_cmd -v PATH JOB=1:$num_jobs $temp_dir/ali_copy_int.JOB.log $temp_dir/create_ali_utt_index.sh JOB + $cmd -v PATH JOB=1:$num_jobs $temp_dir/ali_copy_int.JOB.log $temp_dir/create_ali_utt_index.sh JOB cat <$temp_dir/create_new_ali.py @@ -147,7 +148,7 @@ EOF # split the ref_data_dir to get reference utt2spk for individual ali.JOB.gz files utils/split_data.sh $ref_data_dir $num_jobs - $decode_cmd -v PATH JOB=1:$num_jobs $temp_dir/create_new_ali.JOB.run.log \ + $cmd JOB=1:$num_jobs $temp_dir/create_new_ali.JOB.run.log \ python $temp_dir/create_new_ali.py \ $ref_data_dir/split$num_jobs/JOB/utt2spk \ $temp_dir/create_new_ali.JOB.sh $temp_dir/ali.JOB.gz || exit 1; diff --git a/egs/aspire/s5/local/multi_condition/copy_ali_dir.sh b/egs/aspire/s5/local/multi_condition/copy_ali_dir.sh index b0d475a7710..3b88965b717 100755 --- a/egs/aspire/s5/local/multi_condition/copy_ali_dir.sh +++ b/egs/aspire/s5/local/multi_condition/copy_ali_dir.sh @@ -18,6 +18,7 @@ # begin configuration section utt_prefix= utt_suffix= +cmd=run.pl # end configuration section . utils/parse_options.sh @@ -72,6 +73,6 @@ for line in sys.stdin: set +o pipefail; # unset the pipefail option. 
EOF chmod +x $dest_dir/temp/copy_ali.sh -$decode_cmd -v PATH JOB=1:$nj $dest_dir/temp/copy_ali.JOB.log $dest_dir/temp/copy_ali.sh JOB || exit 1; +$cmd -v PATH JOB=1:$nj $dest_dir/temp/copy_ali.JOB.log $dest_dir/temp/copy_ali.sh JOB || exit 1; echo "$0: copied alignments from $src_dir to $dest_dir" diff --git a/egs/aspire/s5/local/multi_condition/create_uniform_segments.py b/egs/aspire/s5/local/multi_condition/create_uniform_segments.py index 68280500f6b..494848837cd 100755 --- a/egs/aspire/s5/local/multi_condition/create_uniform_segments.py +++ b/egs/aspire/s5/local/multi_condition/create_uniform_segments.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # Copyright 2014 Johns Hopkins University (Authors: Daniel Povey, Vijayaditya Peddinti). Apache 2.0. -# creates a segments file in the provided data directory +# creates a segments file in the provided data directory # into uniform segments with specified window and overlap import imp, sys, argparse, os, math, subprocess @@ -16,7 +16,7 @@ def segment(total_length, window_length, overlap = 0): segments[-2] = (segments[-2][0], segments[-1][1]) segments.pop() return segments - + def get_wave_segments(wav_command, window_length, overlap): raw_output = subprocess.check_output(wav_command+" sox -t wav - -n stat 2>&1 | grep Length ", shell = True) parts = raw_output.split(":") @@ -31,7 +31,7 @@ def prepare_segments_file(kaldi_data_dir, window_length, overlap): raise Exception("Not a proper kaldi data directory") ids = [] files = [] - for line in open(kaldi_data_dir+'/wav.scp').readlines(): + for line in open(kaldi_data_dir+'/wav.scp').readlines(): parts = line.split() ids.append(parts[0]) files.append(" ".join(parts[1:])) @@ -54,6 +54,7 @@ def prepare_segments_file(kaldi_data_dir, window_length, overlap): parser = argparse.ArgumentParser() parser.add_argument('--window-length', type = float, default = 30.0, help = 'length of the window used to cut the segment') parser.add_argument('--overlap', type = float, default = 5.0, help = 
'overlap of neighboring windows') + parser.add_argument('--base-segments', type = str, help = 'Create subsegments of the base segments') parser.add_argument('data_dir', type=str, help='directory such as data/train') params = parser.parse_args() diff --git a/egs/aspire/s5/local/multi_condition/fill_missing_recordings.py b/egs/aspire/s5/local/multi_condition/fill_missing_recordings.py new file mode 100644 index 00000000000..7eb9578301e --- /dev/null +++ b/egs/aspire/s5/local/multi_condition/fill_missing_recordings.py @@ -0,0 +1,50 @@ +# Copyright Johns Hopkins University (Author: Vijayaditya Peddinti) 2015. Apache 2.0. +# This script checks the ctm for missing recordings and places a dummy id. +# This is necessary for compliance with sclite scoring scripts + +import argparse, sys + +def fill_ctm(input_ctm_file, output_ctm_file, recording_names): + recording_index = 0 + with open(input_ctm_file, "r") as infile, open(output_ctm_file, "w") as outfile: + for line in infile: + if line.split()[0] == recording_names[recording_index]: + outfile.write(line) + else: + processed_line = False + recording_index += 1 + while not processed_line: + if recording_index >= len(recording_names): + raise Exception("There is a mismatch between the recording_names_file and the ctm file. There are recordings in ctm file which are not present in the recording file.") + if line.split()[0] == recording_names[recording_index]: + outfile.write(line) + processed_line = True + else: + # there is a missing recording + outfile.write("{0} 1 0.00 0.01 NOTHINGWASDECODEDHERE\n".format(recording_names[recording_index])) + recording_index += 1 + infile.close() + outfile.close() + + + +if __name__ == "__main__": + usage = """ Python script to check the ctm file for missing recordings + and provide a single line default output for the missing recordings. 
It assumes + that the ctm file has recordings in the same order as the wav.scp file""" + + + sys.stderr.write(str(" ".join(sys.argv))) + parser = argparse.ArgumentParser(usage) + parser.add_argument('input_ctm_file', type=str, help='ctm file for the recordings') + parser.add_argument('output_ctm_file', type=str, help='ctm file for the recordings') + parser.add_argument('recording_name_file', type=str, help='file with names of the recordings') + + params = parser.parse_args() + + try: + file_names = map(lambda x: x.strip(), open("{0}".format(params.recording_name_file)).readlines()) + except IOError: + raise Exception("Expected to find {0}".format(params.scp_file)) + + fill_ctm(params.input_ctm_file, params.output_ctm_file, file_names) diff --git a/egs/aspire/s5/local/multi_condition/get_ctm.sh b/egs/aspire/s5/local/multi_condition/get_ctm.sh index e6bac453e52..8bd81fec63f 100755 --- a/egs/aspire/s5/local/multi_condition/get_ctm.sh +++ b/egs/aspire/s5/local/multi_condition/get_ctm.sh @@ -5,8 +5,7 @@ decode_mbr=true filter_ctm_command=cp glm= stm= -window=10 -overlap=5 + [ -f ./path.sh ] && . ./path.sh . 
parse_options.sh || exit 1; @@ -47,7 +46,7 @@ lattice-align-words-lexicon --output-error-lats=true --output-if-empty=true --ma lattice-to-ctm-conf --decode-mbr=$decode_mbr ark:- $decode_dir/score_$LMWT/penalty_$wip/ctm.overlapping || exit 1; # combine the segment-wise ctm files, while resolving overlaps -python local/multi_condition/resolve_ctm_overlaps.py --overlap $overlap --window-length $window $data_dir/utt2spk $decode_dir/score_$LMWT/penalty_$wip/ctm.overlapping $decode_dir/score_$LMWT/penalty_$wip/ctm.merged || exit 1; +python local/multi_condition/resolve_ctm_overlaps.py --segments $data_dir/segments $data_dir/utt2spk $decode_dir/score_$LMWT/penalty_$wip/ctm.overlapping $decode_dir/score_$LMWT/penalty_$wip/ctm.merged || exit 1; merged_ctm=$decode_dir/score_$LMWT/penalty_$wip/ctm.merged cat $merged_ctm | utils/int2sym.pl -f 5 $lang/words.txt | \ diff --git a/egs/aspire/s5/local/multi_condition/prep_test_aspire_diarization.sh b/egs/aspire/s5/local/multi_condition/prep_test_aspire_diarization.sh new file mode 100755 index 00000000000..cba4693e46f --- /dev/null +++ b/egs/aspire/s5/local/multi_condition/prep_test_aspire_diarization.sh @@ -0,0 +1,511 @@ +#!/bin/bash +# Copyright Johns Hopkins University (Author: Daniel Povey, Vijayaditya Peddinti) 2015. +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +# This script generates the ctm files for dev_aspire, test_aspire and eval_aspire +# for scoring with ASpIRE scoring server. +# It also provides the WER for dev_aspire data. 
+ +set -u +set -o pipefail +set -e + +use_icsi_method=false +iter=final +mfccdir=mfcc_reverb_submission +stage=0 +decode_num_jobs=200 +num_jobs=30 +LMWT=12 +word_ins_penalty=0 +min_lmwt=9 +max_lmwt=20 +word_ins_penalties=0.0,0.25,0.5,0.75,1.0 +decode_mbr=true +acwt=0.1 +lattice_beam=8 +ctm_beam=6 +do_segmentation=true +max_count=100 # parameter for extract_ivectors.sh +sub_speaker_frames=1500 +overlap=5 +window=30 +affix= +ivector_scale=1.0 +pad_frames=0 # this did not seem to be helpful but leaving it as an option. +tune_hyper=true +pass2_decode_opts= +filter_ctm=true +weights_file= +silence_weight=0.00001 +create_whole_dir=true +use_vad_prob=false +use_lats=true +transform_weights=false +speech_to_sil_ratio=1 +use_bootstrap_vad=false +nj=30 +. cmd.sh + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -ne 4 ]; then + echo "Usage: $0 [options] " + echo " Options:" + echo " --stage (0|1|2) # start scoring script from part-way through." + echo "e.g.:" + echo "$0 data/train data/lang exp/nnet2_multicondition/nnet_ms_a" + exit 1; +fi + +data_dir=$1 #select from data/{dev_aspire,test_aspire,eval_aspire} +vad_model_dir=$2 +lang=$3 # data/lang +dir=$4 # exp/nnet2_multicondition/nnet_ms_a + +data_id=`basename $data_dir` +model_affix=`basename $dir` +ivector_dir=`dirname $dir` +ivector_affix=${affix:+_$affix}_$model_affix +vad_dir=exp/vad_${data_id}_${affix} +affix=_${affix}_iter${iter} +act_data_id=${data_id} +if [ "$data_id" == "test_aspire" ]; then + out_file=single_dev_test${affix}_$model_affix.ctm + if [ $stage -le 0 ]; then + steps/make_mfcc.sh --nj $nj --cmd "$train_cmd" --mfcc-config conf/mfcc_hires.conf ${data_dir} exp/make_mfcc_reverb/${data_dir} $mfccdir || exit 1; + steps/compute_cmvn_stats.sh $data_dir exp/make_mfcc_reverb/${data_dir} $mfccdir || exit 1; + fi +elif [ "$data_id" == "eval_aspire" ]; then + out_file=single_eval${affix}_$model_affix.ctm + if [ $stage -le 0 ]; then + steps/make_mfcc.sh --nj $nj --cmd "$train_cmd" 
--mfcc-config conf/mfcc_hires.conf ${data_dir} exp/make_mfcc_reverb/${data_dir} $mfccdir || exit 1; + steps/compute_cmvn_stats.sh $data_dir exp/make_mfcc_reverb/${data_dir} $mfccdir || exit 1; + fi +else + if $create_whole_dir; then + if [ $stage -le 0 ]; then + echo "Creating the data dir with whole recordings without segmentation" + # create a whole directory without the segments + unseg_dir=data/${data_id}_whole + src_dir=data/$data_id + mkdir -p $unseg_dir + echo "Creating the $unseg_dir/wav.scp file" + cp $src_dir/wav.scp $unseg_dir + + echo "Creating the $unseg_dir/reco2file_and_channel file" + cat $unseg_dir/wav.scp | awk '{print $1, $1, "A";}' > $unseg_dir/reco2file_and_channel + cat $unseg_dir/wav.scp | awk '{print $1, $1;}' > $unseg_dir/utt2spk + utils/utt2spk_to_spk2utt.pl $unseg_dir/utt2spk > $unseg_dir/spk2utt + + steps/make_mfcc.sh --nj $nj --cmd "$train_cmd" --mfcc-config conf/mfcc_hires.conf $unseg_dir exp/make_mfcc_reverb/${data_id}_whole $mfccdir || exit 1; + steps/compute_cmvn_stats.sh $unseg_dir exp/make_mfcc_reverb/${data_id}_whole $mfccdir || exit 1; + fi + data_id=${data_id}_whole + fi + out_file=single_dev${affix}_${model_affix}.ctm +fi + +if [ $stage -le 1 ]; then + echo "Generating uniform segments for VAD with length 600" + mkdir -p ${vad_dir} + rm -rf ${vad_dir}/data_uniform_windows600 + copy_data_dir.sh --validate-opts "--no-text" data/$data_id ${vad_dir}/data_uniform_windows600 || exit 1 + cp data/$data_id/reco2file_and_channel ${vad_dir}/data_uniform_windows600 || exit 1 + python local/multi_condition/create_uniform_segments.py --overlap 0 --window 600 ${vad_dir}/data_uniform_windows600 || exit 1 + for file in cmvn.scp feats.scp; do + rm -f ${vad_dir}/data_uniform_windows600/$file + done + utils/validate_data_dir.sh --no-text --no-feats ${vad_dir}/data_uniform_windows600 || exit 1 +fi +if [ $stage -le 2 ]; then + diarization/prepare_data.sh --nj $nj --cmd "$train_cmd" ${vad_dir}/data_uniform_windows600 ${vad_dir} ${vad_dir}/mfcc || 
exit 1 +fi + +split_data.sh ${vad_dir}/data_uniform_windows600 $nj + +noise_model_found=false +if [ -f $vad_model_dir/noise.11.mdl ]; then + noise_model_found=true +fi + +if [ $stage -le 3 ]; then + if $noise_model_found; then + $train_cmd JOB=1:$nj ${vad_dir}/do_vad.JOB.log \ + diarization/vad_gmm_3models.sh --config conf/vad_icsi_babel_3models.conf \ + --try-merge-speech-noise true --output-lattice $use_lats --write-feats true \ + --speech-to-sil-ratio $speech_to_sil_ratio --use-bootstrap-vad $use_bootstrap_vad \ + ${vad_dir}/data_uniform_windows600/split$nj/JOB \ + $vad_model_dir/silence.11.mdl $vad_model_dir/speech.11.mdl \ + $vad_model_dir/noise.11.mdl ${vad_dir}/JOB || exit 1 + else + if ! $use_icsi_method; then + $train_cmd JOB=1:$nj ${vad_dir}/do_vad.JOB.log \ + diarization/vad_gmm_2models.sh --config conf/vad_icsi_babel_3models.conf \ + --try-merge-speech-noise true --output-lattice $use_lats --write-feats true \ + --speech-to-sil-ratio $speech_to_sil_ratio \ + ${vad_dir}/data_uniform_windows600/split$nj/JOB \ + $vad_model_dir/silence.11.mdl $vad_model_dir/speech.11.mdl \ + ${vad_dir}/JOB || exit 1 + else + $train_cmd JOB=1:$nj ${vad_dir}/do_vad.JOB.log \ + diarization/vad_gmm_icsi.sh --config conf/vad_icsi_babel.conf \ + --try-merge-speech-noise true --output-lattice $use_lats --write-feats true \ + --speech-to-sil-ratio $speech_to_sil_ratio \ + ${vad_dir}/data_uniform_windows600/split$nj/JOB \ + $vad_model_dir/silence.11.mdl $vad_model_dir/speech.11.mdl \ + ${vad_dir}/JOB || exit 1 + fi + fi + + for n in `seq $nj`; do + for x in `cat ${vad_dir}/data_uniform_windows600/split$nj/$n/utt2spk | awk '{print $1}'`; do + cat ${vad_dir}/$n/$x.vad.final.scp + done + done | sort -k1,1 > ${vad_dir}/vad.scp +fi + +segmented_data_dir=data/${data_id}_uniformsegmented_win${window}_over${overlap} + +if [ $stage -le 4 ]; then + if $use_bootstrap_vad || ! 
$use_vad_prob; then + $train_cmd ${vad_dir}/get_vad_per_file.log \ + segmentation-to-rttm \ + --segments=${vad_dir}/data_uniform_windows600/segments \ + scp:${vad_dir}/vad.scp - \| grep SPEECH \| \ + rttmSort.pl \| diarization/convert_rttm_to_segments.pl \| \ + segmentation-init-from-segments - ark:${vad_dir}/vad_per_file.ark + else + if $use_lats; then + for n in `seq $nj`; do + for x in `cat ${vad_dir}/data_uniform_windows600/split$nj/$n/utt2spk | awk '{print $1}'`; do + cat ${vad_dir}/$n/$x.lat.scp + done | tee ${vad_dir}/$n/lats.scp + done | sort -k1,1 > ${vad_dir}/lats.scp + + $train_cmd JOB=1:$nj ${vad_dir}/log/get_vad_weights.JOB.log \ + lattice-to-post scp:${vad_dir}/JOB/lats.scp ark:- \| \ + post-to-pdf-post ${vad_dir}/JOB/trans.mdl ark:- ark:- \| \ + weight-pdf-post $silence_weight 0:2 ark:- ark:- \| \ + post-to-weights ark:- ark,t:- \| \ + copy-vector ark,t:- ark:${vad_dir}/weights.JOB.ark + else + for n in `seq $nj`; do + for x in `cat ${vad_dir}/data_uniform_windows600/split$nj/$n/utt2spk | awk '{print $1}'`; do + gmm-compute-likes ${vad_dir}/$n/$x.final.mdl ark:${vad_dir}/$n/$x.feat.ark ark:- | \ + loglikes-to-post ark:- ark:- | \ + weight-pdf-post $silence_weight 0:2 ark:- ark:- | \ + post-to-weights ark:- ark,t:- | \ + copy-vector ark,t:- ark:- + done > ${vad_dir}/weights.$n.ark + done + fi + fi +fi + + +diarized_data_dir=${data_dir}_diarized${ivector_affix} +diarization_dir=$ivector_dir/diarization_${data_id}${ivector_affix} + +if [ $stage -le 5 ]; then + local/run_diarization.sh --nj $nj $data_dir $lang \ + "ark:gunzip -c exp/nnet2_multicondition/ivector_weights_dev_aspire_whole_uniformsegmented_win10_over5_v18_voiced_256_64_64_nnet_ms_a/file_weights.gz |" $ivector_dir/extractor \ + exp/nnet2_multicondition/ivectors_train_offline/plda \ + $diarization_dir || exit 1 +fi + +if [ $stage -le 6 ]; then + rm -rf $diarized_data_dir + utils/copy_data_dir.sh --validate-opts "--no-text" $data_dir $diarized_data_dir || exit 1 + + for n in `seq $nj`; do + cat 
$diarization_dir/diarization/data_out/utt2spk.$n || exit 1 + done > $diarized_data_dir/utt2spk || exit 1 + + for n in `seq $nj`; do + cat $diarization_dir/diarization/data_out/segments.$n || exit 1 + done > $diarized_data_dir/segments || exit 1 + + rm -f $diarized_data_dir/{feats,cmvn}.scp + rm -f $diarized_data_dir/text + + utils/utt2spk_to_spk2utt.pl $diarized_data_dir/utt2spk > $diarized_data_dir/spk2utt || exit 1 + utils/fix_data_dir.sh $diarized_data_dir + utils/validate_data_dir.sh --no-text --no-feats $diarized_data_dir || exit 1 +fi + +segmented_data_id=`basename $segmented_data_dir` +diarized_data_id=`basename $diarized_data_dir` +if [ $stage -le 7 ]; then + echo "Extracting features for the segments" + # extract the features/i-vectors once again so that they are indexed by utterance and not by recording + rm -rf data/${segmented_data_id}_hires + copy_data_dir.sh --validate-opts "--no-text " data/${segmented_data_id} data/${segmented_data_id}_hires || exit 1; + steps/make_mfcc.sh --nj 10 --mfcc-config conf/mfcc_hires.conf \ + --cmd "$train_cmd" data/${segmented_data_id}_hires \ + exp/make_reverb_hires/${segmented_data_id} $mfccdir || exit 1; + steps/compute_cmvn_stats.sh data/${segmented_data_id}_hires exp/make_reverb_hires/${segmented_data_id} $mfccdir || exit 1; + utils/fix_data_dir.sh data/${segmented_data_id}_hires + utils/validate_data_dir.sh --no-text data/${segmented_data_id}_hires + + echo "Extracting features for the segments" + # extract the features/i-vectors once again so that they are indexed by utterance and not by recording + rm -rf data/${diarized_data_id}_hires + copy_data_dir.sh --validate-opts "--no-text " data/${diarized_data_id} data/${diarized_data_id}_hires || exit 1; + steps/make_mfcc.sh --nj 10 --mfcc-config conf/mfcc_hires.conf \ + --cmd "$train_cmd" data/${diarized_data_id}_hires \ + exp/make_reverb_hires/${diarized_data_id} $mfccdir || exit 1; + steps/compute_cmvn_stats.sh data/${diarized_data_id}_hires 
exp/make_reverb_hires/${diarized_data_id} $mfccdir || exit 1; + utils/fix_data_dir.sh data/${diarized_data_id}_hires + utils/validate_data_dir.sh --no-text data/${diarized_data_id}_hires +fi + +if [ ! -z $weights_file ]; then + echo "$0: Using provided weights file $weights_file" + ivector_extractor_input=$weights_file +else + if [ $stage -le 8 ]; then + mkdir -p $ivector_dir/ivector_weights_${diarized_data_id}${ivector_affix} + $train_cmd $ivector_dir/ivector_weights_${diarized_data_id}${ivector_affix}/log/get_file_lengths.log \ + feat-to-len scp:data/${data_id}/feats.scp \ + ark,t:$ivector_dir/ivector_weights_${diarized_data_id}${ivector_affix}/file_lengths.ark + + if ! $use_vad_prob; then + segmentation-to-ali --default-label=0 \ + --lengths=ark,t:$ivector_dir/ivector_weights_${diarized_data_id}${ivector_affix}/file_lengths.ark \ + ark:${vad_dir}/vad_per_file.ark ark,t:- | \ + perl -e ' + my $silence_weight = shift @ARGV; + while () { + chomp; + @A = split; + $utt = shift @A; + print STDOUT "$utt ["; + for ($i = 0; $i <= $#A; $i++) { + if ($A[$i] == 0) { + print STDOUT " $silence_weight"; + } else { + print STDOUT " $A[$i]"; + } + } + print STDOUT " ]\n"; + }' $silence_weight | copy-vector ark,t:- "ark:| gzip -c > $ivector_dir/ivector_weights_${diarized_data_id}${ivector_affix}/file_weights.gz" + else + weight_vecs= + for n in `seq $nj`; do + weight_vecs="${weight_vecs}${vad_dir}/weights.$n.ark " + done + + awk '${print $1" "$2}' ${vad_dir}/data_uniform_windows600/segments | utils/utt2spk_to_spk2utt.pl > ${vad_dir}/data_uniform_windows600/reco2utt + + $train_cmd $ivector_dir/ivector_weights_${diarized_data_id}${ivector_affix}/log/get_vad_file_weights.log \ + combine-vector-segments --max-overshoot=2 --overlap=0 "ark:cat $weight_vecs|" \ + ark:${vad_dir}/data_uniform_windows600/reco2utt ark:${vad_dir}/data_uniform_windows600/segments \ + ark,t:$ivector_dir/ivector_weights_${diarized_data_id}${ivector_affix}/file_lengths.ark \ + "ark:| gzip -c > 
$ivector_dir/ivector_weights_${diarized_data_id}${ivector_affix}/file_weights.gz" + fi + fi + + cat $diarized_data_dir/segments | awk '{print $1" "$2" "$3" "$4-0.02}' > $ivector_dir/ivector_weights_${diarized_data_id}${ivector_affix}/truncated_segments + + x_th=0.8 + if [ $stage -le 9 ]; then + if $transform_weights; then + $train_cmd $ivector_dir/ivector_weights_${diarized_data_id}${ivector_affix}/log/extract_weights.log \ + extract-vector-segments "ark:gunzip -c $ivector_dir/ivector_weights_${diarized_data_id}${ivector_affix}/file_weights.gz |" \ + $ivector_dir/ivector_weights_${diarized_data_id}${ivector_affix}/truncated_segments ark,t:- \| \ + awk -v x_th=$x_th '{printf $1" [ "; for(i=3;i<=NF-1;i++) printf 1/sqrt(1+2*exp(-20*($i-x_th)))" " ; print "]"}' \| \ + copy-vector ark,t:- "ark:| gzip -c >$ivector_dir/ivector_weights_${diarized_data_id}${ivector_affix}/weights.gz" + else + $train_cmd $ivector_dir/ivector_weights_${diarized_data_id}${ivector_affix}/log/extract_weights.log \ + extract-vector-segments "ark:gunzip -c $ivector_dir/ivector_weights_${diarized_data_id}${ivector_affix}/file_weights.gz |" \ + $ivector_dir/ivector_weights_${diarized_data_id}${ivector_affix}/truncated_segments \ + "ark:| gzip -c >$ivector_dir/ivector_weights_${diarized_data_id}${ivector_affix}/weights.gz" + fi + fi + ivector_extractor_input=$ivector_dir/ivector_weights_${diarized_data_id}${ivector_affix}/weights.gz +fi + +if [ $stage -le 10 ]; then + echo "Extracting i-vectors, stage 2 with input $ivector_extractor_input" + # this does offline decoding, except we estimate the iVectors per + # speaker, excluding silence (based on alignments from a GMM decoding), with a + # different script. This is just to demonstrate that script. + # the --sub-speaker-frames is optional; if provided, it will divide each speaker + # up into "sub-speakers" of at least that many frames... can be useful if + # acoustic conditions drift over time within the speaker's data. 
+ steps/online/nnet2/extract_ivectors_for_recording.sh --cmd "$train_cmd" --nj 20 \ + --silence-weight $silence_weight \ + --sub-speaker-frames $sub_speaker_frames --max-count $max_count \ + data/${diarized_data_id}_hires $lang $ivector_dir/extractor \ + $ivector_extractor_input $ivector_dir/ivectors_reco${ivector_affix} || exit 1; +fi +if [ $stage -le 11 ]; then + steps/online/nnet2/segment_recording_ivectors.sh ${segmented_data_dir}_hires \ + $ivector_dir/ivectors_reco${ivector_affix} \ + $ivector_dir/ivectors_${segmented_data_id}${ivector_affix} +fi + +decode_dir=$dir/decode_${segmented_data_id}${affix}_pp +if [ $stage -le 12 ]; then + echo "Generating lattices, stage 2 with --acwt $acwt" + local/multi_condition/decode.sh --nj $decode_num_jobs --cmd "$decode_cmd" --config conf/decode.config $pass2_decode_opts \ + --skip-scoring true --iter $iter --acwt $acwt --lattice-beam $lattice_beam \ + --online-ivector-dir $ivector_dir/ivectors_${segmented_data_id}${ivector_affix} \ + exp/tri5a/graph_pp data/${segmented_data_id}_hires ${decode_dir}_tg || \ + { echo "$0: Error decoding"; exit 1; } +fi + +if [ $stage -le 13 ]; then + echo "Rescoring lattices" + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + --skip-scoring true \ + ${lang}_pp_test{,_fg} data/${segmented_data_id}_hires \ + ${decode_dir}_{tg,fg} || exit 1; +fi + +# tune the LMWT and WIP +# make command for filtering the ctms +decode_dir=${decode_dir}_fg +if [ -z $iter ]; then + model=$decode_dir/../final.mdl # assume model one level up from decoding dir. 
+else + model=$decode_dir/../$iter.mdl +fi + +mkdir -p $decode_dir/scoring +# create a python script to filter the ctm, for labels which are mapped +# to null strings in the glm or which are not accepted by the scoring server +python -c " +import sys, re +lines = map(lambda x: x.strip(), open('data/${act_data_id}/glm').readlines()) +patterns = [] +for line in lines: + if re.search('=>', line) is not None: + parts = re.split('=>', line.split('/')[0]) + if parts[1].strip() == '': + patterns.append(parts[0].strip()) +print '|'.join(patterns) +" > $decode_dir/scoring/glm_ignore_patterns || exit 1; + +ignore_patterns=$(cat $decode_dir/scoring/glm_ignore_patterns) +echo "$0: Ignoring these patterns from the ctm ", $ignore_patterns +cat << EOF > $decode_dir/scoring/filter_ctm.py +import sys +file = open(sys.argv[1]) +out_file = open(sys.argv[2], 'w') +ignore_set = "$ignore_patterns".split("|") +ignore_set.append("[noise]") +ignore_set.append("[laughter]") +ignore_set.append("[vocalized-noise]") +ignore_set.append("!SIL") +ignore_set.append("") +ignore_set.append("%hesitation") +ignore_set = set(ignore_set) +print ignore_set +for line in file: + if line.split()[4] not in ignore_set: + out_file.write(line) +out_file.close() +EOF + +filter_ctm_command="python $decode_dir/scoring/filter_ctm.py " + +if $tune_hyper ; then + if [ $stage -le 14 ]; then + if [[ "$act_data_id" =~ "dev_aspire" ]]; then + wip_string=$(echo $word_ins_penalties | sed 's/,/ /g') + temp_wips=($wip_string) + $decode_cmd WIP=1:${#temp_wips[@]} $decode_dir/scoring/log/score.wip.WIP.log \ + wips=\(0 $wip_string\) \&\& \ + wip=\${wips[WIP]} \&\& \ + echo \$wip \&\& \ + $decode_cmd LMWT=$min_lmwt:$max_lmwt $decode_dir/scoring/log/score.LMWT.\$wip.log \ + local/multi_condition/get_ctm.sh --filter-ctm-command "$filter_ctm_command" \ + --window $window --overlap $overlap \ + --beam $ctm_beam --decode-mbr $decode_mbr \ + --glm data/${act_data_id}/glm --stm data/${act_data_id}/stm \ + LMWT \$wip $lang 
data/${segmented_data_id}_hires $model $decode_dir || exit 1; + + local/multi_condition/get_ctm_conf.sh --cmd "$decode_cmd" \ + --use-segments true \ + data/${segmented_data_id}_hires \ + ${lang} \ + ${decode_dir} || exit 1; + + eval "grep Sum $decode_dir/score_{${min_lmwt}..${max_lmwt}}/penalty_{$word_ins_penalties}/*.sys"|utils/best_wer.sh 2>/dev/null + eval "grep Sum $decode_dir/score_{${min_lmwt}..${max_lmwt}}/penalty_{$word_ins_penalties}/*.sys" | \ + utils/best_wer.sh 2>/dev/null | python -c "import sys, re +line = sys.stdin.readline() +file_name=line.split()[-1] +parts=file_name.split('/') +penalty = re.sub('penalty_','',parts[-2]) +lmwt = re.sub('score_','', parts[-3]) +lmfile=open('$decode_dir/scoring/bestLMWT','w') +lmfile.write(str(lmwt)) +lmfile.close() +wipfile=open('$decode_dir/scoring/bestWIP','w') +wipfile.write(str(penalty)) +wipfile.close() +" || exit 1; + LMWT=$(cat $decode_dir/scoring/bestLMWT) + word_ins_penalty=$(cat $decode_dir/scoring/bestWIP) + fi + fi + if [ "$act_data_id" == "test_aspire" ] || [ "$act_data_id" == "eval_aspire" ]; then + dev_decode_dir=$(echo $decode_dir|sed "s/test_aspire/dev_aspire_whole/g; s/eval_aspire/dev_aspire_whole/g") + if [ -f $dev_decode_dir/scoring/bestLMWT ]; then + LMWT=$(cat $dev_decode_dir/scoring/bestLMWT) + echo "Using the bestLMWT $LMWT value found in $dev_decode_dir" + else + echo "Unable to find the bestLMWT in the dev decode dir $dev_decode_dir" + echo "Keeping the default/user-specified value" + fi + if [ -f $dev_decode_dir/scoring/bestWIP ]; then + word_ins_penalty=$(cat $dev_decode_dir/scoring/bestWIP) + echo "Using the bestWIP $word_ins_penalty value found in $dev_decode_dir" + else + echo "Unable to find the bestWIP in the dev decode dir $dev_decode_dir" + echo "Keeping the default/user-specified value" + fi + else + echo "Using the default/user-specified values for LMWT and word_ins_penalty" + fi +fi + +# lattice to ctm conversion and scoring. 
+if [ $stage -le 19 ]; then + echo "Generating CTMs with LMWT $LMWT and word insertion penalty of $word_ins_penalty" + local/multi_condition/get_ctm.sh --filter-ctm-command "$filter_ctm_command" \ + --beam $ctm_beam --decode-mbr $decode_mbr \ + $LMWT $word_ins_penalty $lang data/${segmented_data_id}_hires $model $decode_dir 2>$decode_dir/scoring/finalctm.LMWT$LMWT.WIP$word_ins_penalty.log || exit 1; +fi + +if [ $stage -le 20 ]; then + diarization/filter_ctm.sh --cmd "$train_cmd" $diarization_dir/diarization ${decode_dir}/score_${LMWT}/penalty_${word_ins_penalty}/ctm.filt $data_dir/reco2file_and_channel $diarization_dir/ctm_filter || exit 1 +fi + +if [ $stage -le 21 ]; then + + wip=$word_ins_penalty + + glm=data/${act_data_id}/glm + stm=data/${act_data_id}/stm + + if [ -f $stm ]; then + [ ! -f $glm ] && echo "Need glm file" + hubscr=$KALDI_ROOT/tools/sctk/bin/hubscr.pl + [ ! -f $hubscr ] && echo "Cannot find scoring program at $hubscr" && exit 1; + hubdir=`dirname $hubscr` + $hubscr -p $hubdir -V -l english -h hub5 -g $glm -r $stm $decode_dir/score_$LMWT/penalty_$wip/ctm.filt.nocrosstalk || exit 1; + fi + + cat $decode_dir/score_$LMWT/penalty_$word_ins_penalty/ctm.filt.nocrosstalk | awk '{split($1, parts, "-"); printf("%s 1 %s %s %s\n", parts[1], $3, $4, $5)}' > $out_file + cat ${segmented_data_dir}_hires/wav.scp | awk '{split($1, parts, "-"); printf("%s\n", parts[1])}' > $decode_dir/score_$LMWT/penalty_$word_ins_penalty/recording_names + python local/multi_condition/fill_missing_recordings.py $out_file $out_file.submission $decode_dir/score_$LMWT/penalty_$word_ins_penalty/recording_names + echo "Generated the ctm @ $out_file.submission from the ctm file $decode_dir/score_${LMWT}/penalty_$word_ins_penalty/ctm.filt.nocrosstalk" +fi + + diff --git a/egs/aspire/s5/local/multi_condition/prep_test_aspire_segmented.sh b/egs/aspire/s5/local/multi_condition/prep_test_aspire_segmented.sh new file mode 100644 index 00000000000..3828b006c9f --- /dev/null +++ 
b/egs/aspire/s5/local/multi_condition/prep_test_aspire_segmented.sh @@ -0,0 +1,373 @@ +#!/bin/bash +# Copyright Johns Hopkins University (Author: Daniel Povey, Vijayaditya Peddinti) 2015. Apache 2.0. +# This script generates the ctm files for dev_aspire, test_aspire and eval_aspire +# for scoring with ASpIRE scoring server. +# It also provides the WER for dev_aspire data. + +iter=final +mfccdir=mfcc_reverb_submission +stage=0 +decode_num_jobs=200 +num_jobs=30 +LMWT=12 +word_ins_penalty=0 +min_lmwt=9 +max_lmwt=20 +word_ins_penalties=0.0,0.25,0.5,0.75,1.0 +decode_mbr=true +acwt=0.1 +lattice_beam=8 +ctm_beam=6 +do_segmentation=true +max_count=100 # parameter for extract_ivectors.sh +sub_speaker_frames=1500 +overlap=5 +window=30 +affix= +ivector_scale=1.0 +pad_frames=0 # this did not seem to be helpful but leaving it as an option. +tune_hyper=true +pass2_decode_opts= +filter_ctm=true +weights_file= +silence_weight=0.00001 +stage1_only=false +create_whole_dir=true +. cmd.sh + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: $0 [options] " + echo " Options:" + echo " --stage (0|1|2) # start scoring script from part-way through." 
+ echo "e.g.:" + echo "$0 data/train data/lang exp/nnet2_multicondition/nnet_ms_a" + exit 1; +fi + +data_dir=$1 #select from {dev_aspire, test_aspire, eval_aspire} +lang=$2 # data/lang +dir=$3 # exp/nnet2_multicondition/nnet_ms_a + +model_affix=`basename $dir` +ivector_dir=`dirname $dir` +ivector_affix=${affix:+_$affix}_$model_affix +affix=_${affix}_iter${iter} +act_data_dir=${data_dir} +if [ "$data_dir" == "test_aspire" ]; then + out_file=single_dev_test${affix}_$model_affix.ctm +elif [ "$data_dir" == "eval_aspire" ]; then + out_file=single_eval${affix}_$model_affix.ctm +else + if $create_whole_dir; then + if [ $stage -le 1 ]; then + echo "Creating the data dir with whole recordings without segmentation" + # create a whole directory without the segments + unseg_dir=data/${data_dir}_whole + src_dir=data/$data_dir + mkdir -p $unseg_dir + echo "Creating the $unseg_dir/wav.scp file" + cp $src_dir/wav.scp $unseg_dir + + echo "Creating the $unseg_dir/reco2file_and_channel file" + cat $unseg_dir/wav.scp | awk '{print $1, $1, "A";}' > $unseg_dir/reco2file_and_channel + cat $unseg_dir/wav.scp | awk '{print $1, $1;}' > $unseg_dir/utt2spk + utils/utt2spk_to_spk2utt.pl $unseg_dir/utt2spk > $unseg_dir/spk2utt + + steps/make_mfcc.sh --nj 30 --cmd "$train_cmd" --mfcc-config conf/mfcc_hires.conf $unseg_dir exp/make_mfcc_reverb/${data_dir}_whole $mfccdir || exit 1; + steps/compute_cmvn_stats.sh $unseg_dir exp/make_mfcc_reverb/${data_dir}_whole $mfccdir || exit 1; + fi + data_dir=${data_dir}_whole + fi + out_file=single_dev${affix}_${model_affix}.ctm +fi + +num_jobs=`cat data/${act_data_dir}/wav.scp|wc -l` +segmented_data_dir=${data_dir} +# extract the ivectors +if $do_segmentation; then + segmented_data_dir=${data_dir}_uniformsegmented_win${window}_over${overlap} + + if [ $stage -le 2 ]; then + echo "Generating uniform segments with length $window and overlap $overlap." 
+ rm -rf data/$segmented_data_dir + copy_data_dir.sh --validate-opts "--no-text" data/$data_dir data/$segmented_data_dir || exit 1; + cp data/$data_dir/reco2file_and_channel data/$segmented_data_dir/ || exit 1; + python local/multi_condition/create_uniform_segments.py --overlap $overlap --window $window data/$segmented_data_dir || exit 1; + for file in cmvn.scp feats.scp; do + rm -f data/$segmented_data_dir/$file + done + utils/validate_data_dir.sh --no-text --no-feats data/$segmented_data_dir || exit 1; + fi + +fi + +if [ $stage -le 3 ]; then + echo "Extracting features for the segments" + # extract the features/i-vectors once again so that they are indexed by utterance and not by recording + rm -rf data/${segmented_data_dir}_hires + copy_data_dir.sh --validate-opts "--no-text " data/${segmented_data_dir} data/${segmented_data_dir}_hires || exit 1; + steps/make_mfcc.sh --nj 10 --mfcc-config conf/mfcc_hires.conf \ + --cmd "$train_cmd" data/${segmented_data_dir}_hires \ + exp/make_reverb_hires/${segmented_data_dir} $mfccdir || exit 1; + steps/compute_cmvn_stats.sh data/${segmented_data_dir}_hires exp/make_reverb_hires/${segmented_data_dir} $mfccdir || exit 1; + utils/fix_data_dir.sh data/${segmented_data_dir}_hires + utils/validate_data_dir.sh --no-text data/${segmented_data_dir}_hires +fi + +if [ $stage -le 4 ]; then + echo "Extracting i-vectors, stage 1" + steps/online/nnet2/extract_ivectors_online.sh --cmd "$train_cmd" --nj 20 \ + --max-count $max_count \ + data/${segmented_data_dir}_hires $ivector_dir/extractor \ + $ivector_dir/ivectors_${segmented_data_dir}${ivector_affix}_stage1 || exit 1; +fi +if [ $ivector_scale != 1.0 ] && [ $ivector_scale != 1 ]; then + ivector_scale_affix=_scale$ivector_scale +else + ivector_scale_affix= +fi + +if [ $stage -le 5 ]; then + if [ "$ivector_scale_affix" != "" ]; then + echo "$0: Scaling iVectors, stage 1" + srcdir=$ivector_dir/ivectors_${segmented_data_dir}${ivector_affix}_stage1 + 
outdir=$ivector_dir/ivectors_${segmented_data_dir}${ivector_affix}${ivector_scale_affix}_stage1 + mkdir -p $outdir + copy-matrix --scale=$ivector_scale scp:$srcdir/ivector_online.scp ark:- | \ + copy-feats --compress=true ark:- ark,scp:$outdir/ivector_online.ark,$outdir/ivector_online.scp || exit 1; + cp $srcdir/ivector_period $outdir/ivector_period + fi +fi + +decode_dir=$dir/decode_${segmented_data_dir}${affix}_pp +# generate the lattices +if [ $stage -le 6 ]; then + echo "Generating lattices, stage 1" + local/multi_condition/decode.sh --nj $decode_num_jobs --cmd "$decode_cmd" --config conf/decode.config \ + --online-ivector-dir $ivector_dir/ivectors_${segmented_data_dir}${ivector_affix}${ivector_scale_affix}_stage1 \ + --skip-scoring true --iter $iter \ + exp/tri5a/graph_pp data/${segmented_data_dir}_hires ${decode_dir}_stage1 || exit 1; +fi + +if [ $stage -le 7 ]; then + echo "$0: generating CTM from stage-1 lattices" + local/multi_condition/get_ctm_conf.sh --cmd "$decode_cmd" \ + --use-segments false --iter $iter \ + data/${segmented_data_dir}_hires \ + ${lang} \ + ${decode_dir}_stage1 || exit 1; +fi + +if $stage1_only; then + decode_dir=${decode_dir}_stage1 + mv $decode_dir ${decode_dir}_tg +else + if [ $stage -le 8 ]; then + if $filter_ctm; then + if [ ! -z $weights_file ]; then + echo "$0: Using provided weights file $weights_file" + ivector_extractor_input=$weights_file + else + ctm=${decode_dir}_stage1/score_10/${segmented_data_dir}_hires.ctm + echo "$0: generating weights file from stage-1 ctm $ctm" + + feat-to-len scp:data/${segmented_data_dir}_hires/feats.scp ark,t:- >${decode_dir}_stage1/utt.lengths.$affix + if [ ! 
-f $ctm ]; then echo "$0: stage 8: expected ctm to exist: $ctm"; exit 1; fi + cat $ctm | awk '$6 == 1.0 && $4 < 1.0' | \ + grep -v -w mm | grep -v -w mhm | grep -v -F '[noise]' | \ + grep -v -F '[laughter]' | grep -v -F '<unk>' | \ + perl -e ' $lengths=shift @ARGV; $pad_frames=shift @ARGV; $silence_weight=shift @ARGV; + $pad_frames >= 0 || die "bad pad-frames value $pad_frames"; + open(L, "<$lengths") || die "opening lengths file"; + @all_utts = (); + $utt2ref = { }; + while (<L>) { + ($utt, $len) = split(" ", $_); + push @all_utts, $utt; + $array_ref = [ ]; + for ($n = 0; $n < $len; $n++) { ${$array_ref}[$n] = $silence_weight; } + $utt2ref{$utt} = $array_ref; + } + while (<STDIN>) { + @A = split(" ", $_); + @A == 6 || die "bad ctm line $_"; + $utt = $A[0]; $beg = $A[2]; $len = $A[3]; + $beg_int = int($beg * 100) - $pad_frames; + $len_int = int($len * 100) + 2*$pad_frames; + $array_ref = $utt2ref{$utt}; + !defined $array_ref && die "No length info for utterance $utt"; + for ($t = $beg_int; $t < $beg_int + $len_int; $t++) { + if ($t >= 0 && $t < @$array_ref) { + ${$array_ref}[$t] = 1; + } + } + } + foreach $utt (@all_utts) { $array_ref = $utt2ref{$utt}; + print $utt, " [ ", join(" ", @$array_ref), " ]\n"; + } ' ${decode_dir}_stage1/utt.lengths.$affix $pad_frames $silence_weight | gzip -c >${decode_dir}_stage1/weights${affix}.gz + ivector_extractor_input=${decode_dir}_stage1/weights${affix}.gz + fi + else + ivector_extractor_input=${decode_dir}_stage1 + fi + fi + + if [ $stage -le 8 ]; then + echo "Extracting i-vectors, stage 2 with input $ivector_extractor_input" + # this does offline decoding, except we estimate the iVectors per + # speaker, excluding silence (based on alignments from a GMM decoding), with a + # different script. This is just to demonstrate that script. + # the --sub-speaker-frames is optional; if provided, it will divide each speaker + # up into "sub-speakers" of at least that many frames... 
can be useful if + # acoustic conditions drift over time within the speaker's data. + steps/online/nnet2/extract_ivectors.sh --cmd "$train_cmd" --nj 20 \ + --silence-weight $silence_weight \ + --sub-speaker-frames $sub_speaker_frames --max-count $max_count \ + data/${segmented_data_dir}_hires $lang $ivector_dir/extractor \ + $ivector_extractor_input $ivector_dir/ivectors_${segmented_data_dir}${ivector_affix} || exit 1; + fi + + if [ $stage -le 9 ]; then + echo "Generating lattices, stage 2 with --acwt $acwt" + rm -f ${decode_dir}_tg/.error + local/multi_condition/decode.sh --nj $decode_num_jobs --cmd "$decode_cmd" --config conf/decode.config $pass2_decode_opts \ + --skip-scoring true --iter $iter --acwt $acwt --lattice-beam $lattice_beam \ + --online-ivector-dir $ivector_dir/ivectors_${segmented_data_dir}${ivector_affix} \ + exp/tri5a/graph_pp data/${segmented_data_dir}_hires ${decode_dir}_tg || touch ${decode_dir}_tg/.error + [ -f ${decode_dir}_tg/.error ] && echo "$0: Error decoding" && exit 1; + fi +fi + +if [ $stage -le 10 ]; then + echo "Rescoring lattices" + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + --skip-scoring true \ + ${lang}_pp_test{,_fg} data/${segmented_data_dir}_hires \ + ${decode_dir}_{tg,fg} || exit 1; +fi + +# tune the LMWT and WIP +# make command for filtering the ctms +decode_dir=${decode_dir}_fg +if [ -z $iter ]; then + model=$decode_dir/../final.mdl # assume model one level up from decoding dir. 
+else + model=$decode_dir/../$iter.mdl +fi + +mkdir -p $decode_dir/scoring +# create a python script to filter the ctm, for labels which are mapped +# to null strings in the glm or which are not accepted by the scoring server +python -c " +import sys, re +lines = map(lambda x: x.strip(), open('data/${act_data_dir}/glm').readlines()) +patterns = [] +for line in lines: + if re.search('=>', line) is not None: + parts = re.split('=>', line.split('/')[0]) + if parts[1].strip() == '': + patterns.append(parts[0].strip()) +print '|'.join(patterns) +" > $decode_dir/scoring/glm_ignore_patterns || exit 1; + +ignore_patterns=$(cat $decode_dir/scoring/glm_ignore_patterns) +echo "$0: Ignoring these patterns from the ctm ", $ignore_patterns +cat << EOF > $decode_dir/scoring/filter_ctm.py +import sys +file = open(sys.argv[1]) +out_file = open(sys.argv[2], 'w') +ignore_set = "$ignore_patterns".split("|") +ignore_set.append("[noise]") +ignore_set.append("[laughter]") +ignore_set.append("[vocalized-noise]") +ignore_set.append("!SIL") +ignore_set.append("") +ignore_set.append("%hesitation") +ignore_set = set(ignore_set) +print ignore_set +for line in file: + if line.split()[4] not in ignore_set: + out_file.write(line) +out_file.close() +EOF + +filter_ctm_command="python $decode_dir/scoring/filter_ctm.py " + +if $tune_hyper ; then + if [ $stage -le 11 ]; then + if [[ "$act_data_dir" =~ "dev_aspire" ]]; then + wip_string=$(echo $word_ins_penalties | sed 's/,/ /g') + temp_wips=($wip_string) + $decode_cmd WIP=1:${#temp_wips[@]} $decode_dir/scoring/log/score.wip.WIP.log \ + wips=\(0 $wip_string\) \&\& \ + wip=\${wips[WIP]} \&\& \ + echo \$wip \&\& \ + $decode_cmd LMWT=$min_lmwt:$max_lmwt $decode_dir/scoring/log/score.LMWT.\$wip.log \ + local/multi_condition/get_ctm.sh --filter-ctm-command "$filter_ctm_command" \ + --window $window --overlap $overlap \ + --beam $ctm_beam --decode-mbr $decode_mbr \ + --glm data/${act_data_dir}/glm --stm data/${act_data_dir}/stm \ + LMWT \$wip $lang 
data/${segmented_data_dir}_hires $model $decode_dir || exit 1; + + local/multi_condition/get_ctm_conf.sh --cmd "$decode_cmd" \ + --use-segments true \ + data/${segmented_data_dir}_hires \ + ${lang} \ + ${decode_dir} || exit 1; + + eval "grep Sum $decode_dir/score_{${min_lmwt}..${max_lmwt}}/penalty_{$word_ins_penalties}/*.sys"|utils/best_wer.sh 2>/dev/null + eval "grep Sum $decode_dir/score_{${min_lmwt}..${max_lmwt}}/penalty_{$word_ins_penalties}/*.sys" | \ + utils/best_wer.sh 2>/dev/null | python -c "import sys, re +line = sys.stdin.readline() +file_name=line.split()[-1] +parts=file_name.split('/') +penalty = re.sub('penalty_','',parts[-2]) +lmwt = re.sub('score_','', parts[-3]) +lmfile=open('$decode_dir/scoring/bestLMWT','w') +lmfile.write(str(lmwt)) +lmfile.close() +wipfile=open('$decode_dir/scoring/bestWIP','w') +wipfile.write(str(penalty)) +wipfile.close() +" || exit 1; + LMWT=$(cat $decode_dir/scoring/bestLMWT) + word_ins_penalty=$(cat $decode_dir/scoring/bestWIP) + fi + fi + if [ "$act_data_dir" == "test_aspire" ] || [ "$act_data_dir" == "eval_aspire" ]; then + dev_decode_dir=$(echo $decode_dir|sed "s/test_aspire/dev_aspire_whole/g; s/eval_aspire/dev_aspire_whole/g") + if [ -f $dev_decode_dir/scoring/bestLMWT ]; then + LMWT=$(cat $dev_decode_dir/scoring/bestLMWT) + echo "Using the bestLMWT $LMWT value found in $dev_decode_dir" + else + echo "Unable to find the bestLMWT in the dev decode dir $dev_decode_dir" + echo "Keeping the default/user-specified value" + fi + if [ -f $dev_decode_dir/scoring/bestWIP ]; then + word_ins_penalty=$(cat $dev_decode_dir/scoring/bestWIP) + echo "Using the bestWIP $word_ins_penalty value found in $dev_decode_dir" + else + echo "Unable to find the bestWIP in the dev decode dir $dev_decode_dir" + echo "Keeping the default/user-specified value" + fi + else + echo "Using the default/user-specified values for LMWT and word_ins_penalty" + fi +fi + +# lattice to ctm conversion and scoring. 
+if [ $stage -le 12 ]; then + echo "Generating CTMs with LMWT $LMWT and word insertion penalty of $word_ins_penalty" + local/multi_condition/get_ctm.sh --filter-ctm-command "$filter_ctm_command" \ + --beam $ctm_beam --decode-mbr $decode_mbr \ + $LMWT $word_ins_penalty $lang data/${segmented_data_dir}_hires $model $decode_dir 2>$decode_dir/scoring/finalctm.LMWT$LMWT.WIP$word_ins_penalty.log || exit 1; +fi + +if [ $stage -le 13 ]; then + cat $decode_dir/score_$LMWT/penalty_$word_ins_penalty/ctm.filt | awk '{split($1, parts, "-"); printf("%s 1 %s %s %s\n", parts[1], $3, $4, $5)}' > $out_file + echo "Generated the ctm @ $out_file from the ctm file $decode_dir/score_${LMWT}/penalty_$word_ins_penalty/ctm.filt" +fi + diff --git a/egs/aspire/s5/local/multi_condition/prep_test_aspire_vad.sh b/egs/aspire/s5/local/multi_condition/prep_test_aspire_vad.sh new file mode 100755 index 00000000000..1104b97c3fd --- /dev/null +++ b/egs/aspire/s5/local/multi_condition/prep_test_aspire_vad.sh @@ -0,0 +1,482 @@ +#!/bin/bash +# Copyright 2015 Johns Hopkins University (Author: Daniel Povey), Vijayaditya Peddinti +# 2015 Vimal Manohar +# Apache 2.0. + +# This script generates the ctm files for dev_aspire, test_aspire and eval_aspire +# for scoring with ASpIRE scoring server. +# It also provides the WER for dev_aspire data. 
+ +set -u +set -o pipefail +set -e + +stage=-10 + +graph_dir=exp/tri5a/graph_pp +iter=final # Acoustic model to be used for decoding +mfccdir=mfcc_reverb_submission # Dir to store MFCC features +fbankdir=fbank_reverb_submission # Dir to store Fbank features +sad_mfcc_config=conf/mfcc_hires.conf +sad_fbank_config=conf/fbank.conf +mfcc_config=conf/mfcc_hires.conf +fbank_config=conf/fbank.conf +add_frame_snr=true +append_to_orig_feats=false +feature_type=Snr + +nj=30 # number of parallel jobs for VAD and segmentation +decode_nj=200 # number of parallel jobs for decoding + +# segmentation opts +segmentation_config= +weights_segmentation_config= +segmentation_stage=-10 +segmentation_method=Viterbi +snr_predictor_iter=final +sad_model_iter=final + +# ivector extraction opts +use_ivectors=false +max_count=100 # parameter for extract_ivectors.sh +sub_speaker_frames=2500 +ivector_scale=1.0 +weights_file= +weights_method=Viterbi +silence_weight=0 + +# Decoding and scoring opts +acwt=0.1 +LMWT=12 +word_ins_penalty=0 +min_lmwt=9 +max_lmwt=20 +word_ins_penalties=0.0,0.25,0.5,0.75,1.0 +decode_mbr=true + +lattice_beam=8 +ctm_beam=6 +filter_ctm=true + +# output opts +affix= # append this to the directory names +create_whole_dir=true +tune_hyper=true + +# stage opts +input_frame_snrs_dir= +input_vad_dir= + +. cmd.sh + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + + +if [ $# -ne 5 ]; then + echo "Usage: $0 [options] " + echo " Options:" + echo " --stage (0|1|2) # start scoring script from part-way through." 
+ echo "e.g.:" + echo "$0 data/train data/lang exp/nnet2_multicondition/nnet_ms_a" + exit 1; +fi + +data_dir=$1 #select from data/{dev_aspire,test_aspire,eval_aspire}* +snr_predictor=$2 +sad_model_dir=$3 +lang=$4 # data/lang +dir=$5 # exp/nnet2_multicondition/nnet_ms_a + +data_id=`basename $data_dir` # {dev,test,eval}_aspire* +model_affix=`basename $dir` # nnet_ms_* +vad_affix=${affix:+_$affix} +ivector_dir=`dirname $dir` # exp/nnet2_multicondition +ivector_affix=${affix:+_$affix}_$model_affix +frame_snrs_dir=exp/frame_snrs${vad_affix}_${data_id} +vad_dir=exp/vad_${data_id}${vad_affix} # Temporary directory for VAD +segmentation_dir=exp/segmentation_${data_id}${vad_affix} +affix=_${affix}_iter${iter} # affix to be specific to AM used +act_data_id=${data_id} # the original data_id before data_id gets + # modified to something else + +# Function to create mfcc features +make_mfcc () { + if [ $# -lt 2 ] || [ $[$# % 2] -ne 0 ]; then + echo "$0: make_mfcc: Not enough arguments. Some variable is probably not set" + exit 1 + fi + local this_nj=$nj + local mfcc_config=$mfcc_config + + while [ $# -gt 0 ]; do + if [ $[$# % 2] -ne 0 ]; then + echo "$0: make_mfcc: Not enough arguments. Some variable is probably not set" + exit 1 + fi + case $1 in + --nj) + this_nj=$2 + shift; shift + ;; + --mfcc-config) + mfcc_config=$2 + shift; shift + ;; + *) + if [ $# -eq 2 ]; then + break; + else + echo "$0: make_mfcc: Unknown arguments $*" + exit 1 + fi + ;; + esac + done + + if [ $# -ne 2 ]; then + echo "$0: make_mfcc: Not enough arguments. 
Some variable is probably not set" + exit 1 + fi + + local data_dir=$1 + local mfccdir=$2 + + rm -rf ${data_dir}_hires + utils/copy_data_dir.sh ${data_dir} ${data_dir}_hires + [ -f $data_dir/reco2file_and_channel ] && cp $data_dir/reco2file_and_channel ${data_dir}_hires + + data_dir=${data_dir}_hires + steps/make_mfcc.sh --nj $this_nj --cmd "$train_cmd" \ + --mfcc-config $mfcc_config \ + ${data_dir} exp/make_hires/${data_dir} $mfccdir || exit 1; + steps/compute_cmvn_stats.sh \ + ${data_dir} exp/make_hires/${data_dir} $mfccdir || exit 1; + utils/fix_data_dir.sh ${data_dir} + utils/validate_data_dir.sh --no-text ${data_dir} +} + +make_fbank () { + if [ $# -lt 2 ] || [ $[$# % 2] -ne 0 ]; then + echo "$0: make_fbank: Not enough arguments. Some variable is probably not set" + exit 1 + fi + local this_nj=$nj + local fbank_config=conf/fbank_hires.conf + + while [ $# -gt 0 ]; do + if [ $[$# % 2] -ne 0 ]; then + echo "$0: make_fbank: Not enough arguments. Some variable is probably not set" + exit 1 + fi + case $1 in + --nj) + this_nj=$2 + shift; shift + ;; + --fbank-config) + fbank_config=$2 + shift; shift + ;; + *) + if [ $# -eq 2 ]; then + break; + else + echo "$0: make_fbank: Unknown arguments $*" + exit 1 + fi + ;; + esac + done + + if [ $# -ne 2 ]; then + echo "$0: make_fbank: Not enough arguments. 
Some variable is probably not set" + exit 1 + fi + + + local data_dir=$1 + local fbankdir=$2 + + rm -rf ${data_dir}_fbank + utils/copy_data_dir.sh ${data_dir} ${data_dir}_fbank + [ -f $data_dir/reco2file_and_channel ] && cp $data_dir/reco2file_and_channel ${data_dir}_fbank + + data_dir=${data_dir}_fbank + steps/make_fbank.sh --nj $this_nj --cmd "$train_cmd" \ + --fbank-config $fbank_config \ + ${data_dir} exp/make_fbank/${data_dir} $fbankdir || exit 1; + steps/compute_cmvn_stats.sh --fake \ + ${data_dir} exp/make_fbank/${data_dir} $fbankdir || exit 1; + utils/fix_data_dir.sh ${data_dir} + utils/validate_data_dir.sh --no-text ${data_dir} +} + +if [[ "$data_id" =~ "test_aspire" ]]; then + out_file=single_dev_test${affix}_$model_affix.ctm + if [ $stage -le 0 ]; then + make_mfcc --nj $nj --mfcc-config $sad_mfcc_config $data_dir $mfccdir + make_fbank --nj $nj --fbank-config $sad_fbank_config $data_dir $fbankdir + fi +elif [[ "$data_id" =~ "eval_aspire" ]]; then + out_file=single_eval${affix}_$model_affix.ctm + if [ $stage -le 0 ]; then + make_mfcc --nj $nj --mfcc-config $sad_mfcc_config $data_dir $mfcc_dir + make_fbank --nj $nj --fbank-config $sad_fbank_config $data_dir $fbankdir + fi +else + if $create_whole_dir; then + if [ $stage -le 0 ]; then + echo "Creating the data dir with whole recordings without segmentation" + # create a whole directory without the segments for the + # purposes of recreating the eval setting on dev set + whole_dir=data/${data_id}_whole # unsegmented_dir + mkdir -p $whole_dir + cp $data_dir/wav.scp $whole_dir # same as before + cat $whole_dir/wav.scp | \ + awk '{print $1, $1, "A";}' > $whole_dir/reco2file_and_channel + + cat $whole_dir/wav.scp | awk '{print $1, $1;}' > $whole_dir/utt2spk + utils/utt2spk_to_spk2utt.pl $whole_dir/utt2spk > $whole_dir/spk2utt + + make_mfcc --nj $nj --mfcc-config $sad_mfcc_config $whole_dir $mfccdir + + make_fbank --nj $nj --fbank-config $sad_fbank_config $whole_dir $fbankdir + fi + data_id=${data_id}_whole + fi + 
out_file=single_dev${affix}_${model_affix}.ctm +fi + +if [ $stage -le 1 ]; then + # Compute sub-band SNR + local/snr/compute_frame_snrs.sh --cmd "$train_cmd" \ + --use-gpu no --nj $nj --iter $snr_predictor_iter \ + $snr_predictor \ + data/${data_id}_hires data/${data_id}_fbank \ + $frame_snrs_dir || exit 1 +fi + +compute_sad_opts=(--iter $sad_model_iter) + +if [ ! -z "$input_frame_snrs_dir" ] && [ $stage -ge 2 ]; then + frame_snrs_dir=$input_frame_snrs_dir +fi + +if [ $stage -le 2 ]; then + local/snr/create_snr_data_dir.sh --cmd "$train_cmd" --nj $nj \ + --type $feature_type --append-to-orig-feats $append_to_orig_feats \ + --add-frame-snr $add_frame_snr \ + data/${data_id}_fbank $frame_snrs_dir exp/make_snr_data_dir/${data_id} snr_feats $frame_snrs_dir/${data_id}_snr || exit 1 +fi + +if [ $stage -le 3 ]; then + local/snr/compute_sad.sh \ + --nj $nj --use-gpu no "${compute_sad_opts[@]}" \ + --snr-data-dir $frame_snrs_dir/${data_id}_snr \ + $sad_model_dir $frame_snrs_dir ${vad_dir} || exit 1 +fi + +segmented_data_dir=data/${data_id}_seg${vad_affix} +segmented_data_id=`basename $segmented_data_dir` + +if [ ! -z "$input_vad_dir" ] && [ $stage -ge 3 ]; then + vad_dir=$input_vad_dir +fi + +if [ $stage -le 4 ]; then + local/snr/sad_to_segments.sh --cmd "$train_cmd" \ + --method $segmentation_method --stage $segmentation_stage ${segmentation_config:+--config $segmentation_config} \ + data/${data_id}_hires ${vad_dir} $segmentation_dir $segmented_data_dir +fi + +[ -f $data_dir/reco2file_and_channel ] && cp $data_dir/reco2file_and_channel ${segmented_data_dir} + +if [ $stage -le 5 ]; then + make_mfcc --nj $nj --mfcc-config $mfcc_config \ + $segmented_data_dir $mfccdir +fi + +if $use_ivectors; then + if [ ! 
-z "$weights_file" ]; then + echo "$0: Using provided weights file $weights_file" + ivector_extractor_input=$weights_file + else + mkdir -p $ivector_dir/ivector_weights_${segmented_data_id}${ivector_affix} + + if [ $stage -le 6 ]; then + local/snr/get_weights_for_ivector_extraction.sh --cmd queue.pl \ + --method $weights_method ${segmentation_config:+--config $weights_segmentation_config} \ + --silence-weight $silence_weight \ + ${segmented_data_dir} ${vad_dir} \ + $ivector_dir/ivector_weights_${segmented_data_id}${ivector_affix} + fi + ivector_extractor_input=$ivector_dir/ivector_weights_${segmented_data_id}${ivector_affix}/weights.gz + fi +fi + +if $use_ivectors && [ $stage -le 9 ]; then + echo "Extracting i-vectors, with weights from $ivector_extractor_input" + # this does offline decoding, except we estimate the iVectors per + # speaker, excluding silence (based on alignments from a GMM decoding), with a + # different script. This is just to demonstrate that script. + # the --sub-speaker-frames is optional; if provided, it will divide each speaker + # up into "sub-speakers" of at least that many frames... can be useful if + # acoustic conditions drift over time within the speaker's data. 
+ steps/online/nnet2/extract_ivectors.sh --cmd "$train_cmd" --nj 20 \ + --silence-weight $silence_weight \ + --sub-speaker-frames $sub_speaker_frames --max-count $max_count \ + data/${segmented_data_id}_hires $lang $ivector_dir/extractor \ + $ivector_extractor_input $ivector_dir/ivectors_${segmented_data_id}${ivector_affix} || exit 1; +fi + +decode_dir=$dir/decode_${segmented_data_id}${affix}_pp +if [ $stage -le 10 ]; then + echo "Generating lattices, with --acwt $acwt" + + ivector_opts=(--online-ivector-dir "") + if $use_ivectors; then + ivector_opts=(--online-ivector-dir $ivector_dir/ivectors_${segmented_data_id}${ivector_affix}) + fi + + local/multi_condition/decode.sh --nj $decode_nj --cmd "$decode_cmd" \ + --config conf/decode.config \ + --skip-scoring true --iter $iter --acwt $acwt --lattice-beam $lattice_beam \ + "${ivector_opts[@]}" \ + $graph_dir data/${segmented_data_id}_hires ${decode_dir}_tg || \ + { echo "$0: Error decoding"; exit 1; } +fi + +if [ $stage -le 11 ]; then + echo "Rescoring lattices" + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + --skip-scoring true \ + ${lang}_pp_test{,_fg} data/${segmented_data_id}_hires \ + ${decode_dir}_{tg,fg} || exit 1; +fi + +# tune the LMWT and WIP +# make command for filtering the ctms +decode_dir=${decode_dir}_fg +if [ -z $iter ]; then + model=$decode_dir/../final.mdl # assume model one level up from decoding dir. 
+else + model=$decode_dir/../$iter.mdl +fi + +mkdir -p $decode_dir/scoring +# create a python script to filter the ctm, for labels which are mapped +# to null strings in the glm or which are not accepted by the scoring server +python -c " +import sys, re +lines = map(lambda x: x.strip(), open('data/${act_data_id}/glm').readlines()) +patterns = [] +for line in lines: + if re.search('=>', line) is not None: + parts = re.split('=>', line.split('/')[0]) + if parts[1].strip() == '': + patterns.append(parts[0].strip()) +print '|'.join(patterns) +" > $decode_dir/scoring/glm_ignore_patterns || exit 1; + +ignore_patterns=$(cat $decode_dir/scoring/glm_ignore_patterns) +echo "$0: Ignoring these patterns from the ctm ", $ignore_patterns +cat << EOF > $decode_dir/scoring/filter_ctm.py +import sys +file = open(sys.argv[1]) +out_file = open(sys.argv[2], 'w') +ignore_set = "$ignore_patterns".split("|") +ignore_set.append("[noise]") +ignore_set.append("[laughter]") +ignore_set.append("[vocalized-noise]") +ignore_set.append("!SIL") +ignore_set.append("") +ignore_set.append("%hesitation") +ignore_set = set(ignore_set) +print ignore_set +for line in file: + if line.split()[4] not in ignore_set: + out_file.write(line) +out_file.close() +EOF + +filter_ctm_command="python $decode_dir/scoring/filter_ctm.py " + +if $tune_hyper ; then + if [ $stage -le 12 ]; then + if [[ "$act_data_id" =~ "dev_aspire" ]]; then + wip_string=$(echo $word_ins_penalties | sed 's/,/ /g') + temp_wips=($wip_string) + $decode_cmd WIP=1:${#temp_wips[@]} $decode_dir/scoring/log/score.wip.WIP.log \ + wips=\(0 $wip_string\) \&\& \ + wip=\${wips[WIP]} \&\& \ + echo \$wip \&\& \ + $decode_cmd LMWT=$min_lmwt:$max_lmwt \ + $decode_dir/scoring/log/score.LMWT.\$wip.log \ + local/multi_condition/get_ctm.sh \ + --filter-ctm-command "$filter_ctm_command" \ + --beam $ctm_beam --decode-mbr $decode_mbr \ + --glm data/${act_data_id}/glm --stm data/${act_data_id}/stm \ + LMWT \$wip $lang data/${segmented_data_id}_hires \ + $model 
$decode_dir || exit 1; + + #local/multi_condition/get_ctm_conf.sh --cmd "$decode_cmd" \ + # --use-segments true \ + # data/${segmented_data_id}_hires \ + # ${lang} ${decode_dir} || exit 1; + + eval "grep Sum $decode_dir/score_{${min_lmwt}..${max_lmwt}}/penalty_{$word_ins_penalties}/*.sys"|utils/best_wer.sh 2>/dev/null + eval "grep Sum $decode_dir/score_{${min_lmwt}..${max_lmwt}}/penalty_{$word_ins_penalties}/*.sys" | \ + utils/best_wer.sh 2>/dev/null | python -c "import sys, re +line = sys.stdin.readline() +file_name=line.split()[-1] +parts=file_name.split('/') +penalty = re.sub('penalty_','',parts[-2]) +lmwt = re.sub('score_','', parts[-3]) +lmfile=open('$decode_dir/scoring/bestLMWT','w') +lmfile.write(str(lmwt)) +lmfile.close() +wipfile=open('$decode_dir/scoring/bestWIP','w') +wipfile.write(str(penalty)) +wipfile.close() +" || exit 1; + LMWT=$(cat $decode_dir/scoring/bestLMWT) + word_ins_penalty=$(cat $decode_dir/scoring/bestWIP) + fi + fi + if [[ "$act_data_id" =~ "test_aspire" ]] || [[ "$act_data_id" =~ "eval_aspire" ]]; then + dev_decode_dir=$(echo $decode_dir|sed "s/test_aspire/dev_aspire_whole/g; s/eval_aspire/dev_aspire_whole/g") + if [ -f $dev_decode_dir/scoring/bestLMWT ]; then + LMWT=$(cat $dev_decode_dir/scoring/bestLMWT) + echo "Using the bestLMWT $LMWT value found in $dev_decode_dir" + else + echo "Unable to find the bestLMWT in the dev decode dir $dev_decode_dir" + echo "Keeping the default/user-specified value" + fi + if [ -f $dev_decode_dir/scoring/bestWIP ]; then + word_ins_penalty=$(cat $dev_decode_dir/scoring/bestWIP) + echo "Using the bestWIP $word_ins_penalty value found in $dev_decode_dir" + else + echo "Unable to find the bestWIP in the dev decode dir $dev_decode_dir" + echo "Keeping the default/user-specified value" + fi + else + echo "Using the default/user-specified values for LMWT and word_ins_penalty" + fi +fi + +# lattice to ctm conversion and scoring. 
+if [ $stage -le 13 ]; then + echo "Generating CTMs with LMWT $LMWT and word insertion penalty of $word_ins_penalty" + local/multi_condition/get_ctm.sh --filter-ctm-command "$filter_ctm_command" \ + --beam $ctm_beam --decode-mbr $decode_mbr \ + $LMWT $word_ins_penalty $lang data/${segmented_data_id}_hires $model $decode_dir 2>$decode_dir/scoring/finalctm.LMWT$LMWT.WIP$word_ins_penalty.log || exit 1; +fi + +if [ $stage -le 14 ]; then + cat $decode_dir/score_$LMWT/penalty_$word_ins_penalty/ctm.filt | awk '{split($1, parts, "-"); printf("%s 1 %s %s %s\n", parts[1], $3, $4, $5)}' > $out_file + cat ${segmented_data_dir}_hires/wav.scp | awk '{split($1, parts, "-"); printf("%s\n", parts[1])}' > $decode_dir/score_$LMWT/penalty_$word_ins_penalty/recording_names + python local/multi_condition/fill_missing_recordings.py $out_file $out_file.submission $decode_dir/score_$LMWT/penalty_$word_ins_penalty/recording_names + echo "Generated the ctm @ $out_file.submission from the ctm file $decode_dir/score_${LMWT}/penalty_$word_ins_penalty/ctm.filt" +fi + diff --git a/egs/aspire/s5/local/multi_condition/resolve_ctm_overlaps.py b/egs/aspire/s5/local/multi_condition/resolve_ctm_overlaps.py index 06f50c42155..5b9a69ae6ac 100755 --- a/egs/aspire/s5/local/multi_condition/resolve_ctm_overlaps.py +++ b/egs/aspire/s5/local/multi_condition/resolve_ctm_overlaps.py @@ -27,7 +27,41 @@ def resolve_overlaps(ctms, window_length, overlap): index = i break total_ctm += cur_ctm[:index] - + + index = 0 + for i in xrange(len(next_ctm)): + if next_ctm[i][2] + next_ctm[i][3]/2.0 > (overlap/2.0): + index = i + break + ctms[ctm_index + 1] = next_ctm[index:] + # merge the last ctm entirely + total_ctm +=ctms[-1] + return total_ctm + +def resolve_overlaps_segments(ctms, segments): + total_ctm = [] + if len(ctms) == 0: + raise Exception('Something wrong with the input ctms') + for ctm_index in range(len(ctms) - 1): + cur_ctm = ctms[ctm_index] + next_ctm = ctms[ctm_index + 1] + # find the breaks after overlap 
starts + index = len(cur_ctm) + + if (cur_ctm[0][0] not in segments): + raise Exception('Could not find utterance %s in segments' % cur_ctm[0][0]) + if (next_ctm[0][0] not in segments): + raise Exception('Could not find utterance %s in segments' % next_ctm[0][0]) + + window_length = segments[cur_ctm[0][0]][2] - segments[cur_ctm[0][0]][1] + overlap = segments[cur_ctm[0][0]][2] - segments[next_ctm[0][0]][1] + + for i in xrange(len(cur_ctm)): + if cur_ctm[i][2] + cur_ctm[i][3]/2.0 > (window_length - overlap/2.0): + index = i + break + total_ctm += cur_ctm[:index] + index = 0 for i in xrange(len(next_ctm)): if next_ctm[i][2] + next_ctm[i][3]/2.0 > (overlap/2.0): @@ -37,6 +71,7 @@ def resolve_overlaps(ctms, window_length, overlap): # merge the last ctm entirely total_ctm +=ctms[-1] return total_ctm + def read_ctm(ctm_file_lines, utt2spk): ctms = {} for key in utt2spk.values(): @@ -56,7 +91,7 @@ def read_ctm(ctm_file_lines, utt2spk): ctm.append([parts[0], parts[1], float(parts[2]), float(parts[3]), parts[4], parts[5]]) # append the last ctm - ctms[utt2spk[ctm[0][0]]].append(ctm) + ctms[utt2spk[ctm[0][0]]].append(ctm) return ctms def write_ctm(ctm_lines): @@ -66,16 +101,17 @@ def write_ctm(ctm_lines): return ctm_file_lines if __name__ == "__main__": - usage = """ Python script to resolve overlaps in uniformly segmented ctms """ + usage = """ Python script to resolve overlaps in uniformly segmented ctms """ main_parser = argparse.ArgumentParser(usage) parser = argparse.ArgumentParser() parser.add_argument('--window-length', type = float, default = 30.0, help = 'length of the window used to cut the segment') parser.add_argument('--overlap', type = float, default = 5.0, help = 'overlap of neighboring windows') + parser.add_argument('--segments', type = str, help = 'use segments to resolve overlaps') parser.add_argument('utt2spk', type=str, help='spk2utt_file') parser.add_argument('ctm_in', type=str, help='input_ctm_file') parser.add_argument('ctm_out', type=str, 
help='output_ctm_file') params = parser.parse_args() - + if params.ctm_in == "-": params.ctm_in = sys.stdin else: @@ -90,11 +126,20 @@ def write_ctm(ctm_lines): parts = line.split() utt2spk[parts[0]] = parts[1] + segments = {} + if params.segments: + for line in open(params.segments).readlines(): + parts = line.strip().split() + segments[parts[0]] = [ parts[1] ] + [ float(x) for x in parts[2:] ] + ctms = read_ctm(params.ctm_in.readlines(), utt2spk) speakers = ctms.keys() speakers.sort() for key in speakers: ctm = ctms[key] - ctm = resolve_overlaps(ctm, params.window_length, params.overlap) + if params.segments: + ctm = resolve_overlaps_segments(ctm, segments) + else: + ctm = resolve_overlaps(ctm, params.window_length, params.overlap) params.ctm_out.write("\n".join(write_ctm(ctm))+"\n") params.ctm_out.close() diff --git a/egs/aspire/s5/local/multi_condition/reverberate_wavs.py b/egs/aspire/s5/local/multi_condition/reverberate_wavs.py index 998a3ed5e74..80153047fa3 100755 --- a/egs/aspire/s5/local/multi_condition/reverberate_wavs.py +++ b/egs/aspire/s5/local/multi_condition/reverberate_wavs.py @@ -33,8 +33,8 @@ def return_nonempty_lines(lines): parser.add_argument('output_wav_file_list', type=str, help='wav.scp file to write corrupted output') parser.add_argument('impulses_noises_dir', type=str, help='directory with impulses and noises and info directory (created by local/prep_rirs.sh)') parser.add_argument('output_command_file', type=str, help='file to output the corruption commands') - params = parser.parse_args() - + params = parser.parse_args() + add_noise = True snr_string_parts = params.snrs.split(':') if (len(snr_string_parts) == 1) and snr_string_parts[0] == "inf": diff --git a/egs/aspire/s5/local/multi_condition/rirs/prep_aalto.sh b/egs/aspire/s5/local/multi_condition/rirs/prep_aalto.sh index f8a45c3e790..6771b776e58 100755 --- a/egs/aspire/s5/local/multi_condition/rirs/prep_aalto.sh +++ b/egs/aspire/s5/local/multi_condition/rirs/prep_aalto.sh @@ -27,6 +27,9 
@@ RIR_home=$1 output_dir=$2 log_dir=$3 +mkdir -p $log_dir +mkdir -p $output_dir/info + if [ "$download" = true ]; then mkdir -p $RIR_home (cd $RIR_home; @@ -56,7 +59,7 @@ echo "">$command_file type_num=1 data_files=( $(find $RIR_home/aalto_concert_hall_pori/ -name '*.wav' -type f -print || exit -1) ) total_files=$(echo ${data_files[@]}|wc -w) -echo "" > $log_dir/${DBname}_type${type_num}.rir.list +echo "" > $output_dir/info/${DBname}_type${type_num}.rir.list echo "Found $total_files impulse responses in ${RIR_home}/aalto_concert_hall_pori//" tmpdir=`mktemp -d $log_dir/aalto_XXXXXX` tmpdir=`readlink -e $tmpdir` @@ -67,7 +70,7 @@ for data_file in ${data_files[@]}; do output_file_name=${DBname}_type${type_num}_`basename $data_file| tr '[:upper:]' '[:lower:]'` echo "sox -t wav $data_file -t wav -r $sampling_rate -e signed-integer -b $output_bit ${output_dir}/${output_file_name}" >> $command_file # echo "python local/multi_condition/read_rir.py --output-sampling-rate $sampling_rate wav ${tmpdir}/$file_count.wav ${output_dir}/${output_file_name} || exit -1;" >> $command_file - echo ${output_dir}/${output_file_name} >> $log_dir/${DBname}_type${type_num}.rir.list + echo ${output_dir}/${output_file_name} >> $output_dir/info/${DBname}_type${type_num}.rir.list file_count=$((file_count + 1)) done diff --git a/egs/aspire/s5/local/multi_condition/rirs/prep_air.sh b/egs/aspire/s5/local/multi_condition/rirs/prep_air.sh index c7b6300db50..acc3963daba 100755 --- a/egs/aspire/s5/local/multi_condition/rirs/prep_air.sh +++ b/egs/aspire/s5/local/multi_condition/rirs/prep_air.sh @@ -27,6 +27,9 @@ RIR_home=$1 output_dir=$2 log_dir=$3 +mkdir -p $log_dir +mkdir -p $output_dir/info + if [ "$download" = true ]; then mkdir -p $RIR_home (cd $RIR_home; @@ -44,7 +47,7 @@ type_num=1 python local/multi_condition/get_air_file_patterns.py $AIR_home > $log_dir/air_file_pattern files_done=0 total_files=$(cat $log_dir/air_file_pattern|wc -l) -echo "" > $log_dir/${DBname}_type$type_num.rir.list +echo "" 
> $output_dir/info/${DBname}_type${type_num}.rir.list echo "Found $total_files impulse responses in ${AIR_home}." command_file=$log_dir/${DBname}_read_rir_noise.sh @@ -54,7 +57,7 @@ while read file_pattern output_file_name; do # output_file_name=`echo ${DBname}_type${type_num}_${file_count}_$output_file_name| tr '[:upper:]' '[:lower:]'` output_file_name=`echo ${DBname}_type${type_num}_$output_file_name| tr '[:upper:]' '[:lower:]'` echo "local/multi_condition/read_rir.py --output-sampling-rate $sampling_rate air '${file_pattern}' ${output_dir}/${output_file_name} || exit 1;" >> $command_file - echo ${output_dir}/${output_file_name} >> $log_dir/${DBname}_type$type_num.rir.list + echo ${output_dir}/${output_file_name} >> $output_dir/info/${DBname}_type${type_num}.rir.list file_count=$((file_count + 1)) done < $log_dir/air_file_pattern diff --git a/egs/aspire/s5/local/multi_condition/rirs/prep_c4dm.sh b/egs/aspire/s5/local/multi_condition/rirs/prep_c4dm.sh index 8e5dd34d9ac..b3e4ed06a05 100755 --- a/egs/aspire/s5/local/multi_condition/rirs/prep_c4dm.sh +++ b/egs/aspire/s5/local/multi_condition/rirs/prep_c4dm.sh @@ -28,6 +28,9 @@ RIR_home=$1 output_dir=$2 log_dir=$3 +mkdir -p $log_dir +mkdir -p $output_dir/info + if [ "$download" = true ]; then mkdir -p $RIR_home (cd $RIR_home; @@ -85,7 +88,7 @@ echo "">$command_file type_num=1 data_files=( $(find $RIR_home/c4dm/*/*/ -name '*.wav' -type f -print || exit -1) ) total_files=$(echo ${data_files[@]}|wc -w) -echo "" > $log_dir/${DBname}_type${type_num}.rir.list +echo "" > $output_dir/info/${DBname}_type${type_num}.rir.list echo "Found $total_files impulse responses in ${RIR_home}/c4dm/" tmpdir=`mktemp -d $log_dir/c4dm_XXXXXX` tmpdir=`readlink -e $tmpdir` @@ -96,7 +99,7 @@ for data_file in ${data_files[@]}; do output_file_name=${DBname}_type${type_num}_`basename $data_file| tr '[:upper:]' '[:lower:]'` echo "sox -t wav $data_file -t wav -r $sampling_rate -e signed-integer -b $output_bit ${output_dir}/${output_file_name}" >> 
$command_file #echo "python local/multi_condition/read_rir.py --output-sampling-rate $sampling_rate wav ${tmpdir}/${file_count}.wav ${output_dir}/${output_file_name} || exit -1;" >> $command_file - echo ${output_dir}/${output_file_name} >> $log_dir/${DBname}_type${type_num}.rir.list + echo ${output_dir}/${output_file_name} >> $output_dir/info/${DBname}_type${type_num}.rir.list file_count=$((file_count + 1)) done diff --git a/egs/aspire/s5/local/multi_condition/rirs/prep_mardy.sh b/egs/aspire/s5/local/multi_condition/rirs/prep_mardy.sh index 4690b9b1861..b0d79235c8d 100755 --- a/egs/aspire/s5/local/multi_condition/rirs/prep_mardy.sh +++ b/egs/aspire/s5/local/multi_condition/rirs/prep_mardy.sh @@ -27,6 +27,9 @@ RIR_home=$1 output_dir=$2 log_dir=$3 +mkdir -p $log_dir +mkdir -p $output_dir/info + if [ "$download" = true ]; then mkdir -p $RIR_home (cd $RIR_home; @@ -44,7 +47,7 @@ echo "">$command_file type_num=1 data_files=( $(find $RIR_home/mardy/ -name '*.wav' -type f -print || exit -1) ) total_files=$(echo ${data_files[@]}|wc -w) -echo "" > $log_dir/${DBname}_type${type_num}.rir.list +echo "" > $output_dir/info/${DBname}_type${type_num}.rir.list echo "Found $total_files impulse responses in ${RIR_home}/mardy/" file_count=1 for data_file in ${data_files[@]}; do @@ -52,7 +55,7 @@ for data_file in ${data_files[@]}; do output_file_name=${DBname}_type${type_num}_`basename $data_file| tr '[:upper:]' '[:lower:]'` echo "sox -t wav $data_file -t wav -r $sampling_rate -e signed-integer -b $output_bit ${output_dir}/${output_file_name}" >> $command_file #echo "python local/multi_condition/read_rir.py --output-sampling-rate $sampling_rate wav ${data_file} ${output_dir}/${output_file_name} || exit -1;" >> $command_file - echo ${output_dir}/${output_file_name} >> $log_dir/${DBname}_type${type_num}.rir.list + echo ${output_dir}/${output_file_name} >> $output_dir/info/${DBname}_type${type_num}.rir.list file_count=$((file_count + 1)) done diff --git 
a/egs/aspire/s5/local/multi_condition/rirs/prep_musan.sh b/egs/aspire/s5/local/multi_condition/rirs/prep_musan.sh new file mode 100755 index 00000000000..f19aada1aa8 --- /dev/null +++ b/egs/aspire/s5/local/multi_condition/rirs/prep_musan.sh @@ -0,0 +1,117 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0 +# This script downloads the MUSAN music and noise database +#----------------------------------------- +# all data is wav at 16kHZ + +download=true +sampling_rate=8k +output_bit=16 +DBname=MUSAN +file_splitter= #script to generate job scripts given the command file + +. cmd.sh +. path.sh +. ./utils/parse_options.sh + +if [ $# != 3 ]; then + echo "Usage: " + echo " $0 [options] " + echo "e.g.:" + echo " $0 --download true db/RIR_databases/ data/impulses_noises exp/make_reverb/log" + exit 1; +fi + +RIR_home=$1 +output_dir=$2 +log_dir=$3 + +mkdir -p $log_dir +mkdir -p $output_dir/info + +if [ "$download" == true ]; then + mkdir -p $RIR_home + # MUSAN sound scene database + #========================== + (cd $RIR_home; + rm -rf musan.tar.gz + wget http://www.openslr.org/resources/17/musan.tar.gz || exit 1; + tar -zxvf musan.tar.gz >/dev/null + ) +fi + +MUSAN_home=$RIR_home/musan + +[ ! 
-d $MUSAN_home/noise ] && echo "$0: noise files not downloaded" && exit 1 + +find $MUSAN_home/noise -name "*.wav" -type f > $log_dir/${DBname}_noise.list +for x in `cat $log_dir/${DBname}_noise.list`; do + y=`basename $x` + z=${y%*.wav} + echo "$z $x" +done | sort -k 1,1 > $log_dir/${DBname}_noise.scp + +# Read the list of background noises +if [ `cat $MUSAN_home/noise/free-sound/ANNOTATIONS | wc -l` -ne 38 ]; then + echo "$0: expected 38 lines in file; probably the corpus got updated and hence this script must be modified" + exit 1 +fi + +cat $MUSAN_home/noise/free-sound/ANNOTATIONS | tail -n +2 | sort -u > $log_dir/${DBname}_background_noise.fileids + +command_file=$log_dir/${DBname}_read_rir_noise.sh + +echo -n "">$command_file + +# Ambient noise +base_dir_name=$MUSAN_home/noise/free-sound/ + +echo -n "" > $output_dir/info/${DBname}.background.noise.list +background_noise_files_done=0 +for x in `cat $log_dir/${DBname}_background_noise.fileids`; do + output_filename=$x.wav + [ ! -f $base_dir_name/$x.wav ] && echo "$0: could not find file $x.wav in $base_dir_name" && exit 1 + echo "sox $base_dir_name/$x.wav -r $sampling_rate -e signed-integer -b $output_bit ${output_dir}/${output_filename}" >> $command_file + echo ${output_dir}/${output_filename} >> $output_dir/info/${DBname}.background.noise.list + background_noise_files_done=$((background_noise_files_done + 1)) +done + +echo -n "" > $output_dir/info/${DBname}.foreground.noise.list +foreground_noise_files_done=0 +utils/filter_scp.pl --exclude $log_dir/${DBname}_background_noise.fileids \ + $log_dir/${DBname}_noise.scp | \ +while IFS=$'\n' read x; do + file_id=`echo $x | awk '{print $1}'` + file=`echo $x | awk '{print $2}'` + output_filename=$file_id.wav + + [ ! 
-f $file ] && echo "$0: could not find file $file" && exit 1 + echo "sox $file -r $sampling_rate -e signed-integer -b $output_bit ${output_dir}/${output_filename}" >> $command_file + echo ${output_dir}/${output_filename} >> $output_dir/info/${DBname}.foreground.noise.list + foreground_noise_files_done=$((foreground_noise_files_done + 1)) +done + +background_noise_files_done=`cat $output_dir/info/${DBname}.background.noise.list | wc -l` +foreground_noise_files_done=`cat $output_dir/info/${DBname}.foreground.noise.list | wc -l` + +echo "$0: read $foreground_noise_files_done foreground noise and $background_noise_files_done background noise files" + +if [ "$foreground_noise_files_done" -eq 0 ] || [ $background_noise_files_done -eq 0 ]; then + echo "$0: failed reading noise files from ${DBname} corpus" + exit 1 +fi + +if [ ! -z "$file_splitter" ]; then + num_jobs=$($file_splitter $command_file || exit 1) + job_file=${command_file%.sh}.JOB.sh + job_log=${command_file%.sh}.JOB.log +else + num_jobs=1 + job_file=$command_file + job_log=${command_file%.sh}.log +fi + +# execute the commands using the above created array jobs +time $decode_cmd --max-jobs-run 40 JOB=1:$num_jobs $job_log \ + sh $job_file || exit 1; diff --git a/egs/aspire/s5/local/multi_condition/rirs/prep_openair.sh b/egs/aspire/s5/local/multi_condition/rirs/prep_openair.sh index bd43da77079..349e12faf49 100755 --- a/egs/aspire/s5/local/multi_condition/rirs/prep_openair.sh +++ b/egs/aspire/s5/local/multi_condition/rirs/prep_openair.sh @@ -27,6 +27,9 @@ RIR_home=$1 output_dir=$2 log_dir=$3 +mkdir -p $log_dir +mkdir -p $output_dir/info + if [ "$download" = true ]; then mkdir -p $RIR_home RIR_home_abs=`readlink -e $RIR_home` @@ -428,7 +431,7 @@ echo "">$command_file type_num=1 data_files=( $(find $RIR_home/open_air/ -name '*.wav' -type f -print || exit -1) ) total_files=$(echo ${data_files[@]}|wc -w) -echo "" > $log_dir/${DBname}_type${type_num}.rir.list +echo "" > 
$output_dir/info/${DBname}_type${type_num}.rir.list echo "Found $total_files impulse responses in ${RIR_home}/open_air/" file_count=1 # affix to ensure that files with same name are not overwritten for data_file in ${data_files[@]}; do @@ -436,7 +439,7 @@ for data_file in ${data_files[@]}; do # output_file_name=${DBname}_type${type_num}_${file_count}_`basename $data_file| tr '[:upper:]' '[:lower:]'` output_file_name=${DBname}_type${type_num}_`basename $data_file| tr '[:upper:]' '[:lower:]'` echo "sox -t wav $data_file -t wav -r $sampling_rate -e signed-integer -b $output_bit ${output_dir}/${output_file_name}" >> $command_file - echo ${output_dir}/${output_file_name} >> $log_dir/${DBname}_type${type_num}.rir.list + echo ${output_dir}/${output_file_name} >> $output_dir/info/${DBname}_type${type_num}.rir.list file_count=$((file_count + 1)) done diff --git a/egs/aspire/s5/local/multi_condition/rirs/prep_rvb2014.sh b/egs/aspire/s5/local/multi_condition/rirs/prep_rvb2014.sh index 32394556f01..f6a717b5608 100755 --- a/egs/aspire/s5/local/multi_condition/rirs/prep_rvb2014.sh +++ b/egs/aspire/s5/local/multi_condition/rirs/prep_rvb2014.sh @@ -1,5 +1,6 @@ #!/bin/bash # Copyright 2015 Johns Hopkins University (author: Vijayaditya Peddinti) +# 2015 Vimal Manohar # Apache 2.0 # This script downloads the impulse responses and noise files from the # Reverb2014 challenge @@ -28,6 +29,9 @@ RIR_home=$1 output_dir=$2 log_dir=$3 +mkdir -p $log_dir +mkdir -p $output_dir/info + if [ "$download" = true ]; then mkdir -p $RIR_home (cd $RIR_home; @@ -54,24 +58,24 @@ type_num=1 data_files=( $(find $Reverb2014_home1/RIR -name '*.wav' -type f -print || exit -1) ) files_done=0 total_files=$(echo ${data_files[@]}|wc -w) -echo "" > $log_dir/${DBname}_type${type_num}.rir.list +echo "" > $output_dir/info/${DBname}_type${type_num}.rir.list echo "Found $total_files impulse responses in ${Reverb2014_home1}/RIR." 
for data_file in ${data_files[@]}; do output_file_name=${DBname}_type${type_num}_`basename $data_file | tr '[:upper:]' '[:lower:]'` echo "sox -t wav $data_file -t wav -r $sampling_rate -e signed-integer -b $output_bit ${output_dir}/${output_file_name}" >> $command_file - echo ${output_dir}/${output_file_name} >> $log_dir/${DBname}_type${type_num}.rir.list + echo ${output_dir}/${output_file_name} >> $output_dir/info/${DBname}_type${type_num}.rir.list files_done=$((files_done + 1)) done data_files=( $(find $Reverb2014_home1/NOISE -name '*.wav' -type f -print || exit -1) ) files_done=0 total_files=$(echo ${data_files[@]}|wc -w) -echo "" > $log_dir/${DBname}_type${type_num}.noise.list +echo "" > $output_dir/info/${DBname}_type${type_num}.noise.list echo "Found $total_files noises in ${Reverb2014_home1}/NOISE." for data_file in ${data_files[@]}; do output_file_name=${DBname}_type${type_num}_`basename $data_file| tr '[:upper:]' '[:lower:]'` echo "sox -t wav $data_file -t wav -r $sampling_rate -e signed-integer -b $output_bit ${output_dir}/${output_file_name}" >> $command_file - echo ${output_dir}/${output_file_name} >> $log_dir/${DBname}_type${type_num}.noise.list + echo ${output_dir}/${output_file_name} >> $output_dir/info/${DBname}_type${type_num}.noise.list files_done=$((files_done + 1)) done @@ -80,12 +84,12 @@ type_num=$((type_num + 1)) data_files=( $(find $Reverb2014_home2/RIR -name '*.wav' -type f -print || exit -1) ) files_done=0 total_files=$(echo ${data_files[@]}|wc -w) -echo "" > $log_dir/${DBname}_type${type_num}.rir.list +echo "" > $output_dir/info/${DBname}_type${type_num}.rir.list echo "Found $total_files impulse responses in ${Reverb2014_home2}/RIR." 
for data_file in ${data_files[@]}; do output_file_name=${DBname}_type${type_num}_`basename $data_file| tr '[:upper:]' '[:lower:]'` echo "sox -t wav $data_file -t wav -r $sampling_rate -e signed-integer -b $output_bit ${output_dir}/${output_file_name}" >> $command_file - echo ${output_dir}/${output_file_name} >> $log_dir/${DBname}_type${type_num}.rir.list + echo ${output_dir}/${output_file_name} >> $output_dir/info/${DBname}_type${type_num}.rir.list files_done=$((files_done + 1)) done @@ -93,12 +97,12 @@ done data_files=( $(find $Reverb2014_home2/NOISE -name '*.wav' -type f -print || exit -1) ) files_done=0 total_files=$(echo ${data_files[@]}|wc -w) -echo "" > $log_dir/${DBname}_type${type_num}.noise.list +echo "" > $output_dir/info/${DBname}_type${type_num}.noise.list echo "Found $total_files noises in ${Reverb2014_home2}/NOISE." for data_file in ${data_files[@]}; do output_file_name=${DBname}_type${type_num}_`basename $data_file | tr '[:upper:]' '[:lower:]'` echo "sox -t wav $data_file -t wav -r $sampling_rate -e signed-integer -b $output_bit ${output_dir}/${output_file_name}" >> $command_file - echo ${output_dir}/${output_file_name} >> $log_dir/${DBname}_type${type_num}.noise.list + echo ${output_dir}/${output_file_name} >> $output_dir/info/${DBname}_type${type_num}.noise.list files_done=$((files_done + 1)) done diff --git a/egs/aspire/s5/local/multi_condition/rirs/prep_rwcp.sh b/egs/aspire/s5/local/multi_condition/rirs/prep_rwcp.sh index b44669b86f1..901dada0c9c 100755 --- a/egs/aspire/s5/local/multi_condition/rirs/prep_rwcp.sh +++ b/egs/aspire/s5/local/multi_condition/rirs/prep_rwcp.sh @@ -34,6 +34,9 @@ RIR_home=$1 output_dir=$2 log_dir=$3 +mkdir -p $log_dir +mkdir -p $output_dir/info + if [ "$download" = true ]; then mkdir -p $RIR_home # RWCP sound scene database @@ -62,20 +65,20 @@ for base_dir_name in ${RWCP_dirs[@]}; do files_done=0 total_files=$(echo ${leaf_directories[@]}|wc -w) echo "Found ${total_files} impulse responses in ${base_dir_name}." 
- echo "" > $log_dir/RWCP_type$type_num.rir.list + echo "" > $output_dir/info/${DBname}_type${type_num}.rir.list # create the list of commands to be executed for leaf_dir_name in ${leaf_directories[@]}; do first_channel=$(ls $leaf_dir_name|sed -e"s/.*\.//g"|sort -n|head -1) last_channel=$(ls $leaf_dir_name|sed -e"s/.*\.//g"|sort -nr|head -1) file_base_name=$(basename $leaf_dir_name) output_file_name=`echo ${leaf_dir_name#$base_dir_name}| sed -e"s/[\/\]\+/_/g" | tr '[:upper:]' '[:lower:]'` - output_file_name=RWCP_type${type_num}_rir_${output_file_name}.wav + output_file_name=${DBname}_type${type_num}_rir_${output_file_name}.wav channel_files= for i in `seq $first_channel $last_channel`; do channel_files="$channel_files -t raw -e float -b 32 -c 1 -r 48k $leaf_dir_name/$file_base_name.$i "; done echo "sox -M $channel_files -r $sampling_rate -e signed-integer -b $output_bit ${output_dir}/${output_file_name}" >> $command_file - echo ${output_dir}/${output_file_name} >> $log_dir/RWCP_type$type_num.rir.list + echo ${output_dir}/${output_file_name} >> $output_dir/info/${DBname}_type${type_num}.rir.list files_done=$((files_done + 1)) done done @@ -99,15 +102,15 @@ type_num=$((type_num + 1)) data_files=( $(find $RWCP_home/robot/data/non -name '*.dat' -type f -print || exit -1) ) files_done=0 total_files=$(echo ${data_files[@]}|wc -w) -echo "" > $log_dir/RWCP_type$type_num.rir.list +echo "" > $output_dir/info/${DBname}_type${type_num}.rir.list echo "Found $total_files impulse responses in ${RWCP_home}/robot/data/non." 
# create the list of commands to be executed for data_file in ${data_files[@]}; do temp_file=$tempdir_robo/$files_done.wav python $tempdir_robo/raw_read.py $data_file $temp_file - output_file_name=RWCP_type${type_num}_rir_`basename $data_file .dat | tr '[:upper:]' '[:lower:]'`.wav + output_file_name=${DBname}_type${type_num}_rir_`basename $data_file .dat | tr '[:upper:]' '[:lower:]'`.wav echo "sox -t wav $temp_file -r $sampling_rate -e signed-integer -b $output_bit ${output_dir}/${output_file_name}" >> $command_file - echo ${output_dir}/${output_file_name} >> $log_dir/RWCP_type$type_num.rir.list + echo ${output_dir}/${output_file_name} >> $output_dir/info/${DBname}_type${type_num}.rir.list files_done=$((files_done + 1)) done @@ -117,21 +120,21 @@ base_dir_name=$RWCP_home/micarray/MICARRAY/data6/ leaf_directories=( $(find $base_dir_name -type d -links 2 -print || exit -1) ) files_done=0 total_files=$(echo ${leaf_directories[@]}|wc -w) -echo "" > $log_dir/RWCP_type$type_num.noise.list +echo "" > $output_dir/info/${DBname}_type${type_num}.noise.list echo "Found $total_files noises in ${base_dir_name}." 
for leaf_dir_name in ${leaf_directories[@]}; do first_channel=$(ls $leaf_dir_name|sed -e"s/.*\.//g"|sort -n|head -1) last_channel=$(ls $leaf_dir_name|sed -e"s/.*\.//g"|sort -nr|head -1) file_base_name=$(basename $leaf_dir_name) output_file_name=`echo ${leaf_dir_name#$base_dir_name}| sed -e"s/[\/\]\+/_/g" | tr '[:upper:]' '[:lower:]'` - output_file_name=RWCP_type${type_num}_noise_${output_file_name}.wav + output_file_name=${DBname}_type${type_num}_noise_${output_file_name}.wav channel_files= for i in `seq $first_channel $last_channel`; do channel_files="$channel_files -t raw -e signed-integer -b 16 -c 1 -r 48k $leaf_dir_name/$file_base_name.$i "; done echo "sox -M $channel_files -r $sampling_rate -e signed-integer -b $output_bit ${output_dir}/${output_file_name}" >> $command_file - echo ${output_dir}/${output_file_name} >> $log_dir/RWCP_type$type_num.noise.list + echo ${output_dir}/${output_file_name} >> $output_dir/info/${DBname}_type${type_num}.noise.list files_done=$((files_done + 1)) done @@ -159,10 +162,10 @@ time $decode_cmd --max-jobs-run 40 JOB=1:$num_jobs $job_log \ # get the RWCP database noise mic and room settings to pair with corresponding impulse responses type_num=5 -noise_patterns=( $(ls ${output_dir}/RWCP_type${type_num}_noise*.wav | xargs -n1 basename | python -c" +noise_patterns=( $(ls ${output_dir}/${DBname}_type${type_num}_noise*.wav | xargs -n1 basename | python -c" import sys for line in sys.stdin: - name = line.split('RWCP_type${type_num}_noise')[1] + name = line.split('${DBname}_type${type_num}_noise')[1] print '_'.join(name.split('_')[1:-1]) "|sort -u) ) diff --git a/egs/aspire/s5/local/multi_condition/rirs/prep_varechoic.sh b/egs/aspire/s5/local/multi_condition/rirs/prep_varechoic.sh index 4be2b1779f3..0d2b825000c 100755 --- a/egs/aspire/s5/local/multi_condition/rirs/prep_varechoic.sh +++ b/egs/aspire/s5/local/multi_condition/rirs/prep_varechoic.sh @@ -27,6 +27,9 @@ RIR_home=$1 output_dir=$2 log_dir=$3 +mkdir -p $log_dir +mkdir -p 
$output_dir/info + if [ "$download" = true ]; then mkdir -p $RIR_home (cd $RIR_home; @@ -43,13 +46,13 @@ fi command_file=$log_dir/${DBname}_read_rir_noise.sh echo "">$command_file type_num=1 -echo "" > $log_dir/${DBname}_type$type_num.rir.list +echo "" > $output_dir/info/${DBname}_type${type_num}.rir.list varechoic_home=$RIR_home/icsi_varechoic/varechoic for room_type in ir00 ir43 ir100 ; do for mike in m1 m2 m3 m4; do file_basename=${room_type}${mike} echo "sox -B -e float -b 32 -c 1 -r 8k -t raw $varechoic_home/${file_basename}.raw -t wav -b $output_bit $output_dir/${DBname}_${file_basename}.wav" >> $command_file - echo $output_dir/${DBname}_${file_basename}.wav >> $log_dir/${DBname}_type$type_num.rir.list + echo $output_dir/${DBname}_${file_basename}.wav >> $output_dir/info/${DBname}_type${type_num}.rir.list done done diff --git a/egs/aspire/s5/local/multi_condition/run_nnet2_common.sh b/egs/aspire/s5/local/multi_condition/run_nnet2_common.sh index 5b6424a1d86..4f4141f676a 100755 --- a/egs/aspire/s5/local/multi_condition/run_nnet2_common.sh +++ b/egs/aspire/s5/local/multi_condition/run_nnet2_common.sh @@ -1,5 +1,4 @@ #!/bin/bash -#set -e # this script is based on local/online/run_nnet2_comman.sh # but it operates on corrupted training/dev/test data sets @@ -15,6 +14,8 @@ RIR_home=db/RIR_databases/ # parent directory of the RIR databases files download_rirs=true # download the RIR databases from the urls or assume they are present in the RIR_home directory set -e +set -o pipefail + . cmd.sh . ./path.sh . 
./utils/parse_options.sh diff --git a/egs/aspire/s5/local/run_diarization.sh b/egs/aspire/s5/local/run_diarization.sh new file mode 100755 index 00000000000..ae2b6efd51e --- /dev/null +++ b/egs/aspire/s5/local/run_diarization.sh @@ -0,0 +1,141 @@ +#LDA_MLLT_transform=exp/nnet/final.mat +#nnet=exp/nnet/final.mdl +#mfccdir=`pwd`/mfcc +#vaddir=`pwd`/mfcc +#trials_female=data/sre10_test_female/trials +#trials_male=data/sre10_test_male/trials +# +#steps/make_mfcc.sh --mfcc-config conf/mfcc.conf --nj 80 --cmd "$train_cmd" \ +# data/callhome exp/make_mfcc $mfccdir +#utils/fix_data_dir.sh data/callhome +# +#sid/compute_vad_decision.sh --nj 40 --cmd "$train_cmd" \ +# data/callhome exp/make_vad $vaddir +# +#sid/extract_ivectors.sh --cmd "$train_cmd -l mem_free=10G,ram_free=10G" \ +# exp/extractor data/callhome \ +# exp/callhome +# +#ivector-normalize-length scp:exp/callhome/ivector.scp ark:- | ivector-subtract-global-mean ark:- ark:ivec1.ark +#ivector-subtract-global-mean scp:exp/callhome/ivector.scp ark:ivec2.ark +# +## The version 2 assumes that there are a variable number of speakers. +## it does a kind of greedy clustering. If you remove the _v2 and the threshold +## argument, it then does K-Means clustering and assumes that you have only 2 speakers. +#speaker-diarization_v2 --threshold=-100 plda ark:data/callhome/spk2utt \ +# ark:feat_len.ark \ +# ark:ivec2.ark ark,t:diar_results.txt + +. cmd.sh +. path.sh +set -e +set -o pipefail + +overlap=0.5 +window=1.5 +silence_weight=0.00001 +max_count=100 # parameter for extract_ivectors.sh +mfccdir=mfcc_diarization +stage=-1 +nj=30 + +. utils/parse_options.sh + +if [ $# -ne 6 ]; then + echo "Usage: $0 [options] " + echo " e.g.: $0 data/dev_aspire_whole \"ark:gunzip -c exp/nnet2_multicondition/ivector_weights_dev_aspire_whole/file_weights.ark.gz |\" exp/nnet2_multicondition/ivector_extractor exp/nnet2_multicondition/diarization_dev_aspire_whole" + echo " Options:" + echo " --stage (0|1|2) # start script from part-way through." 
+ exit 1 +fi + +data_dir=$1 +lang=$2 +file_weights=$3 +extractor=$4 +plda=$5 +dir=$6 + +echo "$0: file weights are ignored by this script" + +data_id=`basename $data_dir` +segmented_data_dir=${data_dir}_uniformsegmented_win${window}_over${overlap} + +if [ $stage -le 0 ]; then + utils/copy_data_dir.sh --validate-opts "--no-text" $data_dir $segmented_data_dir || exit 1 + cp $data_dir/reco2file_and_channel $segmented_data_dir || exit 1 + + local/multi_condition/create_uniform_segments.py --overlap $overlap --window $window $segmented_data_dir || exit 1 + for file in cmvn.scp feats.scp text; do + rm -f $segmented_data_dir/$file + done +fi + +utils/validate_data_dir.sh --no-text --no-feats $segmented_data_dir || exit 1 + +segmented_data_id=`basename $segmented_data_dir` + +if [ $stage -le 1 ]; then + steps/make_mfcc.sh --mfcc-config conf/mfcc_hires.conf --nj $nj \ + --cmd "$train_cmd" $segmented_data_dir \ + exp/make_mfcc_diarization/$segmented_data_id $mfccdir || exit 1 + steps/compute_cmvn_stats.sh $segmented_data_dir \ + exp/make_mfcc_diarization/$segmented_data_id $mfccdir || exit 1 + utils/fix_data_dir.sh $segmented_data_dir + utils/validate_data_dir.sh --no-text $segmented_data_dir || exit 1 +fi + +#if [ $stage -le 2 ]; then +# $train_cmd $dir/ivector_weights/log/extract_weights.log \ +# extract-vector-segments --trim-last-frames=2 --max-overshoot=0.025 \ +# "$file_weights" $segmented_data_dir/segments \ +# "ark:| gzip -c > $dir/ivector_weights/weights.gz" || exit 1 +#fi + +if [ $stage -le 3 ]; then + diarization/extract_ivectors.sh --cmd "$train_cmd" --nj $nj \ + --silence-weight $silence_weight --max-count $max_count \ + --ivector-period 1 \ + $segmented_data_dir $lang $extractor \ + $dir/ivectors || exit 1 + #$dir/ivector_weights/weights.gz $dir/ivectors || exit 1 +fi + +utils/split_data.sh $segmented_data_dir $nj || exit 1 + +if [ $stage -le 4 ]; then + $train_cmd JOB=1:$nj $dir/diarization/log/compute_feat_len.JOB.log \ + feat-to-len 
scp:$segmented_data_dir/split$nj/JOB/feats.scp \ + ark,t:$dir/diarization/feat_len.JOB.txt || exit 1 +fi + +if [ $stage -le 5 ]; then + #$train_cmd JOB=1:$nj $dir/diarization/log/do_diarization.JOB.log \ + # speaker-diarization_v2 --threshold=-100 $plda \ + # ark:$segmented_data_dir/split$nj/JOB/spk2utt \ + # ark,t:$dir/diarization/feat_len.JOB.txt \ + # "scp:utils/filter_scp.pl $segmented_data_dir/split$nj/JOB/utt2spk $dir/ivectors/ivectors_utt.scp |" \ + # ark,t:$dir/diarization/diarization_results.JOB.txt || exit 1 + plda= + $train_cmd JOB=1:$nj $dir/diarization/log/do_diarization.JOB.log \ + speaker-diarization --num-speakers=3 $plda \ + ark:$segmented_data_dir/split$nj/JOB/spk2utt \ + "scp:utils/filter_scp.pl $segmented_data_dir/split$nj/JOB/utt2spk $dir/ivectors/ivectors_utt.scp |" \ + ark,t:$dir/diarization/diarization_results.JOB.txt || exit 1 +fi + +echo $nj > $dir/diarization/num_jobs + +if [ $stage -le 6 ]; then + mkdir -p $dir/diarization/data_out + $train_cmd JOB=1:$nj $dir/diarization/log/convert_diarization_to_segmentation.JOB.log \ + segmentation-init-from-diarization --diarization-window-overlap=0.5 \ + ark,t:$dir/diarization/diarization_results.JOB.txt \ + $segmented_data_dir/split$nj/JOB/segments \ + ark,scp:$dir/diarization/diarization_segmentation.JOB.ark,$dir/diarization/diarization_segmentation.JOB.scp || exit 1 + $train_cmd JOB=1:$nj $dir/diarization/log/convert_diarization_segmentation_to_segments.JOB.log \ + segmentation-to-segments ark:$dir/diarization/diarization_segmentation.JOB.ark ark,t:$dir/diarization/data_out/utt2spk.JOB \ + $dir/diarization/data_out/segments.JOB || exit 1 +fi + + diff --git a/egs/aspire/s5/local/score_stm.sh b/egs/aspire/s5/local/score_stm.sh index 7f559f7dd79..16bc70e1d44 100755 --- a/egs/aspire/s5/local/score_stm.sh +++ b/egs/aspire/s5/local/score_stm.sh @@ -17,6 +17,9 @@ # This is a scoring script for the CTMS in /score_/${name}.ctm # it tries to mimic the NIST scoring setup as much as possible (and usually 
does a good job) +set -e +set -o pipefail + # begin configuration section. cmd=run.pl cer=0 @@ -56,6 +59,8 @@ ScoringProgram=`which sclite` || ScoringProgram=$KALDI_ROOT/tools/sctk/bin/sclit SortingProgram=`which hubscr.pl` || SortingProgram=$KALDI_ROOT/tools/sctk/bin/hubscr.pl [ ! -x $ScoringProgram ] && echo "Cannot find scoring program at $ScoringProgram" && exit 1; +hubscr=`which hubscr.pl` +hubdir=`dirname $hubscr` for f in $data/stm ; do [ ! -f $f ] && echo "$0: expecting file $f to exist" && exit 1; @@ -87,12 +92,12 @@ if [ $stage -le 0 ] ; then -n "$name.ctm" -f 0 -D -F -o sum rsum prf dtl sgml -e utf-8 || exit 1 fi -# Score the set... -if [ $stage -le 1 ]; then - $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/score.LMWT.log \ - cp $data/stm $dir/score_LMWT/ '&&' \ - $hubscr -p $hubdir -V -l english -h hub5 -g $data/glm -r $dir/score_LMWT/stm $dir/score_LMWT/ctm.filt || exit 1; -fi +## Score the set... +#if [ $stage -le 1 ]; then +# $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/score.LMWT.log \ +# cp $data/stm $dir/score_LMWT/ '&&' \ +# $hubscr -p $hubdir -V -l english -h hub5 -g $data/glm -r $dir/score_LMWT/stm $dir/score_LMWT/ctm.filt || exit 1; +#fi diff --git a/egs/aspire/s5/local/snr b/egs/aspire/s5/local/snr new file mode 120000 index 00000000000..6d422e11960 --- /dev/null +++ b/egs/aspire/s5/local/snr @@ -0,0 +1 @@ +../../../wsj_noisy/s5/local/snr \ No newline at end of file diff --git a/egs/aspire/s5/local/vad_phone_map_2models b/egs/aspire/s5/local/vad_phone_map_2models new file mode 100644 index 00000000000..2373197dc6a --- /dev/null +++ b/egs/aspire/s5/local/vad_phone_map_2models @@ -0,0 +1,176 @@ +sil 0 +sil_B 0 +sil_E 0 +sil_I 0 +sil_S 0 +laughter 0 +laughter_B 0 +laughter_E 0 +laughter_I 0 +laughter_S 0 +noise 0 +noise_B 0 +noise_E 0 +noise_I 0 +noise_S 0 +oov 3 +oov_B 3 +oov_E 3 +oov_I 3 +oov_S 3 +aa_B 1 +aa_E 1 +aa_I 1 +aa_S 1 +ae_B 1 +ae_E 1 +ae_I 1 +ae_S 1 +ah_B 1 +ah_E 1 +ah_I 1 +ah_S 1 +ao_B 1 +ao_E 1 +ao_I 1 +ao_S 1 +aw_B 1 +aw_E 1 
+aw_I 1 +aw_S 1 +ay_B 1 +ay_E 1 +ay_I 1 +ay_S 1 +b_B 1 +b_E 1 +b_I 1 +b_S 1 +ch_B 1 +ch_E 1 +ch_I 1 +ch_S 1 +d_B 1 +d_E 1 +d_I 1 +d_S 1 +dh_B 1 +dh_E 1 +dh_I 1 +dh_S 1 +eh_B 1 +eh_E 1 +eh_I 1 +eh_S 1 +er_B 1 +er_E 1 +er_I 1 +er_S 1 +ey_B 1 +ey_E 1 +ey_I 1 +ey_S 1 +f_B 1 +f_E 1 +f_I 1 +f_S 1 +g_B 1 +g_E 1 +g_I 1 +g_S 1 +hh_B 1 +hh_E 1 +hh_I 1 +hh_S 1 +ih_B 1 +ih_E 1 +ih_I 1 +ih_S 1 +iy_B 1 +iy_E 1 +iy_I 1 +iy_S 1 +jh_B 1 +jh_E 1 +jh_I 1 +jh_S 1 +k_B 1 +k_E 1 +k_I 1 +k_S 1 +l_B 1 +l_E 1 +l_I 1 +l_S 1 +m_B 1 +m_E 1 +m_I 1 +m_S 1 +n_B 1 +n_E 1 +n_I 1 +n_S 1 +ng_B 1 +ng_E 1 +ng_I 1 +ng_S 1 +ow_B 1 +ow_E 1 +ow_I 1 +ow_S 1 +oy_B 1 +oy_E 1 +oy_I 1 +oy_S 1 +p_B 1 +p_E 1 +p_I 1 +p_S 1 +r_B 1 +r_E 1 +r_I 1 +r_S 1 +s_B 1 +s_E 1 +s_I 1 +s_S 1 +sh_B 1 +sh_E 1 +sh_I 1 +sh_S 1 +t_B 1 +t_E 1 +t_I 1 +t_S 1 +th_B 1 +th_E 1 +th_I 1 +th_S 1 +uh_B 1 +uh_E 1 +uh_I 1 +uh_S 1 +uw_B 1 +uw_E 1 +uw_I 1 +uw_S 1 +v_B 1 +v_E 1 +v_I 1 +v_S 1 +w_B 1 +w_E 1 +w_I 1 +w_S 1 +y_B 1 +y_E 1 +y_I 1 +y_S 1 +z_B 1 +z_E 1 +z_I 1 +z_S 1 +zh_B 1 +zh_E 1 +zh_I 1 +zh_S 1 diff --git a/egs/aspire/s5/local/vad_phone_map_3models b/egs/aspire/s5/local/vad_phone_map_3models new file mode 100644 index 00000000000..15bd04cfe44 --- /dev/null +++ b/egs/aspire/s5/local/vad_phone_map_3models @@ -0,0 +1,176 @@ +sil 0 +sil_B 0 +sil_E 0 +sil_I 0 +sil_S 0 +laughter 2 +laughter_B 2 +laughter_E 2 +laughter_I 2 +laughter_S 2 +noise 2 +noise_B 2 +noise_E 2 +noise_I 2 +noise_S 2 +oov 3 +oov_B 3 +oov_E 3 +oov_I 3 +oov_S 3 +aa_B 1 +aa_E 1 +aa_I 1 +aa_S 1 +ae_B 1 +ae_E 1 +ae_I 1 +ae_S 1 +ah_B 1 +ah_E 1 +ah_I 1 +ah_S 1 +ao_B 1 +ao_E 1 +ao_I 1 +ao_S 1 +aw_B 1 +aw_E 1 +aw_I 1 +aw_S 1 +ay_B 1 +ay_E 1 +ay_I 1 +ay_S 1 +b_B 1 +b_E 1 +b_I 1 +b_S 1 +ch_B 1 +ch_E 1 +ch_I 1 +ch_S 1 +d_B 1 +d_E 1 +d_I 1 +d_S 1 +dh_B 1 +dh_E 1 +dh_I 1 +dh_S 1 +eh_B 1 +eh_E 1 +eh_I 1 +eh_S 1 +er_B 1 +er_E 1 +er_I 1 +er_S 1 +ey_B 1 +ey_E 1 +ey_I 1 +ey_S 1 +f_B 1 +f_E 1 +f_I 1 +f_S 1 +g_B 1 +g_E 1 +g_I 1 +g_S 1 +hh_B 1 +hh_E 1 +hh_I 1 +hh_S 1 +ih_B 1 +ih_E 1 +ih_I 1 
+ih_S 1 +iy_B 1 +iy_E 1 +iy_I 1 +iy_S 1 +jh_B 1 +jh_E 1 +jh_I 1 +jh_S 1 +k_B 1 +k_E 1 +k_I 1 +k_S 1 +l_B 1 +l_E 1 +l_I 1 +l_S 1 +m_B 1 +m_E 1 +m_I 1 +m_S 1 +n_B 1 +n_E 1 +n_I 1 +n_S 1 +ng_B 1 +ng_E 1 +ng_I 1 +ng_S 1 +ow_B 1 +ow_E 1 +ow_I 1 +ow_S 1 +oy_B 1 +oy_E 1 +oy_I 1 +oy_S 1 +p_B 1 +p_E 1 +p_I 1 +p_S 1 +r_B 1 +r_E 1 +r_I 1 +r_S 1 +s_B 1 +s_E 1 +s_I 1 +s_S 1 +sh_B 1 +sh_E 1 +sh_I 1 +sh_S 1 +t_B 1 +t_E 1 +t_I 1 +t_S 1 +th_B 1 +th_E 1 +th_I 1 +th_S 1 +uh_B 1 +uh_E 1 +uh_I 1 +uh_S 1 +uw_B 1 +uw_E 1 +uw_I 1 +uw_S 1 +v_B 1 +v_E 1 +v_I 1 +v_S 1 +w_B 1 +w_E 1 +w_I 1 +w_S 1 +y_B 1 +y_E 1 +y_I 1 +y_S 1 +z_B 1 +z_E 1 +z_I 1 +z_S 1 +zh_B 1 +zh_E 1 +zh_I 1 +zh_S 1 diff --git a/egs/aspire/s5/local/vad_phone_map_voiced b/egs/aspire/s5/local/vad_phone_map_voiced new file mode 100644 index 00000000000..f97504841a2 --- /dev/null +++ b/egs/aspire/s5/local/vad_phone_map_voiced @@ -0,0 +1,176 @@ +sil 0 +sil_B 0 +sil_E 0 +sil_I 0 +sil_S 0 +laughter 2 +laughter_B 2 +laughter_E 2 +laughter_I 2 +laughter_S 2 +noise 2 +noise_B 2 +noise_E 2 +noise_I 2 +noise_S 2 +oov 3 +oov_B 3 +oov_E 3 +oov_I 3 +oov_S 3 +aa_B 1 +aa_E 1 +aa_I 1 +aa_S 1 +ae_B 1 +ae_E 1 +ae_I 1 +ae_S 1 +ah_B 1 +ah_E 1 +ah_I 1 +ah_S 1 +ao_B 1 +ao_E 1 +ao_I 1 +ao_S 1 +aw_B 1 +aw_E 1 +aw_I 1 +aw_S 1 +ay_B 1 +ay_E 1 +ay_I 1 +ay_S 1 +b_B 1 +b_E 1 +b_I 1 +b_S 1 +ch_B 2 +ch_E 2 +ch_I 2 +ch_S 2 +d_B 1 +d_E 1 +d_I 1 +d_S 1 +dh_B 1 +dh_E 1 +dh_I 1 +dh_S 1 +eh_B 1 +eh_E 1 +eh_I 1 +eh_S 1 +er_B 1 +er_E 1 +er_I 1 +er_S 1 +ey_B 1 +ey_E 1 +ey_I 1 +ey_S 1 +f_B 2 +f_E 2 +f_I 2 +f_S 2 +g_B 1 +g_E 1 +g_I 1 +g_S 1 +hh_B 2 +hh_E 2 +hh_I 2 +hh_S 2 +ih_B 1 +ih_E 1 +ih_I 1 +ih_S 1 +iy_B 1 +iy_E 1 +iy_I 1 +iy_S 1 +jh_B 1 +jh_E 1 +jh_I 1 +jh_S 1 +k_B 2 +k_E 2 +k_I 2 +k_S 2 +l_B 2 +l_E 2 +l_I 2 +l_S 2 +m_B 1 +m_E 1 +m_I 1 +m_S 1 +n_B 2 +n_E 2 +n_I 2 +n_S 2 +ng_B 2 +ng_E 2 +ng_I 2 +ng_S 2 +ow_B 1 +ow_E 1 +ow_I 1 +ow_S 1 +oy_B 1 +oy_E 1 +oy_I 1 +oy_S 1 +p_B 2 +p_E 2 +p_I 2 +p_S 2 +r_B 2 +r_E 2 +r_I 2 +r_S 2 +s_B 2 +s_E 2 +s_I 2 +s_S 2 +sh_B 2 +sh_E 2 
+sh_I 2 +sh_S 2 +t_B 2 +t_E 2 +t_I 2 +t_S 2 +th_B 2 +th_E 2 +th_I 2 +th_S 2 +uh_B 1 +uh_E 1 +uh_I 1 +uh_S 1 +uw_B 1 +uw_E 1 +uw_I 1 +uw_S 1 +v_B 1 +v_E 1 +v_I 1 +v_S 1 +w_B 1 +w_E 1 +w_I 1 +w_S 1 +y_B 2 +y_E 2 +y_I 2 +y_S 2 +z_B 1 +z_E 1 +z_I 1 +z_S 1 +zh_B 1 +zh_E 1 +zh_I 1 +zh_S 1 diff --git a/egs/aspire/s5/local/vad_phone_map_vowels b/egs/aspire/s5/local/vad_phone_map_vowels new file mode 100644 index 00000000000..bb68cbb4716 --- /dev/null +++ b/egs/aspire/s5/local/vad_phone_map_vowels @@ -0,0 +1,176 @@ +sil 0 +sil_B 0 +sil_E 0 +sil_I 0 +sil_S 0 +laughter 2 +laughter_B 2 +laughter_E 2 +laughter_I 2 +laughter_S 2 +noise 2 +noise_B 2 +noise_E 2 +noise_I 2 +noise_S 2 +oov 3 +oov_B 3 +oov_E 3 +oov_I 3 +oov_S 3 +aa_B 1 +aa_E 1 +aa_I 1 +aa_S 1 +ae_B 1 +ae_E 1 +ae_I 1 +ae_S 1 +ah_B 1 +ah_E 1 +ah_I 1 +ah_S 1 +ao_B 1 +ao_E 1 +ao_I 1 +ao_S 1 +aw_B 1 +aw_E 1 +aw_I 1 +aw_S 1 +ay_B 1 +ay_E 1 +ay_I 1 +ay_S 1 +b_B 2 +b_E 2 +b_I 2 +b_S 2 +ch_B 2 +ch_E 2 +ch_I 2 +ch_S 2 +d_B 2 +d_E 2 +d_I 2 +d_S 2 +dh_B 2 +dh_E 2 +dh_I 2 +dh_S 2 +eh_B 1 +eh_E 1 +eh_I 1 +eh_S 1 +er_B 1 +er_E 1 +er_I 1 +er_S 1 +ey_B 1 +ey_E 1 +ey_I 1 +ey_S 1 +f_B 2 +f_E 2 +f_I 2 +f_S 2 +g_B 2 +g_E 2 +g_I 2 +g_S 2 +hh_B 2 +hh_E 2 +hh_I 2 +hh_S 2 +ih_B 1 +ih_E 1 +ih_I 1 +ih_S 1 +iy_B 1 +iy_E 1 +iy_I 1 +iy_S 1 +jh_B 2 +jh_E 2 +jh_I 2 +jh_S 2 +k_B 2 +k_E 2 +k_I 2 +k_S 2 +l_B 2 +l_E 2 +l_I 2 +l_S 2 +m_B 2 +m_E 2 +m_I 2 +m_S 2 +n_B 2 +n_E 2 +n_I 2 +n_S 2 +ng_B 2 +ng_E 2 +ng_I 2 +ng_S 2 +ow_B 1 +ow_E 1 +ow_I 1 +ow_S 1 +oy_B 1 +oy_E 1 +oy_I 1 +oy_S 1 +p_B 2 +p_E 2 +p_I 2 +p_S 2 +r_B 2 +r_E 2 +r_I 2 +r_S 2 +s_B 2 +s_E 2 +s_I 2 +s_S 2 +sh_B 2 +sh_E 2 +sh_I 2 +sh_S 2 +t_B 2 +t_E 2 +t_I 2 +t_S 2 +th_B 2 +th_E 2 +th_I 2 +th_S 2 +uh_B 1 +uh_E 1 +uh_I 1 +uh_S 1 +uw_B 1 +uw_E 1 +uw_I 1 +uw_S 1 +v_B 2 +v_E 2 +v_I 2 +v_S 2 +w_B 2 +w_E 2 +w_I 2 +w_S 2 +y_B 2 +y_E 2 +y_I 2 +y_S 2 +z_B 2 +z_E 2 +z_I 2 +z_S 2 +zh_B 2 +zh_E 2 +zh_I 2 +zh_S 2 diff --git a/egs/aspire/s5/path.sh b/egs/aspire/s5/path.sh index e93eb33f24b..ec1563602d8 100755 
--- a/egs/aspire/s5/path.sh +++ b/egs/aspire/s5/path.sh @@ -1,3 +1,4 @@ export KALDI_ROOT=`pwd`/../../.. -export PATH=$PWD/utils/:$KALDI_ROOT/src/bin:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/src/fstbin/:$KALDI_ROOT/src/gmmbin/:$KALDI_ROOT/src/featbin/:$KALDI_ROOT/src/lm/:$KALDI_ROOT/src/sgmmbin/:$KALDI_ROOT/src/sgmm2bin/:$KALDI_ROOT/src/fgmmbin/:$KALDI_ROOT/src/latbin/:$KALDI_ROOT/src/nnetbin:$KALDI_ROOT/src/nnet2bin:$KALDI_ROOT/src/online2bin/:$KALDI_ROOT/src/ivectorbin/:$KALDI_ROOT/src/lmbin:$PWD:$PATH +export PATH=$PWD/utils/:$KALDI_ROOT/src/bin:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/src/fstbin/:$KALDI_ROOT/src/gmmbin/:$KALDI_ROOT/src/featbin/:$KALDI_ROOT/src/lm/:$KALDI_ROOT/src/sgmmbin/:$KALDI_ROOT/src/sgmm2bin/:$KALDI_ROOT/src/fgmmbin/:$KALDI_ROOT/src/latbin/:$KALDI_ROOT/src/nnetbin:$KALDI_ROOT/src/nnet2bin:$KALDI_ROOT/src/online2bin/:$KALDI_ROOT/src/ivectorbin/:$KALDI_ROOT/src/lmbin:$KALDI_ROOT/src/segmenterbin:$KALDI_ROOT/src/nnet3bin:$PWD:$PATH +export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH export LC_ALL=C diff --git a/egs/aspire/s5/run.sh b/egs/aspire/s5/run.sh index c2191aa2c3f..c3ea0a7c1ee 100755 --- a/egs/aspire/s5/run.sh +++ b/egs/aspire/s5/run.sh @@ -134,6 +134,8 @@ steps/align_fmllr.sh --nj 30 --cmd "$train_cmd" \ steps/train_sat.sh --cmd "$train_cmd" \ 5000 100000 data/train_100k data/lang exp/tri3a_ali exp/tri4a || exit 1; +exit 1 + ( utils/mkgraph.sh data/lang_test exp/tri4a exp/tri4a/graph steps/decode_fmllr.sh --nj 25 --cmd "$decode_cmd" --config conf/decode.config \ diff --git a/egs/babel/s5b/conf/mfcc_vad.conf b/egs/babel/s5b/conf/mfcc_vad.conf new file mode 100644 index 00000000000..a5c9243eee0 --- /dev/null +++ b/egs/babel/s5b/conf/mfcc_vad.conf @@ -0,0 +1,5 @@ +--sample-frequency=8000 +--frame-length=25 # the default is 25. +--low-freq=20 # the default. +--high-freq=-300 # the default is zero meaning use the Nyquist (4k in this case). +--num-ceps=13 # higher than the default which is 12. 
diff --git a/egs/babel/s5b/conf/vad_icsi_babel.conf b/egs/babel/s5b/conf/vad_icsi_babel.conf new file mode 100644 index 00000000000..70f651403f5 --- /dev/null +++ b/egs/babel/s5b/conf/vad_icsi_babel.conf @@ -0,0 +1,39 @@ +## Features paramters +window_size=10 # 100 ms +frames_per_gaussian=200 + +## Phase 1 parameters +num_frames_init_silence=2000 # 20s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_sound=10000 # 100s - Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=2000 # 20s - Highest zero crossing frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=6 +sil_gauss_incr=0 +sound_gauss_incr=2 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + + diff --git a/egs/babel/s5b/conf/zc_vad.conf b/egs/babel/s5b/conf/zc_vad.conf new file mode 100644 index 00000000000..1475967e7b1 --- /dev/null +++ b/egs/babel/s5b/conf/zc_vad.conf @@ -0,0 +1,4 @@ +--sample-frequency=8000 +--frame-length=25 # the default is 25. 
+--dither=0.0 +--zero-crossing-threshold=1e-5 diff --git a/egs/babel/s5b/diarization b/egs/babel/s5b/diarization new file mode 120000 index 00000000000..861e0c23da4 --- /dev/null +++ b/egs/babel/s5b/diarization @@ -0,0 +1 @@ +../../rt/s5/diarization \ No newline at end of file diff --git a/egs/babel/s5b/local/train_vad_gmm_rttm.sh b/egs/babel/s5b/local/train_vad_gmm_rttm.sh new file mode 100644 index 00000000000..4c6cdc2954c --- /dev/null +++ b/egs/babel/s5b/local/train_vad_gmm_rttm.sh @@ -0,0 +1,223 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -e +set -o pipefail + +# Begin configuration section. +cmd=run.pl +speech_max_gauss=64 +noise_max_gauss=64 +sil_max_gauss=32 +sil_num_gauss_init=4 +noise_num_gauss_init=4 +speech_num_gauss_init=4 +num_iters=10 +stage=-10 +cleanup=true +top_frames_threshold=1.0 +bottom_frames_threshold=1.0 +ignore_energy=true +add_zero_crossing_feats=true +nj=4 +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# != 5 ]; then + echo "Usage: local/train_vad_gmm_rttm.sh " + echo " e.g.: local/train_vad_gmm_rttm.sh data/dev segments exp/tri4_ali/vad mitfa.rttm exp/vad_dev" + echo "main options (for others, see top of script file)" + echo " --config # config containing options" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --num-iters <#iters> # Number of iterations of E-M" + exit 1; +fi + +data=$1 +vad_scp=$2 +segments=$3 +rttm=$4 +dir=$5 + +mkdir -p $dir + +feat_dim=`head -n 1 $data/feats.scp | feat-to-dim scp:- ark,t:- | awk '{print $2}'` || exit 1 + +ignore_energy_opts= +if $ignore_energy; then + ignore_energy_opts="select-feats 1-$[feat_dim-1] ark:- ark:- |" +fi + +echo "$ignore_energy_opts" > $dir/ignore_energy_opts +echo "$add_zero_crossing_feats" > $dir/add_zero_crossing_feats + +for f in $data/feats.scp $data/cmvn.scp $data/utt2spk; do + [ ! 
-s $f ] && echo "$0: could not find $f or $f is empty" && exit 1 +done + +zc_opts= +[ -f conf/zc_vad.conf ] && zc_opts="--config=conf/zc_vad.conf" +zero_crossing_opts= + +if [ $stage -le -3 ]; then + ########################################################################### + # Prepare data. + # Split the vad in the same way as the data + ########################################################################### + rm -rf $dir/data + utils/copy_data_dir.sh $data $dir/data + + split_data.sh $dir/data $nj || exit 1 + + ########################################################################### + # Add zero-crossing and high-frequency content feats + ########################################################################### + if $add_zero_crossing_feats; then + mkdir -p $dir/data + $cmd JOB=1:$nj $dir/log/compute_zero_crossing.JOB.log \ + compute-zero-crossings $zc_opts scp:$dir/data/split$nj/JOB/wav.scp ark,scp:$dir/data/zero_crossings.JOB.ark,$dir/data/zero_crossings.JOB.scp || exit 1 + + eval cat $dir/data/zero_crossings.{`seq -s',' $nj`}.scp > $dir/data/zero_crossings.scp + + [ ! -f $dir/data/zero_crossings.1.scp ] && exit 1 + fi +fi + +########################################################################### +# Get appropriate $feats variable: +# Apply CMVN. Note that we don't apply CMVN to the zero-crossing feats. +# Remove energy from the features. +# Add zero-crossing feats. +# Add deltas. 
+########################################################################### + +feats="ark:apply-cmvn-sliding scp:$dir/data/split$nj/JOB/feats.scp ark:- |${ignore_energy_opts}" + +if $add_zero_crossing_feats; then + feats="${feats}paste-feats ark:- scp:$dir/data/zero_crossings.JOB.scp ark:- |" +fi + +feats="${feats}add-deltas ark:- ark:- |" + +feats_all="ark:apply-cmvn-sliding scp:$dir/data/feats.scp ark:- |${ignore_energy_opts}" + +if $add_zero_crossing_feats; then + feats_all="${feats_all}paste-feats ark:- scp:$dir/data/zero_crossings.scp ark:- |" +fi + +feats_all="${feats_all}add-deltas ark:- ark:- |" + +if [ $stage -le -3 ]; then + $cmd JOB=1:$nj $dir/log/get_speech_segmentation.JOB.log \ + utils/filter_scp.pl -f 2 $dir/data/split$nj/JOB/wav.scp $rttm \| \ + egrep 'LEXEME.*lex' \| \ + awk '{print "SPEAKER "$2" 1 "$4" "$5" speech "}' \| \ + rttmSmooth.pl -s 0 \| diarization/convert_rttm_to_segments.pl \| \ + segmentation-init-from-segments --label=1 - \ + ark:$dir/init_segmentation_speech.JOB.ark + + $cmd $dir/log/get_noise_segmentation.log \ + utils/filter_scp.pl -f 2 $dir/data/wav.scp $rttm \| \ + egrep '"NON-SPEECH|NON-LEX"' \| \ + awk '{print "SPEAKER "$2" 1 "$4" "$5" noise "}' \| \ + rttmSmooth.pl -s 0 \| diarization/convert_rttm_to_segments.pl \| \ + segmentation-init-from-segments --label=2 - \ + ark:$dir/init_segmentation_noise.1.ark + + $cmd $dir/log/get_silence_segmentation.1.log \ + segmentation-init-from-ali scp:$vad_scp ark:- \| \ + segmentation-to-rttm --segments="utils/filter_scp.pl -f 2 $dir/data/wav.scp $segments |" \ + ark:- - \| grep "SILENCE" \| \ + diarization/convert_rttm_to_segments.pl \| \ + segmentation-init-from-segments --label=0 - \ + ark:$dir/init_segmentation_silence.1.ark + + $cmd $dir/log/select_feats_init_noise.log \ + select-feats-from-segmentation --select-label=2 "$feats_all" \ + ark:$dir/init_segmentation_noise.1.ark \ + ark:$dir/init_feats_noise.1.ark || exit 1 + + $cmd JOB=1:$nj $dir/log/select_feats_init_speech.JOB.log \ + 
select-feats-from-segmentation --select-label=1 "$feats" \ + ark:$dir/init_segmentation_speech.JOB.ark \ + ark:$dir/init_feats_speech.JOB.ark || exit 1 + + $cmd JOB=1:1 $dir/log/select_feats_init_silence.JOB.log \ + select-feats-from-segmentation --select-label=0 "$feats_all" \ + ark:$dir/init_segmentation_silence.JOB.ark \ + ark:$dir/init_feats_silence.JOB.ark || exit 1 +fi + +speech_num_gauss=$speech_num_gauss_init +noise_num_gauss=$noise_num_gauss_init +sil_num_gauss=$sil_num_gauss_init + +if [ $stage -le -1 ]; then + $cmd $dir/log/init_gmm_speech.log \ + gmm-global-init-from-feats --num-gauss=$speech_num_gauss --num-iters=$[speech_num_gauss + 2] \ + "ark:cat $dir/init_feats_speech.{?,??,???}.ark |" $dir/speech.0.mdl || exit 1 + + $cmd $dir/log/init_gmm_noise.log \ + gmm-global-init-from-feats --num-gauss=$noise_num_gauss --num-iters=$[noise_num_gauss + 2] \ + "ark:cat $dir/init_feats_noise.{?,??,???}.ark |" $dir/noise.0.mdl || exit 1 + + $cmd $dir/log/init_gmm_silence.log \ + gmm-global-init-from-feats --num-gauss=$sil_num_gauss --num-iters=$[sil_num_gauss + 2] \ + "ark:cat $dir/init_feats_silence.{?,??,???}.ark |" $dir/silence.0.mdl || exit 1 +fi + +x=0 +while [ $x -le $num_iters ]; do + if [ $stage -le $x ]; then + $cmd JOB=1:$nj $dir/log/acc_gmm_stats_speech.$x.JOB.log \ + gmm-global-acc-stats $dir/speech.$x.mdl \ + "ark:copy-feats ark:$dir/init_feats_speech.JOB.ark ark:- |" \ + $dir/speech_accs.$x.JOB || exit 1 + + $cmd JOB=1:1 $dir/log/acc_gmm_stats_noise.$x.JOB.log \ + gmm-global-acc-stats $dir/noise.$x.mdl \ + "ark:copy-feats ark:$dir/init_feats_noise.JOB.ark ark:- |" \ + $dir/noise_accs.$x.JOB || exit 1 + + $cmd JOB=1:1 $dir/log/acc_gmm_stats_silence.$x.JOB.log \ + gmm-global-acc-stats $dir/silence.$x.mdl \ + "ark:copy-feats ark:$dir/init_feats_silence.JOB.ark ark:- |" \ + $dir/silence_accs.$x.JOB || exit 1 + + $cmd $dir/log/gmm_est_speech.$x.log \ + gmm-global-est --mix-up=$speech_num_gauss $dir/speech.$x.mdl \ + "gmm-global-sum-accs - 
$dir/speech_accs.$x.* |" \ + $dir/speech.$[x+1].mdl || exit 1 + + $cmd $dir/log/gmm_est_noise.$x.log \ + gmm-global-est --mix-up=$noise_num_gauss $dir/noise.$x.mdl \ + "gmm-global-sum-accs - $dir/noise_accs.$x.* |" \ + $dir/noise.$[x+1].mdl || exit 1 + + $cmd $dir/log/gmm_est_silence.$x.log \ + gmm-global-est --mix-up=$sil_num_gauss $dir/silence.$x.mdl \ + "gmm-global-sum-accs - $dir/silence_accs.$x.* |" \ + $dir/silence.$[x+1].mdl || exit 1 + fi + + if [ $sil_num_gauss -lt $sil_max_gauss ]; then + sil_num_gauss=$[sil_num_gauss * 2] + fi + if [ $speech_num_gauss -lt $speech_max_gauss ]; then + speech_num_gauss=$[speech_num_gauss * 2] + fi + if [ $noise_num_gauss -lt $noise_max_gauss ]; then + noise_num_gauss=$[noise_num_gauss * 2] + fi + + x=$[x+1] +done + +# Summarize warning messages... +utils/summarize_warnings.pl $dir/log + diff --git a/egs/babel/s5b/run-4-anydecode.sh b/egs/babel/s5b/run-4-anydecode.sh index a1b943dd35e..1abb574d195 100755 --- a/egs/babel/s5b/run-4-anydecode.sh +++ b/egs/babel/s5b/run-4-anydecode.sh @@ -222,15 +222,14 @@ if [ ! -f $dataset_dir/.done ] ; then echo "Valid dataset kinds are: supervised, unsupervised, shadow"; exit 1 fi - - if [ ! -f ${dataset_dir}/.plp.done ]; then + touch $dataset_dir/.done +fi +if [ ! 
-f ${dataset_dir}/.plp.done ]; then echo --------------------------------------------------------------------- echo "Preparing ${dataset_kind} parametrization files in ${dataset_dir} on" `date` echo --------------------------------------------------------------------- make_plp ${dataset_dir} exp/make_plp/${dataset_id} plp touch ${dataset_dir}/.plp.done - fi - touch $dataset_dir/.done fi ##################################################################### # diff --git a/egs/rt/s5/cmd.sh b/egs/rt/s5/cmd.sh new file mode 120000 index 00000000000..19f7e836644 --- /dev/null +++ b/egs/rt/s5/cmd.sh @@ -0,0 +1 @@ +../../wsj/s5/cmd.sh \ No newline at end of file diff --git a/egs/rt/s5/conf/librispeech_mfcc.conf b/egs/rt/s5/conf/librispeech_mfcc.conf new file mode 100644 index 00000000000..45d284ad05c --- /dev/null +++ b/egs/rt/s5/conf/librispeech_mfcc.conf @@ -0,0 +1 @@ +--use-energy=false diff --git a/egs/rt/s5/conf/mfcc_vad.conf b/egs/rt/s5/conf/mfcc_vad.conf new file mode 100644 index 00000000000..22765c6280e --- /dev/null +++ b/egs/rt/s5/conf/mfcc_vad.conf @@ -0,0 +1,5 @@ +--sample-frequency=16000 +--frame-length=25 # the default is 25. +--low-freq=20 # the default. +--high-freq=-600 # the default is zero meaning use the Nyquist (4k in this case). +--num-ceps=13 # higher than the default which is 12. 
diff --git a/egs/rt/s5/conf/pitch.conf b/egs/rt/s5/conf/pitch.conf new file mode 100644 index 00000000000..e959a19d5b8 --- /dev/null +++ b/egs/rt/s5/conf/pitch.conf @@ -0,0 +1 @@ +--sample-frequency=16000 diff --git a/egs/rt/s5/conf/vad_decode_icsi.conf b/egs/rt/s5/conf/vad_decode_icsi.conf new file mode 100644 index 00000000000..15ba288e3af --- /dev/null +++ b/egs/rt/s5/conf/vad_decode_icsi.conf @@ -0,0 +1,40 @@ +## Features paramters +window_size=100 # 1s +frames_per_gaussian=2000 + +## Phase 1 parameters +num_frames_init_silence=2000 +num_frames_init_sound=10000 +num_frames_init_sound_next=2000 +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=6 +sil_gauss_incr=0 +sound_gauss_incr=2 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +num_frames_init_speech=10000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + + diff --git a/egs/rt/s5/conf/vad_decode_pitch.conf b/egs/rt/s5/conf/vad_decode_pitch.conf new file mode 100644 index 00000000000..d7ba1d40093 --- /dev/null +++ b/egs/rt/s5/conf/vad_decode_pitch.conf @@ -0,0 +1,55 @@ +## Features paramters +window_size=10 # 1s +smooth_weights=false +smoothing_window=2 +smooth_mask=true + +## Phase 1 parameters +num_frames_init_silence=200 +num_frames_init_sound=200 +num_frames_init_sound_next=200 +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=2 +sil_gauss_incr=1 +sound_gauss_incr=1 +sil_frames_incr=200 +sound_frames_incr=200 +sound_frames_next_incr=200 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + 
+## Phase 2 parameters +num_frames_init_speech=5000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=7 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=1 +speech_gauss_incr_phase2=2 +num_iters_phase2=20 +window_size_phase2_init=10 +window_size_phase2_next=10 +window_size_incr_iter=5 + +num_frames_init_speech_phase2=100000 +num_frames_init_silence_phase2=200000 +num_frames_init_sound_phase2=200000 +speech_frames_incr_phase2=200000 +sil_frames_incr_phase2=200000 +sound_frames_incr_phase2=200000 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + + diff --git a/egs/rt/s5/conf/vad_icsi_babel.conf b/egs/rt/s5/conf/vad_icsi_babel.conf new file mode 100644 index 00000000000..70f651403f5 --- /dev/null +++ b/egs/rt/s5/conf/vad_icsi_babel.conf @@ -0,0 +1,39 @@ +## Features paramters +window_size=10 # 100 ms +frames_per_gaussian=200 + +## Phase 1 parameters +num_frames_init_silence=2000 # 20s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_sound=10000 # 100s - Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=2000 # 20s - Highest zero crossing frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=6 +sil_gauss_incr=0 +sound_gauss_incr=2 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 
+num_iters_phase3=7 + + diff --git a/egs/rt/s5/conf/vad_icsi_babel_3models.conf b/egs/rt/s5/conf/vad_icsi_babel_3models.conf new file mode 100644 index 00000000000..1196f0d2aff --- /dev/null +++ b/egs/rt/s5/conf/vad_icsi_babel_3models.conf @@ -0,0 +1,54 @@ +## Features paramters +window_size=10 # 100 ms +frames_per_gaussian=200 + +## Phase 1 parameters +num_frames_init_silence=2000 # 20s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_sound=10000 # 100s - Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=2000 # 20s - Highest zero crossing frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=6 +sil_gauss_incr=0 +sound_gauss_incr=2 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 + +## Phase 3 parameters +num_frames_silence_phase3_init=2000 +num_frames_speech_phase3_init=2000 +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +sil_max_gauss_phase4=8 +speech_max_gauss_phase4=16 +sil_gauss_incr_phase3=1 +sil_gauss_incr_phase4=1 +speech_gauss_incr_phase4=2 +num_iters_phase3=5 +num_iters_phase4=5 + +## Phase 5 parameters +sil_num_gauss_init_phase5=2 +speech_num_gauss_init_phase5=2 +sil_max_gauss_phase5=5 +speech_max_gauss_phase5=12 +sil_gauss_incr_phase5=1 +speech_gauss_incr_phase5=2 +num_iters_phase5=7 + + + diff --git a/egs/rt/s5/conf/vad_icsi_rt.conf b/egs/rt/s5/conf/vad_icsi_rt.conf new file mode 100644 index 00000000000..d19038014db --- /dev/null +++ b/egs/rt/s5/conf/vad_icsi_rt.conf @@ -0,0 +1,40 @@ +## Features paramters +window_size=10 # 100 ms +frames_per_gaussian=200 + +## Phase 1 parameters +num_frames_init_silence=2000 
+num_frames_init_sound=10000 +num_frames_init_sound_next=2000 +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=6 +sil_gauss_incr=0 +sound_gauss_incr=2 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +#num_frames_init_speech=10000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + + diff --git a/egs/rt/s5/conf/zc_vad.conf b/egs/rt/s5/conf/zc_vad.conf new file mode 100644 index 00000000000..b5d94450709 --- /dev/null +++ b/egs/rt/s5/conf/zc_vad.conf @@ -0,0 +1,5 @@ +--sample-frequency=16000 +--frame-length=25 # the default is 25. +--dither=0.0 +--zero-crossing-threshold=1e-5 + diff --git a/egs/rt/s5/diarization b/egs/rt/s5/diarization new file mode 120000 index 00000000000..ba78a9126af --- /dev/null +++ b/egs/rt/s5/diarization @@ -0,0 +1 @@ +../../sre08/v1/diarization \ No newline at end of file diff --git a/egs/rt/s5/local/make_rt_2004_dev.pl b/egs/rt/s5/local/make_rt_2004_dev.pl new file mode 100755 index 00000000000..5f0d07e673c --- /dev/null +++ b/egs/rt/s5/local/make_rt_2004_dev.pl @@ -0,0 +1,56 @@ +#!/usr/bin/perl -w +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +use strict; +use File::Basename; + +if (@ARGV != 2) { + print STDERR "Usage: $0 \n" . 
+ " e.g.: $0 /export/corpora5/LDC/LDC2007S11 data\n"; + exit(1); +} + +my ($db_base, $out_dir) = @ARGV; +$out_dir = "$out_dir/rt04_dev"; + +if (system("mkdir -p $out_dir")) { + die "Error making directory $out_dir"; +} + +open(SPKR, ">", "$out_dir/utt2spk") + or die "Could not open the output file $out_dir/utt2spk"; +open(WAV, ">", "$out_dir/wav.scp") + or die "Could not open the output file $out_dir/wav.scp"; + +open(LIST, 'find ' . $db_base . '/data/audio/dev04s -name "*.sph" |'); + +while (my $line = ) { + chomp($line); + my ($file_id, $path, $suffix) = fileparse($line, qr/\.[^.]*/); + if ($suffix =~ /.sph/) { + print WAV $file_id . " sph2pipe -f wav -p -c 1 $line |\n"; + } elsif ($suffix =~ /.wav/) { + print WAV $file_id . " $line |\n"; + } else { + die "$0: Unknown suffix $suffix in $line\n" + } + + print SPKR "$file_id $file_id\n"; +} + +close(LIST) || die; +close(WAV) || die; +close(SPKR) || die; + +if (system( + "utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_dir"; +} + +system("utils/fix_data_dir.sh $out_dir"); + +if (system( + "utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { + die "Error validating directory $out_dir"; +} diff --git a/egs/rt/s5/local/make_rt_2004_eval.pl b/egs/rt/s5/local/make_rt_2004_eval.pl new file mode 100755 index 00000000000..36ae9eae117 --- /dev/null +++ b/egs/rt/s5/local/make_rt_2004_eval.pl @@ -0,0 +1,57 @@ +#!/usr/bin/perl -w +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +use strict; +use File::Basename; + +if (@ARGV != 2) { + print STDERR "Usage: $0 \n" . 
+ " e.g.: $0 /export/corpora5/LDC/LDC2007S12/package/rt04_eval data\n"; + exit(1); +} + +my ($db_base, $out_dir) = @ARGV; +$out_dir = "$out_dir/rt04_eval"; + +if (system("mkdir -p $out_dir")) { + die "Error making directory $out_dir"; +} + +open(SPKR, ">", "$out_dir/utt2spk") + or die "Could not open the output file $out_dir/utt2spk"; +open(WAV, ">", "$out_dir/wav.scp") + or die "Could not open the output file $out_dir/wav.scp"; + +open(LIST, 'find ' . $db_base . '/data/audio/eval04s -name "*.sph" |'); + +while (my $line = ) { + chomp($line); + my ($file_id, $path, $suffix) = fileparse($line, qr/\.[^.]*/); + if ($suffix =~ /.sph/) { + print WAV $file_id . " sph2pipe -f wav -p -c 1 $line |\n"; + } elsif ($suffix =~ /.wav/) { + print WAV $file_id . " $line |\n"; + } else { + die "$0: Unknown suffix $suffix in $line\n" + } + + print SPKR "$file_id $file_id\n"; +} + +close(LIST) || die; +close(WAV) || die; +close(SPKR) || die; + +if (system( + "utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_dir"; +} + +system("utils/fix_data_dir.sh $out_dir"); + +if (system( + "utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { + die "Error validating directory $out_dir"; +} + diff --git a/egs/rt/s5/local/make_rt_2005_eval.pl b/egs/rt/s5/local/make_rt_2005_eval.pl new file mode 100755 index 00000000000..f00e31b40e8 --- /dev/null +++ b/egs/rt/s5/local/make_rt_2005_eval.pl @@ -0,0 +1,58 @@ +#!/usr/bin/perl -w +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +use strict; +use File::Basename; + +if (@ARGV != 2) { + print STDERR "Usage: $0 \n" . 
+ " e.g.: $0 /export/corpora5/LDC/LDC2011S06 data\n"; + exit(1); +} + +my ($db_base, $out_dir) = @ARGV; +$out_dir = "$out_dir/rt05_eval"; + +if (system("mkdir -p $out_dir")) { + die "Error making directory $out_dir"; +} + +open(SPKR, ">", "$out_dir/utt2spk") + or die "Could not open the output file $out_dir/utt2spk"; +open(WAV, ">", "$out_dir/wav.scp") + or die "Could not open the output file $out_dir/wav.scp"; + +open(LIST, 'find ' . $db_base . '/data/audio/eval05s -name "*.sph" |'); + +while (my $line = ) { + chomp($line); + my ($file_id, $path, $suffix) = fileparse($line, qr/\.[^.]*/); + if ($suffix =~ /.sph/) { + print WAV $file_id . " sph2pipe -f wav -p -c 1 $line |\n"; + } elsif ($suffix =~ /.wav/) { + print WAV $file_id . " $line |\n"; + } else { + die "$0: Unknown suffix $suffix in $line\n" + } + + print SPKR "$file_id $file_id\n"; +} + +close(LIST) || die; +close(WAV) || die; +close(SPKR) || die; + +if (system( + "utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_dir"; +} + +system("utils/fix_data_dir.sh $out_dir"); + +if (system( + "utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { + die "Error validating directory $out_dir"; +} + + diff --git a/egs/rt/s5/local/snr b/egs/rt/s5/local/snr new file mode 120000 index 00000000000..6d422e11960 --- /dev/null +++ b/egs/rt/s5/local/snr @@ -0,0 +1 @@ +../../../wsj_noisy/s5/local/snr \ No newline at end of file diff --git a/egs/rt/s5/path.sh b/egs/rt/s5/path.sh new file mode 100755 index 00000000000..8461d980758 --- /dev/null +++ b/egs/rt/s5/path.sh @@ -0,0 +1,5 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . 
$KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/src/bin:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/src/fstbin/:$KALDI_ROOT/src/gmmbin/:$KALDI_ROOT/src/featbin/:$KALDI_ROOT/src/lm/:$KALDI_ROOT/src/sgmmbin/:$KALDI_ROOT/src/sgmm2bin/:$KALDI_ROOT/src/fgmmbin/:$KALDI_ROOT/src/latbin/:$KALDI_ROOT/src/nnetbin:$KALDI_ROOT/src/nnet2bin/:$KALDI_ROOT/src/kwsbin:$KALDI_ROOT/src/online2bin/:$KALDI_ROOT/src/ivectorbin/:$KALDI_ROOT/src/lmbin/:$KALDI_ROOT/src/nnet3bin/:$KALDI_ROOT/src/segmenterbin/:$PWD:$PATH:$KALDI_ROOT/tools/sctk/bin +export PATH=$KALDI_ROOT/tools/sph2pipe_v2.5/:$PATH +export LC_ALL=C diff --git a/egs/rt/s5/sid b/egs/rt/s5/sid new file mode 120000 index 00000000000..5cb0274b7d6 --- /dev/null +++ b/egs/rt/s5/sid @@ -0,0 +1 @@ +../../sre08/v1/sid/ \ No newline at end of file diff --git a/egs/rt/s5/steps b/egs/rt/s5/steps new file mode 120000 index 00000000000..1b186770dd1 --- /dev/null +++ b/egs/rt/s5/steps @@ -0,0 +1 @@ +../../wsj/s5/steps/ \ No newline at end of file diff --git a/egs/rt/s5/utils b/egs/rt/s5/utils new file mode 120000 index 00000000000..a3279dc8679 --- /dev/null +++ b/egs/rt/s5/utils @@ -0,0 +1 @@ +../../wsj/s5/utils/ \ No newline at end of file diff --git a/egs/sre08/v1/conf/mfcc_vad.conf b/egs/sre08/v1/conf/mfcc_vad.conf new file mode 100644 index 00000000000..a5c9243eee0 --- /dev/null +++ b/egs/sre08/v1/conf/mfcc_vad.conf @@ -0,0 +1,5 @@ +--sample-frequency=8000 +--frame-length=25 # the default is 25. +--low-freq=20 # the default. +--high-freq=-300 # the default is zero meaning use the Nyquist (4k in this case). +--num-ceps=13 # higher than the default which is 12. 
diff --git a/egs/sre08/v1/conf/pitch.conf b/egs/sre08/v1/conf/pitch.conf new file mode 100644 index 00000000000..926bcfca92a --- /dev/null +++ b/egs/sre08/v1/conf/pitch.conf @@ -0,0 +1 @@ +--sample-frequency=8000 diff --git a/egs/sre08/v1/conf/vad_decode.conf b/egs/sre08/v1/conf/vad_decode.conf new file mode 100644 index 00000000000..c3c468020d9 --- /dev/null +++ b/egs/sre08/v1/conf/vad_decode.conf @@ -0,0 +1,42 @@ +## Features paramters +window_size=10 # 1s + +## Phase 1 parameters +num_frames_init_silence=2000 +num_frames_init_sound=2000 +num_frames_init_sound_next=2000 +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=2 +sil_gauss_incr=1 +sound_gauss_incr=1 +sil_frames_incr=2000 +sound_frames_incr=2000 +sound_frames_next_incr=2000 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +num_frames_init_speech=10000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=7 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=1 +speech_gauss_incr_phase2=2 +num_iters_phase2=10 +window_size_phase2=10 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + diff --git a/egs/sre08/v1/conf/vad_decode_icsi.conf b/egs/sre08/v1/conf/vad_decode_icsi.conf new file mode 100644 index 00000000000..bbc9687e6d8 --- /dev/null +++ b/egs/sre08/v1/conf/vad_decode_icsi.conf @@ -0,0 +1,40 @@ +## Features paramters +window_size=10 # 1s +frames_per_gaussian=200 + +## Phase 1 parameters +num_frames_init_silence=2000 +num_frames_init_sound=10000 +num_frames_init_sound_next=2000 +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=6 +sil_gauss_incr=0 +sound_gauss_incr=2 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters 
+num_frames_init_speech=10000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + + diff --git a/egs/sre08/v1/conf/vad_decode_pitch.conf b/egs/sre08/v1/conf/vad_decode_pitch.conf new file mode 100644 index 00000000000..4f713e86a41 --- /dev/null +++ b/egs/sre08/v1/conf/vad_decode_pitch.conf @@ -0,0 +1,43 @@ +## Features paramters +window_size=10 # 1s + +## Phase 1 parameters +num_frames_init_silence=2000 +num_frames_init_sound=2000 +num_frames_init_sound_next=2000 +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=2 +sil_gauss_incr=1 +sound_gauss_incr=1 +sil_frames_incr=2000 +sound_frames_incr=2000 +sound_frames_next_incr=2000 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +num_frames_init_speech=10000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=7 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=1 +speech_gauss_incr_phase2=2 +num_iters_phase2=20 +window_size_phase2=10 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + + diff --git a/egs/sre08/v1/conf/vad_icsi_babel.conf b/egs/sre08/v1/conf/vad_icsi_babel.conf new file mode 100644 index 00000000000..70f651403f5 --- /dev/null +++ b/egs/sre08/v1/conf/vad_icsi_babel.conf @@ -0,0 +1,39 @@ +## Features paramters +window_size=10 # 100 ms +frames_per_gaussian=200 + +## Phase 1 parameters +num_frames_init_silence=2000 # 20s - Lowest energy frames selected to 
initialize Silence GMM +num_frames_init_sound=10000 # 100s - Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=2000 # 20s - Highest zero crossing frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=6 +sil_gauss_incr=0 +sound_gauss_incr=2 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + + diff --git a/egs/sre08/v1/conf/vad_icsi_rt.conf b/egs/sre08/v1/conf/vad_icsi_rt.conf new file mode 100644 index 00000000000..c2964d5171d --- /dev/null +++ b/egs/sre08/v1/conf/vad_icsi_rt.conf @@ -0,0 +1,41 @@ +## Features paramters +window_size=10 # 1s +frames_per_gaussian=2000 + +## Phase 1 parameters +num_frames_init_silence=2000 +num_frames_init_sound=10000 +num_frames_init_sound_next=2000 +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=6 +sil_gauss_incr=0 +sound_gauss_incr=2 +num_iters=5 +min_sil_variance=0.1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +num_frames_init_speech=10000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + + + diff --git a/egs/sre08/v1/conf/zc_vad.conf 
b/egs/sre08/v1/conf/zc_vad.conf new file mode 100644 index 00000000000..1475967e7b1 --- /dev/null +++ b/egs/sre08/v1/conf/zc_vad.conf @@ -0,0 +1,4 @@ +--sample-frequency=8000 +--frame-length=25 # the default is 25. +--dither=0.0 +--zero-crossing-threshold=1e-5 diff --git a/egs/sre08/v1/diarization/analyze_rttm.py b/egs/sre08/v1/diarization/analyze_rttm.py new file mode 100755 index 00000000000..55fd36c057a --- /dev/null +++ b/egs/sre08/v1/diarization/analyze_rttm.py @@ -0,0 +1,23 @@ +#! /usr/bin/env python + +from __future__ import print_function +import sys +import numpy + +A = []; +for line in sys.stdin.readlines(): + line = line.strip(); + splits = line.split(); + x = float(splits[4]); + A.append(x); + +min_x = min(A); +max_x = max(A); +mean_x = sum(A) / len(A); +per10_x = numpy.percentile(A, 10); +per25_x = numpy.percentile(A, 25); +per50_x = numpy.percentile(A, 50); +per75_x = numpy.percentile(A, 75); + +print("%5.2f %5.2f %5.2f" % (min_x, max_x, mean_x)); +print("%5.2f %5.2f %5.2f %5.2f" % (per10_x, per25_x, per50_x, per75_x)); diff --git a/egs/sre08/v1/diarization/convert_ali_to_vad.sh b/egs/sre08/v1/diarization/convert_ali_to_vad.sh new file mode 100755 index 00000000000..8ab72d7a366 --- /dev/null +++ b/egs/sre08/v1/diarization/convert_ali_to_vad.sh @@ -0,0 +1,132 @@ +set -o pipefail + +. path.sh + +cmd=run.pl +nj=4 +stage=-1 +get_whole_vad=false +phone_map= +model= + +. parse_options.sh + +if [ $# -gt 4 ] || [ $# -lt 3 ]; then + echo "Usage: convert_ali_to_vad.sh []" + echo " e.g.: convert_ali_to_vad.sh data/lang exp/tri5_ali" + exit 1 +fi + +data=$1 +lang=$2 +ali_dir=$3 +if [ $# -eq 4 ]; then + dir=$4 +else + dir=$ali_dir/vad +fi + +[ -z "$model" ] && model=$ali_dir/final.mdl + +for f in $ali_dir/ali.1.gz $model; do + [ ! 
-f $f ] && echo "$0: $f does not exist" && exit 1 +done + +tmpdir=$dir/tmp +mkdir -p $tmpdir + +num_jobs=`cat $ali_dir/num_jobs` || exit 1 +echo $num_jobs > $dir/num_jobs + +split_data.sh $data $num_jobs + +if [ -z "$phone_map" ]; then + { + awk '{print $1" 0"}' $lang/phones/silence.txt; + awk '{print $1" 1"}' $lang/phones/nonsilence.txt; + } | awk '{ if ($1 !~ /oov/) { print $1" "$2} else {print $1" 1"} }' | utils/sym2int.pl -f 1 $lang/phones.txt | \ + sort -k1,1 -n > $dir/phone_map +else + cat $phone_map | utils/sym2int.pl -f 1 $lang/phones.txt | \ + sort -k1,1 -n > $dir/phone_map || exit 1 +fi + +if [ $stage -le 0 ]; then + $cmd JOB=1:$num_jobs $dir/log/convert_ali_to_vad.JOB.log \ + ali-to-phones --per-frame=true $model "ark:gunzip -c $ali_dir/ali.JOB.gz |" \ + ark,t:- \| utils/apply_map.pl -f 2- $dir/phone_map \| \ + copy-int-vector ark,t:- "ark:| gzip -c > $tmpdir/vad.JOB.ark.gz" || exit 1 +fi + +if ! $get_whole_vad; then + if [ $stage -le 1 ]; then + $cmd JOB=1:$num_jobs $dir/log/combine_vad.JOB.log \ + copy-int-vector "ark:gunzip -c $tmpdir/vad.JOB.ark.gz |" \ + ark,scp:$dir/vad.JOB.ark,$dir/vad.JOB.scp || exit 1 + + for n in `seq $num_jobs`; do + cat $dir/vad.$n.scp + done | sort -k1,1 > $dir/vad.scp + fi +else + if [ $stage -le 1 ]; then + $cmd JOB=1:$num_jobs $dir/log/get_whole_vad.JOB.log \ + copy-int-vector "ark:gunzip -c $tmpdir/vad.JOB.ark.gz |" ark,t:- \| \ + diarization/convert_vad_to_rttm.pl --silence-class 0 --speech-class 1 \ + --segments $data/split$num_jobs/JOB/segments \| rttmSort.pl \| \ + diarization/convert_rttm_to_vad.pl --ignore-boundaries true \ + --segments-out $tmpdir/segments.JOB \| \ + copy-int-vector ark,t:- ark,scp:$dir/vad.JOB.ark.gz,$dir/vad.JOB.scp || exit 1 + + for n in `seq $num_jobs`; do + cat $dir/vad.$n.scp + done | sort -k1,1 > $dir/vad.scp + fi + + if [ $stage -le 2 ]; then + rm -rf $dir/data + mkdir -p $dir/data + + utils/copy_data_dir.sh $data $dir/data + for f in feats.scp segments spk2utt utt2spk vad.scp cmvn.scp 
spk2gender text; do + [ -f $dir/data/$f ] && rm $dir/data/$f + done + rm -rf $dir/data/split* $dir/data/.backup + + eval cat $tmpdir/segments.{`seq -s',' $num_jobs`} | sort -k1,1 > $dir/data/segments + + awk '{print $1" "$2}' $dir/data/segments > $dir/data/utt2spk || exit 1 + utils/utt2spk_to_spk2utt.pl $dir/data/utt2spk > $dir/data/spk2utt || exit 1 + + utils/fix_data_dir.sh $dir/data || exit 1 + + rm -rf $dir/data_whole + + mkdir -p $dir/data_whole + utils/copy_data_dir.sh $data $dir/data_whole + for f in feats.scp segments spk2utt utt2spk vad.scp cmvn.scp spk2gender text; do + [ -f $dir/data_whole/$f ] && rm $dir/data_whole/$f + done + rm -rf $dir/data_whole/split* $dir/data_whole/.backup + + awk '{print $1" "$1}' $dir/data_whole/wav.scp > $dir/data_whole/utt2spk || exit 1 + utils/utt2spk_to_spk2utt.pl $dir/data_whole/utt2spk > $dir/data_whole/spk2utt || exit 1 + + utils/fix_data_dir.sh $dir/data_whole || exit 1 + fi + + if [ $stage -le 3 ]; then + diarization/prepare_data.sh --cmd $cmd --nj $nj $dir/data_whole $dir/tmpdir $dir/mfcc || exit 1 + fi + + if [ $stage -le 4 ]; then + split_data.sh $dir/data $nj + $cmd JOB=1:$nj $dir/log/extract_feats.JOB.log \ + utils/filter_scp.pl $dir/data/split$nj/JOB/wav.scp $dir/data_whole/feats.scp \| \ + extract-feature-segments scp:- \ + $dir/data/split$nj/JOB/segments \ + ark,scp:$dir/mfcc/raw_mfcc_feats_data.JOB.ark,$dir/mfcc/raw_mfcc_feats_data.JOB.scp || exit 1 + + eval cat $dir/mfcc/raw_mfcc_feats_data.{`seq -s',' $nj`}.scp | sort -k1,1 > $dir/data/feats.scp + fi +fi diff --git a/egs/sre08/v1/diarization/convert_data_dir_to_whole.sh b/egs/sre08/v1/diarization/convert_data_dir_to_whole.sh new file mode 100755 index 00000000000..19dc07e7d6a --- /dev/null +++ b/egs/sre08/v1/diarization/convert_data_dir_to_whole.sh @@ -0,0 +1,96 @@ +set -o pipefail + +. path.sh + +cmd=run.pl +stage=-1 + +. 
parse_options.sh + +if [ $# -ne 2 ]; then + echo "Usage: convert_data_dir_to_whole.sh " + echo " e.g.: convert_data_dir_to_whole.sh data/dev data/dev_whole" + exit 1 +fi + +data=$1 +dir=$2 + +if [ ! -f $data/segments ]; then + utils/copy_data_dir.sh $data $dir + exit 0 +fi + +mkdir -p $dir +cp $data/wav.scp $dir +cp $data/reco2file_and_channel $dir +rm -f $dir/{utt2spk,text} || true + +text_files= +[ -f $data/text ] && text_files=$data/text $dir/text + +cat $data/segments | perl -e ' +if (scalar @ARGV == 4) { + ($utt2spk_in, $utt2spk_out, $text_in, $text_out) = @ARGV; +} elsif (scalar @ARGV == 2) { + ($utt2spk_in, $utt2spk_out) = @ARGV; +} else { + die "Unexpected number of arguments"; +} + +if (defined $text_in) { + open(TI, "<$text_in") || die "Error: fail to open $text_in\n"; + open(TO, ">$text_out") || die "Error: fail to open $text_out\n"; +} +open(UI, "<$utt2spk_in") || die "Error: fail to open $utt2spk_in\n"; +open(UO, ">$utt2spk_out") || die "Error: fail to open $utt2spk_out\n"; + +my %file2utt = (); +while () { + chomp; + my @col = split; + @col >= 4 or die "bad line $_\n"; + + if (! defined $file2utt{$col[1]}) { + $file2utt{$col[1]} = []; + } + push @{$file2utt{$col[1]}}, $col[0]; +} + +my %text = (); +my %utt2spk = (); + +while () { + chomp; + my @col = split; + $utt2spk{$col[0]} = $col[1]; +} + +if (defined $text_in) { + while () { + chomp; + my @col = split; + @col >= 1 or die "bad line $_\n"; + + my $utt = shift @col; + $text{$utt} = join(" ", @col); + } +} + +foreach $file (keys %file2utt) { + my @utts = @{$file2utt{$file}}; + #print STDERR $file . " " . join(" ", @utts) . 
"\n"; + print UO "$file $file\n"; + + if (defined $text_in) { + $text_line = ""; + print TO "$file $text_line\n"; + } +} +' $data/utt2spk $dir/utt2spk $text_files + +sort -u $dir/utt2spk > $dir/utt2spk.tmp +mv $dir/utt2spk.tmp $dir/utt2spk +utils/utt2spk_to_spk2utt.pl $dir/utt2spk > $dir/spk2utt + +utils/fix_data_dir.sh $dir diff --git a/egs/sre08/v1/diarization/convert_ref_to_rttm.pl b/egs/sre08/v1/diarization/convert_ref_to_rttm.pl new file mode 100755 index 00000000000..3971371073c --- /dev/null +++ b/egs/sre08/v1/diarization/convert_ref_to_rttm.pl @@ -0,0 +1,35 @@ +#!/usr/bin/perl -w +# Copyright 2015 Vimal Manohar (Johns Hopkins University) +# Apache 2.0. + +use strict; +use Getopt::Long; +use File::Basename; + +if (@ARGV != 1) { + print STDERR "$0:\n" . + "Usage: convert_ref_to_rttm.pl [options] > \n"; + exit 1; +} + +my $filename = $ARGV[0]; + +print STDERR "Extracting RTTM from ref $filename\n"; + +my $basename = basename($filename); +(my $utt_id = $basename) =~ s/\.[^.]+$//; + +open IN, $filename or die "Could not open $filename"; + +my %seen_spkrs = (); + +while () { + chomp; + my @A = split; + if (! defined $seen_spkrs{$A[2]}) { + printf STDOUT ("SPKR-INFO $utt_id 1 unknown $A[2] \n"); + $seen_spkrs{$A[2]} = 1; + } + + printf STDOUT ("SPEAKER $utt_id 1 %5.2f %5.2f $A[2] \n", $A[0], $A[1] - $A[0]); +} diff --git a/egs/sre08/v1/diarization/convert_ref_to_vad.pl b/egs/sre08/v1/diarization/convert_ref_to_vad.pl new file mode 100755 index 00000000000..3501dc84c53 --- /dev/null +++ b/egs/sre08/v1/diarization/convert_ref_to_vad.pl @@ -0,0 +1,57 @@ +#!/usr/bin/perl -w +# Copyright 2015 Vimal Manohar (Johns Hopkins University) +# Apache 2.0. + +use strict; +use Getopt::Long; +use File::Basename; + +my $frame_shift = 0.01; + +GetOptions('frame-shift:f' => \$frame_shift); + +if (@ARGV != 1) { + print STDERR "$0:\n" . 
+ "Usage: convert_ref_to_vad.pl [options] > \n"; + exit 1; +} + +($frame_shift > 0.0001 && $frame_shift <= 1.0) || + die "Very strange frame-shift value '$frame_shift'"; + +my $filename = $ARGV[0]; + +print STDERR "Extracting VAD from ref $filename\n"; + +my $basename = basename($filename); +(my $utt_id = $basename) =~ s/\.[^.]+$//; + +open IN, $filename or die "Could not open $filename"; + +my $max_time = 0; +while () { + chomp; + my @A = split; + if (int($A[1]/$frame_shift+0.5) > $max_time) { $max_time = int($A[1]/$frame_shift+0.5) } +} + +my @vad = (1)x$max_time; + +close IN; + +open IN, $filename or die "Could not open $filename"; + +while () { + chomp; + my @A = split; + + for (my $i = int($A[0]/$frame_shift); $i <= int($A[1]/$frame_shift+0.5); $i++) { + $vad[$i] = 2; + } +} + +print STDOUT $utt_id; +foreach (@vad) { + print STDOUT " $_"; +} +print STDOUT "\n"; diff --git a/egs/sre08/v1/diarization/convert_rttm_to_segments.pl b/egs/sre08/v1/diarization/convert_rttm_to_segments.pl new file mode 100755 index 00000000000..9504e1361ff --- /dev/null +++ b/egs/sre08/v1/diarization/convert_rttm_to_segments.pl @@ -0,0 +1,49 @@ +#!/usr/bin/perl -w +# Copyright 2015 Vimal Manohar (Johns Hopkins University) +# Apache 2.0. + +use strict; +use POSIX; +use Getopt::Long; +use File::Basename; +use Pod::Usage; + +my $help = 0; + +my $frame_shift = 0.01; +my $ignore_boundaries = "false"; + +GetOptions('frame-shift:f' => \$frame_shift, + 'ignore-boundaries:s' => \$ignore_boundaries, + 'help|?' => \$help); + +if ((@ARGV > 1 || $help)) { + print STDERR "$0:\n" . 
+ "Usage: convert_rttm_to_segments.pl [options] [rttm] > \n"; + exit 0 if $help; + exit 1; +} + +($frame_shift > 0.0001 && $frame_shift <= 1.0) || + die "Very strange frame-shift value '$frame_shift'"; +($ignore_boundaries eq "false" || $ignore_boundaries eq "true") || + die "ignore-boundaries must be (true|false)"; + +while (<>) { + chomp; + my @A = split; + my $file = $A[1]; + + if ($A[0] =~ m/SPKR-INFO/) { + print STDERR "Reading RTTM for file $file\n"; + next; + } elsif ($A[0] !~ m/SPEAKER/) { + next; + } + my $start_frame = floor($A[3] / $frame_shift); + my $end_frame = floor(($A[3] + $A[4]) / $frame_shift); + + my $utt_id = sprintf("$file-%06d-%06d", $start_frame, $end_frame); + + printf STDOUT ("$utt_id $file %6.3f %6.3f\n", $A[3], $A[3] + $A[4]) or die; +} diff --git a/egs/sre08/v1/diarization/convert_rttm_to_vad.pl b/egs/sre08/v1/diarization/convert_rttm_to_vad.pl new file mode 100755 index 00000000000..39fdf11ffad --- /dev/null +++ b/egs/sre08/v1/diarization/convert_rttm_to_vad.pl @@ -0,0 +1,91 @@ +#!/usr/bin/perl -w +# Copyright 2015 Vimal Manohar (Johns Hopkins University) +# Apache 2.0. + +use strict; +use POSIX; +use Getopt::Long; +use File::Basename; + +my $frame_shift = 0.01; +my $ignore_boundaries = "false"; +my $segments_file = ""; + +GetOptions('frame-shift:f' => \$frame_shift, + 'segments-out:s' => \$segments_file, + 'ignore-boundaries:s' => \$ignore_boundaries); + +if (@ARGV > 1) { + print STDERR "$0:\n" . 
+ "Usage: convert_rttm_to_vad.pl [options] [rttm] > \n"; + exit 1; +} +($frame_shift > 0.0001 && $frame_shift <= 1.0) || + die "Very strange frame-shift value '$frame_shift'"; +($ignore_boundaries eq "false" || $ignore_boundaries eq "true") || + die "ignore-boundaries must be (true|false)"; + +if ($segments_file ne "") { + open (SEGMENTS, ">", $segments_file) + or die "Cannot open $segments_file for writing\n"; + print STDERR "Writing segments to $segments_file\n"; +} + +my %vad_for_file = (); +my %start_times = (); +my %end_times = (); + +while (<>) { + chomp; + my @A = split; + my $file = $A[1]; + + if ($A[0] =~ m/SPKR-INFO/) { + print STDERR "Reading RTTM for file $file\n"; + $vad_for_file{$file} = []; + next; + } elsif ($A[0] !~ m/SPEAKER/) { + next; + } + my $start_time = floor($A[3] / $frame_shift); + my $end_time = floor(($A[3] + $A[4]) / $frame_shift); + + if (! defined $start_times{$file}) { + $start_times{$file} = $A[3]; + $end_times{$file} = $A[3] + $A[4]; + } + + $end_times{$file} = $A[3] + $A[4]; + + exists $vad_for_file{$file} or die "SPKR-INFO not yet seen for file $file. RTTM is not sorted using rttmSort.pl?"; + + for (my $i = scalar @{ $vad_for_file{$file} }; $i < $start_time; $i++) { + $vad_for_file{$file}[$i] = 0; + } + scalar @{$vad_for_file{$file}} < $end_time or die "$end_time is < length of file"; + + for (my $i = scalar @{$vad_for_file{$file}}; $i < $end_time; $i++) { + $vad_for_file{$file}[$i] = 1; + } +} + +foreach (keys %vad_for_file) { + my $file = $_; + my $utt_id = $_; + + defined $start_times{$file} or die "Start time for $file not found\n"; + defined $end_times{$file} or die "End time for $file not found\n"; + + my $start_time = floor($start_times{$file} / $frame_shift); + my $end_time = floor($end_times{$file} / $frame_shift); + if ($ignore_boundaries eq "true") { + if ($segments_file ne "") { + $utt_id = sprintf("$file-%06d-%06d", $start_time, $end_time); + print SEGMENTS "$utt_id $file " . 
sprintf("%5.2f %5.2f", $start_time * $frame_shift, $end_time * $frame_shift) . "\n"; + } + print STDOUT $utt_id . " " . join(" ", @{ $vad_for_file{$_} }[$start_time..($end_time-1)]) . "\n"; + } else { + print STDOUT $utt_id . " " . join(" ", @{ $vad_for_file{$_} }) . "\n"; + } +} + diff --git a/egs/sre08/v1/diarization/convert_segments_to_rttm.pl b/egs/sre08/v1/diarization/convert_segments_to_rttm.pl new file mode 100755 index 00000000000..90544c55abb --- /dev/null +++ b/egs/sre08/v1/diarization/convert_segments_to_rttm.pl @@ -0,0 +1,50 @@ +#!/usr/bin/perl -w +# Copyright 2015 Vimal Manohar (Johns Hopkins University) +# Apache 2.0. + +use strict; +use POSIX; +use Getopt::Long; +use File::Basename; + +print STDERR join(@ARGV) . "\n"; + +my $frame_shift = 0.01; +GetOptions('frame-shift:f' => \$frame_shift); + +if (@ARGV != 1) { + print STDERR "$0:\n" . + "Usage: cat | convert_segments_to_rttm.pl [options] > \n"; + exit 1; +} +open (UEM, ">", $ARGV[0]) or die "Could not open $ARGV[0]"; + +($frame_shift > 0.0001 && $frame_shift <= 1.0) || + die "Very strange frame-shift value '$frame_shift'"; + +print STDERR "Extracting RTTM from segments\n"; + +my %seen_files = (); +my %min_time = (); +my %max_time = (); + +while () { + chomp; + my @A = split; + my $utt_id = $A[0]; + my $file_id = $A[1]; + + if (! 
defined $seen_files{$file_id}) { + print STDOUT "SPKR-INFO $file_id 1 unknown speech \n"; + $seen_files{$file_id} = 1; + $min_time{$file_id} = $A[2]; + } + + print STDOUT sprintf("SPEAKER $file_id 1 %.2f %.2f speech \n", $A[2], $A[3] - $A[2]); + $max_time{$file_id} = $A[3]; +} + +foreach (keys %min_time) { + print UEM sprintf("$_ 1 %.2f %.2f\n", $min_time{$_}, $max_time{$_}); +} + diff --git a/egs/sre08/v1/diarization/convert_speaker_conf_to_labels_string.pl b/egs/sre08/v1/diarization/convert_speaker_conf_to_labels_string.pl new file mode 100755 index 00000000000..c6d2d314ce5 --- /dev/null +++ b/egs/sre08/v1/diarization/convert_speaker_conf_to_labels_string.pl @@ -0,0 +1,67 @@ +#!/usr/bin/perl + +use strict; +use POSIX; +use Pod::Usage; +use Getopt::Long; + +my $help = 0; +my $min_spk_conf = 0.8; +my $skip_empty_label_recos = "true"; + +GetOptions('min-spk-conf:f' => \$min_spk_conf, + 'help|?' => \$help, + 'skip-empty-label-recos:s' => \$skip_empty_label_recos); + +if ((@ARGV > 1 || $help)) { + print STDERR "$0:\n" . 
+ "Usage: convert_speaker_conf_to_labels.pl [options] [speaker-conf] > \n"; + exit 0 if $help; + exit 1; +} + +# This is how a line of speaker-conf can look like: +# [ ( )+ ] [ ( )+ ] +#single_f1e7fa29-1 [ 0 0.7433829 1 0.9339832 2 0.9105379 ] [ 0 413.2838 1 818.2226 2 70.9536 ] + +($min_spk_conf > 0.0 && $min_spk_conf <= 1.0) or die "min-spk-conf must be between 0 and 1"; + +while (<>) { + chomp; + + (m/^\s*(\S+)\s+\[((?:\s+[0-9.+]+\s+[0-9.+]+)+)\s+\]\s+\[((?:\s+[0-9.+]+\s+[0-9.+]+)+)\s+\]/) or die "Unparsable line $_ in speaker_conf file"; + + print STDERR "Parsed line: $1 $2 , $3\n"; + my $reco = $1; + my @A = split(' ', $2); + my @B = split(' ', $3); + + ($#A % 2 == 1) or die "Bad line $_: speaker confidences must be of the format "; + $#A == $#B or die "Bad line $_"; + + my @remove_labels; + my @occupancies; + + my $avg_conf = 0; + my $tot_occ = 0; + + for (my $i = 0; $i < @A; $i = $i + 2) { + $avg_conf += $A[$i+1] * $B[$i+1]; + $tot_occ += $B[$i+1]; + } + + $avg_conf /= $tot_occ; + + for (my $i = 0; $i < @A; $i = $i + 2) { + print STDERR "Extracted label : $A[$i], $A[$i+1] " . $min_spk_conf * $avg_conf . "\n"; + if ($A[$i+1] < $min_spk_conf * $avg_conf) { + push @remove_labels, $A[$i]; + } + } + + if ((scalar @remove_labels > 0) || $skip_empty_label_recos ne "true") { + print STDOUT $reco . " " . join(':', @remove_labels) . "\n"; + } else { + print STDERR "Not printing remove-labels for " . $reco . " " . join(':', @remove_labels) . "\n"; + } +} diff --git a/egs/sre08/v1/diarization/convert_vad_to_rttm.pl b/egs/sre08/v1/diarization/convert_vad_to_rttm.pl new file mode 100755 index 00000000000..9e5986df629 --- /dev/null +++ b/egs/sre08/v1/diarization/convert_vad_to_rttm.pl @@ -0,0 +1,92 @@ +#!/usr/bin/perl -w +# Copyright 2015 Vimal Manohar (Johns Hopkins University) +# Apache 2.0. + +use strict; +use POSIX; +use Getopt::Long; +use File::Basename; + +print STDERR join(@ARGV) . 
"\n"; + +my $frame_shift = 0.01; +my $segments = ""; +my $speech_class = 2; +my $silence_class = 1; +GetOptions('frame-shift:f' => \$frame_shift, + 'segments:s' => \$segments, + 'speech-class:i' => \$speech_class, + 'silence-class:i' => \$silence_class) or die; + +my $in; + +if (@ARGV > 1) { + print STDERR "$0:\n" . + "Usage: convert_vad_to_rttm.pl [options] [] > \n"; + exit 1; +} + +if (@ARGV == 0) { + $in = *STDIN; +} else { + open $in, $ARGV[0] or die "Could not open $ARGV[0]"; +} +($frame_shift > 0.0001 && $frame_shift <= 1.0) || + die "Very strange frame-shift value '$frame_shift'"; + +print STDERR "Extracting RTTM from VAD\n"; + +my %utt2file = (); +my %utt2start = (); +if ($segments ne "") { + open SEGMENTS, $segments or die "Could not open segments file $segments\n"; + while () { + chomp; + my @F = split; + (scalar @F == 4) or die "$0: Invalid line $_ in $segments\n"; + $utt2file{$F[0]} = $F[1]; + $utt2start{$F[0]} = $F[2]; + } +} + +my %seen_files = (); +while (<$in>) { + chomp; + my @A = split; + my $file_id = $A[0]; + if ($segments ne "") { + (defined $utt2file{$A[0]}) or die "$0: Unknown utterance $A[0] in VAD\n"; + $file_id = $utt2file{$A[0]}; + } + if (! 
defined $seen_files{$file_id}) { + print STDOUT "SPKR-INFO $file_id 1 unknown speech \n"; + $seen_files{$file_id} = 1; + } + + my $state = 1; # silence state + my $begin_time = 0; + my $end_time = 0; + for (my $i = 1; $i < $#A; $i++) { + if ($state == 1 && $A[$i] == $speech_class) { # speech start + $begin_time = ($i-1) * $frame_shift; + $state = 2; + } elsif ($state == 2 && $A[$i] == $silence_class) { # silence start + $end_time = ($i-1) * $frame_shift; + $state = 1; + my $dur = $end_time - $begin_time; + if ($segments ne "") { + $begin_time = $begin_time + $utt2start{$A[0]}; + } + print STDOUT sprintf("SPEAKER $file_id 1 %5.2f %5.2f speech \n", $begin_time, $dur); + } elsif ($A[$i] != $speech_class && $A[$i] != $silence_class) { + die "Unknown class $A[$i]\n"; + } + } + if ($state == 2) { + my $dur = ($#A-1)*$frame_shift - $begin_time; + if ($segments ne "") { + $begin_time = $begin_time + $utt2start{$A[0]}; + } + print STDOUT sprintf("SPEAKER $file_id 1 %5.2f %5.2f speech \n", $begin_time, $dur); + } +} diff --git a/egs/sre08/v1/diarization/create_uniform_segments.sh b/egs/sre08/v1/diarization/create_uniform_segments.sh new file mode 100755 index 00000000000..ec8536b0a55 --- /dev/null +++ b/egs/sre08/v1/diarization/create_uniform_segments.sh @@ -0,0 +1,84 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -e +set -o pipefail + +# Begin configuration section. +cmd=run.pl +nj=4 +window_length=600 # 10 minutes +frame_shift=0.01 + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# != 3 ]; then + echo "Usage: create_uniform_segments.sh " + echo " e.g.: diarization/train_vad_gmm_ntu.sh data/dev exp/uniform_segment_dev data/dev_uniform" + echo "main options (for others, see top of script file)" + echo " --config # config containing options" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." 
+ exit 1; +fi + +data=$1 +tmpdir=$2 +dir=$3 + +mkdir -p $tmpdir +mkdir -p $dir + +utils/copy_data_dir.sh $data $dir +for f in feats.scp segments spk2utt utt2spk vad.scp cmvn.scp spk2gender text; do + [ -f $dir/$f ] && rm $dir/$f +done +rm -rf $dir/split* $dir/.backup + +awk '{print $1" "$1}' $dir/wav.scp > $dir/utt2spk || exit 1 +utils/utt2spk_to_spk2utt.pl $dir/utt2spk > $dir/spk2utt || exit 1 + +utils/fix_data_dir.sh $dir + +utils/split_data.sh $dir $nj + +$cmd JOB=1:$nj $tmpdir/log/get_wav_lengths.JOB.log \ + wav-to-duration scp:$dir/split$nj/JOB/wav.scp \ + ark,t:$tmpdir/wav_lengths.JOB.txt || exit 1 + +for n in `seq $nj`; do + cat $tmpdir/wav_lengths.$n.txt +done | perl -e ' +$frame_shift = $ARGV[0]; +$window_length = $ARGV[1]; +while () { + chomp; + @F = split; + $file_id = $F[0]; + $duration = $F[1]; + + $num_chunks = int($duration / $window_length + 0.99); + $this_window_length = $duration / $num_chunks; + + for ($i = 0; $i < $num_chunks; $i++) { + $start_time = $i * $this_window_length; + $end_time = ($i+1) * $this_window_length; + $end_time = $duration if ($end_time > $duration); + + $start_str = sprintf("%06d", $start_time / $frame_shift); + $end_str = sprintf("%06d", $end_time / $frame_shift); + + printf STDOUT ("$file_id-$start_str-$end_str $file_id %.2f %.2f\n", $start_time, $end_time); + } +}' $frame_shift $window_length > $dir/segments || exit 1 + +[ ! -s $dir/segments ] && echo "$0: No segments in $dir/segments" && exit 1 + +awk '{print $1" "$2}' $dir/segments > $dir/utt2spk || exit 1 +utils/utt2spk_to_spk2utt.pl $dir/utt2spk > $dir/spk2utt || exit 1 + +utils/fix_data_dir.sh $dir + diff --git a/egs/sre08/v1/diarization/do_vad_gmm.sh b/egs/sre08/v1/diarization/do_vad_gmm.sh new file mode 100755 index 00000000000..ee374383701 --- /dev/null +++ b/egs/sre08/v1/diarization/do_vad_gmm.sh @@ -0,0 +1,334 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -e +set -o pipefail + +# Begin configuration section. 
+cmd=run.pl +nj=4 +speech_duration=75 +sil_duration=30 +speech_max_gauss=12 +sil_max_gauss=4 +speech_gauss_incr=4 +sil_gauss_incr=1 +num_iters=20 +impr_thres=0.002 +stage=-10 +cleanup=true +top_frames_threshold=0.16 +bottom_frames_threshold=0.04 +select_only_voiced_frames=false +ignore_energy_opts= +apply_cmvn= +init_speech_model= +init_silence_model= +window_size=3 +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# != 2 ]; then + echo "Usage: diarization/train_vad_gmm.sh " + echo " e.g.: diarization/train_vad_gmm.sh data/dev exp/vad_dev" + echo "main options (for others, see top of script file)" + echo " --config # config containing options" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --num-iters <#iters> # Number of iterations of E-M" + exit 1; +fi + +data=$1 +dir=$2 + +function build_0gram { +wordlist=$1; lm=$2 +echo "=== Building zerogram $lm from ${wordlist}. ..." +awk '{print $1}' $wordlist | sort -u > $lm +python -c """ +import math +with open('$lm', 'r+') as f: + lines = f.readlines() + p = math.log10(1/float(len(lines))); + lines = ['%f\\t%s'%(p,l) for l in lines] + f.seek(0); f.write('\\n\\\\data\\\\\\nngram 1= %d\\n\\n\\\\1-grams:\\n' % len(lines)) + f.write(''.join(lines) + '\\\\end\\\\') +""" +} + +for f in $data/feats.scp $data/vad.scp; do + [ ! 
-s $f ] && echo "$0: could not find $f or $f is empty" && exit 1 +done + +feat_dim=`feat-to-dim "ark:head -n 1 $data/feats.scp | add-deltas scp:- ark:- |$ignore_energy_opts" ark,t:- | awk '{print $2}'` || exit 1 + +# Prepare a lang directory +if [ $stage -le -2 ]; then + mkdir -p $dir/local + mkdir -p $dir/local/dict + mkdir -p $dir/local/lm + + echo "1" > $dir/local/dict/silence_phones.txt + echo "1" > $dir/local/dict/optional_silence.txt + echo "2" > $dir/local/dict/nonsilence_phones.txt + echo -e "1 1\n2 2" > $dir/local/dict/lexicon.txt + echo -e "1\n2\n1 2" > $dir/local/dict/extra_questions.txt + + mkdir -p $dir/lang + diarization/prepare_vad_lang.sh --num-sil-states 1 --num-nonsil-states 1 \ + $dir/local/dict $dir/local/lang $dir/lang || exit 1 + fstisstochastic $dir/lang/G.fst || echo "[info]: G not stochastic." + diarization/prepare_vad_lang.sh --num-sil-states 30 --num-nonsil-states 75 \ + $dir/local/dict $dir/local/lang $dir/lang_test || exit 1 +fi + +if [ $stage -le -1 ]; then + run.pl $dir/log/create_transition_model.log gmm-init-mono \ + --binary=false $dir/lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 + run.pl $dir/log/create_transition_model.log gmm-init-mono \ + --binary=false $dir/lang_test/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans_test.mdl || exit 1 + + diarization/make_vad_graph.sh --iter trans $dir/lang $dir $dir/graph || exit 1 + diarization/make_vad_graph.sh --iter trans_test $dir/lang_test $dir $dir/graph_test || exit 1 +fi + +cat < $dir/pdf_to_tid.map +0 1 +1 3 +EOF + +apply_cmvn_opts= +if $apply_cmvn; then + apply_cmvn_opts="apply-cmvn-sliding ark:- ark:- |" +fi + +if [ $stage -le 0 ]; then +mkdir -p $dir/q +utils/split_data.sh $data $nj || exit 1 + +select_frames_opts= +select_sil_frames_opts= +if $select_only_voiced_frames; then + select_frames_opts="select-voiced-frames ark:- scp:$data/vad.scp ark:- |" + 
select_sil_frames_opts="select-voiced-frames --select-unvoiced-frames=true ark:- scp:$data/vad.scp ark:- |" +fi + +extract-column scp:$data/feats.scp ark,scp:mfcc/log_energies.ark,$data/log_energies.scp || { echo "extract-column failed"; exit 1; } + +mkdir -p $dir/silence_gmm +tmpdir=$dir/silence_gmm + +for n in `seq $nj`; do + cat < $dir/q/do_vad.$n.sh +set -e +set -o pipefail +set -u + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +while IFS=$'\n' read line; do + feats="ark:echo \$line | copy-feats scp:- ark:- | add-deltas ark:- ark:- |" + utt_id=\$(echo \$line | awk '{print \$1}') + echo \$utt_id > $dir/\$utt_id.list + + speech_num_gauss=6 + sil_num_gauss=2 + this_top_frames_threshold=$top_frames_threshold + this_bottom_frames_threshold=$bottom_frames_threshold + + speech_energies="ark:utils/filter_scp.pl $dir/\$utt_id.list $data/log_energies.scp | copy-vector scp:- ark:- |" + sil_energies="ark:utils/filter_scp.pl $dir/\$utt_id.list $data/log_energies.scp | copy-vector scp:- ark:- |" + + if [ -z "$init_speech_model" ]; then + gmm-global-init-from-feats --num-gauss=\$speech_num_gauss --num-iters=4 \ + "\$feats $apply_cmvn_opts $select_frames_opts select-top-chunks --window-size=$window_size --frames-proportion=$top_frames_threshold --weights=\"\$speech_energies\" ark:- ark:- |$ignore_energy_opts" \ + $tmpdir/\$utt_id.speech.0.mdl || exit 1 + #gmm-global-init-from-feats --num-gauss=\$sil_num_gauss --num-iters=4 \ + # "\$feats $apply_cmvn_opts $select_sil_frames_opts select-top-chunks --window-size=$window_size --frames-proportion=$bottom_frames_threshold --select-frames=\$[sil_num_gauss * 20] --select-bottom-frames=true --weights=\"\$sil_energies\" ark:- ark:- |$ignore_energy_opts" \ + # $tmpdir/\$utt_id.silence.0.mdl || exit 1 + gmm-global-init-from-feats --num-gauss=\$sil_num_gauss --num-iters=4 \ + "\$feats $apply_cmvn_opts $select_sil_frames_opts select-top-chunks --window-size=$window_size 
--frames-proportion=$bottom_frames_threshold --select-bottom-frames=true --weights=\"\$sil_energies\" ark:- ark:- |$ignore_energy_opts" \ + $tmpdir/\$utt_id.silence.0.mdl || exit 1 + else + gmm-global-get-frame-likes $init_speech_model \ + "\${feats}${apply_cmvn_opts}$ignore_energy_opts" ark:$dir/\$utt_id.speech_likes.init.ark || exit 1 + + gmm-global-get-frame-likes $init_silence_model \ + "\${feats}${apply_cmvn_opts}$ignore_energy_opts" ark:$dir/\$utt_id.silence_likes.init.ark || exit 1 + + loglikes-to-class --weights=ark:$dir/\$utt_id.weights.init.ark \ + ark:$dir/\$utt_id.silence_likes.init.ark ark:$dir/\$utt_id.speech_likes.init.ark \ + ark:$dir/\$utt_id.vad.init.ark || exit 1 + + speech_energies="ark:utils/filter_scp.pl $dir/\$utt_id.list $data/log_energies.scp | copy-vector scp:- ark:- |" + sil_energies="ark:utils/filter_scp.pl $dir/\$utt_id.list $data/log_energies.scp | copy-vector scp:- ark:- |" + + #gmm-global-init-from-feats --num-gauss=\$speech_num_gauss --num-iters=4 \ + # "\$feats $apply_cmvn_opts select-voiced-frames ark:- ark:$dir/\$utt_id.vad.init.ark ark:- | select-top-chunks --window-size=$window_size --frames-proportion=$top_frames_threshold --weights=\"\$speech_energies\" ark:- ark:- |$ignore_energy_opts" \ + # $dir/\$utt_id.speech.0.mdl || exit 1 + #gmm-global-init-from-feats --num-gauss=\$sil_num_gauss --num-iters=4 \ + # "\$feats $apply_cmvn_opts select-voiced-frames --select-unvoiced-frames=true ark:- ark:$dir/\$utt_id.vad.init.ark ark:- | select-top-chunks --window-size=$window_size --select-frames=\$[sil_num_gauss * 20] --select-bottom-frames=true --weights=\"\$sil_energies\" ark:- ark:- |$ignore_energy_opts" \ + # $dir/\$utt_id.silence.0.mdl || exit 1 + + #gmm-global-init-from-feats --num-gauss=\$speech_num_gauss --num-iters=4 \ + # "\$feats $apply_cmvn_opts select-top-chunks --window-size=1 --frames-proportion=$top_frames_threshold \ + # --weights=\"\$speech_energies\" --selection-mask=ark:$dir/\$utt_id.vad.init.ark ark:- ark:- 
|$ignore_energy_opts" \ + # $dir/\$utt_id.speech.0.mdl || exit 1 + + gmm-global-init-from-feats --num-gauss=\$sil_num_gauss --num-iters=4 \ + "\$feats $apply_cmvn_opts select-top-chunks --select-bottom-frames=true --invert-mask=true \ + --window-size=1 --select-frames=\$[sil_num_gauss * 100] \ + --weights=\"\$sil_energies\" --selection-mask=ark:$dir/\$utt_id.vad.init.ark ark:- ark:- |$ignore_energy_opts" \ + $tmpdir/\$utt_id.silence.0.mdl || exit 1 + + fi + + gmm-global-get-frame-likes $init_speech_model \ + "\${feats}${apply_cmvn_opts}$ignore_energy_opts" ark:$tmpdir/\$utt_id.speech_likes.0.ark || exit 1 + + x=0 + while [ \$x -lt $num_iters ]; do + if [ $stage -le \$x ]; then + + if [ \$sil_num_gauss -le $sil_max_gauss ]; then + sil_num_gauss=\$[sil_num_gauss + $sil_gauss_incr] + fi + + gmm-global-get-frame-likes $tmpdir/\$utt_id.silence.\$x.mdl \ + "\${feats}${apply_cmvn_opts}$ignore_energy_opts" ark:$tmpdir/\$utt_id.silence_likes.\$x.ark || exit 1 + + loglikes-to-class --weights=ark:$tmpdir/\$utt_id.weights.\$x.ark \ + ark:$tmpdir/\$utt_id.silence_likes.\$x.ark ark:$tmpdir/\$utt_id.speech_likes.0.ark \ + ark:$tmpdir/\$utt_id.vad.\$x.ark || exit 1 + + sil_energies="ark:utils/filter_scp.pl $tmpdir/\$utt_id.list $data/log_energies.scp | select-voiced-frames --select-unvoiced-frames=true scp:- ark:$tmpdir/\$utt_id.vad.\$x.ark ark:- |" + + gmm-global-acc-stats $tmpdir/\$utt_id.silence.\$x.mdl \ + "\$feats $apply_cmvn_opts select-top-chunks --select-bottom-frames=true --invert-mask=true \ + --window-size=$window_size --select-frames=\$[(sil_num_gauss-1) * 2000] \ + --weights=ark:$tmpdir/\$utt_id.weights.\$x.ark --selection-mask=ark:$tmpdir/\$utt_id.vad.\$x.ark ark:- ark:- |$ignore_energy_opts" - | \ + gmm-global-est --mix-up=\$sil_num_gauss $tmpdir/\$utt_id.silence.\$x.mdl \ + - $tmpdir/\$utt_id.silence.\$[x+1].mdl || exit 1 + + fi + x=\$[x+1] + done + + cp $tmpdir/\$utt_id.silence.\$x.mdl $dir/\$utt_id.silence.0.mdl + cp $init_speech_model $dir/\$utt_id.speech.0.mdl 
+ + x=0 + while [ \$x -lt $num_iters ]; do + if [ $stage -le \$x ]; then + + if [ \$speech_num_gauss -le $speech_max_gauss ]; then + speech_num_gauss=\$[speech_num_gauss + $speech_gauss_incr] + fi + + if [ \$sil_num_gauss -le $sil_max_gauss ]; then + sil_num_gauss=\$[sil_num_gauss + $sil_gauss_incr] + fi + + this_top_frames_threshold=1.0 + #this_bottom_frames_threshold=1.0 + #this_top_frames_threshold=\$(perl -e "if (\$this_top_frames_threshold < 0.5) { print \$this_top_frames_threshold * 2 } else { print 0.5 }") + #this_bottom_frames_threshold=\$(perl -e "if (\$this_bottom_frames_threshold < 0.8) { print \$this_bottom_frames_threshold * 2 } else { print \$this_bottom_frames_threshold }") + + + gmm-global-get-frame-likes $dir/\$utt_id.speech.\$x.mdl \ + "\${feats}${apply_cmvn_opts}$ignore_energy_opts" ark:$dir/\$utt_id.speech_likes.\$x.ark || exit 1 + + gmm-global-get-frame-likes $dir/\$utt_id.silence.\$x.mdl \ + "\${feats}${apply_cmvn_opts}$ignore_energy_opts" ark:$dir/\$utt_id.silence_likes.\$x.ark || exit 1 + + loglikes-to-class --weights=ark:$dir/\$utt_id.weights.\$x.ark \ + ark:$dir/\$utt_id.silence_likes.\$x.ark ark:$dir/\$utt_id.speech_likes.\$x.ark \ + ark:$dir/\$utt_id.vad.\$x.ark || exit 1 + + speech_energies="ark:utils/filter_scp.pl $dir/\$utt_id.list $data/log_energies.scp | select-voiced-frames scp:- ark:$dir/\$utt_id.vad.\$x.ark ark:- |" + sil_energies="ark:utils/filter_scp.pl $dir/\$utt_id.list $data/log_energies.scp | select-voiced-frames --select-unvoiced-frames=true scp:- ark:$dir/\$utt_id.vad.\$x.ark ark:- |" + + #gmm-global-init-from-feats --num-gauss=\$speech_num_gauss --num-iters=10 \ + # "\$feats $apply_cmvn_opts select-top-chunks --window-size=$window_size --frames-proportion=\$this_top_frames_threshold --weights=scp:$data/log_energies.scp --selection-mask=ark:$dir/\$utt_id.vad.\$x.ark ark:- ark:- |$ignore_energy_opts" \ + # $dir/\$utt_id.speech.\$[x+1].mdl || exit 1 + # + + #gmm-global-acc-stats $dir/\$utt_id.speech.\$x.mdl \ + # "\$feats 
$apply_cmvn_opts select-voiced-frames ark:- ark:$dir/\$utt_id.vad.\$x.ark ark:- | select-top-chunks --window-size=$window_size --frames-proportion=\$this_top_frames_threshold --weights=\"\$speech_energies\" ark:- ark:- |$ignore_energy_opts" - | \ + # gmm-global-est --mix-up=\$speech_num_gauss $dir/\$utt_id.speech.\$x.mdl \ + # - $dir/\$utt_id.speech.\$[x+1].mdl || exit 1 + + #gmm-global-init-from-feats --num-gauss=\$speech_num_gauss --num-iters=10 \ + # "\$feats $apply_cmvn_opts select-top-chunks --window-size=$window_size --frames-proportion=\$this_top_frames_threshold \ + # --weights=p:$data/log_energies.scp --selection-mask=ark:$dir/\$utt_id.vad.\$x.ark ark:- ark:- |$ignore_energy_opts" \ + # $dir/\$utt_id.speech.\$[x+1].mdl || exit 1 + + gmm-global-acc-stats $dir/\$utt_id.speech.\$x.mdl \ + "\$feats $apply_cmvn_opts select-top-chunks --window-size=$window_size --frames-proportion=\$this_top_frames_threshold \ + --weights=ark:$dir/\$utt_id.weights.\$x.ark --selection-mask=ark:$dir/\$utt_id.vad.\$x.ark ark:- ark:- |$ignore_energy_opts" - | \ + gmm-global-est --mix-up=\$speech_num_gauss $dir/\$utt_id.speech.\$x.mdl \ + - $dir/\$utt_id.speech.\$[x+1].mdl || exit 1 + + + #gmm-global-acc-stats $dir/\$utt_id.silence.\$x.mdl \ + # "\$feats $apply_cmvn_opts select-voiced-frames --select-unvoiced-frames=true ark:- ark:$dir/\$utt_id.vad.\$x.ark ark:- | select-top-chunks --window-size=$window_size --select-frames=\$[sil_num_gauss * 40] --select-bottom-frames=true --weights=\"\$sil_energies\" ark:- ark:- |$ignore_energy_opts" - | \ + # gmm-global-est --mix-up=\$sil_num_gauss $dir/\$utt_id.silence.\$x.mdl \ + # - $dir/\$utt_id.silence.\$[x+1].mdl || exit 1 + + gmm-global-acc-stats $dir/\$utt_id.silence.\$x.mdl \ + "\$feats $apply_cmvn_opts select-top-chunks --select-bottom-frames=true --invert-mask=true \ + --window-size=$window_size \ + --weights=ark:$dir/\$utt_id.weights.\$x.ark --selection-mask=ark:$dir/\$utt_id.vad.\$x.ark ark:- ark:- |$ignore_energy_opts" - | \ + 
gmm-global-est --mix-up=\$sil_num_gauss $dir/\$utt_id.silence.\$x.mdl \ + - $dir/\$utt_id.silence.\$[x+1].mdl || exit 1 + + #objf_impr=\$(cat $dir/log/update.\$utt_id.\$x.log | grep "GMM update: Overall .* objective function" | perl -pe 's/.*GMM update: Overall (\S+) objective function .*/\$1/') + # + #if [ "\$(perl -e "if (\$objf_impr < $impr_thres) { print true; }")" == true ]; then + # break; + #fi + fi + x=\$[x+1] + done + + rm -f $dir/\$utt_id.final.mdl 2>/dev/null || true + #cp $dir/\$utt_id.\$x.mdl $dir/\$utt_id.final.mdl + + ( + copy-transition-model --binary=false $dir/trans_test.mdl - + echo " $feat_dim 2" + gmm-global-copy --binary=false $dir/\$utt_id.silence.\$x.mdl - + gmm-global-copy --binary=false $dir/\$utt_id.speech.\$x.mdl - + ) | gmm-copy - $dir/\$utt_id.final.mdl + + gmm-decode-simple \ + --allow-partial=true --word-symbol-table=$dir/graph/words.txt \ + $dir/\$utt_id.final.mdl $dir/graph_test/HCLG.fst \ + "\${feats}${apply_cmvn_opts}$ignore_energy_opts" ark:/dev/null ark:$dir/\$utt_id.final.ali || exit 1 +done < $data/split$nj/$n/feats.scp +EOF +done +fi + +if [ $stage -le 1 ]; then + $cmd JOB=1:$nj $dir/log/do_vad_job.JOB.log bash -x $dir/q/do_vad.JOB.sh || exit 1 +fi + +if $cleanup; then + for x in `seq $[num_iters - 1]`; do + if [ $[x % 10] -ne 0 ]; then + rm $dir/*.$x.mdl + fi + done +fi + +# Summarize warning messages... +utils/summarize_warnings.pl $dir/log diff --git a/egs/sre08/v1/diarization/do_vad_gmm_icsi.sh b/egs/sre08/v1/diarization/do_vad_gmm_icsi.sh new file mode 100755 index 00000000000..d5551094ba5 --- /dev/null +++ b/egs/sre08/v1/diarization/do_vad_gmm_icsi.sh @@ -0,0 +1,298 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -e +set -o pipefail + +# Begin configuration section. 
+cmd=run.pl +nj=4 +speech_duration=75 +sil_duration=30 +speech_max_gauss=12 +sil_max_gauss=4 +speech_gauss_incr=4 +sil_gauss_incr=1 +num_iters=20 +impr_thres=0.002 +stage=-10 +cleanup=true +top_frames_threshold=0.16 +bottom_frames_threshold=0.04 +select_only_voiced_frames=false +ignore_energy_opts= +apply_cmvn= +init_speech_model= +init_silence_model= +window_size=3 +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# != 2 ]; then + echo "Usage: diarization/train_vad_gmm.sh " + echo " e.g.: diarization/train_vad_gmm.sh data/dev exp/vad_dev" + echo "main options (for others, see top of script file)" + echo " --config # config containing options" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --num-iters <#iters> # Number of iterations of E-M" + exit 1; +fi + +data=$1 +dir=$2 + +function build_0gram { +wordlist=$1; lm=$2 +echo "=== Building zerogram $lm from ${wordlist}. ..." +awk '{print $1}' $wordlist | sort -u > $lm +python -c """ +import math +with open('$lm', 'r+') as f: + lines = f.readlines() + p = math.log10(1/float(len(lines))); + lines = ['%f\\t%s'%(p,l) for l in lines] + f.seek(0); f.write('\\n\\\\data\\\\\\nngram 1= %d\\n\\n\\\\1-grams:\\n' % len(lines)) + f.write(''.join(lines) + '\\\\end\\\\') +""" +} + +for f in $data/feats.scp $data/vad.scp; do + [ ! 
-s $f ] && echo "$0: could not find $f or $f is empty" && exit 1 +done + +feat_dim=`feat-to-dim "ark:head -n 1 $data/feats.scp | add-deltas scp:- ark:- |$ignore_energy_opts" ark,t:- | awk '{print $2}'` || exit 1 + +# Prepare a lang directory +if [ $stage -le -2 ]; then + mkdir -p $dir/local + mkdir -p $dir/local/dict + mkdir -p $dir/local/lm + + echo "1" > $dir/local/dict/silence_phones.txt + echo "1" > $dir/local/dict/optional_silence.txt + echo "2" > $dir/local/dict/nonsilence_phones.txt + echo -e "1 1\n2 2" > $dir/local/dict/lexicon.txt + echo -e "1\n2\n1 2" > $dir/local/dict/extra_questions.txt + + mkdir -p $dir/lang + diarization/prepare_vad_lang.sh --num-sil-states 1 --num-nonsil-states 1 \ + $dir/local/dict $dir/local/lang $dir/lang || exit 1 + fstisstochastic $dir/lang/G.fst || echo "[info]: G not stochastic." + diarization/prepare_vad_lang.sh --num-sil-states 30 --num-nonsil-states 75 \ + $dir/local/dict $dir/local/lang $dir/lang_test || exit 1 +fi + +if [ $stage -le -1 ]; then + run.pl $dir/log/create_transition_model.log gmm-init-mono \ + --binary=false $dir/lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 + run.pl $dir/log/create_transition_model.log gmm-init-mono \ + --binary=false $dir/lang_test/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans_test.mdl || exit 1 + + diarization/make_vad_graph.sh --iter trans $dir/lang $dir $dir/graph || exit 1 + diarization/make_vad_graph.sh --iter trans_test $dir/lang_test $dir $dir/graph_test || exit 1 +fi + +cat < $dir/pdf_to_tid.map +0 1 +1 3 +EOF + +apply_cmvn_opts= +if $apply_cmvn; then + apply_cmvn_opts="apply-cmvn-sliding ark:- ark:- |" +fi + +if [ $stage -le 0 ]; then +mkdir -p $dir/q +utils/split_data.sh $data $nj || exit 1 + +select_frames_opts= +select_sil_frames_opts= +if $select_only_voiced_frames; then + select_frames_opts="select-voiced-frames ark:- scp:$data/vad.scp ark:- |" + 
select_sil_frames_opts="select-voiced-frames --select-unvoiced-frames=true ark:- scp:$data/vad.scp ark:- |" +fi + +extract-column scp:$data/feats.scp ark,scp:mfcc/log_energies.ark,$data/log_energies.scp || { echo "extract-column failed"; exit 1; } + +for n in `seq $nj`; do + cat < $dir/q/do_vad.$n.sh +set -e +set -o pipefail +set -u + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +while IFS=$'\n' read line; do + feats="ark:echo \$line | copy-feats scp:- ark:- | add-deltas ark:- ark:- |" + utt_id=\$(echo \$line | awk '{print \$1}') + echo \$utt_id > $dir/\$utt_id.list + + speech_num_gauss=6 + sil_num_gauss=2 + this_top_frames_threshold=$top_frames_threshold + this_bottom_frames_threshold=$bottom_frames_threshold + + speech_energies="ark:utils/filter_scp.pl $dir/\$utt_id.list $data/log_energies.scp | copy-vector scp:- ark:- |" + sil_energies="ark:utils/filter_scp.pl $dir/\$utt_id.list $data/log_energies.scp | copy-vector scp:- ark:- |" + + if [ -z "$init_speech_model" ]; then + gmm-global-init-from-feats --num-gauss=\$speech_num_gauss --num-iters=4 \ + "\$feats $apply_cmvn_opts $select_frames_opts select-top-chunks --window-size=$window_size --frames-proportion=$top_frames_threshold --weights=\"\$speech_energies\" ark:- ark:- |$ignore_energy_opts" \ + $dir/\$utt_id.speech.0.mdl || exit 1 + #gmm-global-init-from-feats --num-gauss=\$sil_num_gauss --num-iters=4 \ + # "\$feats $apply_cmvn_opts $select_sil_frames_opts select-top-chunks --window-size=$window_size --frames-proportion=$bottom_frames_threshold --select-frames=\$[sil_num_gauss * 20] --select-bottom-frames=true --weights=\"\$sil_energies\" ark:- ark:- |$ignore_energy_opts" \ + # $dir/\$utt_id.silence.0.mdl || exit 1 + gmm-global-init-from-feats --num-gauss=\$sil_num_gauss --num-iters=4 \ + "\$feats $apply_cmvn_opts $select_sil_frames_opts select-top-chunks --window-size=$window_size --frames-proportion=$bottom_frames_threshold --select-bottom-frames=true --weights=\"\$sil_energies\" 
ark:- ark:- |$ignore_energy_opts" \ + $dir/\$utt_id.silence.0.mdl || exit 1 + else + gmm-global-get-frame-likes $init_speech_model \ + "\${feats}${apply_cmvn_opts}$ignore_energy_opts" ark:$dir/\$utt_id.speech_likes.init.ark || exit 1 + + gmm-global-get-frame-likes $init_silence_model \ + "\${feats}${apply_cmvn_opts}$ignore_energy_opts" ark:$dir/\$utt_id.silence_likes.init.ark || exit 1 + + loglikes-to-class --weights=ark:$dir/\$utt_id.weights.init.ark \ + ark:$dir/\$utt_id.silence_likes.init.ark ark:$dir/\$utt_id.speech_likes.init.ark \ + ark:$dir/\$utt_id.vad.init.ark || exit 1 + + speech_energies="ark:utils/filter_scp.pl $dir/\$utt_id.list $data/log_energies.scp | copy-vector scp:- ark:- |" + sil_energies="ark:utils/filter_scp.pl $dir/\$utt_id.list $data/log_energies.scp | copy-vector scp:- ark:- |" + + #gmm-global-init-from-feats --num-gauss=\$speech_num_gauss --num-iters=4 \ + # "\$feats $apply_cmvn_opts select-voiced-frames ark:- ark:$dir/\$utt_id.vad.init.ark ark:- | select-top-chunks --window-size=$window_size --frames-proportion=$top_frames_threshold --weights=\"\$speech_energies\" ark:- ark:- |$ignore_energy_opts" \ + # $dir/\$utt_id.speech.0.mdl || exit 1 + #gmm-global-init-from-feats --num-gauss=\$sil_num_gauss --num-iters=4 \ + # "\$feats $apply_cmvn_opts select-voiced-frames --select-unvoiced-frames=true ark:- ark:$dir/\$utt_id.vad.init.ark ark:- | select-top-chunks --window-size=$window_size --select-frames=\$[sil_num_gauss * 20] --select-bottom-frames=true --weights=\"\$sil_energies\" ark:- ark:- |$ignore_energy_opts" \ + # $dir/\$utt_id.silence.0.mdl || exit 1 + + #gmm-global-init-from-feats --num-gauss=\$speech_num_gauss --num-iters=4 \ + # "\$feats $apply_cmvn_opts select-top-chunks --window-size=1 --frames-proportion=$top_frames_threshold \ + # --weights=\"\$speech_energies\" --selection-mask=ark:$dir/\$utt_id.vad.init.ark ark:- ark:- |$ignore_energy_opts" \ + # $dir/\$utt_id.speech.0.mdl || exit 1 + + gmm-global-init-from-feats 
--num-gauss=\$sil_num_gauss --num-iters=4 \ + "\$feats $apply_cmvn_opts select-top-chunks --select-bottom-frames=true --invert-mask=true \ + --window-size=1 --select-frames=\$[sil_num_gauss * 100] \ + --weights=\"\$sil_energies\" --selection-mask=ark:$dir/\$utt_id.vad.init.ark ark:- ark:- |$ignore_energy_opts" \ + $dir/\$utt_id.silence.0.mdl || exit 1 + + fi + + x=0 + while [ \$x -lt $num_iters ]; do + if [ $stage -le \$x ]; then + + if [ \$speech_num_gauss -le $speech_max_gauss ]; then + speech_num_gauss=\$[speech_num_gauss + $speech_gauss_incr] + fi + + if [ \$sil_num_gauss -le $sil_max_gauss ]; then + sil_num_gauss=\$[sil_num_gauss + $sil_gauss_incr] + fi + + this_top_frames_threshold=0.9 + #this_bottom_frames_threshold=1.0 + #this_top_frames_threshold=\$(perl -e "if (\$this_top_frames_threshold < 0.5) { print \$this_top_frames_threshold * 2 } else { print 0.5 }") + this_bottom_frames_threshold=\$(perl -e "if (\$this_bottom_frames_threshold < 0.8) { print \$this_bottom_frames_threshold * 2 } else { print \$this_bottom_frames_threshold }") + + + gmm-global-get-frame-likes $dir/\$utt_id.speech.\$x.mdl \ + "\${feats}${apply_cmvn_opts}$ignore_energy_opts" ark:$dir/\$utt_id.speech_likes.\$x.ark || exit 1 + + gmm-global-get-frame-likes $dir/\$utt_id.silence.\$x.mdl \ + "\${feats}${apply_cmvn_opts}$ignore_energy_opts" ark:$dir/\$utt_id.silence_likes.\$x.ark || exit 1 + + loglikes-to-class --weights=ark:$dir/\$utt_id.weights.\$x.ark \ + ark:$dir/\$utt_id.silence_likes.\$x.ark ark:$dir/\$utt_id.speech_likes.\$x.ark \ + ark:$dir/\$utt_id.vad.\$x.ark || exit 1 + + speech_energies="ark:utils/filter_scp.pl $dir/\$utt_id.list $data/log_energies.scp | select-voiced-frames scp:- ark:$dir/\$utt_id.vad.\$x.ark ark:- |" + sil_energies="ark:utils/filter_scp.pl $dir/\$utt_id.list $data/log_energies.scp | select-voiced-frames --select-unvoiced-frames=true scp:- ark:$dir/\$utt_id.vad.\$x.ark ark:- |" + + #gmm-global-init-from-feats --num-gauss=\$speech_num_gauss --num-iters=10 \ + # 
"\$feats $apply_cmvn_opts select-top-chunks --window-size=$window_size --frames-proportion=\$this_top_frames_threshold --weights=scp:$data/log_energies.scp --selection-mask=ark:$dir/\$utt_id.vad.\$x.ark ark:- ark:- |$ignore_energy_opts" \ + # $dir/\$utt_id.speech.\$[x+1].mdl || exit 1 + # + + #gmm-global-acc-stats $dir/\$utt_id.speech.\$x.mdl \ + # "\$feats $apply_cmvn_opts select-voiced-frames ark:- ark:$dir/\$utt_id.vad.\$x.ark ark:- | select-top-chunks --window-size=$window_size --frames-proportion=\$this_top_frames_threshold --weights=\"\$speech_energies\" ark:- ark:- |$ignore_energy_opts" - | \ + # gmm-global-est --mix-up=\$speech_num_gauss $dir/\$utt_id.speech.\$x.mdl \ + # - $dir/\$utt_id.speech.\$[x+1].mdl || exit 1 + + #gmm-global-init-from-feats --num-gauss=\$speech_num_gauss --num-iters=10 \ + # "\$feats $apply_cmvn_opts select-top-chunks --window-size=$window_size --frames-proportion=\$this_top_frames_threshold \ + # --weights=p:$data/log_energies.scp --selection-mask=ark:$dir/\$utt_id.vad.\$x.ark ark:- ark:- |$ignore_energy_opts" \ + # $dir/\$utt_id.speech.\$[x+1].mdl || exit 1 + + gmm-global-acc-stats $dir/\$utt_id.speech.\$x.mdl \ + "\$feats $apply_cmvn_opts select-top-chunks --window-size=$window_size --frames-proportion=\$this_top_frames_threshold \ + --weights=ark:$dir/\$utt_id.weights.\$x.ark --selection-mask=ark:$dir/\$utt_id.vad.\$x.ark ark:- ark:- |$ignore_energy_opts" - | \ + gmm-global-est --mix-up=\$speech_num_gauss $dir/\$utt_id.speech.\$x.mdl \ + - $dir/\$utt_id.speech.\$[x+1].mdl || exit 1 + + + #gmm-global-acc-stats $dir/\$utt_id.silence.\$x.mdl \ + # "\$feats $apply_cmvn_opts select-voiced-frames --select-unvoiced-frames=true ark:- ark:$dir/\$utt_id.vad.\$x.ark ark:- | select-top-chunks --window-size=$window_size --select-frames=\$[sil_num_gauss * 40] --select-bottom-frames=true --weights=\"\$sil_energies\" ark:- ark:- |$ignore_energy_opts" - | \ + # gmm-global-est --mix-up=\$sil_num_gauss $dir/\$utt_id.silence.\$x.mdl \ + # - 
$dir/\$utt_id.silence.\$[x+1].mdl || exit 1 + + gmm-global-acc-stats $dir/\$utt_id.silence.\$x.mdl \ + "\$feats $apply_cmvn_opts select-top-chunks --select-bottom-frames=true --invert-mask=true \ + --window-size=$window_size --select-frames=\$[sil_num_gauss * 100] \ + --weights=ark:$dir/\$utt_id.weights.\$x.ark --selection-mask=ark:$dir/\$utt_id.vad.\$x.ark ark:- ark:- |$ignore_energy_opts" - | \ + gmm-global-est --mix-up=\$sil_num_gauss $dir/\$utt_id.silence.\$x.mdl \ + - $dir/\$utt_id.silence.\$[x+1].mdl || exit 1 + + #objf_impr=\$(cat $dir/log/update.\$utt_id.\$x.log | grep "GMM update: Overall .* objective function" | perl -pe 's/.*GMM update: Overall (\S+) objective function .*/\$1/') + # + #if [ "\$(perl -e "if (\$objf_impr < $impr_thres) { print true; }")" == true ]; then + # break; + #fi + fi + x=\$[x+1] + done + + rm -f $dir/\$utt_id.final.mdl 2>/dev/null || true + #cp $dir/\$utt_id.\$x.mdl $dir/\$utt_id.final.mdl + + ( + copy-transition-model --binary=false $dir/trans_test.mdl - + echo " $feat_dim 2" + gmm-global-copy --binary=false $dir/\$utt_id.silence.\$x.mdl - + gmm-global-copy --binary=false $dir/\$utt_id.speech.\$x.mdl - + ) | gmm-copy - $dir/\$utt_id.final.mdl + + gmm-decode-simple \ + --allow-partial=true --word-symbol-table=$dir/graph/words.txt \ + $dir/\$utt_id.final.mdl $dir/graph_test/HCLG.fst \ + "\${feats}${apply_cmvn_opts}$ignore_energy_opts" ark:/dev/null ark:$dir/\$utt_id.final.ali || exit 1 +done < $data/split$nj/$n/feats.scp +EOF +done +fi + +if [ $stage -le 1 ]; then + $cmd JOB=1:$nj $dir/log/do_vad_job.JOB.log bash -x $dir/q/do_vad.JOB.sh || exit 1 +fi + +if $cleanup; then + for x in `seq $[num_iters - 1]`; do + if [ $[x % 10] -ne 0 ]; then + rm $dir/*.$x.mdl + fi + done +fi + +# Summarize warning messages... 
+utils/summarize_warnings.pl $dir/log + diff --git a/egs/sre08/v1/diarization/do_vad_gmm_post.sh b/egs/sre08/v1/diarization/do_vad_gmm_post.sh new file mode 100755 index 00000000000..5b00a6c93d0 --- /dev/null +++ b/egs/sre08/v1/diarization/do_vad_gmm_post.sh @@ -0,0 +1,197 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -e +set -o pipefail + +# Begin configuration section. +cmd=run.pl +nj=4 +speech_duration=75 +sil_duration=30 +speech_num_gauss=16 +sil_num_gauss=4 +num_iters=20 +impr_thres=0.002 +stage=-10 +cleanup=true +select_top_frames=true +top_frames_threshold=0.16 +bottom_frames_threshold=0.04 +init_vad_model= +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# != 2 ]; then + echo "Usage: diarization/train_vad_gmm.sh " + echo " e.g.: diarization/train_vad_gmm.sh data/dev exp/vad_dev" + echo "main options (for others, see top of script file)" + echo " --config # config containing options" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --num-iters <#iters> # Number of iterations of E-M" + exit 1; +fi + +data=$1 +dir=$2 + +function build_0gram { +wordlist=$1; lm=$2 +echo "=== Building zerogram $lm from ${wordlist}. ..." +awk '{print $1}' $wordlist | sort -u > $lm +python -c """ +import math +with open('$lm', 'r+') as f: + lines = f.readlines() + p = math.log10(1/float(len(lines))); + lines = ['%f\\t%s'%(p,l) for l in lines] + f.seek(0); f.write('\\n\\\\data\\\\\\nngram 1= %d\\n\\n\\\\1-grams:\\n' % len(lines)) + f.write(''.join(lines) + '\\\\end\\\\') +""" +} + +for f in $data/feats.scp $data/vad.scp; do + [ ! 
-s $f ] && echo "$0: could not find $f or $f is empty" && exit 1 +done + +feat_dim=`feat-to-dim "scp:head -n 1 $data/feats.scp |" ark,t:- | awk '{print $2}'` || exit 1 + +# Prepare a lang directory +if [ $stage -le -2 ]; then + mkdir -p $dir/local + mkdir -p $dir/local/dict + mkdir -p $dir/local/lm + + echo "1" > $dir/local/dict/silence_phones.txt + echo "1" > $dir/local/dict/optional_silence.txt + echo "2" > $dir/local/dict/nonsilence_phones.txt + echo -e "1 1\n2 2" > $dir/local/dict/lexicon.txt + echo -e "1\n2\n1 2" > $dir/local/dict/extra_questions.txt + + mkdir -p $dir/lang + diarization/prepare_vad_lang.sh --num-sil-states 1 --num-nonsil-states 1 \ + $dir/local/dict $dir/local/lang $dir/lang || exit 1 + fstisstochastic $dir/lang/G.fst || echo "[info]: G not stochastic." + diarization/prepare_vad_lang.sh --num-sil-states 30 --num-nonsil-states 75 \ + $dir/local/dict $dir/local/lang $dir/lang_test || exit 1 +fi + +if [ $stage -le -1 ]; then + run.pl $dir/log/create_transition_model.log gmm-init-mono \ + --binary=false $dir/lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 + run.pl $dir/log/create_transition_model.log gmm-init-mono \ + --binary=false $dir/lang_test/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans_test.mdl || exit 1 + + diarization/make_vad_graph.sh --iter trans $dir/lang $dir $dir/graph || exit 1 + diarization/make_vad_graph.sh --iter trans_test $dir/lang_test $dir $dir/graph_test || exit 1 +fi + +if [ $stage -le 0 ]; then +mkdir -p $dir/q +utils/split_data.sh $data $nj || exit 1 + +for n in `seq $nj`; do + cat < $dir/q/do_vad.$n.sh +set -e +set -o pipefail +set -u + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +while IFS=$'\n' read line; do + feats="ark:echo \$line | copy-feats scp:- ark:- |" + utt_id=\$(echo \$line | awk '{print \$1}') + + if [ -z "$init_vad_model" ]; then + if ! 
$select_top_frames; then + gmm-global-init-from-feats --num-gauss=$speech_num_gauss --num-iters=10 \ + "\$feats select-voiced-frames ark:- scp:$data/vad.scp ark:- |" \ + $dir/\$utt_id.speech.0.mdl || exit 1 + gmm-global-init-from-feats --num-gauss=$sil_num_gauss --num-iters=6 \ + "\$feats select-voiced-frames --select-unvoiced-frames=true ark:- scp:$data/vad.scp ark:- |" \ + $dir/\$utt_id.silence.0.mdl || exit 1 + else + gmm-global-init-from-feats --num-gauss=$speech_num_gauss --num-iters=12 \ + "\$feats select-top-frames --top-frames-proportion=$top_frames_threshold ark:- ark:- |" \ + $dir/\$utt_id.speech.0.mdl || exit 1 + gmm-global-init-from-feats --num-gauss=$sil_num_gauss --num-iters=8 \ + "\$feats select-top-frames --bottom-frames-proportion=$bottom_frames_threshold --top-frames-proportion=0.0 ark:- ark:- |" \ + $dir/\$utt_id.silence.0.mdl || exit 1 + fi + + { + cat $dir/trans.mdl + echo " $feat_dim 2" + gmm-global-copy --binary=false $dir/\$utt_id.silence.0.mdl - + gmm-global-copy --binary=false $dir/\$utt_id.speech.0.mdl - + } | gmm-copy - $dir/\$utt_id.0.mdl || exit 1 + else + cp $init_vad_model $dir/\$utt_id.0.mdl + fi + + x=0 + while [ \$x -lt $num_iters ]; do + gmm-decode-simple \ + --allow-partial=true --word-symbol-table=$dir/graph/words.txt \ + $dir/\$utt_id.\$x.mdl $dir/graph/HCLG.fst \ + "\$feats" ark:/dev/null ark:$dir/\$utt_id.\$x.ali || exit 1 + + gmm-acc-stats-ali \ + $dir/\$utt_id.\$x.mdl "\$feats" \ + ark:$dir/\$utt_id.\$x.ali - | \ + gmm-est $dir/\$utt_id.\$x.mdl - $dir/\$utt_id.\$[x+1].mdl \ + 2>&1 | tee $dir/log/update.\$utt_id.\$x.log || exit 1 + + objf_impr=\$(cat $dir/log/update.\$utt_id.\$x.log | grep "GMM update: Overall .* objective function" | perl -pe 's/.*GMM update: Overall (\S+) objective function .*/\$1/') + + if [ "\$(perl -e "if (\$objf_impr < $impr_thres) { print true; }")" == true ]; then + break; + fi + + x=\$[x+1] + done + + rm -f $dir/\$utt_id.final.mdl 2>/dev/null || true + cp $dir/\$utt_id.\$x.mdl 
$dir/\$utt_id.final.mdl + + ( + copy-transition-model --binary=false $dir/trans_test.mdl - + gmm-copy --write-tm=false --binary=false $dir/\$utt_id.\$x.mdl - + ) | gmm-copy - $dir/\$utt_id.final.mdl + + #gmm-decode-simple \ + # --allow-partial=true --word-symbol-table=$dir/graph/words.txt \ + # $dir/\$utt_id.final.mdl $dir/graph/HCLG.fst \ + # "\$feats" ark:/dev/null ark:$dir/\$utt_id.final.ali || exit 1 + + gmm-decode-simple \ + --allow-partial=true --word-symbol-table=$dir/graph/words.txt \ + $dir/\$utt_id.final.mdl $dir/graph_test/HCLG.fst \ + "\$feats" ark:/dev/null ark:$dir/\$utt_id.final.ali || exit 1 +done < $data/split$nj/$n/feats.scp +EOF +done +fi + +if [ $stage -le 1 ]; then + $cmd JOB=1:$nj $dir/log/do_vad_job.JOB.log bash -x $dir/q/do_vad.JOB.sh || exit 1 +fi + +if $cleanup; then + for x in `seq $[num_iters - 1]`; do + if [ $[x % 10] -ne 0 ]; then + rm $dir/*.$x.mdl + fi + done +fi + +# Summarize warning messages... +utils/summarize_warnings.pl $dir/log diff --git a/egs/sre08/v1/diarization/evaluate_segmentation.pl b/egs/sre08/v1/diarization/evaluate_segmentation.pl new file mode 100755 index 00000000000..9c0dcaae6c8 --- /dev/null +++ b/egs/sre08/v1/diarization/evaluate_segmentation.pl @@ -0,0 +1,198 @@ +#!/usr/bin/perl + +# Copyright 2014 Johns Hopkins University (Author: Sanjeev Khudanpur), Vimal Manohar +# Apache 2.0 + +################################################################################ +# +# This script was written to check the goodness of automatic segmentation tools +# It assumes input in the form of two Kaldi segments files, i.e. a file each of +# whose lines contain four space-separated values: +# +# UtteranceID FileID StartTime EndTime +# +# It computes # missed frames, # false positives and # overlapping frames. 
+# +################################################################################ + +if ($#ARGV == 1) { + $ReferenceSegmentation = $ARGV[0]; + $HypothesizedSegmentation = $ARGV[1]; + printf STDERR ("Comparing reference segmentation\n\t%s\nwith proposed segmentation\n\t%s\n", + $ReferenceSegmentation, + $HypothesizedSegmentation); +} else { + printf STDERR "This program compares the reference segmentation with the proposed segmentation\n"; + printf STDERR "Usage: $0 reference_segments_filename proposed_segments_filename\n"; + printf STDERR "e.g. $0 data/dev10h/segments data/dev10h.seg/segments\n"; + exit (0); +} + +################################################################################ +# First read the reference segmentation, and +# store the start- and end-times of all segments in each file. +################################################################################ + +open (SEGMENTS, "cat $ReferenceSegmentation | sort -k2,2 -k3n,3 -k4n,4 |") + || die "Unable to open $ReferenceSegmentation"; +$numLines = 0; +while ($line=) { + chomp $line; + @field = split("[ \t]+", $line); + unless ($#field == 3) { + exit (1); + printf STDERR "Skipping unparseable line in file $ReferenceSegmentation\n\t$line\n"; + next; + } + $fileID = $field[1]; + unless (exists $firstSeg{$fileID}) { + $firstSeg{$fileID} = $numLines; + $actualSpeech{$fileID} = 0.0; + $hypothesizedSpeech{$fileID} = 0.0; + $foundSpeech{$fileID} = 0.0; + $falseAlarm{$fileID} = 0.0; + $minStartTime{$fileID} = 0.0; + $maxEndTime{$fileID} = 0.0; + } + $refSegName[$numLines] = $field[0]; + $refSegStart[$numLines] = $field[2]; + $refSegEnd[$numLines] = $field[3]; + $actualSpeech{$fileID} += ($field[3]-$field[2]); + $minStartTime{$fileID} = $field[2] if ($minStartTime{$fileID}>$field[2]); + $maxEndTime{$fileID} = $field[3] if ($maxEndTime{$fileID}<$field[3]); + $lastSeg{$fileID} = $numLines; + ++$numLines; +} +close(SEGMENTS); +print STDERR "Read $numLines segments from $ReferenceSegmentation\n"; + 
+################################################################################ +# Process hypothesized segments sequentially, and gather speech/nonspeech stats +################################################################################ + +open (SEGMENTS, "cat $HypothesizedSegmentation | sort -k2,2 -k1,1 |") + # Kaldi segments files are sorted by UtteranceID, but we re-sort them here + # so that all segments of a file are read together, sorted by start-time. + || die "Unable to open $HypothesizedSegmentation"; +$numLines = 0; +$totalHypSpeech = 0.0; +$totalFoundSpeech = 0.0; +$totalFalseAlarm = 0.0; +$numShortSegs = 0; +$numLongSegs = 0; +while ($line=) { + chomp $line; + @field = split("[ \t]+", $line); + unless ($#field == 3) { + exit (1); + printf STDERR "Skipping unparseable line in file $HypothesizedSegmentation\n\t$line\n"; + next; + } + $fileID = $field[1]; + $segStart = $field[2]; + $segEnd = $field[3]; + if (exists $firstSeg{$fileID}) { + # This FileID exists in the reference segmentation + # So gather statistics for this UtteranceID + $hypothesizedSpeech{$fileID} += ($segEnd-$segStart); + $totalHypSpeech += ($segEnd-$segStart); + if (($segStart>=$maxEndTime{$fileID}) || ($segEnd<=$minStartTime{$fileID})) { + # This entire segment is a false alarm + $falseAlarm{$fileID} += ($segEnd-$segStart); + $totalFalseAlarm += ($segEnd-$segStart); + } else { + # This segment may overlap one or more reference segments + $p = $firstSeg{$fileID}; + while ($refSegEnd[$p]<=$segStart) { + ++$p; + } + # The overlap, if any, begins at the reference segment p + $q = $lastSeg{$fileID}; + while ($refSegStart[$q]>=$segEnd) { + --$q; + } + # The overlap, if any, ends at the reference segment q + if ($q<$p) { + # This segment sits entirely in the nonspeech region + # between the two reference speech segments q and p + $falseAlarm{$fileID} += ($segEnd-$segStart); + $totalFalseAlarm += ($segEnd-$segStart); + } else { + if (($segEnd-$segStart)<0.20) { + # For diagnosing 
Pascal's VAD segmentation + print STDOUT "Found short speech region $line\n"; + ++$numShortSegs; + } elsif (($segEnd-$segStart)>60.0) { + ++$numLongSegs; + # For diagnosing Pascal's VAD segmentation + print STDOUT "Found long speech region $line\n"; + } + # There is some overlap with segments p through q + for ($s=$p; $s<=$q; ++$s) { + if ($segStart<$refSegStart[$s]) { + # There is a leading false alarm portion before s + $falseAlarm{$fileID} += ($refSegStart[$s]-$segStart); + $totalFalseAlarm += ($refSegStart[$s]-$segStart); + $segStart=$refSegStart[$s]; + } + $speechPortion = ($refSegEnd[$s]<$segEnd) ? + ($refSegEnd[$s]-$segStart) : ($segEnd-$segStart); + $foundSpeech{$fileID} += $speechPortion; + $totalFoundSpeech += $speechPortion; + $segStart=$refSegEnd[$s]; + } + if ($segEnd>$segStart) { + # There is a trailing false alarm portion after q + $falseAlarm{$fileID} += ($segEnd-$segStart); + $totalFalseAlarm += ($segEnd-$segStart); + } + } + } + } else { + # This FileID does not exist in the reference segmentation + # So all this speech counts as a false alarm + exit (1); + printf STDERR ("Unexpected fileID in hypothesized segments: %s", $fileID); + $totalFalseAlarm += ($segEnd-$segStart); + } + ++$numLines; +} +close(SEGMENTS); +print STDERR "Read $numLines segments from $HypothesizedSegmentation\n"; + +################################################################################ +# Now that all hypothesized segments have been processed, compute needed stats +################################################################################ + +$totalActualSpeech = 0.0; +$totalNonSpeechEst = 0.0; # This is just a crude estimate of total nonspeech. 
+foreach $fileID (sort keys %actualSpeech) { + $totalActualSpeech += $actualSpeech{$fileID}; + $totalNonSpeechEst += $maxEndTime{$fileID} - $actualSpeech{$fileID}; + ####################################################################### + # Print file-wise statistics to STDOUT; can pipe to /dev/null if needed + ####################################################################### + printf STDOUT ("%s: %.2f min actual speech, %.2f min hypothesized: %.2f min overlap (%d\%), %.2f min false alarm (~%d\%)\n", + $fileID, + ($actualSpeech{$fileID}/60.0), + ($hypothesizedSpeech{$fileID}/60.0), + ($foundSpeech{$fileID}/60.0), + ($foundSpeech{$fileID}*100/($actualSpeech{$fileID}+0.01)), + ($falseAlarm{$fileID}/60.0), + ($falseAlarm{$fileID}*100/($maxEndTime{$fileID}-$actualSpeech{$fileID}+0.01))); +} + +################################################################################ +# Finally, we have everything needed to report the segmentation statistics. +################################################################################ + +printf STDERR ("------------------------------------------------------------------------\n"); +printf STDERR ("TOTAL: %.2f hrs actual speech, %.2f hrs hypothesized: %.2f hrs overlap (%d\%), %.2f hrs false alarm (~%d\%)\n", + ($totalActualSpeech/3600.0), + ($totalHypSpeech/3600.0), + ($totalFoundSpeech/3600.0), + ($totalFoundSpeech*100/($totalActualSpeech+0.000001)), + ($totalFalseAlarm/3600.0), + ($totalFalseAlarm*100/($totalNonSpeechEst+0.000001))); +printf STDERR ("\t$numShortSegs segments < 0.2 sec and $numLongSegs segments > 60.0 sec\n"); +printf STDERR ("------------------------------------------------------------------------\n"); diff --git a/egs/sre08/v1/diarization/evaluate_vad.pl b/egs/sre08/v1/diarization/evaluate_vad.pl new file mode 100755 index 00000000000..9626c60e16c --- /dev/null +++ b/egs/sre08/v1/diarization/evaluate_vad.pl @@ -0,0 +1,67 @@ +#!/usr/bin/perl -w +# Copyright 2015 Vimal Manohar (Johns Hopkins 
University) +# Apache 2.0. + +use strict; +use Getopt::Long; + +if (@ARGV != 2) { + print STDERR "$0:\n" . + "Usage: evaluate_vad.pl [options] \n"; + exit 1; +} + +my $ref_file = $ARGV[0]; +my $hyp_file = $ARGV[1]; + +open REF, $ref_file or die "$0: Unable to open reference vad $ref_file\n"; + +open HYP, $hyp_file or die "$0: Unable to open hypothesis vad $hyp_file\n"; + +my %hyps = (); + +while () { + chomp; + my @A = split; + $hyps{$A[0]} = [ @A[1..$#A] ]; +#% print STDERR join(' ', @{$hyps{$A[0]}}) . "\n"; +} + +#foreach (keys %hyps) { +# print STDERR $_ . join(' ', @{$hyps{$_}}) . "\n"; +#} + +while () { + chomp; + my @A = split; + my $fp = 0; + my $fn = 0; + my $cor = 0; + + my @B = @A[1..$#A]; + + my $i = 1; + my @H = @{$hyps{$A[0]}}; + + for ($i=0; $i <= $#B; $i++) { + if ( ($B[$i] == 1) && ($i <= $#H) && ($H[$i] == 2) ) { + $fp++; + } elsif ( ($B[$i] == 2) && ( ($i > $#H) || ($H[$i] == 1) ) ) { + $fn++; + } + else { + $cor++; + } + } + while ($i <= $#H) { + if ( $H[$i] == 2 ) { + $fp++; + } else { + $cor++; + } + $i++; + } + + my $n = scalar @B; + print STDOUT $A[0] . " " . scalar @H . " " . scalar @B . " $cor $fp $fn " . sprintf("%6.4f %6.4f %6.4f\n", $cor/$n, $fp/$n, $fn/$n); +} diff --git a/egs/sre08/v1/diarization/extract_ivectors.sh b/egs/sre08/v1/diarization/extract_ivectors.sh new file mode 100755 index 00000000000..c179d546eac --- /dev/null +++ b/egs/sre08/v1/diarization/extract_ivectors.sh @@ -0,0 +1,183 @@ +#!/bin/bash + +# Copyright 2013 Daniel Povey +# 2015 Vimal Manohar +# Apache 2.0. + + +# This script computes iVectors in the same format as extract_ivectors_online.sh, +# except that they are actually not really computed online, they are first computed +# per speaker and just duplicated many times. +# This is mainly intended for use in decoding, where you want the best possible +# quality of iVectors. 
+# +# This setup also makes it possible to use a previous decoding or alignment, to +# down-weight silence in the stats (default is --silence-weight 0.0). +# +# This is for when you use the "online-decoding" setup in an offline task, and +# you want the best possible results. + + +# Begin configuration section. +nj=30 +cmd="run.pl" +stage=0 +num_gselect=5 # Gaussian-selection using diagonal model: number of Gaussians to select +min_post=0.025 # Minimum posterior to use (posteriors below this are pruned out) +ivector_period=1 +posterior_scale=0.1 # Scale on the acoustic posteriors, intended to account for + # inter-frame correlations. Making this small during iVector + # extraction is equivalent to scaling up the prior, and will + # will tend to produce smaller iVectors where data-counts are + # small. It's not so important that this match the value + # used when training the iVector extractor, but more important + # that this match the value used when you do real online decoding + # with the neural nets trained with these iVectors. +max_count=100 # Interpret this as a number of frames times posterior scale... + # this config ensures that once the count exceeds this (i.e. + # 1000 frames, or 10 seconds, by default), we start to scale + # down the stats, accentuating the prior term. This seems quite + # important for some reason. +compress=true # If true, compress the iVectors stored on disk (it's lossy + # compression, as used for feature matrices). +silence_weight=0.0 +acwt=0.1 # used if input is a decode dir, to get best path from lattices. +mdl=final # change this if decode directory did not have ../final.mdl present. + +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. 
parse_options.sh || exit 1; + + +if [ $# != 4 ] && [ $# != 5 ]; then + echo "Usage: $0 [options] [||] " + echo " e.g.: $0 data/test exp/nnet2_online/extractor exp/tri3/decode_test exp/nnet2_online/ivectors_test" + echo "main options (for others, see top of script file)" + echo " --config # config containing options" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --nj # Number of jobs (also see num-processes and num-threads)" + echo " # Ignored if or supplied." + echo " --stage # To control partial reruns" + echo " --num-gselect # Number of Gaussians to select using" + echo " # diagonal model." + echo " --min-post # Pruning threshold for posteriors" + echo " --ivector-period # How often to extract an iVector (frames)" + echo " --posterior-scale # Scale on posteriors in iVector extraction; " + echo " # affects strength of prior term." + + exit 1; +fi + +if [ $# -eq 4 ]; then + data=$1 + lang=$2 + srcdir=$3 + dir=$4 +else # 5 arguments + data=$1 + lang=$2 + srcdir=$3 + ali_or_decode_dir=$4 + dir=$5 +fi + +for f in $data/feats.scp $srcdir/final.ie $srcdir/final.dubm $srcdir/global_cmvn.stats $srcdir/splice_opts \ + $lang/phones.txt $srcdir/online_cmvn.conf $srcdir/final.mat; do + [ ! -f $f ] && echo "$0: No such file $f" && exit 1; +done + +mkdir -p $dir/log +silphonelist=$(cat $lang/phones/silence.csl) || exit 1; + +if [ ! -z "$ali_or_decode_dir" ]; then + + + if [ -f $ali_or_decode_dir/ali.1.gz ]; then + if [ ! -f $ali_or_decode_dir/${mdl}.mdl ]; then + echo "$0: expected $ali_or_decode_dir/${mdl}.mdl to exist." 
+ exit 1; + fi + nj_orig=$(cat $ali_or_decode_dir/num_jobs) || exit 1; + + if [ $stage -le 0 ]; then + rm $dir/weights.*.gz 2>/dev/null + + $cmd JOB=1:$nj_orig $dir/log/ali_to_post.JOB.log \ + gunzip -c $ali_or_decode_dir/ali.JOB.gz \| \ + ali-to-post ark:- ark:- \| \ + weight-silence-post $silence_weight $silphonelist $ali_or_decode_dir/final.mdl ark:- ark:- \| \ + post-to-weights ark:- "ark:|gzip -c >$dir/weights.JOB.gz" || exit 1; + + # put all the weights in one archive. + for j in $(seq $nj_orig); do gunzip -c $dir/weights.$j.gz; done | gzip -c >$dir/weights.gz || exit 1; + rm $dir/weights.*.gz || exit 1; + fi + + elif [ -f $ali_or_decode_dir/lat.1.gz ]; then + nj_orig=$(cat $ali_or_decode_dir/num_jobs) || exit 1; + if [ ! -f $ali_or_decode_dir/../${mdl}.mdl ]; then + echo "$0: expected $ali_or_decode_dir/../${mdl}.mdl to exist." + exit 1; + fi + + + if [ $stage -le 0 ]; then + rm $dir/weights.*.gz 2>/dev/null + + $cmd JOB=1:$nj_orig $dir/log/lat_to_post.JOB.log \ + lattice-best-path --acoustic-scale=$acwt "ark:gunzip -c $ali_or_decode_dir/lat.JOB.gz|" ark:/dev/null ark:- \| \ + ali-to-post ark:- ark:- \| \ + weight-silence-post $silence_weight $silphonelist $ali_or_decode_dir/../${mdl}.mdl ark:- ark:- \| \ + post-to-weights ark:- "ark:|gzip -c >$dir/weights.JOB.gz" || exit 1; + + # put all the weights in one archive. 
+ for j in $(seq $nj_orig); do gunzip -c $dir/weights.$j.gz; done | gzip -c >$dir/weights.gz || exit 1; + rm $dir/weights.*.gz || exit 1; + fi + elif [ -f $ali_or_decode_dir ] && gunzip -c $ali_or_decode_dir >/dev/null; then + cp -f $ali_or_decode_dir $dir/weights.gz || exit 1; + else + echo "$0: expected ali.1.gz or lat.1.gz to exist in $ali_or_decode_dir"; + exit 1; + fi +fi + +sdata=$data/split$nj; +utils/split_data.sh $data $nj || exit 1; + +echo $ivector_period > $dir/ivector_period || exit 1; +splice_opts=$(cat $srcdir/splice_opts) + +gmm_feats="ark,s,cs:apply-cmvn-online --spk2utt=ark:$sdata/JOB/spk2utt --config=$srcdir/online_cmvn.conf $srcdir/global_cmvn.stats scp:$sdata/JOB/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $srcdir/final.mat ark:- ark:- |" +feats="ark,s,cs:splice-feats $splice_opts scp:$sdata/JOB/feats.scp ark:- | transform-feats $srcdir/final.mat ark:- ark:- |" + + +if [ $stage -le 2 ]; then + if [ ! -z "$ali_or_decode_dir" ]; then + $cmd JOB=1:$nj $dir/log/extract_ivectors.JOB.log \ + gmm-global-get-post --n=$num_gselect --min-post=$min_post $srcdir/final.dubm "$gmm_feats" ark:- \| \ + weight-post ark:- "ark,s,cs:gunzip -c $dir/weights.gz|" ark:- \| \ + ivector-extract --acoustic-weight=$posterior_scale --compute-objf-change=true \ + --max-count=$max_count \ + $srcdir/final.ie "$feats" ark,s,cs:- ark,scp:$dir/ivectors_utt.JOB.ark,$dir/ivectors_utt.JOB.scp || exit 1; + else + $cmd JOB=1:$nj $dir/log/extract_ivectors.JOB.log \ + gmm-global-get-post --n=$num_gselect --min-post=$min_post $srcdir/final.dubm "$gmm_feats" ark:- \| \ + ivector-extract --acoustic-weight=$posterior_scale --compute-objf-change=true \ + --max-count=$max_count \ + $srcdir/final.ie "$feats" ark,s,cs:- ark,scp:$dir/ivectors_utt.JOB.ark,$dir/ivectors_utt.JOB.scp || exit 1; + fi +fi + +ivector_dim=$[$(copy-vector "scp:head -n 1 $dir/ivectors_utt.1.scp |" ark,t:- | wc -w) - 3] || exit 1; +echo "$0: iVector dim is $ivector_dim" + +if [ $stage -le 5 
]; then + echo "$0: combining iVectors across jobs" + for j in $(seq $nj); do cat $dir/ivectors_utt.$j.scp; done >$dir/ivectors_utt.scp || exit 1; +fi + +echo "$0: done extracting (pseudo-online) iVectors" diff --git a/egs/sre08/v1/diarization/filter_ctm.sh b/egs/sre08/v1/diarization/filter_ctm.sh new file mode 100755 index 00000000000..05528e0f87c --- /dev/null +++ b/egs/sre08/v1/diarization/filter_ctm.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +set -e + +. path.sh + +cmd=run.pl + +. utils/parse_options.sh + +data_id=dev +ivector_affix= +ivector_dir=exp/nnet2_multicondition + +diarization_dir=$ivector_dir/diarization_${data_id}${ivector_affix}/diarization +ctm_file=exp/nnet2_multicondition/nnet_ms_a/decode_dev_aspire_whole_uniformsegmented_win10_over5_v23_voiced_256_128_128_iterfinal_pp_fg/score_13/penalty_0.5/ctm.filt +reco2file_and_channel=data/dev_aspire/reco2file_and_channel +dir=$ivector_dir/diarization_${data_id}${ivector_affix}/ctm_filter + +if [ $# -ne 4 ]; then + echo "Usage: diarization/filter_ctm.sh " + echo " e.g.: diarization/filter_ctm.sh $diarization_dir $ctm_file $reco2file_and_channel $dir" + exit 1 +fi + +diarization_dir=$1 +ctm_file=$2 +reco2file_and_channel=$3 +dir=$4 + +nj=$(cat $diarization_dir/num_jobs) + +for n in `seq $nj`; do + cat $diarization_dir/diarization_segmentation.$n.scp || exit 1 +done > $dir/diarization_segmentation.scp || exit 1 + +$cmd $dir/log/compute_ctm_conf.log \ + segmentation-compute-class-ctm-conf "ark,s:segmentation-post-process --merge-adjacent-segments scp:$dir/diarization_segmentation.scp ark:- |" $ctm_file $reco2file_and_channel ark,t:- \| diarization/convert_speaker_conf_to_labels_string.pl --min-spk-conf 0.9 '>' $dir/remove_labels.csl.txt || exit 1 + +$cmd $dir/log/get_nocrosstalk_segmentation.log \ + segmentation-remove-segments --remove-labels-rspecifier=ark,t:$dir/remove_labels.csl.txt scp:$dir/diarization_segmentation.scp ark:- \| segmentation-post-process --merge-labels=0:1:2:3 --merge-dst-label=0 
--merge-adjacent-segments ark:- ark:$dir/non_crosstalk_segmentation.ark || exit 1 + +$cmd $dir/log/filter_ctm.log \ + segmentation-filter-ctm ark:$dir/non_crosstalk_segmentation.ark $ctm_file $reco2file_and_channel ${ctm_file}.nocrosstalk || exit 1 diff --git a/egs/sre08/v1/diarization/gen_vad_topo.pl b/egs/sre08/v1/diarization/gen_vad_topo.pl new file mode 100755 index 00000000000..84330c39fb9 --- /dev/null +++ b/egs/sre08/v1/diarization/gen_vad_topo.pl @@ -0,0 +1,63 @@ +#!/usr/bin/perl + +# Copyright 2012 Johns Hopkins University (author: Daniel Povey) + +# Generate a topology file. This allows control of the number of states in the +# non-silence HMMs, and in the silence HMMs. + +use Getopt::Long; + +$nonsil_self_loop_p = 0.9; +$nonsil_transition_p = 0.1; +$sil_self_loop_p = 0.5; +$sil_transition_p = 0.5; + +GetOptions('nonsil-self-loop-probability:f' => \$nonsil_self_loop_p, + 'nonsil-transition-probability:f' => \$nonsil_transition_p, + 'sil-self-loop-probability:f' => \$sil_self_loop_p, + 'sil-transition-probability:f' => \$sil_transition_p); + +if(@ARGV != 4) { + print STDERR "Usage: sid/gen_vad_topo.pl [options] \n"; + print STDERR "e.g.: sid/gen_vad_topo.pl 75 30 2 1\n"; + exit (1); +} + +($num_nonsil_states, $num_sil_states, $nonsil_phones, $sil_phones) = @ARGV; + +( $num_nonsil_states >= 1 && $num_nonsil_states <= 100 ) || die "Unexpected number of nonsilence-model states $num_nonsil_states\n"; +( $num_sil_states >= 1 && $num_sil_states <= 100 ) || die "Unexpected number of silence-model states $num_sil_states\n"; + +$nonsil_phones =~ s/:/ /g; +$sil_phones =~ s/:/ /g; +$nonsil_phones =~ m/^\d[ \d]*$/ || die "$0: bad arguments @ARGV\n"; +$sil_phones =~ m/^\d[ \d]*$/ || die "$0: bad arguments @ARGV\n"; + +print "\n"; +print "\n"; +print "\n"; +print "$nonsil_phones\n"; +print "\n"; +for ($state = 0; $state < $num_nonsil_states - 1; $state++) { + $statep1 = $state+1; + print " $state 0 $statep1 1.0 \n"; +} +$statep1 = $state+1; +print " $state 0 $state 
$nonsil_self_loop_p $statep1 $nonsil_transition_p \n"; +print " $num_nonsil_states \n"; # non-emitting final state. +print "\n"; + +# Now silence phones. +print "\n"; +print "\n"; +print "$sil_phones\n"; +print "\n"; +for ($state = 0; $state < $num_sil_states - 1; $state++) { + $statep1 = $state+1; + print " $state 0 $statep1 1.0 \n"; +} +$statep1 = $state+1; +print " $state 0 $state $sil_self_loop_p $statep1 $sil_transition_p \n"; +print " $num_sil_states \n"; # non-emitting final state. +print "\n"; +print "\n"; diff --git a/egs/sre08/v1/diarization/make_vad_graph.sh b/egs/sre08/v1/diarization/make_vad_graph.sh new file mode 100755 index 00000000000..3391c68cf64 --- /dev/null +++ b/egs/sre08/v1/diarization/make_vad_graph.sh @@ -0,0 +1,96 @@ +#!/bin/bash + +# steps/make_phone_graph.sh data/train_100k_nodup/ data/lang exp/tri2_ali_100k_nodup/ exp/tri2 + +# Copyright 2013 Johns Hopkins University (Author: Daniel Povey). Apache 2.0. + +# This script makes a phone-based LM, without smoothing to unigram, that +# is to be used for segmentation, and uses that together with a model to +# make a decoding graph. +# Uses SRILM. + +# Begin configuration section. +stage=0 +cmd=run.pl +iter=final # use $iter.mdl from $model_dir +tree=tree +tscale=1.0 # transition scale. +loopscale=0.1 # scale for self-loops. +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +[ -f ./path.sh ] && . ./path.sh; # source the path. +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: $0 [options] " + echo " e.g.: $0 exp/vad_dev/lang exp/vad_dev exp/vad_dev/graph" + echo "Makes the graph in \$dir, corresponding to the model in \$model_dir" + exit 1; +fi + +lang=$1 +model=$2/$iter.mdl +tree=$2/$tree +dir=$3 + +for f in $lang/G.fst $model $tree; do + if [ ! -f $f ]; then + echo "$0: expected $f to exist" + exit 1; + fi +done + +mkdir -p $dir $lang/tmp + +clg=$lang/tmp/CLG.fst + +if [[ ! 
-s $clg || $clg -ot $lang/G.fst ]]; then + echo "$0: creating CLG." + + fstcomposecontext --context-size=1 --central-position=0 \ + $lang/tmp/ilabels < $lang/G.fst | \ + fstarcsort --sort_type=ilabel > $clg + fstisstochastic $clg || echo "[info]: CLG not stochastic." +fi + +if [[ ! -s $dir/Ha.fst || $dir/Ha.fst -ot $model || $dir/Ha.fst -ot $lang/tmp/ilabels ]]; then + make-h-transducer --disambig-syms-out=$dir/disambig_tid.int \ + --transition-scale=$tscale $lang/tmp/ilabels $tree $model \ + > $dir/Ha.fst || exit 1; +fi + +if [[ ! -s $dir/HCLGa.fst || $dir/HCLGa.fst -ot $dir/Ha.fst || $dir/HCLGa.fst -ot $clg ]]; then + fsttablecompose $dir/Ha.fst $clg | fstdeterminizestar --use-log=true \ + | fstrmsymbols $dir/disambig_tid.int | fstrmepslocal | \ + fstminimizeencoded > $dir/HCLGa.fst || exit 1; + fstisstochastic $dir/HCLGa.fst || echo "HCLGa is not stochastic" +fi + +if [[ ! -s $dir/HCLG.fst || $dir/HCLG.fst -ot $dir/HCLGa.fst ]]; then + add-self-loops --self-loop-scale=$loopscale --reorder=true \ + $model < $dir/HCLGa.fst > $dir/HCLG.fst || exit 1; + + if [ $tscale == 1.0 -a $loopscale == 1.0 ]; then + # No point doing this test if transition-scale not 1, as it is bound to fail. + fstisstochastic $dir/HCLG.fst || echo "[info]: final HCLG is not stochastic." + fi +fi + +# keep a copy of the lexicon and a list of silence phones with HCLG... +# this means we can decode without reference to the $lang directory. + +cp $lang/words.txt $dir/ || exit 1; +mkdir -p $dir/phones +cp $lang/phones/word_boundary.* $dir/phones/ 2>/dev/null # might be needed for ctm scoring, +cp $lang/phones/align_lexicon.* $dir/phones/ 2>/dev/null # might be needed for ctm scoring, + # but ignore the error if it's not there. + +cp $lang/phones/disambig.{txt,int} $dir/phones/ 2> /dev/null +cp $lang/phones/silence.csl $dir/phones/ || exit 1; +cp $lang/phones.txt $dir/ 2> /dev/null # ignore the error if it's not there. 
+ +# to make const fst: +# fstconvert --fst_type=const $dir/HCLG.fst $dir/HCLG_c.fst +am-info --print-args=false $model | grep pdfs | awk '{print $NF}' > $dir/num_pdfs diff --git a/egs/sre08/v1/diarization/prepare_data.sh b/egs/sre08/v1/diarization/prepare_data.sh new file mode 100755 index 00000000000..9c76beb765f --- /dev/null +++ b/egs/sre08/v1/diarization/prepare_data.sh @@ -0,0 +1,54 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -e +set -o pipefail + +# Begin configuration section. +cmd=run.pl +nj=32 +stage=-10 +add_pitch=false +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: diarization/prepare_data.sh " + echo " e.g.: diarization/prepare_data.sh data/dev exp/vad_dev mfcc" + echo "Main options (for others, see top of script file)" + echo " --config # config containing options" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + exit 1; +fi + +data=$1 +tmpdir=$2 +featdir=$3 + +if [ $stage -le 1 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $featdir/storage ]; then + date=$(date +'%m_%d_%H_%M') + utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/wsj_noisy-$date/s5/$featdir/storage $featdir/storage + fi + if $add_pitch; then + steps/make_mfcc_pitch.sh --mfcc-config conf/mfcc_vad.conf --nj $nj --cmd "$cmd" \ + $data $tmpdir/make_mfcc_vad $featdir || exit 1 + else + steps/make_mfcc.sh --mfcc-config conf/mfcc_vad.conf --nj $nj --cmd "$cmd" \ + $data $tmpdir/make_mfcc_vad $featdir || exit 1 + fi +fi + +if [ $stage -le 3 ]; then + steps/compute_cmvn_stats.sh $data $tmpdir/make_mfcc_vad $featdir || exit 1 +fi + +#if [ $stage -le 2 ]; then +# sid/compute_vad_decision.sh --vad-config conf/vad.conf --cmd $cmd --nj $nj \ +# $data $tmpdir/make_vad $featdir || exit 1 +#fi diff --git a/egs/sre08/v1/diarization/prepare_vad_lang.sh b/egs/sre08/v1/diarization/prepare_vad_lang.sh new file mode 100755 index 00000000000..448e2f922b2 --- /dev/null +++ b/egs/sre08/v1/diarization/prepare_vad_lang.sh @@ -0,0 +1,200 @@ +#!/bin/bash +# Copyright 2012-2013 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal +# 2014 Guoguo Chen +# 2015 Hainan Xu + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# This script prepares a directory such as data/lang/, in the standard format, +# given a source directory containing a dictionary lexicon.txt in a form like: +# word phone1 phone2 ... 
phoneN +# per line (alternate prons would be separate lines), or a dictionary with probabilities +# called lexiconp.txt in a form: +# word pron-prob phone1 phone2 ... phoneN +# (with 0.0 < pron-prob <= 1.0); note: if lexiconp.txt exists, we use it even if +# lexicon.txt exists. +# and also files silence_phones.txt, nonsilence_phones.txt, optional_silence.txt +# and extra_questions.txt +# Here, silence_phones.txt and nonsilence_phones.txt are lists of silence and +# non-silence phones respectively (where silence includes various kinds of +# noise, laugh, cough, filled pauses etc., and nonsilence phones includes the +# "real" phones.) +# In each line of those files is a list of phones, and the phones on each line +# are assumed to correspond to the same "base phone", i.e. they will be +# different stress or tone variations of the same basic phone. +# The file "optional_silence.txt" contains just a single phone (typically SIL) +# which is used for optional silence in the lexicon. +# extra_questions.txt might be empty; typically will consist of lists of phones, +# all members of each list with the same stress or tone; and also possibly a +# list for the silence phones. This will augment the automtically generated +# questions (note: the automatically generated ones will treat all the +# stress/tone versions of a phone the same, so will not "get to ask" about +# stress or tone). + +# This script constructs a host of other +# derived files, that go in lang/. + +# Begin configuration section. +nonsil_self_loop_probability=0.9 +nonsil_transition_probability=0.1 +sil_self_loop_probability=0.9 +sil_transition_probability=0.1 +num_sil_states=5 +num_nonsil_states=3 +# end configuration sections + +. 
utils/parse_options.sh + +if [ $# -ne 3 ]; then + echo "usage: utils/prepare_lang.sh " + echo "e.g.: utils/prepare_lang.sh data/local/dict data/local/lang data/lang" + echo " should contain the following files:" + echo " extra_questions.txt lexicon.txt nonsilence_phones.txt optional_silence.txt silence_phones.txt" + echo "See http://kaldi.sourceforge.net/data_prep.html#data_prep_lang_creating for more info." + echo "options: " + echo " --num-sil-states # default: 5, #states in silence models." + echo " --num-nonsil-states # default: 3, #states in non-silence models." + exit 1; +fi + +srcdir=$1 +tmpdir=$2 +dir=$3 +mkdir -p $dir $tmpdir $dir/phones + +[ -f path.sh ] && . ./path.sh + +! utils/validate_dict_dir.pl $srcdir && \ + echo "*Error validating directory $srcdir*" && exit 1; + +if [[ ! -f $srcdir/lexicon.txt ]]; then + echo "**Creating $dir/lexicon.txt from $dir/lexiconp.txt" + perl -ape 's/(\S+\s+)\S+\s+(.+)/$1$2/;' < $srcdir/lexiconp.txt > $srcdir/lexicon.txt || exit 1; +fi +if [[ ! -f $srcdir/lexiconp.txt ]]; then + echo "**Creating $srcdir/lexiconp.txt from $srcdir/lexicon.txt" + perl -ape 's/(\S+\s+)(.+)/${1}1.0\t$2/;' < $srcdir/lexicon.txt > $srcdir/lexiconp.txt || exit 1; +fi + +if ! utils/validate_dict_dir.pl $srcdir >&/dev/null; then + utils/validate_dict_dir.pl $srcdir # show the output. + echo "Validation failed (second time)" + exit 1; +fi + +cp -f $srcdir/lexiconp.txt $tmpdir/lexiconp.txt + +cat $srcdir/silence_phones.txt $srcdir/nonsilence_phones.txt | \ + sed 's/ /\n/g' | awk '(NF>0){print}' > $tmpdir/phones +paste -d' ' $tmpdir/phones $tmpdir/phones > $tmpdir/phone_map.txt + +mkdir -p $dir/phones # various sets of phones... 
+ +cat $srcdir/{,non}silence_phones.txt | utils/apply_map.pl $tmpdir/phone_map.txt > $dir/phones/sets.txt +cat $dir/phones/sets.txt | awk '{print "shared", "split", $0;}' > $dir/phones/roots.txt + +cat $srcdir/silence_phones.txt | utils/apply_map.pl $tmpdir/phone_map.txt | \ + awk '{for(n=1;n<=NF;n++) print $n;}' > $dir/phones/silence.txt +cat $srcdir/nonsilence_phones.txt | utils/apply_map.pl $tmpdir/phone_map.txt | \ + awk '{for(n=1;n<=NF;n++) print $n;}' > $dir/phones/nonsilence.txt + +# if extra_questions.txt is empty, it's OK. +cat $srcdir/extra_questions.txt 2>/dev/null | utils/apply_map.pl $tmpdir/phone_map.txt \ + >$dir/phones/extra_questions.txt + +# Create phone symbol table. +cat $dir/phones/{silence,nonsilence}.txt | \ + awk '{n=NR; print $1, n;}' > $dir/phones.txt + +cat $tmpdir/lexiconp.txt | awk '{print $1}' | sort | uniq | awk ' + { + if ($1 == "") { + print " is in the vocabulary!" | "cat 1>&2" + exit 1; + } + if ($1 == "") { + print " is in the vocabulary!" | "cat 1>&2" + exit 1; + } + printf("%s %d\n", $1, NR); + }' > $dir/words.txt || exit 1; + +# create $dir/phones/align_lexicon.{txt,int}. +# This is the new-new style of lexicon aligning. + +# First remove pron-probs from the lexicon. +perl -ape 's/(\S+\s+)\S+\s+(.+)/$1$2/;' <$tmpdir/lexiconp.txt >$tmpdir/align_lexicon.txt + +cat $tmpdir/align_lexicon.txt | \ + perl -ane '@A = split; print $A[0], " ", join(" ", @A), "\n";' | sort | uniq > $dir/phones/align_lexicon.txt + +# create phones/align_lexicon.int +cat $dir/phones/align_lexicon.txt | utils/sym2int.pl -f 3- $dir/phones.txt | \ + utils/sym2int.pl -f 1-2 $dir/words.txt > $dir/phones/align_lexicon.int + +# Create the basic L.fst without disambiguation symbols, for use +# in training. 
+ +utils/make_lexicon_fst.pl --pron-probs $tmpdir/lexiconp.txt | \ + fstcompile --isymbols=$dir/phones.txt --osymbols=$dir/words.txt \ + --keep_isymbols=false --keep_osymbols=false | \ + fstarcsort --sort_type=olabel > $dir/L.fst || exit 1; + +# Create these lists of phones in colon-separated integer list form too, +# for purposes of being given to programs as command-line options. +for f in silence nonsilence; do + utils/sym2int.pl $dir/phones.txt <$dir/phones/$f.txt >$dir/phones/$f.int + utils/sym2int.pl $dir/phones.txt <$dir/phones/$f.txt | \ + awk '{printf(":%d", $1);} END{printf "\n"}' | sed s/:// > $dir/phones/$f.csl || exit 1; +done + +for x in sets extra_questions; do + utils/sym2int.pl $dir/phones.txt <$dir/phones/$x.txt > $dir/phones/$x.int || exit 1; +done + +utils/sym2int.pl -f 3- $dir/phones.txt <$dir/phones/roots.txt \ + > $dir/phones/roots.int || exit 1; + +silphonelist=`cat $dir/phones/silence.csl` +nonsilphonelist=`cat $dir/phones/nonsilence.csl` +diarization/gen_vad_topo.pl \ + --nonsil-self-loop-probability $nonsil_self_loop_probability \ + --nonsil-transition-probability $nonsil_transition_probability \ + --sil-self-loop-probability $sil_self_loop_probability \ + --sil-transition-probability $sil_transition_probability \ + $num_nonsil_states $num_sil_states $nonsilphonelist $silphonelist >$dir/topo + +rm -f $dir/L_disambig.fst 2>/dev/null +ln -s L.fst $dir/L_disambig.fst + +num_words=`cat $dir/words.txt | wc -l 2> /dev/null` || exit 1 +#prob=`perl -e "print log($num_words) / log(10)"` +prob=`perl -e "print log($[num_words+1])"` +while IFS=$'\n' read line; do + word=`echo $line | awk '{print $1}'` + echo 0 0 $word $word $prob +done < $dir/words.txt > $tmpdir/G.txt +echo 0 $prob >> $tmpdir/G.txt + +fstcompile --isymbols=$dir/words.txt --osymbols=$dir/words.txt \ + --keep_isymbols=false --keep_osymbols=false \ + < $tmpdir/G.txt > $dir/G.fst || exit 1 + +#echo "$(basename $0): validating output directory" +#! 
utils/validate_lang.pl $dir && echo "$(basename $0): error validating output" && exit 1; + + +exit 0; + diff --git a/egs/sre08/v1/diarization/split_vad_data.sh b/egs/sre08/v1/diarization/split_vad_data.sh new file mode 100755 index 00000000000..22bbf48e569 --- /dev/null +++ b/egs/sre08/v1/diarization/split_vad_data.sh @@ -0,0 +1,124 @@ +#!/bin/bash +# Copyright 2010-2013 Microsoft Corporation +# Johns Hopkins University (Author: Daniel Povey) +# 2015 Vimal Manohar + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +if [ $# != 2 ]; then + echo "Usage: split_vad_data.sh [--per-utt] " + echo "This script will not split the data-dir if it detects that the output is newer than the input." + echo "By default it splits per speaker (so each speaker is in only one split dir)," + echo "but with the --per-utt option it will ignore the speaker information while splitting." + exit 1 +fi + +data=$1 +numsplit=$2 + +if [ $numsplit -le 0 ]; then + echo "Invalid num-split argument $numsplit"; + exit 1; +fi + +n=0; +feats="" +wavs="" +utt2spks="" +texts="" + +nu=`cat $data/utt2spk | wc -l` +nf=`cat $data/feats.scp 2>/dev/null | wc -l` +nt=`cat $data/text 2>/dev/null | wc -l` # take it as zero if no such file +if [ -f $data/feats.scp ] && [ $nu -ne $nf ]; then + echo "** split_data.sh: warning, #lines is (utt2spk,feats.scp) is ($nu,$nf); you can " + echo "** use utils/fix_data_dir.sh $data to fix this." 
+fi +if [ -f $data/text ] && [ $nu -ne $nt ]; then + echo "** split_data.sh: warning, #lines is (utt2spk,text) is ($nu,$nt); you can " + echo "** use utils/fix_data_dir.sh to fix this." +fi + +s1=$data/split$numsplit/1 +if [ ! -d $s1 ]; then + need_to_split=true +else + need_to_split=false + for f in utt2spk spk2utt spk2warp feats.scp text wav.scp cmvn.scp spk2gender \ + vad.scp segments reco2file_and_channel utt2lang; do + if [[ -f $data/$f && ( ! -f $s1/$f || $s1/$f -ot $data/$f ) ]]; then + need_to_split=true + fi + done +fi + +if ! $need_to_split; then + exit 0; +fi + +spk2utts= # list of per-split spk2utt targets, built below. +for n in `seq $numsplit`; do + mkdir -p $data/split$numsplit/$n + spk2utts="$spk2utts $data/split$numsplit/$n/spk2utt" +done + +# If lockfile is not installed, just don't lock it. It's not a big deal. +which lockfile >&/dev/null && lockfile -l 60 $data/.split_lock + +utils/split_scp.pl $data/spk2utt $spk2utts || exit 1 + +for n in `seq $numsplit`; do + dsn=$data/split$numsplit/$n + utils/spk2utt_to_utt2spk.pl $dsn/spk2utt > $dsn/utt2spk || exit 1; +done + +maybe_wav_scp= +if [ ! -f $data/segments ]; then + maybe_wav_scp=wav.scp # If there is no segments file, then wav file is + # indexed per utt. +fi + +# split some things that are indexed by utterance. +for f in feats.scp text vad.scp utt2lang $maybe_wav_scp; do + if [ -f $data/$f ]; then + utils/filter_scps.pl JOB=1:$numsplit \ + $data/split$numsplit/JOB/utt2spk $data/$f $data/split$numsplit/JOB/$f || exit 1; + fi +done + +# split some things that are indexed by speaker +for f in spk2gender spk2warp cmvn.scp; do + if [ -f $data/$f ]; then + utils/filter_scps.pl JOB=1:$numsplit \ + $data/split$numsplit/JOB/spk2utt $data/$f $data/split$numsplit/JOB/$f || exit 1; + fi +done + +for n in `seq $numsplit`; do + dsn=$data/split$numsplit/$n + if [ -f $data/segments ]; then + utils/filter_scp.pl $dsn/utt2spk $data/segments > $dsn/segments + awk '{print $2;}' $dsn/segments | sort | uniq > $data/tmp.reco # recording-ids.
+ if [ -f $data/reco2file_and_channel ]; then + utils/filter_scp.pl $data/tmp.reco $data/reco2file_and_channel > $dsn/reco2file_and_channel + fi + if [ -f $data/wav.scp ]; then + utils/filter_scp.pl $data/tmp.reco $data/wav.scp >$dsn/wav.scp + fi + rm $data/tmp.reco + fi # else it would have been handled above, see maybe_wav. +done + +rm -f $data/.split_lock + +exit 0 diff --git a/egs/sre08/v1/diarization/train_vad_gmm.sh b/egs/sre08/v1/diarization/train_vad_gmm.sh new file mode 100755 index 00000000000..7713352372f --- /dev/null +++ b/egs/sre08/v1/diarization/train_vad_gmm.sh @@ -0,0 +1,175 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -e +set -o pipefail + +# Begin configuration section. +cmd=run.pl +nj=4 +speech_duration=75 +sil_duration=30 +speech_num_gauss=16 +sil_num_gauss=4 +num_iters=20 +impr_thres=0.002 +stage=-10 +cleanup=true +select_top_frames=true +top_frames_threshold=0.16 +bottom_frames_threshold=0.04 +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# != 2 ]; then + echo "Usage: diarization/train_vad_gmm.sh " + echo " e.g.: diarization/train_vad_gmm.sh data/dev exp/vad_dev" + echo "main options (for others, see top of script file)" + echo " --config # config containing options" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --num-iters <#iters> # Number of iterations of E-M" + exit 1; +fi + +data=$1 +dir=$2 + +function build_0gram { +wordlist=$1; lm=$2 +echo "=== Building zerogram $lm from ${wordlist}. ..." 
+awk '{print $1}' $wordlist | sort -u > $lm +python -c """ +import math +with open('$lm', 'r+') as f: + lines = f.readlines() + p = math.log10(1/float(len(lines))); + lines = ['%f\\t%s'%(p,l) for l in lines] + f.seek(0); f.write('\\n\\\\data\\\\\\nngram 1= %d\\n\\n\\\\1-grams:\\n' % len(lines)) + f.write(''.join(lines) + '\\\\end\\\\') +""" +} + +for f in $data/feats.scp $data/vad.scp; do + [ ! -s $f ] && echo "$0: could not find $f or $f is empty" && exit 1 +done + +feat_dim=`feat-to-dim "scp:head -n 1 $data/feats.scp |" ark,t:- | awk '{print $2}'` || exit 1 + +# Prepare a lang directory +if [ $stage -le -2 ]; then + mkdir -p $dir/local + mkdir -p $dir/local/dict + mkdir -p $dir/local/lm + + echo "1" > $dir/local/dict/silence_phones.txt + echo "1" > $dir/local/dict/optional_silence.txt + echo "2" > $dir/local/dict/nonsilence_phones.txt + echo -e "1 1\n2 2" > $dir/local/dict/lexicon.txt + echo -e "1\n2\n1 2" > $dir/local/dict/extra_questions.txt + + mkdir -p $dir/lang + diarization/prepare_vad_lang.sh --num-sil-states 1 --num-nonsil-states 1 \ + $dir/local/dict $dir/local/lang $dir/lang || exit 1 + fstisstochastic $dir/lang/G.fst || echo "[info]: G not stochastic." 
+ diarization/prepare_vad_lang.sh --num-sil-states 30 --num-nonsil-states 75 \ + $dir/local/dict $dir/local/lang $dir/lang_test || exit 1 +fi + +if [ $stage -le -1 ]; then + run.pl $dir/log/create_transition_model.log gmm-init-mono \ + --binary=false $dir/lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 + run.pl $dir/log/create_transition_model.log gmm-init-mono \ + --binary=false $dir/lang_test/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans_test.mdl || exit 1 + + diarization/make_vad_graph.sh --iter trans $dir/lang $dir $dir/graph || exit 1 + diarization/make_vad_graph.sh --iter trans_test $dir/lang_test $dir $dir/graph_test || exit 1 +fi + +utils/split_data.sh $data $nj || exit 1 +feats="ark:copy-feats scp:$data/feats.scp ark:- |" + +if [ $stage -le 0 ]; then + + if ! $select_top_frames; then + $cmd $dir/log/init_gmm_speech.log \ + gmm-global-init-from-feats --num-gauss=$speech_num_gauss --num-iters=10 \ + "$feats select-voiced-frames ark:- scp:$data/vad.scp ark:- |" \ + $dir/speech.0.mdl || exit 1 + $cmd $dir/log/init_gmm_silence.log \ + gmm-global-init-from-feats --num-gauss=$sil_num_gauss --num-iters=6 \ + "$feats select-voiced-frames --select-unvoiced-frames=true ark:- scp:$data/vad.scp ark:- |" \ + $dir/silence.0.mdl || exit 1 + else + $cmd $dir/log/init_gmm_speech.log \ + gmm-global-init-from-feats --num-gauss=$speech_num_gauss --num-iters=12 \ + "$feats select-top-frames --top-frames-proportion=$top_frames_threshold ark:- ark:- |" \ + $dir/speech.0.mdl || exit 1 + $cmd $dir/log/init_gmm_silence.log \ + gmm-global-init-from-feats --num-gauss=$sil_num_gauss --num-iters=8 \ + "$feats select-top-frames --bottom-frames-proportion=$bottom_frames_threshold --top-frames-proportion=0.0 ark:- ark:- |" \ + $dir/silence.0.mdl || exit 1 + fi + + { + cat $dir/trans.mdl + echo " $feat_dim 2" + gmm-global-copy --binary=false $dir/silence.0.mdl - + gmm-global-copy --binary=false 
$dir/speech.0.mdl - + } > $dir/0.mdl || exit 1 + + x=0 + while [ $x -lt $num_iters ]; do + $cmd $dir/log/decode.$x.log \ + gmm-decode-simple \ + --allow-partial=true --word-symbol-table=$dir/graph/words.txt \ + $dir/$x.mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:$dir/$x.ali || exit 1 + + $cmd $dir/log/update.$x.log \ + gmm-acc-stats-ali \ + $dir/$x.mdl "$feats" \ + ark:$dir/$x.ali - \| \ + gmm-est $dir/$x.mdl - $dir/$[x+1].mdl || exit 1 + + objf_impr=$(cat $dir/log/update.$x.log | grep "GMM update: Overall .* objective function" | perl -pe 's/.*GMM update: Overall (\S+) objective function .*/\$1/') + + if [ "$(perl -e "if ($objf_impr < $impr_thres) { print true; }")" == true ]; then + break; + fi + + x=$[x+1] + done + + rm -f $dir/final.mdl 2>/dev/null || true + cp $dir/$x.mdl $dir/final.mdl + + ( + copy-transition-model --binary=false $dir/trans_test.mdl - + gmm-copy --write-tm=false --binary=false $dir/$x.mdl - + ) | gmm-copy - $dir/final.mdl + + $cmd $dir/log/decode.final.log \ + gmm-decode-simple \ + --allow-partial=true --word-symbol-table=$dir/graph/words.txt \ + $dir/final.mdl $dir/graph_test/HCLG.fst \ + "$feats" ark:/dev/null ark:$dir/final.ali || exit 1 +fi + +if $cleanup; then + for x in `seq $[num_iters - 1]`; do + if [ $[x % 10] -ne 0 ]; then + rm $dir/$x.mdl + fi + done +fi + +# Summarize warning messages... +utils/summarize_warnings.pl $dir/log + diff --git a/egs/sre08/v1/diarization/train_vad_gmm_supervised.sh b/egs/sre08/v1/diarization/train_vad_gmm_supervised.sh new file mode 100755 index 00000000000..e55d24988a2 --- /dev/null +++ b/egs/sre08/v1/diarization/train_vad_gmm_supervised.sh @@ -0,0 +1,245 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -e +set -o pipefail + +# Begin configuration section. 
+cmd=run.pl +speech_max_gauss=64 +sil_max_gauss=32 +noise_max_gauss=16 +sil_num_gauss_init=4 +speech_num_gauss_init=4 +noise_num_gauss_init=4 +train_noise_gmm=false +num_iters=10 +stage=-10 +cleanup=true +top_frames_threshold=1.0 +bottom_frames_threshold=1.0 +ignore_energy=true +add_zero_crossing_feats=true +add_frame_snrs=false +zero_crossings_scp= +frame_snrs_scp= +io_opts="--max-jobs-run 10" +nj=4 +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# != 3 ]; then + echo "Usage: diarization/train_vad_gmm_supervised.sh " + echo " e.g.: diarization/train_vad_gmm_supervised.sh data/dev exp/tri4_ali/vad/vad.scp exp/vad_dev" + echo "main options (for others, see top of script file)" + echo " --config # config containing options" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --num-iters <#iters> # Number of iterations of E-M" + exit 1; +fi + +data=$1 +vad_scp=$2 +dir=$3 + +mkdir -p $dir + +feat_dim=`head -n 1 $data/feats.scp | feat-to-dim scp:- ark,t:- | awk '{print $2}'` || exit 1 + +ignore_energy_opts= +if $ignore_energy; then + ignore_energy_opts="select-feats 1-$[feat_dim-1] ark:- ark:- |" +fi + +echo "$ignore_energy_opts" > $dir/ignore_energy_opts +echo "$add_zero_crossing_feats" > $dir/add_zero_crossing_feats +echo "$add_frame_snrs" > $dir/add_frame_snrs + +for f in $data/feats.scp $data/utt2spk; do + [ ! -s $f ] && echo "$0: could not find $f or $f is empty" && exit 1 +done + +zc_opts= +[ -f conf/zc_vad.conf ] && zc_opts="--config=conf/zc_vad.conf" +zero_crossing_opts= + +if [ $stage -le -4 ]; then + ########################################################################### + # Prepare data. 
+ # Split the vad in the same way as the data + ########################################################################### + rm -rf $dir/data + utils/copy_data_dir.sh $data $dir/data + + utils/filter_scp.pl $vad_scp $data/feats.scp > $dir/data/feats.scp || exit 1 + # Remove bad lines + if [ -f $data/text ]; then + grep -v "IGNORE_TIME_SEGMENT_IN_SCORING" $data/text > $dir/data/text + fi + + utils/fix_data_dir.sh $dir/data || exit 1 +fi + +split_data.sh $dir/data $nj || exit 1 +split_files= +for n in `seq $nj`; do + split_files="$split_files $dir/data/vad.$n.scp" +done + +utils/filter_scp.pl $dir/data/utt2spk $vad_scp | split_scp.pl --utt2spk=$dir/data/utt2spk - $split_files || exit 1 + +if [ $stage -le -3 ]; then + ########################################################################### + # Add zero-crossing and high-frequency content feats + ########################################################################### + if $add_zero_crossing_feats; then + if [ ! -z "$zero_crossings_scp" ]; then + [ ! -s $zero_crossings_scp ] && echo "$zero_crossings_scp does not exist or is empty!" && exit 1 + $cmd JOB=1:$nj $dir/log/copy_zero_crossings.JOB.log \ + utils/filter_scp.pl $dir/data/split$nj/JOB/utt2spk \ + $zero_crossings_scp '>' $dir/data/zero_crossings.JOB.scp || exit 1 + else + if [ -f $data/segments ]; then + $cmd $io_opts JOB=1:$nj $dir/log/compute_zero_crossing.JOB.log \ + extract-segments scp:$dir/data/split$nj/JOB/wav.scp $dir/data/split$nj/JOB/segments ark:- \| \ + compute-zero-crossings $zc_opts ark:- ark,scp:$dir/data/zero_crossings.JOB.ark,$dir/data/zero_crossings.JOB.scp || exit 1 + else + $cmd $io_opts JOB=1:$nj $dir/log/compute_zero_crossing.JOB.log \ + compute-zero-crossings $zc_opts scp:$dir/data/split$nj/JOB/wav.scp ark,scp:$dir/data/zero_crossings.JOB.ark,$dir/data/zero_crossings.JOB.scp || exit 1 + fi + fi + + [ ! 
-f $dir/data/zero_crossings.1.scp ] && exit 1 + + cat $dir/data/zero_crossings.{?,??}.scp > $dir/data/zero_crossings.scp || exit 1 + fi + + if $add_frame_snrs; then + [ -z "$frame_snrs_scp" ] && echo "$0: add-frame-snrs is true but frame-snrs-scp is not supplied" && exit 1 + for n in `seq $nj`; do + utils/filter_scp.pl $dir/data/split$nj/$n/utt2spk $frame_snrs_scp > $dir/data/frame_snrs.$n.scp + done + fi + +fi + +########################################################################### +# Get appropriate $feats variable: +# Apply CMVN. Note that we don't apply CMVN to the zero-crossing feats. +# Remove energy from the features. +# Add zero-crossing feats. +# Add deltas. +########################################################################### +feats="ark:apply-cmvn-sliding scp:$dir/data/split$nj/JOB/feats.scp ark:- |${ignore_energy_opts}" + +if $add_zero_crossing_feats; then + feats="${feats}paste-feats ark:- scp:$dir/data/zero_crossings.JOB.scp ark:- |" +fi + +if $add_frame_snrs; then + feats="${feats}paste-feats ark:- \"ark:vector-to-feat scp:$dir/data/frame_snrs.JOB.scp ark:- |\" ark:- |" +fi + +feats="${feats}add-deltas ark:- ark:- |" + +if [ $stage -le -2 ]; then + $cmd JOB=1:$nj $dir/log/select_feats_init_speech.JOB.log \ + segmentation-init-from-ali scp:$dir/data/vad.JOB.scp ark:- \| \ + select-feats-from-segmentation --select-label=1 --selection-padding=10 \ + "$feats" ark:- \ + ark:$dir/init_feats_speech.JOB.ark || exit 1 + + $cmd JOB=1:$nj $dir/log/select_feats_init_silence.JOB.log \ + segmentation-init-from-ali scp:$dir/data/vad.JOB.scp ark:- \| \ + select-feats-from-segmentation --select-label=0 --selection-padding=2 \ + "$feats" ark:- \ + ark:$dir/init_feats_silence.JOB.ark || exit 1 + + if $train_noise_gmm; then + $cmd JOB=1:$nj $dir/log/select_feats_init_noise.JOB.log \ + segmentation-init-from-ali scp:$dir/data/vad.JOB.scp ark:- \| \ + select-feats-from-segmentation --select-label=2 --selection-padding=2 \ + "$feats" ark:- \ + 
ark:$dir/init_feats_noise.JOB.ark || exit 1 + fi + +fi + +speech_num_gauss=$speech_num_gauss_init +sil_num_gauss=$sil_num_gauss_init +noise_num_gauss=$noise_num_gauss_init + +if [ $stage -le -1 ]; then + $cmd $dir/log/init_gmm_speech.log \ + gmm-global-init-from-feats --num-gauss=$speech_num_gauss --num-iters=$[speech_num_gauss + 2] \ + "ark:cat $dir/init_feats_speech.{?,??,???}.ark |" $dir/speech.0.mdl || exit 1 + + $cmd $dir/log/init_gmm_silence.log \ + gmm-global-init-from-feats --num-gauss=$sil_num_gauss --num-iters=$[sil_num_gauss + 2] \ + "ark:cat $dir/init_feats_silence.{?,??,???}.ark |" $dir/silence.0.mdl || exit 1 + + if $train_noise_gmm; then + $cmd $dir/log/init_gmm_noise.log \ + gmm-global-init-from-feats --num-gauss=$noise_num_gauss --num-iters=$[noise_num_gauss + 2] \ + "ark:cat $dir/init_feats_noise.{?,??,???}.ark |" $dir/noise.0.mdl || exit 1 + fi +fi + +x=0 +while [ $x -le $num_iters ]; do + if [ $stage -le $x ]; then + $cmd JOB=1:$nj $dir/log/acc_gmm_stats_speech.$x.JOB.log \ + gmm-global-acc-stats $dir/speech.$x.mdl \ + "ark:copy-feats ark:$dir/init_feats_speech.JOB.ark ark:- |" \ + $dir/speech_accs.$x.JOB || exit 1 + + $cmd JOB=1:$nj $dir/log/acc_gmm_stats_silence.$x.JOB.log \ + gmm-global-acc-stats $dir/silence.$x.mdl \ + "ark:copy-feats ark:$dir/init_feats_silence.JOB.ark ark:- |" \ + $dir/silence_accs.$x.JOB || exit 1 + + if $train_noise_gmm; then + $cmd JOB=1:$nj $dir/log/acc_gmm_stats_noise.$x.JOB.log \ + gmm-global-acc-stats $dir/noise.$x.mdl \ + "ark:copy-feats ark:$dir/init_feats_noise.JOB.ark ark:- |" \ + $dir/noise_accs.$x.JOB || exit 1 + fi + + $cmd $dir/log/gmm_est_speech.$x.log \ + gmm-global-est --mix-up=$speech_num_gauss $dir/speech.$x.mdl \ + "gmm-global-sum-accs - $dir/speech_accs.$x.* |" \ + $dir/speech.$[x+1].mdl || exit 1 + + $cmd $dir/log/gmm_est_silence.$x.log \ + gmm-global-est --mix-up=$sil_num_gauss $dir/silence.$x.mdl \ + "gmm-global-sum-accs - $dir/silence_accs.$x.* |" \ + $dir/silence.$[x+1].mdl || exit 1 + + if 
$train_noise_gmm; then + $cmd $dir/log/gmm_est_noise.$x.log \ + gmm-global-est --mix-up=$sil_num_gauss $dir/noise.$x.mdl \ + "gmm-global-sum-accs - $dir/noise_accs.$x.* |" \ + $dir/noise.$[x+1].mdl || exit 1 + fi + fi + + if [ $sil_num_gauss -lt $sil_max_gauss ]; then + sil_num_gauss=$[sil_num_gauss * 2] + fi + if [ $speech_num_gauss -lt $speech_max_gauss ]; then + speech_num_gauss=$[speech_num_gauss * 2] + fi + if $train_noise_gmm && [ $noise_num_gauss -lt $noise_max_gauss ]; then + noise_num_gauss=$[noise_num_gauss * 2] + fi + x=$[x+1] + +done + +# Summarize warning messages... +utils/summarize_warnings.pl $dir/log diff --git a/egs/sre08/v1/diarization/train_vad_gmm_v0.sh b/egs/sre08/v1/diarization/train_vad_gmm_v0.sh new file mode 100755 index 00000000000..9c8e7ce3d71 --- /dev/null +++ b/egs/sre08/v1/diarization/train_vad_gmm_v0.sh @@ -0,0 +1,164 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -e +set -o pipefail + +# Begin configuration section. +cmd=run.pl +nj=4 +speech_duration=75 +sil_duration=30 +speech_num_gauss=16 +sil_num_gauss=4 +num_iters=20 +impr_thres=0.002 +stage=-10 +cleanup=true +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# != 2 ]; then + echo "Usage: diarization/train_vad_gmm.sh " + echo " e.g.: diarization/train_vad_gmm.sh data/dev exp/vad_dev" + echo "main options (for others, see top of script file)" + echo " --config # config containing options" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --num-iters <#iters> # Number of iterations of E-M" + exit 1; +fi + +data=$1 +dir=$2 + +function build_0gram { +wordlist=$1; lm=$2 +echo "=== Building zerogram $lm from ${wordlist}. ..." 
+awk '{print $1}' $wordlist | sort -u > $lm +python -c """ +import math +with open('$lm', 'r+') as f: + lines = f.readlines() + p = math.log10(1/float(len(lines))); + lines = ['%f\\t%s'%(p,l) for l in lines] + f.seek(0); f.write('\\n\\\\data\\\\\\nngram 1= %d\\n\\n\\\\1-grams:\\n' % len(lines)) + f.write(''.join(lines) + '\\\\end\\\\') +""" +} + +for f in $data/feats.scp $data/vad.scp; do + [ ! -s $f ] && echo "$0: could not find $f or $f is empty" && exit 1 +done + +feat_dim=`feat-to-dim "scp:head -n 1 $data/feats.scp |" ark,t:- | awk '{print $2}'` || exit 1 + +# Prepare a lang directory +if [ $stage -le -2 ]; then + mkdir -p $dir/local + mkdir -p $dir/local/dict + mkdir -p $dir/local/lm + + echo "1" > $dir/local/dict/silence_phones.txt + echo "1" > $dir/local/dict/optional_silence.txt + echo "2" > $dir/local/dict/nonsilence_phones.txt + echo -e "1 1\n2 2" > $dir/local/dict/lexicon.txt + echo -e "1\n2\n1 2" > $dir/local/dict/extra_questions.txt + + mkdir -p $dir/lang + diarization/prepare_vad_lang.sh --num-sil-states 30 --num-nonsil-states 75 \ + $dir/local/dict $dir/local/lang $dir/lang || exit 1 + fstisstochastic $dir/lang/G.fst || echo "[info]: G not stochastic." +fi + +if [ $stage -le -1 ]; then + run.pl $dir/log/create_transition_model.log gmm-init-mono \ + --binary=false $dir/lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 + + diarization/make_vad_graph.sh --iter trans $dir/lang $dir $dir/graph || exit 1 +fi + +if [ $stage -le 0 ]; then +mkdir -p $dir/q +utils/split_data.sh $data $nj || exit 1 + +for n in `seq $nj`; do + cat < $dir/q/do_vad.$n.sh +set -e +set -o pipefail + +if [ -f path.sh ]; then . ./path.sh; fi +. 
parse_options.sh || exit 1; + +while IFS=$'\n' read line; do + feats="ark:echo \$line | copy-feats scp:- ark:- |" + utt_id=\$(echo \$line | awk '{print \$1}') + + run.pl $dir/log/init_speech_model.\$utt_id.log \ + gmm-global-init-from-feats --num-gauss=$speech_num_gauss \ + "\$feats select-voiced-frames ark:- scp:$data/vad.scp ark:- |" \ + $dir/\$utt_id.speech.0.mdl || exit 1 + run.pl $dir/log/init_sil_model.\$utt_id.log \ + gmm-global-init-from-feats --num-gauss=$sil_num_gauss \ + "\$feats select-voiced-frames --select-unvoiced-frames=true ark:- scp:$data/vad.scp ark:- |" \ + $dir/\$utt_id.silence.0.mdl || exit 1 + + { + cat $dir/trans.mdl + echo " $feat_dim 2" + gmm-global-copy --binary=false $dir/\$utt_id.silence.0.mdl - + gmm-global-copy --binary=false $dir/\$utt_id.speech.0.mdl - + } > $dir/\$utt_id.0.mdl || exit 1 + + x=0 + while [ \$x -lt $num_iters ]; do + run.pl $dir/log/decode.\$utt_id.\$x.log gmm-decode-simple \ + --allow-partial=true --word-symbol-table=$dir/graph/words.txt \ + $dir/\$utt_id.\$x.mdl $dir/graph/HCLG.fst \ + "\$feats" ark:/dev/null ark:$dir/\$utt_id.\$x.ali || exit 1 + + run.pl $dir/log/update.\$utt_id.\$x.log gmm-acc-stats-ali \ + $dir/\$utt_id.\$x.mdl "\$feats" \ + ark:$dir/\$utt_id.\$x.ali - \| \ + gmm-est $dir/\$utt_id.\$x.mdl - $dir/\$utt_id.\$[x+1].mdl || exit 1 + + # NOTE(review): the "GMM update" message is emitted by gmm-est, which logs + # to the update log, not the decode log; grep the update log here. + objf_impr=\$(cat $dir/log/update.\$utt_id.\$x.log | grep "GMM update: .* objective function improvements" | perl -pe 's/.*GMM update: Overall (\S+) objective function improvement/\$1/') + + if [ "\$(perl -e "if (\$objf_impr < $impr_thres) { print true; }")" == true ]; then + break; + fi + + x=\$[x+1] + done + + run.pl $dir/log/decode.\$utt_id.\$x.log gmm-decode-simple \ + --allow-partial=true --word-symbol-table=$dir/graph/words.txt \ + $dir/\$utt_id.\$x.mdl $dir/graph/HCLG.fst \ + "\$feats" ark:/dev/null ark:$dir/\$utt_id.\$x.ali || exit 1 + + rm -f $dir/\$utt_id.final.mdl 2>/dev/null || true + cp $dir/\$utt_id.\$x.mdl $dir/\$utt_id.final.mdl +done < 
$data/split$nj/$n/feats.scp +EOF +done +fi + +if [ $stage -le 1 ]; then + $cmd JOB=1:$nj $dir/log/do_vad_job.JOB.log bash $dir/q/do_vad.JOB.sh || exit 1 +fi + +if $cleanup; then + for x in `seq $[num_iters - 1]`; do + if [ $[x % 10] -ne 0 ]; then + rm $dir/*.x.mdl + fi + done +fi + +# Summarize warning messages... +utils/summarize_warnings.pl $dir/log diff --git a/egs/sre08/v1/diarization/vad_gmm_2models.sh b/egs/sre08/v1/diarization/vad_gmm_2models.sh new file mode 100755 index 00000000000..8b8d2cbd285 --- /dev/null +++ b/egs/sre08/v1/diarization/vad_gmm_2models.sh @@ -0,0 +1,712 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -u +set -o pipefail + +cmd=run.pl +stage=-100 +allow_partial=true +try_merge_speech_noise=false +output_lattice=false +write_feats=false + +## Features paramters +window_size=100 # 1s +min_data=200 +frames_per_gaussian=2000 +num_bins=100 +num_sil_states=30 +num_nonsil_states=75 + +## Phase 1 parameters +num_frames_init_silence=2000 # 20s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_sound=10000 # 100s - Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=2000 # 20s - Highest zero crossing frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=8 +sil_gauss_incr=0 +sound_gauss_incr=2 +sample_per_gaussian=2000 +num_iters_init=3 +num_iters=5 +min_sil_variance=1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +num_frames_init_speech=100000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 +window_size_phase2=10 + +## Phase 3 parameters +num_frames_silence_phase3_init=2000 +num_frames_speech_phase3_init=2000 +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +sil_max_gauss_phase4=8 
+speech_max_gauss_phase4=16 +sil_gauss_incr_phase3=1 +sil_gauss_incr_phase4=1 +speech_gauss_incr_phase4=2 +num_iters_phase3=5 +num_iters_phase4=5 + +## Phase 5 parameters +num_frames_silence_phase5_init=2000 +num_frames_speech_phase5_init=2000 +sil_num_gauss_init_phase5=2 +speech_num_gauss_init_phase5=2 +sil_max_gauss_phase5=5 +speech_max_gauss_phase5=12 +sil_gauss_incr_phase5=1 +speech_gauss_incr_phase5=2 +num_iters_phase5=7 + +speech_to_sil_ratio=1 +frame_snrs_scp= +use_bootstrap_vad=false + +. path.sh +. parse_options.sh || exit 1 + +if [ $# -ne 4 ]; then + echo "Usage: vad_gmm_2models.sh <data-dir> <init-silence-model> <init-speech-model> <dir>" + echo " e.g.: vad_gmm_2models.sh data/rt05_eval exp/librispeech_s5/vad_model/{silence,speech}.0.mdl exp/vad_rt05_eval" + exit 1 +fi + +data=$1 +init_silence_model=$2 +init_speech_model=$3 +dir=$4 + +mkdir -p $dir +tmpdir=$dir/phase1 +phase2_dir=$dir/phase2 +phase3_dir=$dir/phase3 +phase5_dir=$dir/phase5 + +mkdir -p $tmpdir +mkdir -p $phase2_dir +mkdir -p $phase3_dir + +init_model_dir=`dirname $init_speech_model` +ignore_energy_opts=`cat $init_model_dir/ignore_energy_opts` || exit 1 +add_zero_crossing_feats=`cat $init_model_dir/add_zero_crossing_feats` || exit 1 +add_frame_snrs=`cat $init_model_dir/add_frame_snrs` || exit 1 + +zc_opts= +[ -f conf/zc_vad.conf ] && zc_opts="--config=conf/zc_vad.conf" + +# Prepare a lang directory +if [ $stage -le -4 ]; then + mkdir -p $dir/local/dict + mkdir -p $dir/local/lm + mkdir -p $dir/local/dict_2class + mkdir -p $dir/local/lm_2class + + echo "1" > $dir/local/dict/silence_phones.txt + echo "1" > $dir/local/dict/optional_silence.txt + echo "2" > $dir/local/dict/nonsilence_phones.txt + echo "1" > $dir/local/dict_2class/silence_phones.txt + echo "1" > $dir/local/dict_2class/optional_silence.txt + echo "2" > $dir/local/dict_2class/nonsilence_phones.txt + echo "3" >> $dir/local/dict/nonsilence_phones.txt + echo -e "1 1\n2 2" > $dir/local/dict_2class/lexicon.txt + echo -e "1 1\n2 2\n3 3" > $dir/local/dict/lexicon.txt + echo -e "1\n2\n1 2" > 
$dir/local/dict_2class/extra_questions.txt + echo -e "1\n2\n1 2\n3\n1 3\n2 3\n1 2 3" > $dir/local/dict/extra_questions.txt + + mkdir -p $dir/lang + diarization/prepare_vad_lang.sh --num-sil-states $num_sil_states --num-nonsil-states $num_nonsil_states \ + $dir/local/dict $dir/local/lang $dir/lang || exit 1 + diarization/prepare_vad_lang.sh --num-sil-states $num_sil_states --num-nonsil-states $num_nonsil_states \ + $dir/local/dict_2class $dir/local/lang_2class $dir/lang_2class || exit 1 +fi + +feat_dim=`gmm-global-info $init_speech_model | grep "feature dimension" | awk '{print $NF}'` || exit 1 + +if [ $stage -le -3 ]; then + run.pl $dir/log/create_transition_model.log gmm-init-mono \ + $dir/lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 + + run.pl $dir/log/create_transition_model_2class.log gmm-init-mono \ + $dir/lang_2class/topo $feat_dim - $dir/tree_2class \| \ + copy-transition-model --binary=false - $dir/trans_2class.mdl || exit 1 + + diarization/make_vad_graph.sh --iter trans $dir/lang $dir $dir/graph || exit 1 + diarization/make_vad_graph.sh --iter trans_2class --tree tree_2class $dir/lang_2class $dir $dir/graph_2class || exit 1 +fi + +if [ $stage -le -2 ]; then + { + cat $dir/trans_2class.mdl + echo " $feat_dim 2" + gmm-global-copy --binary=false $init_silence_model - || exit 1 + gmm-global-copy --binary=false $init_speech_model - || exit 1 + } | gmm-copy - $dir/init_2class.mdl || exit 1 +fi + +if [ $stage -le -1 ]; then + t=$speech_to_sil_ratio + lang=$dir/lang_test_${t}x + cp -r $dir/lang $lang + perl -e '$t = shift @ARGV; print "0 0 1 1 " . -log(1/($t+3)) . "\n0 0 2 2 ". -log($t/($t+3)). "\n0 0 3 3 ". -log(1/($t+3)) ."\n0 ". 
-log(1/($t+3))' $t | \ + fstcompile --isymbols=$lang/words.txt --osymbols=$lang/words.txt \ + --keep_isymbols=false --keep_osymbols=false \ + > $lang/G.fst || exit 1 + diarization/make_vad_graph.sh --iter trans $lang $dir $dir/graph_test_${t}x || exit 1 + + lang=$dir/lang_2class_test_${t}x + cp -r $dir/lang_2class $lang + perl -e '$t = shift @ARGV; print "0 0 1 1 " . -log(1/($t+2)) . "\n0 0 2 2 ". -log($t/($t+2)). "\n0 ". -log(1/($t+2))' $t | \ + fstcompile --isymbols=$lang/words.txt --osymbols=$lang/words.txt \ + --keep_isymbols=false --keep_osymbols=false \ + > $lang/G.fst || exit 1 + + diarization/make_vad_graph.sh --iter trans_2class --tree tree_2class $lang $dir $dir/graph_2class_test_${t}x || exit 1 +fi + + +while IFS=$'\n' read line; do + feats="ark:echo $line | apply-cmvn-sliding scp:- ark:- |${ignore_energy_opts}" + + utt_id=$(echo $line | awk '{print $1}') + echo $utt_id > $dir/$utt_id.list + + if [ -f $data/segments ]; then + $cmd $dir/log/$utt_id.extract_zero_crossings.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + extract-segments scp:$data/wav.scp - ark:- \| \ + compute-zero-crossings $zc_opts ark:- ark:$dir/$utt_id.zero_crossings.ark || exit 1 + #$cmd $dir/log/$utt_id.extract_pitch.log \ + # utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + # extract-segments scp:$data/wav.scp - ark:- \| \ + # compute-kaldi-pitch-feats --config=conf/pitch.conf --frames-per-chunk=10 --simulate-first-pass-online=true \ + # ark:- ark:$dir/$utt_id.kaldi_pitch.ark || exit 1 + $cmd $dir/log/$utt_id.extract_log_energies.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + extract-segments scp:$data/wav.scp - ark:- \| \ + compute-mfcc-feats --config=conf/mfcc_vad.conf --num-ceps=1 \ + ark:- ark:- \| extract-column ark:- \ + ark:$dir/$utt_id.log_energies.ark || exit 1 + else + $cmd $dir/log/$utt_id.extract_zero_crossings.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp \| \ + compute-zero-crossings $zc_opts scp:- 
ark:$dir/$utt_id.zero_crossings.ark || exit 1 + #$cmd $dir/log/$utt_id.extract_pitch.log \ + # utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp \| \ + # compute-kaldi-pitch-feats --config=conf/pitch.conf --frames-per-chunk=10 --simulate-first-pass-online=true \ + # scp:- ark:$dir/$utt_id.kaldi_pitch.ark || exit 1 + $cmd $dir/log/$utt_id.extract_log_energies.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp \| \ + compute-mfcc-feats --config=conf/mfcc_vad.conf --num-ceps=1 \ + scp:- ark:- \| extract-column ark:- \ + ark:$dir/$utt_id.log_energies.ark || exit 1 + fi + + if $add_frame_snrs; then + [ -z "$frame_snrs_scp" ] && echo "$0: add-frame-snrs is true but frame-snrs-scp is not supplied" && exit 1 + utils/filter_scp.pl $data/utt2spk $frame_snrs_scp > $dir/frame_snrs.scp || exit 1 + fi + + sil_num_gauss=$sil_num_gauss_init + sound_num_gauss=$sound_num_gauss_init + speech_num_gauss=$speech_num_gauss_init + + if $add_zero_crossing_feats; then + feats="${feats}paste-feats ark:- ark:$dir/$utt_id.zero_crossings.ark ark:- |" + fi + + if $add_frame_snrs; then + feats="${feats}paste-feats ark:- \"ark:vector-to-feat scp:$dir/frame_snrs.scp ark:- |\" ark:- |" + fi + + feats="${feats} add-deltas ark:- ark:- |" + + if $write_feats; then + copy-feats "$feats" ark:$dir/$utt_id.feat.ark + fi + + $cmd $dir/log/$utt_id.gmm_compute_likes.bootstrap.log \ + gmm-compute-likes $dir/init_2class.mdl "$feats" \ + ark:$tmpdir/$utt_id.likes.bootstrap.ark & + + # Get VAD: 0 for silence, 1 for speech + $cmd $dir/log/$utt_id.get_vad.bootstrap.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph_2class/words.txt \ + $dir/init_2class.mdl $dir/graph_2class/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| ali-to-pdf $dir/init_2class.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$tmpdir/$utt_id.vad.bootstrap.ark || exit 1 + + if $use_bootstrap_vad; then + segmentation-copy ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + 
ark,scp:$dir/$utt_id.vad.final.ark,$dir/$utt_id.vad.final.scp || exit 1 + continue + fi + + cp $tmpdir/$utt_id.vad.bootstrap.ark $tmpdir/$utt_id.seg.0.ark + + x=0 + goto_phase3=false + + while [ $x -lt $num_iters ]; do + num_frames_silence=$[num_frames_init_silence + sil_num_gauss * frames_per_gaussian ] + num_frames_sound=$[num_frames_init_sound + 5 * sound_num_gauss * frames_per_gaussian ] + num_frames_sound_next=$[num_frames_init_sound_next + sound_num_gauss * frames_per_gaussian ] + + if [ $x -lt 3 ]; then + $cmd $tmpdir/log/$utt_id.select_top.first.$[x+1].log \ + segmentation-copy --filter-label=0 \ + --filter-rspecifier=ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + ark:$tmpdir/$utt_id.seg.$x.ark ark:- \| \ + segmentation-select-top --num-bins=$num_bins \ + --merge-labels=0:2 --merge-dst-label=0 \ + --num-top-frames=$num_frames_sound --num-bottom-frames=$num_frames_silence \ + --top-select-label=2 --bottom-select-label=0 --reject-label=1000 \ + --remove-rejected-frames=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:- ark:$dir/$utt_id.log_energies.ark \ + ark:$tmpdir/$utt_id.seg.first.$[x+1].ark || exit 1 + + $cmd $tmpdir/log/$utt_id.select_top.$[x+1].log \ + segmentation-select-top --num-bins=$num_bins --src-label=2 \ + --num-top-frames=$num_frames_sound_next --num-bottom-frames=-1 \ + --top-select-label=2 --bottom-select-label=-1 --reject-label=1001 \ + --remove-rejected-frames=true --select-from-full-histogram=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:$tmpdir/$utt_id.seg.first.$[x+1].ark "ark:extract-column ark:$dir/$utt_id.zero_crossings.ark ark:- |" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark || exit 1 + else + $cmd $tmpdir/log/$utt_id.select_top.$[x+1].log \ + segmentation-copy --filter-rspecifier=ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + --filter-label=0 ark:$tmpdir/$utt_id.seg.$x.ark \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark || exit 1 + fi + + if [ $x -eq 0 ]; then + { + 
cat $dir/trans.mdl; + echo " $feat_dim 3"; + select-feats-from-segmentation --select-label=0 "$feats" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark ark:- | \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[sil_num_gauss+1] --num-gauss-init=1 --num-gauss=$sil_num_gauss \ + ark:- - || exit 1 + gmm-global-copy --binary=false $init_speech_model -; + select-feats-from-segmentation --select-label=2 "$feats" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark ark:- | \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[sound_num_gauss+1] --num-gauss-init=1 --num-gauss=$sound_num_gauss \ + ark:- - || exit 1 + } 2> $tmpdir/log/$utt_id.init_gmm.log | \ + gmm-copy - $tmpdir/$utt_id.$[x+1].mdl 2>> $tmpdir/log/$utt_id.init_gmm.log + if [ $? -ne 0 ]; then + echo "Insufficient frames for training silence or sound model. Skipping to phase 3" + goto_phase3=true + break; + fi + #|| { echo "See $tmpdir/log/$utt_id.init_gmm.log for errors"; exit 1; } + else + #$cmd $tmpdir/log/$utt_id.gmm_update.$[x+1].log \ + # gmm-est-segmentation --pdfs=0:2 \ + # --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n2 $sound_num_gauss\" |" \ + # $tmpdir/$utt_id.$x.mdl "$feats" \ + # ark:$tmpdir/$utt_id.seg.second.$[x+1].ark \ + # $tmpdir/$utt_id.$[x+1].mdl || exit 1 + $cmd $tmpdir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation --pdfs=0:2 \ + --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n2 $sound_num_gauss\" |" \ + $tmpdir/$utt_id.$x.mdl "$feats" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark \ + $tmpdir/$utt_id.$[x+1].mdl || exit 1 + fi + + $cmd $tmpdir/log/$utt_id.gmm_compute_likes.$x.log \ + gmm-compute-likes $tmpdir/$utt_id.$x.mdl "$feats" \ + ark:$tmpdir/$utt_id.likes.$x.ark & + + $cmd $tmpdir/log/$utt_id.get_seg.$[x+1].log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph/words.txt \ + $tmpdir/$utt_id.$[x+1].mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $tmpdir/$utt_id.$[x+1].mdl ark:- ark:- \| \ 
+ segmentation-init-from-ali ark:- \ + ark:$tmpdir/$utt_id.seg.$[x+1].ark || exit 1 + + if [ $sil_num_gauss -lt $sil_max_gauss ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr] + fi + + if [ $sound_num_gauss -lt $sound_max_gauss ]; then + sound_num_gauss=$[sound_num_gauss + sound_gauss_incr] + fi + + x=$[x+1] + done ## Done training Silence and Speech GMMs + + if ! $goto_phase3; then + $cmd $phase2_dir/log/$utt_id.init_speech.log \ + segmentation-copy --filter-rspecifier=ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + --filter-label=1 ark:$tmpdir/$utt_id.seg.$num_iters.ark ark:- \| \ + select-feats-from-segmentation --select-label=1 "$feats" \ + ark:- ark:- \| \ + gmm-global-init-from-feats \ + --num-iters=$[speech_num_gauss+1] --num-gauss-init=1 --num-gauss=$speech_num_gauss \ + ark:- $phase2_dir/$utt_id.speech.0.mdl + if [ $? -eq 0 ]; then + num_selected_speech=$(grep "Processed .* segmentations; selected" $phase2_dir/log/$utt_id.init_speech.log | perl -pe 's/.+selected (\S+) out of \S+ frames/$1/') + if [ $num_selected_speech -lt $min_data ]; then + echo "Insufficient frames for speech at the end of phase 1. $num_selected_speech < $min_data. See $phase2_dir/log/$utt_id.init_speech.log. Going to phase 3." + goto_phase3=true + fi + else + echo "Failed to find any data for speech at the end of phase 1. See $phase2_dir/log/$utt_id.init_speech.log. Going to phase 3." + goto_phase3=true + fi + + if $goto_phase3; then + rm -f $dir/$utt_id.current_seg.ark + ln -s $tmpdir/$utt_id.seg.$x.ark $dir/$utt_id.current_seg.ark + fi + fi + + if ! 
$goto_phase3; then + echo "Beginning phase 2 for utterance $utt_id" + + $cmd $phase2_dir/log/$utt_id.init_gmm.log \ + gmm-init-pdf-from-global $tmpdir/$utt_id.$num_iters.mdl 1 \ + $phase2_dir/$utt_id.speech.0.mdl $phase2_dir/$utt_id.0.mdl || exit 1 + + x=0 + while [ $x -lt $num_iters_phase2 ]; do + if [ $sil_num_gauss -lt $sil_max_gauss_phase2 ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr_phase2] + fi + + if [ $sound_num_gauss -lt $sound_max_gauss_phase2 ]; then + sound_num_gauss=$[sound_num_gauss + sound_gauss_incr_phase2] + fi + + if [ $speech_num_gauss -lt $speech_max_gauss_phase2 ]; then + speech_num_gauss=$[speech_num_gauss + speech_gauss_incr_phase2] + fi + + $cmd $phase2_dir/log/$utt_id.gmm_compute_likes.$x.log \ + gmm-compute-likes $phase2_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase2_dir/$utt_id.likes.$x.ark & + + $cmd $phase2_dir/log/$utt_id.get_seg.$x.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph/words.txt \ + $phase2_dir/$utt_id.$x.mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $phase2_dir/$utt_id.$x.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase2_dir/$utt_id.seg.$x.ark || exit 1 + + #$cmd $phase2_dir/log/$utt_id.gmm_update.$[x+1].log \ + # gmm-est-segmentation \ + # --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n1 $speech_num_gauss\n2 $sound_num_gauss\" |" \ + # $phase2_dir/$utt_id.$x.mdl "$feats" \ + # ark:$phase2_dir/$utt_id.seg.$x.ark \ + # $phase2_dir/$utt_id.$[x+1].mdl || exit 1 + $cmd $phase2_dir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation \ + --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n1 $speech_num_gauss\n2 $sound_num_gauss\" |" \ + $phase2_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase2_dir/$utt_id.seg.$x.ark \ + $phase2_dir/$utt_id.$[x+1].mdl || exit 1 + + x=$[x+1] + done ## Done training all 3 GMMs + cp $phase2_dir/$utt_id.$x.mdl $dir/$utt_id.final.mdl + rm -f $dir/$utt_id.graph_final + ln -s 
graph_test_${speech_to_sil_ratio}x $dir/$utt_id.graph_final + + $cmd $phase2_dir/log/$utt_id.gmm_compute_likes.$x.log \ + gmm-compute-likes $phase2_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase2_dir/$utt_id.likes.$x.ark & + + $cmd $phase2_dir/log/$utt_id.get_seg.$x.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph/words.txt \ + $phase2_dir/$utt_id.$x.mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $phase2_dir/$utt_id.$x.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase2_dir/$utt_id.seg.$x.ark || exit 1 + + mkdir -p $phase3_dir/log + + $cmd $phase2_dir/log/$utt_id.init_nonsil.log \ + segmentation-copy --merge-labels=1:2 --merge-dst-label=1 \ + ark:$phase2_dir/$utt_id.seg.$x.ark ark:- \| \ + select-feats-from-segmentation --select-label=1 \ + "$feats" ark:- ark:- \| \ + gmm-global-init-from-feats \ + --num-iters=$[sound_num_gauss + speech_num_gauss + 1] \ + --num-gauss-init=1 \ + --num-gauss=$[sound_num_gauss + speech_num_gauss] ark:- \ + $phase2_dir/$utt_id.$x.nonsil.mdl || exit 1 + + $cmd $phase2_dir/log/$utt_id.select_speech_feats.$x.log \ + select-feats-from-segmentation --select-label=1 \ + "$feats" ark:$phase2_dir/$utt_id.seg.$x.ark \ + ark:$phase2_dir/$utt_id.speech_feats.$x.ark + + if $goto_phase3; then + rm -f $dir/$utt_id.current_seg.ark + ln -s $phase2_dir/$utt_id.seg.$x.ark $dir/$utt_id.current_seg.ark + fi + + if [ $? -eq 0 ]; then + num_selected_speech=$(grep "Processed .* segmentations; selected" $phase2_dir/log/$utt_id.select_speech_feats.$x.log | perl -pe 's/.+selected (\S+) out of \S+ frames/$1/') + if [ $num_selected_speech -lt $min_data ]; then + echo "Insufficient frames for speech at the end of phase 2. $num_selected_speech < $min_data. See $phase2_dir/log/$utt_id.select_speech_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + else + echo "Failed to find any data for speech at the end of phase 1. See $phase2_dir/log/$utt_id.select_speech_feats.$x.log. 
Going to phase 3." + goto_phase3=true + fi + + if $try_merge_speech_noise; then + if ! $goto_phase3; then + speech_like=$(gmm-global-get-frame-likes \ + "gmm-extract-pdf $phase2_dir/$utt_id.$x.mdl 1 - |" \ + ark:$phase2_dir/$utt_id.speech_feats.$x.ark ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $phase2_dir/$utt_id.compute_speech_like.$x.log || exit 1 + + $cmd $phase2_dir/log/$utt_id.select_sound_feats.$x.log \ + select-feats-from-segmentation --select-label=2 \ + "$feats" ark:$phase2_dir/$utt_id.seg.$x.ark \ + ark:$phase2_dir/$utt_id.sound_feats.$x.ark + + if [ $? -eq 0 ]; then + num_selected_sound=$(grep "Processed .* segmentations; selected" $phase2_dir/log/$utt_id.select_sound_feats.$x.log | perl -pe 's/.+selected (\S+) out of \S+ frames/$1/') + if [ $num_selected_sound -lt $min_data ]; then + echo "Insufficient frames for sound at the end of phase 2. $num_selected_sound < $min_data. See $phase2_dir/log/$utt_id.select_sound_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + else + echo "Failed to find any data for sound at the end of phase 1. See $phase2_dir/log/$utt_id.select_sound_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + fi + + if ! 
$goto_phase3; then + sound_like=$(gmm-global-get-frame-likes \ + "gmm-extract-pdf $phase2_dir/$utt_id.$x.mdl 2 - |" \ + ark:$phase2_dir/$utt_id.sound_feats.$x.ark ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $phase2_dir/$utt_id.compute_sound_like.$x.log || exit 1 + + nonsil_like=$(select-feats-from-segmentation --merge-labels=1:2 --select-label=1 \ + "$feats" ark:$phase2_dir/$utt_id.seg.$x.ark ark:- | \ + gmm-global-get-frame-likes \ + $phase2_dir/$utt_id.$x.nonsil.mdl ark:- ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $phase2_dir/$utt_id.compute_nonsil_like.$x.log || exit 1 + + if [ ! -z `perl -e "print \"true\" if ($sound_like + $speech_like < $nonsil_like)"` ]; then + goto_phase3=true + fi + fi + fi + fi + + if $goto_phase3; then + echo "Beginning phase 3 for utterance $utt_id" + speech_num_gauss=$speech_num_gauss_init_phase3 + sil_num_gauss=$sil_num_gauss_init_phase3 + + $cmd $phase3_dir/log/$utt_id.gmm_compute_likes.init.log \ + gmm-compute-likes $dir/init_2class.mdl "$feats" \ + ark:$phase3_dir/$utt_id.likes.0.ark & + + $cmd $phase3_dir/log/$utt_id.get_vad.bootstrap.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph_2class/words.txt \ + $dir/init_2class.mdl $dir/graph_2class/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| ali-to-pdf $dir/init_2class.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase3_dir/$utt_id.vad.0.ark || exit 1 + + x=0 + skip_phase4=false + + while [ $x -lt $num_iters_phase3 ]; do + num_frames_silence=$[num_frames_init_silence + sil_num_gauss * frames_per_gaussian ] + + if [ $x -lt 3 ]; then + $cmd $phase3_dir/log/$utt_id.select_top.second.$[x+1].log \ + segmentation-copy --filter-label=0 \ + --filter-rspecifier=ark:$phase3_dir/$utt_id.vad.0.ark \ + ark:$phase3_dir/$utt_id.vad.$x.ark ark:- \| \ 
+ segmentation-select-top --num-bins=$num_bins \ + --merge-dst-label=0 --select-from-full-histogram=true \ + --num-top-frames=-1 --num-bottom-frames=$num_frames_silence \ + --top-select-label=-1 --bottom-select-label=0 --reject-label=1000 \ + --remove-rejected-frames=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:- ark:$dir/$utt_id.log_energies.ark \ + ark:$phase3_dir/$utt_id.vad.second.$[x+1].ark || exit 1 + + else + $cmd $phase3_dir/log/$utt_id.select_top.$[x+1].log \ + segmentation-copy --filter-rspecifier=ark:$phase3_dir/$utt_id.vad.0.ark \ + --filter-label=0 ark:$phase3_dir/$utt_id.vad.$x.ark \ + ark:$phase3_dir/$utt_id.vad.second.$[x+1].ark || exit 1 + fi + + if [ $x -eq 0 ]; then + { + cat $dir/trans.mdl; + echo " $feat_dim 2"; + select-feats-from-segmentation --select-label=0 "$feats" \ + ark:$phase3_dir/$utt_id.vad.second.$[x+1].ark ark:- | \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[sil_num_gauss+1] --num-gauss-init=1 --num-gauss=$sil_num_gauss \ + ark:- - || exit 1 + gmm-global-copy --binary=false $init_speech_model - || exit 1 + } 2> $phase3_dir/log/$utt_id.init_gmm.log | \ + gmm-copy - $phase3_dir/$utt_id.$[x+1].mdl 2>> $phase3_dir/log/$utt_id.init_gmm.log || exit 1 + else + $cmd $phase3_dir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation --pdfs=0 \ + --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\" |" \ + $phase3_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase3_dir/$utt_id.vad.second.$[x+1].ark \ + $phase3_dir/$utt_id.$[x+1].mdl || exit 1 + fi + + $cmd $phase3_dir/log/$utt_id.gmm_compute_likes.$x.log \ + gmm-compute-likes $phase3_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase3_dir/$utt_id.likes.$x.ark & + + $cmd $phase3_dir/log/$utt_id.get_seg.$[x+1].log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph_2class/words.txt \ + $phase3_dir/$utt_id.$[x+1].mdl $dir/graph_2class/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf 
$phase3_dir/$utt_id.$[x+1].mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase3_dir/$utt_id.vad.$[x+1].ark || exit 1 + + if [ $sil_num_gauss -lt $sil_max_gauss ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr] + fi + + x=$[x+1] + done ## Done training Silence and Speech GMMs + + $cmd $phase3_dir/log/$utt_id.init_speech.log \ + segmentation-copy \ + ark:$phase3_dir/$utt_id.vad.$x.ark ark:- \| \ + select-feats-from-segmentation --select-label=1 "$feats" \ + ark:- ark:- \| \ + gmm-global-init-from-feats \ + --num-iters=$[speech_num_gauss+1] --num-gauss-init=1 --num-gauss=$speech_num_gauss \ + ark:- $phase3_dir/$utt_id.speech.$x.mdl + + if [ $? -eq 0 ]; then + num_selected_speech=$(grep "Processed .* segmentations; selected" $phase3_dir/log/$utt_id.init_speech.log | perl -pe 's/.+selected (\S+) out of \S+ frames/$1/') + if [ $num_selected_speech -lt $min_data ]; then + echo "Insufficient frames for speech at the end of phase 3. $num_selected_speech < $min_data. Not re-training speech model." + skip_phase4=true + fi + else + echo "Failed to find any data for speech at the end of phase 3. See $phase3_dir/log/$utt_id.init_speech.log. Not re-training speech model." + skip_phase4=true + fi + + if ! 
$skip_phase4; then + $cmd $phase3_dir/log/$utt_id.init_gmm.$[x+1].log \ + gmm-init-pdf-from-global $phase3_dir/$utt_id.$x.mdl 1 \ + $phase3_dir/$utt_id.speech.$x.mdl $phase3_dir/$utt_id.$[x+1].mdl || exit 1 + x=$[x+1] + + while [ $x -lt $[num_iters_phase4 + num_iters_phase3+1] ]; do + if [ $sil_num_gauss -lt $sil_max_gauss_phase4 ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr_phase4] + fi + + if [ $speech_num_gauss -lt $speech_max_gauss_phase4 ]; then + speech_num_gauss=$[speech_num_gauss + speech_gauss_incr_phase4] + fi + + $cmd $phase3_dir/log/$utt_id.gmm_compute_likes.$x.log \ + gmm-compute-likes $phase3_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase3_dir/$utt_id.likes.$x.ark & + + $cmd $phase3_dir/log/$utt_id.get_seg.$x.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph_2class/words.txt \ + $phase3_dir/$utt_id.$x.mdl $dir/graph_2class/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $phase3_dir/$utt_id.$x.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase3_dir/$utt_id.vad.$x.ark || exit 1 + + $cmd $phase3_dir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation \ + --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n1 $speech_num_gauss\" |" \ + $phase3_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase3_dir/$utt_id.vad.$x.ark \ + $phase3_dir/$utt_id.$[x+1].mdl || exit 1 + + x=$[x+1] + done ## Done training all 3 GMMs + fi + cp $phase3_dir/$utt_id.$x.mdl $dir/$utt_id.final.mdl + + rm -f $dir/$utt_id.graph_final + ln -s graph_2class_test_${speech_to_sil_ratio}x $dir/$utt_id.graph_final + fi + + if $output_lattice; then + $cmd $dir/log/$utt_id.get_seg.final.log \ + gmm-latgen-faster --allow-partial=$allow_partial \ + --word-symbol-table=$dir/$utt_id.graph_final/words.txt \ + $dir/$utt_id.final.mdl $dir/$utt_id.graph_final/HCLG.fst \ + "$feats" ark,scp:$dir/$utt_id.lat.ark,$dir/$utt_id.lat.scp \ + ark:/dev/null ark:- \| \ + ali-to-pdf $dir/$utt_id.final.mdl ark:- ark:- \| \ + 
segmentation-init-from-ali ark:- \ + ark,scp:$dir/$utt_id.vad.final.ark,$dir/$utt_id.vad.final.scp || exit 1 + else + $cmd $dir/log/$utt_id.gmm_compute_likes.final.log \ + gmm-compute-likes $dir/$utt_id.final.mdl "$feats" \ + ark:$dir/$utt_id.likes.final.ark & + + $cmd $dir/log/$utt_id.get_seg.final.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/$utt_id.graph_final/words.txt \ + $dir/$utt_id.final.mdl $dir/$utt_id.graph_final/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $dir/$utt_id.final.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark,scp:$dir/$utt_id.vad.final.ark,$dir/$utt_id.vad.final.scp || exit 1 + fi + +done < $data/feats.scp + + diff --git a/egs/sre08/v1/diarization/vad_gmm_3models.sh b/egs/sre08/v1/diarization/vad_gmm_3models.sh new file mode 100755 index 00000000000..fd2cb44bf48 --- /dev/null +++ b/egs/sre08/v1/diarization/vad_gmm_3models.sh @@ -0,0 +1,718 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. 
+ +set -u +set -o pipefail + +cmd=run.pl +stage=-100 +allow_partial=true +try_merge_speech_noise=false +output_lattice=false +write_feats=false +use_bootstrap_vad=false + +## Features parameters +window_size=100 # 1s +min_data=200 +frames_per_gaussian=2000 +num_bins=100 +num_sil_states=30 +num_nonsil_states=75 + +## Phase 1 parameters +num_frames_init_silence=2000 # 20s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_sound=10000 # 100s - Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=2000 # 20s - Highest zero crossing frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=8 +sil_gauss_incr=0 +sound_gauss_incr=2 +sample_per_gaussian=2000 +num_iters_init=3 +num_iters=5 +min_sil_variance=1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +num_frames_init_speech=100000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 +window_size_phase2=10 + +## Phase 3 parameters +num_frames_silence_phase3_init=2000 +num_frames_speech_phase3_init=2000 +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +sil_max_gauss_phase4=8 +speech_max_gauss_phase4=16 +sil_gauss_incr_phase3=1 +sil_gauss_incr_phase4=1 +speech_gauss_incr_phase4=2 +num_iters_phase3=5 +num_iters_phase4=5 + +## Phase 5 parameters +num_frames_silence_phase5_init=2000 +num_frames_speech_phase5_init=2000 +sil_num_gauss_init_phase5=2 +speech_num_gauss_init_phase5=2 +sil_max_gauss_phase5=5 +speech_max_gauss_phase5=12 +sil_gauss_incr_phase5=1 +speech_gauss_incr_phase5=2 +num_iters_phase5=7 + +speech_to_sil_ratio=1 + +. path.sh +. 
parse_options.sh || exit 1 + +if [ $# -ne 5 ]; then + echo "Usage: vad_gmm_3models.sh <data-dir> <init-silence-model> <init-speech-model> <init-sound-model> <dir>" + echo " e.g.: vad_gmm_3models.sh data/rt05_eval exp/librispeech_s5/vad_model/{silence,speech,noise}.0.mdl exp/vad_rt05_eval" + exit 1 +fi + +data=$1 +init_silence_model=$2 +init_speech_model=$3 +init_sound_model=$4 +dir=$5 + +mkdir -p $dir +tmpdir=$dir/phase1 +phase2_dir=$dir/phase2 +phase3_dir=$dir/phase3 +phase5_dir=$dir/phase5 + +mkdir -p $tmpdir +mkdir -p $phase2_dir +mkdir -p $phase3_dir + +init_model_dir=`dirname $init_speech_model` +ignore_energy_opts=`cat $init_model_dir/ignore_energy_opts` || exit 1 +add_zero_crossing_feats=`cat $init_model_dir/add_zero_crossing_feats` || exit 1 + +zc_opts= +[ -f conf/zc_vad.conf ] && zc_opts="--config=conf/zc_vad.conf" + +# Prepare a lang directory +if [ $stage -le -4 ]; then + mkdir -p $dir/local/dict + mkdir -p $dir/local/lm + mkdir -p $dir/local/dict_2class + mkdir -p $dir/local/lm_2class + + echo "1" > $dir/local/dict/silence_phones.txt + echo "1" > $dir/local/dict/optional_silence.txt + echo "2" > $dir/local/dict/nonsilence_phones.txt + echo "1" > $dir/local/dict_2class/silence_phones.txt + echo "1" > $dir/local/dict_2class/optional_silence.txt + echo "2" > $dir/local/dict_2class/nonsilence_phones.txt + echo "3" >> $dir/local/dict/nonsilence_phones.txt + echo -e "1 1\n2 2" > $dir/local/dict_2class/lexicon.txt + echo -e "1 1\n2 2\n3 3" > $dir/local/dict/lexicon.txt + echo -e "1\n2\n1 2" > $dir/local/dict_2class/extra_questions.txt + echo -e "1\n2\n1 2\n3\n1 3\n2 3\n1 2 3" > $dir/local/dict/extra_questions.txt + + mkdir -p $dir/lang + diarization/prepare_vad_lang.sh --num-sil-states $num_sil_states --num-nonsil-states $num_nonsil_states \ + $dir/local/dict $dir/local/lang $dir/lang || exit 1 + diarization/prepare_vad_lang.sh --num-sil-states $num_sil_states --num-nonsil-states $num_nonsil_states \ + $dir/local/dict_2class $dir/local/lang_2class $dir/lang_2class || exit 1 +fi + +feat_dim=`gmm-global-info $init_speech_model | grep "feature 
dimension" | awk '{print $NF}'` || exit 1 + +if [ $stage -le -3 ]; then + run.pl $dir/log/create_transition_model.log gmm-init-mono \ + $dir/lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 + + run.pl $dir/log/create_transition_model_2class.log gmm-init-mono \ + $dir/lang_2class/topo $feat_dim - $dir/tree_2class \| \ + copy-transition-model --binary=false - $dir/trans_2class.mdl || exit 1 + + diarization/make_vad_graph.sh --iter trans $dir/lang $dir $dir/graph || exit 1 + diarization/make_vad_graph.sh --iter trans_2class --tree tree_2class $dir/lang_2class $dir $dir/graph_2class || exit 1 +fi + +if [ $stage -le -2 ]; then + { + cat $dir/trans.mdl + echo " $feat_dim 3" + gmm-global-copy --binary=false $init_silence_model - || exit 1 + gmm-global-copy --binary=false $init_speech_model - || exit 1 + gmm-global-copy --binary=false $init_sound_model - || exit 1 + } | gmm-copy - $dir/init.mdl || exit 1 + + { + cat $dir/trans_2class.mdl + echo " $feat_dim 2" + gmm-global-copy --binary=false $init_silence_model - || exit 1 + gmm-global-copy --binary=false $init_speech_model - || exit 1 + } | gmm-copy - $dir/init_2class.mdl || exit 1 +fi + +if [ $stage -le -1 ]; then + t=$speech_to_sil_ratio + lang=$dir/lang_test_${t}x + cp -r $dir/lang $lang + perl -e '$t = shift @ARGV; print "0 0 1 1 " . -log(1/($t+3)) . "\n0 0 2 2 ". -log($t/($t+3)). "\n0 0 3 3 ". -log(1/($t+3)) ."\n0 ". -log(1/($t+3))' $t | \ + fstcompile --isymbols=$lang/words.txt --osymbols=$lang/words.txt \ + --keep_isymbols=false --keep_osymbols=false \ + > $lang/G.fst || exit 1 + diarization/make_vad_graph.sh --iter trans $lang $dir $dir/graph_test_${t}x || exit 1 + + lang=$dir/lang_2class_test_${t}x + cp -r $dir/lang_2class $lang + perl -e '$t = shift @ARGV; print "0 0 1 1 " . -log(1/($t+2)) . "\n0 0 2 2 ". -log($t/($t+2)). "\n0 ". 
-log(1/($t+2))' $t | \ + fstcompile --isymbols=$lang/words.txt --osymbols=$lang/words.txt \ + --keep_isymbols=false --keep_osymbols=false \ + > $lang/G.fst || exit 1 + + diarization/make_vad_graph.sh --iter trans_2class --tree tree_2class $lang $dir $dir/graph_2class_test_${t}x || exit 1 +fi + + +while IFS=$'\n' read line; do + feats="ark:echo $line | apply-cmvn-sliding scp:- ark:- |${ignore_energy_opts}" + + utt_id=$(echo $line | awk '{print $1}') + echo $utt_id > $dir/$utt_id.list + + if [ -f $data/segments ]; then + $cmd $dir/log/$utt_id.extract_zero_crossings.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + extract-segments scp:$data/wav.scp - ark:- \| \ + compute-zero-crossings $zc_opts ark:- ark:$dir/$utt_id.zero_crossings.ark || exit 1 + #$cmd $dir/log/$utt_id.extract_pitch.log \ + # utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + # extract-segments scp:$data/wav.scp - ark:- \| \ + # compute-kaldi-pitch-feats --config=conf/pitch.conf --frames-per-chunk=10 --simulate-first-pass-online=true \ + # ark:- ark:$dir/$utt_id.kaldi_pitch.ark || exit 1 + $cmd $dir/log/$utt_id.extract_log_energies.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + extract-segments scp:$data/wav.scp - ark:- \| \ + compute-mfcc-feats --config=conf/mfcc_vad.conf --num-ceps=1 \ + ark:- ark:- \| extract-column ark:- \ + ark:$dir/$utt_id.log_energies.ark || exit 1 + else + $cmd $dir/log/$utt_id.extract_zero_crossings.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp \| \ + compute-zero-crossings $zc_opts scp:- ark:$dir/$utt_id.zero_crossings.ark || exit 1 + #$cmd $dir/log/$utt_id.extract_pitch.log \ + # utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp \| \ + # compute-kaldi-pitch-feats --config=conf/pitch.conf --frames-per-chunk=10 --simulate-first-pass-online=true \ + # scp:- ark:$dir/$utt_id.kaldi_pitch.ark || exit 1 + $cmd $dir/log/$utt_id.extract_log_energies.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp \| \ + 
compute-mfcc-feats --config=conf/mfcc_vad.conf --num-ceps=1 \ + scp:- ark:- \| extract-column ark:- \ + ark:$dir/$utt_id.log_energies.ark || exit 1 + fi + + if $add_frame_snrs; then + [ -z "$frame_snrs_scp" ] && echo "$0: add-frame-snrs is true but frame-snrs-scp is not supplied" && exit 1 + utils/filter_scp.pl $data/utt2spk $frame_snrs_scp > $dir/frame_snrs.scp + fi + + sil_num_gauss=$sil_num_gauss_init + sound_num_gauss=$sound_num_gauss_init + speech_num_gauss=$speech_num_gauss_init + + if $add_zero_crossing_feats; then + feats="${feats}paste-feats ark:- ark:$dir/$utt_id.zero_crossings.ark ark:- |" + fi + + if $add_frame_snrs; then + feats="${feats}paste-feats ark:- \"ark:vector-to-feat scp:$dir/frame_snrs.scp ark:- |\" ark:- |" + fi + + feats="${feats} add-deltas ark:- ark:- |" + + if $write_feats; then + copy-feats "$feats" ark:$dir/$utt_id.feat.ark + fi + + $cmd $dir/log/$utt_id.gmm_compute_likes.bootstrap.log \ + gmm-compute-likes $dir/init.mdl "$feats" \ + ark:$tmpdir/$utt_id.likes.bootstrap.ark & + + # Get VAD: 0 for silence, 1 for speech and 2 for sound + $cmd $dir/log/$utt_id.get_vad.bootstrap.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph/words.txt \ + $dir/init.mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| ali-to-pdf $dir/init.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$tmpdir/$utt_id.vad.bootstrap.ark || exit 1 + + if $use_bootstrap_vad; then + segmentation-copy ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + ark,scp:$dir/$utt_id.vad.final.ark,$dir/$utt_id.vad.final.scp || exit 1 + continue + fi + + cp $tmpdir/$utt_id.vad.bootstrap.ark $tmpdir/$utt_id.seg.0.ark + + x=0 + goto_phase3=false + + while [ $x -lt $num_iters ]; do + num_frames_silence=$[num_frames_init_silence + sil_num_gauss * frames_per_gaussian ] + num_frames_sound=$[num_frames_init_sound + 5 * sound_num_gauss * frames_per_gaussian ] + num_frames_sound_next=$[num_frames_init_sound_next + sound_num_gauss * 
frames_per_gaussian ] + + if [ $x -lt 3 ]; then + $cmd $tmpdir/log/$utt_id.select_top.first.$[x+1].log \ + segmentation-copy --filter-label=0 \ + --filter-rspecifier=ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + ark:$tmpdir/$utt_id.seg.$x.ark ark:- \| \ + segmentation-select-top --num-bins=$num_bins \ + --merge-labels=0:2 --merge-dst-label=0 \ + --num-top-frames=$num_frames_sound --num-bottom-frames=$num_frames_silence \ + --top-select-label=2 --bottom-select-label=0 --reject-label=1000 \ + --remove-rejected-frames=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:- ark:$dir/$utt_id.log_energies.ark \ + ark:$tmpdir/$utt_id.seg.first.$[x+1].ark || exit 1 + + $cmd $tmpdir/log/$utt_id.select_top.$[x+1].log \ + segmentation-select-top --num-bins=$num_bins --src-label=2 \ + --num-top-frames=$num_frames_sound_next --num-bottom-frames=-1 \ + --top-select-label=2 --bottom-select-label=-1 --reject-label=1001 \ + --remove-rejected-frames=true --select-from-full-histogram=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:$tmpdir/$utt_id.seg.first.$[x+1].ark "ark:extract-column ark:$dir/$utt_id.zero_crossings.ark ark:- |" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark || exit 1 + else + $cmd $tmpdir/log/$utt_id.select_top.$[x+1].log \ + segmentation-copy --filter-rspecifier=ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + --filter-label=0 ark:$tmpdir/$utt_id.seg.$x.ark \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark || exit 1 + fi + + if [ $x -eq 0 ]; then + { + cat $dir/trans.mdl; + echo " $feat_dim 3"; + select-feats-from-segmentation --select-label=0 "$feats" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark ark:- | \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[sil_num_gauss+1] --num-gauss-init=1 --num-gauss=$sil_num_gauss \ + ark:- - || exit 1 + gmm-global-copy --binary=false $init_speech_model -; + select-feats-from-segmentation --select-label=2 "$feats" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark ark:- 
| \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[sound_num_gauss+1] --num-gauss-init=1 --num-gauss=$sound_num_gauss \ + ark:- - || exit 1 + } 2> $tmpdir/log/$utt_id.init_gmm.log | \ + gmm-copy - $tmpdir/$utt_id.$[x+1].mdl 2>> $tmpdir/log/$utt_id.init_gmm.log + if [ $? -ne 0 ]; then + echo "Insufficient frames for training silence or sound model. Skipping to phase 3" + goto_phase3=true + break; + fi + #|| { echo "See $tmpdir/log/$utt_id.init_gmm.log for errors"; exit 1; } + else + #$cmd $tmpdir/log/$utt_id.gmm_update.$[x+1].log \ + # gmm-est-segmentation --pdfs=0:2 \ + # --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n2 $sound_num_gauss\" |" \ + # $tmpdir/$utt_id.$x.mdl "$feats" \ + # ark:$tmpdir/$utt_id.seg.second.$[x+1].ark \ + # $tmpdir/$utt_id.$[x+1].mdl || exit 1 + $cmd $tmpdir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation --pdfs=0:2 \ + --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n2 $sound_num_gauss\" |" \ + $tmpdir/$utt_id.$x.mdl "$feats" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark \ + $tmpdir/$utt_id.$[x+1].mdl || exit 1 + fi + + $cmd $tmpdir/log/$utt_id.gmm_compute_likes.$x.log \ + gmm-compute-likes $tmpdir/$utt_id.$x.mdl "$feats" \ + ark:$tmpdir/$utt_id.likes.$x.ark & + + $cmd $tmpdir/log/$utt_id.get_seg.$[x+1].log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph/words.txt \ + $tmpdir/$utt_id.$[x+1].mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $tmpdir/$utt_id.$[x+1].mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$tmpdir/$utt_id.seg.$[x+1].ark || exit 1 + + if [ $sil_num_gauss -lt $sil_max_gauss ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr] + fi + + if [ $sound_num_gauss -lt $sound_max_gauss ]; then + sound_num_gauss=$[sound_num_gauss + sound_gauss_incr] + fi + + x=$[x+1] + done ## Done training Silence and Speech GMMs + + if ! 
$goto_phase3; then + $cmd $phase2_dir/log/$utt_id.init_speech.log \ + segmentation-copy --filter-rspecifier=ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + --filter-label=1 ark:$tmpdir/$utt_id.seg.$num_iters.ark ark:- \| \ + select-feats-from-segmentation --select-label=1 "$feats" \ + ark:- ark:- \| \ + gmm-global-init-from-feats \ + --num-iters=$[speech_num_gauss+1] --num-gauss-init=1 --num-gauss=$speech_num_gauss \ + ark:- $phase2_dir/$utt_id.speech.0.mdl + if [ $? -eq 0 ]; then + num_selected_speech=$(grep "Processed .* segmentations; selected" $phase2_dir/log/$utt_id.init_speech.log | perl -pe 's/.+selected (\S+) out of \S+ frames/$1/') + if [ $num_selected_speech -lt $min_data ]; then + echo "Insufficient frames for speech at the end of phase 1. $num_selected_speech < $min_data. See $phase2_dir/log/$utt_id.init_speech.log. Going to phase 3." + goto_phase3=true + fi + else + echo "Failed to find any data for speech at the end of phase 1. See $phase2_dir/log/$utt_id.init_speech.log. Going to phase 3." + goto_phase3=true + fi + + if $goto_phase3; then + rm -f $dir/$utt_id.current_seg.ark + ln -s $tmpdir/$utt_id.seg.$x.ark $dir/$utt_id.current_seg.ark + fi + fi + + if ! 
$goto_phase3; then + echo "Beginning phase 2 for utterance $utt_id" + + $cmd $phase2_dir/log/$utt_id.init_gmm.log \ + gmm-init-pdf-from-global $tmpdir/$utt_id.$num_iters.mdl 1 \ + $phase2_dir/$utt_id.speech.0.mdl $phase2_dir/$utt_id.0.mdl || exit 1 + + x=0 + while [ $x -lt $num_iters_phase2 ]; do + if [ $sil_num_gauss -lt $sil_max_gauss_phase2 ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr_phase2] + fi + + if [ $sound_num_gauss -lt $sound_max_gauss_phase2 ]; then + sound_num_gauss=$[sound_num_gauss + sound_gauss_incr_phase2] + fi + + if [ $speech_num_gauss -lt $speech_max_gauss_phase2 ]; then + speech_num_gauss=$[speech_num_gauss + speech_gauss_incr_phase2] + fi + + $cmd $phase2_dir/log/$utt_id.gmm_compute_likes.$x.log \ + gmm-compute-likes $phase2_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase2_dir/$utt_id.likes.$x.ark & + + $cmd $phase2_dir/log/$utt_id.get_seg.$x.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph/words.txt \ + $phase2_dir/$utt_id.$x.mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $phase2_dir/$utt_id.$x.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase2_dir/$utt_id.seg.$x.ark || exit 1 + + #$cmd $phase2_dir/log/$utt_id.gmm_update.$[x+1].log \ + # gmm-est-segmentation \ + # --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n1 $speech_num_gauss\n2 $sound_num_gauss\" |" \ + # $phase2_dir/$utt_id.$x.mdl "$feats" \ + # ark:$phase2_dir/$utt_id.seg.$x.ark \ + # $phase2_dir/$utt_id.$[x+1].mdl || exit 1 + $cmd $phase2_dir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation \ + --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n1 $speech_num_gauss\n2 $sound_num_gauss\" |" \ + $phase2_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase2_dir/$utt_id.seg.$x.ark \ + $phase2_dir/$utt_id.$[x+1].mdl || exit 1 + + x=$[x+1] + done ## Done training all 3 GMMs + cp $phase2_dir/$utt_id.$x.mdl $dir/$utt_id.final.mdl + rm -f $dir/$utt_id.graph_final + ln -s 
graph_test_${speech_to_sil_ratio}x $dir/$utt_id.graph_final + + $cmd $phase2_dir/log/$utt_id.gmm_compute_likes.$x.log \ + gmm-compute-likes $phase2_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase2_dir/$utt_id.likes.$x.ark & + + $cmd $phase2_dir/log/$utt_id.get_seg.$x.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph/words.txt \ + $phase2_dir/$utt_id.$x.mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $phase2_dir/$utt_id.$x.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase2_dir/$utt_id.seg.$x.ark || exit 1 + + mkdir -p $phase3_dir/log + + $cmd $phase2_dir/log/$utt_id.init_nonsil.log \ + segmentation-copy --merge-labels=1:2 --merge-dst-label=1 \ + ark:$phase2_dir/$utt_id.seg.$x.ark ark:- \| \ + select-feats-from-segmentation --select-label=1 \ + "$feats" ark:- ark:- \| \ + gmm-global-init-from-feats \ + --num-iters=$[sound_num_gauss + speech_num_gauss + 1] \ + --num-gauss-init=1 \ + --num-gauss=$[sound_num_gauss + speech_num_gauss] ark:- \ + $phase2_dir/$utt_id.$x.nonsil.mdl || exit 1 + + $cmd $phase2_dir/log/$utt_id.select_speech_feats.$x.log \ + select-feats-from-segmentation --select-label=1 \ + "$feats" ark:$phase2_dir/$utt_id.seg.$x.ark \ + ark:$phase2_dir/$utt_id.speech_feats.$x.ark + + if $goto_phase3; then + rm -f $dir/$utt_id.current_seg.ark + ln -s $phase2_dir/$utt_id.seg.$x.ark $dir/$utt_id.current_seg.ark + fi + + if [ $? -eq 0 ]; then + num_selected_speech=$(grep "Processed .* segmentations; selected" $phase2_dir/log/$utt_id.select_speech_feats.$x.log | perl -pe 's/.+selected (\S+) out of \S+ frames/$1/') + if [ $num_selected_speech -lt $min_data ]; then + echo "Insufficient frames for speech at the end of phase 2. $num_selected_speech < $min_data. See $phase2_dir/log/$utt_id.select_speech_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + else + echo "Failed to find any data for speech at the end of phase 1. See $phase2_dir/log/$utt_id.select_speech_feats.$x.log. 
Going to phase 3." + goto_phase3=true + fi + + if $try_merge_speech_noise; then + if ! $goto_phase3; then + speech_like=$(gmm-global-get-frame-likes \ + "gmm-extract-pdf $phase2_dir/$utt_id.$x.mdl 1 - |" \ + ark:$phase2_dir/$utt_id.speech_feats.$x.ark ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $phase2_dir/$utt_id.compute_speech_like.$x.log || exit 1 + + $cmd $phase2_dir/log/$utt_id.select_sound_feats.$x.log \ + select-feats-from-segmentation --select-label=2 \ + "$feats" ark:$phase2_dir/$utt_id.seg.$x.ark \ + ark:$phase2_dir/$utt_id.sound_feats.$x.ark + + if [ $? -eq 0 ]; then + num_selected_sound=$(grep "Processed .* segmentations; selected" $phase2_dir/log/$utt_id.select_sound_feats.$x.log | perl -pe 's/.+selected (\S+) out of \S+ frames/$1/') + if [ $num_selected_sound -lt $min_data ]; then + echo "Insufficient frames for sound at the end of phase 2. $num_selected_sound < $min_data. See $phase2_dir/log/$utt_id.select_sound_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + else + echo "Failed to find any data for sound at the end of phase 1. See $phase2_dir/log/$utt_id.select_sound_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + fi + + if ! 
$goto_phase3; then + sound_like=$(gmm-global-get-frame-likes \ + "gmm-extract-pdf $phase2_dir/$utt_id.$x.mdl 2 - |" \ + ark:$phase2_dir/$utt_id.sound_feats.$x.ark ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $phase2_dir/$utt_id.compute_sound_like.$x.log || exit 1 + + nonsil_like=$(select-feats-from-segmentation --merge-labels=1:2 --select-label=1 \ + "$feats" ark:$phase2_dir/$utt_id.seg.$x.ark ark:- | \ + gmm-global-get-frame-likes \ + $phase2_dir/$utt_id.$x.nonsil.mdl ark:- ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $phase2_dir/$utt_id.compute_nonsil_like.$x.log || exit 1 + + if [ ! -z `perl -e "print \"true\" if ($sound_like + $speech_like < $nonsil_like)"` ]; then + goto_phase3=true + fi + fi + fi + fi + + if $goto_phase3; then + echo "Beginning phase 3 for utterance $utt_id" + speech_num_gauss=$speech_num_gauss_init_phase3 + sil_num_gauss=$sil_num_gauss_init_phase3 + + $cmd $phase3_dir/log/$utt_id.gmm_compute_likes.init.log \ + gmm-compute-likes $dir/init_2class.mdl "$feats" \ + ark:$phase3_dir/$utt_id.likes.0.ark & + + $cmd $phase3_dir/log/$utt_id.get_vad.bootstrap.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph_2class/words.txt \ + $dir/init_2class.mdl $dir/graph_2class/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| ali-to-pdf $dir/init_2class.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase3_dir/$utt_id.vad.0.ark || exit 1 + + x=0 + skip_phase4=false + + while [ $x -lt $num_iters_phase3 ]; do + num_frames_silence=$[num_frames_init_silence + sil_num_gauss * frames_per_gaussian ] + + if [ $x -lt 3 ]; then + $cmd $phase3_dir/log/$utt_id.select_top.second.$[x+1].log \ + segmentation-copy --filter-label=0 \ + --filter-rspecifier=ark:$phase3_dir/$utt_id.vad.0.ark \ + ark:$phase3_dir/$utt_id.vad.$x.ark ark:- \| \ 
+ segmentation-select-top --num-bins=$num_bins \ + --merge-dst-label=0 --select-from-full-histogram=true \ + --num-top-frames=-1 --num-bottom-frames=$num_frames_silence \ + --top-select-label=-1 --bottom-select-label=0 --reject-label=1000 \ + --remove-rejected-frames=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:- ark:$dir/$utt_id.log_energies.ark \ + ark:$phase3_dir/$utt_id.vad.second.$[x+1].ark || exit 1 + + else + $cmd $phase3_dir/log/$utt_id.select_top.$[x+1].log \ + segmentation-copy --filter-rspecifier=ark:$phase3_dir/$utt_id.vad.0.ark \ + --filter-label=0 ark:$phase3_dir/$utt_id.vad.$x.ark \ + ark:$phase3_dir/$utt_id.vad.second.$[x+1].ark || exit 1 + fi + + if [ $x -eq 0 ]; then + { + cat $dir/trans.mdl; + echo " $feat_dim 2"; + select-feats-from-segmentation --select-label=0 "$feats" \ + ark:$phase3_dir/$utt_id.vad.second.$[x+1].ark ark:- | \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[sil_num_gauss+1] --num-gauss-init=1 --num-gauss=$sil_num_gauss \ + ark:- - || exit 1 + gmm-global-copy --binary=false $init_speech_model - || exit 1 + } 2> $phase3_dir/log/$utt_id.init_gmm.log | \ + gmm-copy - $phase3_dir/$utt_id.$[x+1].mdl 2>> $phase3_dir/log/$utt_id.init_gmm.log || exit 1 + else + $cmd $phase3_dir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation --pdfs=0 \ + --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\" |" \ + $phase3_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase3_dir/$utt_id.vad.second.$[x+1].ark \ + $phase3_dir/$utt_id.$[x+1].mdl || exit 1 + fi + + $cmd $phase3_dir/log/$utt_id.gmm_compute_likes.$x.log \ + gmm-compute-likes $phase3_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase3_dir/$utt_id.likes.$x.ark & + + $cmd $phase3_dir/log/$utt_id.get_seg.$[x+1].log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph_2class/words.txt \ + $phase3_dir/$utt_id.$[x+1].mdl $dir/graph_2class/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf 
$phase3_dir/$utt_id.$[x+1].mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase3_dir/$utt_id.vad.$[x+1].ark || exit 1 + + if [ $sil_num_gauss -lt $sil_max_gauss ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr] + fi + + x=$[x+1] + done ## Done training Silence and Speech GMMs + + $cmd $phase3_dir/log/$utt_id.init_speech.log \ + segmentation-copy \ + ark:$phase3_dir/$utt_id.vad.$x.ark ark:- \| \ + select-feats-from-segmentation --select-label=1 "$feats" \ + ark:- ark:- \| \ + gmm-global-init-from-feats \ + --num-iters=$[speech_num_gauss+1] --num-gauss-init=1 --num-gauss=$speech_num_gauss \ + ark:- $phase3_dir/$utt_id.speech.$x.mdl + + if [ $? -eq 0 ]; then + num_selected_speech=$(grep "Processed .* segmentations; selected" $phase3_dir/log/$utt_id.init_speech.log | perl -pe 's/.+selected (\S+) out of \S+ frames/$1/') + if [ $num_selected_speech -lt $min_data ]; then + echo "Insufficient frames for speech at the end of phase 3. $num_selected_speech < $min_data. Not re-training speech model." + skip_phase4=true + fi + else + echo "Failed to find any data for speech at the end of phase 3. See $phase3_dir/log/$utt_id.init_speech.log. Not re-training speech model." + skip_phase4=true + fi + + if ! 
$skip_phase4; then + $cmd $phase3_dir/log/$utt_id.init_gmm.$[x+1].log \ + gmm-init-pdf-from-global $phase3_dir/$utt_id.$x.mdl 1 \ + $phase3_dir/$utt_id.speech.$x.mdl $phase3_dir/$utt_id.$[x+1].mdl || exit 1 + x=$[x+1] + + while [ $x -lt $[num_iters_phase4 + num_iters_phase3+1] ]; do + if [ $sil_num_gauss -lt $sil_max_gauss_phase4 ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr_phase4] + fi + + if [ $speech_num_gauss -lt $speech_max_gauss_phase4 ]; then + speech_num_gauss=$[speech_num_gauss + speech_gauss_incr_phase4] + fi + + $cmd $phase3_dir/log/$utt_id.gmm_compute_likes.$x.log \ + gmm-compute-likes $phase3_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase3_dir/$utt_id.likes.$x.ark & + + $cmd $phase3_dir/log/$utt_id.get_seg.$x.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph_2class/words.txt \ + $phase3_dir/$utt_id.$x.mdl $dir/graph_2class/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $phase3_dir/$utt_id.$x.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase3_dir/$utt_id.vad.$x.ark || exit 1 + + $cmd $phase3_dir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation \ + --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n1 $speech_num_gauss\" |" \ + $phase3_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase3_dir/$utt_id.vad.$x.ark \ + $phase3_dir/$utt_id.$[x+1].mdl || exit 1 + + x=$[x+1] + done ## Done training all 3 GMMs + fi + cp $phase3_dir/$utt_id.$x.mdl $dir/$utt_id.final.mdl + + rm -f $dir/$utt_id.graph_final + ln -s graph_2class_test_${speech_to_sil_ratio}x $dir/$utt_id.graph_final + fi + + if $output_lattice; then + $cmd $dir/log/$utt_id.get_seg.final.log \ + gmm-latgen-faster --allow-partial=$allow_partial \ + --word-symbol-table=$dir/$utt_id.graph_final/words.txt \ + $dir/$utt_id.final.mdl $dir/$utt_id.graph_final/HCLG.fst \ + "$feats" ark,scp:$dir/$utt_id.lat.ark,$dir/$utt_id.lat.scp \ + ark:/dev/null ark:- \| \ + ali-to-pdf $dir/$utt_id.final.mdl ark:- ark:- \| \ + 
segmentation-init-from-ali ark:- \ + ark,scp:$dir/$utt_id.vad.final.ark,$dir/$utt_id.vad.final.scp || exit 1 + else + $cmd $dir/log/$utt_id.gmm_compute_likes.final.log \ + gmm-compute-likes $dir/$utt_id.final.mdl "$feats" \ + ark:$dir/$utt_id.likes.final.ark & + + $cmd $dir/log/$utt_id.get_seg.final.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/$utt_id.graph_final/words.txt \ + $dir/$utt_id.final.mdl $dir/$utt_id.graph_final/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $dir/$utt_id.final.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark,scp:$dir/$utt_id.vad.final.ark,$dir/$utt_id.vad.final.scp || exit 1 + + fi + +done < $data/feats.scp diff --git a/egs/sre08/v1/diarization/vad_gmm_icsi.sh b/egs/sre08/v1/diarization/vad_gmm_icsi.sh new file mode 100755 index 00000000000..de562fcc84c --- /dev/null +++ b/egs/sre08/v1/diarization/vad_gmm_icsi.sh @@ -0,0 +1,629 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -u +set -o pipefail + +cmd=run.pl +stage=-100 +allow_partial=true +try_merge_speech_noise=false +output_lattice=false +write_feats=false + +## Features paramters +window_size=100 # 1s +min_data=200 +frames_per_gaussian=2000 +num_bins=100 +num_sil_states=30 +num_nonsil_states=75 + +## Phase 1 parameters +num_frames_init_silence=2000 # 20s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_sound=10000 # 100s - Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=2000 # 20s - Highest zero crossing frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=8 +sil_gauss_incr=0 +sound_gauss_incr=2 +sample_per_gaussian=2000 +num_iters_init=3 +num_iters=5 +min_sil_variance=1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +num_frames_init_speech=100000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 
+speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 +window_size_phase2=10 + +## Phase 3 parameters +num_frames_silence_phase3_init=2000 +num_frames_speech_phase3_init=2000 +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + +speech_to_sil_ratio=1 + +. path.sh +. parse_options.sh || exit 1 + +if [ $# -ne 4 ]; then + echo "Usage: vad_gmm_icsi.sh " + echo " e.g.: vad_gmm_icsi.sh data/rt05_eval exp/librispeech_s5/vad_model/silence.0.mdl exp/librispeech_s5/vad_model/speech.0.mdl exp/vad_rt05_eval" + exit 1 +fi + +data=$1 +init_silence_model=$2 +init_speech_model=$3 +dir=$4 + +mkdir -p $dir +tmpdir=$dir/phase1 +phase2_dir=$dir/phase2 +phase3_dir=$dir/phase3 + +mkdir -p $tmpdir +mkdir -p $phase2_dir +mkdir -p $phase3_dir + +init_model_dir=`dirname $init_speech_model` +ignore_energy_opts=`cat $init_model_dir/ignore_energy_opts` || exit 1 +add_zero_crossing_feats=`cat $init_model_dir/add_zero_crossing_feats` || exit 1 + +zc_opts= +[ -f conf/zc_vad.conf ] && zc_opts="--config=conf/zc_vad.conf" + +# Prepare a lang directory +if [ $stage -le -12 ]; then + mkdir -p $dir/local/dict + mkdir -p $dir/local/lm + mkdir -p $dir/local/dict_2class + mkdir -p $dir/local/lm_2class + + echo "1" > $dir/local/dict/silence_phones.txt + echo "1" > $dir/local/dict/optional_silence.txt + echo "2" > $dir/local/dict/nonsilence_phones.txt + echo "1" > $dir/local/dict_2class/silence_phones.txt + echo "1" > $dir/local/dict_2class/optional_silence.txt + echo "2" > $dir/local/dict_2class/nonsilence_phones.txt + echo "3" >> $dir/local/dict/nonsilence_phones.txt + echo -e "1 1\n2 2" > $dir/local/dict_2class/lexicon.txt + echo -e "1 1\n2 2\n3 3" > $dir/local/dict/lexicon.txt + echo -e "1\n2\n1 2" > $dir/local/dict_2class/extra_questions.txt + echo -e "1\n2\n1 2\n3\n1 3\n2 3\n1 2 3" > 
$dir/local/dict/extra_questions.txt + + mkdir -p $dir/lang + diarization/prepare_vad_lang.sh --num-sil-states $num_sil_states --num-nonsil-states $num_nonsil_states \ + $dir/local/dict $dir/local/lang $dir/lang || exit 1 + diarization/prepare_vad_lang.sh --num-sil-states $num_sil_states --num-nonsil-states $num_nonsil_states \ + $dir/local/dict_2class $dir/local/lang_2class $dir/lang_2class || exit 1 +fi + +feat_dim=`gmm-global-info $init_speech_model | grep "feature dimension" | awk '{print $NF}'` || exit 1 + +if [ $stage -le -11 ]; then + run.pl $dir/log/create_transition_model.log gmm-init-mono \ + $dir/lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 + + run.pl $dir/log/create_transition_model_2class.log gmm-init-mono \ + $dir/lang_2class/topo $feat_dim - $dir/tree_2class \| \ + copy-transition-model --binary=false - $dir/trans_2class.mdl || exit 1 + + diarization/make_vad_graph.sh --iter trans $dir/lang $dir $dir/graph || exit 1 + diarization/make_vad_graph.sh --iter trans_2class --tree tree_2class $dir/lang_2class $dir $dir/graph_2class || exit 1 +fi + +if [ $stage -le -10 ]; then + { + cat $dir/trans_2class.mdl + echo " $feat_dim 2" + gmm-global-copy --binary=false $init_silence_model - || exit 1 + gmm-global-copy --binary=false $init_speech_model - || exit 1 + } | gmm-copy - $dir/init_2class.mdl || exit 1 +fi + +if [ $stage -le -9 ]; then + t=$speech_to_sil_ratio + lang=$dir/lang_test_${t}x + cp -r $dir/lang $lang + perl -e '$t = shift @ARGV; print "0 0 1 1 " . -log(1/($t+3)) . "\n0 0 2 2 ". -log($t/($t+3)). "\n0 0 3 3 ". -log(1/($t+3)) ."\n0 ". 
-log(1/($t+3))' $t | \ + fstcompile --isymbols=$lang/words.txt --osymbols=$lang/words.txt \ + --keep_isymbols=false --keep_osymbols=false \ + > $lang/G.fst || exit 1 + diarization/make_vad_graph.sh --iter trans $lang $dir $dir/graph_test_${t}x || exit 1 + + lang=$dir/lang_2class_test_${t}x + cp -r $dir/lang_2class $lang + perl -e '$t = shift @ARGV; print "0 0 1 1 " . -log(1/($t+2)) . "\n0 0 2 2 ". -log($t/($t+2)). "\n0 ". -log(1/($t+2))' $t | \ + fstcompile --isymbols=$lang/words.txt --osymbols=$lang/words.txt \ + --keep_isymbols=false --keep_osymbols=false \ + > $lang/G.fst || exit 1 + + diarization/make_vad_graph.sh --iter trans_2class --tree tree_2class $lang $dir $dir/graph_2class_test_${t}x || exit 1 +fi + + +while IFS=$'\n' read line; do + feats="ark:echo $line | apply-cmvn-sliding scp:- ark:- |${ignore_energy_opts}" + + utt_id=$(echo $line | awk '{print $1}') + echo $utt_id > $dir/$utt_id.list + + if [ -f $data/segments ]; then + $cmd $dir/log/$utt_id.extract_zero_crossings.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + extract-segments scp:$data/wav.scp - ark:- \| \ + compute-zero-crossings $zc_opts ark:- ark:$dir/$utt_id.zero_crossings.ark || exit 1 + #$cmd $dir/log/$utt_id.extract_pitch.log \ + # utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + # extract-segments scp:$data/wav.scp - ark:- \| \ + # compute-kaldi-pitch-feats --config=conf/pitch.conf --frames-per-chunk=10 --simulate-first-pass-online=true \ + # ark:- ark:$dir/$utt_id.kaldi_pitch.ark || exit 1 + $cmd $dir/log/$utt_id.extract_log_energies.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + extract-segments scp:$data/wav.scp - ark:- \| \ + compute-mfcc-feats --config=conf/mfcc_vad.conf --num-ceps=1 \ + ark:- ark:- \| extract-column ark:- \ + ark:$dir/$utt_id.log_energies.ark || exit 1 + else + $cmd $dir/log/$utt_id.extract_zero_crossings.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp \| \ + compute-zero-crossings $zc_opts scp:- 
ark:$dir/$utt_id.zero_crossings.ark || exit 1 + #$cmd $dir/log/$utt_id.extract_pitch.log \ + # utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp \| \ + # compute-kaldi-pitch-feats --config=conf/pitch.conf --frames-per-chunk=10 --simulate-first-pass-online=true \ + # scp:- ark:$dir/$utt_id.kaldi_pitch.ark || exit 1 + $cmd $dir/log/$utt_id.extract_log_energies.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp \| \ + compute-mfcc-feats --config=conf/mfcc_vad.conf --num-ceps=1 \ + scp:- ark:- \| extract-column ark:- \ + ark:$dir/$utt_id.log_energies.ark || exit 1 + fi + + sil_num_gauss=$sil_num_gauss_init + sound_num_gauss=$sound_num_gauss_init + speech_num_gauss=$speech_num_gauss_init + + if $add_zero_crossing_feats; then + feats="${feats} paste-feats ark:- ark:$dir/$utt_id.zero_crossings.ark ark:- |" + fi + + feats="${feats} add-deltas ark:- ark:- |" + + if $write_feats; then + copy-feats "$feats" ark:$dir/$utt_id.feat.ark + fi + + # Get VAD: 0 for silence, 1 for speech and 2 for sound + $cmd $dir/log/$utt_id.get_vad.bootstrap.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph_2class/words.txt \ + $dir/init_2class.mdl $dir/graph_2class/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| ali-to-pdf $dir/init_2class.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$tmpdir/$utt_id.vad.bootstrap.ark || exit 1 + + cp $tmpdir/$utt_id.vad.bootstrap.ark $tmpdir/$utt_id.seg.0.ark + + x=0 + goto_phase3=false + + while [ $x -lt $num_iters ]; do + num_frames_silence=$[(x+1) * frames_per_gaussian ] + num_frames_sound=$[5 * frames_per_gaussian ] + num_frames_sound_next=$[(x+1) * frames_per_gaussian ] + + if [ $x -lt 3 ]; then + $cmd $tmpdir/log/$utt_id.select_top.first.$[x+1].log \ + segmentation-copy --filter-label=0 \ + --filter-rspecifier=ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + ark:$tmpdir/$utt_id.seg.$x.ark ark:- \| \ + segmentation-select-top --num-bins=$num_bins \ + --merge-labels=0:2 --merge-dst-label=0 \ + 
--num-top-frames=$num_frames_sound --num-bottom-frames=$num_frames_silence \ + --top-select-label=2 --bottom-select-label=0 --reject-label=1000 \ + --remove-rejected-frames=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:- ark:$dir/$utt_id.log_energies.ark \ + ark:$tmpdir/$utt_id.seg.first.$[x+1].ark || exit 1 + + $cmd $tmpdir/log/$utt_id.select_top.$[x+1].log \ + segmentation-select-top --num-bins=$num_bins --src-label=2 \ + --num-top-frames=$num_frames_sound_next --num-bottom-frames=-1 \ + --top-select-label=2 --bottom-select-label=-1 --reject-label=1001 \ + --remove-rejected-frames=true --select-from-full-histogram=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:$tmpdir/$utt_id.seg.first.$[x+1].ark "ark:extract-column ark:$dir/$utt_id.zero_crossings.ark ark:- |" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark || exit 1 + else + $cmd $tmpdir/log/$utt_id.select_top.$[x+1].log \ + segmentation-copy --filter-rspecifier=ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + --filter-label=0 ark:$tmpdir/$utt_id.seg.$x.ark \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark || exit 1 + fi + + if [ $x -eq 0 ]; then + { + cat $dir/trans.mdl; + echo " $feat_dim 3"; + select-feats-from-segmentation --select-label=0 "$feats" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark ark:- | \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[sil_num_gauss+1] --num-gauss-init=1 --num-gauss=$sil_num_gauss \ + ark:- - || exit 1 + gmm-global-copy --binary=false $init_speech_model -; + select-feats-from-segmentation --select-label=2 "$feats" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark ark:- | \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[sound_num_gauss+1] --num-gauss-init=1 --num-gauss=$sound_num_gauss \ + ark:- - || exit 1 + } 2> $tmpdir/log/$utt_id.init_gmm.log | \ + gmm-copy - $tmpdir/$utt_id.$[x+1].mdl 2>> $tmpdir/log/$utt_id.init_gmm.log || exit 1 + else + #$cmd $tmpdir/log/$utt_id.gmm_update.$[x+1].log 
\ + # gmm-est-segmentation --pdfs=0:2 \ + # --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n2 $sound_num_gauss\" |" \ + # $tmpdir/$utt_id.$x.mdl "$feats" \ + # ark:$tmpdir/$utt_id.seg.second.$[x+1].ark \ + # $tmpdir/$utt_id.$[x+1].mdl || exit 1 + $cmd $tmpdir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation --pdfs=0:2 \ + --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n2 $sound_num_gauss\" |" \ + $tmpdir/$utt_id.$x.mdl "$feats" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark \ + $tmpdir/$utt_id.$[x+1].mdl || exit 1 + fi + + $cmd $tmpdir/log/$utt_id.get_seg.$[x+1].log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph/words.txt \ + $tmpdir/$utt_id.$[x+1].mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $tmpdir/$utt_id.$[x+1].mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$tmpdir/$utt_id.seg.$[x+1].ark || exit 1 + + if [ $sil_num_gauss -lt $sil_max_gauss ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr] + fi + + if [ $sound_num_gauss -lt $sound_max_gauss ]; then + sound_num_gauss=$[sound_num_gauss + sound_gauss_incr] + fi + + x=$[x+1] + done ## Done training Silence and Speech GMMs + + $cmd $phase2_dir/log/$utt_id.init_speech.log \ + segmentation-copy --filter-rspecifier=ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + --filter-label=1 ark:$tmpdir/$utt_id.seg.$num_iters.ark ark:- \| \ + select-feats-from-segmentation --select-label=1 "$feats" \ + ark:- ark:- \| \ + gmm-global-init-from-feats \ + --num-iters=$[speech_num_gauss+1] --num-gauss-init=1 --num-gauss=$speech_num_gauss \ + ark:- $phase2_dir/$utt_id.speech.0.mdl + if [ $? -eq 0 ]; then + num_selected_speech=$(grep "Processed .* segmentations; selected" $phase2_dir/log/$utt_id.init_speech.log | perl -pe 's/.+selected (\S+) out of \S+ frames/$1/') + if [ $num_selected_speech -lt $min_data ]; then + echo "Insufficient frames for speech at the end of phase 1. $num_selected_speech < $min_data. 
See $phase2_dir/log/$utt_id.init_speech.log. Going to phase 3." + goto_phase3=true + fi + else + echo "Failed to find any data for speech at the end of phase 1. See $phase2_dir/log/$utt_id.init_speech.log. Going to phase 3." + goto_phase3=true + fi + + if $goto_phase3; then + rm -f $dir/$utt_id.current_seg.ark + ln -s $tmpdir/$utt_id.seg.$x.ark $dir/$utt_id.current_seg.ark + fi + + if ! $goto_phase3; then + $cmd $phase2_dir/log/$utt_id.init_gmm.log \ + gmm-init-pdf-from-global $tmpdir/$utt_id.$num_iters.mdl 1 \ + $phase2_dir/$utt_id.speech.0.mdl $phase2_dir/$utt_id.0.mdl || exit 1 + + x=0 + while [ $x -lt $num_iters_phase2 ]; do + if [ $sil_num_gauss -lt $sil_max_gauss_phase2 ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr_phase2] + fi + + if [ $sound_num_gauss -lt $sound_max_gauss_phase2 ]; then + sound_num_gauss=$[sound_num_gauss + sound_gauss_incr_phase2] + fi + + if [ $speech_num_gauss -lt $speech_max_gauss_phase2 ]; then + speech_num_gauss=$[speech_num_gauss + speech_gauss_incr_phase2] + fi + + $cmd $phase2_dir/log/$utt_id.get_seg.$x.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph/words.txt \ + $phase2_dir/$utt_id.$x.mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $phase2_dir/$utt_id.$x.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase2_dir/$utt_id.seg.$x.ark || exit 1 + + #$cmd $phase2_dir/log/$utt_id.gmm_update.$[x+1].log \ + # gmm-est-segmentation \ + # --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n1 $speech_num_gauss\n2 $sound_num_gauss\" |" \ + # $phase2_dir/$utt_id.$x.mdl "$feats" \ + # ark:$phase2_dir/$utt_id.seg.$x.ark \ + # $phase2_dir/$utt_id.$[x+1].mdl || exit 1 + $cmd $phase2_dir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation \ + --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n1 $speech_num_gauss\n2 $sound_num_gauss\" |" \ + $phase2_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase2_dir/$utt_id.seg.$x.ark \ + 
$phase2_dir/$utt_id.$[x+1].mdl || exit 1 + + x=$[x+1] + done ## Done training all 3 GMMs + cp $phase2_dir/$utt_id.$x.mdl $dir/$utt_id.final.mdl + rm -f $dir/$utt_id.graph_final + ln -s graph_test_${speech_to_sil_ratio}x $dir/$utt_id.graph_final + + $cmd $phase2_dir/log/$utt_id.get_seg.$x.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph/words.txt \ + $phase2_dir/$utt_id.$x.mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $phase2_dir/$utt_id.$x.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase2_dir/$utt_id.seg.$x.ark || exit 1 + + mkdir -p $phase3_dir/log + + $cmd $phase2_dir/log/$utt_id.init_nonsil.log \ + segmentation-copy --merge-labels=1:2 --merge-dst-label=1 \ + ark:$phase2_dir/$utt_id.seg.$x.ark ark:- \| \ + select-feats-from-segmentation --select-label=1 \ + "$feats" ark:- ark:- \| \ + gmm-global-init-from-feats \ + --num-iters=$[sound_num_gauss + speech_num_gauss + 1] \ + --num-gauss-init=1 \ + --num-gauss=$[sound_num_gauss + speech_num_gauss] ark:- \ + $phase2_dir/$utt_id.$x.nonsil.mdl || exit 1 + + $cmd $phase2_dir/log/$utt_id.select_speech_feats.$x.log \ + select-feats-from-segmentation --select-label=1 \ + "$feats" ark:$phase2_dir/$utt_id.seg.$x.ark \ + ark:$phase2_dir/$utt_id.speech_feats.$x.ark + + if $goto_phase3; then + rm -f $dir/$utt_id.current_seg.ark + ln -s $phase2_dir/$utt_id.seg.$x.ark $dir/$utt_id.current_seg.ark + fi + + if [ $? -eq 0 ]; then + num_selected_speech=$(grep "Processed .* segmentations; selected" $phase2_dir/log/$utt_id.select_speech_feats.$x.log | perl -pe 's/.+selected (\S+) out of \S+ frames/$1/') + if [ $num_selected_speech -lt $min_data ]; then + echo "Insufficient frames for speech at the end of phase 2. $num_selected_speech < $min_data. See $phase2_dir/log/$utt_id.select_speech_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + else + echo "Failed to find any data for speech at the end of phase 1. 
See $phase2_dir/log/$utt_id.select_speech_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + + if $try_merge_speech_noise; then + if ! $goto_phase3; then + speech_like=$(gmm-global-get-frame-likes \ + "gmm-extract-pdf $phase2_dir/$utt_id.$x.mdl 1 - |" \ + ark:$phase2_dir/$utt_id.speech_feats.$x.ark ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $phase2_dir/$utt_id.compute_speech_like.$x.log || exit 1 + + $cmd $phase2_dir/log/$utt_id.select_sound_feats.$x.log \ + select-feats-from-segmentation --select-label=2 \ + "$feats" ark:$phase2_dir/$utt_id.seg.$x.ark \ + ark:$phase2_dir/$utt_id.sound_feats.$x.ark + + if [ $? -eq 0 ]; then + num_selected_sound=$(grep "Processed .* segmentations; selected" $phase2_dir/log/$utt_id.select_sound_feats.$x.log | perl -pe 's/.+selected (\S+) out of \S+ frames/$1/') + if [ $num_selected_sound -lt $min_data ]; then + echo "Insufficient frames for sound at the end of phase 2. $num_selected_sound < $min_data. See $phase2_dir/log/$utt_id.select_sound_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + else + echo "Failed to find any data for sound at the end of phase 1. See $phase2_dir/log/$utt_id.select_sound_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + fi + fi + + if ! 
$goto_phase3; then + sound_like=$(gmm-global-get-frame-likes \ + "gmm-extract-pdf $phase2_dir/$utt_id.$x.mdl 1 - |" \ + ark:$phase2_dir/$utt_id.sound_feats.$x.ark ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $phase2_dir/$utt_id.compute_sound_like.$x.log || exit 1 + + nonsil_like=$(select-feats-from-segmentation --merge-labels=1:2 --select-label=1 \ + "$feats" ark:$phase2_dir/$utt_id.seg.$x.ark ark:- | \ + gmm-global-get-frame-likes \ + $phase2_dir/$utt_id.$x.nonsil.mdl ark:- ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $phase2_dir/$utt_id.compute_nonsil_like.$x.log || exit 1 + + if [ ! -z `perl -e "print \"true\" if ($sound_like + $speech_like < $nonsil_like)"` ]; then + goto_phase3=true + fi + fi + fi + + if $goto_phase3; then + speech_num_gauss=$speech_num_gauss_init_phase3 + sil_num_gauss=$sil_num_gauss_init_phase3 + + $cmd $phase3_dir/log/$utt_id.compute_silence_likes.bootstrap.log \ + gmm-global-get-frame-likes $init_silence_model "$feats" \ + ark:$dir/$utt_id.silence_log_likes.bootstrap.ark || exit 1 + + $cmd $phase3_dir/log/$utt_id.compute_speech_likes.bootstrap.log \ + gmm-global-get-frame-likes $init_speech_model "$feats" \ + ark:$dir/$utt_id.speech_log_likes.bootstrap.ark || exit 1 + + { + cat $dir/trans_2class.mdl; + echo " $feat_dim 2"; + segmentation-select-top --num-bins=$num_bins \ + --src-label=0 --num-top-frames=$[200 * frames_per_gaussian] \ + --top-select-label=0 --bottom-select-label=-1 \ + --reject-label=1000 --select-above-mean=true \ + --remove-rejected-frames=true --select-from-full-histogram=true \ + --window-size=1 --min-window-remainder=1 \ + ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + ark:$dir/$utt_id.silence_log_likes.bootstrap.ark ark:- | \ + select-feats-from-segmentation --select-label=0 "$feats" ark:- ark:- | \ + gmm-global-init-from-feats 
--binary=false \ + --num-iters=$[sil_num_gauss+1] --num-gauss-init=1 --num-gauss=$sil_num_gauss \ + ark:- - || exit 1 + select-feats-from-segmentation --select-label=1 "$feats" \ + ark:$tmpdir/$utt_id.vad.bootstrap.ark ark:- | \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[speech_num_gauss+1] --num-gauss-init=1 --num-gauss=$speech_num_gauss \ + ark:- - || exit 1 + } 2> $phase3_dir/log/$utt_id.check_gmm.log | \ + gmm-copy - $phase3_dir/$utt_id.check.mdl 2>> $phase3_dir/log/$utt_id.check_gmm.log + + $cmd $phase3_dir/log/$utt_id.get_seg.check.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph_2class/words.txt \ + $phase3_dir/$utt_id.check.mdl $dir/graph_2class/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $phase3_dir/$utt_id.check.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase3_dir/$utt_id.seg.check.ark || exit 1 + + num_frames_speech=$(select-feats-from-segmentation --select-label=1 "$feats" \ + ark:$phase3_dir/$utt_id.seg.check.ark ark:- | \ + feat-to-len ark:- ark,t:- | awk '{i+=$2} END{print i}') + + phase3_done=false + if [ $num_frames_speech -lt $min_data ]; then + phase3_done=true + fi + + if ! 
$phase3_done; then + x=0 + + $cmd $phase3_dir/log/$utt_id.init_silence_gmm.log \ + segmentation-select-top --num-bins=$num_bins \ + --src-label=0 --num-top-frames=$num_frames_silence_phase3_init \ + --top-select-label=0 --bottom-select-label=-1 \ + --reject-label=1000 --select-above-mean=true \ + --remove-rejected-frames=true --select-from-full-histogram=true \ + --window-size=1 --min-window-remainder=1 \ + ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + ark:$dir/$utt_id.silence_log_likes.bootstrap.ark ark:- \| \ + select-feats-from-segmentation --select-label=0 "$feats" ark:- ark:- \| \ + gmm-global-init-from-feats \ + --num-iters=$[sil_num_gauss+1] --num-gauss-init=1 --num-gauss=$sil_num_gauss \ + ark:- - \| \ + gmm-init-pdf-from-global $dir/init_2class.mdl 0 - \ + $phase3_dir/$utt_id.tmp.mdl || exit 1 + + $cmd $phase3_dir/log/$utt_id.init_speech_gmm.log \ + segmentation-select-top --num-bins=$num_bins \ + --src-label=1 --num-top-frames=$num_frames_speech_phase3_init \ + --top-select-label=1 --bottom-select-label=-1 \ + --reject-label=1000 --select-above-mean=true \ + --remove-rejected-frames=true --select-from-full-histogram=true \ + --window-size=1 --min-window-remainder=1 \ + ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + ark:$dir/$utt_id.speech_log_likes.bootstrap.ark ark:- \| \ + select-feats-from-segmentation --select-label=1 "$feats" ark:- ark:- \| \ + gmm-global-init-from-feats \ + --num-iters=$[speech_num_gauss+1] --num-gauss-init=1 --num-gauss=$speech_num_gauss \ + ark:- - \| \ + gmm-init-pdf-from-global $phase3_dir/$utt_id.tmp.mdl 1 - \ + $phase3_dir/$utt_id.0.mdl || exit 1 + + x=0 + while [ $x -lt $num_iters_phase3 ]; do + $cmd $phase3_dir/log/$utt_id.get_seg.$x.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph_2class/words.txt \ + $phase3_dir/$utt_id.$x.mdl $dir/graph_2class/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $phase3_dir/$utt_id.$x.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + 
ark:$phase3_dir/$utt_id.seg.$x.ark || exit 1 + + #$cmd $phase3_dir/log/$utt_id.gmm_update.$[x+1].log \ + # gmm-est-segmentation \ + # --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n1 $speech_num_gauss\" |" \ + # $phase3_dir/$utt_id.$x.mdl "$feats" \ + # ark:$phase3_dir/$utt_id.seg.$x.ark \ + # $phase3_dir/$utt_id.$[x+1].mdl || exit 1 + $cmd $phase3_dir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation \ + --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n1 $speech_num_gauss\" |" \ + $phase3_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase3_dir/$utt_id.seg.$x.ark \ + $phase3_dir/$utt_id.$[x+1].mdl || exit 1 + + if [ $sil_num_gauss -lt $sil_max_gauss_phase3 ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr_phase3] + fi + + if [ $speech_num_gauss -lt $speech_max_gauss_phase3 ]; then + speech_num_gauss=$[speech_num_gauss + speech_gauss_incr_phase3] + fi + + x=$[x+1] + done ## Done training all 2 GMMs + + cp $phase3_dir/$utt_id.$x.mdl $dir/$utt_id.final.mdl + rm -f $dir/$utt_id.graph_final + ln -s graph_2class_test_${speech_to_sil_ratio}x $dir/$utt_id.graph_final + else + echo "Not going to phase3" + fi + fi + + if $output_lattice; then + $cmd $dir/log/$utt_id.get_seg.final.log \ + gmm-latgen-faster --allow-partial=$allow_partial \ + --word-symbol-table=$dir/$utt_id.graph_final/words.txt \ + $dir/$utt_id.final.mdl $dir/$utt_id.graph_final/HCLG.fst \ + "$feats" ark,scp:$dir/$utt_id.lat.ark,$dir/$utt_id.lat.scp \ + ark:/dev/null ark:- \| \ + ali-to-pdf $dir/$utt_id.final.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark,scp:$dir/$utt_id.vad.final.ark,$dir/$utt_id.vad.final.scp || exit 1 + else + $cmd $dir/log/$utt_id.get_seg.final.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/$utt_id.graph_final/words.txt \ + $dir/$utt_id.final.mdl $dir/$utt_id.graph_final/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $dir/$utt_id.final.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + 
ark,scp:$dir/$utt_id.vad.final.ark,$dir/$utt_id.vad.final.scp || exit 1 + fi + +done < $data/feats.scp diff --git a/egs/sre08/v1/diarization/vad_gmm_icsi_clean_phase3.sh b/egs/sre08/v1/diarization/vad_gmm_icsi_clean_phase3.sh new file mode 100755 index 00000000000..1cc3dc5af6c --- /dev/null +++ b/egs/sre08/v1/diarization/vad_gmm_icsi_clean_phase3.sh @@ -0,0 +1,606 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -u +set -o pipefail + +cmd=run.pl +stage=-1 +allow_partial=true + +## Features paramters +window_size=100 # 1s +min_data=200 +frames_per_gaussian=2000 +num_bins=100 +num_sil_states=30 +num_nonsil_states=75 + +## Phase 1 parameters +num_frames_init_silence=2000 # 20s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_sound=10000 # 100s - Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=2000 # 20s - Highest zero crossing frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=8 +sil_gauss_incr=0 +sound_gauss_incr=2 +sample_per_gaussian=2000 +num_iters_init=3 +num_iters=5 +min_sil_variance=1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +num_frames_init_speech=100000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 +window_size_phase2=10 + +## Phase 3 parameters +num_frames_silence_phase3_init=2000 +num_frames_speech_phase3_init=2000 +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +sil_max_gauss_phase4=8 +speech_max_gauss_phase4=16 +sil_gauss_incr_phase3=1 +sil_gauss_incr_phase4=1 +speech_gauss_incr_phase4=2 +num_iters_phase3=5 +num_iters_phase4=5 + +speech_to_sil_ratio=1 + +. path.sh +. 
parse_options.sh || exit 1 + +if [ $# -ne 4 ]; then + echo "Usage: vad_gmm_icsi.sh " + echo " e.g.: vad_gmm_icsi.sh data/rt05_eval exp/librispeech_s5/vad_model/silence.0.mdl exp/librispeech_s5/vad_model/speech.0.mdl exp/vad_rt05_eval" + exit 1 +fi + +data=$1 +init_silence_model=$2 +init_speech_model=$3 +dir=$4 + +mkdir -p $dir +tmpdir=$dir/phase1 +phase2_dir=$dir/phase2 +phase3_dir=$dir/phase3 + +mkdir -p $tmpdir +mkdir -p $phase2_dir +mkdir -p $phase3_dir + +init_model_dir=`dirname $init_speech_model` +ignore_energy_opts=`cat $init_model_dir/ignore_energy_opts` || exit 1 +add_zero_crossing_feats=`cat $init_model_dir/add_zero_crossing_feats` || exit 1 + +zc_opts= +[ -f conf/zc_vad.conf ] && zc_opts="--config=conf/zc_vad.conf" + +# Prepare a lang directory +if [ $stage -le -12 ]; then + mkdir -p $dir/local/dict + mkdir -p $dir/local/lm + mkdir -p $dir/local/dict_2class + mkdir -p $dir/local/lm_2class + + echo "1" > $dir/local/dict/silence_phones.txt + echo "1" > $dir/local/dict/optional_silence.txt + echo "2" > $dir/local/dict/nonsilence_phones.txt + echo "1" > $dir/local/dict_2class/silence_phones.txt + echo "1" > $dir/local/dict_2class/optional_silence.txt + echo "2" > $dir/local/dict_2class/nonsilence_phones.txt + echo "3" >> $dir/local/dict/nonsilence_phones.txt + echo -e "1 1\n2 2" > $dir/local/dict_2class/lexicon.txt + echo -e "1 1\n2 2\n3 3" > $dir/local/dict/lexicon.txt + echo -e "1\n2\n1 2" > $dir/local/dict_2class/extra_questions.txt + echo -e "1\n2\n1 2\n3\n1 3\n2 3\n1 2 3" > $dir/local/dict/extra_questions.txt + + mkdir -p $dir/lang + diarization/prepare_vad_lang.sh --num-sil-states $num_sil_states --num-nonsil-states $num_nonsil_states \ + $dir/local/dict $dir/local/lang $dir/lang || exit 1 + diarization/prepare_vad_lang.sh --num-sil-states $num_sil_states --num-nonsil-states $num_nonsil_states \ + $dir/local/dict_2class $dir/local/lang_2class $dir/lang_2class || exit 1 +fi + +feat_dim=`gmm-global-info $init_speech_model | grep "feature dimension" | awk 
'{print $NF}'` || exit 1 + +if [ $stage -le -11 ]; then + run.pl $dir/log/create_transition_model.log gmm-init-mono \ + $dir/lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 + + run.pl $dir/log/create_transition_model_2class.log gmm-init-mono \ + $dir/lang_2class/topo $feat_dim - $dir/tree_2class \| \ + copy-transition-model --binary=false - $dir/trans_2class.mdl || exit 1 + + diarization/make_vad_graph.sh --iter trans $dir/lang $dir $dir/graph || exit 1 + diarization/make_vad_graph.sh --iter trans_2class --tree tree_2class $dir/lang_2class $dir $dir/graph_2class || exit 1 +fi + +if [ $stage -le -10 ]; then + { + cat $dir/trans_2class.mdl + echo " $feat_dim 2" + gmm-global-copy --binary=false $init_silence_model - || exit 1 + gmm-global-copy --binary=false $init_speech_model - || exit 1 + } | gmm-copy - $dir/init.mdl || exit 1 +fi + +if [ $stage -le -9 ]; then + t=$speech_to_sil_ratio + lang=$dir/lang_test_${t}x + cp -r $dir/lang $lang + perl -e "print \"0 0 1 1 \" . -log(1/$[t+3]) . \"\n0 0 2 2 \". -log($t/$[t+3]). \"\n0 0 3 3 \". -log(1/$[t+3]) .\"\n0 \". -log(1/$[t+3])" | \ + fstcompile --isymbols=$lang/words.txt --osymbols=$lang/words.txt \ + --keep_isymbols=false --keep_osymbols=false \ + > $lang/G.fst || exit 1 + diarization/make_vad_graph.sh --iter trans $lang $dir $dir/graph_test_${t}x || exit 1 + + lang=$dir/lang_2class_test_${t}x + cp -r $dir/lang_2class $lang + perl -e "print \"0 0 1 1 \" . -log(1/$[t+2]) . \"\n0 0 2 2 \". -log($t/$[t+2]). \"\n0 \". 
-log(1/$[t+2])" | \ + fstcompile --isymbols=$lang/words.txt --osymbols=$lang/words.txt \ + --keep_isymbols=false --keep_osymbols=false \ + > $lang/G.fst || exit 1 + + diarization/make_vad_graph.sh --iter trans_2class --tree tree_2class $lang $dir $dir/graph_2class_test_${t}x || exit 1 +fi + + +while IFS=$'\n' read line; do + feats="ark:echo $line | apply-cmvn --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- |${ignore_energy_opts}" + + utt_id=$(echo $line | awk '{print $1}') + echo $utt_id > $dir/$utt_id.list + + if [ -f $data/segments ]; then + $cmd $dir/log/$utt_id.extract_zero_crossings.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + extract-segments scp:$data/wav.scp - ark:- \| \ + compute-zero-crossings $zc_opts ark:- ark:$dir/$utt_id.zero_crossings.ark || exit 1 + #$cmd $dir/log/$utt_id.extract_pitch.log \ + # utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + # extract-segments scp:$data/wav.scp - ark:- \| \ + # compute-kaldi-pitch-feats --config=conf/pitch.conf --frames-per-chunk=10 --simulate-first-pass-online=true \ + # ark:- ark:$dir/$utt_id.kaldi_pitch.ark || exit 1 + $cmd $dir/log/$utt_id.extract_log_energies.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + extract-segments scp:$data/wav.scp - ark:- \| \ + compute-mfcc-feats --config=conf/mfcc_vad.conf --num-ceps=1 \ + ark:- ark:- \| extract-column ark:- \ + ark:$dir/$utt_id.log_energies.ark || exit 1 + else + $cmd $dir/log/$utt_id.extract_zero_crossings.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp \| \ + compute-zero-crossings $zc_opts scp:- ark:$dir/$utt_id.zero_crossings.ark || exit 1 + #$cmd $dir/log/$utt_id.extract_pitch.log \ + # utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp \| \ + # compute-kaldi-pitch-feats --config=conf/pitch.conf --frames-per-chunk=10 --simulate-first-pass-online=true \ + # scp:- ark:$dir/$utt_id.kaldi_pitch.ark || exit 1 + $cmd $dir/log/$utt_id.extract_log_energies.log \ + utils/filter_scp.pl 
$dir/$utt_id.list $data/wav.scp \| \ + compute-mfcc-feats --config=conf/mfcc_vad.conf --num-ceps=1 \ + scp:- ark:- \| extract-column ark:- \ + ark:$dir/$utt_id.log_energies.ark || exit 1 + fi + + sil_num_gauss=$sil_num_gauss_init + sound_num_gauss=$sound_num_gauss_init + speech_num_gauss=$speech_num_gauss_init + + if $add_zero_crossing_feats; then + feats="${feats} paste-feats ark:- ark:$dir/$utt_id.zero_crossings.ark ark:- |" + fi + + feats="${feats} add-deltas ark:- ark:- |" + + # Get VAD: 0 for silence, 1 for speech and 2 for sound + $cmd $dir/log/$utt_id.get_vad.bootstrap.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph_2class/words.txt \ + $dir/init.mdl $dir/graph_2class/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| ali-to-pdf $dir/init.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$tmpdir/$utt_id.vad.bootstrap.ark || exit 1 + + cp $tmpdir/$utt_id.vad.bootstrap.ark $tmpdir/$utt_id.seg.0.ark + + x=0 + goto_phase3=false + + while [ $x -lt $num_iters ]; do + num_frames_silence=$[num_frames_init_silence + sil_num_gauss * frames_per_gaussian ] + num_frames_sound=$[num_frames_init_sound + 5 * sound_num_gauss * frames_per_gaussian ] + num_frames_sound_next=$[num_frames_init_sound_next + sound_num_gauss * frames_per_gaussian ] + + if [ $x -lt 3 ]; then + $cmd $tmpdir/log/$utt_id.select_top.first.$[x+1].log \ + segmentation-copy --filter-label=0 \ + --filter-rspecifier=ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + ark:$tmpdir/$utt_id.seg.$x.ark ark:- \| \ + segmentation-select-top --num-bins=$num_bins \ + --merge-labels=0:2 --merge-dst-label=0 \ + --num-top-frames=$num_frames_sound --num-bottom-frames=$num_frames_silence \ + --top-select-label=2 --bottom-select-label=0 --reject-label=1000 \ + --remove-rejected-frames=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:- ark:$dir/$utt_id.log_energies.ark \ + ark:$tmpdir/$utt_id.seg.first.$[x+1].ark || exit 1 + + $cmd 
$tmpdir/log/$utt_id.select_top.$[x+1].log \ + segmentation-select-top --num-bins=$num_bins --src-label=2 \ + --num-top-frames=$num_frames_sound_next --num-bottom-frames=-1 \ + --top-select-label=2 --bottom-select-label=-1 --reject-label=1001 \ + --remove-rejected-frames=true --select-from-full-histogram=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:$tmpdir/$utt_id.seg.first.$[x+1].ark "ark:extract-column ark:$dir/$utt_id.zero_crossings.ark ark:- |" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark || exit 1 + else + $cmd $tmpdir/log/$utt_id.select_top.$[x+1].log \ + segmentation-copy --filter-rspecifier=ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + --filter-label=0 ark:$tmpdir/$utt_id.seg.$x.ark \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark || exit 1 + fi + + if [ $x -eq 0 ]; then + { + cat $dir/trans.mdl; + echo " $feat_dim 3"; + select-feats-from-segmentation --select-label=0 "$feats" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark ark:- | \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[sil_num_gauss+1] --num-gauss-init=1 --num-gauss=$sil_num_gauss \ + ark:- - || exit 1 + gmm-global-copy --binary=false $init_speech_model -; + select-feats-from-segmentation --select-label=2 "$feats" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark ark:- | \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[sound_num_gauss+1] --num-gauss-init=1 --num-gauss=$sound_num_gauss \ + ark:- - || exit 1 + } 2> $tmpdir/log/$utt_id.init_gmm.log | \ + gmm-copy - $tmpdir/$utt_id.$[x+1].mdl 2>> $tmpdir/log/$utt_id.init_gmm.log || exit 1 + else + #$cmd $tmpdir/log/$utt_id.gmm_update.$[x+1].log \ + # gmm-est-segmentation --pdfs=0:2 \ + # --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n2 $sound_num_gauss\" |" \ + # $tmpdir/$utt_id.$x.mdl "$feats" \ + # ark:$tmpdir/$utt_id.seg.second.$[x+1].ark \ + # $tmpdir/$utt_id.$[x+1].mdl || exit 1 + $cmd $tmpdir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation --pdfs=0:2 \ + 
--mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n2 $sound_num_gauss\" |" \ + $tmpdir/$utt_id.$x.mdl "$feats" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark \ + $tmpdir/$utt_id.$[x+1].mdl || exit 1 + fi + + $cmd $tmpdir/log/$utt_id.get_seg.$[x+1].log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph/words.txt \ + $tmpdir/$utt_id.$[x+1].mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $tmpdir/$utt_id.$[x+1].mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$tmpdir/$utt_id.seg.$[x+1].ark || exit 1 + + if [ $sil_num_gauss -lt $sil_max_gauss ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr] + fi + + if [ $sound_num_gauss -lt $sound_max_gauss ]; then + sound_num_gauss=$[sound_num_gauss + sound_gauss_incr] + fi + + x=$[x+1] + done ## Done training Silence and Speech GMMs + + $cmd $phase2_dir/log/$utt_id.init_speech.log \ + segmentation-copy --filter-rspecifier=ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + --filter-label=1 ark:$tmpdir/$utt_id.seg.$num_iters.ark ark:- \| \ + select-feats-from-segmentation --select-label=1 "$feats" \ + ark:- ark:- \| \ + gmm-global-init-from-feats \ + --num-iters=$[speech_num_gauss+1] --num-gauss-init=1 --num-gauss=$speech_num_gauss \ + ark:- $phase2_dir/$utt_id.speech.0.mdl + if [ $? -eq 0 ]; then + num_selected_speech=$(grep "Processed .* segmentations; selected" $phase2_dir/log/$utt_id.init_speech.log | perl -pe 's/.+selected (\S+) out of \S+ frames/$1/') + if [ $num_selected_speech -lt $min_data ]; then + echo "Insufficient frames for speech at the end of phase 1. $num_selected_speech < $min_data. See $phase2_dir/log/$utt_id.init_speech.log. Going to phase 3." + goto_phase3=true + fi + else + echo "Failed to find any data for speech at the end of phase 1. See $phase2_dir/log/$utt_id.init_speech.log. Going to phase 3." 
+ goto_phase3=true + fi + + if $goto_phase3; then + rm -f $dir/$utt_id.current_seg.ark + ln -s $tmpdir/$utt_id.seg.$x.ark $dir/$utt_id.current_seg.ark + fi + + if ! $goto_phase3; then + $cmd $phase2_dir/log/$utt_id.init_gmm.log \ + gmm-init-pdf-from-global $tmpdir/$utt_id.$num_iters.mdl 1 \ + $phase2_dir/$utt_id.speech.0.mdl $phase2_dir/$utt_id.0.mdl || exit 1 + + x=0 + while [ $x -lt $num_iters_phase2 ]; do + if [ $sil_num_gauss -lt $sil_max_gauss_phase2 ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr_phase2] + fi + + if [ $sound_num_gauss -lt $sound_max_gauss_phase2 ]; then + sound_num_gauss=$[sound_num_gauss + sound_gauss_incr_phase2] + fi + + if [ $speech_num_gauss -lt $speech_max_gauss_phase2 ]; then + speech_num_gauss=$[speech_num_gauss + speech_gauss_incr_phase2] + fi + + $cmd $phase2_dir/log/$utt_id.get_seg.$x.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph/words.txt \ + $phase2_dir/$utt_id.$x.mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $phase2_dir/$utt_id.$x.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase2_dir/$utt_id.seg.$x.ark || exit 1 + + #$cmd $phase2_dir/log/$utt_id.gmm_update.$[x+1].log \ + # gmm-est-segmentation \ + # --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n1 $speech_num_gauss\n2 $sound_num_gauss\" |" \ + # $phase2_dir/$utt_id.$x.mdl "$feats" \ + # ark:$phase2_dir/$utt_id.seg.$x.ark \ + # $phase2_dir/$utt_id.$[x+1].mdl || exit 1 + $cmd $phase2_dir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation \ + --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n1 $speech_num_gauss\n2 $sound_num_gauss\" |" \ + $phase2_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase2_dir/$utt_id.seg.$x.ark \ + $phase2_dir/$utt_id.$[x+1].mdl || exit 1 + + x=$[x+1] + done ## Done training all 3 GMMs + cp $phase2_dir/$utt_id.$x.mdl $dir/$utt_id.final.mdl + rm -f $dir/$utt_id.graph_final + ln -s graph_test_${speech_to_sil_ratio}x $dir/$utt_id.graph_final + + $cmd 
$phase2_dir/log/$utt_id.get_seg.$x.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph/words.txt \ + $phase2_dir/$utt_id.$x.mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $phase2_dir/$utt_id.$x.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase2_dir/$utt_id.seg.$x.ark || exit 1 + + mkdir -p $phase3_dir/log + + $cmd $phase2_dir/log/$utt_id.init_nonsil.log \ + segmentation-copy --merge-labels=1:2 --merge-dst-label=1 \ + ark:$phase2_dir/$utt_id.seg.$x.ark ark:- \| \ + select-feats-from-segmentation --select-label=1 \ + "$feats" ark:- ark:- \| \ + gmm-global-init-from-feats \ + --num-iters=$[sound_num_gauss + speech_num_gauss + 1] \ + --num-gauss-init=1 \ + --num-gauss=$[sound_num_gauss + speech_num_gauss] ark:- \ + $phase2_dir/$utt_id.$x.nonsil.mdl || exit 1 + + $cmd $phase2_dir/log/$utt_id.select_speech_feats.$x.log \ + select-feats-from-segmentation --select-label=1 \ + "$feats" ark:$phase2_dir/$utt_id.seg.$x.ark \ + ark:$phase2_dir/$utt_id.speech_feats.$x.ark + + if $goto_phase3; then + rm -f $dir/$utt_id.current_seg.ark + ln -s $phase2_dir/$utt_id.seg.$x.ark $dir/$utt_id.current_seg.ark + fi + + if [ $? -eq 0 ]; then + num_selected_speech=$(grep "Processed .* segmentations; selected" $phase2_dir/log/$utt_id.select_speech_feats.$x.log | perl -pe 's/.+selected (\S+) out of \S+ frames/$1/') + if [ $num_selected_speech -lt $min_data ]; then + echo "Insufficient frames for speech at the end of phase 2. $num_selected_speech < $min_data. See $phase2_dir/log/$utt_id.select_speech_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + else + echo "Failed to find any data for speech at the end of phase 1. See $phase2_dir/log/$utt_id.select_speech_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + + if ! 
$goto_phase3; then + speech_like=$(gmm-global-get-frame-likes \ + "gmm-extract-pdf $phase2_dir/$utt_id.$x.mdl 1 - |" \ + ark:$phase2_dir/$utt_id.speech_feats.$x.ark ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $phase2_dir/$utt_id.compute_speech_like.$x.log || exit 1 + + $cmd $phase2_dir/log/$utt_id.select_sound_feats.$x.log \ + select-feats-from-segmentation --select-label=2 \ + "$feats" ark:$phase2_dir/$utt_id.seg.$x.ark \ + ark:$phase2_dir/$utt_id.sound_feats.$x.ark + + if [ $? -eq 0 ]; then + num_selected_sound=$(grep "Processed .* segmentations; selected" $phase2_dir/log/$utt_id.select_sound_feats.$x.log | perl -pe 's/.+selected (\S+) out of \S+ frames/$1/') + if [ $num_selected_sound -lt $min_data ]; then + echo "Insufficient frames for sound at the end of phase 2. $num_selected_sound < $min_data. See $phase2_dir/log/$utt_id.select_sound_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + else + echo "Failed to find any data for sound at the end of phase 1. See $phase2_dir/log/$utt_id.select_sound_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + fi + + if ! $goto_phase3; then + sound_like=$(gmm-global-get-frame-likes \ + "gmm-extract-pdf $phase2_dir/$utt_id.$x.mdl 1 - |" \ + ark:$phase2_dir/$utt_id.sound_feats.$x.ark ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $phase2_dir/$utt_id.compute_sound_like.$x.log || exit 1 + + nonsil_like=$(select-feats-from-segmentation --merge-labels=1:2 --select-label=1 \ + "$feats" ark:$phase2_dir/$utt_id.seg.$x.ark ark:- | \ + gmm-global-get-frame-likes \ + $phase2_dir/$utt_id.$x.nonsil.mdl ark:- ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $phase2_dir/$utt_id.compute_nonsil_like.$x.log || exit 1 + + if [ ! 
-z `perl -e "print \"true\" if ($sound_like + $speech_like < $nonsil_like)"` ]; then + goto_phase3=true + fi + fi + fi + + if $goto_phase3; then + speech_num_gauss=$speech_num_gauss_init_phase3 + sil_num_gauss=$sil_num_gauss_init_phase3 + + $cmd $phase3_dir/log/$utt_id.compute_silence_likes.bootstrap.log \ + gmm-global-get-frame-likes $init_silence_model "$feats" \ + ark:$dir/$utt_id.silence_log_likes.bootstrap.ark || exit 1 + + $cmd $phase3_dir/log/$utt_id.compute_speech_likes.bootstrap.log \ + gmm-global-get-frame-likes $init_speech_model "$feats" \ + ark:$dir/$utt_id.speech_log_likes.bootstrap.ark || exit 1 + + cp $tmpdir/$utt_id.vad.bootstrap.ark $phase3_dir/$utt_id.vad.0.ark + + x=0 + goto_phase3=false + + while [ $x -lt $num_iters_phase3 ]; do + num_frames_silence=$[num_frames_init_silence + sil_num_gauss * frames_per_gaussian ] + + if [ $x -lt 3 ]; then + $cmd $phase3_dir/log/$utt_id.select_top.second.$[x+1].log \ + segmentation-copy --filter-label=0 \ + --filter-rspecifier=ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + ark:$phase3_dir/$utt_id.vad.$x.ark ark:- \| \ + segmentation-select-top --num-bins=$num_bins \ + --merge-dst-label=0 \ + --num-top-frames=-1 --num-bottom-frames=$num_frames_silence \ + --top-select-label=-1 --bottom-select-label=0 --reject-label=1000 \ + --remove-rejected-frames=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:- ark:$dir/$utt_id.log_energies.ark \ + ark:$phase3_dir/$utt_id.vad.second.$[x+1].ark || exit 1 + + else + $cmd $phase3_dir/log/$utt_id.select_top.$[x+1].log \ + segmentation-copy --filter-rspecifier=ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + --filter-label=0 ark:$phase3_dir/$utt_id.vad.$x.ark \ + ark:$phase3_dir/$utt_id.vad.second.$[x+1].ark || exit 1 + fi + + if [ $x -eq 0 ]; then + { + cat $dir/trans.mdl; + echo " $feat_dim 2"; + select-feats-from-segmentation --select-label=0 "$feats" \ + ark:$phase3_dir/$utt_id.vad.second.$[x+1].ark ark:- | \ + gmm-global-init-from-feats --binary=false 
\ + --num-iters=$[sil_num_gauss+1] --num-gauss-init=1 --num-gauss=$sil_num_gauss \ + ark:- - || exit 1 + gmm-global-copy --binary=false $init_speech_model - || exit 1 + } 2> $phase3_dir/log/$utt_id.init_gmm.log | \ + gmm-copy - $phase3_dir/$utt_id.$[x+1].mdl 2>> $phase3_dir/log/$utt_id.init_gmm.log || exit 1 + else + $cmd $phase3_dir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation --pdfs=0 \ + --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\" |" \ + $phase3_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase3_dir/$utt_id.vad.second.$[x+1].ark \ + $phase3_dir/$utt_id.$[x+1].mdl || exit 1 + fi + + $cmd $phase3_dir/log/$utt_id.get_seg.$[x+1].log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph_2class/words.txt \ + $phase3_dir/$utt_id.$[x+1].mdl $dir/graph_2class/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $phase3_dir/$utt_id.$[x+1].mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase3_dir/$utt_id.vad.$[x+1].ark || exit 1 + + if [ $sil_num_gauss -lt $sil_max_gauss ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr] + fi + + x=$[x+1] + done ## Done training Silence and Speech GMMs + + $cmd $phase3_dir/log/$utt_id.init_speech.log \ + segmentation-copy --filter-rspecifier=ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + --filter-label=1 ark:$phase3_dir/$utt_id.vad.$x.ark ark:- \| \ + select-feats-from-segmentation --select-label=1 "$feats" \ + ark:- ark:- \| \ + gmm-global-init-from-feats \ + --num-iters=$[speech_num_gauss+1] --num-gauss-init=1 --num-gauss=$speech_num_gauss \ + ark:- $phase3_dir/$utt_id.speech.$x.mdl + + $cmd $phase3_dir/log/$utt_id.init_gmm.log \ + gmm-init-pdf-from-global $phase3_dir/$utt_id.$x.mdl 1 \ + $phase3_dir/$utt_id.speech.$x.mdl $phase3_dir/$utt_id.$[x+1].mdl || exit 1 + + x=$[x+1] + + while [ $x -lt $[num_iters_phase4 + num_iters_phase3+1] ]; do + if [ $sil_num_gauss -lt $sil_max_gauss_phase4 ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr_phase4] + fi + + if 
[ $speech_num_gauss -lt $speech_max_gauss_phase4 ]; then + speech_num_gauss=$[speech_num_gauss + speech_gauss_incr_phase4] + fi + + $cmd $phase3_dir/log/$utt_id.get_seg.$x.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph_2class/words.txt \ + $phase3_dir/$utt_id.$x.mdl $dir/graph_2class/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $phase3_dir/$utt_id.$x.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase3_dir/$utt_id.vad.$x.ark || exit 1 + + $cmd $phase3_dir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation \ + --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n1 $speech_num_gauss\" |" \ + $phase3_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase3_dir/$utt_id.vad.$x.ark \ + $phase3_dir/$utt_id.$[x+1].mdl || exit 1 + + x=$[x+1] + done ## Done training all 3 GMMs + + cp $phase3_dir/$utt_id.$x.mdl $dir/$utt_id.final.mdl + rm -f $dir/$utt_id.graph_final + ln -s graph_2class_test_${speech_to_sil_ratio}x $dir/$utt_id.graph_final + fi + + $cmd $dir/log/$utt_id.get_seg.final.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/$utt_id.graph_final/words.txt \ + $dir/$utt_id.final.mdl $dir/$utt_id.graph_final/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $dir/$utt_id.final.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$dir/$utt_id.vad.final.ark || exit 1 + +done < $data/feats.scp diff --git a/egs/sre08/v1/diarization/vad_gmm_icsi_em.sh b/egs/sre08/v1/diarization/vad_gmm_icsi_em.sh new file mode 100755 index 00000000000..8b511606a0b --- /dev/null +++ b/egs/sre08/v1/diarization/vad_gmm_icsi_em.sh @@ -0,0 +1,602 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. 
+ +set -e +set -u +set -o pipefail + +cmd=run.pl +stage=-1 + +## Features paramters +window_size=100 # 1s +force_ignore_energy_opts= +try_merge_sound_speech=false +do_phase1=false + +## Phase 1 parameters +num_frames_init_silence=2000 # 20s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_sound=10000 # 100s - Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=2000 # 20s - Highest zero crossing frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=8 +sil_gauss_incr=0 +sound_gauss_incr=2 +sil_frames_incr=2000 +sound_frames_incr=10000 +sound_frames_next_incr=2000 +num_iters=5 +min_sil_variance=1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +num_frames_init_speech=10000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 +window_size_phase2=10 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + +. path.sh +. 
parse_options.sh || exit 1 + +if [ $# -ne 4 ]; then + echo "Usage: vad_gmm_icsi.sh " + echo " e.g.: vad_gmm_icsi.sh data/rt05_eval exp/librispeech_s5/vad_model/silence.0.mdl exp/librispeech_s5/vad_model/speech.0.mdl exp/vad_rt05_eval" + exit 1 +fi + +data=$1 +init_silence_model=$2 +init_speech_model=$3 +dir=$4 + +mkdir -p $dir +tmpdir=$dir/phase1 +phase2_dir=$dir/phase2 +phase3_dir=$dir/phase3 + +mkdir -p $tmpdir +mkdir -p $phase2_dir +mkdir -p $phase3_dir + +init_model_dir=`dirname $init_speech_model` +ignore_energy_opts=`cat $init_model_dir/ignore_energy_opts` || exit 1 + +add_zero_crossing_feats=`cat $init_model_dir/add_zero_crossing_feats` || exit 1 +zc_opts= +[ -f conf/zc_vad.conf ] && zc_opts="--config=conf/zc_vad.conf" + + +# Prepare a lang directory +if [ $stage -le -12 ]; then + mkdir -p $dir/local + mkdir -p $dir/local/dict + mkdir -p $dir/local/lm + + echo "1" > $dir/local/dict/silence_phones.txt + echo "3" >> $dir/local/dict/silence_phones.txt + echo "1" > $dir/local/dict/optional_silence.txt + echo "2" > $dir/local/dict/nonsilence_phones.txt + echo -e "1 1\n2 2\n3 3" > $dir/local/dict/lexicon.txt + echo -e "1\n2\n1 2\n3\n1 3\n2 3\n1 2 3" > $dir/local/dict/extra_questions.txt + + mkdir -p $dir/lang + diarization/prepare_vad_lang.sh --num-sil-states 1 --num-nonsil-states 1 \ + $dir/local/dict $dir/local/lang $dir/lang || exit 1 + fstisstochastic $dir/lang/G.fst || echo "[info]: G not stochastic." 
+ diarization/prepare_vad_lang.sh --num-sil-states 30 --num-nonsil-states 75 \ + $dir/local/dict $dir/local/lang $dir/lang_test || exit 1 +fi + +feat_dim=`feat-to-dim "ark:head -n 1 $data/feats.scp | copy-feats scp:- ark:- | add-deltas ark:- ark:- |${ignore_energy_opts}" ark,t:- | awk '{print $2}'` || exit 1 + +if [ $stage -le -11 ]; then + run.pl $dir/log/create_transition_model.log gmm-init-mono \ + --binary=false $dir/lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 + run.pl $dir/log/create_transition_model.log gmm-init-mono \ + --binary=false $dir/lang_test/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans_test.mdl || exit 1 + + diarization/make_vad_graph.sh --iter trans $dir/lang $dir $dir/graph || exit 1 + diarization/make_vad_graph.sh --iter trans_test $dir/lang_test $dir $dir/graph_test || exit 1 +fi + +cat < $dir/pdf_to_tid.map +0 1 +1 3 +EOF + +while IFS=$'\n' read line; do + feats="ark:echo $line | copy-feats scp:- ark:- | add-deltas ark:- ark:- |${ignore_energy_opts}" + utt_id=$(echo $line | awk '{print $1}') + echo $utt_id > $dir/$utt_id.list + + if [ -f $data/segments ]; then + $cmd $dir/log/$utt_id.extract_zero_crossings.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + extract-segments scp:$data/wav.scp - ark:- \| \ + compute-zero-crossings $zc_opts ark:- ark:$dir/$utt_id.zero_crossings.ark || exit 1 + $cmd $dir/log/$utt_id.extract_pitch.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + extract-segments scp:$data/wav.scp - ark:- \| \ + compute-kaldi-pitch-feats --config=conf/pitch.conf --frames-per-chunk=10 --simulate-first-pass-online=true \ + ark:- ark:$dir/$utt_id.kaldi_pitch.ark || exit 1 + else + $cmd $dir/log/$utt_id.extract_zero_crossings.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp \| \ + compute-zero-crossings $zc_opts scp:- ark:$dir/$utt_id.zero_crossings.ark || exit 1 + $cmd 
$dir/log/$utt_id.extract_pitch.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp \| \ + compute-kaldi-pitch-feats --config=conf/pitch.conf --frames-per-chunk=10 --simulate-first-pass-online=true \ + scp:- ark:$dir/$utt_id.kaldi_pitch.ark || exit 1 + fi + + $cmd $dir/log/$utt_id.extract_log_energies.log \ + extract-column "scp:utils/filter_scp.pl $dir/$utt_id.list $data/feats.scp |" \ + ark:$dir/$utt_id.log_energies.ark || exit 1 + + sil_num_gauss=$sil_num_gauss_init + sound_num_gauss=$sound_num_gauss_init + speech_num_gauss=$speech_num_gauss_init + num_frames_silence=$num_frames_init_silence + num_frames_sound=$num_frames_init_sound + num_frames_sound_next=$num_frames_init_sound_next + + if $add_zero_crossing_feats; then + feats="${feats} paste-feats ark:- \"ark:add-deltas ark:$dir/$utt_id.zero_crossings.ark ark:- |\" ark:- |" + fi + + + ### Compute likelihoods wrt bootstrapping models + $cmd $dir/log/$utt_id.compute_speech_like.bootstrap.log \ + gmm-global-get-frame-likes $init_speech_model \ + "${feats}" ark:$dir/$utt_id.speech_likes.bootstrap.ark || exit 1 + + $cmd $dir/log/$utt_id.compute_silence_like.bootstrap.log \ + gmm-global-get-frame-likes $init_silence_model \ + "${feats}" ark:$dir/$utt_id.silence_likes.bootstrap.ark || exit 1 + + ### Get bootstrapping VAD + $cmd $tmpdir/log/$utt_id.get_vad.bootstrap.log \ + loglikes-to-class --post=ark:$dir/$utt_id.post.bootstrap.ark \ + ark:$dir/$utt_id.silence_likes.bootstrap.ark \ + ark:$dir/$utt_id.speech_likes.bootstrap.ark \ + ark:$tmpdir/$utt_id.vad.bootstrap.ark || exit 1 + + if [ ! 
-z "$force_ignore_energy_opts" ]; then + ignore_energy_opts=$force_ignore_energy_opts + feats="ark:echo $line | copy-feats scp:- ark:- | add-deltas ark:- ark:- |${ignore_energy_opts}" + if $add_zero_crossing_feats; then + feats="${feats} paste-feats ark:- \"ark:add-deltas ark:$dir/$utt_id.zero_crossings.ark ark:- |\" ark:- |" + fi + fi + + $ + ### Initialize Silence GMM using lowest energy chunks that were classified + ### as silence by the bootstrapping model + $cmd $tmpdir/log/$utt_id.init_silence_gmm.log \ + select-top-chunks \ + --window-size=$window_size \ + --selection-mask=ark:$tmpdir/$utt_id.vad.bootstrap.ark --select-class=0 \ + --select-bottom-frames=true \ + --weights=ark:$dir/$utt_id.log_energies.ark --num-select-frames=$num_frames_silence \ + "${feats}" ark:- ark:$tmpdir/$utt_id.silence_mask.0.ark \| \ + gmm-global-init-from-feats \ + --min-variance=$min_sil_variance --num-gauss=$sil_num_gauss --num-iters=$[sil_num_gauss+2] ark:- \ + $tmpdir/$utt_id.silence.0.mdl || exit 1 + + ### Initialize Sound GMM using highest zero-crossing + ### chunks that were classified + ### as silence by the bootstrapping model. + $cmd $tmpdir/log/$utt_id.init_sound_gmm.log \ + select-top-chunks \ + --window-size=$window_size \ + --selection-mask=ark:$tmpdir/$utt_id.vad.bootstrap.ark --select-class=0 \ + --weights="ark:extract-column ark:$dir/$utt_id.zero_crossings.ark ark:- |" --num-select-frames=$num_frames_sound \ + "${feats}" ark:- ark:$tmpdir/$utt_id.sound_mask.0.ark \| \ + gmm-global-init-from-feats \ + --min-variance=$min_sound_variance --num-gauss=$sound_num_gauss --num-iters=$[sound_num_gauss+2] ark:- \ + $tmpdir/$utt_id.sound.0.mdl || exit 1 + + ### Initialize Speech GMM using highest NCCF chunks that were classified as + ### speech by the bootstrapping model. 
+ $cmd $tmpdir/log/$utt_id.init_speech_gmm.log \ + select-top-chunks \ + --window-size=$window_size \ + --selection-mask=ark:$tmpdir/$utt_id.vad.bootstrap.ark --select-class=1 \ + --weights="ark:extract-column ark:$dir/$utt_id.kaldi_pitch.ark ark:- |" \ + --num-select-frames=$num_frames_init_speech \ + "${feats}" ark:- ark:$tmpdir/$utt_id.speech_mask.0.ark \| \ + gmm-global-init-from-feats \ + --num-gauss=$speech_num_gauss --num-iters=$[speech_num_gauss+2] ark:- \ + $tmpdir/$utt_id.speech.0.mdl || exit 1 + + { + cat $dir/trans_test.mdl + echo " $feat_dim 3" + gmm-global-copy --binary=false $tmpdir/$utt_id.silence.0.mdl + gmm-global-copy --binary=false $tmpdir/$utt_id.sound.0.mdl + gmm-global-copy --binary=false $tmpdir/$utt_id.speech.0.mdl + } | gmm-copy - $dir/$utt_id.0.mdl || exit 1 + + if $do_phase1; then + + ### Compute likelihoods with the newly initialized Silence and Sound GMMs + $cmd $tmpdir/log/$utt_id.compute_silence_likes.0.log \ + gmm-global-get-frame-likes $tmpdir/$utt_id.silence.0.mdl \ + "${feats}" ark:$tmpdir/$utt_id.silence_likes.0.ark || exit 1 + + $cmd $tmpdir/log/$utt_id.compute_sound_likes.0.log \ + gmm-global-get-frame-likes $tmpdir/$utt_id.sound.0.mdl \ + "${feats}" ark:$tmpdir/$utt_id.sound_likes.0.ark || exit 1 + + $cmd $tmpdir/log/$utt_id.compute_speech_likes.0.log \ + gmm-global-get-frame-likes $tmpdir/$utt_id.speech.0.mdl \ + "${feats}" ark:$tmpdir/$utt_id.speech_likes.0.ark || exit 1 + + ### Get initial VAD + { + loglikes-to-class --post=ark:$tmpdir/$utt_id.post.init.ark \ + ark:$tmpdir/$utt_id.silence_likes.0.ark \ + ark:$tmpdir/$utt_id.speech_likes.0.ark \ + ark:$tmpdir/$utt_id.sound_likes.0.ark ark,t:- | \ + perl -pe 's/\[(.+)]/$1/' | \ + utils/apply_map.pl -f 2- <(echo -e "0 0\n1 1\n2 0") | \ + awk '{printf $1" [ "; for (i = 2; i <= NF; i++) {printf $i" ";}; print "]"}' | \ + copy-vector ark,t:- ark:$tmpdir/$utt_id.vad.init.ark ; + } &> $tmpdir/log/$utt_id.get_vad.init.log || exit 1 + + ### Remove frames that were originally 
classified as speech + ### while training Silence and Sound GMMs + $cmd $tmpdir/log/$utt_id.select_feats_phase1.init.log \ + select-top-chunks \ + --window-size=$window_size \ + --selection-mask=ark:$tmpdir/$utt_id.vad.init.ark --select-class=0 \ + "$feats" ark:$tmpdir/$utt_id.feats.init.ark \ + ark:$tmpdir/$utt_id.mask.init.ark || exit 1 + + ## Select energies and zero crossings corresponding to the same selection + + $cmd $tmpdir/log/$utt_id.select_zero_crossings.init.log \ + extract-column ark:$dir/$utt_id.zero_crossings.ark ark:- \| \ + vector-extract-dims ark:- \ + ark:$tmpdir/$utt_id.mask.init.ark \ + ark:$tmpdir/$utt_id.zero_crossings.init.ark || exit 1 + + $cmd $tmpdir/log/$utt_id.select_energies.init.log \ + vector-extract-dims ark:$dir/$utt_id.log_energies.ark \ + ark:$tmpdir/$utt_id.mask.init.ark \ + ark:$tmpdir/$utt_id.energies.init.ark || exit 1 + + $cmd $tmpdir/log/$utt_id.select_vad.init.log \ + vector-extract-dims ark:$tmpdir/$utt_id.vad.init.ark \ + ark:$tmpdir/$utt_id.mask.init.ark \ + ark:$tmpdir/$utt_id.vad.0.ark || exit 1 + + $cmd $tmpdir/log/$utt_id.select_speech_likes.init.log \ + vector-extract-dims \ + ark:$tmpdir/$utt_id.speech_likes.0.ark \ + ark:$tmpdir/$utt_id.mask.init.ark \ + ark:$tmpdir/$utt_id.speech_likes.init.ark || exit 1 + + x=0 + while [ $x -le $num_iters ]; do + ### Update Silence GMM using lowest energy chunks currently classified + ### as silence + $cmd $tmpdir/log/$utt_id.update_silence_gmm.$[x+1].log \ + select-top-chunks \ + --window-size=$window_size \ + --selection-mask=ark:$tmpdir/$utt_id.vad.$x.ark --select-class=0 \ + --select-bottom-frames=true --weights=ark:$tmpdir/$utt_id.energies.init.ark \ + --num-select-frames=$num_frames_silence \ + ark:$tmpdir/$utt_id.feats.init.ark ark:- \| \ + gmm-global-acc-stats \ + $tmpdir/$utt_id.silence.$x.mdl ark:- - \| \ + gmm-global-est --mix-up=$sil_num_gauss $tmpdir/$utt_id.silence.$x.mdl \ + - $tmpdir/$utt_id.silence.$[x+1].mdl || exit 1 + + ### Update Sound GMM using highest 
energy and highest zero crossing + ### chunks currently classified as silence + $cmd $tmpdir/log/$utt_id.update_sound_gmm.$[x+1].log \ + select-top-chunks \ + --window-size=$window_size \ + --selection-mask=ark:$tmpdir/$utt_id.vad.$x.ark --select-class=0 \ + --weights=ark:$tmpdir/$utt_id.zero_crossings.init.ark \ + --num-select-frames=$num_frames_sound \ + ark:$tmpdir/$utt_id.feats.init.ark ark:- \| \ + gmm-global-acc-stats \ + $tmpdir/$utt_id.sound.$x.mdl ark:- - \| \ + gmm-global-est --mix-up=$sound_num_gauss $tmpdir/$utt_id.sound.$x.mdl \ + - $tmpdir/$utt_id.sound.$[x+1].mdl || exit 1 + + ### Compute likelihoods with the current Silence and Sound GMMs + $cmd $tmpdir/log/$utt_id.compute_silence_likes.$[x+1].log \ + gmm-global-get-frame-likes $tmpdir/$utt_id.silence.$[x+1].mdl \ + ark:$tmpdir/$utt_id.feats.init.ark ark:$tmpdir/$utt_id.silence_likes.$[x+1].ark || exit 1 + + $cmd $tmpdir/log/$utt_id.compute_sound_likes.$[x+1].log \ + gmm-global-get-frame-likes $tmpdir/$utt_id.sound.$[x+1].mdl \ + ark:$tmpdir/$utt_id.feats.init.ark ark:$tmpdir/$utt_id.sound_likes.$[x+1].ark || exit 1 + + ### Get new VAD predictions on the subset selected for training + ### Silence and Sound GMMs + { + loglikes-to-class --post=ark:$tmpdir/$utt_id.post.$[x+1].ark \ + ark:$tmpdir/$utt_id.silence_likes.$[x+1].ark \ + ark:$tmpdir/$utt_id.speech_likes.init.ark \ + ark:$tmpdir/$utt_id.sound_likes.$[x+1].ark ark,t:- | \ + perl -pe 's/\[(.+)]/$1/' | \ + utils/apply_map.pl -f 2- <(echo -e "0 0\n1 1\n2 0") | \ + awk '{printf $1" [ "; for (i = 2; i <= NF; i++) {printf $i" ";}; print "]"}' | \ + copy-vector ark,t:- ark:$tmpdir/$utt_id.vad.$[x+1].ark ; + } &>$tmpdir/log/$utt_id.get_vad.$[x+1].log || exit 1 + + $cmd $tmpdir/log/$utt_id.compute_silence_all_likes.$[x+1].log \ + gmm-global-get-frame-likes $tmpdir/$utt_id.silence.$[x+1].mdl \ + "$feats" ark:$tmpdir/$utt_id.silence_all_likes.$[x+1].ark || exit 1 + + $cmd $tmpdir/log/$utt_id.compute_sound_all_likes.$[x+1].log \ + 
gmm-global-get-frame-likes $tmpdir/$utt_id.sound.$[x+1].mdl \ + "$feats" ark:$tmpdir/$utt_id.sound_all_likes.$[x+1].ark || exit 1 + + $cmd $tmpdir/log/$utt_id.get_pred.$[x+1].log \ + loglikes-to-class --post=ark:$tmpdir/$utt_id.pred_post.$[x+1].ark \ + ark:$tmpdir/$utt_id.silence_all_likes.$[x+1].ark \ + ark:$tmpdir/$utt_id.speech_likes.0.ark \ + ark:$tmpdir/$utt_id.sound_all_likes.$[x+1].ark \ + ark:$tmpdir/$utt_id.pred.$[x+1].ark || exit 1 + + x=$[x+1] + if [ $sil_num_gauss -lt $sil_max_gauss ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr] + num_frames_silence=$[num_frames_silence + sil_frames_incr] + fi + if [ $sound_num_gauss -lt $sound_max_gauss ]; then + sound_num_gauss=$[sound_num_gauss + sound_gauss_incr] + num_frames_sound=$[num_frames_sound + sound_frames_incr] + num_frames_sound_next=$[num_frames_sound_next + sound_frames_next_incr] + fi + done ## Done training Silence and Speech GMMs + + #### Update Silence and Sound GMMs using new segmentation + #select-top-chunks --window-size=1 \ + # --selection-mask=ark:$phase2_dir/$utt_id.seg.init.ark --select-class=0 \ + # "$feats" ark:- | gmm-global-acc-stats \ # $tmpdir/$utt_id.silence.$x.mdl ark:- - | \ + # gmm-global-est --mix-up=$sil_num_gauss \ + # $tmpdir/$utt_id.silence.$x.mdl - $phase2_dir/$utt_id.silence.0.mdl || exit 1 + + #select-top-chunks --window-size=1 \ + # --selection-mask=ark:$phase2_dir/$utt_id.seg.init.ark --select-class=2 \ + # "$feats" ark:- | gmm-global-acc-stats \ + # $tmpdir/$utt_id.sound.$x.mdl ark:- - | \ + # gmm-global-est --mix-up=$sound_num_gauss \ + # $tmpdir/$utt_id.sound.$x.mdl - $phase2_dir/$utt_id.sound.0.mdl || exit 1 + + cp $tmpdir/$utt_id.silence.$x.mdl $phase2_dir/$utt_id.silence.0.mdl + cp $tmpdir/$utt_id.sound.$x.mdl $phase2_dir/$utt_id.sound.0.mdl + cp $tmpdir/$utt_id.speech.0.mdl $phase2_dir/$utt_id.speech.0.mdl + + else + cp $tmpdir/$utt_id.silence.0.mdl $phase2_dir/$utt_id.silence.0.mdl + cp $tmpdir/$utt_id.sound.0.mdl $phase2_dir/$utt_id.sound.0.mdl + cp 
$tmpdir/$utt_id.speech.0.mdl $phase2_dir/$utt_id.speech.0.mdl + fi + + x=0 + while [ $x -le $num_iters_phase2 ]; do + ### Compute likelihoods with the current Silence, Speech and Sound GMMs + $cmd $phase2_dir/log/$utt_id.compute_silence_likes.$x.log \ + gmm-global-get-frame-likes $phase2_dir/$utt_id.silence.$x.mdl \ + "$feats" ark:$phase2_dir/$utt_id.silence_likes.$x.ark || exit 1 + + $cmd $phase2_dir/log/$utt_id.compute_sound_likes.$x.log \ + gmm-global-get-frame-likes $phase2_dir/$utt_id.sound.$x.mdl \ + "$feats" ark:$phase2_dir/$utt_id.sound_likes.$x.ark || exit 1 + + $cmd $phase2_dir/log/$utt_id.compute_speech_likes.$x.log \ + gmm-global-get-frame-likes $phase2_dir/$utt_id.speech.$x.mdl \ + "$feats" ark:$phase2_dir/$utt_id.speech_likes.$x.ark || exit 1 + + ### Get segmentation + $cmd $phase2_dir/log/$utt_id.get_seg.$x.log \ + loglikes-to-class --post=ark:$phase2_dir/$utt_id.pred_post.$x.ark \ + ark:$phase2_dir/$utt_id.silence_likes.$x.ark \ + ark:$phase2_dir/$utt_id.speech_likes.$x.ark \ + ark:$phase2_dir/$utt_id.sound_likes.$x.ark \ + ark:$phase2_dir/$utt_id.seg.$x.ark || exit 1 + + ### Update Speech GMM + $cmd $phase2_dir/log/$utt_id.update_gmm_speech.$[x+1].log \ + select-top-chunks --window-size=$window_size_phase2 \ + --selection-mask=ark:$phase2_dir/$utt_id.seg.$x.ark --select-class=1 \ + "$feats" ark:- \| gmm-global-acc-stats \ + $phase2_dir/$utt_id.speech.$x.mdl ark:- - \| \ + gmm-global-est --mix-up=$speech_num_gauss \ + $phase2_dir/$utt_id.speech.$x.mdl - $phase2_dir/$utt_id.speech.$[x+1].mdl || exit 1 + + ### Update Silence GMM + $cmd $phase2_dir/log/$utt_id.update_gmm_silence.$[x+1].log \ + select-top-chunks --window-size=$window_size_phase2 \ + --selection-mask=ark:$phase2_dir/$utt_id.seg.$x.ark --select-class=0 \ + "$feats" ark:- \| gmm-global-acc-stats \ + $phase2_dir/$utt_id.silence.$x.mdl ark:- - \| \ + gmm-global-est --mix-up=$sil_num_gauss \ + $phase2_dir/$utt_id.silence.$x.mdl - $phase2_dir/$utt_id.silence.$[x+1].mdl || exit 1 + + ### Update 
Sound GMM + $cmd $phase2_dir/log/$utt_id.update_gmm_sound.$[x+1].log \ + select-top-chunks --window-size=$window_size_phase2 \ + --selection-mask=ark:$phase2_dir/$utt_id.seg.$x.ark --select-class=2 \ + "$feats" ark:- \| gmm-global-acc-stats \ + $phase2_dir/$utt_id.sound.$x.mdl ark:- - \| \ + gmm-global-est --mix-up=$sound_num_gauss \ + $phase2_dir/$utt_id.sound.$x.mdl - $phase2_dir/$utt_id.sound.$[x+1].mdl || exit 1 + + if [ $sil_num_gauss -lt $sil_max_gauss_phase2 ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr_phase2] + fi + + if [ $sound_num_gauss -lt $sound_max_gauss_phase2 ]; then + sound_num_gauss=$[sound_num_gauss + sound_gauss_incr_phase2] + fi + + if [ $speech_num_gauss -lt $speech_max_gauss_phase2 ]; then + speech_num_gauss=$[speech_num_gauss + speech_gauss_incr_phase2] + fi + + x=$[x+1] + done ## Done training all 3 GMMs + + if ! $try_merge_sound_speech; then + continue; + fi + + x=$[x-1] + mkdir -p $phase3_dir/log + + { + copy-vector ark:$phase2_dir/$utt_id.seg.$x.ark ark,t:- | \ + perl -pe 's/\[(.+)]/$1/' | \ + utils/apply_map.pl -f 2- <(echo -e "0 0\n1 1\n2 1") | \ + awk '{printf $1" [ "; for (i = 2; i <= NF; i++) {printf $i" ";}; print "]"}' | \ + copy-vector ark,t:- ark:$phase3_dir/$utt_id.sil_nonsil.$x.ark; + } &> $phase3_dir/log/$utt_id.get_sil_nonsil.$x.log || exit 1 + + $cmd $phase3_dir/log/$utt_id.init_gmm_nonsil.$x.log \ + select-top-chunks --window-size=1 \ + --selection-mask=ark:$phase3_dir/$utt_id.sil_nonsil.$x.ark --select-class=1 \ + "$feats" ark:- \| gmm-global-init-from-feats \ + --num-gauss=$[sound_num_gauss + speech_num_gauss] --num-iters=20 \ + ark:- $phase2_dir/$utt_id.nonsil.$x.mdl || exit 1 + + $cmd $phase2_dir/$utt_id.compute_silence_likes.pred.$x.log \ + gmm-global-get-frame-likes $phase2_dir/$utt_id.silence.$x.mdl \ + "$feats" ark:$phase2_dir/$utt_id.silence_likes.pred.$x.ark || exit 1 + $cmd $phase2_dir/$utt_id.compute_nonsil_likes.pred.$x.log \ + gmm-global-get-frame-likes $phase2_dir/$utt_id.nonsil.$x.mdl \ + 
"$feats" ark:$phase2_dir/$utt_id.nonsil_likes.pred.$x.ark || exit 1 + + $cmd $phase2_dir/$utt_id.get_pred.nonsil.log \ + loglikes-to-class \ + ark:$phase2_dir/$utt_id.silence_likes.pred.$x.ark \ + ark:$phase2_dir/$utt_id.nonsil_likes.pred.$x.ark \ + ark:$phase2_dir/$utt_id.pred.nonsil.ark || exit 1 + + nonsil_like=$(select-top-chunks --window-size=1 \ + --selection-mask=ark:$phase3_dir/$utt_id.sil_nonsil.$x.ark --select-class=1 \ + "$feats" ark:- | gmm-global-get-frame-likes \ + $phase2_dir/$utt_id.nonsil.$x.mdl ark:- ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $phase2_dir/$utt_id.compute_nonsil_like.$x.log || exit 1 + + speech_like=$(select-top-chunks --window-size=1 \ + --selection-mask=ark:$phase2_dir/$utt_id.seg.$x.ark --select-class=1 \ + "$feats" ark:- | gmm-global-get-frame-likes \ + $phase2_dir/$utt_id.speech.$x.mdl ark:- ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)' ) 2> $phase2_dir/$utt_id.compute_speech_like.$x.log || exit 1 + + sound_like=$(select-top-chunks --window-size=1 \ + --selection-mask=ark:$phase2_dir/$utt_id.seg.$x.ark --select-class=2 \ + "$feats" ark:- | gmm-global-get-frame-likes \ + $phase2_dir/$utt_id.sound.$x.mdl ark:- ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)' ) 2> $phase2_dir/$utt_id.compute_sound_like.$x.log || exit 1 + + merge_nonsil=false + if [ ! 
-z `perl -e "print \"true\" if ($sound_like + $speech_like < $nonsil_like)"` ]; then + merge_nonsil=true + fi + + if $merge_nonsil; then + speech_num_gauss=$speech_num_gauss_init_phase3 + sil_num_gauss=$sil_num_gauss_init_phase3 + + $cmd $phase3_dir/$utt_id.init_gmm_speech.log \ + select-top-chunks --window-size=1 \ + --selection-mask=ark:$phase3_dir/$utt_id.sil_nonsil.$x.ark --select-class=1 \ + "$feats" ark:- \| gmm-global-init-from-feats \ + --num-gauss=$speech_num_gauss --num-iters=$[speech_num_gauss+2] \ + ark:- $phase3_dir/$utt_id.speech.0.mdl || exit 1 + + $cmd $phase3_dir/$utt_id.init_gmm_silence.log \ + select-top-chunks --window-size=1 \ + --selection-mask=ark:$phase3_dir/$utt_id.sil_nonsil.$x.ark --select-class=0 \ + "$feats" ark:- \| gmm-global-init-from-feats \ + --num-gauss=$sil_num_gauss --num-iters=$[sil_num_gauss+2] \ + ark:- $phase3_dir/$utt_id.silence.0.mdl || exit 1 + + cp $phase2_dir/$utt_id.silence.$x.mdl $phase3_dir/$utt_id.silence.0.mdl || exit 1 + cp $phase2_dir/$utt_id.nonsil.$x.mdl $phase3_dir/$utt_id.speech.0.mdl || exit 1 + + x=0 + while [ $x -lt $num_iters_phase3 ]; do + ### Compute likelihoods with the current Silence and Speech + $cmd $phase3_dir/$utt_id.compute_silence_likes.$x.log \ + gmm-global-get-frame-likes $phase3_dir/$utt_id.silence.$x.mdl \ + "$feats" ark:$phase3_dir/$utt_id.silence_likes.$x.ark || exit 1 + + $cmd $phase3_dir/$utt_id.compute_speech_likes.$x.log \ + gmm-global-get-frame-likes $phase3_dir/$utt_id.speech.$x.mdl \ + "$feats" ark:$phase3_dir/$utt_id.speech_likes.$x.ark || exit 1 + + ### Get current VAD + $cmd $phase3_dir/$utt_id.get_vad.$x.log \ + loglikes-to-class \ + ark:$phase3_dir/$utt_id.silence_likes.$x.ark \ + ark:$phase3_dir/$utt_id.speech_likes.$x.ark \ + ark:$phase3_dir/$utt_id.vad.$x.ark || exit 1 + + ### Update Speech GMM + $cmd $phase3_dir/$utt_id.update_speech.$[x+1].log \ + select-top-chunks --window-size=1 \ + --selection-mask=ark:$phase3_dir/$utt_id.vad.$x.ark --select-class=1 \ + "$feats" ark:- 
\| gmm-global-acc-stats \ + $phase3_dir/$utt_id.speech.$x.mdl ark:- - \| \ + gmm-global-est-map \ + $phase3_dir/$utt_id.speech.$x.mdl - $phase3_dir/$utt_id.speech.$[x+1].mdl || exit 1 + + ### Update Silence GMM + $cmd $phase3_dir/$utt_id.update_silence.$[x+1].log \ + select-top-chunks --window-size=1 \ + --selection-mask=ark:$phase3_dir/$utt_id.vad.$x.ark --select-class=0 \ + "$feats" ark:- \| gmm-global-acc-stats \ + $phase3_dir/$utt_id.silence.$x.mdl ark:- - \| \ + gmm-global-est-map \ + $phase3_dir/$utt_id.silence.$x.mdl - $phase3_dir/$utt_id.silence.$[x+1].mdl || exit 1 + + if [ $sil_num_gauss -lt $sil_max_gauss_phase3 ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr_phase3] + fi + + if [ $speech_num_gauss -lt $speech_max_gauss_phase3 ]; then + speech_num_gauss=$[speech_num_gauss + speech_gauss_incr_phase3] + fi + + x=$[x+1] + done + + cp $phase3_dir/$utt_id.silence.$x.mdl $dir/$utt_id.silence.final.mdl + cp $phase3_dir/$utt_id.speech.$x.mdl $dir/$utt_id.speech.final.mdl + fi +done < $data/feats.scp + diff --git a/egs/sre08/v1/diarization/vad_gmm_icsi_pitch.sh b/egs/sre08/v1/diarization/vad_gmm_icsi_pitch.sh new file mode 100755 index 00000000000..6303fe397fc --- /dev/null +++ b/egs/sre08/v1/diarization/vad_gmm_icsi_pitch.sh @@ -0,0 +1,586 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. 
+ +set -e +set -u +set -o pipefail + +cmd=run.pl +stage=-1 + +## Features paramters +window_size=100 # 1s +force_ignore_energy_opts= +try_merge_sound_speech=false +do_phase1=false +smooth_mask=false + +## Phase 1 parameters +num_frames_init_silence=2000 # 20s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_sound=10000 # 100s - Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=2000 # 20s - Highest zero crossing frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=8 +sil_gauss_incr=0 +sound_gauss_incr=2 +sil_frames_incr=2000 +sound_frames_incr=10000 +sound_frames_next_incr=2000 +num_iters=5 +min_sil_variance=1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +num_frames_init_speech=10000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 +window_size_phase2_init=10 +window_size_phase2_next=1 + +num_frames_init_speech_phase2=10000 +num_frames_init_silence_phase2=2000 +num_frames_init_sound_phase2=2000 +speech_frames_incr_phase2=2000 +sil_frames_incr_phase2=2000 +sound_frames_incr_phase2=2000 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + +. path.sh +. 
parse_options.sh || exit 1 + +if [ $# -ne 4 ]; then + echo "Usage: vad_gmm_icsi.sh " + echo " e.g.: vad_gmm_icsi.sh data/rt05_eval exp/librispeech_s5/vad_model/silence.0.mdl exp/librispeech_s5/vad_model/speech.0.mdl exp/vad_rt05_eval" + exit 1 +fi + +data=$1 +init_silence_model=$2 +init_speech_model=$3 +dir=$4 + +mkdir -p $dir +tmpdir=$dir/phase1 +phase2_dir=$dir/phase2 +phase3_dir=$dir/phase3 + +mkdir -p $tmpdir +mkdir -p $phase2_dir +mkdir -p $phase3_dir + +init_model_dir=`dirname $init_speech_model` +ignore_energy_opts=`cat $init_model_dir/ignore_energy_opts` || exit 1 + +add_zero_crossing_feats=`cat $init_model_dir/add_zero_crossing_feats` || exit 1 +zc_opts= +[ -f conf/zc_vad.conf ] && zc_opts="--config=conf/zc_vad.conf" + +while IFS=$'\n' read line; do + feats="ark:echo $line | copy-feats scp:- ark:- | add-deltas ark:- ark:- |${ignore_energy_opts}" + utt_id=$(echo $line | awk '{print $1}') + echo $utt_id > $dir/$utt_id.list + + if [ -f $data/segments ]; then + $cmd $dir/log/$utt_id.extract_zero_crossings.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + extract-segments scp:$data/wav.scp - ark:- \| \ + compute-zero-crossings $zc_opts ark:- ark:$dir/$utt_id.zero_crossings.ark || exit 1 + $cmd $dir/log/$utt_id.extract_pitch.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + extract-segments scp:$data/wav.scp - ark:- \| \ + compute-kaldi-pitch-feats --config=conf/pitch.conf --frames-per-chunk=10 --simulate-first-pass-online=true \ + ark:- ark:$dir/$utt_id.kaldi_pitch.ark || exit 1 + else + $cmd $dir/log/$utt_id.extract_zero_crossings.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp \| \ + compute-zero-crossings $zc_opts scp:- ark:$dir/$utt_id.zero_crossings.ark || exit 1 + $cmd $dir/log/$utt_id.extract_pitch.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp \| \ + compute-kaldi-pitch-feats --config=conf/pitch.conf --frames-per-chunk=10 --simulate-first-pass-online=true \ + scp:- 
ark:$dir/$utt_id.kaldi_pitch.ark || exit 1 + fi + + $cmd $dir/log/$utt_id.extract_log_energies.log \ + extract-column "scp:utils/filter_scp.pl $dir/$utt_id.list $data/feats.scp |" \ + ark:$dir/$utt_id.log_energies.ark || exit 1 + + sil_num_gauss=$sil_num_gauss_init + sound_num_gauss=$sound_num_gauss_init + speech_num_gauss=$speech_num_gauss_init + num_frames_silence=$num_frames_init_silence + num_frames_sound=$num_frames_init_sound + num_frames_sound_next=$num_frames_init_sound_next + + if $add_zero_crossing_feats; then + feats="${feats} paste-feats ark:- \"ark:add-deltas ark:$dir/$utt_id.zero_crossings.ark ark:- |\" ark:- |" + fi + + + ### Compute likelihoods wrt bootstrapping models + $cmd $dir/log/$utt_id.compute_speech_like.bootstrap.log \ + gmm-global-get-frame-likes $init_speech_model \ + "${feats}" ark:$dir/$utt_id.speech_likes.bootstrap.ark || exit 1 + + $cmd $dir/log/$utt_id.compute_silence_like.bootstrap.log \ + gmm-global-get-frame-likes $init_silence_model \ + "${feats}" ark:$dir/$utt_id.silence_likes.bootstrap.ark || exit 1 + + ### Get bootstrapping VAD + $cmd $tmpdir/log/$utt_id.get_vad.bootstrap.log \ + loglikes-to-class --post=ark:$dir/$utt_id.post.bootstrap.ark \ + ark:$dir/$utt_id.silence_likes.bootstrap.ark \ + ark:$dir/$utt_id.speech_likes.bootstrap.ark \ + ark:$tmpdir/$utt_id.vad.bootstrap.ark || exit 1 + + if [ ! 
-z "$force_ignore_energy_opts" ]; then + ignore_energy_opts=$force_ignore_energy_opts + feats="ark:echo $line | copy-feats scp:- ark:- | add-deltas ark:- ark:- |${ignore_energy_opts}" + if $add_zero_crossing_feats; then + feats="${feats} paste-feats ark:- \"ark:add-deltas ark:$dir/$utt_id.zero_crossings.ark ark:- |\" ark:- |" + fi + fi + + ### Initialize Silence GMM using lowest energy chunks that were classified + ### as silence by the bootstrapping model + $cmd $tmpdir/log/$utt_id.init_silence_gmm.log \ + select-top-chunks \ + --window-size=$window_size --smoothing-window=$smoothing_window --smooth-mask=$smooth_mask \ + --selection-mask=ark:$tmpdir/$utt_id.vad.bootstrap.ark --select-class=0 \ + --select-bottom-frames=true \ + --weights=ark:$dir/$utt_id.log_energies.ark --num-select-frames=$num_frames_silence \ + "${feats}" ark:- ark:$tmpdir/$utt_id.silence_mask.0.ark \| \ + gmm-global-init-from-feats \ + --min-variance=$min_sil_variance --num-gauss=$sil_num_gauss --num-iters=$[sil_num_gauss+2] ark:- \ + $tmpdir/$utt_id.silence.0.mdl || exit 1 + + ### Initialize Sound GMM using highest zero-crossing + ### chunks that were classified + ### as silence by the bootstrapping model. + $cmd $tmpdir/log/$utt_id.init_sound_gmm.log \ + select-top-chunks \ + --window-size=$window_size --smoothing-window=$smoothing_window --smooth-mask=$smooth_mask \ + --selection-mask=ark:$tmpdir/$utt_id.vad.bootstrap.ark --select-class=0 \ + --weights="ark:extract-column ark:$dir/$utt_id.zero_crossings.ark ark:- |" --num-select-frames=$num_frames_sound \ + "${feats}" ark:- ark:$tmpdir/$utt_id.sound_mask.0.ark \| \ + gmm-global-init-from-feats \ + --min-variance=$min_sound_variance --num-gauss=$sound_num_gauss --num-iters=$[sound_num_gauss+2] ark:- \ + $tmpdir/$utt_id.sound.0.mdl || exit 1 + + ### Initialize Speech GMM using highest NCCF chunks that were classified as + ### speech by the bootstrapping model. 
+ $cmd $tmpdir/log/$utt_id.init_speech_gmm.log \ + select-top-chunks \ + --window-size=$window_size --smoothing-window=$smoothing_window --smooth-mask=$smooth_mask \ + --selection-mask=ark:$tmpdir/$utt_id.vad.bootstrap.ark --select-class=1 \ + --weights="ark:extract-column ark:$dir/$utt_id.kaldi_pitch.ark ark:- |" \ + --num-select-frames=$num_frames_init_speech \ + "${feats}" ark:- ark:$tmpdir/$utt_id.speech_mask.0.ark \| \ + gmm-global-init-from-feats \ + --num-gauss=$speech_num_gauss --num-iters=$[speech_num_gauss+2] ark:- \ + $tmpdir/$utt_id.speech.0.mdl || exit 1 + + if $do_phase1; then + + ### Compute likelihoods with the newly initialized Silence and Sound GMMs + + $cmd $tmpdir/log/$utt_id.compute_silence_likes.0.log \ + gmm-global-get-frame-likes $tmpdir/$utt_id.silence.0.mdl \ + "${feats}" ark:$tmpdir/$utt_id.silence_likes.0.ark || exit 1 + + $cmd $tmpdir/log/$utt_id.compute_sound_likes.0.log \ + gmm-global-get-frame-likes $tmpdir/$utt_id.sound.0.mdl \ + "${feats}" ark:$tmpdir/$utt_id.sound_likes.0.ark || exit 1 + + $cmd $tmpdir/log/$utt_id.compute_speech_likes.0.log \ + gmm-global-get-frame-likes $tmpdir/$utt_id.speech.0.mdl \ + "${feats}" ark:$tmpdir/$utt_id.speech_likes.0.ark || exit 1 + + ### Get initial VAD + { + loglikes-to-class --post=ark:$tmpdir/$utt_id.post.init.ark \ + ark:$tmpdir/$utt_id.silence_likes.0.ark \ + ark:$tmpdir/$utt_id.speech_likes.0.ark \ + ark:$tmpdir/$utt_id.sound_likes.0.ark ark,t:- | \ + perl -pe 's/\[(.+)]/$1/' | \ + utils/apply_map.pl -f 2- <(echo -e "0 0\n1 1\n2 0") | \ + awk '{printf $1" [ "; for (i = 2; i <= NF; i++) {printf $i" ";}; print "]"}' | \ + copy-vector ark,t:- ark:$tmpdir/$utt_id.vad.init.ark ; + } &> $tmpdir/log/$utt_id.get_vad.init.log || exit 1 + + ### Remove frames that were originally classified as speech + ### while training Silence and Sound GMMs + $cmd $tmpdir/log/$utt_id.select_feats_phase1.init.log \ + select-top-chunks \ + --window-size=$window_size --smoothing-window=$smoothing_window 
--smooth-mask=$smooth_mask \ + --selection-mask=ark:$tmpdir/$utt_id.vad.init.ark --select-class=0 \ + "$feats" ark:$tmpdir/$utt_id.feats.init.ark \ + ark:$tmpdir/$utt_id.mask.init.ark || exit 1 + + ## Select energies and zero crossings corresponding to the same selection + + $cmd $tmpdir/log/$utt_id.select_zero_crossings.init.log \ + extract-column ark:$dir/$utt_id.zero_crossings.ark ark:- \| \ + vector-extract-dims ark:- \ + ark:$tmpdir/$utt_id.mask.init.ark \ + ark:$tmpdir/$utt_id.zero_crossings.init.ark || exit 1 + + $cmd $tmpdir/log/$utt_id.select_energies.init.log \ + vector-extract-dims ark:$dir/$utt_id.log_energies.ark \ + ark:$tmpdir/$utt_id.mask.init.ark \ + ark:$tmpdir/$utt_id.energies.init.ark || exit 1 + + $cmd $tmpdir/log/$utt_id.select_vad.init.log \ + vector-extract-dims ark:$tmpdir/$utt_id.vad.init.ark \ + ark:$tmpdir/$utt_id.mask.init.ark \ + ark:$tmpdir/$utt_id.vad.0.ark || exit 1 + + $cmd $tmpdir/log/$utt_id.select_speech_likes.init.log \ + vector-extract-dims \ + ark:$tmpdir/$utt_id.speech_likes.0.ark \ + ark:$tmpdir/$utt_id.mask.init.ark \ + ark:$tmpdir/$utt_id.speech_likes.init.ark || exit 1 + + x=0 + while [ $x -le $num_iters ]; do + ### Update Silence GMM using lowest energy chunks currently classified + ### as silence + $cmd $tmpdir/log/$utt_id.update_silence_gmm.$[x+1].log \ + select-top-chunks \ + --window-size=$window_size --smoothing-window=$smoothing_window --smooth-mask=$smooth_mask \ + --selection-mask=ark:$tmpdir/$utt_id.vad.$x.ark --select-class=0 \ + --select-bottom-frames=true --weights=ark:$tmpdir/$utt_id.energies.init.ark \ + --num-select-frames=$num_frames_silence \ + ark:$tmpdir/$utt_id.feats.init.ark ark:- \ + ark:$tmpdir/$utt_id.mask_silence.$[x+1].ark \| \ + gmm-global-acc-stats \ + $tmpdir/$utt_id.silence.$x.mdl ark:- - \| \ + gmm-global-est --mix-up=$sil_num_gauss $tmpdir/$utt_id.silence.$x.mdl \ + - $tmpdir/$utt_id.silence.$[x+1].mdl || exit 1 + + ### Update Sound GMM using highest energy and highest zero crossing + ### 
chunks currently classified as silence + $cmd $tmpdir/log/$utt_id.update_sound_gmm.$[x+1].log \ + select-top-chunks \ + --window-size=$window_size --smoothing-window=$smoothing_window --smooth-mask=$smooth_mask \ + --selection-mask=ark:$tmpdir/$utt_id.vad.$x.ark --select-class=0 \ + --weights=ark:$tmpdir/$utt_id.zero_crossings.init.ark \ + --num-select-frames=$num_frames_sound \ + ark:$tmpdir/$utt_id.feats.init.ark ark:- \ + ark:$tmpdir/$utt_id.mask_sound.$[x+1].ark \| \ + gmm-global-acc-stats \ + $tmpdir/$utt_id.sound.$x.mdl ark:- - \| \ + gmm-global-est --mix-up=$sound_num_gauss $tmpdir/$utt_id.sound.$x.mdl \ + - $tmpdir/$utt_id.sound.$[x+1].mdl || exit 1 + + ### Compute likelihoods with the current Silence and Sound GMMs + $cmd $tmpdir/log/$utt_id.compute_silence_likes.$[x+1].log \ + gmm-global-get-frame-likes $tmpdir/$utt_id.silence.$[x+1].mdl \ + ark:$tmpdir/$utt_id.feats.init.ark ark:$tmpdir/$utt_id.silence_likes.$[x+1].ark || exit 1 + + $cmd $tmpdir/log/$utt_id.compute_sound_likes.$[x+1].log \ + gmm-global-get-frame-likes $tmpdir/$utt_id.sound.$[x+1].mdl \ + ark:$tmpdir/$utt_id.feats.init.ark ark:$tmpdir/$utt_id.sound_likes.$[x+1].ark || exit 1 + + ### Get new VAD predictions on the subset selected for training + ### Silence and Sound GMMs + { + loglikes-to-class --post=ark:$tmpdir/$utt_id.post.$[x+1].ark \ + ark:$tmpdir/$utt_id.silence_likes.$[x+1].ark \ + ark:$tmpdir/$utt_id.speech_likes.init.ark \ + ark:$tmpdir/$utt_id.sound_likes.$[x+1].ark ark,t:- | \ + perl -pe 's/\[(.+)]/$1/' | \ + utils/apply_map.pl -f 2- <(echo -e "0 0\n1 1\n2 0") | \ + awk '{printf $1" [ "; for (i = 2; i <= NF; i++) {printf $i" ";}; print "]"}' | \ + copy-vector ark,t:- ark:$tmpdir/$utt_id.vad.$[x+1].ark ; + } &>$tmpdir/log/$utt_id.get_vad.$[x+1].log || exit 1 + + $cmd $tmpdir/log/$utt_id.compute_silence_all_likes.$[x+1].log \ + gmm-global-get-frame-likes $tmpdir/$utt_id.silence.$[x+1].mdl \ + "$feats" ark:$tmpdir/$utt_id.silence_all_likes.$[x+1].ark || exit 1 + + $cmd 
$tmpdir/log/$utt_id.compute_sound_all_likes.$[x+1].log \ + gmm-global-get-frame-likes $tmpdir/$utt_id.sound.$[x+1].mdl \ + "$feats" ark:$tmpdir/$utt_id.sound_all_likes.$[x+1].ark || exit 1 + + $cmd $tmpdir/log/$utt_id.get_pred.$[x+1].log \ + loglikes-to-class --post=ark:$tmpdir/$utt_id.pred_post.$[x+1].ark \ + ark:$tmpdir/$utt_id.silence_all_likes.$[x+1].ark \ + ark:$tmpdir/$utt_id.speech_likes.0.ark \ + ark:$tmpdir/$utt_id.sound_all_likes.$[x+1].ark \ + ark:$tmpdir/$utt_id.pred.$[x+1].ark || exit 1 + + x=$[x+1] + if [ $sil_num_gauss -lt $sil_max_gauss ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr] + num_frames_silence=$[num_frames_silence + sil_frames_incr] + fi + if [ $sound_num_gauss -lt $sound_max_gauss ]; then + sound_num_gauss=$[sound_num_gauss + sound_gauss_incr] + num_frames_sound=$[num_frames_sound + sound_frames_incr] + num_frames_sound_next=$[num_frames_sound_next + sound_frames_next_incr] + fi + done ## Done training Silence and Speech GMMs + + #### Update Silence and Sound GMMs using new segmentation + #select-top-chunks --window-size=1 \ + # --selection-mask=ark:$phase2_dir/$utt_id.seg.init.ark --select-class=0 \ + # "$feats" ark:- | gmm-global-acc-stats \ # $tmpdir/$utt_id.silence.$x.mdl ark:- - | \ + # gmm-global-est --mix-up=$sil_num_gauss \ + # $tmpdir/$utt_id.silence.$x.mdl - $phase2_dir/$utt_id.silence.0.mdl || exit 1 + + #select-top-chunks --window-size=1 \ + # --selection-mask=ark:$phase2_dir/$utt_id.seg.init.ark --select-class=2 \ + # "$feats" ark:- | gmm-global-acc-stats \ + # $tmpdir/$utt_id.sound.$x.mdl ark:- - | \ + # gmm-global-est --mix-up=$sound_num_gauss \ + # $tmpdir/$utt_id.sound.$x.mdl - $phase2_dir/$utt_id.sound.0.mdl || exit 1 + + cp $tmpdir/$utt_id.silence.$x.mdl $phase2_dir/$utt_id.silence.0.mdl + cp $tmpdir/$utt_id.sound.$x.mdl $phase2_dir/$utt_id.sound.0.mdl + cp $tmpdir/$utt_id.speech.0.mdl $phase2_dir/$utt_id.speech.0.mdl + + else + cp $tmpdir/$utt_id.silence.0.mdl $phase2_dir/$utt_id.silence.0.mdl + cp 
$tmpdir/$utt_id.sound.0.mdl $phase2_dir/$utt_id.sound.0.mdl + cp $tmpdir/$utt_id.speech.0.mdl $phase2_dir/$utt_id.speech.0.mdl + fi + + num_frames_speech=$num_frames_init_speech_phase2 + num_frames_silence=$num_frames_init_silence_phase2 + num_frames_sound=$num_frames_init_sound_phase2 + window_size_phase2=$window_size_phase2_init + + x=0 + while [ $x -le $num_iters_phase2 ]; do + ### Compute likelihoods with the current Silence, Speech and Sound GMMs + $cmd $phase2_dir/log/$utt_id.compute_silence_likes.$x.log \ + gmm-global-get-frame-likes $phase2_dir/$utt_id.silence.$x.mdl \ + "$feats" ark:$phase2_dir/$utt_id.silence_likes.$x.ark || exit 1 + + $cmd $phase2_dir/log/$utt_id.compute_sound_likes.$x.log \ + gmm-global-get-frame-likes $phase2_dir/$utt_id.sound.$x.mdl \ + "$feats" ark:$phase2_dir/$utt_id.sound_likes.$x.ark || exit 1 + + $cmd $phase2_dir/log/$utt_id.compute_speech_likes.$x.log \ + gmm-global-get-frame-likes $phase2_dir/$utt_id.speech.$x.mdl \ + "$feats" ark:$phase2_dir/$utt_id.speech_likes.$x.ark || exit 1 + + ### Get segmentation + $cmd $phase2_dir/log/$utt_id.get_seg.$x.log \ + loglikes-to-class --post=ark:$phase2_dir/$utt_id.pred_post.$x.ark \ + ark:$phase2_dir/$utt_id.silence_likes.$x.ark \ + ark:$phase2_dir/$utt_id.speech_likes.$x.ark \ + ark:$phase2_dir/$utt_id.sound_likes.$x.ark \ + ark:$phase2_dir/$utt_id.seg.$x.ark || exit 1 + + ### Update Speech GMM + $cmd $phase2_dir/log/$utt_id.update_gmm_speech.$[x+1].log \ + select-top-chunks --window-size=$window_size_phase2 --smoothing-window=$smoothing_window --smooth-mask=$smooth_mask \ + --selection-mask=ark:$phase2_dir/$utt_id.seg.$x.ark --select-class=1 \ + --num-select-frames=$num_frames_speech \ + "$feats" ark:- ark:$phase2_dir/$utt_id.speech_mask.$[x+1].ark \| \ + gmm-global-acc-stats \ + $phase2_dir/$utt_id.speech.$x.mdl ark:- - \| \ + gmm-global-est --mix-up=$speech_num_gauss \ + $phase2_dir/$utt_id.speech.$x.mdl - $phase2_dir/$utt_id.speech.$[x+1].mdl || exit 1 + + ### Update Silence GMM + $cmd 
$phase2_dir/log/$utt_id.update_gmm_silence.$[x+1].log \ + select-top-chunks --window-size=$window_size_phase2 --smoothing-window=$smoothing_window --smooth-mask=$smooth_mask \ + --selection-mask=ark:$phase2_dir/$utt_id.seg.$x.ark --select-class=0 \ + --num-select-frames=$num_frames_silence \ + "$feats" ark:- ark:$phase2_dir/$utt_id.silence_mask.$[x+1].ark \| \ + gmm-global-acc-stats \ + $phase2_dir/$utt_id.silence.$x.mdl ark:- - \| \ + gmm-global-est --mix-up=$sil_num_gauss \ + $phase2_dir/$utt_id.silence.$x.mdl - $phase2_dir/$utt_id.silence.$[x+1].mdl || exit 1 + + ### Update Sound GMM + $cmd $phase2_dir/log/$utt_id.update_gmm_sound.$[x+1].log \ + select-top-chunks --window-size=$window_size_phase2 --smoothing-window=$smoothing_window --smooth-mask=$smooth_mask \ + --selection-mask=ark:$phase2_dir/$utt_id.seg.$x.ark --select-class=2 \ + --num-select-frames=$num_frames_sound \ + "$feats" ark:- ark:$phase2_dir/$utt_id.sound_mask.$[x+1].ark \| \ + gmm-global-acc-stats \ + $phase2_dir/$utt_id.sound.$x.mdl ark:- - \| \ + gmm-global-est --mix-up=$sound_num_gauss \ + $phase2_dir/$utt_id.sound.$x.mdl - $phase2_dir/$utt_id.sound.$[x+1].mdl || exit 1 + + if [ $sil_num_gauss -lt $sil_max_gauss_phase2 ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr_phase2] + num_frames_silence=$[num_frames_silence + sil_frames_incr_phase2] + fi + + if [ $sound_num_gauss -lt $sound_max_gauss_phase2 ]; then + sound_num_gauss=$[sound_num_gauss + sound_gauss_incr_phase2] + num_frames_sound=$[num_frames_sound + sound_frames_incr_phase2] + fi + + if [ $speech_num_gauss -lt $speech_max_gauss_phase2 ]; then + speech_num_gauss=$[speech_num_gauss + speech_gauss_incr_phase2] + num_frames_speech=$[num_frames_speech + speech_frames_incr_phase2] + fi + + if [ $x -gt $window_size_incr_iter ]; then + window_size_phase2=$window_size_phase2_next + fi + + x=$[x+1] + done ## Done training all 3 GMMs + + if ! 
$try_merge_sound_speech; then + continue; + fi + + x=$[x-1] + mkdir -p $phase3_dir/log + + { + copy-vector ark:$phase2_dir/$utt_id.seg.$x.ark ark,t:- | \ + perl -pe 's/\[(.+)]/$1/' | \ + utils/apply_map.pl -f 2- <(echo -e "0 0\n1 1\n2 1") | \ + awk '{printf $1" [ "; for (i = 2; i <= NF; i++) {printf $i" ";}; print "]"}' | \ + copy-vector ark,t:- ark:$phase3_dir/$utt_id.sil_nonsil.$x.ark; + } &> $phase3_dir/log/$utt_id.get_sil_nonsil.$x.log || exit 1 + + $cmd $phase3_dir/log/$utt_id.init_gmm_nonsil.$x.log \ + select-top-chunks --window-size=1 --smoothing-window=$smoothing_window --smooth-mask=$smooth_mask \ + --selection-mask=ark:$phase3_dir/$utt_id.sil_nonsil.$x.ark --select-class=1 \ + "$feats" ark:- ark:$phase2_dir/$utt_id.nonsil.$x.ark \| \ + gmm-global-init-from-feats \ + --num-gauss=$[sound_num_gauss + speech_num_gauss] --num-iters=20 \ + ark:- $phase2_dir/$utt_id.nonsil.$x.mdl || exit 1 + + $cmd $phase2_dir/$utt_id.compute_silence_likes.pred.$x.log \ + gmm-global-get-frame-likes $phase2_dir/$utt_id.silence.$x.mdl \ + "$feats" ark:$phase2_dir/$utt_id.silence_likes.pred.$x.ark || exit 1 + $cmd $phase2_dir/$utt_id.compute_nonsil_likes.pred.$x.log \ + gmm-global-get-frame-likes $phase2_dir/$utt_id.nonsil.$x.mdl \ + "$feats" ark:$phase2_dir/$utt_id.nonsil_likes.pred.$x.ark || exit 1 + + $cmd $phase2_dir/$utt_id.get_pred.nonsil.log \ + loglikes-to-class \ + ark:$phase2_dir/$utt_id.silence_likes.pred.$x.ark \ + ark:$phase2_dir/$utt_id.nonsil_likes.pred.$x.ark \ + ark:$phase2_dir/$utt_id.pred.nonsil.ark || exit 1 + + nonsil_like=$(select-top-chunks --window-size=1 \ + --selection-mask=ark:$phase3_dir/$utt_id.sil_nonsil.$x.ark --select-class=1 \ + "$feats" ark:- | gmm-global-get-frame-likes \ + $phase2_dir/$utt_id.nonsil.$x.mdl ark:- ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $phase2_dir/$utt_id.compute_nonsil_like.$x.log || exit 1 + + speech_like=$(select-top-chunks 
--window-size=1 \ + --selection-mask=ark:$phase2_dir/$utt_id.seg.$x.ark --select-class=1 \ + "$feats" ark:- | gmm-global-get-frame-likes \ + $phase2_dir/$utt_id.speech.$x.mdl ark:- ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)' ) 2> $phase2_dir/$utt_id.compute_speech_like.$x.log || exit 1 + + sound_like=$(select-top-chunks --window-size=1 \ + --selection-mask=ark:$phase2_dir/$utt_id.seg.$x.ark --select-class=2 \ + "$feats" ark:- | gmm-global-get-frame-likes \ + $phase2_dir/$utt_id.sound.$x.mdl ark:- ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)' ) 2> $phase2_dir/$utt_id.compute_sound_like.$x.log || exit 1 + + merge_nonsil=false + if [ ! -z `perl -e "print \"true\" if ($sound_like + $speech_like < $nonsil_like)"` ]; then + merge_nonsil=true + fi + + if $merge_nonsil; then + speech_num_gauss=$speech_num_gauss_init_phase3 + sil_num_gauss=$sil_num_gauss_init_phase3 + + $cmd $phase3_dir/$utt_id.init_gmm_speech.log \ + select-top-chunks --window-size=1 --smoothing-window=$smoothing_window --smooth-mask=$smooth_mask \ + --selection-mask=ark:$phase3_dir/$utt_id.sil_nonsil.$x.ark --select-class=1 \ + "$feats" ark:- ark:$phase3_dir/$utt_id.speech_mask.0.ark \| \ + gmm-global-init-from-feats \ + --num-gauss=$speech_num_gauss --num-iters=$[speech_num_gauss+2] \ + ark:- $phase3_dir/$utt_id.speech.0.mdl || exit 1 + + $cmd $phase3_dir/$utt_id.init_gmm_silence.log \ + select-top-chunks --window-size=1 --smoothing-window=$smoothing_window --smooth-mask=$smooth_mask \ + --selection-mask=ark:$phase3_dir/$utt_id.sil_nonsil.$x.ark --select-class=0 \ + "$feats" ark:- ark:$phase3_dir/$utt_id.silence_mask.0.ark \| \ + gmm-global-init-from-feats \ --num-gauss=$sil_num_gauss --num-iters=$[sil_num_gauss+2] \ + ark:- $phase3_dir/$utt_id.silence.0.mdl || exit 1 + + cp $phase2_dir/$utt_id.silence.$x.mdl 
$phase3_dir/$utt_id.silence.0.mdl || exit 1 + cp $phase2_dir/$utt_id.nonsil.$x.mdl $phase3_dir/$utt_id.speech.0.mdl || exit 1 + + x=0 + while [ $x -lt $num_iters_phase3 ]; do + ### Compute likelihoods with the current Silence and Speech + $cmd $phase3_dir/$utt_id.compute_silence_likes.$x.log \ + gmm-global-get-frame-likes $phase3_dir/$utt_id.silence.$x.mdl \ + "$feats" ark:$phase3_dir/$utt_id.silence_likes.$x.ark || exit 1 + + $cmd $phase3_dir/$utt_id.compute_speech_likes.$x.log \ + gmm-global-get-frame-likes $phase3_dir/$utt_id.speech.$x.mdl \ + "$feats" ark:$phase3_dir/$utt_id.speech_likes.$x.ark || exit 1 + + ### Get current VAD + $cmd $phase3_dir/$utt_id.get_vad.$x.log \ + loglikes-to-class \ + ark:$phase3_dir/$utt_id.silence_likes.$x.ark \ + ark:$phase3_dir/$utt_id.speech_likes.$x.ark \ + ark:$phase3_dir/$utt_id.vad.$x.ark || exit 1 + + ### Update Speech GMM + $cmd $phase3_dir/$utt_id.update_speech.$[x+1].log \ + select-top-chunks --window-size=1 --smoothing-window=$smoothing_window --smooth-mask=$smooth_mask \ + --selection-mask=ark:$phase3_dir/$utt_id.vad.$x.ark --select-class=1 \ + "$feats" ark:- ark:$phase3_dir/$utt_id.speech_mask.$[x+1].ark \| \ + gmm-global-acc-stats \ + $phase3_dir/$utt_id.speech.$x.mdl ark:- - \| \ + gmm-global-est \ + $phase3_dir/$utt_id.speech.$x.mdl - $phase3_dir/$utt_id.speech.$[x+1].mdl || exit 1 + + ### Update Silence GMM + $cmd $phase3_dir/$utt_id.update_silence.$[x+1].log \ + select-top-chunks --window-size=1 --smoothing-window=$smoothing_window --smooth-mask=$smooth_mask \ + --selection-mask=ark:$phase3_dir/$utt_id.vad.$x.ark --select-class=0 \ + "$feats" ark:- ark:$phase3_dir/$utt_id.silence_mask.$[x+1].ark \| \ + gmm-global-acc-stats \ + $phase3_dir/$utt_id.silence.$x.mdl ark:- - \| \ + gmm-global-est \ + $phase3_dir/$utt_id.silence.$x.mdl - $phase3_dir/$utt_id.silence.$[x+1].mdl || exit 1 + + if [ $sil_num_gauss -lt $sil_max_gauss_phase3 ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr_phase3] + fi + + if [ 
$speech_num_gauss -lt $speech_max_gauss_phase3 ]; then + speech_num_gauss=$[speech_num_gauss + speech_gauss_incr_phase3] + fi + + x=$[x+1] + done + + cp $phase3_dir/$utt_id.silence.$x.mdl $dir/$utt_id.silence.final.mdl + cp $phase3_dir/$utt_id.speech.$x.mdl $dir/$utt_id.speech.final.mdl + fi +done < $data/feats.scp + diff --git a/egs/sre08/v1/diarization/vad_gmm_icsi_vimal.sh b/egs/sre08/v1/diarization/vad_gmm_icsi_vimal.sh new file mode 100755 index 00000000000..52d1e230612 --- /dev/null +++ b/egs/sre08/v1/diarization/vad_gmm_icsi_vimal.sh @@ -0,0 +1,606 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -u +set -o pipefail + +cmd=run.pl +stage=-1 +allow_partial=true + +## Features paramters +window_size=100 # 1s +min_data=200 +frames_per_gaussian=2000 +num_bins=100 + +## Phase 1 parameters +num_frames_init_silence=2000 # 20s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_sound=10000 # 100s - Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=2000 # 20s - Highest zero crossing frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=8 +sil_gauss_incr=0 +sound_gauss_incr=2 +sample_per_gaussian=2000 +num_iters_init=3 +num_iters=5 +min_sil_variance=1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +num_frames_init_speech=100000 +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 +window_size_phase2=10 + +## Phase 3 parameters +num_frames_silence_phase3_init=2000 +num_frames_speech_phase3_init=2000 +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + +speech_to_sil_ratio=1 + +. path.sh +. 
parse_options.sh || exit 1 + +if [ $# -ne 4 ]; then + echo "Usage: vad_gmm_icsi.sh " + echo " e.g.: vad_gmm_icsi.sh data/rt05_eval exp/librispeech_s5/vad_model/silence.0.mdl exp/librispeech_s5/vad_model/speech.0.mdl exp/vad_rt05_eval" + exit 1 +fi + +data=$1 +init_silence_model=$2 +init_speech_model=$3 +dir=$4 + +mkdir -p $dir +tmpdir=$dir/phase1 +phase2_dir=$dir/phase2 +phase3_dir=$dir/phase3 + +mkdir -p $tmpdir +mkdir -p $phase2_dir +mkdir -p $phase3_dir + +init_model_dir=`dirname $init_speech_model` +ignore_energy_opts=`cat $init_model_dir/ignore_energy_opts` || exit 1 +add_zero_crossing_feats=`cat $init_model_dir/add_zero_crossing_feats` || exit 1 + +zc_opts= +[ -f conf/zc_vad.conf ] && zc_opts="--config=conf/zc_vad.conf" + +# Prepare a lang directory +if [ $stage -le -12 ]; then + mkdir -p $dir/local/dict + mkdir -p $dir/local/lm + mkdir -p $dir/local/dict_2class + mkdir -p $dir/local/lm_2class + + echo "1" > $dir/local/dict/silence_phones.txt + echo "1" > $dir/local/dict/optional_silence.txt + echo "2" > $dir/local/dict/nonsilence_phones.txt + echo "1" > $dir/local/dict_2class/silence_phones.txt + echo "1" > $dir/local/dict_2class/optional_silence.txt + echo "2" > $dir/local/dict_2class/nonsilence_phones.txt + echo "3" >> $dir/local/dict/nonsilence_phones.txt + echo -e "1 1\n2 2" > $dir/local/dict_2class/lexicon.txt + echo -e "1 1\n2 2\n3 3" > $dir/local/dict/lexicon.txt + echo -e "1\n2\n1 2" > $dir/local/dict_2class/extra_questions.txt + echo -e "1\n2\n1 2\n3\n1 3\n2 3\n1 2 3" > $dir/local/dict/extra_questions.txt + + mkdir -p $dir/lang + diarization/prepare_vad_lang.sh --num-sil-states 30 --num-nonsil-states 75 \ + $dir/local/dict $dir/local/lang $dir/lang || exit 1 + diarization/prepare_vad_lang.sh --num-sil-states 30 --num-nonsil-states 75 \ + $dir/local/dict_2class $dir/local/lang_2class $dir/lang_2class || exit 1 +fi + +feat_dim=`gmm-global-info $init_speech_model | grep "feature dimension" | awk '{print $NF}'` || exit 1 + +if [ $stage -le -11 ]; then + 
run.pl $dir/log/create_transition_model.log gmm-init-mono \ + $dir/lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 + + run.pl $dir/log/create_transition_model_2class.log gmm-init-mono \ + $dir/lang_2class/topo $feat_dim - $dir/tree_2class \| \ + copy-transition-model --binary=false - $dir/trans_2class.mdl || exit 1 + + diarization/make_vad_graph.sh --iter trans $dir/lang $dir $dir/graph || exit 1 + diarization/make_vad_graph.sh --iter trans_2class --tree tree_2class $dir/lang_2class $dir $dir/graph_2class || exit 1 +fi + +if [ $stage -le -10 ]; then + { + cat $dir/trans_2class.mdl + echo " $feat_dim 2" + gmm-global-copy --binary=false $init_silence_model - || exit 1 + gmm-global-copy --binary=false $init_speech_model - || exit 1 + } | gmm-copy - $dir/init.mdl || exit 1 +fi + +if [ $stage -le -9 ]; then + t=$speech_to_sil_ratio + lang=$dir/lang_test_${t}x + cp -r $dir/lang $lang + perl -e "print \"0 0 1 1 \" . -log(1/$[t+3]) . \"\n0 0 2 2 \". -log($t/$[t+3]). \"\n0 0 3 3 \". -log(1/$[t+3]) .\"\n0 \". -log(1/$[t+3])" | \ + fstcompile --isymbols=$lang/words.txt --osymbols=$lang/words.txt \ + --keep_isymbols=false --keep_osymbols=false \ + > $lang/G.fst || exit 1 + diarization/make_vad_graph.sh --iter trans $lang $dir $dir/graph_test_${t}x || exit 1 + + lang=$dir/lang_2class_test_${t}x + cp -r $dir/lang_2class $lang + perl -e "print \"0 0 1 1 \" . -log(1/$[t+2]) . \"\n0 0 2 2 \". -log($t/$[t+2]). \"\n0 \". 
-log(1/$[t+2])" | \ + fstcompile --isymbols=$lang/words.txt --osymbols=$lang/words.txt \ + --keep_isymbols=false --keep_osymbols=false \ + > $lang/G.fst || exit 1 + + diarization/make_vad_graph.sh --iter trans_2class --tree tree_2class $lang $dir $dir/graph_2class_test_${t}x || exit 1 +fi + + +while IFS=$'\n' read line; do + feats="ark:echo $line | apply-cmvn --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- |${ignore_energy_opts}" + + utt_id=$(echo $line | awk '{print $1}') + echo $utt_id > $dir/$utt_id.list + + if [ -f $data/segments ]; then + $cmd $dir/log/$utt_id.extract_zero_crossings.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + extract-segments scp:$data/wav.scp - ark:- \| \ + compute-zero-crossings $zc_opts ark:- ark:$dir/$utt_id.zero_crossings.ark || exit 1 + #$cmd $dir/log/$utt_id.extract_pitch.log \ + # utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + # extract-segments scp:$data/wav.scp - ark:- \| \ + # compute-kaldi-pitch-feats --config=conf/pitch.conf --frames-per-chunk=10 --simulate-first-pass-online=true \ + # ark:- ark:$dir/$utt_id.kaldi_pitch.ark || exit 1 + $cmd $dir/log/$utt_id.extract_log_energies.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + extract-segments scp:$data/wav.scp - ark:- \| \ + compute-mfcc-feats --config=conf/mfcc_vad.conf --num-ceps=1 \ + ark:- ark:- \| extract-column ark:- \ + ark:$dir/$utt_id.log_energies.ark || exit 1 + else + $cmd $dir/log/$utt_id.extract_zero_crossings.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp \| \ + compute-zero-crossings $zc_opts scp:- ark:$dir/$utt_id.zero_crossings.ark || exit 1 + #$cmd $dir/log/$utt_id.extract_pitch.log \ + # utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp \| \ + # compute-kaldi-pitch-feats --config=conf/pitch.conf --frames-per-chunk=10 --simulate-first-pass-online=true \ + # scp:- ark:$dir/$utt_id.kaldi_pitch.ark || exit 1 + $cmd $dir/log/$utt_id.extract_log_energies.log \ + utils/filter_scp.pl 
$dir/$utt_id.list $data/wav.scp \| \ + compute-mfcc-feats --config=conf/mfcc_vad.conf --num-ceps=1 \ + scp:- ark:- \| extract-column ark:- \ + ark:$dir/$utt_id.log_energies.ark || exit 1 + fi + + sil_num_gauss=$sil_num_gauss_init + sound_num_gauss=$sound_num_gauss_init + speech_num_gauss=$speech_num_gauss_init + + if $add_zero_crossing_feats; then + feats="${feats} paste-feats ark:- ark:$dir/$utt_id.zero_crossings.ark ark:- |" + fi + + feats="${feats} add-deltas ark:- ark:- |" + + # Get VAD: 0 for silence, 1 for speech and 2 for sound + $cmd $dir/log/$utt_id.get_vad.bootstrap.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph_2class/words.txt \ + $dir/init.mdl $dir/graph_2class/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| ali-to-pdf $dir/init.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$tmpdir/$utt_id.vad.bootstrap.ark || exit 1 + + cp $tmpdir/$utt_id.vad.bootstrap.ark $tmpdir/$utt_id.seg.0.ark + + x=0 + goto_phase3=false + + while [ $x -lt $num_iters ]; do + num_frames_silence=$[num_frames_init_silence + sil_num_gauss * frames_per_gaussian ] + num_frames_sound=$[num_frames_init_sound + 5 * sound_num_gauss * frames_per_gaussian ] + num_frames_sound_next=$[num_frames_init_sound_next + sound_num_gauss * frames_per_gaussian ] + + if [ $x -lt 3 ]; then + $cmd $tmpdir/log/$utt_id.select_top.first.$[x+1].log \ + segmentation-copy --filter-label=0 \ + --filter-rspecifier=ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + ark:$tmpdir/$utt_id.seg.$x.ark ark:- \| \ + segmentation-select-top --num-bins=$num_bins \ + --merge-labels=0:2 --merge-dst-label=0 \ + --num-top-frames=$num_frames_sound --num-bottom-frames=$num_frames_silence \ + --top-select-label=2 --bottom-select-label=0 --reject-label=1000 \ + --remove-rejected-frames=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:- ark:$dir/$utt_id.log_energies.ark \ + ark:$tmpdir/$utt_id.seg.first.$[x+1].ark || exit 1 + + $cmd 
$tmpdir/log/$utt_id.select_top.$[x+1].log \ + segmentation-select-top --num-bins=$num_bins --src-label=2 \ + --num-top-frames=$num_frames_sound_next --num-bottom-frames=-1 \ + --top-select-label=2 --bottom-select-label=-1 --reject-label=1001 \ + --remove-rejected-frames=true --select-from-full-histogram=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:$tmpdir/$utt_id.seg.first.$[x+1].ark "ark:extract-column ark:$dir/$utt_id.zero_crossings.ark ark:- |" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark || exit 1 + else + $cmd $tmpdir/log/$utt_id.select_top.$[x+1].log \ + segmentation-copy --filter-rspecifier=ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + --filter-label=0 ark:$tmpdir/$utt_id.seg.$x.ark \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark || exit 1 + fi + + if [ $x -eq 0 ]; then + { + cat $dir/trans.mdl; + echo " $feat_dim 3"; + select-feats-from-segmentation --select-label=0 "$feats" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark ark:- | \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[sil_num_gauss+1] --num-gauss-init=1 --num-gauss=$sil_num_gauss \ + ark:- - || exit 1 + gmm-global-copy --binary=false $init_speech_model -; + select-feats-from-segmentation --select-label=2 "$feats" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark ark:- | \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[sound_num_gauss+1] --num-gauss-init=1 --num-gauss=$sound_num_gauss \ + ark:- - || exit 1 + } 2> $tmpdir/log/$utt_id.init_gmm.log | \ + gmm-copy - $tmpdir/$utt_id.$[x+1].mdl 2>> $tmpdir/log/$utt_id.init_gmm.log || exit 1 + else + #$cmd $tmpdir/log/$utt_id.gmm_update.$[x+1].log \ + # gmm-est-segmentation --pdfs=0:2 \ + # --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n2 $sound_num_gauss\" |" \ + # $tmpdir/$utt_id.$x.mdl "$feats" \ + # ark:$tmpdir/$utt_id.seg.second.$[x+1].ark \ + # $tmpdir/$utt_id.$[x+1].mdl || exit 1 + $cmd $tmpdir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation --pdfs=0:2 \ + 
--mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n2 $sound_num_gauss\" |" \ + $tmpdir/$utt_id.$x.mdl "$feats" \ + ark:$tmpdir/$utt_id.seg.second.$[x+1].ark \ + $tmpdir/$utt_id.$[x+1].mdl || exit 1 + fi + + $cmd $tmpdir/log/$utt_id.get_seg.$[x+1].log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph/words.txt \ + $tmpdir/$utt_id.$[x+1].mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $tmpdir/$utt_id.$[x+1].mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$tmpdir/$utt_id.seg.$[x+1].ark || exit 1 + + if [ $sil_num_gauss -lt $sil_max_gauss ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr] + fi + + if [ $sound_num_gauss -lt $sound_max_gauss ]; then + sound_num_gauss=$[sound_num_gauss + sound_gauss_incr] + fi + + x=$[x+1] + done ## Done training Silence and Speech GMMs + + $cmd $phase2_dir/log/$utt_id.init_speech.log \ + segmentation-copy --filter-rspecifier=ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + --filter-label=1 ark:$tmpdir/$utt_id.seg.$num_iters.ark ark:- \| \ + select-feats-from-segmentation --select-label=1 "$feats" \ + ark:- ark:- \| \ + gmm-global-init-from-feats \ + --num-iters=$[speech_num_gauss+1] --num-gauss-init=1 --num-gauss=$speech_num_gauss \ + ark:- $phase2_dir/$utt_id.speech.0.mdl + if [ $? -eq 0 ]; then + num_selected_speech=$(grep "Processed .* segmentations; selected" $phase2_dir/log/$utt_id.init_speech.log | perl -pe 's/.+selected (\S+) out of \S+ frames/$1/') + if [ $num_selected_speech -lt $min_data ]; then + echo "Insufficient frames for speech at the end of phase 1. $num_selected_speech < $min_data. See $phase2_dir/log/$utt_id.init_speech.log. Going to phase 3." + goto_phase3=true + fi + else + echo "Failed to find any data for speech at the end of phase 1. See $phase2_dir/log/$utt_id.init_speech.log. Going to phase 3." 
+ goto_phase3=true + fi + + if $goto_phase3; then + rm -f $dir/$utt_id.current_seg.ark + ln -s $tmpdir/$utt_id.seg.$x.ark $dir/$utt_id.current_seg.ark + fi + + if ! $goto_phase3; then + $cmd $phase2_dir/log/$utt_id.init_gmm.log \ + gmm-init-pdf-from-global $tmpdir/$utt_id.$num_iters.mdl 1 \ + $phase2_dir/$utt_id.speech.0.mdl $phase2_dir/$utt_id.0.mdl || exit 1 + + x=0 + while [ $x -lt $num_iters_phase2 ]; do + if [ $sil_num_gauss -lt $sil_max_gauss_phase2 ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr_phase2] + fi + + if [ $sound_num_gauss -lt $sound_max_gauss_phase2 ]; then + sound_num_gauss=$[sound_num_gauss + sound_gauss_incr_phase2] + fi + + if [ $speech_num_gauss -lt $speech_max_gauss_phase2 ]; then + speech_num_gauss=$[speech_num_gauss + speech_gauss_incr_phase2] + fi + + $cmd $phase2_dir/log/$utt_id.get_seg.$x.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph/words.txt \ + $phase2_dir/$utt_id.$x.mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $phase2_dir/$utt_id.$x.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase2_dir/$utt_id.seg.$x.ark || exit 1 + + #$cmd $phase2_dir/log/$utt_id.gmm_update.$[x+1].log \ + # gmm-est-segmentation \ + # --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n1 $speech_num_gauss\n2 $sound_num_gauss\" |" \ + # $phase2_dir/$utt_id.$x.mdl "$feats" \ + # ark:$phase2_dir/$utt_id.seg.$x.ark \ + # $phase2_dir/$utt_id.$[x+1].mdl || exit 1 + $cmd $phase2_dir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation \ + --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n1 $speech_num_gauss\n2 $sound_num_gauss\" |" \ + $phase2_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase2_dir/$utt_id.seg.$x.ark \ + $phase2_dir/$utt_id.$[x+1].mdl || exit 1 + + x=$[x+1] + done ## Done training all 3 GMMs + cp $phase2_dir/$utt_id.$x.mdl $dir/$utt_id.final.mdl + rm -f $dir/$utt_id.graph_final + ln -s graph_test_${speech_to_sil_ratio}x $dir/$utt_id.graph_final + + $cmd 
$phase2_dir/log/$utt_id.get_seg.$x.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph/words.txt \ + $phase2_dir/$utt_id.$x.mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $phase2_dir/$utt_id.$x.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase2_dir/$utt_id.seg.$x.ark || exit 1 + + mkdir -p $phase3_dir/log + + $cmd $phase2_dir/log/$utt_id.init_nonsil.log \ + segmentation-copy --merge-labels=1:2 --merge-dst-label=1 \ + ark:$phase2_dir/$utt_id.seg.$x.ark ark:- \| \ + select-feats-from-segmentation --select-label=1 \ + "$feats" ark:- ark:- \| \ + gmm-global-init-from-feats \ + --num-iters=$[sound_num_gauss + speech_num_gauss + 1] \ + --num-gauss-init=1 \ + --num-gauss=$[sound_num_gauss + speech_num_gauss] ark:- \ + $phase2_dir/$utt_id.$x.nonsil.mdl || exit 1 + + $cmd $phase2_dir/log/$utt_id.select_speech_feats.$x.log \ + select-feats-from-segmentation --select-label=1 \ + "$feats" ark:$phase2_dir/$utt_id.seg.$x.ark \ + ark:$phase2_dir/$utt_id.speech_feats.$x.ark + + if $goto_phase3; then + rm -f $dir/$utt_id.current_seg.ark + ln -s $phase2_dir/$utt_id.seg.$x.ark $dir/$utt_id.current_seg.ark + fi + + if [ $? -eq 0 ]; then + num_selected_speech=$(grep "Processed .* segmentations; selected" $phase2_dir/log/$utt_id.select_speech_feats.$x.log | perl -pe 's/.+selected (\S+) out of \S+ frames/$1/') + if [ $num_selected_speech -lt $min_data ]; then + echo "Insufficient frames for speech at the end of phase 2. $num_selected_speech < $min_data. See $phase2_dir/log/$utt_id.select_speech_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + else + echo "Failed to find any data for speech at the end of phase 1. See $phase2_dir/log/$utt_id.select_speech_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + + if ! 
$goto_phase3; then + speech_like=$(gmm-global-get-frame-likes \ + "gmm-extract-pdf $phase2_dir/$utt_id.$x.mdl 1 - |" \ + ark:$phase2_dir/$utt_id.speech_feats.$x.ark ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $phase2_dir/$utt_id.compute_speech_like.$x.log || exit 1 + + $cmd $phase2_dir/log/$utt_id.select_sound_feats.$x.log \ + select-feats-from-segmentation --select-label=2 \ + "$feats" ark:$phase2_dir/$utt_id.seg.$x.ark \ + ark:$phase2_dir/$utt_id.sound_feats.$x.ark + + if [ $? -eq 0 ]; then + num_selected_sound=$(grep "Processed .* segmentations; selected" $phase2_dir/log/$utt_id.select_sound_feats.$x.log | perl -pe 's/.+selected (\S+) out of \S+ frames/$1/') + if [ $num_selected_sound -lt $min_data ]; then + echo "Insufficient frames for sound at the end of phase 2. $num_selected_sound < $min_data. See $phase2_dir/log/$utt_id.select_sound_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + else + echo "Failed to find any data for sound at the end of phase 1. See $phase2_dir/log/$utt_id.select_sound_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + fi + + if ! $goto_phase3; then + sound_like=$(gmm-global-get-frame-likes \ + "gmm-extract-pdf $phase2_dir/$utt_id.$x.mdl 1 - |" \ + ark:$phase2_dir/$utt_id.sound_feats.$x.ark ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $phase2_dir/$utt_id.compute_sound_like.$x.log || exit 1 + + nonsil_like=$(select-feats-from-segmentation --merge-labels=1:2 --select-label=1 \ + "$feats" ark:$phase2_dir/$utt_id.seg.$x.ark ark:- | \ + gmm-global-get-frame-likes \ + $phase2_dir/$utt_id.$x.nonsil.mdl ark:- ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $phase2_dir/$utt_id.compute_nonsil_like.$x.log || exit 1 + + if [ ! 
-z `perl -e "print \"true\" if ($sound_like + $speech_like < $nonsil_like)"` ]; then + goto_phase3=true + fi + fi + fi + + if $goto_phase3; then + speech_num_gauss=$speech_num_gauss_init_phase3 + sil_num_gauss=$sil_num_gauss_init_phase3 + + $cmd $phase3_dir/log/$utt_id.compute_silence_likes.bootstrap.log \ + gmm-global-get-frame-likes $init_silence_model "$feats" \ + ark:$dir/$utt_id.silence_log_likes.bootstrap.ark || exit 1 + + $cmd $phase3_dir/log/$utt_id.compute_speech_likes.bootstrap.log \ + gmm-global-get-frame-likes $init_speech_model "$feats" \ + ark:$dir/$utt_id.speech_log_likes.bootstrap.ark || exit 1 + + { + cat $dir/trans_2class.mdl; + echo " $feat_dim 2"; + segmentation-select-top --num-bins=$num_bins \ + --src-label=0 --num-top-frames=$[200 * frames_per_gaussian] \ + --top-select-label=0 --bottom-select-label=-1 \ + --reject-label=1000 --select-above-mean=true \ + --remove-rejected-frames=true --select-from-full-histogram=true \ + --window-size=1 --min-window-remainder=1 \ + ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + ark:$dir/$utt_id.silence_log_likes.bootstrap.ark ark:- | \ + select-feats-from-segmentation --select-label=0 "$feats" ark:- ark:- | \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[sil_num_gauss+1] --num-gauss-init=1 --num-gauss=$sil_num_gauss \ + ark:- - || exit 1 + select-feats-from-segmentation --select-label=1 "$feats" \ + ark:$tmpdir/$utt_id.vad.bootstrap.ark ark:- | \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[speech_num_gauss+1] --num-gauss-init=1 --num-gauss=$speech_num_gauss \ + ark:- - || exit 1 + } 2> $phase3_dir/log/$utt_id.check_gmm.log | \ + gmm-copy - $phase3_dir/$utt_id.check.mdl 2>> $phase3_dir/log/$utt_id.check_gmm.log + + $cmd $phase3_dir/log/$utt_id.get_seg.check.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph_2class/words.txt \ + $phase3_dir/$utt_id.check.mdl $dir/graph_2class/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf 
$phase3_dir/$utt_id.check.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase3_dir/$utt_id.seg.check.ark || exit 1 + + num_frames_speech=$(select-feats-from-segmentation --select-label=1 "$feats" \ + ark:$phase3_dir/$utt_id.seg.check.ark ark:- | \ + feat-to-len ark:- ark,t:- | awk '{i+=$2} END{print i}') + + phase3_done=false + if [ $num_frames_speech -lt $min_data ]; then + phase3_done=true + fi + + if ! $phase3_done; then + x=0 + + $cmd $phase3_dir/log/$utt_id.init_silence_gmm.log \ + segmentation-select-top --num-bins=$num_bins \ + --src-label=0 --num-top-frames=$num_frames_silence_phase3_init \ + --top-select-label=0 --bottom-select-label=-1 \ + --reject-label=1000 --select-above-mean=true \ + --remove-rejected-frames=true --select-from-full-histogram=true \ + --window-size=1 --min-window-remainder=1 \ + ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + ark:$dir/$utt_id.silence_log_likes.bootstrap.ark ark:- \| \ + select-feats-from-segmentation --select-label=0 "$feats" ark:- ark:- \| \ + gmm-global-init-from-feats \ + --num-iters=$[sil_num_gauss+1] --num-gauss-init=1 --num-gauss=$sil_num_gauss \ + ark:- - \| \ + gmm-init-pdf-from-global $dir/init.mdl 0 - \ + $phase3_dir/$utt_id.tmp.mdl || exit 1 + + $cmd $phase3_dir/log/$utt_id.init_speech_gmm.log \ + segmentation-select-top --num-bins=$num_bins \ + --src-label=1 --num-top-frames=$num_frames_speech_phase3_init \ + --top-select-label=1 --bottom-select-label=-1 \ + --reject-label=1000 --select-above-mean=true \ + --remove-rejected-frames=true --select-from-full-histogram=true \ + --window-size=1 --min-window-remainder=1 \ + ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + ark:$dir/$utt_id.speech_log_likes.bootstrap.ark ark:- \| \ + select-feats-from-segmentation --select-label=1 "$feats" ark:- ark:- \| \ + gmm-global-init-from-feats \ + --num-iters=$[speech_num_gauss+1] --num-gauss-init=1 --num-gauss=$speech_num_gauss \ + ark:- - \| \ + gmm-init-pdf-from-global $phase3_dir/$utt_id.tmp.mdl 1 - \ + 
$phase3_dir/$utt_id.0.mdl || exit 1 + + x=0 + while [ $x -lt $num_iters_phase3 ]; do + $cmd $phase3_dir/log/$utt_id.get_seg.$x.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph_2class/words.txt \ + $phase3_dir/$utt_id.$x.mdl $dir/graph_2class/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $phase3_dir/$utt_id.$x.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase3_dir/$utt_id.seg.$x.ark || exit 1 + + #$cmd $phase3_dir/log/$utt_id.gmm_update.$[x+1].log \ + # gmm-est-segmentation \ + # --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n1 $speech_num_gauss\" |" \ + # $phase3_dir/$utt_id.$x.mdl "$feats" \ + # ark:$phase3_dir/$utt_id.seg.$x.ark \ + # $phase3_dir/$utt_id.$[x+1].mdl || exit 1 + $cmd $phase3_dir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation \ + --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n1 $speech_num_gauss\" |" \ + $phase3_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase3_dir/$utt_id.seg.$x.ark \ + $phase3_dir/$utt_id.$[x+1].mdl || exit 1 + + if [ $sil_num_gauss -lt $sil_max_gauss_phase3 ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr_phase3] + fi + + if [ $speech_num_gauss -lt $speech_max_gauss_phase3 ]; then + speech_num_gauss=$[speech_num_gauss + speech_gauss_incr_phase3] + fi + + x=$[x+1] + done ## Done training all 2 GMMs + + cp $phase3_dir/$utt_id.$x.mdl $dir/$utt_id.final.mdl + rm -f $dir/$utt_id.graph_final + ln -s graph_2class_test_${speech_to_sil_ratio}x $dir/$utt_id.graph_final + else + echo "Not going to phase3" + fi + fi + + $cmd $dir/log/$utt_id.get_seg.final.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/$utt_id.graph_final/words.txt \ + $dir/$utt_id.final.mdl $dir/$utt_id.graph_final/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $dir/$utt_id.final.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$dir/$utt_id.vad.final.ark || exit 1 + +done < $data/feats.scp diff --git 
a/egs/sre08/v1/diarization/vad_gmm_ntu.sh b/egs/sre08/v1/diarization/vad_gmm_ntu.sh new file mode 100755 index 00000000000..ae67c842288 --- /dev/null +++ b/egs/sre08/v1/diarization/vad_gmm_ntu.sh @@ -0,0 +1,297 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -u +set -e +set -o pipefail + +stage=-1 + +## Features paramters +window_size=100 # 1s +filter_using_zero_crossings=true +ignore_energy_opts="-1" + +## Phase 1 parameters +num_frames_silence_init=2000 # 20s - Lowest energy frames selected to initialize Silence GMM +sil_num_gauss_init=2 +sil_max_gauss=2 +sil_gauss_incr=0 +silence_frames_incr=2000 +num_iters=5 +min_sil_variance=1 +min_speech_variance=0 + +## Phase 2 parameters +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 +window_size_phase2=10 + +. path.sh +. parse_options.sh || exit 1 + +if [ $# -ne 4 ]; then + echo "Usage: vad_gmm_icsi.sh " + echo " e.g.: vad_gmm_icsi.sh data/rt05_eval exp/librispeech_s5/vad_model/silence.0.mdl exp/librispeech_s5/vad_model/speech.0.mdl exp/vad_rt05_eval" + exit 1 +fi + +data=$1 +init_silence_model=$2 +init_speech_model=$3 +dir=$4 + +mkdir -p $dir +tmpdir=$dir/phase1 +phase2_dir=$dir/phase2 +phase3_dir=$dir/phase3 + +mkdir -p $tmpdir +mkdir -p $phase2_dir +mkdir -p $phase3_dir + +init_model_dir=`dirname $init_speech_model` +if [ "$ignore_energy_opts" == "-1" ]; then + ignore_energy_opts=`cat $init_model_dir/ignore_energy_opts` || exit 1 +fi +add_zero_crossing_feats=`cat $init_model_dir/add_zero_crossing_feats` || exit 1 + +zc_opts= +[ -f conf/zc_vad.conf ] && zc_opts="--config=conf/zc_vad.conf" + +while IFS=$'\n' read line; do + feats="ark:echo $line | copy-feats scp:- ark:- | add-deltas ark:- ark:- |${ignore_energy_opts}" + utt_id=$(echo $line | awk '{print $1}') + echo $utt_id > $dir/$utt_id.list + + if $add_zero_crossing_feats || $filter_using_zero_crossings; then + if [ -f $data/segments ]; then + 
utils/filter_scp.pl $dir/$utt_id.list $data/segments | \ + extract-segments scp:$data/wav.scp - ark:- | \ + compute-zero-crossings $zc_opts ark:- ark:$dir/$utt_id.zero_crossings.ark || exit 1 + else + utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp | \ + compute-zero-crossings $zc_opts scp:- ark:$dir/$utt_id.zero_crossings.ark || exit 1 + fi + fi + + extract-column "scp:utils/filter_scp.pl $dir/$utt_id.list $data/feats.scp |" \ + ark:$dir/$utt_id.log_energies.ark || exit 1 + + sil_num_gauss=$sil_num_gauss_init + speech_num_gauss=$speech_num_gauss_init + num_frames_silence=$num_frames_silence_init + + if $add_zero_crossing_feats; then + feats="${feats} paste-feats ark:- \"ark:add-deltas ark:$dir/$utt_id.zero_crossings.ark ark:- |\" ark:- |" + fi + + ### Compute likelihoods wrt bootstrapping models + gmm-global-get-frame-likes $init_speech_model \ + "${feats}" ark:$dir/$utt_id.speech_likes.bootstrap.ark || exit 1 + + gmm-global-get-frame-likes $init_silence_model \ + "${feats}" ark:$dir/$utt_id.silence_likes.bootstrap.ark || exit 1 + + ### Get bootstrapping VAD + loglikes-to-class --weights=ark:$dir/$utt_id.post.bootstrap.ark \ + ark:$dir/$utt_id.silence_likes.bootstrap.ark \ + ark:$dir/$utt_id.speech_likes.bootstrap.ark \ + ark:$tmpdir/$utt_id.vad.bootstrap.ark || exit 1 + + ### Initialize Silence GMM using lowest energy chunks that were classified + ### as silence by the bootstrapping model + if ! 
$filter_using_zero_crossings; then + select-top-chunks \ + --window-size=$window_size \ + --selection-mask=ark:$tmpdir/$utt_id.vad.bootstrap.ark --select-class=0 \ + --select-bottom-frames=true \ + --weights=ark:$dir/$utt_id.log_energies.ark --num-select-frames=$num_frames_silence \ + "${feats}" ark:- ark:$tmpdir/$utt_id.mask.0.ark | gmm-global-init-from-feats \ + --min-variance=$min_sil_variance --num-gauss=$sil_num_gauss --num-iters=$[sil_num_gauss] ark:- \ + $tmpdir/$utt_id.silence.0.mdl || exit 1 + else + select-top-chunks \ + --window-size=$window_size \ + --selection-mask=ark:$tmpdir/$utt_id.vad.bootstrap.ark --select-class=0 \ + --weights="ark:extract-column ark:$dir/$utt_id.zero_crossings.ark ark:- |" --num-select-frames=$num_frames_silence \ + "${feats}" ark:- ark:$tmpdir/$utt_id.mask.0.ark | gmm-global-init-from-feats \ + --min-variance=$min_sil_variance --num-gauss=$sil_num_gauss --num-iters=$[sil_num_gauss] ark:- \ + $tmpdir/$utt_id.silence.0.mdl || exit 1 + fi + + gmm-global-get-frame-likes $tmpdir/$utt_id.silence.0.mdl \ + "${feats}" ark:$tmpdir/$utt_id.silence_likes.0.ark || exit 1 + + ### Get initial VAD + loglikes-to-class --weights=ark:$tmpdir/$utt_id.post.init.ark \ + ark:$tmpdir/$utt_id.silence_likes.0.ark \ + ark:$dir/$utt_id.speech_likes.bootstrap.ark \ + ark:$tmpdir/$utt_id.vad.init.ark || exit 1 + + ### Remove frames that were originally classified as speech + ### while training Silence and Sound GMMs + select-top-chunks \ + --window-size=$window_size \ + --selection-mask=ark:$tmpdir/$utt_id.vad.bootstrap.ark --select-class=0 \ + "$feats" ark:$tmpdir/$utt_id.feats.init.ark \ + ark:$tmpdir/$utt_id.mask.init.ark || exit 1 + + #select-top-chunks \ + # --window-size=$window_size \ + # --selection-mask=ark:$tmpdir/$utt_id.vad.init.ark --select-class=0 \ + # "$feats" ark:$tmpdir/$utt_id.feats.init.ark \ + # ark:$tmpdir/$utt_id.mask.init.ark || exit 1 + + ## Select energies and zero crossings corresponding to the same selection + + extract-column 
ark:$dir/$utt_id.zero_crossings.ark ark:- | \ + vector-extract-dims ark:- \ + ark:$tmpdir/$utt_id.mask.init.ark \ + ark:$tmpdir/$utt_id.zero_crossings.init.ark || exit 1 + + vector-extract-dims ark:$dir/$utt_id.log_energies.ark \ + ark:$tmpdir/$utt_id.mask.init.ark \ + ark:$tmpdir/$utt_id.energies.init.ark || exit 1 + + vector-extract-dims ark:$tmpdir/$utt_id.vad.init.ark \ + ark:$tmpdir/$utt_id.mask.init.ark \ + ark:$tmpdir/$utt_id.vad.0.ark || exit 1 + + vector-extract-dims \ + ark:$dir/$utt_id.speech_likes.bootstrap.ark \ + ark:$tmpdir/$utt_id.mask.init.ark \ + ark:$tmpdir/$utt_id.speech_likes.init.ark || exit 1 + + x=0 + while [ $x -le $num_iters ]; do + ### Update Silence GMM using lowest energy chunks currently classified + ### as silence + + if ! $filter_using_zero_crossings; then + select-top-chunks \ + --window-size=$window_size \ + --selection-mask=ark:$tmpdir/$utt_id.vad.$x.ark --select-class=0 \ + --select-bottom-frames=true --weights=ark:$tmpdir/$utt_id.energies.init.ark \ + --num-select-frames=$num_frames_silence \ + ark:$tmpdir/$utt_id.feats.init.ark ark:- | \ + gmm-global-acc-stats \ + $tmpdir/$utt_id.silence.$x.mdl ark:- - | \ + gmm-global-est --mix-up=$sil_num_gauss $tmpdir/$utt_id.silence.$x.mdl \ + - $tmpdir/$utt_id.silence.$[x+1].mdl || exit 1 + else + select-top-chunks \ + --window-size=$window_size \ + --selection-mask=ark:$tmpdir/$utt_id.vad.$x.ark --select-class=0 \ + --weights=ark:$tmpdir/$utt_id.zero_crossings.init.ark \ + --num-select-frames=$num_frames_silence \ + ark:$tmpdir/$utt_id.feats.init.ark ark:- | \ + gmm-global-acc-stats \ + $tmpdir/$utt_id.silence.$x.mdl ark:- - | \ + gmm-global-est --mix-up=$sil_num_gauss $tmpdir/$utt_id.silence.$x.mdl \ + - $tmpdir/$utt_id.silence.$[x+1].mdl || exit 1 + fi + + ### Compute likelihoods with the current Silence and Sound GMMs + gmm-global-get-frame-likes $tmpdir/$utt_id.silence.$[x+1].mdl \ + ark:$tmpdir/$utt_id.feats.init.ark ark:$tmpdir/$utt_id.silence_likes.$[x+1].ark || exit 1 + + ### Get 
new VAD predictions on the subset selected for training + ### Silence and Sound GMMs + loglikes-to-class --weights=ark:$tmpdir/$utt_id.post.$[x+1].ark \ + ark:$tmpdir/$utt_id.silence_likes.$[x+1].ark \ + ark:$tmpdir/$utt_id.speech_likes.init.ark \ + ark:$tmpdir/$utt_id.vad.$[x+1].ark || exit 1 + + gmm-global-get-frame-likes $tmpdir/$utt_id.silence.$[x+1].mdl \ + "$feats" ark:- | \ + loglikes-to-class --weights=ark:$tmpdir/$utt_id.pred_post.$[x+1].ark ark:- \ + ark:$dir/$utt_id.speech_likes.bootstrap.ark \ + ark:$tmpdir/$utt_id.pred.$[x+1].ark || exit 1 + + x=$[x+1] + if [ $sil_num_gauss -lt $sil_max_gauss ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr] + num_frames_silence=$[num_frames_silence + silence_frames_incr] + fi + done ## Done training Silence and Speech GMMs + + ### Compute likelihoods with the current Silence and Sound GMMs + gmm-global-get-frame-likes $tmpdir/$utt_id.silence.$x.mdl \ + "$feats" ark:$phase2_dir/$utt_id.silence_likes.init.ark || exit 1 + + ### Compute initial segmentation for phase 2 training + loglikes-to-class --weights=ark:$phase2_dir/$utt_id.post.init.ark \ + ark:$phase2_dir/$utt_id.silence_likes.init.ark \ + ark:$dir/$utt_id.speech_likes.bootstrap.ark \ + ark:$phase2_dir/$utt_id.seg.init.ark || exit 1 + + ### Initialize Speech GMM + select-top-chunks --window-size=1 \ + --selection-mask=ark:$phase2_dir/$utt_id.seg.init.ark --select-class=1 \ + "$feats" ark:- | gmm-global-init-from-feats --min-variance=$min_speech_variance \ + --num-gauss=$speech_num_gauss --num-iters=$[speech_num_gauss+2] \ + ark:- $phase2_dir/$utt_id.speech.0.mdl || exit 1 + + gmm-global-get-frame-likes $phase2_dir/$utt_id.speech.0.mdl \ + "$feats" ark:$phase2_dir/$utt_id.speech_likes.init.ark || exit 1 + + loglikes-to-class --weights=ark:$phase2_dir/$utt_id.pred.post.init.ark \ + ark:$phase2_dir/$utt_id.silence_likes.init.ark \ + ark:$phase2_dir/$utt_id.speech_likes.init.ark \ + ark:$phase2_dir/$utt_id.pred.init.ark || exit 1 + + cp 
$tmpdir/$utt_id.silence.$x.mdl $phase2_dir/$utt_id.silence.0.mdl || exit 1 + + x=0 + while [ $x -le $num_iters_phase2 ]; do + ### Compute likelihoods with the current Silence, Speech and Sound GMMs + gmm-global-get-frame-likes $phase2_dir/$utt_id.silence.$x.mdl \ + "$feats" ark:$phase2_dir/$utt_id.silence_likes.$x.ark || exit 1 + + gmm-global-get-frame-likes $phase2_dir/$utt_id.speech.$x.mdl \ + "$feats" ark:$phase2_dir/$utt_id.speech_likes.$x.ark || exit 1 + + ### Get segmentation + loglikes-to-class --weights=ark:$phase2_dir/$utt_id.pred.$x.ark \ + ark:$phase2_dir/$utt_id.silence_likes.$x.ark \ + ark:$phase2_dir/$utt_id.speech_likes.$x.ark \ + ark:$phase2_dir/$utt_id.seg.$x.ark || exit 1 + + ### Update Speech GMM + select-top-chunks --window-size=$window_size_phase2 \ + --selection-mask=ark:$phase2_dir/$utt_id.seg.$x.ark --select-class=1 \ + "$feats" ark:- | gmm-global-acc-stats \ + $phase2_dir/$utt_id.speech.$x.mdl ark:- - | \ + gmm-global-est --mix-up=$speech_num_gauss \ + $phase2_dir/$utt_id.speech.$x.mdl - $phase2_dir/$utt_id.speech.$[x+1].mdl || exit 1 + + cp $phase2_dir/$utt_id.silence.$x.mdl $phase2_dir/$utt_id.silence.$[x+1].mdl + ### Update Silence GMM + #select-top-chunks --window-size=$window_size_phase2 \ + # --selection-mask=ark:$phase2_dir/$utt_id.seg.$x.ark --select-class=0 \ + # "$feats" ark:- | gmm-global-acc-stats \ + # $phase2_dir/$utt_id.silence.$x.mdl ark:- - | \ + # gmm-global-est --mix-up=$sil_num_gauss --min-gaussian-occupancy=100 \ + # $phase2_dir/$utt_id.silence.$x.mdl - $phase2_dir/$utt_id.silence.$[x+1].mdl || exit 1 + + if [ $sil_num_gauss -lt $sil_max_gauss_phase2 ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr_phase2] + fi + + if [ $speech_num_gauss -lt $speech_max_gauss_phase2 ]; then + speech_num_gauss=$[speech_num_gauss + speech_gauss_incr_phase2] + fi + + x=$[x+1] + done ## Done training all 3 GMMs + + cp $phase2_dir/$utt_id.silence.$x.mdl $dir/$utt_id.silence.final.mdl + cp $phase2_dir/$utt_id.speech.$x.mdl 
$dir/$utt_id.speech.final.mdl + +done < $data/feats.scp diff --git a/egs/sre08/v1/diarization/vad_gmm_ntu_em.sh b/egs/sre08/v1/diarization/vad_gmm_ntu_em.sh new file mode 100755 index 00000000000..7eb008233fc --- /dev/null +++ b/egs/sre08/v1/diarization/vad_gmm_ntu_em.sh @@ -0,0 +1,276 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -u +set -e +set -o pipefail + +stage=-1 + +## Features paramters +ignore_energy_opts= # select-feats 1-12,14-25,27,38 +window_size=100 # 1s + +## Phase 1 parameters +num_frames_init_silence=2000 # 20s - Lowest energy frames selected to initialize Silence GMM +sil_num_gauss_init=2 +sil_max_gauss=2 +sil_gauss_incr=0 +sil_frames_incr=2000 +num_iters=5 + +## Phase 2 parameters +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 +window_size_phase2=10 + +. path.sh +. parse_options.sh || exit 1 + +if [ $# -ne 4 ]; then + echo "Usage: vad_gmm_icsi.sh " + echo " e.g.: vad_gmm_icsi.sh data/rt05_eval exp/librispeech_s5/vad_model/silence.0.mdl exp/librispeech_s5/vad_model/speech.0.mdl exp/vad_rt05_eval" + exit 1 +fi + +data=$1 +init_silence_model=$2 +init_speech_model=$3 +dir=$4 + +mkdir -p $dir +tmpdir=$dir/phase1 +phase2_dir=$dir/phase2 +phase3_dir=$dir/phase3 + +mkdir -p $tmpdir +mkdir -p $phase2_dir +mkdir -p $phase3_dir + +# Prepare a lang directory +if [ $stage -le -12 ]; then + mkdir -p $dir/local + mkdir -p $dir/local/dict + mkdir -p $dir/local/lm + + echo "1" > $dir/local/dict/silence_phones.txt + echo "1" > $dir/local/dict/optional_silence.txt + echo "2" > $dir/local/dict/nonsilence_phones.txt + echo -e "1 1\n2 2" > $dir/local/dict/lexicon.txt + echo -e "1\n2\n1 2" > $dir/local/dict/extra_questions.txt + + mkdir -p $dir/lang + diarization/prepare_vad_lang.sh --num-sil-states 1 --num-nonsil-states 1 \ + $dir/local/dict $dir/local/lang $dir/lang || exit 1 + fstisstochastic $dir/lang/G.fst || echo "[info]: G not 
stochastic." + diarization/prepare_vad_lang.sh --num-sil-states 30 --num-nonsil-states 75 \ + $dir/local/dict $dir/local/lang $dir/lang_test || exit 1 +fi + +feat_dim=`feat-to-dim "ark:head -n 1 $data/feats.scp | copy-feats scp:- ark:- | add-deltas ark:- ark:- |${ignore_energy_opts}" ark,t:- | awk '{print $2}'` || exit 1 + +if [ $stage -le -11 ]; then + run.pl $dir/log/create_transition_model.log gmm-init-mono \ + --binary=false $dir/lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 + run.pl $dir/log/create_transition_model_test.log gmm-init-mono \ + --binary=false $dir/lang_test/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans_test.mdl || exit 1 + + diarization/make_vad_graph.sh --iter trans $dir/lang $dir $dir/graph || exit 1 + diarization/make_vad_graph.sh --iter trans_test $dir/lang_test $dir $dir/graph_test || exit 1 +fi + +cat <<EOF > $dir/pdf_to_tid.map +0 1 +1 3 +EOF + + +if [ $stage -le -10 ]; then + if [ ! 
-f $data/segments ]; then + compute-zero-crossings --write-as-vector=true scp:$data/wav.scp \ + ark,scp:$dir/zero_crossings.ark,$dir/zero_crossings.scp || exit 1 + else + compute-zero-crossings --write-as-vector=true "ark:extract-segments scp:$data/wav.scp $data/segments ark:- |" \ + ark,scp:$dir/zero_crossings.ark,$dir/zero_crossings.scp || exit 1 + fi + extract-column scp:$data/feats.scp ark,scp:$dir/log_energies.ark,$dir/log_energies.scp || { echo "extract-column failed"; exit 1; } +fi + + +while IFS=$'\n' read line; do + feats="ark:echo $line | copy-feats scp:- ark:- | add-deltas ark:- ark:- |${ignore_energy_opts}" + utt_id=$(echo $line | awk '{print $1}') + echo $utt_id > $dir/$utt_id.list + + sil_num_gauss=$sil_num_gauss_init + speech_num_gauss=$speech_num_gauss_init + num_frames_silence=$num_frames_init_silence + + ### Compute likelihoods wrt bootstrapping models + gmm-global-get-frame-likes $init_speech_model \ + "${feats}" ark:$dir/$utt_id.speech_likes.bootstrap.ark || exit 1 + + gmm-global-get-frame-likes $init_silence_model \ + "${feats}" ark:$dir/$utt_id.silence_likes.bootstrap.ark || exit 1 + + ### Get bootstrapping VAD + loglikes-to-class \ + ark:$dir/$utt_id.silence_likes.bootstrap.ark \ + ark:$dir/$utt_id.speech_likes.bootstrap.ark \ + ark:$tmpdir/$utt_id.vad.bootstrap.ark || exit 1 + + ### Initialize Silence GMM using lowest energy chunks that were classified + ### as silence by the bootstrapping model + select-top-chunks \ + --window-size=$window_size \ + --selection-mask=ark:$tmpdir/$utt_id.vad.bootstrap.ark --select-class=0 \ + --select-bottom-frames=true \ + --weights=scp:$dir/log_energies.scp --num-select-frames=$num_frames_silence \ + "${feats}" ark:- |gmm-global-init-from-feats \ + --num-gauss=$sil_num_gauss --num-iters=$[sil_num_gauss+2] ark:- \ + $tmpdir/$utt_id.silence.0.mdl || exit 1 + + gmm-global-get-frame-likes $tmpdir/$utt_id.silence.0.mdl \ + "${feats}" ark:$tmpdir/$utt_id.silence_likes.0.ark || exit 1 + + ### Get initial VAD + 
loglikes-to-class \ + ark:$tmpdir/$utt_id.silence_likes.0.ark \ + ark:$dir/$utt_id.speech_likes.bootstrap.ark \ + ark:$tmpdir/$utt_id.vad.init.ark || exit 1 + + ### Remove frames that were originally classified as speech + ### while training Silence and Sound GMMs + select-top-chunks \ + --window-size=$window_size \ + --selection-mask=ark:$tmpdir/$utt_id.vad.init.ark --select-class=0 \ + "$feats" ark:$tmpdir/$utt_id.feats.init.ark \ + ark:$tmpdir/$utt_id.mask.init.ark || exit 1 + + ## Select energies and zero crossings corresponding to the same selection + + utils/filter_scp.pl $dir/$utt_id.list $dir/log_energies.scp | \ + vector-extract-dims scp:- ark:$tmpdir/$utt_id.mask.init.ark \ + ark:$tmpdir/$utt_id.energies.init.ark || exit 1 + + utils/filter_scp.pl $dir/$utt_id.list $tmpdir/$utt_id.vad.init.ark | \ + vector-extract-dims ark,t:- ark:$tmpdir/$utt_id.mask.init.ark \ + ark:$tmpdir/$utt_id.vad.0.ark || exit 1 + + gmm-global-get-frame-likes $init_speech_model \ + ark:$tmpdir/$utt_id.feats.init.ark ark:$tmpdir/$utt_id.speech_likes.init.ark || exit 1 + + x=0 + while [ $x -le $num_iters ]; do + ### Update Silence GMM using lowest energy chunks currently classified + ### as silence + select-top-chunks \ + --window-size=$window_size \ + --selection-mask=ark:$tmpdir/$utt_id.vad.$x.ark --select-class=0 \ + --select-bottom-frames=true --weights=ark:$tmpdir/$utt_id.energies.init.ark \ + --num-select-frames=$num_frames_silence \ + ark:$tmpdir/$utt_id.feats.init.ark ark:- | \ + gmm-global-acc-stats \ + $tmpdir/$utt_id.silence.$x.mdl ark:- - | \ + gmm-global-est --mix-up=$sil_num_gauss $tmpdir/$utt_id.silence.$x.mdl \ + - $tmpdir/$utt_id.silence.$[x+1].mdl || exit 1 + + ### Compute likelihoods with the current Silence and Sound GMMs + gmm-global-get-frame-likes $tmpdir/$utt_id.silence.$[x+1].mdl \ + ark:$tmpdir/$utt_id.feats.init.ark ark:$tmpdir/$utt_id.silence_likes.$[x+1].ark || exit 1 + + ### Get new VAD predictions on the subset selected for training + ### Silence and 
Sound GMMs + loglikes-to-class \ + ark:$tmpdir/$utt_id.silence_likes.$[x+1].ark \ + ark:$tmpdir/$utt_id.speech_likes.init.ark \ + ark:$tmpdir/$utt_id.vad.$[x+1].ark || exit 1 + + loglikes-to-class \ + "ark:gmm-global-get-frame-likes $tmpdir/$utt_id.silence.$[x+1].mdl \ + \"$feats\" ark:- |" \ + ark:$dir/$utt_id.speech_likes.bootstrap.ark \ + ark:$tmpdir/$utt_id.pred.$[x+1].ark || exit 1 + + x=$[x+1] + if [ $sil_num_gauss -lt $sil_max_gauss ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr] + num_frames_silence=$[num_frames_silence + silence_frames_incr] + fi + done ## Done training Silence and Speech GMMs + + ### Compute likelihoods with the current Silence and Sound GMMs + gmm-global-get-frame-likes $tmpdir/$utt_id.silence.$x.mdl \ + "$feats" ark:$phase2_dir/$utt_id.silence_likes.init.ark || exit 1 + + ### Compute initial segmentation for phase 2 training + loglikes-to-class \ + ark:$phase2_dir/$utt_id.silence_likes.init.ark \ + ark:$dir/$utt_id.speech_likes.bootstrap.ark \ + ark:$phase2_dir/$utt_id.seg.init.ark || exit 1 + + ### Initialize Speech GMM + select-top-chunks --window-size=1 \ + --selection-mask=ark:$phase2_dir/$utt_id.seg.init.ark --select-class=1 \ + "$feats" ark:- | gmm-global-init-from-feats \ + --num-gauss=$speech_num_gauss --num-iters=$[speech_num_gauss*2] \ + ark:- $phase2_dir/$utt_id.speech.0.mdl || exit 1 + + gmm-global-get-frame-likes $phase2_dir/$utt_id.speech.0.mdl \ + "$feats" ark:$phase2_dir/$utt_id.speech_likes.init.ark || exit 1 + + ### Compute initial segmentation for phase 2 training + loglikes-to-class \ + ark:$phase2_dir/$utt_id.silence_likes.init.ark \ + ark:$phase2_dir/$utt_id.speech_likes.init.ark \ + ark:$phase2_dir/$utt_id.pred.init.ark || exit 1 + + cp $tmpdir/$utt_id.silence.$x.mdl $phase2_dir/$utt_id.silence.0.mdl || exit 1 + { + cat $dir/trans.mdl; + echo "<DIMENSION> $feat_dim <NUMPDFS> 2"; + gmm-global-copy --binary=false $phase2_dir/$utt_id.silence.0.mdl -; + gmm-global-copy --binary=false $phase2_dir/$utt_id.speech.0.mdl -; + } |
gmm-copy - $phase2_dir/$utt_id.0.mdl || exit 1 + + x=0 + while [ $x -le $num_iters_phase2 ]; do + gmm-latgen-faster --acoustic-scale=1.0 --determinize-lattice=false \ + $phase2_dir/$utt_id.$x.mdl $dir/graph/HCLG.fst \ + "$feats" "ark:| gzip -c > $phase2_dir/$utt_id.$x.lat.gz" \ + ark:/dev/null ark:- | \ + ali-to-phones --per-frame=true \ + $phase2_dir/$utt_id.$x.mdl ark:- ark:- | \ + copy-int-vector ark:- ark:$phase2_dir/$utt_id.$x.ali || exit 1 + + lattice-to-post --acoustic-scale=1.0 \ + "ark:gunzip -c $phase2_dir/$utt_id.$x.lat.gz |" ark:- | \ + rand-prune-post 0.6 ark:- ark:- | \ + gmm-acc-stats $phase2_dir/$utt_id.$x.mdl "$feats" ark:- - | \ + gmm-est --mix-up=$[sil_num_gauss+speech_num_gauss] \ + --update-flags=tmv $phase2_dir/$utt_id.$x.mdl - $phase2_dir/$utt_id.$[x+1].mdl || exit 1 + + + if [ $sil_num_gauss -lt $sil_max_gauss_phase2 ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr_phase2] + fi + + if [ $speech_num_gauss -lt $speech_max_gauss_phase2 ]; then + speech_num_gauss=$[speech_num_gauss + speech_gauss_incr_phase2] + fi + + x=$[x+1] + done ## Done training all 3 GMMs + + cp $phase2_dir/$utt_id.silence.$x.mdl $dir/$utt_id.silence.final.mdl + cp $phase2_dir/$utt_id.speech.$x.mdl $dir/$utt_id.speech.final.mdl + +done < $data/feats.scp + diff --git a/egs/sre08/v1/diarization/vad_gmm_snr.sh b/egs/sre08/v1/diarization/vad_gmm_snr.sh new file mode 100644 index 00000000000..94a9740a2b3 --- /dev/null +++ b/egs/sre08/v1/diarization/vad_gmm_snr.sh @@ -0,0 +1,564 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -u +set -o pipefail + +cmd=run.pl +stage=-100 +try_merge_speech_noise=false +write_feats=false + +## Features paramters +window_size=5 # 5 frame. Window over which initial selection of frames + +. path.sh +. 
parse_options.sh || exit 1 + +if [ $# -ne 5 ]; then + echo "Usage: vad_gmm_snr.sh " + echo " e.g.: vad_gmm_snr.sh data/rt05_eval exp/librispeech_s5/vad_model/{silence,speech}.0.mdl exp/vad_rt05_eval" + exit 1 +fi + +data=$1 +frame_snrs_scp=$2 +init_silence_model=$3 +init_speech_model=$4 +dir=$5 + +init_model_dir=`dirname $init_speech_model` +add_zero_crossing_feats=`cat $init_model_dir/add_zero_crossing_feats` || exit 1 +add_frame_snrs=`cat $init_model_dir/add_frame_snrs` || exit 1 + +# Prepare a lang directory +if [ $stage -le -4 ]; then + mkdir -p $dir/local/dict + mkdir -p $dir/local/lm + mkdir -p $dir/local/dict_2class + mkdir -p $dir/local/lm_2class + + echo "1" > $dir/local/dict/silence_phones.txt + echo "1" > $dir/local/dict/optional_silence.txt + echo "2" > $dir/local/dict/nonsilence_phones.txt + echo "1" > $dir/local/dict_2class/silence_phones.txt + echo "1" > $dir/local/dict_2class/optional_silence.txt + echo "2" > $dir/local/dict_2class/nonsilence_phones.txt + echo "3" >> $dir/local/dict/nonsilence_phones.txt + echo -e "1 1\n2 2" > $dir/local/dict_2class/lexicon.txt + echo -e "1 1\n2 2\n3 3" > $dir/local/dict/lexicon.txt + echo -e "1\n2\n1 2" > $dir/local/dict_2class/extra_questions.txt + echo -e "1\n2\n1 2\n3\n1 3\n2 3\n1 2 3" > $dir/local/dict/extra_questions.txt + + mkdir -p $dir/lang + diarization/prepare_vad_lang.sh --num-sil-states $num_sil_states --num-nonsil-states $num_nonsil_states \ + $dir/local/dict $dir/local/lang $dir/lang || exit 1 + diarization/prepare_vad_lang.sh --num-sil-states $num_sil_states --num-nonsil-states $num_nonsil_states \ + $dir/local/dict_2class $dir/local/lang_2class $dir/lang_2class || exit 1 +fi + +feat_dim=`gmm-global-info $init_speech_model | grep "feature dimension" | awk '{print $NF}'` || exit 1 + +if [ $stage -le -3 ]; then + run.pl $dir/log/create_transition_model.log gmm-init-mono \ + $dir/lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 + + run.pl 
$dir/log/create_transition_model_2class.log gmm-init-mono \ + $dir/lang_2class/topo $feat_dim - $dir/tree_2class \| \ + copy-transition-model --binary=false - $dir/trans_2class.mdl || exit 1 + + diarization/make_vad_graph.sh --iter trans $dir/lang $dir $dir/graph || exit 1 + diarization/make_vad_graph.sh --iter trans_2class --tree tree_2class $dir/lang_2class $dir $dir/graph_2class || exit 1 +fi + +if [ $stage -le -2 ]; then + { + cat $dir/trans_2class.mdl + echo " $feat_dim 2" + gmm-global-copy --binary=false $init_silence_model - || exit 1 + gmm-global-copy --binary=false $init_speech_model - || exit 1 + } | gmm-copy - $dir/init_2class.mdl || exit 1 +fi + +if [ $stage -le -1 ]; then + t=$speech_to_sil_ratio + lang=$dir/lang_test_${t}x + cp -r $dir/lang $lang + perl -e '$t = shift @ARGV; print "0 0 1 1 " . -log(1/($t+3)) . "\n0 0 2 2 ". -log($t/($t+3)). "\n0 0 3 3 ". -log(1/($t+3)) ."\n0 ". -log(1/($t+3))' $t | \ + fstcompile --isymbols=$lang/words.txt --osymbols=$lang/words.txt \ + --keep_isymbols=false --keep_osymbols=false \ + > $lang/G.fst || exit 1 + diarization/make_vad_graph.sh --iter trans $lang $dir $dir/graph_test_${t}x || exit 1 + + lang=$dir/lang_2class_test_${t}x + cp -r $dir/lang_2class $lang + perl -e '$t = shift @ARGV; print "0 0 1 1 " . -log(1/($t+2)) . "\n0 0 2 2 ". -log($t/($t+2)). "\n0 ". 
-log(1/($t+2))' $t | \ + fstcompile --isymbols=$lang/words.txt --osymbols=$lang/words.txt \ + --keep_isymbols=false --keep_osymbols=false \ + > $lang/G.fst || exit 1 + + diarization/make_vad_graph.sh --iter trans_2class --tree tree_2class $lang $dir $dir/graph_2class_test_${t}x || exit 1 +fi + +while IFS=$'\n' read line; do + feats="ark:echo $line | apply-cmvn-sliding scp:- ark:- |${ignore_energy_opts}" + + utt_id=$(echo $line | awk '{print $1}') + echo $utt_id > $dir/$utt_id.list + + if [ -f $data/segments ]; then + if $add_zero_crossing_feats; then + # Extract zero-crossing feats for adding as a feature + $cmd $dir/log/$utt_id.extract_zero_crossings.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + extract-segments scp:$data/wav.scp - ark:- \| \ + compute-zero-crossings $zc_opts ark:- ark:$dir/$utt_id.zero_crossings.ark || exit 1 + fi + + # Extract log-energies + $cmd $dir/log/$utt_id.extract_log_energies.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/segments \| \ + extract-segments scp:$data/wav.scp - ark:- \| \ + compute-mfcc-feats --config=conf/mfcc_vad.conf --num-ceps=1 \ + ark:- ark:- \| extract-column ark:- \ + ark:$dir/$utt_id.log_energies.ark || exit 1 + + else + if $add_zero_crossing_feats; then + # Extract zero-crossing feats for adding as a feature + $cmd $dir/log/$utt_id.extract_zero_crossings.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp \| \ + compute-zero-crossings $zc_opts scp:- ark:$dir/$utt_id.zero_crossings.ark || exit 1 + fi + + # Extract log-energies + $cmd $dir/log/$utt_id.extract_log_energies.log \ + utils/filter_scp.pl $dir/$utt_id.list $data/wav.scp \| \ + compute-mfcc-feats --config=conf/mfcc_vad.conf --num-ceps=1 \ + scp:- ark:- \| extract-column ark:- \ + ark:$dir/$utt_id.log_energies.ark || exit 1 + fi + + utils/filter_scp.pl $data/utt2spk $frame_snrs_scp > $dir/frame_snrs.scp + + # Initial GMM parameters + sil_num_gauss=$sil_num_gauss_init + sound_num_gauss=$sound_num_gauss_init + 
speech_num_gauss=$speech_num_gauss_init + + # Optionally add zero-crossings to the features + if $add_zero_crossing_feats; then + feats="${feats}paste-feats ark:- ark:$dir/$utt_id.zero_crossings.ark ark:- |" + fi + + # Optionally add frame-snrs to the features + if $add_frame_snrs; then + feats="${feats}paste-feats ark:- \"ark:vector-to-feat scp:$dir/frame_snrs.scp ark:- |\" ark:- |" + fi + + # Add delta and delta-deltas + feats="${feats} add-deltas ark:- ark:- |" + + if $write_feats; then + copy-feats "$feats" ark:$dir/$utt_id.feat.ark + fi + + # Compute initial likelihoods wrt to speech and silence models + $cmd $dir/log/$utt_id.gmm_compute_likes.bootstrap.log \ + gmm-compute-likes $dir/init.mdl "$feats" \ + ark:$tmpdir/$utt_id.likes.bootstrap.ark & + + # Get VAD from bootstrap model. This is just for baseline. + # This is not actually used later. + $cmd $dir/log/$utt_id.get_vad.bootstrap.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph/words.txt \ + $dir/init.mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| ali-to-pdf $dir/init.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$tmpdir/$utt_id.vad.bootstrap.ark || exit 1 + + # i.e. unless use-bootstrap-vad is given. (Only for baseline) + if $use_bootstrap_vad; then + segmentation-copy ark:$tmpdir/$utt_id.vad.bootstrap.ark \ + ark,scp:$dir/$utt_id.vad.final.ark,$dir/$utt_id.vad.final.scp || exit 1 + continue + fi + + cp $tmpdir/$utt_id.likes.bootstrap.ark $tmpdir/$utt_id.likes.0.ark + + x=0 + goto_phase3=false # Stage for merging speech and noise + + ############################################################################# + # Phase 1 + # Train noise GMM on lowest SNR frames. + # Train speech GMM on highest likelihood frames + # Train silence GMM on lowest energy frames. 
+ ############################################################################# + + while [ $x -lt $num_iters ]; do + # Number of frames to initially train the silence and sound GMMs + num_frames_silence=$[num_frames_init_silence + sil_num_gauss * frames_per_gaussian ] + num_frames_silence_next=$[num_frames_init_silence_next + silence_num_gauss * frames_per_gaussian ] + num_frames_sound=$[num_frames_init_sound + 5 * sound_num_gauss * frames_per_gaussian ] + num_frames_sound_next=$[num_frames_init_sound_next + sound_num_gauss * frames_per_gaussian ] + num_frames_speech=$[num_frames_init_speech + speech_num_gauss * frames_per_gaussian ] + + if [ $x -lt 3 ]; then + # For the initial 3 iterations, the silence, sound and speech frames are + # defined as follows: + # Silence -- low energy and low speech likelihood frames + # Sound -- low SNR and low speech likelihood frames + # Speech -- high speech likelihood + + # Find silence frames + $cmd $tmpdir/log/$utt_id.select_top.silence.$x.log \ + segmentation-init-from-lengths "ark:echo $line | feat-to-len scp:- ark:- |" | \ + segmentation-select-top --num-bins=$num_bins \ + --merge-labels=1 --merge-dst-label=0 \ + --num-bottom-frames=$num_frames_silence \ + --bottom-select-label=0 --reject-label=1000 \ + --remove-rejected-frames=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:- ark:$dir/$utt_id.log_energies.ark ark:- \| \ + segmentation-select-top --num-bins=$num_bins \ + --num-bottom-frames=$num_frames_silence_next \ + --bottom-select-label=0 --reject-label=1000 \ + --remove-rejected-frames=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:- ark:$tmpdir/$utt_id.likes.$x.ark \ + ark:$tmpdir/$utt_id.seg.silence.$x.ark || exit 1 + + # Find noise frames + $cmd $tmpdir/log/$utt_id.select_top.sound.$x.log \ + segmentation-init-from-lengths "ark:echo $line | feat-to-len scp:- ark:- |" | \ + segmentation-select-top --num-bins=$num_bins \ + --merge-labels=1 
--merge-dst-label=2 \ + --num-bottom-frames=$num_frames_sound \ + --bottom-select-label=2 --reject-label=1000 \ + --remove-rejected-frames=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:- "scp:utils/filter_scp.pl $dir/$utt_id.list $dir/frame_snrs.scp |" \| \ + segmentation-select-top --num-bins=$num_bins \ + --num-bottom-frames=$num_frames_sound_next \ + --bottom-select-label=2 --reject-label=1000 \ + --remove-rejected-frames=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:- ark:$tmpdir/$utt_id.likes.$x.ark \ + ark:$tmpdir/$utt_id.seg.sound.$x.ark || exit 1 + + #$cmd $tmpdir/log/$utt_id.merge_segmentations.$x.log \ + # segmentation-merge ark:$tmpdir/$utt_id.seg.silence.$[x+1].ark \ + # ark:$tmpdir/$utt_id.seg.sound.$[x+1].ark \ + # ark:$tmpdir/$utt_id.seg.$[x+1].ark + + # Find speech frames + $cmd $tmpdir/log/$utt_id.select_top.speech.$x.log \ + segmentation-init-from-lengths "ark:echo $line | feat-to-len scp:- ark:- |" | \ + segmentation-select-top --num-bins=$num_bins \ + --num-top-frames=$num_frames_speech \ + --top-select-label=1 --reject-label=1000 \ + --remove-rejected-frames=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:- ark:$tmpdir/$utt_id.likes.$x.ark \ + ark:$tmpdir/$utt_id.seg.speech.$x.ark || exit 1 + else + # Get segmentation + $cmd $tmpdir/log/$utt_id.get_seg.$x.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph/words.txt \ + $tmpdir/$utt_id.$x.mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $tmpdir/$utt_id.$x.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$tmpdir/$utt_id.seg.$x.ark || exit 1 + fi + + if [ $x -eq 0 ]; then + { + cat $dir/trans.mdl; + echo " $feat_dim 3"; + select-feats-from-segmentation --select-label=0 "$feats" \ + ark:$tmpdir/$utt_id.seg.silence.$x.ark ark:- | \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[sil_num_gauss+1] 
--num-gauss-init=1 --num-gauss=$sil_num_gauss \ + ark:- - || exit 1 + select-feats-from-segmentation --select-label=1 "$feats" \ + ark:$tmpdir/$utt_id.seg.speech.$x.ark ark:- | \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[speech_num_gauss+1] --num-gauss-init=1 --num-gauss=$speech_num_gauss \ + ark:- - || exit 1 + select-feats-from-segmentation --select-label=2 "$feats" \ + ark:$tmpdir/$utt_id.seg.sound.$x.ark ark:- | \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[sound_num_gauss+1] --num-gauss-init=1 --num-gauss=$sound_num_gauss \ + ark:- - || exit 1 + } 2> $tmpdir/log/$utt_id.init_gmm.log | \ + gmm-copy - $tmpdir/$utt_id.$[x+1].mdl 2>> $tmpdir/log/$utt_id.init_gmm.log + if [ $? -ne 0 ]; then + echo "Insufficient frames for training silence or sound model. Skipping to phase 3" + goto_phase3=true + break; + fi + else + $cmd $tmpdir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation \ + --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n1 $speech_num_gauss\n2 $sound_num_gauss\" |" \ + $tmpdir/$utt_id.$x.mdl "$feats" \ + ark:$tmpdir/$utt_id.seg.$x.ark \ + $tmpdir/$utt_id.$[x+1].mdl || exit 1 + fi + + $cmd $tmpdir/log/$utt_id.gmm_compute_likes.$[x+1].log \ + gmm-compute-likes $tmpdir/$utt_id.$[x+1].mdl "$feats" \ + ark:$tmpdir/$utt_id.likes.$[x+1].ark & + + if [ $sil_num_gauss -lt $sil_max_gauss ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr] + fi + + if [ $sound_num_gauss -lt $sound_max_gauss ]; then + sound_num_gauss=$[sound_num_gauss + sound_gauss_incr] + fi + + if [ $speech_num_gauss -lt $speech_max_gauss ]; then + speech_num_gauss=$[speech_num_gauss + speech_gauss_incr] + fi + + x=$[x+1] + done ## Done training GMMs + echo "$0: Phase 1 training done!" 
+ + cp $tmpdir/$utt_id.$x.mdl $dir/$utt_id.final.mdl + rm -f $dir/$utt_id.graph_final + ln -s graph_test_${speech_to_sil_ratio}x $dir/$utt_id.graph_final + + # Get final segmentation at the end of phase 1 + $cmd $tmpdir/log/$utt_id.get_seg.$x.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph/words.txt \ + $tmpdir/$utt_id.$x.mdl $dir/graph/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $tmpdir/$utt_id.$x.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$tmpdir/$utt_id.seg.$x.ark || exit 1 + + mkdir -p $phase3_dir/log + + ############################################################################# + # Try merging speech and noise GMMs + ############################################################################# + + # Create a merged model + $cmd $tmpdir/log/$utt_id.init_nonsil.log \ + segmentation-copy --merge-labels=1:2 --merge-dst-label=1 \ + ark:$tmpdir/$utt_id.seg.$x.ark ark:- \| \ + select-feats-from-segmentation --select-label=1 \ + "$feats" ark:- ark:- \| \ + gmm-global-init-from-feats \ + --num-iters=$[sound_num_gauss + speech_num_gauss + 1] \ + --num-gauss-init=1 \ + --num-gauss=$[sound_num_gauss + speech_num_gauss] ark:- \ + $tmpdir/$utt_id.$x.nonsil.mdl || exit 1 + + # Select speech feats from the final segmentation + $cmd $tmpdir/log/$utt_id.select_speech_feats.$x.log \ + select-feats-from-segmentation --select-label=1 \ + "$feats" ark:$tmpdir/$utt_id.seg.$x.ark \ + ark:$tmpdir/$utt_id.speech_feats.$x.ark + + if [ $? -eq 0 ]; then + # Check if there is sufficient speech frames + num_selected_speech=$(grep "Processed .* segmentations; selected" $tmpdir/log/$utt_id.select_speech_feats.$x.log | perl -pe 's/.+selected (\S+) out of \S+ frames/$1/') + if [ $num_selected_speech -lt $min_data ]; then + echo "Insufficient frames for speech at the end of phase 1. $num_selected_speech < $min_data. See $tmpdir/log/$utt_id.select_speech_feats.$x.log. Going to phase 3." 
+ goto_phase3=true + fi + else + # Check if there is any speech frame + echo "Failed to find any data for speech at the end of phase 1. See $tmpdir/log/$utt_id.select_speech_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + + if ! $goto_phase3; then + # Not failed yet. So can compute speech likelihood. + speech_like=$(gmm-global-get-frame-likes \ + "gmm-extract-pdf $tmpdir/$utt_id.$x.mdl 1 - |" \ + ark:$tmpdir/$utt_id.speech_feats.$x.ark ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $tmpdir/$utt_id.compute_speech_like.$x.log || exit 1 + + # Select noise feats from the final segmentation + $cmd $tmpdir/log/$utt_id.select_sound_feats.$x.log \ + select-feats-from-segmentation --select-label=2 \ + "$feats" ark:$tmpdir/$utt_id.seg.$x.ark \ + ark:$tmpdir/$utt_id.sound_feats.$x.ark + + if [ $? -eq 0 ]; then + # Check if there is sufficient noise frames + num_selected_sound=$(grep "Processed .* segmentations; selected" $tmpdir/log/$utt_id.select_sound_feats.$x.log | perl -pe 's/.+selected (\S+) out of \S+ frames/$1/') + if [ $num_selected_sound -lt $min_data ]; then + echo "Insufficient frames for sound at the end of phase 1. $num_selected_sound < $min_data. See $tmpdir/log/$utt_id.select_sound_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + else + # Check if there is any noise frame + echo "Failed to find any data for sound at the end of phase 1. See $phase2_dir/log/$utt_id.select_sound_feats.$x.log. Going to phase 3." + goto_phase3=true + fi + fi + + if ! $goto_phase3; then + # Not failed yet. So can compute noise likelihood. 
+ sound_like=$(gmm-global-get-frame-likes \ + "gmm-extract-pdf $phase2_dir/$utt_id.$x.mdl 2 - |" \ + ark:$phase2_dir/$utt_id.sound_feats.$x.ark ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $phase2_dir/$utt_id.compute_sound_like.$x.log || exit 1 + + # Compute non-silence likelihood using combined speech+noise GMM + nonsil_like=$(select-feats-from-segmentation --merge-labels=1:2 --select-label=1 \ + "$feats" ark:$phase2_dir/$utt_id.seg.$x.ark ark:- | \ + gmm-global-get-frame-likes \ + $phase2_dir/$utt_id.$x.nonsil.mdl ark:- ark,t:- | \ + perl -pe 's/.*\[(.+)]/$1/' | \ + perl -ane '$sum = 0; foreach(@F) { $sum = $sum + $_; $i = $i + 1;}; print STDOUT ($sum)') 2> $phase2_dir/$utt_id.compute_nonsil_like.$x.log || exit 1 + + # Likelihood test -- Check if the combined model gives better likelihood + # than separate speech and noise models. If yes, then go to phase 3 + if [ ! -z `perl -e "print \"true\" if ($sound_like + $speech_like < $nonsil_like)"` ]; then + goto_phase3=true + fi + fi + + if $goto_phase3; then + ############################################################################# + # Phase 3 + # Train speech GMM on highest likelihood frames + # Train silence GMM on lowest energy frames. 
+ ############################################################################# + + while [ $x -lt $num_iters ]; do + # Number of frames to initially train the silence and sound GMMs + num_frames_silence=$[num_frames_init_silence + sil_num_gauss * frames_per_gaussian ] + num_frames_silence_next=$[num_frames_init_silence_next + silence_num_gauss * frames_per_gaussian ] + num_frames_speech=$[num_frames_init_speech + speech_num_gauss * frames_per_gaussian ] + + if [ $x -lt 3 ]; then + # For the initial 3 iterations, the silence and speech frames are + # defined as follows: + # Silence -- low energy and low speech likelihood frames + # Speech -- high speech likelihood + + # Find silence frames + $cmd $phase3_dir/log/$utt_id.select_top.silence.$x.log \ + segmentation-init-from-lengths "ark:echo $line | feat-to-len scp:- ark:- |" | \ + segmentation-select-top --num-bins=$num_bins \ + --merge-labels=1 --merge-dst-label=0 \ + --num-bottom-frames=$num_frames_silence \ + --bottom-select-label=0 --reject-label=1000 \ + --remove-rejected-frames=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:- ark:$dir/$utt_id.log_energies.ark ark:- \| \ + segmentation-select-top --num-bins=$num_bins \ + --num-bottom-frames=$num_frames_silence_next \ + --bottom-select-label=0 --reject-label=1000 \ + --remove-rejected-frames=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:- ark:$phase3_dir/$utt_id.likes.$x.ark \ + ark:$phase3_dir/$utt_id.seg.silence.$x.ark || exit 1 + + # Find speech frames + $cmd $phase3_dir/log/$utt_id.select_top.speech.$x.log \ + segmentation-init-from-lengths "ark:echo $line | feat-to-len scp:- ark:- |" | \ + segmentation-select-top --num-bins=$num_bins \ + --num-top-frames=$num_frames_speech \ + --top-select-label=1 --reject-label=1000 \ + --remove-rejected-frames=true \ + --window-size=$window_size --min-window-remainder=$[window_size/2] \ + ark:- ark:$phase3_dir/$utt_id.likes.$x.ark ark:- \| \ + 
segmentation-select-top --num-bins=$num_bins \ + --num-top-frames=$num_frames_speech \ + ark:$phase3_dir/$utt_id.seg.speech.$x.ark || exit 1 + else + # Get segmentation using current model + $cmd $phase3_dir/log/$utt_id.get_seg.$x.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/graph_2class/words.txt \ + $phase3_dir/$utt_id.$x.mdl $dir/graph_2class/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $phase3_dir/$utt_id.$x.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark:$phase3_dir/$utt_id.seg.$x.ark || exit 1 + fi + + if [ $x -eq 0 ]; then + { + cat $dir/trans_2class.mdl; + echo "<DIMENSION> $feat_dim <NUMPDFS> 2"; + select-feats-from-segmentation --select-label=0 "$feats" \ + ark:$phase3_dir/$utt_id.seg.silence.$x.ark ark:- | \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[sil_num_gauss+1] --num-gauss-init=1 --num-gauss=$sil_num_gauss \ + ark:- - || exit 1 + select-feats-from-segmentation --select-label=1 "$feats" \ + ark:$phase3_dir/$utt_id.seg.speech.$x.ark ark:- | \ + gmm-global-init-from-feats --binary=false \ + --num-iters=$[speech_num_gauss+1] --num-gauss-init=1 --num-gauss=$speech_num_gauss \ + ark:- - || exit 1 + } 2> $phase3_dir/log/$utt_id.init_gmm.log | \ + gmm-copy - $phase3_dir/$utt_id.$[x+1].mdl 2>> $phase3_dir/log/$utt_id.init_gmm.log + if [ $? -ne 0 ]; then + echo "VAD failed for utterance $utt_id. Utterance is fully silence."
+ fail_flag=true + break + fi + else + $cmd $phase3_dir/log/$utt_id.gmm_update.$[x+1].log \ + gmm-update-segmentation \ + --mix-up-rxfilename="echo -e \"0 $sil_num_gauss\n1 $speech_num_gauss\n2 $sound_num_gauss\" |" \ + $phase3_dir/$utt_id.$x.mdl "$feats" \ + ark:$phase3_dir/$utt_id.seg.$x.ark \ + $phase3_dir/$utt_id.$[x+1].mdl || exit 1 + fi + + $cmd $phase3_dir/log/$utt_id.gmm_compute_likes.$[x+1].log \ + gmm-compute-likes $phase3_dir/$utt_id.$[x+1].mdl "$feats" \ + ark:$phase3_dir/$utt_id.likes.$[x+1].ark & + + if [ $sil_num_gauss -lt $sil_max_gauss ]; then + sil_num_gauss=$[sil_num_gauss + sil_gauss_incr] + fi + + if [ $speech_num_gauss -lt $speech_max_gauss ]; then + speech_num_gauss=$[speech_num_gauss + speech_gauss_incr] + fi + + x=$[x+1] + done ## Done training GMMs + echo "$0: Phase 3 training done!" + + cp $phase3_dir/$utt_id.$x.mdl $dir/$utt_id.final.mdl + rm -f $dir/$utt_id.graph_final + ln -s graph_2class_test_${speech_to_sil_ratio}x $dir/$utt_id.graph_final + fi + + if ! $fail_flag; then + $cmd $dir/log/$utt_id.gmm_compute_likes.final.log \ + gmm-compute-likes $dir/$utt_id.final.mdl "$feats" \ + ark:$dir/$utt_id.likes.final.ark & + + $cmd $dir/log/$utt_id.get_seg.final.log \ + gmm-decode-simple --allow-partial=$allow_partial \ + --word-symbol-table=$dir/$utt_id.graph_final/words.txt \ + $dir/$utt_id.final.mdl $dir/$utt_id.graph_final/HCLG.fst \ + "$feats" ark:/dev/null ark:- \| \ + ali-to-pdf $dir/$utt_id.final.mdl ark:- ark:- \| \ + segmentation-init-from-ali ark:- \ + ark,scp:$dir/$utt_id.vad.final.ark,$dir/$utt_id.vad.final.scp || exit 1 + fi + +done < $data/feats.scp diff --git a/egs/sre08/v1/diarization/vad_gmm_vimal.sh b/egs/sre08/v1/diarization/vad_gmm_vimal.sh new file mode 100755 index 00000000000..af18a9d2606 --- /dev/null +++ b/egs/sre08/v1/diarization/vad_gmm_vimal.sh @@ -0,0 +1,299 @@ +#!/bin/bash +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -e +set -u +set -o pipefail + +# Begin configuration section. 
+cmd=run.pl +stage=-1 + +# Decode options +speech_duration=75 +sil_duration=30 +impr_thres=0.002 +cleanup=true +use_loglikes_hypothesis=false +use_latgen=false +map_opts= + +## Features parameters +window_size=100 # 1s +force_ignore_energy_opts= + +## Phase 1 parameters +num_frames_init_silence=2000 # 20s - Lowest energy frames selected to initialize Silence GMM +num_frames_init_sound=10000 # 100s - Highest energy frames selected to initialize Sound GMM +num_frames_init_sound_next=2000 # 20s - Highest zero crossing frames selected to initialize Sound GMM +sil_num_gauss_init=2 +sound_num_gauss_init=2 +sil_max_gauss=2 +sound_max_gauss=8 +sil_gauss_incr=0 +sound_gauss_incr=2 +sil_frames_incr=2000 +sound_frames_incr=10000 +sound_frames_next_incr=2000 +num_iters=5 +min_sil_variance=1 +min_sound_variance=0.01 +min_speech_variance=0.001 + +## Phase 2 parameters +speech_num_gauss_init=6 +sil_max_gauss_phase2=7 +sound_max_gauss_phase2=18 +speech_max_gauss_phase2=16 +sil_gauss_incr_phase2=1 +sound_gauss_incr_phase2=2 +speech_gauss_incr_phase2=2 +num_iters_phase2=5 +window_size_phase2=10 + +## Phase 3 parameters +sil_num_gauss_init_phase3=2 +speech_num_gauss_init_phase3=2 +sil_max_gauss_phase3=5 +speech_max_gauss_phase3=12 +sil_gauss_incr_phase3=1 +speech_gauss_incr_phase3=2 +num_iters_phase3=7 + + +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# != 4 ]; then + echo "Usage: vad_gmm_vimal.sh <data-dir> <init-silence-model> <init-speech-model> <dir>" + echo " e.g.: vad_gmm_icsi.sh data/rt05_eval exp/librispeech_s5/vad_model/silence.0.mdl exp/librispeech_s5/vad_model/speech.0.mdl exp/vad_rt05_eval" + echo "main options (for others, see top of script file)" + echo " --config <config-file> # config containing options" + echo " --cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs."
+ exit 1; +fi + +data=$1 +init_silence_model=$2 +init_speech_model=$3 +dir=$4 + +mkdir -p $dir +tmpdir=$dir/phase1 +phase2_dir=$dir/phase2 +phase3_dir=$dir/phase3 + +mkdir -p $tmpdir +mkdir -p $phase2_dir +mkdir -p $phase3_dir + +init_model_dir=`dirname $init_speech_model` + +for f in $data/feats.scp $data/wav.scp $init_speech_model $init_silence_model; do + [ ! -s $f ] && echo "$0: could not find $f or $f is empty" && exit 1 +done + +ignore_energy_opts=`cat $init_model_dir/ignore_energy_opts` || exit 1 + +add_zero_crossing_feats=`cat $init_model_dir/add_zero_crossing_feats` || exit 1 + +zc_opts= +[ -f conf/zc_vad.conf ] && zc_opts="--config=conf/zc_vad.conf" + +feat_dim=`feat-to-dim "scp:head -n 1 $data/feats.scp |" ark,t:- | awk '{print $2}'` || exit 1 + +# Prepare a lang directory +if [ $stage -le -2 ]; then + mkdir -p $dir/local + mkdir -p $dir/local/dict + mkdir -p $dir/local/lm + + echo $sil_phone_list | \ + while IFS=' ' read phone; do + echo $phone + done > $dir/local/dict/silence_phones.txt + + echo "1" > $dir/local/dict/optional_silence.txt + + echo $speech_phone_list | \ + while IFS=' ' read phone; do + echo $phone + done > $dir/local/dict/nonsilence_phones.txt + + echo "$sil_phone_list $speech_phone_list" | \ + while IFS=' ' read phone; do + echo $phone $phone + done > $dir/local/dict/lexicon.txt + + echo -e "" > $dir/local/dict/extra_questions.txt + + mkdir -p $dir/lang + + # Training-time language model for VAD + diarization/prepare_vad_lang.sh --num-sil-states 1 --num-nonsil-states 1 \ + $dir/local/dict $dir/local/lang $dir/lang || exit 1 + fstisstochastic $dir/lang/G.fst || echo "[info]: G not stochastic." 
+
+  # Testing-time language model for VAD
+  diarization/prepare_vad_lang.sh --num-sil-states 30 --num-nonsil-states 75 \
+    $dir/local/dict $dir/local/lang $dir/lang_test || exit 1
+fi
+
+if [ $stage -le -1 ]; then
+  $cmd $dir/log/create_transition_model.log gmm-init-mono \
+    --binary=false $dir/lang/topo $feat_dim - $dir/tree \| \
+    copy-transition-model --binary=false - $dir/trans.mdl || exit 1
+  # NOTE(review): write the test-time model's log to a distinct file so it does
+  # not clobber the log of the job above.
+  $cmd $dir/log/create_transition_model_test.log gmm-init-mono \
+    --binary=false $dir/lang_test/topo $feat_dim - $dir/tree \| \
+    copy-transition-model --binary=false - $dir/trans_test.mdl || exit 1
+
+  diarization/make_vad_graph.sh --iter trans $dir/lang $dir $dir/graph || exit 1
+  diarization/make_vad_graph.sh --iter trans_test $dir/lang_test $dir $dir/graph_test || exit 1
+fi
+
+cat <<EOF > $dir/pdf_to_tid.map
+0 1
+1 3
+EOF
+
+if [ $stage -le 0 ]; then
+mkdir -p $dir/q
+utils/split_data.sh $data $nj || exit 1
+
+map_est=
+[ ! -z "$init_vad_model" ] && map_est="-map"
+[ ! -z "$init_speech_model" ] && map_est="-map"
+
+for n in `seq $nj`; do
+  cat <<EOF > $dir/q/do_vad.$n.sh
+set -e
+set -o pipefail
+set -u
+
+if [ -f path.sh ]; then . ./path.sh; fi
+. parse_options.sh || exit 1;
+
+while IFS=$'\n' read line; do
+  feats="ark:echo \$line | copy-feats scp:- ark:- |"
+  utt_id=\$(echo \$line | awk '{print \$1}')
+
+  if [ ! -z "$init_vad_model" ]; then
+    cp $init_vad_model $dir/\$utt_id.0.mdl
+  elif [ -z "$init_speech_model" ] || [ -z "$init_sil_model" ]; then
+    if !
$select_top_frames; then
+      gmm-global-init-from-feats --num-gauss=$speech_num_gauss --num-iters=10 \
+        "\$feats select-voiced-frames ark:- scp:$data/vad.scp ark:- |" \
+        $dir/\$utt_id.speech.0.mdl || exit 1
+      gmm-global-init-from-feats --num-gauss=$sil_num_gauss --num-iters=6 \
+        "\$feats select-voiced-frames --select-unvoiced-frames=true ark:- scp:$data/vad.scp ark:- |" \
+        $dir/\$utt_id.silence.0.mdl || exit 1
+    else
+      gmm-global-init-from-feats --num-gauss=$speech_num_gauss --num-iters=12 \
+        "\$feats select-top-chunks --window-size=100 --top-frames-proportion=$top_frames_threshold ark:- ark:- |" \
+        $dir/\$utt_id.speech.0.mdl || exit 1
+      gmm-global-init-from-feats --num-gauss=$sil_num_gauss --num-iters=8 \
+        "\$feats select-top-chunks --window-size=100 --bottom-frames-proportion=$bottom_frames_threshold --top-frames-proportion=0.0 ark:- ark:- |" \
+        $dir/\$utt_id.silence.0.mdl || exit 1
+    fi
+
+    {
+      cat $dir/trans.mdl
+      echo "<DIMENSION> $feat_dim <NUMPDFS> 2"
+      gmm-global-copy --binary=false $dir/\$utt_id.silence.0.mdl -
+      gmm-global-copy --binary=false $dir/\$utt_id.speech.0.mdl -
+    } | gmm-copy - $dir/\$utt_id.0.mdl || exit 1
+  else
+    {
+      cat $dir/trans.mdl
+      echo "<DIMENSION> $feat_dim <NUMPDFS> 2"
+      gmm-global-copy --binary=false $init_speech_model -
+      gmm-global-copy --binary=false $init_sil_model -
+    } | gmm-copy - $dir/\$utt_id.0.mdl || exit 1
+  fi
+
+  x=0
+  while [ \$x -lt $num_iters ]; do
+    if $use_loglikes_hypothesis; then
+      gmm-compute-likes $dir/\$utt_id.\$x.mdl "\$feats" ark:- | \
+        loglikes-to-post --min-post=$frame_select_threshold \
+        ark:- "ark:| gzip -c > $dir/\$utt_id.\$x.post.gz" || exit 1
+
+      gmm-acc-stats \
+        $dir/\$utt_id.\$x.mdl "\$feats" \
+        "ark:gunzip -c $dir/\$utt_id.\$x.post.gz | copy-post-mapped --id-map=$dir/pdf_to_tid.map ark:- ark:- |" - | \
+        gmm-est${map_est} ${map_opts} --update-flags=mv $dir/\$utt_id.\$x.mdl - $dir/\$utt_id.\$[x+1].mdl \
+        2>&1 | tee $dir/log/update.\$utt_id.\$x.log || exit 1
+    elif $use_latgen; then
+      gmm-latgen-faster --acoustic-scale=1.0
--determinize-lattice=false \ + $dir/\$utt_id.\$x.mdl $dir/graph/HCLG.fst \ + "\$feats" "ark:| gzip -c > $dir/\$utt_id.\$x.lat.gz" || exit 1 + + lattice-to-post --acoustic-scale=1.0 \ + "ark:gunzip -c $dir/\$utt_id.\$x.lat.gz |" ark:- | \ + gmm-acc-stats $dir/\$utt_id.\$x.mdl "\$feats" ark:- - | \ + gmm-est${map_est} ${map_opts} --update-flags=mv $dir/\$utt_id.\$x.mdl - $dir/\$utt_id.\$[x+1].mdl \ + 2>&1 | tee $dir/log/update.\$utt_id.\$x.log || exit 1 + else + gmm-decode-simple \ + --allow-partial=true --word-symbol-table=$dir/graph/words.txt \ + $dir/\$utt_id.\$x.mdl $dir/graph/HCLG.fst \ + "\$feats" ark:/dev/null ark:$dir/\$utt_id.\$x.ali || exit 1 + + gmm-acc-stats-ali \ + $dir/\$utt_id.\$x.mdl "\$feats" \ + ark:$dir/\$utt_id.\$x.ali - | \ + gmm-est${map_est} ${map_opts} --update-flags=mv $dir/\$utt_id.\$x.mdl - $dir/\$utt_id.\$[x+1].mdl \ + 2>&1 | tee $dir/log/update.\$utt_id.\$x.log || exit 1 + fi + + objf_impr=\$(cat $dir/log/update.\$utt_id.\$x.log | grep "GMM update: Overall .* objective function" | perl -pe 's/.*GMM update: Overall (\S+) objective function .*/\$1/') + + if [ "\$(perl -e "if (\$objf_impr < $impr_thres) { print true; }")" == true ]; then + break; + fi + + x=\$[x+1] + done + + rm -f $dir/\$utt_id.final.mdl 2>/dev/null || true + #cp $dir/\$utt_id.\$x.mdl $dir/\$utt_id.final.mdl + + ( + copy-transition-model --binary=false $dir/trans_test.mdl - + gmm-copy --write-tm=false --binary=false $dir/\$utt_id.\$x.mdl - + ) | gmm-copy - $dir/\$utt_id.final.mdl + + #gmm-decode-simple \ + # --allow-partial=true --word-symbol-table=$dir/graph/words.txt \ + # $dir/\$utt_id.final.mdl $dir/graph/HCLG.fst \ + # "\$feats" ark:/dev/null ark:$dir/\$utt_id.final.ali || exit 1 + + gmm-decode-simple \ + --allow-partial=true --word-symbol-table=$dir/graph/words.txt \ + $dir/\$utt_id.final.mdl $dir/graph_test/HCLG.fst \ + "\$feats" ark:/dev/null ark:$dir/\$utt_id.final.ali || exit 1 +done < $data/split$nj/$n/feats.scp +EOF +done +fi + +if [ $stage -le 1 ]; then + $cmd 
JOB=1:$nj $dir/log/do_vad_job.JOB.log bash -x $dir/q/do_vad.JOB.sh || exit 1 +fi + +if $cleanup; then + for x in `seq $[num_iters - 1]`; do + if [ $[x % 10] -ne 0 ]; then + rm $dir/*.$x.mdl + fi + done +fi + +# Summarize warning messages... +utils/summarize_warnings.pl $dir/log + diff --git a/egs/sre08/v1/local/prepare_callhome_eval.sh b/egs/sre08/v1/local/prepare_callhome_eval.sh new file mode 100755 index 00000000000..b3b69e64aa8 --- /dev/null +++ b/egs/sre08/v1/local/prepare_callhome_eval.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +set -e + +. path.sh + +echo $* + +if [ $# -ne 2 ]; then + echo "Usage: local/prepare_callhome_eval.sh " + echo " e.g.: local/prepare_callhome_eval.sh /home/dpovey/diarization/data data/callhome_eval" + exit 1 +fi + +DATA=$1 +dir=$2 + +mkdir -p $dir + +sph2pipe=`which sph2pipe` + +if [ -z "$sph2pipe" ]; then + echo "$0: Cannot find sph2pipe" + exit 1 +fi + +for x in `find $DATA/ -name "*.sph"`; do + y=${x##*/} + z=${y%.sph} + echo "$z $sph2pipe -f wav -p $x |" +done | sort -k 1,1 > $dir/wav.scp + +awk '{print $1" "$1}' $dir/wav.scp > $dir/utt2spk +cp $dir/utt2spk $dir/spk2utt + +utils/fix_data_dir.sh $dir diff --git a/egs/sre08/v1/sid/train_ivector_extractor.sh b/egs/sre08/v1/sid/train_ivector_extractor.sh index 5d7eb984485..3ef41846cdc 100755 --- a/egs/sre08/v1/sid/train_ivector_extractor.sh +++ b/egs/sre08/v1/sid/train_ivector_extractor.sh @@ -162,4 +162,4 @@ while [ $x -lt $num_iters ]; do x=$[$x+1] done -ln -s $x.ie $dir/final.ie +ln -sf $x.ie $dir/final.ie diff --git a/egs/wsj/s5/steps/decode_nolats.sh b/egs/wsj/s5/steps/decode_nolats.sh index 9c05d3eea30..dee95007852 100755 --- a/egs/wsj/s5/steps/decode_nolats.sh +++ b/egs/wsj/s5/steps/decode_nolats.sh @@ -93,9 +93,11 @@ esac if [ ! -z "$transform_dir" ]; then # add transforms to features... echo "Using fMLLR transforms from $transform_dir" [ ! -f $transform_dir/trans.1 ] && echo "Expected $transform_dir/trans.1 to exist." 
- [ "`cat $transform_dir/num_jobs`" -ne $nj ] && \ - echo "Mismatch in number of jobs with $transform_dir"; - feats="$feats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark:$transform_dir/trans.JOB ark:- ark:- |" + if [ "`cat $transform_dir/num_jobs`" -ne $nj ]; then + feats="$feats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk 'ark:cat $transform_dir/trans.* |' ark:- ark:- |" + else + feats="$feats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark:$transform_dir/trans.JOB ark:- ark:- |" + fi fi if [ $stage -le 0 ]; then diff --git a/egs/wsj/s5/steps/make_phone_graph.sh b/egs/wsj/s5/steps/make_phone_graph.sh index 4dbb5a8a206..24ae0c55f18 100755 --- a/egs/wsj/s5/steps/make_phone_graph.sh +++ b/egs/wsj/s5/steps/make_phone_graph.sh @@ -142,5 +142,7 @@ if [ $stage -le 7 ]; then # $lang/phones.txt is the symbol table that corresponds to the output # symbols on the graph; decoding scripts expect it as words.txt. cp $lang/phones.txt $dir/phone_graph/words.txt + cp $lang/phones.txt $dir/phone_graph/phones.txt + cp -r $lang/phones $dir/phone_graph/ fi diff --git a/egs/wsj/s5/steps/nnet3/adjust_priors.sh b/egs/wsj/s5/steps/nnet3/adjust_priors.sh new file mode 100755 index 00000000000..daa469f7631 --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/adjust_priors.sh @@ -0,0 +1,63 @@ +#!/bin/bash + +. path.sh + +cmd=run.pl +prior_subset_size=20000 # 20k samples per job, for computing priors. +num_jobs_compute_prior=10 # these are single-threaded, run on CPU. +use_gpu=false # if true, we run on GPU. +egs_dir= +iter=final + +. 
utils/parse_options.sh + +echo "$0 $@" # Print the command line for logging + +if [ $# -ne 1 ]; then + echo "Usage: $0 [opts] " + echo " e.g.: $0 exp/nnet3_sad_snr/tdnn_train_100k_whole_1k_splice2_2_relu500" + exit 1 +fi + +dir=$1 + +if $use_gpu; then + prior_gpu_opt="--use-gpu=yes" + prior_queue_opt="--gpu 1" +else + prior_gpu_opt="--use-gpu=no" + prior_queue_opt="" +fi + +[ -z "$egs_dir" ] && egs_dir=$dir/egs + +for f in $egs_dir/egs.1.ark $dir/configs/vars $egs_dir/info/num_archives; do + if [ ! -f $f ]; then + echo "$f not found" + exit 1 + fi +done + +rm -f $dir/post.$iter.*.vec 2>/dev/null + +. $dir/configs/vars || exit 1; +context_opts="--left-context=$left_context --right-context=$right_context" + +num_archives=$(cat $egs_dir/info/num_archives) || { echo "error: no such file $egs_dir/info/frames_per_eg"; exit 1; } +if [ $num_jobs_compute_prior -gt $num_archives ]; then egs_part=1; +else egs_part=JOB; fi + +$cmd JOB=1:$num_jobs_compute_prior $prior_queue_opt $dir/log/get_post.$iter.JOB.log \ + nnet3-copy-egs --frame=random $context_opts --srand=JOB ark:$egs_dir/egs.$egs_part.ark ark:- \| \ + nnet3-subset-egs --srand=JOB --n=$prior_subset_size ark:- ark:- \| \ + nnet3-merge-egs ark:- ark:- \| \ + nnet3-compute-from-egs $prior_gpu_opt --apply-exp=true \ + $dir/$iter.raw ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| vector-sum ark:- $dir/post.$iter.JOB.vec || exit 1; + +sleep 3; # make sure there is time for $dir/post.$iter.*.vec to appear. 
+ +$cmd $dir/log/vector_sum.$iter.log \ + vector-sum $dir/post.$iter.*.vec $dir/post.$iter.vec || exit 1; + +rm -f $dir/post.$iter.*.vec; diff --git a/egs/wsj/s5/steps/nnet3/components.py b/egs/wsj/s5/steps/nnet3/components.py index 1fc49290dfe..6cb80a71814 100644 --- a/egs/wsj/s5/steps/nnet3/components.py +++ b/egs/wsj/s5/steps/nnet3/components.py @@ -45,7 +45,6 @@ def GetSumDescriptor(inputs): return sum_descriptors - # adds the input nodes and returns the descriptor def AddInputLayer(config_lines, feat_dim, splice_indexes=[0], ivector_dim=0): components = config_lines['components'] @@ -180,18 +179,27 @@ def AddSoftmaxLayer(config_lines, name, input): 'dimension': input['dimension']} -def AddOutputLayer(config_lines, input, label_delay=None): +def AddSigmoidLayer(config_lines, name, input): + components = config_lines['components'] + component_nodes = config_lines['component-nodes'] + + components.append("component name={0}_sigmoid type=SigmoidComponent dim={1}".format(name, input['dimension'])) + component_nodes.append("component-node name={0}_sigmoid component={0}_sigmoid input={1}".format(name, input['descriptor'])) + return {'descriptor': '{0}_sigmoid'.format(name), + 'dimension': input['dimension']} + +def AddOutputLayer(config_lines, input, label_delay = None, objective_type = "linear"): components = config_lines['components'] component_nodes = config_lines['component-nodes'] if label_delay is None: - component_nodes.append('output-node name=output input={0}'.format(input['descriptor'])) + component_nodes.append('output-node name=output input={0} objective={1}'.format(input['descriptor'], objective_type)) else: - component_nodes.append('output-node name=output input=Offset({0},{1})'.format(input['descriptor'], label_delay)) + component_nodes.append('output-node name=output input=Offset({0},{1}) objective={2}'.format(input['descriptor'], label_delay, objective_type)) -def AddFinalLayer(config_lines, input, output_dim, ng_affine_options = " param-stddev=0 
bias-stddev=0 ", label_delay=None, use_presoftmax_prior_scale = False, prior_scale_file = None, include_log_softmax = True): +def AddFinalLayer(config_lines, input, output_dim, ng_affine_options = " param-stddev=0 bias-stddev=0 ", label_delay=None, use_presoftmax_prior_scale = False, prior_scale_file = None, include_log_softmax = True, objective_type = "linear"): components = config_lines['components'] component_nodes = config_lines['component-nodes'] - + prev_layer_output = AddAffineLayer(config_lines, "Final", input, output_dim, ng_affine_options) if include_log_softmax: if use_presoftmax_prior_scale : @@ -199,7 +207,16 @@ def AddFinalLayer(config_lines, input, output_dim, ng_affine_options = " param-s component_nodes.append('component-node name=Final-fixed-scale component=Final-fixed-scale input={0}'.format(prev_layer_output['descriptor'])) prev_layer_output['descriptor'] = "Final-fixed-scale" prev_layer_output = AddSoftmaxLayer(config_lines, "Final", prev_layer_output) - AddOutputLayer(config_lines, prev_layer_output, label_delay) + AddOutputLayer(config_lines, prev_layer_output, label_delay, objective_type) + +def AddFinalSigmoidLayer(config_lines, input, output_dim, ng_affine_options = " param-stddev=0 bias-stddev=0 ", label_delay=None, objective_type = "linear"): + components = config_lines['components'] + component_nodes = config_lines['component-nodes'] + + prev_layer_output = AddAffineLayer(config_lines, "Final", input, output_dim, ng_affine_options) + prev_layer_output = AddSigmoidLayer(config_lines, "Final", prev_layer_output) + AddOutputLayer(config_lines, prev_layer_output, label_delay, objective_type) + def AddLstmLayer(config_lines, name, input, cell_dim, @@ -269,6 +286,11 @@ def AddLstmLayer(config_lines, component_nodes.append("component-node name={0}_c_t component={0}_c input=Sum({0}_c1_t, {0}_c2_t)".format(name)) c_tminus1_descriptor = "IfDefined(Offset({0}_c_t, {1}))".format(name, lstm_delay) + result = re.match("^Append\((.+)\)$", 
input_descriptor) + if result: + print("Removing Append from descriptor", file=sys.stderr) + input_descriptor = result.group(1) + component_nodes.append("# i_t") component_nodes.append("component-node name={0}_i1 component={0}_W_i-xr input=Append({1}, IfDefined(Offset({0}_{2}, {3})))".format(name, input_descriptor, recurrent_connection, lstm_delay)) component_nodes.append("component-node name={0}_i2 component={0}_w_ic input={1}".format(name, c_tminus1_descriptor)) diff --git a/egs/wsj/s5/steps/nnet3/get_egs.sh b/egs/wsj/s5/steps/nnet3/get_egs.sh index 364f6a72443..199a997962a 100755 --- a/egs/wsj/s5/steps/nnet3/get_egs.sh +++ b/egs/wsj/s5/steps/nnet3/get_egs.sh @@ -107,7 +107,7 @@ cp $alidir/tree $dir num_ali_jobs=$(cat $alidir/num_jobs) || exit 1; # Get list of validation utterances. -awk '{print $1}' $data/utt2spk | utils/shuffle_list.pl | head -$num_utts_subset \ +awk '{print $1}' $data/utt2spk | utils/shuffle_list.pl | head -$num_utts_subset | sort \ > $dir/valid_uttlist || exit 1; if [ -f $data/utt2uniq ]; then # this matters if you use data augmentation. @@ -122,7 +122,7 @@ if [ -f $data/utt2uniq ]; then # this matters if you use data augmentation. fi awk '{print $1}' $data/utt2spk | utils/filter_scp.pl --exclude $dir/valid_uttlist | \ - utils/shuffle_list.pl | head -$num_utts_subset > $dir/train_subset_uttlist || exit 1; + utils/shuffle_list.pl | head -$num_utts_subset | sort > $dir/train_subset_uttlist || exit 1; [ -z "$transform_dir" ] && transform_dir=$alidir diff --git a/egs/wsj/s5/steps/nnet3/get_egs_dense_targets.sh b/egs/wsj/s5/steps/nnet3/get_egs_dense_targets.sh new file mode 100755 index 00000000000..a64eac37558 --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/get_egs_dense_targets.sh @@ -0,0 +1,448 @@ +#!/bin/bash + +# Copyright 2012-2015 Johns Hopkins University (Author: Daniel Povey). Apache 2.0. 
+# +# This script, which will generally be called from other neural-net training +# scripts, extracts the training examples used to train the neural net (and also +# the validation examples used for diagnostics), and puts them in separate archives. +# +# This script dumps egs with several frames of labels, controlled by the +# frames_per_eg config variable (default: 8). This takes many times less disk +# space because typically we have 4 to 7 frames of context on the left and +# right, and this ends up getting shared. This is at the expense of slightly +# higher disk I/O while training. + + +# Begin configuration section. +cmd=run.pl +feat_type=raw # set it to 'lda' to use LDA features. +target_type=dense # dense to have dense targets, + # sparse to have posteriors targets +num_targets= +#sparse_input_dim= +deriv_weights_scp= +frames_per_eg=8 # number of frames of labels per example. more->less disk space and + # less time preparing egs, but more I/O during training. + # note: the script may reduce this if reduce_frames_per_eg is true. +left_context=4 # amount of left-context per eg (i.e. extra frames of input features + # not present in the output supervision). +right_context=4 # amount of right-context per eg. +valid_left_context= # amount of left_context for validation egs, typically used in + # recurrent architectures to ensure matched condition with + # training egs +valid_right_context= # amount of right_context for validation egs +compress=true # set this to false to disable compression (e.g. if you want to see whether + # results are affected). + +reduce_frames_per_eg=true # If true, this script may reduce the frames_per_eg + # if there is only one archive and even with the + # reduced frames_per_eg, the number of + # samples_per_iter that would result is less than or + # equal to the user-specified value. +num_utts_subset=300 # number of utterances in validation and training + # subsets used for shrinkage and diagnostics. 
+num_valid_frames_combine=0 # #valid frames for combination weights at the very end. +num_train_frames_combine=10000 # # train frames for the above. +num_frames_diagnostic=4000 # number of frames for "compute_prob" jobs +samples_per_iter=400000 # this is the target number of egs in each archive of egs + # (prior to merging egs). We probably should have called + # it egs_per_iter. This is just a guideline; it will pick + # a number that divides the number of samples in the + # entire data. + +transform_dir= # If supplied, overrides alidir as the place to find fMLLR transforms + +stage=0 +nj=6 # This should be set to the maximum number of jobs you are + # comfortable to run in parallel; you can increase it if your disk + # speed is greater and you have more machines. +online_ivector_dir= # can be used if we are including speaker information as iVectors. +cmvn_opts= # can be used for specifying CMVN options, if feature type is not lda (if lda, + # it doesn't make sense to use different options than were used as input to the + # LDA transform). This is used to turn off CMVN in the online-nnet experiments. + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + + +if [ $# != 3 ]; then + echo "Usage: $0 [opts] " + echo " e.g.: $0 data/train data/train/snr_targets.scp exp/tri4_nnet/egs" + echo "" + echo "Main options (for others, see top of script file)" + echo " --config # config file containing options" + echo " --cmd (utils/run.pl;utils/queue.pl ) # how to run jobs." + echo " --samples-per-iter <#samples;400000> # Target number of egs per archive (option is badly named)" + echo " --feat-type # (raw is the default). The feature type you want" + echo " # to use as input to the neural net." 
+ echo " --frames-per-eg # number of frames per eg on disk" + echo " --left-context # Number of frames on left side to append for feature input" + echo " --right-context # Number of frames on right side to append for feature input" + echo " --num-frames-diagnostic <#frames;4000> # Number of frames used in computing (train,valid) diagnostics" + echo " --num-valid-frames-combine <#frames;10000> # Number of frames used in getting combination weights at the" + echo " # very end." + echo " --stage # Used to run a partially-completed training process from somewhere in" + echo " # the middle." + + exit 1; +fi + +data=$1 +targets_scp=$2 +dir=$3 + +# Check some files. +[ ! -z "$online_ivector_dir" ] && \ + extra_files="$online_ivector_dir/ivector_online.scp $online_ivector_dir/ivector_period" + +#if [ "$feat_type" != "sparse" ]; then +# feats_scp=$data/feats.scp +#fi + +for f in $feats_scp $targets_scp $extra_files; do + [ ! -f $f ] && echo "$0: no such file $f" && exit 1; +done + +sdata=$data/split$nj +utils/split_data.sh $data $nj + +mkdir -p $dir/log $dir/info +[ ! -z "$transform_dir" ] && cp $transform_dir/tree $dir + +# Get list of validation utterances. +awk '{print $1}' $data/utt2spk | utils/shuffle_list.pl | head -$num_utts_subset | sort \ + > $dir/valid_uttlist || exit 1; + +if [ -f $data/utt2uniq ]; then # this matters if you use data augmentation. + echo "File $data/utt2uniq exists, so augmenting valid_uttlist to" + echo "include all perturbed versions of the same 'real' utterances." 
+ mv $dir/valid_uttlist $dir/valid_uttlist.tmp + utils/utt2spk_to_spk2utt.pl $data/utt2uniq > $dir/uniq2utt + cat $dir/valid_uttlist.tmp | utils/apply_map.pl $data/utt2uniq | \ + sort | uniq | utils/apply_map.pl $dir/uniq2utt | \ + awk '{for(n=1;n<=NF;n++) print $n;}' | sort > $dir/valid_uttlist + rm $dir/uniq2utt $dir/valid_uttlist.tmp +fi + +awk '{print $1}' $data/utt2spk | utils/filter_scp.pl --exclude $dir/valid_uttlist | \ + utils/shuffle_list.pl | head -$num_utts_subset > $dir/train_subset_uttlist || exit 1; + +# because we'll need the features with a different number of jobs than $alidir, +# copy to ark,scp. +if [ -f $transform_dir/trans.1 ] && [ $feat_type != "raw" ]; then + echo "$0: using transforms from $transform_dir" + if [ $stage -le 0 ]; then + $cmd $dir/log/copy_transforms.log \ + copy-feats "ark:cat $transform_dir/trans.* |" "ark,scp:$dir/trans.ark,$dir/trans.scp" + fi +fi +if [ -f $transform_dir/raw_trans.1 ] && [ $feat_type == "raw" ]; then + echo "$0: using raw transforms from $transform_dir" + if [ $stage -le 0 ]; then + $cmd $dir/log/copy_transforms.log \ + copy-feats "ark:cat $transform_dir/raw_trans.* |" "ark,scp:$dir/trans.ark,$dir/trans.scp" + fi +fi + + + +## Set up features. +echo "$0: feature type is $feat_type" + +case $feat_type in + raw|sparse) feats="ark,s,cs:utils/filter_scp.pl --exclude $dir/valid_uttlist $sdata/JOB/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:- ark:- |" + valid_feats="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- |" + train_subset_feats="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- |" + echo $cmvn_opts >$dir/cmvn_opts # caution: the top-level nnet training script should copy this to its own dir now. 
+ ;; + lda|lda_sparse) + splice_opts=`cat $transform_dir/splice_opts 2>/dev/null` + # caution: the top-level nnet training script should copy these to its own dir now. + cp $transform_dir/{splice_opts,cmvn_opts,final.mat} $dir || exit 1; + [ ! -z "$cmvn_opts" ] && \ + echo "You cannot supply --cmvn-opts option if feature type is LDA." && exit 1; + cmvn_opts=$(cat $dir/cmvn_opts) + feats="ark,s,cs:utils/filter_scp.pl --exclude $dir/valid_uttlist $sdata/JOB/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |" + valid_feats="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |" + train_subset_feats="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |" + ;; + #sparse) + # for n in `seq $nj`; do + # utils/filter_scp.pl $sdata/$n/utt2spk $feats_scp > $dir/sparse_feats.$n.scp + # done + + # feats_scp_split=$dir/sparse_feats.JOB.scp + + # feats="ark,s,cs:utils/filter_scp.pl --exclude $dir/valid_uttlist $feats_scp_split | copy-post scp:- ark:- |" + # valid_feats="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist $feats_scp | copy-post scp:- ark:- |" + # train_subset_feats="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist $feats_scp | copy-post scp:- ark:- |" + # ;; + *) echo "$0: invalid feature type --feat-type '$feat_type'" && exit 1; +esac + +if [ -f $dir/trans.scp ]; then + feats="$feats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk scp:$dir/trans.scp ark:- ark:- |" + valid_feats="$valid_feats transform-feats --utt2spk=ark:$data/utt2spk scp:$dir/trans.scp|' ark:- 
ark:- |" + train_subset_feats="$train_subset_feats transform-feats --utt2spk=ark:$data/utt2spk scp:$dir/trans.scp|' ark:- ark:- |" +fi + +if [ ! -z "$online_ivector_dir" ]; then + ivector_dim=$(feat-to-dim scp:$online_ivector_dir/ivector_online.scp -) || exit 1; + echo $ivector_dim > $dir/info/ivector_dim + ivector_period=$(cat $online_ivector_dir/ivector_period) || exit 1; + + ivector_opt="--ivectors='ark,s,cs:utils/filter_scp.pl $sdata/JOB/utt2spk $online_ivector_dir/ivector_online.scp | subsample-feats --n=-$ivector_period scp:- ark:- |'" + valid_ivector_opt="--ivectors='ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist $online_ivector_dir/ivector_online.scp | subsample-feats --n=-$ivector_period scp:- ark:- |'" + train_subset_ivector_opt="--ivectors='ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist $online_ivector_dir/ivector_online.scp | subsample-feats --n=-$ivector_period scp:- ark:- |'" +else + echo 0 >$dir/info/ivector_dim +fi + +if [ $stage -le 1 ]; then + echo "$0: working out number of frames of training data" + num_frames=$(steps/nnet2/get_num_frames.sh $data) + echo $num_frames > $dir/info/num_frames + echo "$0: working out feature dim" + feats_one="$(echo $feats | sed s/JOB/1/g)" + #if [ $feat_type != "sparse" ]; then + feat_dim=$(feat-to-dim "$feats_one" -) || exit 1; + #else + # if [ -z "$sparse_input_dim" ]; then + # echo "$0: feat-type is sparse; but sparse-input-dim is not specified" + # exit 1 + # fi + + # feat_dim=$sparse_input_dim + #fi + + echo $feat_dim > $dir/info/feat_dim +else + num_frames=$(cat $dir/info/num_frames) || exit 1; + feat_dim=$(cat $dir/info/feat_dim) || exit 1; +fi + +# the + 1 is to round up, not down... we assume it doesn't divide exactly. +num_archives=$[$num_frames/($frames_per_eg*$samples_per_iter)+1] +# (for small data)- while reduce_frames_per_eg == true and the number of +# archives is 1 and would still be 1 if we reduced frames_per_eg by 1, reduce it +# by 1. 
+reduced=false +while $reduce_frames_per_eg && [ $frames_per_eg -gt 1 ] && \ + [ $[$num_frames/(($frames_per_eg-1)*$samples_per_iter)] -eq 0 ]; do + frames_per_eg=$[$frames_per_eg-1] + num_archives=1 + reduced=true +done +$reduced && echo "$0: reduced frames_per_eg to $frames_per_eg because amount of data is small." + +# We may have to first create a smaller number of larger archives, with number +# $num_archives_intermediate, if $num_archives is more than the maximum number +# of open filehandles that the system allows per process (ulimit -n). +max_open_filehandles=$(ulimit -n) || exit 1 +num_archives_intermediate=$num_archives +archives_multiple=1 +while [ $[$num_archives_intermediate+4] -gt $max_open_filehandles ]; do + archives_multiple=$[$archives_multiple+1] + num_archives_intermediate=$[$num_archives/$archives_multiple]; +done +# now make sure num_archives is an exact multiple of archives_multiple. +num_archives=$[$archives_multiple*$num_archives_intermediate] + +echo $num_archives >$dir/info/num_archives +echo $frames_per_eg >$dir/info/frames_per_eg +# Work out the number of egs per archive +egs_per_archive=$[$num_frames/($frames_per_eg*$num_archives)] +! [ $egs_per_archive -le $samples_per_iter ] && \ + echo "$0: script error: egs_per_archive=$egs_per_archive not <= samples_per_iter=$samples_per_iter" \ + && exit 1; + +echo $egs_per_archive > $dir/info/egs_per_archive + +echo "$0: creating $num_archives archives, each with $egs_per_archive egs, with" +echo "$0: $frames_per_eg labels per example, and (left,right) context = ($left_context,$right_context)" + + + +if [ -e $dir/storage ]; then + # Make soft links to storage directories, if distributing this way.. See + # utils/create_split_dir.pl. 
+ echo "$0: creating data links" + utils/create_data_link.pl $(for x in $(seq $num_archives); do echo $dir/egs.$x.ark; done) + for x in $(seq $num_archives_intermediate); do + utils/create_data_link.pl $(for y in $(seq $nj); do echo $dir/egs_orig.$y.$x.ark; done) + done +fi + +egs_opts="--left-context=$left_context --right-context=$right_context --compress=$compress" + +[ ! -z "$deriv_weights_scp" ] && egs_opts="$egs_opts --deriv-weights-rspecifier=scp:$deriv_weights_scp" + +[ -z $valid_left_context ] && valid_left_context=$left_context; +[ -z $valid_right_context ] && valid_right_context=$right_context; +valid_egs_opts="--left-context=$valid_left_context --right-context=$valid_right_context --compress=$compress" + +echo $left_context > $dir/info/left_context +echo $right_context > $dir/info/right_context + +if [ $target_type == "dense" ]; then + num_targets=$(feat-to-dim "scp:$targets_scp" - 2>/dev/null) || exit 1 +else + if [ -z "$num_targets" ]; then + echo "$0: num-targets is not set" + exit 1 + fi +fi + +for n in `seq $nj`; do + utils/filter_scp.pl $sdata/$n/utt2spk $targets_scp > $dir/targets.$n.scp +done +targets_scp_split=$dir/targets.JOB.scp + +case $target_type in + "dense") + #if [ "$feat_type" == "sparse" ]; then + # echo "$0: dense targets with sparse inputs is not supported" + # exit 1 + #fi + get_egs_program="nnet3-get-egs-dense-targets --num-targets=$num_targets" + + targets="ark:utils/filter_scp.pl --exclude $dir/valid_uttlist $targets_scp_split | copy-feats scp:- ark:- |" + valid_targets="ark:utils/filter_scp.pl $dir/valid_uttlist $targets_scp | copy-feats scp:- ark:- |" + train_subset_targets="ark:utils/filter_scp.pl $dir/train_subset_uttlist $targets_scp | copy-feats scp:- ark:- |" + ;; + "sparse") + #if [ "$feat_type" != "sparse" ]; then + get_egs_program="nnet3-get-egs --num-pdfs=$num_targets" + #else + # get_egs_program="nnet3-get-egs-sparse-input --sparse-input-dim=$feat_dim --num-pdfs=$num_targets" + #fi + targets="ark:utils/filter_scp.pl 
--exclude $dir/valid_uttlist $targets_scp_split | ali-to-post scp:- ark:- |"
+    valid_targets="ark:utils/filter_scp.pl $dir/valid_uttlist $targets_scp | ali-to-post scp:- ark:- |"
+    train_subset_targets="ark:utils/filter_scp.pl $dir/train_subset_uttlist $targets_scp | ali-to-post scp:- ark:- |"
+    ;;
+  *)
+    echo "$0: Unknown --target-type $target_type. Choices are dense and sparse"
+    exit 1
+esac
+
+if [ $stage -le 3 ]; then
+  echo "$0: Getting validation and training subset examples."
+  rm $dir/.error 2>/dev/null
+  echo "$0: ... extracting validation and training-subset alignments."
+
+  (
+  $cmd $dir/log/create_valid_subset.log \
+    $get_egs_program \
+    $valid_ivector_opt $egs_opts "$valid_feats" \
+    "$valid_targets" \
+    "ark:$dir/valid_all.egs" || touch $dir/.error &
+  $cmd $dir/log/create_train_subset.log \
+    $get_egs_program \
+    $train_subset_ivector_opt $egs_opts "$train_subset_feats" \
+    "$train_subset_targets" \
+    "ark:$dir/train_subset_all.egs" || touch $dir/.error &
+  wait;
+
+  [ -f $dir/.error ] && echo "Error detected while creating train/valid egs" && exit 1
+  echo "... Getting subsets of validation examples for diagnostics and combination."
+  $cmd $dir/log/create_valid_subset_combine.log \
+    nnet3-subset-egs --n=$num_valid_frames_combine ark:$dir/valid_all.egs \
+    ark:$dir/valid_combine.egs || touch $dir/.error &
+  $cmd $dir/log/create_valid_subset_diagnostic.log \
+    nnet3-subset-egs --n=$num_frames_diagnostic ark:$dir/valid_all.egs \
+    ark:$dir/valid_diagnostic.egs || touch $dir/.error &
+
+  $cmd $dir/log/create_train_subset_combine.log \
+    nnet3-subset-egs --n=$num_train_frames_combine ark:$dir/train_subset_all.egs \
+    ark:$dir/train_combine.egs || touch $dir/.error &
+  $cmd $dir/log/create_train_subset_diagnostic.log \
+    nnet3-subset-egs --n=$num_frames_diagnostic ark:$dir/train_subset_all.egs \
+    ark:$dir/train_diagnostic.egs || touch $dir/.error &
+  wait
+  sleep 5  # wait for file system to sync.
+ cat $dir/valid_combine.egs $dir/train_combine.egs > $dir/combine.egs + + for f in $dir/{combine,train_diagnostic,valid_diagnostic}.egs; do + [ ! -s $f ] && touch $dir/.error && exit 1 + done + rm -f $dir/valid_all.egs $dir/train_subset_all.egs $dir/{train,valid}_combine.egs + ) & +fi + +if [ $stage -le 4 ]; then + # create egs_orig.*.*.ark; the first index goes to $nj, + # the second to $num_archives_intermediate. + + egs_list= + for n in $(seq $num_archives_intermediate); do + egs_list="$egs_list ark:$dir/egs_orig.JOB.$n.ark" + done + echo "$0: Generating training examples on disk" + # The examples will go round-robin to egs_list. + + $cmd JOB=1:$nj $dir/log/get_egs.JOB.log \ + $get_egs_program \ + $ivector_opt $egs_opts --num-frames=$frames_per_eg "$feats" "$targets" \ + ark:- \| \ + nnet3-copy-egs --random=true --srand=JOB ark:- $egs_list || exit 1; +fi + +if [ $stage -le 5 ]; then + echo "$0: recombining and shuffling order of archives on disk" + # combine all the "egs_orig.*.JOB.scp" (over the $nj splits of the data) and + # shuffle the order, writing to the egs.JOB.ark + + # the input is a concatenation over the input jobs. + egs_list= + for n in $(seq $nj); do + egs_list="$egs_list $dir/egs_orig.$n.JOB.ark" + done + + if [ $archives_multiple == 1 ]; then # normal case. + $cmd --max-jobs-run $nj JOB=1:$num_archives_intermediate $dir/log/shuffle.JOB.log \ + nnet3-shuffle-egs --srand=JOB "ark:cat $egs_list|" ark:$dir/egs.JOB.ark || exit 1; + else + # we need to shuffle the 'intermediate archives' and then split into the + # final archives. we create soft links to manage this splitting, because + # otherwise managing the output names is quite difficult (and we don't want + # to submit separate queue jobs for each intermediate archive, because then + # the --max-jobs-run option is hard to enforce). 
+ output_archives="$(for y in $(seq $archives_multiple); do echo ark:$dir/egs.JOB.$y.ark; done)" + for x in $(seq $num_archives_intermediate); do + for y in $(seq $archives_multiple); do + archive_index=$[($x-1)*$archives_multiple+$y] + # egs.intermediate_archive.{1,2,...}.ark will point to egs.archive.ark + ln -sf egs.$archive_index.ark $dir/egs.$x.$y.ark || exit 1 + done + done + $cmd --max-jobs-run $nj JOB=1:$num_archives_intermediate $dir/log/shuffle.JOB.log \ + nnet3-shuffle-egs --srand=JOB "ark:cat $egs_list|" ark:- \| \ + nnet3-copy-egs ark:- $output_archives || exit 1; + fi + +fi + +wait +[ -f $dir/.error ] && echo "Error detected while creating train/valid egs" && exit 1 + +if [ $stage -le 6 ]; then + echo "$0: removing temporary archives" + for x in $(seq $nj); do + for y in $(seq $num_archives_intermediate); do + file=$dir/egs_orig.$x.$y.ark + [ -L $file ] && rm $(readlink -f $file) + rm $file + done + done + if [ $archives_multiple -gt 1 ]; then + # there are some extra soft links that we should delete. + for f in $dir/egs.*.*.ark; do rm $f; done + fi + echo "$0: removing temporary alignments and transforms" + # Ignore errors below because trans.* might not exist. + rm -f $dir/{ali,trans}.{ark,scp} 2>/dev/null +fi + +echo "$0: Finished preparing training examples" diff --git a/egs/wsj/s5/steps/nnet3/lstm/make_configs.py b/egs/wsj/s5/steps/nnet3/lstm/make_configs.py index 17b8bea228d..958508fe0da 100755 --- a/egs/wsj/s5/steps/nnet3/lstm/make_configs.py +++ b/egs/wsj/s5/steps/nnet3/lstm/make_configs.py @@ -69,7 +69,7 @@ def ParseLstmDelayString(lstm_delay): if len(indexes) < 1: raise ValueError("invalid --lstm-delay argument, too-short element: " + lstm_delay) - elif len(indexes) == 2 and indexes[0] * indexes[1] >= 0: + elif len(indexes) == 2 and indexes[0] * indexes[1] >= 0: raise ValueError('Warning: ' + str(indexes) + ' is not a standard BLSTM mode. 
There should be a negative delay for the forward, and a postive delay for the backward.') lstm_delay_array.append(indexes) except ValueError as e: @@ -77,7 +77,7 @@ def ParseLstmDelayString(lstm_delay): return lstm_delay_array - + if __name__ == "__main__": # we add compulsary arguments as named arguments for readability parser = argparse.ArgumentParser(description="Writes config files and variables " @@ -205,11 +205,11 @@ def ParseLstmDelayString(lstm_delay): prev_layer_output['descriptor'] = 'Append({0}, {1})'.format(prev_layer_output1['descriptor'], prev_layer_output2['descriptor']) prev_layer_output['dimension'] = prev_layer_output1['dimension'] + prev_layer_output2['dimension'] else: # LSTM layer case - prev_layer_output = nodes.AddLstmLayer(config_lines, "Lstm{0}".format(i+1), prev_layer_output, args.cell_dim, - args.recurrent_projection_dim, args.non_recurrent_projection_dim, - args.clipping_threshold, args.norm_based_clipping, - args.ng_per_element_scale_options, args.ng_affine_options, - lstm_delay = lstm_delay[i][0]) + prev_layer_output = nodes.AddLstmLayer(config_lines, "Lstm{0}".format(i+1), prev_layer_output, args.cell_dim, + args.recurrent_projection_dim, args.non_recurrent_projection_dim, + args.clipping_threshold, args.norm_based_clipping, + args.ng_per_element_scale_options, args.ng_affine_options, + lstm_delay = lstm_delay[i][0]) # make the intermediate config file for layerwise discriminative # training nodes.AddFinalLayer(config_lines, prev_layer_output, args.num_targets, args.ng_affine_options, args.label_delay, args.include_log_softmax) diff --git a/egs/wsj/s5/steps/nnet3/make_tdnn_raw_configs.py b/egs/wsj/s5/steps/nnet3/make_tdnn_raw_configs.py new file mode 100755 index 00000000000..d6615af09c2 --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/make_tdnn_raw_configs.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python + +# we're using python 3.x style print but want it to work in python 2.x, +from __future__ import print_function +import re, os, argparse, 
sys, math, warnings + + + +parser = argparse.ArgumentParser(description="Writes config files and variables " + "for TDNNs creation and training", + epilog="See steps/nnet3/train_tdnn.sh for example."); +parser.add_argument("--splice-indexes", type=str, + help="Splice indexes at each hidden layer, e.g. '-3,-2,-1,0,1,2,3 0 -2,2 0 -4,4 0 -8,8'") +parser.add_argument("--feat-dim", type=int, + help="Raw feature dimension, e.g. 13") +parser.add_argument("--ivector-dim", type=int, + help="iVector dimension, e.g. 100", default=0) +parser.add_argument("--include-log-softmax", type=str, + help="add the final softmax layer ", default="true", choices = ["false", "true"]) +parser.add_argument("--final-layer-normalize-target", type=float, + help="RMS target for final layer (set to <1 if final layer learns too fast", + default=1.0) +parser.add_argument("--pnorm-input-dim", type=int, + help="input dimension to p-norm nonlinearities") +parser.add_argument("--pnorm-output-dim", type=int, + help="output dimension of p-norm nonlinearities") +parser.add_argument("--relu-dim", type=int, + help="dimension of ReLU nonlinearities") +parser.add_argument("--sigmoid-dim", type=int, + help="dimension of Sigmoid nonlinearities") +parser.add_argument("--use-presoftmax-prior-scale", type=str, + help="if true, a presoftmax-prior-scale is added", + choices=['true', 'false'], default = "true") +parser.add_argument("--num-targets", type=int, + help="number of network targets (e.g. num-pdf-ids/num-leaves)") +parser.add_argument("--skip-lda", type=str, + help="add lda matrix", + choices=['true', 'false'], default = "false") +parser.add_argument("--add-final-sigmoid", type=str, + help="add a sigmoid layer as the final layer. Applicable only if skip-final-softmax is true.", + choices=['true', 'false'], default = "false") +parser.add_argument("--objective-type", type=str, default="linear", + choices = ["linear", "quadratic", "xent"], + help = "the type of objective; i.e. 
quadratic or linear or cross-entropy") +parser.add_argument("config_dir", + help="Directory to write config files and variables"); +print(' '.join(sys.argv)) + +args = parser.parse_args() + +if not os.path.exists(args.config_dir): + os.makedirs(args.config_dir) + +## Check arguments. +if args.splice_indexes is None: + sys.exit("--splice-indexes argument is required"); +if args.feat_dim is None or not (args.feat_dim > 0): + sys.exit("--feat-dim argument is required"); +if args.num_targets is None or not (args.num_targets > 0): + sys.exit("--num-targets argument is required"); +if not args.relu_dim is None: + if not args.pnorm_input_dim is None or not args.pnorm_output_dim is None or not args.sigmoid_dim is None: + sys.exit("--relu-dim argument not compatible with " + "--pnorm-input-dim, --pnorm-output-dim and --sigmoid-dim options"); + nonlin_input_dim = args.relu_dim + nonlin_output_dim = args.relu_dim +elif not args.sigmoid_dim is None: + if not args.pnorm_input_dim is None or not args.pnorm_output_dim is None: + sys.exit("--sigmoid-dim argument not compatible with " + "--pnorm-input-dim and --pnorm-output-dim options"); + nonlin_input_dim = args.sigmoid_dim + nonlin_output_dim = args.sigmoid_dim +else: + if not args.pnorm_input_dim > 0 or not args.pnorm_output_dim > 0: + sys.exit("--relu-dim and --sigmoid-dim not set, so expected --pnorm-input-dim and " + "--pnorm-output-dim to be provided."); + nonlin_input_dim = args.pnorm_input_dim + nonlin_output_dim = args.pnorm_output_dim + +if args.use_presoftmax_prior_scale == "true": + use_presoftmax_prior_scale = True +else: + use_presoftmax_prior_scale = False + +if args.skip_lda == "true": + skip_lda = True +else: + skip_lda = False + +if args.add_final_sigmoid == "true": + add_final_sigmoid = True +else: + add_final_sigmoid = False + +## Work out splice_array e.g. splice_array = [ [ -3,-2,...3 ], [0], [-2,2], .. 
[ -8,8 ] ] +splice_array = [] +left_context = 0 +right_context = 0 +split1 = args.splice_indexes.split(); # we already checked the string is nonempty. +input_dim = args.feat_dim + args.ivector_dim +if len(split1) < 1: + sys.exit("invalid --splice-indexes argument, too short: " + + args.splice_indexes) +try: + for string in split1: + split2 = string.split(",") + if len(split2) < 1: + sys.exit("invalid --splice-indexes argument, too-short element: " + + args.splice_indexes) + int_list = [] + for int_str in split2: + int_list.append(int(int_str)) + if not int_list == sorted(int_list): + sys.exit("elements of --splice-indexes must be sorted: " + + args.splice_indexes) + left_context += -int_list[0] + right_context += int_list[-1] + splice_array.append(int_list) +except ValueError as e: + sys.exit("invalid --splice-indexes argument " + args.splice_indexes + str(e)) +left_context = max(0, left_context) +right_context = max(0, right_context) +num_hidden_layers = len(splice_array) +input_dim = len(splice_array[0]) * args.feat_dim + args.ivector_dim + +f = open(args.config_dir + "/vars", "w") +print('left_context=' + str(left_context), file=f) +print('right_context=' + str(right_context), file=f) +# the initial l/r contexts are actually not needed. 
+# print('initial_left_context=' + str(splice_array[0][0]), file=f) +# print('initial_right_context=' + str(splice_array[0][-1]), file=f) +print('num_hidden_layers=' + str(num_hidden_layers), file=f) +f.close() + +f = open(args.config_dir + "/init.config", "w") +print('# Config file for initializing neural network prior to', file=f) +print('# preconditioning matrix computation', file=f) +print('input-node name=input dim=' + str(args.feat_dim), file=f) +list=[ ('Offset(input, {0})'.format(n) if n != 0 else 'input' ) for n in splice_array[0] ] +if args.ivector_dim > 0: + print('input-node name=ivector dim=' + str(args.ivector_dim), file=f) + list.append('ReplaceIndex(ivector, t, 0)') +# example of next line: +# output-node name=output input="Append(Offset(input, -3), Offset(input, -2), Offset(input, -1), ... , Offset(input, 3), ReplaceIndex(ivector, t, 0))" +print('output-node name=output input=Append({0})'.format(", ".join(list)), file=f) +f.close() + +for l in range(1, num_hidden_layers + 1): + f = open(args.config_dir + "/layer{0}.config".format(l), "w") + print('# Config file for layer {0} of the network'.format(l), file=f) + if l == 1 and not skip_lda: + print('component name=lda type=FixedAffineComponent matrix={0}/lda.mat'. + format(args.config_dir), file=f) + cur_dim = (nonlin_output_dim * len(splice_array[l-1]) if l > 1 else input_dim) + + print('# Note: param-stddev in next component defaults to 1/sqrt(input-dim).', file=f) + print('component name=affine{0} type=NaturalGradientAffineComponent ' + 'input-dim={1} output-dim={2} bias-stddev=0'. + format(l, cur_dim, nonlin_input_dim), file=f) + if args.relu_dim is not None: + print('component name=nonlin{0} type=RectifiedLinearComponent dim={1}'. + format(l, args.relu_dim), file=f) + elif args.sigmoid_dim is not None: + print('component name=nonlin{0} type=SigmoidComponent dim={1}'. 
+ format(l, args.sigmoid_dim), file=f) + else: + print('# In nnet3 framework, p in P-norm is always 2.', file=f) + print('component name=nonlin{0} type=PnormComponent input-dim={1} output-dim={2}'. + format(l, args.pnorm_input_dim, args.pnorm_output_dim), file=f) + print('component name=renorm{0} type=NormalizeComponent dim={1} target-rms={2}'.format( + l, nonlin_output_dim, + (1.0 if l < num_hidden_layers else args.final_layer_normalize_target)), file=f) + print('component name=final-affine type=NaturalGradientAffineComponent ' + 'input-dim={0} output-dim={1} param-stddev=0 bias-stddev=0'.format( + nonlin_output_dim, args.num_targets), file=f) + + if args.include_log_softmax == "true": + # printing out the next two, and their component-nodes, for l > 1 is not + # really necessary as they will already exist, but it doesn't hurt and makes + # the structure clearer. + if use_presoftmax_prior_scale: + print('component name=final-fixed-scale type=FixedScaleComponent ' + 'scales={0}/presoftmax_prior_scale.vec'.format( + args.config_dir), file=f) + print('component name=final-log-softmax type=LogSoftmaxComponent dim={0}'.format( + args.num_targets), file=f) + elif add_final_sigmoid: + print('component name=final-sigmoid type=SigmoidComponent dim={0}'.format( + args.num_targets), file=f) + print('# Now for the network structure', file=f) + if l == 1: + splices = [ ('Offset(input, {0})'.format(n) if n != 0 else 'input') for n in splice_array[l-1] ] + if args.ivector_dim > 0: splices.append('ReplaceIndex(ivector, t, 0)') + orig_input='Append({0})'.format(', '.join(splices)) + # e.g. orig_input = 'Append(Offset(input, -2), ... Offset(input, 2), ivector)' + if not skip_lda: + print('component-node name=lda component=lda input={0}'.format(orig_input), + file=f) + cur_input='lda' + else: + cur_input = orig_input + else: + # e.g. 
cur_input = 'Append(Offset(renorm1, -2), renorm1, Offset(renorm1, 2))' + splices = [ ('Offset(renorm{0}, {1})'.format(l-1, n) if n !=0 else 'renorm{0}'.format(l-1)) + for n in splice_array[l-1] ] + cur_input='Append({0})'.format(', '.join(splices)) + print('component-node name=affine{0} component=affine{0} input={1} '. + format(l, cur_input), file=f) + print('component-node name=nonlin{0} component=nonlin{0} input=affine{0}'. + format(l), file=f) + print('component-node name=renorm{0} component=renorm{0} input=nonlin{0}'. + format(l), file=f) + + print('component-node name=final-affine component=final-affine input=renorm{0}'. + format(l), file=f) + + if args.include_log_softmax == "true": + if use_presoftmax_prior_scale: + print('component-node name=final-fixed-scale component=final-fixed-scale input=final-affine', + file=f) + print('component-node name=final-log-softmax component=final-log-softmax ' + 'input=final-fixed-scale', file=f) + else: + print('component-node name=final-log-softmax component=final-log-softmax ' + 'input=final-affine', file=f) + print('output-node name=output input=final-log-softmax objective={0}'.format(args.objective_type), file=f) + else: + if add_final_sigmoid: + print('component-node name=final-sigmoid component=final-sigmoid input=final-affine', file=f) + print('output-node name=output input=final-sigmoid objective={0}'.format(args.objective_type), file=f) + else: + print('output-node name=output input=final-affine objective={0}'.format(args.objective_type), file=f) + f.close() + + +# component name=nonlin1 type=PnormComponent input-dim=$pnorm_input_dim output-dim=$pnorm_output_dim +# component name=renorm1 type=NormalizeComponent dim=$pnorm_output_dim +# component name=final-affine type=NaturalGradientAffineComponent input-dim=$pnorm_output_dim output-dim=$num_leaves param-stddev=0 bias-stddev=0 +# component name=final-log-softmax type=LogSoftmaxComponent dim=$num_leaves + + +# ## Write file $config_dir/init.config to initialize the 
network, prior to computing the LDA matrix. +# ##will look like this, if we have iVectors: +# input-node name=input dim=13 +# input-node name=ivector dim=100 +# output-node name=output input="Append(Offset(input, -3), Offset(input, -2), Offset(input, -1), ... , Offset(input, 3), ReplaceIndex(ivector, t, 0))" + +# ## Write file $config_dir/layer1.config that adds the LDA matrix, assumed to be in the config directory as +# ## lda.mat, the first hidden layer, and the output layer. +# component name=lda type=FixedAffineComponent matrix=$config_dir/lda.mat +# component name=affine1 type=NaturalGradientAffineComponent input-dim=$lda_input_dim output-dim=$pnorm_input_dim bias-stddev=0 +# component name=nonlin1 type=PnormComponent input-dim=$pnorm_input_dim output-dim=$pnorm_output_dim +# component name=renorm1 type=NormalizeComponent dim=$pnorm_output_dim +# component name=final-affine type=NaturalGradientAffineComponent input-dim=$pnorm_output_dim output-dim=$num_leaves param-stddev=0 bias-stddev=0 +# component name=final-log-softmax type=LogSoftmax dim=$num_leaves +# # InputOf(output) says use the same Descriptor of the current "output" node. +# component-node name=lda component=lda input=InputOf(output) +# component-node name=affine1 component=affine1 input=lda +# component-node name=nonlin1 component=nonlin1 input=affine1 +# component-node name=renorm1 component=renorm1 input=nonlin1 +# component-node name=final-affine component=final-affine input=renorm1 +# component-node name=final-log-softmax component=final-log-softmax input=final-affine +# output-node name=output input=final-log-softmax + + +# ## Write file $config_dir/layer2.config that adds the second hidden layer. 
+# component name=affine2 type=NaturalGradientAffineComponent input-dim=$lda_input_dim output-dim=$pnorm_input_dim bias-stddev=0 +# component name=nonlin2 type=PnormComponent input-dim=$pnorm_input_dim output-dim=$pnorm_output_dim +# component name=renorm2 type=NormalizeComponent dim=$pnorm_output_dim +# component name=final-affine type=NaturalGradientAffineComponent input-dim=$pnorm_output_dim output-dim=$num_leaves param-stddev=0 bias-stddev=0 +# component-node name=affine2 component=affine2 input=Append(Offset(renorm1, -2), Offset(renorm1, 2)) +# component-node name=nonlin2 component=nonlin2 input=affine2 +# component-node name=renorm2 component=renorm2 input=nonlin2 +# component-node name=final-affine component=final-affine input=renorm2 +# component-node name=final-log-softmax component=final-log-softmax input=final-affine +# output-node name=output input=final-log-softmax + + +# ## ... etc. In this example it would go up to $config_dir/layer5.config. + diff --git a/egs/wsj/s5/steps/nnet3/make_tdnn_snr_predictor_configs.py b/egs/wsj/s5/steps/nnet3/make_tdnn_snr_predictor_configs.py new file mode 100644 index 00000000000..df5e92e62f7 --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/make_tdnn_snr_predictor_configs.py @@ -0,0 +1,306 @@ +#!/usr/bin/env python + +# we're using python 3.x style print but want it to work in python 2.x, +from __future__ import print_function +import re, os, argparse, sys, math, warnings + + + +parser = argparse.ArgumentParser(description="Writes config files and variables " + "for TDNNs creation and training", + epilog="See steps/nnet3/train_tdnn.sh for example."); +parser.add_argument("--splice-indexes", type=str, + help="Splice indexes at each hidden layer, e.g. '-3,-2,-1,0,1,2,3 0 -2,2 0 -4,4 0 -8,8'") +parser.add_argument("--feat-dim", type=int, + help="Raw feature dimension, e.g. 13") +parser.add_argument("--ivector-dim", type=int, + help="iVector dimension, e.g. 
100", default=0) +parser.add_argument("--pnorm-input-dim", type=int, + help="input dimension to p-norm nonlinearities") +parser.add_argument("--pnorm-output-dim", type=int, + help="output dimension of p-norm nonlinearities") +parser.add_argument("--relu-dim", type=int, + help="dimension of ReLU nonlinearities") +parser.add_argument("--sigmoid-dim", type=int, + help="dimension of Sigmoid nonlinearities") +parser.add_argument("--pnorm-input-dims", type=str, + help="input dimension to p-norm nonlinearities") +parser.add_argument("--pnorm-output-dims", type=str, + help="output dimension of p-norm nonlinearities") +parser.add_argument("--relu-dims", type=str, + help="dimension of ReLU nonlinearities") +parser.add_argument("--sigmoid-dims", type=str, + help="dimension of Sigmoid nonlinearities") +parser.add_argument("--use-presoftmax-prior-scale", type=str, + help="if true, a presoftmax-prior-scale is added", + choices=['true', 'false'], default = "true") +parser.add_argument("--num-targets", type=int, + help="number of network targets (e.g. num-pdf-ids/num-leaves)") +parser.add_argument("--include-log-softmax", type=str, + help="add the final softmax layer ", default="true", choices = ["false", "true"]) +parser.add_argument("--final-layer-normalize-target", type=float, + help="RMS target for final layer (set to <1 if final layer learns too fast", + default=1.0) +parser.add_argument("--skip-lda", type=str, + help="add lda matrix", + choices=['true', 'false'], default = "false") +parser.add_argument("--add-final-sigmoid", type=str, + help="add a sigmoid layer as the final layer. Applicable only if skip-final-softmax is true.", + choices=['true', 'false'], default = "false") +parser.add_argument("--objective-type", type=str, default="linear", + choices = ["linear", "quadratic", "xent"], + help = "the type of objective; i.e. 
quadratic or linear or cross-entropy") +parser.add_argument("config_dir", + help="Directory to write config files and variables") +print(' '.join(sys.argv)) + +args = parser.parse_args() + +if not os.path.exists(args.config_dir): + os.makedirs(args.config_dir) + +## Check arguments. +if args.splice_indexes is None: + sys.exit("--splice-indexes argument is required"); +if args.feat_dim is None or not (args.feat_dim > 0): + sys.exit("--feat-dim argument is required"); +if args.num_targets is None or not (args.num_targets > 0): + sys.exit("--num-targets argument is required"); + +if args.use_presoftmax_prior_scale == "true": + use_presoftmax_prior_scale = True +else: + use_presoftmax_prior_scale = False + +if args.skip_lda == "true": + skip_lda = True +else: + skip_lda = False + +if args.include_log_softmax == "true": + include_log_softmax = True +else: + include_log_softmax = False + +if args.add_final_sigmoid == "true": + add_final_sigmoid = True +else: + add_final_sigmoid = False + +## Work out splice_array e.g. splice_array = [ [ -3,-2,...3 ], [0], [-2,2], .. [ -8,8 ] ] +splice_array = [] +left_context = 0 +right_context = 0 +split1 = args.splice_indexes.split(); # we already checked the string is nonempty. 
+input_dim = args.feat_dim + args.ivector_dim +if len(split1) < 1: + sys.exit("invalid --splice-indexes argument, too short: " + + args.splice_indexes) +try: + for string in split1: + split2 = string.split(",") + if len(split2) < 1: + sys.exit("invalid --splice-indexes argument, too-short element: " + + args.splice_indexes) + int_list = [] + for int_str in split2: + int_list.append(int(int_str)) + if not int_list == sorted(int_list): + sys.exit("elements of --splice-indexes must be sorted: " + + args.splice_indexes) + left_context += -int_list[0] + right_context += int_list[-1] + splice_array.append(int_list) +except ValueError as e: + sys.exit("invalid --splice-indexes argument " + args.splice_indexes + str(e)) +left_context = max(0, left_context) +right_context = max(0, right_context) +num_hidden_layers = len(splice_array) +input_dim = len(splice_array[0]) * args.feat_dim + args.ivector_dim + +if (sum([1 for x in [args.relu_dims, args.relu_dim, args.sigmoid_dims, args.sigmoid_dim, args.pnorm_input_dims, args.pnorm_input_dim] if x]) > 1 + or sum([1 for x in [args.relu_dims, args.relu_dim, args.sigmoid_dims, args.sigmoid_dim, args.pnorm_output_dims, args.pnorm_output_dim] if x]) > 1): + sys.exit("only one of the dimension options must be provided") + +if args.relu_dim is not None: + nonlin_input_dims = [args.relu_dim] * num_hidden_layers + nonlin_output_dims = nonlin_input_dims +if args.relu_dims is not None: + nonlin_input_dims = args.relu_dims.strip().split() + nonlin_output_dims = nonlin_input_dims +if args.sigmoid_dim is not None: + nonlin_input_dims = [args.sigmoid_dim] * num_hidden_layers + nonlin_output_dims = nonlin_input_dims +if args.sigmoid_dims is not None: + nonlin_input_dims = args.sigmoid_dims.strip().split() + nonlin_output_dims = nonlin_input_dims +if args.pnorm_input_dims is not None: + assert(args.pnorm_output_dims is not None) + nonlin_input_dims = args.pnorm_input_dims.strip().split() + nonlin_output_dims = args.pnorm_output_dims.strip().split() +if 
args.pnorm_input_dim is not None: + assert(args.pnorm_output_dim is not None) + nonlin_input_dims = [args.pnorm_input_dim] * num_hidden_layers + nonlin_output_dims = [args.pnorm_output_dim] * num_hidden_layers + +nonlin_input_dims = [ int(x) for x in nonlin_input_dims ] +nonlin_output_dims = [ int(x) for x in nonlin_output_dims ] + +assert len(nonlin_input_dims) == num_hidden_layers +assert len(nonlin_output_dims) == num_hidden_layers + +f = open(args.config_dir + "/vars", "w") +print('left_context=' + str(left_context), file=f) +print('right_context=' + str(right_context), file=f) +# the initial l/r contexts are actually not needed. +# print('initial_left_context=' + str(splice_array[0][0]), file=f) +# print('initial_right_context=' + str(splice_array[0][-1]), file=f) +print('num_hidden_layers=' + str(num_hidden_layers), file=f) +f.close() + +f = open(args.config_dir + "/init.config", "w") +print('# Config file for initializing neural network prior to', file=f) +print('# preconditioning matrix computation', file=f) +print('input-node name=input dim=' + str(args.feat_dim), file=f) +list=[ ('Offset(input, {0})'.format(n) if n != 0 else 'input' ) for n in splice_array[0] ] +if args.ivector_dim > 0: + print('input-node name=ivector dim=' + str(args.ivector_dim), file=f) + list.append('ReplaceIndex(ivector, t, 0)') +# example of next line: +# output-node name=output input="Append(Offset(input, -3), Offset(input, -2), Offset(input, -1), ... , Offset(input, 3), ReplaceIndex(ivector, t, 0))" +print('output-node name=output input=Append({0})'.format(", ".join(list)), file=f) +f.close() + +for l in range(1, num_hidden_layers + 1): + f = open(args.config_dir + "/layer{0}.config".format(l), "w") + print('# Config file for layer {0} of the network'.format(l), file=f) + if l == 1 and not skip_lda: + print('component name=lda type=FixedAffineComponent matrix={0}/lda.mat'. 
+ format(args.config_dir), file=f) + cur_dim = (nonlin_output_dims[l-2] * len(splice_array[l-1]) if l > 1 else input_dim) + + print('# Note: param-stddev in next component defaults to 1/sqrt(input-dim).', file=f) + print('component name=affine{0} type=NaturalGradientAffineComponent ' + 'input-dim={1} output-dim={2} bias-stddev=0'. + format(l, cur_dim, nonlin_input_dims[l-1]), file=f) + if args.relu_dims is not None or args.relu_dim is not None: + print('component name=nonlin{0} type=RectifiedLinearComponent dim={1}'. + format(l, nonlin_input_dims[l-1]), file=f) + elif args.sigmoid_dims is not None or args.sigmoid_dim is not None: + print('component name=nonlin{0} type=SigmoidComponent dim={1}'. + format(l, nonlin_input_dims[l-1]), file=f) + else: + print('# In nnet3 framework, p in P-norm is always 2.', file=f) + print('component name=nonlin{0} type=PnormComponent input-dim={1} output-dim={2}'. + format(l, nonlin_input_dims[l-1], nonlin_output_dims[l-1]), file=f) + print('component name=renorm{0} type=NormalizeComponent dim={1} target-rms={2}'.format( + l, nonlin_output_dims[l-1], + (1.0 if l < num_hidden_layers else args.final_layer_normalize_target)), file=f) + print('component name=final-affine type=NaturalGradientAffineComponent ' + 'input-dim={0} output-dim={1} param-stddev=0 bias-stddev=0'.format( + nonlin_output_dims[l-1], args.num_targets), file=f) + + if args.include_log_softmax == "true": + # printing out the next two, and their component-nodes, for l > 1 is not + # really necessary as they will already exist, but it doesn't hurt and makes + # the structure clearer. 
+ if use_presoftmax_prior_scale: + print('component name=final-fixed-scale type=FixedScaleComponent ' + 'scales={0}/presoftmax_prior_scale.vec'.format( + args.config_dir), file=f) + print('component name=final-log-softmax type=LogSoftmaxComponent dim={0}'.format( + args.num_targets), file=f) + elif add_final_sigmoid: + print('component name=final-sigmoid type=SigmoidComponent dim={0}'.format( + args.num_targets), file=f) + print('# Now for the network structure', file=f) + if l == 1: + splices = [ ('Offset(input, {0})'.format(n) if n != 0 else 'input') for n in splice_array[l-1] ] + if args.ivector_dim > 0: splices.append('ReplaceIndex(ivector, t, 0)') + orig_input='Append({0})'.format(', '.join(splices)) + # e.g. orig_input = 'Append(Offset(input, -2), ... Offset(input, 2), ivector)' + if not skip_lda: + print('component-node name=lda component=lda input={0}'.format(orig_input), + file=f) + cur_input='lda' + else: + cur_input = orig_input + else: + # e.g. cur_input = 'Append(Offset(renorm1, -2), renorm1, Offset(renorm1, 2))' + splices = [ ('Offset(renorm{0}, {1})'.format(l-1, n) if n !=0 else 'renorm{0}'.format(l-1)) + for n in splice_array[l-1] ] + cur_input='Append({0})'.format(', '.join(splices)) + print('component-node name=affine{0} component=affine{0} input={1} '. + format(l, cur_input), file=f) + print('component-node name=nonlin{0} component=nonlin{0} input=affine{0}'. + format(l), file=f) + print('component-node name=renorm{0} component=renorm{0} input=nonlin{0}'. + format(l), file=f) + + print('component-node name=final-affine component=final-affine input=renorm{0}'. 
+ format(l), file=f) + + if args.include_log_softmax == "true": + if use_presoftmax_prior_scale: + print('component-node name=final-fixed-scale component=final-fixed-scale input=final-affine', + file=f) + print('component-node name=final-log-softmax component=final-log-softmax ' + 'input=final-fixed-scale', file=f) + else: + print('component-node name=final-log-softmax component=final-log-softmax ' + 'input=final-affine', file=f) + print('output-node name=output input=final-log-softmax objective={0}'.format(args.objective_type), file=f) + else: + if add_final_sigmoid: + print('component-node name=final-sigmoid component=final-sigmoid input=final-affine', file=f) + print('output-node name=output input=final-sigmoid objective={0}'.format(args.objective_type), file=f) + else: + print('output-node name=output input=final-affine objective={0}'.format(args.objective_type), file=f) + f.close() + + +# component name=nonlin1 type=PnormComponent input-dim=$pnorm_input_dim output-dim=$pnorm_output_dim +# component name=renorm1 type=NormalizeComponent dim=$pnorm_output_dim +# component name=final-affine type=NaturalGradientAffineComponent input-dim=$pnorm_output_dim output-dim=$num_leaves param-stddev=0 bias-stddev=0 +# component name=final-log-softmax type=LogSoftmaxComponent dim=$num_leaves + + +# ## Write file $config_dir/init.config to initialize the network, prior to computing the LDA matrix. +# ##will look like this, if we have iVectors: +# input-node name=input dim=13 +# input-node name=ivector dim=100 +# output-node name=output input="Append(Offset(input, -3), Offset(input, -2), Offset(input, -1), ... , Offset(input, 3), ReplaceIndex(ivector, t, 0))" + +# ## Write file $config_dir/layer1.config that adds the LDA matrix, assumed to be in the config directory as +# ## lda.mat, the first hidden layer, and the output layer. 
+# component name=lda type=FixedAffineComponent matrix=$config_dir/lda.mat
+# component name=affine1 type=NaturalGradientAffineComponent input-dim=$lda_input_dim output-dim=$pnorm_input_dim bias-stddev=0
+# component name=nonlin1 type=PnormComponent input-dim=$pnorm_input_dim output-dim=$pnorm_output_dim
+# component name=renorm1 type=NormalizeComponent dim=$pnorm_output_dim
+# component name=final-affine type=NaturalGradientAffineComponent input-dim=$pnorm_output_dim output-dim=$num_leaves param-stddev=0 bias-stddev=0
+# component name=final-log-softmax type=LogSoftmaxComponent dim=$num_leaves
+# # InputOf(output) says use the same Descriptor of the current "output" node.
+# component-node name=lda component=lda input=InputOf(output)
+# component-node name=affine1 component=affine1 input=lda
+# component-node name=nonlin1 component=nonlin1 input=affine1
+# component-node name=renorm1 component=renorm1 input=nonlin1
+# component-node name=final-affine component=final-affine input=renorm1
+# component-node name=final-log-softmax component=final-log-softmax input=final-affine
+# output-node name=output input=final-log-softmax
+
+
+# ## Write file $config_dir/layer2.config that adds the second hidden layer.
+# component name=affine2 type=NaturalGradientAffineComponent input-dim=$lda_input_dim output-dim=$pnorm_input_dim bias-stddev=0 +# component name=nonlin2 type=PnormComponent input-dim=$pnorm_input_dim output-dim=$pnorm_output_dim +# component name=renorm2 type=NormalizeComponent dim=$pnorm_output_dim +# component name=final-affine type=NaturalGradientAffineComponent input-dim=$pnorm_output_dim output-dim=$num_leaves param-stddev=0 bias-stddev=0 +# component-node name=affine2 component=affine2 input=Append(Offset(renorm1, -2), Offset(renorm1, 2)) +# component-node name=nonlin2 component=nonlin2 input=affine2 +# component-node name=renorm2 component=renorm2 input=nonlin2 +# component-node name=final-affine component=final-affine input=renorm2 +# component-node name=final-log-softmax component=final-log-softmax input=final-affine +# output-node name=output input=final-log-softmax + + +# ## ... etc. In this example it would go up to $config_dir/layer5.config. + diff --git a/egs/wsj/s5/steps/nnet3/train_tdnn_raw.sh b/egs/wsj/s5/steps/nnet3/train_tdnn_raw.sh new file mode 100755 index 00000000000..b173b1b5f35 --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/train_tdnn_raw.sh @@ -0,0 +1,660 @@ +#!/bin/bash + +# note, TDNN is the same as what we used to call multisplice. + +# Copyright 2012-2015 Johns Hopkins University (Author: Daniel Povey). +# 2013 Xiaohui Zhang +# 2013 Guoguo Chen +# 2014-2015 Vimal Manohar +# 2014 Vijayaditya Peddinti +# Apache 2.0. + +set -u + +# Begin configuration section. +cmd=run.pl +nj=4 +num_epochs=15 # Number of epochs of training; + # the number of iterations is worked out from this. +initial_effective_lrate=0.01 +final_effective_lrate=0.001 + +relu_dim= +sigmoid_dim= +pnorm_input_dim=3000 +pnorm_output_dim=300 + +pnorm_input_dims= +pnorm_output_dims= +relu_dims= # you can use this to make it use ReLU's instead of p-norms. +sigmoid_dims= # you can use this to make it use Sigmoid's instead of p-norms. 
+ +rand_prune=4.0 # Relates to a speedup we do for LDA. +minibatch_size=512 # This default is suitable for GPU-based training. + # Set it to 128 for multi-threaded CPU-based training. + +samples_per_iter=400000 # each iteration of training, see this many samples + # per job. This option is passed to get_egs.sh +num_jobs_initial=1 # Number of neural net jobs to run in parallel at the start of training +num_jobs_final=8 # Number of neural net jobs to run in parallel at the end of training +prior_subset_size=20000 # 20k samples per job, for computing priors. +num_utts_subset=300 # number of utterances in validation and training + # subsets used for shrinkage and diagnostics. +num_jobs_compute_prior=10 # these are single-threaded, run on CPU. +get_egs_stage=0 # can be used for rerunning after partial +online_ivector_dir= +presoftmax_prior_scale_power=-0.25 +use_presoftmax_prior_scale=true +remove_egs=true # set to false to disable removing egs after training is done. +egs_suffix= +deriv_weights_scp= + +max_models_combine=20 # The "max_models_combine" is the maximum number of models we give + # to the final 'combine' stage, but these models will themselves be averages of + # iteration-number ranges. + +shuffle_buffer_size=5000 # This "buffer_size" variable controls randomization of the samples + # on each iter. You could set it to 0 or to a large value for complete + # randomization, but this would both consume memory and cause spikes in + # disk I/O. Smaller is easier on disk and memory but less random. It's + # not a huge deal though, as samples are anyway randomized right at the start. + # (the point of this is to get data in different minibatches on different iterations, + # since in the preconditioning method, 2 samples in the same minibatch can + # affect each others' gradients. + +add_layers_period=2 # by default, add new layers every 2 iterations. +stage=-6 +exit_stage=-100 # you can set this to terminate the training early. 
Exits before running this stage + +# count space-separated fields in splice_indexes to get num-hidden-layers. +splice_indexes="-4,-3,-2,-1,0,1,2,3,4 0 -2,2 0 -4,4 0" +# Format : layer/....layer/ " +# note: hidden layers which are composed of one or more components, +# so hidden layer indexing is different from component count + +randprune=4.0 # speeds up LDA. +use_gpu=true # if true, we run on GPU. +cleanup=true +keep_model_iter=10 +egs_dir= +skip_lda=true +max_lda_jobs=10 # use no more than 10 jobs for the LDA accumulation. +lda_opts= +egs_opts= +transform_dir= +cmvn_opts= # will be passed to get_lda.sh and get_egs.sh, if supplied. + # only relevant for "raw" features, not lda. +feat_type=raw # or set to 'lda' to use LDA features. +# End configuration section. +frames_per_eg=8 # to be passed on to get_egs.sh +num_targets= # applicable only if posterior_targets is true +posterior_targets=false +objective_type=linear +include_log_softmax=false +max_param_change=1 +config_dir= + +trap 'for pid in $(jobs -pr); do kill -KILL $pid; done' INT QUIT TERM + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# != 3 ]; then + echo "Usage: $0 [opts] " + echo " e.g.: $0 data/train scp:snr_targets/targets.scp exp/nnet3_snr_predictor" + echo "" + echo "Main options (for others, see top of script file)" + echo " --config # config file containing options" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --num-epochs <#epochs|15> # Number of epochs of training" + echo " --initial-effective-lrate # effective learning rate at start of training." + echo " --final-effective-lrate # effective learning rate at end of training." + echo " # data, 0.00025 for large data" + echo " --num-hidden-layers <#hidden-layers|2> # Number of hidden layers, e.g. 
2 for 3 hours of data, 4 for 100hrs" + echo " --add-layers-period <#iters|2> # Number of iterations between adding hidden layers" + echo " --presoftmax-prior-scale-power # use the specified power value on the priors (inverse priors) to scale" + echo " # the pre-softmax outputs (set to 0.0 to disable the presoftmax element scale)" + echo " --num-jobs-initial # Number of parallel jobs to use for neural net training, at the start." + echo " --num-jobs-final # Number of parallel jobs to use for neural net training, at the end" + echo " --num-threads # Number of parallel threads per job, for CPU-based training (will affect" + echo " # results as well as speed; may interact with batch size; if you increase" + echo " # this, you may want to decrease the batch size." + echo " --parallel-opts # extra options to pass to e.g. queue.pl for processes that" + echo " # use multiple threads... note, you might have to reduce mem_free,ram_free" + echo " # versus your defaults, because it gets multiplied by the -pe smp argument." + echo " --io-opts # Options given to e.g. queue.pl for jobs that do a lot of I/O." + echo " --minibatch-size # Size of minibatch to process (note: product with --num-threads" + echo " # should not get too large, e.g. >2k)." + echo " --samples-per-iter <#samples|400000> # Number of samples of data to process per iteration, per" + echo " # process." + echo " --splice-indexes " + echo " # Frame indices used for each splice layer." + echo " # Format : layer/....layer/ " + echo " # (note: we splice processed, typically 40-dimensional frames" + echo " --lda-dim # Dimension to reduce spliced features to with LDA" + echo " --stage # Used to run a partially-completed training process from somewhere in" + echo " # the middle." + + exit 1; +fi + +data=$1 +targets_scp=$2 +dir=$3 + +# Check some files. +for f in $data/feats.scp $targets_scp; do + [ ! -f $f ] && echo "$0: no such file $f" && exit 1; +done + +# in this dir we'll have just one job. 
+sdata=$data/split$nj +utils/split_data.sh $data $nj + +mkdir -p $dir/log +echo $nj > $dir/num_jobs + +# First work out the feature and iVector dimension, needed for tdnn config creation. +case $feat_type in + raw|sparse) feat_dim=$(feat-to-dim --print-args=false scp:$data/feats.scp -) || \ + { echo "$0: Error getting feature dim"; exit 1; } + ;; + lda|lda_sparse) [ ! -f $transform_dir/final.mat ] && echo "$0: With --feat-type lda option, expect $transform_dir/final.mat to exist." + # get num-rows in lda matrix, which is the lda feature dim. + feat_dim=$(matrix-dim --print-args=false $transform_dir/final.mat | cut -f 1) + ;; + #sparse) + # + # if [ -z "$sparse_input_dim" ]; then + # echo "$0: feat-type is sparse; sparse-input-dim must be specified" + # exit 1 + # fi + + # feat_dim=$sparse_input_dim + # ;; + *) + echo "$0: Bad --feat-type '$feat_type';"; exit 1; +esac +if [ -z "$online_ivector_dir" ]; then + ivector_dim=0 +else + ivector_dim=$(feat-to-dim scp:$online_ivector_dir/ivector_online.scp -) || exit 1; +fi + + +if [ $stage -le -5 ]; then + echo "$0: creating neural net configs"; + + raw_nnet_config_opts=() + + + objective_opts="--objective-type=$objective_type" + + if [ "$objective_type" == "xent" ]; then + raw_nnet_config_opts+=(--add-final-sigmoid=true --include-log-softmax=false) + else + raw_nnet_config_opts+=(--include-log-softmax=$include_log_softmax) + fi + + raw_nnet_config_opts+=(--use-presoftmax-prior-scale=$use_presoftmax_prior_scale) + raw_nnet_config_opts+=(--skip-lda=$skip_lda) + + input_dim=$feat_dim + if ! $posterior_targets; then + # Set num targets as the dimension of targets features + num_targets=`feat-to-dim scp:$targets_scp - 2>/dev/null` || exit 1 + fi + + [ -z $num_targets ] && echo "\$num_targets is unset" && exit 1 + [ "$num_targets" -eq "0" ] && echo "\$num_targets is 0" && exit 1 + + if [ ! 
-z "$config_dir" ]; then + cp -rT $config_dir $dir/configs + else + # create the config files for nnet initialization + python steps/nnet3/make_tdnn_snr_predictor_configs.py \ + --splice-indexes="$splice_indexes" \ + --feat-dim=$input_dim \ + --ivector-dim=$ivector_dim \ + "${raw_nnet_config_opts[@]}" $objective_opts \ + ${relu_dim:+--relu-dim="$relu_dim"} \ + ${sigmoid_dim:+--sigmoid-dim="$sigmoid_dim"} \ + ${pnorm_input_dim:+--pnorm-input-dim="$pnorm_input_dim"} \ + ${pnorm_output_dim:+--pnorm-output-dim="$pnorm_output_dim"} \ + ${relu_dims:+--relu-dims="$relu_dims"} \ + ${sigmoid_dims:+--sigmoid-dims="$sigmoid_dims"} \ + ${pnorm_input_dims:+--pnorm-input-dims="$pnorm_input_dims"} \ + ${pnorm_output_dims:+--pnorm-output-dims="$pnorm_output_dims"} \ + --num-targets=$num_targets \ + $dir/configs || exit 1; + fi + + # Initialize as "raw" nnet, prior to training the LDA-like preconditioning + # matrix. This first config just does any initial splicing that we do; + # we do this as it's a convenient way to get the stats for the 'lda-like' + # transform. + $cmd $dir/log/nnet_init.log \ + nnet3-init --srand=-2 $dir/configs/init.config $dir/init.raw || exit 1; +fi + +# sourcing the "vars" below sets +# left_context=(something) +# right_context=(something) +# num_hidden_layers=(something) +. $dir/configs/vars || exit 1; + +context_opts="--left-context=$left_context --right-context=$right_context" + +# Allow 0 hidden layers -- Probably only a single affine component followed by +# a softmax +[ "$num_hidden_layers" -le 0 ] && echo \ + "$0: Expected num_hidden_layers to be defined" && exit 1; + +if [ $stage -le -4 ] && [ -z "$egs_dir" ]; then + extra_opts=() + [ ! -z "$cmvn_opts" ] && extra_opts+=(--cmvn-opts "$cmvn_opts") + [ ! -z "$feat_type" ] && extra_opts+=(--feat-type $feat_type) + [ ! 
-z "$online_ivector_dir" ] && extra_opts+=(--online-ivector-dir $online_ivector_dir) + extra_opts+=(--transform-dir "$transform_dir") + extra_opts+=(--left-context $left_context) + extra_opts+=(--right-context $right_context) + echo "$0: calling get_egs.sh" + [ ! -z "$deriv_weights_scp" ] && extra_opts+=(--deriv-weights-scp $deriv_weights_scp) + + target_type=dense + if $posterior_targets; then + target_type=sparse + fi + + #if [ "$feat_type" == "sparse" ]; then + # if [ -z "$sparse_input_dim" ]; then + # echo "$0: sparse-input-dim or feats-scp not set" + # exit 1 + # fi + # extra_opts+=(--sparse-input-dim $sparse_input_dim --feats-scp $feats_scp) + #fi +if ! $posterior_targets; then + # Set num targets as the dimension of targets features + num_targets=`feat-to-dim scp:$targets_scp - 2>/dev/null` || exit 1 +fi + +[ -z $num_targets ] && echo "\$num_targets is unset" && exit 1 +[ "$num_targets" -eq "0" ] && echo "\$num_targets is 0" && exit 1 + + + steps/nnet3/get_egs_dense_targets.sh $egs_opts "${extra_opts[@]}" \ + --num-utts-subset $num_utts_subset \ + --samples-per-iter $samples_per_iter --stage $get_egs_stage \ + --cmd "$cmd" --nj $nj --num-targets $num_targets $egs_opts \ + --frames-per-eg $frames_per_eg --target-type $target_type \ + $data $targets_scp $dir/egs || exit 1; +fi + +[ -z $egs_dir ] && egs_dir=$dir/egs + +if [ "$feat_dim" != "$(cat $egs_dir/info/feat_dim)" ]; then + echo "$0: feature dimension mismatch with egs, $feat_dim vs $(cat $egs_dir/info/feat_dim)"; + exit 1; +fi +if [ "$ivector_dim" != "$(cat $egs_dir/info/ivector_dim)" ]; then + echo "$0: ivector dimension mismatch with egs, $ivector_dim vs $(cat $egs_dir/info/ivector_dim)"; + exit 1; +fi + +# copy any of the following that exist, to $dir. +cp $egs_dir/{cmvn_opts,splice_opts,final.mat} $dir 2>/dev/null + +# confirm that the egs_dir has the necessary context (especially important if +# the --egs-dir option was used on the command line). 
+egs_left_context=$(cat $egs_dir/info/left_context) || exit -1 +egs_right_context=$(cat $egs_dir/info/right_context) || exit -1 + ( [ $egs_left_context -lt $left_context ] || \ + [ $egs_right_context -lt $right_context ] ) && \ + echo "$0: egs in $egs_dir have too little context" && exit -1; + +frames_per_eg=$(cat $egs_dir/info/frames_per_eg) || { echo "error: no such file $egs_dir/info/frames_per_eg"; exit 1; } +num_archives=$(cat $egs_dir/info/num_archives) || { echo "error: no such file $egs_dir/info/frames_per_eg"; exit 1; } + +# num_archives_expanded considers each separate label-position from +# 0..frames_per_eg-1 to be a separate archive. +num_archives_expanded=$[$num_archives*$frames_per_eg] + +[ $num_jobs_initial -gt $num_jobs_final ] && \ + echo "$0: --initial-num-jobs cannot exceed --final-num-jobs" && exit 1; + +[ $num_jobs_final -gt $num_archives_expanded ] && \ + echo "$0: --final-num-jobs cannot exceed #archives $num_archives_expanded." && exit 1; + + +if ! $skip_lda && [ $stage -le -3 ]; then + echo "$0: getting preconditioning matrix for input features." + num_lda_jobs=$num_archives + [ $num_lda_jobs -gt $max_lda_jobs ] && num_lda_jobs=$max_lda_jobs + + # Write stats with the same format as stats for LDA. + $cmd JOB=1:$num_lda_jobs $dir/log/get_lda_stats.JOB.log \ + nnet3-acc-lda-stats --rand-prune=$rand_prune \ + $dir/init.raw "ark:$egs_dir/egs.JOB.ark" $dir/JOB.lda_stats || exit 1; + + all_lda_accs=$(for n in $(seq $num_lda_jobs); do echo $dir/$n.lda_stats; done) + $cmd $dir/log/sum_transform_stats.log \ + sum-lda-accs $dir/lda_stats $all_lda_accs || exit 1; + + rm $all_lda_accs || exit 1; + + # this computes a fixed affine transform computed in the way we described in + # Appendix C.6 of http://arxiv.org/pdf/1410.7455v6.pdf; it's a scaled variant + # of an LDA transform but without dimensionality reduction. 
+ $cmd $dir/log/get_transform.log \ + nnet-get-feature-transform $lda_opts $dir/lda.mat $dir/lda_stats || exit 1; + + ln -sf ../lda.mat $dir/configs/lda.mat +fi + + +if $include_log_softmax && [ $stage -le -2 ]; then + echo "$0: preparing initial vector for FixedScaleComponent before softmax" + echo " ... using priors^$presoftmax_prior_scale_power and rescaling to average 1" + + # obtains raw pdf count + $cmd JOB=1:$nj $dir/log/acc_pdf.JOB.log \ + ali-to-post "scp:utils/split_scp.pl -j $nj \$[JOB-1] $targets_scp |" ark:- \| \ + post-to-tacc --per-pdf=false --num-targets=$num_targets \ + ark:- $dir/pdf_counts.JOB || exit 1; + $cmd $dir/log/sum_pdf_counts.log \ + vector-sum --binary=false $dir/pdf_counts.* $dir/pdf_counts || exit 1; + rm $dir/pdf_counts.* + + awk -v power=$presoftmax_prior_scale_power -v smooth=0.01 \ + '{ for(i=2; i<=NF-1; i++) { count[i-2] = $i; total += $i; } + num_pdfs=NF-2; average_count = total/num_pdfs; + for (i=0; i $dir/presoftmax_prior_scale.vec + ln -sf ../presoftmax_prior_scale.vec $dir/configs/presoftmax_prior_scale.vec +fi + +if [ $stage -le -1 ]; then + # Add the first layer; this will add in the lda.mat and + # presoftmax_prior_scale.vec. + $cmd $dir/log/add_first_layer.log \ + nnet3-init --srand=-3 $dir/init.raw $dir/configs/layer1.config $dir/0.raw || exit 1; + +fi + + + + +# set num_iters so that as close as possible, we process the data $num_epochs +# times, i.e. $num_iters*$avg_num_jobs) == $num_epochs*$num_archives_expanded, +# where avg_num_jobs=(num_jobs_initial+num_jobs_final)/2. + +num_archives_to_process=$[$num_epochs*$num_archives_expanded] +num_archives_processed=0 +num_iters=$[($num_archives_to_process*2)/($num_jobs_initial+$num_jobs_final)] + +finish_add_layers_iter=$[$num_hidden_layers * $add_layers_period] + +! 
[ $num_iters -gt $[$finish_add_layers_iter+2] ] \ + && echo "$0: Insufficient epochs" && exit 1 + +echo "$0: Will train for $num_epochs epochs = $num_iters iterations" + +if $use_gpu; then + parallel_suffix="" + train_queue_opt="--gpu 1" + combine_queue_opt="--gpu 1" + prior_gpu_opt="--use-gpu=yes" + prior_queue_opt="--gpu 1" + parallel_train_opts= + if ! cuda-compiled; then + echo "$0: WARNING: you are running with one thread but you have not compiled" + echo " for CUDA. You may be running a setup optimized for GPUs. If you have" + echo " GPUs and have nvcc installed, go to src/ and do ./configure; make" + exit 1 + fi +else + echo "$0: without using a GPU this will be very slow. nnet3 does not yet support multiple threads." + parallel_train_opts="--use-gpu=no" + combine_queue_opt="" # the combine stage will be quite slow if not using + # GPU, as we didn't enable that program to use + # multiple threads. + prior_gpu_opt="--use-gpu=no" + prior_queue_opt="" +fi + + +approx_iters_per_epoch_final=$[$num_archives_expanded/$num_jobs_final] +# First work out how many iterations we want to combine over in the final +# nnet3-combine-fast invocation. (We may end up subsampling from these if the +# number exceeds max_model_combine). 
The number we use is: +# min(max(max_models_combine, approx_iters_per_epoch_final), +# 1/2 * iters_after_last_layer_added) +num_iters_combine=$max_models_combine +if [ $num_iters_combine -lt $approx_iters_per_epoch_final ]; then + num_iters_combine=$approx_iters_per_epoch_final +fi +half_iters_after_add_layers=$[($num_iters-$finish_add_layers_iter)/2] +if [ $num_iters_combine -gt $half_iters_after_add_layers ]; then + num_iters_combine=$half_iters_after_add_layers +fi +first_model_combine=$[$num_iters-$num_iters_combine+1] + +x=0 + +cur_egs_dir=$egs_dir + +compute_accuracy=false +if [ "$objective_type" == "linear" ]; then + compute_accuracy=true +fi + +echo $feat_type > $dir/feat_type + +while [ $x -lt $num_iters ]; do + [ $x -eq $exit_stage ] && echo "$0: Exiting early due to --exit-stage $exit_stage" && exit 0; + + this_num_jobs=$(perl -e "print int(0.5+$num_jobs_initial+($num_jobs_final-$num_jobs_initial)*$x/$num_iters);") + + ilr=$initial_effective_lrate; flr=$final_effective_lrate; np=$num_archives_processed; nt=$num_archives_to_process; + this_learning_rate=$(perl -e "print (($x + 1 >= $num_iters ? $flr : $ilr*exp($np*log($flr/$ilr)/$nt))*$this_num_jobs);"); + + echo "On iteration $x, learning rate is $this_learning_rate." + + if [ $x -ge 0 ] && [ $stage -le $x ]; then + # Set off jobs doing some diagnostics, in the background. 
+ # Use the egs dir from the previous iteration for the diagnostics + $cmd $dir/log/compute_prob_valid.$x.log \ + nnet3-compute-prob --compute-accuracy=$compute_accuracy $dir/$x.raw \ + "ark:nnet3-merge-egs ark:$cur_egs_dir/valid_diagnostic.egs ark:- |" & + $cmd $dir/log/compute_prob_train.$x.log \ + nnet3-compute-prob --compute-accuracy=$compute_accuracy $dir/$x.raw \ + "ark:nnet3-merge-egs ark:$cur_egs_dir/train_diagnostic.egs ark:- |" & + + if [ $x -gt 0 ]; then + $cmd $dir/log/progress.$x.log \ + nnet3-show-progress --use-gpu=no "nnet3-copy $dir/$[$x-1].raw - |" "nnet3-copy $dir/$x.raw - |" \ + "ark:nnet3-merge-egs ark:$cur_egs_dir/train_diagnostic.egs ark:-|" '&&' \ + nnet3-info "nnet3-copy $dir/$x.raw - |" & + fi + + echo "Training neural net (pass $x)" + + if [ $x -gt 0 ] && \ + [ $x -le $[($num_hidden_layers-1)*$add_layers_period] ] && \ + [ $[$x%$add_layers_period] -eq 0 ]; then + do_average=false # if we've just mixed up, don't do averaging but take the + # best. + cur_num_hidden_layers=$[1+$x/$add_layers_period] + config=$dir/configs/layer$cur_num_hidden_layers.config + raw="nnet3-copy --learning-rate=$this_learning_rate $dir/$x.raw - | nnet3-init --srand=$x - $config - |" + else + do_average=true + if [ $x -eq 0 ]; then do_average=false; fi # on iteration 0, pick the best, don't average. + raw="nnet3-copy --learning-rate=$this_learning_rate $dir/$x.raw -|" + fi + if $do_average; then + this_minibatch_size=$minibatch_size + else + # on iteration zero or when we just added a layer, use a smaller minibatch + # size (and we will later choose the output of just one of the jobs): the + # model-averaging isn't always helpful when the model is changing too fast + # (i.e. it can worsen the objective function), and the smaller minibatch + # size will help to keep the update stable. 
+ this_minibatch_size=$[$minibatch_size/2]; + fi + + rm $dir/.error 2>/dev/null + + + ( # this sub-shell is so that when we "wait" below, + # we only wait for the training jobs that we just spawned, + # not the diagnostic jobs that we spawned above. + + # We can't easily use a single parallel SGE job to do the main training, + # because the computation of which archive and which --frame option + # to use for each job is a little complex, so we spawn each one separately. + for n in $(seq $this_num_jobs); do + k=$[$num_archives_processed + $n - 1]; # k is a zero-based index that we'll derive + # the other indexes from. + archive=$[($k%$num_archives)+1]; # work out the 1-based archive index. + frame=$[(($k/$num_archives)%$frames_per_eg)]; # work out the 0-based frame + # index; this increases more slowly than the archive index because the + # same archive with different frame indexes will give similar gradients, + # so we want to separate them in time. + + $cmd $train_queue_opt $dir/log/train.$x.$n.log \ + nnet3-train $parallel_train_opts --max-param-change=$max_param_change "$raw" \ + "ark:nnet3-copy-egs --frame=$frame $context_opts ark:$cur_egs_dir/egs$egs_suffix.$archive.ark ark:- | nnet3-shuffle-egs --buffer-size=$shuffle_buffer_size --srand=$x ark:- ark:-| nnet3-merge-egs --minibatch-size=$this_minibatch_size ark:- ark:- |" \ + $dir/$[$x+1].$n.raw || touch $dir/.error & + done + wait + ) + # the error message below is not that informative, but $cmd will + # have printed a more specific one. + [ -f $dir/.error ] && echo "$0: error on iteration $x of training" && exit 1; + + nnets_list= + for n in `seq 1 $this_num_jobs`; do + nnets_list="$nnets_list $dir/$[$x+1].$n.raw" + done + + if $do_average; then + # average the output of the different jobs. + $cmd $dir/log/average.$x.log \ + nnet3-average $nnets_list $dir/$[$x+1].raw || exit 1; + else + # choose the best from the different jobs. 
+ n=$(perl -e '($nj,$pat)=@ARGV; $best_n=1; $best_logprob=-1.0e+10; for ($n=1;$n<=$nj;$n++) { + $fn = sprintf($pat,$n); open(F, "<$fn") || die "Error opening log file $fn"; + undef $logprob; while () { if (m/log-prob-per-frame=(\S+)/) { $logprob=$1; } } + close(F); if (defined $logprob && $logprob > $best_logprob) { $best_logprob=$logprob; + $best_n=$n; } } print "$best_n\n"; ' $this_num_jobs $dir/log/train.$x.%d.log) || exit 1; + [ -z "$n" ] && echo "Error getting best model" && exit 1; + $cmd $dir/log/select.$x.log \ + nnet3-copy $dir/$[$x+1].$n.raw $dir/$[$x+1].raw || exit 1; + fi + + rm $nnets_list + [ ! -f $dir/$[$x+1].raw ] && exit 1; + if [ -f $dir/$[$x-1].raw ] && $cleanup && \ + [ $[($x-1)%$keep_model_iter] -ne 0 ] && [ $[$x-1] -lt $first_model_combine ]; then + rm $dir/$[$x-1].raw + fi + fi + x=$[$x+1] + num_archives_processed=$[$num_archives_processed+$this_num_jobs] +done + + +if [ $stage -le $num_iters ]; then + echo "Doing final combination to produce final.raw" + + # Now do combination. In the nnet3 setup, the logic + # for doing averaging of subsets of the models in the case where + # there are too many models to reliably esetimate interpolation + # factors (max_models_combine) is moved into the nnet3-combine + nnets_list=() + for n in $(seq 0 $[num_iters_combine-1]); do + iter=$[$first_model_combine+$n] + nnet=$dir/$iter.raw + [ ! -f $nnet ] && echo "Expected $nnet to exist" && exit 1; + nnets_list[$n]=$nnet + done + + # Below, we use --use-gpu=no to disable nnet3-combine-fast from using a GPU, + # as if there are many models it can give out-of-memory error; and we set + # num-threads to 8 to speed it up (this isn't ideal...) 
+ + $cmd $combine_queue_opt $dir/log/combine.log \ + nnet3-combine --num-iters=40 \ + --enforce-sum-to-one=true --enforce-positive-weights=true \ + --verbose=3 "${nnets_list[@]}" "ark:nnet3-merge-egs --minibatch-size=1024 ark:$cur_egs_dir/combine.egs ark:-|" \ + $dir/final.raw || exit 1; + + # Compute the probability of the final, combined model with + # the same subset we used for the previous compute_probs, as the + # different subsets will lead to different probs. + $cmd $dir/log/compute_prob_valid.final.log \ + nnet3-compute-prob --compute-accuracy=$compute_accuracy $dir/final.raw \ + "ark:nnet3-merge-egs ark:$cur_egs_dir/valid_diagnostic.egs ark:- |" & + $cmd $dir/log/compute_prob_train.final.log \ + nnet3-compute-prob --compute-accuracy=$compute_accuracy $dir/final.raw \ + "ark:nnet3-merge-egs ark:$cur_egs_dir/train_diagnostic.egs ark:- |" & +fi + +sleep 2 + +if $include_log_softmax && [ $stage -le $[$num_iters+1] ]; then + echo "Getting average posterior for purposes of adjusting the priors." + # Note: this just uses CPUs, using a smallish subset of data. + if [ $num_jobs_compute_prior -gt $num_archives ]; then egs_part=1; + else egs_part=JOB; fi + rm $dir/post.$x.*.vec 2>/dev/null + $cmd JOB=1:$num_jobs_compute_prior $prior_queue_opt $dir/log/get_post.$x.JOB.log \ + nnet3-copy-egs --frame=random $context_opts --srand=JOB ark:$cur_egs_dir/egs.$egs_part.ark ark:- \| \ + nnet3-subset-egs --srand=JOB --n=$prior_subset_size ark:- ark:- \| \ + nnet3-merge-egs ark:- ark:- \| \ + nnet3-compute-from-egs $prior_gpu_opt --apply-exp=true \ + $dir/final.raw ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| vector-sum ark:- $dir/post.$x.JOB.vec || exit 1; + + sleep 3; # make sure there is time for $dir/post.$x.*.vec to appear. + + $cmd $dir/log/vector_sum.$x.log \ + vector-sum $dir/post.$x.*.vec $dir/post.$x.vec || exit 1; + + rm -f $dir/post.vec + ln -s post.$x.vec $dir/post.vec + + rm -f $dir/post.$x.*.vec; + + echo "Computed average posterior vector" +fi + +if [ ! 
-f $dir/final.raw ]; then + echo "$0: $dir/final.raw does not exist." + # we don't want to clean up if the training didn't succeed. + exit 1; +fi + +sleep 2 + +echo Done + +if $cleanup; then + echo Cleaning up data + if $remove_egs && [[ $cur_egs_dir =~ $dir/egs* ]]; then + steps/nnet2/remove_egs.sh $cur_egs_dir + fi + + echo Removing most of the models + for x in `seq 0 $num_iters`; do + if [ $[$x%$keep_model_iter] -ne 0 ] && [ $x -ne $num_iters ] && [ -f $dir/$x.raw ]; then + # delete all but every 100th model; don't delete the ones which combine to form the final model. + rm $dir/$x.raw + fi + done +fi diff --git a/egs/wsj/s5/steps/online/nnet2/extract_ivectors.sh b/egs/wsj/s5/steps/online/nnet2/extract_ivectors.sh index f27baecd673..fd2a2e9a446 100755 --- a/egs/wsj/s5/steps/online/nnet2/extract_ivectors.sh +++ b/egs/wsj/s5/steps/online/nnet2/extract_ivectors.sh @@ -144,7 +144,7 @@ if [ ! -z "$ali_or_decode_dir" ]; then rm $dir/weights.*.gz || exit 1; fi elif [ -f $ali_or_decode_dir ] && gunzip -c $ali_or_decode_dir >/dev/null; then - cp $ali_or_decode_dir $dir/weights.gz || exit 1; + cp -f $ali_or_decode_dir $dir/weights.gz || exit 1; else echo "$0: expected ali.1.gz or lat.1.gz to exist in $ali_or_decode_dir"; exit 1; @@ -172,8 +172,8 @@ if [ $sub_speaker_frames -gt 0 ]; then feat-to-len scp:$data/feats.scp ark,t:- > $dir/utt_counts || exit 1; fi if ! [ $(wc -l <$dir/utt_counts) -eq $(wc -l <$data/feats.scp) ]; then - echo "$0: error getting per-utterance counts." - exit 0; + echo "$0: error getting per-utterance counts. Number of lines in $dir/utt_counts differs from $data/feats.scp" + exit 1; fi cat $data/spk2utt | python -c " import sys @@ -229,8 +229,8 @@ if [ $stage -le 2 ]; then if [ ! 
-z "$ali_or_decode_dir" ]; then $cmd JOB=1:$nj $dir/log/extract_ivectors.JOB.log \ gmm-global-get-post --n=$num_gselect --min-post=$min_post $srcdir/final.dubm "$gmm_feats" ark:- \| \ - weight-post ark:- "ark,s,cs:gunzip -c $dir/weights.gz|" ark:- \| \ - ivector-extract --acoustic-weight=$posterior_scale --compute-objf-change=true \ + weight-post --length-tolerance=1 ark:- "ark,s,cs:gunzip -c $dir/weights.gz|" ark:- \| \ + ivector-extract --length-tolerance=1 --acoustic-weight=$posterior_scale --compute-objf-change=true \ --max-count=$max_count --spk2utt=ark:$this_sdata/JOB/spk2utt \ $srcdir/final.ie "$feats" ark,s,cs:- ark,t:$dir/ivectors_spk.JOB.ark || exit 1; else diff --git a/egs/wsj/s5/steps/online/nnet2/extract_ivectors_for_recording.sh b/egs/wsj/s5/steps/online/nnet2/extract_ivectors_for_recording.sh new file mode 100755 index 00000000000..8626cb25dd1 --- /dev/null +++ b/egs/wsj/s5/steps/online/nnet2/extract_ivectors_for_recording.sh @@ -0,0 +1,197 @@ +#!/bin/bash + +# Copyright 2013 Daniel Povey +# 2015 Vimal Manohar +# Apache 2.0. + +# This script is similar to extract_ivectors.sh but takes into account of +# the fact that the recording is diarized into different speakers. +# i-vectors are extracted per recording instead of per utterance so that the +# same i-vectors can be used with different segments. + +# i-vectors are not really computed online, they are first computed +# per speaker and used for different parts of recording corresponding to that +# speaker. +# This is mainly intended for use in decoding, where you want the best possible +# quality of iVectors. +# +# This setup also makes it possible to use a previous decoding or alignment, to +# down-weight silence in the stats (default is --silence-weight 0.0). +# +# This is for when you use the "online-decoding" setup in an offline task, and +# you want the best possible results. + + +# Begin configuration section. 
+nj=30 +cmd="run.pl" +stage=0 +num_gselect=5 # Gaussian-selection using diagonal model: number of Gaussians to select +min_post=0.025 # Minimum posterior to use (posteriors below this are pruned out) +posterior_scale=0.1 # Scale on the acoustic posteriors, intended to account for + # inter-frame correlations. Making this small during iVector + # extraction is equivalent to scaling up the prior, and will + # will tend to produce smaller iVectors where data-counts are + # small. It's not so important that this match the value + # used when training the iVector extractor, but more important + # that this match the value used when you do real online decoding + # with the neural nets trained with these iVectors. +sub_speaker_frames=0 +max_count=100 # Interpret this as a number of frames times posterior scale... + # this config ensures that once the count exceeds this (i.e. + # 1000 frames, or 10 seconds, by default), we start to scale + # down the stats, accentuating the prior term. This seems quite + # important for some reason. + +compress=true # If true, compress the iVectors stored on disk (it's lossy + # compression, as used for feature matrices). +silence_weight=0.0 +acwt=0.1 # used if input is a decode dir, to get best path from lattices. +mdl=final # change this if decode directory did not have ../final.mdl present. + +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + + +if [ $# != 4 ] && [ $# != 5 ]; then + echo "Usage: $0 [options] [||] " + echo " e.g.: $0 data/test exp/nnet2_online/extractor exp/tri3/decode_test exp/nnet2_online/ivectors_test" + echo "main options (for others, see top of script file)" + echo " --config # config containing options" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --nj # Number of jobs (also see num-processes and num-threads)" + echo " # Ignored if or supplied." 
+ echo " --stage # To control partial reruns" + echo " --num-gselect # Number of Gaussians to select using" + echo " # diagonal model." + echo " --min-post # Pruning threshold for posteriors" + echo " --posterior-scale # Scale on posteriors in iVector extraction; " + echo " # affects strength of prior term." + + exit 1; +fi + +if [ $# -eq 4 ]; then + data=$1 + lang=$2 + srcdir=$3 + dir=$4 +else # 5 arguments + data=$1 + lang=$2 + srcdir=$3 + ali_or_decode_dir=$4 + dir=$5 +fi + +for f in $data/feats.scp $srcdir/final.ie $srcdir/final.dubm $srcdir/global_cmvn.stats $srcdir/splice_opts \ + $lang/phones.txt $srcdir/online_cmvn.conf $srcdir/final.mat; do + [ ! -f $f ] && echo "$0: No such file $f" && exit 1; +done + +mkdir -p $dir/log +silphonelist=$(cat $lang/phones/silence.csl) || exit 1; + +# Get weights for down-weighting silence frames +if [ ! -z "$ali_or_decode_dir" ]; then + if [ -f $ali_or_decode_dir/ali.1.gz ]; then + if [ ! -f $ali_or_decode_dir/${mdl}.mdl ]; then + echo "$0: expected $ali_or_decode_dir/${mdl}.mdl to exist." + exit 1; + fi + nj_orig=$(cat $ali_or_decode_dir/num_jobs) || exit 1; + + if [ $stage -le 0 ]; then + rm $dir/weights.*.gz 2>/dev/null + + $cmd JOB=1:$nj_orig $dir/log/ali_to_post.JOB.log \ + gunzip -c $ali_or_decode_dir/ali.JOB.gz \| \ + ali-to-post ark:- ark:- \| \ + weight-silence-post $silence_weight $silphonelist $ali_or_decode_dir/final.mdl ark:- ark:- \| \ + post-to-weights ark:- "ark:|gzip -c >$dir/weights.JOB.gz" || exit 1; + + # put all the weights in one archive. + for j in $(seq $nj_orig); do gunzip -c $dir/weights.$j.gz; done | gzip -c >$dir/weights.gz || exit 1; + rm $dir/weights.*.gz || exit 1; + fi + + elif [ -f $ali_or_decode_dir/lat.1.gz ]; then + nj_orig=$(cat $ali_or_decode_dir/num_jobs) || exit 1; + if [ ! -f $ali_or_decode_dir/../${mdl}.mdl ]; then + echo "$0: expected $ali_or_decode_dir/../${mdl}.mdl to exist." 
+ exit 1; + fi + + + if [ $stage -le 0 ]; then + rm $dir/weights.*.gz 2>/dev/null + + $cmd JOB=1:$nj_orig $dir/log/lat_to_post.JOB.log \ + lattice-best-path --acoustic-scale=$acwt "ark:gunzip -c $ali_or_decode_dir/lat.JOB.gz|" ark:/dev/null ark:- \| \ + ali-to-post ark:- ark:- \| \ + weight-silence-post $silence_weight $silphonelist $ali_or_decode_dir/../${mdl}.mdl ark:- ark:- \| \ + post-to-weights ark:- "ark:|gzip -c >$dir/weights.JOB.gz" || exit 1; + + # put all the weights in one archive. + for j in $(seq $nj_orig); do gunzip -c $dir/weights.$j.gz; done | gzip -c >$dir/weights.gz || exit 1; + rm $dir/weights.*.gz || exit 1; + fi + elif [ -f $ali_or_decode_dir ] && gunzip -c $ali_or_decode_dir >/dev/null; then + cp -f $ali_or_decode_dir $dir/weights.gz || exit 1; + else + echo "$0: expected ali.1.gz or lat.1.gz to exist in $ali_or_decode_dir"; + exit 1; + fi +fi + +echo $nj > $dir/num_jobs + +sdata=$data/split$nj; +utils/split_data.sh --per-reco $data $nj || exit 1; + +splice_opts=$(cat $srcdir/splice_opts) + +gmm_feats="ark,s,cs:apply-cmvn-online --spk2utt=ark:$sdata/JOB/spk2utt --config=$srcdir/online_cmvn.conf $srcdir/global_cmvn.stats scp:$sdata/JOB/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $srcdir/final.mat ark:- ark:- |" +feats="ark,s,cs:splice-feats $splice_opts scp:$sdata/JOB/feats.scp ark:- | transform-feats $srcdir/final.mat ark:- ark:- |" + +this_sdata=$sdata + +if [ $stage -le 2 ]; then + if [ ! 
-z "$ali_or_decode_dir" ]; then + $cmd JOB=1:$nj $dir/log/extract_ivectors.JOB.log \ + gmm-global-get-post --n=$num_gselect --min-post=$min_post $srcdir/final.dubm "$gmm_feats" ark:- \| \ + weight-post ark:- "ark,s,cs:gunzip -c $dir/weights.gz|" ark:- \| \ + ivector-extract --acoustic-weight=$posterior_scale --compute-objf-change=true \ + --max-count=$max_count --spk2utt=ark:$this_sdata/JOB/spk2utt \ + $srcdir/final.ie "$feats" ark,s,cs:- ark,t:$dir/ivectors_spk.JOB.ark || exit 1; + else + $cmd JOB=1:$nj $dir/log/extract_ivectors.JOB.log \ + gmm-global-get-post --n=$num_gselect --min-post=$min_post $srcdir/final.dubm "$gmm_feats" ark:- \| \ + ivector-extract --acoustic-weight=$posterior_scale --compute-objf-change=true \ + --max-count=$max_count --spk2utt=ark:$this_sdata/JOB/spk2utt \ + $srcdir/final.ie "$feats" ark,s,cs:- ark,t:$dir/ivectors_spk.JOB.ark || exit 1; + fi +fi + +# get an utterance-level set of iVectors (just duplicate the speaker-level ones). +# note: if $this_sdata is set $dir/split$nj, then these won't be real speakers, they'll +# be "sub-speakers" (speakers split up into multiple utterances). 
+if [ $stage -le 3 ]; then + for j in $(seq $nj); do + utils/apply_map.pl -f 2 $dir/ivectors_spk.$j.ark <$this_sdata/$j/utt2spk >$dir/ivectors_utt.$j.ark || exit 1; + cut -d ' ' -f 1-2 $this_sdata/$j/segments | utils/utt2spk_to_spk2utt.pl > $this_sdata/$j/reco2utt || exit 1 + done +fi + +$cmd JOB=1:$nj $dir/log/combine_ivectors_for_reco.JOB.log \ + ivector-combine-to-recording ark:$this_sdata/JOB/reco2utt \ + ark:$this_sdata/JOB/segments \ + ark,t:$dir/ivectors_utt.JOB.ark ark:$dir/reco_segmentation.JOB.ark \ + ark:$dir/ivectors_reco.JOB.ark + +echo "$0: done extracting (pseudo-online) iVectors for recordings" + diff --git a/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh b/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh index d8ac11da720..341789e7787 100755 --- a/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh +++ b/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh @@ -41,6 +41,7 @@ max_count=0 # The use of this option (e.g. --max-count 100) can make # posterior-scaling, so assuming the posterior-scale is 0.1, # --max-count 100 starts having effect after 1000 frames, or # 10 seconds of data. +max_remembered_frames=1000 # End configuration section. 
@@ -102,9 +103,9 @@ echo "--ivector-extractor=$srcdir/final.ie" >>$ieconf echo "--num-gselect=$num_gselect" >>$ieconf echo "--min-post=$min_post" >>$ieconf echo "--posterior-scale=$posterior_scale" >>$ieconf -echo "--max-remembered-frames=1000" >>$ieconf # the default +echo "--max-remembered-frames=$max_remembered_frames" >>$ieconf # the default echo "--max-count=$max_count" >>$ieconf - +echo "--ivector-period=$ivector_period" >> $ieconf absdir=$(readlink -f $dir) diff --git a/egs/wsj/s5/steps/online/nnet2/segment_recording_ivectors.sh b/egs/wsj/s5/steps/online/nnet2/segment_recording_ivectors.sh new file mode 100755 index 00000000000..ce84b43b332 --- /dev/null +++ b/egs/wsj/s5/steps/online/nnet2/segment_recording_ivectors.sh @@ -0,0 +1,94 @@ +#!/bin/bash + +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +# This script creates segment-level ivectors from recording-level ivectors. + +# Begin configuration section. +cmd="run.pl" +stage=-10 +ivector_period=10 +compress=true # If true, compress the iVectors stored on disk (it's lossy + # compression, as used for feature matrices). +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: $0 [options] " + echo " e.g.: $0 data/test exp/nnet2_online/ivectors_test_reco exp/nnet2_online/ivectors_test" + echo "main options (for others, see top of script file)" + echo " --config # config containing options" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --ivector-period # How often to extract an iVector (frames)" + exit 1; +fi + +data=$1 +reco_ivector_dir=$2 +dir=$3 + +for f in $data/feats.scp $reco_ivector_dir/ivectors_reco.1.ark $reco_ivector_dir/reco_segmentation.1.ark; do + [ ! 
-f $f ] && echo "$0: No such file $f" && exit 1; +done + +mkdir -p $dir/log + +echo $ivector_period > $dir/ivector_period || exit 1; + +nj=$(cat $reco_ivector_dir/num_jobs) || exit 1 + +# This will probably work fine because both have the same number of recordings; +# but needs to be checked. Otherwise old data dir must also be input. +utils/split_data.sh --per-reco $data $nj +sdata=$data/split$nj + +for n in `seq $nj`; do + awk '{print $1" "$2}' $sdata/$n/segments | \ + utils/utt2spk_to_spk2utt.pl > $sdata/$n/reco2utt +done + +ivector_dim=$[$(head -n 1 $reco_ivector_dir/ivectors_spk.1.ark | wc -w) - 3] || exit 1; +echo "$0: iVector dim is $ivector_dim" + +base_feat_dim=$(feat-to-dim scp:$data/feats.scp -) || exit 1 + +start_dim=$base_feat_dim +end_dim=$[$base_feat_dim+$ivector_dim-1] + +if [ $stage -le 0 ]; then + if [ -f $data/segments ]; then + $cmd JOB=1:$nj $dir/log/get_ivectors_utt.JOB.log \ + ivector-split-to-segments --offset-frames=0 \ + ark:$reco_ivector_dir/ivectors_reco.JOB.ark \ + ark:$reco_ivector_dir/reco_segmentation.JOB.ark \ + ark:$sdata/JOB/reco2utt ark:$sdata/JOB/segments ark:- \| append-feats \ + --truncate-frames scp:$sdata/JOB/feats.scp ark:- ark:- \| \ + select-feats "$start_dim-$end_dim" ark:- ark:- \| \ + subsample-feats --n=$ivector_period ark:- ark:- \| \ + copy-feats --compress=$compress ark:- \ + ark,scp:$dir/ivector_online.JOB.ark,$dir/ivector_online.JOB.scp || exit 1; + else + $cmd JOB=1:$nj $dir/log/get_ivectors_utt.JOB.log \ + ivector-split-to-segments --offset-frames=0 \ + ark:$reco_ivector_dir/ivectors_reco.JOB.ark \ + ark:$reco_ivector_dir/reco_segmentation.JOB.ark \ + ark:- \| append-feats \ + --truncate-frames scp:$sdata/JOB/feats.scp ark:- ark:- \| \ + select-feats "$start_dim-$end_dim" ark:- ark:- \| \ + subsample-feats --n=$ivector_period ark:- ark:- \| \ + copy-feats --compress=$compress ark:- \ + ark,scp:$dir/ivector_online.JOB.ark,$dir/ivector_online.JOB.scp || exit 1; + fi +fi + +if [ $stage -le 1 ]; then + echo "$0: 
combining iVectors across jobs" + for j in $(seq $nj); do cat $dir/ivector_online.$j.scp; done >$dir/ivector_online.scp || exit 1; +fi + +echo "$0: done extracting (pseudo-online) iVectors" diff --git a/egs/wsj/s5/utils/copy_data_dir.sh b/egs/wsj/s5/utils/copy_data_dir.sh index bb4d4e77e7c..3fa97caf96e 100755 --- a/egs/wsj/s5/utils/copy_data_dir.sh +++ b/egs/wsj/s5/utils/copy_data_dir.sh @@ -22,6 +22,7 @@ utt_prefix= spk_suffix= utt_suffix= validate_opts= # should rarely be needed. +extra_files= # specify addtional files in 'src-data-dir' to copy, ex. "file1 file2 ..." # end configuration section . utils/parse_options.sh @@ -78,9 +79,6 @@ fi if [ -f $srcdir/segments ]; then utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments cp $srcdir/wav.scp $destdir - if [ -f $srcdir/reco2file_and_channel ]; then - cp $srcdir/reco2file_and_channel $destdir/ - fi else # no segments->wav indexed by utt. if [ -f $srcdir/wav.scp ]; then utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp @@ -96,7 +94,12 @@ fi if [ -f $srcdir/cmvn.scp ]; then utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp fi -for f in stm glm ctm; do + +if [ -f $srcdir/utt2uniq ]; then + utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2uniq >$destdir/utt2uniq +fi + +for f in stm glm ctm $extra_files; do if [ -f $srcdir/$f ]; then cp $srcdir/$f $destdir fi diff --git a/egs/wsj/s5/utils/fix_data_dir.sh b/egs/wsj/s5/utils/fix_data_dir.sh index 4716925df7d..8319f686e7b 100755 --- a/egs/wsj/s5/utils/fix_data_dir.sh +++ b/egs/wsj/s5/utils/fix_data_dir.sh @@ -6,6 +6,11 @@ # It puts the original contents of data-dir into # data-dir/.backup +utt_extra_files= +spk_extra_files= + +. 
utils/parse_options.sh + if [ $# != 1 ]; then echo "Usage: fix_data_dir.sh data-dir" exit 1 @@ -35,7 +40,7 @@ function check_sorted { fi } -for x in utt2spk spk2utt feats.scp text segments wav.scp cmvn.scp vad.scp reco2file_and_channel spk2gender utt2lang utt2uniq; do +for x in utt2spk spk2utt feats.scp text segments wav.scp cmvn.scp vad.scp reco2file_and_channel spk2gender utt2lang utt2uniq $utt_extra_files $spk_extra_files; do if [ -f $data/$x ]; then cp $data/$x $data/.backup/$x check_sorted $data/$x @@ -105,7 +110,7 @@ function filter_speakers { filter_file $tmpdir/speakers $data/spk2utt utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk - for s in cmvn.scp spk2gender; do + for s in cmvn.scp spk2gender $spk_extra_files; do f=$data/$s if [ -f $f ]; then filter_file $tmpdir/speakers $f @@ -155,7 +160,7 @@ function filter_utts { fi fi - for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang $maybe_wav; do + for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang $maybe_wav $utt_extra_files; do if [ -f $data/$x ]; then cp $data/$x $data/.backup/$x if ! cmp -s $data/$x <( utils/filter_scp.pl $tmpdir/utts $data/$x ) ; then diff --git a/egs/wsj/s5/utils/split_data.sh b/egs/wsj/s5/utils/split_data.sh index 941890cdd57..627d292f14f 100755 --- a/egs/wsj/s5/utils/split_data.sh +++ b/egs/wsj/s5/utils/split_data.sh @@ -16,16 +16,25 @@ # limitations under the License. split_per_spk=true +split_per_reco=false + if [ "$1" == "--per-utt" ]; then split_per_spk=false shift +elif [ "$1" == "--per-reco" ]; then + split_per_reco=true + split_per_spk=false + shift fi if [ $# != 2 ]; then - echo "Usage: split_data.sh [--per-utt] " + echo "Usage: split_data.sh [--per-utt|--per-reco] " echo "This script will not split the data-dir if it detects that the output is newer than the input." 
echo "By default it splits per speaker (so each speaker is in only one split dir)," echo "but with the --per-utt option it will ignore the speaker information while splitting." + echo "If --per-reco is specified instead, then the split will be such that " + echo "each recording is in only one split. This is useful when diarization " + echo "is done." exit 1 fi @@ -77,10 +86,14 @@ for n in `seq $numsplit`; do utt2spks="$utt2spks $data/split$numsplit/$n/utt2spk" done +utt2spk_opt= if $split_per_spk; then utt2spk_opt="--utt2spk=$data/utt2spk" -else - utt2spk_opt= +elif $split_per_reco; then + if [ -f $data/segments ]; then + awk '{print $1" "$2}' $data/segments > $data/reco2utt + utt2spk_opt="--utt2spk=$data/reco2utt" + fi fi # If lockfile is not installed, just don't lock it. It's not a big deal. diff --git a/egs/wsj/s5/utils/subset_data_dir.sh b/egs/wsj/s5/utils/subset_data_dir.sh index be74ac8c177..7f2e28800ca 100755 --- a/egs/wsj/s5/utils/subset_data_dir.sh +++ b/egs/wsj/s5/utils/subset_data_dir.sh @@ -159,7 +159,7 @@ elif $perspk; then do_filtering; # bash function. exit 0; else - if [ $numutt -gt `cat $srcdir/feats.scp | wc -l` ]; then + if [ $numutt -gt `cat $srcdir/utt2spk | wc -l` ]; then echo "subset_data_dir.sh: cannot subset to more utterances than you originally had." exit 1; fi diff --git a/egs/wsj_noisy/s5/RESULTS b/egs/wsj_noisy/s5/RESULTS new file mode 100644 index 00000000000..d15374ed6a4 --- /dev/null +++ b/egs/wsj_noisy/s5/RESULTS @@ -0,0 +1,282 @@ +#!/bin/bash + +# this RESULTS file was obtained by Haihua Xu in July 2013. + +for x in exp/*/decode*; do [ -d $x ] && grep WER $x/wer_* | utils/best_wer.sh; done +exit 0 + +# Use caution when comparing these results with other published results. +# We use the "20k open" test condition, also known as the "60k vocabulary" +# test condition, in which test utterances are not excluded even if they +# contain words not in the language model. 
This is the hardest test condition, +# and most published results are in the easier 5k and 20k-closed conditions, +# in which we only test on utterances that are in either a 5k or 20k subset +# of the vocabulary. + +# The following results are updated with LDA+MLLT to use 7, not 9 frames of context, +# and also increased the learning rate for the "indirect" fMMI. + +# monophone, deltas, trained on the 2k shortest utterances from the si84 data. +%WER 35.39 [ 2914 / 8234, 284 ins, 467 del, 2163 sub ] exp/mono0a/decode_tgpr_dev93/wer_10 +%WER 25.78 [ 1455 / 5643, 142 ins, 184 del, 1129 sub ] exp/mono0a/decode_tgpr_eval92/wer_9 + +# first triphone build. Built on half of SI-84. +%WER 20.00 [ 1647 / 8234, 257 ins, 197 del, 1193 sub ] exp/tri1/decode_tgpr_dev93/wer_17 +%WER 13.04 [ 736 / 5643, 137 ins, 61 del, 538 sub ] exp/tri1/decode_tgpr_eval92/wer_14 + +# the same, rescored with full trigram model [not pruned.] Note: the tg{1,2,3,4} are +# different rescoring methods. They all give about the same results. Note: 3 and 4 give +# the "correct" LM scores. +%WER 18.87 [ 1554 / 8234, 295 ins, 136 del, 1123 sub ] exp/tri1/decode_tgpr_dev93_tg1/wer_14 +%WER 18.87 [ 1554 / 8234, 295 ins, 136 del, 1123 sub ] exp/tri1/decode_tgpr_dev93_tg2/wer_14 +%WER 18.75 [ 1544 / 8234, 266 ins, 152 del, 1126 sub ] exp/tri1/decode_tgpr_dev93_tg3/wer_15 +%WER 18.76 [ 1545 / 8234, 266 ins, 152 del, 1127 sub ] exp/tri1/decode_tgpr_dev93_tg4/wer_15 + +# tri2a is delta+delta-delta features. +%WER 17.93 [ 1476 / 8234, 256 ins, 161 del, 1059 sub ] exp/tri2a/decode_tgpr_dev93/wer_16 +%WER 12.42 [ 701 / 5643, 132 ins, 64 del, 505 sub ] exp/tri2a/decode_tgpr_eval92/wer_15 +# just demonstrates how to do decoding constrained by lattices. +%WER 16.76 [ 1380 / 8234, 275 ins, 132 del, 973 sub ] exp/tri2a/decode_tgpr_dev93_fromlats/wer_16 + +# This is an LDA+MLLT system. 
+%WER 16.43 [ 1353 / 8234, 241 ins, 162 del, 950 sub ] exp/tri2b/decode_tgpr_dev93/wer_16 +%WER 10.69 [ 603 / 5643, 154 ins, 47 del, 402 sub ] exp/tri2b/decode_tgpr_eval92/wer_14 + +# rescoring the lattices with trigram. +%WER 15.29 [ 1252 / 8191, 219 ins, 153 del, 880 sub ] [PARTIAL] exp/tri2b/decode_tgpr_dev93_tg/wer_18 +# using the "biglm" decoding method to avoid the lattice rescoring step [not faster though.] +%WER 15.31 [ 1261 / 8234, 227 ins, 158 del, 876 sub ] exp/tri2b/decode_tgpr_dev93_tg_biglm/wer_18 +# using a Minimum Bayes Risk decoding method on top of the _tg lattices. +%WER 15.15 [ 1241 / 8191, 221 ins, 155 del, 865 sub ] [PARTIAL] exp/tri2b/decode_tgpr_dev93_tg_mbr/wer_18 + +# fMMI, default learning rate (0.001) + +%WER 15.19 [ 1251 / 8234, 213 ins, 148 del, 890 sub ] exp/tri2b_fmmi_b0.1/decode_tgpr_dev93_it3/wer_15 +%WER 15.14 [ 1247 / 8234, 228 ins, 138 del, 881 sub ] exp/tri2b_fmmi_b0.1/decode_tgpr_dev93_it4/wer_14 +%WER 15.06 [ 1240 / 8234, 211 ins, 152 del, 877 sub ] exp/tri2b_fmmi_b0.1/decode_tgpr_dev93_it5/wer_15 +%WER 15.01 [ 1236 / 8234, 206 ins, 154 del, 876 sub ] exp/tri2b_fmmi_b0.1/decode_tgpr_dev93_it6/wer_15 +%WER 14.99 [ 1234 / 8234, 210 ins, 159 del, 865 sub ] exp/tri2b_fmmi_b0.1/decode_tgpr_dev93_it7/wer_15 +%WER 15.23 [ 1254 / 8234, 200 ins, 184 del, 870 sub ] exp/tri2b_fmmi_b0.1/decode_tgpr_dev93_it8/wer_16 + +%WER 15.55 [ 1280 / 8234, 234 ins, 151 del, 895 sub ] exp/tri2b_fmmi_b0.1_lr0.005/decode_tgpr_dev93_it3/wer_15 +%WER 15.63 [ 1287 / 8234, 242 ins, 150 del, 895 sub ] exp/tri2b_fmmi_b0.1_lr0.005/decode_tgpr_dev93_it4/wer_15 +%WER 15.30 [ 1260 / 8234, 224 ins, 143 del, 893 sub ] exp/tri2b_fmmi_b0.1_lr0.005/decode_tgpr_dev93_it5/wer_15 +%WER 15.34 [ 1263 / 8234, 216 ins, 156 del, 891 sub ] exp/tri2b_fmmi_b0.1_lr0.005/decode_tgpr_dev93_it6/wer_16 +%WER 15.34 [ 1263 / 8234, 242 ins, 139 del, 882 sub ] exp/tri2b_fmmi_b0.1_lr0.005/decode_tgpr_dev93_it7/wer_14 +%WER 15.30 [ 1260 / 8234, 245 ins, 134 del, 881 sub ] 
exp/tri2b_fmmi_b0.1_lr0.005/decode_tgpr_dev93_it8/wer_13 + +%WER 15.21 [ 1252 / 8234, 218 ins, 148 del, 886 sub ] exp/tri2b_fmmi_indirect_b0.1/decode_tgpr_dev93_it3/wer_15 +%WER 15.16 [ 1248 / 8234, 205 ins, 159 del, 884 sub ] exp/tri2b_fmmi_indirect_b0.1/decode_tgpr_dev93_it4/wer_16 +%WER 15.22 [ 1253 / 8234, 229 ins, 147 del, 877 sub ] exp/tri2b_fmmi_indirect_b0.1/decode_tgpr_dev93_it5/wer_15 +%WER 14.90 [ 1227 / 8234, 203 ins, 150 del, 874 sub ] exp/tri2b_fmmi_indirect_b0.1/decode_tgpr_dev93_it6/wer_15 +%WER 14.95 [ 1231 / 8234, 202 ins, 152 del, 877 sub ] exp/tri2b_fmmi_indirect_b0.1/decode_tgpr_dev93_it7/wer_15 +%WER 15.18 [ 1250 / 8234, 184 ins, 172 del, 894 sub ] exp/tri2b_fmmi_indirect_b0.1/decode_tgpr_dev93_it8/wer_16 + +%WER 15.70 [ 1293 / 8234, 218 ins, 163 del, 912 sub ] exp/tri2b_mmi/decode_tgpr_dev93_it3/wer_16 +%WER 15.61 [ 1285 / 8234, 217 ins, 163 del, 905 sub ] exp/tri2b_mmi/decode_tgpr_dev93_it4/wer_16 +%WER 10.46 [ 590 / 5643, 125 ins, 51 del, 414 sub ] exp/tri2b_mmi/decode_tgpr_eval92_it3/wer_15 +%WER 10.40 [ 587 / 5643, 124 ins, 52 del, 411 sub ] exp/tri2b_mmi/decode_tgpr_eval92_it4/wer_16 + +%WER 15.56 [ 1281 / 8234, 224 ins, 152 del, 905 sub ] exp/tri2b_mmi_b0.1/decode_tgpr_dev93_it3/wer_15 +%WER 15.44 [ 1271 / 8234, 220 ins, 165 del, 886 sub ] exp/tri2b_mmi_b0.1/decode_tgpr_dev93_it4/wer_16 +%WER 10.33 [ 583 / 5643, 125 ins, 51 del, 407 sub ] exp/tri2b_mmi_b0.1/decode_tgpr_eval92_it3/wer_15 +%WER 10.33 [ 583 / 5643, 125 ins, 47 del, 411 sub ] exp/tri2b_mmi_b0.1/decode_tgpr_eval92_it4/wer_15 + +%WER 11.43 [ 941 / 8234, 113 ins, 144 del, 684 sub ] exp/tri3b/decode_bd_tgpr_dev93/wer_19 +%WER 16.09 [ 1325 / 8234, 193 ins, 185 del, 947 sub ] exp/tri3b/decode_bd_tgpr_dev93.si/wer_16 +%WER 6.79 [ 383 / 5643, 51 ins, 49 del, 283 sub ] exp/tri3b/decode_bd_tgpr_eval92/wer_18 +%WER 10.61 [ 599 / 5643, 91 ins, 74 del, 434 sub ] exp/tri3b/decode_bd_tgpr_eval92.si/wer_15 +%WER 5.74 [ 324 / 5643, 46 ins, 41 del, 237 sub ] 
exp/tri3b/decode_bd_tgpr_eval92_fg/wer_19 +%WER 5.90 [ 333 / 5643, 46 ins, 39 del, 248 sub ] exp/tri3b/decode_bd_tgpr_eval92_tg/wer_18 + +# this section demonstrates RNNLM-HS rescoring (commented out by default) +# the exact results might differ insignificantly due to hogwild in RNNLM-HS training that introduces indeterminism +%WER 5.92 [ 334 / 5643, 58 ins, 32 del, 244 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg/wer_14 # baseline (no rescoring) +%WER 5.26 [ 297 / 5643, 47 ins, 29 del, 221 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs100_0.3/wer_15 +%WER 5.17 [ 292 / 5643, 46 ins, 30 del, 216 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs300_0.3/wer_16 +%WER 5.64 [ 318 / 5643, 50 ins, 34 del, 234 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs30_0.15/wer_16 +%WER 5.55 [ 313 / 5643, 51 ins, 32 del, 230 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.15/wer_16 +%WER 5.55 [ 313 / 5643, 51 ins, 32 del, 230 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.15_N1000/wer_16 +%WER 5.39 [ 304 / 5643, 50 ins, 30 del, 224 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.3/wer_15 +%WER 5.42 [ 306 / 5643, 50 ins, 30 del, 226 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.3_N10/wer_15 +%WER 5.39 [ 304 / 5643, 50 ins, 30 del, 224 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.3_N1000/wer_15 +%WER 5.37 [ 303 / 5643, 49 ins, 29 del, 225 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.4/wer_14 +%WER 5.37 [ 303 / 5643, 49 ins, 29 del, 225 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.4_N1000/wer_14 +%WER 5.26 [ 297 / 5643, 45 ins, 32 del, 220 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.5_N1000/wer_15 +%WER 5.14 [ 290 / 5643, 43 ins, 32 del, 215 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.75_N1000/wer_18 + +%WER 14.17 [ 1167 / 8234, 222 ins, 123 del, 822 sub ] exp/tri3b/decode_tgpr_dev93/wer_17 +%WER 19.37 [ 1595 / 8234, 315 ins, 153 del, 1127 sub ] exp/tri3b/decode_tgpr_dev93.si/wer_15 + +%WER 12.98 [ 1069 
/ 8234, 209 ins, 116 del, 744 sub ] exp/tri3b/decode_tgpr_dev93_tg/wer_19 +%WER 9.30 [ 525 / 5643, 120 ins, 37 del, 368 sub ] exp/tri3b/decode_tgpr_eval92/wer_18 +%WER 12.95 [ 731 / 5643, 167 ins, 46 del, 518 sub ] exp/tri3b/decode_tgpr_eval92.si/wer_14 +%WER 8.54 [ 482 / 5643, 113 ins, 29 del, 340 sub ] exp/tri3b/decode_tgpr_eval92_tg/wer_17 + +%WER 12.12 [ 998 / 8234, 209 ins, 88 del, 701 sub ] exp/tri4a/decode_tgpr_dev93/wer_17 +%WER 15.98 [ 1316 / 8234, 275 ins, 119 del, 922 sub ] exp/tri4a/decode_tgpr_dev93.si/wer_15 +%WER 7.83 [ 442 / 5643, 107 ins, 23 del, 312 sub ] exp/tri4a/decode_tgpr_eval92/wer_16 +%WER 10.90 [ 615 / 5643, 148 ins, 30 del, 437 sub ] exp/tri4a/decode_tgpr_eval92.si/wer_13 + +%WER 9.15 [ 753 / 8234, 90 ins, 113 del, 550 sub ] exp/tri4b/decode_bd_pp_tgpr_dev93/wer_16 +%WER 12.64 [ 1041 / 8234, 137 ins, 145 del, 759 sub ] exp/tri4b/decode_bd_pp_tgpr_dev93.si/wer_16 +%WER 5.74 [ 324 / 5643, 47 ins, 35 del, 242 sub ] exp/tri4b/decode_bd_pp_tgpr_eval92/wer_19 +%WER 7.92 [ 447 / 5643, 64 ins, 46 del, 337 sub ] exp/tri4b/decode_bd_pp_tgpr_eval92.si/wer_15 +%WER 9.38 [ 772 / 8234, 90 ins, 118 del, 564 sub ] exp/tri4b/decode_bd_tgpr_dev93/wer_18 +%WER 13.07 [ 1076 / 8234, 148 ins, 143 del, 785 sub ] exp/tri4b/decode_bd_tgpr_dev93.si/wer_17 +%WER 6.03 [ 340 / 5643, 66 ins, 26 del, 248 sub ] exp/tri4b/decode_bd_tgpr_eval92/wer_13 +%WER 8.19 [ 462 / 5643, 74 ins, 42 del, 346 sub ] exp/tri4b/decode_bd_tgpr_eval92.si/wer_15 +%WER 12.16 [ 1001 / 8234, 197 ins, 98 del, 706 sub ] exp/tri4b/decode_tgpr_dev93/wer_17 +%WER 15.47 [ 1274 / 8234, 235 ins, 120 del, 919 sub ] exp/tri4b/decode_tgpr_dev93.si/wer_17 +%WER 8.08 [ 456 / 5643, 125 ins, 16 del, 315 sub ] exp/tri4b/decode_tgpr_eval92/wer_13 +%WER 10.49 [ 592 / 5643, 147 ins, 27 del, 418 sub ] exp/tri4b/decode_tgpr_eval92.si/wer_12 +%WER 7.99 [ 658 / 8234, 72 ins, 95 del, 491 sub ] exp/tri4b_fmmi_a/decode_bd_tgpr_dev93_it8/wer_12 +%WER 11.15 [ 918 / 8234, 180 ins, 81 del, 657 sub ] 
exp/tri4b_fmmi_a/decode_tgpr_dev93_it3/wer_15 +%WER 11.23 [ 925 / 8234, 201 ins, 77 del, 647 sub ] exp/tri4b_fmmi_a/decode_tgpr_dev93_it4/wer_12 +%WER 10.64 [ 876 / 8234, 180 ins, 80 del, 616 sub ] exp/tri4b_fmmi_a/decode_tgpr_dev93_it5/wer_13 +%WER 10.43 [ 859 / 8234, 174 ins, 76 del, 609 sub ] exp/tri4b_fmmi_a/decode_tgpr_dev93_it6/wer_12 +%WER 10.42 [ 858 / 8234, 178 ins, 70 del, 610 sub ] exp/tri4b_fmmi_a/decode_tgpr_dev93_it7/wer_11 +%WER 10.41 [ 857 / 8234, 179 ins, 66 del, 612 sub ] exp/tri4b_fmmi_a/decode_tgpr_dev93_it8/wer_10 +%WER 4.09 [ 231 / 5643, 40 ins, 12 del, 179 sub ] exp/tri4b_fmmi_a/decode_tgpr_eval92_it8/wer_11 +%WER 10.61 [ 874 / 8234, 188 ins, 75 del, 611 sub ] exp/tri4b_fmmi_indirect/decode_tgpr_dev93_it3/wer_13 +%WER 10.44 [ 860 / 8234, 183 ins, 79 del, 598 sub ] exp/tri4b_fmmi_indirect/decode_tgpr_dev93_it4/wer_13 +%WER 10.27 [ 846 / 8234, 180 ins, 73 del, 593 sub ] exp/tri4b_fmmi_indirect/decode_tgpr_dev93_it5/wer_12 +%WER 10.27 [ 846 / 8234, 174 ins, 72 del, 600 sub ] exp/tri4b_fmmi_indirect/decode_tgpr_dev93_it6/wer_12 +%WER 10.06 [ 828 / 8234, 162 ins, 80 del, 586 sub ] exp/tri4b_fmmi_indirect/decode_tgpr_dev93_it7/wer_13 +%WER 10.08 [ 830 / 8234, 158 ins, 84 del, 588 sub ] exp/tri4b_fmmi_indirect/decode_tgpr_dev93_it8/wer_13 +%WER 10.77 [ 887 / 8234, 194 ins, 72 del, 621 sub ] exp/tri4b_mmi_b0.1/decode_tgpr_dev93/wer_12 + +%WER 12.27 [ 1010 / 8234, 188 ins, 104 del, 718 sub ] exp/sgmm2_5a/decode_tgpr_dev93/wer_14 +%WER 11.87 [ 977 / 8234, 201 ins, 75 del, 701 sub ] exp/sgmm2_5a_mmi_b0.1/decode_tgpr_dev93_it1/wer_11 +%WER 11.84 [ 975 / 8234, 195 ins, 81 del, 699 sub ] exp/sgmm2_5a_mmi_b0.1/decode_tgpr_dev93_it2/wer_13 +%WER 11.67 [ 961 / 8234, 196 ins, 77 del, 688 sub ] exp/sgmm2_5a_mmi_b0.1/decode_tgpr_dev93_it3/wer_13 +%WER 11.78 [ 970 / 8234, 190 ins, 82 del, 698 sub ] exp/sgmm2_5a_mmi_b0.1/decode_tgpr_dev93_it4/wer_14 +%WER 11.87 [ 977 / 8234, 201 ins, 75 del, 701 sub ] exp/sgmm2_5a_mmi_b0.1_m0.9/decode_tgpr_dev93_it1/wer_11 +%WER 
11.85 [ 976 / 8234, 195 ins, 81 del, 700 sub ] exp/sgmm2_5a_mmi_b0.1_m0.9/decode_tgpr_dev93_it2/wer_13 +%WER 11.67 [ 961 / 8234, 196 ins, 77 del, 688 sub ] exp/sgmm2_5a_mmi_b0.1_m0.9/decode_tgpr_dev93_it3/wer_13 +%WER 11.78 [ 970 / 8234, 190 ins, 82 del, 698 sub ] exp/sgmm2_5a_mmi_b0.1_m0.9/decode_tgpr_dev93_it4/wer_14 + +%WER 8.23 [ 678 / 8234, 87 ins, 103 del, 488 sub ] exp/sgmm2_5b/decode_bd_tgpr_dev93/wer_12 +%WER 4.29 [ 242 / 5643, 37 ins, 18 del, 187 sub ] exp/sgmm2_5b/decode_bd_tgpr_eval92/wer_12 +%WER 10.88 [ 896 / 8234, 195 ins, 82 del, 619 sub ] exp/sgmm2_5b/decode_tgpr_dev93/wer_12 +%WER 6.86 [ 387 / 5643, 97 ins, 18 del, 272 sub ] exp/sgmm2_5b/decode_tgpr_eval92/wer_13 + +%WER 3.93 [ 222 / 5643, 36 ins, 14 del, 172 sub ] exp/sgmm2_5b_mmi_b0.1/decode_bd_tgpr_eval92_it1/wer_11 +%WER 3.77 [ 213 / 5643, 33 ins, 12 del, 168 sub ] exp/sgmm2_5b_mmi_b0.1/decode_bd_tgpr_eval92_it2/wer_12 +%WER 3.62 [ 204 / 5643, 35 ins, 10 del, 159 sub ] exp/sgmm2_5b_mmi_b0.1/decode_bd_tgpr_eval92_it3/wer_10 +%WER 3.69 [ 208 / 5643, 33 ins, 11 del, 164 sub ] exp/sgmm2_5b_mmi_b0.1/decode_bd_tgpr_eval92_it3.mbr/wer_11 +%WER 3.51 [ 198 / 5643, 34 ins, 9 del, 155 sub ] exp/sgmm2_5b_mmi_b0.1/decode_bd_tgpr_eval92_it4/wer_10 +%WER 7.83 [ 645 / 8234, 83 ins, 95 del, 467 sub ] exp/sgmm2_5b_mmi_b0.1_z/decode_bd_tgpr_dev93_it1/wer_12 +%WER 7.63 [ 628 / 8234, 76 ins, 99 del, 453 sub ] exp/sgmm2_5b_mmi_b0.1_z/decode_bd_tgpr_dev93_it2/wer_14 +%WER 7.52 [ 619 / 8234, 86 ins, 88 del, 445 sub ] exp/sgmm2_5b_mmi_b0.1_z/decode_bd_tgpr_dev93_it3/wer_11 +%WER 7.41 [ 610 / 8234, 76 ins, 93 del, 441 sub ] exp/sgmm2_5b_mmi_b0.1_z/decode_bd_tgpr_dev93_it4/wer_13 +%WER 3.92 [ 221 / 5643, 36 ins, 14 del, 171 sub ] exp/sgmm2_5b_mmi_b0.1_z/decode_bd_tgpr_eval92_it1/wer_11 +%WER 3.72 [ 210 / 5643, 32 ins, 12 del, 166 sub ] exp/sgmm2_5b_mmi_b0.1_z/decode_bd_tgpr_eval92_it2/wer_13 +%WER 3.67 [ 207 / 5643, 33 ins, 10 del, 164 sub ] exp/sgmm2_5b_mmi_b0.1_z/decode_bd_tgpr_eval92_it3/wer_12 +%WER 3.60 [ 203 / 
5643, 35 ins, 10 del, 158 sub ] exp/sgmm2_5b_mmi_b0.1_z/decode_bd_tgpr_eval92_it4/wer_11 + +# regular SGMM (only ran the basic one, not discriminatively trained, although the +# scripts are there.) + +# Rescored with quinphone. +# not updated + + +# DNN on fMLLR features (Karel's setup, updated recipe [5.3.2014]). +# frame cross-entropy training +%WER 6.74 [ 555 / 8234, 67 ins, 73 del, 415 sub ] exp/dnn5b_pretrain-dbn_dnn/decode_bd_tgpr_dev93/wer_10 +%WER 4.09 [ 231 / 5643, 33 ins, 15 del, 183 sub ] exp/dnn5b_pretrain-dbn_dnn/decode_bd_tgpr_eval92/wer_12 +# sMBR training (1 iteration) +%WER 6.39 [ 526 / 8234, 56 ins, 77 del, 393 sub ] exp/dnn5b_pretrain-dbn_dnn_smbr/decode_bd_tgpr_dev93_iter1/wer_11 +%WER 3.85 [ 217 / 5643, 23 ins, 16 del, 178 sub ] exp/dnn5b_pretrain-dbn_dnn_smbr/decode_bd_tgpr_eval92_iter1/wer_14 +# sMBR training (1+4 iterations, lattices+alignment updated after 1st iteration) +%WER 6.15 [ 506 / 8234, 55 ins, 70 del, 381 sub ] exp/dnn5b_pretrain-dbn_dnn_smbr_i1lats/decode_bd_tgpr_dev93_iter4/wer_11 +%WER 3.56 [ 201 / 5643, 24 ins, 9 del, 168 sub ] exp/dnn5b_pretrain-dbn_dnn_smbr_i1lats/decode_bd_tgpr_eval92_iter4/wer_13 + + +#DNN results with cpu based setup +%WER 7.21 [ 594 / 8234, 64 ins, 98 del, 432 sub ] exp/nnet5c1/decode_bd_tgpr_dev93/wer_14 + + +#==== Below are some DNN results from an older version of the RESULTS file, + +# Dan's cpu-based neural net recipe. Note: the best number for dev93 is 7.10, an SGMM+MMI system, +# and for eval92 is 3.79, the same system. (On this setup, discriminative training helped a lot, +# which seems to be the reason we can't beat the SGMM+MMI numbers here.) + + +exp/nnet5c1/decode_bd_tgpr_dev93/wer_14:%WER 7.32 [ 603 / 8234, 61 ins, 101 del, 441 sub ] +exp/nnet5c1/decode_bd_tgpr_eval92/wer_14:%WER 4.39 [ 248 / 5643, 32 ins, 17 del, 199 sub ] +# Note: my 4.39% result is worse than Karel's 3.56%. + + +# some GPU-based neural network training results... 
+ +# Below is the recipe with multiple VTLN warps and mel-filterbank inputs.. +%WER 7.24 [ 596 / 8234, 62 ins, 98 del, 436 sub ] exp/nnet5b_gpu/decode_bd_tgpr_dev93/wer_15 +%WER 3.95 [ 223 / 5643, 30 ins, 18 del, 175 sub ] exp/nnet5b_gpu/decode_bd_tgpr_eval92/wer_16 + +# 5c is GPU tanh recipe +%WER 7.08 [ 583 / 8234, 64 ins, 94 del, 425 sub ] exp/nnet5c/decode_bd_tgpr_dev93/wer_14 +%WER 4.02 [ 227 / 5643, 32 ins, 16 del, 179 sub ] exp/nnet5c/decode_bd_tgpr_eval92/wer_13 + +# 5c_gpu (the same, run on GPU) +# note, for 5c and 5c_gpu, we could get better results by using only 4 +# GPU jobs and half the learning rate, but of course it would take twice longer. +%WER 7.29 [ 600 / 8234, 60 ins, 99 del, 441 sub ] exp/nnet5c_gpu/decode_bd_tgpr_dev93/wer_14 +%WER 4.08 [ 230 / 5643, 34 ins, 15 del, 181 sub ] exp/nnet5c_gpu/decode_bd_tgpr_eval92/wer_13 + + +# 5d is the pnorm recipe, with 4 jobs; compare with nnet5c. +%WER 6.97 [ 574 / 8234, 72 ins, 84 del, 418 sub ] exp/nnet5d/decode_bd_tgpr_dev93/wer_12 +%WER 3.86 [ 218 / 5643, 27 ins, 12 del, 179 sub ] exp/nnet5d/decode_bd_tgpr_eval92/wer_13 + (the same without the big dictionary) + %WER 9.40 [ 774 / 8234, 164 ins, 71 del, 539 sub ] exp/nnet5d/decode_tgpr_dev93/wer_12 + %WER 6.45 [ 364 / 5643, 81 ins, 19 del, 264 sub ] exp/nnet5d/decode_tgpr_eval92/wer_14 + +# 5d_gpu is the pnorm recipe with GPU; +%WER 7.07 [ 582 / 8234, 62 ins, 94 del, 426 sub ] exp/nnet5d_gpu/decode_bd_tgpr_dev93/wer_13 +%WER 4.06 [ 229 / 5643, 31 ins, 13 del, 185 sub ] exp/nnet5d_gpu/decode_bd_tgpr_eval92/wer_12 + (the same without the big dictionary)_13 +%WER 9.35 [ 770 / 8234, 161 ins, 78 del, 531 sub ] exp/nnet5d_gpu/decode_tgpr_dev93/wer_12 +%WER 6.59 [ 372 / 5643, 91 ins, 15 del, 266 sub ] exp/nnet5d_gpu/decode_tgpr_eval92/wer_12 + + +%WER 7.13 [ 587 / 8234, 72 ins, 93 del, 422 sub ] exp/nnet5d_gpu/decode_bd_tgpr_dev93/wer_13 +%WER 4.06 [ 229 / 5643, 31 ins, 16 del, 182 sub ] exp/nnet5d_gpu/decode_bd_tgpr_eval92/wer_14 + +# 5e is GPU version of 
ensemble training of pnorm nnets recipe, with 4 jobs; compare with nnet5d. +%WER 7.19 [ 592 / 8234, 72 ins, 89 del, 431 sub ] exp/nnet5e_gpu/decode_bd_tgpr_dev93/wer_10 +%WER 3.97 [ 224 / 5643, 25 ins, 16 del, 183 sub ] exp/nnet5e_gpu/decode_bd_tgpr_eval92/wer_14 + + # decoded with tgpr LM. + %WER 9.55 [ 786 / 8234, 163 ins, 83 del, 540 sub ] exp/nnet5d_gpu/decode_tgpr_dev93/wer_13 + %WER 6.50 [ 367 / 5643, 95 ins, 16 del, 256 sub ] exp/nnet5d_gpu/decode_tgpr_eval92/wer_14 + +# for results with VTLN, see for example local/run_vtln2.sh, they are at the end. + + +# Online-nnet2 results: + +for x in exp/nnet2_online/nnet_ms_a_online/decode_*; do grep WER $x/wer_* | utils/best_wer.sh ; done +%WER 7.02 [ 578 / 8234, 85 ins, 88 del, 405 sub ] exp/nnet2_online/nnet_ms_a_online/decode_bd_tgpr_dev93/wer_10 +%WER 6.56 [ 540 / 8234, 86 ins, 79 del, 375 sub ] exp/nnet2_online/nnet_ms_a_online/decode_bd_tgpr_dev93_smbr_epoch1/wer_11 +%WER 6.51 [ 536 / 8234, 84 ins, 77 del, 375 sub ] exp/nnet2_online/nnet_ms_a_online/decode_bd_tgpr_dev93_smbr_epoch2/wer_12 +%WER 6.42 [ 529 / 8234, 90 ins, 72 del, 367 sub ] exp/nnet2_online/nnet_ms_a_online/decode_bd_tgpr_dev93_smbr_epoch3/wer_12 +%WER 6.40 [ 527 / 8234, 91 ins, 71 del, 365 sub ] exp/nnet2_online/nnet_ms_a_online/decode_bd_tgpr_dev93_smbr_epoch4/wer_12 ** +%WER 7.23 [ 595 / 8234, 77 ins, 96 del, 422 sub ] exp/nnet2_online/nnet_ms_a_online/decode_bd_tgpr_dev93_utt/wer_10 +%WER 6.95 [ 572 / 8234, 70 ins, 90 del, 412 sub ] exp/nnet2_online/nnet_ms_a_online/decode_bd_tgpr_dev93_utt_offline/wer_11 +%WER 3.93 [ 222 / 5643, 32 ins, 12 del, 178 sub ] exp/nnet2_online/nnet_ms_a_online/decode_bd_tgpr_eval92/wer_10 +%WER 3.76 [ 212 / 5643, 27 ins, 11 del, 174 sub ] exp/nnet2_online/nnet_ms_a_online/decode_bd_tgpr_eval92_smbr_epoch1/wer_11 +%WER 3.60 [ 203 / 5643, 27 ins, 9 del, 167 sub ] exp/nnet2_online/nnet_ms_a_online/decode_bd_tgpr_eval92_smbr_epoch2/wer_11 ** +%WER 3.62 [ 204 / 5643, 28 ins, 7 del, 169 sub ] 
exp/nnet2_online/nnet_ms_a_online/decode_bd_tgpr_eval92_smbr_epoch3/wer_11 +%WER 3.72 [ 210 / 5643, 32 ins, 7 del, 171 sub ] exp/nnet2_online/nnet_ms_a_online/decode_bd_tgpr_eval92_smbr_epoch4/wer_11 +%WER 4.02 [ 227 / 5643, 32 ins, 14 del, 181 sub ] exp/nnet2_online/nnet_ms_a_online/decode_bd_tgpr_eval92_utt/wer_9 +%WER 3.99 [ 225 / 5643, 29 ins, 13 del, 183 sub ] exp/nnet2_online/nnet_ms_a_online/decode_bd_tgpr_eval92_utt_offline/wer_10 +%WER 9.45 [ 778 / 8234, 180 ins, 73 del, 525 sub ] exp/nnet2_online/nnet_ms_a_online/decode_tgpr_dev93/wer_10 +%WER 9.39 [ 773 / 8234, 163 ins, 79 del, 531 sub ] exp/nnet2_online/nnet_ms_a_online/decode_tgpr_dev93_utt/wer_11 +%WER 9.25 [ 762 / 8234, 174 ins, 69 del, 519 sub ] exp/nnet2_online/nnet_ms_a_online/decode_tgpr_dev93_utt_offline/wer_10 +%WER 6.57 [ 371 / 5643, 95 ins, 12 del, 264 sub ] exp/nnet2_online/nnet_ms_a_online/decode_tgpr_eval92/wer_11 +%WER 6.68 [ 377 / 5643, 102 ins, 13 del, 262 sub ] exp/nnet2_online/nnet_ms_a_online/decode_tgpr_eval92_utt/wer_10 +%WER 6.56 [ 370 / 5643, 100 ins, 12 del, 258 sub ] exp/nnet2_online/nnet_ms_a_online/decode_tgpr_eval92_utt_offline/wer_10 + diff --git a/egs/wsj_noisy/s5/cmd.sh b/egs/wsj_noisy/s5/cmd.sh new file mode 100644 index 00000000000..6395d96ca36 --- /dev/null +++ b/egs/wsj_noisy/s5/cmd.sh @@ -0,0 +1,30 @@ +# "queue.pl" uses qsub. The options to it are +# options to qsub. If you have GridEngine installed, +# change this to a queue you have access to. +# Otherwise, use "run.pl", which will run jobs locally +# (make sure your --num-jobs options are no more than +# the number of cpus on your machine. 
+ +#a) JHU cluster options +export train_cmd="queue.pl -l arch=*64" +export decode_cmd="queue.pl -l arch=*64 --mem 2G" +export mkgraph_cmd="queue.pl -l arch=*64 --mem 4G" +export big_memory_cmd="queue.pl -l arch=*64 --mem 8G" +export cuda_cmd="queue.pl -l gpu=1" + + + +#b) BUT cluster options +#export train_cmd="queue.pl -q all.q@@blade -l ram_free=1200M,mem_free=1200M" +#export decode_cmd="queue.pl -q all.q@@blade -l ram_free=1700M,mem_free=1700M" +#export decodebig_cmd="queue.pl -q all.q@@blade -l ram_free=4G,mem_free=4G" + +#export cuda_cmd="queue.pl -q long.q@@pco203 -l gpu=1" +#export cuda_cmd="queue.pl -q long.q@pcspeech-gpu" +#export mkgraph_cmd="queue.pl -q all.q@@servers -l ram_free=4G,mem_free=4G" + +#c) run it locally... +#export train_cmd=run.pl +#export decode_cmd=run.pl +#export cuda_cmd=run.pl +#export mkgraph_cmd=run.pl diff --git a/egs/wsj_noisy/s5/conf/decode_dnn.config b/egs/wsj_noisy/s5/conf/decode_dnn.config new file mode 100644 index 00000000000..89dd9929a62 --- /dev/null +++ b/egs/wsj_noisy/s5/conf/decode_dnn.config @@ -0,0 +1,2 @@ +beam=18.0 # beam for decoding. Was 13.0 in the scripts. +lattice_beam=10.0 # this has most effect on size of the lattices. diff --git a/egs/wsj_noisy/s5/conf/fbank.conf b/egs/wsj_noisy/s5/conf/fbank.conf new file mode 100644 index 00000000000..07e1639e6ee --- /dev/null +++ b/egs/wsj_noisy/s5/conf/fbank.conf @@ -0,0 +1,3 @@ +# No non-default options for now. +--num-mel-bins=40 # similar to Google's setup. + diff --git a/egs/wsj_noisy/s5/conf/mfcc.conf b/egs/wsj_noisy/s5/conf/mfcc.conf new file mode 100644 index 00000000000..7361509099f --- /dev/null +++ b/egs/wsj_noisy/s5/conf/mfcc.conf @@ -0,0 +1 @@ +--use-energy=false # only non-default option. 
diff --git a/egs/wsj_noisy/s5/conf/mfcc_hires.conf b/egs/wsj_noisy/s5/conf/mfcc_hires.conf new file mode 100644 index 00000000000..434834a6725 --- /dev/null +++ b/egs/wsj_noisy/s5/conf/mfcc_hires.conf @@ -0,0 +1,10 @@ +# config for high-resolution MFCC features, intended for neural network training +# Note: we keep all cepstra, so it has the same info as filterbank features, +# but MFCC is more easily compressible (because less correlated) which is why +# we prefer this method. +--use-energy=false # use average of log energy, not energy. +--num-mel-bins=40 # similar to Google's setup. +--num-ceps=40 # there is no dimensionality reduction. +--low-freq=20 # low cutoff frequency for mel bins... this is high-bandwidth data, so + # there might be some information at the low end. +--high-freq=-400 # high cutoff frequency, relative to Nyquist of 8000 (=7600) diff --git a/egs/wsj_noisy/s5/conf/mfcc_vad.conf b/egs/wsj_noisy/s5/conf/mfcc_vad.conf new file mode 100644 index 00000000000..4c6dfe78262 --- /dev/null +++ b/egs/wsj_noisy/s5/conf/mfcc_vad.conf @@ -0,0 +1,4 @@ +--num-ceps=13 # higher than the default which is 12. +--low-freq=20 # the default. +--high-freq=-600 # the default is zero meaning use the Nyquist (8k in this case).
+ diff --git a/egs/wsj_noisy/s5/conf/online_cmvn.conf b/egs/wsj_noisy/s5/conf/online_cmvn.conf new file mode 100644 index 00000000000..cbdaf5f281c --- /dev/null +++ b/egs/wsj_noisy/s5/conf/online_cmvn.conf @@ -0,0 +1 @@ +# configuration file for apply-cmvn-online, used in the script ../local/online/run_online_decoding_nnet2.sh diff --git a/egs/wsj_noisy/s5/conf/online_pitch.conf b/egs/wsj_noisy/s5/conf/online_pitch.conf new file mode 100644 index 00000000000..bf108594926 --- /dev/null +++ b/egs/wsj_noisy/s5/conf/online_pitch.conf @@ -0,0 +1,45 @@ +## This config is given by conf/make_pitch_online.sh to the program compute-and-process-kaldi-pitch-feats, +## and is copied by steps/online/nnet2/prepare_online_decoding.sh and similar scripts, to be given +## to programs like online2-wav-nnet2-latgen-faster. +## The program compute-and-process-kaldi-pitch-feats will use it to compute pitch features that +## are the same as those which will be generated in online decoding; this enables us to train +## in a way that's compatible with online decoding. +## + +## most of these options relate to the post-processing rather than the pitch +## extraction itself. +--add-raw-log-pitch=true ## this is intended for input to neural nets, so our + ## approach is "throw everything in and see what + ## sticks". +--normalization-left-context=100 +--normalization-right-context=10 # We're removing almost all the right-context + # for the normalization. The reason why we + # include a small nonzero right-context (of + # just 0.1 second) is that by adding a little + # latency to the computation, it enables us to + # get a more accurate estimate of the pitch on + # the frame we're currently computing the + # normalized pitch of. We know for the current + # frame that we will have at least 10 frames to + # the right, and those extra 10 frames will + # increase the quality of the Viterbi + # backtrace.
+ # + # Note: our changes to the (left,right) context + # from the defaults of (75,75) to (100,10) will + # almost certainly worsen results, but will + # reduce latency. +--frames-per-chunk=10 ## relates to offline simulation of online decoding; 1 + ## would be equivalent to getting in samples one by + ## one. +--simulate-first-pass-online=true ## this makes the online-pitch-extraction code + ## output the 'first-pass' features, which + ## are less accurate than the final ones, and + ## which are the only features the neural-net + ## decoding would ever see (since we can't + ## afford to do lattice rescoring in the + ## neural-net code) +--delay=5 ## We delay all the pitch information by 5 frames. This is almost + ## certainly not helpful, but it helps to reduce the overall latency + ## added by the pitch computation, from 10 (given by + ## --normalization-right-context) to 10 - 5 = 5. diff --git a/egs/wsj_noisy/s5/conf/zc_vad.conf b/egs/wsj_noisy/s5/conf/zc_vad.conf new file mode 100644 index 00000000000..d8435d4f33a --- /dev/null +++ b/egs/wsj_noisy/s5/conf/zc_vad.conf @@ -0,0 +1,3 @@ +--dither=0.0 +--zero-crossing-threshold=1e-5 + diff --git a/egs/wsj_noisy/s5/diarization b/egs/wsj_noisy/s5/diarization new file mode 120000 index 00000000000..ba78a9126af --- /dev/null +++ b/egs/wsj_noisy/s5/diarization @@ -0,0 +1 @@ +../../sre08/v1/diarization \ No newline at end of file diff --git a/egs/wsj_noisy/s5/local/append_utterances.sh b/egs/wsj_noisy/s5/local/append_utterances.sh new file mode 100755 index 00000000000..e94c19d5cb7 --- /dev/null +++ b/egs/wsj_noisy/s5/local/append_utterances.sh @@ -0,0 +1,93 @@ +#!/bin/bash + +# Copyright 2014 Guoguo Chen +# Apache 2.0 + +# Begin configuration section. +pad_silence=0.5 +# End configuration section. + +echo "$0 $@" + +[ -f ./path.sh ] && . ./path.sh +. 
parse_options.sh || exit 1; + +if [ $# -ne 2 ]; then + echo "Usage: $0 [options] " + echo "Options:" + echo " --pad-silence # silence to be padded between utterances" + exit 1; +fi + +input_dir=$1 +output_dir=$2 + +for f in spk2gender spk2utt text utt2spk wav.scp; do + [ ! -f $input_dir/$f ] && echo "$0: no such file $input_dir/$f" && exit 1; +done + +# Checks if sox is on the path. +sox=`which sox` +[ $? -ne 0 ] && "sox: command not found." && exit 1; +sph2pipe=$KALDI_ROOT/tools/sph2pipe_v2.5/sph2pipe +[ ! -x $sph2pipe ] && "sph2pipe: command not found." && exit 1; + +mkdir -p $output_dir +cp -f $input_dir/spk2gender $output_dir/spk2gender + +# Creates a silence wav file. We create this actual sil.wav file instead of +# using sox's padding because this way sox can properly pipe the length in the +# header file. Otherwise sox will have to "count" all the samples and then +# update the header, which is not proper in pipe. +mkdir -p $output_dir/.tmp +$sox -n -r 16000 -b 16 $output_dir/.tmp/sil.wav trim 0.0 $pad_silence + +cat $input_dir/spk2utt | perl -e ' + ($text_in, $wav_in, $text_out, $wav_out, $sox, $sph2pipe, $sil_wav) = @ARGV; + open(TI, "<$text_in") || die "Error: fail to open $text_in\n"; + open(TO, ">$text_out") || die "Error: fail to open $text_out\n"; + open(WI, "<$wav_in") || die "Error: fail to open $wav_in\n"; + open(WO, ">$wav_out") || die "Error: fail to open $wav_out\n"; + while () { + chomp; + my @col = split; # We need to add "my" since we use reference below. + @col >= 2 || "bad line $_\n"; + $spk = shift @col; + $spk2utt{$spk} = \@col; + } + while () { + chomp; + @col = split; + @col >= 2 || die "Error: bad line $_\n"; + $utt = shift @col; + $text{$utt} = join(" ", @col); + } + while () { + chomp; + @col = split; + @col >= 2 || die "Error: bad line $_\n"; + $wav{$col[0]} = $col[4]; + } + foreach $spk (keys %spk2utt) { + @utts = @{$spk2utt{$spk}}; + # print $utts[0] . 
"\n"; + $text_line = ""; + $wav_line = " $sox"; + foreach $utt (@utts) { + $text_line .= " " . $text{$utt}; + $wav_line .= " \"| $sph2pipe -f wav $wav{$utt}\""; # speech + $wav_line .= " $sil_wav"; # silence + } + $text_line = $spk . $text_line . "\n"; + $wav_line = $spk . $wav_line . " -t wav - |\n"; + print TO $text_line; + print WO $wav_line; + }' $input_dir/text $input_dir/wav.scp $output_dir/text \ + $output_dir/wav.scp $sox $sph2pipe $output_dir/.tmp/sil.wav + +cat $input_dir/spk2utt | awk '{print $1" "$1;}' > $output_dir/spk2utt +utils/spk2utt_to_utt2spk.pl $output_dir/spk2utt > $output_dir/utt2spk + +utils/fix_data_dir.sh $output_dir + +exit 0; diff --git a/egs/wsj_noisy/s5/local/cstr_ndx2flist.pl b/egs/wsj_noisy/s5/local/cstr_ndx2flist.pl new file mode 100755 index 00000000000..d19db421a9f --- /dev/null +++ b/egs/wsj_noisy/s5/local/cstr_ndx2flist.pl @@ -0,0 +1,54 @@ +#!/usr/bin/env perl + +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# This is modified from the script in standard Kaldi recipe to account +# for the way the WSJ data is structured on the Edinburgh systems. +# - Arnab Ghoshal, 12/1/12 + +# This program takes as its standard input an .ndx file from the WSJ corpus that looks +# like this: +#;; File: tr_s_wv1.ndx, updated 04/26/94 +#;; +#;; Index for WSJ0 SI-short Sennheiser training data +#;; Data is read WSJ sentences, Sennheiser mic. 
+#;; Contains 84 speakers X (~100 utts per speaker MIT/SRI and ~50 utts +#;; per speaker TI) = 7236 utts +#;; +#11_1_1:wsj0/si_tr_s/01i/01ic0201.wv1 +#11_1_1:wsj0/si_tr_s/01i/01ic0202.wv1 +#11_1_1:wsj0/si_tr_s/01i/01ic0203.wv1 + +# and as command-line argument it takes the names of the WSJ disk locations, e.g.: +# /group/corpora/public/wsjcam0/data on DICE machines. +# It outputs a list of absolute pathnames. + +$wsj_dir = $ARGV[0]; + +while(){ + if(m/^;/){ next; } # Comment. Ignore it. + else { + m/^([0-9_]+):\s*(\S+)$/ || die "Could not parse line $_"; + $filename = $2; # as a subdirectory of the distributed disk. + if ($filename !~ m/\.wv1$/) { $filename .= ".wv1"; } + $filename = "$wsj_dir/$filename"; + if (-e $filename) { + print "$filename\n"; + } else { + print STDERR "File $filename found in the index but not on disk\n"; + } + } +} diff --git a/egs/wsj_noisy/s5/local/cstr_wsj_data_prep.sh b/egs/wsj_noisy/s5/local/cstr_wsj_data_prep.sh new file mode 100755 index 00000000000..35582646d95 --- /dev/null +++ b/egs/wsj_noisy/s5/local/cstr_wsj_data_prep.sh @@ -0,0 +1,199 @@ +#!/bin/bash +set -e + +# Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey) +# Apache 2.0. + +# This is modified from the script in standard Kaldi recipe to account +# for the way the WSJ data is structured on the Edinburgh systems. +# - Arnab Ghoshal, 29/05/12 + +if [ $# -ne 1 ]; then + printf "\nUSAGE: %s \n\n" `basename $0` + echo "The argument should be a the top-level WSJ corpus directory." + echo "It is assumed that there will be a 'wsj0' and a 'wsj1' subdirectory" + echo "within the top-level corpus directory." + exit 1; +fi + +CORPUS=$1 + +dir=`pwd`/data/local/data +lmdir=`pwd`/data/local/nist_lm +mkdir -p $dir $lmdir +local=`pwd`/local +utils=`pwd`/utils + +. ./path.sh # Needed for KALDI_ROOT +sph2pipe=$KALDI_ROOT/tools/sph2pipe_v2.5/sph2pipe +if [ ! 
-x $sph2pipe ]; then + echo "Could not find (or execute) the sph2pipe program at $sph2pipe"; + exit 1; +fi + +if [ -z $IRSTLM ] ; then + export IRSTLM=$KALDI_ROOT/tools/irstlm/ +fi +export PATH=${PATH}:$IRSTLM/bin +if ! command -v prune-lm >/dev/null 2>&1 ; then + echo "$0: Error: the IRSTLM is not available or compiled" >&2 + echo "$0: Error: We used to install it by default, but." >&2 + echo "$0: Error: this is no longer the case." >&2 + echo "$0: Error: To install it, go to $KALDI_ROOT/tools" >&2 + echo "$0: Error: and run extras/install_irstlm.sh" >&2 + exit 1 +fi + +cd $dir + +# This version for SI-84 +cat $CORPUS/wsj0/doc/indices/train/tr_s_wv1.ndx \ + | $local/cstr_ndx2flist.pl $CORPUS | sort \ + | grep -v wsj0/si_tr_s/401 > train_si84.flist + +# This version for SI-284 +cat $CORPUS/wsj1/doc/indices/si_tr_s.ndx \ + $CORPUS/wsj0/doc/indices/train/tr_s_wv1.ndx \ + | $local/cstr_ndx2flist.pl $CORPUS | sort \ + | grep -v wsj0/si_tr_s/401 > train_si284.flist + +# Now for the test sets. +# $CORPUS/wsj1/doc/indices/readme.doc +# describes all the different test sets. +# Note: each test-set seems to come in multiple versions depending +# on different vocabulary sizes, verbalized vs. non-verbalized +# pronunciations, etc. We use the largest vocab and non-verbalized +# pronunciations. +# The most normal one seems to be the "baseline 60k test set", which +# is h1_p0. + +# Nov'92 (333 utts) +# These index files have a slightly different format; +# have to add .wv1, which is done in cstr_ndx2flist.pl +cat $CORPUS/wsj0/doc/indices/test/nvp/si_et_20.ndx | \ + $local/cstr_ndx2flist.pl $CORPUS | sort > test_eval92.flist + +# Nov'92 (330 utts, 5k vocab) +cat $CORPUS/wsj0/doc/indices/test/nvp/si_et_05.ndx | \ + $local/cstr_ndx2flist.pl $CORPUS | sort > test_eval92_5k.flist + +# Nov'93: (213 utts) +# Have to replace a wrong disk-id. 
+cat $CORPUS/wsj1/doc/indices/wsj1/eval/h1_p0.ndx | \ + $local/cstr_ndx2flist.pl $CORPUS | sort > test_eval93.flist + +# Nov'93: (215 utts, 5k) +cat $CORPUS/wsj1/doc/indices/wsj1/eval/h2_p0.ndx | \ + $local/cstr_ndx2flist.pl $CORPUS | sort > test_eval93_5k.flist + +# Dev-set for Nov'93 (503 utts) +cat $CORPUS/wsj1/doc/indices/h1_p0.ndx | \ + $local/cstr_ndx2flist.pl $CORPUS | sort > test_dev93.flist + +# Dev-set for Nov'93 (513 utts, 5k vocab) +cat $CORPUS/wsj1/doc/indices/h2_p0.ndx | \ + $local/cstr_ndx2flist.pl $CORPUS | sort > test_dev93_5k.flist + + +# Dev-set Hub 1,2 (503, 913 utterances) + +# Note: the ???'s below match WSJ and SI_DT, or wsj and si_dt. +# Sometimes this gets copied from the CD's with upcasing, don't know +# why (could be older versions of the disks). +find $CORPUS/???1/??_??_20 -print | grep -i ".wv1" | sort > dev_dt_20.flist +find $CORPUS/???1/??_??_05 -print | grep -i ".wv1" | sort > dev_dt_05.flist + + +# Finding the transcript files: +find -L $CORPUS -iname '*.dot' > dot_files.flist + +# Convert the transcripts into our format (no normalization yet) +for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + $local/flist2scp.pl $x.flist | sort > ${x}_sph.scp + cat ${x}_sph.scp | awk '{print $1}' \ + | $local/find_transcripts.pl dot_files.flist > $x.trans1 +done + +# Do some basic normalization steps. At this point we don't remove OOVs-- +# that will be done inside the training scripts, as we'd like to make the +# data-preparation stage independent of the specific lexicon used. +noiseword=""; +for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + cat $x.trans1 | $local/normalize_transcript.pl $noiseword \ + | sort > $x.txt || exit 1; +done + +# Create scp's with wav's. (the wv1 in the distribution is not really wav, it is sph.) 
+for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + awk '{printf("%s '$sph2pipe' -f wav %s |\n", $1, $2);}' < ${x}_sph.scp \ + > ${x}_wav.scp +done + +# Make the utt2spk and spk2utt files. +for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + cat ${x}_sph.scp | awk '{print $1}' \ + | perl -ane 'chop; m:^...:; print "$_ $&\n";' > $x.utt2spk + cat $x.utt2spk | $utils/utt2spk_to_spk2utt.pl > $x.spk2utt || exit 1; +done + +#in case we want to limit lm's on most frequent words, copy lm training word frequency list +cp $CORPUS/wsj1/doc/lng_modl/vocab/wfl_64.lst $lmdir +chmod u+w $lmdir/*.lst # had weird permissions on source. + +# The 20K vocab, open-vocabulary language model (i.e. the one with UNK), without +# verbalized pronunciations. This is the most common test setup, I understand. + +cp $CORPUS/wsj1/doc/lng_modl/base_lm/bcb20onp.z $lmdir/lm_bg.arpa.gz || exit 1; +chmod u+w $lmdir/lm_bg.arpa.gz + +# trigram would be: +cat $CORPUS/wsj1/doc/lng_modl/base_lm/tcb20onp.z | \ + perl -e 'while(<>){ if(m/^\\data\\/){ print; last; } } while(<>){ print; }' \ + | gzip -c -f > $lmdir/lm_tg.arpa.gz || exit 1; + +prune-lm --threshold=1e-7 $lmdir/lm_tg.arpa.gz $lmdir/lm_tgpr.arpa || exit 1; +gzip -f $lmdir/lm_tgpr.arpa || exit 1; + +# repeat for 5k language models +cp $CORPUS/wsj1/doc/lng_modl/base_lm/bcb05onp.z $lmdir/lm_bg_5k.arpa.gz || exit 1; +chmod u+w $lmdir/lm_bg_5k.arpa.gz + +# trigram would be: !only closed vocabulary here! 
+cp $CORPUS/wsj1/doc/lng_modl/base_lm/tcb05cnp.z $lmdir/lm_tg_5k.arpa.gz || exit 1; +chmod u+w $lmdir/lm_tg_5k.arpa.gz +gunzip $lmdir/lm_tg_5k.arpa.gz +tail -n 4328839 $lmdir/lm_tg_5k.arpa | gzip -c -f > $lmdir/lm_tg_5k.arpa.gz +rm $lmdir/lm_tg_5k.arpa + +prune-lm --threshold=1e-7 $lmdir/lm_tg_5k.arpa.gz $lmdir/lm_tgpr_5k.arpa || exit 1; +gzip -f $lmdir/lm_tgpr_5k.arpa || exit 1; + + +if [ ! -f wsj0-train-spkrinfo.txt ] || [ `cat wsj0-train-spkrinfo.txt | wc -l` -ne 134 ]; then + rm -f wsj0-train-spkrinfo.txt + wget http://www.ldc.upenn.edu/Catalog/docs/LDC93S6A/wsj0-train-spkrinfo.txt \ + || ( echo "Getting wsj0-train-spkrinfo.txt from backup location" && \ + wget --no-check-certificate https://sourceforge.net/projects/kaldi/files/wsj0-train-spkrinfo.txt ); +fi + +if [ ! -f wsj0-train-spkrinfo.txt ]; then + echo "Could not get the spkrinfo.txt file from LDC website (moved)?" + echo "This is possibly omitted from the training disks; couldn't find it." + echo "Everything else may have worked; we just may be missing gender info" + echo "which is only needed for VTLN-related diagnostics anyway." + exit 1 +fi +# Note: wsj0-train-spkrinfo.txt doesn't seem to be on the disks but the +# LDC put it on the web. Perhaps it was accidentally omitted from the +# disks. + +cat $CORPUS/wsj0/doc/spkrinfo.txt \ + $CORPUS/wsj1/doc/evl_spok/spkrinfo.txt \ + $CORPUS/wsj1/doc/dev_spok/spkrinfo.txt \ + $CORPUS/wsj1/doc/train/spkrinfo.txt \ + ./wsj0-train-spkrinfo.txt | \ + perl -ane 'tr/A-Z/a-z/; m/^;/ || print;' | \ + awk '{print $1, $2}' | grep -v -- -- | sort | uniq > spk2gender + + +echo "Data preparation succeeded" diff --git a/egs/wsj_noisy/s5/local/cstr_wsj_extend_dict.sh b/egs/wsj_noisy/s5/local/cstr_wsj_extend_dict.sh new file mode 100755 index 00000000000..b2a9faad704 --- /dev/null +++ b/egs/wsj_noisy/s5/local/cstr_wsj_extend_dict.sh @@ -0,0 +1,172 @@ +#!/bin/bash + +# This script builds a larger word-list and dictionary +# than used for the LMs supplied with the WSJ corpus. 
+# It uses a couple of strategies to fill-in words in +# the LM training data but not in CMUdict. One is +# to generate special prons for possible acronyms, that +# just consist of the constituent letters. The other +# is designed to handle derivatives of known words +# (e.g. deriving the pron of a plural from the pron of +# the base-word), but in a more general, learned-from-data +# way. +# It makes use of scripts in local/dict/ + +if [ $# -ne 1 ]; then + echo "Usage: local/cstr_wsj_train_lms.sh WSJ1_doc_dir" + exit 1 +fi + +export PATH=$PATH:`pwd`/local/dict/ +srcdir=$1 + +if [ ! -d $srcdir/lng_modl ]; then + echo "Expecting 'lng_modl' under WSJ doc directory '$srcdir'" + exit 1 +fi + +mkdir -p data/local/dict_larger +dir=data/local/dict_larger +cp data/local/dict/* data/local/dict_larger # Various files describing phones etc. + # are there; we just want to copy them as the phoneset is the same. +rm data/local/dict_larger/lexicon.txt # we don't want this. +mincount=2 # Minimum count of an OOV we will try to generate a pron for. + +[ ! -f data/local/dict/cmudict/cmudict.0.7a ] && echo "CMU dict not in expected place" && exit 1; + +# Remove comments from cmudict; print first field; remove +# words like FOO(1) which are alternate prons: our dict format won't +# include these markers. +grep -v ';;;' data/local/dict/cmudict/cmudict.0.7a | + perl -ane 's/^(\S+)\(\d+\)/$1/; print; ' | sort | uniq > $dir/dict.cmu + +cat $dir/dict.cmu | awk '{print $1}' | sort | uniq > $dir/wordlist.cmu + +echo "Getting training data [this should take at least a few seconds; if not, there's a problem]" + +# Convert to uppercase, remove XML-like markings. +# For words ending in "." that are not in CMUdict, we assume that these +# are periods that somehow remained in the data during data preparation, +# and we replace the "." with "\n". Note: we found this by looking at +# oov.counts below (before adding this rule).
+ +touch $dir/cleaned.gz +if [ `du -m $dir/cleaned.gz | cut -f 1` -eq 73 ]; then + echo "Not getting cleaned data in $dir/cleaned.gz again [already exists]"; +else + gunzip -c $srcdir/lng_modl/lm_train/np_data/{87,88,89}/*.z \ + | awk '/^){ chop; $isword{$_} = 1; } + while() { + @A = split(" ", $_); + for ($n = 0; $n < @A; $n++) { + $a = $A[$n]; + if (! $isword{$a} && $a =~ s/^([^\.]+)\.$/$1/) { # nonwords that end in "." + # and have no other "." in them: treat as period. + print "$a"; + if ($n+1 < @A) { print "\n"; } + } else { print "$a "; } + } + print "\n"; + } + ' $dir/wordlist.cmu | gzip -c > $dir/cleaned.gz +fi + +# get unigram counts +echo "Getting unigram counts" +gunzip -c $dir/cleaned.gz | tr -s ' ' '\n' | \ + awk '{count[$1]++} END{for (w in count) { print count[w], w; }}' | sort -nr > $dir/unigrams + +cat $dir/unigrams | awk -v dict=$dir/dict.cmu \ + 'BEGIN{while(getline $dir/oov.counts + +echo "Most frequent unseen unigrams are: " +head $dir/oov.counts + +# Prune away singleton counts, and remove things with numbers in +# (which should have been normalized) and with no letters at all. + + +cat $dir/oov.counts | awk -v thresh=$mincount '{if ($1 >= thresh) { print $2; }}' \ + | awk '/[0-9]/{next;} /[A-Z]/{print;}' > $dir/oovlist + +# Automatic rule-finding... + +# First make some prons for possible acronyms. +# Note: we don't do this for things like U.K or U.N, +# or A.B. (which doesn't exist anyway), +# as we consider this normalization/spelling errors. + +cat $dir/oovlist | local/dict/get_acronym_prons.pl $dir/dict.cmu > $dir/dict.acronyms + +mkdir $dir/f $dir/b # forward, backward directions of rules... + # forward is normal suffix + # rules, backward is reversed (prefix rules). These + # dirs contain stuff we create while making the rule-based + # extensions to the dictionary. + +# Remove ; and , from words, if they are present; these +# might crash our scripts, as they are used as separators there. 
+filter_dict.pl $dir/dict.cmu > $dir/f/dict +cat $dir/oovlist | filter_dict.pl > $dir/f/oovs +reverse_dict.pl $dir/f/dict > $dir/b/dict +reverse_dict.pl $dir/f/oovs > $dir/b/oovs + +# The next stage takes a few minutes. +# Note: the forward stage takes longer, as English is +# mostly a suffix-based language, and there are more rules +# that it finds. +for d in $dir/f $dir/b; do + ( + cd $d + cat dict | get_rules.pl 2>get_rules.log >rules + get_rule_hierarchy.pl rules >hierarchy + awk '{print $1}' dict | get_candidate_prons.pl rules dict | \ + limit_candidate_prons.pl hierarchy | \ + score_prons.pl dict | \ + count_rules.pl >rule.counts + # the sort command below is just for convenience of reading. + score_rules.pl rules.with_scores + get_candidate_prons.pl rules.with_scores dict oovs | \ + limit_candidate_prons.pl hierarchy > oovs.candidates + ) & +done +wait + +# Merge the candidates. +reverse_candidates.pl $dir/b/oovs.candidates | cat - $dir/f/oovs.candidates | sort > $dir/oovs.candidates +select_candidate_prons.pl <$dir/oovs.candidates | awk -F';' '{printf("%s %s\n", $1, $2);}' \ + > $dir/dict.oovs + +cat $dir/dict.acronyms $dir/dict.oovs | sort | uniq > $dir/dict.oovs_merged + +awk '{print $1}' $dir/dict.oovs_merged | uniq > $dir/oovlist.handled +sort $dir/oovlist | diff - $dir/oovlist.handled | grep -v 'd' | sed 's:< ::' > $dir/oovlist.not_handled + + +# add_counts.pl attaches to original counts to the list of handled/not-handled OOVs +add_counts.pl $dir/oov.counts $dir/oovlist.handled | sort -nr > $dir/oovlist.handled.counts +add_counts.pl $dir/oov.counts $dir/oovlist.not_handled | sort -nr > $dir/oovlist.not_handled.counts + +echo "**Top OOVs we handled are:**"; +head $dir/oovlist.handled.counts +echo "**Top OOVs we didn't handle are as follows (note: they are mostly misspellings):**"; +head $dir/oovlist.not_handled.counts + + +echo "Count of OOVs we handled is `awk '{x+=$1} END{print x}' $dir/oovlist.handled.counts`" +echo "Count of OOVs we couldn't handle 
is `awk '{x+=$1} END{print x}' $dir/oovlist.not_handled.counts`" +echo "Count of OOVs we didn't handle due to low count is" \ + `awk -v thresh=$mincount '{if ($1 < thresh) x+=$1; } END{print x;}' $dir/oov.counts` +# The two files created above are for humans to look at, as diagnostics. + +cat < $dir/lexicon.txt +!SIL SIL + SPN + SPN + NSN +EOF + +echo "Created $dir/lexicon.txt" diff --git a/egs/wsj_noisy/s5/local/dict/add_counts.pl b/egs/wsj_noisy/s5/local/dict/add_counts.pl new file mode 100755 index 00000000000..a2ace7e9af2 --- /dev/null +++ b/egs/wsj_noisy/s5/local/dict/add_counts.pl @@ -0,0 +1,31 @@ +#!/usr/bin/env perl + + +# Add counts to an oovlist. +# Reads in counts as output by uniq -c, and +# an oovlist, and prints out the counts of the oovlist. + +(@ARGV == 1 || @ARGV == 2) || die "Usage: add_counts.pl count_file [oovlist]\n"; + +$counts = shift @ARGV; + +open(C, "<$counts") || die "Opening counts file $counts"; + +while() { + @A = split(" ", $_); + @A == 2 || die "Bad line in counts file: $_"; + ($count, $word) = @A; + $count =~ m:^\d+$: || die "Bad count $A[0]\n"; + $counts{$word} = $count; +} + +while(<>) { + chop; + $w = $_; + $w =~ m:\S+: || die "Bad word $w"; + defined $counts{$w} || die "Word $w not present in counts file"; + print "\t$counts{$w}\t$w\n"; +} + + + diff --git a/egs/wsj_noisy/s5/local/dict/count_rules.pl b/egs/wsj_noisy/s5/local/dict/count_rules.pl new file mode 100755 index 00000000000..1c6cfc4a547 --- /dev/null +++ b/egs/wsj_noisy/s5/local/dict/count_rules.pl @@ -0,0 +1,44 @@ +#!/usr/bin/env perl + +# This program takes the output of score_prons.pl and collates +# it for each (rule, destress) pair so that we get the +# counts of right/partial/wrong for each pair. 
+ +# The input is a 7-tuple on each line, like: +# word;pron;base-word;base-pron;rule-name;de-stress;right|partial|wrong +# +# The output format is a 5-tuple like: +# +# rule;destress;right-count;partial-count;wrong-count +# + +if (@ARGV != 0 && @ARGV != 1) { + die "Usage: count_rules.pl < scored_candidate_prons > rule_counts"; +} + + +while(<>) { + chop; + $line = $_; + my ($word, $pron, $baseword, $basepron, $rulename, $destress, $score) = split(";", $line); + + my $key = $rulename . ";" . $destress; + + if (!defined $counts{$key}) { + $counts{$key} = [ 0, 0, 0 ]; # new anonymous array. + } + $ref = $counts{$key}; + if ($score eq "right") { + $$ref[0]++; + } elsif ($score eq "partial") { + $$ref[1]++; + } elsif ($score eq "wrong") { + $$ref[2]++; + } else { + die "Bad score $score\n"; + } +} + +while ( my ($key, $value) = each(%counts)) { + print $key . ";" . join(";", @$value) . "\n"; +} diff --git a/egs/wsj_noisy/s5/local/dict/filter_dict.pl b/egs/wsj_noisy/s5/local/dict/filter_dict.pl new file mode 100755 index 00000000000..5e32823ef92 --- /dev/null +++ b/egs/wsj_noisy/s5/local/dict/filter_dict.pl @@ -0,0 +1,19 @@ +#!/usr/bin/env perl + + +# This program reads and writes either a dictionary or just a list +# of words, and it removes any words containing ";" or "," as these +# are used in these programs. It will warn about these. +# It will die if the pronunciations have these symbols in. +while(<>) { + chop; + @A = split(" ", $_); + $word = shift @A; + + if ($word =~ m:[;,]:) { + print STDERR "Omitting line $_ since it has one of the banned characters ; or ,\n" ; + } else { + $_ =~ m:[;,]: && die "Phones cannot have ; or , in them."; + print $_ . 
"\n"; + } +} diff --git a/egs/wsj_noisy/s5/local/dict/find_acronyms.pl b/egs/wsj_noisy/s5/local/dict/find_acronyms.pl new file mode 100755 index 00000000000..55e474c4056 --- /dev/null +++ b/egs/wsj_noisy/s5/local/dict/find_acronyms.pl @@ -0,0 +1,95 @@ +#!/usr/bin/env perl + +# Reads a dictionary, and prints out a list of words that seem to be pronounced +# as acronyms (not including plurals of acronyms, just acronyms). Uses +# the prons of the individual letters (A., B. and so on) to judge this. +# Note: this is somewhat dependent on the convention used in CMUduct, that +# the individual letters are spelled this way (e.g. "A."). + +$max_length = 6; # Max length of words that might be + # acronyms. + +while(<>) { # Read the dict. + chop; + @A = split(" ", $_); + $word = shift @A; + $pron = join(" ", @A); + if ($word =~ m/^([A-Z])\.$/ ) { + chop $word; # Remove trailing "." to get just the letter + $letter = $1; + if (!defined $letter_prons{$letter} ) { + $letter_prons{$letter} = [ ]; # new anonymous array + } + $arrayref = $letter_prons{$letter}; + push @$arrayref, $pron; + } elsif( length($word) <= $max_length ) { + $pronof{$word . "," . $pron} = 1; + $isword{$word} = 1; + #if (!defined $prons{$word} ) { + # $prons{$word} = [ ]; + #} + # push @{$prons{$word}}, $pron; + } +} + +sub get_letter_prons; + +foreach $word (keys %isword) { + my @letter_prons = get_letter_prons($word); + foreach $pron (@letter_prons) { + if (defined $pronof{$word.",".$pron}) { + print "$word $pron\n"; + } + } +} + + +sub get_letter_prons { + @acronym = split("", shift); # The letters in the word. + my @prons = ( "" ); + + while (@acronym > 0) { + $l = shift @acronym; + $n = 1; # num-repeats of letter $l. + while (@acronym > 0 && $acronym[0] eq $l) { + $n++; + shift @acronym; + } + my $arrayref = $letter_prons{$l}; + my @prons_of_block = (); + if ($n == 1) { # Just one repeat. + foreach $lpron ( @$arrayref ) { + push @prons_of_block, $lpron; # typically (always?) just one pron of a letter. 
+ } + } elsif ($n == 2) { # Two repeats. Can be "double a" or "a a" + foreach $lpron ( @$arrayref ) { + push @prons_of_block, "D AH1 B AH0 L " . $lpron; + push @prons_of_block, $lpron . $lpron; + } + } elsif ($n == 3) { # can be "triple a" or "a a a" + foreach $lpron ( @$arrayref ) { + push @prons_of_block, "T R IH1 P AH0 L " . $lpron; + push @prons_of_block, $lpron . $lpron . $lpron; + } + } elsif ($n >= 4) { # let's say it can only be that letter repeated $n times.. + # not sure really. + foreach $lpron ( @$arrayref ) { + $nlpron = ""; + for ($m = 0; $m < $n; $m++) { $nlpron = $nlpron . $lpron; } + push @prons_of_block, $nlpron; + } + } + my @new_prons = (); + foreach $pron (@prons) { + foreach $pron_of_block(@prons_of_block) { + if ($pron eq "") { + push @new_prons, $pron_of_block; + } else { + push @new_prons, $pron . " " . $pron_of_block; + } + } + } + @prons = @new_prons; + } + return @prons; +} diff --git a/egs/wsj_noisy/s5/local/dict/get_acronym_prons.pl b/egs/wsj_noisy/s5/local/dict/get_acronym_prons.pl new file mode 100755 index 00000000000..6294b7046e2 --- /dev/null +++ b/egs/wsj_noisy/s5/local/dict/get_acronym_prons.pl @@ -0,0 +1,123 @@ +#!/usr/bin/env perl + +# Reads a dictionary (for prons of letters), and an OOV list, +# and puts out candidate pronunciations of words in that list +# that could plausibly be acronyms. +# We judge that a word can plausibly be an acronym if it is +# a sequence of just letters (no non-letter characters such +# as "'"), or something like U.K., +# and the number of letters is four or less. +# +# If the text were not already pre-normalized, there would +# be other hints such as capitalization. + +# This program appends +# the prons of the individual letters (A., B. and so on) to work out +# the pron of the acronym. +# Note: this is somewhat dependent on the convention used in CMUduct, that +# the individual letters are spelled this way (e.g. "A."). [it seems +# to also have the separated versions. 
+ +if (!(@ARGV == 1 || @ARGV == 2)) { + print "Usage: get_acronym_prons.pl dict [oovlist]"; +} + +$max_length = 4; # Max #letters in an acronym. (Longer + # acronyms tend to have "pseudo-pronunciations", e.g. think about UNICEF. + +$dict = shift @ARGV; +open(D, "<$dict") || die "Opening dictionary $dict"; + +while(<D>) { # Read the dict, to get the prons of the letters. + chop; + @A = split(" ", $_); + $word = shift @A; + $pron = join(" ", @A); + if ($word =~ m/^([A-Z])\.$/ ) { + chop $word; # Remove trailing "." to get just the letter + $letter = $1; + if (!defined $letter_prons{$letter} ) { + $letter_prons{$letter} = [ ]; # new anonymous array + } + $arrayref = $letter_prons{$letter}; + push @$arrayref, $pron; + } elsif( length($word) <= $max_length ) { + $pronof{$word . "," . $pron} = 1; + $isword{$word} = 1; + #if (!defined $prons{$word} ) { + # $prons{$word} = [ ]; + #} + # push @{$prons{$word}}, $pron; + } +} + +sub get_letter_prons; + +while(<>) { # Read OOVs. + # For now, just do the simple cases without "." in + # between... things with "." in the OOV list seem to + # be mostly errors. + chop; + $word = $_; + if ($word =~ m/^[A-Z]{1,5}$/) { + foreach $pron ( get_letter_prons($word) ) { # E.g. UNPO + print "$word $pron\n"; + } + } elsif ($word =~ m:^(\w\.){1,4}\w\.?$:) { # E.g. U.K. Make the final "." optional. + $letters = $word; + $letters =~ s:\.::g; + foreach $pron ( get_letter_prons($letters) ) { + print "$word $pron\n"; + } + } +} + +sub get_letter_prons { + @acronym = split("", shift); # The letters in the word. + my @prons = ( "" ); + + while (@acronym > 0) { + $l = shift @acronym; + $n = 1; # num-repeats of letter $l. + while (@acronym > 0 && $acronym[0] eq $l) { + $n++; + shift @acronym; + } + my $arrayref = $letter_prons{$l}; + my @prons_of_block = (); + if ($n == 1) { # Just one repeat. + foreach $lpron ( @$arrayref ) { + push @prons_of_block, $lpron; # typically (always?) just one pron of a letter. + } + } elsif ($n == 2) { # Two repeats. 
Can be "double a" or "a a" + foreach $lpron ( @$arrayref ) { + push @prons_of_block, "D AH1 B AH0 L " . $lpron; + push @prons_of_block, $lpron . " " . $lpron; + } + } elsif ($n == 3) { # can be "triple a" or "a a a" + foreach $lpron ( @$arrayref ) { + push @prons_of_block, "T R IH1 P AH0 L " . $lpron; + push @prons_of_block, "$lpron $lpron $lpron"; + } + } elsif ($n >= 4) { # let's say it can only be that letter repeated $n times.. + # not sure really. + foreach $lpron ( @$arrayref ) { + $nlpron = $lpron; + for ($m = 1; $m < $n; $m++) { $nlpron = $nlpron . " " . $lpron; } + push @prons_of_block, $nlpron; + } + } + my @new_prons = (); + foreach $pron (@prons) { + foreach $pron_of_block(@prons_of_block) { + if ($pron eq "") { + push @new_prons, $pron_of_block; + } else { + push @new_prons, $pron . " " . $pron_of_block; + } + } + } + @prons = @new_prons; + } + return @prons; +} diff --git a/egs/wsj_noisy/s5/local/dict/get_candidate_prons.pl b/egs/wsj_noisy/s5/local/dict/get_candidate_prons.pl new file mode 100755 index 00000000000..b091c6d767e --- /dev/null +++ b/egs/wsj_noisy/s5/local/dict/get_candidate_prons.pl @@ -0,0 +1,187 @@ +#!/usr/bin/env perl + +# This script takes three command-line arguments (typically files, or "-"): +# the suffix rules (as output by get_rules.pl), the rule-hierarchy +# (from get_rule_hierarchy.pl), and the words that we want prons to be +# generated for (one per line). + +# The output consists of candidate generated pronunciations for those words, +# together with information about how we generated those pronunciations. +# This does not do pruning of the candidates using the restriction +# "you can't use a more general rule when a more specific one is applicable". +# That is done by limit_candidate_prons.pl. + +# Each line of the output consists of a 4-tuple, separated by ";", of the +# form: +# word;pron;base-word;base-pron;rule-name;destress[;rule-score] +# [the last field is only present if you supplied rules with score information]. 
+# where: +# - "word" is the input word that we queried for, e.g. WASTED +# - "pron" is the generated pronunciation, e.g. "W EY1 S T AH0 D" +# - rule-name is a 4-tuple separated by commas that describes the rule, e.g. +# "STED,STING,D,NG", +# - "base-word" is the base-word we're getting the pron from, +# e.g. WASTING +# - "base-pron" is the pron of the base-word, e.g. "W EY1 S T IH0 NG" +# - "destress" is either "yes" or "no" and corresponds to whether we destressed the +# base-word or not [de-stressing just corresponds to just taking any 2's down to 1's, +# although we may extend this in future]... +# - "rule-score" is a numeric score of the rule (this field is only present +# if there was score information in your rules. + + +(@ARGV == 2 || @ARGV == 3) || die "Usage: get_candidate_prons.pl rules base-dict [ words ]"; + +$min_prefix_len = 3; # this should probably match with get_rules.pl + +$rules = shift @ARGV; # Note: rules may be with destress "yes/no" indicators or without... + # if without, it's treated as if both "yes" and "no" are present. +$dict = shift @ARGV; + +open(R, "<$rules") || die "Opening rules file: $rules"; + +sub process_word; + +while(<R>) { + chop $_; + my ($rule, $destress, $rule_score) = split(";", $_); # We may have "destress" markings (yes|no), + # and scores, or we may have just rule, in which case + # $destress and $rule_score will be undefined. + + my @R = split(",", $rule, 4); # "my" means new instance of @R each + # time we do this loop -> important because we'll be creating + # a reference to @R below. + # Note: the last arg to SPLIT tells it how many fields max to get. + # This stops it from omitting empty trailing fields. + @R == 4 || die "Bad rule $_"; + $suffix = $R[0]; # Suffix of word we want pron for. + if (!defined $isrule{$rule}) { + $isrule{$rule} = 1; # make sure we do this only once for each rule + # (don't repeate for different stresses). + if (!defined $suffix2rule{$suffix}) { + # The syntax [ $x, $y, ... 
] means a reference to a newly created array + # containing $x, $y, etc. \@R creates an array reference to R. + # so suffix2rule is a hash from suffix to ref to array of refs to + # 4-dimensional arrays. + $suffix2rule{$suffix} = [ \@R ]; + } else { + # Below, the syntax @{$suffix2rule{$suffix}} dereferences the array + # reference inside the hash; \@R pushes onto that array a new array + # reference pointing to @R. + push @{$suffix2rule{$suffix}}, \@R; + } + } + if (!defined $rule_score) { $rule_score = -1; } # -1 means we don't have the score info. + + # Now store information on which destress markings (yes|no) this rule + # is valid for, and the associated scores (if supplied) + # If just the rule is given (i.e. no destress marking specified), + # assume valid for both. + if (!defined $destress) { # treat as if both "yes" and "no" are valid. + $rule_and_destress_to_rule_score{$rule.";yes"} = $rule_score; + $rule_and_destress_to_rule_score{$rule.";no"} = $rule_score; + } else { + $rule_and_destress_to_rule_score{$rule.";".$destress} = $rule_score; + } + +} + +open(D, "<$dict") || die "Opening base dictionary: $dict"; +while() { + @A = split(" ", $_); + $word = shift @A; + $pron = join(" ", @A); + if (!defined $word2prons{$word}) { + $word2prons{$word} = [ $pron ]; # Ref to new anonymous array containing just "pron". + } else { + push @{$word2prons{$word}}, $pron; # Push $pron onto array referred to (@$ref derefs array). + } +} +foreach $word (%word2prons) { + # Set up the hash "prefixcount", which says how many times a char-sequence + # is a prefix (not necessarily a strict prefix) of a word in the dict. 
+ $len = length($word); + for ($l = 0; $l <= $len; $l++) { + $prefixcount{substr($word, 0, $l)}++; + } +} + +open(R, "<$rules") || die "Opening rules file: $rules"; + + +while(<>) { + chop; + m/^\S+$/ || die; + process_word($_); +} + +sub process_word { + my $word = shift @_; + $len = length($word); + # $owncount is used in evaluating whether a particular prefix is a prefix + # of some other word in the dict... if a word itself may be in the dict + # (usually because we're running this on the dict itself), we need to + # correct for this. + if (defined $word2prons{$word}) { $owncount = 1; } else { $owncount = 0; } + + for ($prefix_len = $min_prefix_len; $prefix_len <= $len; $prefix_len++) { + my $prefix = substr($word, 0, $prefix_len); + my $suffix = substr($word, $prefix_len); + if ($prefixcount{$prefix} - $owncount == 0) { + # This prefix is not a prefix of any word in the dict, so no point + # checking the rules below-- none of them can match. + next; + } + $rules_array_ref = $suffix2rule{$suffix}; + if (defined $rules_array_ref) { + foreach $R (@$rules_array_ref) { # @$rules_array_ref dereferences the array. + # $R is a refernce to a 4-dimensional array, whose elements we access with + # $$R[0], etc. + my $base_suffix = $$R[1]; + my $base_word = $prefix . $base_suffix; + my $base_prons_ref = $word2prons{$base_word}; + if (defined $base_prons_ref) { + my $psuffix = $$R[2]; + my $base_psuffix = $$R[3]; + if ($base_psuffix ne "") { + $base_psuffix = " " . $base_psuffix; + # Include " ", the space between phones, to prevent + # matching partial phones below. + } + my $base_psuffix_len = length($base_psuffix); + foreach $base_pron (@$base_prons_ref) { # @$base_prons_ref derefs + # that reference to an array. + my $base_pron_prefix_len = length($base_pron) - $base_psuffix_len; + # Note: these lengths are in characters, not phones. 
+ if ($base_pron_prefix_len >= 0 && + substr($base_pron, $base_pron_prefix_len) eq $base_psuffix) { + # The suffix of the base_pron is what it should be. + my $pron_prefix = substr($base_pron, 0, $base_pron_prefix_len); + my $rule = join(",", @$R); # we'll output this.. + my $len = @R; + for ($destress = 0; $destress <= 1; $destress++) { # Two versions + # of each rule: with destressing and without. + # pron is the generated pron. + if ($destress) { $pron_prefix =~ s/2/1/g; } + my $pron; + if ($psuffix ne "") { $pron = $pron_prefix . " " . $psuffix; } + else { $pron = $pron_prefix; } + # Now print out the info about the generated pron. + my $destress_mark = ($destress ? "yes" : "no"); + my $rule_score = $rule_and_destress_to_rule_score{$rule.";".$destress_mark}; + if (defined $rule_score) { # Means that the (rule,destress) combination was + # seen [note: this if-statement may be pointless, as currently we don't + # do any pruning of rules]. + my @output = ($word, $pron, $base_word, $base_pron, $rule, $destress_mark); + if ($rule_score != -1) { push @output, $rule_score; } # If scores were supplied, + # we also output the score info. + print join(";", @output) . "\n"; + } + } + } + } + } + } + } + } +} diff --git a/egs/wsj_noisy/s5/local/dict/get_rule_hierarchy.pl b/egs/wsj_noisy/s5/local/dict/get_rule_hierarchy.pl new file mode 100755 index 00000000000..d7c13a8df57 --- /dev/null +++ b/egs/wsj_noisy/s5/local/dict/get_rule_hierarchy.pl @@ -0,0 +1,73 @@ +#!/usr/bin/env perl + +#This reads in rules, of the form put out by get_rules.pl, e.g.: +# ERT,,ER0 T, +# MENT,ING,M AH0 N T,IH0 NG +# S,TON,Z,T AH0 N +# ,ER,IH0 NG,IH0 NG ER0 +# ,'S,M AH0 N,M AH0 N Z +#TIONS,TIVE,SH AH0 N Z,T IH0 V + +# and it works out a hierarchy that says which rules are sub-cases +# of which rules: it outputs on each line a pair separated by ";", where +# each member of the pair is a rule, first one is the specialization, the +# second one being more general. 
+# E.g.: +# RED,RE,D,/ED,E,D, +# RED,RE,D,/D,,D, +# GING,GE,IH0 NG,/ING,I,IH0 NG, +# TOR,TING,T ER0,T IH0 NG/OR,OR,T ER0,T ER0 +# ERED,ER,D,/RED,R,D, +# ERED,ER,D,/ED,,D, + + + + +while(<>) { + chop; + $rule = $_; + $isrule{$rule} = 1; + push @rules, $rule; +} + +foreach my $rule (@rules) { + # Truncate the letters and phones in the rule, while we + # can, to get more general rules; if the more general rule + # exists, put out the pair. + @A = split(",", $rule); + @suffixa = split("", $A[0]); + @suffixb = split("", $A[1]); + @psuffixa = split(" ", $A[2]); + @psuffixb = split(" ", $A[3]); + for ($common_suffix_len = 0; $common_suffix_len < @suffixa && $common_suffix_len < @suffixb;) { + if ($suffixa[$common_suffix_len] eq $suffixb[$common_suffix_len]) { + $common_suffix_len++; + } else { + last; + } + } + for ($common_psuffix_len = 0; $common_psuffix_len < @psuffixa && $common_psuffix_len < @psuffixb;) { + if ($psuffixa[$common_psuffix_len] eq $psuffixb[$common_psuffix_len]) { + $common_psuffix_len++; + } else { + last; + } + } + # Get all combinations of pairs of integers <= (common_suffix_len, common_psuffix_len), + # except (0,0), and print out this rule together with the corresponding rule (if it exists). + for ($m = 0; $m <= $common_suffix_len; $m++) { + $sa = join("", @suffixa[$m...$#suffixa]); # @x[a..b] is array slice notation. + $sb = join("", @suffixb[$m...$#suffixb]); + for ($n = 0; $n <= $common_psuffix_len; $n++) { + if (!($m == 0 && $n == 0)) { + $psa = join(" ", @psuffixa[$n...$#psuffixa]); + $psb = join(" ", @psuffixb[$n...$#psuffixb]); + $more_general_rule = join(",", ($sa, $sb, $psa, $psb)); + if (defined $isrule{$more_general_rule}) { + print $rule . ";" . $more_general_rule . 
"\n"; + } + } + } + } +} + diff --git a/egs/wsj_noisy/s5/local/dict/get_rules.pl b/egs/wsj_noisy/s5/local/dict/get_rules.pl new file mode 100755 index 00000000000..b10eccc9171 --- /dev/null +++ b/egs/wsj_noisy/s5/local/dict/get_rules.pl @@ -0,0 +1,204 @@ +#!/usr/bin/env perl + +# This program creates suggested suffix rules from a dictionary. +# It outputs quadruples of the form: +# suffix,base-suffix,psuffix,base-psuffix +# where "suffix" is the suffix of the letters of a word, "base-suffix" is +# the suffix of the letters of the base-word, "psuffix" is the suffix of the +# pronunciation of the word (a space-separated list of phonemes), and +# "base-psuffix" is the suffix of the pronunciation of the baseword. +# As far as this program is concerned, there is no distinction between +# "word" and "base-word". To simplify things slightly, what it does +# is return all tuples (a,b,c,d) [with a != b] such that there are +# at least $min_suffix_count instances in the dictionary of +# a (word-prefix, pron-prefix) pair where there exists (word,pron) +# pairs of the form +# ( word-prefix . a, pron-prefix . c) +# and +# ( word-prefix . b, pron-prefix . d) +# For example if (a,b,c,d) equals (USLY,US,S L IY0,S) +# then this quadruple will be output as long as there at least +# e.g. 30 instances of prefixes like (FAM, F EY1 M AH0) +# where there exist (word, pron) pairs like: +# FAMOUS, F EY1 M AH0 S +# FAMOUSLY F EY1 M AH0 S L IY0 +# +# There are some modifications to the picture above, for efficiency. +# If $disallow_empty_suffix != 0, this program will not output 4-tuples where +# the first element (the own-word suffix) is empty, as this would cause +# efficiency problems in get_candidate_prons.pl. If +# $ignore_prefix_stress != 0, this program will ignore stress markings +# while evaluating whether prefixes are the same. +# The minimum count for a quadruple to be output is $min_suffix_count +# (e.g. 30). 
+# +# The function of this program is not to evaluate the accuracy of these rules; +# it is mostly a pruning step, where we suggest rules that have large enough +# counts to be suitable for our later procedure where we evaluate their +# accuracy in predicting prons. + +$disallow_empty_suffix = 1; # Disallow rules where the suffix of the "own-word" is + # empty. This is for efficiency in later stages (e.g. get_candidate_prons.pl). +$min_prefix_len = 3; # this must match with get_candidate_prons.pl +$ignore_prefix_stress = 1; # or 0 to take account of stress in prefix. +$min_suffix_count = 20; + +# Takes in dictionary. + +print STDERR "Reading dict\n"; +while(<>) { + @A = split(" ", $_); + my $word = shift @A; + my $pron = join(" ", @A); + if (!defined $prons{$word}) { + $prons{$word} = $pron; + push @words, $word; + } else { + $prons{$word} = $prons{$word} . ";" . $pron; + } +} + +# Get common suffixes (e.g., count >100). Include empty suffix. + +print STDERR "Getting common suffix counts.\n"; +{ + foreach $word (@words) { + $len = length($word); + for ($x = $min_prefix_len; $x <= $len; $x++) { + $suffix_count{substr($word, $x)}++; + } + } + + foreach $suffix (keys %suffix_count) { + if ($suffix_count{$suffix} >= $min_suffix_count) { + $newsuffix_count{$suffix} = $suffix_count{$suffix}; + } + } + %suffix_count = %newsuffix_count; + undef %newsuffix_count; + + foreach $suffix ( sort { $suffix_count{$b} <=> $suffix_count{$a} } keys %suffix_count ) { + print STDERR "$suffix_count{$suffix} $suffix\n"; + } +} + +print STDERR "Getting common suffix pairs.\n"; + +{ + print STDERR " Getting map from prefix -> suffix-set.\n"; + + # Create map from prefix -> suffix-set. + foreach $word (@words) { + $len = length($word); + for ($x = $min_prefix_len; $x <= $len; $x++) { + $prefix = substr($word, 0, $x); + $suffix = substr($word, $x); + if (defined $suffix_count{$suffix}) { # Suffix is common... 
+ if (!defined $suffixes_of{$prefix}) { + $suffixes_of{$prefix} = [ $suffix ]; # Create a reference to a new array with + # one element. + } else { + push @{$suffixes_of{$prefix}}, $suffix; # Push $suffix onto array that the + # hash member is a reference . + } + } + } + } + my %suffix_set_count; + print STDERR " Getting map from suffix-set -> count.\n"; + while ( my ($key, $value) = each(%suffixes_of) ) { + my @suffixes = sort ( @$value ); + $suffix_set_count{join(";", @suffixes)}++; + } + print STDERR " Getting counts for suffix pairs.\n"; + while ( my ($suffix_set, $count) = each (%suffix_set_count) ) { + my @suffixes = split(";", $suffix_set); + # Consider pairs to be ordered. This is more convenient + # later on. + foreach $suffix_a (@suffixes) { + foreach $suffix_b (@suffixes) { + if ($suffix_a ne $suffix_b) { + $suffix_pair = $suffix_a . "," . $suffix_b; + $suffix_pair_count{$suffix_pair} += $count; + } + } + } + } + + # To save memory, only keep pairs above threshold in the hash. + while ( my ($suffix_pair, $count) = each (%suffix_pair_count) ) { + if ($count >= $min_suffix_count) { + $new_hash{$suffix_pair} = $count; + } + } + %suffix_pair_count = %new_hash; + undef %new_hash; + + # Print out the suffix pairs so the user can see. + foreach $suffix_pair ( + sort { $suffix_pair_count{$b} <=> $suffix_pair_count{$a} } keys %suffix_pair_count ) { + print STDERR "$suffix_pair_count{$suffix_pair} $suffix_pair\n"; + } +} + +print STDERR "Getting common suffix/suffix/psuffix/psuffix quadruples\n"; + +{ + while ( my ($prefix, $suffixes_ref) = each(%suffixes_of) ) { + # Note: suffixes_ref is a reference to an array. We dereference with + # @$suffixes_ref. + # Consider each pair of suffixes (in each order). + foreach my $suffix_a ( @$suffixes_ref ) { + foreach my $suffix_b ( @$suffixes_ref ) { + # could just used "defined" in next line, but this is for clarity. 
+ $suffix_pair = $suffix_a.",".$suffix_b; + if ( $suffix_pair_count{$suffix_pair} >= $min_suffix_count ) { + foreach $pron_a_str (split(";", $prons{$prefix.$suffix_a})) { + @pron_a = split(" ", $pron_a_str); + foreach $pron_b_str (split(";", $prons{$prefix.$suffix_b})) { + @pron_b = split(" ", $pron_b_str); + $len_a = @pron_a; # evaluating array as scalar automatically gives length. + $len_b = @pron_b; + for (my $pos = 0; $pos <= $len_a && $pos <= $len_b; $pos++) { + # $pos is starting-pos of psuffix-pair. + $psuffix_a = join(" ", @pron_a[$pos...$#pron_a]); + $psuffix_b = join(" ", @pron_b[$pos...$#pron_b]); + $quadruple = $suffix_pair . "," . $psuffix_a . "," . $psuffix_b; + $quadruple_count{$quadruple}++; + + my $pron_a_pos = $pron_a[$pos], $pron_b_pos = $pron_b[$pos]; + if ($ignore_prefix_stress) { + $pron_a_pos =~ s/\d//; # e.g convert IH0 to IH. Only affects + $pron_b_pos =~ s/\d//; # whether we exit the loop below. + } + if ($pron_a_pos ne $pron_b_pos) { + # This is important: we don't consider a pron suffix-pair to be + # valid unless the pron prefix is the same. + last; + } + } + } + } + } + } + } + } + # To save memory, only keep pairs above threshold in the hash. + while ( my ($quadruple, $count) = each (%quadruple_count) ) { + if ($count >= $min_suffix_count) { + $new_hash{$quadruple} = $count; + } + } + %quadruple_count = %new_hash; + undef %new_hash; + + # Print out the quadruples for diagnostics. + foreach $quadruple ( + sort { $quadruple_count{$b} <=> $quadruple_count{$a} } keys %quadruple_count ) { + print STDERR "$quadruple_count{$quadruple} $quadruple\n"; + } +} +# Now print out the quadruples; these are the output of this program. 
+foreach $quadruple (keys %quadruple_count) { + print $quadruple."\n"; +} diff --git a/egs/wsj_noisy/s5/local/dict/limit_candidate_prons.pl b/egs/wsj_noisy/s5/local/dict/limit_candidate_prons.pl new file mode 100755 index 00000000000..b01218f6e96 --- /dev/null +++ b/egs/wsj_noisy/s5/local/dict/limit_candidate_prons.pl @@ -0,0 +1,103 @@ +#!/usr/bin/env perl + +# This program enforces the rule that +# if a "more specific" rule applies, we cannot use the more general rule. +# It takes in tuples generated by get_candidate_prons (one per line, separated +# by ";"), of the form: +# word;pron;base-word;base-pron;rule-name;de-stress[;rule-score] +# [note: we mean that the last element, the numeric score of the rule, is optional] +# and it outputs a (generally shorter) list +# of the same form. + + +# For each word: + # For each (base-word,base-pron): + # Eliminate "more-general" rules as follows: + # For each pair of rules applying to this (base-word, base-pron): + # If pair is in more-general hash, disallow more general one. + # Let the output be: for each (base-word, base-pron, rule): + # for (destress-prefix) in [yes, no], do: + # print out the word input, the rule-name, [destressed:yes|no], and the new pron. + + +if (@ARGV != 1 && @ARGV != 2) { + die "Usage: limit_candidate_prons.pl rule_hierarchy [candidate_prons] > limited_candidate_prons"; +} + +$hierarchy = shift @ARGV; +open(H, "<$hierarchy") || die "Opening rule hierarchy $hierarchy"; + +while(<H>) { + chop; + m:.+;.+: || die "Bad rule-hierarchy line $_"; + $hierarchy{$_} = 1; # Format is: if $rule1 is the string form of the more specific rule + # and $rule21 is that string form of the more general rule, then $hierarchy{$rule1.";".$rule2} + # is defined, else undefined. 
+} + + +sub process_word; + +undef $cur_word; +@cur_lines = (); + +while(<>) { + # input, output is: + # word;pron;base-word;base-pron;rule-name;destress;score + chop; + m:^([^;]+);: || die "Unexpected input: $_"; + $word = $1; + if (!defined $cur_word || $word eq $cur_word) { + if (!defined $cur_word) { $cur_word = $word; } + push @cur_lines, $_; + } else { + process_word(@cur_lines); # Process a series of suggested prons + # for a particular word. + $cur_word = $word; + @cur_lines = ( $_ ); + } +} +process_word(@cur_lines); + +sub process_word { + my %pair2rule_list; # hash from $baseword.";".$baseword to ref + # to array of [ line1, line2, ... ]. + my @cur_lines = @_; + foreach my $line (@cur_lines) { + my ($word, $pron, $baseword, $basepron, $rulename, $destress, $rule_score) = split(";", $line); + my $key = $baseword.";".$basepron; + if (defined $pair2rule_list{$key}) { + push @{$pair2rule_list{$key}}, $line; # @{...} derefs the array pointed to + # by the array ref inside {}. + } else { + $pair2rule_list{$key} = [ $line ]; # [ $x ] is new anonymous array with 1 elem ($x) + } + } + while ( my ($key, $value) = each(%pair2rule_list) ) { + my @lines = @$value; # array of lines that are for this (baseword,basepron). + my @stress, @rules; # Arrays of stress markers and rule names, indexed by + # same index that indexes @lines. + for (my $n = 0; $n < @lines; $n++) { + my $line = $lines[$n]; + my ($word, $pron, $baseword, $basepron, $rulename, $destress, $rule_score) = split(";", $line); + $stress[$n] = $destress; + $rules[$n] = $rulename; + } + for (my $m = 0; $m < @lines; $m++) { + my $ok = 1; # if stays 1, this line is OK. + for (my $n = 0; $n < @lines; $n++) { + if ($m != $n && $stress[$m] eq $stress[$n]) { + if (defined $hierarchy{$rules[$n].";".$rules[$m]}) { + # Note: this "hierarchy" variable is defined if $rules[$n] is a more + # specific instances of $rules[$m], thus invalidating $rules[$m]. + $ok = 0; + last; # no point iterating further. 
+ } + } + } + if ($ok != 0) { + print $lines[$m] . "\n"; + } + } + } +} diff --git a/egs/wsj_noisy/s5/local/dict/reverse_candidates.pl b/egs/wsj_noisy/s5/local/dict/reverse_candidates.pl new file mode 100755 index 00000000000..5b7aabd8abd --- /dev/null +++ b/egs/wsj_noisy/s5/local/dict/reverse_candidates.pl @@ -0,0 +1,50 @@ +#!/usr/bin/env perl + +# This takes the output of e.g. get_candidate_prons.pl or limit_candidate_prons.pl, +# which is 7-tuples, one per line, of the form: + +# word;pron;base-word;base-pron;rule-name;de-stress;rule-score +# (where rule-score is somtimes listed as optional, but this +# program does expect it, since we don't anticipate it being used +# without it). +# This program assumes that all the words and prons and rules have +# come from a reversed dictionary (reverse_dict.pl) where the order +# of the characters in the words, and the phones in the prons, have +# been reversed, and it un-reverses them. That it, the characters +# in "word" and "base-word", and the phones in "pron" and "base-pron" +# are reversed; and the rule ("rule-name") is parsed as a 4-tuple, +# like: +# suffix,base-suffix,psuffix,base-psuffix +# so this program reverses the characters in "suffix" and "base-suffix" +# and the phones (separated by spaces) in "psuffix" and "base-psuffix". + +sub reverse_str { + $str = shift; + return join("", reverse(split("", $str))); +} +sub reverse_pron { + $str = shift; + return join(" ", reverse(split(" ", $str))); +} + +while(<>){ + chop; + @A = split(";", $_); + @A == 7 || die "Bad input line $_: found $len fields, expected 7."; + + ($word,$pron,$baseword,$basepron,$rule,$destress,$score) = @A; + $word = reverse_str($word); + $pron = reverse_pron($pron); + $baseword = reverse_str($baseword); + $basepron = reverse_pron($basepron); + @R = split(",", $rule, 4); + @R == 4 || die "Bad rule $rule"; + + $R[0] = reverse_str($R[0]); # suffix. + $R[1] = reverse_str($R[1]); # base-suffix. + $R[2] = reverse_pron($R[2]); # pron. 
+ $R[3] = reverse_pron($R[3]); # base-pron. + $rule = join(",", @R); + @A = ($word,$pron,$baseword,$basepron,$rule,$destress,$score); + print join(";", @A) . "\n"; +} diff --git a/egs/wsj_noisy/s5/local/dict/reverse_dict.pl b/egs/wsj_noisy/s5/local/dict/reverse_dict.pl new file mode 100755 index 00000000000..2cd38c54b6a --- /dev/null +++ b/egs/wsj_noisy/s5/local/dict/reverse_dict.pl @@ -0,0 +1,14 @@ +#!/usr/bin/env perl + +# Used in conjunction with get_rules.pl +# example input line: XANTHE Z AE1 N DH +# example output line: EHTNAX DH N AE1 Z + +while(<>){ + @A = split(" ", $_); + $word = shift @A; + $word = join("", reverse(split("", $word))); # Reverse letters of word. + @A = reverse(@A); # Reverse phones in pron. + unshift @A, $word; + print join(" ", @A) . "\n"; +} diff --git a/egs/wsj_noisy/s5/local/dict/score_prons.pl b/egs/wsj_noisy/s5/local/dict/score_prons.pl new file mode 100755 index 00000000000..6aa72e42158 --- /dev/null +++ b/egs/wsj_noisy/s5/local/dict/score_prons.pl @@ -0,0 +1,50 @@ +#!/usr/bin/env perl + +# This program takes candidate prons from "get_candidate_prons.pl" or +# "limit_candidate_prons.pl", and a reference dictionary covering those words, +# and outputs the same format but with scoring information added (so we go from +# 6 to 7 fields). The scoring information says, for each generated pron, +# whether we have a match, a partial match, or no match, to some word in the +# dictionary. A partial match means it's correct except for stress. + +# The input is a 6-tuple on each line, like: +# word;pron;base-word;base-pron;rule-name;de-stress +# +# The output is the same except with one more field, the score, +# which may be "right", "wrong", "partial". 
+ +if (@ARGV != 1 && @ARGV != 2) { + die "Usage: score_prons.pl reference_dict [candidate_prons] > scored_candidate_prons"; +} + +$dict = shift @ARGV; +open(D, "<$dict") || die "Opening dictionary $dict"; + +while(<D>) { # Set up some hashes that tell us when + # a (word,pron) pair is correct (and the same for + # prons with stress information removed). + chop; + @A = split(" ", $_); + $word = shift @A; + $pron = join(" ", @A); + $pron_nostress = $pron; + $pron_nostress =~ s:\d::g; + $word_and_pron{$word.";".$pron} = 1; + $word_and_pron_nostress{$word.";".$pron_nostress} = 1; +} + +while(<>) { + chop; + $line = $_; + my ($word, $pron, $baseword, $basepron, $rulename, $destress) = split(";", $line); + $pron_nostress = $pron; + $pron_nostress =~ s:\d::g; + if (defined $word_and_pron{$word.";".$pron}) { + $score = "right"; + } elsif (defined $word_and_pron_nostress{$word.";".$pron_nostress}) { + $score = "partial"; + } else { + $score = "wrong"; + } + print $line.";".$score."\n"; +} diff --git a/egs/wsj_noisy/s5/local/dict/score_rules.pl b/egs/wsj_noisy/s5/local/dict/score_rules.pl new file mode 100755 index 00000000000..252d94677d5 --- /dev/null +++ b/egs/wsj_noisy/s5/local/dict/score_rules.pl @@ -0,0 +1,52 @@ +#!/usr/bin/env perl + +# This program takes the output of count_rules.pl, which is tuples +# of the form +# +# rule;destress;right-count;partial-count;wrong-count +# +# and outputs lines of the form +# +# rule;de-stress;score +# +# where the score, between 0 and 1 (1 better), is +# equal to: +# +# It forms a score between 0 and 1, of the form: +# ((#correct) + $partial_score * (#partial)) / (#correct + #partial + #wrong + $ballast) +# +# where $partial_score (e.g. 0.8) is the score we assign to a "partial" match, +# and $ballast is a small number, e.g. 1, that is treated like "extra" wrong scores, to penalize +# rules with few observations. 
+# +# It outputs all rules that at are at least the + +$ballast = 1; +$partial_score = 0.8; +$destress_penalty = 1.0e-05; # Give destressed rules a small +# penalty vs. their no-destress counterparts, so if we +# have to choose arbitrarily we won't destress (seems safer)> + +for ($n = 1; $n <= 4; $n++) { + if ($ARGV[0] eq "--ballast") { + shift @ARGV; + $ballast = shift @ARGV; + } + if ($ARGV[0] eq "--partial-score") { + shift @ARGV; + $partial_score = shift @ARGV; + ($partial_score >= 0.0 && $partial_score <= 1.0) || die "Invalid partial_score: $partial_score"; + } +} + +(@ARGV == 0 || @ARGV == 1) || die "Usage: score_rules.pl [--ballast ballast-count] [--partial-score partial-score] [input from count_rules.pl]"; + +while(<>) { + @A = split(";", $_); + @A == 5 || die "Bad input line; $_"; + ($rule,$destress,$right_count,$partial_count,$wrong_count) = @A; + $rule_score = ($right_count + $partial_score*$partial_count) / + ($right_count+$partial_count+$wrong_count+$ballast); + if ($destress eq "yes") { $rule_score -= $destress_penalty; } + print join(";", $rule, $destress, sprintf("%.5f", $rule_score)) . "\n"; +} diff --git a/egs/wsj_noisy/s5/local/dict/select_candidate_prons.pl b/egs/wsj_noisy/s5/local/dict/select_candidate_prons.pl new file mode 100755 index 00000000000..a24ccdd4de8 --- /dev/null +++ b/egs/wsj_noisy/s5/local/dict/select_candidate_prons.pl @@ -0,0 +1,84 @@ +#!/usr/bin/env perl + +# This takes the output of e.g. get_candidate_prons.pl or limit_candidate_prons.pl +# or reverse_candidates.pl, which is 7-tuples, one per line, of the form: +# +# word;pron;base-word;base-pron;rule-name;de-stress;rule-score +# +# and selects the most likely prons for the words based on rule +# score. It outputs in the same format as the input (thus, it is +# similar to limit_candidates.pl in its input and output format, +# except it has a different way of selecting the prons to put out). 
+# +# This script will select the $max_prons best pronunciations for +# each candidate word, subject to the constraint that no pron should +# have a rule score worse than $min_rule_score. +# It first merges the candidates by, if there are multiple candidates +# generating the same pron, selecting the candidate that had the +# best associated score. It then sorts the prons on score and +# selects the n best prons (but doesn't print out candidates with +# score beneath the threshold). + + +$max_prons = 4; +$min_rule_score = 0.35; + + +for ($n = 1; $n <= 3; $n++) { + if ($ARGV[0] eq "--max-prons") { + shift @ARGV; + $max_prons = shift @ARGV; + } + if ($ARGV[0] eq "--min-rule-score") { + shift @ARGV; + $min_rule_score = shift @ARGV; + } +} + +if (@ARGV != 0 && @ARGV != 1) { + die "Usage: select_candidates_prons.pl [candidate_prons] > selected_candidate_prons"; +} + +sub process_word; + +undef $cur_word; +@cur_lines = (); + +while(<>) { + # input, output is: + # word;pron;base-word;base-pron;rule-name;destress;score + chop; + m:^([^;]+);: || die "Unexpected input: $_"; + $word = $1; + if (!defined $cur_word || $word eq $cur_word) { + if (!defined $cur_word) { $cur_word = $word; } + push @cur_lines, $_; + } else { + process_word(@cur_lines); # Process a series of suggested prons + # for a particular word. + $cur_word = $word; + @cur_lines = ( $_ ); + } +} +process_word(@cur_lines); + + +sub process_word { + my %pron2rule_score; # hash from generated pron to rule score for that pron. + my %pron2line; # hash from generated pron to best line for that pron. 
+ my @cur_lines = @_; + foreach my $line (@cur_lines) { + my ($word, $pron, $baseword, $basepron, $rulename, $destress, $rule_score) = split(";", $line); + if (!defined $pron2rule_score{$pron} || + $rule_score > $pron2rule_score{$pron}) { + $pron2rule_score{$pron} = $rule_score; + $pron2line{$pron} = $line; + } + } + my @prons = sort { $pron2rule_score{$b} <=> $pron2rule_score{$a} } keys %pron2rule_score; + for (my $n = 0; $n < @prons && $n < $max_prons && + $pron2rule_score{$prons[$n]} >= $min_rule_score; $n++) { + print $pron2line{$prons[$n]} . "\n"; + } +} + diff --git a/egs/wsj_noisy/s5/local/find_transcripts.pl b/egs/wsj_noisy/s5/local/find_transcripts.pl new file mode 100755 index 00000000000..6429411b864 --- /dev/null +++ b/egs/wsj_noisy/s5/local/find_transcripts.pl @@ -0,0 +1,64 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + + +# This program takes on its standard input a list of utterance +# id's, one for each line. (e.g. 4k0c030a is a an utterance id). +# It takes as +# Extracts from the dot files the transcripts for a given +# dataset (represented by a file list). 
+# + +@ARGV == 1 || die "find_transcripts.pl dot_files_flist < utterance_ids > transcripts"; +$dot_flist = shift @ARGV; + +open(L, "<$dot_flist") || die "Opening file list of dot files: $dot_flist\n"; +while(){ + chop; + m:\S+/(\w{6})00.dot: || die "Bad line in dot file list: $_"; + $spk = $1; + $spk2dot{$spk} = $_; +} + + + +while(){ + chop; + $uttid = $_; + $uttid =~ m:(\w{6})\w\w: || die "Bad utterance id $_"; + $spk = $1; + if($spk ne $curspk) { + %utt2trans = { }; # Don't keep all the transcripts in memory... + $curspk = $spk; + $dotfile = $spk2dot{$spk}; + defined $dotfile || die "No dot file for speaker $spk\n"; + open(F, "<$dotfile") || die "Error opening dot file $dotfile\n"; + while() { + $_ =~ m:(.+)\((\w{8})\)\s*$: || die "Bad line $_ in dot file $dotfile (line $.)\n"; + $trans = $1; + $utt = $2; + $utt2trans{$utt} = $trans; + } + } + if(!defined $utt2trans{$uttid}) { + print STDERR "No transcript for utterance $uttid (current dot file is $dotfile)\n"; + } else { + print "$uttid $utt2trans{$uttid}\n"; + } +} + + diff --git a/egs/wsj_noisy/s5/local/flist2scp.pl b/egs/wsj_noisy/s5/local/flist2scp.pl new file mode 100755 index 00000000000..234e4add1ed --- /dev/null +++ b/egs/wsj_noisy/s5/local/flist2scp.pl @@ -0,0 +1,31 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. 
+ + +# takes in a file list with lines like +# /mnt/matylda2/data/WSJ1/13-16.1/wsj1/si_dt_20/4k0/4k0c030a.wv1 +# and outputs an scp in kaldi format with lines like +# 4k0c030a /mnt/matylda2/data/WSJ1/13-16.1/wsj1/si_dt_20/4k0/4k0c030a.wv1 +# (the first thing is the utterance-id, which is the same as the basename of the file. + + +while(<>){ + m:^\S+/(\w+)\.[wW][vV]1$: || die "Bad line $_"; + $id = $1; + $id =~ tr/A-Z/a-z/; # Necessary because of weirdness on disk 13-16.1 (uppercase filenames) + print "$id $_"; +} + diff --git a/egs/wsj_noisy/s5/local/generate_example_kws.sh b/egs/wsj_noisy/s5/local/generate_example_kws.sh new file mode 100755 index 00000000000..2c849438192 --- /dev/null +++ b/egs/wsj_noisy/s5/local/generate_example_kws.sh @@ -0,0 +1,110 @@ +#!/bin/bash + +# Copyright 2012 Johns Hopkins University (Author: Guoguo Chen) +# Apache 2.0. + + +if [ $# -ne 2 ]; then + echo "Usage: local/generate_example_kws.sh " + echo " e.g.: local/generate_example_kws.sh data/test_eval92/ " + exit 1; +fi + +datadir=$1; +kwsdatadir=$2; +text=$datadir/text; + +mkdir -p $kwsdatadir; + +# Generate keywords; we generate 20 unigram keywords with at least 20 counts, +# 20 bigram keywords with at least 10 counts and 10 trigram keywords with at +# least 5 counts. +cat $text | perl -e ' + %unigram = (); + %bigram = (); + %trigram = (); + while(<>) { + chomp; + @col=split(" ", $_); + shift @col; + for($i = 0; $i < @col; $i++) { + # unigram case + if (!defined($unigram{$col[$i]})) { + $unigram{$col[$i]} = 0; + } + $unigram{$col[$i]}++; + + # bigram case + if ($i < @col-1) { + $word = $col[$i] . " " . $col[$i+1]; + if (!defined($bigram{$word})) { + $bigram{$word} = 0; + } + $bigram{$word}++; + } + + # trigram case + if ($i < @col-2) { + $word = $col[$i] . " " . $col[$i+1] . " " . 
$col[$i+2]; + if (!defined($trigram{$word})) { + $trigram{$word} = 0; + } + $trigram{$word}++; + } + } + } + + $max_count = 100; + $total = 20; + $current = 0; + $min_count = 20; + while ($current < $total && $min_count <= $max_count) { + foreach $x (keys %unigram) { + if ($unigram{$x} == $min_count) { + print "$x\n"; + $unigram{$x} = 0; + $current++; + } + if ($current == $total) { + last; + } + } + $min_count++; + } + + $total = 20; + $current = 0; + $min_count = 4; + while ($current < $total && $min_count <= $max_count) { + foreach $x (keys %bigram) { + if ($bigram{$x} == $min_count) { + print "$x\n"; + $bigram{$x} = 0; + $current++; + } + if ($current == $total) { + last; + } + } + $min_count++; + } + + $total = 10; + $current = 0; + $min_count = 3; + while ($current < $total && $min_count <= $max_count) { + foreach $x (keys %trigram) { + if ($trigram{$x} == $min_count) { + print "$x\n"; + $trigram{$x} = 0; + $current++; + } + if ($current == $total) { + last; + } + } + $min_count++; + } + ' > $kwsdatadir/raw_keywords.txt + +echo "Keywords generation succeeded" diff --git a/egs/wsj_noisy/s5/local/kws_data_prep.sh b/egs/wsj_noisy/s5/local/kws_data_prep.sh new file mode 100755 index 00000000000..5222a88c9ef --- /dev/null +++ b/egs/wsj_noisy/s5/local/kws_data_prep.sh @@ -0,0 +1,60 @@ +#!/bin/bash + +# Copyright 2012 Johns Hopkins University (Author: Guoguo Chen) +# Apache 2.0. + + +if [ $# -ne 3 ]; then + echo "Usage: local/kws_data_prep.sh " + echo " e.g.: local/kws_data_prep.sh data/lang_test_bd_tgpr/ data/test_eval92/ data/kws/" + exit 1; +fi + +langdir=$1; +datadir=$2; +kwsdatadir=$3; + +mkdir -p $kwsdatadir; + +# Create keyword id for each keyword +cat $kwsdatadir/raw_keywords.txt | perl -e ' + $idx=1; + while(<>) { + chomp; + printf "WSJ-%04d $_\n", $idx; + $idx++; + }' > $kwsdatadir/keywords.txt + +# Map the keywords to integers; note that we remove the keywords that +# are not in our $langdir/words.txt, as we won't find them anyway... 
+cat $kwsdatadir/keywords.txt | \ + sym2int.pl --map-oov 0 -f 2- $langdir/words.txt | \ + grep -v " 0 " | grep -v " 0$" > $kwsdatadir/keywords.int + +# Compile keywords into FSTs +transcripts-to-fsts ark:$kwsdatadir/keywords.int ark:$kwsdatadir/keywords.fsts + +# Create utterance id for each utterance; Note that by "utterance" here I mean +# the keys that will appear in the lattice archive. You may have to modify here +cat $datadir/wav.scp | \ + awk '{print $1}' | \ + sort | uniq | perl -e ' + $idx=1; + while(<>) { + chomp; + print "$_ $idx\n"; + $idx++; + }' > $kwsdatadir/utter_id + +# Map utterance to the names that will appear in the rttm file. You have +# to modify the commands below accoring to your rttm file. In the WSJ case +# since each file is an utterance, we assume that the actual file names will +# be the "names" in the rttm, so the utterance names map to themselves. +cat $datadir/wav.scp | \ + awk '{print $1}' | \ + sort | uniq | perl -e ' + while(<>) { + chomp; + print "$_ $_\n"; + }' > $kwsdatadir/utter_map; +echo "Kws data preparation succeeded" diff --git a/egs/wsj_noisy/s5/local/ndx2flist.pl b/egs/wsj_noisy/s5/local/ndx2flist.pl new file mode 100755 index 00000000000..48fc3dec101 --- /dev/null +++ b/egs/wsj_noisy/s5/local/ndx2flist.pl @@ -0,0 +1,62 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. 
+ + +# This program takes as its standard input an .ndx file from the WSJ corpus that looks +# like this: +#;; File: tr_s_wv1.ndx, updated 04/26/94 +#;; +#;; Index for WSJ0 SI-short Sennheiser training data +#;; Data is read WSJ sentences, Sennheiser mic. +#;; Contains 84 speakers X (~100 utts per speaker MIT/SRI and ~50 utts +#;; per speaker TI) = 7236 utts +#;; +#11_1_1:wsj0/si_tr_s/01i/01ic0201.wv1 +#11_1_1:wsj0/si_tr_s/01i/01ic0202.wv1 +#11_1_1:wsj0/si_tr_s/01i/01ic0203.wv1 + +#and as command-line arguments it takes the names of the WSJ disk locations, e.g.: +#/mnt/matylda2/data/WSJ0/11-1.1 /mnt/matylda2/data/WSJ0/11-10.1 ... etc. +# It outputs a list of absolute pathnames (it does this by replacing e.g. 11_1_1 with +# /mnt/matylda2/data/WSJ0/11-1.1. +# It also does a slight fix because one of the WSJ disks (WSJ1/13-16.1) was distributed with +# uppercase rather than lower case filenames. + +foreach $fn (@ARGV) { + $fn =~ m:.+/([0-9\.\-]+)/?$: || die "Bad command-line argument $fn\n"; + $disk_id=$1; + $disk_id =~ tr/-\./__/; # replace - and . with - so 11-10.1 becomes 11_10_1 + $fn =~ s:/$::; # Remove final slash, just in case it is present. + $disk2fn{$disk_id} = $fn; +} + +while(){ + if(m/^;/){ next; } # Comment. Ignore it. + else { + m/^([0-9_]+):\s*(\S+)$/ || die "Could not parse line $_"; + $disk=$1; + if(!defined $disk2fn{$disk}) { + die "Disk id $disk not found"; + } + $filename = $2; # as a subdirectory of the distributed disk. + if($disk eq "13_16_1" && `hostname` =~ m/fit.vutbr.cz/) { + # The disk 13-16.1 has been uppercased for some reason, on the + # BUT system. This is a fix specifically for that case. + $filename =~ tr/a-z/A-Z/; # This disk contains all uppercase filenames. Why? 
+ } + print "$disk2fn{$disk}/$filename\n"; + } +} diff --git a/egs/wsj_noisy/s5/local/nnet/run_dnn.sh b/egs/wsj_noisy/s5/local/nnet/run_dnn.sh new file mode 100755 index 00000000000..edb00c91f18 --- /dev/null +++ b/egs/wsj_noisy/s5/local/nnet/run_dnn.sh @@ -0,0 +1,138 @@ +#!/bin/bash + +# Copyright 2012-2014 Brno University of Technology (Author: Karel Vesely) +# Apache 2.0 + +# This example script trains a DNN on top of fMLLR features. +# The training is done in 3 stages, +# +# 1) RBM pre-training: +# in this unsupervised stage we train stack of RBMs, +# a good starting point for frame cross-entropy trainig. +# 2) frame cross-entropy training: +# the objective is to classify frames to correct pdfs. +# 3) sequence-training optimizing sMBR: +# the objective is to emphasize state-sequences with better +# frame accuracy w.r.t. reference alignment. + +. ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. + ## This relates to the queue. + +. ./path.sh ## Source the tools/utils (import the queue.pl) + +# Config: +gmmdir=exp/tri4b +data_fmllr=data-fmllr-tri4b +stage=0 # resume training with --stage=N +# End of config. +. 
utils/parse_options.sh || exit 1; +# + +if [ $stage -le 0 ]; then + # Store fMLLR features, so we can train on them easily, + # test_dev93 + dir=$data_fmllr/test_dev93 + steps/nnet/make_fmllr_feats.sh --nj 10 --cmd "$train_cmd" \ + --transform-dir $gmmdir/decode_bd_tgpr_dev93 \ + $dir data/test_dev93 $gmmdir $dir/log $dir/data || exit 1 + # test_eval92 + dir=$data_fmllr/test_eval92 + steps/nnet/make_fmllr_feats.sh --nj 8 --cmd "$train_cmd" \ + --transform-dir $gmmdir/decode_bd_tgpr_eval92 \ + $dir data/test_eval92 $gmmdir $dir/log $dir/data || exit 1 + # train + dir=$data_fmllr/train_si284 + steps/nnet/make_fmllr_feats.sh --nj 10 --cmd "$train_cmd" \ + --transform-dir ${gmmdir}_ali_si284 \ + $dir data/train_si284 $gmmdir $dir/log $dir/data || exit 1 + # split the data : 90% train 10% cross-validation (held-out) + utils/subset_data_dir_tr_cv.sh $dir ${dir}_tr90 ${dir}_cv10 || exit 1 +fi + +if [ $stage -le 1 ]; then + # Pre-train DBN, i.e. a stack of RBMs + dir=exp/dnn5b_pretrain-dbn + (tail --pid=$$ -F $dir/log/pretrain_dbn.log 2>/dev/null)& # forward log + $cuda_cmd $dir/log/pretrain_dbn.log \ + steps/nnet/pretrain_dbn.sh --rbm-iter 3 $data_fmllr/train_si284 $dir || exit 1; +fi + +if [ $stage -le 2 ]; then + # Train the DNN optimizing per-frame cross-entropy. 
+ dir=exp/dnn5b_pretrain-dbn_dnn + ali=${gmmdir}_ali_si284 + feature_transform=exp/dnn5b_pretrain-dbn/final.feature_transform + dbn=exp/dnn5b_pretrain-dbn/6.dbn + (tail --pid=$$ -F $dir/log/train_nnet.log 2>/dev/null)& # forward log + # Train + $cuda_cmd $dir/log/train_nnet.log \ + steps/nnet/train.sh --feature-transform $feature_transform --dbn $dbn --hid-layers 0 --learn-rate 0.008 \ + $data_fmllr/train_si284_tr90 $data_fmllr/train_si284_cv10 data/lang $ali $ali $dir || exit 1; + # Decode (reuse HCLG graph) + steps/nnet/decode.sh --nj 10 --cmd "$decode_cmd" --config conf/decode_dnn.config --acwt 0.1 \ + $gmmdir/graph_bd_tgpr $data_fmllr/test_dev93 $dir/decode_bd_tgpr_dev93 || exit 1; + steps/nnet/decode.sh --nj 8 --cmd "$decode_cmd" --config conf/decode_dnn.config --acwt 0.1 \ + $gmmdir/graph_bd_tgpr $data_fmllr/test_eval92 $dir/decode_bd_tgpr_eval92 || exit 1; +fi + + +# Sequence training using sMBR criterion, we do Stochastic-GD +# with per-utterance updates. We use usually good acwt 0.1 +# Lattices are re-generated after 1st epoch, to get faster convergence. 
+dir=exp/dnn5b_pretrain-dbn_dnn_smbr +srcdir=exp/dnn5b_pretrain-dbn_dnn +acwt=0.1 + +if [ $stage -le 3 ]; then + # First we generate lattices and alignments: + steps/nnet/align.sh --nj 100 --cmd "$train_cmd" \ + $data_fmllr/train_si284 data/lang $srcdir ${srcdir}_ali || exit 1; + steps/nnet/make_denlats.sh --nj 100 --cmd "$decode_cmd" --config conf/decode_dnn.config --acwt $acwt \ + $data_fmllr/train_si284 data/lang $srcdir ${srcdir}_denlats || exit 1; +fi + +if [ $stage -le 4 ]; then + # Re-train the DNN by 1 iteration of sMBR + steps/nnet/train_mpe.sh --cmd "$cuda_cmd" --num-iters 1 --acwt $acwt --do-smbr true \ + $data_fmllr/train_si284 data/lang $srcdir ${srcdir}_ali ${srcdir}_denlats $dir || exit 1 + # Decode (reuse HCLG graph) + for ITER in 1; do + steps/nnet/decode.sh --nj 10 --cmd "$decode_cmd" --config conf/decode_dnn.config \ + --nnet $dir/${ITER}.nnet --acwt $acwt \ + $gmmdir/graph_bd_tgpr $data_fmllr/test_dev93 $dir/decode_bd_tgpr_dev93_it${ITER} || exit 1; + steps/nnet/decode.sh --nj 8 --cmd "$decode_cmd" --config conf/decode_dnn.config \ + --nnet $dir/${ITER}.nnet --acwt $acwt \ + $gmmdir/graph_bd_tgpr $data_fmllr/test_eval92 $dir/decode_bd_tgpr_eval92_it${ITER} || exit 1; + done +fi + +# Re-generate lattices, run 4 more sMBR iterations +dir=exp/dnn5b_pretrain-dbn_dnn_smbr_i1lats +srcdir=exp/dnn5b_pretrain-dbn_dnn_smbr +acwt=0.1 + +if [ $stage -le 5 ]; then + # Generate lattices and alignments: + steps/nnet/align.sh --nj 100 --cmd "$train_cmd" \ + $data_fmllr/train_si284 data/lang $srcdir ${srcdir}_ali || exit 1; + steps/nnet/make_denlats.sh --nj 100 --cmd "$decode_cmd" --config conf/decode_dnn.config --acwt $acwt \ + $data_fmllr/train_si284 data/lang $srcdir ${srcdir}_denlats || exit 1; +fi + +if [ $stage -le 6 ]; then + # Re-train the DNN by 1 iteration of sMBR + steps/nnet/train_mpe.sh --cmd "$cuda_cmd" --num-iters 4 --acwt $acwt --do-smbr true \ + $data_fmllr/train_si284 data/lang $srcdir ${srcdir}_ali ${srcdir}_denlats $dir || exit 1 + # Decode 
(reuse HCLG graph) + for ITER in 1 2 3 4; do + steps/nnet/decode.sh --nj 10 --cmd "$decode_cmd" --config conf/decode_dnn.config \ + --nnet $dir/${ITER}.nnet --acwt $acwt \ + $gmmdir/graph_bd_tgpr $data_fmllr/test_dev93 $dir/decode_bd_tgpr_dev93_iter${ITER} || exit 1; + steps/nnet/decode.sh --nj 8 --cmd "$decode_cmd" --config conf/decode_dnn.config \ + --nnet $dir/${ITER}.nnet --acwt $acwt \ + $gmmdir/graph_bd_tgpr $data_fmllr/test_eval92 $dir/decode_bd_tgpr_eval92_iter${ITER} || exit 1; + done +fi + +# Getting results [see RESULTS file] +# for x in exp/*/decode*; do [ -d $x ] && grep WER $x/wer_* | utils/best_wer.sh; done diff --git a/egs/wsj_noisy/s5/local/nnet2/run_5b.sh b/egs/wsj_noisy/s5/local/nnet2/run_5b.sh new file mode 100755 index 00000000000..329e917baa5 --- /dev/null +++ b/egs/wsj_noisy/s5/local/nnet2/run_5b.sh @@ -0,0 +1,70 @@ +#!/bin/bash + + +stage=0 +train_stage=-100 +# This trains only unadapted (just cepstral mean normalized) features, +# and uses various combinations of VTLN warping factor and time-warping +# factor to artificially expand the amount of data. + +. cmd.sh + +. utils/parse_options.sh # to parse the --stage option, if given + +[ $# != 0 ] && echo "Usage: local/run_4b.sh [--stage --train-stage ]" && exit 1; + +set -e + +if [ $stage -le 0 ]; then + # Create the training data. 
+ featdir=`pwd`/mfcc/nnet5b; mkdir -p $featdir + fbank_conf=conf/fbank_40.conf + echo "--num-mel-bins=40" > $fbank_conf + steps/nnet2/get_perturbed_feats.sh --cmd "$train_cmd" \ + $fbank_conf $featdir exp/perturbed_fbanks_si284 data/train_si284 data/train_si284_perturbed_fbank & + steps/nnet2/get_perturbed_feats.sh --cmd "$train_cmd" --feature-type mfcc \ + conf/mfcc.conf $featdir exp/perturbed_mfcc_si284 data/train_si284 data/train_si284_perturbed_mfcc & + wait +fi + +if [ $stage -le 1 ]; then + steps/align_fmllr.sh --nj 30 --cmd "$train_cmd" \ + data/train_si284_perturbed_mfcc data/lang exp/tri4b exp/tri4b_ali_si284_perturbed_mfcc +fi + +if [ $stage -le 2 ]; then + steps/nnet2/train_block.sh --stage "$train_stage" \ + --cleanup false \ + --initial-learning-rate 0.01 --final-learning-rate 0.001 \ + --num-epochs 10 --num-epochs-extra 5 \ + --cmd "$decode_cmd" \ + --hidden-layer-dim 1536 \ + --num-block-layers 3 --num-normal-layers 3 \ + data/train_si284_perturbed_fbank data/lang exp/tri4b_ali_si284_perturbed_mfcc exp/nnet5b || exit 1 +fi + +if [ $stage -le 3 ]; then # create testing fbank data. 
+ featdir=`pwd`/mfcc + fbank_conf=conf/fbank_40.conf + for x in test_eval92 test_eval93 test_dev93; do + rm -r data/${x}_fbank + cp -r data/$x data/${x}_fbank + rm -r ${x}_fbank/split* || true + steps/make_fbank.sh --fbank-config "$fbank_conf" --nj 8 \ + --cmd "$train_cmd" data/${x}_fbank exp/make_fbank/$x $featdir || exit 1; + steps/compute_cmvn_stats.sh data/${x}_fbank exp/make_fbank/$x $featdir || exit 1; + done +fi + +if [ $stage -le 4 ]; then + steps/nnet2/decode.sh --cmd "$decode_cmd" --nj 10 \ + exp/tri4b/graph_bd_tgpr data/test_dev93_fbank exp/nnet5b/decode_bd_tgpr_dev93 + + steps/nnet2/decode.sh --cmd "$decode_cmd" --nj 8 \ + exp/tri4b/graph_bd_tgpr data/test_eval92_fbank exp/nnet5b/decode_bd_tgpr_eval92 +fi + + + +exit 0; + diff --git a/egs/wsj_noisy/s5/local/nnet2/run_5b_gpu.sh b/egs/wsj_noisy/s5/local/nnet2/run_5b_gpu.sh new file mode 100755 index 00000000000..2dc5afa0e87 --- /dev/null +++ b/egs/wsj_noisy/s5/local/nnet2/run_5b_gpu.sh @@ -0,0 +1,101 @@ +#!/bin/bash + + +stage=0 +train_stage=-100 +temp_dir= +# This trains only unadapted (just cepstral mean normalized) features, +# and uses various combinations of VTLN warping factor and time-warping +# factor to artificially expand the amount of data. + + + + +. ./cmd.sh +. ./path.sh +! cuda-compiled && cat < --train-stage ]" + echo "Options: " + echo " --stage # controls partial re-runs" + echo " --train-stage # use with --stage 2 to control partial rerun of training" + echo " --temp-dir # e.g. --temp-dir /export/my-machine/dpovey/wsj-temp-5b" + echo " # (puts temporary data including MFCC and egs at this location)" + exit 1; +fi + +set -e + +if [ $stage -le 0 ]; then + # Create the training data. + + if [ ! 
-z "$temp_dir" ]; then + mkdir -p $temp_dir/mfcc_5b_gpu + featdir=$temp_dir/mfcc_5b_gpu + else + featdir=`pwd`/mfcc/nnet5b_gpu; + mkdir -p $featdir + fi + fbank_conf=conf/fbank_40.conf + echo "--num-mel-bins=40" > $fbank_conf + steps/nnet2/get_perturbed_feats.sh --cmd "$train_cmd" \ + $fbank_conf $featdir exp/perturbed_fbanks_si284 data/train_si284 data/train_si284_perturbed_fbank & + steps/nnet2/get_perturbed_feats.sh --cmd "$train_cmd" --feature-type mfcc \ + conf/mfcc.conf $featdir exp/perturbed_mfcc_si284 data/train_si284 data/train_si284_perturbed_mfcc & + wait +fi + +if [ $stage -le 1 ]; then + steps/align_fmllr.sh --nj 30 --cmd "$train_cmd" \ + data/train_si284_perturbed_mfcc data/lang exp/tri4b exp/tri4b_ali_si284_perturbed_mfcc +fi + +if [ $stage -le 2 ]; then + if [ ! -z "$temp_dir" ] && [ ! -e exp/nnet5b_gpu/egs ]; then + mkdir -p exp/nnet5b_gpu + mkdir -p $temp_dir/nnet5b_gpu/egs + ln -s $temp_dir/nnet5b_gpu/egs exp/nnet5b_gpu/ + fi + + steps/nnet2/train_block.sh --stage "$train_stage" \ + --num-threads 1 --max-change 40.0 --minibatch-size 512 --num-jobs-nnet 8 \ + --parallel-opts "-l gpu=1" \ + --initial-learning-rate 0.0075 --final-learning-rate 0.00075 \ + --num-epochs 10 --num-epochs-extra 5 \ + --cmd "$decode_cmd" \ + --hidden-layer-dim 1536 \ + --num-block-layers 3 --num-normal-layers 3 \ + data/train_si284_perturbed_fbank data/lang exp/tri4b_ali_si284_perturbed_mfcc exp/nnet5b_gpu || exit 1 +fi + +if [ $stage -le 3 ]; then # create testing fbank data. 
+ featdir=`pwd`/mfcc + fbank_conf=conf/fbank_40.conf + for x in test_eval92 test_eval93 test_dev93; do + rm -r data/${x}_fbank + cp -r data/$x data/${x}_fbank + rm -r ${x}_fbank/split* || true + steps/make_fbank.sh --fbank-config "$fbank_conf" --nj 8 \ + --cmd "$train_cmd" data/${x}_fbank exp/make_fbank/$x $featdir || exit 1; + steps/compute_cmvn_stats.sh data/${x}_fbank exp/make_fbank/$x $featdir || exit 1; + done +fi + +if [ $stage -le 4 ]; then + steps/nnet2/decode.sh --cmd "$decode_cmd" --nj 10 \ + exp/tri4b/graph_bd_tgpr data/test_dev93_fbank exp/nnet5b_gpu/decode_bd_tgpr_dev93 + + steps/nnet2/decode.sh --cmd "$decode_cmd" --nj 8 \ + exp/tri4b/graph_bd_tgpr data/test_eval92_fbank exp/nnet5b_gpu/decode_bd_tgpr_eval92 +fi + + + +exit 0; + diff --git a/egs/wsj_noisy/s5/local/nnet2/run_5c.sh b/egs/wsj_noisy/s5/local/nnet2/run_5c.sh new file mode 100755 index 00000000000..e33546572ad --- /dev/null +++ b/egs/wsj_noisy/s5/local/nnet2/run_5c.sh @@ -0,0 +1,61 @@ +#!/bin/bash + +# This is neural net training on top of adapted 40-dimensional features. +# + +train_stage=-10 +use_gpu=true + +. cmd.sh +. ./path.sh +. utils/parse_options.sh + + +if $use_gpu; then + if ! cuda-compiled; then + cat </dev/null + for data in test_eval92 test_dev93 test_eval93; do + steps/online/nnet2/extract_ivectors_online.sh --cmd "$train_cmd" --nj 8 \ + data/${data}_hires exp/nnet3/extractor exp/nnet3/ivectors_${data} || touch exp/nnet3/.error & + done + wait + [ -f exp/nnet3/.error ] && echo "$0: error extracting iVectors." && exit 1; +fi + +exit 0; diff --git a/egs/wsj_noisy/s5/local/nnet3/run_tdnn.sh b/egs/wsj_noisy/s5/local/nnet3/run_tdnn.sh new file mode 100755 index 00000000000..71f87f82f24 --- /dev/null +++ b/egs/wsj_noisy/s5/local/nnet3/run_tdnn.sh @@ -0,0 +1,66 @@ +#!/bin/bash + +# this is the standard "tdnn" system, built in nnet3; it's what we use to +# call multi-splice. + +. cmd.sh + + +# At this script level we don't support not running on GPU, as it would be painfully slow. 
+# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +dir=exp/nnet3/nnet_tdnn_a +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat < transcript2"; +$noise_word = shift @ARGV; + +while() { + $_ =~ m:^(\S+) (.+): || die "bad line $_"; + $utt = $1; + $trans = $2; + print "$utt"; + foreach $w (split (" ",$trans)) { + $w =~ tr:a-z:A-Z:; # Upcase everything to match the CMU dictionary. . + $w =~ s:\\::g; # Remove backslashes. We don't need the quoting. + $w =~ s:^\%PERCENT$:PERCENT:; # Normalization for Nov'93 test transcripts. + $w =~ s:^\.POINT$:POINT:; # Normalization for Nov'93 test transcripts. + if($w =~ m:^\[\<\w+\]$: || # E.g. [\]$: || # E.g. [door_slam>], this means a door slammed in the next word. Delete. + $w =~ m:\[\w+/\]$: || # E.g. [phone_ring/], which indicates the start of this phenomenon. + $w =~ m:\[\/\w+]$: || # E.g. [/phone_ring], which indicates the end of this phenomenon. + $w eq "~" || # This is used to indicate truncation of an utterance. Not a word. + $w eq ".") { # "." is used to indicate a pause. Silence is optional anyway so not much + # point including this in the transcript. + next; # we won't print this word. + } elsif($w =~ m:\[\w+\]:) { # Other noises, e.g. [loud_breath]. + print " $noise_word"; + } elsif($w =~ m:^\<([\w\']+)\>$:) { + # e.g. replace with and. (the <> means verbal deletion of a word).. but it's pronounced. + print " $1"; + } elsif($w eq "--DASH") { + print " -DASH"; # This is a common issue; the CMU dictionary has it as -DASH. +# } elsif($w =~ m:(.+)\-DASH$:) { # E.g. INCORPORATED-DASH... 
seems the DASH gets combined with previous word +# print " $1 -DASH"; + } else { + print " $w"; + } + } + print "\n"; +} diff --git a/egs/wsj_noisy/s5/local/online/run_nnet2.sh b/egs/wsj_noisy/s5/local/online/run_nnet2.sh new file mode 100755 index 00000000000..eac76b621ef --- /dev/null +++ b/egs/wsj_noisy/s5/local/online/run_nnet2.sh @@ -0,0 +1,220 @@ +#!/bin/bash + +# this is our online-nnet2 build. it's a "multi-splice" system (i.e. we have +# splicing at various layers), with p-norm nonlinearities. We use the "accel2" +# script which uses between 2 and 14 GPUs depending how far through training it +# is. You can safely reduce the --num-jobs-final to however many GPUs you have +# on your system. + +# For joint training with RM, this script is run using the following command line, +# and note that the --stage 8 option is only needed in case you already ran the +# earlier stages. +# local/online/run_nnet2.sh --stage 8 --dir exp/nnet2_online/nnet_ms_a_partial --exit-train-stage 15 + +. cmd.sh + + +stage=0 +train_stage=-10 +use_gpu=true +dir=exp/nnet2_online/nnet_ms_a +exit_train_stage=-100 +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if $use_gpu; then + if ! cuda-compiled; then + cat </dev/null + for year in eval92 dev93; do + steps/online/nnet2/extract_ivectors.sh --cmd "$train_cmd" --nj 8 \ + --sub-speaker-frames 1500 \ + data/test_${year}_hires data/lang exp/nnet2_online/extractor \ + exp/tri4b/decode_tgpr_$year exp/nnet2_online/ivectors_spk_test_${year} || touch exp/nnet2_online/.error & + done + wait + [ -f exp/nnet2_online/.error ] && echo "$0: Error getting iVectors" && exit 1; + + for lm_suffix in bd_tgpr; do # just use the bd decoding, to avoid wasting time. + graph_dir=exp/tri4b/graph_${lm_suffix} + # use already-built graphs. 
+ for year in eval92 dev93; do + steps/nnet2/decode.sh --nj 8 --cmd "$decode_cmd" \ + --online-ivector-dir exp/nnet2_online/ivectors_spk_test_$year \ + $graph_dir data/test_${year}_hires $dir/decode_${lm_suffix}_${year}_spk || touch exp/nnet2_online/.error & + done + done + wait + [ -f exp/nnet2_online/.error ] && echo "$0: Error decoding" && exit 1; +fi + + + + +exit 0; + +# Here are results. + +# first, this is the baseline. We choose as a baseline our best fMLLR+p-norm system trained +# on si284, so this is a very good baseline. For others you can see ../RESULTS. + + +# %WER 7.13 [ 587 / 8234, 72 ins, 93 del, 422 sub ] exp/nnet5d_gpu/decode_bd_tgpr_dev93/wer_13 +# %WER 4.06 [ 229 / 5643, 31 ins, 16 del, 182 sub ] exp/nnet5d_gpu/decode_bd_tgpr_eval92/wer_14 +# %WER 9.35 [ 770 / 8234, 161 ins, 78 del, 531 sub ] exp/nnet5d_gpu/decode_tgpr_dev93/wer_12 +# %WER 6.59 [ 372 / 5643, 91 ins, 15 del, 266 sub ] exp/nnet5d_gpu/decode_tgpr_eval92/wer_12 + + +# Here is the offline decoding of our system (note: it still has the iVectors estimated frame +# by frame, and for each utterance independently). + +for x in exp/nnet2_online/nnet_a_gpu/decode_*; do grep WER $x/wer_* | utils/best_wer.sh; done | grep -v utt +%WER 7.53 [ 620 / 8234, 63 ins, 105 del, 452 sub ] exp/nnet2_online/nnet_a_gpu/decode_bd_tgpr_dev93/wer_12 +%WER 4.47 [ 252 / 5643, 27 ins, 22 del, 203 sub ] exp/nnet2_online/nnet_a_gpu/decode_bd_tgpr_eval92/wer_13 +%WER 9.91 [ 816 / 8234, 164 ins, 90 del, 562 sub ] exp/nnet2_online/nnet_a_gpu/decode_tgpr_dev93/wer_12 +%WER 7.12 [ 402 / 5643, 91 ins, 22 del, 289 sub ] exp/nnet2_online/nnet_a_gpu/decode_tgpr_eval92/wer_13 + + # Here is the version of the above without iVectors, as done by + # ./run_nnet2_baseline.sh. It's about 0.5% absolute worse. + # There is also an _online version of that decode directory, which is + # essentially the same (we don't show the results here, as it's not really interesting). 
+ for x in exp/nnet2_online/nnet_a_gpu_baseline/decode_*; do grep WER $x/wer_* | utils/best_wer.sh; done + %WER 8.03 [ 661 / 8234, 80 ins, 105 del, 476 sub ] exp/nnet2_online/nnet_a_gpu_baseline/decode_bd_tgpr_dev93/wer_11 + %WER 5.10 [ 288 / 5643, 43 ins, 22 del, 223 sub ] exp/nnet2_online/nnet_a_gpu_baseline/decode_bd_tgpr_eval92/wer_11 + %WER 10.51 [ 865 / 8234, 177 ins, 95 del, 593 sub ] exp/nnet2_online/nnet_a_gpu_baseline/decode_tgpr_dev93/wer_11 + %WER 7.34 [ 414 / 5643, 88 ins, 25 del, 301 sub ] exp/nnet2_online/nnet_a_gpu_baseline/decode_tgpr_eval92/wer_13 + +# Next, truly-online decoding. +# The results below are not quite as good as those in nnet_a_gpu, but I believe +# the difference is that in this setup we're not using config files, and the +# default beams/lattice-beams in the scripts are slightly different: 15.0/8.0 +# above, and 13.0/6.0 below. +for x in exp/nnet2_online/nnet_a_gpu_online/decode_*; do grep WER $x/wer_* | utils/best_wer.sh; done | grep -v utt +%WER 7.53 [ 620 / 8234, 74 ins, 97 del, 449 sub ] exp/nnet2_online/nnet_a_gpu_online/decode_bd_tgpr_dev93/wer_11 +%WER 4.45 [ 251 / 5643, 35 ins, 19 del, 197 sub ] exp/nnet2_online/nnet_a_gpu_online/decode_bd_tgpr_eval92/wer_12 +%WER 10.02 [ 825 / 8234, 166 ins, 88 del, 571 sub ] exp/nnet2_online/nnet_a_gpu_online/decode_tgpr_dev93/wer_12 +%WER 6.91 [ 390 / 5643, 103 ins, 15 del, 272 sub ] exp/nnet2_online/nnet_a_gpu_online/decode_tgpr_eval92/wer_10 + +# Below is as above, but decoding each utterance separately. It actually seems slightly better, +# which is counterintuitive. 
+for x in exp/nnet2_online/nnet_a_gpu_online/decode_*; do grep WER $x/wer_* | utils/best_wer.sh; done | grep utt +%WER 7.55 [ 622 / 8234, 57 ins, 109 del, 456 sub ] exp/nnet2_online/nnet_a_gpu_online/decode_bd_tgpr_dev93_utt/wer_13 +%WER 4.43 [ 250 / 5643, 27 ins, 21 del, 202 sub ] exp/nnet2_online/nnet_a_gpu_online/decode_bd_tgpr_eval92_utt/wer_13 +%WER 9.98 [ 822 / 8234, 179 ins, 80 del, 563 sub ] exp/nnet2_online/nnet_a_gpu_online/decode_tgpr_dev93_utt/wer_11 +%WER 7.12 [ 402 / 5643, 98 ins, 18 del, 286 sub ] exp/nnet2_online/nnet_a_gpu_online/decode_tgpr_eval92_utt/wer_12 diff --git a/egs/wsj_noisy/s5/local/online/run_nnet2_baseline.sh b/egs/wsj_noisy/s5/local/online/run_nnet2_baseline.sh new file mode 100755 index 00000000000..17d0face7e3 --- /dev/null +++ b/egs/wsj_noisy/s5/local/online/run_nnet2_baseline.sh @@ -0,0 +1,86 @@ +#!/bin/bash + +. cmd.sh + + +stage=1 +train_stage=-10 +use_gpu=true +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if $use_gpu; then + if ! cuda-compiled; then + cat </dev/null + for data in test_eval92 test_dev93 test_eval93; do + steps/online/nnet2/extract_ivectors_online.sh --cmd "$train_cmd" --nj 8 \ + data/${data}_hires exp/nnet2_online/extractor exp/nnet2_online/ivectors_${data} || touch exp/nnet2_online/.error & + done + wait + [ -f exp/nnet2_online/.error ] && echo "$0: error extracting iVectors." && exit 1; +fi + +exit 0; diff --git a/egs/wsj_noisy/s5/local/online/run_nnet2_discriminative.sh b/egs/wsj_noisy/s5/local/online/run_nnet2_discriminative.sh new file mode 100755 index 00000000000..a92e9c3367b --- /dev/null +++ b/egs/wsj_noisy/s5/local/online/run_nnet2_discriminative.sh @@ -0,0 +1,95 @@ +#!/bin/bash + + +# This is discriminative training, to be run after run_nnet2.sh. + +. cmd.sh + + +stage=1 +train_stage=-10 +use_gpu=true +srcdir=exp/nnet2_online/nnet_ms_a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if $use_gpu; then + if ! 
cuda-compiled; then + cat </dev/null || true + + for epoch in 1 2 3 4; do + # do the actual online decoding with iVectors, carrying info forward from + # previous utterances of the same speaker. + # We just do the bd_tgpr decodes; otherwise the number of combinations + # starts to get very large. + for lm_suffix in bd_tgpr; do + graph_dir=exp/tri4b/graph_${lm_suffix} + for year in eval92 dev93; do + steps/online/nnet2/decode.sh --cmd "$decode_cmd" --nj 8 --iter smbr_epoch${epoch} \ + "$graph_dir" data/test_${year} ${srcdir}_online/decode_${lm_suffix}_${year}_smbr_epoch${epoch} || touch $error_file & + done + done + done + wait + [ -f $error_file ] && echo "$0: error decoding the SMBR systems." && exit 1; +fi + diff --git a/egs/wsj_noisy/s5/local/online/run_nnet2_perturb_speed.sh b/egs/wsj_noisy/s5/local/online/run_nnet2_perturb_speed.sh new file mode 100755 index 00000000000..1a69e50f3ea --- /dev/null +++ b/egs/wsj_noisy/s5/local/online/run_nnet2_perturb_speed.sh @@ -0,0 +1,143 @@ +#!/bin/bash +# Copyright 2013 Johns Hopkins University (author: Daniel Povey) +# 2014 Tom Ko +# Apache 2.0 + +# This example script demonstrates how speed perturbation of the data helps the nnet training. + +. ./cmd.sh +. ./path.sh + +stage=-1 +train_stage=-10 +use_gpu=true +nnet_dir=exp/nnet2_online_perturb + +if $use_gpu; then + if ! cuda-compiled; then + cat < data/$y/utt2spk; + cp data/$y/utt2spk data/$y/spk2utt; + steps/compute_cmvn_stats.sh data/$y exp/make_mfcc/$y $mfccdir || exit 1; +done + + + # basis fMLLR experiments. + # First a baseline: decode per-utterance with normal fMLLR. +steps/decode_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri3b/graph${lang_suffix}_tgpr data/test_dev93_utt \ + exp/tri3b/decode${lang_suffix}_tgpr_dev93_utt || exit 1; +steps/decode_fmllr.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri3b/graph${lang_suffix}_tgpr data/test_eval92_utt \ + exp/tri3b/decode${lang_suffix}_tgpr_eval92_utt || exit 1; + + # get the fMLLR basis. 
+steps/get_fmllr_basis.sh --cmd "$train_cmd" \ + data/train_si84 data/lang${lang_suffix} exp/tri3b + + # decoding tri3b with basis fMLLR +steps/decode_basis_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri3b/graph${lang_suffix}_tgpr data/test_dev93 \ + exp/tri3b/decode${lang_suffix}_tgpr_dev93_basis || exit 1; +steps/decode_basis_fmllr.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri3b/graph${lang_suffix}_tgpr data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_tgpr_eval92_basis || exit 1; + + # The same, per-utterance. +steps/decode_basis_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri3b/graph${lang_suffix}_tgpr data/test_dev93_utt \ + exp/tri3b/decode${lang_suffix}_tgpr_dev93_basis_utt || exit 1; +steps/decode_basis_fmllr.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri3b/graph${lang_suffix}_tgpr data/test_eval92_utt \ + exp/tri3b/decode${lang_suffix}_tgpr_eval92_basis_utt || exit 1; + + diff --git a/egs/wsj_noisy/s5/local/run_bnf.sh b/egs/wsj_noisy/s5/local/run_bnf.sh new file mode 100644 index 00000000000..7c37bf72e7d --- /dev/null +++ b/egs/wsj_noisy/s5/local/run_bnf.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# Note: In order to run BNF, run run_bnf.sh +. ./path.sh +. ./cmd.sh + +set -e +set -o pipefail +set -u + +. utils/parse_options.sh + +bnf_train_stage=-100 +align_dir=exp/tri4b_ali_si284 +if [ ! 
-f exp_bnf/tri6_bnf/.done ]; then + mkdir -p exp_bnf + mkdir -p exp_bnf/tri6_bnf + echo --------------------------------------------------------------------- + echo "Starting training the bottleneck network" + echo --------------------------------------------------------------------- + steps/nnet2/train_tanh_bottleneck.sh \ + --stage $bnf_train_stage --num-jobs-nnet 4 \ + --num-threads 1 --mix-up 5000 --max-change 40 \ + --minibatch-size 512 \ + --initial-learning-rate 0.005 \ + --final-learning-rate 0.0005 \ + --num-hidden-layers 5 \ + --bottleneck-dim 42 --hidden-layer-dim 1024 --cmd "$train_cmd" \ + data/train_si284 data/lang $align_dir exp_bnf/tri6_bnf || exit 1 + touch exp_bnf/tri6_bnf/.done +fi + +[ ! -d param_bnf ] && mkdir -p param_bnf +if [ ! -f data_bnf/train_bnf/.done ]; then + mkdir -p data_bnf + # put the archives in param_bnf/. + steps/nnet2/dump_bottleneck_features.sh --cmd "$train_cmd" \ + --transform-dir exp/tri4b_ali_si284 data/train_si284 data_bnf/train_bnf exp_bnf/tri6_bnf param_bnf exp_bnf/dump_bnf + touch data_bnf/train_bnf/.done +fi + +[ ! -d data/test_eval92 ] && echo "No such directory data/test_eval92" && exit 1; +[ ! -d data/test_dev93 ] && echo "No such directory data/test_dev93" && exit 1; +[ ! -d exp/tri4b/decode_bd_tgpr_eval92 ] && echo "No such directory exp/tri4b/decode_bd_tgpr_eval92" && exit 1; +[ ! -d exp/tri4b/decode_bd_tgpr_dev93 ] && echo "No such directory exp/tri4b/decode_bd_tgpr_dev93" && exit 1; +# put the archives in param_bnf/. +steps/nnet2/dump_bottleneck_features.sh --nj 8 \ + --transform-dir exp/tri4b/decode_bd_tgpr_eval92 data/test_eval92 data_bnf/eval92_bnf exp_bnf/tri6_bnf param_bnf exp_bnf/dump_bnf + +steps/nnet2/dump_bottleneck_features.sh --nj 10 \ + --transform-dir exp/tri4b/decode_bd_tgpr_dev93 data/test_dev93 data_bnf/dev93_bnf exp_bnf/tri6_bnf param_bnf exp_bnf/dump_bnf + + + +if [ ! 
data_bnf/train/.done -nt data_bnf/train_bnf/.done ]; then + steps/nnet/make_fmllr_feats.sh --cmd "$train_cmd -tc 10" \ + --transform-dir $align_dir data_bnf/train_sat data/train_si284 \ + exp/tri4b exp_bnf/make_fmllr_feats/log param_bnf/ + + steps/append_feats.sh --cmd "$train_cmd" --nj 4 \ + data_bnf/train_bnf data_bnf/train_sat data_bnf/train \ + exp_bnf/append_feats/log param_bnf/ + steps/compute_cmvn_stats.sh --fake data_bnf/train exp_bnf/make_fmllr_feats param_bnf + rm -r data_bnf/train_sat + + touch data_bnf/train/.done +fi +## preparing Bottleneck features for eval92 and dev93 +steps/nnet/make_fmllr_feats.sh \ + --nj 8 --transform-dir exp/tri4b/decode_bd_tgpr_eval92 data_bnf/eval92_sat data/test_eval92 \ + exp/tri4b_ali_si284 exp_bnf/make_fmllr_feats/log param_bnf/ +steps/nnet/make_fmllr_feats.sh \ + --nj 10 --transform-dir exp/tri4b/decode_bd_tgpr_dev93 data_bnf/dev93_sat data/test_dev93 \ + exp/tri4b_ali_si284 exp_bnf/make_fmllr_feats/log param_bnf/ + +steps/append_feats.sh --nj 4 \ + data_bnf/eval92_bnf data_bnf/eval92_sat data_bnf/eval92 \ + exp_bnf/append_feats/log param_bnf/ +steps/append_feats.sh --nj 4 \ + data_bnf/dev93_bnf data_bnf/dev93_sat data_bnf/dev93 \ + exp_bnf/append_feats/log param_bnf/ + +steps/compute_cmvn_stats.sh --fake data_bnf/eval92 exp_bnf/make_fmllr_feats param_bnf +steps/compute_cmvn_stats.sh --fake data_bnf/dev93 exp_bnf/make_fmllr_feats param_bnf + +rm -r data_bnf/eval92_sat +rm -r data_bnf/dev93_sat + +exit 0; diff --git a/egs/wsj_noisy/s5/local/run_bnf_sgmm.sh b/egs/wsj_noisy/s5/local/run_bnf_sgmm.sh new file mode 100644 index 00000000000..6cfe1df67ed --- /dev/null +++ b/egs/wsj_noisy/s5/local/run_bnf_sgmm.sh @@ -0,0 +1,154 @@ +#!/bin/bash + +# This script builds the SGMM system on top of the kaldi internal bottleneck features. + +. ./cmd.sh + +set -e +set -o pipefail +set -u + +# Set my_nj; typically 64. 
+numLeaves=2500 +numGauss=15000 +numLeavesSGMM=10000 +bnf_num_gauss_ubm=600 +bnf_num_gauss_sgmm=7000 +align_dir=exp/tri4b_ali_si284 +bnf_decode_acwt=0.0357 +sgmm_group_extra_opts=(--group 3 --cmd "queue.pl -l arch=*64 --mem 7G") + +if [ ! -d exp_bnf ]; then + echo "$0: before running this script, please run local/run_bnf.sh" + exit 1; +fi + +echo --------------------------------------------------------------------- +echo "Starting exp_bnf/tri5 on" `date` +echo --------------------------------------------------------------------- +if [ ! exp_bnf/tri5/.done -nt data_bnf/train/.done ]; then + steps/train_lda_mllt.sh --splice-opts "--left-context=1 --right-context=1" \ + --dim 60 --cmd "$train_cmd" \ + $numLeaves $numGauss data_bnf/train data/lang $align_dir exp_bnf/tri5 ; + touch exp_bnf/tri5/.done +fi + +echo --------------------------------------------------------------------- +echo "Starting exp_bnf/tri6 on" `date` +echo --------------------------------------------------------------------- +if [ ! 
exp_bnf/tri6/.done -nt exp_bnf/tri5/.done ]; then + steps/train_sat.sh --cmd "$train_cmd" \ + $numLeaves $numGauss data_bnf/train data/lang exp_bnf/tri5 exp_bnf/tri6 + touch exp_bnf/tri6/.done +fi +echo --------------------------------------------------------------------- +echo "Decoding with SAT models on top of bottleneck features on" `date` +echo --------------------------------------------------------------------- +decode1=exp_bnf/tri6/decode_bd_tgpr_eval92 +decode2=exp_bnf/tri6/decode_bd_tgpr_dev93 +utils/mkgraph.sh \ + data/lang_test_bd_tgpr exp_bnf/tri6 exp_bnf/tri6/graph_bd_tgpr |tee exp_bnf/tri6/mkgraph.log + +mkdir -p $decode1 $decode2 +#By default, we do not care about the lattices for this step -- we just want the transforms +#Therefore, we will reduce the beam sizes, to reduce the decoding times +steps/decode_fmllr_extra.sh --skip-scoring true --beam 10 --lattice-beam 4 \ + --acwt $bnf_decode_acwt \ + exp_bnf/tri6/graph_bd_tgpr data_bnf/eval92 ${decode1} |tee ${decode1}/decode.log +steps/decode_fmllr_extra.sh --skip-scoring true --beam 10 --lattice-beam 4 \ + --acwt $bnf_decode_acwt \ + exp_bnf/tri6/graph_bd_tgpr data_bnf/dev93 ${decode2} |tee ${decode2}/decode.log + +echo --------------------------------------------------------------------- +echo "Starting exp_bnf/ubm7 on" `date` +echo --------------------------------------------------------------------- +if [ ! exp_bnf/ubm7/.done -nt exp_bnf/tri6/.done ]; then + steps/train_ubm.sh \ + $bnf_num_gauss_ubm data_bnf/train data/lang exp_bnf/tri6 exp_bnf/ubm7 + touch exp_bnf/ubm7/.done +fi + +if [ ! 
exp_bnf/sgmm7/.done -nt exp_bnf/ubm7/.done ]; then + echo --------------------------------------------------------------------- + echo "Starting exp_bnf/sgmm7 on" `date` + echo --------------------------------------------------------------------- + steps/train_sgmm2_group.sh \ + "${sgmm_group_extra_opts[@]}"\ + $numLeavesSGMM $bnf_num_gauss_sgmm data_bnf/train data/lang \ + exp_bnf/tri6 exp_bnf/ubm7/final.ubm exp_bnf/sgmm7 + touch exp_bnf/sgmm7/.done +fi + +## SGMM2 decoding +decode1=exp_bnf/sgmm7/decode_bd_tgpr_eval92 +decode2=exp_bnf/sgmm7/decode_bd_tgpr_dev93 + echo --------------------------------------------------------------------- + echo "Spawning $decode1 and $decode2 on" `date` + echo --------------------------------------------------------------------- + utils/mkgraph.sh \ + data/lang_test_bd_tgpr exp_bnf/sgmm7 exp_bnf/sgmm7/graph_bd_tgpr |tee exp_bnf/sgmm7/mkgraph.log + + mkdir -p $decode1 $decode2 + steps/decode_sgmm2.sh --skip-scoring false --use-fmllr true \ + --acwt $bnf_decode_acwt --scoring-opts "--min-lmwt 20 --max-lmwt 40" --cmd "$decode_cmd" \ + --transform-dir exp_bnf/tri6/decode_bd_tgpr_eval92 \ + exp_bnf/sgmm7/graph_bd_tgpr data_bnf/eval92 $decode1 |tee $decode1/decode.log + steps/decode_sgmm2.sh --skip-scoring false --use-fmllr true \ + --acwt $bnf_decode_acwt --scoring-opts "--min-lmwt 20 --max-lmwt 40" --cmd "$decode_cmd" \ + --transform-dir exp_bnf/tri6/decode_bd_tgpr_dev93 \ + exp_bnf/sgmm7/graph_bd_tgpr data_bnf/dev93 $decode2 |tee $decode2/decode.log + +if [ ! exp_bnf/sgmm7_ali/.done -nt exp_bnf/sgmm7/.done ]; then + echo --------------------------------------------------------------------- + echo "Starting exp_bnf/sgmm7_ali on" `date` + echo --------------------------------------------------------------------- + steps/align_sgmm2.sh \ + --transform-dir exp_bnf/tri6 --nj 30 --use-graphs true \ + data_bnf/train data/lang exp_bnf/sgmm7 exp_bnf/sgmm7_ali + touch exp_bnf/sgmm7_ali/.done +fi + +if [ ! 
exp_bnf/sgmm7_denlats/.done -nt exp_bnf/sgmm7/.done ]; then + echo --------------------------------------------------------------------- + echo "Starting exp_bnf/sgmm5_denlats on" `date` + echo --------------------------------------------------------------------- + steps/make_denlats_sgmm2.sh \ + "${sgmm_denlats_extra_opts[@]}" \ + --transform-dir exp_bnf/tri6 --nj 30 --beam 14.0 --acwt $bnf_decode_acwt --lattice-beam 8 \ + data_bnf/train data/lang exp_bnf/sgmm7_ali exp_bnf/sgmm7_denlats + touch exp_bnf/sgmm7_denlats/.done +fi + +if [ ! exp_bnf/sgmm7_mmi_b0.1/.done -nt exp_bnf/sgmm7_denlats/.done ]; then + steps/train_mmi_sgmm2.sh \ + --acwt $bnf_decode_acwt \ + --transform-dir exp_bnf/tri6 --boost 0.1 --drop-frames true \ + data_bnf/train data/lang exp_bnf/sgmm7_ali exp_bnf/sgmm7_denlats \ + exp_bnf/sgmm7_mmi_b0.1 + touch exp_bnf/sgmm7_mmi_b0.1/.done; +fi + +## SGMM_MMI rescoring +for iter in 1 2 3 4; do + # Decode SGMM+MMI (via rescoring). + decode1=exp_bnf/sgmm7_mmi_b0.1/decode_bd_tgpr_eval92_it$iter + mkdir -p $decode1 + steps/decode_sgmm2_rescore.sh --skip-scoring false --cmd "$decode_cmd" \ + --iter $iter --transform-dir exp_bnf/tri6/decode_bd_tgpr_eval92 --scoring-opts "--min-lmwt 20 --max-lmwt 40" \ + data/lang_test_bd_tgpr data_bnf/eval92 exp_bnf/sgmm7/decode_bd_tgpr_eval92 $decode1 | tee ${decode1}/decode.log +done + +for iter in 1 2 3 4; do + # Decode SGMM+MMI (via rescoring). 
+ decode2=exp_bnf/sgmm7_mmi_b0.1/decode_bd_tgpr_dev93_it$iter + mkdir -p $decode2 + steps/decode_sgmm2_rescore.sh --skip-scoring false --cmd "$decode_cmd" \ + --iter $iter --transform-dir exp_bnf/tri6/decode_bd_tgpr_dev93 --scoring-opts "--min-lmwt 20 --max-lmwt 40" \ + data/lang_test_bd_tgpr data_bnf/dev93 exp_bnf/sgmm7/decode_bd_tgpr_dev93 $decode2 | tee ${decode2}/decode.log +done + +echo --------------------------------------------------------------------- +echo "Finished successfully on" `date` +echo --------------------------------------------------------------------- + +#exit 1 diff --git a/egs/wsj_noisy/s5/local/run_fwdbwd.sh b/egs/wsj_noisy/s5/local/run_fwdbwd.sh new file mode 100755 index 00000000000..c84f2f1e0fb --- /dev/null +++ b/egs/wsj_noisy/s5/local/run_fwdbwd.sh @@ -0,0 +1,41 @@ +#prepare reverse lexicon and language model for backwards decoding +utils/prepare_lang.sh --reverse true data/local/dict "" data/local/lang_tmp.reverse data/lang.reverse || exit 1; +utils/reverse_lm.sh data/local/nist_lm/lm_bg_5k.arpa.gz data/lang.reverse data/lang_test_bg_5k.reverse || exit 1; +utils/reverse_lm_test.sh data/lang_test_bg_5k data/lang_test_bg_5k.reverse || exit 1; + +# normal forward decoding +utils/mkgraph.sh data/lang_test_bg_5k exp/tri2a exp/tri2a/graph_bg5k +steps/decode_fwdbwd.sh --beam 10.0 --latbeam 4.0 --nj 8 --cmd "$decode_cmd" \ + exp/tri2a/graph_bg5k data/test_eval92 exp/tri2a/decode_eval92_bg5k_10 || exit 1; + +# backward decoding +utils/mkgraph.sh --reverse data/lang_test_bg_5k.reverse exp/tri2a exp/tri2a/graph_bg5k_r +steps/decode_fwdbwd.sh --beam 10.0 --latbeam 4.0 --reverse true --nj 8 --cmd "$decode_cmd" \ + exp/tri2a/graph_bg5k_r data/test_eval92 exp/tri2a/decode_eval92_bg5k_reverse10 || exit 1; + +# pingpong decoding +steps/decode_fwdbwd.sh --beam 10.0 --max-beam 20.0 --reverse true --nj 8 --cmd "$decode_cmd" \ + --first_pass exp/tri2a/decode_eval92_bg5k_10 exp/tri2a/graph_bg5k_r data/test_eval92 exp/tri2a/decode_eval92_bg5k_pingpong10 || 
exit 1; +steps/decode_fwdbwd.sh --beam 10.0 --max-beam 20.0 --nj 8 --cmd "$decode_cmd" \ + --first_pass exp/tri2a/decode_eval92_bg5k_reverse10 exp/tri2a/graph_bg5k data/test_eval92 exp/tri2a/decode_eval92_bg5k_pongping10 || exit 1; + +# same for bigger language models (on machine with 8GB RAM, you can run the whole decoding in 3-4 min without SGE) +utils/prepare_lang.sh --reverse true data/local/dict_larger "" data/local/lang_larger.reverse data/lang_bd.reverse || exit; +utils/reverse_lm.sh --lexicon data/local/dict_larger/lexicon.txt data/local/local_lm/3gram-mincount/lm_pr6.0.gz data/lang_bd.reverse data/lang_test_bd_tgpr.reverse || exit 1; +utils/reverse_lm_test.sh data/lang_test_bd_tgpr data/lang_test_bd_tgpr.reverse || exit 1; + +utils/mkgraph.sh data/lang_test_bd_tgpr exp/tri2a exp/tri2a/graph_bd_tgpr +steps/decode_fwdbwd.sh --beam 10.0 --latbeam 4.0 --nj 4 --cmd run.pl \ + exp/tri2a/graph_bd_tgpr data/test_eval92 exp/tri2a/decode_eval92_bdtgpr4_10 || exit 1; + +utils/mkgraph.sh --reverse data/lang_test_bd_tgpr.reverse exp/tri2a exp/tri2a/graph_bd_tgpr_r +steps/decode_fwdbwd.sh --beam 10.0 --latbeam 4.0 --reverse true --nj 4 --cmd run.pl \ + exp/tri2a/graph_bd_tgpr_r data/test_eval92 exp/tri2a/decode_eval92_bdtgpr4_reverse10 || exit 1; + +steps/decode_fwdbwd.sh --beam 10.0 --max-beam 20.0 --reverse true --nj 4 --cmd run.pl \ + --first_pass exp/tri2a/decode_eval92_bdtgpr4_10 exp/tri2a/graph_bd_tgpr_r data/test_eval92 \ + exp/tri2a/decode_eval92_bdtgpr4_pingpong10 || exit 1; + +steps/decode_fwdbwd.sh --beam 10.0 --max-beam 20.0 --nj 4 --cmd run.pl \ + --first_pass exp/tri2a/decode_eval92_bdtgpr4_reverse10 exp/tri2a/graph_bd_tgpr data/test_eval92 \ + exp/tri2a/decode_eval92_bdtgpr4_pongping10 || exit 1; diff --git a/egs/wsj_noisy/s5/local/run_gender_dep.sh b/egs/wsj_noisy/s5/local/run_gender_dep.sh new file mode 100755 index 00000000000..050474d70b6 --- /dev/null +++ b/egs/wsj_noisy/s5/local/run_gender_dep.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# This script is not 
really finished, all it does is train a model with its +# means adapted to the female data, to demonstrate MAP adaptation. To have real +# gender dependent decoding (which anyway we're not very enthused about), we +# would have to train both models, do some kind of gender identification, and +# then decode. Or we could use the gender information in the test set. But +# anyway that's not a direction we really want to go right now. + +. ./cmd.sh + +awk '{if ($2 == "f") { print $1; }}' < data/train_si84/spk2gender > spklist + +utils/subset_data_dir.sh --spk-list spklist data/train_si84 data/train_si84_f + +steps/align_si.sh --nj 10 --cmd "$train_cmd" \ + data/train_si84_f data/lang exp/tri2b exp/tri2b_ali_si84_f + +steps/train_map.sh --cmd "$train_cmd" data/train_si84_f data/lang exp/tri2b_ali_si84_f exp/tri2b_f diff --git a/egs/wsj_noisy/s5/local/run_kl_hmm.sh b/egs/wsj_noisy/s5/local/run_kl_hmm.sh new file mode 100644 index 00000000000..9e7679a7675 --- /dev/null +++ b/egs/wsj_noisy/s5/local/run_kl_hmm.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# Copyright 2013 Idiap Research Institute (Author: David Imseng) +# Apache 2.0 + +. 
cmd.sh + +states=20000 +dir=exp/tri4b_pretrain-dbn_dnn/ + +steps/kl_hmm/build_tree.sh --cmd "$big_memory_cmd" --thresh -1 --nnet_dir exp/tri4b_pretrain-dbn_dnn/ \ + ${states} data-fmllr-tri4b/train_si284 data/lang exp/tri4b_ali_si284 exp/tri4b-${states} || exit 1; + +utils/mkgraph.sh data/lang_test_bd_tgpr exp/tri4b-${states} exp/tri4b-${states}/graph_bd_tgpr || exit 1; + +steps/kl_hmm/train_kl_hmm.sh --nj 30 --cmd "$big_memory_cmd" --model exp/tri4b-${states}/final.mdl data-fmllr-tri4b/train_si284 exp/tri4b-${states} $dir/kl-hmm-${states} + +steps/kl_hmm/decode_kl_hmm.sh --nj 10 --cmd "$big_memory_cmd" --acwt 0.1 --nnet $dir/kl-hmm-${states}/final.nnet --model exp/tri4b-${states}/final.mdl \ + --config conf/decode_dnn.config exp/tri4b-${states}/graph_bd_tgpr/ data-fmllr-tri4b/test_dev93 $dir/decode_dev93_kl-hmm-bd-${states}_tst + +steps/kl_hmm/decode_kl_hmm.sh --nj 8 --cmd "$big_memory_cmd" --acwt 0.1 --nnet $dir/kl-hmm-${states}/final.nnet --model exp/tri4b-${states}/final.mdl \ + --config conf/decode_dnn.config exp/tri4b-${states}/graph_bd_tgpr/ data-fmllr-tri4b/test_eval92 $dir/decode_eval92_kl-hmm-bd-${states}_tst + + diff --git a/egs/wsj_noisy/s5/local/run_mmi_tri2b.sh b/egs/wsj_noisy/s5/local/run_mmi_tri2b.sh new file mode 100755 index 00000000000..d7ddbfbaf62 --- /dev/null +++ b/egs/wsj_noisy/s5/local/run_mmi_tri2b.sh @@ -0,0 +1,73 @@ +#!/bin/bash + +lang_suffix= + +echo "$0 $@" # Print the command line for logging +. utils/parse_options.sh || exit 1; + +. ./cmd.sh + +# Train and test MMI (and boosted MMI) on tri2b system. +steps/make_denlats.sh --sub-split 20 --nj 10 --cmd "$train_cmd" \ + data/train_si84 data/lang${lang_suffix} \ + exp/tri2b exp/tri2b_denlats_si84 || exit 1; + +# train the basic MMI system. 
+steps/train_mmi.sh --cmd "$train_cmd" \ + data/train_si84 data/lang${lang_suffix} exp/tri2b_ali_si84 \ + exp/tri2b_denlats_si84 exp/tri2b_mmi || exit 1; +for iter in 3 4; do + steps/decode_si.sh --nj 10 --cmd "$decode_cmd" --iter $iter \ + exp/tri2b/graph${lang_suffix}_tgpr data/test_dev93 \ + exp/tri2b_mmi/decode${lang_suffix}_tgpr_dev93_it$iter & + steps/decode_si.sh --nj 8 --cmd "$decode_cmd" --iter $iter \ + exp/tri2b/graph${lang_suffix}_tgpr data/test_eval92 \ + exp/tri2b_mmi/decode${lang_suffix}_tgpr_eval92_it$iter & +done + +# MMI with 0.1 boosting factor. +steps/train_mmi.sh --cmd "$train_cmd" --boost 0.1 \ + data/train_si84 data/lang${lang_suffix} exp/tri2b_ali_si84 \ + exp/tri2b_denlats_si84 exp/tri2b_mmi_b0.1 || exit 1; + +for iter in 3 4; do + steps/decode_si.sh --nj 10 --cmd "$decode_cmd" --iter $iter \ + exp/tri2b/graph${lang_suffix}_tgpr data/test_dev93 \ + exp/tri2b_mmi_b0.1/decode${lang_suffix}_tgpr_dev93_it$iter & + steps/decode_si.sh --nj 8 --cmd "$decode_cmd" --iter $iter \ + exp/tri2b/graph${lang_suffix}_tgpr data/test_eval92 \ + exp/tri2b_mmi_b0.1/decode${lang_suffix}_tgpr_eval92_it$iter & +done + + +# Train a UBM with 400 components, for fMMI. 
+steps/train_diag_ubm.sh --silence-weight 0.5 --nj 10 --cmd "$train_cmd" \ + 400 data/train_si84 data/lang${lang_suffix} exp/tri2b_ali_si84 exp/dubm2b + +steps/train_mmi_fmmi.sh --boost 0.1 --cmd "$train_cmd" \ + data/train_si84 data/lang${lang_suffix} exp/tri2b_ali_si84 \ + exp/dubm2b exp/tri2b_denlats_si84 exp/tri2b_fmmi_b0.1 + +for iter in `seq 3 8`; do + steps/decode_fmmi.sh --nj 10 --cmd "$decode_cmd" --iter $iter \ + exp/tri2b/graph${lang_suffix}_tgpr data/test_dev93 \ + exp/tri2b_fmmi_b0.1/decode${lang_suffix}_tgpr_dev93_it$iter & +done + +steps/train_mmi_fmmi.sh --learning-rate 0.005 --boost 0.1 --cmd "$train_cmd" \ + data/train_si84 data/lang${lang_suffix} exp/tri2b_ali_si84 \ + exp/dubm2b exp/tri2b_denlats_si84 exp/tri2b_fmmi_b0.1_lr0.005 || exit 1; +for iter in `seq 3 8`; do + steps/decode_fmmi.sh --nj 10 --cmd "$decode_cmd" --iter $iter \ + exp/tri2b/graph${lang_suffix}_tgpr data/test_dev93 \ + exp/tri2b_fmmi_b0.1_lr0.005/decode${lang_suffix}_tgpr_dev93_it$iter & +done + +steps/train_mmi_fmmi_indirect.sh --boost 0.1 --cmd "$train_cmd" \ + data/train_si84 data/lang${lang_suffix} exp/tri2b_ali_si84 \ + exp/dubm2b exp/tri2b_denlats_si84 exp/tri2b_fmmi_indirect_b0.1 +for iter in `seq 3 8`; do + steps/decode_fmmi.sh --nj 10 --cmd "$decode_cmd" --iter $iter \ + exp/tri2b/graph${lang_suffix}_tgpr data/test_dev93 \ + exp/tri2b_fmmi_indirect_b0.1/decode${lang_suffix}_tgpr_dev93_it$iter & +done diff --git a/egs/wsj_noisy/s5/local/run_mmi_tri4b.sh b/egs/wsj_noisy/s5/local/run_mmi_tri4b.sh new file mode 100755 index 00000000000..db34f8e1d84 --- /dev/null +++ b/egs/wsj_noisy/s5/local/run_mmi_tri4b.sh @@ -0,0 +1,50 @@ +#!/bin/bash +. 
./cmd.sh + +steps/make_denlats.sh --nj 30 --sub-split 30 --cmd "$train_cmd" \ + --transform-dir exp/tri4b_ali_si284 \ + data/train_si284 data/lang exp/tri4b exp/tri4b_denlats_si284 || exit 1; + +steps/train_mmi.sh --cmd "$train_cmd" --boost 0.1 \ + data/train_si284 data/lang exp/tri4b_ali_si284 exp/tri4b_denlats_si284 \ + exp/tri4b_mmi_b0.1 || exit 1; + +steps/decode.sh --nj 10 --cmd "$decode_cmd" --transform-dir exp/tri3b/decode_tgpr_dev93 \ + exp/tri4b/graph_tgpr data/test_dev93 exp/tri4b_mmi_b0.1/decode_tgpr_dev93 + +#first, train UBM for fMMI experiments. +steps/train_diag_ubm.sh --silence-weight 0.5 --nj 30 --cmd "$train_cmd" \ + 600 data/train_si284 data/lang exp/tri4b_ali_si284 exp/dubm4b + +# Next, fMMI+MMI. +steps/train_mmi_fmmi.sh \ + --boost 0.1 --cmd "$train_cmd" data/train_si284 data/lang exp/tri4b_ali_si284 exp/dubm4b exp/tri4b_denlats_si284 \ + exp/tri4b_fmmi_a || exit 1; + +for iter in 3 4 5 6 7 8; do + steps/decode_fmmi.sh --nj 10 --cmd "$decode_cmd" --iter $iter \ + --transform-dir exp/tri3b/decode_tgpr_dev93 exp/tri4b/graph_tgpr data/test_dev93 \ + exp/tri4b_fmmi_a/decode_tgpr_dev93_it$iter & +done +# decode the last iter with the bd model. +for iter in 8; do + steps/decode_fmmi.sh --nj 10 --cmd "$decode_cmd" --iter $iter \ + --transform-dir exp/tri3b/decode_bd_tgpr_dev93 exp/tri4b/graph_bd_tgpr data/test_dev93 \ + exp/tri4b_fmmi_a/decode_bd_tgpr_dev93_it$iter & + steps/decode_fmmi.sh --nj 8 --cmd "$decode_cmd" --iter $iter \ + --transform-dir exp/tri3b/decode_bd_tgpr_eval92 exp/tri4b/graph_bd_tgpr data/test_eval92 \ + exp/tri4b_fmmi_a/decode_tgpr_eval92_it$iter & +done + + +# fMMI + mmi with indirect differential. 
+steps/train_mmi_fmmi_indirect.sh \ + --boost 0.1 --cmd "$train_cmd" data/train_si284 data/lang exp/tri4b_ali_si284 exp/dubm4b exp/tri4b_denlats_si284 \ + exp/tri4b_fmmi_indirect || exit 1; + +for iter in 3 4 5 6 7 8; do + steps/decode_fmmi.sh --nj 10 --cmd "$decode_cmd" --iter $iter \ + --transform-dir exp/tri3b/decode_tgpr_dev93 exp/tri4b/graph_tgpr data/test_dev93 \ + exp/tri4b_fmmi_indirect/decode_tgpr_dev93_it$iter & +done + diff --git a/egs/wsj_noisy/s5/local/run_nnet2.sh b/egs/wsj_noisy/s5/local/run_nnet2.sh new file mode 100755 index 00000000000..728356e3ec7 --- /dev/null +++ b/egs/wsj_noisy/s5/local/run_nnet2.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +. ./cmd.sh + +# This shows what you can potentially run; you'd probably want to pick and choose. + +use_gpu=true + +if $use_gpu; then + local/nnet2/run_5b_gpu.sh # various VTLN combinations, Mel-filterbank features, si284 train (multiplied by 5). + local/nnet2/run_5c.sh --use-gpu true # this is on top of fMLLR features. + local/nnet2/run_6c_gpu.sh # this is discriminative training of tanh neural nets on top of run_5c_gpu.sh + local/nnet2/run_5d.sh --use-gpu true # this is p-norm training on top of fMLLR features. + local/nnet2/run_5e_gpu.sh # this is ensemble training of p-norm nnets on top of fMLLR features. + local/nnet2/run_6d_gpu.sh # this is discriminative training of p-norm neural nets on top of run_5d_gpu.sh +else + local/nnet2/run_5b.sh # various VTLN combinations, Mel-filterbank features, si284 train (multiplied by 5). + local/nnet2/run_5c.sh --use-gpu false # this is on top of fMLLR features. + local/nnet2/run_5d.sh --use-gpu false # this is p-norm on top of fMLLR features. 
+fi + + diff --git a/egs/wsj_noisy/s5/local/run_raw_fmllr.sh b/egs/wsj_noisy/s5/local/run_raw_fmllr.sh new file mode 100644 index 00000000000..69f716b80f5 --- /dev/null +++ b/egs/wsj_noisy/s5/local/run_raw_fmllr.sh @@ -0,0 +1,67 @@ +#!/bin/bash + + +steps/align_raw_fmllr.sh --nj 10 --cmd "$train_cmd" --use-graphs true \ + data/train_si84 data/lang exp/tri2b exp/tri2b_ali_si84_raw + +steps/train_raw_sat.sh --cmd "$train_cmd" \ + 2500 15000 data/train_si84 data/lang exp/tri2b_ali_si84_raw exp/tri3c || exit 1; + + +mfccdir=mfcc +for x in test_eval92 test_eval93 test_dev93 ; do + y=${x}_utt + mkdir -p data/$y + cp -r data/$x/* data/$y + cat data/$x/utt2spk | awk '{print $1, $1;}' > data/$y/utt2spk; + cp data/$y/utt2spk data/$y/spk2utt; + steps/compute_cmvn_stats.sh data/$y exp/make_mfcc/$y $mfccdir || exit 1; +done + +( +utils/mkgraph.sh data/lang_test_tgpr exp/tri3c exp/tri3c/graph_tgpr || exit 1; +steps/decode_raw_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri3c/graph_tgpr data/test_dev93 exp/tri3c/decode_tgpr_dev93 || exit 1; +steps/decode_raw_fmllr.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri3c/graph_tgpr data/test_eval92 exp/tri3c/decode_tgpr_eval92 || exit 1; + +steps/decode_raw_fmllr.sh --nj 30 --cmd "$decode_cmd" \ + exp/tri3c/graph_tgpr data/test_dev93_utt exp/tri3c/decode_tgpr_dev93_utt || exit 1; +steps/decode_raw_fmllr.sh --nj 30 --cmd "$decode_cmd" \ + exp/tri3c/graph_tgpr data/test_eval92_utt exp/tri3c/decode_tgpr_eval92_utt || exit 1; + +steps/decode_raw_fmllr.sh --use-normal-fmllr true --nj 10 --cmd "$decode_cmd" \ + exp/tri3c/graph_tgpr data/test_dev93 exp/tri3c/decode_tgpr_dev93_2fmllr || exit 1; +steps/decode_raw_fmllr.sh --use-normal-fmllr true --nj 8 --cmd "$decode_cmd" \ + exp/tri3c/graph_tgpr data/test_eval92 exp/tri3c/decode_tgpr_eval92_2fmllr || exit 1; +)& + +( +utils/mkgraph.sh data/lang_test_bd_tgpr exp/tri3c exp/tri3c/graph_bd_tgpr || exit 1; + +steps/decode_raw_fmllr.sh --cmd "$decode_cmd" --nj 8 exp/tri3c/graph_bd_tgpr \ + data/test_eval92 
exp/tri3c/decode_bd_tgpr_eval92 + steps/decode_raw_fmllr.sh --cmd "$decode_cmd" --nj 10 exp/tri3c/graph_bd_tgpr \ + data/test_dev93 exp/tri3c/decode_bd_tgpr_dev93 +)& + +steps/align_fmllr.sh --nj 20 --cmd "$train_cmd" \ + data/train_si284 data/lang exp/tri3c exp/tri3c_ali_si284 || exit 1; + + +steps/train_raw_sat.sh --cmd "$train_cmd" \ + 4200 40000 data/train_si284 data/lang exp/tri3c_ali_si284 exp/tri4d || exit 1; +( + utils/mkgraph.sh data/lang_test_tgpr exp/tri4d exp/tri4d/graph_tgpr || exit 1; + steps/decode_raw_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri4d/graph_tgpr data/test_dev93 exp/tri4d/decode_tgpr_dev93 || exit 1; + steps/decode_raw_fmllr.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri4d/graph_tgpr data/test_eval92 exp/tri4d/decode_tgpr_eval92 || exit 1; +) & + + +wait + + +#for x in exp/tri3{b,c}/decode_tgpr*; do grep WER $x/wer_* | utils/best_wer.sh ; done + diff --git a/egs/wsj_noisy/s5/local/run_rnnlm-hs_tri3b.sh b/egs/wsj_noisy/s5/local/run_rnnlm-hs_tri3b.sh new file mode 100755 index 00000000000..302973aaf2f --- /dev/null +++ b/egs/wsj_noisy/s5/local/run_rnnlm-hs_tri3b.sh @@ -0,0 +1,122 @@ +#!/bin/bash + +lang_suffix= + +echo "$0 $@" # Print the command line for logging +. utils/parse_options.sh || exit 1; + +. cmd.sh + # This step interpolates a small RNNLM (with weight 0.15) with the 4-gram LM. 
+steps/rnnlmrescore.sh --rnnlm_ver rnnlm-hs-0.1b \ + --N 100 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.15 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm-hs.h30.voc10k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs30_0.15 || exit 1; + +steps/rnnlmrescore.sh --rnnlm_ver rnnlm-hs-0.1b \ + --N 100 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.3 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm-hs.h100.voc20k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs100_0.3 || exit 1; + +steps/rnnlmrescore.sh --rnnlm_ver rnnlm-hs-0.1b \ + --N 100 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.3 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm-hs.h300.voc30k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs300_0.3 || exit 1; + +steps/rnnlmrescore.sh --rnnlm_ver rnnlm-hs-0.1b \ + --N 100 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.3 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm-hs.h400.voc40k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs400_0.3 || exit 1; + +steps/rnnlmrescore.sh --rnnlm_ver rnnlm-hs-0.1b \ + --N 1000 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.3 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm-hs.h400.voc40k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs400_0.3_N1000 + +steps/rnnlmrescore.sh --rnnlm_ver rnnlm-hs-0.1b \ + --N 1000 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.3 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm-hs.h400.voc40k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs400_0.3_N1000 \ + || exit 1; + 
+dir=exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs400_0.4_N1000 +rm -rf $dir +cp -r exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs400_0.3_N1000 $dir +steps/rnnlmrescore.sh --rnnlm_ver rnnlm-hs-0.1b \ + --stage 7 --N 1000 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.4 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm-hs.h400.voc40k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg $dir + +dir=exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs400_0.4 +rm -rf $dir +cp -r exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs400_0.3 $dir +steps/rnnlmrescore.sh --rnnlm_ver rnnlm-hs-0.1b \ + --stage 7 --N 100 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.4 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm-hs.h400.voc40k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg $dir + +dir=exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs400_0.15 +rm -rf $dir +cp -r exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs400_0.3 $dir +steps/rnnlmrescore.sh --rnnlm_ver rnnlm-hs-0.1b \ + --stage 7 --N 100 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.15 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm-hs.h400.voc40k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg $dir + +steps/rnnlmrescore.sh --rnnlm_ver rnnlm-hs-0.1b \ + --N 10 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.3 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm-hs.h400.voc40k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs400_0.3_N10 \ + || exit 1; + +dir=exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs400_0.4_N1000 +rm -rf $dir +cp -r exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs400_0.3_N1000 $dir +steps/rnnlmrescore.sh --rnnlm_ver rnnlm-hs-0.1b \ + --stage 7 --N 1000 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.4 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm-hs.h400.voc40k data/test_eval92 \ + 
exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg $dir + +dir=exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs400_0.15_N1000 +rm -rf $dir +cp -r exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs400_0.3_N1000 $dir +steps/rnnlmrescore.sh --rnnlm_ver rnnlm-hs-0.1b \ + --stage 7 --N 1000 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.15 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm-hs.h400.voc40k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg $dir + +dir=exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs400_0.5_N1000 +rm -rf $dir +cp -r exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs400_0.3_N1000 $dir +steps/rnnlmrescore.sh --rnnlm_ver rnnlm-hs-0.1b \ + --stage 7 --N 1000 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.5 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm-hs.h400.voc40k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg $dir + +dir=exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs400_0.75_N1000 +rm -rf $dir +cp -r exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm-hs400_0.3_N1000 $dir +steps/rnnlmrescore.sh --rnnlm_ver rnnlm-hs-0.1b \ + --stage 7 --N 1000 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.75 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm-hs.h400.voc40k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg $dir diff --git a/egs/wsj_noisy/s5/local/run_rnnlms_sgmm5b.sh b/egs/wsj_noisy/s5/local/run_rnnlms_sgmm5b.sh new file mode 100755 index 00000000000..67fcee50a93 --- /dev/null +++ b/egs/wsj_noisy/s5/local/run_rnnlms_sgmm5b.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +for test in dev93 eval92; do + + steps/lmrescore.sh --cmd "$decode_cmd" data/lang_test_bd_tgpr data/lang_test_bd_fg \ + data/test_${test} exp/sgmm5b_mmi_b0.1/decode_bd_tgpr_${test}_it4 exp/sgmm5b_mmi_b0.1/decode_bd_fg_${test}_it4 || exit 1; + + +# Note: for N-best-list generation, choosing the acoustic scale (12) that gave +# the best WER on this test set. 
Ideally we should do this on a dev set. + + # This step interpolates a small RNNLM (with weight 0.25) with the 4-gram LM. + steps/rnnlmrescore.sh \ + --N 100 --cmd "$decode_cmd" --inv-acwt 12 \ + 0.25 data/lang_test_bd_fg data/local/rnnlm.h30.voc10k data/test_${test} \ + exp/sgmm5b_mmi_b0.1/decode_bd_fg_${test}_it4 exp/sgmm5b_mmi_b0.1/decode_bd_fg_${test}_it4_rnnlm30_0.25 \ + || exit 1; + + steps/rnnlmrescore.sh \ + --N 100 --cmd "$decode_cmd" --inv-acwt 12 \ + 0.5 data/lang_test_bd_fg data/local/rnnlm.h100.voc20k data/test_${test} \ + exp/sgmm5b_mmi_b0.1/decode_bd_fg_${test}_it4 exp/sgmm5b_mmi_b0.1/decode_bd_fg_${test}_it4_rnnlm100_0.5 \ + || exit 1; + + steps/rnnlmrescore.sh \ + --N 100 --cmd "$decode_cmd" --inv-acwt 12 \ + 0.5 data/lang_test_bd_fg data/local/rnnlm.h200.voc30k data/test_${test} \ + exp/sgmm5b_mmi_b0.1/decode_bd_fg_${test}_it4 exp/sgmm5b_mmi_b0.1/decode_bd_fg_${test}_it4_rnnlm200_0.5 \ + || exit 1; + + steps/rnnlmrescore.sh \ + --N 100 --cmd "$decode_cmd" --inv-acwt 12 \ + 0.5 data/lang_test_bd_fg data/local/rnnlm.h300.voc40k data/test_${test} \ + exp/sgmm5b_mmi_b0.1/decode_bd_fg_${test}_it4 exp/sgmm5b_mmi_b0.1/decode_bd_fg_${test}_it4_rnnlm300_0.5 \ + || exit 1; + + steps/rnnlmrescore.sh \ + --N 100 --cmd "$decode_cmd" --inv-acwt 12 \ + 0.75 data/lang_test_bd_fg data/local/rnnlm.h300.voc40k data/test_${test} \ + exp/sgmm5b_mmi_b0.1/decode_bd_fg_${test}_it4 exp/sgmm5b_mmi_b0.1/decode_bd_fg_${test}_it4_rnnlm300_0.75 \ + || exit 1; +done diff --git a/egs/wsj_noisy/s5/local/run_rnnlms_tri3b.sh b/egs/wsj_noisy/s5/local/run_rnnlms_tri3b.sh new file mode 100755 index 00000000000..5d056860848 --- /dev/null +++ b/egs/wsj_noisy/s5/local/run_rnnlms_tri3b.sh @@ -0,0 +1,80 @@ +#!/bin/bash + +lang_suffix= + +echo "$0 $@" # Print the command line for logging +. utils/parse_options.sh || exit 1; + +. cmd.sh + + # This step interpolates a small RNNLM (with weight 0.25) with the 4-gram LM. 
+steps/rnnlmrescore.sh \ + --N 100 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.25 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm.h30.voc10k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm30_0.25 || exit 1; + +steps/rnnlmrescore.sh \ + --N 100 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.5 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm.h100.voc20k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm100_0.5 || exit 1; + +steps/rnnlmrescore.sh \ + --N 100 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.5 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm.h200.voc30k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm200_0.5 || exit 1; + +steps/rnnlmrescore.sh \ + --N 100 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.5 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm.h300.voc40k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm300_0.5 || exit 1; + +steps/rnnlmrescore.sh \ + --N 1000 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.5 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm.h300.voc40k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm300_0.5_N1000 + +dir=exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm300_0.75_N1000 +rm -rf $dir +cp -r exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm300_0.5_N1000 $dir +steps/rnnlmrescore.sh \ + --stage 7 --N 1000 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.75 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm.h300.voc40k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg $dir + +dir=exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm300_0.75 +rm -rf $dir +cp -r 
exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm300_0.5 $dir +steps/rnnlmrescore.sh \ + --stage 7 --N 100 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.75 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm.h300.voc40k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg $dir + +dir=exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm300_0.25 +rm -rf $dir +cp -r exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm300_0.5 $dir +steps/rnnlmrescore.sh \ + --stage 7 --N 100 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.25 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm.h300.voc40k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg $dir + +steps/rnnlmrescore.sh \ + --N 10 --cmd "$decode_cmd" --inv-acwt 17 \ + 0.5 data/lang${lang_suffix}_test_bd_fg \ + data/local/rnnlm.h300.voc40k data/test_eval92 \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg \ + exp/tri3b/decode${lang_suffix}_bd_tgpr_eval92_fg_rnnlm300_0.5_N10 \ + || exit 1; + diff --git a/egs/wsj_noisy/s5/local/run_segmentation.sh b/egs/wsj_noisy/s5/local/run_segmentation.sh new file mode 100755 index 00000000000..0eee275419c --- /dev/null +++ b/egs/wsj_noisy/s5/local/run_segmentation.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +# Copyright 2014 Guoguo Chen +# Apache 2.0 + +# This script demonstrates how to re-segment long audios into short segments. +# The basic idea is to decode with an existing in-domain acoustic model, and a +# bigram language model built from the reference, and then work out the +# segmentation from a ctm like file. + +. ./cmd.sh +. 
./path.sh + +local/append_utterances.sh data/train_si284 data/train_si284_long +steps/cleanup/split_long_utterance.sh \ + --seg-length 30 --overlap-length 5 \ + data/train_si284_long data/train_si284_split + +steps/make_mfcc.sh --cmd "$train_cmd" --nj 64 \ + data/train_si284_split exp/make_mfcc/train_si284_split mfcc || exit 1; +steps/compute_cmvn_stats.sh data/train_si284_split \ + exp/make_mfcc/train_si284_split mfcc || exit 1; + +steps/cleanup/make_segmentation_graph.sh \ + --cmd "$mkgraph_cmd" --nj 32 \ + data/train_si284_split/ data/lang exp/tri2b/ \ + exp/tri2b/graph_train_si284_split || exit 1; + +steps/cleanup/decode_segmentation.sh \ + --nj 64 --cmd "$decode_cmd" --skip-scoring true \ + exp/tri2b/graph_train_si284_split \ + data/train_si284_split exp/tri2b/decode_train_si284_split || exit 1; + +steps/get_ctm.sh --cmd "$decode_cmd" data/train_si284_split \ + exp/tri2b/graph_train_si284_split exp/tri2b/decode_train_si284_split + +steps/cleanup/make_segmentation_data_dir.sh --wer-cutoff 0.9 \ + --min-sil-length 0.5 --max-seg-length 15 --min-seg-length 1 \ + exp/tri2b/decode_train_si284_split/score_10/train_si284_split.ctm \ + data/train_si284_split data/train_si284_reseg + +# Now, use the re-segmented data for training. 
+steps/make_mfcc.sh --cmd "$train_cmd" --nj 64 \ + data/train_si284_reseg exp/make_mfcc/train_si284_reseg mfcc || exit 1; +steps/compute_cmvn_stats.sh data/train_si284_reseg \ + exp/make_mfcc/train_si284_reseg mfcc || exit 1; + +steps/align_fmllr.sh --nj 20 --cmd "$train_cmd" \ + data/train_si284_reseg data/lang exp/tri3b exp/tri3b_ali_si284_reseg || exit 1; + +steps/train_sat.sh --cmd "$train_cmd" \ + 4200 40000 data/train_si284_reseg \ + data/lang exp/tri3b_ali_si284_reseg exp/tri4c || exit 1; + +utils/mkgraph.sh data/lang_test_tgpr exp/tri4c exp/tri4c/graph_tgpr || exit 1; +steps/decode_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri4c/graph_tgpr data/test_dev93 exp/tri4c/decode_tgpr_dev93 || exit 1; +steps/decode_fmllr.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri4c/graph_tgpr data/test_eval92 exp/tri4c/decode_tgpr_eval92 || exit 1; diff --git a/egs/wsj_noisy/s5/local/run_sgmm.sh b/egs/wsj_noisy/s5/local/run_sgmm.sh new file mode 100755 index 00000000000..27d8449896f --- /dev/null +++ b/egs/wsj_noisy/s5/local/run_sgmm.sh @@ -0,0 +1,112 @@ +#!/bin/bash + +# This script is invoked from ../run.sh +# It contains some SGMM-related scripts that I am breaking out of the main run.sh for clarity. + +. cmd.sh + +# SGMM system on si84 data [sgmm5a]. Note: the system we aligned from used the si284 data for +# training, but this shouldn't have much effect. 
+ +( + steps/align_fmllr.sh --nj 30 --cmd "$train_cmd" \ + data/train_si84 data/lang exp/tri4b exp/tri4b_ali_si84 || exit 1; + + steps/train_ubm.sh --cmd "$train_cmd" \ + 400 data/train_si84 data/lang exp/tri4b_ali_si84 exp/ubm5a || exit 1; + + steps/train_sgmm.sh --cmd "$train_cmd" \ + 3500 10000 data/train_si84 data/lang exp/tri4b_ali_si84 \ + exp/ubm5a/final.ubm exp/sgmm5a || exit 1; + + ( + utils/mkgraph.sh data/lang_test_tgpr exp/sgmm5a exp/sgmm5a/graph_tgpr + steps/decode_sgmm.sh --nj 10 --cmd "$decode_cmd" --transform-dir exp/tri4b/decode_tgpr_dev93 \ + exp/sgmm5a/graph_tgpr data/test_dev93 exp/sgmm5a/decode_tgpr_dev93 + ) & + + steps/align_sgmm.sh --nj 30 --cmd "$train_cmd" --transform-dir exp/tri4b_ali_si84 \ + --use-graphs true --use-gselect true data/train_si84 data/lang exp/sgmm5a exp/sgmm5a_ali_si84 || exit 1; + steps/make_denlats_sgmm.sh --nj 30 --sub-split 30 --cmd "$decode_cmd" --transform-dir exp/tri4b_ali_si84 \ + data/train_si84 data/lang exp/sgmm5a_ali_si84 exp/sgmm5a_denlats_si84 + + steps/train_mmi_sgmm.sh --cmd "$decode_cmd" --transform-dir exp/tri4b_ali_si84 --boost 0.1 \ + data/train_si84 data/lang exp/sgmm5a_ali_si84 exp/sgmm5a_denlats_si84 exp/sgmm5a_mmi_b0.1 + + for iter in 1 2 3 4; do + steps/decode_sgmm_rescore.sh --cmd "$decode_cmd" --iter $iter \ + --transform-dir exp/tri4b/decode_tgpr_dev93 data/lang_test_tgpr data/test_dev93 exp/sgmm5a/decode_tgpr_dev93 \ + exp/sgmm5a_mmi_b0.1/decode_tgpr_dev93_it$iter & + done + + steps/train_mmi_sgmm.sh --cmd "$decode_cmd" --transform-dir exp/tri4b_ali_si84 --boost 0.1 \ + --update-opts "--cov-min-value=0.9" data/train_si84 data/lang exp/sgmm5a_ali_si84 exp/sgmm5a_denlats_si84 exp/sgmm5a_mmi_b0.1_m0.9 + + for iter in 1 2 3 4; do + steps/decode_sgmm_rescore.sh --cmd "$decode_cmd" --iter $iter \ + --transform-dir exp/tri4b/decode_tgpr_dev93 data/lang_test_tgpr data/test_dev93 exp/sgmm5a/decode_tgpr_dev93 \ + exp/sgmm5a_mmi_b0.1_m0.9/decode_tgpr_dev93_it$iter & + done + +) & + + +( +# The next 
commands are the same thing on all the si284 data. + +# SGMM system on the si284 data [sgmm5b] + steps/train_ubm.sh --cmd "$train_cmd" \ + 600 data/train_si284 data/lang exp/tri4b_ali_si284 exp/ubm5b || exit 1; + + steps/train_sgmm.sh --cmd "$train_cmd" \ + 5500 25000 data/train_si284 data/lang exp/tri4b_ali_si284 \ + exp/ubm5b/final.ubm exp/sgmm5b || exit 1; + + ( + utils/mkgraph.sh data/lang_test_tgpr exp/sgmm5b exp/sgmm5b/graph_tgpr + steps/decode_sgmm.sh --nj 10 --cmd "$decode_cmd" --transform-dir exp/tri4b/decode_tgpr_dev93 \ + exp/sgmm5b/graph_tgpr data/test_dev93 exp/sgmm5b/decode_tgpr_dev93 + steps/decode_sgmm.sh --nj 8 --cmd "$decode_cmd" --transform-dir exp/tri4b/decode_tgpr_eval92 \ + exp/sgmm5b/graph_tgpr data/test_eval92 exp/sgmm5b/decode_tgpr_eval92 + + utils/mkgraph.sh data/lang_test_bd_tgpr exp/sgmm5b exp/sgmm5b/graph_bd_tgpr || exit 1; + steps/decode_sgmm.sh --nj 10 --cmd "$decode_cmd" --transform-dir exp/tri4b/decode_bd_tgpr_dev93 \ + exp/sgmm5b/graph_bd_tgpr data/test_dev93 exp/sgmm5b/decode_bd_tgpr_dev93 + steps/decode_sgmm.sh --nj 8 --cmd "$decode_cmd" --transform-dir exp/tri4b/decode_bd_tgpr_eval92 \ + exp/sgmm5b/graph_bd_tgpr data/test_eval92 exp/sgmm5b/decode_bd_tgpr_eval92 + ) & + + steps/align_sgmm.sh --nj 30 --cmd "$train_cmd" --transform-dir exp/tri4b_ali_si284 \ + --use-graphs true --use-gselect true data/train_si284 data/lang exp/sgmm5b exp/sgmm5b_ali_si284 + + steps/make_denlats_sgmm.sh --nj 30 --sub-split 30 --cmd "$decode_cmd" --transform-dir exp/tri4b_ali_si284 \ + data/train_si284 data/lang exp/sgmm5b_ali_si284 exp/sgmm5b_denlats_si284 + + steps/train_mmi_sgmm.sh --cmd "$decode_cmd" --transform-dir exp/tri4b_ali_si284 --boost 0.1 \ + data/train_si284 data/lang exp/sgmm5b_ali_si284 exp/sgmm5b_denlats_si284 exp/sgmm5b_mmi_b0.1 + + for iter in 1 2 3 4; do + for test in dev93 eval92; do + steps/decode_sgmm_rescore.sh --cmd "$decode_cmd" --iter $iter \ + --transform-dir exp/tri4b/decode_tgpr_${test} data/lang_test_tgpr 
data/test_${test} exp/sgmm5b/decode_tgpr_${test} \ + exp/sgmm5b_mmi_b0.1/decode_tgpr_${test}_it$iter & + + steps/decode_sgmm_rescore.sh --cmd "$decode_cmd" --iter $iter \ + --transform-dir exp/tri4b/decode_bd_tgpr_${test} data/lang_test_bd_tgpr data/test_${test} exp/sgmm5b/decode_bd_tgpr_${test} \ + exp/sgmm5b_mmi_b0.1/decode_bd_tgpr_${test}_it$iter & + done + done +) & + + + +# Train quinphone SGMM system. + +steps/train_sgmm.sh --cmd "$train_cmd" \ + --context-opts "--context-width=5 --central-position=2" \ + 5500 25000 data/train_si284 data/lang exp/tri4b_ali_si284 \ + exp/ubm5b/final.ubm exp/sgmm5c || exit 1; + +# Decode from lattices in exp/sgmm5a/decode_tgpr_dev93. +steps/decode_sgmm_fromlats.sh --cmd "$decode_cmd" --transform-dir exp/tri4b/decode_tgpr_dev93 \ + data/test_dev93 data/lang_test_tgpr exp/sgmm5a/decode_tgpr_dev93 exp/sgmm5c/decode_tgpr_dev93 diff --git a/egs/wsj_noisy/s5/local/run_sgmm2.sh b/egs/wsj_noisy/s5/local/run_sgmm2.sh new file mode 100755 index 00000000000..d767b054499 --- /dev/null +++ b/egs/wsj_noisy/s5/local/run_sgmm2.sh @@ -0,0 +1,148 @@ +#!/bin/bash + +# This script is invoked from ../run.sh +# It contains some SGMM-related scripts that I am breaking out of the main run.sh for clarity. + +. cmd.sh + +# Note: you might want to try to give the option --spk-dep-weights=false to train_sgmm2.sh; +# this takes out the "symmetric SGMM" part which is not always helpful. + +# SGMM system on si84 data [sgmm5a]. Note: the system we aligned from used the si284 data for +# training, but this shouldn't have much effect. 
+ +( + steps/align_fmllr.sh --nj 30 --cmd "$train_cmd" \ + data/train_si84 data/lang exp/tri4b exp/tri4b_ali_si84 || exit 1; + + steps/train_ubm.sh --cmd "$train_cmd" \ + 400 data/train_si84 data/lang exp/tri4b_ali_si84 exp/ubm5a || exit 1; + + steps/train_sgmm2.sh --cmd "$train_cmd" \ + 7000 9000 data/train_si84 data/lang exp/tri4b_ali_si84 \ + exp/ubm5a/final.ubm exp/sgmm2_5a || exit 1; + + ( + utils/mkgraph.sh data/lang_test_tgpr exp/sgmm2_5a exp/sgmm2_5a/graph_tgpr + steps/decode_sgmm2.sh --nj 10 --cmd "$decode_cmd" --transform-dir exp/tri4b/decode_tgpr_dev93 \ + exp/sgmm2_5a/graph_tgpr data/test_dev93 exp/sgmm2_5a/decode_tgpr_dev93 + ) & + + steps/align_sgmm2.sh --nj 30 --cmd "$train_cmd" --transform-dir exp/tri4b_ali_si84 \ + --use-graphs true --use-gselect true data/train_si84 data/lang exp/sgmm2_5a exp/sgmm2_5a_ali_si84 || exit 1; + steps/make_denlats_sgmm2.sh --nj 30 --sub-split 30 --cmd "$decode_cmd" --transform-dir exp/tri4b_ali_si84 \ + data/train_si84 data/lang exp/sgmm2_5a_ali_si84 exp/sgmm2_5a_denlats_si84 + + steps/train_mmi_sgmm2.sh --cmd "$decode_cmd" --transform-dir exp/tri4b_ali_si84 --boost 0.1 \ + data/train_si84 data/lang exp/sgmm2_5a_ali_si84 exp/sgmm2_5a_denlats_si84 exp/sgmm2_5a_mmi_b0.1 + + for iter in 1 2 3 4; do + steps/decode_sgmm2_rescore.sh --cmd "$decode_cmd" --iter $iter \ + --transform-dir exp/tri4b/decode_tgpr_dev93 data/lang_test_tgpr data/test_dev93 exp/sgmm2_5a/decode_tgpr_dev93 \ + exp/sgmm2_5a_mmi_b0.1/decode_tgpr_dev93_it$iter & + done + + steps/train_mmi_sgmm2.sh --cmd "$decode_cmd" --transform-dir exp/tri4b_ali_si84 --boost 0.1 \ + --update-opts "--cov-min-value=0.9" data/train_si84 data/lang exp/sgmm2_5a_ali_si84 exp/sgmm2_5a_denlats_si84 exp/sgmm2_5a_mmi_b0.1_m0.9 + + for iter in 1 2 3 4; do + steps/decode_sgmm2_rescore.sh --cmd "$decode_cmd" --iter $iter \ + --transform-dir exp/tri4b/decode_tgpr_dev93 data/lang_test_tgpr data/test_dev93 exp/sgmm2_5a/decode_tgpr_dev93 \ + 
exp/sgmm2_5a_mmi_b0.1_m0.9/decode_tgpr_dev93_it$iter & + done + +) & + + +( +# The next commands are the same thing on all the si284 data. + +# SGMM system on the si284 data [sgmm5b] + steps/train_ubm.sh --cmd "$train_cmd" \ + 600 data/train_si284 data/lang exp/tri4b_ali_si284 exp/ubm5b || exit 1; + + steps/train_sgmm2.sh --cmd "$train_cmd" \ + 11000 25000 data/train_si284 data/lang exp/tri4b_ali_si284 \ + exp/ubm5b/final.ubm exp/sgmm2_5b || exit 1; + + ( + utils/mkgraph.sh data/lang_test_tgpr exp/sgmm2_5b exp/sgmm2_5b/graph_tgpr + steps/decode_sgmm2.sh --nj 10 --cmd "$decode_cmd" --transform-dir exp/tri4b/decode_tgpr_dev93 \ + exp/sgmm2_5b/graph_tgpr data/test_dev93 exp/sgmm2_5b/decode_tgpr_dev93 + steps/decode_sgmm2.sh --nj 8 --cmd "$decode_cmd" --transform-dir exp/tri4b/decode_tgpr_eval92 \ + exp/sgmm2_5b/graph_tgpr data/test_eval92 exp/sgmm2_5b/decode_tgpr_eval92 + + utils/mkgraph.sh data/lang_test_bd_tgpr exp/sgmm2_5b exp/sgmm2_5b/graph_bd_tgpr || exit 1; + steps/decode_sgmm2.sh --nj 10 --cmd "$decode_cmd" --transform-dir exp/tri4b/decode_bd_tgpr_dev93 \ + exp/sgmm2_5b/graph_bd_tgpr data/test_dev93 exp/sgmm2_5b/decode_bd_tgpr_dev93 + steps/decode_sgmm2.sh --nj 8 --cmd "$decode_cmd" --transform-dir exp/tri4b/decode_bd_tgpr_eval92 \ + exp/sgmm2_5b/graph_bd_tgpr data/test_eval92 exp/sgmm2_5b/decode_bd_tgpr_eval92 + ) & + + + # This shows how you would build and test a quinphone SGMM2 system, but + ( + steps/train_sgmm2.sh --cmd "$train_cmd" \ + --context-opts "--context-width=5 --central-position=2" \ + 11000 25000 data/train_si284 data/lang exp/tri4b_ali_si284 \ + exp/ubm5b/final.ubm exp/sgmm2_5c || exit 1; + # Decode from lattices in exp/sgmm2_5b + steps/decode_sgmm2_fromlats.sh --cmd "$decode_cmd" --transform-dir exp/tri4b/decode_tgpr_dev93 \ + data/test_dev93 data/lang_test_tgpr exp/sgmm2_5b/decode_tgpr_dev93 exp/sgmm2_5c/decode_tgpr_dev93 + steps/decode_sgmm2_fromlats.sh --cmd "$decode_cmd" --transform-dir exp/tri4b/decode_tgpr_eval92 \ + data/test_eval92 
data/lang_test_tgpr exp/sgmm2_5b/decode_tgpr_eval92 exp/sgmm2_5c/decode_tgpr_eval92 + ) & + + + steps/align_sgmm2.sh --nj 30 --cmd "$train_cmd" --transform-dir exp/tri4b_ali_si284 \ + --use-graphs true --use-gselect true data/train_si284 data/lang exp/sgmm2_5b exp/sgmm2_5b_ali_si284 + + steps/make_denlats_sgmm2.sh --nj 30 --sub-split 30 --cmd "$decode_cmd" --transform-dir exp/tri4b_ali_si284 \ + data/train_si284 data/lang exp/sgmm2_5b_ali_si284 exp/sgmm2_5b_denlats_si284 + + steps/train_mmi_sgmm2.sh --cmd "$decode_cmd" --transform-dir exp/tri4b_ali_si284 --boost 0.1 \ + data/train_si284 data/lang exp/sgmm2_5b_ali_si284 exp/sgmm2_5b_denlats_si284 exp/sgmm2_5b_mmi_b0.1 + + for iter in 1 2 3 4; do + for test in eval92; do # dev93 + steps/decode_sgmm2_rescore.sh --cmd "$decode_cmd" --iter $iter \ + --transform-dir exp/tri4b/decode_bd_tgpr_${test} data/lang_test_bd_fg data/test_${test} exp/sgmm2_5b/decode_bd_tgpr_${test} \ + exp/sgmm2_5b_mmi_b0.1/decode_bd_tgpr_${test}_it$iter & + done + done + + steps/train_mmi_sgmm2.sh --cmd "$decode_cmd" --transform-dir exp/tri4b_ali_si284 --boost 0.1 \ + --drop-frames true data/train_si284 data/lang exp/sgmm2_5b_ali_si284 exp/sgmm2_5b_denlats_si284 exp/sgmm2_5b_mmi_b0.1_z + + for iter in 1 2 3 4; do + for test in eval92 dev93; do + steps/decode_sgmm2_rescore.sh --cmd "$decode_cmd" --iter $iter \ + --transform-dir exp/tri4b/decode_bd_tgpr_${test} data/lang_test_bd_fg data/test_${test} exp/sgmm2_5b/decode_bd_tgpr_${test} \ + exp/sgmm2_5b_mmi_b0.1_z/decode_bd_tgpr_${test}_it$iter & + done + done + +) & + +wait + +# Examples of combining some of the best decodings: SGMM+MMI with +# MMI+fMMI on a conventional system. 
+
+local/score_combine.sh data/test_eval92 \
+  data/lang_test_bd_tgpr \
+  exp/tri4b_fmmi_a/decode_tgpr_eval92_it8 \
+  exp/sgmm2_5b_mmi_b0.1/decode_bd_tgpr_eval92_it3 \
+  exp/combine_tri4b_fmmi_a_sgmm2_5b_mmi_b0.1/decode_bd_tgpr_eval92_it8_3
+
+
+# %WER 4.43 [ 250 / 5643, 41 ins, 12 del, 197 sub ] exp/tri4b_fmmi_a/decode_tgpr_eval92_it8/wer_11
+# %WER 3.85 [ 217 / 5643, 35 ins, 11 del, 171 sub ] exp/sgmm2_5b_mmi_b0.1/decode_bd_tgpr_eval92_it3/wer_10
+# combined to:
+# %WER 3.76 [ 212 / 5643, 32 ins, 12 del, 168 sub ] exp/combine_tri4b_fmmi_a_sgmm2_5b_mmi_b0.1/decode_bd_tgpr_eval92_it8_3/wer_12
+
+# Checking MBR decode of baseline:
+cp -r -T exp/sgmm2_5b_mmi_b0.1/decode_bd_tgpr_eval92_it3{,.mbr}
+local/score_mbr.sh data/test_eval92 data/lang_test_bd_tgpr exp/sgmm2_5b_mmi_b0.1/decode_bd_tgpr_eval92_it3.mbr
+# MBR decoding did not seem to help (baseline was 3.85). I think this is normal at such low WERs.
+# %WER 3.86 [ 218 / 5643, 35 ins, 11 del, 172 sub ] exp/sgmm2_5b_mmi_b0.1/decode_bd_tgpr_eval92_it3.mbr/wer_10 diff --git a/egs/wsj_noisy/s5/local/run_vtln.sh b/egs/wsj_noisy/s5/local/run_vtln.sh new file mode 100755 index 00000000000..4312ff896aa --- /dev/null +++ b/egs/wsj_noisy/s5/local/run_vtln.sh @@ -0,0 +1,86 @@ +#!/bin/bash + +lang_suffix= + +echo "$0 $@" # Print the command line for logging +. utils/parse_options.sh || exit 1; + +. 
cmd.sh +featdir=mfcc_vtln +num_leaves=2500 +num_gauss=15000 + + +# train linear vtln +steps/train_lvtln.sh --cmd "$train_cmd" $num_leaves $num_gauss \ + data/train_si84 data/lang${lang_suffix} exp/tri2a exp/tri2c || exit 1 +mkdir -p data/train_si84_vtln +cp -r data/train_si84/* data/train_si84_vtln || exit 1 +cp exp/tri2c/final.warp data/train_si84_vtln/spk2warp || exit 1 + +utils/mkgraph.sh data/lang${lang_suffix}_test_bg_5k \ + exp/tri2c exp/tri2c/graph${lang_suffix}_bg_5k || exit 1; +utils/mkgraph.sh data/lang${lang_suffix}_test_tgpr \ + exp/tri2c exp/tri2c/graph${lang_suffix}_tgpr || exit 1; + + +for t in eval93 dev93 eval92; do + nj=10 + [ $t == eval92 ] && nj=8 + steps/decode_lvtln.sh --nj $nj --cmd "$decode_cmd" \ + exp/tri2c/graph${lang_suffix}_bg_5k data/test_$t \ + exp/tri2c/decode${lang_suffix}_${t}_bg_5k || exit 1 + mkdir -p data/test_${t}_vtln + cp -r data/test_$t/* data/test_${t}_vtln || exit 1 + cp exp/tri2c/decode${lang_suffix}_${t}_bg_5k/final.warp \ + data/test_${t}_vtln/spk2warp || exit 1 +done + +for x in test_eval92 test_eval93 test_dev93 train_si84; do + steps/make_mfcc.sh --nj 20 --cmd "$train_cmd" data/${x}_vtln exp/make_mfcc/${x}_vtln ${featdir} || exit 1 + steps/compute_cmvn_stats.sh data/${x}_vtln exp/make_mfcc/${x}_vtln ${featdir} || exit 1 + utils/fix_data_dir.sh data/${x}_vtln || exit 1 # remove segments with problems +done + +steps/align_si.sh --nj 10 --cmd "$train_cmd" \ + data/train_si84_vtln data/lang${lang_suffix} \ + exp/tri2c exp/tri2c_ali_si84 || exit 1 + +steps/train_lda_mllt.sh --cmd "$train_cmd" \ + --splice-opts "--left-context=3 --right-context=3" \ + 2500 15000 data/train_si84_vtln \ + data/lang${lang_suffix} exp/tri2c_ali_si84 exp/tri2d || exit 1 + +( +utils/mkgraph.sh data/lang${lang_suffix}_test_tgpr \ + exp/tri2d exp/tri2d/graph${lang_suffix}_tgpr || exit 1; +steps/decode.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri2d/graph${lang_suffix}_tgpr data/test_dev93_vtln \ + exp/tri2d/decode${lang_suffix}_tgpr_dev93 || exit 1; 
+steps/decode.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri2d/graph${lang_suffix}_tgpr data/test_eval92_vtln \ + exp/tri2d/decode${lang_suffix}_tgpr_eval92 || exit 1; +) & + +steps/align_si.sh --nj 10 --cmd "$train_cmd" \ + data/train_si84_vtln data/lang${lang_suffix} \ + exp/tri2d exp/tri2d_ali_si84 || exit 1 + +# From 2d system, train 3c which is LDA + MLLT + SAT. +steps/train_sat.sh --cmd "$train_cmd" \ + 2500 15000 data/train_si84_vtln \ + data/lang${lang_suffix} exp/tri2d_ali_si84 exp/tri3c || exit 1; + +( +utils/mkgraph.sh data/lang${lang_suffix}_test_tgpr \ + exp/tri3c exp/tri3c/graph${lang_suffix}_tgpr || exit 1; +steps/decode_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri3c/graph${lang_suffix}_tgpr data/test_dev93_vtln \ + exp/tri3c/decode${lang_suffix}_tgpr_dev93 || exit 1; +steps/decode_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri3c/graph${lang_suffix}_tgpr data/test_eval93_vtln \ + exp/tri3c/decode${lang_suffix}_tgpr_eval93 || exit 1; +steps/decode_fmllr.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri3c/graph${lang_suffix}_tgpr data/test_eval92_vtln \ + exp/tri3c/decode${lang_suffix}_tgpr_eval92 || exit 1; +) & diff --git a/egs/wsj_noisy/s5/local/run_vtln2.sh b/egs/wsj_noisy/s5/local/run_vtln2.sh new file mode 100755 index 00000000000..1a94d74d612 --- /dev/null +++ b/egs/wsj_noisy/s5/local/run_vtln2.sh @@ -0,0 +1,105 @@ +#!/bin/bash + +lang_suffix= + +echo "$0 $@" # Print the command line for logging +. utils/parse_options.sh || exit 1; + +. 
cmd.sh +featdir=mfcc_vtln +num_leaves=2500 +num_gauss=15000 + +# train linear vtln +steps/train_lvtln.sh --cmd "$train_cmd" $num_leaves $num_gauss \ + data/train_si84 data/lang${lang_suffix} exp/tri2b exp/tri2c || exit 1 + +mkdir -p data/train_si84_vtln +cp -r data/train_si84/* data/train_si84_vtln || exit 1 +cp exp/tri2c/final.warp data/train_si84_vtln/spk2warp || exit 1 + +utils/mkgraph.sh data/lang${lang_suffix}_test_bg_5k \ + exp/tri2c exp/tri2c/graph${lang_suffix}_bg_5k || exit 1; +utils/mkgraph.sh data/lang${lang_suffix}_test_tgpr \ + exp/tri2c exp/tri2c/graph${lang_suffix}_tgpr || exit 1; + +for t in eval93 dev93 eval92; do + nj=10 + [ $t == eval92 ] && nj=8 + steps/decode_lvtln.sh --nj $nj --cmd "$decode_cmd" \ + exp/tri2c/graph${lang_suffix}_bg_5k data/test_$t \ + exp/tri2c/decode${lang_suffix}_${t}_bg_5k || exit 1 + mkdir -p data/test_${t}_vtln + cp -r data/test_$t/* data/test_${t}_vtln || exit 1 + cp exp/tri2c/decode${lang_suffix}_${t}_bg_5k/final.warp \ + data/test_${t}_vtln/spk2warp || exit 1 +done + + +for x in test_eval92 test_eval93 test_dev93 train_si84; do + steps/make_mfcc.sh --nj 20 --cmd "$train_cmd" data/${x}_vtln exp/make_mfcc/${x}_vtln ${featdir} || exit 1 + steps/compute_cmvn_stats.sh data/${x}_vtln exp/make_mfcc/${x}_vtln ${featdir} || exit 1 + utils/fix_data_dir.sh data/${x}_vtln || exit 1 # remove segments with problems +done + +steps/align_si.sh --nj 10 --cmd "$train_cmd" \ + data/train_si84_vtln data/lang${lang_suffix} \ + exp/tri2c exp/tri2c_ali_si84 || exit 1 + +steps/train_lda_mllt.sh --cmd "$train_cmd" \ + --splice-opts "--left-context=3 --right-context=3" \ + 2500 15000 data/train_si84_vtln data/lang${lang_suffix} \ + exp/tri2c_ali_si84 exp/tri2d || exit 1 + +( +utils/mkgraph.sh data/lang${lang_suffix}_test_tgpr \ + exp/tri2d exp/tri2d/graph${lang_suffix}_tgpr || exit 1; +steps/decode.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri2d/graph${lang_suffix}_tgpr data/test_dev93_vtln \ + exp/tri2d/decode${lang_suffix}_tgpr_dev93 || exit 1; 
+steps/decode.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri2d/graph${lang_suffix}_tgpr data/test_eval92_vtln \ + exp/tri2d/decode${lang_suffix}_tgpr_eval92 || exit 1; +) & + +steps/align_si.sh --nj 10 --cmd "$train_cmd" \ + data/train_si84_vtln data/lang${lang_suffix} \ + exp/tri2d exp/tri2d_ali_si84 || exit 1 + +# From 2d system, train 3c which is LDA + MLLT + SAT. +steps/train_sat.sh --cmd "$train_cmd" \ + 2500 15000 data/train_si84_vtln data/lang${lang_suffix} \ + exp/tri2d_ali_si84 exp/tri3c || exit 1; + +( +utils/mkgraph.sh data/lang${lang_suffix}_test_tgpr \ + exp/tri3c exp/tri3c/graph${lang_suffix}_tgpr || exit 1; +steps/decode_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri3c/graph${lang_suffix}_tgpr data/test_dev93_vtln \ + exp/tri3c/decode${lang_suffix}_tgpr_dev93 || exit 1; +steps/decode_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri3c/graph${lang_suffix}_tgpr data/test_eval93_vtln \ + exp/tri3c/decode${lang_suffix}_tgpr_eval93 || exit 1; +steps/decode_fmllr.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri3c/graph${lang_suffix}_tgpr data/test_eval92_vtln \ + exp/tri3c/decode${lang_suffix}_tgpr_eval92 || exit 1; +) & + + +# Below shows the results we got with this script. +# Actually we only have improvement on dev93 and the others get worse. 
+# With VTLN: +# for x in exp/tri3c/decode_tgpr_{dev,eval}{92,93}; do grep WER $x/wer_* | utils/best_wer.sh ; done +# %WER 13.86 [ 1141 / 8234, 235 ins, 123 del, 783 sub ] exp/tri3c/decode_tgpr_dev93/wer_17 +# %WER 9.23 [ 521 / 5643, 131 ins, 31 del, 359 sub ] exp/tri3c/decode_tgpr_eval92/wer_16 +# %WER 12.47 [ 430 / 3448, 67 ins, 43 del, 320 sub ] exp/tri3c/decode_tgpr_eval93/wer_14 + +# Baseline: +#(note, I had to run the following extra decoding to get this) +#steps/decode_fmllr.sh --nj 10 --cmd "$decode_cmd" exp/tri3b/graph_tgpr data/test_eval93_vtln exp/tri3b/decode_tgpr_eval93 +# +# a04:s5: for x in exp/tri3b/decode_tgpr_{dev,eval}{92,93}; do grep WER $x/wer_* | utils/best_wer.sh ; done +# %WER 14.37 [ 1183 / 8234, 228 ins, 122 del, 833 sub ] exp/tri3b/decode_tgpr_dev93/wer_19 +# %WER 8.98 [ 507 / 5643, 129 ins, 28 del, 350 sub ] exp/tri3b/decode_tgpr_eval92/wer_14 +# %WER 12.21 [ 421 / 3448, 68 ins, 39 del, 314 sub ] exp/tri3b/decode_tgpr_eval93/wer_14 + diff --git a/egs/wsj_noisy/s5/local/score.sh b/egs/wsj_noisy/s5/local/score.sh new file mode 120000 index 00000000000..0afefc3158c --- /dev/null +++ b/egs/wsj_noisy/s5/local/score.sh @@ -0,0 +1 @@ +../steps/score_kaldi.sh \ No newline at end of file diff --git a/egs/wsj_noisy/s5/local/score_combine.sh b/egs/wsj_noisy/s5/local/score_combine.sh new file mode 100755 index 00000000000..576962c7442 --- /dev/null +++ b/egs/wsj_noisy/s5/local/score_combine.sh @@ -0,0 +1,95 @@ +#!/bin/bash + +# Copyright 2013 Arnab Ghoshal + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. 
+# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Script for system combination using minimum Bayes risk decoding. +# This calls lattice-combine to create a union of lattices that have been +# normalized by removing the total forward cost from them. The resulting lattice +# is used as input to lattice-mbr-decode. This should not be put in steps/ or +# utils/ since the scores on the combined lattice must not be scaled. + +# begin configuration section. +cmd=run.pl +min_lmwt=9 +max_lmwt=20 +lat_weights= +#end configuration section. + +help_message="Usage: "$(basename $0)" [options] <data-dir> <graph-dir|lang-dir> <decode-dir1> <decode-dir2> [decode-dir3 ... ] <out-dir> +Options: + --cmd (run.pl|queue.pl...) # specify how to run the sub-processes. + --min-lmwt INT # minimum LM-weight for lattice rescoring + --max-lmwt INT # maximum LM-weight for lattice rescoring + --lat-weights STR # colon-separated string of lattice weights +"; + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -lt 5 ]; then + printf "$help_message\n"; + exit 1; +fi + +data=$1 +graphdir=$2 +odir=${@: -1} # last argument to the script +shift 2; +decode_dirs=( $@ ) # read the remaining arguments into an array +unset decode_dirs[${#decode_dirs[@]}-1] # 'pop' the last argument which is odir +num_sys=${#decode_dirs[@]} # number of systems to combine + +symtab=$graphdir/words.txt +[ ! -f $symtab ] && echo "$0: missing word symbol table '$symtab'" && exit 1; +[ ! -f $data/text ] && echo "$0: missing reference '$data/text'" && exit 1; + + +mkdir -p $odir/log + +for i in `seq 0 $[num_sys-1]`; do + model=${decode_dirs[$i]}/../final.mdl # model one level up from decode dir + for f in $model ${decode_dirs[$i]}/lat.1.gz ; do + [ !
-f $f ] && echo "$0: expecting file $f to exist" && exit 1; + done + lats[$i]="\"ark:gunzip -c ${decode_dirs[$i]}/lat.*.gz |\"" +done + +mkdir -p $odir/scoring/log + +cat $data/text | sed 's:<NOISE>::g' | sed 's:<SPOKEN_NOISE>::g' \ + > $odir/scoring/test_filt.txt + +if [ -z "$lat_weights" ]; then + $cmd LMWT=$min_lmwt:$max_lmwt $odir/log/combine_lats.LMWT.log \ + lattice-combine --inv-acoustic-scale=LMWT ${lats[@]} ark:- \| \ + lattice-mbr-decode --word-symbol-table=$symtab ark:- \ + ark,t:$odir/scoring/LMWT.tra || exit 1; +else + $cmd LMWT=$min_lmwt:$max_lmwt $odir/log/combine_lats.LMWT.log \ + lattice-combine --inv-acoustic-scale=LMWT --lat-weights=$lat_weights \ + ${lats[@]} ark:- \| \ + lattice-mbr-decode --word-symbol-table=$symtab ark:- \ + ark,t:$odir/scoring/LMWT.tra || exit 1; +fi + +$cmd LMWT=$min_lmwt:$max_lmwt $odir/scoring/log/score.LMWT.log \ + cat $odir/scoring/LMWT.tra \| \ + utils/int2sym.pl -f 2- $symtab \| sed 's:\<UNK\>::g' \| \ + compute-wer --text --mode=present \ + ark:$odir/scoring/test_filt.txt ark,p:- ">&" $odir/wer_LMWT || exit 1; + +exit 0 diff --git a/egs/wsj_noisy/s5/local/score_mbr.sh b/egs/wsj_noisy/s5/local/score_mbr.sh new file mode 100755 index 00000000000..4052512f726 --- /dev/null +++ b/egs/wsj_noisy/s5/local/score_mbr.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +# Script for minimum bayes risk decoding. + +[ -f ./path.sh ] && . ./path.sh; + +# begin configuration section. +cmd=run.pl +min_lmwt=9 +max_lmwt=20 +#end configuration section. + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: local/score_mbr.sh [--cmd (run.pl|queue.pl...)] <data-dir> <lang-dir|graph-dir> <decode-dir>" + echo " Options:" + echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." + echo " --min_lmwt # minimum LM-weight for lattice rescoring " + echo " --max_lmwt # maximum LM-weight for lattice rescoring " + exit 1; +fi + +data=$1 +lang_or_graph=$2 +dir=$3 + +symtab=$lang_or_graph/words.txt + +for f in $symtab $dir/lat.1.gz $data/text; do + [ !
-f $f ] && echo "score_mbr.sh: no such file $f" && exit 1; +done + +mkdir -p $dir/scoring/log + +cat $data/text | sed 's:<NOISE>::g' | sed 's:<SPOKEN_NOISE>::g' > $dir/scoring/test_filt.txt + +# We submit the jobs separately, not as an array, because it's hard +# to get the inverse of the LM scales. +rm $dir/.error 2>/dev/null +for inv_acwt in `seq $min_lmwt $max_lmwt`; do + acwt=`perl -e "print (1.0/$inv_acwt);"` + $cmd $dir/scoring/rescore_mbr.${inv_acwt}.log \ + lattice-mbr-decode --acoustic-scale=$acwt --word-symbol-table=$symtab \ + "ark:gunzip -c $dir/lat.*.gz|" ark,t:$dir/scoring/${inv_acwt}.tra \ + || touch $dir/.error & +done +wait; +[ -f $dir/.error ] && echo "score_mbr.sh: error getting MBR output."; + + +$cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/score.LMWT.log \ + cat $dir/scoring/LMWT.tra \| \ + utils/int2sym.pl -f 2- $symtab \| sed 's:\<UNK\>::g' \| \ + compute-wer --text --mode=present \ + ark:$dir/scoring/test_filt.txt ark,p:- ">" $dir/wer_LMWT || exit 1; + diff --git a/egs/wsj_noisy/s5/local/snr/compute_frame_snrs.sh b/egs/wsj_noisy/s5/local/snr/compute_frame_snrs.sh new file mode 100755 index 00000000000..e147f39f5dc --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/compute_frame_snrs.sh @@ -0,0 +1,136 @@ +#!/bin/bash + +# Copyright 2015 Vimal Manohar +# Apache 2.0 + +set -e +set -o pipefail + +. path.sh + +# This script computes per-frame SNR from time-frequency bin SNR predicted +# by an SNR predictor nnet and the original noisy fbank features + +cmd=run.pl +nj=4 +use_gpu=no +iter=final +copy_opts= # Due to code change, the log(Irm) predicted might have previously been log(sqrt(Irm)). Hence use "matrix-scale --scale=2.0 ark:- ark:- \|". Also for log(Snr), it might have been log(sqrt(Snr)). +stage=0 + +. 
utils/parse_options.sh + +if [ $# -ne 4 ]; then + echo "Usage: $0 " + echo " e.g.: $0 exp/nnet3_snr_predictor/nnet_tdnn_a data/train_si284_corrupted_hires data/train_si284_corrupted_fbank exp/frame_snrs_train_si284_corrupted" + exit 1 +fi + +snr_predictor_nnet_dir=$1 +corrupted_data_dir=$2 +corrupted_fbank_dir=$3 +dir=$4 + +mkdir -p $dir + +[ ! -f $snr_predictor_nnet_dir/target_type ] && echo "$snr_predictor_nnet_dir/target_type could not be found" && exit 1 +prediction_type=`cat $snr_predictor_nnet_dir/target_type` || exit 1 +echo $prediction_type > $dir/prediction_type + +if [ $prediction_type == "IrmExp" ]; then + copy_opts="$copy_opts copy-matrix --apply-log=true ark:- ark:- |" +fi + +split_data.sh $corrupted_fbank_dir $nj +split_data.sh $corrupted_data_dir $nj + +sdata=$corrupted_data_dir/split$nj + +cmvn_opts=$(cat $snr_predictor_nnet_dir/cmvn_opts 2>/dev/null) || exit 1 + +feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- |" + +if [ -f $snr_predictor_nnet_dir/final.mat ]; then + feat_type=lda + + splice_opts=`cat $snr_predictor_nnet_dir/splice_opts 2>/dev/null` + feats="$feats splice-feats $splice_opts ark:- ark:- | transform-feats $snr_predictor_nnet_dir/final.mat ark:- ark:- |" +fi + +gpu_cmd=$cmd +if [ $use_gpu != "no" ]; then + gpu_cmd="$cmd --gpu 1" +fi + +if [ $stage -le 0 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/storage ]; then + date=$(date +'%m_%d_%H_%M') + utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/wsj_noisy-$date/s5/$dir/storage $dir/storage + for n in `seq $nj`; do + utils/create_data_link.pl $dir/nnet_pred.$n.ark + utils/create_data_link.pl $dir/clean_pred.$n.ark + utils/create_data_link.pl $dir/frame_snrs.$n.ark + done + fi + + $gpu_cmd JOB=1:$nj $dir/log/compute_nnet_pred.JOB.log \ + nnet3-compute --use-gpu=$use_gpu $snr_predictor_nnet_dir/$iter.raw "$feats" \ + ark:- \| ${copy_opts}copy-feats --compress=false ark:- \ + ark,scp:$dir/nnet_pred.JOB.ark,$dir/nnet_pred.JOB.scp || exit 1 + + for n in `seq $nj`; do + cat $dir/nnet_pred.$n.scp; + done > $dir/nnet_pred.scp +fi + +if [ $stage -le 1 ]; then + case $prediction_type in + "Irm"|"IrmExp") + # nnet_pred is log (clean energy / (clean energy + noise energy) ) + $cmd JOB=1:$nj $dir/log/compute_frame_snrs.JOB.log \ + compute-frame-snrs --prediction-type="Irm" \ + scp:$corrupted_fbank_dir/split$nj/JOB/feats.scp \ + ark:$dir/nnet_pred.JOB.ark \ + "ark:|vector-to-feat ark:- ark:- | copy-feats --compress=true ark:- ark,scp:$dir/frame_snrs.JOB.ark,$dir/frame_snrs.JOB.scp" \ + "ark:|copy-feats --compress=true ark:- ark,scp:$dir/clean_pred.JOB.ark,$dir/clean_pred.JOB.scp" \ + "ark:|copy-feats --compress=true ark:- ark,scp:$dir/out_snr.JOB.ark,$dir/out_snr.JOB.scp" + ;; + "FbankMask") + # nnet_pred is log (clean feat / noisy feat) + $cmd JOB=1:$nj $dir/log/compute_frame_snrs.JOB.log \ + compute-frame-snrs --prediction-type="FbankMask" \ + scp:$corrupted_fbank_dir/split$nj/JOB/feats.scp \ + ark:$dir/nnet_pred.JOB.ark \ + "ark:|vector-to-feat ark:- ark:- | copy-feats --compress=true ark:- ark,scp:$dir/frame_snrs.JOB.ark,$dir/frame_snrs.JOB.scp" \ + "ark:|copy-feats --compress=true ark:- ark,scp:$dir/clean_pred.JOB.ark,$dir/clean_pred.JOB.scp" \ + "ark:|copy-feats --compress=true ark:- ark,scp:$dir/out_snr.JOB.ark,$dir/out_snr.JOB.scp" + ;; + "FrameSnr") + $cmd JOB=1:$nj 
$dir/log/compute_frame_snrs.JOB.log \ + extract-column 0 $dir/nnet_pred.JOB.ark ark:- \| \ + vector-to-feat ark:- ark:- \| copy-feats --compress=true ark:- ark,scp:$dir/frame_snrs.JOB.ark,$dir/frame_snrs.JOB.scp + ;; + "Snr") + $cmd JOB=1:$nj $dir/log/compute_frame_snrs.JOB.log \ + compute-frame-snrs --prediction-type="Snr" \ + scp:$corrupted_fbank_dir/split$nj/JOB/feats.scp \ + ark:$dir/nnet_pred.JOB.ark \ + "ark:|vector-to-feat ark:- ark:- | copy-feats --compress=true ark:- ark,scp:$dir/frame_snrs.JOB.ark,$dir/frame_snrs.JOB.scp" \ + "ark:|copy-feats --compress=true ark:- ark,scp:$dir/clean_pred.JOB.ark,$dir/clean_pred.JOB.scp" \ + "ark:|copy-feats --compress=true ark:- ark,scp:$dir/out_snr.JOB.ark,$dir/out_snr.JOB.scp" + ;; + *) + echo "Unknown prediction-type '$prediction_type'" && exit 1 + esac +fi + +if [ $stage -le 2 ]; then + for n in `seq $nj`; do + cat $dir/frame_snrs.$n.scp + done > $dir/frame_snrs.scp + + for n in `seq $nj`; do + cat $dir/out_snr.$n.scp + done > $dir/nnet_pred_snrs.scp +fi + diff --git a/egs/wsj_noisy/s5/local/snr/compute_sad.sh b/egs/wsj_noisy/s5/local/snr/compute_sad.sh new file mode 100755 index 00000000000..d87fa35de47 --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/compute_sad.sh @@ -0,0 +1,146 @@ +#!/bin/bash + +# Copyright 2015 Vimal Manohar +# Apache 2.0 +set -o pipefail +set -e +set -u + +. cmd.sh +. path.sh + +method=LogisticRegression +nj=40 +stage=-10 +iter=final +splice_opts="--left-context=10 --right-context=10" +model_dir=exp/nnet3_sad_snr/tdnn_train_si284_corrupted_splice21 +snr_pred_dir=exp/frame_snrs_lwr_snr_reverb_dev_aspire_whole/ +dir=exp/nnet3_sad_snr/sad_train_si284_corrupted +quantization_bins=-2.5:2.5:7.5:12.5:17.5 +use_gpu=no +sil_prior=0.5 +speech_prior=0.5 +add_frame_snr=false +snr_data_dir= + +. utils/parse_options.sh + +if [ $# -ne 3 ]; then + echo "Usage: $0 " + echo " e.g.: $0 $model_dir $snr_pred_dir $dir" + exit 1 +fi + +model_dir=$1 +snr_pred_dir=$2 +dir=$3 + +if [ -z "$snr_data_dir" ]; then + if [ ! 
-s $snr_pred_dir/nnet_pred_snrs.scp ]; then + echo "$0: Could not read $snr_pred_dir/nnet_pred_snrs.scp or it is empty" + exit 1 + fi +fi + +mkdir -p $dir + +feat_type=`cat $model_dir/feat_type` || exit 1 + +echo $nj > $dir/num_jobs + +gpu_opts= +if [ $use_gpu == "yes" ]; then + gpu_opts="--gpu 1" +fi + +append_feats_opts="copy-feats scp:- ark:- |" + + +if [ -z "$snr_data_dir" ]; then + if $add_frame_snr; then + append_feats_opts="append-feats scp:- scp:$snr_pred_dir/frame_snrs.scp ark:- |" + fi +fi + +feats=snr_pred_dir/nnet_pred_snrs.scp +if [ ! -z "$snr_data_dir" ]; then + feats=$snr_data_dir/feats.scp +fi + +if [ $stage -le 1 ]; then + case $method in + "LogisticRegressionSubsampled") + model=$model_dir/$iter.mdl + + $decode_cmd --mem 8G JOB=1:$nj $dir/log/eval_logistic_regression.JOB.log \ + logistic-regression-eval-on-feats "$model" \ + "ark:utils/split_scp.pl -j $nj \$[JOB-1] $feats |$append_feats_opts splice-feats $splice_opts ark:- ark:- |" \ + ark:$dir/log_nnet_posteriors.JOB.ark || exit 1 + ;; + "LogisticRegression"|"Dnn") + model=$model_dir/$iter.raw + + if [ $feat_type != "sparse" ]; then + $decode_cmd --mem 8G $gpu_opts JOB=1:$nj $dir/log/eval_tdnn.JOB.log \ + nnet3-compute --apply-exp=false --use-gpu=$use_gpu "$model" \ + "ark:utils/split_scp.pl -j $nj \$[JOB-1] $feats |$append_feats_opts" \ + ark:$dir/log_nnet_posteriors.JOB.ark || exit 1 + else + num_bins=`echo $quantization_bins | awk -F ':' '{print NF + 1}' 2>/dev/null` || exit 1 + feat_dim=`head -n 1 $feats | feat-to-dim scp:- - 2>/dev/null` || exit 1 + if $add_frame_snr; then + feat_dim=$[feat_dim+1] + fi + sparse_input_dim=$[num_bins * feat_dim] + + + train_num_bins=`cat $model_dir/quantization_bin_boundaries | awk -F ':' '{print NF + 1}' 2>/dev/null` || exit 1 + + if [ $num_bins -ne $train_num_bins ]; then + echo "$0: Mismatch in number of bins during test and train; $num_bins vs $train_num_bins" + exit 1 + fi + + $decode_cmd --mem 8G $gpu_opts JOB=1:$nj $dir/log/eval_tdnn.JOB.log \ + 
nnet3-compute-from-sparse-input --apply-exp=false --use-gpu=$use_gpu --sparse-input-dim=$sparse_input_dim "$model" \ + "ark:utils/split_scp.pl -j $nj \$[JOB-1] $feats |$append_feats_opts quantize-feats ark:- $quantization_bins ark:- |" \ + ark:$dir/log_nnet_posteriors.JOB.ark || exit 1 + fi + ;; + *) + echo "Unknown method $method" + exit 1 + esac +fi + +if [ $stage -le 2 ]; then + if [ ! -f $model_dir/post.$iter.vec ]; then + echo "Could not find $model_dir/post.$iter.vec. Usually computed by averaging the nnet posteriors" + exit 1 + fi + + cat $model_dir/post.$iter.vec | awk '{if (NF != 4) { print "posterior vector must have dimension two; but has dimension "NF-2; exit 1;} else { printf ("[ %f %f ]\n", log($2/($2+$3)), log($3/($2+$3)));}}' > $dir/nnet_log_priors + + $decode_cmd JOB=1:$nj $dir/log/get_likes.JOB.log \ + matrix-add-offset ark:$dir/log_nnet_posteriors.JOB.ark "vector-scale --scale=-1.0 --binary=false $dir/nnet_log_priors - |" \ + ark,scp:$dir/log_likes.JOB.ark,$dir/log_likes.JOB.scp || exit 1 + + cat $dir/nnet_log_priors | awk -v sil_prior=$sil_prior -v speech_prior=$speech_prior '{sum_prior = speech_prior + sil_prior; printf ("[ %f %f ]", -$2+log(sil_prior)-log(sum_prior), -$3+log(speech_prior)-log(sum_prior));}' > $dir/log_priors + + $decode_cmd JOB=1:$nj $dir/log/adjust_priors.JOB.log \ + matrix-add-offset ark:$dir/log_nnet_posteriors.JOB.ark $dir/log_priors \ + ark,scp:$dir/log_posteriors.JOB.ark,$dir/log_posteriors.JOB.scp || exit 1 + + $decode_cmd JOB=1:$nj $dir/log/extract_logits.JOB.log \ + vector-sum "ark:extract-column --column-index=1 scp:$dir/log_posteriors.JOB.scp ark:- |" \ + "ark:extract-column --column-index=0 scp:$dir/log_posteriors.JOB.scp ark:- | vector-scale --scale=-1 ark:- ark:- |" \ + ark,scp:$dir/logits.JOB.ark,$dir/logits.JOB.scp || exit 1 +fi + +if [ $stage -le 3 ]; then + $decode_cmd JOB=1:$nj $dir/log/extract_prob.JOB.log \ + loglikes-to-post scp:$dir/log_posteriors.JOB.scp ark:- \| \ + weight-pdf-post 0 0 ark:- ark:- \| 
post-to-weights ark:- \ + ark,scp:$dir/speech_prob.JOB.ark,$dir/speech_prob.JOB.scp || exit 1 +fi diff --git a/egs/wsj_noisy/s5/local/snr/convert_file_vad_to_segments.sh b/egs/wsj_noisy/s5/local/snr/convert_file_vad_to_segments.sh new file mode 100755 index 00000000000..7f828f81e25 --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/convert_file_vad_to_segments.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +. path.sh + +cmd=run.pl +nj=4 +keep_only_speech=false + +. utils/parse_options.sh + +if [ $# -ne 2 ]; then + echo "Usage: $0 " + echo " e.g.: $0 exp/vad_data_prep_dev/file_vad exp/vad_data_prep_dev/file_vad" + exit 1 +fi + +vad_dir=$1 +dir=$2 + +merge_opts=( --remove-labels=4:10 --merge-labels=0:1 --merge-dst-label=1 ) +if $keep_only_speech; then + merge_opts=( --remove-labels=0:4:10 ) +fi + +$cmd JOB=1:$nj $dir/log/get_segments.JOB.log \ + segmentation-init-from-ali ark:$vad_dir/vad.JOB.ark ark:- \| \ + segmentation-post-process ${merge_opts[@]} \ + --shrink-length=20 --shrink-label=1 --merge-adjacent-segments=true --max-intersegment-length=1 \ + ark:- ark:- \| segmentation-to-segments --single-speaker=true ark:- \ + ark,t:$dir/reco2utt.JOB ark,t:$dir/segments.JOB || exit 1 + +for n in `seq $nj`; do + cat $dir/reco2utt.$n +done > $dir/reco2utt + +for n in `seq $nj`; do + cat $dir/segments.$n +done > $dir/segments + diff --git a/egs/wsj_noisy/s5/local/snr/corrupt.py b/egs/wsj_noisy/s5/local/snr/corrupt.py new file mode 100755 index 00000000000..7d7d872aba6 --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/corrupt.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python +# Copyright 2014 Johns Hopkins University (Authors: Vijayaditya Peddinti). Apache 2.0. 
+ +# corrupts the wave files supplied via input pipe with the specified +# room-impulse response (RIR) and additive noise distortions (specified by corresponding files) + +import wave, struct, sys, scipy.signal as signal, numpy as np, argparse, scipy.io.wavfile, warnings, subprocess + +def wave_load_from_command(wav_command, temp_file='temp.wav'): + try: + subprocess.check_call(wav_command + ' > ' + temp_file, shell = True) + return wave_load(temp_file) + except subprocess.CalledProcessError: + return None + +def wave_load_from_command_secure(wav_command): + sub_commands = wav_command.split('|') + subprocess_list = [] + input = None + for command in sub_commands: + subprocess_list.append(subprocess.Popen(command.split(), stdin = input, stdout = subprocess.PIPE)) + input = subprocess_list[-1].stdout + return wave_load(subprocess_list[-1].stdout) + +def wave_load(file): + if hasattr(file,'read'): + warnings.warn(' Assuming that the input is int stream.') + wav = wave.open(file) + (nchannels, sampwidth, framerate, nframes, comptype, compname) = wav.getparams () + if sampwidth == 2: + dtype = np.int16 + elif sampwidth == 4: + dtype = np.int32 + elif sampwidth == 8: + dtype = np.int64 + frames = wav.readframes(nframes * nchannels) + out = struct.unpack_from("%dh" % nframes * nchannels, frames) + out = np.array(out, dtype = dtype) + out = np.reshape(out, [nchannels, -1], order = 'F') + out = out.transpose() + else: + [framerate, out] = scipy.io.wavfile.read(file) + if len(out.shape) == 1: + out = out.reshape([out.shape[0], 1]) + if issubclass(out.dtype.type, np.integer): + max_val = float(np.iinfo(out.dtype).max) + out = out / max_val + return (framerate, out) + +def wav_write(file_handle, fs, data, normalize = 'True'): + if str(data.dtype) in set(['float64', 'float32']): + #rms_val = np.sqrt(np.mean(data * data)) + #data = (0.25 * data / rms_val ) * (2 ** 15) + if (normalize): + data = (0.99 * data / np.max(np.abs(data))) * (2 ** 15) + else: + data = np.clip(data, 0, 
0.99) * (2 ** 15); + data = data.astype('int16', copy = False) + elif str(data.dtype) == 'int16': + pass + else: + raise Exception('Not implemented for '+str(data.dtype)) + scipy.io.wavfile.write(file_handle, fs, data) + +def corrupt(x, h, n, snr, signal_db): + # x : signal, single channel signal + # h : room impulse response, can be multi-channel + # n : noise signal, can be multi-channel (same as h) + # snr : snr of the noise added signal + # signal_db : required clean signal power + + direct_signal = x[1][:, 0] # make input single channel + x_power = float(np.mean(direct_signal**2)) + direct_signal = direct_signal / np.sqrt(x_power) * (10 ** (signal_db / 20.0)) + assert (abs(float(np.mean(direct_signal**2)) - 10 ** (signal_db / 10.0)) < 1e-10) + + # compute direct reverberation of the RIR + fs = x[0] + if h is not None: + x = x[1][:, 0] / np.sqrt(x_power) * (10 ** (signal_db / 20.0)) # make input single channel + assert(h[0] == fs) + h = h[1] # copy the samples from (sampling_rate, samples) tuple + channel_one = h[:,0] + max_h = max(channel_one) + delay_impulse = [i for i, j in enumerate(channel_one) if j == max_h][0] + before_impulse = np.floor(fs * 0.001) + after_impulse = np.floor(fs * 0.05) + direct_rir = channel_one[max(0, delay_impulse - before_impulse):min(len(channel_one), delay_impulse + after_impulse)] + direct_rir = np.array(direct_rir) + direct_signal = signal.fftconvolve(x, direct_rir) + + # compute the reverberant signal + y = np.zeros([x.shape[0] + h.shape[0] - 1, h.shape[1]]) + for channel in xrange(h.shape[1]): + y[:, channel] = signal.fftconvolve(x, h[:,channel]) + else: + assert (n is not None) + x = x[1][:, 0] / np.sqrt(x_power) * (10 ** (signal_db / 20.0)) # make input single channel + y = np.zeros([x.shape[0], n[1].shape[1]]) + for channel in xrange(n[1].shape[1]): + y[:, channel] = x + direct_signal = x; + delay_impulse = 0 + # compute the scaled noise + if n is not None: + fs_n = n[0] + n = n[1] + sys.stderr.write('Noise signal : 
'+str(n.shape) + '\n') + n_y = np.zeros(y.shape) + assert(fs_n == fs) # sampling rate of noise and signal is same + assert(n.shape[1] == y.shape[1]) # both the reverberant signal and noise signal have the same number of channels + # repeat the source noise data "n" to match the length of the reverberant signal + num_reps = int(np.floor(n_y.shape[0] / n.shape[0])) + dest_array_index = 0 + for i in xrange(0, num_reps): + np.copyto(n_y[n.shape[0] * i : n.shape[0] * (i+1), :], n) + dest_array_index = n.shape[0] * (i+1) + # fill the remaining portion of destination with the initial samples of n + np.copyto(n_y[dest_array_index:, :], n[0 : n_y.shape[0] - dest_array_index, :]) + # normalize noise data according to the prefixed SNR value + n_ref = n_y[:, 0] + n_power = float(np.mean(n_ref**2)) + x_power = float(np.mean(direct_signal**2)) + M_snr = np.multiply(1/n_power, x_power) + M_snr = np.sqrt((10**(-snr/10.0))*M_snr) + n_scaled = np.dot(n_y, np.diagflat(M_snr)) + y = y + n_scaled + # scipy.io.savemat('debug.mat',{'n_scaled':n_scaled, 'y':y, 'x':x, 'h':h}) + return (y[delay_impulse:(delay_impulse + x.shape[0]), :], x, n_scaled) + +if __name__ == "__main__": + usage = """ Python script to corrupt the input wav stream with + the specified room impulse response and noise source.""" + sys.stderr.write(str(" ".join(sys.argv))) + main_parser = argparse.ArgumentParser(usage) + main_parser.add_argument('--temp-file-name', type=str, default='temp.wav', help='file name of temp file to be used') + main_parser.add_argument('--normalize', type=str, default='true', choices=['true','True','false','False'], help='normalize wave while writing') + main_parser.add_argument('input_file', type=str, help='file with list of wave files and corresponding corruption parameters') + main_params = main_parser.parse_args() + temp_file = main_params.temp_file_name + wav_param_list = map( lambda x: x.strip(), open(main_params.input_file)) + + for line in wav_param_list: + try: + parser = 
argparse.ArgumentParser() + parser.add_argument('--rir-file', type=str, help='file with the room impulse response') + parser.add_argument('--noise-file', type=str, help='file with additive noise') + parser.add_argument('--snr-db', type=float, default=20, help='desired SNR(dB) of the output') + parser.add_argument('--signal-db', type=float, default=-5, help='desired signal power (dB) of the clean signal') + parser.add_argument('--multi-channel', type=str, default='False', help='is output multi-channel') + parser.add_argument('--out-clean-file', type=str, help='Write the clean file just before adding noise') + parser.add_argument('--out-noise-file', type=str, help='Write the noise file just before adding to clean') + parser.add_argument('input_file', type=str, help='input-file') + parser.add_argument('output_file', type=str, help='output-file') + + parts = line.split('|') + wav_command = "|".join(parts[:-1]) + params = parser.parse_args(parts[-1].split()) + if params.multi_channel.lower() == 'true': + params.multi_channel = True + raise Exception("Cannot generate multi-channel outputs") + else: + params.multi_channel = False + sys.stderr.write(line) + # read the wav input from the stdin + x = wave_load_from_command(wav_command, temp_file) + if x is None: + sys.stderr.write('There was error trying to run the command\n'+wav_command) + continue + + sys.stderr.write('Input signal : '+str(x[1].shape) + '\n') + fs = x[0] + if x[1].shape[1] > 1: + raise Exception('Input wave file cannot be multi-channel') + # read the impulse response if available from the file + if params.rir_file is not None: + h = wave_load(params.rir_file) + if not params.multi_channel: + sys.stderr.write('Impulse response : '+str(h[1].shape) + '\n') + channel1 = h[1][:, 0] + h = (h[0], channel1.reshape([channel1.shape[0],1])) # just select the first channel + else: + h = None + + # read the noise if available from the file + if params.noise_file is not None: + n = wave_load(params.noise_file) + if not 
params.multi_channel: + channel1 = n[1][:, 0] + n = (n[0], channel1.reshape([channel1.shape[0], 1])) + else: + n = None + + y,x_clean,x_noise = corrupt(x, h, n, params.snr_db, params.signal_db) + if params.out_clean_file is not None: + wav_write(params.out_clean_file, fs, x_clean, main_params.normalize) + if params.out_noise_file is not None: + wav_write(params.out_noise_file, fs, x_noise, main_params.normalize) + wav_write(params.output_file, fs, y, main_params.normalize) + sys.stderr.write('Output signal : '+str(y.shape) + '\n') + if hasattr(params.output_file, 'write'): + params.output_file.flush() + except struct.error: + warnings.warn("Could not reverberate signal {0}") + continue diff --git a/egs/wsj_noisy/s5/local/snr/corrupt_data_dir.sh b/egs/wsj_noisy/s5/local/snr/corrupt_data_dir.sh new file mode 100755 index 00000000000..01c4748bf23 --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/corrupt_data_dir.sh @@ -0,0 +1,267 @@ +#!/bin/bash + +# Copyright 2014 Johns Hopkins University (Author: Vijayaditya Peddinti) +# 2015 Tom Ko +# 2015 Vimal Manohar +# Apache 2.0. +# This script processes generates multi-condition training data from clean data dir +# and directory with impulse responses and noises + +. ./cmd.sh; +set -e +set -o pipefail + +stage=0 +random_seed=0 +num_files_per_job=100 +background_snrs="20:10:15:5:0:-2:-5:-10" +foreground_snrs="20:10:15:5:0:-2:-5:-10" +tmp_dir=exp/make_corrupt +output_clean_dir= +output_clean_wav_dir= +output_noise_dir= +output_noise_wav_dir= +dest_wav_dir= +select_only_corruption_with_noise=false +nj=200 +dry_run=true +pad_silence=false + +. ./path.sh; +. 
./utils/parse_options.sh + +if [ $# != 3 ]; then + echo "Usage: $0 [options] " + echo "e.g.:" + echo " $0 --random-seed 12 data/train_si284 data_multicondition/impulses_noises data/train_si284p" + exit 1; +fi + +src_dir=$1 +impnoise_dir=$2 +dest_dir=$3 + +# $impnoise_dir must contain a directory info which has the following files +# impulse_files - list of impulse response wav files +# background_noise_files - list of noise wav files +# noise_impulse_* - containes pairs of impulse responses and noise files in +# the following format +# background_noise_files = +# impulse_files = + +data_id=`basename $src_dir` +if [ -z "$dest_wav_dir" ]; then + dest_wav_dir=$dest_dir/wavs_${data_id} +fi + +dest_wav_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $dest_wav_dir ${PWD}` + +if [ ! -z "$output_clean_dir" ]; then + [ -z "$output_clean_wav_dir" ] && output_clean_wav_dir=$output_clean_dir/wavs_${data_id} + mkdir -p $output_clean_dir $output_clean_wav_dir + output_clean_wav_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $output_clean_wav_dir ${PWD}` +fi + +if [ ! 
-z "$output_noise_dir" ]; then + [ -z "$output_noise_wav_dir" ] && output_noise_wav_dir=$output_noise_dir/wavs_${data_id} + mkdir -p $output_noise_wav_dir + output_noise_wav_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $output_noise_wav_dir ${PWD}` +fi + +mkdir -p $dest_dir +mkdir -p $tmp_dir +mkdir -p $dest_wav_dir + +wav_prefix="corrupted${random_seed}_" +utt_prefix="corrupted${random_seed}_" +spk_prefix="corrupted${random_seed}_" + +if [ $stage -le 0 ]; then + # Create the distorted wave files + utils/copy_data_dir.sh --spk-prefix "$spk_prefix" --utt-prefix "$utt_prefix" \ + $src_dir $dest_dir + cat $src_dir/utt2spk | \ + awk -v p=$utt_prefix '{printf("%s%s %s\n", p, $1, $1);}' > $dest_dir/utt2uniq + + cat $src_dir/wav.scp | sed -e "s/^\s*//g" | \ + cut -d' ' -f1 | \ + awk -v p1=$dest_wav_dir -v p2=$wav_prefix \ + '{printf("%s%s %s/%s%s.wav\n", p2, $1, p1, p2, $1);}'> $tmp_dir/corrupted_${random_seed}.list + + if [ ! -z "$output_clean_dir" ]; then + utils/copy_data_dir.sh --extra-files utt2uniq $dest_dir $output_clean_dir + + cat $src_dir/wav.scp | sed -e "s/^\s*//g" | \ + cut -d' ' -f1 | \ + awk -v p1=$output_clean_wav_dir -v p2=$wav_prefix \ + '{printf("%s%s %s/%s%s.wav\n", p2, $1, p1, p2, $1);}'> $tmp_dir/clean_${random_seed}.list + fi + + if [ ! 
-z "$output_noise_dir" ]; then + utils/copy_data_dir.sh --extra-files utt2uniq $dest_dir $output_noise_dir + + cat $src_dir/wav.scp | sed -e "s/^\s*//g" | \ + cut -d' ' -f1 | \ + awk -v p1=$output_noise_wav_dir -v p2=$wav_prefix \ + '{printf("%s%s %s/%s%s.wav\n", p2, $1, p1, p2, $1);}'> $tmp_dir/noise_${random_seed}.list + fi +fi + + +## Create a list of new wave files +#if [ $stage -le 1 ]; then +# # Create the new wav.scp file +# python -c " +#import re +#file_ids = map(lambda x: x.split()[0], open('$src_dir/wav.scp').readlines()) +#dest_file_names = map(lambda x: x.split()[0], open('$tmp_dir/corrupted_${random_seed}.list')) +#for file_id, dest_file_name in zip(file_ids, dest_file_names): +# print '$wav_prefix{0} cat {1} |'.format(file_id, dest_file_name) +#" > $dest_dir/wav.scp +# +# if [ ! -z "$output_clean_dir" ]; then +# python -c " +#import re +#file_ids = map(lambda x: x.split()[0], open('$src_dir/wav.scp').readlines()) +#dest_file_names = map(lambda x: x.split()[0], open('$tmp_dir/clean_${random_seed}.list')) +#for file_id, dest_file_name in zip(file_ids, dest_file_names): +# print '$wav_prefix{0} cat {1} |'.format(file_id, dest_file_name) +#" > $output_clean_dir/wav.scp +# fi +# +# if [ ! -z "$output_noise_dir" ]; then +# python -c " +#import re +#file_ids = map(lambda x: x.split()[0], open('$src_dir/wav.scp').readlines()) +#dest_file_names = map(lambda x: x.split()[0], open('$tmp_dir/noise_${random_seed}.list')) +#for file_id, dest_file_name in zip(file_ids, dest_file_names): +# print '$wav_prefix{0} cat {1} |'.format(file_id, dest_file_name) +#" > $output_noise_dir/wav.scp +# fi +#fi + +if [ $stage -le 1 ]; then + # Modify segments file to point to the new wav files + if [ -f $dest_dir/segments ]; then + cat $dest_dir/segments | awk -v p=$wav_prefix \ + '{printf("%s %s%s %s %s\n", $1, p, $2, $3, $4);}' > $tmp_dir/segments_temp + mv $tmp_dir/segments_temp $dest_dir/segments + + if [ ! 
-z "$output_clean_dir" ]; then + cat $output_clean_dir/segments | awk -v p=$wav_prefix \ + '{printf("%s %s%s %s %s\n", $1, p, $2, $3, $4);}' > $tmp_dir/segments_temp + mv $tmp_dir/segments_temp $output_clean_dir/segments + fi + + if [ ! -z "$output_noise_dir" ]; then + cat $output_noise_dir/segments | awk -v p=$wav_prefix \ + '{printf("%s %s%s %s %s\n", $1, p, $2, $3, $4);}' > $tmp_dir/segments_temp + mv $tmp_dir/segments_temp $output_noise_dir/segments + fi + fi +fi + +# Remove these files as we would have to extract +# features for this new audio and out audio +# is single channel +for file in cmvn.scp feats.scp reco2file_and_channel; do + rm -f $dest_dir/$file + if [ ! -z "$output_clean_dir" ]; then + rm -f $output_clean_dir/$file + fi + if [ ! -z "$output_noise_dir" ]; then + rm -f $output_noise_dir/$file + fi +done + +#if $select_only_corruption_with_noise; then +# if [ $stage -le 2 ]; then +# mkdir -p ${impnoise_dir}_noisy/info +# cp ${impnoise_dir}/info/* ${impnoise_dir}_noisy/info +# cat ${impnoise_dir}/info/noise_impulse_* | grep "impulse_files" | \ +# python -c " +#import sys +#for line in sys.stdin.readlines(): +# for x in line.strip().split('=')[1].split(): +# print (x) +#" > ${impnoise_dir}_noisy/info/impulse_files +# fi +# impnoise_dir=${impnoise_dir}_noisy +#fi +[ ! -s $impnoise_dir/info/impulse_files ] && echo "$0: $impnoise_dir/info/impulse_files contains no impulses" && exit 1 +[ ! -s $impnoise_dir/info/background_noise_files ] && echo "$0: $impnoise_dir/info/noise_files contains no noises" && exit 1 +[ ! 
-s $impnoise_dir/info/foreground_noise_files ] && echo "$0: $impnoise_dir/info/noise_files contains no noises" && exit 1 + +if $pad_silence; then + sox -n -r 16000 -c 1 $tmp_dir/silence.wav trim 0.0 2.0 + + cat $src_dir/wav.scp | \ + awk -v sil=$tmp_dir/silence.wav '{ + if ($NF == 2) { + print $1" cat "$2" | sox "sil" -t wav - "sil" -t wav - |"; + } else { + print $0" sox "sil" -t wav - "sil" -t wav - |"; + } }' > $tmp_dir/src_wav.scp +else + cp $src_dir/wav.scp $tmp_dir/src_wav.scp +fi + +if [ $stage -le 3 ]; then + python local/snr/corrupt_wavs.py \ + --background-snrs $background_snrs --foreground-snrs $foreground_snrs \ + --random-seed $random_seed \ + --output-clean-wav-file-list $tmp_dir/clean_${random_seed}.list \ + --output-noise-wav-file-list $tmp_dir/noise_${random_seed}.list \ + $tmp_dir/src_wav.scp $tmp_dir/corrupted_${random_seed}.list $impnoise_dir \ + $tmp_dir/corrupt_wav_commands.${random_seed}.list +fi + +if [ $stage -le 4 ]; then + corrupt_wav_command_lists= + for i in `seq $nj`; do + corrupt_wav_command_lists="$corrupt_wav_command_lists $tmp_dir/corrupt_wav_commands.${random_seed}.$i.list" + done + + #if $select_only_corruptions_with_noise; then + # grep "noise-file=" $tmp_dir/corrupt_wav_commands.${random_seed}.list | \ + # awk '{print $1}' > $tmp_dir/wavs_with_noise.${random_seed}.list + + # utils/filter_scp.pl $tmp_dir/wavs_with_noise.${random_seed}.list \ + # $tmp_dir/corrupt_wav_commands.${random_seed}.list | \ + # utils/split_scp.pl /dev/stdin \ + # $corrupt_wav_command_lists + + # utils/filter_scp.pl $tmp_dir/wavs_with_noise.${random_seed}.list \ + # $tmp_dir/corrupted_${random_seed}.list | sort -k1,1 -u > $dest_dir/wav.scp + # utils/filter_scp.pl $tmp_dir/wavs_with_noise.${random_seed}.list \ + # $tmp_dir/clean_${random_seed}.list | sort -k1,2 -u > $output_clean_dir/wav.scp + # utils/filter_scp.pl $tmp_dir/wavs_with_noise.${random_seed}.list \ + # $tmp_dir/noise_${random_seed}.list | sort -k1,2 -u > $output_noise_dir/wav.scp + + # 
utils/fix_data_dir.sh $dest_dir + # utils/fix_data_dir.sh $output_noise_dir + # utils/fix_data_dir.sh $output_clean_dir + #else + utils/split_scp.pl $tmp_dir/corrupt_wav_commands.${random_seed}.list \ + $corrupt_wav_command_lists + #fi + + if ! $dry_run; then + $train_cmd JOB=1:$nj $tmp_dir/corrupt_wavs.${random_seed}.JOB.log \ + cat $tmp_dir/corrupt_wav_commands.${random_seed}.JOB.list \| \ + awk '{a=""; for (i=2; i<=NF; i++) a=a" "$i; print(a)}' \| \ + bash -xe || exit 1 + fi +fi + +if [ $stage -le 5 ]; then + cat $tmp_dir/corrupted_${random_seed}.list | sort -k1,1 > $dest_dir/wav.scp + cat $tmp_dir/clean_${random_seed}.list | sort -k1,1 > $output_clean_dir/wav.scp + cat $tmp_dir/noise_${random_seed}.list | sort -k1,1 > $output_noise_dir/wav.scp + + utils/fix_data_dir.sh $dest_dir + utils/fix_data_dir.sh $output_noise_dir + utils/fix_data_dir.sh $output_clean_dir +fi + +echo "Successfully generated corrupted data and stored it in $dest_dir." && exit 0; diff --git a/egs/wsj_noisy/s5/local/snr/corrupt_wavs.py b/egs/wsj_noisy/s5/local/snr/corrupt_wavs.py new file mode 100644 index 00000000000..68daa216de0 --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/corrupt_wavs.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python +# Copyright 2014 Johns Hopkins University (Authors: Vijayaditya Peddinti). Apache 2.0. 
+# 2015 Tom Ko +# 2015 Vimal Manohar +# script to generate multicondition training data / dev data / test data +import argparse, glob, math, os, random, scipy.io.wavfile, sys + +class list_cyclic_iterator: + def __init__(self, list, random_seed = 0): + self.list_index = 0 + self.list = list + random.seed(random_seed) + random.shuffle(self.list) + + def next(self): + if (len(self.list) == 0): + return None + item = self.list[self.list_index] + self.list_index = (self.list_index + 1) % len(self.list) + return item + + def next_few(self, num_items): + if (len(self.list) == 0): + return None + a = [] + if (num_items > len(self.list)): + num_items = len(self.list) + for i in range(0, num_items): + item = self.list[self.list_index] + self.list_index = (self.list_index + 1) % len(self.list) + a.append(item) + return a + + def add_to_list(self, a): + self.list.insert(random.randrange(len(self.list)+1), a) + +def return_nonempty_lines(lines): + new_lines = [] + for line in lines: + if len(line.strip()) > 0: + new_lines.append(line.strip()) + return new_lines + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--background-snrs', type=str, default = '20:10:0', + help='snrs to be used for corruption') + parser.add_argument('--foreground-snrs', type=str, default = '20:10:0', + help='snrs to be used with foreground noises') + parser.add_argument('--rms-amplitudes', type=str, + default = '0.0001:0.0005:0.001:0.002:0.005:0.008:0.01:0.02:0.05:0.08:0.1:0.2:0.3:0.5', + help='desired signal rms') + parser.add_argument('--foreground-prob', type=float, default = 0.7, + help = 'probability with which to add foreground noise ' + 'on non-pair impulse noise files') + parser.add_argument('--foreground-prob-for-pair', type=float, default = 0.4, + help = 'probability with which to add foreground noise ' + 'on impulse-noise pairs') + parser.add_argument('--reverb-prob', type=float, default=0.4, + help = 'probability with which to add reverberation') + 
parser.add_argument('--random-seed', type = int, default = 0, + help = 'seed to be used in the randomization') + parser.add_argument('--output-clean-wav-file-list', type=str, + help='file to write clean output') + parser.add_argument('--output-noise-wav-file-list', type=str, + help='file to write noise output') + parser.add_argument('wav_scp', type=str, + help='wav.scp file to corrupt') + parser.add_argument('output_wav_scp', type=str, + help='wav.scp file to write corrupted output') + parser.add_argument('impulses_noises_dir', type=str, + help='directory with impulses and noises and info ' + 'directory (created by local/snr/prepare_noise_impulses.sh)') + parser.add_argument('output_command_file', type=str, + help='file to output the corruption commands') + params = parser.parse_args() + + add_noise = True + background_snr_string_parts = params.background_snrs.split(':') + foreground_snr_string_parts = params.foreground_snrs.split(':') + if (len(background_snr_string_parts) == 1 and + background_snr_string_parts[0] == "inf" and + len(foreground_snr_string_parts) == 1 and + foreground_snr_string_parts[0] == "inf"): + add_noise = False + + background_snrs = list_cyclic_iterator(background_snr_string_parts, + random_seed = params.random_seed) + foreground_snrs = list_cyclic_iterator(foreground_snr_string_parts, + random_seed = params.random_seed) + rms_amplitudes = list_cyclic_iterator(params.rms_amplitudes.split(':'), + random_seed = params.random_seed) + + wav_files = return_nonempty_lines(open(params.wav_scp, 'r').readlines()) + wav_out_files = return_nonempty_lines( + open(params.output_wav_scp, 'r').readlines()) + assert(len(wav_files) == len(wav_out_files)) + + if params.output_clean_wav_file_list is not None: + clean_wav_out_files = return_nonempty_lines( + open(params.output_clean_wav_file_list, 'r').readlines()) + assert(len(wav_files) == len(clean_wav_out_files)) + # TODO: They must also be corresponding files. 
This must be checked + # somewhere down the line. + if params.output_noise_wav_file_list is not None: + noise_wav_out_files = return_nonempty_lines( + open(params.output_noise_wav_file_list, 'r').readlines()) + assert(len(wav_files) == len(noise_wav_out_files)) + + impulses = list_cyclic_iterator(return_nonempty_lines( + open(params.impulses_noises_dir+'/info/impulse_files').readlines()), + random_seed = params.random_seed) # This list could be empty + + background_noises = list_cyclic_iterator(return_nonempty_lines( + open(params.impulses_noises_dir+'/info/background_noise_files').readlines()), + random_seed = params.random_seed) + # This must ideally not be empty because it will create infinities in SNR objective + + foreground_noises = list_cyclic_iterator(return_nonempty_lines( + open(params.impulses_noises_dir+'/info/foreground_noise_files').readlines()), + random_seed = params.random_seed) + # This list could be empty too, which just means we won't be adding any foreground noise + + # noise-impulse pair files. If a background noise has a corresponding pair + # then with a high probability, an rir paired with it will be selected. + # Also there will be a low probability for adding foreground noise. 
+ noises_impulses_files = glob.glob(params.impulses_noises_dir+'/info/noise_impulse_*') + impulse_noise_index = [] + + for file in noises_impulses_files: + noises_list = [] + impulses_set = set([]) + for line in return_nonempty_lines(open(file).readlines()): + line = line.strip() + if len(line) == 0 or line[0] == '#': + continue + parts = line.split('=') + if parts[0].strip() == 'noise_files': + noises_list = list_cyclic_iterator(parts[1].split(), + random_seed = params.random_seed) + elif parts[0].strip() == 'impulse_files': + impulses_set = set(parts[1].split()) + else: + raise Exception('Unknown format of ' + file) + impulse_noise_index.append([impulses_set, noises_list]) + + command_list = [] + for i in range(len(wav_files)): + wav_file_splits = wav_files[i].split() + wav_file = " ".join(wav_file_splits[1:]) + file_id = wav_file_splits[0] + + splits = wav_out_files[i].split() + output_file_id = splits[0] + output_wav_file = " ".join(splits[1:]) + + if random.uniform(0,1) <= params.reverb_prob: + # randomly select corruption parameters + impulse_file = impulses.next() + else: + impulse_file = None + + found_impulse_noise_pair = False + if impulse_file is not None: + for x in impulse_noise_index: + if impulse_file in x[0]: + found_impulse_noise_pair = True + background_noise_file = x[1].next() + foreground_prob = params.foreground_prob_for_pair + + if not found_impulse_noise_pair: + foreground_prob = params.foreground_prob + background_noise_file = background_noises.next() + + background_snr = background_snrs.next() + rms_amplitude = rms_amplitudes.next() + + assert(len(wav_file.strip()) > 0) + assert(impulse_file is None or len(impulse_file.strip()) > 0) + assert(len(background_noise_file.strip()) > 0) + assert(len(background_snr.strip()) > 0) + assert(len(output_wav_file.strip()) > 0) + + rir_opts = '' + if impulse_file is not None: + rir_opts = '--rir-file={0}'.format(impulse_file) + + background_noise_opts = '--background-noise-file={0} 
--background-snr-db={1}'.format(background_noise_file, background_snr) + + foreground_noise_opts = '' + if random.uniform(0,1) <= foreground_prob: + foreground_snr = foreground_snrs.next() + foreground_noise_files = foreground_noises.next_few(10) + if foreground_noise_files is not None: + foreground_noise_opts = '--foreground-noise-files={0} --foreground-snr-db={1}'.format(":".join(foreground_noise_files), foreground_snr) + + volume_opts = "" + if rms_amplitude is not None: + assert(len(rms_amplitude.strip()) > 0) + volume_opts = "--volume=-1 --rms-amplitude={0} --normalize-by-power=true".format(rms_amplitude) + + if params.output_clean_wav_file_list is not None: + splits = clean_wav_out_files[i].split() + assert(output_file_id == splits[0]) + output_clean_wav_opts = "--output-clean-file={0}".format(" ".join(splits[1:])) + if params.output_noise_wav_file_list is not None: + splits = noise_wav_out_files[i].split() + assert(output_file_id == splits[0]) + output_noise_wav_opts = "--output-noise-file={0}".format(" ".join(splits[1:])) + + # wav_file here is something like "cat .wav |" + command = "{0} {1} corrupt-wav {2} {3} {4} {5} {6} {7} - {8}\n".format(output_file_id, wav_file, rir_opts, background_noise_opts, foreground_noise_opts, volume_opts, output_clean_wav_opts, output_noise_wav_opts, output_wav_file) + command_list.append(command) + + file_handle = open(params.output_command_file, 'w') + file_handle.write("".join(command_list)) + file_handle.close() + diff --git a/egs/wsj_noisy/s5/local/snr/create_segmented_data_dir_from_vad.sh b/egs/wsj_noisy/s5/local/snr/create_segmented_data_dir_from_vad.sh new file mode 100755 index 00000000000..962ec98e77c --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/create_segmented_data_dir_from_vad.sh @@ -0,0 +1,76 @@ +#!/bin/bash + +set -o pipefail +set -e +set -u + +. path.sh + +cmd=run.pl +nj=4 +feats= + +. 
utils/parse_options.sh + +if [ $# -ne 6 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/train_100k exp/vad_data_prep_train_100k/file_vad/segments exp/vad_data_prep_train_100k/file_vad/seg2utt exp/make_segmented_feats/log segmented_feats data/train_100k_seg" + exit 1 +fi + +data=$1 +segments=$2 +seg2utt_file=$3 +tmpdir=$4 +featdir=$5 +dir=$6 + + +utils/copy_data_dir.sh --extra-files utt2uniq $data $dir + +if [ -z "$feats" ]; then + feats=$data/feats.scp + cp $data/cmvn.scp $dir +fi + +for f in $data/spk2utt $segments $seg2utt_file $feats; do + if [ ! -f $f ]; then + echo "$0: Could not read file $f" + exit 1 + fi +done + +featdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $featdir ${PWD}` + +mkdir -p $tmpdir $featdir + +rm -f $dir/{cmvn.scp,feats.scp,utt2spk,spk2utt,utt2uniq,spk2utt,text} +utils/filter_scp.pl -f 2 $data/spk2utt $segments > $dir/segments + +awk '{print $1" "$2}' $dir/segments > $dir/utt2spk + +utils/utt2spk_to_spk2utt.pl $dir/utt2spk > $dir/spk2utt + +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $featdir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/wsj_noisy-$(date +'%m_%d_%H_%M')/s5/$featdir $featdir/storage +fi + +data_id=`basename $data` + +for n in `seq $nj`; do + utils/create_data_link.pl $featdir/storage/raw_feats_${data_id}.$n.ark +done + +$cmd JOB=1:$nj $tmpdir/extract_feature_segments.JOB.log \ + extract-feature-segments scp:$feats \ + "ark,t:utils/split_scp.pl -j $nj \$[JOB-1] $dir/segments |" \ + ark:- \| copy-feats --compress=true ark:- \ + ark,scp:$featdir/raw_feats_${data_id}.JOB.ark,$featdir/raw_feats_${data_id}.JOB.scp + +for n in `seq $nj`; do + cat $featdir/raw_feats_${data_id}.$n.scp +done | sort -k1,1 > $dir/feats.scp + +utils/fix_data_dir.sh $dir + diff --git a/egs/wsj_noisy/s5/local/snr/create_segmented_vad_from_vad.sh b/egs/wsj_noisy/s5/local/snr/create_segmented_vad_from_vad.sh new file mode 100755 index 00000000000..ca95486dd6d --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/create_segmented_vad_from_vad.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +set -o pipefail +set -e +set -u + +. path.sh + +cmd=run.pl +nj=4 + +. utils/parse_options.sh + +if [ $# -ne 4 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/train_100k exp/vad_data_prep_train_100k/file_vad exp/make_segmented_vad/train_100k exp/vad_data_prep_train_100k/seg_vad" + exit 1 +fi + +data=$1 +vad_dir=$2 +tmpdir=$3 +dir=$4 + +segments=$vad_dir/segments +seg2utt_file=$vad_dir/reco2utt +vad_scp=$vad_dir/vad.scp + +for f in $vad_scp $segments $seg2utt_file; do + if [ ! 
-f $f ]; then + echo "$0: Could not read file $f" + exit 1 + fi +done + +mkdir -p $dir/split$nj + +utils/filter_scp.pl -f 2 $data/utt2spk $segments > $dir/segments +utils/filter_scp.pl -f 2 $data/utt2spk $seg2utt_file > $dir/reco2utt + +$cmd JOB=1:$nj $tmpdir/extract_vad_segments.JOB.log \ + extract-int-vector-segments scp:$vad_scp \ + "ark,t:utils/split_scp.pl -j $nj \$[JOB-1] $dir/segments |" \ + ark:- \| segmentation-init-from-ali ark:- ark:- \| \ + segmentation-post-process --merge-labels=0:2 --merge-dst-label=0 ark:- ark:- \| \ + segmentation-to-ali ark:- ark,scp:$dir/split$nj/vad.JOB.ark,$dir/split$nj/vad.JOB.scp || exit 1 + +for n in `seq $nj`; do + cat $dir/split$nj/vad.$n.scp +done | sort -k1,1 > $dir/vad.scp + diff --git a/egs/wsj_noisy/s5/local/snr/create_snr_data_dir.sh b/egs/wsj_noisy/s5/local/snr/create_snr_data_dir.sh new file mode 100755 index 00000000000..cc5b3e8d736 --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/create_snr_data_dir.sh @@ -0,0 +1,115 @@ +#!/bin/bash +set -e +set -o pipefail + +. path.sh + +append_to_orig_feats=true +add_frame_snr=false +add_pov_feature=false +nj=4 +cmd=run.pl +stage=0 +dataid= +compress=true +type=Snr + +. utils/parse_options.sh + +if [ $# -ne 5 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/train_100k_whole_hires exp/frame_snrs_snr_train_100k_whole exp/make_snr_data_dir snr_feats data/train_100k_whole_snr" + exit 1 +fi + +data=$1 +snr_dir=$2 +tmpdir=$3 +featdir=$4 +dir=$5 + +extra_files= +$append_to_orig_feats && extra_files="$extra_files $data/feats.scp" +$add_frame_snr && extra_files="$extra_files $snr_dir/frame_snrs.scp" + +scp_file=$snr_dir/nnet_pred_snrs.scp +type_str=snr +if [ $type == "Irm" ]; then + type_str=irm + scp_file=$snr_dir/nnet_pred.scp +fi + +for f in $scp_file $extra_files; do + [ ! 
-f $f ] && echo "$0: Could not find $f" && exit 1 +done + +featdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $featdir ${PWD}` + +[ -z "$dataid" ] && dataid=`basename $data` +mkdir -p $dir $featdir $tmpdir/$dataid + +for n in `seq $nj`; do + utils/create_data_link.pl $featdir/appended_${type_str}_feats_$dataid.$n.ark +done + +if $add_pov_feature; then +if [ $stage -le 0 ]; then + utils/split_data.sh $data $nj + sdata=$data/split$nj + $cmd JOB=1:$nj $tmpdir/make_pov_feat_$dataid.JOB.log \ + compute-kaldi-pitch-feats --config=conf/pitch.conf \ + scp:$sdata/JOB/wav.scp ark:- \| process-kaldi-pitch-feats \ + --add-delta-pitch=false --add-normalized-log-pitch=false ark:- \ + ark,scp:$featdir/pov_feature_$dataid.JOB.ark,$featdir/pov_feature_$dataid.JOB.scp || exit 1 +fi +fi + +for n in `seq $nj`; do + cat $featdir/pov_feature_$dataid.$n.scp +done > $snr_dir/pov_feature.scp + +if [ $stage -le 1 ]; then + if $append_to_orig_feats; then + utils/split_data.sh $data $nj + sdata=$data/split$nj + + if [ $type == "Snr" ]; then + if $add_frame_snr; then + append_opts="paste-feats ark:- scp:$snr_dir/frame_snrs.scp ark:- |" + fi + fi + if $add_pov_feature; then + append_opts="paste-feats --length-tolerance=2 ark:- scp:$snr_dir/pov_feature.scp ark:- |" + fi + + $cmd JOB=1:$nj $tmpdir/$dataid/make_append_${type_str}_feats.JOB.log \ + paste-feats scp:$sdata/JOB/feats.scp scp:$scp_file ark:- \| \ + $append_opts copy-feats --compress=$compress ark:- \ + ark,scp:$featdir/appended_${type_str}_feats_$dataid.JOB.ark,$featdir/appended_${type_str}_feats_$dataid.JOB.scp || exit 1 + else + if [ $type == "Snr" ]; then + if $add_frame_snr; then + append_opts="paste-feats scp:- scp:$snr_dir/frame_snrs.scp ark:- |" + fi + fi + if $add_pov_feature; then + append_opts="paste-feats --length-tolerance=2 scp:- scp:$snr_dir/pov_feature.scp ark:- |" + fi + + $cmd JOB=1:$nj $tmpdir/$dataid/make_append_${type_str}_feats.JOB.log \ + utils/split_scp.pl -j $nj \$[JOB-1] 
$scp_file \| \ + $append_opts copy-feats --compress=$compress ark:- \ + ark,scp:$featdir/appended_${type_str}_feats_$dataid.JOB.ark,$featdir/appended_${type_str}_feats_$dataid.JOB.scp || exit 1 + fi + +fi + +utils/copy_data_dir.sh $data $dir +rm -f $dir/cmvn.scp + +steps/compute_cmvn_stats.sh --fake $dir $tmpdir/$dataid $featdir + +for n in `seq $nj`; do + cat $featdir/appended_${type_str}_feats_$dataid.$n.scp +done > $dir/feats.scp + diff --git a/egs/wsj_noisy/s5/local/snr/get_corruption_parameter_lists.py b/egs/wsj_noisy/s5/local/snr/get_corruption_parameter_lists.py new file mode 100755 index 00000000000..d760676da11 --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/get_corruption_parameter_lists.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python +# Copyright 2014 Johns Hopkins University (Authors: Vijayaditya Peddinti). +# 2015 Vimal Manohar +# Apache 2.0. +# script to generate multicondition training data / dev data / test data +import argparse, glob, math, os, random, scipy.io.wavfile, sys + +class list_cyclic_iterator: + def __init__(self, list, random_seed = 0): + self.list_index = 0 + self.list = list + random.seed(random_seed) + random.shuffle(self.list) + + def next(self): + if (len(self.list) == 0): + return None + item = self.list[self.list_index] + self.list_index = (self.list_index + 1) % len(self.list) + return item + +# Return non-empty lines from a list of lines +def return_nonempty_lines(lines): + new_lines = [] + for line in lines: + if len(line.strip()) > 0: + new_lines.append(line.strip()) + return new_lines + +def exists_wavfile(file_name): + return os.path.isfile(file_name) + try: + scipy.io.wavfile.read(file_name) + return True + except IOError: + return False + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--snrs', type=str, default = '20:10:0', + help='snrs to be used for corruption') + parser.add_argument('--signal-dbs', type=str, default = '2:0:-2:-5:-10:-20:-40:-60:-80:-120:-160:-200:-300:-400', + 
help='clean signal dbs to be used') + parser.add_argument('--num-files-per-job', type=int, default = None, + help='number of commands for corruption to be stored in each file -- This is the number of parallel jobs that will be run') + parser.add_argument('--check-output-exists', type = str, default = 'True', + help = 'process file only if output file does not exist', choices = ['True', 'true', 'False', 'false']) + parser.add_argument('--random-seed', type = int, default = 0, + help = 'seed to be used in the randomization of corruption') + parser.add_argument('--clean-wav-scp-file', type=str, help='list file to write the clean output file before adding noise to create corrupted file') + parser.add_argument('--noise-wav-scp-file', type=str, help='list file to write the noise that is added to create the corrupted output file') + parser.add_argument('wav_scp_file', type=str, help='wav.scp file to corrupt') + parser.add_argument('output_wav_scp_file', type=str, help='list file to write corrupted output') + parser.add_argument('impulses_noises_dir', type=str, help='directory with impulses and noises and info directory (e.g. 
created by local/multicondition/prep_rirs.sh)') + parser.add_argument('output_command_file', type=str, help='file to output the corruption commands') + params = parser.parse_args() + + add_noise = True + snr_string_parts = params.snrs.split(':') + if (len(snr_string_parts) == 1) and snr_string_parts[0] == "inf": + add_noise = False + snrs = list_cyclic_iterator(snr_string_parts, random_seed = params.random_seed) + + signal_db_string_parts = params.signal_dbs.split(':') + signal_dbs = list_cyclic_iterator(signal_db_string_parts, random_seed = params.random_seed) + + if params.check_output_exists.lower() == 'true': + params.check_output_exists = True + else: + params.check_output_exists = False + + wav_files = return_nonempty_lines(open(params.wav_scp_file, 'r').readlines()) + wav_out_files = return_nonempty_lines(open(params.output_wav_scp_file, 'r').readlines()) + assert(len(wav_files) == len(wav_out_files)) + if params.clean_wav_scp_file is not None: + clean_wav_out_files = return_nonempty_lines(open(params.clean_wav_scp_file, 'r').readlines()) + assert(len(wav_files) == len(clean_wav_out_files)) + if params.noise_wav_scp_file is not None: + noise_wav_out_files = return_nonempty_lines(open(params.noise_wav_scp_file, 'r').readlines()) + assert(len(wav_files) == len(noise_wav_out_files)) + impulses = list_cyclic_iterator(return_nonempty_lines(open(params.impulses_noises_dir+'/info/impulse_files').readlines()), random_seed = params.random_seed) # This list could be empty + noises_impulses_files = glob.glob(params.impulses_noises_dir+'/info/noise_impulse_*') + impulse_noise_index = [] + + all_impulses_set = set() + for file in noises_impulses_files: + noises_list = [] + impulses_set = set([]) + for line in return_nonempty_lines(open(file).readlines()): + line = line.strip() + if len(line) == 0 or line[0] == '#': + continue + parts = line.split('=') + if parts[0].strip() == 'noise_files': + noises_list = list_cyclic_iterator(parts[1].split()) + elif parts[0].strip() == 
'impulse_files': + impulses_set = set(parts[1].split()) + all_impulses_set.union(impulses_set) + else: + raise Exception('Unknown format of ' + file) + impulse_noise_index.append([impulses_set, noises_list]) + impulses = list_cyclic_iterator(list(all_impulses_set)) + + if params.num_files_per_job is None: + lines_per_file = len(wav_files) + else: + lines_per_file = params.num_files_per_job + num_parts = int(math.ceil(len(wav_files)/ float(lines_per_file))) # The number of parallel jobs + indices_per_file = map(lambda x: xrange(lines_per_file * (x-1), lines_per_file * x), range(1, num_parts)) + indices_per_file.append(xrange(lines_per_file * (num_parts-1), len(wav_files))) + + part_counter = 1 + commands_file_base, ext = os.path.splitext(params.output_command_file) + for indices in indices_per_file: + command_list = [] + for i in indices: + wav_file = " ".join(wav_files[i].split()[1:]) # Can be a pipe input + output_wav_file = wav_out_files[i] # An actual wave file + clean_wav_file = '' + noise_wav_file = '' + if params.clean_wav_scp_file is not None: + clean_wav_file = ''.join(['--out-clean-file ', clean_wav_out_files[i], ' ']) + if params.noise_wav_scp_file is not None: + noise_wav_file = ''.join(['--out-noise-file ', noise_wav_out_files[i], ' ']) + impulse_file = impulses.next() # Can be None + noise_file = '' + snr = '' + signal_db = '' + found_impulse = (impulse_file is not None) + found_noise = False + if add_noise: + for j in xrange(len(impulse_noise_index)): + if impulse_file is None and not impulse_noise_index[j][0]: + noise_file = impulse_noise_index[j][1].next() + snr = snrs.next() + signal_db = signal_dbs.next() + assert(len(wav_file.strip()) > 0) + assert(len(noise_file.strip()) > 0) + assert(len(snr.strip()) > 0) + assert(len(signal_db.strip()) > 0) + assert(len(output_wav_file.strip()) > 0) + command_list.append("{0} --noise-file {1} --snr-db {2} --signal-db {3} {4}{5}- {6} \n".format(wav_file, noise_file, snr, signal_db, clean_wav_file, 
noise_wav_file, output_wav_file)) + found_noise = True + break + if impulse_file in impulse_noise_index[j][0]: + noise_file = impulse_noise_index[j][1].next() + snr = snrs.next() + signal_db = signal_dbs.next() + assert(len(wav_file.strip()) > 0) + assert(len(impulse_file.strip()) > 0) + assert(len(noise_file.strip()) > 0) + assert(len(snr.strip()) > 0) + assert(len(signal_db.strip()) > 0) + assert(len(output_wav_file.strip()) > 0) + command_list.append("{0} --rir-file {1} --noise-file {2} --snr-db {3} --signal-db {4} {5}{6}- {7} \n".format(wav_file, impulse_file, noise_file, snr, signal_db, clean_wav_file, noise_wav_file, output_wav_file)) + found_impulse = True + found_noise = True + break + if not found_noise: + continue + assert (found_impulse) + assert(len(wav_file.strip()) > 0) + assert(len(impulse_file.strip()) > 0) + assert(len(output_wav_file.strip()) > 0) + command_list.append("{0} --rir-file {1} {2}{3}- {4} \n".format(wav_file, impulse_file, clean_wav_file, noise_wav_file, output_wav_file)) + if params.check_output_exists and exists_wavfile(output_wav_file): + # we perform the check at this point to ensure replication of (wavfile, impulse, noise, snr) tuples across runs. + command_list.pop() + file_handle = open("{0}.{1}{2}".format(commands_file_base, part_counter, ext), 'w') + part_counter += 1 + file_handle.write("".join(command_list)) + file_handle.close() + print num_parts diff --git a/egs/wsj_noisy/s5/local/snr/get_weights_for_ivector_extraction.sh b/egs/wsj_noisy/s5/local/snr/get_weights_for_ivector_extraction.sh new file mode 100755 index 00000000000..7cb26f0b06d --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/get_weights_for_ivector_extraction.sh @@ -0,0 +1,146 @@ +#!/bin/bash + +# Copyright 2015 Vimal Manohar +# Apache 2.0 + +set -o pipefail + +. 
path.sh + +cmd=run.pl +method=Viterbi +stage=-1 + +# Viterbi options +min_silence_duration=30 # minimum number of frames for silence +min_speech_duration=30 # minimum number of frames for speech +nonsil_self_loop_probability=0.9 +nonsil_transition_probability=0.1 +sil_self_loop_probability=0.9 +sil_transition_probability=0.1 +silence_weight=0 +speech_prior=0.2 +sil_prior=0.8 + +# Decoding options +acwt=1 +beam=10 +max_active=7000 + +. utils/parse_options.sh + +data_dir=data/dev_aspire_whole_seg_v102 +file_vad_dir=exp/vad_dev_aspire_v102 +dir=exp/nnet2_multicondition/ivector_weights_dev_aspire_whole_seg_v102 + +if [ $# -ne 3 ]; then + echo "Usage: $0 " + echo " e.g.: $0 $data_dir $file_vad_dir $dir" + exit 1 +fi + +data_dir=$1 +file_vad_dir=$2 +dir=$3 + +for f in $file_vad_dir/log_likes.1.scp; do + if [ ! -f $f ]; then + echo "$0: Could not find $f" + exit 1 + fi +done + +mkdir -p $dir + +nj=`cat $file_vad_dir/num_jobs` || exit 1 +utils/split_data.sh $data_dir $nj || exit 1 + +perl -e "\$sum_prior = $speech_prior + $sil_prior; printf ('[ %f %f ]', log($sil_prior)-log(\$sum_prior), log($speech_prior)-log(\$sum_prior));" > $dir/log_priors + +case $method in + "Weighting") + $cmd JOB=1:$nj $dir/log/extract_weights.JOB.log \ + extract-feature-segments --snip-edges=true \ + "ark:cat $file_vad_dir/log_likes.*.ark |" \ + ark,t:$data_dir/split$nj/JOB/segments ark:- \| \ + matrix-add-offset ark:- $dir/log_priors ark:- \| \ + logprob-to-post ark:- ark:- \| \ + weight-pdf-post $silence_weight 0 ark:- ark:- \| \ + post-to-weights ark:- "ark:|gzip -c > $dir/weights.JOB.gz" || exit 1 + ;; + "Viterbi") + # Prepare a lang directory + if [ $stage -le 1 ]; then + mkdir -p $dir/local/dict + mkdir -p $dir/local/lm + + echo "1" > $dir/local/dict/silence_phones.txt + echo "1" > $dir/local/dict/optional_silence.txt + echo "2" > $dir/local/dict/nonsilence_phones.txt + echo -e "1 1\n2 2" > $dir/local/dict/lexicon.txt + echo -e "1\n2\n1 2" > $dir/local/dict/extra_questions.txt + + mkdir -p 
$dir/lang + diarization/prepare_vad_lang.sh --num-sil-states $min_silence_duration \ + --num-nonsil-states $min_speech_duration \ + $dir/local/dict $dir/local/lang $dir/lang || exit 1 + fi + + feat_dim=2 # dummy. We don't need this. + if [ $stage -le 2 ]; then + $cmd $dir/log/create_transition_model.log gmm-init-mono \ + $dir/lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 + fi + + lang=$dir/lang_test_sp${speech_prior}_sil${sil_prior} + if [ $stage -le 3 ]; then + cp -r $dir/lang $lang + perl -e '$sil_prior = shift @ARGV; $speech_prior = shift @ARGV; $s = $sil_prior + $speech_prior; $sil_prior = $sil_prior / $s; $speech_prior = $speech_prior / $s; $s = $sil_prior + $speech_prior; print "0 0 1 1 " . -log($sil_prior/(1.1 * $s)) . "\n0 0 2 2 ". -log($speech_prior/(1.1 * $s)). "\n0 ". -log(0.1 / 1.1)' $sil_prior $speech_prior | \ + fstcompile --isymbols=$lang/words.txt --osymbols=$lang/words.txt \ + --keep_isymbols=false --keep_osymbols=false \ + > $lang/G.fst || exit 1 + fi + + if [ $stage -le 4 ]; then + $cmd $dir/log/make_vad_graph.log \ + diarization/make_vad_graph.sh --iter trans \ + $lang $dir $dir/graph_test_${t}x || exit 1 + fi + + file_nj=`cat $file_vad_dir/num_jobs` || exit 1 + + log_likes=ark:$file_vad_dir/log_likes.JOB.ark + + decoder_opts+=(--acoustic-scale=$acwt --beam=$beam --max-active=$max_active) + + if [ $stage -le 5 ]; then + $cmd JOB=1:$file_nj $dir/log/decode.JOB.log \ + decode-faster-mapped ${decoder_opts[@]} \ + $dir/trans.mdl \ + $dir/graph_test_${t}x/HCLG.fst $log_likes \ + ark:/dev/null ark:- \| \ + ali-to-pdf $dir/trans.mdl ark:- \ + "ark:|gzip -c > $dir/ali.JOB.gz" || exit 1 + fi + + if [ $stage -le 6 ]; then + $cmd JOB=1:$nj $dir/log/extract_weights.JOB.log \ + extract-int-vector-segments --snip-edges=true \ + "ark:gunzip -c $dir/ali.*.gz |" \ + ark,t:$data_dir/split$nj/JOB/segments ark:- \| \ + ali-to-post ark:- ark:- \| \ + weight-pdf-post $silence_weight 0 ark:- ark:- \| \ + 
post-to-weights ark:- "ark:|gzip -c > $dir/weights.JOB.gz" || exit 1 + fi + + ;; + *) + echo "$0: Unknown method $method for weights extraction" + exit 1 + ;; +esac + +for n in `seq $nj`; do cat $dir/weights.$n.gz; done > $dir/weights.gz + + diff --git a/egs/wsj_noisy/s5/local/snr/make_irm_targets.sh b/egs/wsj_noisy/s5/local/snr/make_irm_targets.sh new file mode 100644 index 00000000000..9868bb0e4ae --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/make_irm_targets.sh @@ -0,0 +1,55 @@ +#!/bin/bash + +# Copyright 2015 Vimal Manohar +# Apache 2.0 + +nj=4 +cmd=run.pl +compress=true +data_id= + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# != 5 ]; then + echo "Usage: $0 [options] "; + echo "e.g.: $0 data/train_clean_fbank data/train_noise_fbank data/train_corrupted_hires exp/make_snr_targets/train snr_targets" + echo "options: " + echo " --nj # number of parallel jobs" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." 
+ exit 1; +fi + +clean_fbank_dir=$1 +noise_fbank_dir=$2 +dir=$3 +tmpdir=$4 +targets_dir=$5 + +mkdir -p $targets_dir + +[ -z "$data_id" ] && data_id=`basename $dir` + +utils/split_data.sh $clean_fbank_dir $nj +utils/split_data.sh $noise_fbank_dir $nj + +targets_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $targets_dir ${PWD}` + +for n in `seq $nj`; do + utils/create_data_link.pl $targets_dir/${data_id}.$n.ark +done + +# Log file follows the same convention as make_snr_targets.sh +# ($tmpdir/make_<targets>_<data_id>.JOB.log); the old "$tmpdir/${tmpdir}_..." +# nested the tmpdir path inside itself. "|| exit 1" is needed because this +# script does not use "set -e". +$cmd JOB=1:$nj $tmpdir/make_`basename $targets_dir`_${data_id}.JOB.log \ + compute-snr-targets --target-type="Irm" \ + scp:$clean_fbank_dir/split$nj/JOB/feats.scp \ + scp:$noise_fbank_dir/split$nj/JOB/feats.scp \ + ark:- \| \ + copy-feats --compress=$compress ark:- \ + ark,scp:$targets_dir/${data_id}.JOB.ark,$targets_dir/${data_id}.JOB.scp || exit 1 + +for n in `seq $nj`; do + cat $targets_dir/${data_id}.$n.scp +done > $dir/`basename $targets_dir`.scp + diff --git a/egs/wsj_noisy/s5/local/snr/make_snr_targets.sh b/egs/wsj_noisy/s5/local/snr/make_snr_targets.sh new file mode 100755 index 00000000000..86ed40d24f3 --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/make_snr_targets.sh @@ -0,0 +1,94 @@ +#!/bin/bash + +# Copyright 2015 Vimal Manohar +# Apache 2.0 +set -e +set -o pipefail + +nj=4 +cmd=run.pl +compress=true +data_id= +target_type=Irm +ali_rspecifier= +silence_phones_str=0 +apply_exp=false +ignore_noise_dir=false +ceiling=inf +floor=-inf +stage=0 +length_tolerance=2 + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# != 5 ]; then + echo "Usage: $0 [options] --target-type (Irm|Snr) "; + echo " or : $0 [options] --target-type FbankMask "; + echo "e.g.: $0 data/train_clean_fbank data/train_noise_fbank data/train_corrupted_hires exp/make_snr_targets/train snr_targets" + echo "options: " + echo " --nj # number of parallel jobs" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." 
+ exit 1; +fi + +clean_fbank_dir=$1 +noise_or_noisy_fbank_dir=$2 +dir=$3 +tmpdir=$4 +targets_dir=$5 + +mkdir -p $targets_dir + +[ -z "$data_id" ] && data_id=`basename $dir` + +utils/split_data.sh $clean_fbank_dir $nj +utils/split_data.sh $noise_or_noisy_fbank_dir $nj + +$ignore_noise_dir && utils/split_data.sh $dir $nj + +targets_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $targets_dir ${PWD}` + +for n in `seq $nj`; do + utils/create_data_link.pl $targets_dir/${data_id}.$n.ark +done + +#if [ ! -z "$ali_rspecifier" ]; then +# if ! $apply_exp; then +# echo "If --ali-rspecifier is specified, then apply-exp must be true." +# exit 1 +# fi +#fi + +apply_exp_opts= +if $apply_exp; then + apply_exp_opts=" copy-matrix --apply-exp=true ark:- ark:- |" +fi + +if [ $stage -le 1 ]; then + if ! $ignore_noise_dir; then + $cmd JOB=1:$nj $tmpdir/make_`basename $targets_dir`_${data_id}.JOB.log \ + compute-snr-targets --length-tolerance=$length_tolerance --target-type=$target_type ${ali_rspecifier:+--ali-rspecifier="$ali_rspecifier" --silence-phones=$silence_phones_str} \ + --floor=$floor --ceiling=$ceiling \ + scp:$clean_fbank_dir/split$nj/JOB/feats.scp \ + scp:$noise_or_noisy_fbank_dir/split$nj/JOB/feats.scp \ + ark:- \|$apply_exp_opts \ + copy-feats --compress=$compress ark:- \ + ark,scp:$targets_dir/${data_id}.JOB.ark,$targets_dir/${data_id}.JOB.scp || exit 1 + else + feat_dim=$(feat-to-dim scp:$clean_fbank_dir/feats.scp -) || exit 1 + $cmd JOB=1:$nj $tmpdir/make_`basename $targets_dir`_${data_id}.JOB.log \ + compute-snr-targets --length-tolerance=$length_tolerance --target-type=$target_type ${ali_rspecifier:+--ali-rspecifier="$ali_rspecifier" --silence-phones=$silence_phones_str} \ + --floor=$floor --ceiling=$ceiling --binary-targets --target-dim=$feat_dim \ + scp:$dir/split$nj/JOB/feats.scp \ + ark:- \|$apply_exp_opts \ + copy-feats --compress=$compress ark:- \ + ark,scp:$targets_dir/${data_id}.JOB.ark,$targets_dir/${data_id}.JOB.scp || 
exit 1 + fi +fi + +for n in `seq $nj`; do + cat $targets_dir/${data_id}.$n.scp +done > $dir/`basename $targets_dir`.scp diff --git a/egs/wsj_noisy/s5/local/snr/normalize_wavs.py b/egs/wsj_noisy/s5/local/snr/normalize_wavs.py new file mode 100755 index 00000000000..83055d45be2 --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/normalize_wavs.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +# Copyright 2014 Johns Hopkins University (Authors: Vijayaditya Peddinti). Apache 2.0. + +# normalizes the wave files provided in input file list with a common scaling factor +# the common scaling factor is computed to 1/\sqrt(1/(total_samples) * \sum_i{\sum_j x_i(j)^2}) where total_samples is sum of all samples of all wavefiles. If the data is multi-channel then each channel is treated as a seperate wave files +import argparse, scipy.io.wavfile, warnings, numpy as np, math + +def get_normalization_coefficient(file_list, is_rir, additional_scaling): + assert(len(file_list) > 0) + sampling_rate = None + total_energy = 0.0 + total_samples = 0.0 + prev_dtype_max_value = None + for file in file_list: + try: + [rate, data] = scipy.io.wavfile.read(file) + if not str(data.dtype) in set(['int16', 'int32', 'int64']): + raise Exception('Cannot process {0}, only wav files of integer type are suppported'.format(file)) + + dtype_max_value = np.iinfo(data.dtype).max + # ensure that all the data in the current list is of the same format + if prev_dtype_max_value is not None: + assert(dtype_max_value == prev_dtype_max_value) + prev_dtype_max_value = dtype_max_value + + if len(data.shape) == 1: + data = data.reshape([data.shape[0], 1]) + if sampling_rate is not None: + assert(rate == sampling_rate) + else: + sampling_rate = rate + data = data / dtype_max_value + if is_rir: + # just count the energy of the direct impulse response + # this is treated as energy of signal from 0.001 seconds before impulse + # to 0.05 seconds after impulse. 
This is done as we do not want the + # recording length to influence the scaling factor + channel_one = data[:, 0] + max_d = max(channel_one) + delay_impulse = [i for i, j in enumerate(channel_one) if j == max_d][0] + before_impulse = np.floor(rate * 0.001) + after_impulse = np.floor(rate * 0.05) + start_index = max(0, delay_impulse - before_impulse) + end_index = min(len(channel_one), delay_impulse + after_impulse) + else: + start_index = 0 + end_index = data.shape[0] + # numpy does not check for numerical overflow in integer type + # so we convert the data into floats + data = data.astype(np.float64) + total_energy += np.sum(data[start_index:end_index, :] ** 2) + data_shape = list(data.shape) + data_shape[0] = end_index-start_index + total_samples += np.prod(data_shape) + except IOError: + warnings.warn("Did not find the file {0}.".format(file)) + assert(total_samples > 0) + scaling_coefficient = np.sqrt(total_samples / total_energy) + print "Scaling coefficient is {0}.".format(scaling_coefficient) + if math.isnan(scaling_coefficient): + raise Exception(" Nan encountered while computing scaling coefficient. 
This is mostly due to numerical overflow") + return scaling_coefficient + +if __name__ == "__main__": + usage = """ Python script to normalize input wave file list""" + + parser = argparse.ArgumentParser(usage) + parser.add_argument('--is-room-impulse-response', type=str, default = "false", help='is the input a list of room impulse responses', choices = ['True', 'False', 'true', 'false']) + parser.add_argument('--extra-scaling-factor', type=float, default = 1.0, help='additional scaling factor to be multiplied with the wav files') + parser.add_argument('input_file_list', type=str, help='list of wav files to be normalized collectively') + parser.add_argument('output_file', type=str, help='output file to store normalization coefficient') + params = parser.parse_args() + if params.is_room_impulse_response.lower() == 'true': + params.is_room_impulse_response = True + else: + params.is_room_impulse_response = False + + file_list = [] + for line in open(params.input_file_list).readlines(): + if len(line.strip()) > 0 : + file_list.append(line.strip()) + norm_coefficient = get_normalization_coefficient(file_list, params.is_room_impulse_response, params.extra_scaling_factor) + out_file = open(params.output_file, 'w') + out_file.write('{0}'.format(norm_coefficient)) + out_file.close() diff --git a/egs/wsj_noisy/s5/local/snr/prepare_impulses_noises.sh b/egs/wsj_noisy/s5/local/snr/prepare_impulses_noises.sh new file mode 100755 index 00000000000..848d9104498 --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/prepare_impulses_noises.sh @@ -0,0 +1,198 @@ +#!/bin/bash +# set -e + +# Copyright 2014 Johns Hopkins University (Author: Vijayaditya Peddinti) +# Apache 2.0. +# This script processes RIRs available from available databases into +# 8Khz wav files so that these can be used to corrupt the Fisher data. +# The databases used are: +# RWCP : http://research.nii.ac.jp/src/en/RWCP-SSD.html (this data is mirrored @ openslr.org. 
Thanks to Mitsubishi Electric Research Laboratories) +# AIRD : Aachen Impulse response database (http://www.ind.rwth-aachen.de/en/research/tools-downloads/aachen-impulse-response-database/) +# Reverb2014 : http://reverb2014.dereverberation.com/download.html +# OpenAIR : http://www.openairlib.net/auralizationdb +# MARDY : http://www.commsp.ee.ic.ac.uk/~sap/resources/mardy-multichannel-acoustic-reverberation-database-at-york-database/ +# QMUL impulse response dataset : http://c4dm.eecs.qmul.ac.uk/rdr/handle/123456789/6 +# Impulse responses from Varechoic chamber at Bell Labs : http://www1.icsi.berkeley.edu/Speech/papers/gelbart-ms/pointers/ +# Concert Hall impulse responses, Aalto University : http://legacy.spa.aalto.fi/projects/poririrs/ +set -e +set -o pipefail +set -u + +stage=0 +download_rirs=true # download the RIRs +sampling_rate=16000 # sampling rate to be used for the RIRs +log_dir=exp/make_reverb/log # directory to store the log files +RIR_home=db/RIR_databases/ # parent directory of the RIR databases files +db_string="'aalto' 'air' 'rwcp' 'rvb2014' 'c4dm' 'varechoic' 'mardy' 'openair' 'musan'" # RIR dbs to be used in the experiment + +. cmd.sh +. path.sh +. 
utils/parse_options.sh + +echo $* +if [ $# -ne 1 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/impulse_noises" + exit 1; +fi + +output_dir=$1 +mkdir -p $output_dir/info ${output_dir}_non_normalized/info +mkdir -p $log_dir + +if [ -z "$db_string" ]; then + echo "$0 : Please specify the db_string."; + exit 1; +fi + +# write the file_splitter to create job files for queue.pl +# we use this to parallelize the audio corruption and download jobs +cat << EOF > $log_dir/file_splitter.py +#!/usr/bin/env python +import os.path, sys, math + +num_lines_per_file = int(sys.argv[1]) +input_file = sys.argv[2] +[file_base_name, ext] = os.path.splitext(input_file) +lines = open(input_file).readlines(); +num_lines = len(lines) +num_jobs = int(math.ceil(num_lines/ float(num_lines_per_file))) + +# filtering commands into separate task files +for i in xrange(1, num_jobs+1) : + cur_lines = map(lambda index: lines[index], range(i - 1, num_lines , num_jobs)) + file = open("{0}.{1}{2}".format(file_base_name, i, ext), 'w') + file.write("which python\n") + file.write("".join(cur_lines)) + file.close() +print num_jobs +EOF +chmod +x $log_dir/file_splitter.py + +if [ $stage -le 1 ]; then + echo "Extracting the impulse responses from the databases $db_string" + num_db_jobs=`echo $db_string|wc -w` + $decode_cmd JOB=1:$num_db_jobs $log_dir/log/DBprocess.JOB.log \ + db=\(0 $db_string \) \&\& \ + local/multi_condition/rirs/prep_\$\{db\[JOB\]\}.sh \ + --file-splitter "$log_dir/file_splitter.py 10 " \ + --download $download_rirs --sampling-rate $sampling_rate \ + $RIR_home ${output_dir}_non_normalized/\$\{db\[JOB\]\} $log_dir/\${db\[JOB\]\} || exit 1; +fi + +if [ $stage -le 2 ]; then + echo "Normalizing the extracted room impulse responses and noises, per type" + echo "Note: Due to wav-format mismatch between sox and scipy, there might be warnings generated during file normalization." + echo " 'WavFileWarning: Unknown wave file format' warnings are benign." 
+ # normalizing the RIR files + for i in `find ${output_dir}_non_normalized -name "*type*.rir.list"`; do + echo "Processing files in $i" + python local/multi_condition/normalize_wavs.py \ + --is-room-impulse-response true $i $i.normval || exit 1; + norm_coefficient=`cat $i.normval` + echo "" > $i.normalized + while read file_name; do + if [ ! -z $file_name ]; then + output_file_name=`echo $file_name | sed s:${output_dir}_non_normalized:${output_dir}:g` + mkdir -p `dirname $output_file_name` + sox --volume $norm_coefficient -t wav $file_name -t wav $output_file_name 2>/dev/null + echo $output_file_name >> $i.normalized + fi + done < $i + done +fi + +if [ $stage -le 3 ]; then + # normalizing the noise files + for i in `find ${output_dir}_non_normalized -name "*type*.noise.list" -o -name "*.background.noise.list" -o -name "*.foreground.noise.list"`; do + echo "Processing files in $i" + python local/multi_condition/normalize_wavs.py --is-room-impulse-response false $i $i.normval || exit 1; + norm_coefficient=`cat $i.normval` + echo "" > $i.normalized + while read file_name; do + if [ ! 
-z $file_name ]; then + output_file_name=`echo $file_name | sed s:${output_dir}_non_normalized:${output_dir}:g` + mkdir -p `dirname $output_file_name` + sox --volume $norm_coefficient -t wav $file_name -t wav $output_file_name 2>/dev/null + echo $output_file_name >> $i.normalized + fi + done < $i + done +fi + +if [ $stage -le 4 ]; then + # copying the noise-rir pairing files + db_string_bash=($(echo $db_string|sed -e "s/'//g")) + for i in `seq 0 $[${#db_string_bash[@]}-1]`; do + x=${db_string_bash[i]} + mkdir -p $output_dir/$x/info + cp ${output_dir}_non_normalized/$x/info/* $output_dir/$x/info + done + + # generating the rir-list + db_string_python=$(echo $db_string|sed -e "s/'\s\+'/','/g") +python -c " +import glob, string, re, sys +dbs=[$db_string_python] +rirs = [] +for db in dbs: + files = glob.glob('${output_dir}/{0}/info/{1}*type*.rir.list.normalized'.format(db,string.upper(db))) + for file in files: + for line in open(file).readlines(): + if len(line.strip()) > 0: + rirs.append(line.strip()) +if len(rirs) == 0: + sys.stderr.write('Did not read any rirs') + sys.exit(1) +final_rir_list_file = open('$output_dir/info/impulse_files', 'w') +final_rir_list_file.write('\n'.join(rirs)) +final_rir_list_file.close() +" + + # generating the background noise-list + db_string_python=$(echo $db_string|sed -e "s/'\s\+'/','/g") + python -c " +import glob, string, re, sys +dbs=[$db_string_python] +noises = [] +for db in dbs: + files = glob.glob('$output_dir/{0}/info/{1}*type*.noise.list.normalized'.format(db,string.upper(db))) + files.extend(glob.glob('$output_dir/{0}/info/{1}*.background.noise.list.normalized'.format(db,string.upper(db)))) + for file in files: + sys.stderr.write(file) + for line in open(file).readlines(): + if len(line.strip()) > 0: + noises.append(line.strip()) +if len(noises) == 0: + sys.stderr.write('Did not read any noises') + sys.exit(1) +final_noise_list_file = open('$output_dir/info/background_noise_files', 'w') 
+final_noise_list_file.write('\n'.join(noises)) +final_noise_list_file.close() +" + + # generating the foreground noise-list + db_string_python=$(echo $db_string|sed -e "s/'\s\+'/','/g") + python -c " +import glob, string, re, sys +dbs=[$db_string_python] +noises = [] +for db in dbs: + files = glob.glob('$output_dir/{0}/info/{1}*.foreground.noise.list.normalized'.format(db,string.upper(db))) + for file in files: + for line in open(file).readlines(): + if len(line.strip()) > 0: + noises.append(line.strip()) +if len(noises) == 0: + sys.stderr.write('Did not read any noises') + sys.exit(1) +final_noise_list_file = open('$output_dir/info/foreground_noise_files', 'w') +final_noise_list_file.write('\n'.join(noises)) +final_noise_list_file.close() +" + + wc -l $output_dir/info/impulse_files + wc -l $output_dir/info/background_noise_files + wc -l $output_dir/info/foreground_noise_files + +fi diff --git a/egs/wsj_noisy/s5/local/snr/prepare_unsad_data.sh b/egs/wsj_noisy/s5/local/snr/prepare_unsad_data.sh new file mode 100755 index 00000000000..ad9fb0fd981 --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/prepare_unsad_data.sh @@ -0,0 +1,550 @@ +#!/bin/bash + +set -u +set -e +set -o pipefail + +. path.sh + +stage=-2 +reco_nj=40 +nj=100 +cmd=queue.pl +map_noise_to_sil=true +map_unknown_to_speech=true +feat_type=mfcc +add_pitch=false +pitch_config= +phone_map= +feat_config= +config_dir=conf +outside_keep_proportion=1.0 +get_whole_recordings_and_weights=true + +. utils/parse_options.sh + +if [ $# -ne 6 ]; then + echo "This script takes a data directory and creates a new data directory " + echo "and speech activity labels " + echo "for the purpose of training a Universal Speech Activity Detector." 
+ echo "Usage: $0 [options] " + echo " e.g.: $0 data/train_100k data/lang exp/tri4a_ali_100k exp/vad_data_prep" + echo "" + echo "Main options (for others, see top of script file)" + echo " --config # config file containing options" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --file-nj <#njobs|4> # Split a whole data directory into these many pieces" + echo " --nj <#njobs|4> # Split a segmented data directory into these many pieces" + exit 1 +fi + +data_dir=$1 +lang=$2 +ali_dir=$3 +model_dir=$4 +out_data_dir=$5 +dir=$6 + +if [ $feat_type != "plp" ] && [ $feat_type != "mfcc" ]; then + echo "$0: --feat-type must be plp or mfcc. Must match the model_dir used." + exit 1 +fi + +[ -z "$phone_map" ] && phone_map=$config_dir/phone_map +[ -z "$feat_config" ] && feat_config=$config_dir/$feat_type.conf +[ -z "$pitch_config" ] && pitch_config=$config_dir/pitch.conf + +extra_files= + +if $add_pitch; then + extra_files="$extra_files $pitch_config" +fi + +for f in $phone_map $feat_config $extra_files; do + if [ ! 
-f $f ]; then + echo "$f could not be found" + exit 1 + fi +done + +mkdir -p $dir + +function make_mfcc { + local nj=$nj + local mfcc_config=$feat_config + local add_pitch=$add_pitch + local cmd=$cmd + local pitch_config=$pitch_config + + while [ $# -gt 0 ]; do + if [ $1 == "--nj" ]; then + nj=$2 + shift; shift; + elif [ $1 == "--mfcc-config" ]; then + mfcc_config=$2 + shift; shift; + elif [ $1 == "--add-pitch" ]; then + add_pitch=$2 + shift; shift; + elif [ $1 == "--cmd" ]; then + cmd=$2 + shift; shift; + elif [ $1 == "--pitch-config" ]; then + pitch_config=$2 + shift; shift; + else + break + fi + done + + if [ $# -ne 3 ]; then + echo "Usage: make_mfcc " + exit 1 + fi + + if $add_pitch; then + steps/make_mfcc_pitch.sh --cmd "$cmd" --nj $nj \ + --mfcc-config $mfcc_config --pitch-config $pitch_config $1 $2 $3 || exit 1 + else + steps/make_mfcc.sh --cmd "$cmd" --nj $nj \ + --mfcc-config $mfcc_config $1 $2 $3 || exit 1 + fi +} + +function make_plp { + local nj=$nj + local mfcc_config=$feat_config + local add_pitch=$add_pitch + local cmd=$cmd + local pitch_config=$pitch_config + + while [ $# -gt 0 ]; do + if [ $1 == "--nj" ]; then + nj=$2 + shift; shift; + elif [ $1 == "--plp-config" ]; then + plp_config=$2 + shift; shift; + elif [ $1 == "--add-pitch" ]; then + add_pitch=$2 + shift; shift; + elif [ $1 == "--cmd" ]; then + cmd=$2 + shift; shift; + elif [ $1 == "--pitch-config" ]; then + pitch_config=$2 + shift; shift; + else + break + fi + done + + if [ $# -ne 3 ]; then + echo "Usage: make_plp " + exit 1 + fi + + if $add_pitch; then + steps/make_plp_pitch.sh --cmd "$cmd" --nj $nj \ + --plp-config $plp_config --pitch-config $pitch_config $1 $2 $3 || exit 1 + else + steps/make_plp.sh --cmd "$cmd" --nj $nj \ + --plp-config $plp_config $1 $2 $3 || exit 1 + fi +} + +if $map_noise_to_sil || $map_unknown_to_speech; then + cat $phone_map | \ + awk -v map_noise_to_sil=$map_noise_to_sil -v map_unknown_to_speech=$map_unknown_to_speech \ + '{if ($2 == 2 && map_noise_to_sil == 
"true") print $1" 0"; + else if ($2 == 3 && map_unknown_to_speech == "true") print $1" 1"; + else print $0;}' > \ + $dir/phone_map + # NOTE: both flags must be compared against the literal string "true" inside + # awk: a bare variable set to "false" is a non-empty string and hence truthy, + # so the old bare test mapped unknown phones to speech unconditionally. + phone_map=$dir/phone_map +fi + +data_id=$(basename $data_dir) + +utils/split_data.sh --per-reco $data_dir $reco_nj + +# Convert alignment for the provided segments into +# initial speech activity labels +vad_dir=$dir/`basename ${ali_dir}`_vad_${data_id} +if [ $stage -le -1 ]; then + diarization/convert_ali_to_vad.sh --phone-map $phone_map \ + --cmd "$cmd" \ + $data_dir $lang $ali_dir $vad_dir || exit 1 +fi + +[ ! -s $vad_dir/vad.scp ] && echo "$0: $vad_dir/vad.scp is empty" && exit 1 + +# Compute total lengths of each recording +if [ $stage -le 0 ]; then + $cmd JOB=1:$reco_nj $dir/log/get_recording_lengths.JOB.log \ + wav-to-duration scp:$data_dir/split$reco_nj/JOB/wav.scp \ + ark,t:- \| awk \'\{print \$1 " " int\(\$2 \* 100\)\}\' '>' $dir/reco_lengths.JOB.ark.txt || exit 1 + + for n in `seq $reco_nj`; do + cat $dir/reco_lengths.$n.ark.txt + done | sort -u > $dir/reco_lengths.ark.txt +fi + +# Create extended data directory that consists of the provided +# segments along with the segments outside it. +# This is basically dividing the whole recording into pieces +# consisting of pieces corresponding to the provided segments +# and outside the provided segments. 
+ +# First create the segments outside of the provided segments +extended_data_dir=$dir/${data_id}_extended +if [ $stage -le 1 ]; then + rm -rf $extended_data_dir + mkdir -p $extended_data_dir/split$reco_nj + utils/copy_data_dir.sh $data_dir $extended_data_dir + for f in cmvn.scp feats.scp text; do + rm -f $extended_data_dir/$f + done + + $cmd JOB=1:$reco_nj $dir/log/get_empty_segments.JOB.log \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=1 --ignore-missing=false \ + "ark:segmentation-init-from-lengths --label=0 ark,t:$dir/reco_lengths.JOB.ark.txt ark:- |" \ + "ark:segmentation-init-from-segments $data_dir/split$reco_nj/JOB/segments ark:- |" \ + ark:- \| segmentation-post-process --remove-labels=1 ark:- ark:- \| \ + segmentation-post-process --max-segment-length=1000 --post-process-label=0 \ + ark:- ark:- \| segmentation-to-segments --single-speaker=true --frame-overlap=0 \ + ark:- ark,t:$extended_data_dir/split$reco_nj/utt2spk_empty.JOB \ + ark,t:$extended_data_dir/split$reco_nj/segments_empty.JOB || exit 1 +fi + +awk '{print $1" "$2"-"$1}' $data_dir/segments > $data_dir/old2new.utt_map + +# Combine provided segments with segments outside the provided segments to +# create the extended data directory +if [ $stage -le 2 ] ; then + for n in `seq $reco_nj`; do + cat $data_dir/split$reco_nj/$n/segments | \ + utils/apply_map.pl -f 1 $data_dir/old2new.utt_map | \ + cat - $extended_data_dir/split$reco_nj/segments_empty.$n | \ + sort -k1,1 | tee $extended_data_dir/split$reco_nj/segments.$n + done > $extended_data_dir/segments + + awk '{print $1" "$2}' $extended_data_dir/segments > $extended_data_dir/utt2spk + + utils/utt2spk_to_spk2utt.pl $extended_data_dir/utt2spk > $extended_data_dir/spk2utt + utils/fix_data_dir.sh $extended_data_dir +fi + +## Create text for the extended data directory +if [ $stage -le 3 ]; then + mkdir -p $dir/split$reco_nj + for n in `seq $reco_nj`; do + cat $extended_data_dir/split$reco_nj/utt2spk_empty.$n | awk '{print 
$1}' > \ + $extended_data_dir/split$reco_nj/text_empty.$n || exit 1 + cat $data_dir/split$reco_nj/$n/text | \ + utils/apply_map.pl -f 1 $data_dir/old2new.utt_map | \ + cat - $extended_data_dir/split$reco_nj/text_empty.$n | sort -k1,1 | tee $extended_data_dir/split$reco_nj/text.$n + done > $extended_data_dir/text + utils/fix_data_dir.sh $extended_data_dir +fi +# NOTE: "sort -k1,1 | tee" above must be a single pipe (as in the segments +# loop of the previous stage); with the old "||" OR-list, tee only ran when +# sort failed, so the per-split text.$n files were never created. + +# Get initial voice activity labels for the outside segments and combine them +# with the voice activity labels for the provided segments. +# The extended voice activity labels are put in $dir/vad/vad.scp +if [ $stage -le 4 ]; then + mkdir -p $dir/vad + + # We split the initial vad.scp based on recording with the same splits as + # the other files + for n in `seq $reco_nj`; do + utils/filter_scp.pl $data_dir/split$reco_nj/$n/utt2spk $vad_dir/vad.scp | \ + utils/apply_map.pl -f 1 $data_dir/old2new.utt_map > \ + $dir/vad/vad_tmp.$n.scp || exit 1 + [ ! -s $dir/vad/vad_tmp.$n.scp ] && echo "$0: no utterances in $dir/vad/vad_tmp.$n.scp" && exit 1 + done + + $cmd JOB=1:$reco_nj $dir/log/get_empty_vad.JOB.log \ + segmentation-init-from-segments --label=0 --per-utt=true \ + $extended_data_dir/split$reco_nj/segments_empty.JOB ark:- \| \ + segmentation-to-ali ark:- ark,scp:$dir/vad/vad_empty.JOB.ark,$dir/vad/vad_empty.JOB.scp + + for n in `seq $reco_nj`; do + cat $dir/vad/vad_tmp.$n.scp $dir/vad/vad_empty.$n.scp | sort -k 1,1 | tee $dir/vad/vad.$n.scp + done > $dir/vad/vad.scp + + for n in `seq $reco_nj`; do + cat $dir/vad/vad_tmp.$n.scp + done > $dir/vad/vad_tmp.scp + + for n in `seq $reco_nj`; do + cat $dir/vad/vad_empty.$n.scp + done > $dir/vad/vad_empty.scp +fi + +# Make features for the extended data directory. +# At this stage, we can split into larger number of pieces. 
+if [ $stage -le 6 ]; then + if [ $feat_type == "mfcc" ]; then + make_mfcc --cmd "$cmd" --nj $nj \ + --mfcc-config $feat_config \ + --add-pitch $add_pitch --pitch-config $pitch_config \ + ${extended_data_dir} exp/make_mfcc/${data_id}_extended mfcc || exit 1 + elif [ $feat_type == "plp" ]; then + make_plp --cmd "$cmd" --nj $nj \ + --plp-config $feat_config \ + --add-pitch $add_pitch --pitch-config $pitch_config \ + ${extended_data_dir} exp/make_plp/${data_id}_extended plp || exit 1 + fi + utils/fix_data_dir.sh $extended_data_dir + + # We also create a temporary directory to compute cmvn stats + # only on the provided segments and copy the stats to the + # extended data directory + temp_data_dir=$dir/${data_id}_temp + + rm -rf $temp_data_dir || true + + awk '{print $2" "$1}' $data_dir/old2new.utt_map > $data_dir/new2old.utt_map + utils/subset_data_dir.sh --utt-list $data_dir/new2old.utt_map $extended_data_dir $temp_data_dir + + if [ $feat_type == "mfcc" ]; then + make_mfcc --cmd "$cmd" --nj $nj \ + --mfcc-config $feat_config \ + --add-pitch $add_pitch --pitch-config $pitch_config \ + ${temp_data_dir} exp/make_mfcc/${data_id}_temp mfcc || exit 1 + steps/compute_cmvn_stats.sh \ + ${temp_data_dir} exp/make_mfcc/${data_id}_temp mfcc || exit 1 + elif [ $feat_type == "plp" ]; then + make_plp --cmd "$cmd" --nj $nj \ + --plp-config $feat_config \ + --add-pitch $add_pitch --pitch-config $pitch_config \ + ${temp_data_dir} exp/make_plp/${data_id}_temp plp || exit 1 + steps/compute_cmvn_stats.sh \ + ${temp_data_dir} exp/make_plp/${data_id}_temp plp || exit 1 + fi + + cp ${temp_data_dir}/cmvn.scp $extended_data_dir + rm -rf $extended_data_dir/split* +fi + +# By default, we use word LM. If required, we can think +# consider phone LM +graph_dir=$model_dir/graph +if [ $stage -le 7 ]; then + if [ ! 
-d $graph_dir ]; then + utils/mkgraph.sh ${lang} $model_dir $graph_dir || exit 1 + fi +fi + +# Decode without lattice (get only best path) +if [ $stage -le 8 ]; then + steps/decode_nolats.sh --cmd "$cmd --mem 2G" --nj $nj \ + --max-active 1000 --beam 10.0 --write-words false \ + --write-alignments true \ + $graph_dir ${extended_data_dir} \ + ${model_dir}/decode_${data_id}_extended || exit 1 +fi + +# Get VAD based on the decoded best path +decode_vad_dir=$dir/${model_dir}_decode_vad_${data_id} +if [ $stage -le 9 ]; then + diarization/convert_ali_to_vad.sh --phone-map $phone_map \ + --cmd "$cmd" --model $model_dir/final.mdl \ + $extended_data_dir $graph_dir \ + $model_dir/decode_${data_id}_extended $decode_vad_dir || exit 1 +fi + + + for n in `seq $reco_nj`; do + cat $dir/vad/vad_tmp.$n.scp + done > $dir/vad/vad_tmp.scp + + for n in `seq $reco_nj`; do + cat $dir/vad/vad_empty.$n.scp + done > $dir/vad/vad_empty.scp +# Intersect the initial VAD with the VAD from the decode +if [ $stage -le 10 ]; then + vad_scps=() + mkdir -p $dir/vad/split$nj + mkdir -p $decode_vad_dir/split$nj + for n in `seq $nj`; do + vad_scps+=($dir/vad/split$nj/vad.$n.scp) + done + utils/split_scp.pl $dir/vad/vad.scp ${vad_scps[@]} + + mkdir -p $dir/intersected_segmentations + + # For outside of the provided segments, + # * Intersect the initial VAD and the decode VAD and label the mismatch + # regions as class 10, which can be removed later. 
+ $cmd JOB=1:$nj $dir/log/intersect_empty_segments.JOB.log \ + utils/filter_scp.pl $dir/vad/vad_empty.scp $dir/vad/split$nj/vad.JOB.scp \ + '>' $dir/vad/split$nj/vad_empty.JOB.scp '&&' \ + utils/filter_scp.pl $dir/vad/split$nj/vad_empty.JOB.scp $decode_vad_dir/vad.scp \ + '>' $decode_vad_dir/split$nj/vad_empty.JOB.scp '&&' \ + segmentation-intersect-segments --mismatch-label=10 \ + "ark:segmentation-init-from-ali scp:$dir/vad/split$nj/vad_empty.JOB.scp ark:- |" \ + "ark:segmentation-init-from-ali scp:$decode_vad_dir/split$nj/vad_empty.JOB.scp ark:- |" \ + ark:- \| segmentation-post-process --remove-labels=10 \ + --merge-adjacent-segments=true --max-intersegment-length=10 ark:- \ + ark,scp:$dir/intersected_segmentations/intersected_segmentations_empty.JOB.ark,$dir/intersected_segmentations/intersected_segmentations_empty.JOB.scp || exit 1 + + # For the provided segments, + # * For now, just convert the inital VAD into segmentations + $cmd JOB=1:$nj $dir/log/intersect_provided_segments.JOB.log \ + utils/filter_scp.pl $dir/vad/vad_tmp.scp $dir/vad/split$nj/vad.JOB.scp \ + '>' $dir/vad/split$nj/vad_tmp.JOB.scp '&&' \ + utils/filter_scp.pl $dir/vad/split$nj/vad_tmp.JOB.scp $decode_vad_dir/vad.scp \ + '>' $decode_vad_dir/split$nj/vad_tmp.JOB.scp '&&' \ + segmentation-intersect-segments --mismatch-label=10 \ + "ark:segmentation-init-from-ali scp:$dir/vad/split$nj/vad_tmp.JOB.scp ark:- |" \ + "ark:segmentation-init-from-ali scp:$decode_vad_dir/split$nj/vad_tmp.JOB.scp ark:- |" \ + ark:- \| segmentation-post-process --remove-labels=10 \ + --merge-adjacent-segments=true --max-intersegment-length=10 ark:- \ + ark,scp:$dir/intersected_segmentations/intersected_segmentations_tmp.JOB.ark,$dir/intersected_segmentations/intersected_segmentations_tmp.JOB.scp || exit 1 + + for n in `seq $nj`; do + cat $dir/intersected_segmentations/intersected_segmentations_empty.$n.scp + done > $dir/intersected_segmentations/intersected_segmentations_empty.scp + + for n in `seq $nj`; do + cat 
$dir/intersected_segmentations/intersected_segmentations_tmp.$n.scp + done > $dir/intersected_segmentations/intersected_segmentations_tmp.scp +fi + +# Optionally select only a small percentage of the outside utterances +# in the final set of utterances. This can be used to balance the amount of +# speech vs silence. +empty_copy_cmd="cat $dir/intersected_segmentations/intersected_segmentations_empty.scp" +if [ $outside_keep_proportion != 1.0 ]; then + nlines=`wc -l $dir/intersected_segmentations/intersected_segmentations_empty.scp` || exit 1 + empty_copy_cmd="utils/subset_scp.pl $nlines $dir/intersected_segmentations/intersected_segmentations_empty.scp" +fi + +if [ $stage -le 11 ]; then + eval $empty_copy_cmd | \ + cat - $dir/intersected_segmentations/intersected_segmentations_tmp.scp > \ + $dir/intersected_segmentations/final_segmentations_p$outside_keep_proportion.scp || exit 1 +fi + +if [ $stage -le 12 ]; then + awk '{print $1" "$2}' $extended_data_dir/segments | \ + utils/utt2spk_to_spk2utt.pl > $extended_data_dir/reco2utt + + mkdir -p $dir/reco_segmentations + mkdir -p $extended_data_dir/split$reco_nj + + reco2utts=() + for n in `seq $reco_nj`; do + reco2utts+=($extended_data_dir/split$reco_nj/reco2utt.$n) + done + utils/split_scp.pl $extended_data_dir/reco2utt ${reco2utts[@]} + + $cmd JOB=1:$reco_nj $dir/log/get_reco_segmentation.JOB.log \ + utils/spk2utt_to_utt2spk.pl $extended_data_dir/split$reco_nj/reco2utt.JOB '>' $extended_data_dir/split$reco_nj/utt2reco.JOB '&&' \ + segmentation-combine-segments \ + "scp:utils/filter_scp.pl $extended_data_dir/split$reco_nj/utt2reco.JOB $dir/intersected_segmentations/final_segmentations_p$outside_keep_proportion.scp |" \ + "ark,t:utils/filter_scp.pl $extended_data_dir/split$reco_nj/utt2reco.JOB $extended_data_dir/segments |" \ + ark,t:$extended_data_dir/split$reco_nj/reco2utt.JOB ark:- \| \ + segmentation-post-process --remove-labels=3 --merge-adjacent-segments=true \ + --max-segment-length=1000 ark:- \ + 
ark:$dir/reco_segmentations/reco_segmentation.JOB.ark || exit 1 + + mkdir -p $dir/reco_vad + + $cmd JOB=1:$reco_nj $dir/log/get_reco_vad.JOB.log \ + segmentation-to-ali --default-label=4 --lengths="ark,t:cat $dir/reco_lengths.*.ark.txt |" \ + ark:$dir/reco_segmentations/reco_segmentation.JOB.ark \ + ark,scp:$dir/reco_vad/vad.JOB.ark,$dir/reco_vad/vad.JOB.scp || exit 1 +fi + +if $get_whole_recordings_and_weights; then + if [ $stage -le 13 ]; then + rm -rf $dir/final_vad + mkdir -p $dir/final_vad + + # Get deriv weights assigning 1.0 to frames of classes 0, 1 or 2 + # and 0.0 to every other class. This ensures that training + # is done only on frames that have accurate labels + $cmd JOB=1:$reco_nj $dir/log/get_deriv_weights.JOB.log \ + segmentation-post-process --merge-labels=0:1:2 --merge-dst-label=1 \ + --remove-labels=3:4:10 \ + ark:$dir/reco_segmentations/reco_segmentation.JOB.ark ark:- \| \ + segmentation-to-ali --default-label=0 --lengths="ark,t:cat $dir/reco_lengths.*.ark.txt |" \ + ark:- ark:- \| ali-to-post ark:- ark:- \| weight-pdf-post 0 0 ark:- ark:- \| \ + post-to-weights ark:- \ + ark,scp:$dir/final_vad/deriv_weights.JOB.ark,$dir/final_vad/deriv_weights.JOB.scp || exit 1 + + # Get deriv weights assigning 1.0 to frames to class 0 + # and 0.0 to every other class. This ensures that training is + # done only on accurately labelled silence frames for the + # uncorrupted data. + # Note that we don't have noise for the uncorrupted data. + # Hence, we won't be able to get accurate sub-band IRM targets + # for the speech regions. But for silence regions, the IRM targets + # are 0 anyways. 
+ $cmd JOB=1:$reco_nj $dir/log/get_deriv_weights_for_uncorrupted.JOB.log \ + segmentation-create-subsegments --filter-label=1 ark:$dir/reco_segmentations/reco_segmentation.JOB.ark "ark:segmentation-init-from-segments $extended_data_dir/split$reco_nj/segments_empty.JOB ark:- |" ark:- \| \ + segmentation-post-process --remove-labels=1:2:3:4:10 ark:- ark:- \| \ + segmentation-post-process --merge-labels=0 --merge-dst-label=1 \ + ark:- ark:- \| \ + segmentation-to-ali --default-label=0 --lengths="ark,t:cat $dir/reco_lengths.*.ark.txt |" \ + ark:- ark:- \| ali-to-post ark:- ark:- \| weight-pdf-post 0 0 ark:- ark:- \| \ + post-to-weights ark:- \ + ark,scp:$dir/final_vad/deriv_weights_for_uncorrupted.JOB.ark,$dir/final_vad/deriv_weights_for_uncorrupted.JOB.scp || exit 1 + fi + + rm -rf $out_data_dir + diarization/convert_data_dir_to_whole.sh $extended_data_dir $out_data_dir + rm -f $out_data_dir/{feats.scp,cmvn.scp,text} + + for n in `seq $reco_nj`; do + cat $dir/final_vad/deriv_weights.$n.scp + done > $dir/final_vad/deriv_weights.scp + + for n in `seq $reco_nj`; do + cat $dir/final_vad/deriv_weights_for_uncorrupted.$n.scp + done > $dir/final_vad/deriv_weights_for_uncorrupted.scp + + echo "$0: Finished creating corpus for training Universal SAD with deriv weights" + exit 0 +fi + +# Split the recording into new segments in the output data directory. 
+# Create VAD corresponding to the same segments +if [ $stage -le 13 ]; then + rm -rf $out_data_dir + utils/copy_data_dir.sh $extended_data_dir $out_data_dir + rm -f $out_data_dir/{feats.scp,cmvn.scp,text} + + mkdir -p $out_data_dir/split$reco_nj + + $cmd JOB=1:$reco_nj $dir/log/split_reco_into_segments.JOB.log \ + segmentation-post-process --merge-labels=0:1:2:3:4:10 --merge-dst-label=1 \ + --max-intersegment-length=10 --max-segment-length=1000 --merge-adjacent-segments \ + ark:$dir/reco_segmentations/reco_segmentation.JOB.ark ark:- \| \ + segmentation-to-segments --single-speaker=true --frame-overlap=0 \ + ark:- ark,t:$out_data_dir/split$reco_nj/utt2spk.JOB \ + ark,t:$out_data_dir/split$reco_nj/segments.JOB || exit 1 + + for n in `seq $reco_nj`; do + cat $out_data_dir/split$reco_nj/segments.$n + done > $out_data_dir/segments + + for n in `seq $reco_nj`; do + cat $out_data_dir/split$reco_nj/utt2spk.$n + done > $out_data_dir/utt2spk + + utils/utt2spk_to_spk2utt.pl $out_data_dir/utt2spk > $out_data_dir/spk2utt + + mkdir -p $dir/final_vad + $cmd JOB=1:$reco_nj $dir/log/extract_segment_vad.JOB.log \ + extract-int-vector-segments ark:$dir/reco_vad/vad.JOB.ark \ + ark,t:$out_data_dir/split$reco_nj/segments.JOB \ + ark,scp:$dir/final_vad/vad.JOB.ark,$dir/final_vad/vad.JOB.scp || exit 1 + + for n in `seq $reco_nj`; do + cat $dir/final_vad/vad.$n.scp + done > $dir/final_vad/vad.scp +fi + +echo "$0: Finished creating corpus for training Universal SAD" diff --git a/egs/wsj_noisy/s5/local/snr/prepare_vad_training_data.sh b/egs/wsj_noisy/s5/local/snr/prepare_vad_training_data.sh new file mode 100755 index 00000000000..11eb5f68964 --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/prepare_vad_training_data.sh @@ -0,0 +1,290 @@ +#!/bin/bash + +set -u +set -e +set -o pipefail + +. path.sh +. cmd.sh + +stage=-2 +file_nj=40 +nj=100 +vad_dir= +ali_dir= +graph_dir=exp/tri4a/graph +transform_dir= +map_noise_to_sil=true +phone_map= +mfcc_config= + +. 
utils/parse_options.sh + +if [ $# -ne 5 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/train_100k data/train_100k_corrupted exp/tri4a_ali_100k exp/vad_data_prep" + exit 1 +fi + +data_dir=$1 +corrupted_data_dir=$2 +lang=$3 +model_dir=$4 +dir=$5 + +mkdir -p $dir + +data_id=$(basename $data_dir) + +utils/split_data.sh $data_dir $file_nj + +if [ -z "$ali_dir" ]; then + ali_dir=$dir/`basename ${model_dir}`_ali_${data_id} + if [ $stage -le -2 ]; then + steps/align_si.sh --cmd "$train_cmd" --nj $nj \ + $data_dir $lang $model_dir $ali_dir || exit 1 + fi +fi + +if [ -z "$phone_map" ] || [ -z "$mfcc_config" ]; then + echo phone-map and mfcc-config are required && exit 1 +fi + +if [ ! -f $phone_map ]; then + echo "$0: Expecting $phone_map to exist!" && exit 1 +fi + +if $map_noise_to_sil; then + cat $phone_map | awk '{if ($2 == 2) print $1" 0"; else print $0}' > $dir/phone_map + phone_map=$dir/phone_map +fi + +if [ -z "$vad_dir" ]; then + vad_dir=$dir/`basename ${model_dir}`_vad_${data_id} + if [ $stage -le -1 ]; then + diarization/convert_ali_to_vad.sh --phone-map $phone_map \ + --cmd "$train_cmd" \ + $data_dir $lang $ali_dir $vad_dir || exit 1 + fi +fi + +if [ $stage -le 0 ]; then + $train_cmd JOB=1:$file_nj $dir/log/get_file_lengths.JOB.log \ + wav-to-duration scp:$data_dir/split$file_nj/JOB/wav.scp \ + ark,t:- \| awk \'\{print \$1 " " int\(\$2 \* 100\)\}\' '>' $dir/file_lengths.JOB.ark || exit 1 +fi + +extended_data_dir=$dir/${data_id}_extended +if [ $stage -le 1 ]; then + rm -rf $extended_data_dir + mkdir -p $extended_data_dir/split$file_nj + utils/copy_data_dir.sh $data_dir $extended_data_dir + for f in cmvn.scp feats.scp text; do + rm -f $extended_data_dir/$f + done + + $train_cmd JOB=1:$file_nj $dir/log/get_empty_segments.JOB.log \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=1 --ignore-missing=false \ + "ark:segmentation-init-from-lengths --label=0 ark:$dir/file_lengths.JOB.ark ark:- |" \ + "ark:segmentation-init-from-segments 
$data_dir/split$file_nj/JOB/segments ark:- |" \ + ark:- \| segmentation-post-process --remove-labels=1 ark:- ark:- \| \ + segmentation-post-process --max-segment-length=1000 --post-process-label=0 \ + ark:- ark:- \| segmentation-to-segments --single-speaker=true --frame-overlap=0 \ + ark:- ark,t:$extended_data_dir/split$file_nj/utt2spk_empty.JOB \ + ark,t:$extended_data_dir/split$file_nj/segments_empty.JOB || exit 1 +fi + +if [ $stage -le 2 ] ; then + for n in `seq $file_nj`; do + cat $extended_data_dir/split$file_nj/utt2spk_empty.$n $data_dir/split$file_nj/$n/utt2spk | sort -k1,1 | tee $extended_data_dir/split$file_nj/utt2spk.$n + done > $extended_data_dir/utt2spk + + [ ! -s $extended_data_dir/utt2spk ] && echo "$0: $extended_data_dir/utt2spk is empty!" && exit 1 + + for n in `seq $file_nj`; do + cat $extended_data_dir/split$file_nj/segments_empty.$n $data_dir/split$file_nj/$n/segments | sort -k1,1 | tee $extended_data_dir/split$file_nj/segments.$n + done > $extended_data_dir/segments + + utils/utt2spk_to_spk2utt.pl $extended_data_dir/utt2spk > $extended_data_dir/spk2utt + utils/fix_data_dir.sh $extended_data_dir +fi + +if [ $stage -le 3 ]; then + mkdir -p $dir/split$file_nj + for n in `seq $file_nj`; do + cat $extended_data_dir/split$file_nj/utt2spk_empty.$n | awk '{print $1}' > \ + $extended_data_dir/split$file_nj/text_empty.$n || exit 1 + cat $data_dir/split$file_nj/$n/text $extended_data_dir/split$file_nj/text_empty.$n | sort -k1,1 | tee $extended_data_dir/split$file_nj/text.$n + done > $extended_data_dir/text + utils/fix_data_dir.sh $extended_data_dir +fi + +[ ! -s $vad_dir/vad.scp ] && echo "$0: $vad_dir/vad.scp is empty" && exit 1 +if [ $stage -le 4 ]; then + mkdir -p $dir/vad + for n in `seq $file_nj`; do + utils/filter_scp.pl $data_dir/split$file_nj/$n/utt2spk $vad_dir/vad.scp > \ + $dir/vad/vad_tmp.$n.scp || exit 1 + [ ! 
-s $dir/vad/vad_tmp.$n.scp ] && echo "$0: no utterances in $dir/vad/vad_tmp.$n.scp" && exit 1 + done + + $train_cmd JOB=1:$file_nj $dir/log/get_empty_vad.JOB.log \ + segmentation-init-from-segments --label=0 --per-utt=true $extended_data_dir/split$file_nj/segments_empty.JOB ark:- \| \ + segmentation-to-ali ark:- ark,scp:$dir/vad/vad_empty.JOB.ark,$dir/vad/vad_empty.JOB.scp + + for n in `seq $file_nj`; do + cat $dir/vad/vad_tmp.$n.scp $dir/vad/vad_empty.$n.scp | sort -k 1,1 | tee $dir/vad/vad.$n.scp + done > $dir/vad/vad.scp +fi + +if [ $stage -le 6 ]; then + steps/make_mfcc.sh --cmd "$train_cmd" --nj $nj --mfcc-config $mfcc_config ${extended_data_dir} \ + exp/make_mfcc/${data_id}_whole mfcc || exit 1 + utils/fix_data_dir.sh $extended_data_dir + + temp_data_dir=$dir/${data_id}_temp + utils/copy_data_dir.sh $data_dir ${temp_data_dir} + steps/make_mfcc.sh --cmd "$train_cmd" --nj $nj --mfcc-config $mfcc_config ${temp_data_dir} \ + exp/make_mfcc/${data_id}_temp mfcc || exit 1 + steps/compute_cmvn_stats.sh ${temp_data_dir} exp/make_mfcc/${data_id}_temp mfcc +fi + +[ -z "$model_dir" ] && model_dir=$ali_dir +[ -z "$graph_dir" ] && graph_dir=$model_dir/graph + +if [ $stage -le 7 ]; then + if [ ! 
-d $graph_dir ]; then + utils/mkgraph.sh ${lang} $model_dir $graph_dir || exit 1 + fi +fi + +if [ $stage -le 8 ]; then + steps/decode_nolats.sh --cmd "$train_cmd --mem 2G" --nj $nj --transform-dir "$transform_dir" \ + --max-active 1000 --beam 10.0 --write-words false --write-alignments true \ + $graph_dir ${extended_data_dir} ${model_dir}/decode_${data_id}_whole || exit 1 +fi + +decode_vad_dir=$dir/`basename ${model_dir}`_decode_vad_${data_id} +if [ $stage -le 9 ]; then + diarization/convert_ali_to_vad.sh --phone-map $phone_map \ + --cmd "$train_cmd" --model $model_dir/final.mdl \ + $extended_data_dir $graph_dir $model_dir/decode_${data_id}_whole $decode_vad_dir || exit 1 +fi + +if [ $stage -le 10 ]; then + vad_scps=() + mkdir -p $dir/vad/split$nj + mkdir -p $decode_vad_dir/split$nj + for n in `seq $nj`; do + vad_scps+=($dir/vad/split$nj/vad.$n.scp) + done + utils/split_scp.pl $dir/vad/vad.scp ${vad_scps[@]} + + mkdir -p $dir/intersected_segmentations + $train_cmd JOB=1:$nj $dir/log/intersect_segments_empty.JOB.log \ + utils/filter_scp.pl $data_dir/utt2spk $dir/vad/split$nj/vad.JOB.scp \ + '>' $dir/vad/split$nj/vad_tmp.JOB.scp '&&' \ + utils/filter_scp.pl --exclude $data_dir/utt2spk $dir/vad/split$nj/vad.JOB.scp \ + '>' $dir/vad/split$nj/vad_empty.JOB.scp '&&' \ + utils/filter_scp.pl $dir/vad/split$nj/vad_tmp.JOB.scp $decode_vad_dir/vad.scp \ + '>' $decode_vad_dir/split$nj/vad_tmp.JOB.scp '&&' \ + utils/filter_scp.pl $dir/vad/split$nj/vad_empty.JOB.scp $decode_vad_dir/vad.scp \ + '>' $decode_vad_dir/split$nj/vad_empty.JOB.scp '&&' \ + segmentation-intersect-segments --mismatch-label=10 \ + "ark:segmentation-init-from-ali scp:$dir/vad/split$nj/vad_empty.JOB.scp ark:- |" \ + "ark:segmentation-init-from-ali scp:$decode_vad_dir/split$nj/vad_empty.JOB.scp ark:- |" \ + ark,scp:$dir/intersected_segmentations/intersected_segmentations_empty.JOB.ark,$dir/intersected_segmentations/intersected_segmentations_empty.JOB.scp '&&' \ + segmentation-init-from-ali 
scp:$dir/vad/split$nj/vad_tmp.JOB.scp \ + ark,scp:$dir/intersected_segmentations/intersected_segmentations_tmp.JOB.ark,$dir/intersected_segmentations/intersected_segmentations_tmp.JOB.scp || exit 1 + + for n in `seq $nj`; do + cat $dir/intersected_segmentations/intersected_segmentations_empty.$n.scp + cat $dir/intersected_segmentations/intersected_segmentations_tmp.$n.scp + done > $dir/intersected_segmentations/final_segmentations.scp +fi + +#if [ $stage -le 11 ]; then +# for n in `seq +# utils/split_data.sh $extended_data_dir $nj +# +# $train_cmd JOB=1:$nj $dir/log/post_process_intersected_orig_segmentations.JOB.log \ +# utils/filter_scp.pl $data_dir/utt2spk $dir/intersected_segmentations.JOB.scp \| \ +# segmentation-post-process --remove-labels=10 --merge-adjacent-segments=true \ +# --max-intersegment-length=10 scp:- \ +# ark,scp:$dir/intersected_segmentations/intersected_segmentations_tmp.JOB.ark,$dir/intersected_segmentations/intersected_segmentations_tmp.JOB.scp || exit 1 +# +# $train_cmd JOB=1:$nj $dir/log/create_final_segmentations.JOB.log \ +# utils/filter_scp.pl --exclude $dir/intersected_segmentations/intersected_segmentations_tmp.JOB.scp \ +# $dir/intersected_segmentations/intersected_segmentations.JOB.scp \| \ +# cat $dir/intersected_segmentations/intersected_segmentations_tmp.JOB.scp - \| \ +# sort -k1,1 \| segmentation-post-process --remove-labels=10 --merge-adjacent-segments=true \ +# scp:- ark,scp:$dir/intersected_segmentations/final_segmentations.JOB.ark,$dir/intersected_segmentations/final_segmentations.JOB.scp || exit 1 +# +# for n in `seq $nj`; do +# cat $dir/intersected_segmentations/final_segmentations.$n.scp +# done > $dir/intersected_segmentations/final_segmentations.scp +#fi + +if [ $stage -le 12 ]; then + awk '{print $1" "$2}' $extended_data_dir/segments | \ + utils/utt2spk_to_spk2utt.pl > $extended_data_dir/reco2utt + + mkdir -p $dir/file_vad + + reco2utts=() + for n in `seq $file_nj`; do + 
reco2utts+=($extended_data_dir/split$file_nj/reco2utt.$n) + done + utils/split_scp.pl $extended_data_dir/reco2utt ${reco2utts[@]} + + $train_cmd JOB=1:$file_nj $dir/log/get_file_vad.JOB.log \ + utils/spk2utt_to_utt2spk.pl $extended_data_dir/split$file_nj/reco2utt.JOB '>' $extended_data_dir/split$file_nj/utt2reco.JOB '&&' \ + segmentation-combine-segments \ + "scp:utils/filter_scp.pl $extended_data_dir/split$file_nj/utt2reco.JOB $dir/intersected_segmentations/final_segmentations.scp |" \ + "ark,t:utils/filter_scp.pl $extended_data_dir/split$file_nj/utt2reco.JOB $extended_data_dir/segments |" \ + ark,t:$extended_data_dir/split$file_nj/reco2utt.JOB ark:- \| \ + segmentation-post-process --remove-labels=3:4 --merge-adjacent-segments=true ark:- ark:- \| \ + segmentation-to-ali --default-label=4 --lengths="ark:cat $dir/file_lengths.*.ark |" \ + ark:- ark,scp:$dir/file_vad/vad.JOB.ark,$dir/file_vad/vad.JOB.scp || exit 1 + + for n in `seq $file_nj`; do + cat $dir/file_vad/vad.$n.scp + done > $dir/file_vad/vad.scp +fi + +#vad_data_dir=$dir/${data_id}_vad +#if [ $stage -le 13 ]; then +# diarization/convert_data_dir_to_whole.sh $extended_data_dir $vad_data_dir +# utils/fix_data_dir.sh ${vad_data_dir} +# +# utils/copy_data_dir.sh ${vad_data_dir} ${vad_data_dir}_hires +# steps/make_mfcc.sh --mfcc-config conf/mfcc_hires.conf ${vad_data_dir}_hires exp/make_hires/${data_id}_vad mfcc_hires +# steps/compute_cmvn_stats.sh ${vad_data_dir}_hires exp/make_hires/${data_id}_vad mfcc_hires +# utils/fix_data_dir.sh ${vad_data_dir}_hires +# +# utils/copy_data_dir.sh ${vad_data_dir} ${vad_data_dir}_fbank +# steps/make_fbank.sh --fbank-config conf/fbank.conf ${vad_data_dir}_fbank exp/make_fbank/${data_id}_vad fbank +# steps/compute_cmvn_stats.sh --fake ${vad_data_dir}_fbank exp/make_fbank/${data_id}_vad mfcc_fbank +# utils/fix_data_dir.sh ${vad_data_dir}_fbank +#fi + +if [ $stage -le 13 ]; then + local/snr/prepare_vad_training_data.sh --nj $file_nj --cmd "$train_cmd" \ + --keep-only-speech 
false $dir/file_vad/ $dir/file_vad || exit 1 +fi + +# $train_cmd JOB=1:$file_nj $dir/file_vad/split$file_nj/log/get_segments.JOB.log \ +# segmentation-init-from-ali ark:$dir/file_vad/vad.JOB.ark ark:- \| \ +# segmentation-post-process --remove-labels=4:10 --merge-labels=0:1 --merge-dst-label=1 \ +# --shrink-length=20 --shrink-label=1 --merge-adjacent-segments=true --max-intersegment-length=1 \ +# ark:- ark:- \| segmentation-to-segments --single-speaker=true ark:- \ +# ark,t:$dir/file_vad/split$file_nj/reco2utt.JOB ark,t:$dir/file_vad/split$file_nj/segments.JOB || exit 1 +# +# for n in `seq $file_nj`; do +# cat $dir/file_vad/split$file_nj/reco2utt.$n +# done > $dir/file_vad/reco2utt +# +# for n in `seq $file_nj`; do +# cat $dir/file_vad/split$file_nj/segments.$n +# done > $dir/file_vad/segments +#fi diff --git a/egs/wsj_noisy/s5/local/snr/run_corrupt.sh b/egs/wsj_noisy/s5/local/snr/run_corrupt.sh new file mode 100755 index 00000000000..cfbafeb7b1a --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/run_corrupt.sh @@ -0,0 +1,319 @@ +#!/bin/bash +set -e +set -o pipefail + +. path.sh +. cmd.sh + +num_data_reps=5 +data_dir=data/train_si284 +dest_wav_dir=wavs +nj=40 +stage=1 +corruption_stage=-10 +pad_silence=false +mfcc_config=conf/mfcc_hires.conf +fbank_config=conf/fbank.conf +data_only=true +corrupt_only=true +dry_run=true + +. 
utils/parse_options.sh + +if [ $# -ne 0 ]; then + echo "Usage: $0" + exit 1 +fi + +dataid=`basename ${data_dir}` + +if [ $stage -le $num_data_reps ]; then + corrupted_data_dirs= + start_stage=1 + if [ $stage -gt 1 ]; then + start_stage=$stage + fi + for x in `seq $start_stage $num_data_reps`; do + cur_dest_dir=data/temp_${dataid}_$x + output_clean_dir=data/temp_clean_${dataid}_$x + output_noise_dir=data/temp_noise_${dataid}_$x + local/snr/corrupt_data_dir.sh --dry-run $dry_run --random-seed $x --dest-wav-dir $dest_wav_dir/corrupted$x \ + --output-clean-wav-dir $dest_wav_dir/clean$x --output-clean-dir $output_clean_dir \ + --output-noise-wav-dir $dest_wav_dir/noise$x --output-noise-dir $output_noise_dir \ + --pad-silence $pad_silence --stage $corruption_stage --tmp-dir exp/make_corrupt_$dataid/$x \ + --nj $nj $data_dir data/impulse_noises $cur_dest_dir + corrupted_data_dirs+=" $cur_dest_dir" + clean_data_dirs+=" $output_clean_dir" + noise_data_dirs+=" $output_noise_dir" + done + + rm -rf ${data_dir}_{corrupted,clean,noise} + utils/combine_data.sh --extra-files utt2uniq ${data_dir}_corrupted ${corrupted_data_dirs} + utils/combine_data.sh --extra-files utt2uniq ${data_dir}_clean ${clean_data_dirs} + utils/combine_data.sh --extra-files utt2uniq ${data_dir}_noise ${noise_data_dirs} + + rm -rf $corrupted_data_dirs + rm -rf $clean_data_dirs +fi + +data_id=`basename $data_dir` +corrupted_data_dir=${data_dir}_corrupted +corrupted_data_id=`basename $corrupted_data_dir` +clean_data_dir=${data_dir}_clean +clean_data_id=`basename $clean_data_dir` +noise_data_dir=${data_dir}_noise +noise_data_id=`basename $noise_data_dir` + +$corrupt_only && echo "--corrupt-only is true" && exit 1 + +mfccdir=mfcc_hires +#if [ $stage -le 2 ]; then +# if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $mfccdir/storage ]; then +# date=$(date +'%m_%d_%H_%M') +# utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/wsj_noisy-$date/s5/$mfccdir/storage $mfccdir/storage +# fi +# +# utils/copy_data_dir.sh ${clean_data_dir} ${clean_data_dir}_hires +# steps/make_mfcc.sh --cmd "$train_cmd" --nj $nj --mfcc-config $mfcc_config ${clean_data_dir}_hires exp/make_hires/${clean_data_id} mfcc_hires +# steps/compute_cmvn_stats.sh ${clean_data_dir}_hires exp/make_hires/${clean_data_id} mfcc_hires +# utils/fix_data_dir.sh ${clean_data_dir}_hires +#fi + +if [ $stage -le 12 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + date=$(date +'%m_%d_%H_%M') + utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/wsj_noisy-$date/s5/$mfccdir/storage $mfccdir/storage + fi + + rm -rf ${corrupted_data_dir}_hires + utils/copy_data_dir.sh ${corrupted_data_dir} ${corrupted_data_dir}_hires + steps/make_mfcc.sh --cmd "$train_cmd" --nj $nj --mfcc-config $mfcc_config ${corrupted_data_dir}_hires exp/make_hires/${corrupted_data_id} mfcc_hires || true + steps/compute_cmvn_stats.sh --fake ${corrupted_data_dir}_hires exp/make_hires/${corrupted_data_id} mfcc_hires + utils/fix_data_dir.sh --utt-extra-files utt2uniq ${corrupted_data_dir}_hires +fi + +fbankdir=fbank_feats +if [ $stage -le 13 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $fbankdir/storage ]; then + date=$(date +'%m_%d_%H_%M') + utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/wsj_noisy-$date/s5/$fbankdir/storage $fbankdir/storage + fi + + rm -rf ${clean_data_dir}_fbank + utils/copy_data_dir.sh ${clean_data_dir} ${clean_data_dir}_fbank + steps/make_fbank.sh --cmd "$train_cmd --max-jobs-run 50" --nj $nj --fbank-config $fbank_config ${clean_data_dir}_fbank exp/make_fbank/${clean_data_id} fbank_feats || true + steps/compute_cmvn_stats.sh --fake ${clean_data_dir}_fbank exp/make_fbank/${clean_data_id} fbank_feats + utils/fix_data_dir.sh ${clean_data_dir}_fbank +fi + +if [ $stage -le 14 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $fbankdir/storage ]; then + date=$(date +'%m_%d_%H_%M') + utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/wsj_noisy-$date/s5/$fbankdir/storage $fbankdir/storage + fi + + rm -rf ${noise_data_dir}_fbank + utils/copy_data_dir.sh ${noise_data_dir} ${noise_data_dir}_fbank + steps/make_fbank.sh --cmd "$train_cmd --max-jobs-run 50" --nj $nj --fbank-config $fbank_config ${noise_data_dir}_fbank exp/make_fbank/${noise_data_id} fbank_feats || true + steps/compute_cmvn_stats.sh --fake ${noise_data_dir}_fbank exp/make_fbank/${noise_data_id} fbank_feats + utils/fix_data_dir.sh ${noise_data_dir}_fbank +fi + +if [ $stage -le 15 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $fbankdir/storage ]; then + date=$(date +'%m_%d_%H_%M') + utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/wsj_noisy-$date/s5/$fbankdir/storage $fbankdir/storage + fi + + rm -rf ${corrupted_data_dir}_fbank + utils/copy_data_dir.sh ${corrupted_data_dir} ${corrupted_data_dir}_fbank + steps/make_fbank.sh --cmd "$train_cmd --max-jobs-run 50" --nj $nj --fbank-config $fbank_config ${corrupted_data_dir}_fbank exp/make_fbank/${corrupted_data_id} fbank_feats || true + steps/compute_cmvn_stats.sh --fake ${corrupted_data_dir}_fbank exp/make_fbank/${corrupted_data_id} fbank_feats + utils/fix_data_dir.sh --utt-extra-files utt2uniq ${corrupted_data_dir}_fbank +fi + +if [ $stage -le 16 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + date=$(date +'%m_%d_%H_%M') + utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/wsj_noisy-$date/s5/$mfccdir/storage $mfccdir/storage + fi + + rm -rf ${clean_data_dir}_hires + utils/copy_data_dir.sh ${clean_data_dir} ${clean_data_dir}_hires + steps/make_mfcc.sh --cmd "$train_cmd" --nj $nj --mfcc-config $mfcc_config ${clean_data_dir}_hires exp/make_hires/${clean_data_id} mfcc_hires || true + steps/compute_cmvn_stats.sh --fake ${clean_data_dir}_hires exp/make_hires/${clean_data_id} mfcc_hires + utils/fix_data_dir.sh --utt-extra-files utt2uniq ${clean_data_dir}_hires +fi + +if [ $stage -le 17 ]; then + utils/copy_data_dir.sh --utt-prefix "clean-" --spk-prefix "clean-" ${clean_data_dir}_fbank ${clean_data_dir}_clean_fbank + utils/copy_data_dir.sh --utt-prefix "clean-" --spk-prefix "clean-" ${clean_data_dir}_hires ${clean_data_dir}_clean_hires +fi + +if [ $stage -le 18 ]; then + rm -rf ${data_dir}_hires + utils/copy_data_dir.sh ${data_dir} ${data_dir}_hires + steps/make_mfcc.sh --cmd "$train_cmd" --nj $nj --mfcc-config $mfcc_config ${data_dir}_hires exp/make_hires/${data_id} mfcc_hires || true + steps/compute_cmvn_stats.sh --fake ${data_dir}_hires exp/make_hires/${data_id} 
mfcc_hires + utils/fix_data_dir.sh --utt-extra-files utt2uniq ${data_dir}_hires +fi + +if [ $stage -le 19 ]; then + rm -rf ${data_dir}_fbank + utils/copy_data_dir.sh ${data_dir} ${data_dir}_fbank + steps/make_fbank.sh --cmd "$train_cmd --max-jobs-run 50" --nj $nj --fbank-config $fbank_config ${data_dir}_fbank exp/make_fbank/${data_id} fbank_feats || true + steps/compute_cmvn_stats.sh --fake ${data_dir}_fbank exp/make_fbank/${data_id} fbank_feats + utils/fix_data_dir.sh --utt-extra-files utt2uniq ${data_dir}_fbank +fi + +if [ $stage -le 20 ]; then +utils/combine_data.sh --extra-files utt2uniq ${data_dir}_multi_fbank ${corrupted_data_dir} ${clean_data_dir}_clean_fbank ${data_dir}_fbank +utils/combine_data.sh --extra-files utt2uniq ${data_dir}_multi_hires ${corrupted_data_dir} ${clean_data_dir}_clean_hires ${data_dir}_hires +fi + +[ $(cat ${clean_data_dir}_fbank/utt2spk | wc -l) -ne $(cat ${corrupted_data_dir}_fbank/utt2spk | wc -l) ] && echo "$0: ${clean_data_dir}_fbank/utt2spk and ${corrupted_data_dir}_fbank/utt2spk have different number of lines" && exit 1 + +[ $(cat ${noise_data_dir}_fbank/utt2spk | wc -l) -ne $(cat ${corrupted_data_dir}_fbank/utt2spk | wc -l) ] && echo "$0: ${noise_data_dir}_fbank/utt2spk and ${corrupted_data_dir}_fbank/utt2spk have different number of lines" && exit 1 + +$data_only && echo "--data-only is true" && exit 1 + +tmpdir=exp/make_irm_targets +targets_dir=irm_targets +if [ $stage -le 16 ]; then + utils/split_data.sh ${clean_data_dir}_fbank $nj + utils/split_data.sh ${noise_data_dir}_fbank $nj + + sleep 2 + + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $targets_dir/storage ]; then + date=$(date +'%m_%d_%H_%M') + utils/create_split_dir.pl /export/b0{5,6,7,8}/$USER/kaldi-data/egs/wsj_noisy-$date/s5/$targets_dir/storage $targets_dir/storage + for n in `seq $nj`; do + utils/create_data_link.pl $targets_dir/${data_id}.$n.ark + done + fi + + mkdir -p $targets_dir + $train_cmd --max-jobs-run 30 JOB=1:$nj $tmpdir/${tmpdir}_${data_id}.JOB.log \ + compute-snr-targets --target-type="Irm" \ + scp:${clean_data_dir}_fbank/split$nj/JOB/feats.scp \ + scp:${noise_data_dir}_fbank/split$nj/JOB/feats.scp \ + ark:- \| \ + copy-feats --compress=true ark:- \ + ark,scp:$targets_dir/${data_id}.JOB.ark,$targets_dir/${data_id}.JOB.scp + + for n in `seq $nj`; do + cat $targets_dir/${data_id}.$n.scp + done > ${corrupted_data_dir}_hires/`basename $targets_dir`.scp +fi + +exit 0 + +tmpdir=exp/make_fbank_mask_targets +targets_dir=fbank_mask_targets +if [ $stage -le 17 ]; then + utils/split_data.sh ${corrupted_data_dir}_fbank $nj + utils/split_data.sh ${clean_data_dir}_fbank $nj + + sleep 2 + + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $targets_dir/storage ]; then + date=$(date +'%m_%d_%H_%M') + utils/create_split_dir.pl /export/b0{1,2,3,4}/$USER/kaldi-data/egs/wsj_noisy-$date/s5/$targets_dir/storage $targets_dir/storage + for n in `seq $nj`; do + utils/create_data_link.pl $targets_dir/${data_id}.$n.ark + done + fi + + mkdir -p $targets_dir + $train_cmd --max-jobs-run 30 JOB=1:$nj $tmpdir/${tmpdir}_${data_id}.JOB.log \ + compute-snr-targets --target-type="FbankMask" \ + scp:${clean_data_dir}_fbank/split$nj/JOB/feats.scp \ + scp:${corrupted_data_dir}_fbank/split$nj/JOB/feats.scp \ + ark:- \| \ + copy-feats --compress=true ark:- \ + ark,scp:$targets_dir/${data_id}.JOB.ark,$targets_dir/${data_id}.JOB.scp + + for n in `seq $nj`; do + cat $targets_dir/${data_id}.$n.scp + done > ${corrupted_data_dir}_hires/`basename $targets_dir`.scp +fi + +tmpdir=exp/make_snr_targets +targets_dir=snr_targets +if [ $stage -le 18 ]; then + utils/split_data.sh ${clean_data_dir}_fbank $nj + utils/split_data.sh ${noise_data_dir}_fbank $nj + + sleep 2 + + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $targets_dir/storage ]; then + date=$(date +'%m_%d_%H_%M') + utils/create_split_dir.pl /export/b0{5,6,7,8}/$USER/kaldi-data/egs/wsj_noisy-$date/s5/$targets_dir/storage $targets_dir/storage + for n in `seq $nj`; do + utils/create_data_link.pl $targets_dir/${data_id}.$n.ark + done + fi + + mkdir -p $targets_dir + $train_cmd --max-jobs-run 30 JOB=1:$nj $tmpdir/${tmpdir}_${data_id}.JOB.log \ + compute-snr-targets --target-type="Snr" \ + scp:${clean_data_dir}_fbank/split$nj/JOB/feats.scp \ + scp:${noise_data_dir}_fbank/split$nj/JOB/feats.scp \ + ark:- \| \ + copy-feats --compress=true ark:- \ + ark,scp:$targets_dir/${data_id}.JOB.ark,$targets_dir/${data_id}.JOB.scp + + for n in `seq $nj`; do + cat $targets_dir/${data_id}.$n.scp + done > ${corrupted_data_dir}_hires/`basename $targets_dir`.scp +fi + +tmpdir=exp/make_frame_snr_correct_targets +targets_dir=frame_snr_correct_targets +if [ $stage -le 19 ]; then + utils/split_data.sh ${clean_data_dir}_fbank $nj + utils/split_data.sh ${noise_data_dir}_fbank $nj + + sleep 2 + + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $targets_dir/storage ]; then + date=$(date +'%m_%d_%H_%M') + utils/create_split_dir.pl /export/b0{5,6,7,8}/$USER/kaldi-data/egs/wsj_noisy-$date/s5/$targets_dir/storage $targets_dir/storage + for n in `seq $nj`; do + utils/create_data_link.pl $targets_dir/${data_id}.$n.ark + done + fi + + mkdir -p $targets_dir + $train_cmd JOB=1:$nj $tmpdir/${tmpdir}_${data_id}.JOB.log \ + matrix-sum --scale1=1.0 --scale2=-1.0 \ + "ark:compute-mfcc-feats --config=conf/mfcc.conf --num-ceps=1 --num-mel-bins=3 scp:${clean_data_dir}_fbank/split$nj/JOB/wav.scp ark:- |" \ + "ark:compute-mfcc-feats --config=conf/mfcc.conf --num-ceps=1 --num-mel-bins=3 scp:${noise_data_dir}_fbank/split$nj/JOB/wav.scp ark:- |" \ + ark,scp:$targets_dir/${data_id}.JOB.ark,$targets_dir/${data_id}.JOB.scp + + for n in `seq $nj`; do + cat $targets_dir/${data_id}.$n.scp + done > ${corrupted_data_dir}_hires/`basename $targets_dir`.scp +fi + +tmpdir=exp/make_frame_snr_targets +targets_dir=frame_snr_targets +if [ $stage -le 20 ]; then + utils/split_data.sh ${clean_data_dir}_fbank $nj + utils/split_data.sh ${noise_data_dir}_fbank $nj + + sleep 2 + + mkdir -p $targets_dir + $train_cmd --max-jobs-run 30 JOB=1:$nj $tmpdir/${tmpdir}_${data_id}.JOB.log \ + vector-sum \ + "ark:matrix-scale scp:${clean_data_dir}_fbank/split$nj/JOB/feats.scp ark:- | matrix-sum-cols --log-sum-exp=true ark:- ark:- |" \ + "ark:matrix-scale scp:${noise_data_dir}_fbank/split$nj/JOB/feats.scp ark:- | matrix-sum-cols --log-sum-exp=true ark:- ark:- | vector-scale --scale=-1.0 ark:- ark:- |" \ + ark:- \| vector-to-feat ark:- \ + ark,scp:$targets_dir/${data_id}.JOB.ark,$targets_dir/${data_id}.JOB.scp + + for n in `seq $nj`; do + cat $targets_dir/${data_id}.$n.scp + done > ${corrupted_data_dir}_hires/`basename $targets_dir`.scp +fi diff --git a/egs/wsj_noisy/s5/local/snr/run_test.sh b/egs/wsj_noisy/s5/local/snr/run_test.sh new file mode 100755 index 00000000000..353bc9cd549 --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/run_test.sh @@ -0,0 +1,117 @@ 
+#!/bin/bash + +set -e +set -o pipefail +set -u + +. path.sh +. cmd.sh + +feat_affix=_bp_vh +snr_affix= +affix=_babel_assamese_r2000_a +reco_nj=32 + +src_data_dir=data/dev10h.pem +data_dir=data/babel_assamese_dev10h +irm_predictor=exp/nnet3_irm_predictor/nnet_tdnn_a_w_bp_vh_seg_babel_assamese_unsad_r2000_n6_lrate0.00003_0.000003 +predictor_iter=final +sad_nnet_dir=exp/nnet3_sad_snr/tdnn_irm_babel_assamese_train_unsad_splice5_2 +sad_nnet_iter=final +append_to_orig_feats=false +add_pov_feature=false +create_uniform_segments=false +overlap_length=500 +window_length=3000 +stage=-1 +feature_type=Snr + +segmentation_config=conf/segmentation.conf +weights_segmentation_config=conf/segmentation.conf +fbank_config=conf/fbank_bp.conf +mfcc_config=conf/mfcc_hires_bp.conf + +echo $* + +. utils/parse_options.sh + +if [ $# -ne 4 ]; then + echo "Usage: $0 " + echo " e.g.: $0 $src_data_dir $data_dir $irm_predictor $sad_nnet_dir" + exit 1 +fi + +src_data_dir=$1 +data_dir=$2 +irm_predictor=$3 +sad_nnet_dir=$4 + +data_id=`basename $data_dir` +snr_data_dir=exp/frame_snrs_irm${affix}${feat_affix}_${data_id}_whole${feat_affix}/${data_id}${feat_affix}_snr${snr_affix} +sad_dir=${sad_nnet_dir}/sad${affix}_${data_id}_whole${feat_affix} +seg_dir=${sad_nnet_dir}/segmentation${affix}_${data_id}_whole${feat_affix} + +if [ $stage -le 0 ]; then + diarization/convert_data_dir_to_whole.sh $src_data_dir ${data_dir}_whole + utils/copy_data_dir.sh ${data_dir}_whole ${data_dir}_whole${feat_affix}_hires + utils/copy_data_dir.sh ${data_dir}_whole ${data_dir}_whole${feat_affix}_fbank +fi + +if [ $stage -le 1 ]; then + steps/make_mfcc.sh --mfcc-config $mfcc_config --nj $reco_nj --cmd "$train_cmd" \ + ${data_dir}_whole${feat_affix}_hires exp/make_hires/${data_id}_whole${feat_affix} mfcc_hires + steps/compute_cmvn_stats.sh ${data_dir}_whole${feat_affix}_hires exp/make_hires/${data_id}_whole${feat_affix} mfcc_hires +fi + +if [ $stage -le 2 ]; then + steps/make_fbank.sh --fbank-config $fbank_config --nj 
$reco_nj --cmd "$train_cmd" \ + ${data_dir}_whole${feat_affix}_fbank exp/make_fbank/${data_id}_whole${feat_affix} fbank + steps/compute_cmvn_stats.sh ${data_dir}_whole${feat_affix}_fbank exp/make_fbank/${data_id}_whole${feat_affix} fbank +fi + +if [ $stage -le 3 ]; then + local/snr/compute_frame_snrs.sh --use-gpu yes --nj 32 \ + --cmd "$decode_cmd --max-jobs-run 32 --gpu 1" --iter $predictor_iter \ + $irm_predictor ${data_dir}_whole${feat_affix}_hires ${data_dir}_whole${feat_affix}_fbank \ + exp/frame_snrs_irm${affix}${feat_affix}_${data_id}_whole${feat_affix} +fi + +if [ $stage -le 4 ]; then + local/snr/create_snr_data_dir.sh --cmd "$train_cmd" --nj $reco_nj --type $feature_type \ + --add-frame-snr true --append-to-orig-feats $append_to_orig_feats --add-pov-feature $add_pov_feature --dataid ${data_id}_whole${feat_affix} \ + ${data_dir}_whole${feat_affix}_fbank exp/frame_snrs_irm${affix}${feat_affix}_${data_id}_whole${feat_affix} \ + exp/make_snr_data_dir snr_feats $snr_data_dir +fi + +if [ $stage -le 5 ]; then + local/snr/compute_sad.sh --snr-data-dir $snr_data_dir --use-gpu no --method Dnn \ + --iter $sad_nnet_iter $sad_nnet_dir $snr_data_dir $sad_dir +fi + +if ! 
$create_uniform_segments; then + if [ $stage -le 6 ]; then + local/snr/sad_to_segments.sh --config $segmentation_config --acwt 0.1 \ + --speech-to-sil-ratio 1.0 --sil-self-loop-probability 0.5 \ + --sil-transition-probability 0.5 ${data_dir}_whole${feat_affix}_fbank \ + $sad_dir $seg_dir $seg_dir/${data_id}_seg + fi + + if [ $stage -le 7 ]; then + local/snr/get_weights_for_ivector_extraction.sh --cmd queue.pl \ + --method Viterbi --config $weights_segmentation_config \ + --silence-weight 0 \ + ${seg_dir}/${data_id}_seg ${sad_dir} $seg_dir/ivector_weights_${data_id}_seg + fi +else + if [ $stage -le 6 ]; then + local/snr/uniform_segment_data_dir.sh --overlap-length $overlap_length \ + --window-length $window_length \ + ${data_dir}_whole${feat_affix}_fbank $seg_dir $seg_dir/${data_id}_uniform_seg + fi + if [ $stage -le 7 ]; then + local/snr/get_weights_for_ivector_extraction.sh --cmd queue.pl \ + --method Viterbi --config $weights_segmentation_config \ + --silence-weight 0 \ + ${seg_dir}/${data_id}_uniform_seg ${sad_dir} $seg_dir/ivector_weights_${data_id}_uniform_seg + fi +fi diff --git a/egs/wsj_noisy/s5/local/snr/run_train_sad.sh b/egs/wsj_noisy/s5/local/snr/run_train_sad.sh new file mode 100755 index 00000000000..2280a815a54 --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/run_train_sad.sh @@ -0,0 +1,236 @@ +#!/bin/bash + +# this is the standard "tdnn" system, built in nnet3; it's what we use to +# call multi-splice. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +num_epochs=8 +splice_indexes=`seq -s',' -50 50` +initial_effective_lrate=0.005 +final_effective_lrate=0.0005 +relu_dim= +sigmoid_dim=50 +train_data_dir=data/train_si284_corrupted_hires +snr_scp= +vad_scp= +final_vad_scp= +max_change_per_sample=0.075 +datadir= +egs_dir= +dir= +nj=40 +method=Dnn +splice_opts="--left-context=50 --right-context=50" +max_param_change=1 +feat_type= +config_dir= +deriv_weights_scp= +lda_opts= + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ $method == "Dnn" ]; then + num_hidden_layers=`echo $splice_indexes | perl -ane 'print scalar @F'` || exit 1 +else + num_hidden_layers=0 +fi + +if [ -z "$dir" ]; then + dir=exp/nnet3_sad_snr/nnet_tdnn_a +fi + +case $method in + "Dnn") + dir=${dir} #_i${relu_dim}_n${num_hidden_layers}_lrate${initial_effective_lrate}_${final_effective_lrate} + ;; + "LogisticRegressionSubsampled") + dir=${dir} + ;; + "LogisticRegression") + dir=${dir} + ;; + "Gmm") + dir=${dir}_gmm + ;; +esac + +if ! cuda-compiled; then + cat < $datadir/segments.tmp + # #cat $datadir/segments.tmp | utils/apply_map.pl -f 2 $train_data_dir/utt2spk > $datadir/segments + # #utils/filter_scp.pl -f 2 $train_data_dir/utt2spk $seg2utt_file | \ + # # utils/apply_map.pl -f 2 $train_data_dir/utt2spk > $datadir/utt2spk + # #utils/utt2spk_to_spk2utt.pl $datadir/utt2spk > $datadir/spk2utt + # # + # #if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/snr_feats/storage ]; then + # # utils/create_split_dir.pl \ + # # /export/b0{3,4,5,6}/$USER/kaldi-data/egs/wsj_noisy-$(date +'%m_%d_%H_%M')/s5/$dir/snr_feats/storage $dir/snr_feats/storage + # #fi + + # #$train_cmd JOB=1:$nj $dir/log/extract_feature_segments.JOB.log \ + # # extract-feature-segments scp:$snr_scp \ + # # "ark,t:utils/split_scp.pl -j $nj \$[JOB-1] $datadir/segments.tmp |" \ + # # ark:- \| copy-feats --compress=true ark:- \ + # # ark,scp:$dir/snr_feats/raw_snr.JOB.ark,$dir/snr_feats/raw_snr.JOB.scp + + # #for n in `seq $nj`; do + # # cat $dir/snr_feats/raw_snr.$n.scp + # #done | sort -k1,1 > $datadir/feats.scp + + # #utils/fix_data_dir.sh $datadir + # else + # cp $snr_scp $datadir/feats.scp + # fi + # else + # cp $snr_scp $datadir/feats.scp + # fi + # steps/compute_cmvn_stats.sh --fake $datadir $datadir/log snr + #fi + +datadir=${train_data_dir} + +if [ -z "$final_vad_scp" ] && [ $method != "Gmm" ]; then + if [ $stage -le 1 ]; then + mkdir -p $dir/vad/split$nj + vad_scp_splits=() + for n in `seq $nj`; do + vad_scp_splits+=($dir/vad/vad.tmp.$n.scp) + done + utils/split_scp.pl $vad_scp ${vad_scp_splits[@]} || exit 1 + + cat < $dir/vad/vad_map +0 0 +1 1 +2 0 +3 0 +4 1 +EOF + $train_cmd JOB=1:$nj $dir/vad/log/convert_vad.JOB.log \ + copy-int-vector scp:$dir/vad/vad.tmp.JOB.scp ark,t:- \| \ + utils/apply_map.pl -f 2- $dir/vad/vad_map \| \ + copy-int-vector ark,t:- \ + ark,scp:$dir/vad/split$nj/vad.JOB.ark,$dir/vad/split$nj/vad.JOB.scp || exit 1 + fi + + for n in `seq $nj`; do + cat $dir/vad/split$nj/vad.$n.scp + done | sort -k1,1 > $dir/vad/vad.scp + final_vad_scp=$dir/vad/vad.scp +fi + +if [ ! -s $final_vad_scp ]; then + echo "$0: $final_vad_scp file is empty!" 
&& exit 1 +fi + +feats_opts=(--feat-type $feat_type) +if [ "$feat_type" == "sparse" ]; then + exit 1 +fi + +if [ $stage -le 3 ]; then + case $method in + "Gmm") + diarization/train_vad_gmm_supervised.sh \ + --ignore-energy false --add-zero-crossing-feats false \ + --add-frame-snrs false \ + --nj $nj --cmd "$train_cmd" \ + $datadir $final_vad_scp $dir || exit 1 + ;; + "LogisticRegressionSubsampled") + $train_cmd --mem 8G $dir/log/train_logistic_regression.log \ + logistic-regression-train-on-feats --num-frames=8000000 --num-targets=2 \ + "ark:cat $datadir/feats.scp | splice-feats $splice_opts scp:- ark:- |" \ + scp:$final_vad_scp $dir/0.mdl || exit 1 + ;; + "LogisticRegression") + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/wsj_noisy-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + deriv_weights_opt= + if [ ! -z "$deriv_weights_scp" ]; then + deriv_weights_opt="--deriv-weights-scp $deriv_weights_scp" + fi + + steps/nnet3/train_tdnn_raw.sh --stage $train_stage \ + --num-epochs $num_epochs --num-jobs-initial 1 --num-jobs-final 4 \ + --splice-indexes "$splice_indexes" --no-hidden-layers true --minibatch-size 512 \ + --egs-dir "$egs_dir" "${feats_opts[@]}" \ + --cmvn-opts "--norm-means=false --norm-vars=false" \ + --max-param-change $max_param_change $deriv_weights_opt \ + --initial-effective-lrate $initial_effective_lrate --final-effective-lrate $final_effective_lrate \ + --cmd "$decode_cmd" --nj 40 --objective-type linear --use-presoftmax-prior-scale false \ + --include-log-softmax true --skip-lda true --posterior-targets true \ + --num-targets 2 --cleanup false --max-param-change $max_param_change \ + $datadir "$final_vad_scp" $dir || exit 1; + ;; + "Dnn") + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/wsj_noisy-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + deriv_weights_opt= + if [ ! -z "$deriv_weights_scp" ]; then + deriv_weights_opt="--deriv-weights-scp $deriv_weights_scp" + fi + + bash -x steps/nnet3/train_tdnn_raw.sh --stage $train_stage \ + --num-epochs $num_epochs --num-jobs-initial 2 --num-jobs-final 4 \ + --splice-indexes "$splice_indexes" \ + --egs-dir "$egs_dir" ${feats_opts[@]} \ + --cmvn-opts "--norm-means=false --norm-vars=false" \ + --max-param-change $max_param_change $deriv_weights_opt --lda-opts "$lda_opts" \ + --initial-effective-lrate $initial_effective_lrate --final-effective-lrate $final_effective_lrate \ + --cmd "$decode_cmd" --nj 40 --objective-type linear --cleanup true --use-presoftmax-prior-scale true \ + --include-log-softmax true --skip-lda true --posterior-targets true \ + --num-targets 2 --max-param-change $max_param_change --config-dir "$config_dir" --pnorm-input-dim "" --pnorm-output-dim "" \ + --cleanup false${relu_dim:+ --relu-dim $relu_dim}${sigmoid_dim:+ --sigmoid-dim $sigmoid_dim} \ + $datadir "$final_vad_scp" $dir || exit 1; + ;; + *) + echo "Unknown method $method" + exit 1 + esac +fi + diff --git a/egs/wsj_noisy/s5/local/snr/run_train_snr_predictor.sh b/egs/wsj_noisy/s5/local/snr/run_train_snr_predictor.sh new file mode 100755 index 00000000000..9cdee06090f --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/run_train_snr_predictor.sh @@ -0,0 +1,109 @@ +#!/bin/bash + +# this is the standard "tdnn" system, built in nnet3; it's what we use to +# call multi-splice. + +. cmd.sh + + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. 
+ +stage=0 +train_stage=-10 +get_egs_stage=-10 +num_epochs=8 +num_utts_subset=300 # number of utterances in validation and training + # subsets used for shrinkage and diagnostics. +splice_indexes="-4,-3,-2,-1,0,1,2,3,4 0 -3,1 0 -7,2 0" +initial_effective_lrate=0.005 +final_effective_lrate=0.0005 +pnorm_input_dims="3000 3000 3000 3000 3000 3000" +pnorm_output_dims="300 300 300 300 300 300" +relu_dims= +train_data_dir=data/train_si284_corrupted_hires +targets_scp=data/train_si284_corrupted_hires/snr_targets.scp +max_param_change=1 +add_layers_period=2 +target_type=IrmExp +config_dir= +egs_dir= +egs_suffix= +src_dir= +src_iter=final +dir= +affix= +deriv_weights_scp= + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_hidden_layers=`echo $splice_indexes | perl -ane 'print scalar @F'` || exit 1 +if [ -z "$dir" ]; then + dir=exp/nnet3_snr_predictor/nnet_tdnn_a +fi + +if [ -z "$relu_dims" ]; then +dir=${dir}_pn${num_hidden_layers}_lrate${initial_effective_lrate}_${final_effective_lrate} +else +dir=${dir}_rn${num_hidden_layers}_lrate${initial_effective_lrate}_${final_effective_lrate} +fi + +dir=${dir}${affix} + +if ! cuda-compiled; then + cat < $dir/target_type + + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/wsj_noisy-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + deriv_weights_opt= + if [ ! 
-z "$deriv_weights_scp" ]; then + deriv_weights_opt="--deriv-weights-scp $deriv_weights_scp" + fi + + if [ -z "$src_dir" ]; then + steps/nnet3/train_tdnn_raw.sh --stage $train_stage \ + --num-epochs $num_epochs --num-jobs-initial 2 --num-jobs-final 4 \ + --splice-indexes "$splice_indexes" --egs-suffix "$egs_suffix" --num-utts-subset $num_utts_subset \ + --feat-type raw --egs-dir "$egs_dir" --get-egs-stage $get_egs_stage \ + --cmvn-opts "--norm-means=false --norm-vars=false" $deriv_weights_opt \ + --max-param-change $max_param_change \ + --initial-effective-lrate $initial_effective_lrate --final-effective-lrate $final_effective_lrate \ + --cmd "$decode_cmd" --nj 40 --objective-type $objective_type --cleanup false --config-dir "$config_dir" \ + --pnorm-input-dims "$pnorm_input_dims" --pnorm-output-dims "$pnorm_output_dims" --pnorm-input-dim "" --pnorm-output-dim "" \ + --relu-dims "$relu_dims" \ + --add-layers-period $add_layers_period \ + $train_data_dir $targets_scp $dir || exit 1; + else + steps/nnet3/train_more.sh --stage $train_stage \ + --num-epochs $num_epochs --num-jobs-initial 2 --num-jobs-final 4 \ + --egs-suffix "$egs_suffix" $deriv_weights_opt \ + --feat-type raw --egs-dir "$egs_dir" --get-egs-stage $get_egs_stage \ + --cmvn-opts "--norm-means=false --norm-vars=false" --iter $src_iter \ + --max-param-change $max_param_change \ + --initial-effective-lrate $initial_effective_lrate --final-effective-lrate $final_effective_lrate \ + --cmd "$decode_cmd" --nj 40 --objective-type $objective_type --cleanup false --config-dir "$config_dir" \ + $train_data_dir $targets_scp $src_dir $dir || exit 1; + fi +fi + diff --git a/egs/wsj_noisy/s5/local/snr/sad_to_segments.sh b/egs/wsj_noisy/s5/local/snr/sad_to_segments.sh new file mode 100755 index 00000000000..6bf391c60e6 --- /dev/null +++ b/egs/wsj_noisy/s5/local/snr/sad_to_segments.sh @@ -0,0 +1,201 @@ +#! /bin/bash + +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -u +. 
path.sh + +cmd=run.pl +method=Viterbi +stage=-10 + +# General segmentation options +max_intersegment_length=50 # Merge nearby speech segments if the silence + # between them is less than this many frames. +max_relabel_length=10 # maximum duration of speech that will be removed as part + # of smoothing process. This is only if there are no other + # speech segments nearby. +pad_length=50 # Pad speech segments by this many frames on either side +post_pad_length=50 # Pad speech segments by this many frames on either side +max_segment_length=1000 # Segments that are longer than this are split into + # overlapping frames. +overlap_length=100 # Overlapping frames when segments are split. + # See the above option. + +# Viterbi options +min_silence_duration=30 # minimum number of frames for silence +min_speech_duration=30 # minimum number of frames for speech +nonsil_self_loop_probability=0.9 +nonsil_transition_probability=0.1 +sil_self_loop_probability=0.9 +sil_transition_probability=0.1 +speech_to_sil_ratio=1.0 # the prior on speech vs silence +speech_prior=0.5 +sil_prior=0.5 + +# Decoding options +acwt=1 +beam=10 +max_active=7000 + +. 
utils/parse_options.sh + +if [ $# -ne 4 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/dev_aspire_whole exp/vad_dev_aspire exp/segmentation_dev_aspire data/dev_aspire_seg" + exit 1 +fi + +if [ "$speech_to_sil_ratio" != "1.0" ]; then + speech_prior=$speech_to_sil_ratio + sil_prior=1 +fi + +data_dir=$1 +vad_dir=$2 +dir=$3 +segmented_data_dir=$4 + +nj=`cat $vad_dir/num_jobs` || exit 1 + +mkdir -p $dir + +if [ $stage -le 0 ]; then + utils/copy_data_dir.sh $data_dir $segmented_data_dir || exit 1 + rm -f $segmented_data_dir/{cmvn.scp,feats.scp,text,segments,utt2spk,spk2utt} +fi + +decoder_opts=(--allow-partial=true) +case $method in + "Smoothing") + if [ $stage -le 1 ]; then + cat < $dir/prob_to_ali.awk +#!/bin/awk -f +{ + printf \$1; + for (i=3; i < NF; i++) { + if (\$i > 0.5) + printf " 1"; + else + printf " 0"; + } + print ""; +} +EOF + + $cmd JOB=1:$nj $dir/log/convert_speech_prob_to_segments.JOB.log \ + copy-vector scp:$vad_dir/speech_prob.JOB.scp ark,t:- \| \ + awk -f $dir/prob_to_ali.awk \| \ + segmentation-init-from-ali ark,t:- ark:- \| \ + segmentation-post-process --remove-labels=0 ark:- ark:- \| \ + segmentation-post-process --merge-adjacent-segments=true \ + --max-intersegment-length=$max_intersegment_length ark:- ark:- \| \ + segmentation-post-process --max-relabel-length=$max_relabel_length --relabel-short-segments-class=1 ark:- ark:- \| \ + segmentation-post-process --widen-label=1 --widen-length=$pad_length ark:- ark:- \| \ + segmentation-post-process --merge-adjacent-segments=true \ + --max-intersegment-length=$max_intersegment_length ark:- ark:- \| \ + segmentation-post-process \ + --max-segment-length=$max_segment_length --overlap-length=$overlap_length ark:- ark:- \| \ + segmentation-to-segments ark:- \ + ark,t:$dir/utt2spk.JOB \ + ark,t:$dir/segments.JOB || exit 1 + fi + ;; + "Viterbi") + # Prepare a lang directory + if [ $stage -le 1 ]; then + mkdir -p $dir/local/dict + mkdir -p $dir/local/lm + + echo "1" > $dir/local/dict/silence_phones.txt + 
echo "1" > $dir/local/dict/optional_silence.txt + echo "2" > $dir/local/dict/nonsilence_phones.txt + echo -e "1 1\n2 2" > $dir/local/dict/lexicon.txt + echo -e "1\n2\n1 2" > $dir/local/dict/extra_questions.txt + + mkdir -p $dir/lang + diarization/prepare_vad_lang.sh \ + --nonsil-self-loop-probability $nonsil_self_loop_probability \ + --nonsil-transition-probability $nonsil_transition_probability \ + --sil-self-loop-probability $sil_self_loop_probability \ + --sil-transition-probability $sil_transition_probability \ + --num-sil-states $min_silence_duration \ + --num-nonsil-states $min_speech_duration \ + $dir/local/dict $dir/local/lang $dir/lang || exit 1 + fi + + feat_dim=2 # dummy. We don't need this. + if [ $stage -le 2 ]; then + $cmd $dir/log/create_transition_model.log gmm-init-mono \ + $dir/lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 + fi + t=sp${speech_prior}_sil${sil_prior} + lang=$dir/lang_test_${t} + if [ $stage -le 3 ]; then + cp -r $dir/lang $lang + perl -e '$sil_prior = shift @ARGV; $speech_prior = shift @ARGV; $s = $sil_prior + $speech_prior; $sil_prior = $sil_prior / $s; $speech_prior = $speech_prior / $s; $s = $sil_prior + $speech_prior; print "0 0 1 1 " . -log($sil_prior/(1.1 * $s)) . "\n0 0 2 2 ". -log($speech_prior/(1.1 * $s)). "\n0 ". 
-log(0.1 / 1.1)' $sil_prior $speech_prior | \ + fstcompile --isymbols=$lang/words.txt --osymbols=$lang/words.txt \ + --keep_isymbols=false --keep_osymbols=false \ + > $lang/G.fst || exit 1 + fi + + if [ $stage -le 4 ]; then + $cmd $dir/log/make_vad_graph.log \ + diarization/make_vad_graph.sh --iter trans \ + $lang $dir $dir/graph_test_${t} || exit 1 + fi + + log_likes=ark:$vad_dir/log_likes.JOB.ark + + decoder_opts+=(--acoustic-scale=$acwt --beam=$beam --max-active=$max_active) + + if [ $stage -le 5 ]; then + $cmd JOB=1:$nj $dir/log/decode.JOB.log \ + decode-faster-mapped ${decoder_opts[@]} \ + $dir/trans.mdl \ + $dir/graph_test_${t}/HCLG.fst $log_likes \ + ark:/dev/null ark:- \| \ + ali-to-phones --per-frame=true $dir/trans.mdl ark:- \ + "ark:|gzip -c > $dir/ali.JOB.gz" || exit 1 + fi + + if [ $stage -le 6 ]; then + $cmd JOB=1:$nj $dir/log/segmentation.JOB.log \ + segmentation-init-from-ali "ark:gunzip -c $dir/ali.JOB.gz |" ark:- \| \ + segmentation-post-process --remove-labels=1 ark:- ark:- \| \ + segmentation-post-process --merge-labels=2 --merge-dst-label=1 --widen-label=1 --widen-length=$pad_length ark:- ark:- \| \ + segmentation-post-process --merge-adjacent-segments=true --max-intersegment-length=$max_intersegment_length ark:- ark:- \| \ + segmentation-post-process --widen-label=1 --widen-length=$post_pad_length ark:- ark:- \| \ + segmentation-post-process --max-segment-length=$max_segment_length --overlap-length=$overlap_length ark:- ark:- \| \ + segmentation-to-segments ark:- \ + ark,t:$dir/utt2spk.JOB \ + ark,t:$dir/segments.JOB || exit 1 + fi + ;; + *) + echo "$0: Unknown method $method specified for segmentation" + exit 1 +esac + +for n in `seq $nj`; do + cat $dir/utt2spk.$n +done > $segmented_data_dir/utt2spk + +for n in `seq $nj`; do + cat $dir/segments.$n +done > $segmented_data_dir/segments + +if [ ! -s $segmented_data_dir/utt2spk ] || [ ! -s $segmented_data_dir/segments ]; then + echo "$0: Segmentation failed to generate segments or utt2spk!" 
+ exit 1 +fi + +utils/utt2spk_to_spk2utt.pl $segmented_data_dir/utt2spk > $segmented_data_dir/spk2utt || exit 1 +utils/fix_data_dir.sh $segmented_data_dir + +if [ ! -s $segmented_data_dir/utt2spk ] || [ ! -s $segmented_data_dir/segments ]; then + echo "$0: Segmentation failed to generate segments or utt2spk!" + exit 1 +fi diff --git a/egs/wsj_noisy/s5/local/vad_phone_map_2models b/egs/wsj_noisy/s5/local/vad_phone_map_2models new file mode 100644 index 00000000000..f2bc2eb7d0e --- /dev/null +++ b/egs/wsj_noisy/s5/local/vad_phone_map_2models @@ -0,0 +1,351 @@ +SIL 0 +SIL_B 0 +SIL_E 0 +SIL_I 0 +SIL_S 0 +SPN 3 +SPN_B 3 +SPN_E 3 +SPN_I 3 +SPN_S 3 +NSN 3 +NSN_B 3 +NSN_E 3 +NSN_I 3 +NSN_S 3 +S_B 1 +S_E 1 +S_I 1 +S_S 1 +UW_B 1 +UW_E 1 +UW_I 1 +UW_S 1 +UW0_B 1 +UW0_E 1 +UW0_I 1 +UW0_S 1 +UW1_B 1 +UW1_E 1 +UW1_I 1 +UW1_S 1 +UW2_B 1 +UW2_E 1 +UW2_I 1 +UW2_S 1 +T_B 1 +T_E 1 +T_I 1 +T_S 1 +N_B 1 +N_E 1 +N_I 1 +N_S 1 +K_B 1 +K_E 1 +K_I 1 +K_S 1 +Y_B 1 +Y_E 1 +Y_I 1 +Y_S 1 +Z_B 1 +Z_E 1 +Z_I 1 +Z_S 1 +AO_B 1 +AO_E 1 +AO_I 1 +AO_S 1 +AO0_B 1 +AO0_E 1 +AO0_I 1 +AO0_S 1 +AO1_B 1 +AO1_E 1 +AO1_I 1 +AO1_S 1 +AO2_B 1 +AO2_E 1 +AO2_I 1 +AO2_S 1 +AY_B 1 +AY_E 1 +AY_I 1 +AY_S 1 +AY0_B 1 +AY0_E 1 +AY0_I 1 +AY0_S 1 +AY1_B 1 +AY1_E 1 +AY1_I 1 +AY1_S 1 +AY2_B 1 +AY2_E 1 +AY2_I 1 +AY2_S 1 +SH_B 1 +SH_E 1 +SH_I 1 +SH_S 1 +W_B 1 +W_E 1 +W_I 1 +W_S 1 +NG_B 1 +NG_E 1 +NG_I 1 +NG_S 1 +EY_B 1 +EY_E 1 +EY_I 1 +EY_S 1 +EY0_B 1 +EY0_E 1 +EY0_I 1 +EY0_S 1 +EY1_B 1 +EY1_E 1 +EY1_I 1 +EY1_S 1 +EY2_B 1 +EY2_E 1 +EY2_I 1 +EY2_S 1 +B_B 1 +B_E 1 +B_I 1 +B_S 1 +CH_B 1 +CH_E 1 +CH_I 1 +CH_S 1 +OY_B 1 +OY_E 1 +OY_I 1 +OY_S 1 +OY0_B 1 +OY0_E 1 +OY0_I 1 +OY0_S 1 +OY1_B 1 +OY1_E 1 +OY1_I 1 +OY1_S 1 +OY2_B 1 +OY2_E 1 +OY2_I 1 +OY2_S 1 +JH_B 1 +JH_E 1 +JH_I 1 +JH_S 1 +D_B 1 +D_E 1 +D_I 1 +D_S 1 +ZH_B 1 +ZH_E 1 +ZH_I 1 +ZH_S 1 +G_B 1 +G_E 1 +G_I 1 +G_S 1 +UH_B 1 +UH_E 1 +UH_I 1 +UH_S 1 +UH0_B 1 +UH0_E 1 +UH0_I 1 +UH0_S 1 +UH1_B 1 +UH1_E 1 +UH1_I 1 +UH1_S 1 +UH2_B 1 +UH2_E 1 +UH2_I 1 +UH2_S 1 +F_B 1 +F_E 1 +F_I 1 +F_S 
1 +V_B 1 +V_E 1 +V_I 1 +V_S 1 +ER_B 1 +ER_E 1 +ER_I 1 +ER_S 1 +ER0_B 1 +ER0_E 1 +ER0_I 1 +ER0_S 1 +ER1_B 1 +ER1_E 1 +ER1_I 1 +ER1_S 1 +ER2_B 1 +ER2_E 1 +ER2_I 1 +ER2_S 1 +AA_B 1 +AA_E 1 +AA_I 1 +AA_S 1 +AA0_B 1 +AA0_E 1 +AA0_I 1 +AA0_S 1 +AA1_B 1 +AA1_E 1 +AA1_I 1 +AA1_S 1 +AA2_B 1 +AA2_E 1 +AA2_I 1 +AA2_S 1 +IH_B 1 +IH_E 1 +IH_I 1 +IH_S 1 +IH0_B 1 +IH0_E 1 +IH0_I 1 +IH0_S 1 +IH1_B 1 +IH1_E 1 +IH1_I 1 +IH1_S 1 +IH2_B 1 +IH2_E 1 +IH2_I 1 +IH2_S 1 +M_B 1 +M_E 1 +M_I 1 +M_S 1 +DH_B 1 +DH_E 1 +DH_I 1 +DH_S 1 +L_B 1 +L_E 1 +L_I 1 +L_S 1 +AH_B 1 +AH_E 1 +AH_I 1 +AH_S 1 +AH0_B 1 +AH0_E 1 +AH0_I 1 +AH0_S 1 +AH1_B 1 +AH1_E 1 +AH1_I 1 +AH1_S 1 +AH2_B 1 +AH2_E 1 +AH2_I 1 +AH2_S 1 +P_B 1 +P_E 1 +P_I 1 +P_S 1 +OW_B 1 +OW_E 1 +OW_I 1 +OW_S 1 +OW0_B 1 +OW0_E 1 +OW0_I 1 +OW0_S 1 +OW1_B 1 +OW1_E 1 +OW1_I 1 +OW1_S 1 +OW2_B 1 +OW2_E 1 +OW2_I 1 +OW2_S 1 +AW_B 1 +AW_E 1 +AW_I 1 +AW_S 1 +AW0_B 1 +AW0_E 1 +AW0_I 1 +AW0_S 1 +AW1_B 1 +AW1_E 1 +AW1_I 1 +AW1_S 1 +AW2_B 1 +AW2_E 1 +AW2_I 1 +AW2_S 1 +HH_B 1 +HH_E 1 +HH_I 1 +HH_S 1 +AE_B 1 +AE_E 1 +AE_I 1 +AE_S 1 +AE0_B 1 +AE0_E 1 +AE0_I 1 +AE0_S 1 +AE1_B 1 +AE1_E 1 +AE1_I 1 +AE1_S 1 +AE2_B 1 +AE2_E 1 +AE2_I 1 +AE2_S 1 +R_B 1 +R_E 1 +R_I 1 +R_S 1 +TH_B 1 +TH_E 1 +TH_I 1 +TH_S 1 +IY_B 1 +IY_E 1 +IY_I 1 +IY_S 1 +IY0_B 1 +IY0_E 1 +IY0_I 1 +IY0_S 1 +IY1_B 1 +IY1_E 1 +IY1_I 1 +IY1_S 1 +IY2_B 1 +IY2_E 1 +IY2_I 1 +IY2_S 1 +EH_B 1 +EH_E 1 +EH_I 1 +EH_S 1 +EH0_B 1 +EH0_E 1 +EH0_I 1 +EH0_S 1 +EH1_B 1 +EH1_E 1 +EH1_I 1 +EH1_S 1 +EH2_B 1 +EH2_E 1 +EH2_I 1 +EH2_S 1 diff --git a/egs/wsj_noisy/s5/local/wer_hyp_filter b/egs/wsj_noisy/s5/local/wer_hyp_filter new file mode 100755 index 00000000000..939cd08720d --- /dev/null +++ b/egs/wsj_noisy/s5/local/wer_hyp_filter @@ -0,0 +1,11 @@ +#!/bin/sed -f +s:::g +s:::g +s:::g +s/://g +s/\*//g +s/-HOLDER/HOLDER/g +s/COMPAIGN/CAMPAIGN/g +s/APPROACHES-/APPROACHES/g +s/RESEACHERS/RESEARCHERS/g + diff --git a/egs/wsj_noisy/s5/local/wer_output_filter b/egs/wsj_noisy/s5/local/wer_output_filter new file mode 100755 index 
00000000000..939cd08720d --- /dev/null +++ b/egs/wsj_noisy/s5/local/wer_output_filter @@ -0,0 +1,11 @@ +#!/bin/sed -f +s:::g +s:::g +s:::g +s/://g +s/\*//g +s/-HOLDER/HOLDER/g +s/COMPAIGN/CAMPAIGN/g +s/APPROACHES-/APPROACHES/g +s/RESEACHERS/RESEARCHERS/g + diff --git a/egs/wsj_noisy/s5/local/wer_ref_filter b/egs/wsj_noisy/s5/local/wer_ref_filter new file mode 100755 index 00000000000..939cd08720d --- /dev/null +++ b/egs/wsj_noisy/s5/local/wer_ref_filter @@ -0,0 +1,11 @@ +#!/bin/sed -f +s:::g +s:::g +s:::g +s/://g +s/\*//g +s/-HOLDER/HOLDER/g +s/COMPAIGN/CAMPAIGN/g +s/APPROACHES-/APPROACHES/g +s/RESEACHERS/RESEARCHERS/g + diff --git a/egs/wsj_noisy/s5/local/wsj_data_prep.sh b/egs/wsj_noisy/s5/local/wsj_data_prep.sh new file mode 100755 index 00000000000..3463747138a --- /dev/null +++ b/egs/wsj_noisy/s5/local/wsj_data_prep.sh @@ -0,0 +1,212 @@ +#!/bin/bash + +# Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey) +# Apache 2.0. + + +if [ $# -le 3 ]; then + echo "Arguments should be a list of WSJ directories, see ../run.sh for example." + exit 1; +fi + + +dir=`pwd`/data/local/data +lmdir=`pwd`/data/local/nist_lm +mkdir -p $dir $lmdir +local=`pwd`/local +utils=`pwd`/utils + +. ./path.sh # Needed for KALDI_ROOT +sph2pipe=$KALDI_ROOT/tools/sph2pipe_v2.5/sph2pipe +if [ ! -x $sph2pipe ]; then + echo "Could not find (or execute) the sph2pipe program at $sph2pipe"; + exit 1; +fi + +if [ -z $IRSTLM ] ; then + export IRSTLM=$KALDI_ROOT/tools/irstlm/ +fi +export PATH=${PATH}:$IRSTLM/bin +if ! command -v prune-lm >/dev/null 2>&1 ; then + echo "$0: Error: the IRSTLM is not available or compiled" >&2 + echo "$0: Error: We used to install it by default, but." >&2 + echo "$0: Error: this is no longer the case." >&2 + echo "$0: Error: To install it, go to $KALDI_ROOT/tools" >&2 + echo "$0: Error: and run extras/install_irstlm.sh" >&2 + exit 1 +fi + +cd $dir +# Make directory of links to the WSJ disks such as 11-13.1. 
This relies on the command +# line arguments being absolute pathnames. +rm -r links/ 2>/dev/null +mkdir links/ +ln -s $* links + +# Do some basic checks that we have what we expected. +if [ ! -d links/11-13.1 -o ! -d links/13-34.1 -o ! -d links/11-2.1 ]; then + echo "wsj_data_prep.sh: Spot check of command line arguments failed" + echo "Command line arguments must be absolute pathnames to WSJ directories" + echo "with names like 11-13.1." + exit 1; +fi + +# This version for SI-84 + +cat links/11-13.1/wsj0/doc/indices/train/tr_s_wv1.ndx | \ + $local/ndx2flist.pl $* | sort | \ + grep -v -i 11-2.1/wsj0/si_tr_s/401 > train_si84.flist + +nl=`cat train_si84.flist | wc -l` +[ "$nl" -eq 7138 ] || echo "Warning: expected 7138 lines in train_si84.flist, got $nl" + +# This version for SI-284 +cat links/13-34.1/wsj1/doc/indices/si_tr_s.ndx \ + links/11-13.1/wsj0/doc/indices/train/tr_s_wv1.ndx | \ + $local/ndx2flist.pl $* | sort | \ + grep -v -i 11-2.1/wsj0/si_tr_s/401 > train_si284.flist + +nl=`cat train_si284.flist | wc -l` +[ "$nl" -eq 37416 ] || echo "Warning: expected 37416 lines in train_si284.flist, got $nl" + +# Now for the test sets. +# links/13-34.1/wsj1/doc/indices/readme.doc +# describes all the different test sets. +# Note: each test-set seems to come in multiple versions depending +# on different vocabulary sizes, verbalized vs. non-verbalized +# pronunciations, etc. We use the largest vocab and non-verbalized +# pronunciations. +# The most normal one seems to be the "baseline 60k test set", which +# is h1_p0. 
+ +# Nov'92 (333 utts) +# These index files have a slightly different format; +# have to add .wv1 +cat links/11-13.1/wsj0/doc/indices/test/nvp/si_et_20.ndx | \ + $local/ndx2flist.pl $* | awk '{printf("%s.wv1\n", $1)}' | \ + sort > test_eval92.flist + +# Nov'92 (330 utts, 5k vocab) +cat links/11-13.1/wsj0/doc/indices/test/nvp/si_et_05.ndx | \ + $local/ndx2flist.pl $* | awk '{printf("%s.wv1\n", $1)}' | \ + sort > test_eval92_5k.flist + +# Nov'93: (213 utts) +# Have to replace a wrong disk-id. +cat links/13-32.1/wsj1/doc/indices/wsj1/eval/h1_p0.ndx | \ + sed s/13_32_1/13_33_1/ | \ + $local/ndx2flist.pl $* | sort > test_eval93.flist + +# Nov'93: (213 utts, 5k) +cat links/13-32.1/wsj1/doc/indices/wsj1/eval/h2_p0.ndx | \ + sed s/13_32_1/13_33_1/ | \ + $local/ndx2flist.pl $* | sort > test_eval93_5k.flist + +# Dev-set for Nov'93 (503 utts) +cat links/13-34.1/wsj1/doc/indices/h1_p0.ndx | \ + $local/ndx2flist.pl $* | sort > test_dev93.flist + +# Dev-set for Nov'93 (513 utts, 5k vocab) +cat links/13-34.1/wsj1/doc/indices/h2_p0.ndx | \ + $local/ndx2flist.pl $* | sort > test_dev93_5k.flist + + +# Dev-set Hub 1,2 (503, 913 utterances) + +# Note: the ???'s below match WSJ and SI_DT, or wsj and si_dt. +# Sometimes this gets copied from the CD's with upcasing, don't know +# why (could be older versions of the disks). 
+find `readlink links/13-16.1`/???1/??_??_20 -print | grep -i ".wv1" | sort > dev_dt_20.flist +find `readlink links/13-16.1`/???1/??_??_05 -print | grep -i ".wv1" | sort > dev_dt_05.flist + + +# Finding the transcript files: +for x in $*; do find -L $x -iname '*.dot'; done > dot_files.flist + +# Convert the transcripts into our format (no normalization yet) +for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + $local/flist2scp.pl $x.flist | sort > ${x}_sph.scp + cat ${x}_sph.scp | awk '{print $1}' | $local/find_transcripts.pl dot_files.flist > $x.trans1 +done + +# Do some basic normalization steps. At this point we don't remove OOVs-- +# that will be done inside the training scripts, as we'd like to make the +# data-preparation stage independent of the specific lexicon used. +noiseword=""; +for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + cat $x.trans1 | $local/normalize_transcript.pl $noiseword | sort > $x.txt || exit 1; +done + +# Create scp's with wav's. (the wv1 in the distribution is not really wav, it is sph.) +for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + awk '{printf("%s '$sph2pipe' -f wav %s |\n", $1, $2);}' < ${x}_sph.scp > ${x}_wav.scp +done + +# Make the utt2spk and spk2utt files. 
+for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + cat ${x}_sph.scp | awk '{print $1}' | perl -ane 'chop; m:^...:; print "$_ $&\n";' > $x.utt2spk + cat $x.utt2spk | $utils/utt2spk_to_spk2utt.pl > $x.spk2utt || exit 1; +done + + +#in case we want to limit lm's on most frequent words, copy lm training word frequency list +cp links/13-32.1/wsj1/doc/lng_modl/vocab/wfl_64.lst $lmdir +chmod u+w $lmdir/*.lst # had weird permissions on source. + +# The 20K vocab, open-vocabulary language model (i.e. the one with UNK), without +# verbalized pronunciations. This is the most common test setup, I understand. + +cp links/13-32.1/wsj1/doc/lng_modl/base_lm/bcb20onp.z $lmdir/lm_bg.arpa.gz || exit 1; +chmod u+w $lmdir/lm_bg.arpa.gz + +# trigram would be: +cat links/13-32.1/wsj1/doc/lng_modl/base_lm/tcb20onp.z | \ + perl -e 'while(<>){ if(m/^\\data\\/){ print; last; } } while(<>){ print; }' | \ + gzip -c -f > $lmdir/lm_tg.arpa.gz || exit 1; + +prune-lm --threshold=1e-7 $lmdir/lm_tg.arpa.gz $lmdir/lm_tgpr.arpa || exit 1; +gzip -f $lmdir/lm_tgpr.arpa || exit 1; + +# repeat for 5k language models +cp links/13-32.1/wsj1/doc/lng_modl/base_lm/bcb05onp.z $lmdir/lm_bg_5k.arpa.gz || exit 1; +chmod u+w $lmdir/lm_bg_5k.arpa.gz + +# trigram would be: !only closed vocabulary here! +cp links/13-32.1/wsj1/doc/lng_modl/base_lm/tcb05cnp.z $lmdir/lm_tg_5k.arpa.gz || exit 1; +chmod u+w $lmdir/lm_tg_5k.arpa.gz +gunzip $lmdir/lm_tg_5k.arpa.gz +tail -n 4328839 $lmdir/lm_tg_5k.arpa | gzip -c -f > $lmdir/lm_tg_5k.arpa.gz +rm $lmdir/lm_tg_5k.arpa + +prune-lm --threshold=1e-7 $lmdir/lm_tg_5k.arpa.gz $lmdir/lm_tgpr_5k.arpa || exit 1; +gzip -f $lmdir/lm_tgpr_5k.arpa || exit 1; + + +if [ ! -f wsj0-train-spkrinfo.txt ] || [ `cat wsj0-train-spkrinfo.txt | wc -l` -ne 134 ]; then + rm wsj0-train-spkrinfo.txt + ! 
wget http://www.ldc.upenn.edu/Catalog/docs/LDC93S6A/wsj0-train-spkrinfo.txt && \ + echo "Getting wsj0-train-spkrinfo.txt from backup location" && \ + wget --no-check-certificate https://sourceforge.net/projects/kaldi/files/wsj0-train-spkrinfo.txt +fi + +if [ ! -f wsj0-train-spkrinfo.txt ]; then + echo "Could not get the spkrinfo.txt file from LDC website (moved)?" + echo "This is possibly omitted from the training disks; couldn't find it." + echo "Everything else may have worked; we just may be missing gender info" + echo "which is only needed for VTLN-related diagnostics anyway." + exit 1 +fi +# Note: wsj0-train-spkrinfo.txt doesn't seem to be on the disks but the +# LDC put it on the web. Perhaps it was accidentally omitted from the +# disks. + +cat links/11-13.1/wsj0/doc/spkrinfo.txt \ + links/13-32.1/wsj1/doc/evl_spok/spkrinfo.txt \ + links/13-34.1/wsj1/doc/dev_spok/spkrinfo.txt \ + links/13-34.1/wsj1/doc/train/spkrinfo.txt \ + ./wsj0-train-spkrinfo.txt | \ + perl -ane 'tr/A-Z/a-z/; m/^;/ || print;' | \ + awk '{print $1, $2}' | grep -v -- -- | sort | uniq > spk2gender + + +echo "Data preparation succeeded" diff --git a/egs/wsj_noisy/s5/local/wsj_extend_dict.sh b/egs/wsj_noisy/s5/local/wsj_extend_dict.sh new file mode 100755 index 00000000000..160d866843a --- /dev/null +++ b/egs/wsj_noisy/s5/local/wsj_extend_dict.sh @@ -0,0 +1,181 @@ +#!/bin/bash + +# This script builds a larger word-list and dictionary +# than used for the LMs supplied with the WSJ corpus. +# It uses a couple of strategies to fill-in words in +# the LM training data but not in CMUdict. One is +# to generate special prons for possible acronyms, that +# just consist of the constituent letters. The other +# is designed to handle derivatives of known words +# (e.g. deriving the pron of a plural from the pron of +# the base-word), but in a more general, learned-from-data +# way. +# It makes use of scripts in local/dict/ + +dict_suffix= + +echo "$0 $@" # Print the command line for logging +. 
utils/parse_options.sh || exit 1; + +if [ $# -ne 1 ]; then + echo "Usage: local/wsj_train_lms.sh /foo/bar/WSJ/13-32.1/" + exit 1 +fi +if [ "`basename $1`" != 13-32.1 ]; then + echo "Expecting the argument to this script to end in 13-32.1" + exit 1 +fi + +# e.g. +#srcdir=/mnt/matylda2/data/WSJ1/13-32.1 +export PATH=$PATH:`pwd`/local/dict/ +srcdir=$1 +mkdir -p data/local/dict${dict_suffix}_larger +dir=data/local/dict${dict_suffix}_larger +cp -r data/local/dict${dict_suffix}/* \ + data/local/dict${dict_suffix}_larger # Various files describing phones etc. + # are there; we just want to copy them + # as the phoneset is the same. +rm data/local/dict${dict_suffix}_larger/lexicon.txt # we don't want this. +rm data/local/dict${dict_suffix}_larger/lexiconp.txt # we don't want this either. +mincount=2 # Minimum count of an OOV we will try to generate a pron for. + +[ ! -f data/local/dict${dict_suffix}/cmudict/cmudict.0.7a ] && \ + echo "CMU dict not in expected place" && exit 1; + +# Remove comments from cmudict; print first field; remove +# words like FOO(1) which are alternate prons: our dict format won't +# include these markers. +grep -v ';;;' data/local/dict${dict_suffix}/cmudict/cmudict.0.7a | + perl -ane 's/^(\S+)\(\d+\)/$1/; print; ' | sort | uniq > $dir/dict.cmu + +cat $dir/dict.cmu | awk '{print $1}' | sort | uniq > $dir/wordlist.cmu + +echo "Getting training data [this should take at least a few seconds; if not, there's a problem]" + +# Convert to uppercase, remove XML-like markings. +# For words ending in "." that are not in CMUdict, we assume that these +# are periods that somehow remained in the data during data preparation, +# and we we replace the "." with "\n". Note: we found this by looking at +# oov.counts below (before adding this rule). 
+ +touch $dir/cleaned.gz +if [ `du -m $dir/cleaned.gz | cut -f 1` -eq 73 ]; then + echo "Not getting cleaned data in $dir/cleaned.gz again [already exists]"; +else + gunzip -c $srcdir/wsj1/doc/lng_modl/lm_train/np_data/{87,88,89}/*.z \ + | awk '/^){ chop; $isword{$_} = 1; } + while() { + @A = split(" ", $_); + for ($n = 0; $n < @A; $n++) { + $a = $A[$n]; + if (! $isword{$a} && $a =~ s/^([^\.]+)\.$/$1/) { # nonwords that end in "." + # and have no other "." in them: treat as period. + print "$a"; + if ($n+1 < @A) { print "\n"; } + } else { print "$a "; } + } + print "\n"; + } + ' $dir/wordlist.cmu | gzip -c > $dir/cleaned.gz +fi + +# get unigram counts +echo "Getting unigram counts" +gunzip -c $dir/cleaned.gz | tr -s ' ' '\n' | \ + awk '{count[$1]++} END{for (w in count) { print count[w], w; }}' | sort -nr > $dir/unigrams + +cat $dir/unigrams | awk -v dict=$dir/dict.cmu \ + 'BEGIN{while(getline $dir/oov.counts + +echo "Most frequent unseen unigrams are: " +head $dir/oov.counts + +# Prune away singleton counts, and remove things with numbers in +# (which should have been normalized) and with no letters at all. + + +cat $dir/oov.counts | awk -v thresh=$mincount '{if ($1 >= thresh) { print $2; }}' \ + | awk '/[0-9]/{next;} /[A-Z]/{print;}' > $dir/oovlist + +# Automatic rule-finding... + +# First make some prons for possible acronyms. +# Note: we don't do this for things like U.K or U.N, +# or A.B. (which doesn't exist anyway), +# as we consider this normalization/spelling errors. + +cat $dir/oovlist | local/dict/get_acronym_prons.pl $dir/dict.cmu > $dir/dict.acronyms + +mkdir $dir/f $dir/b # forward, backward directions of rules... + # forward is normal suffix + # rules, backward is reversed (prefix rules). These + # dirs contain stuff we create while making the rule-based + # extensions to the dictionary. + +# Remove ; and , from words, if they are present; these +# might crash our scripts, as they are used as separators there. 
+filter_dict.pl $dir/dict.cmu > $dir/f/dict +cat $dir/oovlist | filter_dict.pl > $dir/f/oovs +reverse_dict.pl $dir/f/dict > $dir/b/dict +reverse_dict.pl $dir/f/oovs > $dir/b/oovs + +# The next stage takes a few minutes. +# Note: the forward stage takes longer, as English is +# mostly a suffix-based language, and there are more rules +# that it finds. +for d in $dir/f $dir/b; do + ( + cd $d + cat dict | get_rules.pl 2>get_rules.log >rules + get_rule_hierarchy.pl rules >hierarchy + awk '{print $1}' dict | get_candidate_prons.pl rules dict | \ + limit_candidate_prons.pl hierarchy | \ + score_prons.pl dict | \ + count_rules.pl >rule.counts + # the sort command below is just for convenience of reading. + score_rules.pl rules.with_scores + get_candidate_prons.pl rules.with_scores dict oovs | \ + limit_candidate_prons.pl hierarchy > oovs.candidates + ) & +done +wait + +# Merge the candidates. +reverse_candidates.pl $dir/b/oovs.candidates | cat - $dir/f/oovs.candidates | sort > $dir/oovs.candidates +select_candidate_prons.pl <$dir/oovs.candidates | awk -F';' '{printf("%s %s\n", $1, $2);}' \ + > $dir/dict.oovs + +cat $dir/dict.acronyms $dir/dict.oovs | sort | uniq > $dir/dict.oovs_merged + +awk '{print $1}' $dir/dict.oovs_merged | uniq > $dir/oovlist.handled +sort $dir/oovlist | diff - $dir/oovlist.handled | grep -v 'd' | sed 's:< ::' > $dir/oovlist.not_handled + + +# add_counts.pl attaches to original counts to the list of handled/not-handled OOVs +add_counts.pl $dir/oov.counts $dir/oovlist.handled | sort -nr > $dir/oovlist.handled.counts +add_counts.pl $dir/oov.counts $dir/oovlist.not_handled | sort -nr > $dir/oovlist.not_handled.counts + +echo "**Top OOVs we handled are:**"; +head $dir/oovlist.handled.counts +echo "**Top OOVs we didn't handle are as follows (note: they are mostly misspellings):**"; +head $dir/oovlist.not_handled.counts + + +echo "Count of OOVs we handled is `awk '{x+=$1} END{print x}' $dir/oovlist.handled.counts`" +echo "Count of OOVs we couldn't handle 
is `awk '{x+=$1} END{print x}' $dir/oovlist.not_handled.counts`" +echo "Count of OOVs we didn't handle due to low count is" \ + `awk -v thresh=$mincount '{if ($1 < thresh) x+=$1; } END{print x;}' $dir/oov.counts` +# The two files created above are for humans to look at, as diagnostics. + +cat < $dir/lexicon.txt +!SIL SIL + SPN + SPN + NSN +EOF + +echo "Created $dir/lexicon.txt" diff --git a/egs/wsj_noisy/s5/local/wsj_format_data.sh b/egs/wsj_noisy/s5/local/wsj_format_data.sh new file mode 100755 index 00000000000..c476e83ee6f --- /dev/null +++ b/egs/wsj_noisy/s5/local/wsj_format_data.sh @@ -0,0 +1,73 @@ +#!/bin/bash + +# Copyright 2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey) +# 2015 Guoguo Chen +# Apache 2.0 + +# This script takes data prepared in a corpus-dependent way +# in data/local/, and converts it into the "canonical" form, +# in various subdirectories of data/, e.g. data/lang, data/lang_test_ug, +# data/train_si284, data/train_si84, etc. + +# Don't bother doing train_si84 separately (although we have the file lists +# in data/local/) because it's just the first 7138 utterances in train_si284. +# We'll create train_si84 after doing the feature extraction. + +lang_suffix= + +echo "$0 $@" # Print the command line for logging +. utils/parse_options.sh || exit 1; + +. 
./path.sh || exit 1; + +echo "Preparing train and test data" +srcdir=data/local/data +lmdir=data/local/nist_lm +tmpdir=data/local/lm_tmp +lexicon=data/local/lang${lang_suffix}_tmp/lexiconp.txt +mkdir -p $tmpdir + +for x in train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + mkdir -p data/$x + cp $srcdir/${x}_wav.scp data/$x/wav.scp || exit 1; + cp $srcdir/$x.txt data/$x/text || exit 1; + cp $srcdir/$x.spk2utt data/$x/spk2utt || exit 1; + cp $srcdir/$x.utt2spk data/$x/utt2spk || exit 1; + utils/filter_scp.pl data/$x/spk2utt $srcdir/spk2gender > data/$x/spk2gender || exit 1; +done + + +# Next, for each type of language model, create the corresponding FST +# and the corresponding lang_test_* directory. + +echo Preparing language models for test + +for lm_suffix in bg tgpr tg bg_5k tgpr_5k tg_5k; do + test=data/lang${lang_suffix}_test_${lm_suffix} + + mkdir -p $test + cp -r data/lang${lang_suffix}/* $test || exit 1; + + gunzip -c $lmdir/lm_${lm_suffix}.arpa.gz | \ + utils/find_arpa_oovs.pl $test/words.txt > $tmpdir/oovs_${lm_suffix}.txt + + # grep -v ' ' because the LM seems to have some strange and useless + # stuff in it with multiple 's in the history. Encountered some other similar + # things in a LM from Geoff. Removing all "illegal" combinations of and , + # which are supposed to occur only at being/end of utt. These can cause + # determinization failures of CLG [ends up being epsilon cycles]. 
+ gunzip -c $lmdir/lm_${lm_suffix}.arpa.gz | \ + grep -v ' ' | \ + grep -v ' ' | \ + grep -v ' ' | \ + arpa2fst - | fstprint | \ + utils/remove_oovs.pl $tmpdir/oovs_${lm_suffix}.txt | \ + utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=$test/words.txt \ + --osymbols=$test/words.txt --keep_isymbols=false --keep_osymbols=false | \ + fstrmepsilon | fstarcsort --sort_type=ilabel > $test/G.fst + + utils/validate_lang.pl --skip-determinization-check $test || exit 1; +done + +echo "Succeeded in formatting data." +rm -r $tmpdir diff --git a/egs/wsj_noisy/s5/local/wsj_format_local_lms.sh b/egs/wsj_noisy/s5/local/wsj_format_local_lms.sh new file mode 100755 index 00000000000..22493fbe963 --- /dev/null +++ b/egs/wsj_noisy/s5/local/wsj_format_local_lms.sh @@ -0,0 +1,87 @@ +#!/bin/bash + +# Copyright Johns Hopkins University (Author: Daniel Povey) 2012 +# Guoguo Chen 2014 + +lang_suffix= + +echo "$0 $@" # Print the command line for logging +. ./path.sh +. utils/parse_options.sh || exit 1; + +[ ! -d data/lang${lang_suffix}_bd ] &&\ + echo "Expect data/local/lang${lang_suffix}_bd to exist" && exit 1; + +lm_srcdir_3g=data/local/local_lm/3gram-mincount +lm_srcdir_4g=data/local/local_lm/4gram-mincount + +[ ! -d "$lm_srcdir_3g" ] && echo "No such dir $lm_srcdir_3g" && exit 1; +[ ! -d "$lm_srcdir_4g" ] && echo "No such dir $lm_srcdir_4g" && exit 1; + +for d in data/lang${lang_suffix}_test_bd_{tg,tgpr,tgconst,fg,fgpr,fgconst}; do + rm -r $d 2>/dev/null + cp -r data/lang${lang_suffix}_bd $d +done + +lang=data/lang${lang_suffix}_bd + +# Check a few files that we have to use. +for f in words.txt oov.int; do + if [[ ! -f $lang/$f ]]; then + echo "$0: no such file $lang/$f" + exit 1; + fi +done + +# Parameters needed for ConstArpaLm. 
+unk=`cat $lang/oov.int` +bos=`grep "" $lang/words.txt | awk '{print $2}'` +eos=`grep "" $lang/words.txt | awk '{print $2}'` +if [[ -z $bos || -z $eos ]]; then + echo "$0: and symbols are not in $lang/words.txt" + exit 1; +fi + +# Be careful: this time we dispense with the grep -v ' ' so this might +# not work for LMs generated from all toolkits. +gunzip -c $lm_srcdir_3g/lm_pr6.0.gz | \ + arpa2fst - | fstprint | \ + utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=$lang/words.txt \ + --osymbols=$lang/words.txt --keep_isymbols=false --keep_osymbols=false | \ + fstrmepsilon | fstarcsort --sort_type=ilabel > data/lang${lang_suffix}_test_bd_tgpr/G.fst || exit 1; + fstisstochastic data/lang${lang_suffix}_test_bd_tgpr/G.fst + +gunzip -c $lm_srcdir_3g/lm_unpruned.gz | \ + arpa2fst - | fstprint | \ + utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=$lang/words.txt \ + --osymbols=$lang/words.txt --keep_isymbols=false --keep_osymbols=false | \ + fstrmepsilon | fstarcsort --sort_type=ilabel > data/lang${lang_suffix}_test_bd_tg/G.fst || exit 1; + fstisstochastic data/lang${lang_suffix}_test_bd_tg/G.fst + +# Build ConstArpaLm for the unpruned language model. +gunzip -c $lm_srcdir_3g/lm_unpruned.gz | \ + utils/map_arpa_lm.pl $lang/words.txt | \ + arpa-to-const-arpa --bos-symbol=$bos --eos-symbol=$eos \ + --unk-symbol=$unk - data/lang${lang_suffix}_test_bd_tgconst/G.carpa || exit 1 + +gunzip -c $lm_srcdir_4g/lm_unpruned.gz | \ + arpa2fst - | fstprint | \ + utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=$lang/words.txt \ + --osymbols=$lang/words.txt --keep_isymbols=false --keep_osymbols=false | \ + fstrmepsilon | fstarcsort --sort_type=ilabel > data/lang${lang_suffix}_test_bd_fg/G.fst || exit 1; + fstisstochastic data/lang${lang_suffix}_test_bd_fg/G.fst + +# Build ConstArpaLm for the unpruned language model. 
+gunzip -c $lm_srcdir_4g/lm_unpruned.gz | \ + utils/map_arpa_lm.pl $lang/words.txt | \ + arpa-to-const-arpa --bos-symbol=$bos --eos-symbol=$eos \ + --unk-symbol=$unk - data/lang${lang_suffix}_test_bd_fgconst/G.carpa || exit 1 + +gunzip -c $lm_srcdir_4g/lm_pr7.0.gz | \ + arpa2fst - | fstprint | \ + utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=$lang/words.txt \ + --osymbols=$lang/words.txt --keep_isymbols=false --keep_osymbols=false | \ + fstrmepsilon | fstarcsort --sort_type=ilabel > data/lang${lang_suffix}_test_bd_fgpr/G.fst || exit 1; + fstisstochastic data/lang${lang_suffix}_test_bd_fgpr/G.fst + +exit 0; diff --git a/egs/wsj_noisy/s5/local/wsj_prepare_dict.sh b/egs/wsj_noisy/s5/local/wsj_prepare_dict.sh new file mode 100755 index 00000000000..c644f91bc6e --- /dev/null +++ b/egs/wsj_noisy/s5/local/wsj_prepare_dict.sh @@ -0,0 +1,92 @@ +#!/bin/bash + +# Copyright 2010-2012 Microsoft Corporation +# 2012-2014 Johns Hopkins University (Author: Daniel Povey) +# 2015 Guoguo Chen + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# Call this script from one level above, e.g. from the s3/ directory. It puts +# its output in data/local/. 
+ +# The parts of the output of this that will be needed are +# [in data/local/dict/ ] +# lexicon.txt +# extra_questions.txt +# nonsilence_phones.txt +# optional_silence.txt +# silence_phones.txt + +# run this from ../ +dict_suffix= + +echo "$0 $@" # Print the command line for logging +. utils/parse_options.sh || exit 1; + +dir=data/local/dict${dict_suffix} +mkdir -p $dir + + +# (1) Get the CMU dictionary +svn co https://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict \ + $dir/cmudict || exit 1; + +# can add -r 10966 for strict compatibility. + + +#(2) Dictionary preparation: + + +# Make phones symbol-table (adding in silence and verbal and non-verbal noises at this point). +# We are adding suffixes _B, _E, _S for beginning, ending, and singleton phones. + +# silence phones, one per line. +(echo SIL; echo SPN; echo NSN) > $dir/silence_phones.txt +echo SIL > $dir/optional_silence.txt + +# nonsilence phones; on each line is a list of phones that correspond +# really to the same base phone. +cat $dir/cmudict/cmudict.0.7a.symbols | perl -ane 's:\r::; print;' | \ + perl -e 'while(<>){ + chop; m:^([^\d]+)(\d*)$: || die "Bad phone $_"; + $phones_of{$1} .= "$_ "; } + foreach $list (values %phones_of) {print $list . "\n"; } ' \ + > $dir/nonsilence_phones.txt || exit 1; + +# A few extra questions that will be added to those obtained by automatically clustering +# the "real" phones. These ask about stress; there's also one for silence. +cat $dir/silence_phones.txt| awk '{printf("%s ", $1);} END{printf "\n";}' > $dir/extra_questions.txt || exit 1; +cat $dir/nonsilence_phones.txt | perl -e 'while(<>){ foreach $p (split(" ", $_)) { + $p =~ m:^([^\d]+)(\d*)$: || die "Bad phone $_"; $q{$2} .= "$p "; } } foreach $l (values %q) {print "$l\n";}' \ + >> $dir/extra_questions.txt || exit 1; + +grep -v ';;;' $dir/cmudict/cmudict.0.7a | \ + perl -ane 'if(!m:^;;;:){ s:(\S+)\(\d+\) :$1 :; print; }' \ + > $dir/lexicon1_raw_nosil.txt || exit 1; + +# Add to cmudict the silences, noises etc. 
+ +# the sort | uniq is to remove a duplicated pron from cmudict. +(echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; echo ' NSN'; ) | \ + cat - $dir/lexicon1_raw_nosil.txt | sort | uniq > $dir/lexicon2_raw.txt || exit 1; + + +# lexicon.txt is without the _B, _E, _S, _I markers. +# This is the input to wsj_format_data.sh +cp $dir/lexicon2_raw.txt $dir/lexicon.txt + +rm $dir/lexiconp.txt 2>/dev/null + +echo "Dictionary preparation succeeded" + diff --git a/egs/wsj_noisy/s5/local/wsj_train_lms.sh b/egs/wsj_noisy/s5/local/wsj_train_lms.sh new file mode 100755 index 00000000000..0807210be18 --- /dev/null +++ b/egs/wsj_noisy/s5/local/wsj_train_lms.sh @@ -0,0 +1,217 @@ +#!/bin/bash + +# This script trains LMs on the WSJ LM-training data. +# It requires that you have already run wsj_extend_dict.sh, +# to get the larger-size dictionary including all of CMUdict +# plus any OOVs and possible acronyms that we could easily +# derive pronunciations for. + +dict_suffix= + +echo "$0 $@" # Print the command line for logging +. utils/parse_options.sh || exit 1; + +dir=data/local/local_lm +srcdir=data/local/dict${dict_suffix}_larger +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/kaldi_lm:$PATH +( # First make sure the kaldi_lm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d kaldi_lm ]; then + echo Not installing the kaldi_lm toolkit since it is already there. + else + echo Downloading and installing the kaldi_lm tools + if [ ! -f kaldi_lm.tar.gz ]; then + wget http://www.danielpovey.com/files/kaldi/kaldi_lm.tar.gz || exit 1; + fi + tar -xvzf kaldi_lm.tar.gz || exit 1; + cd kaldi_lm + make || exit 1; + echo Done making the kaldi_lm tools + fi +) || exit 1; + + + +if [ ! -f $srcdir/cleaned.gz -o ! -f $srcdir/lexicon.txt ]; then + echo "Expecting files $srcdir/cleaned.gz and $srcdir/lexicon.txt to exist"; + echo "You need to run local/wsj_extend_dict.sh before running this script." 
+ exit 1; +fi + +# Get a wordlist-- keep everything but silence, which should not appear in +# the LM. +awk '{print $1}' $srcdir/lexicon.txt | grep -v -w '!SIL' > $dir/wordlist.txt + +# Get training data with OOV words (w.r.t. our current vocab) replaced with . +echo "Getting training data with OOV words replaced with (train_nounk.gz)" +gunzip -c $srcdir/cleaned.gz | awk -v w=$dir/wordlist.txt \ + 'BEGIN{while((getline0) v[$1]=1;} + {for (i=1;i<=NF;i++) if ($i in v) printf $i" ";else printf " ";print ""}'|sed 's/ $//g' \ + | gzip -c > $dir/train_nounk.gz + +# Get unigram counts (without bos/eos, but this doens't matter here, it's +# only to get the word-map, which treats them specially & doesn't need their +# counts). +# Add a 1-count for each word in word-list by including that in the data, +# so all words appear. +gunzip -c $dir/train_nounk.gz | cat - $dir/wordlist.txt | \ + awk '{ for(x=1;x<=NF;x++) count[$x]++; } END{for(w in count){print count[w], w;}}' | \ + sort -nr > $dir/unigram.counts + +# Get "mapped" words-- a character encoding of the words that makes the common words very short. +cat $dir/unigram.counts | awk '{print $2}' | get_word_map.pl "" "" "" > $dir/word_map + +gunzip -c $dir/train_nounk.gz | awk -v wmap=$dir/word_map 'BEGIN{while((getline0)map[$1]=$2;} + { for(n=1;n<=NF;n++) { printf map[$n]; if(n$dir/train.gz + +# To save disk space, remove the un-mapped training data. We could +# easily generate it again if needed. +rm $dir/train_nounk.gz + +train_lm.sh --arpa --lmtype 3gram-mincount $dir +#Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 141.444826 +# 7.8 million N-grams. + +prune_lm.sh --arpa 6.0 $dir/3gram-mincount/ +# 1.45 million N-grams. +# Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 165.394139 + +train_lm.sh --arpa --lmtype 4gram-mincount $dir +#Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 126.734180 +# 10.3 million N-grams. 
+ +prune_lm.sh --arpa 7.0 $dir/4gram-mincount +# 1.50 million N-grams +# Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 155.663757 + + +exit 0 + +### Below here, this script is showing various commands that +## were run during LM tuning. + +train_lm.sh --arpa --lmtype 3gram-mincount $dir +#Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 141.444826 +# 7.8 million N-grams. + +prune_lm.sh --arpa 3.0 $dir/3gram-mincount/ +#Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 156.408740 +# 2.5 million N-grams. + +prune_lm.sh --arpa 6.0 $dir/3gram-mincount/ +# 1.45 million N-grams. +# Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 165.394139 + +train_lm.sh --arpa --lmtype 4gram-mincount $dir +#Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 126.734180 +# 10.3 million N-grams. + +prune_lm.sh --arpa 3.0 $dir/4gram-mincount +#Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 143.206294 +# 2.6 million N-grams. + +prune_lm.sh --arpa 4.0 $dir/4gram-mincount +# Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 146.927717 +# 2.15 million N-grams. + +prune_lm.sh --arpa 5.0 $dir/4gram-mincount +# 1.86 million N-grams +# Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 150.162023 + +prune_lm.sh --arpa 7.0 $dir/4gram-mincount +# 1.50 million N-grams +# Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 155.663757 + +train_lm.sh --arpa --lmtype 3gram $dir +# Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 135.692866 +# 20.0 million N-grams + +! which ngram-count \ + && echo "SRILM tools not installed so not doing the comparison" && exit 1; + +################# +# You could finish the script here if you wanted. +# Below is to show how to do baselines with SRILM. +# You'd have to install the SRILM toolkit first. 
+ +heldout_sent=10000 # Don't change this if you want result to be comparable with + # kaldi_lm results +sdir=$dir/srilm # in case we want to use SRILM to double-check perplexities. +mkdir -p $sdir +gunzip -c $srcdir/cleaned.gz | head -$heldout_sent > $sdir/cleaned.heldout +gunzip -c $srcdir/cleaned.gz | tail -n +$heldout_sent > $sdir/cleaned.train +(echo ""; echo "" ) | cat - $dir/wordlist.txt > $sdir/wordlist.final.s + +# 3-gram: +ngram-count -text $sdir/cleaned.train -order 3 -limit-vocab -vocab $sdir/wordlist.final.s -unk \ + -map-unk "" -kndiscount -interpolate -lm $sdir/srilm.o3g.kn.gz +ngram -lm $sdir/srilm.o3g.kn.gz -ppl $sdir/cleaned.heldout # consider -debug 2 +#file data/local/local_lm/srilm/cleaned.heldout: 10000 sentences, 218996 words, 478 OOVs +#0 zeroprobs, logprob= -491456 ppl= 141.457 ppl1= 177.437 + +# Trying 4-gram: +ngram-count -text $sdir/cleaned.train -order 4 -limit-vocab -vocab $sdir/wordlist.final.s -unk \ + -map-unk "" -kndiscount -interpolate -lm $sdir/srilm.o4g.kn.gz +ngram -order 4 -lm $sdir/srilm.o4g.kn.gz -ppl $sdir/cleaned.heldout +#file data/local/local_lm/srilm/cleaned.heldout: 10000 sentences, 218996 words, 478 OOVs +#0 zeroprobs, logprob= -480939 ppl= 127.233 ppl1= 158.822 + +#3-gram with pruning: +ngram-count -text $sdir/cleaned.train -order 3 -limit-vocab -vocab $sdir/wordlist.final.s -unk \ + -prune 0.0000001 -map-unk "" -kndiscount -interpolate -lm $sdir/srilm.o3g.pr7.kn.gz +ngram -lm $sdir/srilm.o3g.pr7.kn.gz -ppl $sdir/cleaned.heldout +#file data/local/local_lm/srilm/cleaned.heldout: 10000 sentences, 218996 words, 478 OOVs +#0 zeroprobs, logprob= -510828 ppl= 171.947 ppl1= 217.616 +# Around 2.25M N-grams. +# Note: this is closest to the experiment done with "prune_lm.sh --arpa 3.0 $dir/3gram-mincount/" +# above, which gave 2.5 million N-grams and a perplexity of 156. + +# Note: all SRILM experiments above fully discount all singleton 3 and 4-grams. 
+# You can use -gt3min=0 and -gt4min=0 to stop this (this will be comparable to +# the kaldi_lm experiments above without "-mincount". + +## From here is how to train with +# IRSTLM. This is not really working at the moment. + +if [ -z $IRSTLM ] ; then + export IRSTLM=$KALDI_ROOT/tools/irstlm/ +fi +export PATH=${PATH}:$IRSTLM/bin +if ! command -v prune-lm >/dev/null 2>&1 ; then + echo "$0: Error: the IRSTLM is not available or compiled" >&2 + echo "$0: Error: We used to install it by default, but." >&2 + echo "$0: Error: this is no longer the case." >&2 + echo "$0: Error: To install it, go to $KALDI_ROOT/tools" >&2 + echo "$0: Error: and run extras/install_irstlm.sh" >&2 + exit 1 +fi + +idir=$dir/irstlm +mkdir $idir +gunzip -c $srcdir/cleaned.gz | tail -n +$heldout_sent | add-start-end.sh | \ + gzip -c > $idir/train.gz + +dict -i=WSJ.cleaned.irstlm.txt -o=dico -f=y -sort=no + cat dico | gawk 'BEGIN{while (getline<"vocab.20k.nooov") v[$1]=1; print "DICTIONARY 0 "length(v);}FNR>1{if ($1 in v)\ +{print $0;}}' > vocab.irstlm.20k + + +build-lm.sh -i "gunzip -c $idir/train.gz" -o $idir/lm_3gram.gz -p yes \ + -n 3 -s improved-kneser-ney -b yes +# Testing perplexity with SRILM tools: +ngram -lm $idir/lm_3gram.gz -ppl $sdir/cleaned.heldout +#data/local/local_lm/irstlm/lm_3gram.gz: line 162049: warning: non-zero probability for in closed-vocabulary LM +#file data/local/local_lm/srilm/cleaned.heldout: 10000 sentences, 218996 words, 0 OOVs +#0 zeroprobs, logprob= -513670 ppl= 175.041 ppl1= 221.599 + +# Perplexity is very bad (should be ~141, since we used -p option, +# not 175), +# but adding -debug 3 to the command line shows that +# the IRSTLM LM does not seem to sum to one properly, so it seems that +# it produces an LM that isn't interpretable in the normal way as an ARPA +# LM. 
+ + + diff --git a/egs/wsj_noisy/s5/local/wsj_train_rnnlms.sh b/egs/wsj_noisy/s5/local/wsj_train_rnnlms.sh new file mode 100755 index 00000000000..1d4fda63fe7 --- /dev/null +++ b/egs/wsj_noisy/s5/local/wsj_train_rnnlms.sh @@ -0,0 +1,162 @@ +#!/bin/bash + +# Copyright 2012 Johns Hopkins University (author: Daniel Povey) Tony Robinson +# 2015 Guoguo Chen + +# This script trains LMs on the WSJ LM-training data. +# It requires that you have already run wsj_extend_dict.sh, +# to get the larger-size dictionary including all of CMUdict +# plus any OOVs and possible acronyms that we could easily +# derive pronunciations for. + +# This script takes no command-line arguments but takes the --cmd option. + +# Begin configuration section. +rand_seed=0 +cmd=run.pl +nwords=10000 # This is how many words we're putting in the vocab of the RNNLM. +hidden=30 +class=200 # Num-classes... should be somewhat larger than sqrt of nwords. +direct=1000 # Number of weights that are used for "direct" connections, in millions. +rnnlm_ver=rnnlm-0.3e # version of RNNLM to use +threads=1 # for RNNLM-HS +bptt=2 # length of BPTT unfolding in RNNLM +bptt_block=20 # length of BPTT unfolding in RNNLM +dict_suffix= +# End configuration section. + +[ -f ./path.sh ] && . ./path.sh +. utils/parse_options.sh + +if [ $# != 1 ]; then + echo "Usage: local/wsj_train_rnnlms.sh [options] " + echo "For options, see top of script file" + exit 1; +fi + +dir=$1 +srcdir=data/local/dict${dict_suffix}_larger +mkdir -p $dir + +export PATH=$KALDI_ROOT/tools/$rnnlm_ver:$PATH + + +( # First make sure the kaldi_lm toolkit is installed. + # Note: this didn't work out of the box for me, I had to + # change the g++ version to just "g++" (no cross-compilation + # needed for me as I ran on a machine that had been setup + # as 64 bit by default. + cd $KALDI_ROOT/tools || exit 1; + if [ -f $rnnlm_ver/rnnlm ]; then + echo Not installing the rnnlm toolkit since it is already there. 
+ else + if [ $rnnlm_ver == "rnnlm-hs-0.1b" ]; then + extras/install_rnnlm_hs.sh + else + echo Downloading and installing the rnnlm tools + # http://www.fit.vutbr.cz/~imikolov/rnnlm/$rnnlm_ver.tgz + if [ ! -f $rnnlm_ver.tgz ]; then + wget http://www.fit.vutbr.cz/~imikolov/rnnlm/$rnnlm_ver.tgz || exit 1; + fi + mkdir $rnnlm_ver + cd $rnnlm_ver + tar -xvzf ../$rnnlm_ver.tgz || exit 1; + make CC=g++ || exit 1; + echo Done making the rnnlm tools + fi + fi +) || exit 1; + + +if [ ! -f $srcdir/cleaned.gz -o ! -f $srcdir/lexicon.txt ]; then + echo "Expecting files $srcdir/cleaned.gz and $srcdir/wordlist.final to exist"; + echo "You need to run local/wsj_extend_dict.sh before running this script." + exit 1; +fi + +cat $srcdir/lexicon.txt | awk '{print $1}' | grep -v -w '!SIL' > $dir/wordlist.all + +# Get training data with OOV words (w.r.t. our current vocab) replaced with . +echo "Getting training data with OOV words replaced with (train_nounk.gz)" +gunzip -c $srcdir/cleaned.gz | awk -v w=$dir/wordlist.all \ + 'BEGIN{while((getline0) v[$1]=1;} + {for (i=1;i<=NF;i++) if ($i in v) printf $i" ";else printf " ";print ""}'|sed 's/ $//g' \ + | gzip -c > $dir/all.gz + +echo "Splitting data into train and validation sets." +heldout_sent=10000 +gunzip -c $dir/all.gz | head -n $heldout_sent > $dir/valid.in # validation data +gunzip -c $dir/all.gz | tail -n +$heldout_sent | \ + perl -e ' use List::Util qw(shuffle); @A=<>; print join("", shuffle(@A)); ' \ + > $dir/train.in # training data + + + # The rest will consist of a word-class represented by , that + # maps (with probabilities) to a whole class of words. + +# Get unigram counts from our training data, and use this to select word-list +# for RNNLM training; e.g. 10k most frequent words. Rest will go in a class +# that we (manually, at the shell level) assign probabilities for words that +# are in that class. Note: this word-list doesn't need to include ; this +# automatically gets added inside the rnnlm program. 
+# Note: by concatenating with $dir/wordlist.all, we are doing add-one +# smoothing of the counts. + +cat $dir/train.in $dir/wordlist.all | grep -v '' | grep -v '' | \ + awk '{ for(x=1;x<=NF;x++) count[$x]++; } END{for(w in count){print count[w], w;}}' | \ + sort -nr > $dir/unigram.counts + +head -$nwords $dir/unigram.counts | awk '{print $2}' > $dir/wordlist.rnn + +tail -n +$nwords $dir/unigram.counts > $dir/unk_class.counts + +tot=`awk '{x=x+$1} END{print x}' $dir/unk_class.counts` +awk -v tot=$tot '{print $2, ($1*1.0/tot);}' <$dir/unk_class.counts >$dir/unk.probs + + +for type in train valid; do + cat $dir/$type.in | awk -v w=$dir/wordlist.rnn \ + 'BEGIN{while((getline0) v[$1]=1;} + {for (i=1;i<=NF;i++) if ($i in v) printf $i" ";else printf " ";print ""}'|sed 's/ $//g' \ + > $dir/$type +done +rm $dir/train.in # no longer needed-- and big. + +# Now randomize the order of the training data. +cat $dir/train | awk -v rand_seed=$rand_seed 'BEGIN{srand(rand_seed);} {printf("%f\t%s\n", rand(), $0);}' | \ + sort | cut -f 2 > $dir/foo +mv $dir/foo $dir/train + +# OK we'll train the RNNLM on this data. + +# todo: change 100 to 320. +# using 100 classes as square root of 10k. +echo "Training RNNLM (note: this uses a lot of memory! Run it on a big machine.)" +#time rnnlm -train $dir/train -valid $dir/valid -rnnlm $dir/100.rnnlm \ +# -hidden 100 -rand-seed 1 -debug 2 -class 100 -bptt 2 -bptt-block 20 \ +# -direct-order 4 -direct 1000 -binary >& $dir/rnnlm1.log & + +$cmd $dir/rnnlm.log \ + $KALDI_ROOT/tools/$rnnlm_ver/rnnlm -threads $threads -independent -train $dir/train -valid $dir/valid \ + -rnnlm $dir/rnnlm -hidden $hidden -rand-seed 1 -debug 2 -class $class -bptt $bptt -bptt-block $bptt_block \ + -direct-order 4 -direct $direct -binary || exit 1; + + +# make it like a Kaldi table format, with fake utterance-ids. 
+cat $dir/valid.in | awk '{ printf("uttid-%d ", NR); print; }' > $dir/valid.with_ids + +utils/rnnlm_compute_scores.sh --rnnlm_ver $rnnlm_ver $dir $dir/tmp.valid $dir/valid.with_ids \ + $dir/valid.scores +nw=`wc -w < $dir/valid.with_ids` # Note: valid.with_ids includes utterance-ids which + # is one per word, to account for the at the end of each sentence; this is the + # correct number to normalize buy. +p=`awk -v nw=$nw '{x=x+$2} END{print exp(x/nw);}' <$dir/valid.scores` +echo Perplexity is $p | tee $dir/perplexity.log + +rm $dir/train $dir/all.gz + +# This is a better setup, but takes a long time to train: +#echo "Training RNNLM (note: this uses a lot of memory! Run it on a big machine.)" +#time rnnlm -train $dir/train -valid $dir/valid -rnnlm $dir/320.rnnlm \ +# -hidden 320 -rand-seed 1 -debug 2 -class 300 -bptt 2 -bptt-block 20 \ +# -direct-order 4 -direct 2000 -binary diff --git a/egs/wsj_noisy/s5/path.sh b/egs/wsj_noisy/s5/path.sh new file mode 100755 index 00000000000..343ace34179 --- /dev/null +++ b/egs/wsj_noisy/s5/path.sh @@ -0,0 +1,4 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/src/bin:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/src/fstbin/:$KALDI_ROOT/src/gmmbin/:$KALDI_ROOT/src/featbin/:$KALDI_ROOT/src/lm/:$KALDI_ROOT/src/sgmmbin/:$KALDI_ROOT/src/sgmm2bin/:$KALDI_ROOT/src/fgmmbin/:$KALDI_ROOT/src/latbin/:$KALDI_ROOT/src/nnetbin:$KALDI_ROOT/src/nnet2bin/:$KALDI_ROOT/src/kwsbin:$KALDI_ROOT/src/online2bin/:$KALDI_ROOT/src/ivectorbin/:$KALDI_ROOT/src/lmbin/:$KALDI_ROOT/src/nnet3bin/:$KALDI_ROOT/src/segmenterbin/:$PWD:$PATH +export LC_ALL=C diff --git a/egs/wsj_noisy/s5/run.sh b/egs/wsj_noisy/s5/run.sh new file mode 100755 index 00000000000..947cf5cb19c --- /dev/null +++ b/egs/wsj_noisy/s5/run.sh @@ -0,0 +1,478 @@ +#!/bin/bash + +. ./cmd.sh ## You'll want to change cmd.sh to something that will work on your system. + ## This relates to the queue. 
+ +# This is a shell script, but it's recommended that you run the commands one by +# one by copying and pasting into the shell. + +#wsj0=/ais/gobi2/speech/WSJ/csr_?_senn_d? +#wsj1=/ais/gobi2/speech/WSJ/csr_senn_d? + +#wsj0=/mnt/matylda2/data/WSJ0 +#wsj1=/mnt/matylda2/data/WSJ1 + +#wsj0=/data/corpora0/LDC93S6B +#wsj1=/data/corpora0/LDC94S13B + +wsj0=/export/corpora5/LDC/LDC93S6B +wsj1=/export/corpora5/LDC/LDC94S13B + +local/wsj_data_prep.sh $wsj0/??-{?,??}.? $wsj1/??-{?,??}.? || exit 1; + +# Sometimes, we have seen WSJ distributions that do not have subdirectories +# like '11-13.1', but instead have 'doc', 'si_et_05', etc. directly under the +# wsj0 or wsj1 directories. In such cases, try the following: +# +# corpus=/exports/work/inf_hcrc_cstr_general/corpora/wsj +# local/cstr_wsj_data_prep.sh $corpus +# rm data/local/dict/lexiconp.txt +# $corpus must contain a 'wsj0' and a 'wsj1' subdirectory for this to work. +# +# "nosp" refers to the dictionary before silence probabilities and pronunciation +# probabilities are added. +local/wsj_prepare_dict.sh --dict-suffix "_nosp" || exit 1; + +utils/prepare_lang.sh data/local/dict_nosp \ + "" data/local/lang_tmp_nosp data/lang_nosp || exit 1; + +local/wsj_format_data.sh --lang-suffix "_nosp" || exit 1; + + # We suggest to run the next three commands in the background, + # as they are not a precondition for the system building and + # most of the tests: these commands build a dictionary + # containing many of the OOVs in the WSJ LM training data, + # and an LM trained directly on that data (i.e. not just + # copying the arpa files from the disks from LDC). + # Caution: the commands below will only work if $decode_cmd + # is setup to use qsub. Else, just remove the --cmd option. + # NOTE: If you have a setup corresponding to the cstr_wsj_data_prep.sh style, + # use local/cstr_wsj_extend_dict.sh $corpus/wsj1/doc/ instead. + + # Note: I am commenting out the RNNLM-building commands below. 
They take up a lot + # of CPU time and are not really part of the "main recipe." + # Be careful: appending things like "--mem 10G" to $decode_cmd + # won't always work, it depends what $decode_cmd is. + ( + local/wsj_extend_dict.sh --dict-suffix "_nosp" $wsj1/13-32.1 && \ + utils/prepare_lang.sh data/local/dict_nosp_larger \ + "" data/local/lang_tmp_nosp_larger data/lang_nosp_bd && \ + local/wsj_train_lms.sh --dict-suffix "_nosp" && + local/wsj_format_local_lms.sh --lang-suffix "_nosp" # && + # + # ( local/wsj_train_rnnlms.sh --dict-suffix "_nosp" \ + # --cmd "$decode_cmd --mem 10G" data/local/rnnlm.h30.voc10k & + # sleep 20; # wait till tools compiled. + # local/wsj_train_rnnlms.sh --dict-suffix "_nosp" \ + # --cmd "$decode_cmd --mem 12G" \ + # --hidden 100 --nwords 20000 --class 350 \ + # --direct 1500 data/local/rnnlm.h100.voc20k & + # local/wsj_train_rnnlms.sh --dict-suffix "_nosp" \ + # --cmd "$decode_cmd --mem 14G" \ + # --hidden 200 --nwords 30000 --class 350 \ + # --direct 1500 data/local/rnnlm.h200.voc30k & + # local/wsj_train_rnnlms.sh --dict-suffix "_nosp" \ + # --cmd "$decode_cmd --mem 16G" \ + # --hidden 300 --nwords 40000 --class 400 \ + # --direct 2000 data/local/rnnlm.h300.voc40k & + # ) + false && \ # Comment this out to train RNNLM-HS + ( + num_threads_rnnlm=8 + local/wsj_train_rnnlms.sh --dict-suffix "_nosp" \ + --rnnlm_ver rnnlm-hs-0.1b --threads $num_threads_rnnlm \ + --cmd "$decode_cmd --mem 1G --num-threads $num_threads_rnnlm" --bptt 4 --bptt-block 10 \ + --hidden 30 --nwords 10000 --direct 1000 data/local/rnnlm-hs.h30.voc10k + local/wsj_train_rnnlms.sh --dict-suffix "_nosp" \ + --rnnlm_ver rnnlm-hs-0.1b --threads $num_threads_rnnlm \ + --cmd "$decode_cmd --mem 1G --num-threads $num_threads_rnnlm" --bptt 4 --bptt-block 10 \ + --hidden 100 --nwords 20000 --direct 1500 data/local/rnnlm-hs.h100.voc20k + local/wsj_train_rnnlms.sh --dict-suffix "_nosp" \ + --rnnlm_ver rnnlm-hs-0.1b --threads $num_threads_rnnlm \ + --cmd "$decode_cmd --mem 1G 
--num-threads $num_threads_rnnlm" --bptt 4 --bptt-block 10 \ + --hidden 300 --nwords 30000 --direct 1500 data/local/rnnlm-hs.h300.voc30k + local/wsj_train_rnnlms.sh --dict-suffix "_nosp" \ + --rnnlm_ver rnnlm-hs-0.1b --threads $num_threads_rnnlm \ + --cmd "$decode_cmd --mem 1G --num-threads $num_threads_rnnlm" --bptt 4 --bptt-block 10 \ + --hidden 400 --nwords 40000 --direct 2000 data/local/rnnlm-hs.h400.voc40k + ) + ) & + +# Now make MFCC features. +# mfccdir should be some place with a largish disk where you +# want to store MFCC features. +mfccdir=mfcc +for x in test_eval92 test_eval93 test_dev93 train_si284; do + steps/make_mfcc.sh --cmd "$train_cmd" --nj 20 \ + data/$x exp/make_mfcc/$x $mfccdir || exit 1; + steps/compute_cmvn_stats.sh data/$x exp/make_mfcc/$x $mfccdir || exit 1; +done + +utils/subset_data_dir.sh --first data/train_si284 7138 data/train_si84 || exit 1 + +# Now make subset with the shortest 2k utterances from si-84. +utils/subset_data_dir.sh --shortest data/train_si84 2000 data/train_si84_2kshort || exit 1; + +# Now make subset with half of the data from si-84. +utils/subset_data_dir.sh data/train_si84 3500 data/train_si84_half || exit 1; + + +# Note: the --boost-silence option should probably be omitted by default +# for normal setups. It doesn't always help. [it's to discourage non-silence +# models from modeling silence.] 
+steps/train_mono.sh --boost-silence 1.25 --nj 10 --cmd "$train_cmd" \ + data/train_si84_2kshort data/lang_nosp exp/mono0a || exit 1; + +( + utils/mkgraph.sh --mono data/lang_nosp_test_tgpr \ + exp/mono0a exp/mono0a/graph_nosp_tgpr && \ + steps/decode.sh --nj 10 --cmd "$decode_cmd" exp/mono0a/graph_nosp_tgpr \ + data/test_dev93 exp/mono0a/decode_nosp_tgpr_dev93 && \ + steps/decode.sh --nj 8 --cmd "$decode_cmd" exp/mono0a/graph_nosp_tgpr \ + data/test_eval92 exp/mono0a/decode_nosp_tgpr_eval92 +) & + +steps/align_si.sh --boost-silence 1.25 --nj 10 --cmd "$train_cmd" \ + data/train_si84_half data/lang_nosp exp/mono0a exp/mono0a_ali || exit 1; + +steps/train_deltas.sh --boost-silence 1.25 --cmd "$train_cmd" 2000 10000 \ + data/train_si84_half data/lang_nosp exp/mono0a_ali exp/tri1 || exit 1; + +while [ ! -f data/lang_nosp_test_tgpr/tmp/LG.fst ] || \ + [ -z data/lang_nosp_test_tgpr/tmp/LG.fst ]; do + sleep 20; +done +sleep 30; +# or the mono mkgraph.sh might be writing +# data/lang_test_tgpr/tmp/LG.fst which will cause this to fail. + +utils/mkgraph.sh data/lang_nosp_test_tgpr \ + exp/tri1 exp/tri1/graph_nosp_tgpr || exit 1; + +steps/decode.sh --nj 10 --cmd "$decode_cmd" exp/tri1/graph_nosp_tgpr \ + data/test_dev93 exp/tri1/decode_nosp_tgpr_dev93 || exit 1; +steps/decode.sh --nj 8 --cmd "$decode_cmd" exp/tri1/graph_nosp_tgpr \ + data/test_eval92 exp/tri1/decode_nosp_tgpr_eval92 || exit 1; + +# test various modes of LM rescoring (4 is the default one). +# This is just confirming they're equivalent. +for mode in 1 2 3 4; do + steps/lmrescore.sh --mode $mode --cmd "$decode_cmd" \ + data/lang_nosp_test_{tgpr,tg} data/test_dev93 \ + exp/tri1/decode_nosp_tgpr_dev93 \ + exp/tri1/decode_nosp_tgpr_dev93_tg$mode || exit 1; +done + +# demonstrate how to get lattices that are "word-aligned" (arcs coincide with +# words, with boundaries in the right place). 
+sil_label=`grep '!SIL' data/lang_nosp_test_tgpr/words.txt | awk '{print $2}'` +steps/word_align_lattices.sh --cmd "$train_cmd" --silence-label $sil_label \ + data/lang_nosp_test_tgpr exp/tri1/decode_nosp_tgpr_dev93 \ + exp/tri1/decode_nosp_tgpr_dev93_aligned || exit 1; + +steps/align_si.sh --nj 10 --cmd "$train_cmd" \ + data/train_si84 data/lang_nosp exp/tri1 exp/tri1_ali_si84 || exit 1; + +# Train tri2a, which is deltas + delta-deltas, on si84 data. +steps/train_deltas.sh --cmd "$train_cmd" 2500 15000 \ + data/train_si84 data/lang_nosp exp/tri1_ali_si84 exp/tri2a || exit 1; + +utils/mkgraph.sh data/lang_nosp_test_tgpr \ + exp/tri2a exp/tri2a/graph_nosp_tgpr || exit 1; + +steps/decode.sh --nj 10 --cmd "$decode_cmd" exp/tri2a/graph_nosp_tgpr \ + data/test_dev93 exp/tri2a/decode_nosp_tgpr_dev93 || exit 1; +steps/decode.sh --nj 8 --cmd "$decode_cmd" exp/tri2a/graph_nosp_tgpr \ + data/test_eval92 exp/tri2a/decode_nosp_tgpr_eval92 || exit 1; + +utils/mkgraph.sh data/lang_nosp_test_bg_5k exp/tri2a exp/tri2a/graph_nosp_bg5k +steps/decode.sh --nj 8 --cmd "$decode_cmd" exp/tri2a/graph_nosp_bg5k \ + data/test_eval92 exp/tri2a/decode_nosp_eval92_bg5k || exit 1; + +steps/train_lda_mllt.sh --cmd "$train_cmd" \ + --splice-opts "--left-context=3 --right-context=3" 2500 15000 \ + data/train_si84 data/lang_nosp exp/tri1_ali_si84 exp/tri2b || exit 1; + +utils/mkgraph.sh data/lang_nosp_test_tgpr \ + exp/tri2b exp/tri2b/graph_nosp_tgpr || exit 1; +steps/decode.sh --nj 10 --cmd "$decode_cmd" exp/tri2b/graph_nosp_tgpr \ + data/test_dev93 exp/tri2b/decode_nosp_tgpr_dev93 || exit 1; +steps/decode.sh --nj 8 --cmd "$decode_cmd" exp/tri2b/graph_nosp_tgpr \ + data/test_eval92 exp/tri2b/decode_nosp_tgpr_eval92 || exit 1; + +# At this point, you could run the example scripts that show how VTLN works. +# We haven't included this in the default recipes yet. 
+# local/run_vtln.sh --lang-suffix "_nosp" +# local/run_vtln2.sh --lang-suffix "_nosp" + +# Now, with dev93, compare lattice rescoring with biglm decoding, +# going from tgpr to tg. Note: results are not the same, even though they should +# be, and I believe this is due to the beams not being wide enough. The pruning +# seems to be a bit too narrow in the current scripts (got at least 0.7% absolute +# improvement from loosening beams from their current values). + +steps/decode_biglm.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri2b/graph_nosp_tgpr data/lang_test_{tgpr,tg}/G.fst \ + data/test_dev93 exp/tri2b/decode_nosp_tgpr_dev93_tg_biglm + +# baseline via LM rescoring of lattices. +steps/lmrescore.sh --cmd "$decode_cmd" \ + data/lang_nosp_test_tgpr/ data/lang_nosp_test_tg/ \ + data/test_dev93 exp/tri2b/decode_nosp_tgpr_dev93 \ + exp/tri2b/decode_nosp_tgpr_dev93_tg || exit 1; + +# Trying Minimum Bayes Risk decoding (like Confusion Network decoding): +mkdir exp/tri2b/decode_nosp_tgpr_dev93_tg_mbr +cp exp/tri2b/decode_nosp_tgpr_dev93_tg/lat.*.gz \ + exp/tri2b/decode_nosp_tgpr_dev93_tg_mbr +local/score_mbr.sh --cmd "$decode_cmd" \ + data/test_dev93/ data/lang_nosp_test_tgpr/ \ + exp/tri2b/decode_nosp_tgpr_dev93_tg_mbr + +steps/decode_fromlats.sh --cmd "$decode_cmd" \ + data/test_dev93 data/lang_nosp_test_tgpr exp/tri2b/decode_nosp_tgpr_dev93 \ + exp/tri2a/decode_nosp_tgpr_dev93_fromlats || exit 1 + +# Align tri2b system with si84 data. +steps/align_si.sh --nj 10 --cmd "$train_cmd" \ + --use-graphs true data/train_si84 \ + data/lang_nosp exp/tri2b exp/tri2b_ali_si84 || exit 1; + +local/run_mmi_tri2b.sh --lang-suffix "_nosp" + +# From 2b system, train 3b which is LDA + MLLT + SAT. 
+steps/train_sat.sh --cmd "$train_cmd" 2500 15000 \ + data/train_si84 data/lang_nosp exp/tri2b_ali_si84 exp/tri3b || exit 1; +utils/mkgraph.sh data/lang_nosp_test_tgpr \ + exp/tri3b exp/tri3b/graph_nosp_tgpr || exit 1; +steps/decode_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri3b/graph_nosp_tgpr data/test_dev93 \ + exp/tri3b/decode_nosp_tgpr_dev93 || exit 1; +steps/decode_fmllr.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri3b/graph_nosp_tgpr data/test_eval92 \ + exp/tri3b/decode_nosp_tgpr_eval92 || exit 1; + +# At this point you could run the command below; this gets +# results that demonstrate the basis-fMLLR adaptation (adaptation +# on small amounts of adaptation data). +local/run_basis_fmllr.sh --lang-suffix "_nosp" + +steps/lmrescore.sh --cmd "$decode_cmd" \ + data/lang_nosp_test_tgpr data/lang_nosp_test_tg \ + data/test_dev93 exp/tri3b/decode_nosp_tgpr_dev93 \ + exp/tri3b/decode_nosp_tgpr_dev93_tg || exit 1; +steps/lmrescore.sh --cmd "$decode_cmd" \ + data/lang_nosp_test_tgpr data/lang_nosp_test_tg \ + data/test_eval92 exp/tri3b/decode_nosp_tgpr_eval92 \ + exp/tri3b/decode_nosp_tgpr_eval92_tg || exit 1; + +# Trying the larger dictionary ("big-dict"/bd) + locally produced LM. +utils/mkgraph.sh data/lang_nosp_test_bd_tgpr \ + exp/tri3b exp/tri3b/graph_nosp_bd_tgpr || exit 1; + +steps/decode_fmllr.sh --cmd "$decode_cmd" --nj 8 \ + exp/tri3b/graph_nosp_bd_tgpr data/test_eval92 \ + exp/tri3b/decode_nosp_bd_tgpr_eval92 || exit 1; +steps/decode_fmllr.sh --cmd "$decode_cmd" --nj 10 \ + exp/tri3b/graph_nosp_bd_tgpr data/test_dev93 \ + exp/tri3b/decode_nosp_bd_tgpr_dev93 || exit 1; + +# Example of rescoring with ConstArpaLm. 
+steps/lmrescore_const_arpa.sh \ + --cmd "$decode_cmd" data/lang_nosp_test_bd_{tgpr,fgconst} \ + data/test_eval92 exp/tri3b/decode_nosp_bd_tgpr_eval92{,_fgconst} || exit 1; + +steps/lmrescore.sh --cmd "$decode_cmd" \ + data/lang_nosp_test_bd_tgpr data/lang_nosp_test_bd_fg \ + data/test_eval92 exp/tri3b/decode_nosp_bd_tgpr_eval92 \ + exp/tri3b/decode_nosp_bd_tgpr_eval92_fg || exit 1; +steps/lmrescore.sh --cmd "$decode_cmd" \ + data/lang_nosp_test_bd_tgpr data/lang_nosp_test_bd_tg \ + data/test_eval92 exp/tri3b/decode_nosp_bd_tgpr_eval92 \ + exp/tri3b/decode_nosp_bd_tgpr_eval92_tg || exit 1; + +# The command below is commented out as we commented out the steps above +# that build the RNNLMs, so it would fail. +# local/run_rnnlms_tri3b.sh --lang-suffix "_nosp" + +# The command below is commented out as we commented out the steps above +# that build the RNNLMs (HS version), so it would fail. +# wait; local/run_rnnlm-hs_tri3b.sh --lang-suffix "_nosp" + +# The following two steps, which are a kind of side-branch, try mixing up +( # from the 3b system. This is to demonstrate that script. + steps/mixup.sh --cmd "$train_cmd" \ + 20000 data/train_si84 data/lang_nosp exp/tri3b exp/tri3b_20k || exit 1; + steps/decode_fmllr.sh --cmd "$decode_cmd" --nj 10 \ + exp/tri3b/graph_nosp_tgpr data/test_dev93 \ + exp/tri3b_20k/decode_nosp_tgpr_dev93 || exit 1; +) + +# From 3b system, align all si284 data. +steps/align_fmllr.sh --nj 20 --cmd "$train_cmd" \ + data/train_si284 data/lang_nosp exp/tri3b exp/tri3b_ali_si284 || exit 1; + + +# From 3b system, train another SAT system (tri4a) with all the si284 data. 
+ +steps/train_sat.sh --cmd "$train_cmd" 4200 40000 \ + data/train_si284 data/lang_nosp exp/tri3b_ali_si284 exp/tri4a || exit 1; +( + utils/mkgraph.sh data/lang_nosp_test_tgpr \ + exp/tri4a exp/tri4a/graph_nosp_tgpr || exit 1; + steps/decode_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri4a/graph_nosp_tgpr data/test_dev93 \ + exp/tri4a/decode_nosp_tgpr_dev93 || exit 1; + steps/decode_fmllr.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri4a/graph_nosp_tgpr data/test_eval92 \ + exp/tri4a/decode_nosp_tgpr_eval92 || exit 1; +) & + + +# This step is just to demonstrate the train_quick.sh script, in which we +# initialize the GMMs from the old system's GMMs. +steps/train_quick.sh --cmd "$train_cmd" 4200 40000 \ + data/train_si284 data/lang_nosp exp/tri3b_ali_si284 exp/tri4b || exit 1; + +( + utils/mkgraph.sh data/lang_nosp_test_tgpr \ + exp/tri4b exp/tri4b/graph_nosp_tgpr || exit 1; + steps/decode_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri4b/graph_nosp_tgpr data/test_dev93 \ + exp/tri4b/decode_nosp_tgpr_dev93 || exit 1; + steps/decode_fmllr.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri4b/graph_nosp_tgpr data/test_eval92 \ + exp/tri4b/decode_nosp_tgpr_eval92 || exit 1; + + utils/mkgraph.sh data/lang_nosp_test_bd_tgpr \ + exp/tri4b exp/tri4b/graph_nosp_bd_tgpr || exit 1; + steps/decode_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri4b/graph_nosp_bd_tgpr data/test_dev93 \ + exp/tri4b/decode_nosp_bd_tgpr_dev93 || exit 1; + steps/decode_fmllr.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri4b/graph_nosp_bd_tgpr data/test_eval92 \ + exp/tri4b/decode_nosp_bd_tgpr_eval92 || exit 1; +) & + +# Silprob for normal lexicon. 
+steps/get_prons.sh --cmd "$train_cmd" \ + data/train_si284 data/lang_nosp exp/tri4b || exit 1; +utils/dict_dir_add_pronprobs.sh --max-normalize true \ + data/local/dict_nosp \ + exp/tri4b/pron_counts_nowb.txt exp/tri4b/sil_counts_nowb.txt \ + exp/tri4b/pron_bigram_counts_nowb.txt data/local/dict || exit 1 + +utils/prepare_lang.sh data/local/dict \ + "" data/local/lang_tmp data/lang || exit 1; + +for lm_suffix in bg bg_5k tg tg_5k tgpr tgpr_5k; do + mkdir -p data/lang_test_${lm_suffix} + cp -r data/lang/* data/lang_test_${lm_suffix}/ || exit 1; + rm -rf data/lang_test_${lm_suffix}/tmp + cp data/lang_nosp_test_${lm_suffix}/G.* data/lang_test_${lm_suffix}/ +done + +# Silprob for larger lexicon. +utils/dict_dir_add_pronprobs.sh --max-normalize true \ + data/local/dict_nosp_larger \ + exp/tri4b/pron_counts_nowb.txt exp/tri4b/sil_counts_nowb.txt \ + exp/tri4b/pron_bigram_counts_nowb.txt data/local/dict_larger || exit 1 + +utils/prepare_lang.sh data/local/dict_larger \ + "" data/local/lang_tmp_larger data/lang_bd || exit 1; + +for lm_suffix in tgpr tgconst tg fgpr fgconst fg; do + mkdir -p data/lang_test_bd_${lm_suffix} + cp -r data/lang_bd/* data/lang_test_bd_${lm_suffix}/ || exit 1; + rm -rf data/lang_test_bd_${lm_suffix}/tmp + cp data/lang_nosp_test_bd_${lm_suffix}/G.* data/lang_test_bd_${lm_suffix}/ +done + +( + utils/mkgraph.sh data/lang_test_tgpr exp/tri4b exp/tri4b/graph_tgpr || exit 1; + steps/decode_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri4b/graph_tgpr data/test_dev93 exp/tri4b/decode_tgpr_dev93 || exit 1; + steps/decode_fmllr.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri4b/graph_tgpr data/test_eval92 exp/tri4b/decode_tgpr_eval92 || exit 1; + + utils/mkgraph.sh data/lang_test_bd_tgpr \ + exp/tri4b exp/tri4b/graph_bd_tgpr || exit 1; + steps/decode_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri4b/graph_bd_tgpr data/test_dev93 \ + exp/tri4b/decode_bd_tgpr_dev93 || exit 1; + steps/decode_fmllr.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri4b/graph_bd_tgpr 
data/test_eval92 \ + exp/tri4b/decode_bd_tgpr_eval92 || exit 1; +) & + + +# Train and test MMI, and boosted MMI, on tri4b (LDA+MLLT+SAT on +# all the data). Use 30 jobs. +steps/align_fmllr.sh --nj 30 --cmd "$train_cmd" \ + data/train_si284 data/lang_nosp exp/tri4b exp/tri4b_ali_si284 || exit 1; + +# These demonstrate how to build a sytem usable for online-decoding with the nnet2 setup. +# (see local/run_nnet2.sh for other, non-online nnet2 setups). +local/online/run_nnet2.sh +local/online/run_nnet2_baseline.sh +local/online/run_nnet2_discriminative.sh + +local/run_mmi_tri4b.sh + +#local/run_nnet2.sh + +## Segregated some SGMM builds into a separate file. +#local/run_sgmm.sh + +# You probably want to run the sgmm2 recipe as it's generally a bit better: +local/run_sgmm2.sh + +# We demonstrate MAP adaptation of GMMs to gender-dependent systems here. This also serves +# as a generic way to demonstrate MAP adaptation to different domains. +# local/run_gender_dep.sh + +# You probably want to run the hybrid recipe as it is complementary: +local/nnet/run_dnn.sh + +# The following demonstrate how to re-segment long audios. +# local/run_segmentation.sh + +# The next two commands show how to train a bottleneck network based on the nnet2 setup, +# and build an SGMM system on top of it. +#local/run_bnf.sh +#local/run_bnf_sgmm.sh + + +# You probably want to try KL-HMM +#local/run_kl_hmm.sh + +# Getting results [see RESULTS file] +# for x in exp/*/decode*; do [ -d $x ] && grep WER $x/wer_* | utils/best_wer.sh; done + + +# KWS setup. 
We leave it commented out by default + +# $duration is the length of the search collection, in seconds +#duration=`feat-to-len scp:data/test_eval92/feats.scp ark,t:- | awk '{x+=$2} END{print x/100;}'` +#local/generate_example_kws.sh data/test_eval92/ data/kws/ +#local/kws_data_prep.sh data/lang_test_bd_tgpr/ data/test_eval92/ data/kws/ +# +#steps/make_index.sh --cmd "$decode_cmd" --acwt 0.1 \ +# data/kws/ data/lang_test_bd_tgpr/ \ +# exp/tri4b/decode_bd_tgpr_eval92/ \ +# exp/tri4b/decode_bd_tgpr_eval92/kws +# +#steps/search_index.sh --cmd "$decode_cmd" \ +# data/kws \ +# exp/tri4b/decode_bd_tgpr_eval92/kws +# +# If you want to provide the start time for each utterance, you can use the --segments +# option. In WSJ each file is an utterance, so we don't have to set the start time. +#cat exp/tri4b/decode_bd_tgpr_eval92/kws/result.* | \ +# utils/write_kwslist.pl --flen=0.01 --duration=$duration \ +# --normalize=true --map-utter=data/kws/utter_map \ +# - exp/tri4b/decode_bd_tgpr_eval92/kws/kwslist.xml + +# # forward-backward decoding example [way to speed up decoding by decoding forward +# # and backward in time] +# local/run_fwdbwd.sh diff --git a/egs/wsj_noisy/s5/steps b/egs/wsj_noisy/s5/steps new file mode 120000 index 00000000000..6e99bf5b5ad --- /dev/null +++ b/egs/wsj_noisy/s5/steps @@ -0,0 +1 @@ +../../wsj/s5/steps \ No newline at end of file diff --git a/egs/wsj_noisy/s5/utils b/egs/wsj_noisy/s5/utils new file mode 120000 index 00000000000..b240885218f --- /dev/null +++ b/egs/wsj_noisy/s5/utils @@ -0,0 +1 @@ +../../wsj/s5/utils \ No newline at end of file diff --git a/src/Makefile b/src/Makefile index 57a4b98e0aa..7f683f26fcc 100644 --- a/src/Makefile +++ b/src/Makefile @@ -6,16 +6,15 @@ SHELL := /bin/bash SUBDIRS = base matrix util feat tree thread gmm transform sgmm \ - fstext hmm lm decoder lat kws cudamatrix nnet \ + fstext hmm lm decoder lat kws cudamatrix nnet segmenter \ bin fstbin gmmbin fgmmbin sgmmbin featbin \ nnetbin latbin sgmm2 sgmm2bin nnet2 nnet3 
chain nnet3bin nnet2bin kwsbin \ - ivector ivectorbin online2 online2bin lmbin chainbin + ivector ivectorbin online2 online2bin lmbin chainbin segmenterbin MEMTESTDIRS = base matrix util feat tree thread gmm transform sgmm \ - fstext hmm lm decoder lat nnet kws chain \ - bin fstbin gmmbin fgmmbin sgmmbin featbin \ + fstext hmm lm decoder lat nnet kws chain segmenter \ nnetbin latbin sgmm2 nnet2 nnet3 nnet2bin nnet3bin sgmm2bin kwsbin \ - ivector ivectorbin online2 online2bin lmbin + ivector ivectorbin online2 online2bin lmbin chainbin segmenterbin CUDAMEMTESTDIR = cudamatrix @@ -145,9 +144,9 @@ $(EXT_SUBDIRS) : mklibdir # this is necessary for correct parallel compilation #1)The tools depend on all the libraries -bin fstbin gmmbin fgmmbin sgmmbin sgmm2bin featbin nnetbin nnet2bin nnet3bin chainbin latbin ivectorbin lmbin kwsbin online2bin: \ +bin fstbin gmmbin fgmmbin sgmmbin sgmm2bin featbin nnetbin nnet2bin nnet3bin chainbin latbin ivectorbin lmbin kwsbin online2bin segmenterbin: \ base matrix util feat tree optimization thread gmm transform sgmm sgmm2 fstext hmm \ - lm decoder lat cudamatrix nnet nnet2 nnet3 ivector + lm decoder lat cudamatrix nnet nnet2 nnet3 ivector segmenter #2)The libraries have inter-dependencies base: @@ -171,7 +170,8 @@ nnet: base util matrix cudamatrix nnet2: base util matrix thread lat gmm hmm tree transform cudamatrix nnet3: base util matrix thread lat gmm hmm tree transform cudamatrix chain chain: lat hmm tree fstext matrix cudamatrix util base -ivector: base util matrix thread transform tree gmm +ivector: base util matrix thread transform tree gmm +segmenter: base matrix util gmm thread #3)Dependencies for optional parts of Kaldi onlinebin: base matrix util feat tree optimization gmm transform sgmm sgmm2 fstext hmm lm decoder lat cudamatrix nnet nnet2 online thread # python-kaldi-decoding: base matrix util feat tree optimization thread gmm transform sgmm sgmm2 fstext hmm decoder lat online diff --git a/src/base/kaldi-extra-types.h 
b/src/base/kaldi-extra-types.h new file mode 100644 index 00000000000..239d696c5c9 --- /dev/null +++ b/src/base/kaldi-extra-types.h @@ -0,0 +1,44 @@ +// base/kaldi-extra-types.h + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_KALDI_EXTRA_TYPES_H_ +#define KALDI_BASE_KALDI_EXTRA_TYPES_H_ 1 + +#include "base/kaldi-common.h" + +namespace kaldi { + +struct UtteranceSegment { + std::string reco_id; + BaseFloat start_time; + BaseFloat end_time; + std::string channel_id; + + UtteranceSegment() : start_time(-1), end_time(-1), channel_id("-1") { } + + void Reset() { + start_time = end_time = -1; + reco_id = ""; + channel_id = "-1"; + } +}; + +} + +#endif // KALDI_BASE_KALDI_EXTRA_TYPES_H_ diff --git a/src/bin/Makefile b/src/bin/Makefile index ac175e42e0e..3619b8b7ed9 100644 --- a/src/bin/Makefile +++ b/src/bin/Makefile @@ -22,7 +22,10 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \ matrix-logprob matrix-sum latgen-tracking-mapped \ build-pfile-from-ali get-post-on-ali tree-info am-info \ vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \ - transform-vec align-text matrix-dim + loglikes-to-post copy-post-mapped weight-pdf-post \ + split-speakers-on-diarization-assigments vector-apply-log \ + matrix-scale matrix-sum-cols extract-int-vector-segments \ + 
matrix-add-offset OBJFILES = diff --git a/src/bin/copy-matrix.cc b/src/bin/copy-matrix.cc index d7b8181c64c..70419e534d7 100644 --- a/src/bin/copy-matrix.cc +++ b/src/bin/copy-matrix.cc @@ -38,14 +38,25 @@ int main(int argc, char *argv[]) { "See also: copy-feats\n"; bool binary = true; + bool apply_log = false; + bool apply_exp = false; + BaseFloat apply_power = 1.0; BaseFloat scale = 1.0; + ParseOptions po(usage); po.Register("binary", &binary, "Write in binary mode (only relevant if output is a wxfilename)"); po.Register("scale", &scale, "This option can be used to scale the matrices being copied."); - + po.Register("apply-log", &apply_log, + "This option can be used to apply log on the matrices. " + "Must be avoided if matrix has negative quantities."); + po.Register("apply-exp", &apply_exp, + "This option can be used to apply exp on the matrices"); + po.Register("apply-power", &apply_power, + "This option can be used to apply a power on the matrices"); + po.Read(argc, argv); if (po.NumArgs() != 2) { @@ -53,6 +64,8 @@ int main(int argc, char *argv[]) { exit(1); } + if (apply_log && apply_exp) + KALDI_ERR << "Only one of apply-log and apply-exp can be given"; std::string matrix_in_fn = po.GetArg(1), matrix_out_fn = po.GetArg(2); @@ -73,6 +86,8 @@ int main(int argc, char *argv[]) { Matrix mat; ReadKaldiObject(matrix_in_fn, &mat); if (scale != 1.0) mat.Scale(scale); + if (apply_log) mat.ApplyLog(); + if (apply_exp) mat.ApplyExp(); Output ko(matrix_out_fn, binary); mat.Write(ko.Stream(), binary); KALDI_LOG << "Copied matrix to " << matrix_out_fn; @@ -82,9 +97,12 @@ int main(int argc, char *argv[]) { BaseFloatMatrixWriter writer(matrix_out_fn); SequentialBaseFloatMatrixReader reader(matrix_in_fn); for (; !reader.Done(); reader.Next(), num_done++) { - if (scale != 1.0) { + if (scale != 1.0 || apply_log || apply_exp || apply_power != 1.0) { Matrix mat(reader.Value()); - mat.Scale(scale); + if (scale != 1.0) mat.Scale(scale); + if (apply_log) mat.ApplyLog(); + if 
(apply_exp) mat.ApplyExp(); + if (apply_power != 1.0) mat.ApplyPow(apply_power); writer.Write(reader.Key(), mat); } else { writer.Write(reader.Key(), reader.Value()); diff --git a/src/bin/copy-post-mapped.cc b/src/bin/copy-post-mapped.cc new file mode 100644 index 00000000000..0c8efa766ac --- /dev/null +++ b/src/bin/copy-post-mapped.cc @@ -0,0 +1,114 @@ +// bin/copy-post-mapped.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "hmm/posterior.h" +#include "util/kaldi-io.h" + +namespace kaldi { + void MapPosterior(std::vector id_map, + Posterior *post) { + for (size_t i = 0; i < post->size(); i++) { + for (size_t j = 0; j < (*post)[i].size(); j++) { + (*post)[i][j].first = id_map[(*post)[i][j].first]; + } + } + } +} + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + + const char *usage = + "Copy archives of posteriors mapping to new ids\n" + "(Also see copy-post, rand-prune-post and sum-post)\n" + "\n" + "Usage: copy-post-mapped \n"; + + BaseFloat scale = 1.0; + std::string id_map_rxfilename = ""; + + ParseOptions po(usage); + po.Register("scale", &scale, "Scale for posteriors"); + po.Register("id-map", &id_map_rxfilename, + "File name containing old->new id mapping (each line is: " + "old-integer-id new-integer-id)"); + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string post_rspecifier = po.GetArg(1), + post_wspecifier = po.GetArg(2); + + kaldi::SequentialPosteriorReader posterior_reader(post_rspecifier); + kaldi::PosteriorWriter posterior_writer(post_wspecifier); + + std::vector id_map; + if (id_map_rxfilename != "") { // read id map. 
std::vector<std::vector<int32> > vec;
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Extract integer vectors corresponding to segments from whole vectors\n" + "Usage: extract-int-vector-segments [options...] \n" + " e.g.: extract-int-vector-segments ark:file_ali.ark ark,t:data/train/segments ark:utt_ali.ark"; + + // construct all the global objects + ParseOptions po(usage); + + BaseFloat min_segment_length = 0.1, // Minimum segment length in seconds. + max_overshoot = 0.0; // max time by which last segment can overshoot + int32 frame_shift = 10; + int32 frame_length = 25; + bool snip_edges = true; + + // Register the options + po.Register("min-segment-length", &min_segment_length, + "Minimum segment length in seconds (reject shorter segments)"); + po.Register("frame-length", &frame_length, "Frame length in milliseconds"); + po.Register("frame-shift", &frame_shift, "Frame shift in milliseconds"); + po.Register("max-overshoot", &max_overshoot, + "End segments overshooting by less (in seconds) are truncated," + " else rejected."); + po.Register("snip-edges", &snip_edges, + "If true, n_frames frames will be snipped from the beginning of each " + "extracted feature matrix, " + "where n_frames = ceil((frame_length - frame_shift) / frame_shift), " + "except for the segments at the beginning of a file, where " + "the snipping is done from the end. 
" + "This ensures that only the feature vectors that " + "completely fit in the segment are extracted. " + "This makes the extracted segment lengths match the lengths of the " + "features that have been extracted from already segmented audio."); + + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string ali_rspecifier = po.GetArg(1); // get script file/vector archive + std::string segments_rspecifier = po.GetArg(2);// get segment file + std::string ali_wspecifier = po.GetArg(3); // get written archive name + + Int32VectorWriter alignment_writer(ali_wspecifier); + SequentialUtteranceSegmentReader segments_reader(segments_rspecifier); + RandomAccessInt32VectorReader alignment_reader(ali_rspecifier); + + + int32 num_done = 0, num_err = 0; + + int32 snip_length = 0; + if (snip_edges) { + snip_length = static_cast(ceil( + 1.0 * (frame_length - frame_shift) / frame_shift)); + } + + for (; !segments_reader.Done(); segments_reader.Next()) { + const std::string &seg_id = segments_reader.Key(); + const UtteranceSegment &segment = segments_reader.Value(); + + if (!alignment_reader.HasKey(segment.reco_id)) { + KALDI_WARN << "Did not find vector for utterance " << segment.reco_id + << ", skipping segment " << segment.reco_id; + num_err++; + continue; + } + const std::vector &ali = alignment_reader.Value(segment.reco_id); + + // total number of samples present in alignment + int32 num_samp = ali.size(); + // Convert start & end times of the segment to corresponding sample number + int32 start_samp = static_cast(( + (segment.start_time * 1000.0 / frame_shift))); + int32 end_samp = static_cast((segment.end_time * 1000.0 / frame_shift + 0.0495)); + + if (snip_edges) { + // snip the edge at the end of the segment (usually 2 frames), + end_samp -= snip_length; + } + + /* start sample must be less than total number of samples + * otherwise skip the segment + */ + if (start_samp < 0 || start_samp >= num_samp) { + KALDI_WARN << "Start 
sample out of range " << start_samp << " [length:] " + << num_samp << ", skipping segment " << seg_id; + num_err++; + continue; + } + + /* end sample must be less than total number samples + * otherwise skip the segment + */ + if (end_samp > num_samp) { + if (end_samp > + num_samp + static_cast(max_overshoot / frame_shift)) { + KALDI_WARN << "End sample too far out of range " << end_samp + << " [overshooted length:] " + << num_samp + static_cast(max_overshoot / frame_shift) + << ", skipping segment " << seg_id; + num_err++; + continue; + } + end_samp = num_samp; // for small differences, just truncate. + } + + /* check whether the segment size is less than minimum segment length(default 0.1 sec) + * if yes, skip the segment + */ + if (end_samp + <= start_samp + + static_cast(round( + (min_segment_length * 1000.0 / frame_shift)))) { + KALDI_WARN<< "Segment " << seg_id << " too short, skipping it."; + num_err++; + continue; + } + + std::vector seg_ali(ali.begin() + start_samp, + ali.begin() + end_samp); + + alignment_writer.Write(seg_id, seg_ali); + num_done++; + } + KALDI_LOG << "Successfully processed " << num_done << " segments; failed " + << num_err << " segments."; + /* prints number of segments processed */ + if (num_done == 0) return -1; + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/bin/loglikes-to-class.cc b/src/bin/loglikes-to-class.cc new file mode 100644 index 00000000000..474c22aa769 --- /dev/null +++ b/src/bin/loglikes-to-class.cc @@ -0,0 +1,152 @@ +// bin/loglike-to-pred.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "gmm/am-diag-gmm.h" +#include "hmm/transition-model.h" +#include "hmm/hmm-utils.h" +#include "hmm/posterior.h" + +/* For each frame, predict its class given the log-likelihoods under + * different class models based on maximum-likelihood decoding */ + +int main(int argc, char *argv[]) { + using namespace kaldi; + typedef kaldi::int32 int32; + try { + const char *usage = + "Convert a set of vectors of log-likelihoods (e.g. from gmm-global-get-frame-likes) to class predictions using ML decoding\n" + "Usage: loglikes-to-pred [options] [ ... 
] \n" + "e.g.:\n" + " loglikes-to-pred ark:silence_likes.ark ark:speech_likes.ark ark:vad.ark\n"; + + std::string weights_wspecifier; + std::string post_wspecifier; + + ParseOptions po(usage); + po.Register("weights", &weights_wspecifier, "Write posterior probability of each class."); + po.Register("post", &post_wspecifier, "Write posteriors"); + + po.Read(argc, argv); + + if (po.NumArgs() < 2) { + po.PrintUsage(); + exit(1); + } + + std::string loglikes_vec_rspecifier1 = po.GetArg(1); + std::string prediction_wspecifier = po.GetArg(po.NumArgs()); + + int32 num_done = 0; + SequentialBaseFloatVectorReader loglikes_reader1(loglikes_vec_rspecifier1); + std::vector loglikes_readers(po.NumArgs()-2, + static_cast(NULL)); + BaseFloatVectorWriter prediction_writer(prediction_wspecifier); + BaseFloatVectorWriter weights_writer(weights_wspecifier); + PosteriorWriter post_writer(post_wspecifier); + + for (int32 i = 0; i < po.NumArgs()-2; i++) + loglikes_readers[i] = new RandomAccessBaseFloatVectorReader(po.GetArg(i+2)); + + std::vector class_loglikes(po.NumArgs()-1); + std::vector class_counts(po.NumArgs()-1); + + for (; !loglikes_reader1.Done(); loglikes_reader1.Next()) { + const Vector &loglikes1 = loglikes_reader1.Value(); + std::string key = loglikes_reader1.Key(); + std::vector*> loglikes(po.NumArgs()-2, + static_cast*>(NULL)); + for (int32 i = 0; i < po.NumArgs()-2; i++) { + if (!loglikes_readers[i]->HasKey(key)) { + KALDI_ERR << "Key " << key << " not found in " + << po.GetArg(i+2); + } + loglikes[i] = new Vector(loglikes_readers[i]->Value(key)); + } + + Vector prediction(loglikes1.Dim()); + Vector weights; + Posterior post(loglikes1.Dim()); + + if (weights_wspecifier != "") + weights.Resize(loglikes1.Dim()); + + for (int32 j = 0; j < loglikes1.Dim(); j++) { + BaseFloat max_like = loglikes1(j); + Vector this_log_likes(po.NumArgs()-1); + this_log_likes(0) = max_like; + for (int32 i = 0; i < po.NumArgs()-2; i++) { + if (loglikes[i] == NULL) continue; + 
this_log_likes(i+1) = (*loglikes[i])(j); + KALDI_VLOG(1) << loglikes1(j) << " " << (*loglikes[i])(j); + if ((*(loglikes[i]))(j) > max_like) { + prediction(j) = i+1; + max_like = (*(loglikes[i]))(j); + } + } + if (weights_wspecifier != "") { + weights(j) = Exp(max_like - this_log_likes.LogSumExp()); + KALDI_ASSERT(weights(j) <= 1.0); + } + class_loglikes[prediction(j)] += max_like; + class_counts[prediction(j)]++; + + if (post_wspecifier != "") { + post[j].push_back(std::make_pair(0, Exp(this_log_likes(0) - this_log_likes.LogSumExp()))); + for (int32 i = 0; i < po.NumArgs()-2; i++) { + post[j].push_back(std::make_pair(i+1, Exp(this_log_likes(i+1) - this_log_likes.LogSumExp()))); + } + } + } + + for (int32 i = 0; i < po.NumArgs()-2; i++) { + delete loglikes[i]; + } + + prediction_writer.Write(key, prediction); + if (weights_wspecifier != "") + weights_writer.Write(key, weights); + + if (post_wspecifier != "") + post_writer.Write(key, post); + + num_done++; + } + + KALDI_LOG << "Average log-likelihood of frames of class " << 0 + << " is " << class_loglikes[0] / class_counts[0] + << " over " << class_counts[0] << " frames."; + + for (int32 i = 0; i < po.NumArgs()-2; i++) { + delete loglikes_readers[i]; + KALDI_LOG << "Average log-likelihood of frames of class " << i+1 + << " is " << class_loglikes[i+1] / class_counts[i+1] + << " over " << class_counts[i+1] << " frames."; + } + + KALDI_LOG << "Converted " << num_done << " sets of log-likes vectors to predictions."; + return (num_done != 0 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/bin/loglikes-to-post.cc b/src/bin/loglikes-to-post.cc new file mode 100644 index 00000000000..3c4a195ac90 --- /dev/null +++ b/src/bin/loglikes-to-post.cc @@ -0,0 +1,89 @@ +// bin/loglike-to-post.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "gmm/am-diag-gmm.h" +#include "hmm/transition-model.h" +#include "hmm/hmm-utils.h" +#include "hmm/posterior.h" + +/* Convert a matrix of log-likelihoods to posteriors */ + +int main(int argc, char *argv[]) { + using namespace kaldi; + typedef kaldi::int32 int32; + try { + const char *usage = + "Convert a matrix of log-likelihoods (e.g. from gmm-compute-loglikes) to posteriors\n" + "Usage: loglikes-to-post [options] \n" + "e.g.:\n" + " gmm-compute-loglikes [args] | loglike-to-post ark:- ark:1.post\n"; + + ParseOptions po(usage); + + BaseFloat min_post = 0.01; + bool random_prune = true; // preserve expectations. + + po.Register("min-post", &min_post, "Minimum posterior we will output (smaller " + "ones are pruned). 
Also see --random-prune"); + po.Register("random-prune", &random_prune, "If true, prune posteriors with a " + "randomized method that preserves expectations."); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string loglikes_rspecifier = po.GetArg(1); + std::string posteriors_wspecifier = po.GetArg(2); + + int32 num_done = 0; + SequentialBaseFloatMatrixReader loglikes_reader(loglikes_rspecifier); + PosteriorWriter posterior_writer(posteriors_wspecifier); + + for (; !loglikes_reader.Done(); loglikes_reader.Next()) { + num_done++; + const Matrix &loglikes = loglikes_reader.Value(); + // Posterior is vector > > + Posterior post(loglikes.NumRows()); + for (int32 i = 0; i < loglikes.NumRows(); i++) { + Vector row(SubVector(loglikes, i)); + row.ApplySoftMax(); + for (int32 j = 0; j < row.Dim(); j++) { + BaseFloat p = row(j); + if (p >= min_post) { + post[i].push_back(std::make_pair(j, p)); + } else if (random_prune && (p / min_post) >= RandUniform()) { + post[i].push_back(std::make_pair(j, min_post)); + } + } + } + posterior_writer.Write(loglikes_reader.Key(), post); + } + KALDI_LOG << "Converted " << num_done << " log-likes matrices to posteriors."; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/bin/matrix-add-offset.cc b/src/bin/matrix-add-offset.cc new file mode 100644 index 00000000000..81651f7503a --- /dev/null +++ b/src/bin/matrix-add-offset.cc @@ -0,0 +1,83 @@ +// bin/matrix-add-offset.cc + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Add an offset vector to the rows of matrices in a table.\n" + "\n" + "Usage: matrix-add-offset [options] \n" + "e.g.: matrix-add-offset log_post.mat neg_priors.vec log_like.mat\n" + "See also: matrix-sum-rows, matrix-sum, vector-sum\n"; + + + ParseOptions po(usage); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + std::string rspecifier = po.GetArg(1); + std::string vector_rxfilename = po.GetArg(2); + std::string wspecifier = po.GetArg(3); + + SequentialBaseFloatMatrixReader mat_reader(rspecifier); + BaseFloatMatrixWriter mat_writer(wspecifier); + + int32 num_done = 0; + + Vector vec; + { + bool binary_in; + Input ki(vector_rxfilename, &binary_in); + vec.Read(ki.Stream(), binary_in); + } + + for (; !mat_reader.Done(); mat_reader.Next()) { + std::string key = mat_reader.Key(); + Matrix mat(mat_reader.Value()); + if (vec.Dim() != mat.NumCols()) { + KALDI_ERR << "Mismatch in vector dimension and number of columns in matrix; " + << vec.Dim() << " vs " << mat.NumCols(); + } + mat.AddVecToRows(1.0, vec); + mat_writer.Write(key, mat); + num_done++; + } + + KALDI_LOG << "Added offset to " << num_done << " matrices."; + + return (num_done != 0 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + + diff --git a/src/bin/matrix-scale.cc b/src/bin/matrix-scale.cc new file mode 100644 index 00000000000..e58fdbcce34 --- /dev/null +++ b/src/bin/matrix-scale.cc @@ -0,0 +1,64 @@ +// bin/matrix-scale.cc + +// Copyright 2009-2011 Microsoft Corporation +// 2014 Johns Hopkins University (author: Daniel Povey) +// 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Scale a set of matrices in a Table\n" + "Usage: matrix-scale [options] \n"; + + ParseOptions po(usage); + BaseFloat scale = 1.0; + + po.Register("scale", &scale, "Scaling factor for matrices"); + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string rspecifier = po.GetArg(1); + std::string wspecifier = po.GetArg(2); + + BaseFloatMatrixWriter mat_writer(wspecifier); + + SequentialBaseFloatMatrixReader mat_reader(rspecifier); + for (; !mat_reader.Done(); mat_reader.Next()) { + Matrix mat(mat_reader.Value()); + mat.Scale(scale); + mat_writer.Write(mat_reader.Key(), mat); + } + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/bin/matrix-sum-cols.cc b/src/bin/matrix-sum-cols.cc new file mode 100644 index 00000000000..2a648826039 --- /dev/null +++ b/src/bin/matrix-sum-cols.cc @@ -0,0 +1,96 @@ +// bin/matrix-sum-cols.cc + +// Copyright 2012 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Sum the rows of an input table of matrices and output the corresponding\n" + "table of vectors\n" + "\n" + "Usage: matrix-sum-rows [options] \n" + "e.g.: matrix-sum-rows ark:- ark:- | vector-sum ark:- sum.vec\n" + "See also: matrix-sum, vector-sum\n"; + + + ParseOptions po(usage); + + bool log_sum_exp = false; + + po.Register("log-sum-exp", &log_sum_exp, "Sum columns considering the " + "numbers to be stored in log"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + std::string rspecifier = po.GetArg(1); + std::string wspecifier = po.GetArg(2); + + SequentialBaseFloatMatrixReader mat_reader(rspecifier); + BaseFloatVectorWriter vec_writer(wspecifier); + + int32 num_done = 0; + int64 num_rows_done = 0; + + for (; !mat_reader.Done(); mat_reader.Next()) { + std::string key = mat_reader.Key(); + if (!log_sum_exp) { + Matrix mat(mat_reader.Value(), kTrans); + // Do the summation in double, to minimize roundoff. + Vector vec(mat.NumCols()); + vec.AddRowSumMat(1.0, mat, 0.0); + Vector float_vec(vec); + vec_writer.Write(key, float_vec); + num_rows_done += mat.NumCols(); + } else { + Matrix mat(mat_reader.Value()); + Vector vec(mat.NumRows()); + for (size_t t = 0; t < mat.NumRows(); t++) { + SubVector mat_t(mat, t); + vec(t) = mat_t.LogSumExp(); + } + Vector float_vec(vec); + vec_writer.Write(key, float_vec); + num_rows_done += mat.NumRows(); + } + + num_done++; + } + + KALDI_LOG << "Summed columns for " << num_rows_done << " in " + << num_done << " matrices, "; + + return (num_done != 0 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + + diff --git a/src/bin/matrix-sum.cc b/src/bin/matrix-sum.cc index 6c1d2ad9f12..f5e776627bb 100644 --- a/src/bin/matrix-sum.cc +++ b/src/bin/matrix-sum.cc @@ -29,7 +29,8 @@ namespace kaldi { // of the first two input archives. int32 TypeOneUsage(const ParseOptions &po, BaseFloat scale1, - BaseFloat scale2) { + BaseFloat scale2, + bool log_add_exp = false) { int32 num_args = po.NumArgs(); std::string matrix_in_fn1 = po.GetArg(1), matrix_out_fn = po.GetArg(num_args); @@ -69,7 +70,10 @@ int32 TypeOneUsage(const ParseOptions &po, if (SameDim(matrix2, matrix_out)) { BaseFloat scale = (i == 0 ? scale2 : 1.0); // note: i == 0 corresponds to the 2nd input archive. - matrix_out.AddMat(scale, matrix2, kNoTrans); + if (log_add_exp) + matrix_out.LogAddExpMat(scale, matrix2, kNoTrans); + else + matrix_out.AddMat(scale, matrix2, kNoTrans); } else { KALDI_WARN << "Dimension mismatch for utterance " << key << " : " << matrix2.NumRows() << " by " @@ -104,7 +108,8 @@ int32 TypeOneUsage(const ParseOptions &po, } int32 TypeTwoUsage(const ParseOptions &po, - bool binary) { + bool binary, + bool log_add_exp = false) { KALDI_ASSERT(po.NumArgs() == 2); KALDI_ASSERT(ClassifyRspecifier(po.GetArg(1), NULL, NULL) != kNoRspecifier && "matrix-sum: first argument must be an rspecifier"); @@ -133,7 +138,10 @@ int32 TypeTwoUsage(const ParseOptions &po, num_err++; } else { Matrix dmat(mat); - sum.AddMat(1.0, dmat, kNoTrans); + if (log_add_exp) + sum.LogAddExpMat(1.0, dmat, kNoTrans); + else + sum.AddMat(1.0, dmat, kNoTrans); num_done++; } } @@ -209,6 +217,7 @@ int main(int argc, char *argv[]) { BaseFloat scale1 = 1.0, scale2 = 1.0; bool binary = true; + bool log_add_exp = false; ParseOptions po(usage); @@ -216,6 +225,8 @@ int main(int argc, char *argv[]) { "(only for type one usage)"); po.Register("scale2", &scale2, "Scale applied to second matrix " "(only for type one usage)"); + 
po.Register("log-add-exp", &log_add_exp, "Treat the input matrices to be " + "in log and also output in log"); po.Register("binary", &binary, "If true, write output as binary (only " "relevant for usage types two or three"); @@ -239,6 +250,9 @@ int main(int argc, char *argv[]) { ClassifyWspecifier(po.GetArg(N), NULL, NULL, NULL) == kNoWspecifier) { KALDI_ASSERT(scale1 == 1.0 && scale2 == 1.0); // summing flat files. + if (log_add_exp) + KALDI_ERR << "log-add-exp is not supported with type 3 usage"; + exit_status = TypeThreeUsage(po, binary); } else { po.PrintUsage(); diff --git a/src/bin/post-to-tacc.cc b/src/bin/post-to-tacc.cc index 6456195e998..9133f3e3c66 100644 --- a/src/bin/post-to-tacc.cc +++ b/src/bin/post-to-tacc.cc @@ -39,33 +39,49 @@ int main(int argc, char *argv[]) { bool binary = true; bool per_pdf = false; + int32 num_targets = -1; + ParseOptions po(usage); po.Register("binary", &binary, "Write output in binary mode."); - po.Register("per-pdf", &per_pdf, "if ture, accumulate counts per pdf-id" + po.Register("per-pdf", &per_pdf, "if true, accumulate counts per pdf-id" " rather than transition-id. 
(default: false)"); + po.Register("num-targets", &num_targets, "number of targets; useful when " + "there is no transition model."); po.Read(argc, argv); - if (po.NumArgs() != 3) { + if ( (po.NumArgs() != 3) && (po.NumArgs() != 2) ) { po.PrintUsage(); exit(1); } - std::string model_rxfilename = po.GetArg(1), - post_rspecifier = po.GetArg(2), - accs_wxfilename = po.GetArg(3); + int32 N = po.NumArgs(); + + std::string model_rxfilename, + post_rspecifier = po.GetArg(N-1), + accs_wxfilename = po.GetArg(N); + + + if (N == 3) + model_rxfilename = po.GetArg(1); + else + KALDI_ASSERT(num_targets > 0 && !per_pdf); kaldi::SequentialPosteriorReader posterior_reader(post_rspecifier); - int32 num_transition_ids; + int32 num_transition_ids = 0; + + TransitionModel *trans_model = NULL; + if (N == 3) { bool binary_in; Input ki(model_rxfilename, &binary_in); - TransitionModel trans_model; - trans_model.Read(ki.Stream(), binary_in); - num_transition_ids = trans_model.NumTransitionIds(); + trans_model = new TransitionModel; + trans_model->Read(ki.Stream(), binary_in); + num_transition_ids = trans_model->NumTransitionIds(); + } - Vector transition_accs(num_transition_ids+1); // +1 because they're - // 1-based; position zero is empty. We'll write as float. + Vector accs(trans_model ? num_transition_ids+1 : num_targets); + // +1 because tids 1-based; position zero is empty. 
int32 num_done = 0; for (; !posterior_reader.Done(); posterior_reader.Next()) { @@ -73,12 +89,16 @@ int main(int argc, char *argv[]) { int32 num_frames = static_cast(posterior.size()); for (int32 i = 0; i < num_frames; i++) { for (int32 j = 0; j < static_cast(posterior[i].size()); j++) { - int32 tid = posterior[i][j].first; - if (tid <= 0 || tid > num_transition_ids) - KALDI_ERR << "Invalid transition-id " << tid + int32 id = posterior[i][j].first; + if (num_targets < 0 && (id <= 0 || id > num_transition_ids) ) + KALDI_ERR << "Invalid transition-id " << id + << " encountered for utterance " + << posterior_reader.Key(); + else if (num_targets >= 0 && (id < 0 || id > num_targets)) + KALDI_ERR << "Invalid target " << id << " encountered for utterance " << posterior_reader.Key(); - transition_accs(tid) += posterior[i][j].second; + accs(id) += posterior[i][j].second; } } num_done++; @@ -86,21 +106,21 @@ int main(int argc, char *argv[]) { if (per_pdf) { KALDI_LOG << "accumulate counts per pdf-id"; - int32 num_pdf_ids = trans_model.NumPdfs(); + int32 num_pdf_ids = trans_model->NumPdfs(); Vector pdf_accs(num_pdf_ids); for (int32 i = 1; i < num_transition_ids; i++) { - int32 pid = trans_model.TransitionIdToPdf(i); - pdf_accs(pid) += transition_accs(i); + int32 pid = trans_model->TransitionIdToPdf(i); + pdf_accs(pid) += accs(i); } Vector pdf_accs_float(pdf_accs); Output ko(accs_wxfilename, binary); pdf_accs_float.Write(ko.Stream(), binary); } else { - Vector transition_accs_float(transition_accs); + Vector accs_float(accs); Output ko(accs_wxfilename, binary); - transition_accs_float.Write(ko.Stream(), binary); + accs_float.Write(ko.Stream(), binary); } - KALDI_LOG << "Done computing transition stats over " + KALDI_LOG << "Done accumulating stats over " << num_done << " utterances; wrote stats to " << accs_wxfilename; return (num_done != 0 ? 
0 : 1); diff --git a/src/bin/split-speakers-on-diarization-assigments.cc b/src/bin/split-speakers-on-diarization-assigments.cc new file mode 100644 index 00000000000..2501e1d026d --- /dev/null +++ b/src/bin/split-speakers-on-diarization-assigments.cc @@ -0,0 +1,73 @@ +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "util/common-utils.h" + +using namespace kaldi; + +int main(int argc, char *argv[]) { + using namespace kaldi; + typedef kaldi::int32 int32; + try { + const char *usage = "Splits speakers using diarization assigments\n" + "Usage: split-speakers-on-diarization-assigments \n" + " e.g.: split-speakers-on-diarization-assigments ark,t:data/dev/utt2spk ark,t:exp/diarization_dev/diarization.txt ark,t:exp/diarization_dev/utt2spk\n" + "\n"; + + ParseOptions po(usage); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + std::string utt2spk_rspecifier = po.GetArg(1); + std::string diar_rspecifier = po.GetArg(2); + std::string utt2spk_wspecifier = po.GetArg(3); + + SequentialTokenReader utt2spk_reader(utt2spk_rspecifier); + RandomAccessInt32Reader diar_reader(diar_rspecifier); + TokenWriter utt2spk_writer(utt2spk_wspecifier); + + int32 num_done = 0, num_err = 0; + for (; !utt2spk_reader.Done(); 
utt2spk_reader.Next()) { + std::string utt = utt2spk_reader.Key(); + std::string spk = utt2spk_reader.Value(); + + if (!diar_reader.HasKey(utt)) { + KALDI_WARN << "No speaker assignment for utterance " << utt; + num_err++; + continue; + } else { + int32 spk_id = diar_reader.Value(utt); + std::ostringstream oss; + oss << spk << "-" << spk_id; + utt2spk_writer.Write(utt, oss.str()); + } + num_done++; + } + + KALDI_LOG << "Done splitting speaker for " << num_done << " utterances; " + << "failed for " << num_err; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/bin/vector-apply-log.cc b/src/bin/vector-apply-log.cc new file mode 100644 index 00000000000..41be6e4efe4 --- /dev/null +++ b/src/bin/vector-apply-log.cc @@ -0,0 +1,70 @@ +// bin/vector-apply-log.cc + +// Copyright 2009-2011 Microsoft Corporation +// 2014 Johns Hopkins University (author: Daniel Povey) +// 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Apply log on a set of vectors in a Table (useful for probabilities)\n" + "Usage: vector-apply-log [options] \n"; + + bool invert = false; + + ParseOptions po(usage); + + po.Register("invert", &invert, "Apply exp instead of log"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string rspecifier = po.GetArg(1); + std::string wspecifier = po.GetArg(2); + + BaseFloatVectorWriter vec_writer(wspecifier); + + SequentialBaseFloatVectorReader vec_reader(rspecifier); + for (; !vec_reader.Done(); vec_reader.Next()) { + Vector vec(vec_reader.Value()); + if (!invert) + vec.ApplyLog(); + else + vec.ApplyExp(); + vec_writer.Write(vec_reader.Key(), vec); + } + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + + diff --git a/src/bin/vector-scale.cc b/src/bin/vector-scale.cc index 60d4d3121d2..ea68ae31ad0 100644 --- a/src/bin/vector-scale.cc +++ b/src/bin/vector-scale.cc @@ -30,11 +30,14 @@ int main(int argc, char *argv[]) { const char *usage = "Scale a set of vectors in a Table (useful for speaker vectors and " "per-frame weights)\n" - "Usage: vector-scale [options] \n"; + "Usage: vector-scale [options] \n"; ParseOptions po(usage); BaseFloat scale = 1.0; + bool binary = false; + po.Register("binary", &binary, "If true, write output as binary " + "not relevant for archives"); po.Register("scale", &scale, "Scaling factor for vectors"); po.Read(argc, argv); @@ -43,17 +46,33 @@ int main(int argc, char *argv[]) { exit(1); } - std::string rspecifier = po.GetArg(1); - std::string wspecifier = po.GetArg(2); + std::string vector_in_fn = po.GetArg(1); + std::string vector_out_fn = po.GetArg(2); - BaseFloatVectorWriter vec_writer(wspecifier); - - SequentialBaseFloatVectorReader vec_reader(rspecifier); 
- for (; !vec_reader.Done(); vec_reader.Next()) { - Vector vec(vec_reader.Value()); + if (ClassifyWspecifier(vector_in_fn, NULL, NULL, NULL) != kNoWspecifier) { + if (ClassifyRspecifier(vector_in_fn, NULL, NULL) == kNoRspecifier) { + KALDI_ERR << "Cannot mix archives and regular files"; + } + BaseFloatVectorWriter vec_writer(vector_out_fn); + SequentialBaseFloatVectorReader vec_reader(vector_in_fn); + for (; !vec_reader.Done(); vec_reader.Next()) { + Vector vec(vec_reader.Value()); + vec.Scale(scale); + vec_writer.Write(vec_reader.Key(), vec); + } + } else { + if (ClassifyRspecifier(vector_in_fn, NULL, NULL) != kNoRspecifier) { + KALDI_ERR << "Cannot mix archives and regular files"; + } + bool binary_in; + Input ki(vector_in_fn, &binary_in); + Vector vec; + vec.Read(ki.Stream(), binary_in); vec.Scale(scale); - vec_writer.Write(vec_reader.Key(), vec); + Output ko(vector_out_fn, binary); + vec.Write(ko.Stream(), binary); } + return 0; } catch(const std::exception &e) { std::cerr << e.what(); diff --git a/src/bin/weight-pdf-post.cc b/src/bin/weight-pdf-post.cc new file mode 100644 index 00000000000..90f0b1e1643 --- /dev/null +++ b/src/bin/weight-pdf-post.cc @@ -0,0 +1,150 @@ +// bin/weight-pdf-post.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "gmm/am-diag-gmm.h" +#include "hmm/transition-model.h" +#include "hmm/hmm-utils.h" +#include "hmm/posterior.h" + +namespace kaldi { + +void WeightPdfPost(const ConstIntegerSet &pdf_set, + BaseFloat pdf_scale, + Posterior *post) { + for (size_t i = 0; i < post->size(); i++) { + std::vector > this_post; + this_post.reserve((*post)[i].size()); + for (size_t j = 0; j < (*post)[i].size(); j++) { + int32 pdf_id = (*post)[i][j].first; + BaseFloat weight = (*post)[i][j].second; + if (pdf_set.count(pdf_id) != 0) { // is a silence. + if (pdf_scale != 0.0) + this_post.push_back(std::make_pair(pdf_id, weight*pdf_scale)); + } else { + this_post.push_back(std::make_pair(pdf_id, weight)); + } + } + (*post)[i].swap(this_post); + } +} + +void WeightPdfPostDistributed(const ConstIntegerSet &pdf_set, + BaseFloat pdf_scale, + Posterior *post) { + for (size_t i = 0; i < post->size(); i++) { + std::vector > this_post; + this_post.reserve((*post)[i].size()); + BaseFloat sil_weight = 0.0, nonsil_weight = 0.0; + for (size_t j = 0; j < (*post)[i].size(); j++) { + int32 pdf_id = (*post)[i][j].first; + BaseFloat weight = (*post)[i][j].second; + if (pdf_set.count(pdf_id) != 0) sil_weight += weight; + else nonsil_weight += weight; + } + KALDI_ASSERT(sil_weight >= 0.0 && nonsil_weight >= 0.0); // This "distributed" + // weighting approach doesn't make sense if we have negative weights. 
+ if (sil_weight + nonsil_weight == 0.0) continue; + BaseFloat frame_scale = (sil_weight * pdf_scale + nonsil_weight) / + (sil_weight + nonsil_weight); + if (frame_scale != 0.0) { + for (size_t j = 0; j < (*post)[i].size(); j++) { + int32 pdf_id = (*post)[i][j].first; + BaseFloat weight = (*post)[i][j].second; + this_post.push_back(std::make_pair(pdf_id, weight * frame_scale)); + } + } + (*post)[i].swap(this_post); + } +} + +} + +int main(int argc, char *argv[]) { + using namespace kaldi; + typedef kaldi::int32 int32; + try { + const char *usage = + "Apply weight to specific pdfs or tids in posts\n" + "Usage: weight-pdf-post [options] " + " \n" + "e.g.:\n" + " weight-pdf-post 0.00001 0:2 ark:1.post ark:nosil.post\n"; + + ParseOptions po(usage); + + bool distribute = false; + + po.Register("distribute", &distribute, "If true, rather than weighting the " + "individual posteriors, apply the weighting to the whole frame: " + "i.e. on time t, scale all posterior entries by " + "p(sil)*silence-weight + p(non-sil)*1.0"); + + po.Read(argc, argv); + + if (po.NumArgs() != 4) { + po.PrintUsage(); + exit(1); + } + + std::string pdf_weight_str = po.GetArg(1), + pdfs_str = po.GetArg(2), + posteriors_rspecifier = po.GetArg(3), + posteriors_wspecifier = po.GetArg(4); + + BaseFloat pdf_weight = 0.0; + if (!ConvertStringToReal(pdf_weight_str, &pdf_weight)) + KALDI_ERR << "Invalid pdf-weight parameter: expected float, got \"" + << pdf_weight << '"'; + std::vector pdfs; + if (!SplitStringToIntegers(pdfs_str, ":", false, &pdfs)) + KALDI_ERR << "Invalid pdf string string " << pdfs_str; + if (pdfs.empty()) + KALDI_WARN <<"No pdf specified, this will have no effect"; + ConstIntegerSet pdf_set(pdfs); // faster lookup. 
+ + int32 num_posteriors = 0; + SequentialPosteriorReader posterior_reader(posteriors_rspecifier); + PosteriorWriter posterior_writer(posteriors_wspecifier); + + for (; !posterior_reader.Done(); posterior_reader.Next()) { + num_posteriors++; + // Posterior is vector > > + Posterior post = posterior_reader.Value(); + // Posterior is vector > > + if (distribute) + WeightPdfPostDistributed(pdf_set, + pdf_weight, &post); + else + WeightPdfPost(pdf_set, + pdf_weight, &post); + + posterior_writer.Write(posterior_reader.Key(), post); + } + KALDI_LOG << "Done " << num_posteriors << " posteriors."; + return (num_posteriors != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/bin/weight-post.cc b/src/bin/weight-post.cc index d536896eaaa..8198f8db746 100644 --- a/src/bin/weight-post.cc +++ b/src/bin/weight-post.cc @@ -34,7 +34,14 @@ int main(int argc, char *argv[]) { "\n" "Usage: weight-post \n"; + ParseOptions po(usage); + + int32 length_tolerance = 0; + po.Register("length-tolerance", &length_tolerance, + "Tolerance on difference in number of frames in posterior " + "and weights."); + po.Read(argc, argv); if (po.NumArgs() != 3) { @@ -61,18 +68,23 @@ int main(int argc, char *argv[]) { continue; } const Vector &weights = weights_reader.Value(key); - if (weights.Dim() != static_cast(post.size())) { + if (std::abs(weights.Dim() - static_cast(post.size())) > length_tolerance) { KALDI_WARN << "Weights for utterance " << key << " have wrong size, " << weights.Dim() << " vs. 
" << post.size(); num_err++; continue; } - for (size_t i = 0; i < post.size(); i++) { + int32 len = std::min(static_cast(post.size()), weights.Dim()); + for (size_t i; i < len; i++) { if (weights(i) == 0.0) post[i].clear(); for (size_t j = 0; j < post[i].size(); j++) post[i][j].second *= weights(i); } + for (int32 j = post.size() - 1; j >= len; j--) { + post.pop_back(); + } + KALDI_ASSERT(post.size() == len); post_writer.Write(key, post); num_done++; } diff --git a/src/cudamatrix/cu-kernels-ansi.h b/src/cudamatrix/cu-kernels-ansi.h index 804bea1a217..caa2239cffa 100644 --- a/src/cudamatrix/cu-kernels-ansi.h +++ b/src/cudamatrix/cu-kernels-ansi.h @@ -59,6 +59,7 @@ void cudaF_apply_exp(dim3 Gr, dim3 Bl, float* mat, MatrixDim d); void cudaF_apply_pow(dim3 Gr, dim3 Bl, float* mat, float power, MatrixDim d); void cudaF_apply_pow_abs(dim3 Gr, dim3 Bl, float* mat, float power, bool include_sign, MatrixDim d); void cudaF_apply_heaviside(dim3 Gr, dim3 Bl, float* mat, MatrixDim d); +void cudaF_apply_signum(dim3 Gr, dim3 Bl, float* mat, MatrixDim d); void cudaF_apply_floor(dim3 Gr, dim3 Bl, float* mat, float floor_val, MatrixDim d); void cudaF_copy_cols(dim3 Gr, dim3 Bl, float* dst, const float* src, const MatrixIndexT_cuda* reorder, MatrixDim dst_dim, int src_stride); void cudaF_add_cols(dim3 Gr, dim3 Bl, float* dst, const float* src, const MatrixIndexT_cuda* reorder, MatrixDim dst_dim, int src_stride); @@ -89,6 +90,7 @@ void cudaF_calc_pnorm_deriv(dim3 Gr, dim3 Bl, float *y, const float *x1, const f void cudaF_calc_group_max_deriv(dim3 Gr, dim3 Bl, float *y, const float *x1, const float *x2, MatrixDim d, int src_stride, int group_size); void cudaF_div_rows_vec(dim3 Gr, dim3 Bl, float *mat, const float *vec_div, MatrixDim d); void cudaF_add_mat(dim3 Gr, dim3 Bl, float alpha, const float *src, float *dst, MatrixDim d, int src_stride, int A_trans); +void cudaF_log_add_exp_mat(dim3 Gr, dim3 Bl, float alpha, const float *src, float *dst, MatrixDim d, int src_stride, int 
A_trans); void cudaF_add_mat_blocks(dim3 Gr, dim3 Bl, float alpha, const float *src, int32_cuda num_row_blocks, int32_cuda num_col_blocks, float *dst, MatrixDim d, int src_stride, int A_trans); void cudaF_add_mat_mat_div_mat(dim3 Gr, dim3 Bl, const float *A, const float *B, const float *C, float *dst, MatrixDim d, int stride_a, int stride_b, int stride_c); void cudaF_add_vec_to_cols(dim3 Gr, dim3 Bl, float alpha, const float *col, float beta, float *dst, MatrixDim d); @@ -198,6 +200,7 @@ void cudaD_apply_exp(dim3 Gr, dim3 Bl, double* mat, MatrixDim d); void cudaD_apply_pow(dim3 Gr, dim3 Bl, double* mat, double power, MatrixDim d); void cudaD_apply_pow_abs(dim3 Gr, dim3 Bl, double* mat, double power, bool include_sign, MatrixDim d); void cudaD_apply_heaviside(dim3 Gr, dim3 Bl, double* mat, MatrixDim d); +void cudaD_apply_signum(dim3 Gr, dim3 Bl, double* mat, MatrixDim d); void cudaD_apply_floor(dim3 Gr, dim3 Bl, double* mat, double floor_val, MatrixDim d); void cudaD_copy_cols(dim3 Gr, dim3 Bl, double* dst, const double* src, const MatrixIndexT_cuda* reorder, MatrixDim dst_dim, int src_stride); void cudaD_add_cols(dim3 Gr, dim3 Bl, double* dst, const double* src, const MatrixIndexT_cuda* reorder, MatrixDim dst_dim, int src_stride); @@ -228,6 +231,7 @@ void cudaD_calc_pnorm_deriv(dim3 Gr, dim3 Bl, double *y, const double *x1, const void cudaD_calc_group_max_deriv(dim3 Gr, dim3 Bl, double *y, const double *x1, const double *x2, MatrixDim d, int src_stride, int group_size); void cudaD_div_rows_vec(dim3 Gr, dim3 Bl, double *mat, const double *vec_div, MatrixDim d); void cudaD_add_mat(dim3 Gr, dim3 Bl, double alpha, const double *src, double *dst, MatrixDim d, int src_stride, int A_trans); +void cudaD_log_add_exp_mat(dim3 Gr, dim3 Bl, double alpha, const double *src, double *dst, MatrixDim d, int src_stride, int A_trans); void cudaD_add_mat_blocks(dim3 Gr, dim3 Bl, double alpha, const double *src, int32_cuda num_row_blocks, int32_cuda num_col_blocks, double *dst, 
MatrixDim d, int src_stride, int A_trans); void cudaD_add_mat_mat_div_mat(dim3 Gr, dim3 Bl, const double *A, const double *B, const double *C, double *dst, MatrixDim d, int stride_a, int stride_b, int stride_c); void cudaD_add_vec_to_cols(dim3 Gr, dim3 Bl, double alpha, const double *col, double beta, double *dst, MatrixDim d); diff --git a/src/cudamatrix/cu-kernels.cu b/src/cudamatrix/cu-kernels.cu index 5412dec7ecd..0cd7fb5b253 100644 --- a/src/cudamatrix/cu-kernels.cu +++ b/src/cudamatrix/cu-kernels.cu @@ -573,6 +573,112 @@ static void _add_mat_trans(Real alpha, const Real* src, Real* dst, MatrixDim d, dst[index] = alpha*src[index_src] + dst[index]; } +__device__ +static float _log_add(float x, float y) { + float diff; + if (x < y) { + diff = x - y; + x = y; + } else { + diff = y - x; + } + // diff is negative. x is now the larger one. + + if (diff >= log(FLT_EPSILON)) { + float res; + res = x + log1p(exp(diff)); + return res; + } else { + return x; // return the larger one. + } +} + +__device__ +static double _log_add(double x, double y) { + double diff; + if (x < y) { + diff = x - y; + x = y; + } else { + diff = y - x; + } + // diff is negative. x is now the larger one. + + if (diff >= log(DBL_EPSILON)) { + double res; + res = x + log1p(exp(diff)); + return res; + } else { + return x; // return the larger one. + } +} + +__device__ +static float _log_sub(float x, float y) { + if (y >= x) { // Throws exception if y>=x. + if (y == x) + return -1.0 / 0.0; + else + return 0.0 / 0.0; + } + + float diff = y - x; // Will be negative. + float res = x + log1p(-exp(diff)); + + if (isnan(res)) + return -1.0 / 0.0; + return res; +} + +__device__ +static double _log_sub(double x, double y) { + if (y >= x) { // Throws exception if y>=x. + if (y == x) + return -1.0 / 0.0; + else + return 0.0 / 0.0; + } + + double diff = y - x; // Will be negative. 
+ double res = x + log1p(-exp(diff)); + + if (isnan(res)) + return -1.0 / 0.0; + return res; +} + + +template +__global__ +static void _log_add_exp_mat(Real alpha, const Real* src, Real* dst, MatrixDim d, int src_stride) { + int32_cuda i = blockIdx.x * blockDim.x + threadIdx.x; + int32_cuda j = blockIdx.y * blockDim.y + threadIdx.y; + int32_cuda index = i + j*d.stride; + int32_cuda index_src = i + j*src_stride; + if (i < d.cols && j < d.rows) { + if (alpha > 0) + dst[index] = _log_add(log(alpha) + src[index_src], dst[index]); + else if (alpha < 0) + dst[index] = _log_sub(dst[index], log(-alpha) + src[index_src]); + } +} + +template +__global__ +static void _log_add_exp_mat_trans(Real alpha, const Real* src, Real* dst, MatrixDim d, int src_stride) { + int32_cuda i = blockIdx.x * blockDim.x + threadIdx.x; + int32_cuda j = blockIdx.y * blockDim.y + threadIdx.y; + int32_cuda index = i + j *d.stride; + int32_cuda index_src = j + i*src_stride; + if (i < d.cols && j < d.rows) { + if (alpha > 0) + dst[index] = _log_add(log(alpha) + src[index_src], dst[index]); + else if (alpha < 0) + dst[index] = _log_sub(dst[index], log(-alpha) + src[index_src]); + } + +} + template __global__ static void _add_mat_blocks(Real alpha, const Real* src, int32_cuda num_row_blocks, int32_cuda num_col_blocks, Real* dst, MatrixDim d, int src_stride) { @@ -1260,6 +1366,23 @@ static void _apply_heaviside(Real* mat, MatrixDim d) { } +// Caution, here i/block{idx,dim}.x is the row index and j/block{idx,dim}.y is the col index. +// this is for no reason, really, I just happened to prefer this +// at the time. 
[dan] +template +__global__ +static void _apply_signum(Real* mat, MatrixDim d) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + int j = blockIdx.y * blockDim.y + threadIdx.y; + int index = i * d.stride + j; + + if (i < d.rows && j < d.cols) { + if (mat[index] > 0.0) mat[index] = 1.0; + else if (mat[index] < 0.0) mat[index] = -1.0; + } +} + + template __global__ static void _apply_floor(Real* mat, Real floor_val, MatrixDim d) { @@ -2146,6 +2269,10 @@ void cudaF_apply_heaviside(dim3 Gr, dim3 Bl, float* mat, MatrixDim d) { _apply_heaviside<<>>(mat, d); } +void cudaF_apply_signum(dim3 Gr, dim3 Bl, float* mat, MatrixDim d) { + _apply_signum<<>>(mat, d); +} + void cudaF_copy_cols(dim3 Gr, dim3 Bl, float* dst, const float* src, const MatrixIndexT_cuda* reorder, MatrixDim dst_dim, int src_stride) { _copy_cols<<>>(dst, src, reorder, dst_dim, src_stride); } @@ -2271,6 +2398,14 @@ void cudaF_add_mat(dim3 Gr, dim3 Bl, float alpha, const float* src, float* dst, } } +void cudaF_log_add_exp_mat(dim3 Gr, dim3 Bl, float alpha, const float* src, float* dst, MatrixDim d, int src_stride, int A_trans) { + if (A_trans) { + _log_add_exp_mat_trans<<>>(alpha,src,dst,d,src_stride); + } else { + _log_add_exp_mat<<>>(alpha,src,dst,d,src_stride); + } +} + void cudaF_add_mat_blocks(dim3 Gr, dim3 Bl, float alpha, const float* src, int32_cuda num_row_blocks, int32_cuda num_col_blocks, float* dst, MatrixDim d, int src_stride, int A_trans) { if (A_trans) { _add_mat_blocks_trans<<>>(alpha, src, num_row_blocks, num_col_blocks, dst, d, src_stride); @@ -2608,6 +2743,10 @@ void cudaD_apply_heaviside(dim3 Gr, dim3 Bl, double* mat, MatrixDim d) { _apply_heaviside<<>>(mat, d); } +void cudaD_apply_signum(dim3 Gr, dim3 Bl, double* mat, MatrixDim d) { + _apply_signum<<>>(mat, d); +} + void cudaD_copy_cols(dim3 Gr, dim3 Bl, double* dst, const double* src, const MatrixIndexT_cuda* reorder, MatrixDim dst_dim, int src_stride) { _copy_cols<<>>(dst, src, reorder, dst_dim, src_stride); } @@ -2733,6 +2872,14 @@ 
void cudaD_add_mat(dim3 Gr, dim3 Bl, double alpha, const double* src, double* ds } } +void cudaD_log_add_exp_mat(dim3 Gr, dim3 Bl, double alpha, const double* src, double* dst, MatrixDim d, int src_stride, int A_trans) { + if (A_trans) { + _log_add_exp_mat_trans<<>>(alpha,src,dst,d,src_stride); + } else { + _log_add_exp_mat<<>>(alpha,src,dst,d,src_stride); + } +} + void cudaD_add_mat_blocks(dim3 Gr, dim3 Bl, double alpha, const double* src, int32_cuda num_row_blocks, int32_cuda num_col_blocks, double* dst, MatrixDim d, int src_stride, int A_trans) { if (A_trans) { _add_mat_blocks_trans<<>>(alpha, src, num_row_blocks, num_col_blocks, dst, d, src_stride); diff --git a/src/cudamatrix/cu-kernels.h b/src/cudamatrix/cu-kernels.h index fc1fbae54da..cb25ffb106d 100644 --- a/src/cudamatrix/cu-kernels.h +++ b/src/cudamatrix/cu-kernels.h @@ -125,6 +125,7 @@ inline void cuda_apply_exp(dim3 Gr, dim3 Bl, float* mat, MatrixDim d) { cudaF_ap inline void cuda_apply_pow(dim3 Gr, dim3 Bl, float* mat, float power, MatrixDim dim) { cudaF_apply_pow(Gr,Bl,mat,power,dim); } inline void cuda_apply_pow_abs(dim3 Gr, dim3 Bl, float* mat, float power, bool include_sign, MatrixDim dim) { cudaF_apply_pow_abs(Gr,Bl,mat,power,include_sign, dim); } inline void cuda_apply_heaviside(dim3 Gr, dim3 Bl, float* mat, MatrixDim dim) { cudaF_apply_heaviside(Gr,Bl,mat,dim); } +inline void cuda_apply_signum(dim3 Gr, dim3 Bl, float* mat, MatrixDim dim) { cudaF_apply_signum(Gr,Bl,mat,dim); } inline void cuda_apply_floor(dim3 Gr, dim3 Bl, float* mat, float floor_val, MatrixDim dim) { cudaF_apply_floor(Gr,Bl,mat,floor_val,dim); } inline void cuda_apply_ceiling(dim3 Gr, dim3 Bl, float* mat, float ceiling_val, MatrixDim dim) { cudaF_apply_ceiling(Gr,Bl,mat,ceiling_val,dim); } inline void cuda_copy_cols(dim3 Gr, dim3 Bl, float* dst, const float* src, const MatrixIndexT_cuda* reorder, MatrixDim dst_dim, int src_stride) { @@ -169,6 +170,7 @@ inline void cuda_mul_rows_group_mat(dim3 Gr, dim3 Bl, float *y, const float 
*x, inline void cuda_calc_pnorm_deriv(dim3 Gr, dim3 Bl, float *y, const float *x1, const float *x2, MatrixDim d, int src_stride, int group_size, float power) {cudaF_calc_pnorm_deriv(Gr, Bl, y, x1, x2, d, src_stride, group_size, power); } inline void cuda_calc_group_max_deriv(dim3 Gr, dim3 Bl, float *y, const float *x1, const float *x2, MatrixDim d, int src_stride, int group_size) {cudaF_calc_group_max_deriv(Gr, Bl, y, x1, x2, d, src_stride, group_size); } inline void cuda_add_mat(dim3 Gr, dim3 Bl, float alpha, const float *src, float *dst, MatrixDim d, int src_stride, int A_trans) { cudaF_add_mat(Gr,Bl,alpha,src,dst,d,src_stride, A_trans); } +inline void cuda_log_add_exp_mat(dim3 Gr, dim3 Bl, float alpha, const float *src, float *dst, MatrixDim d, int src_stride, int A_trans) { cudaF_log_add_exp_mat(Gr,Bl,alpha,src,dst,d,src_stride, A_trans); } inline void cuda_add_mat_blocks(dim3 Gr, dim3 Bl, float alpha, const float *src, int32_cuda num_row_blocks, int32_cuda num_col_blocks, float *dst, MatrixDim d, int src_stride, int A_trans) { cudaF_add_mat_blocks(Gr, Bl, alpha, src, num_row_blocks, num_col_blocks, dst, d, src_stride, A_trans); } inline void cuda_add_mat_mat_div_mat(dim3 Gr, dim3 Bl, const float *A, const float *B, const float *C, float *dst, MatrixDim d, int stride_a, int stride_b, int stride_c) { cudaF_add_mat_mat_div_mat(Gr,Bl,A,B,C,dst,d,stride_a,stride_b,stride_c); } inline void cuda_add_vec_to_cols(dim3 Gr, dim3 Bl, float alpha, const float *col, float beta, float *dst, MatrixDim d) { cudaF_add_vec_to_cols(Gr,Bl,alpha,col,beta,dst,d); } @@ -311,6 +313,7 @@ inline void cuda_apply_exp(dim3 Gr, dim3 Bl, double* mat, MatrixDim d) { cudaD_a inline void cuda_apply_pow(dim3 Gr, dim3 Bl, double* mat, double power, MatrixDim dim) { cudaD_apply_pow(Gr,Bl,mat,power,dim); } inline void cuda_apply_pow_abs(dim3 Gr, dim3 Bl, double* mat, double power, bool include_sign, MatrixDim dim) { cudaD_apply_pow_abs(Gr,Bl,mat,power,include_sign,dim); } inline void 
cuda_apply_heaviside(dim3 Gr, dim3 Bl, double* mat, MatrixDim dim) { cudaD_apply_heaviside(Gr,Bl,mat,dim); } +inline void cuda_apply_signum(dim3 Gr, dim3 Bl, double* mat, MatrixDim dim) { cudaD_apply_signum(Gr,Bl,mat,dim); } inline void cuda_apply_floor(dim3 Gr, dim3 Bl, double* mat, double floor_val, MatrixDim dim) { cudaD_apply_floor(Gr,Bl,mat,floor_val,dim); } inline void cuda_apply_ceiling(dim3 Gr, dim3 Bl, double* mat, double ceiling_val, MatrixDim dim) { cudaD_apply_ceiling(Gr,Bl,mat,ceiling_val,dim); } inline void cuda_copy_cols(dim3 Gr, dim3 Bl, double* dst, const double* src, const MatrixIndexT_cuda* reorder, MatrixDim dst_dim, int src_stride) { @@ -355,6 +358,7 @@ inline void cuda_mul_rows_group_mat(dim3 Gr, dim3 Bl, double *y, const double *x inline void cuda_calc_pnorm_deriv(dim3 Gr, dim3 Bl, double *y, const double *x1, const double *x2, MatrixDim d, int src_stride, int group_size, double power) {cudaD_calc_pnorm_deriv(Gr, Bl, y, x1, x2, d, src_stride, group_size, power); } inline void cuda_calc_group_max_deriv(dim3 Gr, dim3 Bl, double *y, const double *x1, const double *x2, MatrixDim d, int src_stride, int group_size) {cudaD_calc_group_max_deriv(Gr, Bl, y, x1, x2, d, src_stride, group_size); } inline void cuda_add_mat(dim3 Gr, dim3 Bl, double alpha, const double *src, double *dst, MatrixDim d, int src_stride, int A_trans) { cudaD_add_mat(Gr,Bl,alpha,src,dst,d,src_stride, A_trans); } +inline void cuda_log_add_exp_mat(dim3 Gr, dim3 Bl, double alpha, const double *src, double *dst, MatrixDim d, int src_stride, int A_trans) { cudaD_log_add_exp_mat(Gr,Bl,alpha,src,dst,d,src_stride, A_trans); } inline void cuda_add_mat_blocks(dim3 Gr, dim3 Bl, double alpha, const double *src, int32_cuda num_row_blocks, int32_cuda num_col_blocks, double *dst, MatrixDim d, int src_stride, int A_trans) { cudaD_add_mat_blocks(Gr, Bl, alpha, src, num_row_blocks, num_col_blocks, dst, d, src_stride, A_trans); } inline void cuda_add_mat_mat_div_mat(dim3 Gr, dim3 Bl, const double 
*A, const double *B, const double *C, double *dst, MatrixDim d, int stride_a, int stride_b, int stride_c) { cudaD_add_mat_mat_div_mat(Gr,Bl,A,B,C,dst,d,stride_a,stride_b,stride_c); } inline void cuda_add_vec_to_cols(dim3 Gr, dim3 Bl, double alpha, const double *col, double beta, double *dst, MatrixDim d) { cudaD_add_vec_to_cols(Gr,Bl,alpha,col,beta,dst,d); } diff --git a/src/cudamatrix/cu-matrix.cc b/src/cudamatrix/cu-matrix.cc index 03114e61ed1..5a0b23d81cc 100644 --- a/src/cudamatrix/cu-matrix.cc +++ b/src/cudamatrix/cu-matrix.cc @@ -908,6 +908,33 @@ void CuMatrixBase::AddMat(Real alpha, const CuMatrixBase& A, } } +template +void CuMatrixBase::LogAddExpMat(Real alpha, const CuMatrixBase& A, + MatrixTransposeType transA) { + +#if HAVE_CUDA == 1 + if (CuDevice::Instantiate().Enabled()) { + if (transA == kNoTrans) { + KALDI_ASSERT(A.NumRows() == num_rows_ && A.NumCols() == num_cols_); + } else { + KALDI_ASSERT(A.NumCols() == num_rows_ && A.NumRows() == num_cols_); + } + if (num_rows_ == 0) return; + Timer tim; + dim3 dimBlock(CU2DBLOCK, CU2DBLOCK); + dim3 dimGrid(n_blocks(NumCols(), CU2DBLOCK), n_blocks(NumRows(), CU2DBLOCK)); + cuda_log_add_exp_mat(dimGrid, dimBlock, alpha, A.data_, data_, Dim(), + A.Stride(), (transA == kTrans ? 
1 : 0)); + CU_SAFE_CALL(cudaGetLastError()); + + CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed()); + } else +#endif + { + Mat().LogAddExpMat(alpha, A.Mat(), transA); + } +} + template void CuMatrixBase::AddMatBlocks(Real alpha, const CuMatrixBase &A, MatrixTransposeType transA) { @@ -2017,6 +2044,24 @@ void CuMatrixBase::ApplyHeaviside() { } } +template +void CuMatrixBase::ApplySignum() { +#if HAVE_CUDA == 1 + if (CuDevice::Instantiate().Enabled()) { + Timer tim; + dim3 dimBlock(CU2DBLOCK, CU2DBLOCK); + dim3 dimGrid(n_blocks(NumRows(), CU2DBLOCK), + n_blocks(NumCols(), CU2DBLOCK)); + + cuda_apply_heaviside(dimGrid, dimBlock, data_, Dim()); + CU_SAFE_CALL(cudaGetLastError()); + CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed()); + } else +#endif + { + Mat().ApplySignum(); + } +} template void CuMatrixBase::ApplyExp() { diff --git a/src/cudamatrix/cu-matrix.h b/src/cudamatrix/cu-matrix.h index fd4c642ab7f..95aa41bd8ec 100644 --- a/src/cudamatrix/cu-matrix.h +++ b/src/cudamatrix/cu-matrix.h @@ -337,6 +337,7 @@ class CuMatrixBase { ///< multiply the result by the sign of the input. void ApplyPowAbs(Real power, bool include_sign=false); void ApplyHeaviside(); ///< For each element, sets x = (x > 0 ? 
1.0 : 0.0) + void ApplySignum(); ///< For each element, sets x = (1 if x > 0; 0 if x = 0; -1 if x < 0) void ApplyFloor(Real floor_val); void ApplyCeiling(Real ceiling_val); void ApplyExp(); @@ -381,6 +382,10 @@ class CuMatrixBase { /// *this += alpha * A void AddMat(Real alpha, const CuMatrixBase &A, MatrixTransposeType trans = kNoTrans); + + /// Version of AddMat when the matrices are stored in log + void LogAddExpMat(Real alpha, const CuMatrixBase &A, + MatrixTransposeType trans = kNoTrans); /// if A.NumRows() is multiple of (*this)->NumRows and A.NumCols() is multiple of (*this)->NumCols /// divide A into blocks of the same size as (*this) and add them to *this (times alpha) diff --git a/src/feat/feature-fbank.cc b/src/feat/feature-fbank.cc index af1f7b1a346..15cca6574a8 100644 --- a/src/feat/feature-fbank.cc +++ b/src/feat/feature-fbank.cc @@ -28,9 +28,9 @@ Fbank::Fbank(const FbankOptions &opts) if (opts.energy_floor > 0.0) log_energy_floor_ = Log(opts.energy_floor); - int32 padded_window_size = opts.frame_opts.PaddedWindowSize(); - if ((padded_window_size & (padded_window_size-1)) == 0) // Is a power of two... - srfft_ = new SplitRadixRealFft(padded_window_size); + int32 num_fft_bins = opts.frame_opts.NumFftBins(); + if ((num_fft_bins & (num_fft_bins-1)) == 0) // Is a power of two... + srfft_ = new SplitRadixRealFft(num_fft_bins); // We'll definitely need the filterbanks info for VTLN warping factor 1.0. // [note: this call caches it.] The reason we call this here is to @@ -134,6 +134,11 @@ void Fbank::ComputeInternal(const VectorBase &wave, // Cut the window, apply window function ExtractWindow(wave, r, opts_.frame_opts, feature_window_function_, &window, (opts_.use_energy && opts_.raw_energy ? 
&log_energy : NULL)); + + int32 num_fft_bins = opts_.frame_opts.NumFftBins(); + + KALDI_ASSERT(window.Dim() <= num_fft_bins); + window.Resize(num_fft_bins, kCopyData); // Compute energy after window function (not the raw one) if (opts_.use_energy && !opts_.raw_energy) diff --git a/src/feat/feature-functions.cc b/src/feat/feature-functions.cc index 9678e909a5a..d3de2d3b1d6 100644 --- a/src/feat/feature-functions.cc +++ b/src/feat/feature-functions.cc @@ -531,5 +531,44 @@ void SlidingWindowCmn(const SlidingWindowCmnOptions &opts, } +void ComputeZeroCrossings(const VectorBase &wave, + FrameExtractionOptions frame_opts, + BaseFloat threshold, + Vector *output, + Vector *wave_remainder) { + KALDI_ASSERT(output != NULL); + + // Get dimensions of output features + int32 rows_out = NumFrames(wave.Dim(), frame_opts); + if (rows_out == 0) + KALDI_ERR << "No frames fit in file (#samples is " << wave.Dim() << ")"; + // Prepare the output buffer + output->Resize(rows_out); + + // Optionally extract the remainder for further processing + if (wave_remainder != NULL) + ExtractWaveformRemainder(wave, frame_opts, wave_remainder); + + // Buffers + Vector window; // windowed waveform. + + FeatureWindowFunction feature_window_function(frame_opts); + // Compute all the freames, r is frame index.. 
+ for (int32 r = 0; r < rows_out; r++) { + // Cut the window, apply window function + ExtractWindow(wave, r, frame_opts, feature_window_function, + &window, NULL); + + int32 zc = 0; + for (int32 i = 1; i < window.Dim(); i++) { + if ( (window(i-1) < -threshold && window(i) > threshold) + || (window(i-1) > threshold && window(i) < -threshold) ) { + zc++; + } + } + + (*output)(r) = zc; + } +} } // namespace kaldi diff --git a/src/feat/feature-functions.h b/src/feat/feature-functions.h index c5dfe9a3010..74fad44bfa1 100644 --- a/src/feat/feature-functions.h +++ b/src/feat/feature-functions.h @@ -79,6 +79,7 @@ struct FrameExtractionOptions { bool remove_dc_offset; // Subtract mean of wave before FFT. std::string window_type; // e.g. Hamming window bool round_to_power_of_two; + int32 num_fft_bins; bool snip_edges; // Maybe "hamming", "rectangular", "povey", "hanning" // "povey" is a window I made to be similar to Hamming but to go to zero at the @@ -93,6 +94,7 @@ struct FrameExtractionOptions { remove_dc_offset(true), window_type("povey"), round_to_power_of_two(true), + num_fft_bins(128), snip_edges(true){ } void Register(OptionsItf *opts) { @@ -110,6 +112,8 @@ struct FrameExtractionOptions { "(\"hamming\"|\"hanning\"|\"povey\"|\"rectangular\")"); opts->Register("round-to-power-of-two", &round_to_power_of_two, "If true, round window size to power of two."); + opts->Register("num-fft-bins", &num_fft_bins, + "Number of FFT bins to compute spectrogram"); opts->Register("snip-edges", &snip_edges, "If true, end effects will be handled by outputting only frames that " "completely fit in the file, and the number of frames depends on the " @@ -126,6 +130,15 @@ struct FrameExtractionOptions { return (round_to_power_of_two ? RoundUpToNearestPowerOfTwo(WindowSize()) : WindowSize()); } + + int32 NumFftBins() const { + int32 padded_window_size = PaddedWindowSize(); + if (num_fft_bins > padded_window_size) + return (round_to_power_of_two ? 
RoundUpToNearestPowerOfTwo(num_fft_bins) : + num_fft_bins); + return padded_window_size; + } + }; @@ -345,6 +358,12 @@ void SlidingWindowCmn(const SlidingWindowCmnOptions &opts, MatrixBase *output); +void ComputeZeroCrossings(const VectorBase &wave, + FrameExtractionOptions frame_opts, + BaseFloat threshold, + Vector *output, + Vector *wave_remainder); + /// @} End of "addtogroup feat" } // namespace kaldi diff --git a/src/feat/feature-mfcc.cc b/src/feat/feature-mfcc.cc index 518f7462951..fc89714a782 100644 --- a/src/feat/feature-mfcc.cc +++ b/src/feat/feature-mfcc.cc @@ -41,9 +41,9 @@ Mfcc::Mfcc(const MfccOptions &opts) if (opts.energy_floor > 0.0) log_energy_floor_ = Log(opts.energy_floor); - int32 padded_window_size = opts.frame_opts.PaddedWindowSize(); - if ((padded_window_size & (padded_window_size-1)) == 0) // Is a power of two... - srfft_ = new SplitRadixRealFft(padded_window_size); + int32 num_fft_bins = opts.frame_opts.NumFftBins(); + if ((num_fft_bins & (num_fft_bins-1)) == 0) // Is a power of two... + srfft_ = new SplitRadixRealFft(num_fft_bins); // We'll definitely need the filterbanks info for VTLN warping factor 1.0. // [note: this call caches it.] The reason we call this here is to @@ -134,6 +134,9 @@ void Mfcc::ComputeInternal(const VectorBase &wave, std::vector temp_buffer; // used by srfft. for (int32 r = 0; r < rows_out; r++) { // r is frame index.. BaseFloat log_energy; + + // If opts_.raw_energy and opts_.use_energy are both true, then + // log_energy is computed before windowing and stored in log_energy ExtractWindow(wave, r, opts_.frame_opts, feature_window_function_, &window, (opts_.use_energy && opts_.raw_energy ? 
&log_energy : NULL)); @@ -141,6 +144,10 @@ void Mfcc::ComputeInternal(const VectorBase &wave, log_energy = Log(std::max(VecVec(window, window), std::numeric_limits::min())); + int32 num_fft_bins = opts_.frame_opts.NumFftBins(); + KALDI_ASSERT(window.Dim() <= num_fft_bins); + window.Resize(num_fft_bins, kCopyData); + if (srfft_ != NULL) // Compute FFT using the split-radix algorithm. srfft_->Compute(window.Data(), true, &temp_buffer); else // An alternative algorithm that works for non-powers-of-two. diff --git a/src/feat/feature-spectrogram.cc b/src/feat/feature-spectrogram.cc index faa0b44aba6..ba8b4f58ba8 100644 --- a/src/feat/feature-spectrogram.cc +++ b/src/feat/feature-spectrogram.cc @@ -29,9 +29,9 @@ Spectrogram::Spectrogram(const SpectrogramOptions &opts) if (opts.energy_floor > 0.0) log_energy_floor_ = Log(opts.energy_floor); - int32 padded_window_size = opts.frame_opts.PaddedWindowSize(); - if ((padded_window_size & (padded_window_size-1)) == 0) // Is a power of two - srfft_ = new SplitRadixRealFft(padded_window_size); + int32 num_fft_bins = opts_.frame_opts.NumFftBins(); + if ((num_fft_bins & (num_fft_bins-1)) == 0) // Is a power of two + srfft_ = new SplitRadixRealFft(num_fft_bins); } Spectrogram::~Spectrogram() { @@ -45,7 +45,33 @@ void Spectrogram::Compute(const VectorBase &wave, // Get dimensions of output features int32 rows_out = NumFrames(wave.Dim(), opts_.frame_opts); - int32 cols_out = opts_.frame_opts.PaddedWindowSize()/2 +1; + + int32 num_fft_bins = opts_.frame_opts.NumFftBins(); + + BaseFloat sample_freq = opts_.frame_opts.samp_freq; + BaseFloat nyquist = 0.5 * sample_freq; + BaseFloat low_freq = opts_.low_freq, high_freq; + if (opts_.high_freq > 0.0) + high_freq = opts_.high_freq; + else + high_freq = nyquist + opts_.high_freq; + + if (low_freq < 0.0 || low_freq >= nyquist + || high_freq <= 0.0 || high_freq > nyquist + || high_freq <= low_freq) + KALDI_ERR << "Bad values in options: low-freq " << low_freq + << " and high-freq " << high_freq 
<< " vs. nyquist " + << nyquist; + + int32 low_c = low_freq / sample_freq * num_fft_bins; + int32 high_c = high_freq / sample_freq * num_fft_bins; + + int32 cols_out = high_c - low_c + 1; + + if (opts_.use_energy && low_c != 0) { + cols_out++; + } + if (rows_out == 0) KALDI_ERR << "No frames fit in file (#samples is " << wave.Dim() << ")"; // Prepare the output buffer @@ -63,10 +89,13 @@ void Spectrogram::Compute(const VectorBase &wave, for (int32 r = 0; r < rows_out; r++) { // Cut the window, apply window function ExtractWindow(wave, r, opts_.frame_opts, feature_window_function_, - &window, (opts_.raw_energy ? &log_energy : NULL)); + &window, (opts_.use_energy && opts_.raw_energy ? &log_energy : NULL)); + + KALDI_ASSERT(window.Dim() <= num_fft_bins); + window.Resize(num_fft_bins, kCopyData); // Compute energy after window function (not the raw one) - if (!opts_.raw_energy) + if (opts_.use_energy && !opts_.raw_energy) log_energy = Log(std::max(VecVec(window, window), std::numeric_limits::min())); @@ -83,12 +112,16 @@ void Spectrogram::Compute(const VectorBase &wave, power_spectrum.ApplyLog(); // Output buffers - SubVector this_output(output->Row(r)); - this_output.CopyFromVec(power_spectrum); - if (opts_.energy_floor > 0.0 && log_energy < log_energy_floor_) { + SubVector this_output( + (output->Row(r)).Range((opts_.use_energy && low_c != 0) ? 
1 : 0, high_c - low_c + 1)); + SubVector this_power_spectrum(power_spectrum, + low_c, high_c - low_c + 1); + this_output.CopyFromVec(this_power_spectrum); + if (opts_.use_energy && opts_.energy_floor > 0.0 && log_energy < log_energy_floor_) { log_energy = log_energy_floor_; } - this_output(0) = log_energy; + if (opts_.use_energy) + this_output(0) = log_energy; } } diff --git a/src/feat/feature-spectrogram.h b/src/feat/feature-spectrogram.h index 500e3f4a588..4e508ec6cb2 100644 --- a/src/feat/feature-spectrogram.h +++ b/src/feat/feature-spectrogram.h @@ -39,10 +39,13 @@ struct SpectrogramOptions { FrameExtractionOptions frame_opts; BaseFloat energy_floor; bool raw_energy; // If true, compute energy before preemphasis and windowing + bool use_energy; // append an extra dimension with energy to the filter banks + BaseFloat low_freq; // e.g. 20; lower frequency cutoff + BaseFloat high_freq; // an upper frequency cutoff; 0 -> no cutoff, negative SpectrogramOptions() : energy_floor(0.0), // not in log scale: a small value e.g. 
1.0e-10 - raw_energy(true) {} + raw_energy(true), use_energy(true), low_freq(0), high_freq(0) {} void Register(OptionsItf *opts) { frame_opts.Register(opts); @@ -50,6 +53,12 @@ struct SpectrogramOptions { "Floor on energy (absolute, not relative) in Spectrogram computation"); opts->Register("raw-energy", &raw_energy, "If true, compute energy before preemphasis and windowing"); + opts->Register("use-energy", &use_energy, + "Add an extra dimension with energy to the spectrogram output."); + opts->Register("low-freq", &low_freq, + "Low cutoff frequency for mel bins"); + opts->Register("high-freq", &high_freq, + "High cutoff frequency for mel bins (if < 0, offset from Nyquist)"); } }; diff --git a/src/feat/mel-computations.cc b/src/feat/mel-computations.cc index 9949a468d4c..ac7a3c07d9e 100644 --- a/src/feat/mel-computations.cc +++ b/src/feat/mel-computations.cc @@ -36,13 +36,8 @@ MelBanks::MelBanks(const MelBanksOptions &opts, int32 num_bins = opts.num_bins; if (num_bins < 3) KALDI_ERR << "Must have at least 3 mel bins"; BaseFloat sample_freq = frame_opts.samp_freq; - int32 window_length = static_cast(frame_opts.samp_freq*0.001*frame_opts.frame_length_ms); - int32 window_length_padded = - (frame_opts.round_to_power_of_two ? - RoundUpToNearestPowerOfTwo(window_length) : - window_length); - KALDI_ASSERT(window_length_padded % 2 == 0); - int32 num_fft_bins = window_length_padded/2; + int32 num_fft_bins = frame_opts.NumFftBins(); + BaseFloat nyquist = 0.5 * sample_freq; BaseFloat low_freq = opts.low_freq, high_freq; @@ -58,8 +53,7 @@ MelBanks::MelBanks(const MelBanksOptions &opts, << " and high-freq " << high_freq << " vs. 
nyquist " << nyquist; - BaseFloat fft_bin_width = sample_freq / window_length_padded; - // fft-bin width [think of it as Nyquist-freq / half-window-length] + BaseFloat fft_bin_width = sample_freq / num_fft_bins; BaseFloat mel_low_freq = MelScale(low_freq); BaseFloat mel_high_freq = MelScale(high_freq); @@ -103,9 +97,9 @@ MelBanks::MelBanks(const MelBanksOptions &opts, center_freqs_(bin) = InverseMelScale(center_mel); // this_bin will be a vector of coefficients that is only // nonzero where this mel bin is active. - Vector this_bin(num_fft_bins); + Vector this_bin(num_fft_bins / 2); int32 first_index = -1, last_index = -1; - for (int32 i = 0; i < num_fft_bins; i++) { + for (int32 i = 0; i < num_fft_bins / 2; i++) { BaseFloat freq = (fft_bin_width * i); // center freq of this fft bin. BaseFloat mel = MelScale(freq); if (mel > left_mel && mel < right_mel) { diff --git a/src/feat/pitch-functions.cc b/src/feat/pitch-functions.cc index 4f717129994..732764f24b4 100644 --- a/src/feat/pitch-functions.cc +++ b/src/feat/pitch-functions.cc @@ -1357,7 +1357,8 @@ OnlineProcessPitch::OnlineProcessPitch( dim_ ((opts.add_pov_feature ? 1 : 0) + (opts.add_normalized_log_pitch ? 1 : 0) + (opts.add_delta_pitch ? 1 : 0) - + (opts.add_raw_log_pitch ? 1 : 0)) { + + (opts.add_raw_log_pitch ? 1 : 0) + + (opts.add_raw_pov ? 1 : 0)) { KALDI_ASSERT(dim_ > 0 && " At least one of the pitch features should be chosen. 
" "Check your post-process-pitch options."; @@ -1380,6 +1381,8 @@ void OnlineProcessPitch::GetFrame(int32 frame, (*feat)(index++) = GetDeltaPitchFeature(frame_delayed); if (opts_.add_raw_log_pitch) (*feat)(index++) = GetRawLogPitchFeature(frame_delayed); + if (opts_.add_raw_pov) + (*feat)(index++) = GetRawPov(frame_delayed); KALDI_ASSERT(index == dim_); } @@ -1437,6 +1440,12 @@ BaseFloat OnlineProcessPitch::GetNormalizedLogPitchFeature(int32 frame) { return normalized_log_pitch * opts_.pitch_scale; } +BaseFloat OnlineProcessPitch::GetRawPov(int32 frame) const { + Vector tmp(kRawFeatureDim); + src_->GetFrame(frame, &tmp); // (NCCF, pitch) from pitch extractor + BaseFloat nccf = tmp(0); + return NccfToPov(nccf); +} // inline void OnlineProcessPitch::GetNormalizationWindow(int32 t, diff --git a/src/feat/pitch-functions.h b/src/feat/pitch-functions.h index ee3c293ef81..a4be23d8be1 100644 --- a/src/feat/pitch-functions.h +++ b/src/feat/pitch-functions.h @@ -229,6 +229,7 @@ struct ProcessPitchOptions { bool add_normalized_log_pitch; bool add_delta_pitch; bool add_raw_log_pitch; + bool add_raw_pov; ProcessPitchOptions() : pitch_scale(2.0), @@ -243,7 +244,7 @@ struct ProcessPitchOptions { add_pov_feature(true), add_normalized_log_pitch(true), add_delta_pitch(true), - add_raw_log_pitch(false) { } + add_raw_log_pitch(false), add_raw_pov(false) { } void Register(ParseOptions *opts) { @@ -283,6 +284,8 @@ struct ProcessPitchOptions { "features"); opts->Register("add-raw-log-pitch", &add_raw_log_pitch, "If true, log(pitch) is added to output features"); + opts->Register("add-raw-pov", &add_raw_pov, + "If true, add NCCF converted to POV"); } }; @@ -388,6 +391,10 @@ class OnlineProcessPitch: public OnlineFeatureInterface { /// Called from GetFrame(). inline BaseFloat GetNormalizedLogPitchFeature(int32 frame); + /// Computes and returns the raw POV for this frame. + /// Called from GetFrame(). 
+ inline BaseFloat GetRawPov(int32 frame) const; + /// Computes the normalization window sizes. inline void GetNormalizationWindow(int32 frame, int32 src_frames_ready, diff --git a/src/featbin/Makefile b/src/featbin/Makefile index 9843e7bbd4b..12c96c3240e 100644 --- a/src/featbin/Makefile +++ b/src/featbin/Makefile @@ -15,7 +15,11 @@ BINFILES = compute-mfcc-feats compute-plp-feats compute-fbank-feats \ process-kaldi-pitch-feats compare-feats wav-to-duration add-deltas-sdc \ compute-and-process-kaldi-pitch-feats modify-cmvn-stats wav-copy \ wav-reverberate append-vector-to-feats detect-sinusoids shift-feats \ - concat-feats + concat-feats \ + extract-column compute-zero-crossings \ + extract-vector-segments combine-vector-segments \ + compute-snr-targets compute-frame-snrs vector-to-feat \ + corrupt-wav OBJFILES = @@ -23,7 +27,7 @@ TESTFILES = ADDLIBS = ../feat/kaldi-feat.a ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ ../thread/kaldi-thread.a ../tree/kaldi-tree.a ../matrix/kaldi-matrix.a \ - ../util/kaldi-util.a ../base/kaldi-base.a + ../util/kaldi-util.a ../base/kaldi-base.a ../segmenter/kaldi-segmenter.a include ../makefiles/default_rules.mk diff --git a/src/featbin/apply-cmvn-sliding.cc b/src/featbin/apply-cmvn-sliding.cc index 4a6d02d16cd..5d6c7532cd0 100644 --- a/src/featbin/apply-cmvn-sliding.cc +++ b/src/featbin/apply-cmvn-sliding.cc @@ -36,9 +36,12 @@ int main(int argc, char *argv[]) { "\n" "Usage: apply-cmvn-sliding [options] \n"; + std::string skip_dims_str; ParseOptions po(usage); SlidingWindowCmnOptions opts; opts.Register(&po); + po.Register("skip-dims", &skip_dims_str, "Dimensions for which to skip " + "normalization: colon-separated list of integers, e.g. 13:14:15)"); po.Read(argc, argv); @@ -46,6 +49,15 @@ int main(int argc, char *argv[]) { po.PrintUsage(); exit(1); } + + std::vector skip_dims; // optionally use "fake" + // (zero-mean/unit-variance) stats for some + // dims to disable normalization. 
+ if (!SplitStringToIntegers(skip_dims_str, ":", false, &skip_dims)) { + KALDI_ERR << "Bad --skip-dims option (should be colon-separated list of " + << "integers)"; + } + int32 num_done = 0, num_err = 0; diff --git a/src/featbin/combine-vector-segments.cc b/src/featbin/combine-vector-segments.cc new file mode 100644 index 00000000000..e1baed827d2 --- /dev/null +++ b/src/featbin/combine-vector-segments.cc @@ -0,0 +1,178 @@ +// featbin/combine-vector-segments.cc + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "base/kaldi-extra-types.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + using namespace kaldi; + + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + try { + const char *usage = + "Combine vectors corresponding to segments to whole vectors. " + "Does the reverse operation of extract-vector-segments." + "Usage: combine-vector-segments [options...] \n"; + + ParseOptions po(usage); + + BaseFloat min_segment_length = 0.1, // Minimum segment length in seconds. 
+ max_overshoot = 0.0; // max time by which last segment can overshoot + BaseFloat frame_shift = 0.01; + BaseFloat default_weight = 0; + int32 overlap = 0; + + po.Register("min-segment-length", &min_segment_length, + "Minimum segment length in seconds (reject shorter segments)"); + po.Register("frame-shift", &frame_shift, + "Frame shift in second"); + po.Register("max-overshoot", &max_overshoot, + "End segments overshooting by less (in seconds) are truncated," + " else rejected."); + po.Register("default-weight", &default_weight, "Fill any extra " + "length with this weight"); + po.Register("overlap", &overlap, "Overlap in segments"); + + po.Read(argc, argv); + + if (po.NumArgs() != 5) { + po.PrintUsage(); + exit(1); + } + + std::string vecs_rspecifier = po.GetArg(1); // input vector archive + std::string reco2utt_rspecifier = po.GetArg(2); + std::string segments_rspecifier = po.GetArg(3); + std::string lengths_rspecifier = po.GetArg(4); // lengths archive + std::string vecs_wspecifier = po.GetArg(5); // output archive + + SequentialTokenVectorReader reco2utt_reader(reco2utt_rspecifier); + RandomAccessUtteranceSegmentReader segment_reader(segments_rspecifier); + RandomAccessBaseFloatVectorReader vector_reader(vecs_rspecifier); + RandomAccessInt32Reader length_reader(lengths_rspecifier); + BaseFloatVectorWriter vector_writer(vecs_wspecifier); + + int32 num_reco = 0, num_success = 0, num_missing = 0, num_err = 0; + + for (; !reco2utt_reader.Done(); reco2utt_reader.Next(), num_reco++) { + std::string reco = reco2utt_reader.Key(); + const std::vector &uttlist = reco2utt_reader.Value(); + + if (!length_reader.HasKey(reco)) { + KALDI_WARN << "Could not find length for recording " + << reco; + num_missing++; + } + int32 file_length = length_reader.Value(reco); + + Vector out_vector(file_length); + segmenter::Segmentation seg; + + for (std::vector::const_iterator it = uttlist.begin(); + it != uttlist.end(); ++it) { + + if (!segment_reader.HasKey(*it)) { + KALDI_WARN << 
"Could not find utterance " << *it << " in segments " + << "file " << segments_rspecifier; + num_err++; + continue; + } + if (!vector_reader.HasKey(*it)) { + KALDI_WARN << "Could not find vector for utterance " << *it; + num_err++; + continue; + } + + const UtteranceSegment &segment = segment_reader.Value(*it); + const Vector &vector = vector_reader.Value(*it); + seg.Emplace(std::round(segment.start_time / frame_shift), + std::round(segment.end_time / frame_shift), 1, + vector); + } + + seg.Sort(); + + size_t i = 0; + for (segmenter::SegmentList::iterator it = seg.Begin(); + it != seg.End(); ++it, i++) { + if (i != 0) { + it->start_frame += overlap / 2; + } + if (i != seg.Dim()) { + it->end_frame -= overlap / 2; + } + + if (it->start_frame < 0 || it->start_frame >= file_length) { + KALDI_WARN << "start frame out of range " << it->start_frame << " [length:] " + << file_length << ", skipping segment "; + num_err++; + continue; + } + + /* end frame must be less than total number samples + * otherwise skip the segment + */ + if (it->end_frame > file_length) { + if (it->end_frame > + file_length + static_cast(max_overshoot / frame_shift)) { + KALDI_WARN << "end frame too far out of range " << it->end_frame + << " [overshooted length:] " << file_length + static_cast(max_overshoot / frame_shift) << ", skipping segment"; + num_err++; + continue; + } + it->end_frame = file_length; // for small differences, just truncate. 
+ } + + KALDI_ASSERT(it->end_frame <= out_vector.Dim()); + SubVector this_out_vec(out_vector, + it->start_frame, it->end_frame - it->start_frame); + + const Vector &vector = it->VectorValue(); + KALDI_ASSERT(vector.Dim() >= it->end_frame - it->start_frame); + + KALDI_ASSERT(overlap / 2 + it->end_frame - it->start_frame <= vector.Dim()); + + SubVector in_vec(vector, overlap / 2, it->end_frame - it->start_frame); + + KALDI_ASSERT(in_vec.Dim() == this_out_vec.Dim()); + this_out_vec.CopyFromVec(in_vec); + } + + vector_writer.Write(reco, out_vector); + + num_success++; + } + + KALDI_LOG << "Read " << num_reco << " recordings and succeeded on " + << num_success << " recordings; " << num_missing + << " recording missing; " << num_err << " utterances " + << "skipped"; + + return (num_success > 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/featbin/compute-frame-snrs.cc b/src/featbin/compute-frame-snrs.cc new file mode 100644 index 00000000000..b789e058b7a --- /dev/null +++ b/src/featbin/compute-frame-snrs.cc @@ -0,0 +1,246 @@ +// featbin/compute-frame-snrs.cc + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +void ComputeFrameSnrsUsingCorruptedFbank(const Matrix &clean_fbank, + const Matrix &fbank, + Vector *frame_snrs, + BaseFloat ceiling = 100) { + int32 min_len = frame_snrs->Dim(); + + for (size_t t = 0; t < min_len; t++) { + Vector clean_fbank_t(clean_fbank.Row(t)); + Vector fbank_t(fbank.Row(t)); + + BaseFloat clean_energy_t = clean_fbank_t.LogSumExp(); + BaseFloat total_energy_t = fbank_t.LogSumExp(); + + if (kaldi::ApproxEqual(total_energy_t, clean_energy_t, 1e-10)) { + (*frame_snrs)(t) = ceiling; + } else { + BaseFloat noise_energy_t = (total_energy_t > clean_energy_t ? + LogSub(total_energy_t, clean_energy_t) : + LogSub(clean_energy_t, total_energy_t) ); + + (*frame_snrs)(t) = clean_energy_t - noise_energy_t; + } + } +} + +void ComputeFrameSnrsUsingNoiseFbank(const Matrix &clean_fbank, + const Matrix &noise_fbank, + Vector *frame_snrs) { + int32 min_len = frame_snrs->Dim(); + + for (size_t t = 0; t < min_len; t++) { + Vector clean_fbank_t(clean_fbank.Row(t)); + Vector noise_fbank_t(noise_fbank.Row(t)); + clean_fbank_t.Scale(2.0); + noise_fbank_t.Scale(2.0); + + BaseFloat noise_energy = noise_fbank_t.LogSumExp(); + clean_fbank_t.Add(-noise_energy); + (*frame_snrs)(t) = clean_fbank_t.LogSumExp(); + } +} + +} + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace std; + + const char *usage = + "Compute frame-level log-SNRs from time-frequency bin predictions and " + "the corrupted fbank features. 
\n" + "Optionally write clean feats as output.\n" + "Usage: compute-frame-snrs [] []\n" + " e.g.: compute-frame-snrs scp:data/train_fbank/feats.scp \"ark:nnet3-compute exp/nnet3/final.raw scp:data/train_hires/feats.scp ark:- |\" ark:-\n"; + + int32 length_tolerance = 0; + std::string prediction_type = "FbankMask"; + BaseFloat ceiling = 100; + + ParseOptions po(usage); + + po.Register("length-tolerance", &length_tolerance, + "If length is different, trim as shortest up to a frame " + " difference of length-tolerance, otherwise exclude segment."); + po.Register("prediction-type", &prediction_type, + "Prediction type can be FbankMask or IRM"); + po.Register("ceiling", &ceiling, + "Maximum log-frame-SNR allowed"); + + po.Read(argc, argv); + + if (po.NumArgs() < 3 || po.NumArgs() > 5) { + po.PrintUsage(); + exit(1); + } + + std::string fbank_rspecifier = po.GetArg(1), + mask_rspecifier = po.GetArg(2), + frame_snr_wspecifier = po.GetArg(3), + clean_fbank_wspecifier, out_snr_wspecifier; + if (po.NumArgs() >= 4) + clean_fbank_wspecifier = po.GetArg(4); + if (po.NumArgs() == 5) + out_snr_wspecifier = po.GetArg(5); + + SequentialBaseFloatMatrixReader fbank_reader(fbank_rspecifier); + RandomAccessBaseFloatMatrixReader mask_reader(mask_rspecifier); + BaseFloatVectorWriter frame_snr_writer(frame_snr_wspecifier); + BaseFloatMatrixWriter clean_fbank_writer(clean_fbank_wspecifier); + BaseFloatMatrixWriter out_snr_writer(out_snr_wspecifier); + + int32 num_done = 0, num_fail = 0; + + for (; !fbank_reader.Done(); fbank_reader.Next()) { + const Matrix &fbank = fbank_reader.Value(); + const std::string &utt = fbank_reader.Key(); + + if (!mask_reader.HasKey(utt)) { + KALDI_WARN << "No mask features for utt " << utt; + num_fail++; + continue; + } + + Matrix mask(mask_reader.Value(utt)); + + if (mask.NumCols() != fbank.NumCols()) { + KALDI_ERR << "Dimension mismatch between fbank and mask; " + << fbank.NumCols() << " vs " << mask.NumCols() + << " for utt " << utt; + } + + int32 min_len = 
0, max_len = 0; + if (mask.NumRows() < fbank.NumRows()) { + min_len = mask.NumRows(); + max_len = fbank.NumRows(); + } else { + min_len = fbank.NumRows(); + max_len = mask.NumRows(); + } + + if (max_len - min_len > length_tolerance || min_len == 0) { + KALDI_WARN << "Length mismatch " << max_len << " vs. " << min_len + << (utt.empty() ? "" : " for utt ") << utt + << " exceeds tolerance " << length_tolerance; + num_fail++; + continue; + } + + if (max_len - min_len > 0) { + KALDI_VLOG(2) << "Length mismatch " << max_len << " vs. " << min_len + << (utt.empty() ? "" : " for utt ") << utt + << " exceeds tolerance " << length_tolerance; + } + + // TODO: Support correction of length mismatch + KALDI_ASSERT(max_len == min_len); + + Vector frame_snrs(min_len); + + Matrix &clean_fbank = mask; + // clean_fbank temporarily stores the mask + + Matrix out_snr; + + if (prediction_type == "Irm") { + clean_fbank.ApplyCeiling(0.0); // S / (N + S) + clean_fbank.AddMat(1.0, fbank, kNoTrans); // F * S / (N + S) + // clean_fbank has been computed + ComputeFrameSnrsUsingCorruptedFbank(clean_fbank, fbank, &frame_snrs, ceiling); + if (!out_snr_wspecifier.empty()) { + out_snr.Resize(fbank.NumRows(), fbank.NumCols()); + // First compute noise + out_snr.CopyFromMat(fbank); + out_snr.LogAddExpMat(-1.0, clean_fbank, kNoTrans); // Noise computed + out_snr.AddMat(-1.0, clean_fbank); // N / S + out_snr.Scale(-1.0); // S / N + out_snr.ApplyCeiling(ceiling); + } + } else if (prediction_type == "FbankMask") { + clean_fbank.ApplyCeiling(0.0); // S / T + clean_fbank.AddMat(1.0, fbank, kNoTrans); // F * S / T + // clean_fbank has been computed + ComputeFrameSnrsUsingCorruptedFbank(clean_fbank, fbank, &frame_snrs, ceiling); + if (!out_snr_wspecifier.empty()) { + out_snr.Resize(fbank.NumRows(), fbank.NumCols()); + out_snr.CopyFromMat(fbank); + out_snr.LogAddExpMat(-1.0, clean_fbank, kNoTrans); // Noise computed + out_snr.AddMat(-1.0, clean_fbank); // N / S + out_snr.Scale(-1.0); // S / N + 
out_snr.ApplyCeiling(ceiling); + } + } else if (prediction_type == "Snr") { + mask.ApplyCeiling(ceiling); + if (!out_snr_wspecifier.empty()) { + out_snr.Resize(mask.NumRows(), mask.NumCols()); + out_snr.CopyFromMat(mask); + out_snr.ApplyCeiling(ceiling); + } + + Matrix noise_fbank(mask); + + Matrix zeros(mask.NumRows(), mask.NumCols()); + + clean_fbank.Scale(-1.0); // N/S + clean_fbank.LogAddExpMat(1.0, zeros); // 1 + N / S + clean_fbank.Scale(-1.0); // irm has been computed // (1+N/S)^-1.0 + clean_fbank.AddMat(1.0, fbank, kNoTrans); // F * (S / (S + N)) + // clean_fbank has been computed + + noise_fbank.LogAddExpMat(1.0, zeros); // S/N + noise_fbank.Scale(-1.0); // ~irm has been computed // (1+S/N)^-1.0 + noise_fbank.AddMat(1.0, fbank, kNoTrans); // F * (N / (S + N)) + // noise_fbank has been computed + ComputeFrameSnrsUsingNoiseFbank(clean_fbank, noise_fbank, &frame_snrs); + } else { + KALDI_ERR << "Unknown prediction-type" << prediction_type; + } + + if (clean_fbank_wspecifier != "") { + clean_fbank_writer.Write(utt, clean_fbank); + } + + if (!out_snr_wspecifier.empty()) { + out_snr_writer.Write(utt, out_snr); + } + + frame_snr_writer.Write(utt, frame_snrs); + num_done++; + } + + KALDI_LOG << "Computed frame snr for " << num_done << " utterances; " + << "failed for " << num_fail; + + return (num_done == 0 ? 1 : 0); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/featbin/compute-snr-targets.cc b/src/featbin/compute-snr-targets.cc new file mode 100644 index 00000000000..a971b04ed2e --- /dev/null +++ b/src/featbin/compute-snr-targets.cc @@ -0,0 +1,273 @@ +// featbin/compute-snr-targets.cc + +// Copyright 2015-2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Compute snr targets using clean and noisy speech features.\n" + "The targets can be of 3 types -- \n" + "Irm (Ideal Ratio Mask) = Clean fbank / (Clean fbank + Noise fbank)\n" + "FbankMask = Clean fbank / Noisy fbank\n" + "Snr (Signal To Noise Ratio) = Clean fbank / Noise fbank\n" + "Both input and output features are assumed to be in log domain.\n" + "ali-rspecifier and silence-phones are used to identify whether " + "a particular frame is \"clean\" or not. 
Silence frames in " + "\"clean\" fbank are treated as \"noise\" and hence the SNR for those " + "frames are -inf in log scale.\n" + "Usage: compute-snr-targets [options] \n" + " or compute-snr-targets [options] --binary-targets \n" + "e.g.: compute-snr-targets scp:clean.scp scp:noisy.scp ark:targets.ark\n"; + + std::string target_type = "Irm"; + std::string ali_rspecifier; + std::string silence_phones_str; + std::string floor_str = "-inf", ceiling_str = "inf"; + int32 length_tolerance = 0; + bool binary_targets = false; + int32 target_dim = -1; + + ParseOptions po(usage); + po.Register("target_type", &target_type, "Target type can be FbankMask or IRM"); + po.Register("ali-rspecifier", &ali_rspecifier, "If provided, all the " + "energy in the silence region of clean file is considered noise"); + po.Register("silence-phones", &silence_phones_str, "Comma-separated list of " + "silence phones"); + po.Register("floor", &floor_str, "If specified, the target is floored at " + "this value. You may want to do this if you are using targets " + "in original log form as is usual in the case of Snr, but may " + "not if you are applying Exp() as is usual in the case of Irm"); + po.Register("ceiling", &ceiling_str, "If specified, the target is ceiled " + "at this value. You may want to do this if you expect " + "infinities or very large values, particularly for Snr targets."); + po.Register("length-tolerance", &length_tolerance, "Tolerate differences " + "in utterance lengths of these many frames"); + po.Register("binary-targets", &binary_targets, "If specified, then the " + "targets are created considering each frame to be either " + "completely signal or completely noise as decided by the " + "ali-rspecifier option. When ali-rspecifier is not specified, " + "then the entire utterance is considered to be just signal." 
+ "If this option is specified, then only a single argument " + "-- the clean features -- is must be specified."); + po.Register("target-dim", &target_dim, "Overrides the target dimension. " + "Applicable only with --binary-targets is specified"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3 && po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::vector silence_phones; + if (!silence_phones_str.empty()) { + if (!SplitStringToIntegers(silence_phones_str, ":", false, &silence_phones)) { + KALDI_ERR << "Invalid silence-phones string " << silence_phones_str; + } + std::sort(silence_phones.begin(), silence_phones.end()); + } + + double floor = kLogZeroDouble, ceiling = -kLogZeroDouble; + + if (floor_str != "-inf") + if (!ConvertStringToReal(floor_str, &floor)) { + KALDI_ERR << "Invalid --floor value " << floor_str; + } + + if (ceiling_str != "inf") + if (!ConvertStringToReal(ceiling_str, &ceiling)) { + KALDI_ERR << "Invalid --ceiling value " << ceiling_str; + } + + int32 num_done = 0, num_err = 0, num_success = 0; + int64 num_sil_frames = 0; + int64 num_speech_frames = 0; + + if (!binary_targets) { + // This is the 'normal' case, where we have both clean and + // noise/corrupted input features. + // The word 'noisy' in the variable names is used to mean 'corrupted'. 
+ std::string clean_rspecifier = po.GetArg(1), + noisy_rspecifier = po.GetArg(2), + targets_wspecifier = po.GetArg(3); + + SequentialBaseFloatMatrixReader noisy_reader(noisy_rspecifier); + RandomAccessBaseFloatMatrixReader clean_reader(clean_rspecifier); + BaseFloatMatrixWriter kaldi_writer(targets_wspecifier); + + RandomAccessInt32VectorReader alignment_reader(ali_rspecifier); + + for (; !noisy_reader.Done(); noisy_reader.Next(), num_done++) { + const std::string &key = noisy_reader.Key(); + Matrix total_energy(noisy_reader.Value()); + // Although this is called 'energy', it is actually log filterbank + // features of noise or corrupted files + // Actually noise feats in the case of Irm and Snr + + // TODO: Support multiple corrupted version for a particular clean file + std::string uniq_key = key; + if (!clean_reader.HasKey(uniq_key)) { + KALDI_WARN << "Could not find uniq key " << uniq_key << " " + << "in clean feats " << clean_rspecifier; + num_err++; + continue; + } + + Matrix clean_energy(clean_reader.Value(uniq_key)); + + if (target_type == "Irm") { + total_energy.LogAddExpMat(1.0, clean_energy, kNoTrans); + } + + if (!ali_rspecifier.empty()) { + if (!alignment_reader.HasKey(uniq_key)) { + KALDI_WARN << "Could not find uniq key " << uniq_key + << "in alignment " << ali_rspecifier; + num_err++; + continue; + } + const std::vector &ali = alignment_reader.Value(key); + + if (std::abs(static_cast (ali.size()) - clean_energy.NumRows()) > length_tolerance) { + KALDI_WARN << "Mismatch in number of frames in alignment " + << "and feats; " << static_cast(ali.size()) + << " vs " << clean_energy.NumRows(); + num_err++; + continue; + } + + int32 length = std::min(static_cast(ali.size()), clean_energy.NumRows()); + if (ali.size() < length) + // TODO: Support this case + KALDI_ERR << "This code currently does not support the case " + << "where alignment smaller than features because " + << "it is not expected to happen"; + + KALDI_ASSERT(clean_energy.NumRows() == length); 
+ KALDI_ASSERT(total_energy.NumRows() == length); + + if (clean_energy.NumRows() < length) clean_energy.Resize(length, clean_energy.NumCols(), kCopyData); + if (total_energy.NumRows() < length) total_energy.Resize(length, total_energy.NumCols(), kCopyData); + + for (int32 i = 0; i < clean_energy.NumRows(); i++) { + if (std::binary_search(silence_phones.begin(), silence_phones.end(), ali[i])) { + clean_energy.Row(i).Set(kLogZeroDouble); + num_sil_frames++; + } else num_speech_frames++; + } + } + + clean_energy.AddMat(-1.0, total_energy); + if (ceiling_str != "inf") { + clean_energy.ApplyCeiling(ceiling); + } + + if (floor_str != "-inf") { + clean_energy.ApplyFloor(floor); + } + + kaldi_writer.Write(key, Matrix(clean_energy)); + num_success++; + } + } else { + // Copying tables of features. + std::string feats_rspecifier = po.GetArg(1), + targets_wspecifier = po.GetArg(2); + + SequentialBaseFloatMatrixReader feats_reader(feats_rspecifier); + BaseFloatMatrixWriter kaldi_writer(targets_wspecifier); + + RandomAccessInt32VectorReader alignment_reader(ali_rspecifier); + + int64 num_sil_frames = 0; + int64 num_speech_frames = 0; + + for (; !feats_reader.Done(); feats_reader.Next(), num_done++) { + const std::string &key = feats_reader.Key(); + const Matrix &feats = feats_reader.Value(); + + Matrix targets; + + if (target_dim < 0) + targets.Resize(feats.NumRows(), feats.NumCols()); + else + targets.Resize(feats.NumRows(), target_dim); + + if (target_type == "Snr") + targets.Set(-kLogZeroDouble); + + if (!ali_rspecifier.empty()) { + if (!alignment_reader.HasKey(key)) { + KALDI_WARN << "Could not find uniq key " << key + << " in alignment " << ali_rspecifier; + num_err++; + continue; + } + + const std::vector &ali = alignment_reader.Value(key); + + if (std::abs(static_cast (ali.size()) - feats.NumRows()) > length_tolerance) { + KALDI_WARN << "Mismatch in number of frames in alignment " + << "and feats; " << static_cast(ali.size()) + << " vs " << feats.NumRows(); + num_err++; 
+ continue; + } + + int32 length = std::min(static_cast(ali.size()), feats.NumRows()); + KALDI_ASSERT(ali.size() >= length); + + for (int32 i = 0; i < feats.NumRows(); i++) { + if (std::binary_search(silence_phones.begin(), silence_phones.end(), ali[i])) { + targets.Row(i).Set(kLogZeroDouble); + num_sil_frames++; + } else { + num_speech_frames++; + } + } + + if (ceiling_str != "inf") { + targets.ApplyCeiling(ceiling); + } + + if (floor_str != "-inf") { + targets.ApplyFloor(floor); + } + + kaldi_writer.Write(key, targets); + } + } + } + + KALDI_LOG << "Computed SNR targets for " << num_success + << " out of " << num_done << " utterances; failed for " + << num_err; + KALDI_LOG << "Got [ " << num_speech_frames << "," + << num_sil_frames << "] frames of silence and speech"; + return (num_success > 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/featbin/compute-zero-crossings.cc b/src/featbin/compute-zero-crossings.cc new file mode 100644 index 00000000000..b4ad0dbe0f6 --- /dev/null +++ b/src/featbin/compute-zero-crossings.cc @@ -0,0 +1,139 @@ +// featbin/compute-zero-crossings.cc + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "feat/feature-functions.h" +#include "feat/wave-reader.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + const char *usage = + "Create zero-crossing features\n" + "Usage: compute-zero-crossings [options...] \n"; + + // construct all the global objects + ParseOptions po(usage); + FrameExtractionOptions opts; + + int32 channel = -1; + BaseFloat min_duration = 0.0, zero_crossing_threshold = 0.0; + bool write_as_vector = false; + + // Register the option struct + opts.Register(&po); + // Register the options + po.Register("channel", &channel, "Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right)"); + po.Register("min-duration", &min_duration, "Minimum duration of segments to process (in seconds)."); + po.Register("zero-crossing-threshold", &zero_crossing_threshold, + "Take any value within this threshold as zero " + "for zero crossing computation"); + po.Register("write-as-vector", &write_as_vector, "Write as a vector " + "to interpret the output as weights instead of " + "a column of feature matrix"); + + // OPTION PARSING .......................................................... 
+ // + + // parse options (+filling the registered variables) + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string wav_rspecifier = po.GetArg(1); + std::string output_wspecifier = po.GetArg(2); + + SequentialTableReader reader(wav_rspecifier); + BaseFloatMatrixWriter *matrix_writer = NULL; + BaseFloatVectorWriter *vector_writer = NULL; + + if (write_as_vector) + vector_writer = new BaseFloatVectorWriter(output_wspecifier); + else + matrix_writer = new BaseFloatMatrixWriter(output_wspecifier); + + int32 num_utts = 0, num_success = 0; + for (; !reader.Done(); reader.Next()) { + num_utts++; + std::string utt = reader.Key(); + const WaveData &wave_data = reader.Value(); + if (wave_data.Duration() < min_duration) { + KALDI_WARN << "File: " << utt << " is too short (" + << wave_data.Duration() << " sec): producing no output."; + continue; + } + int32 num_chan = wave_data.Data().NumRows(), this_chan = channel; + { // This block works out the channel (0=left, 1=right...) + KALDI_ASSERT(num_chan > 0); // should have been caught in + // reading code if no channels. + if (channel == -1) { + this_chan = 0; + if (num_chan != 1) + KALDI_WARN << "Channel not specified but you have data with " + << num_chan << " channels; defaulting to zero"; + } else { + if (this_chan >= num_chan) { + KALDI_WARN << "File with id " << utt << " has " + << num_chan << " channels but you specified channel " + << channel << ", producing no output."; + continue; + } + } + } + + if (opts.samp_freq != wave_data.SampFreq()) + KALDI_ERR << "Sample frequency mismatch: you specified " + << opts.samp_freq << " but data has " + << wave_data.SampFreq() << " (use --sample-frequency " + << "option). 
Utterance is " << utt; + + SubVector waveform(wave_data.Data(), this_chan); + Vector zero_crossings; + ComputeZeroCrossings(waveform, opts, zero_crossing_threshold, &zero_crossings, NULL); + + if (write_as_vector) { + vector_writer->Write(utt, zero_crossings); + } else { + Matrix mat(zero_crossings.Dim(), 1); + mat.CopyColFromVec(zero_crossings, 0); + matrix_writer->Write(utt, mat); + } + + if(num_utts % 10 == 0) + KALDI_LOG << "Processed " << num_utts << " utterances"; + KALDI_VLOG(2) << "Processed features for key " << utt; + num_success++; + } + KALDI_LOG << " Done " << num_success << " out of " << num_utts + << " utterances."; + + if (vector_writer != NULL) delete vector_writer; + if (matrix_writer != NULL) delete matrix_writer; + return (num_success != 0 ? 0 : 1); + } catch(const std::exception& e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/featbin/corrupt-wav.cc b/src/featbin/corrupt-wav.cc new file mode 100644 index 00000000000..88b9f3dfd1f --- /dev/null +++ b/src/featbin/corrupt-wav.cc @@ -0,0 +1,544 @@ +// featbin/corrupt-wav.cc + +// Copyright 2015 Tom Ko +// 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "feat/wave-reader.h" +#include "feat/signal.h" + +namespace kaldi { + +/* + This function is to repeatedly concatenate signal1 by itself + to match the length of signal2 and add the two signals together. +*/ +void AddVectorsOfUnequalLength(const VectorBase &signal1, + VectorBase *signal2, + VectorBase *signal1_added) { + if (signal1_added) + KALDI_ASSERT(signal2->Dim() == signal1_added->Dim()); + for (int32 po = 0; po < signal2->Dim(); po += signal1.Dim()) { + int32 block_length = signal1.Dim(); + if (signal2->Dim() - po < block_length) block_length = signal2->Dim() - po; + signal2->Range(po, block_length).AddVec(1.0, signal1.Range(0, block_length)); + if (signal1_added) + signal1_added->Range(po, block_length).CopyFromVec( + signal1.Range(0, block_length)); + } +} + +inline BaseFloat MaxAbsolute( + const VectorBase &vector) { + return std::max(std::abs(vector.Max()), std::abs(vector.Min())); +} + +inline BaseFloat ComputeEnergy( + const VectorBase &vec) { + return VecVec(vec, vec) / vec.Dim(); +} + +inline BaseFloat DbToValue(const BaseFloat &db) { + return Exp(db * Log(10.0) / 10.0); +} + +/* + Early reverberation component of the signal is composed of reflections + within 0.05 seconds of the direct path signal (assumed to be the peak of + the room impulse response). This function returns the energy in + this early reverberation component of the signal. + The input parameters to this function are the room impulse response, the signal + and their sampling frequency respectively. 
+*/ +BaseFloat ComputeEarlyReverbEnergy(const Vector &rir, const Vector &signal, + BaseFloat samp_freq) { + int32 peak_index = 0; + rir.Max(&peak_index); + KALDI_VLOG(1) << "peak index is " << peak_index; + + const float sec_before_peak = 0.001; + const float sec_after_peak = 0.05; + int32 early_rir_start_index = peak_index - sec_before_peak * samp_freq; + int32 early_rir_end_index = peak_index + sec_after_peak * samp_freq; + if (early_rir_start_index < 0) early_rir_start_index = 0; + if (early_rir_end_index > rir.Dim()) early_rir_end_index = rir.Dim(); + + int32 duration = early_rir_end_index - early_rir_start_index; + Vector early_rir(rir.Range(early_rir_start_index, duration)); + Vector early_reverb(signal); + FFTbasedBlockConvolveSignals(early_rir, &early_reverb); + + // compute the energy + return ComputeEnergy(early_reverb); +} + +/* + This is the core function to do reverberation and noise addition + on the given signal. The noise will be scaled before the addition + to match the given signal-to-noise ratio (SNR) and it will also concatenate + itself repeatedly to match the length of the signal. + The input parameters to this function are the room impulse response, + the sampling frequency, the SNR(dB), the noise and the signal respectively. 
+*/ +void DoCorruption(BaseFloat samp_freq, const Vector &rir, + Vector *noise, BaseFloat background_snr_db, + const std::vector > &foreground_noises, + int32 channel, BaseFloat foreground_snr_db, + Vector *signal, + Vector *out_clean = NULL, + Vector *out_noise = NULL, + BaseFloat min_duration = 0.1, BaseFloat search_fraction = 0.1) { + BaseFloat input_power = 0; + + if (rir.Dim() > 0) { + FFTbasedBlockConvolveSignals(rir, signal); + input_power = ComputeEarlyReverbEnergy(rir, *signal, samp_freq); + } else { + input_power = ComputeEnergy(*signal); + } + + if (out_clean) + out_clean->CopyFromVec(*signal); + + if (noise->Dim() > 0) { + BaseFloat noise_power = ComputeEnergy(*noise); + BaseFloat scale_factor = sqrt(DbToValue(-background_snr_db) + * input_power / noise_power); + noise->Scale(scale_factor); + KALDI_VLOG(1) << "Noise signal is being scaled with " << scale_factor + << " to generate output with SNR " << background_snr_db << "db\n"; + AddVectorsOfUnequalLength(*noise, signal, out_noise); + } + + KALDI_ASSERT(search_fraction <= 1.0); + if (foreground_noises.size() > 0) { + int32 t = 0; + while (t < signal->Dim()) { + // Start position to add foreground noise must be beyond the current 't' + // but not more than search_fraction * signal->Dim() + int32 start_t = t + search_fraction * signal->Dim() * RandUniform(); + int32 max_duration_possible = signal->Dim() - start_t; + + // Check if the max duration possible is less than a minimum duration. + // This is to avoid adding very short duration of noise, say 1 frame. 
+ if (max_duration_possible < min_duration * samp_freq) break; + + int32 i = RandInt(0, foreground_noises.size() - 1); + KALDI_ASSERT(channel < foreground_noises[i].NumRows()); + Vector foreground_noise(foreground_noises[i].Row(channel)); + if (max_duration_possible < foreground_noise.Dim()) { + SubVector this_foreground_noise(foreground_noise, + 0, max_duration_possible); + SubVector this_signal(*signal, + start_t, max_duration_possible); + + BaseFloat noise_power = ComputeEnergy(this_foreground_noise); + BaseFloat signal_power = ComputeEnergy(this_signal); + + BaseFloat scale_factor = sqrt(DbToValue(-foreground_snr_db) + * signal_power / noise_power); + this_signal.AddVec(scale_factor, this_foreground_noise); + + if (out_noise) { + SubVector this_out_noise(*out_noise, + start_t, max_duration_possible); + this_out_noise.AddVec(scale_factor, this_foreground_noise); + } + + break; + } else { + SubVector this_signal(*signal, + start_t, foreground_noise.Dim()); + + BaseFloat noise_power = ComputeEnergy(foreground_noise); + BaseFloat signal_power = ComputeEnergy(this_signal); + + BaseFloat scale_factor = sqrt(DbToValue(-foreground_snr_db) + * signal_power / noise_power); + this_signal.AddVec(scale_factor, foreground_noise); + + if (out_noise) { + SubVector this_out_noise(*out_noise, + start_t, foreground_noise.Dim()); + this_out_noise.AddVec(scale_factor, foreground_noise); + } + + t += foreground_noise.Dim(); + } + } + } +} + +} + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Corrupts the wave files supplied via input pipe with the specified\n" + "room-impulse response (rir_matrix) and additive noise distortions\n" + "(specified by corresponding files).\n" + "Usage: wav-reverberate [options] \n" + " e.g.: wav-reverberate --rir-file=large_roon_rir.wav clean.wav corrupted.wav\n"; + + ParseOptions po(usage); + + std::string rir_file; + std::string background_noise_file; + std::string foreground_noise_files_str; + 
std::string out_clean_file; + std::string out_noise_file; + + BaseFloat background_snr_db = 20; + BaseFloat foreground_snr_db = 20; + bool multi_channel_output = false; + int32 input_channel = 0; + int32 rir_channel = 0; + int32 noise_channel = 0; + bool normalize_output = true; + BaseFloat volume = 0; + BaseFloat rms_amplitude = 0.1; + bool normalize_by_amplitude = false, normalize_by_power = false; + int32 srand_seed = 0; + BaseFloat min_duration = 0.1, search_fraction = 0.1; + + po.Register("multi-channel-output", &multi_channel_output, + "Specifies if the output should be multi-channel or not"); + po.Register("input-wave-channel", &input_channel, + "Specifies the channel to be used from input as only a " + "single channel will be used to generate reverberated output"); + po.Register("rir-channel", &rir_channel, + "Specifies the channel of the room impulse response, " + "it will only be used when multi-channel-output is false"); + po.Register("noise-channel", &noise_channel, + "Specifies the channel of the noise file, " + "it will only be used when multi-channel-output is false"); + po.Register("background-snr-db", &background_snr_db, + "Desired SNR(dB) of background noise wrt clean signal"); + po.Register("foreground-snr-db", &foreground_snr_db, + "Desired SNR(db) of foreground noise wrt of background corrupted signal"); + po.Register("normalize-output", &normalize_output, + "If true, then after reverberating and " + "possibly adding noise, scale so that the signal " + "energy is the same as the original input signal."); + po.Register("volume", &volume, + "If nonzero, a scaling factor on the signal that is applied " + "after reverberating and possibly adding noise. " + "If you set this option to a nonzero value, it will be as" + "if you had also specified --normalize-output=false. 
" + "If you set this option to a negative value, it will be " + "ignored and instead the --signal-db option would be used."); + po.Register("rms-amplitude", &rms_amplitude, + "Desired rms after corruption. This will be used " + "only if volume is less than 0"); + po.Register("normalize-by-amplitude", &normalize_by_amplitude, + "Make the maximum amplitude in the output signal to be 95% of " + "the amplitude range possible in wave output"); + po.Register("normalize-by-power", &normalize_by_power, + "Make the amplitude such that the RMS energy of the signal " + "is rms-amplitude"); + po.Register("output-noise-file", &out_noise_file, + "Wave file to write the output noise file just before " + "adding it to the reverberated signal"); + po.Register("output-clean-file", &out_clean_file, + "Wave file to write the output clean file just before " + "adding additive noise. It may have reverberation"); + po.Register("rir-file", &rir_file, + "File with room impulse response"); + po.Register("background-noise-file", &background_noise_file, + "File with additive background noise"); + po.Register("foreground-noise-files", &foreground_noise_files_str, + "Colon separated list of foreground noise files"); + po.Register("min-duration", &min_duration, + "If the duration of signal in which we can add foreground " + "noise is smaller than this min-duration, then the noise " + "would not be added."); + po.Register("search-fraction", &search_fraction, + "The maximum separation between two foreground noise additions " + "specified as a fraction of the length of the file"); + po.Register("srand", &srand_seed, "Seed for random number generator"); + + po.Read(argc, argv); + + srand(srand_seed); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + if (multi_channel_output) { + if (rir_channel != 0 || noise_channel != 0) + KALDI_WARN << "options for --rir-channel and --noise-channel" + "are ignored as --multi-channel-output is true."; + } + + std::string input_wave_file = 
po.GetArg(1); + std::string output_wave_file = po.GetArg(2); + + KALDI_VLOG(1) << "input-wav-file: " << input_wave_file; + KALDI_VLOG(1) << "output-wav-file: " << output_wave_file; + KALDI_VLOG(1) << "rir-file: " << (!rir_file.empty() ? rir_file : "None"); + KALDI_VLOG(1) << "background-noise-file: " + << (!background_noise_file.empty() ? background_noise_file : "None"); + KALDI_VLOG(1) << "foreground-noise-files-str: " + << (!foreground_noise_files_str.empty() ? foreground_noise_files_str : "None"); + + /************************************************************************** + * Read input wave + **************************************************************************/ + + WaveData input_wave; + { + Input ki(input_wave_file); + input_wave.Read(ki.Stream()); + } + + const Matrix &input_matrix = input_wave.Data(); + BaseFloat samp_freq_input = input_wave.SampFreq(); + int32 num_samp_input = input_matrix.NumCols(), // #samples in the input + num_input_channel = input_matrix.NumRows(); // #channels in the input + KALDI_VLOG(1) << "sampling frequency of input: " << samp_freq_input + << " #samples: " << num_samp_input + << " #channel: " << num_input_channel; + KALDI_ASSERT(input_channel < num_input_channel); + + /************************************************************************** + * Read room impulse response if it exists + **************************************************************************/ + + const Matrix *rir_matrix = NULL; + + BaseFloat samp_freq_rir = samp_freq_input; + int32 num_samp_rir = 0, + num_rir_channel = 1; + + WaveData rir_wave; + if (!rir_file.empty()) { + { + Input ki(rir_file); + rir_wave.Read(ki.Stream()); + } + rir_matrix = &rir_wave.Data(); + + samp_freq_rir = rir_wave.SampFreq(); + KALDI_ASSERT(samp_freq_input == samp_freq_rir); + num_samp_rir = rir_matrix->NumCols(); + num_rir_channel = rir_matrix->NumRows(); + KALDI_VLOG(1) << "sampling frequency of rir: " << samp_freq_rir + << " #samples: " << num_samp_rir + << " #channel: " 
<< num_rir_channel; + if (!multi_channel_output) { + KALDI_ASSERT(rir_channel < num_rir_channel); + } + } else { + rir_channel = 0; + // Cannot create multichannel output without an rir-file + KALDI_ASSERT(!multi_channel_output); + } + + /************************************************************************** + * Read background noise if it is provided + **************************************************************************/ + + const Matrix *background_noise_matrix = NULL; + WaveData noise_wave; + if (!background_noise_file.empty()) { + { + Input ki(background_noise_file); + noise_wave.Read(ki.Stream()); + } + background_noise_matrix = &noise_wave.Data(); + BaseFloat samp_freq_noise = noise_wave.SampFreq(); + KALDI_ASSERT(samp_freq_input == samp_freq_noise); + int32 num_samp_noise = background_noise_matrix->NumCols(), + num_noise_channel = background_noise_matrix->NumRows(); + KALDI_VLOG(1) << "sampling frequency of noise: " << samp_freq_noise + << " #samples: " << num_samp_noise + << " #channel: " << num_noise_channel; + if (multi_channel_output) { + KALDI_ASSERT(num_rir_channel == num_noise_channel); + } else { + KALDI_ASSERT(noise_channel < num_noise_channel); + } + } + + /************************************************************************** + * Read foreground noises if it is provided + **************************************************************************/ + + std::vector > foreground_noise_matrices; + std::vector foreground_noise_files; + + if (!foreground_noise_files_str.empty()) { + SplitStringToVector(foreground_noise_files_str, ":", + true, &foreground_noise_files); + + foreground_noise_matrices.resize(foreground_noise_files.size()); + for (size_t i = 0; i < foreground_noise_files.size(); i++) { + const std::string &noise_file = foreground_noise_files[i]; + WaveData noise_wave; + { + Input ki(noise_file); + noise_wave.Read(ki.Stream()); + } + + Matrix &noise_matrix = foreground_noise_matrices[i]; + 
noise_matrix.Resize(noise_wave.Data().NumRows(), + noise_wave.Data().NumCols()); + + noise_matrix.CopyFromMat(noise_wave.Data()); + + BaseFloat samp_freq_noise = noise_wave.SampFreq(); + KALDI_ASSERT(samp_freq_input == samp_freq_noise); + int32 num_samp_noise = noise_matrix.NumCols(), + num_noise_channel = noise_matrix.NumRows(); + KALDI_VLOG(1) << "sampling frequency of noise: " << samp_freq_noise + << " #samples: " << num_samp_noise + << " #channel: " << num_noise_channel; + if (multi_channel_output) { + KALDI_ASSERT(num_rir_channel == num_noise_channel); + } else { + KALDI_ASSERT(noise_channel < num_noise_channel); + } + } + } + + /************************************************************************** + * Prepare output wave matrix along with the output clean and noise + * matrices which need to be written optionally. + **************************************************************************/ + + int32 num_output_channels = (multi_channel_output ? num_rir_channel : 1); + Matrix out_matrix(num_output_channels, num_samp_input); + + Matrix out_clean_matrix; + Matrix out_noise_matrix; + + for (int32 output_channel = 0; output_channel < num_output_channels; + output_channel++) { + Vector input(num_samp_input); + input.CopyRowFromMat(input_matrix, input_channel); + float power_before_corruption = VecVec(input, input) / input.Dim(); + + int32 this_rir_channel = (multi_channel_output ? + output_channel : rir_channel); + Vector rir(num_samp_rir); + + if (!rir_file.empty()) { + // Read a particular channel of room impulse response and convert it + // to a floating point number + rir.CopyRowFromMat(*rir_matrix, this_rir_channel); + rir.Scale(1.0 / (1 << 15)); + } + + Vector background_noise; + + if (!background_noise_file.empty()) { + background_noise.Resize(background_noise_matrix->NumCols()); + int32 this_noise_channel = (multi_channel_output ? 
+ output_channel : noise_channel); + background_noise.CopyRowFromMat(*background_noise_matrix, + this_noise_channel); + } + + Vector clean_signal(input.Dim()); + Vector noise_signal(input.Dim()); + + DoCorruption(samp_freq_input, rir, &background_noise, background_snr_db, + foreground_noise_matrices, + multi_channel_output ? output_channel : noise_channel, + foreground_snr_db, &input, + (!out_clean_file.empty() ? &clean_signal : NULL), + (!out_noise_file.empty() ? &noise_signal : NULL), + min_duration, search_fraction); + + BaseFloat power_after_corruption = ComputeEnergy(input); + + if (volume > 0) { + input.Scale(volume); + if (!out_clean_file.empty()) + clean_signal.Scale(volume); + if (!background_noise_file.empty()) + background_noise.Scale(volume); + } else if (volume < 0) { + BaseFloat scale; + + if (normalize_by_amplitude) { + BaseFloat max = MaxAbsolute(input); + + scale = Exp( Log(rms_amplitude) // signal_db to amplitude + - Log(max) // actual max amplitude + + 15.0 * Log(2.0) // * 2^15 + + Log(0.95) ); // Allow only 0.95 of max amplitude possible + } else if (normalize_by_power) { + scale = Exp( Log(rms_amplitude) // rms amplitude + - 0.5 * Log(power_before_corruption) // clean rms amplitude + + 15.0 * Log(2.0)); // * 2^15 + } + + input.Scale(scale); + + if (!out_clean_file.empty()) + clean_signal.Scale(scale); + if (!background_noise_file.empty()) + noise_signal.Scale(scale); + } else if (normalize_output) + input.Scale(sqrt(power_before_corruption / power_after_corruption)); + + out_matrix.CopyRowFromVec(input, output_channel); + + if (!out_clean_file.empty()) { + if (output_channel == 0) + out_clean_matrix.Resize(out_matrix.NumRows(), out_matrix.NumCols()); + out_clean_matrix.CopyRowFromVec(clean_signal, output_channel); + } + + if (!out_noise_file.empty()) { + if (output_channel == 0) + out_noise_matrix.Resize(out_matrix.NumRows(), out_matrix.NumCols()); + out_noise_matrix.CopyRowFromVec(noise_signal, output_channel); + } + } + + WaveData 
out_wave(samp_freq_input, out_matrix); + Output ko(output_wave_file, false); + out_wave.Write(ko.Stream()); + + if (!out_clean_file.empty()) { + WaveData out_clean_wave(samp_freq_input, out_clean_matrix); + Output ko(out_clean_file, false); + out_clean_wave.Write(ko.Stream()); + } + + if (!out_noise_file.empty()) { + WaveData out_noise_wave(samp_freq_input, out_noise_matrix); + Output ko(out_noise_file, false); + out_noise_wave.Write(ko.Stream()); + } + + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/featbin/extract-column.cc b/src/featbin/extract-column.cc new file mode 100644 index 00000000000..2bbf6b17235 --- /dev/null +++ b/src/featbin/extract-column.cc @@ -0,0 +1,82 @@ +// featbin/extract-column.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace std; + + const char *usage = + "Extract a column out of a matrix. \n" + "This is most useful to extract log-energies \n" + "from feature files\n" + "\n" + "Usage: extract-column [options] --column-index= \n" + " e.g. 
extract-column ark:feats-in.ark ark:energies.ark\n" + "See also: select-feats, subset-feats, subsample-feats, extract-rows\n"; + + ParseOptions po(usage); + + int32 column_index = 0; + + po.Register("column-index", &column_index, + "Index of column to extract"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + string feat_rspecifier = po.GetArg(1); + string vector_wspecifier = po.GetArg(2); + + SequentialBaseFloatMatrixReader reader(feat_rspecifier); + BaseFloatVectorWriter writer(vector_wspecifier); + + int32 num_done = 0, num_err = 0; + + string line; + + for (; !reader.Done(); reader.Next(), num_done++) { + const Matrix& feats(reader.Value()); + Vector col(feats.NumRows()); + if (column_index >= feats.NumCols()) { + KALDI_ERR << "Column index " << column_index << " is " + << "not less than number of columns " << feats.NumCols(); + } + col.CopyColFromMat(feats, column_index); + writer.Write(reader.Key(), col); + } + + KALDI_LOG << "Processed " << num_done << " segments successfully; " + << "errors on " << num_err; + + return (num_done > 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/featbin/extract-feature-segments.cc b/src/featbin/extract-feature-segments.cc index 93f599feb3a..15ce87255b0 100644 --- a/src/featbin/extract-feature-segments.cc +++ b/src/featbin/extract-feature-segments.cc @@ -3,6 +3,7 @@ // Copyright 2009-2011 Microsoft Corporation; Govivace Inc. // 2012-2013 Mirko Hannemann; Arnab Ghoshal // 2015 Tanel Alumae +// 2015 Vimal Manohar // See ../../COPYING for clarification regarding multiple authors // @@ -47,19 +48,19 @@ int main(int argc, char *argv[]) { ParseOptions po(usage); BaseFloat min_segment_length = 0.1, // Minimum segment length in seconds. 
- max_overshoot = 0.0; // max time by which last segment can overshoot + max_overshoot = 0.0; // max time by which last segment can overshoot int32 frame_shift = 10; int32 frame_length = 25; bool snip_edges = true; // Register the options po.Register("min-segment-length", &min_segment_length, - "Minimum segment length in seconds (reject shorter segments)"); + "Minimum segment length in seconds (reject shorter segments)"); po.Register("frame-length", &frame_length, "Frame length in milliseconds"); po.Register("frame-shift", &frame_shift, "Frame shift in milliseconds"); po.Register("max-overshoot", &max_overshoot, - "End segments overshooting by less (in seconds) are truncated," - " else rejected."); + "End segments overshooting by less (in seconds) are truncated," + " else rejected."); po.Register("snip-edges", &snip_edges, "If true, n_frames frames will be snipped from the end of each " "extracted feature matrix, " @@ -72,24 +73,21 @@ int main(int argc, char *argv[]) { // OPTION PARSING ... // parse options (+filling the registered variables) po.Read(argc, argv); - // number of arguments should be 3 - // (scriptfile, segments file and outputwav write mode) if (po.NumArgs() != 3) { po.PrintUsage(); exit(1); } std::string rspecifier = po.GetArg(1); // get script file/feature archive - std::string segments_rxfilename = po.GetArg(2); // get segment file + std::string segments_rspecifier = po.GetArg(2); // get segment file std::string wspecifier = po.GetArg(3); // get written archive name BaseFloatMatrixWriter feat_writer(wspecifier); + SequentialUtteranceSegmentReader segments_reader(segments_rspecifier); RandomAccessBaseFloatMatrixReader feat_reader(rspecifier); - Input ki(segments_rxfilename); // no binary argment: never binary. 
- - int32 num_lines = 0, num_success = 0; + int32 num_err = 0, num_done = 0; int32 snip_length = 0; if (snip_edges) { @@ -97,70 +95,24 @@ int main(int argc, char *argv[]) { 1.0 * (frame_length - frame_shift) / frame_shift)); } - std::string line; - /* read each line from segments file */ - while (std::getline(ki.Stream(), line)) { - num_lines++; - std::vector split_line; - // Split the line by space or tab and check the number of fields in each - // line. There must be 4 fields--segment name , reacording wav file name, - // start time, end time; 5th field (channel info) is optional. - SplitStringToVector(line, " \t\r", true, &split_line); - if (split_line.size() != 4 && split_line.size() != 5) { - KALDI_WARN << "Invalid line in segments file: " << line; - continue; - } - std::string segment = split_line[0], - utterance = split_line[1], - start_str = split_line[2], - end_str = split_line[3]; - - // Convert the start time and endtime to real from string. Segment is - // ignored if start or end time cannot be converted to real. - double start, end; - if (!ConvertStringToReal(start_str, &start)) { - KALDI_WARN << "Invalid line in segments file [bad start]: " << line; - continue; - } - if (!ConvertStringToReal(end_str, &end)) { - KALDI_WARN << "Invalid line in segments file [bad end]: " << line; - continue; - } + for (; !segments_reader.Done(); segments_reader.Next()) { + const std::string &seg_id = segments_reader.Key(); + const UtteranceSegment &segment = segments_reader.Value(); - // start time must not be negative; start time must not be greater than - // end time, except if end time is -1 - if (start < 0 || end <= 0 || start >= end) { - KALDI_WARN << "Invalid line in segments file " - "[empty or invalid segment]: " - << line; + if (!feat_reader.HasKey(segment.reco_id)) { + KALDI_WARN << "Did not find features for utterance " << segment.reco_id + << ", skipping segment " << seg_id; + num_err++; continue; } - int32 channel = -1; // means channel info is unspecified. 
- // if each line has 5 elements then 5th element must be channel identifier - if (split_line.size() == 5) { - if (!ConvertStringToInteger(split_line[4], &channel) || channel < 0) { - KALDI_WARN<< "Invalid line in segments file [bad channel]: " << line; - continue; - } - } - - /* check whether a segment start time and end time exists in utterance - * if fails , skips the segment. - */ - if (!feat_reader.HasKey(utterance)) { - KALDI_WARN << "Did not find features for utterance " << utterance - << ", skipping segment " << segment; - continue; - } - const Matrix &feats = feat_reader.Value(utterance); - // total number of samples present in wav data + const Matrix &feats = feat_reader.Value(segment.reco_id); + // total number of samples present in features int32 num_samp = feats.NumRows(); - // total number of channels present in wav file - int32 num_chan = feats.NumCols(); + int32 dim= feats.NumCols(); // Convert start & end times of the segment to corresponding sample number - int32 start_samp = static_cast(round( - (start * 1000.0 / frame_shift))); - int32 end_samp = static_cast(round(end * 1000.0 / frame_shift)); + int32 start_samp = static_cast(( + (segment.start_time * 1000.0 / frame_shift))); + int32 end_samp = static_cast((segment.end_time * 1000.0 / frame_shift + 0.0495)); if (snip_edges) { // snip the edge at the end of the segment (usually 2 frames), @@ -172,11 +124,18 @@ int main(int argc, char *argv[]) { */ if (start_samp < 0 || start_samp >= num_samp) { KALDI_WARN << "Start sample out of range " << start_samp - << " [length:] " << num_samp << "x" << num_chan - << ", skipping segment " << segment; + << " [length:] " << num_samp << "x" << dim + << ", skipping segment " << seg_id; + num_err++; continue; } + if (end_samp < start_samp) { + KALDI_WARN << "End sample out of range " << end_samp + << " < start sample " << start_samp + << "; skipping segment " << seg_id; + } + /* end sample must be less than total number samples * otherwise skip the segment */ @@ 
-185,9 +144,9 @@ int main(int argc, char *argv[]) { + static_cast( round(max_overshoot * 1000.0 / frame_shift))) { KALDI_WARN<< "End sample too far out of range " << end_samp - << " [length:] " << num_samp << "x" << num_chan - << ", skipping segment " - << segment; + << " [length:] " << num_samp << "x" << dim + << ", skipping segment " << seg_id; + num_err++; continue; } end_samp = num_samp; // for small differences, just truncate. @@ -200,21 +159,22 @@ int main(int argc, char *argv[]) { <= start_samp + static_cast(round( (min_segment_length * 1000.0 / frame_shift)))) { - KALDI_WARN<< "Segment " << segment << " too short, skipping it."; + KALDI_WARN<< "Segment " << seg_id << " too short, skipping it."; + num_err++; continue; } SubMatrix segment_matrix(feats, start_samp, - end_samp-start_samp, 0, num_chan); + end_samp-start_samp, 0, dim); Matrix outmatrix(segment_matrix); // write segment in feature archive. - feat_writer.Write(segment, outmatrix); - num_success++; + feat_writer.Write(seg_id, outmatrix); + num_done++; } - KALDI_LOG << "Successfully processed " << num_success << " lines out of " - << num_lines << " in the segments file. "; + KALDI_LOG << "Successfully processed " << num_done << " segments; failed " + << num_err << " segments."; /* prints number of segments processed */ - if (num_success == 0) return -1; + if (num_done == 0) return -1; return 0; } catch(const std::exception &e) { std::cerr << e.what(); diff --git a/src/featbin/extract-vector-segments.cc b/src/featbin/extract-vector-segments.cc new file mode 100644 index 00000000000..0ee6374b6f2 --- /dev/null +++ b/src/featbin/extract-vector-segments.cc @@ -0,0 +1,165 @@ +// featbin/extract-vector-segments.cc + +// Copyright 2009-2011 Microsoft Corporation; Govivace Inc. 
+// 2012-2013 Mirko Hannemann; Arnab Ghoshal +// 2015 Tanel Alumae +// 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Extract vectors corresponding to segments from whole vectors\n" + "Usage: extract-vector-segments [options...] \n"; + + // construct all the global objects + ParseOptions po(usage); + + BaseFloat min_segment_length = 0.1, // Minimum segment length in seconds. 
+ max_overshoot = 0.0; // max time by which last segment can overshoot + int32 frame_shift = 10; + int32 frame_length = 25; + bool snip_edges = true; + + // Register the options + po.Register("min-segment-length", &min_segment_length, + "Minimum segment length in seconds (reject shorter segments)"); + po.Register("frame-length", &frame_length, "Frame length in milliseconds"); + po.Register("frame-shift", &frame_shift, "Frame shift in milliseconds"); + po.Register("max-overshoot", &max_overshoot, + "End segments overshooting by less (in seconds) are truncated," + " else rejected."); + po.Register("snip-edges", &snip_edges, + "If true, n_frames frames will be snipped from the beginning of each " + "extracted feature matrix, " + "where n_frames = ceil((frame_length - frame_shift) / frame_shift), " + "except for the segments at the beginning of a file, where " + "the snipping is done from the end. " + "This ensures that only the feature vectors that " + "completely fit in the segment are extracted. 
" + "This makes the extracted segment lengths match the lengths of the " + "features that have been extracted from already segmented audio."); + + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string rspecifier = po.GetArg(1); // get script file/vector archive + std::string segments_rspecifier = po.GetArg(2);// get segment file + std::string wspecifier = po.GetArg(3); // get written archive name + + BaseFloatVectorWriter vector_writer(wspecifier); + + SequentialUtteranceSegmentReader segments_reader(segments_rspecifier); + RandomAccessBaseFloatVectorReader vector_reader(rspecifier); + + int32 num_done = 0, num_err = 0; + + int32 snip_length = 0; + if (snip_edges) { + snip_length = static_cast(ceil( + 1.0 * (frame_length - frame_shift) / frame_shift)); + } + + for (; !segments_reader.Done(); segments_reader.Next()) { + const std::string &seg_id = segments_reader.Key(); + const UtteranceSegment &segment = segments_reader.Value(); + + if (!vector_reader.HasKey(segment.reco_id)) { + KALDI_WARN << "Did not find vector for utterance " << segment.reco_id + << ", skipping segment " << segment.reco_id; + continue; + } + const Vector &vector = vector_reader.Value(segment.reco_id); + + // total number of samples present in features + int32 num_samp = vector.Dim(); + // Convert start & end times of the segment to corresponding sample number + int32 start_samp = static_cast(( + (segment.start_time * 1000.0 / frame_shift))); + int32 end_samp = static_cast((segment.end_time * 1000.0 / frame_shift + 0.0495)); + + if (snip_edges) { + // snip the edge at the end of the segment (usually 2 frames), + end_samp -= snip_length; + } + + /* start sample must be less than total number of samples + * otherwise skip the segment + */ + if (start_samp < 0 || start_samp >= num_samp) { + KALDI_WARN << "Start sample out of range " << start_samp << " [length:] " + << num_samp << ", skipping segment " << seg_id; + num_err++; + continue; + } + + /* end 
sample must be less than total number samples + * otherwise skip the segment + */ + if (end_samp > num_samp) { + if (end_samp > + num_samp + static_cast<int32>(max_overshoot / frame_shift)) { + KALDI_WARN << "End sample too far out of range " << end_samp + << " [overshooted length:] " + << num_samp + static_cast<int32>(max_overshoot / frame_shift) + << ", skipping segment " << seg_id; + num_err++; + continue; + } + end_samp = num_samp; // for small differences, just truncate. + } + + /* check whether the segment size is less than minimum segment length(default 0.1 sec) + * if yes, skip the segment + */ + if (end_samp + <= start_samp + + static_cast<int32>(round( + (min_segment_length * 1000.0 / frame_shift)))) { + KALDI_WARN<< "Segment " << seg_id << " too short, skipping it."; + num_err++; + continue; + } + + SubVector<BaseFloat> segment_vector(vector, start_samp, + end_samp-start_samp); + Vector<BaseFloat> out_vector(segment_vector); + vector_writer.Write(seg_id, out_vector); // write segment in feature archive. + num_done++; + } + KALDI_LOG << "Successfully processed " << num_done << " segments; failed " + << num_err << " segments."; + /* prints number of segments processed */ + if (num_done == 0) return -1; + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/featbin/vector-to-feat.cc b/src/featbin/vector-to-feat.cc new file mode 100644 index 00000000000..5e98cf95a1c --- /dev/null +++ b/src/featbin/vector-to-feat.cc @@ -0,0 +1,99 @@ +// featbin/vector-to-feat.cc + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Convert a vector into a single feature so that it can be appended \n" + "to other feature matrices\n" + "Usage: vector-to-feats \n" + "or: vector-to-feats \n" + "e.g.: vector-to-feats scp:weights.scp ark:weight_feats.ark\n" + " or: vector-to-feats weight_vec feat_mat\n" + "See also: copy-feats, copy-matrix, paste-feats, \n" + "subsample-feats, splice-feats\n"; + + ParseOptions po(usage); + bool compress = false, binary = true; + + po.Register("binary", &binary, "Binary-mode output (not relevant if writing " + "to archive)"); + po.Register("compress", &compress, "If true, write output in compressed form" + "(only currently supported for wxfilename, i.e. 
archive/script," + "output)"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + int32 num_done = 0; + + if (ClassifyRspecifier(po.GetArg(1), NULL, NULL) != kNoRspecifier) { + std::string vector_rspecifier = po.GetArg(1); + std::string feature_wspecifier = po.GetArg(2); + + SequentialBaseFloatVectorReader vector_reader(vector_rspecifier); + BaseFloatMatrixWriter feat_writer(feature_wspecifier); + CompressedMatrixWriter compressed_feat_writer(feature_wspecifier); + + for (; !vector_reader.Done(); vector_reader.Next(), ++num_done) { + const Vector &vec = vector_reader.Value(); + Matrix feat(vec.Dim(), 1); + feat.CopyColFromVec(vec, 0); + + if (!compress) + feat_writer.Write(vector_reader.Key(), feat); + else + compressed_feat_writer.Write(vector_reader.Key(), CompressedMatrix(feat)); + } + KALDI_LOG << "Converted " << num_done << " vectors into features"; + return (num_done != 0 ? 0 : 1); + } + + KALDI_ASSERT(!compress && "Compression not yet supported for single files"); + + std::string vector_rxfilename = po.GetArg(1), + feature_wxfilename = po.GetArg(2); + + Vector vec; + ReadKaldiObject(vector_rxfilename, &vec); + + Matrix feat(vec.Dim(), 1); + feat.CopyColFromVec(vec, 0); + + WriteKaldiObject(feat, feature_wxfilename, binary); + + KALDI_LOG << "Converted vector " << PrintableRxfilename(vector_rxfilename) + << " to " << PrintableWxfilename(feature_wxfilename); + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/gmmbin/Makefile b/src/gmmbin/Makefile index e5504891aab..7d46cb8284d 100644 --- a/src/gmmbin/Makefile +++ b/src/gmmbin/Makefile @@ -28,7 +28,8 @@ BINFILES = gmm-init-mono gmm-est gmm-acc-stats-ali gmm-align \ gmm-est-fmllr-raw gmm-est-fmllr-raw-gpost gmm-global-init-from-feats \ gmm-global-info gmm-latgen-faster-regtree-fmllr gmm-est-fmllr-global \ gmm-acc-mllt-global gmm-transform-means-global gmm-global-get-post \ - gmm-global-gselect-to-post 
gmm-global-est-lvtln-trans + gmm-global-gselect-to-post gmm-global-est-lvtln-trans \ + gmm-init-pdf-from-global gmm-extract-pdf OBJFILES = diff --git a/src/gmmbin/gmm-copy.cc b/src/gmmbin/gmm-copy.cc index 0b33bc6d67f..10a52b57f3e 100644 --- a/src/gmmbin/gmm-copy.cc +++ b/src/gmmbin/gmm-copy.cc @@ -36,12 +36,14 @@ int main(int argc, char *argv[]) { bool binary_write = true, copy_am = true, - copy_tm = true; + copy_tm = true, + write_tm = true; ParseOptions po(usage); po.Register("binary", &binary_write, "Write output in binary mode"); po.Register("copy-am", ©_am, "Copy the acoustic model (AmDiagGmm object)"); po.Register("copy-tm", ©_tm, "Copy the transition model"); + po.Register("write-tm", &write_tm, "Write the transition model"); po.Read(argc, argv); @@ -66,7 +68,7 @@ int main(int argc, char *argv[]) { { Output ko(model_out_filename, binary_write); - if (copy_tm) + if (copy_tm && write_tm) trans_model.Write(ko.Stream(), binary_write); if (copy_am) am_gmm.Write(ko.Stream(), binary_write); diff --git a/src/gmmbin/gmm-extract-pdf.cc b/src/gmmbin/gmm-extract-pdf.cc new file mode 100644 index 00000000000..284de11d6bb --- /dev/null +++ b/src/gmmbin/gmm-extract-pdf.cc @@ -0,0 +1,85 @@ +// gmmbin/gmm-extract-pdf.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "gmm/am-diag-gmm.h" +#include "hmm/transition-model.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + + const char *usage = + "Extract a pdf from a GMM based model \n" + "Usage: gmm-extract-pdf [options] \n" + "e.g.:\n" + " gmm-extract-pdf 1.mdl 0 0.gmm \n"; + + bool binary_write = true; + + ParseOptions po(usage); + po.Register("binary", &binary_write, "Write output in binary mode"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string model_in_filename = po.GetArg(1), + gmm_out_filename = po.GetArg(3); + + int32 pdf_id; + if (!ConvertStringToInteger(po.GetArg(2), &pdf_id)) { + KALDI_ERR << "Unable to convert argument 2 (" << po.GetArg(2) + << ") to integer"; + } + + AmDiagGmm am_gmm; + { + bool binary_read; + Input ki(model_in_filename, &binary_read); + TransitionModel trans_model; + trans_model.Read(ki.Stream(), binary_read); + am_gmm.Read(ki.Stream(), binary_read); + } + + if (pdf_id >= am_gmm.NumPdfs() || pdf_id < 0) { + KALDI_ERR << "pdf-id " << pdf_id << " is not in the " + << "expected range [0-" << am_gmm.NumPdfs() - 1 << "]"; + } + + { + Output ko(gmm_out_filename, binary_write); + const DiagGmm &gmm = am_gmm.GetPdf(pdf_id); + gmm.Write(ko.Stream(), binary_write); + } + + KALDI_LOG << "Written gmm to " << gmm_out_filename; + } catch(const std::exception &e) { + std::cerr << e.what() << '\n'; + return -1; + } +} + + + diff --git a/src/gmmbin/gmm-global-init-from-feats.cc b/src/gmmbin/gmm-global-init-from-feats.cc index e83486dc3fb..2b781da31fe 100644 --- a/src/gmmbin/gmm-global-init-from-feats.cc +++ b/src/gmmbin/gmm-global-init-from-feats.cc @@ -29,7 +29,7 @@ namespace kaldi { // We initialize the GMM parameters by setting the variance to the global // variance of the features, and the means to distinct randomly chosen frames. 
-void InitGmmFromRandomFrames(const Matrix<BaseFloat> &feats, DiagGmm *gmm) { +void InitGmmFromRandomFrames(const Matrix<BaseFloat> &feats, BaseFloat var_floor, DiagGmm *gmm) { int32 num_gauss = gmm->NumGauss(), num_frames = feats.NumRows(), dim = feats.NumCols(); KALDI_ASSERT(num_frames >= 10 * num_gauss && "Too few frames to train on"); @@ -39,8 +39,10 @@ void InitGmmFromRandomFrames(const Matrix<BaseFloat> &feats, DiagGmm *gmm) { var.AddVec2(1.0 / num_frames, feats.Row(i)); } var.AddVec2(-1.0, mean); - if (var.Max() <= 0.0) - KALDI_ERR << "Features do not have positive variance " << var; + var.ApplyFloor(var_floor); + + //if (var.Min() <= 0.0) + // KALDI_ERR << "Features do not have positive variance " << var; DiagGmmNormal gmm_normal(*gmm); @@ -183,7 +185,7 @@ int main(int argc, char *argv[]) { KALDI_LOG << "Initializing GMM means from random frames to " << num_gauss_init << " Gaussians."; - InitGmmFromRandomFrames(feats, &gmm); + InitGmmFromRandomFrames(feats, gmm_opts.min_variance, &gmm); // we'll increase the #Gaussians by splitting, // till halfway through training. diff --git a/src/gmmbin/gmm-init-pdf-from-global.cc b/src/gmmbin/gmm-init-pdf-from-global.cc new file mode 100644 index 00000000000..f7fb7190e33 --- /dev/null +++ b/src/gmmbin/gmm-init-pdf-from-global.cc @@ -0,0 +1,94 @@ +// gmmbin/gmm-init-pdf-from-global.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "gmm/am-diag-gmm.h" +#include "hmm/transition-model.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + + const char *usage = + "Copy GMM based model and replace a pdf with a new GMM\n" + "Usage: gmm-init-pdf-from-global [options] \n" + "e.g.:\n" + " gmm-init-pdf-from-global 1.mdl new.gmm 1 1.new.mdl \n"; + + + bool binary_write = true; + + ParseOptions po(usage); + po.Register("binary", &binary_write, "Write output in binary mode"); + + po.Read(argc, argv); + + if (po.NumArgs() != 4) { + po.PrintUsage(); + exit(1); + } + + std::string model_in_filename = po.GetArg(1), + gmm_in_filename = po.GetArg(3), + model_out_filename = po.GetArg(4); + + int32 pdf_id; + if (!ConvertStringToInteger(po.GetArg(2), &pdf_id)) { + KALDI_ERR << "Unable to convert argument 2 (" << po.GetArg(2) + << ") to integer"; + } + + AmDiagGmm am_gmm; + TransitionModel trans_model; + { + bool binary_read; + Input ki(model_in_filename, &binary_read); + trans_model.Read(ki.Stream(), binary_read); + am_gmm.Read(ki.Stream(), binary_read); + } + + if (pdf_id >= am_gmm.NumPdfs() || pdf_id < 0) { + KALDI_ERR << "pdf-id " << pdf_id << " is not in the " + << "expected range [0-" << am_gmm.NumPdfs() - 1 << "]"; + } + + { + bool binary_read; + Input ki(gmm_in_filename, &binary_read); + DiagGmm gmm; + gmm.Read(ki.Stream(), binary_read); + am_gmm.GetPdf(pdf_id).CopyFromDiagGmm(gmm); + } + + { + Output ko(model_out_filename, binary_write); + trans_model.Write(ko.Stream(), binary_write); + am_gmm.Write(ko.Stream(), binary_write); + } + + KALDI_LOG << "Written model to " << model_out_filename; + } catch(const std::exception &e) { + std::cerr << e.what() << '\n'; + return -1; + } +} + + diff --git a/src/ivectorbin/Makefile b/src/ivectorbin/Makefile index 
d05efc3093a..a6043f2f969 100644 --- a/src/ivectorbin/Makefile +++ b/src/ivectorbin/Makefile @@ -15,7 +15,10 @@ BINFILES = ivector-extractor-init ivector-extractor-acc-stats \ ivector-subtract-global-mean ivector-plda-scoring \ logistic-regression-train logistic-regression-eval \ logistic-regression-copy create-split-from-vad \ - ivector-extract-online ivector-adapt-plda + ivector-extract-online ivector-adapt-plda \ + select-top-frames speaker-diarization ivector-combine-to-recording \ + ivector-split-to-segments \ + logistic-regression-train-on-feats logistic-regression-eval-on-feats OBJFILES = @@ -26,6 +29,6 @@ TESTFILES = ADDLIBS = ../ivector/kaldi-ivector.a ../hmm/kaldi-hmm.a ../gmm/kaldi-gmm.a \ ../tree/kaldi-tree.a ../thread/kaldi-thread.a ../matrix/kaldi-matrix.a \ - ../util/kaldi-util.a ../base/kaldi-base.a + ../util/kaldi-util.a ../base/kaldi-base.a ../segmenter/kaldi-segmenter.a include ../makefiles/default_rules.mk diff --git a/src/ivectorbin/ivector-combine-to-recording.cc b/src/ivectorbin/ivector-combine-to-recording.cc new file mode 100644 index 00000000000..65e36f4f882 --- /dev/null +++ b/src/ivectorbin/ivector-combine-to-recording.cc @@ -0,0 +1,126 @@ +// ivectorbin/ivector-combine-to-recording.cc + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "base/kaldi-common.h" +#include "base/kaldi-extra-types.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + using namespace kaldi; + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + try { + const char *usage = + "Combine iVectors for utterances into iVector matrix for recording " + "along with a segmentation corresponding to the utterances\n" + "Usage: ivector-combine-to-recording \n" + "e.g.: \n" + " ivector-combine-to-recording data/dev_diarized/split10/1/segments ark:exp/ivectors_dev_reco/ivectors_utt.1.ark ark:exp/ivectors_dev_reco/reco_segmentation.1.ark ark:exp/ivectors_dev_reco/ivectors_seg.1.ark\n"; + + ParseOptions po(usage); + BaseFloat frame_shift = 0.01; + + po.Register("frame-shift", &frame_shift, + "Frame shift in second"); + + po.Read(argc, argv); + + if (po.NumArgs() != 5) { + po.PrintUsage(); + exit(1); + } + + std::string reco2utt_rxfilename = po.GetArg(1); + std::string segments_rspecifier = po.GetArg(2), + utt_ivectors_rspecifier = po.GetArg(3), + segmentation_wspecifier = po.GetArg(4), + ivectors_wspecifier = po.GetArg(5); + + SequentialTokenVectorReader reco2utt_reader(reco2utt_rxfilename); + RandomAccessUtteranceSegmentReader segment_reader(segments_rspecifier); + RandomAccessBaseFloatVectorReader ivector_reader(utt_ivectors_rspecifier); + segmenter::SegmentationWriter seg_writer(segmentation_wspecifier); + BaseFloatMatrixWriter ivector_writer(ivectors_wspecifier); + + int32 num_reco = 0, num_success = 0, num_err = 0; + + BaseFloat ivector_dim = -1; + for (; !reco2utt_reader.Done(); reco2utt_reader.Next(), num_reco++) { + std::string reco = reco2utt_reader.Key(); + const std::vector &uttlist = reco2utt_reader.Value(); + + bool missing_utt = false; + segmenter::Segmentation seg; + + for (std::vector::const_iterator it = uttlist.begin(); + it != uttlist.end(); ++it) { + + if (!segment_reader.HasKey(*it)) { + KALDI_WARN << "Could not find utterance " << *it << " in 
segments " + << "file " << segments_rspecifier; + missing_utt = true; + } + if (!ivector_reader.HasKey(*it)) { + KALDI_WARN << "Could not find iVector for utterance " << *it; + missing_utt = true; + } + + if (missing_utt) { + num_err++; + break; + } + + const UtteranceSegment &segment = segment_reader.Value(*it); + const Vector &ivector = ivector_reader.Value(*it); + + if (ivector_dim == -1) + ivector_dim = ivector.Dim(); + + seg.Emplace(std::round(segment.start_time / frame_shift), + std::round(segment.end_time / frame_shift), 1, + ivector); + } + + if (missing_utt) continue; + seg.Sort(); + seg_writer.Write(reco, seg); + + Matrix ivector_out(seg.Dim(), ivector_dim); + + size_t i = 0; + for (segmenter::SegmentList::const_iterator it = seg.Begin(); + it != seg.End(); ++it, i++) { + ivector_out.CopyRowFromVec(it->VectorValue(), i); + } + ivector_writer.Write(reco, ivector_out); + num_success++; + } + + KALDI_LOG << "Combined iVectors for " << num_success + << " out of " << num_reco << " recordings; " + << " errors in " << num_err << " recordings"; + return (num_success > 0 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/ivectorbin/ivector-extract.cc b/src/ivectorbin/ivector-extract.cc index 220677d9af0..aa191e24875 100644 --- a/src/ivectorbin/ivector-extract.cc +++ b/src/ivectorbin/ivector-extract.cc @@ -33,7 +33,7 @@ class IvectorExtractTask { public: IvectorExtractTask(const IvectorExtractor &extractor, std::string utt, - const Matrix &feats, + const MatrixBase &feats, const Posterior &posterior, BaseFloatVectorWriter *writer, double *tot_auxf_change): @@ -95,7 +95,8 @@ int32 RunPerSpeaker(const std::string &ivector_extractor_rxfilename, const std::string &spk2utt_rspecifier, const std::string &feature_rspecifier, const std::string &posterior_rspecifier, - const std::string &ivector_wspecifier) { + const std::string &ivector_wspecifier, + int32 length_tolerance = 0) { IvectorExtractor extractor; ReadKaldiObject(ivector_extractor_rxfilename, &extractor); SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); @@ -131,16 +132,24 @@ int32 RunPerSpeaker(const std::string &ivector_extractor_rxfilename, continue; } Posterior posterior = posterior_reader.Value(utt); - if (feats.NumRows() != posterior.size()) { + if (std::abs(feats.NumRows() - static_cast(posterior.size())) > length_tolerance) { KALDI_WARN << "Posterior has wrong size " << posterior.size() << " vs. 
feats " << feats.NumRows() << " for " << utt; num_utt_err++; continue; } + + for (int32 j = 0; j < static_cast(posterior.size()) - feats.NumRows(); j++) { + posterior.pop_back(); + } ScalePosterior(opts.acoustic_weight, &posterior); num_utt_done++; - utt_stats.AccStats(feats, posterior); + + if (posterior.size() == feats.NumRows()) + utt_stats.AccStats(feats, posterior); + else + utt_stats.AccStats(feats.Range(0, posterior.size(), 0, feats.NumCols()), posterior); } if (utt_stats.NumFrames() == 0.0) { @@ -225,6 +234,7 @@ int main(int argc, char *argv[]) { bool compute_objf_change = true; IvectorEstimationOptions opts; std::string spk2utt_rspecifier; + int32 length_tolerance = 0; TaskSequencerConfig sequencer_config; po.Register("compute-objf-change", &compute_objf_change, "If true, compute the change in objective function from using " @@ -236,6 +246,9 @@ int main(int argc, char *argv[]) { "is not the normal way iVectors are obtained for speaker-id. " "This option will cause the program to ignore the --num-threads " "option."); + po.Register("length-tolerance", &length_tolerance, + "Tolerance on difference in number of frames in posterior " + "and weights."); opts.Register(&po); sequencer_config.Register(&po); @@ -279,13 +292,17 @@ int main(int argc, char *argv[]) { const Matrix &mat = feature_reader.Value(); Posterior posterior = posterior_reader.Value(utt); - if (static_cast(posterior.size()) != mat.NumRows()) { + if (std::abs(mat.NumRows() - static_cast(posterior.size())) > length_tolerance) { KALDI_WARN << "Size mismatch between posterior " << posterior.size() << " and features " << mat.NumRows() << " for utterance " << utt; num_err++; continue; } + + for (int32 j = 0; j < static_cast(posterior.size()) - mat.NumRows(); j++) { + posterior.pop_back(); + } double *auxf_ptr = (compute_objf_change ? &tot_auxf_change : NULL ); @@ -302,7 +319,11 @@ int main(int argc, char *argv[]) { &posterior); // note: now, this_t == sum of posteriors. 
- sequencer.Run(new IvectorExtractTask(extractor, utt, mat, posterior, + if (posterior.size() == mat.NumRows()) + sequencer.Run(new IvectorExtractTask(extractor, utt, mat, posterior, + &ivector_writer, auxf_ptr)); + else + sequencer.Run(new IvectorExtractTask(extractor, utt, mat.Range(0, posterior.size(), 0, mat.NumCols()), posterior, &ivector_writer, auxf_ptr)); tot_t += this_t; @@ -328,7 +349,8 @@ int main(int argc, char *argv[]) { spk2utt_rspecifier, feature_rspecifier, posterior_rspecifier, - ivectors_wspecifier); + ivectors_wspecifier, + length_tolerance); } } catch(const std::exception &e) { std::cerr << e.what(); diff --git a/src/ivectorbin/ivector-split-to-segments.cc b/src/ivectorbin/ivector-split-to-segments.cc new file mode 100644 index 00000000000..80a4750627e --- /dev/null +++ b/src/ivectorbin/ivector-split-to-segments.cc @@ -0,0 +1,172 @@ +// ivectorbin/ivector-split-to-segments.cc + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "base/kaldi-common.h" +#include "base/kaldi-extra-types.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + using namespace kaldi; + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + try { + const char *usage = + "Split iVectors for recording into new segments\n" + "Usage: ivector-split-to-segments [options] " + " \n" + "e.g.: \n" + " ivector-split-to-segments " + "ark:exp/ivectors_dev_reco/ivectors_reco.1.ark " + "ark:exp/ivectors_dev_reco/reco_segmentation.1.ark ark:data/dev_uniformsegmented_win10_over5/split10/1/reco2utt " + "ark:data/dev_uniformsegmented_win10_over5/split10/1/segments ark:- | " + "paste-feats ark:feats.1.ark ark:- " + "ark:exp/ivectors_dev_uniformsegmented_win10_over5/ivector_online.ark\n"; + + ParseOptions po(usage); + BaseFloat frame_shift = 0.01; + int32 offset_frames = 2; + + po.Register("frame-shift", &frame_shift, + "Frame shift in second"); + po.Register("offset-frames", &offset_frames, + "Number of frames to reduce output iVector size by, to " + "adjust boundary to match feature length"); + + po.Read(argc, argv); + + if (po.NumArgs() != 5) { + po.PrintUsage(); + exit(1); + } + + std::string ivectors_rspecifier = po.GetArg(1), + segmentations_rspecifier = po.GetArg(2), + reco2utt_rspecifier = po.GetArg(3), + segments_rspecifier = po.GetArg(4), + ivectors_wspecifier = po.GetArg(5); + + SequentialTokenVectorReader reco2utt_reader(reco2utt_rspecifier); + RandomAccessUtteranceSegmentReader segment_reader(segments_rspecifier); + segmenter::RandomAccessSegmentationReader reco_seg_reader(segmentations_rspecifier); // Corresponding to the read reco iVectors + RandomAccessBaseFloatMatrixReader ivector_reader(ivectors_rspecifier); + BaseFloatMatrixWriter ivector_writer(ivectors_wspecifier); + + int32 num_reco = 0, num_success = 0, num_err = 0; + + int32 ivector_dim = -1; + for (; !reco2utt_reader.Done(); reco2utt_reader.Next(), num_reco++) { + std::string reco = 
reco2utt_reader.Key(); + const std::vector &uttlist = reco2utt_reader.Value(); + + if (!reco_seg_reader.HasKey(reco)) { + KALDI_WARN << "Could not read segmentation for recording " << reco; + continue; + } + + if (!ivector_reader.HasKey(reco)) { + KALDI_WARN << "Could not find iVector for recording " << reco; + continue; + } + + const Matrix& ivector_in = ivector_reader.Value(reco); + segmenter::Segmentation reco_seg(reco_seg_reader.Value(reco)); + + ivector_dim = ivector_in.NumCols(); + + KALDI_ASSERT(ivector_in.NumRows() == reco_seg.Dim()); + size_t i = 0; + for (segmenter::SegmentList::iterator it = reco_seg.Begin(); + it != reco_seg.End(); ++it, i++) { + it->SetVectorValue(SubVector(ivector_in, i)); + } + KALDI_ASSERT(i == ivector_in.NumRows()); + + // reco_seg is sorted because it is written that way. + // This can be checked if needed. + + // Convert the segments file to segmentation + segmenter::Segmentation segments_seg; + for (std::vector::const_iterator it = uttlist.begin(); + it != uttlist.end(); ++it) { + + if (!segment_reader.HasKey(*it)) { + KALDI_WARN << "Could not find utterance " << *it << " in segments " + << "file " << segments_rspecifier; + num_err++; + continue; + } + const UtteranceSegment &segment = segment_reader.Value(*it); + + + segments_seg.Emplace(std::round(segment.start_time / frame_shift), + std::round(segment.end_time / frame_shift), 1, *it); + } + + segmenter::Segmentation new_seg; + segments_seg.CreateSubSegments(reco_seg, 1, 1, &new_seg); + + segmenter::SegmentList::iterator new_it = new_seg.Begin(); + for (segmenter::SegmentList::iterator utt_it = segments_seg.Begin(); + utt_it != segments_seg.End(); ++utt_it) { + // Effectively doing "For each segment in segments file" + + // start_frame and end_frame are all reco-level for both utt_it and + // seg_it + + // Create iVector matrix for the segment in segments file. 
+ // Offset frames is to correct for the fact that feats extracted for + // the segment will have about 2 frames less at the end. + Matrix ivector(utt_it->end_frame - utt_it->start_frame - offset_frames, ivector_dim); + + KALDI_ASSERT(new_it->StringValue() == utt_it->StringValue()); + // By the way the CreateSubSegments function is written, this + // must be true + + for (; new_it != new_seg.End() && new_it->StringValue() == utt_it->StringValue(); ++new_it) { + size_t num_frames = new_it->end_frame - new_it->start_frame; + if (new_it->end_frame > utt_it->end_frame - offset_frames) { + num_frames -= offset_frames; + } + // Copy iVector for the subsegment into the iVector matrix in + // the segments file + SubMatrix this_ivector(ivector, new_it->start_frame - utt_it->start_frame, num_frames, 0, ivector.NumCols()); + this_ivector.CopyRowsFromVec(new_it->VectorValue()); + KALDI_ASSERT(this_ivector(0,0) != 0); + + KALDI_VLOG(2) << utt_it->StringValue() << " " << new_it->start_frame << " " << new_it->end_frame << " " << new_it->VectorValue(); + } + + ivector_writer.Write(utt_it->StringValue(), ivector); + } + num_success++; + } + + KALDI_LOG << "Split iVectors for " << num_success + << " out of " << num_reco << " recordings; " + << " errors in " << num_err << " segments"; + return (num_success > 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/ivectorbin/logistic-regression-eval-on-feats.cc b/src/ivectorbin/logistic-regression-eval-on-feats.cc new file mode 100644 index 00000000000..773c9a0a1d3 --- /dev/null +++ b/src/ivectorbin/logistic-regression-eval-on-feats.cc @@ -0,0 +1,88 @@ +// ivectorbin/logistic-regression-eval-on-feats.cc + +// Copyright 2014 David Snyder +// 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "ivector/logistic-regression.h" + +int main(int argc, char *argv[]) { + using namespace kaldi; + typedef kaldi::int32 int32; + try { + const char *usage = + "Evaluates a model on input vectors and outputs either\n" + "log posterior probabilities or scores.\n" + "Usage1: logistic-regression-eval " + "\n"; + + ParseOptions po(usage); + + bool apply_log = true; + po.Register("apply-log", &apply_log, + "If false, apply Exp to the log posteriors output. 
This is " + "helpful when combining posteriors from multiple logistic " + "regression models."); + LogisticRegressionConfig config; + config.Register(&po); + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string model = po.GetArg(1), + feats_rspecifier = po.GetArg(2), + log_posteriors_wspecifier = po.GetArg(3); + + LogisticRegression classifier; + ReadKaldiObject(model, &classifier); + + Matrix feats; + + SequentialBaseFloatMatrixReader feats_reader(feats_rspecifier); + BaseFloatMatrixWriter log_probs_writer(log_posteriors_wspecifier); + + int32 num_utt_done = 0; + + for (; !feats_reader.Done(); feats_reader.Next()) { + const std::string &key = feats_reader.Key(); + const Matrix &feats = feats_reader.Value(); + + Matrix log_posteriors; + + classifier.GetLogPosteriors(feats, &log_posteriors); + if (!apply_log) + log_posteriors.ApplyExp(); + + log_probs_writer.Write(key, log_posteriors); + num_utt_done++; + } + + KALDI_LOG << "Calculated log posteriors for " << num_utt_done << " vectors."; + return (num_utt_done == 0 ? 1 : 0); + + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/ivectorbin/logistic-regression-train-on-feats.cc b/src/ivectorbin/logistic-regression-train-on-feats.cc new file mode 100644 index 00000000000..c3347088606 --- /dev/null +++ b/src/ivectorbin/logistic-regression-train-on-feats.cc @@ -0,0 +1,148 @@ +// segmenterbin/logistic-regression-train-on-feats.cc + +// Copyright 2014 David Snyder +// 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "ivector/logistic-regression.h" + + +int main(int argc, char *argv[]) { + using namespace kaldi; + typedef kaldi::int32 int32; + try { + const char *usage = + "Trains a model using Logistic Regression with L-BFGS from " + "a set of features, where each row corresponds to a frame.\n" + "The corresponding frame labels are in " + "which is a vector of integers with one label for each frame.\n" + "The number of targets is input by the user; the labels must be " + "between 0 and -1.\n" + "Usage: logistic-regression-train-on-feats \n" + " \n"; + + ParseOptions po(usage); + + bool binary = true; + int32 num_targets = 2; + int32 num_frames = 200000; + int32 srand_seed = 0; + std::string model_rxfilename; + + LogisticRegressionConfig config; + config.Register(&po); + po.Register("binary", &binary, "Write output in binary mode"); + po.Register("num-targets", &num_targets, "Number of target labels"); + po.Register("num-frames", &num_frames, + "Number of feature vectors to store in " + "memory and train on (randomly chosen from the input features)"); + po.Register("srand", &srand_seed, "Seed for random number generator "); + po.Register("model-rxfilename", &model_rxfilename, "Initialize " + "logistic-regression model"); + + po.Read(argc, argv); + + srand(srand_seed); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string feats_rspecifier = po.GetArg(1), + labels_rspecifier = po.GetArg(2), + 
model_out = po.GetArg(3); + + SequentialBaseFloatMatrixReader feats_reader(feats_rspecifier); + RandomAccessInt32VectorReader labels_reader(labels_rspecifier); + + Matrix feats; + std::vector labels(num_frames, -1); + + KALDI_ASSERT(num_frames > 0); + KALDI_LOG << "Reading features (will keep " << num_frames << " frames.)"; + + int64 num_read = 0, dim = 0; + + for (; !feats_reader.Done(); feats_reader.Next()) { + const std::string &key = feats_reader.Key(); + const Matrix &this_feats = feats_reader.Value(); + + if (!labels_reader.HasKey(key)) { + KALDI_WARN << "No labels found for utterance " << key; + continue; + } + + std::vector this_labels = labels_reader.Value(key); + + for (int32 t = 0; t < this_feats.NumRows(); t++) { + num_read++; + if (dim == 0) { + dim = this_feats.NumCols(); + feats.Resize(num_frames, dim); + } else if (this_feats.NumCols() != dim) { + KALDI_ERR << "Features have inconsistent dims " + << this_feats.NumCols() << " vs. " << dim + << " (current utt is) " << feats_reader.Key(); + } + if (num_read <= num_frames) { + feats.Row(num_read - 1).CopyFromVec(this_feats.Row(t)); + labels[num_read - 1] = this_labels[t]; + } else { + BaseFloat keep_prob = num_frames / static_cast(num_read); + if (WithProb(keep_prob)) { // With probability "keep_prob" + int32 t1 = RandInt(0, num_frames - 1); + feats.Row(t1).CopyFromVec(this_feats.Row(t)); + if ( this_labels[t] < 0 || this_labels[t] >= num_targets ) { + KALDI_ERR << "Label must be between 0 and -1; " + << "; but found label " << this_labels[t]; + } + labels[t1] = this_labels[t]; + } + } + } + } + + if (num_read < num_frames) { + KALDI_WARN << "Number of frames read " << num_read << " was less than " + << "target number " << num_frames << ", using all we read."; + feats.Resize(num_read, dim, kCopyData); + } else { + BaseFloat percent = num_frames * 100.0 / num_read; + KALDI_LOG << "Kept " << num_frames << " out of " << num_read + << " input frames = " << percent << "%."; + } + + LogisticRegression 
classifier = LogisticRegression(); + + if (!model_rxfilename.empty()) { + ReadKaldiObject(model_rxfilename, &classifier); + } + + classifier.Train(feats, labels, config); + WriteKaldiObject(classifier, model_out, binary); + + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/ivectorbin/select-top-frames.cc b/src/ivectorbin/select-top-frames.cc new file mode 100644 index 00000000000..8b3bb109365 --- /dev/null +++ b/src/ivectorbin/select-top-frames.cc @@ -0,0 +1,378 @@ +// ivectorbin/select-top-chunks.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" +#include "feat/feature-functions.h" + +namespace kaldi { + + void SmoothVector(int32 window_size, Vector *weights) { + Vector weights_tmp(weights->Dim()); + weights_tmp.CopyFromVec(*weights); + + for (int32 i = 0; i < weights_tmp.Dim(); i++) { + int32 left_index = std::max(0, i - window_size); + int32 right_index = std::min(i + window_size, weights_tmp.Dim() - 1); + SubVector this_weights(weights_tmp, + left_index, right_index - left_index); + (*weights)(i) = this_weights.Sum() / this_weights.Dim(); + } + } + + void SmoothMask(int32 window_size, int32 select_class, BaseFloat threshold, + Vector *mask) { + Vector mask_tmp(mask->Dim()); + mask_tmp.CopyFromVec(*mask); + + for (int32 i = 0; i < mask_tmp.Dim(); i++) { + int32 left_index = std::max(0, i - window_size); + int32 right_index = std::min(i + window_size, mask_tmp.Dim() - 1); + + int32 mask_sum = 0; + for (int32 j = left_index; j <= right_index; j++) { + mask_sum += (mask_tmp(j) == select_class ? 1 : 0); + } + (*mask)(i) = (static_cast(mask_sum) / (right_index - left_index + 1) >= threshold ? select_class : -1.0); + } + } + + template + class OtherVectorComparator { + public: + OtherVectorComparator(const std::vector &vec, bool descending = true) + : vec_(vec), descending_(descending) { } + + bool operator() (int32 a, int32 b) { + if (descending_) return vec_[a] > vec_[b]; + else return vec_[a] < vec_[b]; + } + + inline void SetDescending() { descending_ = true; } + inline void SetAscending() { descending_ = false; } + + private: + const std::vector &vec_; + bool descending_; + }; + + template class OtherVectorComparator; +} + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using kaldi::int32; + + const char *usage = + "Select a subset of chunks of frames of the input files, based on the log-energy\n" + "of the frames\n" + "Usage: select-top-chunks [options] " + " []\n" + "e.g. 
: select-top-chunks --frames-proportion=0.1 --window-size=100" + "scp:feats.scp ark:-\n"; + + BaseFloat frames_proportion = 1.0; + int32 window_size = 100; // 100 frames = 1 second assuming a shift of 10ms + std::string mask_rspecifier, weights_rspecifier, weights_next_rspecifier; + int32 select_frames = -1, select_frames_next = -1; + int32 select_class = 1; + int32 dim_as_weight = -1; + bool select_bottom_frames = false, select_bottom_frames_next = false; + bool smooth_weights = false, smooth_mask = false; + int32 smoothing_window = 4; + BaseFloat selection_threshold = 0.5; + + ParseOptions po(usage); + po.Register("frames-proportion", &frames_proportion, + "Select only the top / bottom proportion frames by the feature value"); + po.Register("select-class", &select_class, + "Select frames of this class in the mask"); + po.Register("num-select-frames", &select_frames, + "Select these many frames, instead of looking at a proportion. " + "Overrides frame-proportion if provided and >= 0"); + po.Register("num-select-frames-next", &select_frames_next, + "Second level of selection of frames among frames selected " + "using num-frames"); + po.Register("window-size", &window_size, + "Size of window to consider at once"); + po.Register("weights", &weights_rspecifier, + "Read weights from an archive to do selection"); + po.Register("weights_next", &weights_next_rspecifier, + "Read weights from an archive to do a second level of selection"); + po.Register("selection-mask", &mask_rspecifier, + "Selection mask on the frames. These are chosen for the chunk " + "based on majority"); + po.Register("select-bottom-frames", &select_bottom_frames, + "Select the bottom frames instead of top frames"); + po.Register("select-bottom-frames-next", &select_bottom_frames_next, + "Select bottom frames for second level of selection"); + po.Register("use-dim-as-weight", &dim_as_weight, + "Use a particular dimension of feature (e.g. 
C0) as the weight " + "when --weights is not specified"); + po.Register("smoothing-window", &smoothing_window, + "Size of smoothing window. Applicable if --smooth-vectors=true"); + po.Register("smooth-weights", &smooth_weights, + "Smooth weights over a window"); + po.Register("smooth-mask", &smooth_mask, + "Smooth mask over a window"); + po.Register("selection-threshold", &selection_threshold, + "Select chunks that have this fraction of frames to be " + "the select-class"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2 && po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string feat_rspecifier = po.GetArg(1), + feat_wspecifier = po.GetArg(2), + mask_wspecifier; + + if (po.NumArgs() == 3) + mask_wspecifier = po.GetArg(3); + + SequentialBaseFloatMatrixReader feat_reader(feat_rspecifier); + BaseFloatMatrixWriter feat_writer(feat_wspecifier); + + RandomAccessBaseFloatVectorReader weights_reader(weights_rspecifier); + RandomAccessBaseFloatVectorReader weights_next_reader(weights_next_rspecifier); + RandomAccessBaseFloatVectorReader mask_reader(mask_rspecifier); + + BaseFloatVectorWriter mask_writer(mask_wspecifier); + + int32 num_done = 0, num_err = 0; + long long num_select = 0, num_frames = 0, num_filtered = 0, + num_select_second_level = 0, num_filtered_second_level = 0, + num_masked = 0; + + for (; !feat_reader.Done(); feat_reader.Next()) { + std::string utt = feat_reader.Key(); + const Matrix &feats = feat_reader.Value(); + if (feats.NumRows() == 0) { + KALDI_WARN << "Empty feature matrix for utterance " << utt; + num_err++; + continue; + } + num_frames += feats.NumRows(); + + int32 num_chunks = ( feats.NumRows() + 0.5 * window_size ) / window_size; + if (num_chunks == 0) { + KALDI_WARN << "No chunks found for utterance " << utt; + num_err++; + continue; + } + + // Find chunk size to use + int32 chunk_size = feats.NumRows() / num_chunks; + + Vector weights; + Vector mask; + Vector weights_next; + + // Read weights if specified + if 
(weights_rspecifier != "") { + if (!weights_reader.HasKey(utt)) { + KALDI_WARN << "weights not found for utterance " << utt; + num_err++; + continue; + } + weights = (weights_reader.Value(utt)); + } + + if (dim_as_weight > 0) { + KALDI_ASSERT(dim_as_weight < feats.NumCols()); + weights.CopyColFromMat(feats, dim_as_weight); + } + + // Read mask if specified + if (mask_rspecifier != "") { + if (!mask_reader.HasKey(utt)) { + KALDI_WARN << "mask not found for utterance " << utt; + num_err++; + continue; + } + mask = (mask_reader.Value(utt)); + } + + // Read second level weights if specified + if (weights_next_rspecifier != "") { + if (!weights_next_reader.HasKey(utt)) { + KALDI_WARN << "second-level weights not found for utterance " << utt; + num_err++; + continue; + } + weights_next = (weights_next_reader.Value(utt)); + } + + std::vector chunk_weights; + std::vector chunk_weights_next; + std::vector chunk_mask; + + if (weights_rspecifier != "") + chunk_weights.resize(num_chunks, 0.0); + + if (weights_next_rspecifier != "") + chunk_weights_next.resize(num_chunks, 0.0); + + if (mask_rspecifier != "") + chunk_mask.resize(num_chunks, 0); + else + chunk_mask.resize(num_chunks, 1); + + if (smooth_weights) { + SmoothVector(smoothing_window, &weights); + SmoothVector(smoothing_window, &weights_next); + } + + if (smooth_mask) { + SmoothMask(smoothing_window, select_class, selection_threshold, &mask); + } + + // Find average weight for each chunk + for (int32 i = 0; i < num_chunks; i++) { + if (weights_rspecifier != "") { + SubVector this_chunk_weights(weights, i*chunk_size, chunk_size); + chunk_weights[i] = this_chunk_weights.Sum() / chunk_size; + } + if (weights_next_rspecifier != "") { + SubVector this_chunk_weights_next(weights_next, i*chunk_size, chunk_size); + chunk_weights_next[i] = this_chunk_weights_next.Sum() / chunk_size; + } + if (mask_rspecifier != "") { + SubVector this_chunk_mask(mask, i*chunk_size, chunk_size); + int32 mask_sum = 0; + for (int32 j = 0; j < 
this_chunk_mask.Dim(); j++) { + if (this_chunk_mask(j) == select_class) { + mask_sum++; + num_masked++; + } + } + chunk_mask[i] = (static_cast(mask_sum) / chunk_size >= selection_threshold ? 1 : 0); + } + } + + std::vector idx; + for (int32 i = 0; i < num_chunks; i++) { + if ( chunk_mask[i] == 1 ) + idx.push_back(i); + } + + int32 this_select = 0; + + if (select_frames < 0) { + if (frames_proportion == 1.0) + this_select = idx.size(); + else + this_select = frames_proportion * idx.size() + 0.5; + } else + this_select = (static_cast(select_frames) + 0.5) / chunk_size ; + + // No chunk selected. Just select one instead. + if (this_select == 0) this_select = 1; + + num_filtered += idx.size() * chunk_size; + + if (this_select < idx.size()) { + // Need to select frames because this_select is less than the + // number of chunks found + + if (chunk_weights.size() > 0) { + // Select only top frames according to chunk_weights + OtherVectorComparator comparator(chunk_weights); + if (select_bottom_frames) comparator.SetAscending(); + + sort(idx.begin(), idx.end(), comparator); + } + idx.resize(this_select); + } else { + this_select = idx.size(); + } + + num_select += this_select * chunk_size; + num_filtered_second_level += idx.size() * chunk_size; + + int32 this_select_next = idx.size(); + if (select_frames_next > 0) + this_select_next = (static_cast(select_frames_next) + 0.5) / chunk_size; + if (this_select_next == 0) this_select_next = 1; + + if (this_select_next < idx.size()) { + // Need to select frames at second level because this_select is + // less than the number of chunks retained after first level + // selection + + if (chunk_weights_next.size() > 0) { + // Select only top frames according to chunk_weights + OtherVectorComparator comparator(chunk_weights_next); + if (select_bottom_frames_next) comparator.SetAscending(); + + sort(idx.begin(), idx.end(), comparator); + } + idx.resize(this_select_next); + } else { + this_select_next = idx.size(); + } + + 
num_select_second_level += this_select_next * chunk_size; + + Matrix selected_feats(this_select_next * chunk_size, feats.NumCols()); + Vector output_mask(feats.NumRows()); + + int32 n = 0; + for (std::vector::const_iterator it = idx.begin(); + it != idx.end(); ++it, n++) { + KALDI_VLOG(2) << utt << " " << *it * chunk_size << " " << *it * chunk_size + chunk_size; + SubMatrix src_feats(feats, *it * chunk_size, chunk_size, 0, feats.NumCols()); + SubMatrix dst_feats(selected_feats, n*chunk_size, chunk_size, 0, feats.NumCols()); + dst_feats.CopyFromMat(src_feats); + + if (mask_wspecifier != "") { + SubVector this_mask(output_mask, *it * chunk_size, chunk_size); + this_mask.Set(1.0); + } + } + + feat_writer.Write(feat_reader.Key(), selected_feats); + if (mask_wspecifier != "") + mask_writer.Write(feat_reader.Key(), output_mask); + + num_done++; + } + + KALDI_LOG << "Done selecting " << num_select_second_level + << " top frames out of " + << num_filtered_second_level << " frames at second level; " + << num_select << " top frames were selected at first level " + << "out of " << num_filtered + << " frames filtered through the mask out of " + << num_frames << " frames ; " + << " unmasked " << num_masked << " frames; processed " + << num_done << " utterances, " + << num_err << " had errors."; + return (num_done != 0 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/ivectorbin/select-voiced-frames.cc b/src/ivectorbin/select-voiced-frames.cc index beba1a4f068..cfad5c71156 100644 --- a/src/ivectorbin/select-voiced-frames.cc +++ b/src/ivectorbin/select-voiced-frames.cc @@ -37,8 +37,14 @@ int main(int argc, char *argv[]) { "Usage: select-voiced-frames [options] " " \n" "E.g.: select-voiced-frames [options] scp:feats.scp scp:vad.scp ark:-\n"; - + + bool select_unvoiced_frames = false; + ParseOptions po(usage); + po.Register("select-unvoiced-frames", &select_unvoiced_frames, + "Reverses the operation of this file and selects " + "unvoiced frames instead"); + po.Read(argc, argv); if (po.NumArgs() != 3) { @@ -86,15 +92,27 @@ int main(int argc, char *argv[]) { } int32 dim = 0; for (int32 i = 0; i < voiced.Dim(); i++) - if (voiced(i) != 0.0) - dim++; + if (!select_unvoiced_frames) { + if (voiced(i) != 0.0) + dim++; + } else { + if (voiced(i) == 0.0) + dim++; + } Matrix voiced_feat(dim, feat.NumCols()); int32 index = 0; for (int32 i = 0; i < feat.NumRows(); i++) { - if (voiced(i) != 0.0) { - KALDI_ASSERT(voiced(i) == 1.0); // should be zero or one. - voiced_feat.Row(index).CopyFromVec(feat.Row(i)); - index++; + if (!select_unvoiced_frames) { + if (voiced(i) != 0.0) { + KALDI_ASSERT(voiced(i) == 1.0); // should be zero or one. 
+ voiced_feat.Row(index).CopyFromVec(feat.Row(i)); + index++; + } + } else { + if (voiced(i) == 0.0) { + voiced_feat.Row(index).CopyFromVec(feat.Row(i)); + index++; + } } } KALDI_ASSERT(index == dim); diff --git a/src/ivectorbin/speaker-diarization.cc b/src/ivectorbin/speaker-diarization.cc new file mode 100644 index 00000000000..17718ecd325 --- /dev/null +++ b/src/ivectorbin/speaker-diarization.cc @@ -0,0 +1,135 @@ +// Copyright 2014 David Snyder + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "tree/cluster-utils.h" +#include "tree/cluster-utils.cc" +#include "tree/clusterable-classes.h" +#include "ivector/logistic-regression.h" +#include "ivector/plda.h" + +using namespace kaldi; + +int main(int argc, char *argv[]) { + using namespace kaldi; + typedef kaldi::int32 int32; + try { + const char *usage = "Does speaker diarzation using k-means clustering of PLDA transformed i-vectors.\n" + "Usage: speaker-diarization \n" + " e.g.: speaker-diarization plda ark,t:data/dev/spk2utt scp:exp/ivectors_dev/ivector.scp ark,t:exp/diarization_dev/diarization.txt\n" + "\n"; + + + int32 num_speakers = 2; + ParseOptions po(usage); + + po.Register("num-speakers", &num_speakers, "Number of speakers to use in the k-means clustering algorithm"); + + ClusterKMeansOptions cfg; + + po.Read(argc, argv); + + if (po.NumArgs() != 4 && po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + int32 num_utt_err = 0; + int32 num_utt_done = 0; + + std::string plda_rxfilename, spk2utt_rspecifier, ivector_rspecifier, + diar_wspecifier; + + if (po.NumArgs() == 4) { + plda_rxfilename = po.GetArg(1); + spk2utt_rspecifier = po.GetArg(2); + ivector_rspecifier = po.GetArg(3); + diar_wspecifier = po.GetArg(4); + } else { + spk2utt_rspecifier = po.GetArg(1); + ivector_rspecifier = po.GetArg(2); + diar_wspecifier = po.GetArg(3); + } + + Plda plda; + PldaConfig plda_config; + int32 dim = 0; + + if (plda_rxfilename != "") { + ReadKaldiObject(plda_rxfilename, &plda); + dim = plda.Dim(); + } + + SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); + RandomAccessBaseFloatVectorReader ivector_reader(ivector_rspecifier); + Int32Writer diar_writer(diar_wspecifier); + for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { + std::string spk = spk2utt_reader.Key(); + const std::vector &uttlist = spk2utt_reader.Value(); + std::vector ivector_clusters; + ivector_clusters.reserve(uttlist.size()); + + for (size_t i = 0; i < 
uttlist.size(); i++) { + std::string utt = uttlist[i]; + if (!ivector_reader.HasKey(utt)) { + KALDI_WARN << "No iVector present in input for utterance " << utt; + num_utt_err++; + } else { + ivector_clusters.resize(ivector_clusters.size() + 1); + Vector ivector = ivector_reader.Value(utt); + Vector *transformed_ivector; + + if (plda_rxfilename != "") { + transformed_ivector = new Vector(dim); + plda.TransformIvector(plda_config, ivector, + transformed_ivector); + } else { + transformed_ivector = &ivector; + } + + Clusterable *cluster = new VectorClusterable(*transformed_ivector, 1.0); + ivector_clusters.back() = cluster; + num_utt_done++; + + if (plda_rxfilename != "") { + delete transformed_ivector; + } + } + } + std::vector ivector_clusters_out; + std::vector assignments_out; + //BaseFloat imprv = ClusterKMeansOnce(ivector_clusters, 2, &ivector_clusters_out, &assignments_out, cfg); + BaseFloat imprv = ClusterKMeans(ivector_clusters, num_speakers, &ivector_clusters_out, &assignments_out, cfg); + for (int32 i = 0; i < ivector_clusters.size(); i++) { + delete ivector_clusters[i]; + } + for (int32 i = 0; i < ivector_clusters_out.size(); i++) { + delete ivector_clusters_out[i]; + } + + for (size_t i = 0; i < uttlist.size(); i++) { + std::string utt = uttlist[i]; + diar_writer.Write(utt, assignments_out[i]); + } + KALDI_LOG << "Objf improvement is " << imprv << " for utt " << spk; + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/matrix/Makefile b/src/matrix/Makefile index 6cc9645f15b..e0b892cd54f 100644 --- a/src/matrix/Makefile +++ b/src/matrix/Makefile @@ -10,7 +10,7 @@ include ../kaldi.mk # you can uncomment matrix-lib-speed-test if you want to do the speed tests. 
-TESTFILES = matrix-lib-test kaldi-gpsr-test sparse-matrix-test #matrix-lib-speed-test +TESTFILES = matrix-lib-test kaldi-gpsr-test sparse-matrix-test matrix-lib-speed-test OBJFILES = kaldi-matrix.o kaldi-vector.o packed-matrix.o sp-matrix.o tp-matrix.o \ matrix-functions.o qr.o srfft.o kaldi-gpsr.o compressed-matrix.o \ diff --git a/src/matrix/kaldi-matrix.cc b/src/matrix/kaldi-matrix.cc index 198616ef3ed..4491d364911 100644 --- a/src/matrix/kaldi-matrix.cc +++ b/src/matrix/kaldi-matrix.cc @@ -393,6 +393,87 @@ void MatrixBase::AddMat(const Real alpha, const MatrixBase& A, } } +template +void MatrixBase::LogAddExpMat(const Real alpha, const MatrixBase& A, + MatrixTransposeType transA) { + if (alpha == 0) return; + + if (&A == this) { + if (transA == kNoTrans) { + Add(alpha + 1.0); + } else { + KALDI_ASSERT(num_rows_ == num_cols_ && "AddMat: adding to self (transposed): not symmetric."); + Real *data = data_; + if (alpha == 1.0) { // common case-- handle separately. + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < row; col++) { + Real *lower = data + (row * stride_) + col, + *upper = data + (col * stride_) + row; + Real sum = LogAdd(*lower, *upper); + *lower = *upper = sum; + } + *(data + (row * stride_) + row) += Log(2.0); // diagonal. + } + } else { + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < row; col++) { + Real *lower = data + (row * stride_) + col, + *upper = data + (col * stride_) + row; + Real lower_tmp = *lower; + if (alpha > 0) { + *lower = LogAdd(*lower, Log(alpha) + *upper); + *upper = LogAdd(*upper, Log(alpha) + lower_tmp); + } else { + KALDI_ASSERT(alpha < 0); + *lower = LogSub(*lower, Log(-alpha) + *upper); + *upper = LogSub(*upper, Log(-alpha) + lower_tmp); + } + } + if (alpha > -1.0) + *(data + (row * stride_) + row) += Log(1.0 + alpha); // diagonal. 
+ else + KALDI_ERR << "Cannot subtract log-matrices if the difference is " + << "negative"; + } + } + } + } else { + int aStride = (int) A.stride_; + Real *adata = A.data_, *data = data_; + if (transA == kNoTrans) { + KALDI_ASSERT(A.num_rows_ == num_rows_ && A.num_cols_ == num_cols_); + if (num_rows_ == 0) return; + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < num_cols_; col++) { + Real *value = data + (row * stride_) + col, + *aValue = adata + (row * aStride) + col; + if (alpha > 0) + *value = LogAdd(*value, Log(alpha) + *aValue); + else { + KALDI_ASSERT(alpha < 0); + *value = LogSub(*value, Log(-alpha) + *aValue); + } + } + } + } else { + KALDI_ASSERT(A.num_cols_ == num_rows_ && A.num_rows_ == num_cols_); + if (num_rows_ == 0) return; + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < num_cols_; col++) { + Real *value = data + (row * stride_) + col, + *aValue = adata + (col * aStride) + row; + if (alpha > 0) + *value = LogAdd(*value, Log(alpha) + *aValue); + else { + KALDI_ASSERT(alpha < 0); + *value = LogSub(*value, Log(-alpha) + *aValue); + } + } + } + } + } +} + template template void MatrixBase::AddSp(const Real alpha, const SpMatrix &S) { @@ -1969,6 +2050,17 @@ void MatrixBase::ApplyHeaviside() { } } +template +void MatrixBase::ApplySignum() { + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; + for (MatrixIndexT i = 0; i < num_rows; i++) { + Real *data = this->RowData(i); + for (MatrixIndexT j = 0; j < num_cols; j++) { + if (data[j] > 0) data[j] = 1.0; + else if (data[j] < 0) data[j] = -1.0; + } + } +} template bool MatrixBase::Power(Real power) { diff --git a/src/matrix/kaldi-matrix.h b/src/matrix/kaldi-matrix.h index c16ffb22135..3323971c776 100644 --- a/src/matrix/kaldi-matrix.h +++ b/src/matrix/kaldi-matrix.h @@ -348,7 +348,11 @@ class MatrixBase { /// please leave it as it (i.e. 
returning zero) because it affects the /// RectifiedLinearComponent in the neural net code. void ApplyHeaviside(); - + + /// Applies the Signum function (1 if x > 0, 0 if x = 0 and -1 if x < 0) + /// to all matrix elements + void ApplySignum(); + /// Eigenvalue Decomposition of a square NxN matrix into the form (*this) = P D /// P^{-1}. Be careful: the relationship of D to the eigenvalues we output is /// slightly complicated, due to the need for P to be real. In the symmetric @@ -537,6 +541,10 @@ class MatrixBase { /// *this += alpha * M [or M^T] void AddMat(const Real alpha, const MatrixBase &M, MatrixTransposeType transA = kNoTrans); + + /// *this += alpha * M [or M^T] when the matrices are stored as log + void LogAddExpMat(const Real alpha, const MatrixBase &M, + MatrixTransposeType transA = kNoTrans); /// *this = beta * *this + alpha * M M^T, for symmetric matrices. It only /// updates the lower triangle of *this. It will leave the matrix asymmetric; diff --git a/src/matrix/kaldi-vector.cc b/src/matrix/kaldi-vector.cc index 41a25f598c2..4ee64e33091 100644 --- a/src/matrix/kaldi-vector.cc +++ b/src/matrix/kaldi-vector.cc @@ -1033,6 +1033,37 @@ void VectorBase::AddVec(const float alpha, const VectorBase &v); template void VectorBase::AddVec(const double alpha, const VectorBase &v); + +template +void VectorBase::LogAddExpVec(const Real alpha, const VectorBase &v) { + KALDI_ASSERT(dim_ == v.dim_); + if (alpha == 0) return; + + // remove __restrict__ if it causes compilation problems. 
+ register Real *__restrict__ data = data_; + register Real *__restrict__ other_data = v.data_; + + MatrixIndexT dim = dim_; + if (alpha != 1.0) + for (MatrixIndexT i = 0; i < dim; i++) { + if (alpha > 0) { + data[i] = LogAdd(data[i], other_data[i] + Log(alpha)); + } else { + KALDI_ASSERT(alpha < 0); + data[i] = LogSub(data[i], other_data[i] + Log(-alpha)); + } + } + else + for (MatrixIndexT i = 0; i < dim; i++) + data[i] = LogAdd(data[i], other_data[i]); +} + +template +void VectorBase::LogAddExpVec(const float alpha, const VectorBase &v); +template +void VectorBase::LogAddExpVec(const double alpha, const VectorBase &v); + + template template void VectorBase::AddVec2(const Real alpha, const VectorBase &v) { diff --git a/src/matrix/kaldi-vector.h b/src/matrix/kaldi-vector.h index 498ddda302d..b6a217efb9e 100644 --- a/src/matrix/kaldi-vector.h +++ b/src/matrix/kaldi-vector.h @@ -177,6 +177,10 @@ class VectorBase { template void AddVec(const Real alpha, const VectorBase &v); + /// Add vector : *this = *this + alpha * rv, where the vector and rv are + /// stored in log and the answer is returned in log + void LogAddExpVec(const Real alpha, const VectorBase &v); + /// Add vector : *this = *this + alpha * rv^2 [element-wise squaring]. 
void AddVec2(const Real alpha, const VectorBase &v); diff --git a/src/matrix/matrix-lib-speed-test.cc b/src/matrix/matrix-lib-speed-test.cc index cd5405b8c3d..00115a4377b 100644 --- a/src/matrix/matrix-lib-speed-test.cc +++ b/src/matrix/matrix-lib-speed-test.cc @@ -268,6 +268,45 @@ static void UnitTestAddVecToColsSpeed() { KALDI_LOG << __func__ << " finished in " << t.Elapsed() << " seconds."; } +template static void UnitTestLogAddExpMat() { + std::vector sizes; + sizes.push_back(512); + sizes.push_back(1024); + + Matrix alphas_mat(1,5); + alphas_mat.SetRandUniform(); + Vector alphas(alphas_mat.Row(0)); + //alphas.Add(-0.5); + //alphas.Scale(2.0): + + for (size_t i = 0; i < sizes.size(); i++) { + MatrixIndexT size = sizes[i]; + { + for (int32 j=0; j<5; j++) { + Matrix A(size,size), B(size, size); + A.SetRandn(); B.SetRandn(); + A.ApplyPowAbs(1.0); + B.ApplyPowAbs(1.0); + Matrix logA(A); + logA.ApplyLog(); + Matrix logB(B); + logB.ApplyLog(); + + Real alpha = alphas(j); + Matrix sum1(A); + sum1.AddMat(alpha, B, kNoTrans); + + if (alpha > 0) { + Matrix sum2(logA); + sum2.LogAddExpMat(alpha, logB, kNoTrans); + sum2.ApplyExp(); + KALDI_ASSERT(sum1.ApproxEqual(sum2)); + } + } + } + } +} + template static void MatrixUnitSpeedTest() { UnitTestRealFftSpeed(); UnitTestSplitRadixRealFftSpeed(); @@ -277,6 +316,7 @@ template static void MatrixUnitSpeedTest() { UnitTestAddColSumMatSpeed(); UnitTestAddVecToRowsSpeed(); UnitTestAddVecToColsSpeed(); + UnitTestLogAddExpMat(); } } // namespace kaldi diff --git a/src/nnet3/nnet-chain-example.cc b/src/nnet3/nnet-chain-example.cc index 4e9438ee378..76bc8bbc66b 100644 --- a/src/nnet3/nnet-chain-example.cc +++ b/src/nnet3/nnet-chain-example.cc @@ -25,49 +25,6 @@ namespace kaldi { namespace nnet3 { -// writes compressed as unsigned char a vector 'vec' that is required to have -// values between 0 and 1. 
-static inline void WriteVectorAsChar(std::ostream &os, - bool binary, - const VectorBase &vec) { - if (binary) { - int32 dim = vec.Dim(); - std::vector char_vec(dim); - const BaseFloat *data = vec.Data(); - for (int32 i = 0; i < dim; i++) { - BaseFloat value = data[i]; - KALDI_ASSERT(value >= 0.0 && value <= 1.0); - // below, the adding 0.5 is done so that we round to the closest integer - // rather than rounding down (since static_cast will round down). - char_vec[i] = static_cast(255.0 * value + 0.5); - } - WriteIntegerVector(os, binary, char_vec); - } else { - // the regular floating-point format will be more readable for text mode. - vec.Write(os, binary); - } -} - -// reads data written by WriteVectorAsChar. -static inline void ReadVectorAsChar(std::istream &is, - bool binary, - Vector *vec) { - if (binary) { - BaseFloat scale = 1.0 / 255.0; - std::vector char_vec; - ReadIntegerVector(is, binary, &char_vec); - int32 dim = char_vec.size(); - vec->Resize(dim, kUndefined); - BaseFloat *data = vec->Data(); - for (int32 i = 0; i < dim; i++) - data[i] = scale * char_vec[i]; - } else { - vec->Read(is, binary); - } -} - - - void NnetChainSupervision::Write(std::ostream &os, bool binary) const { CheckDim(); WriteToken(os, binary, ""); diff --git a/src/nnet3/nnet-component-itf.cc b/src/nnet3/nnet-component-itf.cc index 151433f2c62..73267049e1c 100644 --- a/src/nnet3/nnet-component-itf.cc +++ b/src/nnet3/nnet-component-itf.cc @@ -58,6 +58,10 @@ Component* Component::NewComponentOfType(const std::string &component_type) { ans = new SoftmaxComponent(); } else if (component_type == "LogSoftmaxComponent") { ans = new LogSoftmaxComponent(); + } else if (component_type == "LogComponent") { + ans = new LogComponent(); + } else if (component_type == "ExpComponent") { + ans = new ExpComponent(); } else if (component_type == "RectifiedLinearComponent") { ans = new RectifiedLinearComponent(); } else if (component_type == "NormalizeComponent") { @@ -70,6 +74,8 @@ Component* 
Component::NewComponentOfType(const std::string &component_type) { ans = new AffineComponent(); } else if (component_type == "NaturalGradientAffineComponent") { ans = new NaturalGradientAffineComponent(); + } else if (component_type == "NaturalGradientPositiveAffineComponent") { + ans = new NaturalGradientPositiveAffineComponent(); } else if (component_type == "PerElementScaleComponent") { ans = new PerElementScaleComponent(); } else if (component_type == "NaturalGradientPerElementScaleComponent") { diff --git a/src/nnet3/nnet-component-itf.h b/src/nnet3/nnet-component-itf.h index 93cc0769bf6..b18ffeaa1b4 100644 --- a/src/nnet3/nnet-component-itf.h +++ b/src/nnet3/nnet-component-itf.h @@ -74,10 +74,12 @@ enum ComponentProperties { // forward-pass output (e.g. true for Sigmoid). kBackpropInPlace = 0x400, // true if we can do the backprop operation in-place // (input and output matrices may be the same). - kStoresStats = 0x800 // true if the StoreStats operation stores + kStoresStats = 0x800, // true if the StoreStats operation stores // statistics e.g. on average node activations and // derivatives of the nonlinearity, (as it does for // Tanh, Sigmoid, ReLU and Softmax). + kSparsityPrior = 0x1000, + kPositiveLinearParameters = 0x2000 }; diff --git a/src/nnet3/nnet-diagnostics.cc b/src/nnet3/nnet-diagnostics.cc index 7f7d485ffe0..f690d534431 100644 --- a/src/nnet3/nnet-diagnostics.cc +++ b/src/nnet3/nnet-diagnostics.cc @@ -95,9 +95,17 @@ void NnetComputeProb::ProcessOutputs(const NnetExample &eg, { BaseFloat tot_weight, tot_objf; bool supply_deriv = config_.compute_deriv; + + CuMatrix nnet_output_deriv(output.NumRows(), output.NumCols()); + ComputeObjectiveFunction(io.features, obj_type, io.name, - supply_deriv, computer, - &tot_weight, &tot_objf); + output, + &tot_weight, &tot_objf, + supply_deriv ? 
&nnet_output_deriv : NULL); + + if (supply_deriv) + computer->AcceptOutputDeriv(io.name, &nnet_output_deriv); + SimpleObjectiveInfo &totals = objf_info_[io.name]; totals.tot_weight += tot_weight; totals.tot_objective += tot_objf; diff --git a/src/nnet3/nnet-example-utils.cc b/src/nnet3/nnet-example-utils.cc index 99d41fb06c4..7411d59d224 100644 --- a/src/nnet3/nnet-example-utils.cc +++ b/src/nnet3/nnet-example-utils.cc @@ -219,5 +219,43 @@ void GetComputationRequest(const Nnet &nnet, KALDI_ERR << "No outputs in computation request."; } +void WriteVectorAsChar(std::ostream &os, + bool binary, + const VectorBase &vec) { + if (binary) { + int32 dim = vec.Dim(); + std::vector char_vec(dim); + const BaseFloat *data = vec.Data(); + for (int32 i = 0; i < dim; i++) { + BaseFloat value = data[i]; + KALDI_ASSERT(value >= 0.0 && value <= 1.0); + // below, the adding 0.5 is done so that we round to the closest integer + // rather than rounding down (since static_cast will round down). + char_vec[i] = static_cast(255.0 * value + 0.5); + } + WriteIntegerVector(os, binary, char_vec); + } else { + // the regular floating-point format will be more readable for text mode. 
+ vec.Write(os, binary); + } +} + +void ReadVectorAsChar(std::istream &is, + bool binary, + Vector *vec) { + if (binary) { + BaseFloat scale = 1.0 / 255.0; + std::vector char_vec; + ReadIntegerVector(is, binary, &char_vec); + int32 dim = char_vec.size(); + vec->Resize(dim, kUndefined); + BaseFloat *data = vec->Data(); + for (int32 i = 0; i < dim; i++) + data[i] = scale * char_vec[i]; + } else { + vec->Read(is, binary); + } +} + } // namespace nnet3 } // namespace kaldi diff --git a/src/nnet3/nnet-example-utils.h b/src/nnet3/nnet-example-utils.h index d54c3296dac..afa68ee10dc 100644 --- a/src/nnet3/nnet-example-utils.h +++ b/src/nnet3/nnet-example-utils.h @@ -62,6 +62,16 @@ void GetComputationRequest(const Nnet &nnet, bool store_component_stats, ComputationRequest *computation_request); +// writes compressed as unsigned char a vector 'vec' that is required to have +// values between 0 and 1. +void WriteVectorAsChar(std::ostream &os, + bool binary, + const VectorBase &vec); + +// reads data written by WriteVectorAsChar. +void ReadVectorAsChar(std::istream &is, + bool binary, + Vector *vec); } // namespace nnet3 diff --git a/src/nnet3/nnet-example.cc b/src/nnet3/nnet-example.cc index 9a34258e0ee..87221b3dfbc 100644 --- a/src/nnet3/nnet-example.cc +++ b/src/nnet3/nnet-example.cc @@ -19,6 +19,7 @@ // limitations under the License. #include "nnet3/nnet-example.h" +#include "nnet3/nnet-example-utils.h" #include "lat/lattice-functions.h" #include "hmm/posterior.h" @@ -31,6 +32,8 @@ void NnetIo::Write(std::ostream &os, bool binary) const { WriteToken(os, binary, name); WriteIndexVector(os, binary, indexes); features.Write(os, binary); + WriteToken(os, binary, ""); // for DerivWeights. Want to save space. 
+ WriteVectorAsChar(os, binary, deriv_weights); WriteToken(os, binary, ""); KALDI_ASSERT(static_cast(features.NumRows()) == indexes.size()); } @@ -40,7 +43,14 @@ void NnetIo::Read(std::istream &is, bool binary) { ReadToken(is, binary, &name); ReadIndexVector(is, binary, &indexes); features.Read(is, binary); - ExpectToken(is, binary, ""); + std::string token; + ReadToken(is, binary, &token); + // in the future this back-compatibility code can be reworked. + if (token != "") { + KALDI_ASSERT(token == ""); + ReadVectorAsChar(is, binary, &deriv_weights); + ExpectToken(is, binary, ""); + } } bool NnetIo::operator == (const NnetIo &other) const { @@ -52,7 +62,8 @@ bool NnetIo::operator == (const NnetIo &other) const { Matrix this_mat, other_mat; features.GetMatrix(&this_mat); other.features.GetMatrix(&other_mat); - return ApproxEqual(this_mat, other_mat); + return ApproxEqual(this_mat, other_mat) && + deriv_weights.ApproxEqual(other.deriv_weights); } NnetIo::NnetIo(const std::string &name, @@ -65,10 +76,44 @@ NnetIo::NnetIo(const std::string &name, indexes[i].t = t_begin + i; } +NnetIo::NnetIo(const std::string &name, + const VectorBase &deriv_weights, + int32 t_begin, const MatrixBase &feats): + name(name), features(feats), deriv_weights(deriv_weights) { + int32 num_rows = feats.NumRows(); + KALDI_ASSERT(num_rows > 0); + indexes.resize(num_rows); // sets all n,t,x to zeros. + for (int32 i = 0; i < num_rows; i++) + indexes[i].t = t_begin + i; +} + void NnetIo::Swap(NnetIo *other) { name.swap(other->name); indexes.swap(other->indexes); features.Swap(&(other->features)); + deriv_weights.Swap(&(other->deriv_weights)); +} + +NnetIo::NnetIo(const std::string &name, + int32 t_begin, + const SparseMatrix &feats): + name(name), features(feats) { + int32 num_rows = feats.NumRows(); + KALDI_ASSERT(num_rows > 0); + indexes.resize(num_rows); // sets all n,t,x to zeros. 
+ for (int32 i = 0; i < num_rows; i++) + indexes[i].t = t_begin + i; +} + +NnetIo::NnetIo(const std::string &name, + const VectorBase &deriv_weights, + int32 t_begin, const SparseMatrix &feats): + name(name), features(feats), deriv_weights(deriv_weights) { + int32 num_rows = feats.NumRows(); + KALDI_ASSERT(num_rows > 0); + indexes.resize(num_rows); // sets all n,t,x to zeros. + for (int32 i = 0; i < num_rows; i++) + indexes[i].t = t_begin + i; } NnetIo::NnetIo(const std::string &name, @@ -85,7 +130,20 @@ NnetIo::NnetIo(const std::string &name, indexes[i].t = t_begin + i; } - +NnetIo::NnetIo(const std::string &name, + const VectorBase &deriv_weights, + int32 dim, + int32 t_begin, + const Posterior &labels): + name(name), deriv_weights(deriv_weights) { + int32 num_rows = labels.size(); + KALDI_ASSERT(num_rows > 0); + SparseMatrix sparse_feats(dim, labels); + features = sparse_feats; + indexes.resize(num_rows); // sets all n,t,x to zeros. + for (int32 i = 0; i < num_rows; i++) + indexes[i].t = t_begin + i; +} void NnetExample::Write(std::ostream &os, bool binary) const { // Note: weight, label, input_frames and spk_info are members. This is a diff --git a/src/nnet3/nnet-example.h b/src/nnet3/nnet-example.h index 1df7cd1e78e..eb57bdd6a11 100644 --- a/src/nnet3/nnet-example.h +++ b/src/nnet3/nnet-example.h @@ -44,6 +44,15 @@ struct NnetIo { /// The features or labels. GeneralMatrix may contain either a CompressedMatrix, /// a Matrix, or SparseMatrix (a SparseMatrix would be the natural format for posteriors). GeneralMatrix features; + + /// This is a vector of per-frame weights, required to be between 0 and 1, + /// that is applied to the derivative during training (but not during model + /// combination, where the derivatives need to agree with the computed objf + /// values for the optimization code to work). + /// If this vector is empty it means we're not applying per-frame weights, + /// so it's equivalent to a vector of all ones. 
This vector is written + /// to disk compactly as unsigned char. + Vector deriv_weights; /// This constructor creates NnetIo with name "name", indexes with n=0, x=0, /// and t values ranging from t_begin to t_begin + feats.NumRows() - 1, and @@ -51,6 +60,19 @@ struct NnetIo { /// represents. NnetIo(const std::string &name, int32 t_begin, const MatrixBase &feats); + + NnetIo(const std::string &name, + const VectorBase &deriv_weights, + int32 t_begin, const MatrixBase &feats); + + /// This constructor is similar to the above constructed, + /// but takes in sparse input features. + NnetIo(const std::string &name, + int32 t_begin, const SparseMatrix &feats); + + NnetIo(const std::string &name, + const VectorBase &deriv_weights, + int32 t_begin, const SparseMatrix &feats); /// This constructor sets "name" to the provided string, sets "indexes" with /// n=0, x=0, and t from t_begin to t_begin + labels.size() - 1, and the labels @@ -59,6 +81,12 @@ struct NnetIo { int32 dim, int32 t_begin, const Posterior &labels); + + NnetIo(const std::string &name, + const VectorBase &deriv_weights, + int32 dim, + int32 t_begin, + const Posterior &labels); void Swap(NnetIo *other); @@ -103,7 +131,6 @@ struct NnetExample { bool operator == (const NnetExample &other) const { return io == other.io; } }; - typedef TableWriter > NnetExampleWriter; typedef SequentialTableReader > SequentialNnetExampleReader; typedef RandomAccessTableReader > RandomAccessNnetExampleReader; diff --git a/src/nnet3/nnet-nnet.cc b/src/nnet3/nnet-nnet.cc index 8dea02b8918..597755f7764 100644 --- a/src/nnet3/nnet-nnet.cc +++ b/src/nnet3/nnet-nnet.cc @@ -73,8 +73,14 @@ std::string Nnet::GetAsConfigLine(int32 node_index, bool include_dim) const { node.descriptor.WriteConfig(ans, node_names_); if (include_dim) ans << " dim=" << node.Dim(*this); - ans << " objective=" << (node.u.objective_type == kLinear ? 
"linear" : - "quadratic"); + + if (node.u.objective_type == kLinear) + ans << " objective=linear"; + else if (node.u.objective_type == kQuadratic) + ans << " objective=quadratic"; + else if (node.u.objective_type == kCrossEntropy) + ans << " objective=xent"; + break; case kComponent: ans << "component-node name=" << name << " component=" @@ -100,8 +106,7 @@ bool Nnet::IsOutputNode(int32 node) const { int32 size = nodes_.size(); KALDI_ASSERT(node >= 0 && node < size); return (nodes_[node].node_type == kDescriptor && - (node + 1 == size || - nodes_[node + 1].node_type != kComponent)); + (nodes_[node + 1].node_type != kComponent)); } bool Nnet::IsInputNode(int32 node) const { @@ -385,6 +390,8 @@ void Nnet::ProcessOutputNodeConfigLine( nodes_[node_index].u.objective_type = kLinear; } else if (objective_type == "quadratic") { nodes_[node_index].u.objective_type = kQuadratic; + } else if (objective_type == "xent") { + nodes_[node_index].u.objective_type = kCrossEntropy; } else { KALDI_ERR << "Invalid objective type: " << objective_type; } diff --git a/src/nnet3/nnet-nnet.h b/src/nnet3/nnet-nnet.h index a48fbb26f88..672a871c563 100644 --- a/src/nnet3/nnet-nnet.h +++ b/src/nnet3/nnet-nnet.h @@ -49,7 +49,12 @@ namespace nnet3 { /// - Objective type kQuadratic is used to mean the objective function /// f(x, y) = -0.5 (x-y).(x-y), which is to be maximized, as in the kLinear /// case. -enum ObjectiveType { kLinear, kQuadratic }; +/// - Objective type kCrossEntropy is the objective function that is used +/// to learn a set of bernoulli random variables. 
+/// f(x, y) = x * y + (1-x) * Log(1-Exp(y)), where +/// x is the true probability of class 1 and +/// y is the predicted log probability of class 1 +enum ObjectiveType { kLinear, kQuadratic, kCrossEntropy }; enum NodeType { kInput, kDescriptor, kComponent, kDimRange, kNone }; diff --git a/src/nnet3/nnet-simple-component.cc b/src/nnet3/nnet-simple-component.cc index 533f962a6db..ed01a9554b1 100644 --- a/src/nnet3/nnet-simple-component.cc +++ b/src/nnet3/nnet-simple-component.cc @@ -2020,7 +2020,59 @@ void PerElementOffsetComponent::UnVectorize( offsets_.CopyFromVec(params); } +const BaseFloat LogComponent::kLogFloor = 1.0e-10; +void LogComponent::Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const { + // Apllies log function (x >= epsi ? log(x) : log(epsi)). + out->CopyFromMat(in); + out->ApplyFloor(kLogFloor); + out->ApplyLog(); +} + +void LogComponent::Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in_value, + const CuMatrixBase &out_value, + const CuMatrixBase &out_deriv, + Component *to_update, + CuMatrixBase *in_deriv) const { + if (in_deriv != NULL) { + CuMatrix divided_in_value(in_value), floored_in_value(in_value); + divided_in_value.Set(1.0); + floored_in_value.CopyFromMat(in_value); + floored_in_value.ApplyFloor(kLogFloor); // (x > epsi ? x : epsi) + + divided_in_value.DivElements(floored_in_value); // (x > epsi ? 1/x : 1/epsi) + in_deriv->CopyFromMat(in_value); + in_deriv->Add(-1.0 * kLogFloor); // (x - epsi) + in_deriv->ApplyHeaviside(); // (x > epsi ? 1 : 0) + in_deriv->MulElements(divided_in_value); // (dy/dx: x > epsi ? 
1/x : 0) + in_deriv->MulElements(out_deriv); // dF/dx = dF/dy * dy/dx + } +} + +void ExpComponent::Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const { + // Applied exp function + out->CopyFromMat(in); + out->ApplyExp(); +} + +void ExpComponent::Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &,//in_value, + const CuMatrixBase &out_value, + const CuMatrixBase &out_deriv, + Component *to_update, + CuMatrixBase *in_deriv) const { + if (in_deriv != NULL) { + in_deriv->CopyFromMat(out_value); + in_deriv->MulElements(out_deriv); + } +} NaturalGradientAffineComponent::NaturalGradientAffineComponent(): max_change_per_sample_(0.0), @@ -2072,10 +2124,15 @@ void NaturalGradientAffineComponent::Read(std::istream &is, bool binary) { ReadBasicType(is, binary, &max_change_scale_stats_); ReadToken(is, binary, &token); } - if (token != "" && - token != "") - KALDI_ERR << "Expected or " - << ", got " << token; + + std::ostringstream ostr_beg, ostr_end; + ostr_beg << "<" << Type() << ">"; // e.g. "" + ostr_end << ""; // e.g. "" + + if (token != ostr_end.str() && + token != ostr_beg.str()) + KALDI_ERR << "Expected " << ostr_beg.str() << " or " + << ostr_end.str() << ", got " << token; SetNaturalGradientConfigs(); } @@ -2221,7 +2278,10 @@ void NaturalGradientAffineComponent::Write(std::ostream &os, WriteBasicType(os, binary, active_scaling_count_); WriteToken(os, binary, ""); WriteBasicType(os, binary, max_change_scale_stats_); - WriteToken(os, binary, ""); + + std::ostringstream ostr_end; + ostr_end << ""; // e.g. 
"" + WriteToken(os, binary, ostr_end.str()); } std::string NaturalGradientAffineComponent::Info() const { @@ -2264,7 +2324,7 @@ NaturalGradientAffineComponent::NaturalGradientAffineComponent( SetNaturalGradientConfigs(); } -void NaturalGradientAffineComponent::Update( +BaseFloat NaturalGradientAffineComponent::Update( const std::string &debug_info, const CuMatrixBase &in_value, const CuMatrixBase &out_deriv) { @@ -2315,6 +2375,7 @@ void NaturalGradientAffineComponent::Update( precon_ones, 1.0); linear_params_.AddMatMat(local_lrate, out_deriv_temp, kTrans, in_value_precon_part, kNoTrans, 1.0); + return local_lrate; } void NaturalGradientAffineComponent::ZeroStats() { @@ -2342,6 +2403,404 @@ void NaturalGradientAffineComponent::Add(BaseFloat alpha, const Component &other bias_params_.AddVec(alpha, other->bias_params_); } +NaturalGradientPositiveAffineComponent::NaturalGradientPositiveAffineComponent(): + NaturalGradientAffineComponent(), ensure_positive_linear_component_(false), + sparsity_constant_(0.0) { } + +void NaturalGradientPositiveAffineComponent::Read(std::istream &is, bool binary) { + std::ostringstream ostr_beg, ostr_end; + ostr_beg << "<" << Type() << ">"; // e.g. "" + ostr_end << ""; // e.g. "" + // might not see the "" part because + // of how ReadNew() works. 
+ std::string token; + ReadToken(is, binary, &token); + if (token == ostr_beg.str()) { + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &ensure_positive_linear_component_); + ReadToken(is, binary, &token); + } else if (token == "") { + ReadBasicType(is, binary, &ensure_positive_linear_component_); + ReadToken(is, binary, &token); + } // else KALDI_ASSERT(token == ""); + if (token == "") { + ReadBasicType(is, binary, &sparsity_constant_); + ReadToken(is, binary, &token); + } + if (token != "") { + KALDI_ERR << "Expecting token ; got " << token; + } + ReadBasicType(is, binary, &learning_rate_); + ExpectToken(is, binary, ""); + linear_params_.Read(is, binary); + ExpectToken(is, binary, ""); + bias_params_.Read(is, binary); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &rank_in_); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &rank_out_); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &update_period_); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &num_samples_history_); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &alpha_); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &max_change_per_sample_); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &is_gradient_); + ReadToken(is, binary, &token); + if (token == "") { + ReadBasicType(is, binary, &update_count_); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &active_scaling_count_); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &max_change_scale_stats_); + ExpectToken(is, binary, ostr_end.str()); + } else { + if (token != ostr_end.str()) + KALDI_ERR << "Expected " + << ostr_end.str() << ", got " << token; + } + SetNaturalGradientConfigs(); +} + +void NaturalGradientPositiveAffineComponent::InitFromConfig(ConfigLine *cfl) { + bool ok = true; + std::string matrix_filename; + BaseFloat learning_rate = learning_rate_; + BaseFloat num_samples_history = 2000.0, alpha = 4.0, + max_change_per_sample = 0.0; + 
int32 input_dim = -1, output_dim = -1, rank_in = 20, rank_out = 80, + update_period = 4; + bool ensure_positive_linear_component = false; + BaseFloat sparsity_constant = 0.0; + cfl->GetValue("learning-rate", &learning_rate); // optional. + cfl->GetValue("num-samples-history", &num_samples_history); + cfl->GetValue("alpha", &alpha); + cfl->GetValue("max-change-per-sample", &max_change_per_sample); + cfl->GetValue("rank-in", &rank_in); + cfl->GetValue("rank-out", &rank_out); + cfl->GetValue("update-period", &update_period); + cfl->GetValue("ensure-positive-linear-component", &ensure_positive_linear_component); + cfl->GetValue("sparsity-constant", &sparsity_constant); + + if (cfl->GetValue("matrix", &matrix_filename)) { + Init(rank_in, rank_out, update_period, + num_samples_history, alpha, max_change_per_sample, + ensure_positive_linear_component, sparsity_constant, + matrix_filename); + if (cfl->GetValue("input-dim", &input_dim)) + KALDI_ASSERT(input_dim == InputDim() && + "input-dim mismatch vs. matrix."); + if (cfl->GetValue("output-dim", &output_dim)) + KALDI_ASSERT(output_dim == OutputDim() && + "output-dim mismatch vs. 
matrix."); + } else { + ok = ok && cfl->GetValue("input-dim", &input_dim); + ok = ok && cfl->GetValue("output-dim", &output_dim); + BaseFloat param_stddev = 1.0 / std::sqrt(input_dim), + bias_stddev = 1.0, bias_mean = 0.0; + cfl->GetValue("param-stddev", ¶m_stddev); + cfl->GetValue("bias-stddev", &bias_stddev); + cfl->GetValue("bias-mean", &bias_mean); + Init(input_dim, output_dim, param_stddev, + bias_mean, bias_stddev, rank_in, rank_out, update_period, + num_samples_history, alpha, max_change_per_sample, + ensure_positive_linear_component, sparsity_constant); + } + if (cfl->HasUnusedValues()) + KALDI_ERR << "Could not process these elements in initializer: " + << cfl->UnusedValues(); + if (!ok) + KALDI_ERR << "Bad initializer " << cfl->WholeLine(); +} + +void NaturalGradientPositiveAffineComponent::Init( + int32 rank_in, int32 rank_out, + int32 update_period, BaseFloat num_samples_history, BaseFloat alpha, + BaseFloat max_change_per_sample, bool ensure_positive_linear_component, + BaseFloat sparsity_constant, std::string matrix_filename) { + NaturalGradientAffineComponent::Init(rank_in, rank_out, update_period, + num_samples_history, alpha, max_change_per_sample, matrix_filename); + ensure_positive_linear_component_ = ensure_positive_linear_component; + sparsity_constant_ = sparsity_constant; + SetPositive(NaturalGradientPositiveAffineComponent::kFloor); +} + +void NaturalGradientPositiveAffineComponent::Init( + int32 input_dim, int32 output_dim, + BaseFloat param_stddev, + BaseFloat bias_mean, BaseFloat bias_stddev, + int32 rank_in, int32 rank_out, int32 update_period, + BaseFloat num_samples_history, BaseFloat alpha, + BaseFloat max_change_per_sample, + bool ensure_positive_linear_component, BaseFloat sparsity_constant) { + NaturalGradientAffineComponent::Init(input_dim, output_dim, param_stddev, + bias_mean, bias_stddev, rank_in, rank_out, update_period, + num_samples_history, alpha, max_change_per_sample); + ensure_positive_linear_component_ = 
ensure_positive_linear_component; + sparsity_constant_ = sparsity_constant; + SetPositive(NaturalGradientPositiveAffineComponent::kAbsoluteValue); +} + +void NaturalGradientPositiveAffineComponent::Write(std::ostream &os, bool binary) const { + std::ostringstream ostr_beg, ostr_end; + ostr_beg << "<" << Type() << ">"; // e.g. "" + ostr_end << ""; // e.g. "" + WriteToken(os, binary, ostr_beg.str()); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, ensure_positive_linear_component_); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, sparsity_constant_); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, learning_rate_); + WriteToken(os, binary, ""); + linear_params_.Write(os, binary); + WriteToken(os, binary, ""); + bias_params_.Write(os, binary); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, rank_in_); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, rank_out_); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, update_period_); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, num_samples_history_); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, alpha_); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, max_change_per_sample_); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, is_gradient_); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, update_count_); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, active_scaling_count_); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, max_change_scale_stats_); + WriteToken(os, binary, ostr_end.str()); +} + +std::string NaturalGradientPositiveAffineComponent::Info() const { + std::stringstream stream; + stream << NaturalGradientAffineComponent::Info() + << ", ensure-positive-linear-component=" + << ensure_positive_linear_component_ + << ", sparsity=constant=" << sparsity_constant_; + return stream.str(); +} + +Component* NaturalGradientPositiveAffineComponent::Copy() const { + 
NaturalGradientPositiveAffineComponent *ans = new NaturalGradientPositiveAffineComponent(); + ans->learning_rate_ = learning_rate_; + ans->rank_in_ = rank_in_; + ans->rank_out_ = rank_out_; + ans->update_period_ = update_period_; + ans->num_samples_history_ = num_samples_history_; + ans->alpha_ = alpha_; + ans->linear_params_ = linear_params_; + ans->bias_params_ = bias_params_; + ans->preconditioner_in_ = preconditioner_in_; + ans->preconditioner_out_ = preconditioner_out_; + ans->max_change_per_sample_ = max_change_per_sample_; + ans->is_gradient_ = is_gradient_; + ans->update_count_ = update_count_; + ans->active_scaling_count_ = active_scaling_count_; + ans->max_change_scale_stats_ = max_change_scale_stats_; + ans->SetNaturalGradientConfigs(); + ans->ensure_positive_linear_component_ = ensure_positive_linear_component_; + ans->sparsity_constant_ = sparsity_constant_; + return ans; +} + +void NaturalGradientPositiveAffineComponent::Backprop( + const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in_value, + const CuMatrixBase &, // out_value + const CuMatrixBase &out_deriv, + Component *to_update_in, + CuMatrixBase *in_deriv) const { + NaturalGradientPositiveAffineComponent *to_update = + dynamic_cast(to_update_in); + + // Propagate the derivative back to the input. + // add with coefficient 1.0 since property kBackpropAdds is true. + // If we wanted to add with coefficient 0.0 we'd need to zero the + // in_deriv, in case of infinities. 
+ if (in_deriv) + in_deriv->AddMatMat(1.0, out_deriv, kNoTrans, linear_params_, kNoTrans, + 1.0); + + if (to_update != NULL) { + to_update->Update(debug_info, in_value, out_deriv, linear_params_); + } +} + +BaseFloat NaturalGradientPositiveAffineComponent::Update( + const std::string &debug_info, + const CuMatrixBase &in_value, + const CuMatrixBase &out_deriv, + const CuMatrixBase &linear_params) { + BaseFloat local_lrate = NaturalGradientAffineComponent::Update(debug_info, in_value, out_deriv); + CuMatrix sign_linear_params(linear_params); + sign_linear_params.ApplySignum(); + if (sparsity_constant_ > 0.0) + linear_params_.AddMat(-sparsity_constant_ / in_value.NumRows() * local_lrate, sign_linear_params); + return local_lrate; +} + +void NaturalGradientPositiveAffineComponent::Scale(BaseFloat scale) { + if (ensure_positive_linear_component_ && scale < 0) + KALDI_ERR << "Scaling a positive linear component by a negative value!"; + NaturalGradientAffineComponent::Scale(scale); +} + +void NaturalGradientPositiveAffineComponent::Add(BaseFloat alpha, const Component &other_in) { + const NaturalGradientPositiveAffineComponent *other = + dynamic_cast(&other_in); + KALDI_ASSERT(other); + if (ensure_positive_linear_component_ && !other->PositiveLinearComponentEnsured()) { + KALDI_ERR << "Trying to add a positive linear component with a non-positive one"; + } + NaturalGradientAffineComponent::Add(alpha, other_in); +} + +void NaturalGradientPositiveAffineComponent::SetPositive(NaturalGradientPositiveAffineComponent::PositivityMethod method) { + if (ensure_positive_linear_component_) { + if (method == kFloor) linear_params_.ApplyFloor(0.0); + else if (method == kAbsoluteValue) linear_params_.ApplyPowAbs(1.0); + } +} + +/* +void NaturalGradientLogExpAffineComponent::Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const { + + CuVector in_max(in.NumRows()); + in_max.ComputeMaxPerRow(in); + + CuMatrix in_temp(in.NumRows(), 
in.NumCols() + 1); + CuSubMatrix in_linear(in_temp, 0, in.NumRows(), 0, in.NumCols()); + in_linear.CopyFromMat(in); + in_linear.AddVecToCols(-1.0, in_max); + + CuMatrix params(linear_params_.NumRows(), linear_params_.NumCols() + 1); + params.Range(0, linear_params_.NumRows(), 0, linear_params_.NumCols()).CopyFromMat(linear_params_); + params.CopyColFromVec(linear_params_.NumCols(), bias_params_); + + // Think of bias_params_ as being stored in log-domain + // No need for asserts as they'll happen within the matrix operations. + out->CopyRowsFromVec(bias_params_); // copies bias_params_ to each row + // of *out. + + out->LogExpAddMatMat(1.0, in, kNoTrans, kLogUnit, linear_params_, kTrans, kNoLogUnit, 1.0, NaturalGradientPositiveAffineComponent::kFloor); +} + +void NaturalGradientLogExpAffineComponent::Backprop( + const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in_value, + const CuMatrixBase &out_value, + const CuMatrixBase &out_deriv, + Component *to_update_in, + CuMatrixBase *in_deriv) const { + NaturalGradientLogExpAffineComponent *to_update = + dynamic_cast(to_update_in); + + CuVector in_max; + in_max.ComputeRowMax(in_value); + + CuMatrix in(in_value); + in.AddVecToCols(-1.0, in_max, 0.0); + in.ApplyExp(); + + CuMatrix out(out_value); + out.AddVecToCols(-1.0, in_max, 0.0); + + out.ApplyExp(); + out.MulElements(1.0, out_deriv); + + if (in_deriv) { + CuMatrix this_in_deriv(in_deriv->NumRows(), in_deriv->NumCols()); + this_in_deriv.AddMatMat(-1.0, out, kNoTrans, linear_params_, kNoTrans, 0.0); + this_in_deriv.MulElements(in); + in_deriv->AddMat(-1.0, this_in_deriv, kNoTrans); + } + + if (to_update) { + to_update->Update(debug_info, in, out, linear_params_); + } +} + +BaseFloat NaturalGradientLogExpAffineComponent::Update( + const std::string &debug_info, + const CuMatrixBase &in, + const CuMatrixBase &out, + const CuMatrixBase &linear_params) { + + linear_params_.AddMatMat(1.0, in, kTrans, out, kNoTrans); + + 
CuMatrix in_value_temp; + + in_value_temp.Resize(in_value.NumRows(), + in_value.NumCols() + 1, kUndefined); + in_value_temp.Range(0, in_value.NumRows(), + 0, in_value.NumCols()).CopyFromMat(in_value); + + // Add the 1.0 at the end of each row "in_value_temp" + in_value_temp.Range(0, in_value.NumRows(), + in_value.NumCols(), 1).Set(1.0); + + CuMatrix out_deriv_temp(out_deriv); + + CuMatrix row_products(2, + in_value.NumRows()); + CuSubVector in_row_products(row_products, 0), + out_row_products(row_products, 1); + + // These "scale" values will get multiplied into the learning rate (faster + // than having the matrices scaled inside the preconditioning code). + BaseFloat in_scale, out_scale; + + preconditioner_in_.PreconditionDirections(&in_value_temp, &in_row_products, + &in_scale); + preconditioner_out_.PreconditionDirections(&out_deriv_temp, &out_row_products, + &out_scale); + + // "scale" is a scaling factor coming from the PreconditionDirections calls + // (it's faster to have them output a scaling factor than to have them scale + // their outputs). + BaseFloat scale = in_scale * out_scale; + + CuSubMatrix in_value_precon_part(in_value_temp, + 0, in_value_temp.NumRows(), + 0, in_value_temp.NumCols() - 1); + // this "precon_ones" is what happens to the vector of 1's representing + // offsets, after multiplication by the preconditioner. 
+ CuVector precon_ones(in_value_temp.NumRows()); + + precon_ones.CopyColFromMat(in_value_temp, in_value_temp.NumCols() - 1); + + BaseFloat local_lrate = scale * learning_rate_; + update_count_ += 1.0; + bias_params_.AddMatVec(local_lrate, out_deriv_temp, kTrans, + precon_ones, 1.0); + linear_params_.AddMatMat(local_lrate, out_deriv_temp, kTrans, + in_value_precon_part, kNoTrans, 1.0); + return local_lrate; + + BaseFloat local_lrate = NaturalGradientAffineComponent::Update(debug_info, in_value, out_deriv); + CuMatrix sign_linear_params(linear_params); + sign_linear_params.ApplySignum(); + if (sparsity_constant_ > 0.0) + linear_params_.AddMat(-sparsity_constant_ / in_value.NumRows() * local_lrate, sign_linear_params); + return local_lrate; +} +*/ + std::string FixedAffineComponent::Info() const { std::ostringstream stream; stream << Component::Info(); @@ -2651,7 +3110,7 @@ void FixedScaleComponent::Init(const CuVectorBase &scales) { void FixedScaleComponent::InitFromConfig(ConfigLine *cfl) { - std::string filename; + std::string filename; BaseFloat scale; // Accepts "scales" config (for filename) or "dim" -> random init, for testing. 
if (cfl->GetValue("scales", &filename)) { if (cfl->HasUnusedValues()) @@ -2660,6 +3119,15 @@ void FixedScaleComponent::InitFromConfig(ConfigLine *cfl) { CuVector vec; ReadKaldiObject(filename, &vec); Init(vec); + } else if (cfl->GetValue("scale", &scale)) { + int32 dim; + if (!cfl->GetValue("dim", &dim) || cfl->HasUnusedValues()) + KALDI_ERR << "Invalid initializer for layer of type " + << Type() << ": \"" << cfl->WholeLine() << "\""; + KALDI_ASSERT(dim > 0); + CuVector vec(dim); + vec.Add(scale); + Init(vec); } else { int32 dim; if (!cfl->GetValue("dim", &dim) || cfl->HasUnusedValues()) diff --git a/src/nnet3/nnet-simple-component.h b/src/nnet3/nnet-simple-component.h index 4ba4a7d1c0b..47bb3e527a8 100644 --- a/src/nnet3/nnet-simple-component.h +++ b/src/nnet3/nnet-simple-component.h @@ -380,11 +380,12 @@ class AffineComponent: public UpdatableComponent { friend class NaturalGradientAffineComponent; // This function Update() is for extensibility; child classes may override // this, e.g. for natural gradient update. - virtual void Update( + virtual BaseFloat Update( const std::string &debug_info, const CuMatrixBase &in_value, const CuMatrixBase &out_deriv) { UpdateSimple(in_value, out_deriv); + return 0.0; // child classes may return local learning rate } // UpdateSimple is used when *this is a gradient. Child classes may override // this if needed, but typically won't need to. @@ -626,6 +627,61 @@ class LogSoftmaxComponent: public NonlinearComponent { LogSoftmaxComponent &operator = (const LogSoftmaxComponent &other); // Disallow. 
}; +// The LogComponent outputs the log of input values as y = Log(max(x, epsi)) +class LogComponent: public NonlinearComponent { + public: + explicit LogComponent(int32 dim): NonlinearComponent(dim) { } + explicit LogComponent(const LogComponent &other): + NonlinearComponent(other) { } + LogComponent() { } + virtual std::string Type() const { return "LogComponent"; } + virtual int32 Properties() const { + return kSimpleComponent|kBackpropNeedsInput|kStoresStats; + } + virtual void Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const; + virtual void Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in_value, + const CuMatrixBase &out_value, + const CuMatrixBase &out_deriv, + Component *to_update, + CuMatrixBase *in_deriv) const; + + virtual Component* Copy() const { return new LogComponent(*this); } + private: + LogComponent &operator = (const LogComponent &other); // Disallow. 
+ static const BaseFloat kLogFloor; +}; + +// The ExpComponent outputs the exp of input values as y = Exp(x) +class ExpComponent: public NonlinearComponent { + public: + explicit ExpComponent(int32 dim): NonlinearComponent(dim) { } + explicit ExpComponent(const ExpComponent &other): + NonlinearComponent(other) { } + ExpComponent() { } + virtual std::string Type() const { return "ExpComponent"; } + virtual int32 Properties() const { + return kSimpleComponent|kBackpropNeedsOutput|kStoresStats; + } + virtual void Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const; + virtual void Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &, + const CuMatrixBase &out_value, + const CuMatrixBase &, + Component *to_update, + CuMatrixBase *in_deriv) const; + + virtual Component* Copy() const { return new ExpComponent(*this); } + private: + ExpComponent &operator = (const ExpComponent &other); // Disallow. +}; + /// Keywords: natural gradient descent, NG-SGD, naturalgradient. For /// the top-level of the natural gradient code look here, and also in /// nnet-precondition-online.h. @@ -662,6 +718,9 @@ class NaturalGradientAffineComponent: public AffineComponent { const NaturalGradientAffineComponent &other); virtual void ZeroStats(); + protected: + friend class NaturalGradientPositiveAffineComponent; + private: // disallow assignment operator. NaturalGradientAffineComponent &operator= ( @@ -706,12 +765,113 @@ class NaturalGradientAffineComponent: public AffineComponent { // from the class variables. 
void SetNaturalGradientConfigs(); - virtual void Update( + // returns the local learning rate used + virtual BaseFloat Update( const std::string &debug_info, const CuMatrixBase &in_value, const CuMatrixBase &out_deriv); }; +class NaturalGradientPositiveAffineComponent: public NaturalGradientAffineComponent { + public: + enum PositivityMethod { kFloor, kAbsoluteValue }; + virtual std::string Type() const { return "NaturalGradientPositiveAffineComponent"; } + virtual void Read(std::istream &is, bool binary); + virtual void Write(std::ostream &os, bool binary) const; + + void Init(int32 input_dim, int32 output_dim, + BaseFloat param_stddev, + BaseFloat bias_mean, BaseFloat bias_stddev, + int32 rank_in, int32 rank_out, int32 update_period, + BaseFloat num_samples_history, BaseFloat alpha, + BaseFloat max_change_per_sample, + bool ensure_positive_linear_component, + BaseFloat sparsity_constant = 0.0); + + void Init(int32 rank_in, + int32 rank_out, int32 update_period, + BaseFloat num_samples_history, + BaseFloat alpha, BaseFloat max_change_per_sample, + bool ensure_positive_linear_component, + BaseFloat sparsity_constant, + std::string matrix_filename); + + virtual void InitFromConfig(ConfigLine *cfl); + virtual std::string Info() const; + virtual Component* Copy() const; + virtual void Scale(BaseFloat scale); + virtual void Add(BaseFloat alpha, const Component &other); + NaturalGradientPositiveAffineComponent(); + + void Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in_value, + const CuMatrixBase &, // out_value + const CuMatrixBase &out_deriv, + Component *to_update_in, + CuMatrixBase *in_deriv) const; + + // Make the linear_params_ to be positive using the method defined by + // PositivityMethod. + // The normal way this is done is to apply flooring at 0.0 (kFloor). + // But during initialization, the parameters are made positive by taking + // the absolute value (kAbsoluteValue). 
+ void SetPositive(PositivityMethod method = kFloor); + void SetSparsityConstant(BaseFloat sparsity_constant) { + sparsity_constant_ = sparsity_constant; } + bool PositiveLinearComponentEnsured() const { + return ensure_positive_linear_component_; } + + int32 Properties() const { + return kSimpleComponent|kUpdatableComponent|kLinearInParameters| + kBackpropNeedsInput|kBackpropAdds| + kSparsityPrior|kPositiveLinearParameters; + } + + protected: + KALDI_DISALLOW_COPY_AND_ASSIGN(NaturalGradientPositiveAffineComponent); + + bool ensure_positive_linear_component_; + BaseFloat sparsity_constant_; + + virtual BaseFloat Update( + const std::string &debug_info, + const CuMatrixBase &in_value, + const CuMatrixBase &out_deriv, + const CuMatrixBase &linear_params); + + // // Add L1 regularization penalty term. This is assumed to be called by the + // // Update method after the normal update without considering the L1 penalty is + // // done. + // void AddPenalty(BaseFloat sparsity_constant, BaseFloat local_lrate); + friend class NaturalGradientLogExpAffineComponent; +}; + +class NaturalGradientLogExpAffineComponent: public NaturalGradientPositiveAffineComponent { + public: + virtual std::string Type() const { return "NaturalGradientLogExpAffineComponent"; } + + virtual void Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const; + + virtual void Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in_value, + const CuMatrixBase &, // out_value + const CuMatrixBase &out_deriv, + Component *to_update_in, + CuMatrixBase *in_deriv) const; + + int32 Properties() const { + return kSimpleComponent|kUpdatableComponent| + kBackpropNeedsInput|kBackpropAdds; + } + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(NaturalGradientLogExpAffineComponent); + +}; /// FixedAffineComponent is an affine transform that is supplied /// at network initialization time and is not trainable. 
diff --git a/src/nnet3/nnet-test-utils.cc b/src/nnet3/nnet-test-utils.cc index 98ba77822d2..86b37c3a7a5 100644 --- a/src/nnet3/nnet-test-utils.cc +++ b/src/nnet3/nnet-test-utils.cc @@ -870,7 +870,7 @@ void ComputeExampleComputationRequestSimple( static void GenerateRandomComponentConfig(std::string *component_type, std::string *config) { - int32 n = RandInt(0, 26); + int32 n = RandInt(0, 28); BaseFloat learning_rate = 0.001 * RandInt(1, 3); std::ostringstream os; @@ -1136,6 +1136,27 @@ static void GenerateRandomComponentConfig(std::string *component_type, << " pool-z-step=" << pool_z_step; break; } + case 27: { + *component_type = "ExpComponent"; + os << "dim=" << RandInt(1, 50); + break; + } + case 28: { + *component_type = "LogComponent"; + os << "dim=" << RandInt(1, 50); + } + //case 28: { + // break; + // *component_type = "NaturalGradientPositiveAffineComponent"; + // int32 input_dim = RandInt(1, 50), output_dim = RandInt(1, 50); + // bool ensure_positive_linear_component = true; + // BaseFloat sparsity_constant = std::abs(RandGauss()); + // os << "input-dim=" << input_dim << " output-dim=" << output_dim + // << " learning-rate=" << learning_rate + // << " sparsity-constant=" << sparsity_constant + // << " ensure-positive-linear-component=" << (ensure_positive_linear_component ? 
"true" : "false"); + // break; + //} default: KALDI_ERR << "Error generating random component"; } diff --git a/src/nnet3/nnet-training.cc b/src/nnet3/nnet-training.cc index 359254f4794..f3ed861271f 100644 --- a/src/nnet3/nnet-training.cc +++ b/src/nnet3/nnet-training.cc @@ -96,9 +96,25 @@ void NnetTrainer::ProcessOutputs(const NnetExample &eg, ObjectiveType obj_type = nnet_->GetNode(node_index).u.objective_type; BaseFloat tot_weight, tot_objf; bool supply_deriv = true; - ComputeObjectiveFunction(io.features, obj_type, io.name, - supply_deriv, computer, - &tot_weight, &tot_objf); + + const CuMatrixBase &nnet_output = computer->GetOutput(io.name); + CuMatrix nnet_output_deriv(nnet_output.NumRows(), + nnet_output.NumCols(), + kUndefined); + + ComputeObjectiveFunction(io.features, obj_type, io.name, nnet_output, + &tot_weight, &tot_objf, + supply_deriv ? &nnet_output_deriv : NULL); + + if (supply_deriv) { + if (config_.apply_deriv_weights && io.deriv_weights.Dim() != 0) { + CuVector cu_deriv_weights(io.deriv_weights); + nnet_output_deriv.MulRowsVec(cu_deriv_weights); + } + + computer->AcceptOutputDeriv(io.name, &nnet_output_deriv); + } + objf_info_[io.name].UpdateStats(io.name, config_.print_interval, num_minibatches_processed_++, tot_weight, tot_objf); @@ -192,18 +208,52 @@ NnetTrainer::~NnetTrainer() { void ComputeObjectiveFunction(const GeneralMatrix &supervision, ObjectiveType objective_type, const std::string &output_name, - bool supply_deriv, - NnetComputer *computer, + const CuMatrixBase &output, BaseFloat *tot_weight, - BaseFloat *tot_objf) { - const CuMatrixBase &output = computer->GetOutput(output_name); - + BaseFloat *tot_objf, + CuMatrixBase *output_deriv) { if (output.NumCols() != supervision.NumCols()) KALDI_ERR << "Nnet versus example output dimension (num-classes) " << "mismatch for '" << output_name << "': " << output.NumCols() << " (nnet) vs. 
" << supervision.NumCols() << " (egs)\n"; switch (objective_type) { + case kCrossEntropy: { + // objective is x * log(y) + (1-x) * log(1-y) + CuMatrix cu_post(supervision.NumRows(), supervision.NumCols(), + kUndefined); // x + cu_post.CopyFromGeneralMat(supervision); + + CuMatrix n_cu_post(cu_post.NumRows(), cu_post.NumCols()); + n_cu_post.Set(1.0); + n_cu_post.AddMat(-1.0, cu_post); // 1-x + + CuMatrix log_prob(output); // y + log_prob.ApplyLog(); // log(y) + + CuMatrix n_output(output.NumRows(), output.NumCols(), kSetZero); + n_output.Set(1.0); + n_output.AddMat(-1.0, output); // 1-y + n_output.ApplyLog(); // log(1-y) + + *tot_weight = cu_post.NumRows() * cu_post.NumCols(); + *tot_objf = TraceMatMat(log_prob, cu_post, kTrans) + + TraceMatMat(n_output, n_cu_post, kTrans); + + if (output_deriv) { + // deriv is x / y - (1-x) / (1-y) + n_output.ApplyExp(); // 1-y + n_cu_post.DivElements(n_output); // 1-x / (1-y) + + log_prob.ApplyExp(); // y + cu_post.DivElements(log_prob); // x / y + + output_deriv->CopyFromMat(cu_post); // x / y + output_deriv->AddMat(-1.0, n_cu_post); // x / y - (1-x) / (1-y) + } + + break; + } case kLinear: { // objective is x * y. switch (supervision.Type()) { @@ -215,33 +265,40 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, // of log-likelihoods that are normalized to sum to one. *tot_weight = cu_post.Sum(); *tot_objf = TraceMatSmat(output, cu_post, kTrans); - if (supply_deriv) { - CuMatrix output_deriv(output.NumRows(), output.NumCols(), - kUndefined); - cu_post.CopyToMat(&output_deriv); - computer->AcceptOutputDeriv(output_name, &output_deriv); + if (output_deriv) { + cu_post.CopyToMat(output_deriv); } break; } case kFullMatrix: { // there is a redundant matrix copy in here if we're not using a GPU // but we don't anticipate this code branch being used in many cases. 
- CuMatrix cu_post(supervision.GetFullMatrix()); - *tot_weight = cu_post.Sum(); - *tot_objf = TraceMatMat(output, cu_post, kTrans); - if (supply_deriv) - computer->AcceptOutputDeriv(output_name, &cu_post); + if (output_deriv) { + supervision.CopyToMat(output_deriv); + CuMatrixBase &cu_post = *output_deriv; + *tot_weight = cu_post.Sum(); + *tot_objf = TraceMatMat(output, cu_post, kTrans); + } else { + CuMatrix cu_post(supervision.GetFullMatrix()); + *tot_weight = cu_post.Sum(); + *tot_objf = TraceMatMat(output, cu_post, kTrans); + } break; } case kCompressedMatrix: { Matrix post; supervision.GetMatrix(&post); - CuMatrix cu_post; - cu_post.Swap(&post); - *tot_weight = cu_post.Sum(); - *tot_objf = TraceMatMat(output, cu_post, kTrans); - if (supply_deriv) - computer->AcceptOutputDeriv(output_name, &cu_post); + if (output_deriv) { + output_deriv->CopyFromMat(post); + CuMatrixBase &cu_post = *output_deriv; + *tot_weight = cu_post.Sum(); + *tot_objf = TraceMatMat(output, cu_post, kTrans); + } else { + CuMatrix cu_post; + cu_post.Swap(&post); + *tot_weight = cu_post.Sum(); + *tot_objf = TraceMatMat(output, cu_post, kTrans); + } break; } } @@ -249,15 +306,21 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, } case kQuadratic: { // objective is -0.5 (x - y)^2 - CuMatrix diff(supervision.NumRows(), - supervision.NumCols(), - kUndefined); - diff.CopyFromGeneralMat(supervision); - diff.AddMat(-1.0, output); - *tot_weight = diff.NumRows(); - *tot_objf = -0.5 * TraceMatMat(diff, diff, kTrans); - if (supply_deriv) - computer->AcceptOutputDeriv(output_name, &diff); + if (output_deriv) { + CuMatrixBase &diff = *output_deriv; + diff.CopyFromGeneralMat(supervision); + diff.AddMat(-1.0, output); + *tot_weight = diff.NumRows(); + *tot_objf = -0.5 * TraceMatMat(diff, diff, kTrans); + } else { + CuMatrix diff(supervision.NumRows(), + supervision.NumCols(), + kUndefined); + diff.CopyFromGeneralMat(supervision); + diff.AddMat(-1.0, output); + *tot_weight = diff.NumRows(); + 
*tot_objf = -0.5 * TraceMatMat(diff, diff, kTrans); + } break; } default: @@ -266,7 +329,5 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, } } - - } // namespace nnet3 } // namespace kaldi diff --git a/src/nnet3/nnet-training.h b/src/nnet3/nnet-training.h index 7ad964084a7..d6a55286a10 100644 --- a/src/nnet3/nnet-training.h +++ b/src/nnet3/nnet-training.h @@ -38,13 +38,15 @@ struct NnetTrainerOptions { BaseFloat max_param_change; NnetOptimizeOptions optimize_config; NnetComputeOptions compute_config; + bool apply_deriv_weights; NnetTrainerOptions(): zero_component_stats(true), store_component_stats(true), print_interval(100), debug_computation(false), momentum(0.0), - max_param_change(2.0) { } + max_param_change(2.0), + apply_deriv_weights(true) { } void Register(OptionsItf *opts) { opts->Register("store-component-stats", &store_component_stats, "If true, store activations and derivatives for nonlinear " @@ -64,6 +66,9 @@ struct NnetTrainerOptions { "so that the 'effective' learning rate is the same as " "before (because momentum would normally increase the " "effective learning rate by 1/(1-momentum))"); + opts->Register("apply-deriv-weights", &apply_deriv_weights, + "If true, apply the per-frame derivative weights stored with " + "the example"); // register the optimization options with the prefix "optimization". 
ParseOptions optimization_opts("optimization", opts); @@ -198,10 +203,10 @@ class NnetTrainer { void ComputeObjectiveFunction(const GeneralMatrix &supervision, ObjectiveType objective_type, const std::string &output_name, - bool supply_deriv, - NnetComputer *computer, + const CuMatrixBase &output, BaseFloat *tot_weight, - BaseFloat *tot_objf); + BaseFloat *tot_objf, + CuMatrixBase *output_deriv); diff --git a/src/nnet3/nnet-utils.cc b/src/nnet3/nnet-utils.cc index 803ca98abed..08d7a1e8be0 100644 --- a/src/nnet3/nnet-utils.cc +++ b/src/nnet3/nnet-utils.cc @@ -142,7 +142,7 @@ void ComputeSimpleNnetContext(const Nnet &nnet, // This will crash if the total context (left + right) is greater // than window_size. - int32 window_size = 100; + int32 window_size = 150; // by going "<= modulus" instead of "< modulus" we do one more computation // than we really need; it becomes a sanity check. for (int32 input_start = 0; input_start <= modulus; input_start++) @@ -421,6 +421,5 @@ std::string NnetInfo(const Nnet &nnet) { return ostr.str(); } - } // namespace nnet3 } // namespace kaldi diff --git a/src/nnet3/nnet-utils.h b/src/nnet3/nnet-utils.h index 149c0e08485..3a24441fa66 100644 --- a/src/nnet3/nnet-utils.h +++ b/src/nnet3/nnet-utils.h @@ -139,6 +139,8 @@ void UnVectorizeNnet(const VectorBase ¶ms, /// Returns the number of updatable components in the nnet. int32 NumUpdatableComponents(const Nnet &dest); +void EffectPositivity(Nnet *nnet); + /// Convert all components of type RepeatedAffineComponent or /// NaturalGradientRepeatedAffineComponent to BlockAffineComponent in nnet. 
void ConvertRepeatedToBlockAffine(Nnet *nnet); diff --git a/src/nnet3bin/Makefile b/src/nnet3bin/Makefile index 0a57c17fad0..0fba034a3c1 100644 --- a/src/nnet3bin/Makefile +++ b/src/nnet3bin/Makefile @@ -12,7 +12,8 @@ BINFILES = nnet3-init nnet3-info nnet3-get-egs nnet3-copy-egs nnet3-subset-egs \ nnet3-am-adjust-priors nnet3-am-copy nnet3-compute-prob \ nnet3-average nnet3-am-info nnet3-combine nnet3-latgen-faster \ nnet3-copy nnet3-show-progress nnet3-align-compiled \ - nnet3-get-egs-dense-targets nnet3-compute + nnet3-get-egs-dense-targets nnet3-compute nnet3-get-egs-sparse-input \ + nnet3-compute-from-sparse-input OBJFILES = diff --git a/src/nnet3bin/nnet3-acc-lda-stats.cc b/src/nnet3bin/nnet3-acc-lda-stats.cc index b59f467c7da..f5e5e41c69b 100644 --- a/src/nnet3bin/nnet3-acc-lda-stats.cc +++ b/src/nnet3bin/nnet3-acc-lda-stats.cc @@ -87,6 +87,11 @@ class NnetLdaStatsAccumulator { // "row" is actually just a redudant copy, since we're likely on CPU, // but we're about to do an outer product, so this doesn't dominate. 
Vector row(cu_row); + + BaseFloat deriv_weight = 0.0; + if (output_supervision->deriv_weights.Dim() > 0 && r < output_supervision->deriv_weights.Dim()) { + deriv_weight = output_supervision->deriv_weights(r); + } const SparseVector &post(smat.Row(r)); const std::pair *post_data = post.Data(), @@ -94,7 +99,7 @@ class NnetLdaStatsAccumulator { for (; post_data != post_end; ++post_data) { MatrixIndexT pdf = post_data->first; BaseFloat weight = post_data->second; - BaseFloat pruned_weight = RandPrune(weight, rand_prune); + BaseFloat pruned_weight = RandPrune(weight, rand_prune) * deriv_weight; if (pruned_weight != 0.0) lda_stats_.Accumulate(row, pdf, pruned_weight); } diff --git a/src/nnet3bin/nnet3-compute-from-sparse-input.cc b/src/nnet3bin/nnet3-compute-from-sparse-input.cc new file mode 100644 index 00000000000..cc5cbda1950 --- /dev/null +++ b/src/nnet3bin/nnet3-compute-from-sparse-input.cc @@ -0,0 +1,181 @@ +// nnet3bin/nnet3-compute-from-sparse-input.cc + +// Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey) +// 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "nnet3/nnet-am-decodable-simple.h" +#include "base/timer.h" +#include "nnet3/nnet-utils.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace kaldi::nnet3; + typedef kaldi::int32 int32; + + const char *usage = + "Propagate the features through raw neural network model " + "and write the output.\n" + "If --apply-exp=true, apply the Exp() function to the output " + "before writing it out.\n" + "\n" + "Usage: nnet3-compute [options] \n" + " e.g.: nnet3-compute final.raw scp:feats.scp ark:nnet_prediction.ark\n" + "See also: nnet3-compute-from-egs\n"; + + ParseOptions po(usage); + Timer timer; + + NnetSimpleComputationOptions opts; + + bool apply_exp = false; + std::string use_gpu = "yes"; + + int32 sparse_input_dim = -1; + std::string word_syms_filename; + std::string ivector_rspecifier, + online_ivector_rspecifier, + utt2spk_rspecifier; + int32 online_ivector_period = 0; + opts.Register(&po); + + po.Register("ivectors", &ivector_rspecifier, "Rspecifier for " + "iVectors as vectors (i.e. not estimated online); per utterance " + "by default, or per speaker if you provide the --utt2spk option."); + po.Register("utt2spk", &utt2spk_rspecifier, "Rspecifier for " + "utt2spk option used to get ivectors per speaker"); + po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier for " + "iVectors estimated online, as matrices. 
If you supply this," + " you must set the --online-ivector-period option."); + po.Register("online-ivector-period", &online_ivector_period, "Number of frames " + "between iVectors in matrices supplied to the --online-ivectors " + "option"); + po.Register("apply-exp", &apply_exp, "If true, apply exp function to " + "output"); + po.Register("sparse-input-dim", &sparse_input_dim, "Sparse dim"); + po.Register("use-gpu", &use_gpu, + "yes|no|optional|wait, only has effect if compiled with CUDA"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + KALDI_ASSERT(sparse_input_dim > 0); + +#if HAVE_CUDA==1 + CuDevice::Instantiate().SelectGpuId(use_gpu); +#endif + + std::string nnet_rxfilename = po.GetArg(1), + feature_rspecifier = po.GetArg(2), + matrix_wspecifier = po.GetArg(3); + + Nnet nnet; + ReadKaldiObject(nnet_rxfilename, &nnet); + + RandomAccessBaseFloatMatrixReader online_ivector_reader( + online_ivector_rspecifier); + RandomAccessBaseFloatVectorReaderMapped ivector_reader( + ivector_rspecifier, utt2spk_rspecifier); + + BaseFloatMatrixWriter matrix_writer(matrix_wspecifier); + + int32 num_success = 0, num_fail = 0; + int64 frame_count = 0; + + SequentialPosteriorReader feature_reader(feature_rspecifier); + + int32 left_context = 0, right_context = 0; + ComputeSimpleNnetContext(nnet, &left_context, &right_context); + + for (; !feature_reader.Done(); feature_reader.Next()) { + std::string utt = feature_reader.Key(); + SparseMatrix features(sparse_input_dim, feature_reader.Value()); + if (features.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_fail++; + continue; + } + const Matrix *online_ivectors = NULL; + const Vector *ivector = NULL; + if (!ivector_rspecifier.empty()) { + if (!ivector_reader.HasKey(utt)) { + KALDI_WARN << "No iVector available for utterance " << utt; + num_fail++; + continue; + } else { + ivector = &ivector_reader.Value(utt); + } + } + if (!online_ivector_rspecifier.empty()) { + if 
(!online_ivector_reader.HasKey(utt)) { + KALDI_WARN << "No online iVector available for utterance " << utt; + num_fail++; + continue; + } else { + online_ivectors = &online_ivector_reader.Value(utt); + } + } + + Matrix mat(features.NumRows(), features.NumCols()); + features.CopyToMat(&mat); + + Vector priors; + NnetDecodableBase nnet_computer( + opts, nnet, priors, mat, + ivector, online_ivectors, + online_ivector_period); + + Matrix matrix(nnet_computer.NumFrames(), + nnet_computer.OutputDim()); + for (int32 t = 0; t < nnet_computer.NumFrames(); t++) { + SubVector row(matrix, t); + nnet_computer.GetOutputForFrame(t, &row); + } + + if (apply_exp) + matrix.ApplyExp(); + + matrix_writer.Write(utt, matrix); + + frame_count += features.NumRows(); + num_success++; + } + + double elapsed = timer.Elapsed(); + KALDI_LOG << "Time taken "<< elapsed + << "s: real-time factor assuming 100 frames/sec is " + << (elapsed*100.0/frame_count); + KALDI_LOG << "Done " << num_success << " utterances, failed for " + << num_fail; + + if (num_success != 0) return 0; + else return 1; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/nnet3bin/nnet3-compute.cc b/src/nnet3bin/nnet3-compute.cc index 9db3569500e..34a3f6330d1 100644 --- a/src/nnet3bin/nnet3-compute.cc +++ b/src/nnet3bin/nnet3-compute.cc @@ -157,6 +157,9 @@ int main(int argc, char *argv[]) { num_success++; } +#if HAVE_CUDA==1 + CuDevice::Instantiate().PrintProfile(); +#endif double elapsed = timer.Elapsed(); KALDI_LOG << "Time taken "<< elapsed << "s: real-time factor assuming 100 frames/sec is " diff --git a/src/nnet3bin/nnet3-copy-egs.cc b/src/nnet3bin/nnet3-copy-egs.cc index efb51f51910..8e94bcf8119 100644 --- a/src/nnet3bin/nnet3-copy-egs.cc +++ b/src/nnet3bin/nnet3-copy-egs.cc @@ -101,6 +101,215 @@ bool ContainsSingleExample(const NnetExample &eg, return true; } +struct QuantizationOptions { + void Register(OptionsItf *opts) { + opts->Register("bin-boundaries", 
&bin_boundaries_str, "Bin boundaries"); + } + + std::string bin_boundaries_str; +}; + +class ExampleSelector { + public: + bool SelectFromExample(const NnetExample &eg, + NnetExample *eg_out) const; + + void QuantizeExample(const NnetExample &eg, + NnetExample *eg_out) const; + + ExampleSelector(std::string frame_str, + int32 left_context, int32 right_context, + int32 frame_shift, + const QuantizationOptions &quantization_opts, + bool quantize_input): frame_str_(frame_str), + left_context_(left_context), right_context_(right_context), + frame_shift_(frame_shift), quantize_input_(quantize_input), + quantization_opts_(quantization_opts) { + if (!SplitStringToFloats(quantization_opts_.bin_boundaries_str, ":", false, + &bin_boundaries_)) { + KALDI_ERR << "Bad value for --bin-boundaries option: " + << quantization_opts_.bin_boundaries_str; + } + + if (quantize_input && NumBins() <= 1) { + KALDI_ERR << "Bad value for --bin-boundaries option: " + << quantization_opts_.bin_boundaries_str; + } + } + + private: + + void FilterExample(const NnetExample &eg, + int32 min_input_t, + int32 max_input_t, + int32 min_output_t, + int32 max_output_t, + NnetExample *eg_out) const; + + void FilterAndQuantizeGeneralMatrixRows(const GeneralMatrix &in, + const std::vector &keep_rows, + GeneralMatrix *out) const; + + void QuantizeGeneralMatrix(const GeneralMatrix &in, + GeneralMatrix *out) const; + + void QuantizeFeats(const MatrixBase &in, SparseMatrix *out) const; + void QuantizeFeats(SparseMatrix *out) const; + + int32 NumBins() const { return bin_boundaries_.size() + 1; } + + std::string frame_str_; + int32 left_context_; + int32 right_context_; + int32 frame_shift_; + bool quantize_input_; + + const QuantizationOptions &quantization_opts_; + + std::vector bin_boundaries_; +}; + +void ExampleSelector::QuantizeFeats(const MatrixBase &in, + SparseMatrix *out) const { + out->Resize(in.NumRows(), in.NumCols() * NumBins()); + for (size_t t = 0; t < in.NumRows(); t++) { + std::vector > 
bins(in.NumCols()); + for (size_t j = 0; j < in.NumCols(); j++) { + auto bin = std::lower_bound(bin_boundaries_.begin(), + bin_boundaries_.end(), in(t,j)); + size_t k; + if (bin != bin_boundaries_.end()) + k = static_cast(bin - bin_boundaries_.begin()); + else { + k = static_cast(NumBins() - 1); + KALDI_ASSERT(k == bin_boundaries_.end() - bin_boundaries_.begin()); + } + + KALDI_ASSERT(k >= 0 && k < NumBins()); + KALDI_ASSERT(j + NumBins() + k < out->NumCols()); + + bins[j] = std::make_pair(j * NumBins() + k, 1.0); + } + out->SetRow(t, SparseVector(out->NumCols(), bins)); + } +} + +void ExampleSelector::QuantizeFeats (SparseMatrix *out) const { + SparseVector *row = out->Data(); + for (size_t t = 0; t < out->NumRows(); t++, ++row) { + std::pair *pairs = row->Data(); + for (size_t j = 0; j < (out->Row(t)).NumElements(); j++, ++pairs) { + auto bin = std::lower_bound(bin_boundaries_.begin(), + bin_boundaries_.end(), pairs->second); + size_t k; + if (bin != bin_boundaries_.end()) + k = static_cast(bin - bin_boundaries_.begin()); + else + k = static_cast(NumBins()); + + pairs->first = pairs->first * NumBins() + k; + pairs->second = 1.0; + } + } +} + +void ExampleSelector::QuantizeGeneralMatrix (const GeneralMatrix &in, + GeneralMatrix *out) const { + out->Clear(); + switch (in.Type()) { + case kCompressedMatrix: { + const CompressedMatrix &cmat = in.GetCompressedMatrix(); + Matrix full_mat(cmat); + SparseMatrix smat; + QuantizeFeats(full_mat, &smat); + out->SwapSparseMatrix(&smat); + return; + } + case kSparseMatrix: { + SparseMatrix smat(in.GetSparseMatrix()); + QuantizeFeats(&smat); + out->SwapSparseMatrix(&smat); + return; + } + case kFullMatrix: { + const Matrix &full_mat = in.GetFullMatrix(); + SparseMatrix smat; + QuantizeFeats(full_mat, &smat); + out->SwapSparseMatrix(&smat); + return; + } + default: + KALDI_ERR << "Invalid general-matrix type."; + } +} + +void ExampleSelector::FilterAndQuantizeGeneralMatrixRows( + const GeneralMatrix &in, + const std::vector 
&keep_rows, + GeneralMatrix *out) const { + out->Clear(); + KALDI_ASSERT(keep_rows.size() == static_cast(in.NumRows())); + int32 num_kept_rows = 0; + std::vector::const_iterator iter = keep_rows.begin(), + end = keep_rows.end(); + for (; iter != end; ++iter) + if (*iter) + num_kept_rows++; + if (num_kept_rows == 0) + KALDI_ERR << "No kept rows"; + switch (in.Type()) { + case kCompressedMatrix: { + const CompressedMatrix &cmat = in.GetCompressedMatrix(); + Matrix full_mat_out; + FilterCompressedMatrixRows(cmat, keep_rows, &full_mat_out); + SparseMatrix smat_out; + QuantizeFeats(full_mat_out, &smat_out); + out->SwapSparseMatrix(&smat_out); + return; + } + case kSparseMatrix: { + const SparseMatrix &smat = in.GetSparseMatrix(); + SparseMatrix smat_out; + FilterSparseMatrixRows(smat, keep_rows, &smat_out); + QuantizeFeats(&smat_out); + return; + } + case kFullMatrix: { + const Matrix &full_mat = in.GetFullMatrix(); + Matrix full_mat_out; + FilterMatrixRows(full_mat, keep_rows, &full_mat_out); + SparseMatrix smat_out; + QuantizeFeats(full_mat_out, &smat_out); + out->SwapSparseMatrix(&smat_out); + return; + } + default: + KALDI_ERR << "Invalid general-matrix type."; + } +} + +void ExampleSelector::QuantizeExample(const NnetExample &eg, + NnetExample *eg_out) const { + eg_out->io.clear(); + eg_out->io.resize(eg.io.size()); + for (size_t i = 0; i < eg.io.size(); i++) { + bool is_input = false; + const NnetIo &io_in = eg.io[i]; + NnetIo &io_out = eg_out->io[i]; + const std::string &name = io_in.name; + io_out.name = name; + if (name == "input") { + is_input = true; + } + io_out.indexes = io_in.indexes; + if (!is_input || !quantize_input_) // Just copy everything. 
+ io_out.features = io_in.features; + else { + QuantizeGeneralMatrix(io_in.features, &io_out.features); + } + } +} + /** This function filters the indexes (and associated feature rows) in a NnetExample, removing any index/row in an NnetIo named "input" with t < @@ -108,16 +317,16 @@ bool ContainsSingleExample(const NnetExample &eg, min_output_t or t > max_output_t. Will crash if filtering removes all Indexes of "input" or "output". */ -void FilterExample(const NnetExample &eg, - int32 min_input_t, - int32 max_input_t, - int32 min_output_t, - int32 max_output_t, - NnetExample *eg_out) { +void ExampleSelector::FilterExample(const NnetExample &eg, + int32 min_input_t, + int32 max_input_t, + int32 min_output_t, + int32 max_output_t, + NnetExample *eg_out) const { eg_out->io.clear(); eg_out->io.resize(eg.io.size()); for (size_t i = 0; i < eg.io.size(); i++) { - bool is_input_or_output; + bool is_input_or_output = false, is_input = false; int32 min_t, max_t; const NnetIo &io_in = eg.io[i]; NnetIo &io_out = eg_out->io[i]; @@ -127,6 +336,7 @@ void FilterExample(const NnetExample &eg, min_t = min_input_t; max_t = max_input_t; is_input_or_output = true; + is_input = true; } else if (name == "output") { min_t = min_output_t; max_t = max_output_t; @@ -160,8 +370,13 @@ void FilterExample(const NnetExample &eg, if (num_kept == 0) KALDI_ERR << "FilterExample removed all indexes for '" << name << "'"; - FilterGeneralMatrixRows(io_in.features, keep, - &io_out.features); + if (is_input && quantize_input_) + FilterAndQuantizeGeneralMatrixRows(io_in.features, keep, + &io_out.features); + else + FilterGeneralMatrixRows(io_in.features, keep, + &io_out.features); + KALDI_ASSERT(io_out.features.NumRows() == num_kept && indexes_out.size() == static_cast(num_kept)); } @@ -185,27 +400,23 @@ void FilterExample(const NnetExample &eg, the end of a file and has a smaller than normal number of supervised frames. 
*/ -bool SelectFromExample(const NnetExample &eg, - std::string frame_str, - int32 left_context, - int32 right_context, - int32 frame_shift, - NnetExample *eg_out) { +bool ExampleSelector::SelectFromExample(const NnetExample &eg, + NnetExample *eg_out) const { int32 min_input_t, max_input_t, min_output_t, max_output_t; if (!ContainsSingleExample(eg, &min_input_t, &max_input_t, &min_output_t, &max_output_t)) KALDI_ERR << "Too late to perform frame selection/context reduction on " << "these examples (already merged?)"; - if (frame_str != "") { + if (frame_str_ != "") { // select one frame. - if (frame_str == "random") { + if (frame_str_ == "random") { min_output_t = max_output_t = RandInt(min_output_t, max_output_t); } else { int32 frame; - if (!ConvertStringToInteger(frame_str, &frame)) - KALDI_ERR << "Invalid option --frame='" << frame_str << "'"; + if (!ConvertStringToInteger(frame_str_, &frame)) + KALDI_ERR << "Invalid option --frame='" << frame_str_ << "'"; if (frame < min_output_t || frame > max_output_t) { // Frame is out of range. Should happen only rarely. Calling code // makes sure of this. @@ -217,28 +428,29 @@ bool SelectFromExample(const NnetExample &eg, // There may come a time when we want to remove or make it possible to disable // the error messages below. The std::max and std::min expressions may seem // unnecessary but are intended to make life easier if and when we do that. 
- if (left_context != -1) { - if (min_input_t > min_output_t - left_context) - KALDI_ERR << "You requested --left-context=" << left_context + if (left_context_ != -1) { + if (min_input_t > min_output_t - left_context_) + KALDI_ERR << "You requested --left-context=" << left_context_ << ", but example only has left-context of " << (min_output_t - min_input_t); - min_input_t = std::max(min_input_t, min_output_t - left_context); + min_input_t = std::max(min_input_t, min_output_t - left_context_); } - if (right_context != -1) { - if (max_input_t < max_output_t + right_context) - KALDI_ERR << "You requested --right-context=" << right_context + if (right_context_ != -1) { + if (max_input_t < max_output_t + right_context_) + KALDI_ERR << "You requested --right-context=" << right_context_ << ", but example only has right-context of " << (max_input_t - max_output_t); - max_input_t = std::min(max_input_t, max_output_t + right_context); + max_input_t = std::min(max_input_t, max_output_t + right_context_); } + FilterExample(eg, min_input_t, max_input_t, min_output_t, max_output_t, eg_out); - if (frame_shift != 0) { + if (frame_shift_ != 0) { std::vector exclude_names; // we can later make this exclude_names.push_back(std::string("ivector")); // configurable. - ShiftExampleTimes(frame_shift, exclude_names, eg_out); + ShiftExampleTimes(frame_shift_, exclude_names, eg_out); } return true; } @@ -279,8 +491,14 @@ int main(int argc, char *argv[]) { // you can set frame to a number to select a single frame with a particular // offset, or to 'random' to select a random single frame. 
std::string frame_str; + + bool quantize_input = false; + + QuantizationOptions quantization_opts; ParseOptions po(usage); + quantization_opts.Register(&po); + po.Register("random", &random, "If true, will write frames to output " "archives randomly, not round-robin."); po.Register("frame-shift", &frame_shift, "Allows you to shift time values " @@ -301,8 +519,8 @@ int main(int argc, char *argv[]) { "feature left-context that we output."); po.Register("right-context", &right_context, "Can be used to truncate the " "feature right-context that we output."); - - + po.Register("quantize-input", &quantize_input, "If true, quantize input"); + po.Read(argc, argv); srand(srand_seed); @@ -321,6 +539,8 @@ int main(int argc, char *argv[]) { for (int32 i = 0; i < num_outputs; i++) example_writers[i] = new NnetExampleWriter(po.GetArg(i+2)); + ExampleSelector selector(frame_str, left_context, right_context, frame_shift, + quantization_opts, quantize_input); int64 num_read = 0, num_written = 0; for (; !example_reader.Done(); example_reader.Next(), num_read++) { @@ -331,15 +551,20 @@ int main(int argc, char *argv[]) { for (int32 c = 0; c < count; c++) { int32 index = (random ? Rand() : num_written) % num_outputs; if (frame_str == "" && left_context == -1 && right_context == -1 && - frame_shift == 0) { + !quantize_input && frame_shift == 0) { example_writers[index]->Write(key, eg); num_written++; } else { // the --frame option or context options were set. NnetExample eg_modified; - if (SelectFromExample(eg, frame_str, left_context, right_context, - frame_shift, &eg_modified)) { - // this branch of the if statement will almost always be taken (should only - // not be taken for shorter-than-normal egs from the end of a file. + if (! 
( frame_str.empty() && left_context == -1 && right_context == -1 ) ) { + if (selector.SelectFromExample(eg, &eg_modified)) { + // this branch of the if statement will almost always be taken (should only + // not be taken for shorter-than-normal egs from the end of a file). + example_writers[index]->Write(key, eg_modified); + num_written++; + } + } else { + selector.QuantizeExample(eg, &eg_modified); + example_writers[index]->Write(key, eg_modified); + num_written++; + } diff --git a/src/nnet3bin/nnet3-copy.cc b/src/nnet3bin/nnet3-copy.cc index 8d171cfa121..a1bb2d5153c 100644 --- a/src/nnet3bin/nnet3-copy.cc +++ b/src/nnet3bin/nnet3-copy.cc @@ -24,6 +24,34 @@ #include "hmm/transition-model.h" #include "nnet3/am-nnet-simple.h" #include "nnet3/nnet-utils.h" +#include "nnet3/nnet-simple-component.h" + +namespace kaldi { +namespace nnet3 { + +void SetSparsityConstant(BaseFloat sparsity_constant, + Nnet *nnet) { + bool success = false; + for (int32 c = 0; c < nnet->NumComponents(); c++) { + Component *comp = nnet->GetComponent(c); + if ( (comp->Properties() & kUpdatableComponent) && + (comp->Properties() & kSparsityPrior) ) { + // For now all updatable components inherit from class UpdatableComponent. + // If that changes in future, we will change this code. 
+ NaturalGradientPositiveAffineComponent *uc = + dynamic_cast(comp); + if (uc == NULL) + KALDI_ERR << "Updatable component does not inherit from class " + "UpdatableComponent; change this code."; + uc->SetSparsityConstant(sparsity_constant); + success = true; + } + } + KALDI_ASSERT(success); +} + +} +} int main(int argc, char *argv[]) { try { @@ -42,12 +70,19 @@ int main(int argc, char *argv[]) { bool binary_write = true; BaseFloat learning_rate = -1; - + BaseFloat sparsity_constant = -1; + BaseFloat scale = 1.0; + ParseOptions po(usage); po.Register("binary", &binary_write, "Write output in binary mode"); po.Register("learning-rate", &learning_rate, "If supplied, all the learning rates of updatable components" "are set to this value."); + po.Register("sparsity-constant", &sparsity_constant, + "If supplied, set the sparsity constant for regularization " + "of the components that support have a positive sparsity-constant"); + po.Register("scale", &scale, "The parameter matrices are scaled" + " by the specified value."); po.Read(argc, argv); @@ -64,6 +99,12 @@ int main(int argc, char *argv[]) { if (learning_rate >= 0) SetLearningRate(learning_rate, &nnet); + + if (scale != 1.0) + ScaleNnet(scale, &nnet); + + if (sparsity_constant >= 0.0) + SetSparsityConstant(sparsity_constant, &nnet); WriteKaldiObject(nnet, raw_nnet_wxfilename, binary_write); KALDI_LOG << "Copied raw neural net from " << raw_nnet_rxfilename diff --git a/src/nnet3bin/nnet3-get-egs-dense-targets.cc b/src/nnet3bin/nnet3-get-egs-dense-targets.cc index 23bf8922a5b..ac05be7dcc0 100644 --- a/src/nnet3bin/nnet3-get-egs-dense-targets.cc +++ b/src/nnet3bin/nnet3-get-egs-dense-targets.cc @@ -32,6 +32,7 @@ namespace nnet3 { static void ProcessFile(const MatrixBase &feats, const MatrixBase *ivector_feats, + const VectorBase *deriv_weights, const MatrixBase &targets, const std::string &utt_id, bool compress, @@ -57,7 +58,7 @@ static void ProcessFile(const MatrixBase &feats, int32 tot_frames = left_context + 
frames_per_eg + right_context; - Matrix input_frames(tot_frames, feats.NumCols()); + Matrix input_frames(tot_frames, feats.NumCols(), kUndefined); // Set up "input_frames". for (int32 j = -left_context; j < frames_per_eg + right_context; j++) { @@ -76,7 +77,7 @@ static void ProcessFile(const MatrixBase &feats, input_frames)); // if applicable, add the iVector feature. - if (ivector_feats != NULL) { + if (ivector_feats) { // try to get closest frame to middle of window to get // a representative iVector. int32 closest_frame = t + (actual_frames_per_eg / 2); @@ -108,8 +109,16 @@ static void ProcessFile(const MatrixBase &feats, this_target_dest.CopyFromVec(this_target_src); } - // push this created targets matrix into the eg - eg.io.push_back(NnetIo("output", 0, targets_dest)); + if (!deriv_weights) { + // push this created targets matrix into the eg + eg.io.push_back(NnetIo("output", 0, targets_dest)); + } else { + Vector this_deriv_weights(targets_dest.NumRows()); + int32 frames_to_copy = std::min(t + actual_frames_per_eg, deriv_weights->Dim()) - t; + this_deriv_weights.Range(0, frames_to_copy).CopyFromVec(deriv_weights->Range(t, frames_to_copy)); + if (this_deriv_weights.Sum() == 0) continue; // Ignore frames that have frame weights 0 + eg.io.push_back(NnetIo("output", this_deriv_weights, 0, targets_dest)); + } if (compress) eg.Compress(); @@ -158,7 +167,7 @@ int main(int argc, char *argv[]) { int32 num_targets = -1, left_context = 0, right_context = 0, num_frames = 1, length_tolerance = 100; - std::string ivector_rspecifier; + std::string ivector_rspecifier, deriv_weights_rspecifier; ParseOptions po(usage); po.Register("compress", &compress, "If true, write egs in " @@ -174,6 +183,11 @@ int main(int argc, char *argv[]) { "features, as matrix."); po.Register("length-tolerance", &length_tolerance, "Tolerance for " "difference in num-frames between feat and ivector matrices"); + po.Register("deriv-weights-rspecifier", &deriv_weights_rspecifier, + "Per-frame weights 
(only binary - 0 or 1) that specifies " + "whether a frame's gradient must be backpropagated or not. " + "Not specifying this is equivalent to specifying a vector of " + "all 1s."); po.Read(argc, argv); @@ -194,6 +208,7 @@ int main(int argc, char *argv[]) { RandomAccessBaseFloatMatrixReader matrix_reader(matrix_rspecifier); NnetExampleWriter example_writer(examples_wspecifier); RandomAccessBaseFloatMatrixReader ivector_reader(ivector_rspecifier); + RandomAccessBaseFloatVectorReader deriv_weights_reader(deriv_weights_rspecifier); int32 num_done = 0, num_err = 0; int64 num_frames_written = 0, num_egs_written = 0; @@ -226,7 +241,7 @@ int main(int argc, char *argv[]) { } } - if (ivector_feats != NULL && + if (ivector_feats && (abs(feats.NumRows() - ivector_feats->NumRows()) > length_tolerance || ivector_feats->NumRows() == 0)) { KALDI_WARN << "Length difference between feats " << feats.NumRows() @@ -235,8 +250,33 @@ int main(int argc, char *argv[]) { num_err++; continue; } + + const Vector *deriv_weights = NULL; + if (!deriv_weights_rspecifier.empty()) { + if (!deriv_weights_reader.HasKey(key)) { + KALDI_WARN << "No deriv weights for utterance " << key; + num_err++; + continue; + } else { + // this address will be valid until we call HasKey() or Value() + // again. 
+ deriv_weights = &(deriv_weights_reader.Value(key)); + } + } + + if (deriv_weights && + (abs(feats.NumRows() - deriv_weights->Dim()) > length_tolerance + || deriv_weights->Dim() == 0)) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and deriv weights " << deriv_weights->Dim() + << " exceeds tolerance " << length_tolerance; + num_err++; + continue; + } + - ProcessFile(feats, ivector_feats, target_matrix, key, compress, + ProcessFile(feats, ivector_feats, deriv_weights, target_matrix, + key, compress, num_targets, left_context, right_context, num_frames, &num_frames_written, &num_egs_written, &example_writer); diff --git a/src/nnet3bin/nnet3-get-egs-sparse-input.cc b/src/nnet3bin/nnet3-get-egs-sparse-input.cc new file mode 100644 index 00000000000..78e7a2cbfaa --- /dev/null +++ b/src/nnet3bin/nnet3-get-egs-sparse-input.cc @@ -0,0 +1,244 @@ +// nnet3bin/nnet3-get-egs-sparse-input.cc + +// Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey) +// 2014 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "hmm/transition-model.h" +#include "hmm/posterior.h" +#include "nnet3/nnet-example.h" + +namespace kaldi { +namespace nnet3 { + + +static void ProcessFile(const SparseMatrix &feats, + const MatrixBase *ivector_feats, + const Posterior &pdf_post, + const std::string &utt_id, + bool compress, + int32 num_pdfs, + int32 left_context, + int32 right_context, + int32 frames_per_eg, + int64 *num_frames_written, + int64 *num_egs_written, + NnetExampleWriter *example_writer) { + KALDI_ASSERT(feats.NumRows() == static_cast(pdf_post.size())); + + for (int32 t = 0; t < feats.NumRows(); t += frames_per_eg) { + + // actual_frames_per_eg is the number of frames with nonzero + // posteriors. At the end of the file we pad with zero posteriors + // so that all examples have the same structure (prevents the need + // for recompilations). + int32 actual_frames_per_eg = std::min(frames_per_eg, + feats.NumRows() - t); + + + int32 tot_frames = left_context + frames_per_eg + right_context; + + SparseMatrix input_frames(tot_frames, feats.NumCols()); + + // Set up "input_frames". + for (int32 j = -left_context; j < frames_per_eg + right_context; j++) { + int32 t2 = j + t; + if (t2 < 0) t2 = 0; + if (t2 >= feats.NumRows()) t2 = feats.NumRows() - 1; + input_frames.SetRow(j + left_context, feats.Row(t2)); + } + + NnetExample eg; + + // call the regular input "input". + eg.io.push_back(NnetIo("input", - left_context, + input_frames)); + + // if applicable, add the iVector feature. + if (ivector_feats != NULL) { + // try to get closest frame to middle of window to get + // a representative iVector. 
+ int32 closest_frame = t + (actual_frames_per_eg / 2); + KALDI_ASSERT(ivector_feats->NumRows() > 0); + if (closest_frame >= ivector_feats->NumRows()) + closest_frame = ivector_feats->NumRows() - 1; + Matrix ivector(1, ivector_feats->NumCols()); + ivector.Row(0).CopyFromVec(ivector_feats->Row(closest_frame)); + eg.io.push_back(NnetIo("ivector", 0, ivector)); + } + + // add the labels. + Posterior labels(frames_per_eg); + for (int32 i = 0; i < actual_frames_per_eg; i++) + labels[i] = pdf_post[t + i]; + // remaining posteriors for frames are empty. + eg.io.push_back(NnetIo("output", num_pdfs, 0, labels)); + + if (compress) + eg.Compress(); + + std::ostringstream os; + os << utt_id << "-" << t; + + std::string key = os.str(); // key is - + + *num_frames_written += actual_frames_per_eg; + *num_egs_written += 1; + + example_writer->Write(key, eg); + } +} + + +} // namespace nnet3 +} // namespace kaldi + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace kaldi::nnet3; + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Get frame-by-frame examples of data for nnet3 neural network training.\n" + "Essentially this is a format change from features and posteriors\n" + "into a special frame-by-frame format. This program handles the\n" + "common case where you have some input features, possibly some\n" + "iVectors, and one set of labels. 
If people in future want to\n" + "do different things they may have to extend this program or create\n" + "different versions of it for different tasks (the egs format is quite\n" + "general)\n" + "\n" + "Usage: nnet3-get-egs [options] " + " \n" + "\n" + "An example [where $feats expands to the actual features]:\n" + "nnet-get-egs --num-pdfs=2658 --left-context=12 --right-context=9 --num-frames=8 \"$feats\"\\\n" + "\"ark:gunzip -c exp/nnet/ali.1.gz | ali-to-pdf exp/nnet/1.nnet ark:- ark:- | ali-to-post ark:- ark:- |\" \\\n" + " ark:- \n"; + + + bool compress = true; + int32 num_pdfs = -1, left_context = 0, right_context = 0, + num_frames = 1, length_tolerance = 100, sparse_input_dim = -1; + + std::string ivector_rspecifier; + + ParseOptions po(usage); + po.Register("compress", &compress, "If true, write egs in " + "compressed format."); + po.Register("num-pdfs", &num_pdfs, "Number of pdfs in the acoustic " + "model"); + po.Register("sparse-input-dim", &sparse_input_dim, "Sparse input feature dimension"); + po.Register("left-context", &left_context, "Number of frames of left " + "context the neural net requires."); + po.Register("right-context", &right_context, "Number of frames of right " + "context the neural net requires."); + po.Register("num-frames", &num_frames, "Number of frames with labels " + "that each example contains."); + po.Register("ivectors", &ivector_rspecifier, "Rspecifier of ivector " + "features, as matrix."); + po.Register("length-tolerance", &length_tolerance, "Tolerance for " + "difference in num-frames between feat and ivector matrices"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + if (num_pdfs <= 0) + KALDI_ERR << "--num-pdfs option is required."; + + if (sparse_input_dim <= 0) + KALDI_ERR << "--sparse-input-dim option is required."; + + std::string feature_rspecifier = po.GetArg(1), + pdf_post_rspecifier = po.GetArg(2), + examples_wspecifier = po.GetArg(3); + + // Read in all the training 
files. + SequentialPosteriorReader feat_reader(feature_rspecifier); + RandomAccessPosteriorReader pdf_post_reader(pdf_post_rspecifier); + NnetExampleWriter example_writer(examples_wspecifier); + RandomAccessBaseFloatMatrixReader ivector_reader(ivector_rspecifier); + + int32 num_done = 0, num_err = 0; + int64 num_frames_written = 0, num_egs_written = 0; + + for (; !feat_reader.Done(); feat_reader.Next()) { + std::string key = feat_reader.Key(); + SparseMatrix feats(sparse_input_dim, feat_reader.Value()); + if (!pdf_post_reader.HasKey(key)) { + KALDI_WARN << "No pdf-level posterior for key " << key; + num_err++; + } else { + const Posterior &pdf_post = pdf_post_reader.Value(key); + if (pdf_post.size() != feats.NumRows()) { + KALDI_WARN << "Posterior has wrong size " << pdf_post.size() + << " versus " << feats.NumRows(); + num_err++; + continue; + } + const Matrix *ivector_feats = NULL; + if (!ivector_rspecifier.empty()) { + if (!ivector_reader.HasKey(key)) { + KALDI_WARN << "No iVectors for utterance " << key; + num_err++; + continue; + } else { + // this address will be valid until we call HasKey() or Value() + // again. 
+ ivector_feats = &(ivector_reader.Value(key)); + } + } + + if (ivector_feats != NULL && + (abs(feats.NumRows() - ivector_feats->NumRows()) > length_tolerance + || ivector_feats->NumRows() == 0)) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and iVectors " << ivector_feats->NumRows() + << "exceeds tolerance " << length_tolerance; + num_err++; + continue; + } + + ProcessFile(feats, ivector_feats, pdf_post, key, compress, + num_pdfs, left_context, right_context, num_frames, + &num_frames_written, &num_egs_written, + &example_writer); + num_done++; + } + } + + KALDI_LOG << "Finished generating examples, " + << "successfully processed " << num_done + << " feature files, wrote " << num_egs_written << " examples, " + << " with " << num_frames_written << " egs in total; " + << num_err << " files had errors."; + return (num_egs_written == 0 || num_err > num_done ? 1 : 0); + } catch(const std::exception &e) { + std::cerr << e.what() << '\n'; + return -1; + } +} diff --git a/src/nnet3bin/nnet3-get-egs.cc b/src/nnet3bin/nnet3-get-egs.cc index 75f264f1ceb..38592709542 100644 --- a/src/nnet3bin/nnet3-get-egs.cc +++ b/src/nnet3bin/nnet3-get-egs.cc @@ -32,6 +32,7 @@ namespace nnet3 { static void ProcessFile(const MatrixBase &feats, const MatrixBase *ivector_feats, + const VectorBase *deriv_weights, const Posterior &pdf_post, const std::string &utt_id, bool compress, @@ -75,7 +76,7 @@ static void ProcessFile(const MatrixBase &feats, input_frames)); // if applicable, add the iVector feature. - if (ivector_feats != NULL) { + if (ivector_feats) { // try to get closest frame to middle of window to get // a representative iVector. int32 closest_frame = t + (actual_frames_per_eg / 2); @@ -92,7 +93,17 @@ static void ProcessFile(const MatrixBase &feats, for (int32 i = 0; i < actual_frames_per_eg; i++) labels[i] = pdf_post[t + i]; // remaining posteriors for frames are empty. 
- eg.io.push_back(NnetIo("output", num_pdfs, 0, labels)); + + if (!deriv_weights) { + eg.io.push_back(NnetIo("output", num_pdfs, 0, labels)); + } else { + Vector this_deriv_weights(frames_per_eg); + int32 frames_to_copy = std::min(t + actual_frames_per_eg, deriv_weights->Dim()) - t; + this_deriv_weights.Range(0, frames_to_copy).CopyFromVec(deriv_weights->Range(t, frames_to_copy)); + if (this_deriv_weights.Sum() == 0) continue; // Ignore frames that have frame weights 0 + eg.io.push_back(NnetIo("output", this_deriv_weights, num_pdfs, 0, labels)); + } + if (compress) eg.Compress(); @@ -143,7 +154,7 @@ int main(int argc, char *argv[]) { int32 num_pdfs = -1, left_context = 0, right_context = 0, num_frames = 1, length_tolerance = 100; - std::string ivector_rspecifier; + std::string ivector_rspecifier, deriv_weights_rspecifier; ParseOptions po(usage); po.Register("compress", &compress, "If true, write egs in " @@ -160,6 +171,11 @@ int main(int argc, char *argv[]) { "features, as a matrix."); po.Register("length-tolerance", &length_tolerance, "Tolerance for " "difference in num-frames between feat and ivector matrices"); + po.Register("deriv-weights-rspecifier", &deriv_weights_rspecifier, + "Per-frame weights (only binary - 0 or 1) that specifies " + "whether a frame's gradient must be backpropagated or not. 
" + "Not specifying this is equivalent to specifying a vector of " + "all 1s."); po.Read(argc, argv); @@ -181,6 +197,7 @@ int main(int argc, char *argv[]) { RandomAccessPosteriorReader pdf_post_reader(pdf_post_rspecifier); NnetExampleWriter example_writer(examples_wspecifier); RandomAccessBaseFloatMatrixReader ivector_reader(ivector_rspecifier); + RandomAccessBaseFloatVectorReader deriv_weights_reader(deriv_weights_rspecifier); int32 num_done = 0, num_err = 0; int64 num_frames_written = 0, num_egs_written = 0; @@ -192,13 +209,17 @@ int main(int argc, char *argv[]) { KALDI_WARN << "No pdf-level posterior for key " << key; num_err++; } else { - const Posterior &pdf_post = pdf_post_reader.Value(key); - if (pdf_post.size() != feats.NumRows()) { + Posterior pdf_post = pdf_post_reader.Value(key); + if (abs(static_cast(pdf_post.size()) - feats.NumRows()) > length_tolerance + || pdf_post.size() < feats.NumRows()) { KALDI_WARN << "Posterior has wrong size " << pdf_post.size() << " versus " << feats.NumRows(); num_err++; continue; } + while (static_cast(pdf_post.size()) > feats.NumRows()) { + pdf_post.pop_back(); + } const Matrix *ivector_feats = NULL; if (!ivector_rspecifier.empty()) { if (!ivector_reader.HasKey(key)) { @@ -212,7 +233,7 @@ int main(int argc, char *argv[]) { } } - if (ivector_feats != NULL && + if (ivector_feats && (abs(feats.NumRows() - ivector_feats->NumRows()) > length_tolerance || ivector_feats->NumRows() == 0)) { KALDI_WARN << "Length difference between feats " << feats.NumRows() @@ -221,8 +242,33 @@ int main(int argc, char *argv[]) { num_err++; continue; } + + const Vector *deriv_weights = NULL; + if (!deriv_weights_rspecifier.empty()) { + if (!deriv_weights_reader.HasKey(key)) { + KALDI_WARN << "No deriv weights for utterance " << key; + num_err++; + continue; + } else { + // this address will be valid until we call HasKey() or Value() + // again. 
+ deriv_weights = &(deriv_weights_reader.Value(key)); + } + } + + if (deriv_weights && + (abs(feats.NumRows() - deriv_weights->Dim()) > length_tolerance + || deriv_weights->Dim() == 0)) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and deriv weights " << deriv_weights->Dim() + << " exceeds tolerance " << length_tolerance; + num_err++; + continue; + } + - ProcessFile(feats, ivector_feats, pdf_post, key, compress, + ProcessFile(feats, ivector_feats, deriv_weights, pdf_post, + key, compress, num_pdfs, left_context, right_context, num_frames, &num_frames_written, &num_egs_written, &example_writer); diff --git a/src/nnet3bin/nnet3-latgen-faster.cc b/src/nnet3bin/nnet3-latgen-faster.cc index 6728b6224fd..fa08f3ea7db 100644 --- a/src/nnet3bin/nnet3-latgen-faster.cc +++ b/src/nnet3bin/nnet3-latgen-faster.cc @@ -65,6 +65,8 @@ int main(int argc, char *argv[]) { po.Register("ivectors", &ivector_rspecifier, "Rspecifier for " "iVectors as vectors (i.e. not estimated online); per utterance " "by default, or per speaker if you provide the --utt2spk option."); + po.Register("utt2spk", &utt2spk_rspecifier, "Rspecifier for " + "utt2spk option used to get ivectors per speaker"); po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier for " "iVectors estimated online, as matrices. 
If you supply this," " you must set the --online-ivector-period option."); diff --git a/src/online2bin/ivector-extract-online2.cc b/src/online2bin/ivector-extract-online2.cc index 3251d93b5dd..6f2c73bbecf 100644 --- a/src/online2bin/ivector-extract-online2.cc +++ b/src/online2bin/ivector-extract-online2.cc @@ -55,6 +55,8 @@ int main(int argc, char *argv[]) { g_num_threads = 8; bool repeat = false; + int32 length_tolerance = 0; + std::string frame_weights_rspecifier; po.Register("num-threads", &g_num_threads, "Number of threads to use for computing derived variables " @@ -62,6 +64,12 @@ int main(int argc, char *argv[]) { po.Register("repeat", &repeat, "If true, output the same number of iVectors as input frames " "(including repeated data)."); + po.Register("frame-weights-rspecifier", &frame_weights_rspecifier, + "Archive of frame weights to scale stats"); + po.Register("length-tolerance", &length_tolerance, + "Tolerance on the difference in number of frames " + "for feats and weights"); + po.Read(argc, argv); if (po.NumArgs() != 3) { @@ -82,9 +90,9 @@ int main(int argc, char *argv[]) { SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier); + RandomAccessBaseFloatVectorReader frame_weights_reader(frame_weights_rspecifier); BaseFloatMatrixWriter ivector_writer(ivectors_wspecifier); - for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { std::string spk = spk2utt_reader.Key(); const std::vector &uttlist = spk2utt_reader.Value(); @@ -105,6 +113,31 @@ int main(int argc, char *argv[]) { &matrix_feature); ivector_feature.SetAdaptationState(adaptation_state); + + if (!frame_weights_rspecifier.empty()) { + if (!frame_weights_reader.HasKey(utt)) { + KALDI_WARN << "Did not find weights for utterance " << utt; + num_err++; + continue; + } + const Vector &weights = frame_weights_reader.Value(utt); + + if (std::abs(weights.Dim() - feats.NumRows()) > length_tolerance) { + num_err++; + continue; + } + 
+ std::vector > frame_weights; + for (int32 i = 0; i < feats.NumRows(); i++) { + if (i < weights.Dim()) + frame_weights.push_back(std::make_pair(i, weights(i))); + else + frame_weights.push_back(std::make_pair(i, 0.0)); + } + + + ivector_feature.UpdateFrameWeights(frame_weights); + } int32 T = feats.NumRows(), n = (repeat ? 1 : ivector_config.ivector_period), diff --git a/src/segmenter/Makefile b/src/segmenter/Makefile new file mode 100644 index 00000000000..d8690866c26 --- /dev/null +++ b/src/segmenter/Makefile @@ -0,0 +1,15 @@ +all: + +include ../kaldi.mk + +TESTFILES = segmentation-io-test segmentation-test + +OBJFILES = segmenter.o + +LIBNAME = kaldi-segmenter + +ADDLIBS = ../gmm/kaldi-gmm.a \ + ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a ../thread/kaldi-thread.a + +include ../makefiles/default_rules.mk + diff --git a/src/segmenter/segmentation-io-test.cc b/src/segmenter/segmentation-io-test.cc new file mode 100644 index 00000000000..79f1907d22c --- /dev/null +++ b/src/segmenter/segmentation-io-test.cc @@ -0,0 +1,58 @@ +// segmenter/segmentation-io-test.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "segmenter/segmenter.h" + +namespace kaldi { +namespace segmenter { + + +void UnitTestSegmentationIo() { + Segmentation seg; + int32 max_length = 100, num_classes = 3; + seg.GenRandomSegmentation(max_length, num_classes); + + bool binary = (rand() % 2 == 0); + std::ostringstream os; + + seg.Write(os, binary); + + Segmentation seg2; + std::istringstream is(os.str()); + seg2.Read(is, binary); + + std::ostringstream os2; + seg2.Write(os2, binary); + + KALDI_ASSERT(os2.str() == os.str()); +} + +} // namespace segmenter +} // namespace kaldi + +int main() { + using namespace kaldi; + using namespace kaldi::segmenter; + + for (int32 i = 0; i < 100; i++) + UnitTestSegmentationIo(); + return 0; +} + + diff --git a/src/segmenter/segmentation-test.cc b/src/segmenter/segmentation-test.cc new file mode 100644 index 00000000000..0de227e14c8 --- /dev/null +++ b/src/segmenter/segmentation-test.cc @@ -0,0 +1,207 @@ +// segmenter/segmentation-test.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "segmenter/segmenter.h" + +namespace kaldi { +namespace segmenter { + + +int32 GenerateRandomAlignment(int32 max_length, int32 num_classes, + std::vector *ali) { + int32 N = RandInt(1, max_length); + int32 C = RandInt(1, num_classes); + + ali->clear(); + + int32 len = 0; + while (len < N) { + int32 c = RandInt(0, C-1); + int32 n = std::min(RandInt(1, N), N - len); + ali->insert(ali->begin() + len, n, c); + len += n; + } + KALDI_ASSERT(ali->size() == N && len == N); + + int32 state = -1, num_segments = 0; + for (std::vector::const_iterator it = ali->begin(); + it != ali->end(); ++it) { + if (*it != state) num_segments++; + state = *it; + } + + return num_segments; +} + +void TestConversionToAlignment() { + std::vector ali; + int32 max_length = 1000, num_classes = 3; + int32 num_segments = GenerateRandomAlignment(max_length, num_classes, &ali); + + Segmentation seg; + KALDI_ASSERT(num_segments == seg.InsertFromAlignment(ali, 0)); + + std::vector out_ali; + { + seg.ConvertToAlignment(&out_ali); + KALDI_ASSERT(ali == out_ali); + } + + { + seg.ConvertToAlignment(&out_ali, num_classes, max_length * 2); + std::vector tmp_ali(out_ali.begin(), out_ali.begin() + ali.size()); + KALDI_ASSERT(ali == tmp_ali); + for (std::vector::const_iterator it = out_ali.begin() + ali.size(); + it != out_ali.end(); ++it) { + KALDI_ASSERT(*it == num_classes); + } + } + + seg.Clear(); + KALDI_ASSERT(num_segments == seg.InsertFromAlignment(ali, max_length)); + { + seg.ConvertToAlignment(&out_ali, num_classes, max_length * 2); + + for (std::vector::const_iterator it = out_ali.begin(); + it != out_ali.begin() + max_length; ++it) { + KALDI_ASSERT(*it == num_classes); + } + std::vector tmp_ali(out_ali.begin() + max_length, out_ali.begin() + max_length + ali.size()); + KALDI_ASSERT(tmp_ali == ali); + + for (std::vector::const_iterator it = out_ali.begin() + max_length + ali.size(); + it != out_ali.end(); ++it) { + KALDI_ASSERT(*it == num_classes); + } + } +} + +void TestRemoveSegments() { 
+ std::vector ali; + int32 max_length = 1000, num_classes = 10; + int32 num_segments = GenerateRandomAlignment(max_length, num_classes, &ali); + + Segmentation seg; + KALDI_ASSERT(num_segments == seg.InsertFromAlignment(ali, 0)); + + for (int32 i = 0; i < num_classes; i++) { + Segmentation out_seg(seg); + out_seg.RemoveSegments(i); + std::vector out_ali; + out_seg.ConvertToAlignment(&out_ali, i, ali.size()); + KALDI_ASSERT(ali == out_ali); + } + + { + std::vector classes; + for (int32 i = 0; i < 3; i++) + classes.push_back(RandInt(0, num_classes - 1)); + std::sort(classes.begin(), classes.end()); + + Segmentation out_seg1(seg); + out_seg1.RemoveSegments(classes); + + Segmentation out_seg2(seg); + for (std::vector::const_iterator it = classes.begin(); + it != classes.end(); ++it) + out_seg2.RemoveSegments(*it); + + std::vector out_ali1, out_ali2; + out_seg1.ConvertToAlignment(&out_ali1); + out_seg2.ConvertToAlignment(&out_ali2); + + KALDI_ASSERT(out_ali1 == out_ali2); + } +} + +void TestIntersectSegments() { + int32 max_length = 100, num_classes = 3; + + std::vector primary_ali; + GenerateRandomAlignment(max_length, num_classes, &primary_ali); + + std::vector secondary_ali; + GenerateRandomAlignment(max_length, num_classes, &secondary_ali); + + Segmentation primary_seg; + primary_seg.InsertFromAlignment(primary_ali); + + Segmentation secondary_seg; + secondary_seg.InsertFromAlignment(secondary_ali); + + { + Segmentation out_seg; + primary_seg.IntersectSegments(secondary_seg, &out_seg, num_classes); + + std::vector out_ali; + out_seg.ConvertToAlignment(&out_ali); + + std::vector oracle_ali(primary_ali.size()); + + for (size_t i = 0; i < oracle_ali.size(); i++) { + int32 p = (i < primary_ali.size()) ? primary_ali[i] : -1; + int32 s = (i < secondary_ali.size()) ? secondary_ali[i] : -2; + + oracle_ali[i] = (p == s) ? 
p : num_classes; + } + + KALDI_ASSERT(oracle_ali == out_ali); + } + + { + Segmentation out_seg; + primary_seg.IntersectSegments(secondary_seg, &out_seg); + + std::vector out_ali; + out_seg.ConvertToAlignment(&out_ali, num_classes); + + std::vector oracle_ali(out_ali.size()); + + for (size_t i = 0; i < oracle_ali.size(); i++) { + int32 p = (i < primary_ali.size()) ? primary_ali[i] : -1; + int32 s = (i < secondary_ali.size()) ? secondary_ali[i] : -2; + + oracle_ali[i] = (p == s) ? p : num_classes; + } + + KALDI_ASSERT(oracle_ali == out_ali); + } + +} + +void UnitTestSegmentation() { + TestConversionToAlignment(); + TestRemoveSegments(); + TestIntersectSegments(); +} + +} // namespace segmenter +} // namespace kaldi + +int main() { + using namespace kaldi; + using namespace kaldi::segmenter; + + for (int32 i = 0; i < 10; i++) + UnitTestSegmentation(); + return 0; +} + + + diff --git a/src/segmenter/segmenter.cc b/src/segmenter/segmenter.cc new file mode 100644 index 00000000000..b0d98a6bbba --- /dev/null +++ b/src/segmenter/segmenter.cc @@ -0,0 +1,1535 @@ +#include "segmenter/segmenter.h" +#include + +namespace kaldi { +namespace segmenter { + +void Segment::Write(std::ostream &os, bool binary) const { + if (binary) { + os.write(reinterpret_cast(&start_frame), sizeof(start_frame)); + os.write(reinterpret_cast(&end_frame), sizeof(start_frame)); + os.write(reinterpret_cast(&class_id), sizeof(class_id)); + } else { + WriteBasicType(os, binary, start_frame); + WriteBasicType(os, binary, end_frame); + WriteBasicType(os, binary, Label()); + } +} + +void Segment::Read(std::istream &is, bool binary) { + if (binary) { + is.read(reinterpret_cast(&start_frame), sizeof(start_frame)); + is.read(reinterpret_cast(&end_frame), sizeof(end_frame)); + is.read(reinterpret_cast(&class_id), sizeof(class_id)); + } else { + ReadBasicType(is, binary, &start_frame); + ReadBasicType(is, binary, &end_frame); + int32 label; + ReadBasicType(is, binary, &label); + SetLabel(label); + } +} + +void 
HistogramEncoder::Initialize(int32 num_bins, BaseFloat bin_w, BaseFloat min_s) { + bin_sizes.clear(); + bin_sizes.resize(num_bins, 0); + bin_width = bin_w; + min_score = min_s; +} + +void HistogramEncoder::Encode(BaseFloat x, int32 n) { + int32 i = (x - min_score ) / bin_width; + if (i < 0) i = 0; + if (i >= NumBins()) i = NumBins() - 1; + bin_sizes[i] += n; +} + +void Segmentation::GenRandomSegmentation(int32 max_length, int32 num_classes) { + Clear(); + int32 s = max_length; + int32 e = max_length; + + while (s >= 0) { + int32 chunk_size = rand() % (max_length / 10); + s = e - chunk_size + 1; + int32 k = rand() % num_classes; + + if (k != 0) { + Segment seg(s,e,k); + segments_.push_front(seg); + dim_++; + } + e = s - 1; + } + Check(); +} + +/** + * This function splits an input segmentation in_segmentation into pieces of + * approximately segment_length. Each piece is given the same class id as the + * original segment. + * The way this function is written is that it first figures out the number of + * pieces that the segment must be broken into. Then it creates that many pieces + * of equal size (actual_segment_length). +**/ +void Segmentation::SplitSegments( + const Segmentation &in_segmentation, + int32 segment_length) { + Clear(); + for (SegmentList::const_iterator it = in_segmentation.Begin(); + it != in_segmentation.End(); ++it) { + int32 length = it->Length(); + + // Adding 0.5 here makes num_chunks like the ceil of the number that it + // actually is. This results in all pieces to be smaller than + // segment_length rather than being larger. 
+ int32 num_chunks = (static_cast(length)) / segment_length + 0.5; + int32 actual_segment_length = static_cast(length) / num_chunks + + 0.5; + + int32 start_frame = it->start_frame; + for (int32 j = 0; j < num_chunks; j++) { + int32 end_frame = std::min(start_frame + actual_segment_length - 1, + it->end_frame); + Emplace(start_frame, end_frame, it->class_id); + start_frame = end_frame + 1; + } + } + Check(); +} + +/** + * This function splits the current segmentation into pieces of segment_length. + * But if the last remaining piece is smaller than min_remainder, then the last + * piece is merged to the piece before it, resulting in a piece that is of + * length < segment_length + min_remainder. + * The way this function works it is it looks at the current segment length and + * checks if it is larger than segment_length + min_remainder. If it is larger, + * then it must be split. To do this, it first modifies the start_frame of + * the current frame to start_frame + segment_length - overlap. + * It then creates a new segment of length segment_length from the original + * start_frame to start_frame + segment_length - 1 and adds it just before the + * current segment. So in the next iteration, we would actually be back to the + * same segment, but whose start_frame had just been modified. +**/ +void Segmentation::SplitSegments(int32 segment_length, + int32 min_remainder, int32 overlap, + int32 label) { + KALDI_ASSERT(overlap < segment_length); + for (SegmentList::iterator it = segments_.begin(); + it != segments_.end(); ++it) { + if (label != -1 && it->Label() != label) continue; + + int32 start_frame = it->start_frame; + int32 length = it->Length(); + + if (length > segment_length + min_remainder) { + // Split segment + // To show what this is doing, consider the following example, where it is + // currently pointing to B. + // A <--> B <--> C + + // Modify the start_frame of the current frame. 
This prepares the current + // segment to be used as the "next segment" when we move the iterator in + // the next statement. + // In the example, the start_frame for B has just been modified. + it->start_frame = start_frame + segment_length - overlap; + + // Create a new segment and add it to the where the current iterator is. + // The statement below results in this: + // A <--> B1 <--> B <--> C + // with the iterator it pointing at B1. So when the iterator is + // incremented in the for loop, it will point to B again, but whose + // start_frame had been modified. + it = segments_.emplace(it, start_frame, + start_frame + segment_length - 1, it->Label()); + + // Forward list code + // it->end_frame = start_frame + segment_length - 1; + // it = segments_.emplace(it+1, it->end_frame + 1, end_frame, it->Label()); + dim_++; + } + } + Check(); +} + +/** + * This function is very straight forward. It just merges the labels in + * merge_labels to the class_id dest_label. This means any segment that originally + * had the class_id as any of the labels in merge_labels would end up having the + * class_id dest_label. +**/ +void Segmentation::MergeLabels(const std::vector &merge_labels, + int32 dest_label) { + std::is_sorted(merge_labels.begin(), merge_labels.end()); + int32 size = 0; + for (SegmentList::iterator it = segments_.begin(); + it != segments_.end(); ++it, size++) { + if (std::binary_search(merge_labels.begin(), merge_labels.end(), it->Label())) { + it->SetLabel(dest_label); + } + } + KALDI_ASSERT(size == Dim()); + Check(); +} + +/** + * This function is used to merge segments next to each other in the SegmentList + * and within a distance of max_intersegment_length frames from each other, + * provided the segments are of the same class_id. 
+**/ +void Segmentation::MergeAdjacentSegments(int32 max_intersegment_length) { + for (SegmentList::iterator it = segments_.begin(), prev_it = segments_.begin(); + it != segments_.end();) { + + if (it != segments_.begin() && + it->Label() == prev_it->Label() && + prev_it->end_frame + max_intersegment_length >= it->start_frame) { + // merge segments + if (prev_it->end_frame < it->end_frame) { + // This is to avoid cases with overlapping segments where the previous + // segment ends after the current segment ends. In that case, the + // current segment must be simply removed. + // Otherwise the previous segment must be modified to cover the current + // segment. + prev_it->end_frame = it->end_frame; + } + it = segments_.erase(it); + dim_--; + } else { + // no merging of segments + prev_it = it; + ++it; + } + } + Check(); +} + +/** + * Create a HistogramEncoder object based on this segmentation +**/ +void Segmentation::CreateHistogram( + int32 label, const Vector &scores, + const HistogramOptions &opts, HistogramEncoder *hist_encoder) { + if (Dim() == 0) + KALDI_ERR << "Segmentation must not be empty"; + + int32 num_bins = opts.num_bins; + BaseFloat min_score = std::numeric_limits::infinity(); + BaseFloat max_score = -std::numeric_limits::infinity(); + + mean_scores_.clear(); + mean_scores_.resize(Dim(), std::numeric_limits::quiet_NaN()); + + std::vector num_frames(Dim(), 0); + + int32 i = 0; + for (SegmentList::iterator it = segments_.begin(); + it != segments_.end(); ++it, i++) { + if (it->Label() != label) continue; + SubVector this_segment_scores(scores, it->start_frame, it->end_frame - it->start_frame + 1); + BaseFloat mean_score = this_segment_scores.Sum() / this_segment_scores.Dim(); + + mean_scores_[i] = mean_score; + num_frames[i] = this_segment_scores.Dim(); + + if (mean_score > max_score) max_score = mean_score; + if (mean_score < min_score) min_score = mean_score; + } + KALDI_ASSERT(i == mean_scores_.size()); + + if (opts.select_above_mean) { + min_score = 
scores.Sum() / scores.Dim(); + } + + BaseFloat bin_width = (max_score - min_score) / num_bins; + hist_encoder->Initialize(num_bins, bin_width, min_score); + + hist_encoder->select_from_full_histogram = opts.select_from_full_histogram; + + i = 0; + for (SegmentList::const_iterator it = segments_.begin(); it != segments_.end(); ++it, i++) { + if (it->Label() != label) continue; + KALDI_ASSERT(!KALDI_ISNAN(mean_scores_[i])); + + if (opts.select_above_mean && mean_scores_[i] < min_score) continue; + KALDI_ASSERT(mean_scores_[i] >= min_score); + + hist_encoder->Encode(mean_scores_[i], num_frames[i]); + } + KALDI_ASSERT(i == mean_scores_.size()); + Check(); +} + +int32 Segmentation::SelectTopBins( + const HistogramEncoder &hist_encoder, + int32 src_label, int32 dst_label, int32 reject_label, + int32 num_frames_select, bool remove_rejected_frames) { + KALDI_ASSERT(mean_scores_.size() == Dim()); + KALDI_ASSERT(dst_label >=0 && reject_label >= 0); + KALDI_ASSERT(num_frames_select >= 0); + + BaseFloat min_score_for_selection = std::numeric_limits::infinity(); + int32 num_top_frames = 0, i = hist_encoder.NumBins() - 1; + while (i >= (hist_encoder.select_from_full_histogram ? 
0 : (hist_encoder.NumBins() / 2))) { + num_top_frames += hist_encoder.BinSize(i); + if (num_top_frames >= num_frames_select) { + num_top_frames -= hist_encoder.BinSize(i); + if (num_top_frames == 0) { + num_top_frames += hist_encoder.BinSize(i); + i--; + } + break; + } + i--; + } + min_score_for_selection = hist_encoder.min_score + (i+1) * hist_encoder.bin_width; + + i = 0; + for (SegmentList::iterator it = segments_.begin(); + it != segments_.end(); i++) { + if (it->Label() != src_label) { + ++it; + continue; + } + KALDI_ASSERT(!KALDI_ISNAN(mean_scores_[i])); + if (mean_scores_[i] >= min_score_for_selection) { + it->SetLabel(dst_label); + ++it; + } else { + if (remove_rejected_frames) { + it = segments_.erase(it); + dim_--; + } else { + it->SetLabel(reject_label); + ++it; + } + } + } + KALDI_ASSERT(i == mean_scores_.size()); + + if (remove_rejected_frames) mean_scores_.clear(); + + Check(); + return num_top_frames; +} + +int32 Segmentation::SelectBottomBins( + const HistogramEncoder &hist_encoder, + int32 src_label, int32 dst_label, int32 reject_label, + int32 num_frames_select, bool remove_rejected_frames) { + KALDI_ASSERT(mean_scores_.size() == Dim()); + KALDI_ASSERT(dst_label >=0 && reject_label >= 0); + KALDI_ASSERT(num_frames_select >= 0); + + BaseFloat max_score_for_selection = -std::numeric_limits::infinity(); + int32 num_bottom_frames = 0, i = 0; + while (i < (hist_encoder.select_from_full_histogram ? 
hist_encoder.NumBins() : (hist_encoder.NumBins() / 2))) { + num_bottom_frames += hist_encoder.BinSize(i); + if (num_bottom_frames >= num_frames_select) { + num_bottom_frames -= hist_encoder.BinSize(i); + if (num_bottom_frames == 0) { + num_bottom_frames += hist_encoder.BinSize(i); + i++; + } + break; + } + i++; + } + max_score_for_selection = hist_encoder.min_score + i * hist_encoder.bin_width; + + i = 0; + for (SegmentList::iterator it = segments_.begin(); + it != segments_.end(); i++) { + if (it->Label() != src_label) { + ++it; + continue; + } + KALDI_ASSERT(!KALDI_ISNAN(mean_scores_[i])); + if (mean_scores_[i] < max_score_for_selection) { + it->SetLabel(dst_label); + ++it; + } else { + if (remove_rejected_frames) { + it = segments_.erase(it); + dim_--; + } else { + it->SetLabel(reject_label); + ++it; + } + } + } + KALDI_ASSERT(i == mean_scores_.size()); + + if (remove_rejected_frames) mean_scores_.clear(); + + Check(); + return num_bottom_frames; +} + +std::pair Segmentation::SelectTopAndBottomBins( + const HistogramEncoder &hist_encoder, + int32 src_label, int32 top_label, int32 num_frames_top, + int32 bottom_label, int32 num_frames_bottom, + int32 reject_label, bool remove_rejected_frames) { + KALDI_ASSERT(mean_scores_.size() == Dim()); + KALDI_ASSERT(top_label >= 0 && bottom_label >= 0 && reject_label >= 0); + KALDI_ASSERT(num_frames_top >= 0 && num_frames_bottom >= 0); + + BaseFloat min_score_for_selection = std::numeric_limits::infinity(); + int32 num_selected_top = 0, i = hist_encoder.NumBins() - 1; + while (i >= hist_encoder.NumBins() / 2) { + int32 this_selected = hist_encoder.BinSize(i); + num_selected_top += this_selected; + if (num_selected_top >= num_frames_top) { + num_selected_top -= this_selected; + if (num_selected_top == 0) { + num_selected_top += this_selected; + i--; + } + break; + } + i--; + } + min_score_for_selection = hist_encoder.min_score + (i+1) * hist_encoder.bin_width; + + BaseFloat max_score_for_selection = 
-std::numeric_limits::infinity(); + int32 num_selected_bottom= 0; + i = 0; + while (i < hist_encoder.NumBins() / 2) { + int32 this_selected = hist_encoder.BinSize(i); + num_selected_bottom += this_selected; + if (num_selected_bottom >= num_frames_bottom) { + num_selected_bottom -= this_selected; + if (num_selected_bottom == 0) { + num_selected_bottom += this_selected; + i++; + } + break; + } + i++; + } + max_score_for_selection = hist_encoder.min_score + i * hist_encoder.bin_width; + + i = 0; + for (SegmentList::iterator it = segments_.begin(); + it != segments_.end(); i++) { + if (it->Label() != src_label) { + ++it; + continue; + } + KALDI_ASSERT(!KALDI_ISNAN(mean_scores_[i])); + if (mean_scores_[i] >= min_score_for_selection) { + it->SetLabel(top_label); + ++it; + } else if (mean_scores_[i] < max_score_for_selection) { + it->SetLabel(bottom_label); + ++it; + } else { + if (remove_rejected_frames) { + it = segments_.erase(it); + dim_--; + } else { + it->SetLabel(reject_label); + ++it; + } + } + } + KALDI_ASSERT(i == mean_scores_.size()); + + if (remove_rejected_frames) mean_scores_.clear(); + + Check(); + return std::make_pair(num_selected_top, num_selected_bottom); +} + +/** + * This function intersects the segmentation with the filter segmentation + * and includes only sub-segments where the filter segmentation has the label + * filter_label. + * If filter_label is -1, the filter_label would dynamically change to be the + * label of the primary segmentation. + * For e.g. if the segmentation is + * start_frame end_frame label + * 5 10 1 + * 8 12 2 + * and filter_segmentation is + * 0 7 1 + * 7 10 2 + * 10 13 1. + * And filter_label is 1. 
Then after intersection, this + * object would hold + * 5 7 1 + * 8 10 2 + * 10 12 2 +**/ +void Segmentation::IntersectSegments( + const Segmentation &secondary_segmentation, + Segmentation *out_seg, int32 mismatch_label) const { + KALDI_ASSERT(secondary_segmentation.Dim() > 0); + KALDI_ASSERT(out_seg); + out_seg->Clear(); + SegmentList::const_iterator s_it = secondary_segmentation.Begin(); + for (SegmentList::const_iterator p_it = Begin(); p_it != End(); ++p_it) { + if (s_it == secondary_segmentation.End()) + --s_it; + // This statement was necessary so that it would not crash at the next + // statement. + KALDI_ASSERT(s_it != secondary_segmentation.End()); + + // If the secondary segment start is beyond the start of the current + // segment, then move the secondary segment iterator back. + while (s_it != secondary_segmentation.Begin() && + s_it->start_frame > p_it->start_frame) + --s_it; + KALDI_ASSERT(s_it != secondary_segmentation.End()); + + // Now, we can move the secondary segment iterator until the end of the + // secondary segment is just at the current segment. + // There are two possibilities here: + // (i) The secondary segment ends are on either side of the primary's + // start_frame. + // i.e. s_it->start_frame < p_it->start_frame <= s_it->end_frame + // (ii) The secondary segment ends are after the primary's start_frame. + // i.e. p_it->start_frame <= s_it->start_frame <= s_it->end_frame + while (s_it != secondary_segmentation.End() && + s_it->end_frame < p_it->start_frame) ++s_it; + + // Actual intersection is done here. 
+ int32 start_frame = p_it->start_frame; + for (; s_it != secondary_segmentation.End() && + s_it->start_frame <= p_it->end_frame; ++s_it) { + int32 new_label = p_it->Label(); + if (s_it->Label() != p_it->Label()) { + new_label = mismatch_label; + } + + if (start_frame < s_it->start_frame) { + // This is the first part of handling case (ii) + if (mismatch_label != -1) + out_seg->Emplace(start_frame, s_it->start_frame - 1, + mismatch_label); + start_frame = s_it->start_frame; + } + + KALDI_ASSERT(start_frame == std::max(p_it->start_frame, s_it->start_frame)); + // Once this is done, it reduces to case (i) + + if (s_it->end_frame < p_it->end_frame) { + // This is the case in which there is a part of primary segment that is + // not intersected by the secondary segment. We split the primary + // segment into two parts, the first of which is created using the + // emplace statement below and the second is created by modifying the + // current segment to the contain only the part after secondary + // segment's end_frame. 
+ if (start_frame <= s_it->end_frame) { + if (new_label != -1) + out_seg->Emplace(start_frame, s_it->end_frame, new_label); + start_frame = s_it->end_frame + 1; + } + KALDI_ASSERT(start_frame <= p_it->end_frame); + } else { // if (s_it->end_frame > p_it->end_frame) + if (new_label != -1) + out_seg->Emplace(start_frame, p_it->end_frame, new_label); + start_frame = p_it->end_frame + 1; + } + } + + if (s_it == secondary_segmentation.End()) { + --s_it; + if (start_frame < p_it->end_frame) { + if (mismatch_label != -1) + out_seg->Emplace(start_frame, p_it->end_frame, mismatch_label); + start_frame = p_it->end_frame + 1; + } + ++s_it; + } + } +} + +/* +void Segmentation::IntersectSegments( + const Segmentation &filter_segmentation, + int32 filter_label) { + SegmentList::iterator it = segments_.begin(); + SegmentList::const_iterator filter_it = filter_segmentation.Begin(); + + int32 orig_filter_label = filter_label; + + while (it != segments_.end()) { + if (orig_filter_label == -1) filter_label = it->Label(); + + // Move the filter iterator up to the first segment where the end point of + // the filter segment is just after the start of the current segment. + while (filter_it != filter_segmentation.End() && + (filter_it->end_frame < it->start_frame || + filter_it->Label() != filter_label)) { + ++filter_it; + } + + // If the filter has reached the end, then we are done + if (filter_it == filter_segmentation.End()) { + while (it != segments_.end()) { + // There is no segment in the filter_segmentation beyond this. So the + // intersection is empty. Hence erase the remaining segments. + it = segments_.erase(it); + dim_--; + } + break; + } + + // If the segment in the filter is beyond the end of the current segment, + // then there is no intersection between this segment and the + // filter_segmentation. Hence remove this segment. 
+ if (filter_it->start_frame > it->end_frame) { + it = segments_.erase(it); + dim_--; + continue; + } + + // Filter start_frame is after the start_frame of this segment. + // So throw away the initial part of this segment as it is not in the + // filter. i.e. Set the start of this segment to be the start of the filter + // segment. + if (filter_it->start_frame > it->start_frame) + it->start_frame = filter_it->start_frame; + + if (filter_it->end_frame < it->end_frame) { + // filter segment ends before the end of the current segment. Then end + // the current segment right at the end of the filter and leave the + // iterator at the remaining part. + int32 start_frame = it->start_frame; + it->start_frame = filter_it->end_frame + 1; + segments_.emplace(it, start_frame, filter_it->end_frame, it->Label()); + dim_++; + } else { + // filter segment ends after the end of this current segment. So + // we don't need to create any new segment. Just advance the iterator + // to the next segment. + ++it; + } + } + Check(); +} +*/ + +/** + * A very straight forward function to extend this segmentation by adding + * segments from another segmentation seg. If sort is called, the segments would + * be sorted after extension. This can be skipped if its known that the segments + * would be sorted. +**/ +void Segmentation::Extend(const Segmentation &seg, bool sort) { + for (SegmentList::const_iterator it = seg.Begin(); it != seg.End(); ++it) { + segments_.push_back(*it); + dim_++; + } + if (sort) Sort(); +} + +/** + * This function is a little complicated in what it does. But this is required + * for one of the applications. This function creates a new segmentation by + * sub-segmenting an overlapping primary segmentation and assign new class_id to + * the regions where the primary segmentation intersects the non-overlapping + * secondary segmentation segments with class_id secondary_label. 
+ * This is similar to the function "IntersectSegments", but instead of keeping + * only the filtered subsegments, all the subsegments are kept, while only + * changing the class_id of the filtered sub-segments. + * For the sub-segments, where the secondary segment class_id is + * secondary_label, the created sub-segment is labeled "subsegment_label", + * provided it is non-negative. Otherwise, the created sub-segment is labeled + * the class_id of the secondary segment. + * For the other sub-segments, where the secondary segment + * class_id is not secondary_label, the created sub-segment retains the class_id + * of the parent segment. + * Additionally this program adds the secondary segmentation's vector_value + * along with this segmentation's string_value if they exist. +**/ + +void Segmentation::SubSegmentUsingNonOverlappingSegments( + const Segmentation &secondary_segmentation, int32 secondary_label, + int32 subsegment_label, Segmentation *out_seg) const { + KALDI_ASSERT(secondary_segmentation.Dim() > 0); + KALDI_ASSERT(secondary_segmentation.IsNonOverlapping()); + SegmentList::const_iterator s_it = secondary_segmentation.Begin(); + for (SegmentList::const_iterator p_it = Begin(); p_it != End(); ++p_it) { + if (s_it == secondary_segmentation.End()) + --s_it; + // This statement was necessary so that it would not crash at the next + // statement. + KALDI_ASSERT(s_it != secondary_segmentation.End()); + + // If the secondary segment start is beyond the start of the current + // segment, then move the secondary segment iterator back. + while (s_it != secondary_segmentation.Begin() && + s_it->start_frame > p_it->start_frame) + --s_it; + KALDI_ASSERT(s_it != secondary_segmentation.End()); + + // Now, we can move the secondary segment iterator until the end of the + // secondary segment is just at the current segment. + // There are two possibilities here: + // (i) The secondary segment ends are on either side of the primary's + // start_frame. + // i.e. 
s_it->start_frame < p_it->start_frame <= s_it->end_frame + // (ii) The secondary segment ends are after the primary's start_frame. + // i.e. p_it->start_frame <= s_it->start_frame <= s_it->end_frame + while (s_it != secondary_segmentation.End() && + s_it->end_frame < p_it->start_frame) ++s_it; + + // Actual intersection is done here. + int32 start_frame = p_it->start_frame; + for (; s_it != secondary_segmentation.End() && + s_it->start_frame <= p_it->end_frame; ++s_it) { + int32 new_label = p_it->Label(); + + if (s_it->Label() == secondary_label) { + new_label = (subsegment_label >= 0 ? subsegment_label : s_it->Label()); + } + + if (start_frame < s_it->start_frame) { + // This is the first part of handling case (ii) + out_seg->Emplace(start_frame, s_it->start_frame - 1, + p_it->Label()); + start_frame = s_it->start_frame; + } // Once this is done, it reduces to case (i) + KALDI_ASSERT(start_frame == std::max(p_it->start_frame, s_it->start_frame)); + + if (s_it->end_frame < p_it->end_frame) { + // This is the case in which there is a part of primary segment that is + // not intersected by the secondary segment. We split the primary + // segment into two parts, the first of which is created using the + // emplace statement below and the second is created by modifying the + // current segment to the contain only the part after secondary + // segment's end_frame. + out_seg->Emplace(start_frame, s_it->end_frame, new_label); + start_frame = s_it->end_frame + 1; + KALDI_ASSERT(start_frame <= p_it->end_frame); + } else { // if (s_it->end_frame > p_it->end_frame) + out_seg->Emplace(start_frame, p_it->end_frame, new_label); + } + } + } +} + +void Segmentation::SubSegmentUsingSmallOverlapSegments( + const Segmentation &secondary_segmentation, int32 secondary_label, + int32 subsegment_label, Segmentation *out_seg) const { + // TODO: When the secondary segmentation has overlap, it just considers the + // label of the latest segment. 
+ KALDI_ASSERT(secondary_segmentation.Dim() > 0); + KALDI_ASSERT(secondary_segmentation.HasSmallOverlap()); + SegmentList::const_iterator s_it = secondary_segmentation.Begin(); + for (SegmentList::const_iterator p_it = Begin(); p_it != End(); ++p_it) { + if (s_it == secondary_segmentation.End()) + --s_it; + // This statement was necessary so that it would not crash at the next + // statement. + KALDI_ASSERT(s_it != secondary_segmentation.End()); + + // If the secondary segment start is beyond the start of the current + // segment, then move the secondary segment iterator back. + while (s_it != secondary_segmentation.Begin() && + s_it->start_frame > p_it->start_frame) + --s_it; + KALDI_ASSERT(s_it != secondary_segmentation.End()); + + // Now, we can move the secondary segment iterator until the end of the + // secondary segment is just at the current segment. + // There are two possibilities here: + // (i) The secondary segment ends are on either side of the primary's + // start_frame. + // i.e. s_it->start_frame < p_it->start_frame <= s_it->end_frame + // (ii) The secondary segment ends are after the primary's start_frame. + // i.e. p_it->start_frame <= s_it->start_frame <= s_it->end_frame + while (s_it != secondary_segmentation.End() && + s_it->end_frame < p_it->start_frame) ++s_it; + + // Actual intersection is done here. + int32 start_frame = p_it->start_frame; + for (; s_it != secondary_segmentation.End() && + s_it->start_frame <= p_it->end_frame; ++s_it) { + SegmentList::const_iterator s_it_next(s_it); + ++s_it_next; + + int32 end_frame = s_it->end_frame; + if (s_it_next != secondary_segmentation.End() && + s_it_next->start_frame <= s_it->end_frame) { + end_frame = s_it_next->start_frame - 1; + } + + int32 new_label = p_it->Label(); + + if (s_it->Label() == secondary_label) { + new_label = (subsegment_label >= 0 ? 
subsegment_label : s_it->Label()); + } + + if (start_frame < s_it->start_frame) { + // This is the first part of handling case (ii) + out_seg->Emplace(start_frame, s_it->start_frame - 1, + p_it->Label()); + start_frame = s_it->start_frame; + } // Once this is done, it reduces to case (i) + KALDI_ASSERT(start_frame == std::max(p_it->start_frame, s_it->start_frame)); + + if (end_frame < p_it->end_frame) { + // This is the case in which there is a part of primary segment that is + // not intersected by the secondary segment. We split the primary + // segment into two parts, the first of which is created using the + // emplace statement below and the second is created by modifying the + // current segment to the contain only the part after secondary + // segment's end_frame. + out_seg->Emplace(start_frame, end_frame, new_label); + start_frame = end_frame + 1; + KALDI_ASSERT(start_frame <= p_it->end_frame); + } else { // if (end_frame > p_it->end_frame) + out_seg->Emplace(start_frame, p_it->end_frame, new_label); + } + } + } +} + + +// const Segmentation &nonoverlapping_segmentation, +// int32 secondary_label, int32 subsegment_label, +// Segmentation *out_segmentation) const { +// out_segmentation->Clear(); +// SegmentList::const_iterator s_it = nonoverlapping_segmentation.Begin(); +// for (SegmentList::const_iterator p_it = Begin(); p_it != End(); ++p_it) { +// if (s_it == nonoverlapping_segmentation.End()) --s_it; +// // This statement was necessary so that it would not crash at the next +// // statement. +// +// +// // The following two statements may be a little inefficient and there might +// // be better way to do this. This is a TODO. +// +// // If the secondary segment start is beyond the start of the current +// // segment, then move the secondary segment iterator back. +// while (s_it->start_frame > p_it->start_frame) --s_it; +// // Now, we can move the secondary segment iterator until the end of the +// // secondary segment is just before the current segment. 
+// while (s_it != nonoverlapping_segmentation.End() && +// s_it->end_frame < p_it->start_frame) ++s_it; +// // This is so that state is equalized and we can be sure that always, the +// // secondary segment is just one segment before the current segment. +// +// // Actual intersection is done here. +// for (; s_it->start_frame <= p_it->end_frame && +// s_it != nonoverlapping_segmentation.End(); ++s_it) { +// int32 new_label = p_it->Label(); +// +// if (s_it->Label() == secondary_label) { +// new_label = (subsegment_label >= 0 ? +// subsegment_label : s_it->Label()); +// } +// +// out_segmentation->Emplace( +// std::max(s_it->start_frame, p_it->start_frame), +// std::min(s_it->end_frame, p_it->end_frame), new_label, +// s_it->VectorValue(), p_it->StringValue()); +// } +// } +//} + +/** +void Segmentation::CreateSubSegmentsOld( + const Segmentation &filter_segmentation, + int32 filter_label, + int32 subsegment_label) { + SegmentList::iterator it = segments_.begin(); + SegmentList::const_iterator filter_it = filter_segmentation.Begin(); + + while (it != segments_.end()) { + + // If the start of the segment in the filter is before the current + // segment then move the filter iterator up to the first segment where the + // end point of the filter segment is just after the start of the current + // segment + while (filter_it != filter_segmentation.End() && + (filter_it->end_frame < it->start_frame || + filter_it->Label() != filter_label)) { + ++filter_it; + } + + // If the filter has reached the end, then we are done + if (filter_it == filter_segmentation.End()) { + break; + } + + // If the segment in the filter is beyond the end of the current segment, + // then increment the iterator until the current segment end + // point is just after the start of the filter segment + if (filter_it->start_frame > it->end_frame) { + ++it; + continue; + } + + // filter start_frame is after the start_frame of this segment. + // So split the segment into two parts at filter_start. 
+ // Create a new segment for the + // first part which retains the same label as before. + // For now, retain the same label for the second part. The + // label would change while processing the end of the subsegment. + if (filter_it->start_frame > it->start_frame) { + segments_.emplace(it, it->start_frame, filter_it->start_frame - 1, it->Label()); + dim_++; + it->start_frame = filter_it->start_frame; + } + + if (filter_it->end_frame < it->end_frame) { + // filter segment ends before the end of the current segment. Then end + // the current segment right at the end of the filter and leave the + // remaining part for the next segment + int32 start_frame = it->start_frame; + it->start_frame = filter_it->end_frame + 1; + segments_.emplace(it, start_frame, filter_it->end_frame, subsegment_label); + dim_++; + } else { + // filter segment ends after the end of this current segment. + // So change the label of this frame to + // subsegment_label + it->SetLabel(subsegment_label); + ++it; + } + } + Check(); +} +**/ + +/** + * This function is used to widen segments of class_id "label" by "length" frames + * on either side of the segment. This is useful to widen segments of speech. + * While widening, it also reduces the length of the segment adjacent to it. + * This may not be required in some applications, but it is ok for speech / + * silence. We are calling frames within a "length" number of frames near the + * speech segment as speech and hence we reduce the width of the silence segment + * before it. 
+**/ +void Segmentation::WidenSegments(int32 label, int32 length) { + for (SegmentList::iterator it = segments_.begin(); + it != segments_.end(); ++it) { + if (it->Label() == label) { + if (it != segments_.begin()) { + // it is not the beginning of the segmentation, so we can widen it on + // the start_frame side + SegmentList::iterator prev_it = it; + --prev_it; + it->start_frame -= length; + if (prev_it->Label() == label && it->start_frame < prev_it->end_frame) { + // After widening this segment, it overlaps the previous segment that + // also has the same class_id. Then turn this segment into a composite + // one + it->start_frame = prev_it->start_frame; + // and remove the previous segment from the list. + Erase(prev_it); + } else if (prev_it->Label() != label && + it->start_frame < prev_it->end_frame) { + // Previous segment is not the same class_id, so we cannot turn this into + // a composite segment. + if (it->start_frame <= prev_it->start_frame) { + // The extended segment absorbs the previous segment into it + // So remove the previous segment + Erase(prev_it); + } else { + // The extended segment reduces the length of the previous + // segment. But does not completely overlap it. + prev_it->end_frame -= length; + if (prev_it->end_frame < prev_it->start_frame) Erase(prev_it); + } + } + if (it->start_frame < 0) it->start_frame = 0; + } else { + it->start_frame -= length; + if (it->start_frame < 0) it->start_frame = 0; + } + + SegmentList::iterator next_it = it; + ++next_it; + + if (next_it != segments_.end()) + // We do not know the length of the file. + //So we don't want to extend the last one. + it->end_frame += length; // Line (1) + } else { // if (it->Label() != label) + if (it != segments_.begin()) { + SegmentList::iterator prev_it = it; + --prev_it; + if (prev_it->end_frame >= it->end_frame) { + // The extended previous segment in Line (1) completely + // overlaps the current segment. So remove the current segment. 
+ it = Erase(it); + --it; // So that we can increment in the for loop + } else if (prev_it->end_frame >= it->start_frame) { + // The extended previous segment in Line (1) reduces the length of + // this segment. + it->start_frame = prev_it->end_frame + 1; + } + } + } + } +} + +void Segmentation::ShrinkSegments(int32 label, int32 length) { + for (SegmentList::iterator it = segments_.begin(); + it != segments_.end();) { + if (it->Label() == label) { + if (it->Length() <= 2 * length) { + it = segments_.erase(it); + dim_--; + } else { + it->start_frame += length; + it->end_frame -= length; + ++it; + } + } else + ++it; + } +} + +/** + * This function relabels segments of class_id "label" that are shorter than + * max_length frames, provided the segments before and after it are of the same + * class_id "other_label". Now all three segments have the same class_id + * "other_label" and hence can be merged into a composite segment. + * An example where this is useful is when there is a short segment of silence + * with speech segments on either sides. Then the short segment of silence is + * removed and called speech instead. The three continguous segments of speech + * are merged into a single composite segment. +**/ +void Segmentation::RelabelShortSegments(int32 label, int32 max_length) { + for (SegmentList::iterator it = segments_.begin(); + it != segments_.end();) { + if (it == segments_.begin()) { + ++it; + continue; + } + + SegmentList::iterator next_it = it; + ++next_it; + if (next_it == segments_.end()) break; + + SegmentList::iterator prev_it = it; + --prev_it; + + if (next_it->Label() == prev_it->Label() && it->Label() == label + && it->Length() < max_length) { + prev_it->end_frame = next_it->end_frame; + segments_.erase(it); + it = segments_.erase(next_it); + dim_ -= 2; + } else + ++it; + } +} + +/** + * This is very straight forward. 
It removes all segments of class_id "label" +**/ +void Segmentation::RemoveSegments(int32 label) { + for (SegmentList::iterator it = segments_.begin(); + it != segments_.end();) { + if (it->Label() == label) { + it = segments_.erase(it); + dim_--; + } else { + ++it; + } + } + Check(); +} + +/** + * This is very straight forward. It removes any segment whose class_id is + * contained in the vector "labels" +**/ +void Segmentation::RemoveSegments(const std::vector &labels) { + KALDI_ASSERT(std::is_sorted(labels.begin(), labels.end())); + for (SegmentList::iterator it = segments_.begin(); + it != segments_.end();) { + if (std::binary_search(labels.begin(), labels.end(), it->Label())) { + it = segments_.erase(it); + dim_--; + } else { + ++it; + } + } + Check(); +} + +void Segmentation::Clear() { + segments_.clear(); + dim_ = 0; + mean_scores_.clear(); +} + +void Segmentation::Read(std::istream &is, bool binary) { + Clear(); + + if (binary) { + int32 sz = is.peek(); + if (sz == Segment::SizeOf()) { + is.get(); + } else { + KALDI_ERR << "Segmentation::Read: expected to see Segment of size " + << Segment::SizeOf() << ", saw instead " << sz + << ", at file position " << is.tellg(); + } + + int32 segmentssz; + is.read(reinterpret_cast(&segmentssz), sizeof(segmentssz)); + if (is.fail() || segmentssz < 0) + KALDI_ERR << "Segmentation::Read: read failure at file position " + << is.tellg(); + + for (int32 i = 0; i < segmentssz; i++) { + Segment seg; + seg.Read(is, binary); + segments_.push_back(seg); + } + dim_ = segmentssz; + } else { + if (int c = is.peek() != static_cast('[')) { + KALDI_ERR << "Segmentation::Read: expected to see [, saw " + << static_cast(c) << ", at file position " << is.tellg(); + } + is.get(); // consume the '[' + while (is.peek() != static_cast(']')) { + KALDI_ASSERT(!is.eof()); + Segment seg; + seg.Read(is, binary); + segments_.push_back(seg); + dim_++; + is >> std::ws; + } + is.get(); + KALDI_ASSERT(!is.eof()); + } + Check(); +} + +void 
 Segmentation::Write(std::ostream &os, bool binary) const { + SegmentList::const_iterator it = segments_.begin(); + if (binary) { + char sz = Segment::SizeOf(); + os.write(&sz, 1); + + int32 segmentssz = static_cast(Dim()); + KALDI_ASSERT((size_t)segmentssz == Dim()); + + os.write(reinterpret_cast(&segmentssz), sizeof(segmentssz)); + + for (; it != segments_.end(); ++it) { + it->Write(os, binary); + } + } else { + os << "[ "; + for (; it != segments_.end(); ++it) { + it->Write(os, binary); + os << std::endl; + } + os << "]" << std::endl; + } +} + +/** + * This function is used to write the segmentation in RTTM format. Each class is + * treated as a "SPEAKER". If map_to_speech_and_sil is true, then the class_id 0 + * is treated as SILENCE and every other class_id as SPEECH. The argument + * start_time is used to set what the time corresponding to the 0 frame in the + * segment. Each segment is converted into the following line, + * SPEAKER 1 + * ,where + * is the file_id supplied as an argument + * is the start time of the segment in seconds + * is the length of the segment in seconds + * is the class_id stored in the segment. If map_to_speech_and_sil is + * set true then is either SPEECH or SILENCE. + * The function returns the largest class_id that it encounters. 
+**/ +int32 Segmentation::WriteRttm(std::ostream &os, const std::string &file_id, const std::string &channel, + BaseFloat frame_shift, BaseFloat start_time, + bool map_to_speech_and_sil) const { + SegmentList::const_iterator it = segments_.begin(); + int32 largest_class = 0; + for (; it != segments_.end(); ++it) { + os << "SPEAKER " << file_id << " " << channel << " " + << it->start_frame * frame_shift + start_time << " " + << (it->Length()) * frame_shift << " "; + if (map_to_speech_and_sil) { + switch (it->Label()) { + case 1: + os << "SPEECH "; + break; + default: + os << "SILENCE "; + break; + } + largest_class = 1; + } else { + if (it->Label() >= 0) { + os << it->Label() << " "; + if (it->Label() > largest_class) + largest_class = it->Label(); + } + } + os << "" << std::endl; + } + return largest_class; +} + +/** + * This function is used to convert the segmentation into frame-level alignment + * with the label for each frame being the class_id of segment the frame belongs + * to. + * The arguments are used to provide extended functionality that is required + * for most cases. + * default_label : the label that is used as filler in regions where the frame + * is not in any of the segments. In most applications, certain + * segments are removed, such as the ones that are silence. Then + * the segments would not span the entire duration of the file. + * e.g. + * 10 35 1 + * 41 190 2 + * ... + * Here there is no segment from 36-40. These frames are + * filled with default_label. + * length : the number of frames required in the alignment. In most + * applications, the length of the alignment required is known. + * Usually it must match the length of the features (obtained + * using feat-to-len). Then the alignment is resized to this + * length and filled with default_label. The segments are then + * read and the frames corresponding to the segments are + * relabeled with the class_id of the respective segments. 
+ * tolerance : the tolerance in number of frames that we allow for the + * frame index corresponding to the end_frame of the last + * segment. Since, we use 25 ms widows with 10 ms frame shift, + * it is possible that the features length is 2 frames less than + * the end of the last segment. So the user can set the + * tolerance to 2 in order to avoid returning with error in this + * function. +**/ +bool Segmentation::ConvertToAlignment(std::vector *alignment, + int32 default_label, int32 length, + int32 tolerance) const { + KALDI_ASSERT(alignment != NULL); + alignment->clear(); + + if (length != -1) { + KALDI_ASSERT(length >= 0); + alignment->resize(length, default_label); + } + + SegmentList::const_iterator it = segments_.begin(); + for (; it != segments_.end(); ++it) { + if (length != -1 && it->end_frame >= length + tolerance) { + KALDI_WARN << "End frame (" << it->end_frame << ") " + << ">= length + tolerance (" << length + tolerance << ")." + << "Conversion failed."; + return false; + } + + int32 end_frame = it->end_frame; + if (length == -1) { + alignment->resize(it->end_frame + 1, default_label); + } else { + if (it->end_frame >= length) + end_frame = length - 1; + } + + KALDI_ASSERT(end_frame < alignment->size()); + for (size_t i = it->start_frame; i <= end_frame; i++) { + (*alignment)[i] = it->Label(); + } + } + return true; +} + +int32 Segmentation::InsertFromAlignment( + const std::vector &alignment, + int32 start_time_offset, + std::vector *frame_counts_per_class) { + if (alignment.size() == 0) return 0; + + int32 num_segments = 0; + int32 state = -1, start_frame = -1; + for (int32 i = 0; i < alignment.size(); i++) { + if (alignment[i] != state) { + // Change of state i.e. a different class id. + // So the previous segment has ended. + if (state != -1) { + // state == -1 in the beginning of the alignment. That is just + // initialization step and hence no creation of segment. 
+ Emplace(start_frame + start_time_offset, + i-1 + start_time_offset, state); + num_segments++; + + if (frame_counts_per_class) { + if (frame_counts_per_class->size() <= state) { + frame_counts_per_class->resize(state + 1, 0); + } + (*frame_counts_per_class)[state] += i - start_frame; + } + } + start_frame = i; + state = alignment[i]; + } + } + + KALDI_ASSERT(state >= 0 && start_frame < alignment.size()); + Emplace(start_frame + start_time_offset, + alignment.size()-1 + start_time_offset, state); + num_segments++; + if (frame_counts_per_class) { + if (frame_counts_per_class->size() <= state) { + frame_counts_per_class->resize(state + 1, 0); + } + (*frame_counts_per_class)[state] += alignment.size() - start_frame; + } + + return num_segments; +} + +int32 Segmentation::InsertFromSegmentation( + const Segmentation &seg, + int32 start_time_offset, + std::vector *frame_counts_per_class) { + if (seg.Dim() == 0) return 0; + + int32 num_segments = 0; + + for (SegmentList::const_iterator it = seg.Begin(); it != seg.End(); ++it) { + Emplace(it->start_frame + start_time_offset, + it->end_frame + start_time_offset, it->Label()); + num_segments++; + if (frame_counts_per_class) { + if (frame_counts_per_class->size() <= it->Label()) { + frame_counts_per_class->resize(it->Label() + 1, 0); + } + (*frame_counts_per_class)[it->Label()] += it->Length(); + } + } + + return num_segments; +} + +void Segmentation::Check() const { + int32 dim = 0; + for (SegmentList::const_iterator it = segments_.begin(); + it != segments_.end(); ++it, dim++) { + KALDI_ASSERT(it->Label() >= 0); + }; + KALDI_ASSERT(dim == dim_); + KALDI_ASSERT(mean_scores_.size() == 0 || mean_scores_.size() == dim_); +} + +bool Segmentation::IsNonOverlapping() const { + int32 end_frame = Begin()->end_frame; + int32 start_frame = Begin()->start_frame; + for (SegmentList::const_iterator it = Begin(); it != End(); ++it) { + if (it == Begin()) continue; + if (it->start_frame <= end_frame || it->start_frame < start_frame) + 
return false; + end_frame = it->end_frame; + start_frame = it->start_frame; + } + return true; +} + +bool Segmentation::HasSmallOverlap() const { + int32 end_frame = Begin()->end_frame; + int32 start_frame = Begin()->start_frame; + for (SegmentList::const_iterator it = Begin(); it != End(); ++it) { + if (it == Begin()) continue; + if (it->end_frame < end_frame || it->start_frame < start_frame) + return false; + SegmentList::const_iterator next_it = it; + ++next_it; + if (next_it != End() && next_it->start_frame <= end_frame) + return false; + end_frame = it->end_frame; + start_frame = it->start_frame; + } + return true; +} + +SegmentationPostProcessor::SegmentationPostProcessor( + const SegmentationPostProcessingOptions &opts) : opts_(opts) { + if (!opts_.filter_in_fn.empty()) { + if (ClassifyRspecifier(opts_.filter_in_fn, NULL, NULL) == + kNoRspecifier) { + bool binary_read; + Input ki(opts_.filter_in_fn, &binary_read); + filter_segmentation_.Read(ki.Stream(), binary_read); + } else { + filter_reader_.Open(opts_.filter_in_fn); + } + } + + if (!opts_.remove_labels_csl.empty()) { + if (!SplitStringToIntegers(opts_.remove_labels_csl, ":", + false, &remove_labels_)) { + KALDI_ERR << "Bad value for --remove-labels option: " + << opts_.remove_labels_csl; + } + std::sort(remove_labels_.begin(), remove_labels_.end()); + } + + if (!opts_.merge_labels_csl.empty()) { + if (!SplitStringToIntegers(opts_.merge_labels_csl, ":", + false, &merge_labels_)) { + KALDI_ERR << "Bad value for --merge-labels option: " + << opts_.merge_labels_csl; + } + std::sort(merge_labels_.begin(), merge_labels_.end()); + } + + Check(); +} + +void SegmentationPostProcessor::Check() const { + if ( IsFilteringToBeDone() && opts_.post_process_label < 0) { + KALDI_ERR << "Invalid value " << opts_.post_process_label << " for option " + << "--post-process-label. 
It must be non-negative."; + } + + if (IsWideningSegmentsToBeDone() && opts_.widen_label < 0) { + KALDI_ERR << "Invalid value " << opts_.widen_label << " for option " + << "--widen-label. It must be non-negative."; + } + + if (IsWideningSegmentsToBeDone() && opts_.widen_length <= 0) { + KALDI_ERR << "Invalid value " << opts_.widen_length << " for option " + << "--widen-length. It must be positive."; + } + + if (IsShrinkingSegmentsToBeDone() && opts_.shrink_length <= 0) { + KALDI_ERR << "Invalid value " << opts_.shrink_length << " for option " + << "--shrink-length. It must be positive."; + } + + if (IsRelabelingShortSegmentsToBeDone() && + opts_.relabel_short_segments_class < 0) { + KALDI_ERR << "Invalid value " << opts_.relabel_short_segments_class + << " for option " << "--relabel-short-segments-class. " + << "It must be non-negative."; + } + + if (IsRelabelingShortSegmentsToBeDone() && opts_.max_relabel_length <= 0) { + KALDI_ERR << "Invalid value " << opts_.max_relabel_length << " for option " + << "--max-relabel-length. It must be positive."; + } + + if (IsRemovingSegmentsToBeDone() && remove_labels_[0] < 0) { + KALDI_ERR << "Invalid value " << opts_.remove_labels_csl + << " for option " << "--remove-labels. " + << "The labels must be non-negative."; + } + + if (IsMergingAdjacentSegmentsToBeDone() && + opts_.max_intersegment_length < 0) { + KALDI_ERR << "Invalid value " << opts_.max_intersegment_length + << " for option " + << "--max-intersegment-length. It must be non-negative."; + } + + if (IsSplittingSegmentsToBeDone() && opts_.max_segment_length <= 0) { + KALDI_ERR << "Invalid value " << opts_.max_segment_length + << " for option " + << "--max-segment-length. It must be positive."; + } + + if (opts_.post_process_label != -1 && opts_.post_process_label < 0) { + KALDI_ERR << "Invalid value " << opts_.post_process_label << " for option " + << "--post-process-label. 
It must be non-negative."; + } +} + +bool SegmentationPostProcessor::FilterAndPostProcess(Segmentation *seg, const + std::string *key) { + if (!key) { + Filter(seg); + } else { + if (!Filter(*key, seg)) return false; + } + + return PostProcess(seg); +} + +bool SegmentationPostProcessor::PostProcess(Segmentation *seg) const { + MergeLabels(seg); + WidenSegments(seg); + ShrinkSegments(seg); + RelabelShortSegments(seg); + RemoveSegments(seg); + MergeAdjacentSegments(seg); + SplitSegments(seg); + + return true; +} + +void SegmentationPostProcessor::Filter(Segmentation *seg) const { + if (!IsFilteringToBeDone()) return; + KALDI_ASSERT(ClassifyRspecifier(opts_.filter_in_fn, NULL, NULL) == + kNoRspecifier); + Segmentation tmp_seg(*seg); + tmp_seg.IntersectSegments(filter_segmentation_, seg); +} + +bool SegmentationPostProcessor::Filter(const std::string &key, + Segmentation *seg) { + if (!IsFilteringToBeDone()) return true; + KALDI_ASSERT(ClassifyRspecifier(opts_.filter_in_fn, NULL, NULL) != + kNoRspecifier); + if (!filter_reader_.HasKey(key)) { + KALDI_WARN << "Could not find filter for utterance " << key; + if (!opts_.ignore_missing_filter_keys) return false; + return true; + } + + Segmentation tmp_seg(*seg); + tmp_seg.IntersectSegments(filter_reader_.Value(key), seg); + return true; +} + +void SegmentationPostProcessor::MergeLabels(Segmentation *seg) const { + if (!IsMergingLabelsToBeDone()) return; + seg->MergeLabels(merge_labels_, opts_.merge_dst_label); +} + +void SegmentationPostProcessor::WidenSegments(Segmentation *seg) const { + if (!IsWideningSegmentsToBeDone()) return; + seg->WidenSegments(opts_.widen_label, opts_.widen_length); +} + +void SegmentationPostProcessor::ShrinkSegments(Segmentation *seg) const { + if (!IsShrinkingSegmentsToBeDone()) return; + seg->ShrinkSegments(opts_.widen_label, opts_.shrink_length); +} + +void SegmentationPostProcessor::RelabelShortSegments(Segmentation *seg) const { + if (!IsRelabelingShortSegmentsToBeDone()) return; + 
seg->RelabelShortSegments(opts_.relabel_short_segments_class, + opts_.max_relabel_length); +} + +void SegmentationPostProcessor::RemoveSegments(Segmentation *seg) const { + if (!IsRemovingSegmentsToBeDone()) return; + seg->RemoveSegments(remove_labels_); +} + +void SegmentationPostProcessor::MergeAdjacentSegments(Segmentation *seg) const { + if (!IsMergingAdjacentSegmentsToBeDone()) return; + seg->MergeAdjacentSegments(opts_.max_intersegment_length); +} + +void SegmentationPostProcessor::SplitSegments(Segmentation *seg) const { + if (!IsSplittingSegmentsToBeDone()) return; + seg->SplitSegments(opts_.max_segment_length, opts_.max_segment_length / 2, + opts_.overlap_length, + opts_.post_process_label); +} + +} +} diff --git a/src/segmenter/segmenter.h b/src/segmenter/segmenter.h new file mode 100644 index 00000000000..70cd060ba73 --- /dev/null +++ b/src/segmenter/segmenter.h @@ -0,0 +1,633 @@ +#ifndef SEGMENTER_H +#define SEGMENTER_H + +#include +#include "base/kaldi-common.h" +#include "matrix/kaldi-matrix.h" +#include "segmenter/segmenter.h" +#include "util/kaldi-table.h" +#include "itf/options-itf.h" + +namespace kaldi { +namespace segmenter { + +// ClassId is just an integer for now. We could change it +// later if needed. +typedef int32 ClassId; + +/** + * This structure defines a single segment. It consists of the following basic + * properties: + * 1) start_frame : This is the frame index of the first frame in the + * segment. + * 2) end_frame : This is the frame index of the last frame in the segment. + * Note that the end_frame is included in the segment. + * 3) class_id : This is the class corresponding to the segments. For e.g., + * could be 0, 1 or 2 depending on whether the segment is + * silence, speech or noise. In general, it can be any + * integer class label. + * Some other properties that a segment might hold temporarily are + * vector_value : This is some real valued vector such as average energy or + * ivector for the segment. 
+ * string_value : Some string value such as segment_id that is characteristic + * of the segment. +**/ + +struct Segment { + int32 start_frame; + int32 end_frame; + ClassId class_id; + Vector vector_value; + std::string string_value; + + // Accessors for labels or class id. This is useful in the future when + // we might change the type of label. + inline int32 Label() const { return class_id; } + inline void SetLabel(int32 label) { class_id = label; } + inline int32 Length() const { return end_frame - start_frame + 1; } + + // This is the default constructor that sets everything to undefined values. + Segment() : start_frame(-1), end_frame(-1), class_id(-1) { } + + // This constructor initializes the segmented with the provided start and end + // frames and the segment label. This is the main constructor. + Segment(int32 start, int32 end, int32 label) : + start_frame(start), end_frame(end), class_id(label) { } + + + // This constructor is an extension to the above main constructor and also + // initializes the vector_value of the segment. + Segment(int32 start, int32 end, int32 label, const Vector& vec) : + Segment(start, end, label) { + vector_value.Resize(vec.Dim()); + vector_value.CopyFromVec(vec); + } + + // This constructor is an extension to the above constructor and + // initializes the string_value along with the vector_value + Segment(int32 start, int32 end, int32 label, + const Vector& vec, const std::string &str) : + Segment(start, end, label, vec) { + string_value = str; + } + + // This constructor is an extension to the main constructor and + // additionally initializes the string_value of the segment. + Segment(int32 start, int32 end, int32 label, const std::string &str) : + Segment(start, end, label) { + string_value = str; + } + + void Write(std::ostream &os, bool binary) const; + void Read(std::istream &is, bool binary); + + // This is a function that returns the size of the elements in the structure. 
+ // It is used during I/O in binary mode, which checks for the total size + // required to store the segment. + static size_t SizeOf() { + return (sizeof(int32) + sizeof(int32) + sizeof(ClassId)); + } + + // Accessors to get vector and string values corresponding to the segment. + const Vector& VectorValue() const { return vector_value; } + const std::string& StringValue() const { return string_value; } + + // Accessor to set the vector value not during initialization. + void SetVectorValue(const VectorBase& vec) { + vector_value.Resize(vec.Dim()); + vector_value.CopyFromVec(vec); + } +}; + +/** This structure is used to encode some vector of real values into bins. This + * is mainly used in the classification of segments into speech, silence and + * noise depending on the vector of frame-level energy and/or zero-crossing + * of the frames in the segment. +**/ + +struct HistogramEncoder { + // Width of the bins in the histogram of real values + BaseFloat bin_width; + + // Minimum score corresponding to the lowest bin of the histogram + BaseFloat min_score; + + // This is a vector that stores the number of real values contained in the + // different bins. + std::vector bin_sizes; + + // A flag that is relevant only in a particular function. See the comments + // in Encode function for details. + bool select_from_full_histogram; + + // default constructor + HistogramEncoder(): bin_width(-1), + min_score(std::numeric_limits::infinity()), + select_from_full_histogram(false) {} + + // Accessors for different quantities + inline int32 NumBins() const { return bin_sizes.size(); } + inline int32 BinSize(int32 i) const { return bin_sizes[i]; } + + // Initialize the container to a specific number of bins and also size + // and the value each bin represents. + void Initialize(int32 num_bins, BaseFloat bin_w, BaseFloat min_s); + + // Insert the real value 'x' with a count of 'n' times into the appropriate + // bin in the histogram. 
+ void Encode(BaseFloat x, int32 n); +}; + +/** + * Structure for some common options related to segmentation that would be used + * in multiple segmentation programs. Some of the operations include merging, + * filtering etc. +**/ + +struct SegmentationPostProcessingOptions { + std::string filter_in_fn; + int32 filter_label; + bool ignore_missing_filter_keys; + std::string merge_labels_csl; + int32 merge_dst_label; + int32 widen_label; + int32 widen_length; + int32 shrink_label; + int32 shrink_length; + int32 relabel_short_segments_class; + int32 max_relabel_length; + std::string remove_labels_csl; + bool merge_adjacent_segments; + int32 max_intersegment_length; + int32 max_segment_length; + int32 overlap_length; + int32 post_process_label; + + SegmentationPostProcessingOptions() : + filter_label(-1), ignore_missing_filter_keys(false), merge_dst_label(-1), + widen_label(-1), widen_length(-1), + shrink_label(-1), shrink_length(-1), + relabel_short_segments_class(-1), max_relabel_length(-1), + merge_adjacent_segments(false), max_intersegment_length(0), + max_segment_length(-1), overlap_length(0), post_process_label(-1) { } + + void Register(OptionsItf *opts) { + opts->Register("filter-in-fn", &filter_in_fn, + "The segmentation that is used as a filter for the " + "Intersection or Filtering post-processing operation. " + "Refer to the IntersectSegments() code for details. " + "Used in conjunction with the option --filter-label."); + //opts->Register("filter-label", &filter_label, "The label on which the " + // "Intersection or Filtering operation is done. " + // "Refer to the IntersectSegments() code for details. " + // "Used in conjunction with the options --filter-in-fn."); + opts->Register("ignore-missing-filter-keys", &ignore_missing_filter_keys, + "If this is true and a key could not be found in the " + "filter, the post-processing skips the Filtering operation. " + "Otherwise, it counts it as an error. Applicable only when " + "--filter-in-fn is an archive. 
" + "Used in conjunction with the option --filter-in-fn."); + opts->Register("merge-labels", &merge_labels_csl, "Merge labels into a " + "single label defined by merge-dst-label." + "The labels are specified as a colon-separated list. " + "Refer to the MergeLabels() code for details. " + "Used in conjunction with the option --merge-dst-label"); + opts->Register("merge-dst-label", &merge_dst_label, + "Merge labels into this label. " + "Refer to the MergeLabels() code for details. " + "Used in conjunction with the option --merge-labels."); + opts->Register("widen-label", &widen_label, + "Widen segments of this class_id " + "by shrinking the adjacent segments of other class_ids or " + "merging with adjacent segments of the same class_id. " + "Refer to the WidenSegments() code for details. " + "Used in conjunction with the option --widen-length."); + opts->Register("widen-length", &widen_length, "Widen segments by this many " + "frames on either side. " + "See option --widen-label for details. " + "Refer to the WidenSegments() code for details. " + "Used in conjunction with the option --widen-label."); + opts->Register("shrink-label", &shrink_label, + "Shrink segments of this class_id " + "by shrinking the adjacent segments of other class_ids or " + "merging with adjacent segments of the same class_id. " + "Refer to the ShrinkSegments() code for details. " + "Used in conjunction with the option --widen-length."); + opts->Register("shrink-length", &shrink_length, "Shrink segments by this many " + "frames on either side. " + "See option --shrink-label for details. " + "Refer to the ShrinkSegments() code for details. " + "Used in conjunction with the option --shrink-label."); + opts->Register("relabel-short-segments-class", &relabel_short_segments_class, + "The class_id for which the short segments are to be " + "relabeled as the class_id of the neighboring segments. " + "Refer to RelabelShortSegments() code for details. 
" + "Used in conjunction with the option --max-relabel-length."); + opts->Register("max-relabel-length", &max_relabel_length, + "The maximum length of segment in number of frames that " + "will be relabeled to the class-id of the adjacent " + "segments, provided the adjacent segments both have the " + "same class-id. " + "Refer to RelabelShortSegments() code for details. " + "Used in conjunction with the option " + "--relabel-short-segments-class"); + opts->Register("remove-labels", &remove_labels_csl, + "Remove any segment whose class_id is contained in " + "remove_labels_csl. " + "Refer to the RemoveLabels() code for details."); + opts->Register("merge-adjacent-segments", &merge_adjacent_segments, + "Merge adjacent segments of the same label if they are " + "within max-intersegment-length distance. " + "Refer to the MergeAdjacentSegments() code for details. " + "Used in conjunction with the option " + "--max-intersegment-length\n"); + opts->Register("max-intersegment-length", &max_intersegment_length, + "The maximum intersegment length that is allowed for " + "two adjacent segments to be merged. " + "Refer to the MergeAdjacentSegments() code for details. " + "Used in conjunction with the option " + "--merge-adjacent-segments\n"); + opts->Register("max-segment-length", &max_segment_length, + "If segment is longer than this length, split it into " + "pieces with less than these many frames. " + "Refer to the SplitSegments() code for details. " + "Used in conjunction with the option --overlap-length."); + opts->Register("overlap-length", &overlap_length, + "When splitting segments longer than max-segment-length, " + "have the pieces overlap by these many frames. " + "Refer to the SplitSegments() code for details. " + "Used in conjunction with the option --max-segment-length."); + opts->Register("post-process-label", &post_process_label, + "Do post processing only on this label. 
This option is " + "applicable to only a few operations including " + "SplitSegments"); + //opts->Register("mask-rspecifier", &mask_rspecifier, "Unselect " + // "those regions that have label mask-label in this " + // "mask-segmentation"); + //opts->Register("mask-label", &mask_label, "The label on which the " + // "masking is done"); + } +}; + +/** + * Structure for options for histogram encoding +**/ + +struct HistogramOptions { + int32 num_bins; + bool select_above_mean; + bool select_from_full_histogram; + + HistogramOptions() : num_bins(100), select_above_mean(false), select_from_full_histogram(false) {} + + void Register(OptionsItf *opts) { + opts->Register("num-bins", &num_bins, "Number of bins in the histogram " + "created using the scores. Use larger number of bins to " + "make a finer selection"); + opts->Register("select-above-mean", &select_above_mean, "If true, " + "use mean as the reference instead of min"); + opts->Register("select-from-full-histogram", &select_from_full_histogram, + "Do not restrict selection to one half"); + + } + +}; + +/** + * Comparator to order segments based on start frame +**/ + +class SegmentComparator { + public: + bool operator() (const Segment &lhs, const Segment &rhs) const { + return lhs.start_frame < rhs.start_frame; + } +}; + +// Segments are stored as a doubly-linked-list. This could be changed later +// if needed. Hence defining a typedef SegmentList. +typedef std::list SegmentList; + +/** + * The main class to store segmentation and do operations on it. The segments + * are stored in the structure SegmentList, which is currently a doubly-linked + * list. + * See the .cc file for details of implementation of the different functions. + * This file gives only a small description of the functions. +**/ + +class Segmentation { + public: + // Default constructor + Segmentation() { + Clear(); + } + + // Create random segmentation. Useful for debugging purposes. 
+ void GenRandomSegmentation(int32 max_length, int32 num_classes); + + // Split the input segmentation into pieces of approximately + // segment_length and store it in this segmentation. + // Most probably, you want to use the split segments version that is below + // this one. + void SplitSegments(const Segmentation &in_segments, + int32 segment_length); + + // Split this segmentation into pieces of size + // segment_length such that the last remaining piece + // is not longer than min_remainder. + // Optionally create overlapping pieces with the number + // of overlapping frames specified by overlap. + // Typically used to create 1s windows from 10 minute long chunks + void SplitSegments(int32 segment_length, + int32 min_remainder, int32 overlap = 0, + int32 label = -1); + + // Modify this segmentation to merge labels in merge_labels vector into a + // single label dest_label. + // e.g Merge noise and silence into a single silence label + void MergeLabels(const std::vector &merge_labels, + int32 dest_label); + + // Merge adjacent segments of the same label. "Adjacent" is defined as being + // within max_intersegment_length of each other. i.e. start_frame of next + // segment must not be greater than max_intersegment_length away from + // end_frame of the current segment. + void MergeAdjacentSegments(int32 max_intersegment_length = 1); + + // Create a Histogram Encoder that can map a segment to + // a bin based on the average score + void CreateHistogram(int32 label, const Vector &score, + const HistogramOptions &opts, HistogramEncoder *hist); + + // Modify this segmentation to select the top bins in the + // histogram. Assumes that this segmentation also has the + // average scores. + int32 SelectTopBins(const HistogramEncoder &hist, + int32 src_label, int32 dst_label, int32 reject_label, + int32 num_frames_select, bool remove_rejected_frames); + + // Modify this segmentation to select the bottom bins in the histogram. 
+ // Assumes that this segmentation also has the average scores. + int32 SelectBottomBins(const HistogramEncoder &hist, + int32 src_label, int32 dst_label, int32 reject_label, + int32 num_frames_select, bool remove_rejected_frames); + + // Modify this segmentation to select the top and bottom bins in the + // histogram. Assumes that this segmentation also has the average scores. + std::pair SelectTopAndBottomBins( + const HistogramEncoder &hist_encoder, + int32 src_label, int32 top_label, int32 num_frames_top, + int32 bottom_label, int32 num_frames_bottom, + int32 reject_label, bool remove_rejected_frames); + + // Initialize this segmentation from in_segmentation. + // But select subsegments of this segmentation by including + // only regions for which the "filter_segmentation" has + // the label "filter_label". + //void IntersectSegments(const Segmentation &in_segmentation, + // const Segmentation &filter_segmentation, + // int32 filter_label); + + // Select subsegments of this segmentation by including + // only regions for which the "filter_segmentation" has + // the label "filter_label". + // For e.g. if the segmentation is + // start_frame end_frame label + // 5 10 1 + // 8 12 2 + // and filter_segmentation is + // 0 7 1 + // 7 10 2 + // 10 13 1. + // And filter_label is 1. Then after intersection, this + // object would hold + // 5 7 1 + // 8 10 2 + // 10 12 2 + void IntersectSegments(const Segmentation &secondary_segmentation, + Segmentation *out_seg, + int32 mismatch_label = -1) const; + + // Extend a segmentation by adding another one. By default, the + // resultant segmentation would be sorted. If its known that the other + // segmentation must all be after this segmentation, then sort may be given + // false. + void Extend(const Segmentation &other_seg, bool sort = true); + + // Create new segmentation by sub-segmenting this segmentation and + // assign new labels to the filtered regions from secondary segmentation. 
+ // This is similar to "IntersectSegments", but instead of keeping only + // the filtered subsegments, all the subsegments are kept, while only + // changing the labels of the filtered subsegment to "subsegment_label". + // Additionally this program adds the secondary segmentation's + // vector_value along with this segmentation's string_value + void CreateSubSegments(const Segmentation &secondary_segmentation, + int32 secondary_label, int32 subsegment_label, + Segmentation *out_seg) const { + SubSegmentUsingNonOverlappingSegments(secondary_segmentation, + secondary_label, subsegment_label, + out_seg); + } + + void SubSegmentUsingNonOverlappingSegments(const Segmentation &secondary_segmentation, + int32 secondary_label, int32 subsegment_label, + Segmentation *out_seg) const; + + void SubSegmentUsingSmallOverlapSegments(const Segmentation &secondary_segmentation, + int32 secondary_label, int32 subsegment_label, + Segmentation *out_seg) const; + + void CreateSubSegmentsOld(const Segmentation &filter_segmentation, + int32 filter_label, + int32 subsegment_label); + + // Widen segments of label "label" by "length" frames + // on either side. But don't increase the length beyond the + // neighboring segment. Also if the neighboring segment is + // of a different type than "label", that segment is + // shortened to fix the boundary between the segment and the + // neighbor + void WidenSegments(int32 label, int32 length); + void ShrinkSegments(int32 label, int32 length); + + // Relabel segments of label "label" if they have a length + // less than "max_length", label "label" and the previous + // and next segments have the same label (not necessarily "label") + // The three contiguous segments have the same label and hence are merged + // together. 
+ void RelabelShortSegments(int32 label, int32 max_length); + + // Remove segments of label "label" + void RemoveSegments(int32 label); + + // Remove segments of labels "labels" + void RemoveSegments(const std::vector &labels); + + // Reset segmentation i.e. clear all values + void Clear(); + + // Read segmentation object from input stream + void Read(std::istream &is, bool binary); + + // Write segmentation object to output stream + void Write(std::ostream &os, bool binary) const; + + // Write the segmentation in the form of an RTTM + int32 WriteRttm(std::ostream &os, const std::string &file_id, const std::string &channel, + BaseFloat frame_shift, BaseFloat start_time, + bool map_to_speech_and_sil) const; + + // Convert current segmentation to alignment + bool ConvertToAlignment(std::vector *alignment, + int32 default_label = 0, int32 length = -1, + int32 tolerance = 2) const; + + // Insert segments created from alignment whose 0th frame corresponds to + // start_time_offset + int32 InsertFromAlignment(const std::vector &alignment, + int32 start_time_offset = 0, + std::vector *frame_counts_per_class = NULL); + + int32 InsertFromSegmentation(const Segmentation &seg, + int32 start_time_offset = 0, + std::vector *frame_counts_per_class = NULL); + + // The following functions construct new segment in-place in the + // segmentation and increments the dim_ of the segmentation. There's one + // emplace for each constructor in Segment. 
+ inline void Emplace(int32 start_frame, int32 end_frame, ClassId class_id) { + dim_++; + segments_.emplace_back(start_frame, end_frame, class_id); + } + + inline void Emplace(int32 start_frame, int32 end_frame, ClassId class_id, + const Vector &vec) { + dim_++; + segments_.emplace_back(start_frame, end_frame, class_id, vec); + } + + inline void Emplace(int32 start_frame, int32 end_frame, ClassId class_id, + const std::string &str) { + dim_++; + segments_.emplace_back(start_frame, end_frame, class_id, str); + } + + inline void Emplace(int32 start_frame, int32 end_frame, ClassId class_id, + const Vector &vec, const std::string &str) { + dim_++; + segments_.emplace_back(start_frame, end_frame, class_id, vec, str); + } + + // Call erase operation on the SegmentList and returns the iterator pointing + // to the next segment in the SegmentList and also decrements dim_. + inline SegmentList::iterator Erase(SegmentList::iterator it) { + dim_--; + return segments_.erase(it); + } + + // Check if all segments have class_id >=0 and if dim_ matches the number of + // segments. + void Check() const; + + // Check if segmentation is non-overlapping. + bool IsNonOverlapping() const; + + // Check if segmentation does not have large overlaps. + bool HasSmallOverlap() const; + + // Sort the segments on the start_frame + inline void Sort() { segments_.sort(SegmentComparator()); }; + + // Public accessors + inline int32 Dim() const { return dim_; } + SegmentList::iterator Begin() { return segments_.begin(); } + SegmentList::const_iterator Begin() const { return segments_.cbegin(); } + SegmentList::iterator End() { return segments_.end(); } + SegmentList::const_iterator End() const { return segments_.cend(); } + + const SegmentList* Data() const { return &segments_; } + + private: + // number of segments in the segmentation + int32 dim_; + + // list of segments in the segmentation + SegmentList segments_; + + // the score for each segment in the segmentation. 
If it has a non-zero + // size, then the size must equal dim_. + std::vector mean_scores_; + + friend class SegmentationPostProcessor; +}; + +typedef TableWriter > SegmentationWriter; +typedef SequentialTableReader > SequentialSegmentationReader; +typedef RandomAccessTableReader > RandomAccessSegmentationReader; +typedef RandomAccessTableReaderMapped > RandomAccessBaseFloatMatrixReaderMapped; + +class SegmentationPostProcessor { + public: + explicit SegmentationPostProcessor( + const SegmentationPostProcessingOptions &opts); + + bool FilterAndPostProcess(Segmentation *seg, const std::string *key = NULL); + bool PostProcess(Segmentation *seg) const; + + bool Filter(const std::string &key, Segmentation *seg); + void Filter(Segmentation *seg) const; + void MergeLabels(Segmentation *seg) const; + void WidenSegments(Segmentation *seg) const; + void ShrinkSegments(Segmentation *seg) const; + void RelabelShortSegments(Segmentation *seg) const; + void RemoveSegments(Segmentation *seg) const; + void MergeAdjacentSegments(Segmentation *seg) const; + void SplitSegments(Segmentation *seg) const; + + private: + const SegmentationPostProcessingOptions &opts_; + std::vector merge_labels_; + std::vector remove_labels_; + Segmentation filter_segmentation_; + RandomAccessSegmentationReader filter_reader_; + + inline bool IsFilteringToBeDone() const { + return (!opts_.filter_in_fn.empty()); + } + + inline bool IsMergingLabelsToBeDone() const { + return (!opts_.merge_labels_csl.empty() || opts_.merge_dst_label != -1); + } + + inline bool IsWideningSegmentsToBeDone() const { + return (opts_.widen_label != -1 || opts_.widen_length != -1); + } + + inline bool IsShrinkingSegmentsToBeDone() const { + return (opts_.shrink_label != -1 || opts_.shrink_length != -1); + } + + inline bool IsRelabelingShortSegmentsToBeDone() const { + return (opts_.relabel_short_segments_class != -1 || opts_.max_relabel_length != -1); + } + + inline bool IsRemovingSegmentsToBeDone() const { + return 
(!opts_.remove_labels_csl.empty()); + } + + inline bool IsMergingAdjacentSegmentsToBeDone() const { + return (opts_.merge_adjacent_segments); + } + + inline bool IsSplittingSegmentsToBeDone() const { + return (opts_.max_segment_length != -1); + } + + void Check() const; +}; + +} +} + +#endif // SEGMENTER_H diff --git a/src/segmenterbin/Makefile b/src/segmenterbin/Makefile new file mode 100644 index 00000000000..6547ac01656 --- /dev/null +++ b/src/segmenterbin/Makefile @@ -0,0 +1,28 @@ + +all: + +EXTRA_CXXFLAGS = -Wno-sign-compare +include ../kaldi.mk + +BINFILES = segmentation-copy segmentation-init-from-segments \ + segmentation-init-from-ali segmentation-select-top \ + gmm-acc-pdf-stats-segmentation select-feats-from-segmentation \ + gmm-est-segmentation segmentation-to-rttm segmentation-post-process \ + gmm-update-segmentation segmentation-to-ali \ + segmentation-create-subsegments segmentation-remove-segments \ + segmentation-init-from-diarization segmentation-to-segments \ + segmentation-compute-class-ctm-conf segmentation-filter-ctm \ + segmentation-init-from-lengths segmentation-merge \ + segmentation-intersect-segments segmentation-combine-segments + +OBJFILES = + + + +TESTFILES = + +ADDLIBS = ../hmm/kaldi-hmm.a ../gmm/kaldi-gmm.a ../segmenter/kaldi-segmenter.a ../tree/kaldi-tree.a \ + ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a ../thread/kaldi-thread.a + +include ../makefiles/default_rules.mk + diff --git a/src/segmenterbin/gmm-acc-pdf-stats-segmentation.cc b/src/segmenterbin/gmm-acc-pdf-stats-segmentation.cc new file mode 100644 index 00000000000..600af3a8009 --- /dev/null +++ b/src/segmenterbin/gmm-acc-pdf-stats-segmentation.cc @@ -0,0 +1,170 @@ +// gmmbin/gmm-acc-pdf-stats-segmentation.cc + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "gmm/am-diag-gmm.h" +#include "gmm/mle-am-diag-gmm.h" +#include "hmm/transition-model.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + using namespace kaldi; + using namespace segmenter; + + typedef kaldi::int32 int32; + try { + const char *usage = + "Accumulate pdf stats for GMM training from segmentation.\n" + "Usage: gmm-acc-pdf-stats-segmentation [options] " + " \n" + "e.g.:\n gmm-acc-stats-ali 1.mdl scp:train.scp ark:1.seg 1.acc\n"; + + ParseOptions po(usage); + bool binary = true; + std::string class2pdf_rxfilename, pdfs_str; + + po.Register("binary", &binary, "Write output in binary mode"); + po.Register("class2pdf", &class2pdf_rxfilename, + "Map from class label to pdf id"); + po.Register("pdfs", &pdfs_str, + "Only accumulate stats for these pdfs"); + + po.Read(argc, argv); + + if (po.NumArgs() != 4) { + po.PrintUsage(); + exit(1); + } + + std::string model_filename = po.GetArg(1), + feature_rspecifier = po.GetArg(2), + segmentation_rspecifier = po.GetArg(3), + accs_wxfilename = po.GetArg(4); + + unordered_map class2pdf; + if (class2pdf_rxfilename != "") { + Input ki; + if (!ki.OpenTextMode(class2pdf_rxfilename)) + KALDI_ERR << "Unable to open file " << class2pdf_rxfilename + << " for reading in text mode"; + std::istream &is = ki.Stream(); + std::string line; + while (std::getline(is, line)) { + std::vector v; + if (!SplitStringToIntegers(line, " \t\r", true, &v) || v.size() 
!= 2) { + KALDI_ERR << "Unable to parse line " << line << " in " + << class2pdf_rxfilename; + } + class2pdf.insert(std::make_pair(v[0], v[1])); + } + + if (!is.eof()) { + KALDI_ERR << "Did not reach EOF. Could not read file " << class2pdf_rxfilename + << " successfully"; + } + } + + std::vector pdfs; + if (pdfs_str != "") { + if (!SplitStringToIntegers(pdfs_str, ":", true, &pdfs)) { + KALDI_ERR << "Unable to parse string " << pdfs_str; + } + } + + AmDiagGmm am_gmm; + { + bool binary; + Input ki(model_filename, &binary); + TransitionModel trans_model; + trans_model.Read(ki.Stream(), binary); + am_gmm.Read(ki.Stream(), binary); + } + + AccumAmDiagGmm gmm_accs; + gmm_accs.Init(am_gmm, kGmmMeans|kGmmVariances|kGmmWeights); + + double tot_like = 0.0; + kaldi::int64 tot_t = 0; + + SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); + RandomAccessSegmentationReader segmentation_reader(segmentation_rspecifier); + + int32 num_done = 0, num_err = 0; + for (; !feature_reader.Done(); feature_reader.Next()) { + std::string key = feature_reader.Key(); + if (!segmentation_reader.HasKey(key)) { + KALDI_WARN << "No segmentation for utterance " << key; + num_err++; + continue; + } + const Matrix &mat = feature_reader.Value(); + const Segmentation &segmentation = segmentation_reader.Value(key); + + BaseFloat tot_like_this_file = 0.0; + BaseFloat tot_t_this_file = 0.0; + + for (SegmentList::const_iterator it = segmentation.Begin(); + it != segmentation.End(); ++it) { + int32 pdf_id; + if (class2pdf_rxfilename != "") + pdf_id = it->Label(); + else + pdf_id = class2pdf.at(it->Label()); + if ( (pdfs_str != "" && std::binary_search(pdfs.begin(), pdfs.end(), pdf_id)) + || (pdfs_str == "" && pdf_id < am_gmm.NumPdfs() && pdf_id >=0) ) { + KALDI_ASSERT(pdf_id >= 0 && pdf_id < am_gmm.NumPdfs()); + for (int32 i = it->start_frame; i <= it->end_frame; i++) + tot_like_this_file += gmm_accs.AccumulateForGmm(am_gmm, mat.Row(i), + pdf_id, 1.0); + tot_t_this_file = it->end_frame - 
it->start_frame + 1; + } + } + tot_like += tot_like_this_file; + tot_t += tot_t_this_file; + + if (num_done % 50 == 0) { + KALDI_LOG << "Processed " << num_done << " utterances; for utterance " + << key << " avg. like is " + << (tot_like/tot_t) + << " over " << tot_t <<" frames."; + } + num_done++; + } + KALDI_LOG << "Done " << num_done << " files, " << num_err + << " with errors."; + + KALDI_LOG << "Overall avg like per frame (Gaussian only) = " + << (tot_like/tot_t) << " over " << tot_t << " frames."; + + { + Output ko(accs_wxfilename, binary); + gmm_accs.Write(ko.Stream(), binary); + } + KALDI_LOG << "Written accs."; + if (num_done != 0) + return 0; + else + return 1; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/gmm-est-segmentation.cc b/src/segmenterbin/gmm-est-segmentation.cc new file mode 100644 index 00000000000..d943129021a --- /dev/null +++ b/src/segmenterbin/gmm-est-segmentation.cc @@ -0,0 +1,378 @@ +// gmmbin/gmm-est-segmentation.cc + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "gmm/am-diag-gmm.h" +#include "gmm/mle-am-diag-gmm.h" +#include "hmm/transition-model.h" +#include "segmenter/segmenter.h" + +namespace kaldi { +namespace segmenter { + +void MleAmDiagGmmUpdateSubsetPdfs (const MleDiagGmmOptions &config, + const AccumAmDiagGmm &am_diag_gmm_acc, + const std::vector *pdfs, + GmmFlagsType flags, + AmDiagGmm *am_gmm, + BaseFloat *obj_change_out, + BaseFloat *count_out) { + KALDI_ASSERT(am_diag_gmm_acc.Dim() == am_gmm->Dim()); + KALDI_ASSERT(am_gmm != NULL); + KALDI_ASSERT(am_diag_gmm_acc.NumAccs() == am_gmm->NumPdfs()); + if (obj_change_out != NULL) *obj_change_out = 0.0; + if (count_out != NULL) *count_out = 0.0; + + BaseFloat tot_obj_change = 0.0, tot_count = 0.0; + int32 tot_elems_floored = 0, tot_gauss_floored = 0, + tot_gauss_removed = 0; + + if (pdfs != NULL) { + for (std::vector::const_iterator it = pdfs->begin(); + it != pdfs->end(); ++it) { + BaseFloat obj_change, count; + int32 elems_floored, gauss_floored, gauss_removed; + MleDiagGmmUpdate(config, am_diag_gmm_acc.GetAcc(*it), flags, + &(am_gmm->GetPdf(*it)), + &obj_change, &count, &elems_floored, + &gauss_floored, &gauss_removed); + KALDI_LOG << "Count for pdf " << *it << " is " << count; + tot_obj_change += obj_change; + tot_count += count; + tot_elems_floored += elems_floored; + tot_gauss_floored += gauss_floored; + tot_gauss_removed += gauss_removed; + } + } else { + for (int32 i = 0; i < am_diag_gmm_acc.NumAccs(); i++) { + BaseFloat obj_change, count; + int32 elems_floored, gauss_floored, gauss_removed; + + MleDiagGmmUpdate(config, am_diag_gmm_acc.GetAcc(i), flags, + &(am_gmm->GetPdf(i)), + &obj_change, &count, &elems_floored, + &gauss_floored, &gauss_removed); + KALDI_LOG << "Count for pdf " << i << " is " << count; + + tot_obj_change += obj_change; + tot_count += count; + tot_elems_floored += elems_floored; + tot_gauss_floored += gauss_floored; + tot_gauss_removed += gauss_removed; + } + } + + 
if (obj_change_out != NULL) *obj_change_out = tot_obj_change; + if (count_out != NULL) *count_out = tot_count; + KALDI_LOG << tot_elems_floored << " variance elements floored in " + << tot_gauss_floored << " Gaussians, out of " + << am_gmm->NumGauss(); + if (config.remove_low_count_gaussians) { + KALDI_LOG << "Removed " << tot_gauss_removed + << " Gaussians due to counts < --min-gaussian-occupancy=" + << config.min_gaussian_occupancy + << " and --remove-low-count-gaussians=true"; + } +} + +} +} + +int main(int argc, char *argv[]) { + using namespace kaldi; + using namespace segmenter; + + typedef kaldi::int32 int32; + try { + const char *usage = + "Accumulate pdf stats for GMM training from segmentation " + "and update GMM\n" + "Usage: gmm-est-segmentation [options] " + " \n" + "e.g.:\n gmm-acc-stats-ali 1.mdl scp:train.scp ark:1.seg 2.mdl\n"; + + ParseOptions po(usage); + bool binary = true; + std::string class2pdf_rxfilename, pdfs_str; + MleDiagGmmOptions gmm_opts; + int32 mixup = 0; + std::string mixup_per_pdf_str, mixup_rxfilename; + int32 mixdown = 0; + BaseFloat perturb_factor = 0.01; + BaseFloat power = 0.2; + BaseFloat min_count = 20.0; + std::string update_flags_str = "mvw"; + std::string occs_out_filename; + int32 num_iters = 3; + + po.Register("binary", &binary, "Write output in binary mode"); + po.Register("class2pdf", &class2pdf_rxfilename, + "Map from class label to pdf id"); + po.Register("pdfs", &pdfs_str, + "Only accumulate stats for these pdfs"); + po.Register("mix-up", &mixup, "Increase number of mixture components to " + "this overall target."); + po.Register("mix-up-per-pdf", &mixup_per_pdf_str, + "Mix-up per pdf specified as comma separated list"); + po.Register("mix-up-rxfilename", &mixup_rxfilename, + "Mix-up per pdf specified in a table"); + po.Register("min-count", &min_count, + "Minimum per-Gaussian count enforced while mixing up and down."); + po.Register("mix-down", &mixdown, "If nonzero, merge mixture components to this " + "target."); 
+ po.Register("power", &power, "If mixing up, power to allocate Gaussians to" + " states."); + po.Register("update-flags", &update_flags_str, "Which GMM parameters to " + "update: subset of mvwt."); + po.Register("perturb-factor", &perturb_factor, "While mixing up, perturb " + "means by standard deviation times this factor."); + po.Register("write-occs", &occs_out_filename, "File to write pdf " + "occupation counts to."); + po.Register("num-iters", &num_iters, "Number of iterations of ML estimation"); + + gmm_opts.Register(&po); + po.Read(argc, argv); + + if (po.NumArgs() != 4) { + po.PrintUsage(); + exit(1); + } + + kaldi::GmmFlagsType update_flags = + StringToGmmFlags(update_flags_str); + + std::string model_in_filename = po.GetArg(1), + feature_rspecifier = po.GetArg(2), + segmentation_rspecifier = po.GetArg(3), + model_out_filename = po.GetArg(4); + + unordered_map class2pdf; + if (class2pdf_rxfilename != "") { + Input ki; + if (!ki.OpenTextMode(class2pdf_rxfilename)) + KALDI_ERR << "Unable to open file " << class2pdf_rxfilename + << " for reading in text mode"; + std::istream &is = ki.Stream(); + std::string line; + while (std::getline(is, line)) { + std::vector v; + if (!SplitStringToIntegers(line, " \t\r", true, &v) || v.size() != 2) { + KALDI_ERR << "Unable to parse line " << line << " in " + << class2pdf_rxfilename; + } + class2pdf.insert(std::make_pair(v[0], v[1])); + } + + if (!is.eof()) { + KALDI_ERR << "Did not reach EOF. 
Could not read file " << class2pdf_rxfilename + << " successfully"; + } + } + + std::vector pdfs; + if (pdfs_str != "") { + if (!SplitStringToIntegers(pdfs_str, ":", true, &pdfs)) { + KALDI_ERR << "Unable to parse string " << pdfs_str; + } + } + + AmDiagGmm am_gmm; + TransitionModel trans_model; + { + bool binary; + Input ki(model_in_filename, &binary); + trans_model.Read(ki.Stream(), binary); + am_gmm.Read(ki.Stream(), binary); + } + + std::vector components_per_pdf(am_gmm.NumPdfs()); + for (int32 i = 0; i < am_gmm.NumPdfs(); i++) { + components_per_pdf[i] = am_gmm.GetPdf(i).NumGauss(); + } + + std::vector target_components_per_pdf(am_gmm.NumPdfs(), -1); + if (mixup_per_pdf_str != "") { + std::vector mixup_per_pdf; + if (!SplitStringToIntegers(mixup_per_pdf_str, ":", true, &mixup_per_pdf) + && mixup_per_pdf.size() != am_gmm.NumPdfs()) { + KALDI_ERR << "Unable to parse string " << mixup_per_pdf_str + << " or it has wrong size (!= " << am_gmm.NumPdfs() << ")"; + } + for (int32 i = 0; i < am_gmm.NumPdfs(); i++) { + target_components_per_pdf[i] = mixup_per_pdf[i]; + } + } else if (mixup_rxfilename != "") { + Input ki(mixup_rxfilename); + std::string line; + while (std::getline(ki.Stream(), line)) { + std::vector split_line; + // Split the line by space or tab and check the number of fields in each + // line. There must be 2 fields--the pdf-id and the target number of + // mixture components for that pdf. 
+ SplitStringToVector(line, " \t\r", true, &split_line); + if (split_line.size() != 2) { + KALDI_ERR << "Invalid line in file: " << line; + } + + int32 pdf_id, num_mix; + if (!ConvertStringToInteger(split_line[0], &pdf_id)) { + KALDI_ERR << "Invalid line in file [bad pdf_id]: " << line; + } + if (!ConvertStringToInteger(split_line[1], &num_mix)) { + KALDI_ERR << "Invalid line in file [bad num_mix]: " << line; + } + target_components_per_pdf[pdf_id] = num_mix; + } + } + + std::vector components_incr_per_pdf(am_gmm.NumPdfs(), 0); + for (int32 i = 0 ; i < am_gmm.NumPdfs(); i++) { + components_incr_per_pdf[i] = std::ceil((target_components_per_pdf[i] - components_per_pdf[i]) / (num_iters / 2)); + } + + RandomAccessSegmentationReader segmentation_reader(segmentation_rspecifier); + + for (int32 n = 0; n < num_iters; n++) { + SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); + + AccumAmDiagGmm gmm_accs; + gmm_accs.Init(am_gmm, update_flags); + + double tot_like = 0.0; + kaldi::int64 tot_t = 0; + + int32 num_done = 0, num_err = 0; + for (; !feature_reader.Done(); feature_reader.Next()) { + std::string key = feature_reader.Key(); + if (!segmentation_reader.HasKey(key)) { + KALDI_WARN << "No segmentation for utterance " << key; + num_err++; + continue; + } + const Matrix &mat = feature_reader.Value(); + const Segmentation &segmentation = segmentation_reader.Value(key); + + BaseFloat tot_like_this_file = 0.0; + BaseFloat tot_t_this_file = 0.0; + + for (SegmentList::const_iterator it = segmentation.Begin(); + it != segmentation.End(); ++it) { + int32 pdf_id; + if (class2pdf_rxfilename != "") + try { + pdf_id = class2pdf.at(it->Label()); + } catch (const std::out_of_range& oor) { + KALDI_VLOG(2) << "Out of Range error: " << oor.what() << '\n'; + continue; + } + else + pdf_id = it->Label(); + if ( (pdfs_str != "" && std::binary_search(pdfs.begin(), pdfs.end(), pdf_id)) + || (pdfs_str == "" && pdf_id < am_gmm.NumPdfs() && pdf_id >=0) ) { + KALDI_ASSERT(pdf_id >= 
0 && pdf_id < am_gmm.NumPdfs()); + for (int32 i = it->start_frame; i <= it->end_frame; i++) + tot_like_this_file += gmm_accs.AccumulateForGmm(am_gmm, mat.Row(i), + pdf_id, 1.0); + tot_t_this_file = it->end_frame - it->start_frame + 1; + } + } + tot_like += tot_like_this_file; + tot_t += tot_t_this_file; + + num_done++; + } + + KALDI_LOG << "In iteration " << n << ", done " << num_done << " files, " << num_err + << " with errors."; + + KALDI_LOG << "In iteration " << n << ", overall avg like per frame (Gaussian only) = " + << (tot_like/tot_t) << " over " << tot_t << " frames."; + + KALDI_ASSERT(tot_t > 0); + + BaseFloat objf_impr, count; + MleAmDiagGmmUpdateSubsetPdfs(gmm_opts, gmm_accs, pdfs.size() > 0 ? &pdfs : NULL, update_flags, + &am_gmm, &objf_impr, &count); + + KALDI_LOG << "GMM update: In iteration " << n << ", overall " + << (objf_impr/count) + << " objective function improvement per frame over " + << count << " frames"; + + KALDI_ASSERT(count > 0); + + if (mixup != 0 || mixdown != 0 || + (n == num_iters - 1 && !occs_out_filename.empty()) ) { + // get pdf occupation counts + Vector pdf_occs; + pdf_occs.Resize(gmm_accs.NumAccs()); + for (int i = 0; i < gmm_accs.NumAccs(); i++) + pdf_occs(i) = gmm_accs.GetAcc(i).occupancy().Sum(); + + if (mixdown != 0) + am_gmm.MergeByCount(pdf_occs, mixdown, power, min_count); + + if (mixup != 0) + am_gmm.SplitByCount(pdf_occs, mixup, perturb_factor, + power, min_count); + + if (n == num_iters - 1 && !occs_out_filename.empty()) { + bool binary = false; + WriteKaldiObject(pdf_occs, occs_out_filename, binary); + } + } + + if (mixup_per_pdf_str != "" || mixup_rxfilename != "") { + if (pdfs_str != "") { + for (std::vector::const_iterator it = pdfs.begin(); + it != pdfs.end(); ++it) { + components_per_pdf[*it] += components_incr_per_pdf[*it]; + if (target_components_per_pdf[*it] > 0 && + components_per_pdf[*it] > target_components_per_pdf[*it]) + components_per_pdf[*it] = target_components_per_pdf[*it]; + 
 am_gmm.GetPdf(*it).Split(components_per_pdf[*it], perturb_factor); + } + } else { + for (int32 i = 0; i < am_gmm.NumPdfs(); i++) { + components_per_pdf[i] += components_incr_per_pdf[i]; + if (target_components_per_pdf[i] > 0 && + components_per_pdf[i] > target_components_per_pdf[i]) + components_per_pdf[i] = target_components_per_pdf[i]; + am_gmm.GetPdf(i).Split(components_per_pdf[i], perturb_factor); + } + } + } + if (num_done == 0) return 1; + } + + { + Output ko(model_out_filename, binary); + trans_model.Write(ko.Stream(), binary); + am_gmm.Write(ko.Stream(), binary); + } + KALDI_LOG << "Written model to " << model_out_filename; + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/segmenterbin/gmm-global-init-from-segmentation.cc b/src/segmenterbin/gmm-global-init-from-segmentation.cc new file mode 100644 index 00000000000..cbcfb1ba008 --- /dev/null +++ b/src/segmenterbin/gmm-global-init-from-segmentation.cc @@ -0,0 +1,196 @@ +// segmenterbin/gmm-global-init-from-segmentation.cc + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "gmm/am-diag-gmm.h" +#include "gmm/mle-am-diag-gmm.h" +#include "hmm/transition-model.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + using namespace kaldi; + using namespace segmenter; + + typedef kaldi::int32 int32; + try { + const char *usage = + "Initialize an AmGmm or some pdfs of it from segmentation\n" + "Usage: gmm-init-from-segmentation [options] " + " \n" + "e.g.:\n gmm-init-from-segmentation --pdfs=0:2 1.mdl scp:train.scp ark:1.seg 2.mdl\n"; + + ParseOptions po(usage); + MleDiagGmmOptions gmm_opts; + + bool binary = true; + int32 num_gauss = 100; + int32 num_gauss_init = 0; + int32 num_iters = 50; + int32 num_frames = 200000; + int32 srand_seed = 0; + int32 num_threads = 4; + int32 label = -1; + + po.Register("binary", &binary, "Write output in binary mode"); + po.Register("num-gauss", &num_gauss, "Number of Gaussians in the model"); + po.Register("num-gauss-init", &num_gauss_init, "Number of Gaussians in " + "the model initially (if nonzero and less than num_gauss, " + "we'll do mixture splitting)"); + po.Register("num-iters", &num_iters, "Number of iterations of training"); + po.Register("num-frames", &num_frames, "Number of feature vectors to store in " + "memory and train on (randomly chosen from the input features)"); + po.Register("srand", &srand_seed, "Seed for random number generator "); + po.Register("num-threads", &num_threads, "Number of threads used for " + "statistics accumulation"); + + gmm_opts.Register(&po); + + po.Read(argc, argv); + + srand(srand_seed); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string model_filename = po.GetArg(1), + feature_rspecifier = po.GetArg(2), + segmentation_rspecifier = po.GetArg(3), + model_wxfilename = po.GetArg(4); + + // Read class2pdf map + unordered_map class2pdf; + if (class2pdf_rxfilename != "") { + Input ki; + if (!ki.OpenTextMode(class2pdf_rxfilename)) + KALDI_ERR << 
"Unable to open file " << class2pdf_rxfilename + << " for reading in text mode"; + std::istream &is = ki.Stream(); + std::string line; + while (std::getline(is, line)) { + std::vector v; + if (!SplitStringToIntegers(line, " \t\r", true, &v) || v.size() != 2) { + KALDI_ERR << "Unable to parse line " << line << " in " + << class2pdf_rxfilename; + } + class2pdf.insert(std::make_pair(v[0], v[1])); + } + + if (!is.eof()) { + KALDI_ERR << "Did not reach EOF. Could not read file " << class2pdf_rxfilename + << " successfully"; + } + } + + // Seed AmDiagGmm + AmDiagGmm am_gmm; + { + bool binary; + Input ki(model_filename, &binary); + TransitionModel trans_model; + trans_model.Read(ki.Stream(), binary); + am_gmm.Read(ki.Stream(), binary); + } + MleAccumGmm + + double tot_like = 0.0; + kaldi::int64 tot_t = 0; + + SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); + RandomAccessSegmentationReader segmentation_reader(segmentation_rspecifier); + + int32 num_done = 0, num_err = 0; + for (; !feature_reader.Done(); feature_reader.Next()) { + std::string key = feature_reader.Key(); + if (!segmentation_reader.HasKey(key)) { + KALDI_WARN << "No segmentation for utterance " << key; + num_err++; + continue; + } + const Matrix &mat = feature_reader.Value(); + const Segmentation &segmentation = segmentation_reader.Value(key); + + BaseFloat tot_like_this_file = 0.0; + BaseFloat tot_t_this_file = 0.0; + + for (std::forward_list::const_iterator it = segmentation.Begin(); + it != segmentation.End(); ++it) { + int32 pdf_id; + if (class2pdf_rxfilename != "") + pdf_id = it->Label(); + else + pdf_id = class2pdf.at(it->Label()); + + if ( (pdfs_str != "" && std::binary_search(pdfs.begin(), pdfs.end(), pdf_id)) + || (pdfs_str == "") ) { + // Pdf needs to be initialized + KALDI_ASSERT(pdf_id < NumPdfs() && pdf_id >= 0); + if (gauss_clusterable[pdf_id] == NULL) { + gauss_clusterable[pdf_id] = new GaussClusterable(mat.NumCols(), gmm_opts.min_variance); + } + for (int32 i = 
 it->start_frame; i <= it->end_frame; i++) + gauss_clusterable[pdf_id]->AddStats(mat.Row(i), 1.0); + } + } + num_done++; + } + + if (pdfs_str != "") { + for (std::vector::const_iterator it = pdfs.begin(); + it != pdfs.end(); ++it) { + if (init_am_gmm) { + DiagGmm gmm(*gauss_clusterable[pdf_id], var_floor); + // Initialize am_gmm from scratch + am_gmm.Init( + + } + + if (*it < NumPdfs()) { + DiagGmm &gmm = am_gmm.GetPdf(*it); + } else { + DiagGmm gmm; + am_gmm.AddPdf( + (*it) + } + } + + KALDI_LOG << "Done " << num_done << " files, " << num_err + << " with errors."; + + KALDI_LOG << "Overall avg like per frame (Gaussian only) = " + << (tot_like/tot_t) << " over " << tot_t << " frames."; + + { + Output ko(accs_wxfilename, binary); + gmm_accs.Write(ko.Stream(), binary); + } + KALDI_LOG << "Written accs."; + if (num_done != 0) + return 0; + else + return 1; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/segmenterbin/gmm-update-segmentation.cc b/src/segmenterbin/gmm-update-segmentation.cc new file mode 100644 index 00000000000..93df529aaae --- /dev/null +++ b/src/segmenterbin/gmm-update-segmentation.cc @@ -0,0 +1,302 @@ +// segmenterbin/gmm-update-segmentation.cc + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "gmm/am-diag-gmm.h" +#include "gmm/mle-am-diag-gmm.h" +#include "hmm/transition-model.h" +#include "segmenter/segmenter.h" + +namespace kaldi { +namespace segmenter { + +void TrainGmm(const std::list > &feats_list, + const MleDiagGmmOptions &gmm_opts, + int32 target_num_components, + int32 max_iters, GmmFlagsType update_flags, + BaseFloat perturb_factor, + int32 pdf_id, AmDiagGmm *am_gmm, + double *tot_like, int64 *tot_t, + double *tot_objf_impr, int64 *tot_count) { + + DiagGmm *gmm = &am_gmm->GetPdf(pdf_id); + + int32 num_components = gmm->NumGauss(); + int32 num_iters = target_num_components - num_components + 1; + int32 components_incr = 1; + + if (num_iters > max_iters) { + components_incr = std::ceil( static_cast( (target_num_components - num_components) / (max_iters - 1)) ); + num_iters = max_iters; + } + + for (int32 iter = 0; iter < num_iters; iter++) { + AccumDiagGmm gmm_accs; + gmm_accs.Resize(*gmm, update_flags); + + double like = 0.0; + int64 time = 0; + + std::list >::const_iterator it = feats_list.begin(); + for (; it != feats_list.end(); ++it, time++) { + like += gmm_accs.AccumulateFromDiag(*gmm, *it, 1.0); + } + + KALDI_LOG << "For pdf " << pdf_id << " with " << num_components + << " gaussians, in iteration " + << iter << ", log-likelihood is " << (like / time) + << " over " << time << " frames"; + + BaseFloat objf_impr, count; + MleDiagGmmUpdate(gmm_opts, gmm_accs, update_flags, gmm, + &objf_impr, &count); + + KALDI_LOG << "For pdf " << pdf_id << " with " << num_components + << " gaussians, in iteration " + << iter << ", objf improvement is " << (objf_impr / count) + << " over " << count << " frames"; + + if (num_components < target_num_components) { + gmm->Split(num_components + components_incr, perturb_factor); + num_components = gmm->NumGauss(); + } + (*tot_objf_impr) += objf_impr; + (*tot_count) += count; + if (iter == num_iters - 1) { + (*tot_t) += time; + (*tot_like) 
+= like; + } + } +} + +} +} + +int main(int argc, char *argv[]) { + using namespace kaldi; + using namespace segmenter; + + typedef kaldi::int32 int32; + try { + const char *usage = + "Accumulate pdf stats for GMM training from segmentation " + "and update GMM\n" + "Usage: gmm-update-segmentation [options] " + " \n" + "e.g.:\n gmm-update-segmentation 1.mdl scp:train.scp ark:1.seg 2.mdl\n"; + + ParseOptions po(usage); + bool binary = true; + std::string class2pdf_rxfilename, pdfs_str; + MleDiagGmmOptions gmm_opts; + int32 mixup = 0; + std::string mixup_per_pdf_str, mixup_rxfilename; + int32 mixdown = 0; + BaseFloat perturb_factor = 0.01; + BaseFloat power = 0.2; + BaseFloat min_count = 20.0; + std::string update_flags_str = "mvw"; + std::string occs_out_filename; + int32 num_iters = 3; + + po.Register("binary", &binary, "Write output in binary mode"); + po.Register("class2pdf", &class2pdf_rxfilename, + "Map from class label to pdf id"); + po.Register("pdfs", &pdfs_str, + "Only accumulate stats for these pdfs"); + po.Register("mix-up", &mixup, "Increase number of mixture components to " + "this overall target."); + po.Register("mix-up-per-pdf", &mixup_per_pdf_str, + "Mix-up per pdf specified as comma separated list"); + po.Register("mix-up-rxfilename", &mixup_rxfilename, + "Mix-up per pdf specified in a table"); + po.Register("min-count", &min_count, + "Minimum per-Gaussian count enforced while mixing up and down."); + po.Register("mix-down", &mixdown, "If nonzero, merge mixture components to this " + "target."); + po.Register("power", &power, "If mixing up, power to allocate Gaussians to" + " states."); + po.Register("update-flags", &update_flags_str, "Which GMM parameters to " + "update: subset of mvwt."); + po.Register("perturb-factor", &perturb_factor, "While mixing up, perturb " + "means by standard deviation times this factor."); + po.Register("write-occs", &occs_out_filename, "File to write pdf " + "occupation counts to."); + po.Register("num-iters", 
&num_iters, "Number of iterations of ML estimation"); + + gmm_opts.Register(&po); + po.Read(argc, argv); + + if (po.NumArgs() != 4) { + po.PrintUsage(); + exit(1); + } + + kaldi::GmmFlagsType update_flags = + StringToGmmFlags(update_flags_str); + + std::string model_in_filename = po.GetArg(1), + feature_rspecifier = po.GetArg(2), + segmentation_rspecifier = po.GetArg(3), + model_out_filename = po.GetArg(4); + + unordered_map class2pdf; + if (class2pdf_rxfilename != "") { + Input ki; + if (!ki.OpenTextMode(class2pdf_rxfilename)) + KALDI_ERR << "Unable to open file " << class2pdf_rxfilename + << " for reading in text mode"; + std::istream &is = ki.Stream(); + std::string line; + while (std::getline(is, line)) { + std::vector v; + if (!SplitStringToIntegers(line, " \t\r", true, &v) || v.size() != 2) { + KALDI_ERR << "Unable to parse line " << line << " in " + << class2pdf_rxfilename; + } + class2pdf.insert(std::make_pair(v[0], v[1])); + } + + if (!is.eof()) { + KALDI_ERR << "Did not reach EOF. 
 Could not read file " << class2pdf_rxfilename + << " successfully"; + } + } + + std::vector pdfs; + if (pdfs_str != "") { + if (!SplitStringToIntegers(pdfs_str, ":", true, &pdfs)) { + KALDI_ERR << "Unable to parse string " << pdfs_str; + } + } + + AmDiagGmm am_gmm; + TransitionModel trans_model; + { + bool binary; + Input ki(model_in_filename, &binary); + trans_model.Read(ki.Stream(), binary); + am_gmm.Read(ki.Stream(), binary); + } + + std::vector components_per_pdf(am_gmm.NumPdfs()); + for (int32 i = 0; i < am_gmm.NumPdfs(); i++) { + components_per_pdf[i] = am_gmm.GetPdf(i).NumGauss(); + } + + std::vector > > feats_per_pdf(am_gmm.NumPdfs()); + + std::vector target_components_per_pdf(am_gmm.NumPdfs(), -1); + if (mixup_per_pdf_str != "") { + std::vector mixup_per_pdf; + if (!SplitStringToIntegers(mixup_per_pdf_str, ":", true, &mixup_per_pdf) + && mixup_per_pdf.size() != am_gmm.NumPdfs()) { + KALDI_ERR << "Unable to parse string " << mixup_per_pdf_str + << " or it has wrong size (!= " << am_gmm.NumPdfs() << ")"; + } + for (int32 i = 0; i < am_gmm.NumPdfs(); i++) { + target_components_per_pdf[i] = mixup_per_pdf[i]; + } + } else if (mixup_rxfilename != "") { + Input ki(mixup_rxfilename); + std::string line; + while (std::getline(ki.Stream(), line)) { + std::vector split_line; + // Split the line by space or tab and check the number of fields in each + // line. There must be 2 fields--the pdf_id and its target number of + // mixture components. 
+ SplitStringToVector(line, " \t\r", true, &split_line); + if (split_line.size() != 2) { + KALDI_ERR << "Invalid line in file: " << line; + } + + int32 pdf_id, num_mix; + if (!ConvertStringToInteger(split_line[0], &pdf_id)) { + KALDI_ERR << "Invalid line in file [bad pdf_id]: " << line; + } + if (!ConvertStringToInteger(split_line[1], &num_mix)) { + KALDI_ERR << "Invalid line in file [bad num_mix]: " << line; + } + target_components_per_pdf[pdf_id] = num_mix; + } + } + + SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); + RandomAccessSegmentationReader segmentation_reader(segmentation_rspecifier); + + int32 num_done = 0, num_err = 0; + for (; !feature_reader.Done(); feature_reader.Next()) { + std::string key = feature_reader.Key(); + if (!segmentation_reader.HasKey(key)) { + KALDI_WARN << "No segmentation for utterance " << key; + num_err++; + continue; + } + const Matrix &mat = feature_reader.Value(); + const Segmentation &segmentation = segmentation_reader.Value(key); + + for (SegmentList::const_iterator it = segmentation.Begin(); + it != segmentation.End(); ++it) { + int32 pdf_id; + if (class2pdf_rxfilename != "") + try { + pdf_id = class2pdf.at(it->Label()); + } catch (const std::out_of_range& oor) { + KALDI_VLOG(2) << "Out of Range error: " << oor.what() << '\n'; + continue; + } + else + pdf_id = it->Label(); + if ( (pdfs_str != "" && std::binary_search(pdfs.begin(), pdfs.end(), pdf_id)) + || (pdfs_str == "" && pdf_id < am_gmm.NumPdfs() && pdf_id >=0) ) { + KALDI_ASSERT(pdf_id >= 0 && pdf_id < am_gmm.NumPdfs()); + for (int32 i = it->start_frame; i <= it->end_frame; i++) + feats_per_pdf[pdf_id].emplace_back(mat.Row(i)); + } + } + num_done++; + } + + std::vector components_incr_per_pdf(am_gmm.NumPdfs(), 0); + + double tot_like = 0.0, tot_objf_impr = 0.0; + int64 tot_t = 0, tot_count = 0; + + for (int32 i = 0 ; i < am_gmm.NumPdfs(); i++) { + TrainGmm(feats_per_pdf[i], gmm_opts, target_components_per_pdf[i], + num_iters, update_flags, 
perturb_factor, + i, &am_gmm, &tot_like, &tot_t, &tot_objf_impr, &tot_count); + } + + { + Output ko(model_out_filename, binary); + trans_model.Write(ko.Stream(), binary); + am_gmm.Write(ko.Stream(), binary); + } + KALDI_LOG << "Written model to " << model_out_filename; + return (num_done > 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + + diff --git a/src/segmenterbin/segmentation-combine-segments.cc b/src/segmenterbin/segmentation-combine-segments.cc new file mode 100644 index 00000000000..64781aff811 --- /dev/null +++ b/src/segmenterbin/segmentation-combine-segments.cc @@ -0,0 +1,110 @@ +// segmenterbin/segmentation-combine-segments.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Combine utterance-level segmentations in an archive to file-level " + "segmentations using the kaldi segments to map utterances to " + "file.\n" + "\n" + "Usage: segmentation-combine-segments [options] \n" + " e.g.: segmentation-combine-segments ark:utt.seg ark,t:data/dev/segments ark,t:data/dev/reco2utt ark:file.seg\n"; + + bool binary = true; + BaseFloat frame_shift = 0.01; + ParseOptions po(usage); + + po.Register("binary", &binary, "Write in binary mode (only relevant if output is a wxfilename)"); + po.Register("frame-shift", &frame_shift, "Frame shift in seconds"); + + po.Read(argc, argv); + + if (po.NumArgs() != 4) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_rspecifier = po.GetArg(1), + segments_rspecifier = po.GetArg(2), + reco2utt_rspecifier = po.GetArg(3), + segmentation_wspecifier = po.GetArg(4); + + SequentialTokenVectorReader reco2utt_reader(reco2utt_rspecifier); + RandomAccessUtteranceSegmentReader segments_reader(segments_rspecifier); + RandomAccessSegmentationReader segmentation_reader(segmentation_rspecifier); + SegmentationWriter segmentation_writer(segmentation_wspecifier); + + int32 num_done = 0, num_segmentations = 0; + int64 num_segments = 0; + int64 num_err = 0; + + std::vector frame_counts_per_class; + + for (; !reco2utt_reader.Done(); reco2utt_reader.Next()) { + const std::vector &utts = reco2utt_reader.Value(); + const std::string &reco_id = reco2utt_reader.Key(); + + Segmentation seg; + + for (std::vector::const_iterator it = utts.begin(); + it != utts.end(); ++it) { + if (!segments_reader.HasKey(*it)) { + KALDI_WARN << "Could not find utterance " << *it << " in " + << "segments " << segments_rspecifier; + num_err++; + continue; + } + + if (!segmentation_reader.HasKey(*it)) { + KALDI_WARN << 
"Could not find utterance " << *it << " in " + << "segmentation " << segmentation_rspecifier; + num_err++; + continue; + } + const UtteranceSegment &segment = segments_reader.Value(*it); + const Segmentation &utt_seg = segmentation_reader.Value(*it); + + num_segments += seg.InsertFromSegmentation(utt_seg, + segment.start_time / frame_shift, + &frame_counts_per_class); + num_done++; + } + seg.Sort(); + segmentation_writer.Write(reco_id, seg); + num_segmentations++; + } + + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + + + + diff --git a/src/segmenterbin/segmentation-compute-class-ctm-conf.cc b/src/segmenterbin/segmentation-compute-class-ctm-conf.cc new file mode 100644 index 00000000000..3c6071edb95 --- /dev/null +++ b/src/segmenterbin/segmentation-compute-class-ctm-conf.cc @@ -0,0 +1,232 @@ +// segmenterbin/segmentation-compute-class-ctm-conf.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" +#include "hmm/posterior.h" + +namespace kaldi { + +struct StringPairHasher { + size_t operator() (const std::pair &str_pair) const { + return StringHasher()(str_pair.first) + kPrime * StringHasher()(str_pair.second); + } + + private: + static const int kPrime = 7853; +}; + +} + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Computes per-class confidences for different segmentation " + "(usually diarization) classes using word-confidences from a CTM file" + "\n" + "Usage: segmentation-compute-class-ctm-conf [options] " + " e.g.: segmentation-compute-class-ctm-conf ark:exp/nnet2_multicondition/diarization_dev/diarziation_results.txt ctm reco2file_and_channel ark:exp/nnet2_multicondition/diarization_dev/post.ark\n"; + + bool binary = true; + BaseFloat frame_shift = 0.01; + + ParseOptions po(usage); + + po.Register("binary", &binary, "Write in binary mode (only relevant if output is a wxfilename)"); + po.Register("frame-shift", &frame_shift, "Frame shift"); + + po.Read(argc, argv); + + if (po.NumArgs() != 4) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_rspecifier = po.GetArg(1), + ctm_rxfilename = po.GetArg(2), + reco2file_and_channel_rxfilename = po.GetArg(3), + post_wspecifier = po.GetArg(4); + + RandomAccessSegmentationReader seg_reader(segmentation_rspecifier); + Input ki(ctm_rxfilename); + PosteriorWriter conf_writer(post_wspecifier); + + long num_lines = 0, num_success = 0; + int32 num_recos = 0; + + std::unordered_map,std::string, StringPairHasher> file_and_channel2reco_map; + { + Input ki(reco2file_and_channel_rxfilename); + std::string line; + while (std::getline(ki.Stream(), line)) { + std::vector split_line; + // Split the line by space or tab and check the number of fields in each + // line. 
There must be 3 fields--recording name, file name, channel + SplitStringToVector(line, " \t\r", true, &split_line); + if (split_line.size() != 3) { + KALDI_ERR << "Invalid line in reco2file_and_channel file: " << line; + } + std::string reco_id = split_line[0], + file_id = split_line[1], + channel = split_line[2]; + + file_and_channel2reco_map[std::make_pair(file_id, channel)] = reco_id; + } + } + + std::string line, recording, prev_recording; + segmenter::Segmentation seg; + segmenter::SegmentList::const_iterator seg_it; + + std::unordered_map >*, StringHasher> confidences; + std::set reco_list; + + while (std::getline(ki.Stream(), line)) { + num_lines++; + std::vector split_line; + // Split the line by space or tab and check the number of fields in each + // line. There must be 6 fields--file name, channel, + // start time, end time, word, confidence + SplitStringToVector(line, " \t\r", true, &split_line); + if (split_line.size() != 6 && split_line.size() != 5) { + KALDI_ERR << "Invalid line in ctm file: " << line; + } + + if (split_line.size() == 5) { + KALDI_VLOG(1) << "Line '" << line << "' does not have confidence. Skipping,"; + continue; + } + + std::string start_str = split_line[2], + duration_str = split_line[3]; + + // Convert the start time and endtime to real from string. Segment is + // ignored if start or end time cannot be converted to real. 
+ double start, end; + if (!ConvertStringToReal(start_str, &start)) { + KALDI_WARN << "Invalid line in ctm file [bad start]: " << line; + continue; + } + if (!ConvertStringToReal(duration_str, &end)) { + KALDI_WARN << "Invalid line in ctm file [bad end]: " << line; + continue; + } + + end += start; + + // start time must not be negative; start time must not be greater than + // end time, except if end time is -1 + if (start < 0 || (end != -1.0 && end <= 0) || ((start >= end) && (end > 0))) { + KALDI_WARN << "Invalid line in ctm file [empty or invalid segment]: " + << line; + continue; + } + + double conf; + if (!ConvertStringToReal(split_line[5], &conf)) { + KALDI_ERR << "Invalid line in ctm file [bad conf]: " << line; + } + + std::string reco_id; + try { + reco_id = file_and_channel2reco_map.at(std::make_pair(split_line[0], split_line[1])); + } catch (std::out_of_range &oor) { + KALDI_ERR << "Out of range error: " << oor.what(); + } + + if (prev_recording == "" || prev_recording != reco_id) { + if (!seg_reader.HasKey(reco_id)) { + KALDI_ERR << "Could not find segmentation for recording " << reco_id; + } + seg = seg_reader.Value(reco_id); + seg_it = seg.Begin(); + + std::map > *conf_acc; + + if (confidences.count(reco_id) > 0) { + conf_acc = confidences[reco_id]; + } else { + conf_acc = new std::map >; + confidences.insert(std::make_pair(reco_id, conf_acc)); + reco_list.insert(reco_id); + num_recos++; + } + + while (seg_it != seg.End() && seg_it->end_frame * frame_shift < start) ++seg_it; + while (seg_it != seg.Begin() && seg_it->start_frame * frame_shift > start) --seg_it; + + BaseFloat this_word_occ = 0; + std::vector occs; + while (seg_it != seg.End() && seg_it->start_frame * frame_shift <= end) { + double fraction = (std::min(static_cast(seg_it->end_frame * frame_shift), end) - std::max(static_cast(seg_it->start_frame * frame_shift), start) + frame_shift) / (end-start+frame_shift); + std::map >::iterator it = conf_acc->find(seg_it->Label()); + if (it == 
conf_acc->end()) { + (*conf_acc)[seg_it->Label()] = std::make_pair(conf * fraction, fraction); + } else { + (it->second).first += conf * fraction; + (it->second).second += fraction; + } + this_word_occ += fraction; + occs.push_back(seg_it->start_frame * frame_shift); + occs.push_back(seg_it->end_frame * frame_shift); + occs.push_back(fraction); + ++seg_it; + } + if (!kaldi::ApproxEqual(this_word_occ, 1.0)) { + KALDI_WARN << "This word from " << start << " - " << end + << "; computed occupancy is " << this_word_occ << "; (seg_start,seg_end,frac) = " << SubVector(&occs[0], occs.size());; + } + } + + prev_recording = recording; + num_success++; + } + + long double occ = 0; + for (std::set::const_iterator it = reco_list.begin(); + it != reco_list.end(); ++it) { + const std::map > *conf_acc = confidences[*it]; + Posterior post(2); + for (std::map >::const_iterator c_it = conf_acc->begin(); + c_it != conf_acc->end(); ++c_it) { + post[0].push_back(std::make_pair(c_it->first, (c_it->second).first / (c_it->second).second)); + post[1].push_back(std::make_pair(c_it->first, (c_it->second).second)); + occ += (c_it->second).second; + } + conf_writer.Write(*it, post); + } + + KALDI_LOG << "Successfully processed " << num_success << " lines out of " + << num_lines << " in the ctm file; wrote " + << num_recos << " recordings; total word occupancy is " << occ; + + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + + + + diff --git a/src/segmenterbin/segmentation-copy.cc b/src/segmenterbin/segmentation-copy.cc new file mode 100644 index 00000000000..b1654d364c0 --- /dev/null +++ b/src/segmenterbin/segmentation-copy.cc @@ -0,0 +1,93 @@ +// segmenterbin/segmentation-copy.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Copy segmentation or archives of segmentation\n" + "\n" + "Usage: segmentation-copy [options] (segmentation-in-rspecifier|segmentation-in-rxfilename) (segmentation-out-wspecifier|segmentation-out-wxfilename)\n" + " e.g.: segmentation-copy --binary=false foo -\n" + " segmentation-copy ark:1.ali ark,t:-\n"; + + bool binary = true; + ParseOptions po(usage); + + po.Register("binary", &binary, + "Write in binary mode (only relevant if output is a wxfilename)"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_in_fn = po.GetArg(1), + segmentation_out_fn = po.GetArg(2); + + // all these "fn"'s are either rspecifiers or filenames. 
+ + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + Segmentation seg; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + seg.Read(ki.Stream(), binary_in); + } + Output ko(segmentation_out_fn, binary); + seg.Write(ko.Stream(), binary); + KALDI_LOG << "Copied segmentation to " << segmentation_out_fn; + return 0; + } else { + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + for (; !reader.Done(); reader.Next(), num_done++) { + writer.Write(reader.Key(), reader.Value()); + } + + KALDI_LOG << "Copied " << num_done << " segmentation; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-create-subsegments.cc b/src/segmenterbin/segmentation-create-subsegments.cc new file mode 100644 index 00000000000..618bcf39b62 --- /dev/null +++ b/src/segmenterbin/segmentation-create-subsegments.cc @@ -0,0 +1,140 @@ +// segmenterbin/segmentation-create-subsegments.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Create sub-segmentation of a segmentation by intersecting with " + "segments from another segmentation. The regions where the other " + "segmentation has class_id 'filter-label' is labeled " + "'subsegment-label'\n" + "\n" + "Usage: segmentation-create-subsegments --subsegment-label=1000 --filter-label=10 [options] (segmentation-in-rspecifier|segmentation-in-rxfilename) (filter-segmentation-in-rspecifier|filter-segmentation-out-rxfilename) (segmentation-out-wspecifier|segmentation-out-wxfilename)\n" + " e.g.: segmentation-copy --binary=false foo -\n" + " segmentation-copy ark:1.ali ark,t:-\n"; + + bool binary = true, ignore_missing = true; + int32 filter_label = -1, subsegment_label = -1; + ParseOptions po(usage); + + po.Register("binary", &binary, + "Write in binary mode (only relevant if output is a wxfilename)"); + po.Register("filter-label", &filter_label, "The class_id on which the " + "filtering is done."); + po.Register("subsegment-label", &subsegment_label, + "If non-negative, change the class_id of " + "the intersection of the two segmentations to this label."); + po.Register("ignore-missing", &ignore_missing, "Ignore missing " + "segmentations in filter. 
If this is set true, then the " + "segmentations with missing key in filter are written " + "without any modification."); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_in_fn = po.GetArg(1), + secondary_segmentation_in_fn = po.GetArg(2), + segmentation_out_fn = po.GetArg(3); + + // all these "fn"'s are either rspecifiers or filenames. + + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + filter_is_rspecifier = + (ClassifyRspecifier(secondary_segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier || in_is_rspecifier != filter_is_rspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + Segmentation seg; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + seg.Read(ki.Stream(), binary_in); + } + Segmentation secondary_seg; + { + bool binary_in; + Input ki(secondary_segmentation_in_fn, &binary_in); + secondary_seg.Read(ki.Stream(), binary_in); + } + + Segmentation new_seg; + seg.SubSegmentUsingSmallOverlapSegments(secondary_seg, filter_label, + subsegment_label, &new_seg); + Output ko(segmentation_out_fn, binary); + new_seg.Write(ko.Stream(), binary); + KALDI_LOG << "Copied segmentation to " << segmentation_out_fn; + return 0; + } else { + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + RandomAccessSegmentationReader filter_reader(secondary_segmentation_in_fn); + for (; !reader.Done(); reader.Next(), num_done++) { + const Segmentation &seg = reader.Value(); + const std::string &key = reader.Key(); + + if (!filter_reader.HasKey(key)) { + KALDI_WARN << "Could not find filter for utterance " << key; + if (!ignore_missing) { + num_err++; + } else + 
writer.Write(key, seg); + continue; + } + const Segmentation &secondary_segmentation = filter_reader.Value(key); + + Segmentation new_seg; + seg.SubSegmentUsingSmallOverlapSegments(secondary_segmentation, filter_label, + subsegment_label, &new_seg); + + writer.Write(key, new_seg); + } + + KALDI_LOG << "Created subsegments for " << num_done << " segmentations; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + + diff --git a/src/segmenterbin/segmentation-filter-ctm.cc b/src/segmenterbin/segmentation-filter-ctm.cc new file mode 100644 index 00000000000..ca99e41af2b --- /dev/null +++ b/src/segmenterbin/segmentation-filter-ctm.cc @@ -0,0 +1,214 @@ +// segmenterbin/segmentation-filter-ctm.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" +#include "hmm/posterior.h" + +#include + +namespace kaldi { + +struct StringPairHasher { + size_t operator() (const std::pair &str_pair) const { + return StringHasher()(str_pair.first) + kPrime * StringHasher()(str_pair.second); + } + + private: + static const int kPrime = 7853; +}; + +} + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Filter CTM based on the segmentation" + "\n" + "Usage: segmentation-filter-ctm [options] " + " e.g.: segmentation-filter-ctm ark:exp/nnet2_multicondition/diarization_dev/diarziation_results.txt ctm reco2file_and_channel ctm_out\n"; + + bool binary = true; + BaseFloat frame_shift = 0.01; + BaseFloat keep_word_threshold = 0.5; + + ParseOptions po(usage); + + po.Register("binary", &binary, "Write in binary mode (only relevant if output is a wxfilename)"); + po.Register("frame-shift", &frame_shift, "Frame shift"); + po.Register("keep-word-threshold", &keep_word_threshold, + "Keep word in CTM if at least this fraction of word is inside segment"); + + po.Read(argc, argv); + + if (po.NumArgs() != 4) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_rspecifier = po.GetArg(1), + ctm_rxfilename = po.GetArg(2), + reco2file_and_channel_rxfilename = po.GetArg(3), + ctm_wxfilename = po.GetArg(4); + + RandomAccessSegmentationReader seg_reader(segmentation_rspecifier); + Input ki(ctm_rxfilename); + Output ko(ctm_wxfilename, false); + + long num_lines = 0, num_success = 0; + int32 num_recos = 0; + + std::unordered_map,std::string, StringPairHasher> file_and_channel2reco_map; + { + Input ki(reco2file_and_channel_rxfilename); + std::string line; + while (std::getline(ki.Stream(), line)) { + std::vector split_line; + // Split the line by space or tab and check the number of fields in each + // line. 
There must be 3 fields--recording name, file name, channel + SplitStringToVector(line, " \t\r", true, &split_line); + if (split_line.size() != 3) { + KALDI_ERR << "Invalid line in reco2file_and_channel file: " << line; + } + std::string reco_id = split_line[0], + file_id = split_line[1], + channel = split_line[2]; + + file_and_channel2reco_map[std::make_pair(file_id, channel)] = reco_id; + } + } + + std::string line, recording, prev_recording; + segmenter::Segmentation seg; + segmenter::SegmentList::const_iterator seg_it; + + std::unordered_map >*, StringHasher> confidences; + std::set reco_list; + + std::regex special_regex(".*(||)>*"); + + while (std::getline(ki.Stream(), line)) { + num_lines++; + std::vector split_line; + // Split the line by space or tab and check the number of fields in each + // line. There must be 6 fields--file name, channel, + // start time, end time, word, confidence + SplitStringToVector(line, " \t\r", true, &split_line); + if (split_line.size() != 6 && split_line.size() != 5) { + KALDI_ERR << "Invalid line in ctm file: " << line; + } + + if (split_line.size() == 5 || std::regex_match(line, special_regex)) { + KALDI_VLOG(2) << "Seen line that matches special regex " << line; + ko.Stream() << line << "\n"; + continue; + } + + std::string start_str = split_line[2], + duration_str = split_line[3]; + + // Convert the start time and endtime to real from string. Segment is + // ignored if start or end time cannot be converted to real. 
+ double start, end; + if (!ConvertStringToReal(start_str, &start)) { + KALDI_WARN << "Invalid line in ctm file [bad start]: " << line; + continue; + } + if (!ConvertStringToReal(duration_str, &end)) { + KALDI_WARN << "Invalid line in ctm file [bad end]: " << line; + continue; + } + + end += start; + + // start time must not be negative; start time must not be greater than + // end time, except if end time is -1 + if (start < 0 || (end != -1.0 && end <= 0) || ((start >= end) && (end > 0))) { + KALDI_WARN << "Invalid line in ctm file [empty or invalid segment]: " + << line; + continue; + } + + double conf = -1.0; + if (split_line.size() == 6 && !ConvertStringToReal(split_line[5], &conf)) { + KALDI_ERR << "Invalid line in ctm file [bad conf]: " << line; + } + + std::string reco_id; + try { + reco_id = file_and_channel2reco_map.at(std::make_pair(split_line[0], split_line[1])); + } catch (std::out_of_range &oor) { + KALDI_ERR << "Out of range error: " << oor.what(); + } + + if (prev_recording == "" || prev_recording != reco_id) { + if (!seg_reader.HasKey(reco_id)) { + KALDI_ERR << "Could not find segmentation for recording " << reco_id; + } + seg = seg_reader.Value(reco_id); + seg_it = seg.Begin(); + + std::map > *conf_acc; + + if (confidences.count(reco_id) > 0) { + conf_acc = confidences[reco_id]; + } else { + conf_acc = new std::map >; + confidences.insert(std::make_pair(reco_id, conf_acc)); + reco_list.insert(reco_id); + num_recos++; + } + + while (seg_it != seg.End() && seg_it->end_frame * frame_shift < start) ++seg_it; + while (seg_it != seg.Begin() && seg_it->start_frame * frame_shift > start) --seg_it; + + BaseFloat this_word_occ = 0; + while (seg_it != seg.End() && seg_it->start_frame * frame_shift <= end) { + double fraction = (std::min(static_cast(seg_it->end_frame * frame_shift), end) - std::max(static_cast(seg_it->start_frame * frame_shift), start) + frame_shift) / (end-start+frame_shift); + this_word_occ += fraction; + ++seg_it; + } + + if (this_word_occ 
> keep_word_threshold) { + ko.Stream() << line << "\n"; + } + } + + prev_recording = recording; + num_success++; + } + + KALDI_LOG << "Successfully processed " << num_success << " lines out of " + << num_lines << " in the ctm file; wrote " + << num_recos; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + + + + + diff --git a/src/segmenterbin/segmentation-init-from-ali.cc b/src/segmenterbin/segmentation-init-from-ali.cc new file mode 100644 index 00000000000..75887ccd587 --- /dev/null +++ b/src/segmenterbin/segmentation-init-from-ali.cc @@ -0,0 +1,142 @@ +// segmenterbin/segmentation-init-from-ali.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Initialize segmentations from alignments file\n" + "\n" + "Usage: segmentation-init-from-ali [options] \n" + " e.g.: segmentation-init-from-ali ark:1.ali ark:-\n"; + + std::string reco2utt_rspecifier; + std::string segments_rspecifier; + BaseFloat frame_shift = 0.01; + + ParseOptions po(usage); + + po.Register("reco2utt-rspecifier", &reco2utt_rspecifier, + "Use reco2utt and segments files to create file-level " + "segmentations instead of utterance-level segmentations. " + "Works in conjunction with --segments-rspecifier option."); + po.Register("segments-rspecifier", &segments_rspecifier, + "Use reco2utt and segments files to create file-level " + "segmentations instead of utterance-level segmentations. " + "Works in conjunction with --segments-rspecifier option."); + po.Register("frame-shift", &frame_shift, "Frame shift in seconds"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string ali_rspecifier = po.GetArg(1), + segmentation_wspecifier = po.GetArg(2); + + SegmentationWriter segmentation_writer(segmentation_wspecifier); + + int32 num_done = 0, num_segmentations = 0; + int64 num_segments = 0; + int64 num_err = 0; + + std::vector frame_counts_per_class; + + if (reco2utt_rspecifier.empty() && segments_rspecifier.empty()) { + SequentialInt32VectorReader alignment_reader(ali_rspecifier); + + for (; !alignment_reader.Done(); alignment_reader.Next()) { + std::string key = alignment_reader.Key(); + const std::vector &alignment = alignment_reader.Value(); + + Segmentation seg; + + num_segments += seg.InsertFromAlignment(alignment, 0, + &frame_counts_per_class); + + segmentation_writer.Write(key, seg); + num_done++; + num_segmentations++; + } + } else { + if (reco2utt_rspecifier.empty() || 
segments_rspecifier.empty()) { + KALDI_ERR << "Require both --reco2utt-rspecifier and " + << "--segments-rspecifier to be non-empty"; + } + SequentialTokenVectorReader reco2utt_reader(reco2utt_rspecifier); + RandomAccessUtteranceSegmentReader segments_reader(segments_rspecifier); + RandomAccessInt32VectorReader alignment_reader(ali_rspecifier); + + for (; !reco2utt_reader.Done(); reco2utt_reader.Next()) { + const std::vector &utts = reco2utt_reader.Value(); + const std::string &reco_id = reco2utt_reader.Key(); + + Segmentation seg; + for (std::vector::const_iterator it = utts.begin(); + it != utts.end(); ++it) { + if (!segments_reader.HasKey(*it)) { + KALDI_WARN << "Could not find utterance " << *it << " in " + << "segments " << segments_rspecifier; + num_err++; + continue; + } + + if (!alignment_reader.HasKey(*it)) { + KALDI_WARN << "Could not find utterance " << *it << " in " + << "alignment " << ali_rspecifier; + num_err++; + continue; + } + + const UtteranceSegment &segment = segments_reader.Value(*it); + const std::vector &alignment = alignment_reader.Value(*it); + + num_segments += seg.InsertFromAlignment(alignment, + segment.start_time / frame_shift, + &frame_counts_per_class); + + num_done++; + } + segmentation_writer.Write(reco_id, seg); + num_segmentations++; + } + } + + KALDI_LOG << "Processed " << num_done << " utterances; failed with " + << num_err << " utterances; " + << "wrote " << num_segmentations << " segmentations " + << "with a total of " << num_segments << " segments."; + KALDI_LOG << "Number of frames for the different classes are : "; + WriteIntegerVector(KALDI_LOG, false, frame_counts_per_class); + + return (num_err < num_segmentations ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-init-from-diarization.cc b/src/segmenterbin/segmentation-init-from-diarization.cc new file mode 100644 index 00000000000..38065fc906c --- /dev/null +++ b/src/segmenterbin/segmentation-init-from-diarization.cc @@ -0,0 +1,164 @@ +// segmenterbin/segmentation-init-from-diarization.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Initialize segmentations from diarization file\n" + "\n" + "Usage: segmentation-init-from-diarization [options] diarization-rspecifier segments-rxfilename segmentation-out-wspecifier \n" + " e.g.: segmentation-init-from-diarization diarization segments ark:-\n"; + + bool binary = true, per_utt = false; + BaseFloat frame_shift = 0.01, overlap = 0.5; + + ParseOptions po(usage); + + po.Register("binary", &binary, "Write in binary mode (only relevant if output is a wxfilename)"); + po.Register("frame-shift", &frame_shift, "Frame shift in seconds"); + po.Register("diarization-window-overlap", &overlap, "Overlap of diarization window in seconds"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string diar_rspecifier = po.GetArg(1), + segments_rxfilename = po.GetArg(2), + segmentation_wspecifier = po.GetArg(3); + + Input ki(segments_rxfilename); + RandomAccessInt32Reader diar_reader(diar_rspecifier); + SegmentationWriter writer(segmentation_wspecifier); + + int32 num_lines = 0, num_success = 0, num_segmentations = 0, num_err = 0; + + std::string line, prev_recording; + Segmentation seg; + + while (std::getline(ki.Stream(), line)) { + num_lines++; + std::vector split_line; + // Split the line by space or tab and check the number of fields in each + // line. There must be 4 fields--segment name , reacording wav file name, + // start time, end time; 5th field (channel info) is optional. 
+ SplitStringToVector(line, " \t\r", true, &split_line); + if (split_line.size() != 4 && split_line.size() != 5) { + KALDI_WARN << "Invalid line in segments file: " << line; + num_err++; + continue; + } + std::string segment = split_line[0], + recording = split_line[1], + start_str = split_line[2], + end_str = split_line[3]; + + // Convert the start time and endtime to real from string. Segment is + // ignored if start or end time cannot be converted to real. + double start, end; + if (!ConvertStringToReal(start_str, &start)) { + KALDI_WARN << "Invalid line in segments file [bad start]: " << line; + num_err++; + continue; + } + if (!ConvertStringToReal(end_str, &end)) { + KALDI_WARN << "Invalid line in segments file [bad end]: " << line; + num_err++; + continue; + } + + // start time must not be negative; start time must not be greater than + // end time, except if end time is -1 + if (start < 0 || (end != -1.0 && end <= 0) || ((start >= end) && (end > 0))) { + KALDI_WARN << "Invalid line in segments file [empty or invalid segment]: " + << line; + num_err++; + continue; + } + + if (split_line.size() >= 5) + KALDI_ERR << "Not supporting channel in segments file"; + + if (!diar_reader.HasKey(segment)) { + KALDI_WARN << "Could not find diarization assignment for " + << "utterance " << segment; + num_err++; + continue; + } + int32 label = diar_reader.Value(segment); + + if (!per_utt) { + if (prev_recording != "" && prev_recording != recording) { + // Start of new recording + + // Fix the previous recording's last segment to remove any + // overlap-related adjustment + segmenter::SegmentList::iterator seg_it = seg.End(); + --seg_it; + seg_it->end_frame += std::round(overlap / 2 / frame_shift); + writer.Write(prev_recording, seg); + num_segmentations++; + seg.Clear(); + } + // Adjustment due to overlap with next segment + end -= overlap / 2; + + if (seg.Dim() != 0) + start += overlap / 2; + + seg.Emplace(std::round(start / frame_shift), + std::round(end / frame_shift) - 
1, label); + } + + if (per_utt) { + seg.Emplace(0.0, std::round((end - start)/ frame_shift) - 1, label); + writer.Write(segment, seg); + num_segmentations++; + seg.Clear(); + } + + prev_recording = recording; + num_success++; + } + + if (!per_utt) { + writer.Write(prev_recording, seg); + num_segmentations++; + } + + KALDI_LOG << "Successfully processed " << num_success << " lines out of " + << num_lines << " in the segments file; wrote " + << num_segmentations << " segmentations; failed for " + << num_err << " segments"; + + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-init-from-lengths.cc b/src/segmenterbin/segmentation-init-from-lengths.cc new file mode 100644 index 00000000000..aba8eac4bc0 --- /dev/null +++ b/src/segmenterbin/segmentation-init-from-lengths.cc @@ -0,0 +1,74 @@ +// segmenterbin/segmentation-init-from-lengths.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Initialize segmentations from frame lengths file\n" + "\n" + "Usage: segmentation-init-from-lengths [options] \n" + " e.g.: segmentation-init-from-lengths \"ark:feat-to-len scp:feats.scp ark:- |\" ark:-\n"; + + int32 label = 1; + + ParseOptions po(usage); + + po.Register("label", &label, "Assign the segment this class_id"); + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string lengths_rspecifier = po.GetArg(1), + segmentation_wspecifier = po.GetArg(2); + + SequentialInt32Reader lengths_reader(lengths_rspecifier); + SegmentationWriter segmentation_writer(segmentation_wspecifier); + + int32 num_done = 0; + + for (; !lengths_reader.Done(); lengths_reader.Next()) { + std::string key = lengths_reader.Key(); + const int32 &len = lengths_reader.Value(); + + Segmentation seg; + seg.Emplace(0, len - 1, label); + + segmentation_writer.Write(key, seg); + num_done++; + } + + KALDI_LOG << "Processed " << num_done << " utterances"; + + return (num_done > 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-init-from-segments.cc b/src/segmenterbin/segmentation-init-from-segments.cc new file mode 100644 index 00000000000..e3d0a57fdf3 --- /dev/null +++ b/src/segmenterbin/segmentation-init-from-segments.cc @@ -0,0 +1,141 @@ +// segmenterbin/segmentation-init-from-segments.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Initialize segmentations from segments file\n" + "\n" + "Usage: segmentation-init-from-segments [options] segments-rxfilename segmentation-out-wspecifier \n" + " e.g.: segmentation-init-from-segments segments ark:-\n"; + + bool binary = true, per_utt = false; + int32 label = 1; + BaseFloat frame_shift = 0.01; + + ParseOptions po(usage); + + po.Register("binary", &binary, "Write in binary mode (only relevant if output is a wxfilename)"); + po.Register("label", &label, "Label for all the segments"); + po.Register("per-utt", &per_utt, "Get segmentation per utterance instead of " + "per file"); + po.Register("frame-shift", &frame_shift, "Frame shift"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string segments_rxfilename = po.GetArg(1), + segmentation_wspecifier = po.GetArg(2); + + Input ki(segments_rxfilename); + SegmentationWriter writer(segmentation_wspecifier); + + int32 num_lines = 0, num_success = 0, num_segmentations = 0; + + std::string line, prev_recording; + Segmentation seg; + + while (std::getline(ki.Stream(), line)) { + num_lines++; + std::vector split_line; + // Split the line by space or tab and check the number of fields in each + // line. 
There must be 4 fields--segment name , reacording wav file name, + // start time, end time; 5th field (channel info) is optional. + SplitStringToVector(line, " \t\r", true, &split_line); + if (split_line.size() != 4 && split_line.size() != 5) { + KALDI_WARN << "Invalid line in segments file: " << line; + continue; + } + std::string segment = split_line[0], + recording = split_line[1], + start_str = split_line[2], + end_str = split_line[3]; + + // Convert the start time and endtime to real from string. Segment is + // ignored if start or end time cannot be converted to real. + double start, end; + if (!ConvertStringToReal(start_str, &start)) { + KALDI_WARN << "Invalid line in segments file [bad start]: " << line; + continue; + } + if (!ConvertStringToReal(end_str, &end)) { + KALDI_WARN << "Invalid line in segments file [bad end]: " << line; + continue; + } + + // start time must not be negative; start time must not be greater than + // end time, except if end time is -1 + if (start < 0 || (end != -1.0 && end <= 0) || ((start >= end) && (end > 0))) { + KALDI_WARN << "Invalid line in segments file [empty or invalid segment]: " + << line; + continue; + } + + if (split_line.size() >= 5) + KALDI_ERR << "Not supporting channel in segments file"; + + if (!per_utt) { + if (prev_recording != "" && prev_recording != recording) { + writer.Write(prev_recording, seg); + num_segmentations++; + seg.Clear(); + } + seg.Emplace(std::round(start / frame_shift), + std::round(end / frame_shift) - 1, label); + } else { + seg.Emplace(0, std::round((end - start)/ frame_shift), label); + writer.Write(segment, seg); + num_segmentations++; + seg.Clear(); + } + + prev_recording = recording; + num_success++; + } + + if (!per_utt) { + seg.Sort(); + writer.Write(prev_recording, seg); + num_segmentations++; + } + + KALDI_LOG << "Successfully processed " << num_success << " lines out of " + << num_lines << " in the segments file; wrote " + << num_segmentations << " segmentations."; + + } 
catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + + + diff --git a/src/segmenterbin/segmentation-intersect-segments.cc b/src/segmenterbin/segmentation-intersect-segments.cc new file mode 100644 index 00000000000..7458957943c --- /dev/null +++ b/src/segmenterbin/segmentation-intersect-segments.cc @@ -0,0 +1,131 @@ +// segmenterbin/segmentation-intersect-segments.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Intersect segments from two archives\n" + "\n" + "Usage: segmentation-intersect-segments [options] (segmentation-rpecifier1|segmentation-rxfilename1) (segmentation-rspecifier2|segmentation-rxfilename2) (segmentation-wspecifier|segmentation-wxfilename)\n" + " e.g.: segmentation-intersect-segments --binary=false foo bar -\n" + " segmentation-intersect-segments ark:foo.seg ark:bar.seg ark,t:-\n" + "See also: segmentation-merge, segmentation-copy, segmentation-post-process --merge-labels\n"; + + bool binary = true; + int32 mismatch_label = -1; + + ParseOptions po(usage); + + po.Register("binary", &binary, "Write in binary mode (only relevant if output is a wxfilename)"); + po.Register("mismatch-label", &mismatch_label, "Label to be added for the " + "mismatch segments"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_in_fn = po.GetArg(1), + secondary_segmentation_in_fn = po.GetArg(2), + segmentation_out_fn = po.GetArg(3); + + + // all these "fn"'s are either rspecifiers or filenames. 
+ bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != (ClassifyRspecifier(secondary_segmentation_in_fn, NULL, NULL) != kNoRspecifier) || + in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + Segmentation seg; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + seg.Read(ki.Stream(), binary_in); + } + + Segmentation secondary_seg; + { + bool binary_in; + Input ki(secondary_segmentation_in_fn, &binary_in); + secondary_seg.Read(ki.Stream(), binary_in); + } + + Segmentation out_seg; + seg.IntersectSegments(secondary_seg, &out_seg, mismatch_label); + + Output ko(segmentation_out_fn, binary); + out_seg.Write(ko.Stream(), binary); + KALDI_LOG << "Intersected segmentations " << segmentation_in_fn + << " and " << secondary_segmentation_in_fn << "; wrote " + << segmentation_out_fn; + return 0; + } else { + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader primary_reader(segmentation_in_fn); + RandomAccessSegmentationReader secondary_reader(secondary_segmentation_in_fn); + + for (; !primary_reader.Done(); primary_reader.Next()) { + const Segmentation &seg = primary_reader.Value(); + const std::string &key = primary_reader.Key(); + + if (!secondary_reader.HasKey(key)) { + KALDI_WARN << "Could not find segmentation for key " << key + << " in " << secondary_segmentation_in_fn; + num_err++; + continue; + } + const Segmentation &secondary_seg = secondary_reader.Value(key); + + Segmentation out_seg; + seg.IntersectSegments(secondary_seg, &out_seg, mismatch_label); + out_seg.Sort(); + + writer.Write(key, out_seg); + num_done++; + } + + KALDI_LOG << "Intersected " << num_done << " segmentations; failed with " + << num_err << " segmentations"; + return 
(num_done != 0 ? 0 : 1); + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/segmenterbin/segmentation-merge.cc b/src/segmenterbin/segmentation-merge.cc new file mode 100644 index 00000000000..c6d564fabd7 --- /dev/null +++ b/src/segmenterbin/segmentation-merge.cc @@ -0,0 +1,124 @@ +// segmenterbin/segmentation-merge.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Merge corresponding segments from two archives\n" + "\n" + "Usage: segmentation-merge [options] (segmentation-rpecifier1|segmentation-rxfilename1) (segmentation-rspecifier2|segmentation-rxfilename2) ... 
(segmentation-rspecifier|segmentation-rxfilename) (segmentation-wspecifier|segmentation-wxfilename)\n" + " e.g.: segmentation-merge --binary=false foo bar -\n" + " segmentation-merge ark:foo.seg ark:bar.seg ark,t:-\n" + "See also: segmentation-copy, segmentation-post-process --merge-labels\n"; + + bool binary = true; + ParseOptions po(usage); + + po.Register("binary", &binary, "Write in binary mode (only relevant if output is a wxfilename)"); + + po.Read(argc, argv); + + if (po.NumArgs() <= 2) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_in_fn = po.GetArg(1), + segmentation_out_fn = po.GetArg(po.NumArgs()); + + + // all these "fn"'s are either rspecifiers or filenames. + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + Segmentation seg; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + seg.Read(ki.Stream(), binary_in); + } + + for (int32 i = 2; i < po.NumArgs(); i++) { + bool binary_in; + Input ki(po.GetArg(i), &binary_in); + Segmentation other_seg; + other_seg.Read(ki.Stream(), binary_in); + seg.Extend(other_seg, false); + } + + Output ko(segmentation_out_fn, binary); + seg.Write(ko.Stream(), binary); + KALDI_LOG << "Copied segmentation to " << segmentation_out_fn; + return 0; + } else { + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + std::vector other_readers(po.NumArgs()-2, + static_cast(NULL)); + + for (size_t i = 0; i < po.NumArgs()-2; i++) { + other_readers[i] = new RandomAccessSegmentationReader(po.GetArg(i+2)); + } + + for (; !reader.Done(); reader.Next()) { + Segmentation seg(reader.Value()); + std::string key = reader.Key(); + + for 
(size_t i = 0; i < po.NumArgs()-2; i++) { + if (!other_readers[i]->HasKey(key)) { + KALDI_WARN << "Could not find segmentation for key " << key + << " in " << po.GetArg(i+2); + num_err++; + } + const Segmentation &other_seg = other_readers[i]->Value(key); + seg.Extend(other_seg, false); + } + seg.Sort(); + + writer.Write(key, seg); + num_done++; + } + + KALDI_LOG << "Merged " << num_done << " segmentation; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-post-process.cc b/src/segmenterbin/segmentation-post-process.cc new file mode 100644 index 00000000000..67f65b1abee --- /dev/null +++ b/src/segmenterbin/segmentation-post-process.cc @@ -0,0 +1,128 @@ +// segmenterbin/segmentation-post-process.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Post processing of segmentation that does the following operations " + "in order: \n" + "1) Intersection or Filtering: Intersects the input segmentation with " + "segments from the segmentation in 'filter-rspecifier' and retains " + "only regions where the segment in the filter has the class-id " + "'filter-label'. See method IntersectSegments() for details.\n" + "2) Merge labels: Merge labels specified in 'merge-labels' into a " + "single label 'label'. Any segment that has class_id that is contained " + "in 'merge-labels' is assigned class_id 'label'. " + "See method MergeLabels() for details.\n" + "3) Widen segments: Widen segments of label 'widen-label' by " + "'widen-length' frames on either side of the segment. This process " + "also shrinks the adjacent segments so that it does not overlap with " + "the widened segment or merges the adjacent segment into a composite " + "segment if they both have the same class_id. 
" + "See method WidenSegment() for details.\n" + "4) with the \n" + "Usage: segmentation-post-process [options] (segmentation-in-rspecifier|segmentation-in-rxfilename) (segmentation-out-wspecifier|segmentation-out-wxfilename)\n" + " e.g.: segmentation-post-process --binary=false foo -\n" + " segmentation-post-process ark:1.ali ark,t:-\n" + "See also: segmentation-merge, segmentation-copy, segmentation-remove-segments\n"; + + bool binary = true; + + ParseOptions po(usage); + + SegmentationPostProcessingOptions opts; + + po.Register("binary", &binary, + "Write in binary mode (only relevant if output is a wxfilename)"); + + opts.Register(&po); + + po.Read(argc, argv); + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + SegmentationPostProcessor post_processor(opts); + std::string segmentation_in_fn = po.GetArg(1), + segmentation_out_fn = po.GetArg(2); + + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + Segmentation seg; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + seg.Read(ki.Stream(), binary_in); + } + if (post_processor.PostProcess(&seg)) { + Output ko(segmentation_out_fn, binary); + seg.Write(ko.Stream(), binary); + KALDI_LOG << "Post-processed segmentation " << segmentation_in_fn + << " and wrote " << segmentation_out_fn; + return 0; + } + KALDI_LOG << "Failed post-processing segmentation " + << segmentation_in_fn ; + return 1; + } + + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + for (; !reader.Done(); reader.Next()){ + Segmentation seg(reader.Value()); + std::string key = reader.Key(); + + if (!post_processor.FilterAndPostProcess(&seg, &key)) { 
+ num_err++; + continue; + } + + writer.Write(key, seg); + num_done++; + } + + KALDI_LOG << "Successfully post-processed " << num_done + << " segmentations; " + << "failed with " << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-remove-segments.cc b/src/segmenterbin/segmentation-remove-segments.cc new file mode 100644 index 00000000000..a8098dee9de --- /dev/null +++ b/src/segmenterbin/segmentation-remove-segments.cc @@ -0,0 +1,126 @@ +// segmenterbin/segmentation-remove-segments.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Remove segments of particular class_id (e.g silence or noise) " + "or a set of class_ids" + "\n" + "Usage: segmentation-remove-segments [options] (segmentation-in-rspecifier|segmentation-in-rxfilename) (segmentation-out-wspecifier|segmentation-out-wxfilename)\n" + " e.g.: segmentation-remove-segments --remove-label=0 ark:foo.ark ark:foo.speech.ark\n" + "See also: segmentation-post-process, segmentation-merge, segmentation-copy\n"; + + bool binary = true; + int32 remove_label = -1; + std::string remove_labels_rspecifier = ""; + ParseOptions po(usage); + + po.Register("binary", &binary, "Write in binary mode (only relevant if output is a wxfilename)"); + po.Register("remove-label", &remove_label, "Remove segments of this label"); + po.Register("remove-labels-rspecifier", &remove_labels_rspecifier, "Specify colon separated list of labels for each recording"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_in_fn = po.GetArg(1), + segmentation_out_fn = po.GetArg(2); + + // all these "fn"'s are either rspecifiers or filenames. 
+ + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_missing = 0; + + if (!in_is_rspecifier) { + Segmentation seg; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + seg.Read(ki.Stream(), binary_in); + } + seg.RemoveSegments(remove_label); + Output ko(segmentation_out_fn, binary); + seg.Write(ko.Stream(), binary); + KALDI_LOG << "Copied segmentation to " << segmentation_out_fn; + return 0; + } else { + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + + RandomAccessTokenReader remove_labels_reader(remove_labels_rspecifier); + + for (; !reader.Done(); reader.Next(), num_done++) { + Segmentation seg(reader.Value()); + std::string key = reader.Key(); + + if (remove_labels_rspecifier != "") { + if (!remove_labels_reader.HasKey(key)) { + KALDI_WARN << "No remove-labels found for recording " << key; + num_missing++; + writer.Write(key, seg); + continue; + } + std::vector merge_labels; + const std::string& remove_labels_str = remove_labels_reader.Value(key); + + if (!SplitStringToIntegers(remove_labels_str, ":", false, + &merge_labels)) { + KALDI_ERR << "Bad CSL " << remove_labels_str; + } + + remove_label = merge_labels[0]; + seg.MergeLabels(merge_labels, remove_label); + } + + seg.RemoveSegments(remove_label); + + writer.Write(key, seg); + } + + KALDI_LOG << "Removed segments " + << "from " << num_done << " segmentations; " + << "remove-labels missing for " << num_missing; + return (num_done != 0 ? 
0 : 1); + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-select-top.cc b/src/segmenterbin/segmentation-select-top.cc new file mode 100644 index 00000000000..26c08e3134e --- /dev/null +++ b/src/segmenterbin/segmentation-select-top.cc @@ -0,0 +1,145 @@ +// segmenterbin/segmentation-select-top.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Select top segments from the segmentations and write new segmentation\n" + "\n" + "Usage: segmentation-select-top [options] \n" + " e.g.: segmentation-select-top ark:1.seg ark:1.log_energies ark:-\n"; + + ParseOptions po(usage); + + int32 top_select_label = 3, bottom_select_label = 1; + int32 reject_label = 4; + int32 num_top_frames = 10000, num_bottom_frames = 2000; + int32 window_size = 100, min_remainder = 50; + bool remove_rejected_frames = false; + + SegmentationPostProcessingOptions opts; + HistogramOptions hist_opts; + + int32 &src_label = opts.merge_dst_label; + + po.Register("src-label", &src_label, "Select top segments of only this " + " class label"); + po.Register("num-top-frames", &num_top_frames, "Number of frames to " + "select from the top half"); + po.Register("num-bottom-frames", &num_bottom_frames, "Number of frames to " + "select from the bottom half"); + po.Register("top-select-label", &top_select_label, "The label to assign " + "for the selected top segments"); + po.Register("bottom-select-label", &bottom_select_label, "The label to " + "assign for the selected bottom segments"); + po.Register("reject-label", &reject_label, "The label assigned to " + "segments that are binned in histogram but do not make it to " + "the top or bottom"); + po.Register("window-size", &window_size, "Split segments into windows of " + "this size"); + po.Register("min-window-remainder", &min_remainder, "Do not split segment " + "if final piece is smaller than this size"); + po.Register("remove-rejected-frames", &remove_rejected_frames, "If true, " + "then remove the chunks that are not selected"); + opts.Register(&po); + hist_opts.Register(&po); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + 
SegmentationPostProcessor post_processor(opts); + + std::string segmentation_rspecifier = po.GetArg(1), + scores_rspecifier = po.GetArg(2), + segmentation_wspecifier = po.GetArg(3); + + SequentialSegmentationReader segmentation_reader(segmentation_rspecifier); + RandomAccessBaseFloatVectorReader scores_reader(scores_rspecifier); + SegmentationWriter segmentation_writer(segmentation_wspecifier); + + int64 num_done = 0, num_err = 0, num_selected_top = 0, num_selected_bottom = 0; + + for (; !segmentation_reader.Done(); segmentation_reader.Next()) { + std::string key = segmentation_reader.Key(); + if (!scores_reader.HasKey(key)) { + KALDI_WARN << "Could not read scores for utterance " << key; + num_err++; + continue; + } + + const Segmentation &in_seg = segmentation_reader.Value(); + const Vector &scores = scores_reader.Value(key); + + Segmentation out_seg(in_seg); // Make a copy + + post_processor.Filter(key, &out_seg); + post_processor.MergeLabels(&out_seg); + + out_seg.SplitSegments(window_size, min_remainder); + + HistogramEncoder hist_encoder; + out_seg.CreateHistogram(src_label, scores, hist_opts, &hist_encoder); + + if (top_select_label == -1) + num_selected_bottom += out_seg.SelectBottomBins(hist_encoder, src_label, + bottom_select_label, + reject_label, num_bottom_frames, + remove_rejected_frames); + else if (bottom_select_label == -1) + num_selected_top += out_seg.SelectTopBins(hist_encoder, src_label, + top_select_label, reject_label, num_top_frames, + remove_rejected_frames); + else { + std::pair p = out_seg.SelectTopAndBottomBins(hist_encoder, src_label, + top_select_label, num_top_frames, + bottom_select_label, num_bottom_frames, reject_label, + remove_rejected_frames); + num_selected_top += p.first; + num_selected_bottom += p.second; + } + + segmentation_writer.Write(key, out_seg); + num_done++; + } + + KALDI_LOG << "Processed " << num_done << " segmentations; " + << "error in " << num_err << "; " + << "Selected " << num_selected_top << " and " + << 
num_selected_bottom << " top and bottom frames respectively"; + + return (num_done == 0 || num_err >= num_done ? 1 : 0); + + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/segmenterbin/segmentation-to-ali.cc b/src/segmenterbin/segmentation-to-ali.cc new file mode 100644 index 00000000000..ea13dc56775 --- /dev/null +++ b/src/segmenterbin/segmentation-to-ali.cc @@ -0,0 +1,100 @@ +// segmenterbin/segmentation-to-ali.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Convert segmentation to alignment\n" + "\n" + "Usage: segmentation-to-ali [options] segmentation-rspecifier ali-wspecifier\n" + " e.g.: segmentation-to-ali ark:1.seg ark:1.ali\n"; + + std::string lengths_rspecifier; + int32 default_label = 0, frame_tolerance = 2; + + ParseOptions po(usage); + + po.Register("lengths", &lengths_rspecifier, "Archive of frame lengths " + "of the utterances. 
Fills up any extra length with " + "the specified default-label"); + po.Register("default-label", &default_label, "Fill any extra length " + "with this label"); + po.Register("frame-tolerance", &frame_tolerance, "Tolerate shortage of " + "this many frames in the specified lengths file"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_rspecifier = po.GetArg(1); + std::string alignment_wspecifier = po.GetArg(2); + + RandomAccessInt32Reader lengths_reader(lengths_rspecifier); + + SequentialSegmentationReader segmentation_reader(segmentation_rspecifier); + Int32VectorWriter alignment_writer(alignment_wspecifier); + + int32 num_err = 0, num_done = 0; + for (; !segmentation_reader.Done(); segmentation_reader.Next()) { + Segmentation seg(segmentation_reader.Value()); + std::string key = segmentation_reader.Key(); + + int32 len = -1; + if (lengths_rspecifier != "") { + if (!lengths_reader.HasKey(key)) { + KALDI_WARN << "Could not find length for utterance " << key; + num_err++; + continue; + } + len = lengths_reader.Value(key); + } + + std::vector ali; + if (!seg.ConvertToAlignment(&ali, default_label, len, frame_tolerance)) { + KALDI_WARN << "Conversion failed for utterance " << key; + num_err++; + continue; + } + alignment_writer.Write(key, ali); + num_done++; + } + + KALDI_LOG << "Converted " << num_done << " segmentation into alignments; " + << "failed with " << num_err << " segmentations"; + return (num_done != 0 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + + + + diff --git a/src/segmenterbin/segmentation-to-rttm.cc b/src/segmenterbin/segmentation-to-rttm.cc new file mode 100644 index 00000000000..4e10ef43360 --- /dev/null +++ b/src/segmenterbin/segmentation-to-rttm.cc @@ -0,0 +1,203 @@ +// segmenterbin/segmentation-to-rttm.cc + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Convert segmentation into RTTM\n" + "\n" + "Usage: segmentation-to-rttm [options] segmentation-in-rspecifier rttm-out-wxfilename\n" + " e.g.: segmentation-to-rttm foo -\n" + " segmentation-to-rttm ark:1.seg -\n"; + + bool binary = true; + bool map_to_speech_and_sil = true; + + BaseFloat frame_shift = 0.01; + std::string segments_rxfilename; + std::string reco2file_and_channel_rxfilename; + ParseOptions po(usage); + + po.Register("binary", &binary, "Write in binary mode (only relevant if output is a wxfilename)"); + po.Register("frame-shift", &frame_shift, "Frame shift in seconds"); + po.Register("segments", &segments_rxfilename, "Segments file"); + po.Register("reco2file-and-channel", &reco2file_and_channel_rxfilename, "reco2file_and_channel file"); + po.Register("map-to-speech-and-sil", &map_to_speech_and_sil, "Map all classes to SPEECH and SILENCE"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + unordered_map utt2file; + unordered_map utt2start_time; + + if (segments_rxfilename != "") { + Input ki(segments_rxfilename); // no binary argment: never binary. + int32 i = 0; + std::string line; + /* read each line from segments file */ + while (std::getline(ki.Stream(), line)) { + std::vector split_line; + // Split the line by space or tab and check the number of fields in each + // line. There must be 4 fields--segment name , reacording wav file name, + // start time, end time; 5th field (channel info) is optional. 
+ SplitStringToVector(line, " \t\r", true, &split_line); + if (split_line.size() != 4 && split_line.size() != 5) { + KALDI_WARN << "Invalid line in segments file: " << line; + continue; + } + std::string segment = split_line[0], + utterance = split_line[1], + start_str = split_line[2], + end_str = split_line[3]; + + // Convert the start time and endtime to real from string. Segment is + // ignored if start or end time cannot be converted to real. + double start, end; + if (!ConvertStringToReal(start_str, &start)) { + KALDI_WARN << "Invalid line in segments file [bad start]: " << line; + continue; + } + if (!ConvertStringToReal(end_str, &end)) { + KALDI_WARN << "Invalid line in segments file [bad end]: " << line; + continue; + } + // start time must not be negative; start time must not be greater than + // end time, except if end time is -1 + if (start < 0 || end <= 0 || start >= end) { + KALDI_WARN << "Invalid line in segments file [empty or invalid segment]: " + << line; + continue; + } + int32 channel = -1; // means channel info is unspecified. + // if each line has 5 elements then 5th element must be channel identifier + if(split_line.size() == 5) { + if (!ConvertStringToInteger(split_line[4], &channel) || channel < 0) { + KALDI_WARN << "Invalid line in segments file [bad channel]: " << line; + continue; + } + } + + utt2file.insert(std::make_pair(segment, utterance)); + utt2start_time.insert(std::make_pair(segment, start)); + i++; + } + KALDI_LOG << "Read " << i << " lines from " << segments_rxfilename; + } + + std::unordered_map , StringHasher> reco2file_and_channel; + + if (!reco2file_and_channel_rxfilename.empty()) { + Input ki(reco2file_and_channel_rxfilename); // no binary argment: never binary. 
+ + int32 i = 0; + std::string line; + /* read each line from reco2file_and_channel file */ + while (std::getline(ki.Stream(), line)) { + std::vector split_line; + SplitStringToVector(line, " \t\r", true, &split_line); + if (split_line.size() != 3) { + KALDI_WARN << "Invalid line in reco2file_and_channel file: " << line; + continue; + } + + const std::string &reco_id = split_line[0]; + const std::string &file_id = split_line[1]; + const std::string &channel = split_line[2]; + + reco2file_and_channel.insert(std::make_pair(reco_id, std::make_pair(file_id, channel))); + i++; + } + + KALDI_LOG << "Read " << i << " lines from " << reco2file_and_channel_rxfilename; + } + + std::unordered_set seen_files; + + std::string segmentation_rspecifier = po.GetArg(1), + rttm_out_wxfilename = po.GetArg(2); + + int64 num_done = 0, num_err = 0; + + Output ko(rttm_out_wxfilename, false); + SequentialSegmentationReader reader(segmentation_rspecifier); + for (; !reader.Done(); reader.Next(), num_done++) { + Segmentation seg(reader.Value()); + const std::string &key = reader.Key(); + + std::string reco_id = key; + BaseFloat start_time = 0.0; + if (!segments_rxfilename.empty()) { + if (utt2file.count(key) == 0 || utt2start_time.count(key) == 0) + KALDI_ERR << "Could not find key " << key << " in segments " + << segments_rxfilename; + KALDI_ASSERT(utt2file.count(key) > 0 && utt2start_time.count(key) > 0); + reco_id = utt2file[key]; + start_time = utt2start_time[key]; + } + + std::string file_id, channel; + if (!reco2file_and_channel_rxfilename.empty()) { + if (reco2file_and_channel.count(reco_id) == 0) + KALDI_ERR << "Could not find recording " << reco_id + << " in " << reco2file_and_channel_rxfilename; + file_id = reco2file_and_channel[reco_id].first; + channel = reco2file_and_channel[reco_id].second; + } else { + file_id = reco_id; + channel = "1"; + } + + int32 largest_class = seg.WriteRttm(ko.Stream(), file_id, channel, frame_shift, start_time, map_to_speech_and_sil); + + if 
(map_to_speech_and_sil) { + if (seen_files.count(reco_id) == 0) { + ko.Stream() << "SPKR-INFO " << file_id << " " << channel << " unknown SILENCE \n"; + ko.Stream() << "SPKR-INFO " << file_id << " " << channel << " unknown SPEECH \n"; + seen_files.insert(reco_id); + } + } else { + for (int32 i = 0; i < largest_class; i++) { + ko.Stream() << "SPKR-INFO " << file_id << " " << channel << " unknown " << i << " \n"; + } + } + } + + KALDI_LOG << "Copied " << num_done << " segmentation; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + + + diff --git a/src/segmenterbin/segmentation-to-segments.cc b/src/segmenterbin/segmentation-to-segments.cc new file mode 100644 index 00000000000..e40b9c7f774 --- /dev/null +++ b/src/segmenterbin/segmentation-to-segments.cc @@ -0,0 +1,122 @@ +// segmenterbin/segmentation-to-segments.cc + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmenter.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Convert segmentation to a segments file and utt2spk file." 
+ "Assumes that the segmentations are indexed by file-id and " + "treats speakers from different files as distinct speakers." + "\n" + "Usage: segmentation-to-segments [options] \n" + " e.g.: segmentation-to-segments ark:foo ark,t:utt2spk ark,t:segments\n"; + + BaseFloat frame_shift = 0.01, frame_overlap = 0.015; + bool single_speaker = false; + ParseOptions po(usage); + + po.Register("frame-shift", &frame_shift, "Frame shift in seconds"); + po.Register("frame-overlap", &frame_overlap, "Frame overlap in seconds"); + po.Register("single-speaker", &single_speaker, "If this is set true, then " + "each file is considered to contain only a single speaker " + "and all the utterances are mapped to that in utt2spk file"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + if (frame_shift < 0.001 || frame_shift > 1) { + KALDI_ERR << "Invalid frame-shift " << frame_shift << "; must be in " + << "the range [0.001,1]"; + } + + if (frame_overlap < 0 || frame_overlap > 1) { + KALDI_ERR << "Invalid frame-overlap " << frame_overlap << "; must be in " + << "the range [0,1]"; + } + + std::string segmentation_rspecifier = po.GetArg(1), + utt2spk_wspecifier = po.GetArg(2), + segments_wspecifier = po.GetArg(3); + + SequentialSegmentationReader reader(segmentation_rspecifier); + TokenWriter utt2spk_writer(utt2spk_wspecifier); + UtteranceSegmentWriter segments_writer(segments_wspecifier); + + int32 num_done = 0; + for (; !reader.Done(); reader.Next(), num_done++) { + const Segmentation &seg = reader.Value(); + const std::string &key = reader.Key(); + + std::string file_id = key; + + int32 i = 0; + for (SegmentList::const_iterator it = seg.Begin(); + it != seg.End(); ++it, i++) { + UtteranceSegment segment; + + segment.reco_id = key; + segment.start_time = it->start_frame * frame_shift; + segment.end_time = (it->end_frame + 1) * frame_shift + frame_overlap; + + std::ostringstream oss; + + if (!single_speaker) { + oss << key << "-" << it->Label(); + 
+        } else {
+          oss << key;
+        }
+
+        std::string spk = oss.str();
+
+        oss << "-";
+        oss << std::setw(6) << std::setfill('0') << it->start_frame;
+        oss << std::setw(1) << "-";
+        oss << std::setw(6) << std::setfill('0')
+            << it->end_frame + 1
+               + static_cast<int32>(frame_overlap / frame_shift);
+
+        std::string utt = oss.str();
+
+        utt2spk_writer.Write(utt, spk);
+        segments_writer.Write(utt, segment);
+      }
+    }
+
+    KALDI_LOG << "Converted " << num_done << " segmentations to segments";
+
+    return (num_done != 0 ? 0 : 1);
+  } catch(const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}
+
+
diff --git a/src/segmenterbin/select-feats-from-segmentation.cc b/src/segmenterbin/select-feats-from-segmentation.cc
new file mode 100644
index 00000000000..69016d188a1
--- /dev/null
+++ b/src/segmenterbin/select-feats-from-segmentation.cc
@@ -0,0 +1,119 @@
+// segmenterbin/select-feats-from-segmentation.cc
+
+// Copyright 2015   Vimal Manohar (Johns Hopkins University)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "segmenter/segmenter.h"
+
+int main(int argc, char *argv[]) {
+  try {
+    using namespace kaldi;
+    using namespace segmenter;
+
+    const char *usage =
+        "Select features corresponding to segments of a given label from the "
+        "segmentation and write them out as a new feature matrix\n"
+        "\n"
+        "Usage: select-feats-from-segmentation [options] "
+        "<feats-rspecifier> <segmentation-rspecifier> <feats-wspecifier>\n"
+        " e.g.: select-feats-from-segmentation ark:1.feats ark:1.seg ark:-\n";
+
+    ParseOptions po(usage);
+
+    SegmentationPostProcessingOptions opts;
+    int32 &select_label = opts.merge_dst_label;
+    int32 selection_padding = 0;
+
+    po.Register("select-label", &select_label, "Select frames of only this "
+                "class label");
+    po.Register("selection-padding", &selection_padding, "If this is > 0, then "
+                "this number of frames at the boundary are not selected. "
+                "Similar to program select-interior-frames.");
+
+    opts.Register(&po);
+
+    po.Read(argc, argv);
+
+    if (po.NumArgs() != 3) {
+      po.PrintUsage();
+      exit(1);
+    }
+
+    SegmentationPostProcessor post_processor(opts);
+
+    std::string feats_rspecifier = po.GetArg(1),
+        segmentation_rspecifier = po.GetArg(2),
+        feats_wspecifier = po.GetArg(3);
+
+    SequentialBaseFloatMatrixReader feats_reader(feats_rspecifier);
+    RandomAccessSegmentationReader segmentation_reader(segmentation_rspecifier);
+    BaseFloatMatrixWriter feats_writer(feats_wspecifier);
+
+    int64 num_done = 0, num_err = 0, num_frames_selected = 0, num_frames = 0;
+
+    for (; !feats_reader.Done(); feats_reader.Next()) {
+      std::string key = feats_reader.Key();
+      if (!segmentation_reader.HasKey(key)) {
+        KALDI_WARN << "Could not read segmentation for utterance " << key;
+        num_err++;
+        continue;
+      }
+
+      const Matrix<BaseFloat> &feats_in = feats_reader.Value();
+      const Segmentation &in_seg = segmentation_reader.Value(key);
+
+      Segmentation seg(in_seg);
+      post_processor.MergeLabels(&seg);
+
+      Matrix<BaseFloat> feats_out(feats_in.NumRows(), feats_in.NumCols());
+      int32 j = 0;
+      for (SegmentList::const_iterator it = seg.Begin();
+           it != seg.End(); ++it) {
+        if (it->Label() != select_label ||
+            it->end_frame - it->start_frame + 1 <= 2 * selection_padding)
+          continue;
+        const SubMatrix<BaseFloat> this_feats_in(feats_in,
+            it->start_frame + selection_padding,
+            it->end_frame - it->start_frame + 1 - 2 * selection_padding,
+            0, feats_in.NumCols());
+        SubMatrix<BaseFloat> this_feats_out(feats_out, j,
+            it->end_frame - it->start_frame + 1 - 2 * selection_padding,
+            0, feats_in.NumCols());
+        this_feats_out.CopyFromMat(this_feats_in);
+        j += this_feats_in.NumRows();
+        num_frames_selected += this_feats_in.NumRows();
+      }
+
+      num_frames += feats_in.NumRows();
+      // If no frames are selected, then we don't write anything
+      if (j > 0) {
+        feats_out.Resize(j, feats_in.NumCols(), kCopyData);
+        feats_writer.Write(key, feats_out);
+      }
+      num_done++;
+    }
+
+    KALDI_LOG << "Processed " << num_done << " segmentations; "
+              << "selected " << num_frames_selected << " out of "
+              << num_frames << " frames";
+
+    return (num_frames > 0 ? 0 : 1);
+  } catch(const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}
+
diff --git a/src/util/kaldi-holder-inl.h b/src/util/kaldi-holder-inl.h
index 8130961c0fa..5dffc49b49a 100644
--- a/src/util/kaldi-holder-inl.h
+++ b/src/util/kaldi-holder-inl.h
@@ -28,6 +28,7 @@
 #include "util/kaldi-io.h"
 #include "util/text-utils.h"
 #include "matrix/kaldi-matrix.h"
+#include "base/kaldi-extra-types.h"
 
 namespace kaldi {
 
@@ -685,6 +686,82 @@ class TokenVectorHolder {
   T t_;
 };
 
+// Holder for segments
+class UtteranceSegmentHolder {
+ public:
+  typedef UtteranceSegment T;
+
+  UtteranceSegmentHolder() { }
+
+  static bool Write(std::ostream &os, bool, const T &t) {  // ignore binary-mode.
+    KALDI_ASSERT(IsToken(t.reco_id) && IsToken(t.channel_id) &&
+                 t.end_time > t.start_time);
+    os << t.reco_id << ' ' << t.start_time << ' '
+       << t.end_time;
+
+    if (t.channel_id != "-1")
+      os << ' ' << t.channel_id;
+
+    os << '\n';
+    return os.good();
+  }
+
+  void Clear() {
+    t_.Reset();
+  }
+
+  // Reads into the holder.
+  bool Read(std::istream &is) {
+    Clear();
+    // there is no binary/non-binary mode.
+
+    std::string line;
+    getline(is, line);  // this will discard the \n, if present.
+    if (is.fail()) {
+      KALDI_WARN << "UtteranceSegmentHolder::Read, error reading line " << (is.eof() ? "[eof]" : "");
+      return false;  // probably eof. fail in any case.
+    }
+    const char *white_chars = " \t\n\r\f\v";
+    std::vector<std::string> split;
+    SplitStringToVector(line, white_chars, true, &split);  // true== omit empty strings e.g.
+    // between spaces.
+
+    KALDI_ASSERT(split.size() == 4 || split.size() == 3);
+
+    t_.reco_id = split[0];
+
+    if (!ConvertStringToReal(split[1], &t_.start_time)) {
+      KALDI_WARN << "Invalid line in segments file [bad start]: " << line;
+      return false;
+    }
+    if (!ConvertStringToReal(split[2], &t_.end_time)) {
+      KALDI_WARN << "Invalid line in segments file [bad end]: " << line;
+      return false;
+    }
+
+    if (t_.end_time < t_.start_time) {
+      KALDI_WARN << "Invalid start and end times in line: " << line;
+      return false;
+    }
+
+    if (split.size() == 4) {
+      t_.channel_id = split[3];
+    } else
+      t_.channel_id = "-1";
+
+    return true;
+  }
+
+  // Read in text format since it's basically a text-mode thing.. doesn't really matter,
+  // it would work either way since we ignore the extra '\r'.
+  static bool IsReadInBinary() { return false; }
+
+  const T &Value() const { return t_; }
+
+ private:
+  KALDI_DISALLOW_COPY_AND_ASSIGN(UtteranceSegmentHolder);
+  T t_;
+};
 
 class HtkMatrixHolder {
  public:
diff --git a/src/util/table-types.h b/src/util/table-types.h
index 819c98fdf82..daef5542355 100644
--- a/src/util/table-types.h
+++ b/src/util/table-types.h
@@ -169,6 +169,10 @@
 typedef RandomAccessTableReader<TokenVectorHolder> RandomAccessTokenVectorReader;
 
+typedef TableWriter<UtteranceSegmentHolder> UtteranceSegmentWriter;
+typedef SequentialTableReader<UtteranceSegmentHolder> SequentialUtteranceSegmentReader;
+typedef RandomAccessTableReader<UtteranceSegmentHolder> RandomAccessUtteranceSegmentReader;
+
 /// @}
 
 // Note: for FST reader/writer, see ../fstext/fstext-utils.h
diff --git a/src/vadbin/Makefile b/src/vadbin/Makefile
new file mode 100644
index 00000000000..0396509557d
--- /dev/null
+++ b/src/vadbin/Makefile
@@ -0,0 +1,22 @@
+
+all:
+
+EXTRA_CXXFLAGS = -Wno-sign-compare
+include ../kaldi.mk
+
+BINFILES = compute-vad select-voiced-frames \
+           create-split-from-vad \
+           select-top-frames select-top-chunks \
+           select-interior-frames vector-extract-dims
+
+OBJFILES =
+
+
+
+TESTFILES =
+
+
+ADDLIBS = ../thread/kaldi-thread.a ../matrix/kaldi-matrix.a \
+          ../util/kaldi-util.a ../base/kaldi-base.a
+
+include ../makefiles/default_rules.mk
diff --git a/src/vadbin/select-interior-frames.cc b/src/vadbin/select-interior-frames.cc
new file mode 100644
index 00000000000..9ff54346896
--- /dev/null
+++ b/src/vadbin/select-interior-frames.cc
@@ -0,0 +1,164 @@
+// vadbin/select-interior-frames.cc
+
+// Copyright     2013  Daniel Povey
+//               2015  Vimal Manohar
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+
+
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "matrix/kaldi-matrix.h"
+#include "feat/feature-functions.h"
+
+
+int main(int argc, char *argv[]) {
+  try {
+    using namespace kaldi;
+    using kaldi::int32;
+
+    const char *usage =
+        "Select a subset of frames of the input files, based on the output of\n"
+        "compute-vad or a similar program (a vector of length num-frames,\n"
+        "containing 1.0 for voiced, 0.0 for unvoiced).\n"
+        "Usage: select-interior-frames [options] "
+        "<feats-rspecifier> <vad-rspecifier> <feats-wspecifier>\n"
+        "E.g.: select-interior-frames [options] scp:feats.scp scp:vad.scp ark:-\n";
+
+    bool select_unvoiced_frames = false;
+    int32 padding = 0;
+
+    ParseOptions po(usage);
+    po.Register("select-unvoiced-frames", &select_unvoiced_frames,
+                "Reverses the operation of this file and selects "
+                "unvoiced frames instead");
+    po.Register("padding", &padding,
+                "Ignore frames at a boundary of this many frames");
+
+    po.Read(argc, argv);
+
+    if (po.NumArgs() != 3) {
+      po.PrintUsage();
+      exit(1);
+    }
+
+    std::string feat_rspecifier = po.GetArg(1),
+        vad_rspecifier = po.GetArg(2),
+        feat_wspecifier = po.GetArg(3);
+
+    SequentialBaseFloatMatrixReader feat_reader(feat_rspecifier);
+    RandomAccessInt32VectorReader vad_reader(vad_rspecifier);
+    BaseFloatMatrixWriter feat_writer(feat_wspecifier);
+
+    int32 num_done = 0, num_err = 0;
+    long long num_frames = 0, num_select = 0;
+
+    for (;!feat_reader.Done(); feat_reader.Next()) {
+      std::string utt = feat_reader.Key();
+      const Matrix<BaseFloat> &feat = feat_reader.Value();
+      if (feat.NumRows() == 0) {
+        KALDI_WARN << "Empty feature matrix for utterance " << utt;
+        num_err++;
+        continue;
+      }
+      if (!vad_reader.HasKey(utt)) {
+        KALDI_WARN << "No VAD input found for utterance " << utt;
+        num_err++;
+        continue;
+      }
+      const std::vector<int32> &voiced = vad_reader.Value(utt);
+
+      if (std::abs(static_cast<int32>(feat.NumRows()) -
+                   static_cast<int32>(voiced.size())) > 1) {
+        KALDI_WARN << "Mismatch in number for frames " << feat.NumRows()
+                   << " for features and VAD " << voiced.size()
+                   << ", for utterance " << utt;
+        num_err++;
+        continue;
+      }
+      int32 dim = 0;
+      for (std::vector<int32>::const_iterator it = voiced.begin();
+           it != voiced.end(); ++it) {
+        if (!select_unvoiced_frames) {
+          if (*it != 0)
+            dim++;
+        } else {
+          if (*it == 0)
+            dim++;
+        }
+      }
+
+      if (dim == 0) {
+        if (select_unvoiced_frames) {
+          KALDI_WARN << "No unvoiced frames found for utterance " << utt;
+        } else {
+          KALDI_WARN << "No voiced frames found in utterance " << utt;
+        }
+        num_err++;
+        continue;
+      }
+      Matrix<BaseFloat> voiced_feat(dim, feat.NumCols());
+      int32 index = 0;
+      bool voiced_state = false;
+      int32 start_idx = 0, end_idx = 0;
+      for (int32 i = 0; i < std::min(static_cast<int32>(feat.NumRows()),
+                                     static_cast<int32>(voiced.size())); i++) {
+        if ((!voiced_state && voiced[i] != 0) ||
+            (voiced_state && voiced[i] == 0)) {
+          // Reached voiced state from unvoiced state
+          // or unvoiced state from voiced state
+          end_idx = i;
+          if ((!voiced_state && select_unvoiced_frames) ||
+              (voiced_state && !select_unvoiced_frames)) {
+            if (end_idx - start_idx > 2 * padding &&
+                start_idx + padding < feat.NumRows()) {
+              KALDI_ASSERT(index < voiced_feat.NumRows() &&
+                           index + end_idx - start_idx - 2 * padding <=
+                           voiced_feat.NumRows());
+              SubMatrix<BaseFloat> src_feat(feat, start_idx + padding,
+                  end_idx - start_idx - 2 * padding, 0, feat.NumCols());
+              SubMatrix<BaseFloat> dst_feat(voiced_feat, index,
+                  end_idx - start_idx - 2 * padding, 0, feat.NumCols());
+              dst_feat.CopyFromMat(src_feat);
+              index += end_idx - start_idx - 2 * padding;
+            }
+          }
+          start_idx = i;
+          voiced_state = !voiced_state;
+        }
+      }
+
+      if (!voiced_state && select_unvoiced_frames) {
+        end_idx = std::min(static_cast<int32>(feat.NumRows()),
+                           static_cast<int32>(voiced.size()));
+        if (end_idx - start_idx > 2 * padding &&
+            start_idx + padding < feat.NumRows()) {
+          KALDI_ASSERT(index < voiced_feat.NumRows() &&
+                       index + end_idx - start_idx - 2 * padding <=
+                       voiced_feat.NumRows());
+          SubMatrix<BaseFloat> src_feat(feat, start_idx + padding,
+              end_idx - start_idx - 2 * padding, 0, feat.NumCols());
+          SubMatrix<BaseFloat> dst_feat(voiced_feat, index,
+              end_idx - start_idx - 2 * padding, 0, feat.NumCols());
+          dst_feat.CopyFromMat(src_feat);
+        }
+      }
+
+      feat_writer.Write(utt, voiced_feat);
+      num_select += voiced_feat.NumRows();
+      num_frames += feat.NumRows();
+
+      num_done++;
+    }
+
+    KALDI_LOG << "Done selecting " << num_select << " voiced frames"
+              << " out of " << num_frames << " frames; processed "
+              << num_done << " utterances, "
+              << num_err << " had errors.";
+    return (num_done != 0 ? 0 : 1);
+  } catch(const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}
+
+
diff --git a/src/vadbin/vector-extract-dims.cc b/src/vadbin/vector-extract-dims.cc
new file mode 100644
index 00000000000..5ba1038bc7c
--- /dev/null
+++ b/src/vadbin/vector-extract-dims.cc
@@ -0,0 +1,136 @@
+// vadbin/vector-extract-dims.cc
+
+// Copyright 2015   Vimal Manohar (Johns Hopkins University)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "matrix/kaldi-vector.h"
+#include "transform/transform-common.h"
+
+
+int main(int argc, char *argv[]) {
+  try {
+    using namespace kaldi;
+
+    const char *usage =
+        "Extract only some dimensions of a vector\n"
+        "\n"
+        "Usage: vector-extract-dims [options] "
+        "<vector-rspecifier> <mask-rspecifier> <vector-wspecifier>\n"
+        " e.g.: vector-extract-dims ark:2.vec ark:2.mask ark,t:-\n"
+        "see also: extract-rows, select-voiced-frames\n";
+
+    bool select_unmasked_dims = false;
+
+    ParseOptions po(usage);
+    po.Register("select-unmasked-dims", &select_unmasked_dims,
+                "Reverses the operation of this file and selects "
+                "dimensions that have 0 in the mask");
+
+    po.Read(argc, argv);
+
+    if (po.NumArgs() != 3) {
+      po.PrintUsage();
+      exit(1);
+    }
+
+    std::string vector_rspecifier = po.GetArg(1),
+        mask_rspecifier = po.GetArg(2),
+        vector_wspecifier = po.GetArg(3);
+
+    SequentialBaseFloatVectorReader vector_reader(vector_rspecifier);
+    RandomAccessBaseFloatVectorReader mask_reader(mask_rspecifier);
+    BaseFloatVectorWriter vector_writer(vector_wspecifier);
+
+    int32 num_done = 0, num_err = 0;
+    long long num_dims = 0, num_select = 0;
+
+    for (;!vector_reader.Done(); vector_reader.Next()) {
+      std::string utt = vector_reader.Key();
+      const Vector<BaseFloat> &vec = vector_reader.Value();
+      if (vec.Dim() == 0) {
+        KALDI_WARN << "Empty input vector for utterance " << utt;
+        num_err++;
+        continue;
+      }
+      if (!mask_reader.HasKey(utt)) {
+        KALDI_WARN << "No mask input found for utterance " << utt;
+        num_err++;
+        continue;
+      }
+      const Vector<BaseFloat> &mask = mask_reader.Value(utt);
+
+      if (vec.Dim() != mask.Dim()) {
+        KALDI_WARN << "Mismatch in number for dimensions " << vec.Dim()
+                   << " for vector and mask " << mask.Dim()
+                   << ", for utterance " << utt;
+        num_err++;
+        continue;
+      }
+
+      int32 dim = 0;
+      for (int32 i = 0; i < mask.Dim(); i++)
+        if (!select_unmasked_dims) {
+          if (mask(i) != 0.0)
+            dim++;
+        } else {
+          if (mask(i) == 0.0)
+            dim++;
+        }
+
+      if (dim == 0) {
+        KALDI_WARN << "No dimensions were selected for utterance "
+                   << utt;
+        num_err++;
+        continue;
+      }
+
+      Vector<BaseFloat> masked_vec(dim);
+
+      int32 index = 0;
+      for (int32 i = 0; i < vec.Dim(); i++) {
+        if (!select_unmasked_dims) {
+          if (mask(i) != 0.0) {
+            KALDI_ASSERT(mask(i) == 1.0);  // should be zero or one.
+            masked_vec(index) = vec(i);
+            index++;
+          }
+        } else {
+          if (mask(i) == 0.0) {
+            masked_vec(index) = vec(i);
+            index++;
+          }
+        }
+      }
+      KALDI_ASSERT(index == dim);
+      vector_writer.Write(utt, masked_vec);
+      num_done++;
+      num_select += dim;
+      num_dims += vec.Dim();
+    }
+
+    KALDI_LOG << "Done selecting " << num_select << " unmasked dimensions "
+              << "out of " << num_dims << " total dimensions; processed "
+              << num_done << " utterances, "
+              << num_err << " had errors.";
+    return (num_done != 0 ? 0 : 1);
+  } catch(const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}
+
+
diff --git a/tools/Makefile b/tools/Makefile
index 22e91e8677b..36c33e16e71 100644
--- a/tools/Makefile
+++ b/tools/Makefile
@@ -3,9 +3,9 @@
 CXX = g++
 # CXX = clang++  # Uncomment this line to build with Clang.
 
-OPENFST_VERSION = 1.3.4
+# OPENFST_VERSION = 1.3.4
 # Uncomment the next line to build with OpenFst-1.4.1.
-# OPENFST_VERSION = 1.4.1
+OPENFST_VERSION = 1.4.1
 
 # Note: OpenFst >= 1.4 requires C++11 support, hence you will need to use a
 # relatively recent C++ compiler, e.g. gcc >= 4.6, clang >= 3.0.