diff --git a/.gitignore b/.gitignore index 5764bfe22..ed66c79d1 100644 --- a/.gitignore +++ b/.gitignore @@ -77,6 +77,8 @@ GSYMS /egs/*/*/plp /egs/*/*/exp /egs/*/*/data +/egs/*/*/wav +/egs/*/*/enhan # /tools/ /tools/pocolm/ @@ -149,3 +151,8 @@ GSYMS /tools/cub-1.8.0/ /tools/cub /tools/python/ + +# These CMakeLists.txt files are all genareted on the fly at the moment. +# They are added here to avoid accidently checkin. +/src/**/CMakeLists.txt +/build* diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 000000000..748d88a35 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,195 @@ +cmake_minimum_required(VERSION 3.5) +project(kaldi) + +set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake;${CMAKE_MODULE_PATH}") +include(GNUInstallDirs) +include(Utils) +include(third_party/get_third_party) + +message(STATUS "Running gen_cmake_skeleton.py") +execute_process(COMMAND python + "${CMAKE_CURRENT_SOURCE_DIR}/cmake/gen_cmake_skeleton.py" + "${CMAKE_CURRENT_SOURCE_DIR}/src" + "--quiet" +) + +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_EXTENSIONS OFF) +set(CMAKE_INSTALL_MESSAGE LAZY) # hide "-- Up-to-date: ..." +if(BUILD_SHARED_LIBS) + set(CMAKE_POSITION_INDEPENDENT_CODE ON) + if(WIN32) + set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) + message(FATAL_ERROR "DLL is not supported currently") + elseif(APPLE) + set(CMAKE_INSTALL_RPATH "@loader_path") + else() + set(CMAKE_INSTALL_RPATH "$ORIGIN;$ORIGIN/../lib") + endif() +endif() + +set(MATHLIB "OpenBLAS" CACHE STRING "OpenBLAS|MKL|Accelerate") +option(KALDI_BUILD_EXE "If disabled, will make add_kaldi_executable a no-op" ON) +option(KALDI_BUILD_TEST "If disabled, will make add_kaldi_test_executable a no-op" ON) +option(KALDI_USE_PATCH_NUMBER "Use MAJOR.MINOR.PATCH format, otherwise MAJOR.MINOR" OFF) + +link_libraries(${CMAKE_DL_LIBS}) + +find_package(Threads) +link_libraries(Threads::Threads) + +if(MATHLIB STREQUAL "OpenBLAS") + set(BLA_VENDOR "OpenBLAS") + find_package(LAPACK REQUIRED) + add_definitions(-DHAVE_CLAPACK=1) + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/tools/CLAPACK) + link_libraries(${BLAS_LIBRARIES} ${LAPACK_LIBRARIES}) +elseif(MATHLIB STREQUAL "MKL") + set(BLA_VENDOR "Intel10_64lp") + # find_package(BLAS REQUIRED) + normalize_env_path(ENV{MKLROOT}) + find_package(LAPACK REQUIRED) + add_definitions(-DHAVE_MKL=1) + include_directories($ENV{MKLROOT}/include) # TODO: maybe not use env, idk, find_package doesnt handle includes... + link_libraries(${BLAS_LIBRARIES} ${LAPACK_LIBRARIES}) +elseif(MATHLIB STREQUAL "Accelerate") + set(BLA_VENDOR "Apple") + find_package(BLAS REQUIRED) + find_package(LAPACK REQUIRED) + add_definitions(-DHAVE_CLAPACK=1) + link_libraries(${BLAS_LIBRARIES} ${LAPACK_LIBRARIES}) +else() + message(FATAL_ERROR "${MATHLIB} is not tested and supported, you are on your own now.") +endif() + +if(MSVC) + # Added in source, but we actually should do it in build script, whatever... + # add_definitions(-DWIN32_LEAN_AND_MEAN=1) + + add_compile_options(/permissive- /FS /wd4819 /EHsc /bigobj) + + # some warnings related with fst + add_compile_options(/wd4018 /wd4244 /wd4267 /wd4291 /wd4305) + + set(CUDA_USE_STATIC_CUDA_RUNTIME OFF CACHE INTERNAL "") + if(NOT DEFINED ENV{CUDAHOSTCXX}) + set(ENV{CUDAHOSTCXX} ${CMAKE_CXX_COMPILER}) + endif() + if(NOT DEFINED CUDA_HOST_COMPILER) + set(CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER}) + endif() +endif() + +find_package(CUDA) +if(CUDA_FOUND) + set(CUB_ROOT_DIR "${PROJECT_SOURCE_DIR}/tools/cub") + + set(CUDA_PROPAGATE_HOST_FLAGS ON) + set(KALDI_CUDA_NVCC_FLAGS "--default-stream=per-thread;-std=c++${CMAKE_CXX_STANDARD}") + if(MSVC) + list(APPEND KALDI_CUDA_NVCC_FLAGS "-Xcompiler /permissive-,/FS,/wd4819,/EHsc,/bigobj") + list(APPEND KALDI_CUDA_NVCC_FLAGS "-Xcompiler /wd4018,/wd4244,/wd4267,/wd4291,/wd4305") + if(BUILD_SHARED_LIBS) + list(APPEND CUDA_NVCC_FLAGS_RELEASE -Xcompiler /MD) + list(APPEND CUDA_NVCC_FLAGS_DEBUG -Xcompiler /MDd) + endif() + else() + # list(APPEND KALDI_CUDA_NVCC_FLAGS "-Xcompiler -std=c++${CMAKE_CXX_STANDARD}") + list(APPEND KALDI_CUDA_NVCC_FLAGS "-Xcompiler -fPIC") + endif() + set(CUDA_NVCC_FLAGS ${KALDI_CUDA_NVCC_FLAGS} ${CUDA_NVCC_FLAGS}) + + add_definitions(-DHAVE_CUDA=1) + add_definitions(-DCUDA_API_PER_THREAD_DEFAULT_STREAM=1) + include_directories(${CUDA_INCLUDE_DIRS}) + link_libraries( + ${CUDA_LIBRARIES} + ${CUDA_CUDA_LIBRARY} + ${CUDA_CUBLAS_LIBRARIES} + ${CUDA_CUFFT_LIBRARIES} + ${CUDA_curand_LIBRARY} + ${CUDA_cusolver_LIBRARY} + ${CUDA_cusparse_LIBRARY}) + + find_package(NvToolExt REQUIRED) + include_directories(${NvToolExt_INCLUDE_DIR}) + link_libraries(${NvToolExt_LIBRARIES}) + + find_package(CUB REQUIRED) + include_directories(${CUB_INCLUDE_DIR}) +endif() + +add_definitions(-DKALDI_NO_PORTAUDIO=1) + +include(VersionHelper) +get_version() # this will set KALDI_VERSION and KALDI_PATCH_NUMBER +if(${KALDI_USE_PATCH_NUMBER}) + set(KALDI_VERSION "${KALDI_VERSION}.${KALDI_PATCH_NUMBER}") +endif() + +get_third_party(openfst) +set(OPENFST_ROOT_DIR ${CMAKE_CURRENT_BINARY_DIR}/openfst) +include(third_party/openfst_lib_target) +link_libraries(fst) + +# add all native libraries +add_subdirectory(src/base) # NOTE, we need to patch the target with version from outside +set_property(TARGET kaldi-base PROPERTY COMPILE_DEFINITIONS "KALDI_VERSION=\"${KALDI_VERSION}\"") +add_subdirectory(src/matrix) +add_subdirectory(src/cudamatrix) +add_subdirectory(src/util) +add_subdirectory(src/feat) +add_subdirectory(src/tree) +add_subdirectory(src/gmm) +add_subdirectory(src/transform) +add_subdirectory(src/sgmm2) +add_subdirectory(src/fstext) +add_subdirectory(src/hmm) +add_subdirectory(src/lm) +add_subdirectory(src/decoder) +add_subdirectory(src/lat) +add_subdirectory(src/nnet) +add_subdirectory(src/nnet2) +add_subdirectory(src/nnet3) +add_subdirectory(src/rnnlm) +add_subdirectory(src/chain) +add_subdirectory(src/ivector) +add_subdirectory(src/online) +add_subdirectory(src/online2) +add_subdirectory(src/kws) + +add_subdirectory(src/itf) + +# add all cuda libraries +if(CUDA_FOUND) + add_subdirectory(src/cudafeat) + add_subdirectory(src/cudadecoder) +endif() + +# add all native executables +add_subdirectory(src/gmmbin) +add_subdirectory(src/featbin) +add_subdirectory(src/onlinebin) + +# add all cuda executables +if(CUDA_FOUND) + add_subdirectory(src/cudafeatbin) + add_subdirectory(src/cudadecoderbin) +endif() + +include(CMakePackageConfigHelpers) +# maybe we should put this into subfolder? +configure_package_config_file( + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/kaldi-config.cmake.in + ${CMAKE_BINARY_DIR}/cmake/kaldi-config.cmake + INSTALL_DESTINATION lib/cmake/kaldi +) +write_basic_package_version_file( + ${CMAKE_BINARY_DIR}/cmake/kaldi-config-version.cmake + VERSION ${KALDI_VERSION} + COMPATIBILITY AnyNewerVersion +) +install(FILES ${CMAKE_BINARY_DIR}/cmake/kaldi-config.cmake ${CMAKE_BINARY_DIR}/cmake/kaldi-config-version.cmake + DESTINATION lib/cmake/kaldi +) +install(EXPORT kaldi-targets DESTINATION ${CMAKE_INSTALL_PREFIX}/lib/cmake/kaldi) diff --git a/INSTALL b/INSTALL index 2dbf31811..7beb79a73 100644 --- a/INSTALL +++ b/INSTALL @@ -1,9 +1,16 @@ This is the official Kaldi INSTALL. Look also at INSTALL.md for the git mirror installation. -[for native Windows install, see windows/INSTALL] +[Option 1 in the following does not apply to native Windows install, see windows/INSTALL or following Option 2] -(1) -go to tools/ and follow INSTALL instructions there. +Option 1 (bash + makefile): -(2) -go to src/ and follow INSTALL instructions there. + Steps: + (1) + go to tools/ and follow INSTALL instructions there. + (2) + go to src/ and follow INSTALL instructions there. + +Option 2 (cmake): + + Go to cmake/ and follow INSTALL.md instructions there. + Note, it may not be well tested and some features are missing currently. diff --git a/cmake/FindBLAS.cmake b/cmake/FindBLAS.cmake new file mode 100644 index 000000000..67676110c --- /dev/null +++ b/cmake/FindBLAS.cmake @@ -0,0 +1,816 @@ +# Distributed under the OSI-approved BSD 3-Clause License. See accompanying +# file Copyright.txt or https://cmake.org/licensing for details. + +#[=======================================================================[.rst: +FindBLAS +-------- + +Find Basic Linear Algebra Subprograms (BLAS) library + +This module finds an installed Fortran library that implements the +BLAS linear-algebra interface (see http://www.netlib.org/blas/). The +list of libraries searched for is taken from the ``autoconf`` macro file, +``acx_blas.m4`` (distributed at +http://ac-archive.sourceforge.net/ac-archive/acx_blas.html). + +Input Variables +^^^^^^^^^^^^^^^ + +The following variables may be set to influence this module's behavior: + +``BLA_STATIC`` + if ``ON`` use static linkage + +``BLA_VENDOR`` + If set, checks only the specified vendor, if not set checks all the + possibilities. List of vendors valid in this module: + + * Goto + * OpenBLAS + * FLAME + * ATLAS PhiPACK + * CXML + * DXML + * SunPerf + * SCSL + * SGIMATH + * IBMESSL + * Intel10_32 (intel mkl v10 32 bit) + * Intel10_64lp (intel mkl v10+ 64 bit, threaded code, lp64 model) + * Intel10_64lp_seq (intel mkl v10+ 64 bit, sequential code, lp64 model) + * Intel10_64ilp (intel mkl v10+ 64 bit, threaded code, ilp64 model) + * Intel10_64ilp_seq (intel mkl v10+ 64 bit, sequential code, ilp64 model) + * Intel (obsolete versions of mkl 32 and 64 bit) + * ACML + * ACML_MP + * ACML_GPU + * Apple + * NAS + * Generic + +``BLA_F95`` + if ``ON`` tries to find the BLAS95 interfaces + +``BLA_PREFER_PKGCONFIG`` + if set ``pkg-config`` will be used to search for a BLAS library first + and if one is found that is preferred + +Result Variables +^^^^^^^^^^^^^^^^ + +This module defines the following variables: + +``BLAS_FOUND`` + library implementing the BLAS interface is found +``BLAS_LINKER_FLAGS`` + uncached list of required linker flags (excluding ``-l`` and ``-L``). +``BLAS_LIBRARIES`` + uncached list of libraries (using full path name) to link against + to use BLAS (may be empty if compiler implicitly links BLAS) +``BLAS95_LIBRARIES`` + uncached list of libraries (using full path name) to link against + to use BLAS95 interface +``BLAS95_FOUND`` + library implementing the BLAS95 interface is found + +.. note:: + + C or CXX must be enabled to use Intel Math Kernel Library (MKL) + + For example, to use Intel MKL libraries and/or Intel compiler: + + .. code-block:: cmake + + set(BLA_VENDOR Intel10_64lp) + find_package(BLAS) + +Hints +^^^^^ + +Set ``MKLROOT`` environment variable to a directory that contains an MKL +installation. + +#]=======================================================================] + +include(CheckFunctionExists) +include(CheckFortranFunctionExists) +include(CMakePushCheckState) +include(FindPackageHandleStandardArgs) +cmake_push_check_state() +set(CMAKE_REQUIRED_QUIET ${BLAS_FIND_QUIETLY}) + +set(_blas_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES}) + +# Check the language being used +if( NOT (CMAKE_C_COMPILER_LOADED OR CMAKE_CXX_COMPILER_LOADED OR CMAKE_Fortran_COMPILER_LOADED) ) + if(BLAS_FIND_REQUIRED) + message(FATAL_ERROR "FindBLAS requires Fortran, C, or C++ to be enabled.") + else() + message(STATUS "Looking for BLAS... - NOT found (Unsupported languages)") + return() + endif() +endif() + +if(BLA_PREFER_PKGCONFIG) + find_package(PkgConfig) + pkg_check_modules(PKGC_BLAS blas) + if(PKGC_BLAS_FOUND) + set(BLAS_FOUND ${PKGC_BLAS_FOUND}) + set(BLAS_LIBRARIES "${PKGC_BLAS_LINK_LIBRARIES}") + return() + endif() +endif() + +macro(Check_Fortran_Libraries LIBRARIES _prefix _name _flags _list _thread) + # This macro checks for the existence of the combination of fortran libraries + # given by _list. If the combination is found, this macro checks (using the + # Check_Fortran_Function_Exists macro) whether can link against that library + # combination using the name of a routine given by _name using the linker + # flags given by _flags. If the combination of libraries is found and passes + # the link test, LIBRARIES is set to the list of complete library paths that + # have been found. Otherwise, LIBRARIES is set to FALSE. + + # N.B. _prefix is the prefix applied to the names of all cached variables that + # are generated internally and marked advanced by this macro. + + set(_libdir ${ARGN}) + + set(_libraries_work TRUE) + set(${LIBRARIES}) + set(_combined_name) + if (NOT _libdir) + if (WIN32) + set(_libdir ENV LIB) + elseif (APPLE) + set(_libdir ENV DYLD_LIBRARY_PATH) + else () + set(_libdir ENV LD_LIBRARY_PATH) + endif () + endif () + + list(APPEND _libdir "${CMAKE_C_IMPLICIT_LINK_DIRECTORIES}") + + foreach(_library ${_list}) + set(_combined_name ${_combined_name}_${_library}) + if(NOT "${_thread}" STREQUAL "") + set(_combined_name ${_combined_name}_thread) + endif() + if(_libraries_work) + if (BLA_STATIC) + if (WIN32) + set(CMAKE_FIND_LIBRARY_SUFFIXES .lib ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif () + if (APPLE) + set(CMAKE_FIND_LIBRARY_SUFFIXES .lib ${CMAKE_FIND_LIBRARY_SUFFIXES}) + else () + set(CMAKE_FIND_LIBRARY_SUFFIXES .a ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif () + else () + if (CMAKE_SYSTEM_NAME STREQUAL "Linux") + # for ubuntu's libblas3gf and liblapack3gf packages + set(CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES} .so.3gf) + endif () + endif () + find_library(${_prefix}_${_library}_LIBRARY + NAMES ${_library} + PATHS ${_libdir} + ) + mark_as_advanced(${_prefix}_${_library}_LIBRARY) + set(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY}) + set(_libraries_work ${${_prefix}_${_library}_LIBRARY}) + endif() + endforeach() + if(_libraries_work) + # Test this combination of libraries. + set(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}} ${_thread}) + # message("DEBUG: CMAKE_REQUIRED_LIBRARIES = ${CMAKE_REQUIRED_LIBRARIES}") + if (CMAKE_Fortran_COMPILER_LOADED) + check_fortran_function_exists("${_name}" ${_prefix}${_combined_name}_WORKS) + else() + check_function_exists("${_name}_" ${_prefix}${_combined_name}_WORKS) + endif() + set(CMAKE_REQUIRED_LIBRARIES) + set(_libraries_work ${${_prefix}${_combined_name}_WORKS}) + endif() + if(_libraries_work) + if("${_list}" STREQUAL "") + set(${LIBRARIES} "${LIBRARIES}-PLACEHOLDER-FOR-EMPTY-LIBRARIES") + else() + set(${LIBRARIES} ${${LIBRARIES}} ${_thread}) # for static link + endif() + else() + set(${LIBRARIES} FALSE) + endif() + #message("DEBUG: ${LIBRARIES} = ${${LIBRARIES}}") +endmacro() + +set(BLAS_LINKER_FLAGS) +set(BLAS_LIBRARIES) +set(BLAS95_LIBRARIES) +if (NOT $ENV{BLA_VENDOR} STREQUAL "") + set(BLA_VENDOR $ENV{BLA_VENDOR}) +else () + if(NOT BLA_VENDOR) + set(BLA_VENDOR "All") + endif() +endif () + +if (BLA_VENDOR STREQUAL "All") + if(NOT BLAS_LIBRARIES) + # Implicitly linked BLAS libraries + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "" + "" + ) + endif() +endif () + +#BLAS in intel mkl 10+ library? (em64t 64bit) +if (BLA_VENDOR MATCHES "Intel" OR BLA_VENDOR STREQUAL "All") + if (NOT BLAS_LIBRARIES) + + # System-specific settings + if (WIN32) + if (BLA_STATIC) + set(BLAS_mkl_DLL_SUFFIX "") + else() + set(BLAS_mkl_DLL_SUFFIX "_dll") + endif() + else() + # Switch to GNU Fortran support layer if needed (but not on Apple, where MKL does not provide it) + if(CMAKE_Fortran_COMPILER_LOADED AND CMAKE_Fortran_COMPILER_ID STREQUAL "GNU" AND NOT APPLE) + set(BLAS_mkl_INTFACE "gf") + set(BLAS_mkl_THREADING "gnu") + set(BLAS_mkl_OMP "gomp") + else() + set(BLAS_mkl_INTFACE "intel") + set(BLAS_mkl_THREADING "intel") + set(BLAS_mkl_OMP "iomp5") + endif() + set(BLAS_mkl_LM "-lm") + set(BLAS_mkl_LDL "-ldl") + endif() + + if (BLA_VENDOR MATCHES "_64ilp") + set(BLAS_mkl_ILP_MODE "ilp64") + else () + set(BLAS_mkl_ILP_MODE "lp64") + endif () + + if (CMAKE_C_COMPILER_LOADED OR CMAKE_CXX_COMPILER_LOADED) + if(BLAS_FIND_QUIETLY OR NOT BLAS_FIND_REQUIRED) + find_package(Threads) + else() + find_package(Threads REQUIRED) + endif() + + set(BLAS_SEARCH_LIBS "") + + if(BLA_F95) + set(BLAS_mkl_SEARCH_SYMBOL sgemm_f95) + set(_LIBRARIES BLAS95_LIBRARIES) + if (WIN32) + # Find the main file (32-bit or 64-bit) + set(BLAS_SEARCH_LIBS_WIN_MAIN "") + if (BLA_VENDOR STREQUAL "Intel10_32" OR BLA_VENDOR STREQUAL "All") + list(APPEND BLAS_SEARCH_LIBS_WIN_MAIN + "mkl_blas95${BLAS_mkl_DLL_SUFFIX} mkl_intel_c${BLAS_mkl_DLL_SUFFIX}") + endif() + if (BLA_VENDOR MATCHES "^Intel10_64i?lp" OR BLA_VENDOR STREQUAL "All") + list(APPEND BLAS_SEARCH_LIBS_WIN_MAIN + "mkl_blas95_${BLAS_mkl_ILP_MODE}${BLAS_mkl_DLL_SUFFIX} mkl_intel_${BLAS_mkl_ILP_MODE}${BLAS_mkl_DLL_SUFFIX}") + endif () + + # Add threading/sequential libs + set(BLAS_SEARCH_LIBS_WIN_THREAD "") + if (BLA_VENDOR MATCHES "_seq$" OR BLA_VENDOR STREQUAL "All") + list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD + "mkl_sequential${BLAS_mkl_DLL_SUFFIX}") + endif() + if (NOT BLA_VENDOR MATCHES "_seq$" OR BLA_VENDOR STREQUAL "All") + # old version + list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD + "libguide40 mkl_intel_thread${BLAS_mkl_DLL_SUFFIX}") + # mkl >= 10.3 + list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD + "libiomp5md mkl_intel_thread${BLAS_mkl_DLL_SUFFIX}") + endif() + + # Cartesian product of the above + foreach (MAIN ${BLAS_SEARCH_LIBS_WIN_MAIN}) + foreach (THREAD ${BLAS_SEARCH_LIBS_WIN_THREAD}) + list(APPEND BLAS_SEARCH_LIBS + "${MAIN} ${THREAD} mkl_core${BLAS_mkl_DLL_SUFFIX}") + endforeach() + endforeach() + else () + if (BLA_VENDOR STREQUAL "Intel10_32" OR BLA_VENDOR STREQUAL "All") + # old version + list(APPEND BLAS_SEARCH_LIBS + "mkl_blas95 mkl_${BLAS_mkl_INTFACE} mkl_${BLAS_mkl_THREADING}_thread mkl_core guide") + + # mkl >= 10.3 + list(APPEND BLAS_SEARCH_LIBS + "mkl_blas95 mkl_${BLAS_mkl_INTFACE} mkl_${BLAS_mkl_THREADING}_thread mkl_core ${BLAS_mkl_OMP}") + endif () + if (BLA_VENDOR MATCHES "^Intel10_64i?lp$" OR BLA_VENDOR STREQUAL "All") + # old version + list(APPEND BLAS_SEARCH_LIBS + "mkl_blas95 mkl_${BLAS_mkl_INTFACE}_${BLAS_mkl_ILP_MODE} mkl_${BLAS_mkl_THREADING}_thread mkl_core guide") + + # mkl >= 10.3 + list(APPEND BLAS_SEARCH_LIBS + "mkl_blas95_${BLAS_mkl_ILP_MODE} mkl_${BLAS_mkl_INTFACE}_${BLAS_mkl_ILP_MODE} mkl_${BLAS_mkl_THREADING}_thread mkl_core ${BLAS_mkl_OMP}") + endif () + if (BLA_VENDOR MATCHES "^Intel10_64i?lp_seq$" OR BLA_VENDOR STREQUAL "All") + list(APPEND BLAS_SEARCH_LIBS + "mkl_blas95_${BLAS_mkl_ILP_MODE} mkl_${BLAS_mkl_INTFACE}_${BLAS_mkl_ILP_MODE} mkl_sequential mkl_core") + endif () + endif () + else () + set(BLAS_mkl_SEARCH_SYMBOL sgemm) + set(_LIBRARIES BLAS_LIBRARIES) + if (WIN32) + # Find the main file (32-bit or 64-bit) + set(BLAS_SEARCH_LIBS_WIN_MAIN "") + if (BLA_VENDOR STREQUAL "Intel10_32" OR BLA_VENDOR STREQUAL "All") + list(APPEND BLAS_SEARCH_LIBS_WIN_MAIN + "mkl_intel_c${BLAS_mkl_DLL_SUFFIX}") + endif() + if (BLA_VENDOR MATCHES "^Intel10_64i?lp" OR BLA_VENDOR STREQUAL "All") + list(APPEND BLAS_SEARCH_LIBS_WIN_MAIN + "mkl_intel_${BLAS_mkl_ILP_MODE}${BLAS_mkl_DLL_SUFFIX}") + endif () + + # Add threading/sequential libs + set(BLAS_SEARCH_LIBS_WIN_THREAD "") + if (NOT BLA_VENDOR MATCHES "_seq$" OR BLA_VENDOR STREQUAL "All") + # old version + list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD + "libguide40 mkl_intel_thread${BLAS_mkl_DLL_SUFFIX}") + # mkl >= 10.3 + list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD + "libiomp5md mkl_intel_thread${BLAS_mkl_DLL_SUFFIX}") + endif() + if (BLA_VENDOR MATCHES "_seq$" OR BLA_VENDOR STREQUAL "All") + list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD + "mkl_sequential${BLAS_mkl_DLL_SUFFIX}") + endif() + + # Cartesian product of the above + foreach (MAIN ${BLAS_SEARCH_LIBS_WIN_MAIN}) + foreach (THREAD ${BLAS_SEARCH_LIBS_WIN_THREAD}) + list(APPEND BLAS_SEARCH_LIBS + "${MAIN} ${THREAD} mkl_core${BLAS_mkl_DLL_SUFFIX}") + endforeach() + endforeach() + else () + if (BLA_VENDOR STREQUAL "Intel10_32" OR BLA_VENDOR STREQUAL "All") + # old version + list(APPEND BLAS_SEARCH_LIBS + "mkl_${BLAS_mkl_INTFACE} mkl_${BLAS_mkl_THREADING}_thread mkl_core guide") + + # mkl >= 10.3 + list(APPEND BLAS_SEARCH_LIBS + "mkl_${BLAS_mkl_INTFACE} mkl_${BLAS_mkl_THREADING}_thread mkl_core ${BLAS_mkl_OMP}") + endif () + if (BLA_VENDOR MATCHES "^Intel10_64i?lp$" OR BLA_VENDOR STREQUAL "All") + # old version + list(APPEND BLAS_SEARCH_LIBS + "mkl_${BLAS_mkl_INTFACE}_${BLAS_mkl_ILP_MODE} mkl_${BLAS_mkl_THREADING}_thread mkl_core guide") + + # mkl >= 10.3 + list(APPEND BLAS_SEARCH_LIBS + "mkl_${BLAS_mkl_INTFACE}_${BLAS_mkl_ILP_MODE} mkl_${BLAS_mkl_THREADING}_thread mkl_core ${BLAS_mkl_OMP}") + endif () + if (BLA_VENDOR MATCHES "^Intel10_64i?lp_seq$" OR BLA_VENDOR STREQUAL "All") + list(APPEND BLAS_SEARCH_LIBS + "mkl_${BLAS_mkl_INTFACE}_${BLAS_mkl_ILP_MODE} mkl_sequential mkl_core") + endif () + + #older vesions of intel mkl libs + if (BLA_VENDOR STREQUAL "Intel" OR BLA_VENDOR STREQUAL "All") + list(APPEND BLAS_SEARCH_LIBS + "mkl") + list(APPEND BLAS_SEARCH_LIBS + "mkl_ia32") + list(APPEND BLAS_SEARCH_LIBS + "mkl_em64t") + endif () + endif () + endif () + + if (DEFINED ENV{MKLROOT}) + if (BLA_VENDOR STREQUAL "Intel10_32") + set(_BLAS_MKLROOT_LIB_DIR "$ENV{MKLROOT}/lib/ia32") + elseif (BLA_VENDOR MATCHES "^Intel10_64i?lp$" OR BLA_VENDOR MATCHES "^Intel10_64i?lp_seq$") + set(_BLAS_MKLROOT_LIB_DIR "$ENV{MKLROOT}/lib/intel64") + endif () + endif () + if (_BLAS_MKLROOT_LIB_DIR) + if (WIN32) + string(APPEND _BLAS_MKLROOT_LIB_DIR "_win") + elseif (APPLE) + string(APPEND _BLAS_MKLROOT_LIB_DIR "_mac") + else () + string(APPEND _BLAS_MKLROOT_LIB_DIR "_lin") + endif () + endif () + + foreach (IT ${BLAS_SEARCH_LIBS}) + string(REPLACE " " ";" SEARCH_LIBS ${IT}) + if (NOT ${_LIBRARIES}) + check_fortran_libraries( + ${_LIBRARIES} + BLAS + ${BLAS_mkl_SEARCH_SYMBOL} + "" + "${SEARCH_LIBS}" + "${CMAKE_THREAD_LIBS_INIT};${BLAS_mkl_LM};${BLAS_mkl_LDL}" + "${_BLAS_MKLROOT_LIB_DIR}" + ) + endif () + endforeach () + + endif () + unset(BLAS_mkl_ILP_MODE) + unset(BLAS_mkl_INTFACE) + unset(BLAS_mkl_THREADING) + unset(BLAS_mkl_OMP) + unset(BLAS_mkl_DLL_SUFFIX) + unset(BLAS_mkl_LM) + unset(BLAS_mkl_LDL) + endif () +endif () + +if(BLA_F95) + find_package_handle_standard_args(BLAS REQUIRED_VARS BLAS95_LIBRARIES) + set(BLAS95_FOUND ${BLAS_FOUND}) + if(BLAS_FOUND) + set(BLAS_LIBRARIES "${BLAS95_LIBRARIES}") + endif() +endif() + +if (BLA_VENDOR STREQUAL "Goto" OR BLA_VENDOR STREQUAL "All") + if(NOT BLAS_LIBRARIES) + # gotoblas (http://www.tacc.utexas.edu/tacc-projects/gotoblas2) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "goto2" + "" + ) + endif() +endif () + +if (BLA_VENDOR STREQUAL "OpenBLAS" OR BLA_VENDOR STREQUAL "All") + if(NOT BLAS_LIBRARIES) + # OpenBLAS (http://www.openblas.net) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "openblas" + "" + ) + endif() + if(NOT BLAS_LIBRARIES) + find_package(Threads) + # OpenBLAS (http://www.openblas.net) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "openblas" + "${CMAKE_THREAD_LIBS_INIT}" + ) + endif() +endif () + +if (BLA_VENDOR STREQUAL "FLAME" OR BLA_VENDOR STREQUAL "All") + if(NOT BLAS_LIBRARIES) + # FLAME's blis library (https://github.com/flame/blis) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "blis" + "" + ) + endif() +endif () + +if (BLA_VENDOR STREQUAL "ATLAS" OR BLA_VENDOR STREQUAL "All") + if(NOT BLAS_LIBRARIES) + # BLAS in ATLAS library? (http://math-atlas.sourceforge.net/) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + dgemm + "" + "f77blas;atlas" + "" + ) + endif() +endif () + +# BLAS in PhiPACK libraries? (requires generic BLAS lib, too) +if (BLA_VENDOR STREQUAL "PhiPACK" OR BLA_VENDOR STREQUAL "All") + if(NOT BLAS_LIBRARIES) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "sgemm;dgemm;blas" + "" + ) + endif() +endif () + +# BLAS in Alpha CXML library? +if (BLA_VENDOR STREQUAL "CXML" OR BLA_VENDOR STREQUAL "All") + if(NOT BLAS_LIBRARIES) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "cxml" + "" + ) + endif() +endif () + +# BLAS in Alpha DXML library? (now called CXML, see above) +if (BLA_VENDOR STREQUAL "DXML" OR BLA_VENDOR STREQUAL "All") + if(NOT BLAS_LIBRARIES) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "dxml" + "" + ) + endif() +endif () + +# BLAS in Sun Performance library? +if (BLA_VENDOR STREQUAL "SunPerf" OR BLA_VENDOR STREQUAL "All") + if(NOT BLAS_LIBRARIES) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "-xlic_lib=sunperf" + "sunperf;sunmath" + "" + ) + if(BLAS_LIBRARIES) + set(BLAS_LINKER_FLAGS "-xlic_lib=sunperf") + endif() + endif() +endif () + +# BLAS in SCSL library? (SGI/Cray Scientific Library) +if (BLA_VENDOR STREQUAL "SCSL" OR BLA_VENDOR STREQUAL "All") + if(NOT BLAS_LIBRARIES) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "scsl" + "" + ) + endif() +endif () + +# BLAS in SGIMATH library? +if (BLA_VENDOR STREQUAL "SGIMATH" OR BLA_VENDOR STREQUAL "All") + if(NOT BLAS_LIBRARIES) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "complib.sgimath" + "" + ) + endif() +endif () + +# BLAS in IBM ESSL library? (requires generic BLAS lib, too) +if (BLA_VENDOR STREQUAL "IBMESSL" OR BLA_VENDOR STREQUAL "All") + if(NOT BLAS_LIBRARIES) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "essl;blas" + "" + ) + endif() +endif () + +#BLAS in acml library? +if (BLA_VENDOR MATCHES "ACML" OR BLA_VENDOR STREQUAL "All") + if( ((BLA_VENDOR STREQUAL "ACML") AND (NOT BLAS_ACML_LIB_DIRS)) OR + ((BLA_VENDOR STREQUAL "ACML_MP") AND (NOT BLAS_ACML_MP_LIB_DIRS)) OR + ((BLA_VENDOR STREQUAL "ACML_GPU") AND (NOT BLAS_ACML_GPU_LIB_DIRS)) + ) + # try to find acml in "standard" paths + if( WIN32 ) + file( GLOB _ACML_ROOT "C:/AMD/acml*/ACML-EULA.txt" ) + else() + file( GLOB _ACML_ROOT "/opt/acml*/ACML-EULA.txt" ) + endif() + if( WIN32 ) + file( GLOB _ACML_GPU_ROOT "C:/AMD/acml*/GPGPUexamples" ) + else() + file( GLOB _ACML_GPU_ROOT "/opt/acml*/GPGPUexamples" ) + endif() + list(GET _ACML_ROOT 0 _ACML_ROOT) + list(GET _ACML_GPU_ROOT 0 _ACML_GPU_ROOT) + if( _ACML_ROOT ) + get_filename_component( _ACML_ROOT ${_ACML_ROOT} PATH ) + if( SIZEOF_INTEGER EQUAL 8 ) + set( _ACML_PATH_SUFFIX "_int64" ) + else() + set( _ACML_PATH_SUFFIX "" ) + endif() + if( CMAKE_Fortran_COMPILER_ID STREQUAL "Intel" ) + set( _ACML_COMPILER32 "ifort32" ) + set( _ACML_COMPILER64 "ifort64" ) + elseif( CMAKE_Fortran_COMPILER_ID STREQUAL "SunPro" ) + set( _ACML_COMPILER32 "sun32" ) + set( _ACML_COMPILER64 "sun64" ) + elseif( CMAKE_Fortran_COMPILER_ID STREQUAL "PGI" ) + set( _ACML_COMPILER32 "pgi32" ) + if( WIN32 ) + set( _ACML_COMPILER64 "win64" ) + else() + set( _ACML_COMPILER64 "pgi64" ) + endif() + elseif( CMAKE_Fortran_COMPILER_ID STREQUAL "Open64" ) + # 32 bit builds not supported on Open64 but for code simplicity + # We'll just use the same directory twice + set( _ACML_COMPILER32 "open64_64" ) + set( _ACML_COMPILER64 "open64_64" ) + elseif( CMAKE_Fortran_COMPILER_ID STREQUAL "NAG" ) + set( _ACML_COMPILER32 "nag32" ) + set( _ACML_COMPILER64 "nag64" ) + else() + set( _ACML_COMPILER32 "gfortran32" ) + set( _ACML_COMPILER64 "gfortran64" ) + endif() + + if( BLA_VENDOR STREQUAL "ACML_MP" ) + set(_ACML_MP_LIB_DIRS + "${_ACML_ROOT}/${_ACML_COMPILER32}_mp${_ACML_PATH_SUFFIX}/lib" + "${_ACML_ROOT}/${_ACML_COMPILER64}_mp${_ACML_PATH_SUFFIX}/lib" ) + else() + set(_ACML_LIB_DIRS + "${_ACML_ROOT}/${_ACML_COMPILER32}${_ACML_PATH_SUFFIX}/lib" + "${_ACML_ROOT}/${_ACML_COMPILER64}${_ACML_PATH_SUFFIX}/lib" ) + endif() + endif() +elseif(BLAS_${BLA_VENDOR}_LIB_DIRS) + set(_${BLA_VENDOR}_LIB_DIRS ${BLAS_${BLA_VENDOR}_LIB_DIRS}) +endif() + +if( BLA_VENDOR STREQUAL "ACML_MP" ) + foreach( BLAS_ACML_MP_LIB_DIRS ${_ACML_MP_LIB_DIRS}) + check_fortran_libraries ( + BLAS_LIBRARIES + BLAS + sgemm + "" "acml_mp;acml_mv" "" ${BLAS_ACML_MP_LIB_DIRS} + ) + if( BLAS_LIBRARIES ) + break() + endif() + endforeach() +elseif( BLA_VENDOR STREQUAL "ACML_GPU" ) + foreach( BLAS_ACML_GPU_LIB_DIRS ${_ACML_GPU_LIB_DIRS}) + check_fortran_libraries ( + BLAS_LIBRARIES + BLAS + sgemm + "" "acml;acml_mv;CALBLAS" "" ${BLAS_ACML_GPU_LIB_DIRS} + ) + if( BLAS_LIBRARIES ) + break() + endif() + endforeach() +else() + foreach( BLAS_ACML_LIB_DIRS ${_ACML_LIB_DIRS} ) + check_fortran_libraries ( + BLAS_LIBRARIES + BLAS + sgemm + "" "acml;acml_mv" "" ${BLAS_ACML_LIB_DIRS} + ) + if( BLAS_LIBRARIES ) + break() + endif() + endforeach() +endif() + +# Either acml or acml_mp should be in LD_LIBRARY_PATH but not both +if(NOT BLAS_LIBRARIES) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "acml;acml_mv" + "" + ) +endif() +if(NOT BLAS_LIBRARIES) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "acml_mp;acml_mv" + "" + ) +endif() +if(NOT BLAS_LIBRARIES) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "acml;acml_mv;CALBLAS" + "" + ) +endif() +endif () # ACML + +# Apple BLAS library? +if (BLA_VENDOR STREQUAL "Apple" OR BLA_VENDOR STREQUAL "All") + if(NOT BLAS_LIBRARIES) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + dgemm + "" + "Accelerate" + "" + ) + endif() +endif () + +if (BLA_VENDOR STREQUAL "NAS" OR BLA_VENDOR STREQUAL "All") + if ( NOT BLAS_LIBRARIES ) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + dgemm + "" + "vecLib" + "" + ) + endif () +endif () + +# Generic BLAS library? +if (BLA_VENDOR STREQUAL "Generic" OR BLA_VENDOR STREQUAL "All") + if(NOT BLAS_LIBRARIES) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "blas" + "" + ) + endif() +endif () + +if(NOT BLA_F95) + find_package_handle_standard_args(BLAS REQUIRED_VARS BLAS_LIBRARIES) +endif() + +# On compilers that implicitly link BLAS (such as ftn, cc, and CC on Cray HPC machines) +# we used a placeholder for empty BLAS_LIBRARIES to get through our logic above. +if (BLAS_LIBRARIES STREQUAL "BLAS_LIBRARIES-PLACEHOLDER-FOR-EMPTY-LIBRARIES") + set(BLAS_LIBRARIES "") +endif() + +cmake_pop_check_state() +set(CMAKE_FIND_LIBRARY_SUFFIXES ${_blas_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES}) diff --git a/cmake/FindCUB.cmake b/cmake/FindCUB.cmake new file mode 100644 index 000000000..33c8a926f --- /dev/null +++ b/cmake/FindCUB.cmake @@ -0,0 +1,25 @@ +# Try to find the CUB library and headers. +# CUB_ROOT_DIR - where to find + +# CUB_FOUND - system has CUB +# CUB_INCLUDE_DIRS - the CUB include directory + + +find_path(CUB_INCLUDE_DIR + NAMES cub/cub.cuh + HINTS ${CUB_ROOT_DIR} + DOC "The directory where CUB includes reside" +) + +set(CUB_INCLUDE_DIRS ${CUB_INCLUDE_DIR}) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(CUB + FOUND_VAR CUB_FOUND + REQUIRED_VARS CUB_INCLUDE_DIR +) + +mark_as_advanced(CUB_FOUND) + +add_library(CUB INTERFACE) +target_include_directories(CUB INTERFACE ${CUB_INCLUDE_DIR}) diff --git a/cmake/FindICU.cmake b/cmake/FindICU.cmake new file mode 100644 index 000000000..8c460082c --- /dev/null +++ b/cmake/FindICU.cmake @@ -0,0 +1,428 @@ +# Distributed under the OSI-approved BSD 3-Clause License. See accompanying +# file Copyright.txt or https://cmake.org/licensing for details. + +#[=======================================================================[.rst: +FindICU +------- + +Find the International Components for Unicode (ICU) libraries and +programs. + +This module supports multiple components. +Components can include any of: ``data``, ``i18n``, ``io``, ``le``, +``lx``, ``test``, ``tu`` and ``uc``. + +Note that on Windows ``data`` is named ``dt`` and ``i18n`` is named +``in``; any of the names may be used, and the appropriate +platform-specific library name will be automatically selected. + +This module reports information about the ICU installation in +several variables. General variables:: + + ICU_VERSION - ICU release version + ICU_FOUND - true if the main programs and libraries were found + ICU_LIBRARIES - component libraries to be linked + ICU_INCLUDE_DIRS - the directories containing the ICU headers + +Imported targets:: + + ICU:: + +Where ```` is the name of an ICU component, for example +``ICU::i18n``. + +ICU programs are reported in:: + + ICU_GENCNVAL_EXECUTABLE - path to gencnval executable + ICU_ICUINFO_EXECUTABLE - path to icuinfo executable + ICU_GENBRK_EXECUTABLE - path to genbrk executable + ICU_ICU-CONFIG_EXECUTABLE - path to icu-config executable + ICU_GENRB_EXECUTABLE - path to genrb executable + ICU_GENDICT_EXECUTABLE - path to gendict executable + ICU_DERB_EXECUTABLE - path to derb executable + ICU_PKGDATA_EXECUTABLE - path to pkgdata executable + ICU_UCONV_EXECUTABLE - path to uconv executable + ICU_GENCFU_EXECUTABLE - path to gencfu executable + ICU_MAKECONV_EXECUTABLE - path to makeconv executable + ICU_GENNORM2_EXECUTABLE - path to gennorm2 executable + ICU_GENCCODE_EXECUTABLE - path to genccode executable + ICU_GENSPREP_EXECUTABLE - path to gensprep executable + ICU_ICUPKG_EXECUTABLE - path to icupkg executable + ICU_GENCMN_EXECUTABLE - path to gencmn executable + +ICU component libraries are reported in:: + + ICU__FOUND - ON if component was found + ICU__LIBRARIES - libraries for component + +ICU datafiles are reported in:: + + ICU_MAKEFILE_INC - Makefile.inc + ICU_PKGDATA_INC - pkgdata.inc + +Note that ```` is the uppercased name of the component. + +This module reads hints about search results from:: + + ICU_ROOT - the root of the ICU installation + +The environment variable ``ICU_ROOT`` may also be used; the +ICU_ROOT variable takes precedence. + +The following cache variables may also be set:: + + ICU_

_EXECUTABLE - the path to executable

+ ICU_INCLUDE_DIR - the directory containing the ICU headers + ICU__LIBRARY - the library for component + +.. note:: + + In most cases none of the above variables will require setting, + unless multiple ICU versions are available and a specific version + is required. + +Other variables one may set to control this module are:: + + ICU_DEBUG - Set to ON to enable debug output from FindICU. +#]=======================================================================] + +# Written by Roger Leigh + +set(icu_programs + gencnval + icuinfo + genbrk + icu-config + genrb + gendict + derb + pkgdata + uconv + gencfu + makeconv + gennorm2 + genccode + gensprep + icupkg + gencmn) + +set(icu_data + Makefile.inc + pkgdata.inc) + +# The ICU checks are contained in a function due to the large number +# of temporary variables needed. +function(_ICU_FIND) + # Set up search paths, taking compiler into account. Search ICU_ROOT, + # with ICU_ROOT in the environment as a fallback if unset. + if(ICU_ROOT) + list(APPEND icu_roots "${ICU_ROOT}") + else() + if(NOT "$ENV{ICU_ROOT}" STREQUAL "") + file(TO_CMAKE_PATH "$ENV{ICU_ROOT}" NATIVE_PATH) + list(APPEND icu_roots "${NATIVE_PATH}") + set(ICU_ROOT "${NATIVE_PATH}" + CACHE PATH "Location of the ICU installation" FORCE) + endif() + endif() + + # Find include directory + list(APPEND icu_include_suffixes "include") + find_path(ICU_INCLUDE_DIR + NAMES "unicode/utypes.h" + HINTS ${icu_roots} + PATH_SUFFIXES ${icu_include_suffixes} + DOC "ICU include directory") + set(ICU_INCLUDE_DIR "${ICU_INCLUDE_DIR}" PARENT_SCOPE) + + # Get version + if(ICU_INCLUDE_DIR AND EXISTS "${ICU_INCLUDE_DIR}/unicode/uvernum.h") + file(STRINGS "${ICU_INCLUDE_DIR}/unicode/uvernum.h" icu_header_str + REGEX "^#define[\t ]+U_ICU_VERSION[\t ]+\".*\".*") + + string(REGEX REPLACE "^#define[\t ]+U_ICU_VERSION[\t ]+\"([^ \\n]*)\".*" + "\\1" icu_version_string "${icu_header_str}") + set(ICU_VERSION "${icu_version_string}") + set(ICU_VERSION "${icu_version_string}" PARENT_SCOPE) + unset(icu_header_str) + unset(icu_version_string) + endif() + + if(CMAKE_SIZEOF_VOID_P EQUAL 8) + # 64-bit binary directory + set(_bin64 "bin64") + # 64-bit library directory + set(_lib64 "lib64") + endif() + + + # Find all ICU programs + list(APPEND icu_binary_suffixes "${_bin64}" "bin" "sbin") + foreach(program ${icu_programs}) + string(TOUPPER "${program}" program_upcase) + set(cache_var "ICU_${program_upcase}_EXECUTABLE") + set(program_var "ICU_${program_upcase}_EXECUTABLE") + find_program("${cache_var}" + NAMES "${program}" + HINTS ${icu_roots} + PATH_SUFFIXES ${icu_binary_suffixes} + DOC "ICU ${program} executable" + NO_PACKAGE_ROOT_PATH + ) + mark_as_advanced(cache_var) + set("${program_var}" "${${cache_var}}" PARENT_SCOPE) + endforeach() + + # Find all ICU libraries + list(APPEND icu_library_suffixes "${_lib64}" "lib") + set(ICU_REQUIRED_LIBS_FOUND ON) + set(static_prefix ) + # static icu libraries compiled with MSVC have the prefix 's' + if(MSVC) + set(static_prefix "s") + endif() + foreach(component ${ICU_FIND_COMPONENTS}) + string(TOUPPER "${component}" component_upcase) + set(component_cache "ICU_${component_upcase}_LIBRARY") + set(component_cache_release "${component_cache}_RELEASE") + set(component_cache_debug "${component_cache}_DEBUG") + set(component_found "${component_upcase}_FOUND") + set(component_libnames "icu${component}") + set(component_debug_libnames "icu${component}d") + + # Special case deliberate library naming mismatches between Unix + # and Windows builds + unset(component_libnames) + unset(component_debug_libnames) + list(APPEND component_libnames "icu${component}") + list(APPEND component_debug_libnames "icu${component}d") + if(component STREQUAL "data") + list(APPEND component_libnames "icudt") + # Note there is no debug variant at present + list(APPEND component_debug_libnames "icudtd") + endif() + if(component STREQUAL "dt") + list(APPEND component_libnames "icudata") + # Note there is no debug variant at present + list(APPEND component_debug_libnames "icudatad") + endif() + if(component STREQUAL "i18n") + list(APPEND component_libnames "icuin") + list(APPEND component_debug_libnames "icuind") + endif() + if(component STREQUAL "in") + list(APPEND component_libnames "icui18n") + list(APPEND component_debug_libnames "icui18nd") + endif() + + if(static_prefix) + unset(static_component_libnames) + unset(static_component_debug_libnames) + foreach(component_libname ${component_libnames}) + list(APPEND static_component_libnames + ${static_prefix}${component_libname}) + endforeach() + foreach(component_libname ${component_debug_libnames}) + list(APPEND static_component_debug_libnames + ${static_prefix}${component_libname}) + endforeach() + list(APPEND component_libnames ${static_component_libnames}) + list(APPEND component_debug_libnames ${static_component_debug_libnames}) + endif() + find_library("${component_cache_release}" + NAMES ${component_libnames} + HINTS ${icu_roots} + PATH_SUFFIXES ${icu_library_suffixes} + DOC "ICU ${component} library (release)" + NO_PACKAGE_ROOT_PATH + ) + find_library("${component_cache_debug}" + NAMES ${component_debug_libnames} + HINTS ${icu_roots} + PATH_SUFFIXES ${icu_library_suffixes} + DOC "ICU ${component} library (debug)" + NO_PACKAGE_ROOT_PATH + ) + include(SelectLibraryConfigurations) + select_library_configurations(ICU_${component_upcase}) + mark_as_advanced("${component_cache_release}" "${component_cache_debug}") + if(${component_cache}) + set("${component_found}" ON) + list(APPEND ICU_LIBRARY "${${component_cache}}") + endif() + mark_as_advanced("${component_found}") + set("${component_cache}" "${${component_cache}}" PARENT_SCOPE) + set("${component_found}" "${${component_found}}" PARENT_SCOPE) + if(${component_found}) + if (ICU_FIND_REQUIRED_${component}) + list(APPEND ICU_LIBS_FOUND "${component} (required)") + else() + list(APPEND ICU_LIBS_FOUND "${component} (optional)") + endif() + else() + if (ICU_FIND_REQUIRED_${component}) + set(ICU_REQUIRED_LIBS_FOUND OFF) + list(APPEND ICU_LIBS_NOTFOUND "${component} (required)") + else() + list(APPEND ICU_LIBS_NOTFOUND "${component} (optional)") + endif() + endif() + endforeach() + set(_ICU_REQUIRED_LIBS_FOUND "${ICU_REQUIRED_LIBS_FOUND}" PARENT_SCOPE) + set(ICU_LIBRARY "${ICU_LIBRARY}" PARENT_SCOPE) + + # Find all ICU data files + if(CMAKE_LIBRARY_ARCHITECTURE) + list(APPEND icu_data_suffixes + "${_lib64}/${CMAKE_LIBRARY_ARCHITECTURE}/icu/${ICU_VERSION}" + "lib/${CMAKE_LIBRARY_ARCHITECTURE}/icu/${ICU_VERSION}" + "${_lib64}/${CMAKE_LIBRARY_ARCHITECTURE}/icu" + "lib/${CMAKE_LIBRARY_ARCHITECTURE}/icu") + endif() + list(APPEND icu_data_suffixes + "${_lib64}/icu/${ICU_VERSION}" + "lib/icu/${ICU_VERSION}" + "${_lib64}/icu" + "lib/icu") + foreach(data ${icu_data}) + string(TOUPPER "${data}" data_upcase) + string(REPLACE "." "_" data_upcase "${data_upcase}") + set(cache_var "ICU_${data_upcase}") + set(data_var "ICU_${data_upcase}") + find_file("${cache_var}" + NAMES "${data}" + HINTS ${icu_roots} + PATH_SUFFIXES ${icu_data_suffixes} + DOC "ICU ${data} data file") + mark_as_advanced(cache_var) + set("${data_var}" "${${cache_var}}" PARENT_SCOPE) + endforeach() + + if(NOT ICU_FIND_QUIETLY) + if(ICU_LIBS_FOUND) + message(STATUS "Found the following ICU libraries:") + foreach(found ${ICU_LIBS_FOUND}) + message(STATUS " ${found}") + endforeach() + endif() + if(ICU_LIBS_NOTFOUND) + message(STATUS "The following ICU libraries were not found:") + foreach(notfound ${ICU_LIBS_NOTFOUND}) + message(STATUS " ${notfound}") + endforeach() + endif() + endif() + + if(ICU_DEBUG) + message(STATUS "--------FindICU.cmake search debug--------") + message(STATUS "ICU binary path search order: ${icu_roots}") + message(STATUS "ICU include path search order: ${icu_roots}") + message(STATUS "ICU library path search order: ${icu_roots}") + message(STATUS "----------------") + endif() +endfunction() + +_ICU_FIND() + +include(FindPackageHandleStandardArgs) +FIND_PACKAGE_HANDLE_STANDARD_ARGS(ICU + FOUND_VAR ICU_FOUND + REQUIRED_VARS ICU_INCLUDE_DIR + ICU_LIBRARY + _ICU_REQUIRED_LIBS_FOUND + VERSION_VAR ICU_VERSION + FAIL_MESSAGE "Failed to find all ICU components") + +unset(_ICU_REQUIRED_LIBS_FOUND) + +if(ICU_FOUND) + set(ICU_INCLUDE_DIRS "${ICU_INCLUDE_DIR}") + set(ICU_LIBRARIES "${ICU_LIBRARY}") + foreach(_ICU_component ${ICU_FIND_COMPONENTS}) + string(TOUPPER "${_ICU_component}" _ICU_component_upcase) + set(_ICU_component_cache "ICU_${_ICU_component_upcase}_LIBRARY") + set(_ICU_component_cache_release "ICU_${_ICU_component_upcase}_LIBRARY_RELEASE") + set(_ICU_component_cache_debug "ICU_${_ICU_component_upcase}_LIBRARY_DEBUG") + set(_ICU_component_lib "ICU_${_ICU_component_upcase}_LIBRARIES") + set(_ICU_component_found "${_ICU_component_upcase}_FOUND") + set(_ICU_imported_target "ICU::${_ICU_component}") + if(${_ICU_component_found}) + set("${_ICU_component_lib}" "${${_ICU_component_cache}}") + if(NOT TARGET ${_ICU_imported_target}) + add_library(${_ICU_imported_target} UNKNOWN IMPORTED) + if(ICU_INCLUDE_DIR) + set_target_properties(${_ICU_imported_target} PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${ICU_INCLUDE_DIR}") + endif() + if(EXISTS "${${_ICU_component_cache}}") + set_target_properties(${_ICU_imported_target} PROPERTIES + IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" + IMPORTED_LOCATION "${${_ICU_component_cache}}") + endif() + if(EXISTS "${${_ICU_component_cache_release}}") + set_property(TARGET ${_ICU_imported_target} APPEND PROPERTY + IMPORTED_CONFIGURATIONS RELEASE) + set_target_properties(${_ICU_imported_target} PROPERTIES + IMPORTED_LINK_INTERFACE_LANGUAGES_RELEASE "CXX" + IMPORTED_LOCATION_RELEASE "${${_ICU_component_cache_release}}") + endif() + if(EXISTS "${${_ICU_component_cache_debug}}") + set_property(TARGET ${_ICU_imported_target} APPEND PROPERTY + IMPORTED_CONFIGURATIONS DEBUG) + set_target_properties(${_ICU_imported_target} PROPERTIES + IMPORTED_LINK_INTERFACE_LANGUAGES_DEBUG "CXX" + IMPORTED_LOCATION_DEBUG "${${_ICU_component_cache_debug}}") + endif() + if(CMAKE_DL_LIBS AND _ICU_component STREQUAL "uc") + set_target_properties(${_ICU_imported_target} PROPERTIES + INTERFACE_LINK_LIBRARIES "${CMAKE_DL_LIBS}") + endif() + endif() + endif() + unset(_ICU_component_upcase) + unset(_ICU_component_cache) + unset(_ICU_component_lib) + unset(_ICU_component_found) + unset(_ICU_imported_target) + endforeach() +endif() + +if(ICU_DEBUG) + message(STATUS "--------FindICU.cmake results debug--------") + message(STATUS "ICU found: ${ICU_FOUND}") + message(STATUS "ICU_VERSION number: ${ICU_VERSION}") + message(STATUS "ICU_ROOT directory: ${ICU_ROOT}") + message(STATUS "ICU_INCLUDE_DIR directory: ${ICU_INCLUDE_DIR}") + message(STATUS "ICU_LIBRARIES: ${ICU_LIBRARIES}") + + foreach(program IN LISTS icu_programs) + string(TOUPPER "${program}" program_upcase) + set(program_lib "ICU_${program_upcase}_EXECUTABLE") + message(STATUS "${program} program: ${${program_lib}}") + unset(program_upcase) + unset(program_lib) + endforeach() + + foreach(data IN LISTS icu_data) + string(TOUPPER "${data}" data_upcase) + string(REPLACE "." "_" data_upcase "${data_upcase}") + set(data_lib "ICU_${data_upcase}") + message(STATUS "${data} data: ${${data_lib}}") + unset(data_upcase) + unset(data_lib) + endforeach() + + foreach(component IN LISTS ICU_FIND_COMPONENTS) + string(TOUPPER "${component}" component_upcase) + set(component_lib "ICU_${component_upcase}_LIBRARIES") + set(component_found "${component_upcase}_FOUND") + message(STATUS "${component} library found: ${${component_found}}") + message(STATUS "${component} library: ${${component_lib}}") + unset(component_upcase) + unset(component_lib) + unset(component_found) + endforeach() + message(STATUS "----------------") +endif() + +unset(icu_programs) diff --git a/cmake/FindLAPACK.cmake b/cmake/FindLAPACK.cmake new file mode 100644 index 000000000..60fbf0726 --- /dev/null +++ b/cmake/FindLAPACK.cmake @@ -0,0 +1,430 @@ +# Distributed under the OSI-approved BSD 3-Clause License. See accompanying +# file Copyright.txt or https://cmake.org/licensing for details. + +#[=======================================================================[.rst: +FindLAPACK +---------- + +Find Linear Algebra PACKage (LAPACK) library + +This module finds an installed fortran library that implements the +LAPACK linear-algebra interface (see http://www.netlib.org/lapack/). + +The approach follows that taken for the autoconf macro file, +``acx_lapack.m4`` (distributed at +http://ac-archive.sourceforge.net/ac-archive/acx_lapack.html). + +Input Variables +^^^^^^^^^^^^^^^ + +The following variables may be set to influence this module's behavior: + +``BLA_STATIC`` + if ``ON`` use static linkage + +``BLA_VENDOR`` + If set, checks only the specified vendor, if not set checks all the + possibilities. List of vendors valid in this module: + + * ``Intel10_32`` (intel mkl v10 32 bit) + * ``Intel10_64lp`` (intel mkl v10+ 64 bit, threaded code, lp64 model) + * ``Intel10_64lp_seq`` (intel mkl v10+ 64 bit, sequential code, lp64 model) + * ``Intel10_64ilp`` (intel mkl v10+ 64 bit, threaded code, ilp64 model) + * ``Intel10_64ilp_seq`` (intel mkl v10+ 64 bit, sequential code, ilp64 model) + * ``Intel`` (obsolete versions of mkl 32 and 64 bit) + * ``OpenBLAS`` + * ``FLAME`` + * ``ACML`` + * ``Apple`` + * ``NAS`` + * ``Generic`` + +``BLA_F95`` + if ``ON`` tries to find BLAS95/LAPACK95 + +Result Variables +^^^^^^^^^^^^^^^^ + +This module defines the following variables: + +``LAPACK_FOUND`` + library implementing the LAPACK interface is found +``LAPACK_LINKER_FLAGS`` + uncached list of required linker flags (excluding -l and -L). +``LAPACK_LIBRARIES`` + uncached list of libraries (using full path name) to link against + to use LAPACK +``LAPACK95_LIBRARIES`` + uncached list of libraries (using full path name) to link against + to use LAPACK95 +``LAPACK95_FOUND`` + library implementing the LAPACK95 interface is found + +.. note:: + + C or CXX must be enabled to use Intel MKL + + For example, to use Intel MKL libraries and/or Intel compiler: + + .. code-block:: cmake + + set(BLA_VENDOR Intel10_64lp) + find_package(LAPACK) +#]=======================================================================] + +set(_lapack_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES}) + +# Check the language being used +if( NOT (CMAKE_C_COMPILER_LOADED OR CMAKE_CXX_COMPILER_LOADED OR CMAKE_Fortran_COMPILER_LOADED) ) + if(LAPACK_FIND_REQUIRED) + message(FATAL_ERROR "FindLAPACK requires Fortran, C, or C++ to be enabled.") + else() + message(STATUS "Looking for LAPACK... - NOT found (Unsupported languages)") + return() + endif() +endif() + +if (CMAKE_Fortran_COMPILER_LOADED) +include(CheckFortranFunctionExists) +else () +include(CheckFunctionExists) +endif () +include(CMakePushCheckState) + +cmake_push_check_state() +set(CMAKE_REQUIRED_QUIET ${LAPACK_FIND_QUIETLY}) + +set(LAPACK_FOUND FALSE) +set(LAPACK95_FOUND FALSE) + +# TODO: move this stuff to separate module + +macro(Check_Lapack_Libraries LIBRARIES _prefix _name _flags _list _blas _threads) +# This macro checks for the existence of the combination of fortran libraries +# given by _list. If the combination is found, this macro checks (using the +# Check_Fortran_Function_Exists macro) whether can link against that library +# combination using the name of a routine given by _name using the linker +# flags given by _flags. If the combination of libraries is found and passes +# the link test, LIBRARIES is set to the list of complete library paths that +# have been found. Otherwise, LIBRARIES is set to FALSE. + +# N.B. _prefix is the prefix applied to the names of all cached variables that +# are generated internally and marked advanced by this macro. + +set(_libraries_work TRUE) +set(${LIBRARIES}) +set(_combined_name) +if (NOT _libdir) + if (WIN32) + set(_libdir ENV LIB) + elseif (APPLE) + set(_libdir ENV DYLD_LIBRARY_PATH) + else () + set(_libdir ENV LD_LIBRARY_PATH) + endif () +endif () + +list(APPEND _libdir "${CMAKE_C_IMPLICIT_LINK_DIRECTORIES}") + +foreach(_library ${_list}) + set(_combined_name ${_combined_name}_${_library}) + + if(_libraries_work) + if (BLA_STATIC) + if (WIN32) + set(CMAKE_FIND_LIBRARY_SUFFIXES .lib ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif () + if (APPLE) + set(CMAKE_FIND_LIBRARY_SUFFIXES .lib ${CMAKE_FIND_LIBRARY_SUFFIXES}) + else () + set(CMAKE_FIND_LIBRARY_SUFFIXES .a ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif () + else () + if (CMAKE_SYSTEM_NAME STREQUAL "Linux") + # for ubuntu's libblas3gf and liblapack3gf packages + set(CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES} .so.3gf) + endif () + endif () + find_library(${_prefix}_${_library}_LIBRARY + NAMES ${_library} + PATHS ${_libdir} + ) + mark_as_advanced(${_prefix}_${_library}_LIBRARY) + set(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY}) + set(_libraries_work ${${_prefix}_${_library}_LIBRARY}) + endif() +endforeach() + +if(_libraries_work) + # Test this combination of libraries. + if(UNIX AND BLA_STATIC) + set(CMAKE_REQUIRED_LIBRARIES ${_flags} "-Wl,--start-group" ${${LIBRARIES}} ${_blas} "-Wl,--end-group" ${_threads}) + else() + set(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}} ${_blas} ${_threads}) + endif() +# message("DEBUG: CMAKE_REQUIRED_LIBRARIES = ${CMAKE_REQUIRED_LIBRARIES}") + if (NOT CMAKE_Fortran_COMPILER_LOADED) + check_function_exists("${_name}_" ${_prefix}${_combined_name}_WORKS) + else () + check_fortran_function_exists(${_name} ${_prefix}${_combined_name}_WORKS) + endif () + set(CMAKE_REQUIRED_LIBRARIES) + set(_libraries_work ${${_prefix}${_combined_name}_WORKS}) + #message("DEBUG: ${LIBRARIES} = ${${LIBRARIES}}") +endif() + +if(_libraries_work) + set(${LIBRARIES} ${${LIBRARIES}} ${_blas} ${_threads}) +else() + set(${LIBRARIES} FALSE) +endif() + +endmacro() + + +set(LAPACK_LINKER_FLAGS) +set(LAPACK_LIBRARIES) +set(LAPACK95_LIBRARIES) + + +if(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED) + find_package(BLAS) +else() + find_package(BLAS REQUIRED) +endif() + + +if(BLAS_FOUND) + set(LAPACK_LINKER_FLAGS ${BLAS_LINKER_FLAGS}) + if (NOT $ENV{BLA_VENDOR} STREQUAL "") + set(BLA_VENDOR $ENV{BLA_VENDOR}) + else () + if(NOT BLA_VENDOR) + set(BLA_VENDOR "All") + endif() + endif () + +#intel lapack +if (BLA_VENDOR MATCHES "Intel" OR BLA_VENDOR STREQUAL "All") + if (NOT WIN32) + set(LAPACK_mkl_LM "-lm") + set(LAPACK_mkl_LDL "-ldl") + endif () + if (CMAKE_C_COMPILER_LOADED OR CMAKE_CXX_COMPILER_LOADED) + if(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED) + find_PACKAGE(Threads) + else() + find_package(Threads REQUIRED) + endif() + + if (BLA_VENDOR MATCHES "_64ilp") + set(LAPACK_mkl_ILP_MODE "ilp64") + else () + set(LAPACK_mkl_ILP_MODE "lp64") + endif () + + set(LAPACK_SEARCH_LIBS "") + + if (BLA_F95) + set(LAPACK_mkl_SEARCH_SYMBOL "cheev_f95") + set(_LIBRARIES LAPACK95_LIBRARIES) + set(_BLAS_LIBRARIES ${BLAS95_LIBRARIES}) + + # old + list(APPEND LAPACK_SEARCH_LIBS + "mkl_lapack95") + # new >= 10.3 + list(APPEND LAPACK_SEARCH_LIBS + "mkl_intel_c") + list(APPEND LAPACK_SEARCH_LIBS + "mkl_lapack95_${LAPACK_mkl_ILP_MODE}") + else() + set(LAPACK_mkl_SEARCH_SYMBOL "cheev") + set(_LIBRARIES LAPACK_LIBRARIES) + set(_BLAS_LIBRARIES ${BLAS_LIBRARIES}) + + # old + list(APPEND LAPACK_SEARCH_LIBS + "mkl_lapack") + endif() + + # First try empty lapack libs + if (NOT ${_LIBRARIES}) + check_lapack_libraries( + ${_LIBRARIES} + LAPACK + ${LAPACK_mkl_SEARCH_SYMBOL} + "" + "" + "${_BLAS_LIBRARIES}" + "" + ) + endif () + # Then try the search libs + foreach (IT ${LAPACK_SEARCH_LIBS}) + if (NOT ${_LIBRARIES}) + check_lapack_libraries( + ${_LIBRARIES} + LAPACK + ${LAPACK_mkl_SEARCH_SYMBOL} + "" + "${IT}" + "${_BLAS_LIBRARIES}" + "${CMAKE_THREAD_LIBS_INIT};${LAPACK_mkl_LM};${LAPACK_mkl_LDL}" + ) + endif () + endforeach () + + unset(LAPACK_mkl_ILP_MODE) + unset(LAPACK_mkl_SEARCH_SYMBOL) + unset(LAPACK_mkl_LM) + unset(LAPACK_mkl_LDL) + endif () +endif() + +if (BLA_VENDOR STREQUAL "Goto" OR BLA_VENDOR STREQUAL "All") + if(NOT LAPACK_LIBRARIES) + check_lapack_libraries( + LAPACK_LIBRARIES + LAPACK + cheev + "" + "goto2" + "${BLAS_LIBRARIES}" + "" + ) + endif() +endif () + +if (BLA_VENDOR STREQUAL "OpenBLAS" OR BLA_VENDOR STREQUAL "All") + if(NOT LAPACK_LIBRARIES) + check_lapack_libraries( + LAPACK_LIBRARIES + LAPACK + cheev + "" + "openblas" + "${BLAS_LIBRARIES}" + "" + ) + endif() +endif () + +if (BLA_VENDOR STREQUAL "FLAME" OR BLA_VENDOR STREQUAL "All") + if(NOT LAPACK_LIBRARIES) + check_lapack_libraries( + LAPACK_LIBRARIES + LAPACK + cheev + "" + "flame" + "${BLAS_LIBRARIES}" + "" + ) + endif() +endif () + +#acml lapack +if (BLA_VENDOR MATCHES "ACML" OR BLA_VENDOR STREQUAL "All") + if (BLAS_LIBRARIES MATCHES ".+acml.+") + set (LAPACK_LIBRARIES ${BLAS_LIBRARIES}) + endif () +endif () + +# Apple LAPACK library? +if (BLA_VENDOR STREQUAL "Apple" OR BLA_VENDOR STREQUAL "All") + if(NOT LAPACK_LIBRARIES) + check_lapack_libraries( + LAPACK_LIBRARIES + LAPACK + cheev + "" + "Accelerate" + "${BLAS_LIBRARIES}" + "" + ) + endif() +endif () +if (BLA_VENDOR STREQUAL "NAS" OR BLA_VENDOR STREQUAL "All") + if ( NOT LAPACK_LIBRARIES ) + check_lapack_libraries( + LAPACK_LIBRARIES + LAPACK + cheev + "" + "vecLib" + "${BLAS_LIBRARIES}" + "" + ) + endif () +endif () +# Generic LAPACK library? +if (BLA_VENDOR STREQUAL "Generic" OR + BLA_VENDOR STREQUAL "ATLAS" OR + BLA_VENDOR STREQUAL "All") + if ( NOT LAPACK_LIBRARIES ) + check_lapack_libraries( + LAPACK_LIBRARIES + LAPACK + cheev + "" + "lapack" + "${BLAS_LIBRARIES}" + "" + ) + endif () +endif () + +else() + message(STATUS "LAPACK requires BLAS") +endif() + +if(BLA_F95) + if(LAPACK95_LIBRARIES) + set(LAPACK95_FOUND TRUE) + else() + set(LAPACK95_FOUND FALSE) + endif() + if(NOT LAPACK_FIND_QUIETLY) + if(LAPACK95_FOUND) + message(STATUS "A library with LAPACK95 API found.") + else() + if(LAPACK_FIND_REQUIRED) + message(FATAL_ERROR + "A required library with LAPACK95 API not found. Please specify library location." + ) + else() + message(STATUS + "A library with LAPACK95 API not found. Please specify library location." + ) + endif() + endif() + endif() + set(LAPACK_FOUND "${LAPACK95_FOUND}") + set(LAPACK_LIBRARIES "${LAPACK95_LIBRARIES}") +else() + if(LAPACK_LIBRARIES) + set(LAPACK_FOUND TRUE) + else() + set(LAPACK_FOUND FALSE) + endif() + + if(NOT LAPACK_FIND_QUIETLY) + if(LAPACK_FOUND) + message(STATUS "A library with LAPACK API found.") + else() + if(LAPACK_FIND_REQUIRED) + message(FATAL_ERROR + "A required library with LAPACK API not found. Please specify library location." + ) + else() + message(STATUS + "A library with LAPACK API not found. Please specify library location." + ) + endif() + endif() + endif() +endif() + +cmake_pop_check_state() +set(CMAKE_FIND_LIBRARY_SUFFIXES ${_lapack_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES}) diff --git a/cmake/FindNvToolExt.cmake b/cmake/FindNvToolExt.cmake new file mode 100644 index 000000000..5f2998e44 --- /dev/null +++ b/cmake/FindNvToolExt.cmake @@ -0,0 +1,35 @@ +# The following variables are optionally searched for defaults +# NvToolExt_ROOT_DIR: +# +# The following are set after configuration is done: +# NvToolExt_FOUND +# NvToolExt_INCLUDE_DIR +# NvToolExt_LIBRARIES +# NvToolExt_LIBRARY_DIR +# NvToolExt: a target + +include(FindPackageHandleStandardArgs) + +set(NvToolExt_SEARCH_DIRS ${CUDA_TOOLKIT_ROOT_DIR}) +if(WIN32) + list(APPEND NvToolExt_SEARCH_DIRS "C:/Program Files/NVIDIA Corporation/NvToolsExt") +endif() +set(NvToolExt_SEARCH_DIRS ${NvToolExt_ROOT_DIR} ${NvToolExt_SEARCH_DIRS}) + + +find_path(NvToolExt_INCLUDE_DIR nvToolsExt.h HINTS ${NvToolExt_SEARCH_DIRS} PATH_SUFFIXES include) + +# 32bit not considered +set(NvToolExt_LIBNAME nvToolsExt libnvToolsExt.so libnvToolsExt.a libnvToolsExt.so nvToolsExt64_1.lib) +find_library(NvToolExt_LIBRARIES NAMES ${NvToolExt_LIBNAME} HINTS ${NvToolExt_SEARCH_DIRS} + PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64) + +find_package_handle_standard_args(NvToolExt REQUIRED_VARS NvToolExt_INCLUDE_DIR NvToolExt_LIBRARIES) + +add_library(NvToolExt INTERFACE) +target_include_directories(NvToolExt INTERFACE ${NvToolExt_INCLUDE_DIR}) +# target_link_directories(NvToolExt INTERFACE ${NvToolExt_INCLUDE_DIR}) +target_link_libraries(NvToolExt INTERFACE ${NvToolExt_LIBRARIES}) + +unset(NvToolExt_SEARCH_DIRS) +unset(NvToolExt_LIBNAME) diff --git a/cmake/INSTALL.md b/cmake/INSTALL.md new file mode 100644 index 000000000..0082212eb --- /dev/null +++ b/cmake/INSTALL.md @@ -0,0 +1,49 @@ +# Install Instruction + +Execute following commands in the repo root. + +## Build with Old Style Make Generator +```bash +mkdir -p build && cd build +cmake -DCMAKE_INSTALL_PREFIX=../dist .. # configure +cmake --build . --target install -- -j8 # build && install, substitude -j8 with /m:8 if you are on Windows +``` + +## Build with Ninja Generator +``` bash +mkdir -p build && cd build +cmake -GNinja -DCMAKE_INSTALL_PREFIX=../dist .. +cmake --build . --target install +``` + +After built, you can find all installed files in /dist + +# For Advance Configuration + +Follow options are currently available: + +| Variable | Available Options | Default | +| ---------------------- | ------------------------- | -------- | +| MATHLIB | OpenBLAS, MKL, Accelerate | OpenBLAS | +| KALDI_BUILD_EXE | ON,OFF | ON | +| KALDI_BUILD_TEST | ON,OFF | ON | +| KALDI_USE_PATCH_NUMBER | ON,OFF | OFF | +| BUILD_SHARED_LIBS | ON,OFF | OFF | + +Append `-D=` to the configure command to use it, e.g., +`-DKALDI_BUILD_TEST=OFF` will disable building of test executables. For more +information, please refers to +[CMake Documentation](https://cmake.org/cmake/help/latest/manual/cmake.1.html). +For quick learning CMake usage, LLVM's short introuction will do the trick: +[Basic CMake usage](https://llvm.org/docs/CMake.html#usage), +[Options and variables](https://llvm.org/docs/CMake.html#options-and-variables), +[Frequently-used CMake variables](https://llvm.org/docs/CMake.html#frequently-used-cmake-variables). + +NOTE 1: Currently, BUILD_SHARED_LIBS does not work on Windows due to some symbols + (variables) are not properly exported. + +NOTE 2: For scripts users, since you are doing an out of source build, and the + install destination is at your disposal, the `$PATH` is not configured + properly in this case. Scripts will not work out of box. See how `$PATH` + is modified in [path.sh](../egs/wsj/s5/path.sh). You should add + `/bin` to your `$PATH` before running any scripts. diff --git a/cmake/Utils.cmake b/cmake/Utils.cmake new file mode 100644 index 000000000..88dbefdac --- /dev/null +++ b/cmake/Utils.cmake @@ -0,0 +1,46 @@ +if(NOT CMAKE_VERSION VERSION_LESS "3.10") + include_guard() +endif() + +# For Windows, some env or vars are using backward slash for pathes, convert +# them to forward slashes will fix some nasty problem in CMake. +macro(normalize_path in_path) + file(TO_CMAKE_PATH "${${in_path}}" normalize_path_out_path) + set(${in_path} "${normalize_path_out_path}") + unset(normalize_path_out_path) +endmacro() + +macro(normalize_env_path in_path) + file(TO_CMAKE_PATH "$${in_path}" normalize_env_path_out_path) + set(${in_path} "${normalize_env_path_out_path}") + unset(normalize_env_path_out_path) +endmacro() + + +macro(add_kaldi_executable) + if(${KALDI_BUILD_EXE}) + cmake_parse_arguments(kaldi_exe "" "NAME" "SOURCES;DEPENDS" ${ARGN}) + add_executable(${kaldi_exe_NAME} ${kaldi_exe_SOURCES}) + target_link_libraries(${kaldi_exe_NAME} PRIVATE ${kaldi_exe_DEPENDS}) + # list(APPEND KALDI_EXECUTABLES ${kaldi_exe_NAME}) + install(TARGETS ${kaldi_exe_NAME} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) + + unset(kaldi_exe_NAME) + unset(kaldi_exe_SOURCES) + unset(kaldi_exe_DEPENDS) + endif() +endmacro() + +macro(add_kaldi_test_executable) + if(${KALDI_BUILD_TEST}) + cmake_parse_arguments(kaldi_test_exe "" "NAME" "SOURCES;DEPENDS" ${ARGN}) + add_executable(${kaldi_test_exe_NAME} ${kaldi_test_exe_SOURCES}) + target_link_libraries(${kaldi_test_exe_NAME} PRIVATE ${kaldi_test_exe_DEPENDS}) + # list(APPEND KALDI_TEST_EXECUTABLES ${kaldi_test_exe_NAME}) + install(TARGETS ${kaldi_test_exe_NAME} RUNTIME DESTINATION testbin) + + unset(kaldi_test_exe_NAME) + unset(kaldi_test_exe_SOURCES) + unset(kaldi_test_exe_DEPENDS) + endif() +endmacro() diff --git a/cmake/VersionHelper.cmake b/cmake/VersionHelper.cmake new file mode 100644 index 000000000..e494a2556 --- /dev/null +++ b/cmake/VersionHelper.cmake @@ -0,0 +1,14 @@ +function(get_version) + file(READ ${CMAKE_CURRENT_SOURCE_DIR}/src/.version version) + string(STRIP ${version} version) + execute_process(COMMAND git log -n1 --format=%H src/.version + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE version_commit + OUTPUT_STRIP_TRAILING_WHITESPACE) + execute_process(COMMAND git rev-list --count "${version_commit}..HEAD" + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE patch_number) + + set(KALDI_VERSION ${version} PARENT_SCOPE) + set(KALDI_PATCH_NUMBER ${patch_number} PARENT_SCOPE) +endfunction() diff --git a/cmake/gen_cmake_skeleton.py b/cmake/gen_cmake_skeleton.py new file mode 100644 index 000000000..fa5069436 --- /dev/null +++ b/cmake/gen_cmake_skeleton.py @@ -0,0 +1,310 @@ +import os +import sys +import re +import argparse + +# earily parse, will refernece args globally +parser = argparse.ArgumentParser() +parser.add_argument("working_dir") +parser.add_argument("--quiet", default=False, action="store_true") +args = parser.parse_args() + +def print_wrapper(*args_, **kwargs): + if not args.quiet: + print(*args_, **kwargs) + +def get_subdirectories(d): + return [name for name in os.listdir(d) if os.path.isdir(os.path.join(d, name))] + +def is_bin_dir(d): + return d.endswith("bin") + +def get_files(d): + return [name for name in os.listdir(d) if os.path.isfile(os.path.join(d, name))] + +def is_header(f): + return f.endswith(".h") + +def is_cu_source(f): + return f.endswith(".cu") + +def is_test_source(f): + return f.endswith("-test.cc") + +def is_source(f): + return f.endswith(".cc") and not is_test_source(f) + +def dir_name_to_lib_target(dir_name): + return "kaldi-" + dir_name + +def wrap_notwin32_condition(should_wrap, lines): + if isinstance(lines, str): + lines = [lines] + if should_wrap: + return ["if(NOT WIN32)"] + list(map(lambda l: " " + l, lines)) + ["endif()"] + else: + return lines + + +def get_exe_additional_depends(t): + additional = { + "transform-feats" : ["transform"], + "interpolate-pitch" : ["transform"], + "post-to-feats" : ["hmm"], + "append-post-to-feats" : ["hmm"], + "gmm-est-fmllr-gpost": ["sgmm2", "hmm"], + "gmm-est-fmllr": ["hmm", "transform"], + "gmm-latgen-faster": ["decoder"], + "gmm-transform-means": ["hmm"], + "gmm-post-to-gpost": ["hmm"], + "gmm-init-lvtln": ["transform"], + "gmm-rescore-lattice": ["hmm", "lat"], + "gmm-est-fmllr-global": ["transform"], + "gmm-copy": ["hmm"], + "gmm-train-lvtln-special": ["transform", "hmm"], + "gmm-est-map": ["hmm"], + "gmm-acc-stats2": ["hmm"], + "gmm-decode-faster-regtree-mllr": ["decoder"], + "gmm-global-est-fmllr": ["transform"], + "gmm-est-basis-fmllr": ["hmm", "transform"], + "gmm-init-model": ["hmm"], + "gmm-est-weights-ebw": ["hmm"], + "gmm-init-biphone": ["hmm"], + "gmm-compute-likes": ["hmm"], + "gmm-est-fmllr-raw-gpost": ["hmm", "transform"], + # gmm-* is a bottom case, it will add link dependencies to all other + # target whose names start with gmm-, it is harmless, but will increase + # link time. Better to avoid it at best. + "gmm-*": ["hmm", "transform", "lat", "decoder"], + } + if t in additional: + return list(map(lambda name: dir_name_to_lib_target(name), additional[t])) + elif (t.split("-", 1)[0] + "-*") in additional: + wildcard = (t.split("-", 1)[0] + "-*") + return list(map(lambda name: dir_name_to_lib_target(name), additional[wildcard])) + else: + return [] + +def disable_for_win32(t): + disabled = [ + "online-audio-client", + "online-net-client", + "online2-tcp-nnet3-decode-faster", + "online-server-gmm-decode-faster", + "online-audio-server-decode-faster" + ] + return t in disabled + +class CMakeListsHeaderLibrary(object): + def __init__(self, dir_name): + self.dir_name = dir_name + self.target_name = dir_name_to_lib_target(self.dir_name) + self.header_list = [] + + def add_header(self, filename): + self.header_list.append(filename) + + def add_source(self, filename): + pass + + def add_cuda_source(self, filename): + pass + + def add_test_source(self, filename): + pass + + def gen_code(self): + ret = [] + if len(self.header_list) > 0: + ret.append("set(PUBLIC_HEADERS") + for f in self.header_list: + ret.append(" " + f) + ret.append(")\n") + + ret.append("add_library(" + self.target_name + " INTERFACE)") + ret.append("target_include_directories(" + self.target_name + " INTERFACE ") + ret.append(" $") + ret.append(" $") + ret.append(")\n") + + ret.append(""" +install(TARGETS {tgt} EXPORT kaldi-targets) + +install(FILES ${{PUBLIC_HEADERS}} DESTINATION include/kaldi/{dir}) +""".format(tgt=self.target_name, dir=self.dir_name)) + + return "\n".join(ret) + +class CMakeListsLibrary(object): + + def __init__(self, dir_name): + self.dir_name = dir_name + self.target_name = dir_name_to_lib_target(self.dir_name) + self.header_list = [] + self.source_list = [] + self.cuda_source_list = [] + self.test_source_list = [] + self.depends = [] + + def add_header(self, filename): + self.header_list.append(filename) + + def add_source(self, filename): + self.source_list.append(filename) + + def add_cuda_source(self, filename): + self.cuda_source_list.append(filename) + + def add_test_source(self, filename): + self.test_source_list.append(filename) + + def load_dependency_from_makefile(self, filename): + with open(filename) as f: + makefile = f.read() + if "ADDLIBS" not in makefile: + print_wrapper("WARNING: non-standard", filename) + return + libs = makefile.split("ADDLIBS")[-1].split("\n\n")[0] + libs = re.findall("[^\s\\\\=]+", libs) + for l in libs: + self.depends.append(os.path.splitext(os.path.basename(l))[0]) + + def gen_code(self): + ret = [] + + if len(self.header_list) > 0: + ret.append("set(PUBLIC_HEADERS") + for f in self.header_list: + ret.append(" " + f) + ret.append(")\n") + + if len(self.cuda_source_list) > 0: + self.source_list.append("${CUDA_OBJS}") + ret.append("cuda_include_directories(${CMAKE_CURRENT_SOURCE_DIR}/..)") + ret.append("cuda_compile(CUDA_OBJS") + for f in self.cuda_source_list: + ret.append(" " + f) + ret.append(")\n") + + ret.append("add_library(" + self.target_name) + for f in self.source_list: + ret.append(" " + f) + ret.append(")\n") + ret.append("target_include_directories(" + self.target_name + " PUBLIC ") + ret.append(" $") + ret.append(" $") + ret.append(")\n") + + if len(self.depends) > 0: + ret.append("target_link_libraries(" + self.target_name + " PUBLIC") + for d in self.depends: + ret.append(" " + d) + ret.append(")\n") + + def get_test_exe_name(filename): + exe_name = os.path.splitext(f)[0] + if self.dir_name.startswith("nnet") and exe_name.startswith("nnet"): + return self.dir_name + "-" + exe_name.split("-", 1)[1] + else: + return exe_name + + if len(self.test_source_list) > 0: + ret.append("if(KALDI_BUILD_TEST)") + for f in self.test_source_list: + exe_target = get_test_exe_name(f) + depends = (self.target_name + " " + " ".join(get_exe_additional_depends(exe_target))).strip() + ret.extend(wrap_notwin32_condition(disable_for_win32(self.target_name), + " add_kaldi_test_executable(NAME " + exe_target + " SOURCES " + f + " DEPENDS " + depends + ")")) + ret.append("endif()") + + ret.append(""" +install(TARGETS {tgt} + EXPORT kaldi-targets + ARCHIVE DESTINATION ${{CMAKE_INSTALL_LIBDIR}} + LIBRARY DESTINATION ${{CMAKE_INSTALL_LIBDIR}} + RUNTIME DESTINATION ${{CMAKE_INSTALL_BINDIR}} +) + +install(FILES ${{PUBLIC_HEADERS}} DESTINATION include/kaldi/{dir}) +""".format(tgt=self.target_name, dir=self.dir_name)) + + return "\n".join(ret) + + + +class CMakeListsExecutable(object): + + def __init__(self, dir_name, filename): + assert(dir_name.endswith("bin")) + self.list = [] + exe_name = os.path.splitext(os.path.basename(filename))[0] + file_name = filename + depend = dir_name_to_lib_target(dir_name[:-3]) + self.list.append((exe_name, file_name, depend)) + + def gen_code(self): + ret = [] + for exe_name, file_name, depend in self.list: + depends = (depend + " " + " ".join(get_exe_additional_depends(exe_name))).strip() + ret.extend(wrap_notwin32_condition(disable_for_win32(exe_name), + "add_kaldi_executable(NAME " + exe_name + " SOURCES " + file_name + " DEPENDS " + depends + ")")) + + return "\n".join(ret) + +class CMakeListsFile(object): + + GEN_CMAKE_HEADER = "# generated with cmake/gen_cmake_skeleton.py, DO NOT MODIFY.\n" + + def __init__(self, directory): + self.path = os.path.realpath(os.path.join(directory, "CMakeLists.txt")) + self.sections = [] + + def add_section(self, section): + self.sections.append(section) + + def write_file(self): + with open(self.path, "w", newline='\n') as f: # good luck for python2 + f.write(CMakeListsFile.GEN_CMAKE_HEADER) + for s in self.sections: + code = s.gen_code() + f.write(code) + f.write("\n") + print_wrapper(" Writed", self.path) + + +if __name__ == "__main__": + os.chdir(args.working_dir) + print_wrapper("Working in ", args.working_dir) + + subdirs = get_subdirectories(".") + for d in subdirs: + cmakelists = CMakeListsFile(d) + if is_bin_dir(d): + for f in get_files(d): + if is_source(f): + dir_name = os.path.basename(d) + filename = os.path.basename(f) + exe = CMakeListsExecutable(dir_name, filename) + cmakelists.add_section(exe) + else: + dir_name = os.path.basename(d) + lib = None + makefile = os.path.join(d, "Makefile") + if not os.path.exists(makefile): + lib = CMakeListsHeaderLibrary(dir_name) + else: + lib = CMakeListsLibrary(dir_name) + lib.load_dependency_from_makefile(makefile) + cmakelists.add_section(lib) + for f in sorted(get_files(d)): + filename = os.path.basename(f) + if is_source(filename): + lib.add_source(filename) + elif is_cu_source(filename): + lib.add_cuda_source(filename) + elif is_test_source(filename): + lib.add_test_source(filename) + elif is_header(filename): + lib.add_header(filename) + + cmakelists.write_file() diff --git a/cmake/kaldi-config.cmake.in b/cmake/kaldi-config.cmake.in new file mode 100644 index 000000000..123f58c56 --- /dev/null +++ b/cmake/kaldi-config.cmake.in @@ -0,0 +1,7 @@ +@PACKAGE_INIT@ + +find_package(Threads) + +if(NOT TARGET kaldi-base) + include(${CMAKE_CURRENT_LIST_DIR}/kaldi-targets.cmake) +endif() diff --git a/cmake/third_party/get_third_party.cmake b/cmake/third_party/get_third_party.cmake new file mode 100644 index 000000000..8e24dc9f6 --- /dev/null +++ b/cmake/third_party/get_third_party.cmake @@ -0,0 +1,20 @@ +# Download and unpack a third-party library at configure time +# The original code is at the README of google-test: +# https://github.com/google/googletest/tree/master/googletest +function(get_third_party name) + configure_file( + "${PROJECT_SOURCE_DIR}/cmake/third_party/${name}.cmake" + "${CMAKE_CURRENT_BINARY_DIR}/${name}-download/CMakeLists.txt") + execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . + RESULT_VARIABLE result + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${name}-download") + if(result) + message(FATAL_ERROR "CMake step for ${name} failed: ${result}") + endif() + execute_process(COMMAND ${CMAKE_COMMAND} --build . + RESULT_VARIABLE result + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${name}-download") + if(result) + message(FATAL_ERROR "Build step for ${name} failed: ${result}") + endif() +endfunction() diff --git a/cmake/third_party/openfst.cmake b/cmake/third_party/openfst.cmake new file mode 100644 index 000000000..19a7f527f --- /dev/null +++ b/cmake/third_party/openfst.cmake @@ -0,0 +1,14 @@ +cmake_minimum_required(VERSION 2.8.2) +project(openfst-download NONE) + +include(ExternalProject) +ExternalProject_Add(openfst + GIT_REPOSITORY https://github.com/kkm000/openfst + GIT_TAG 0bca6e76d24647427356dc242b0adbf3b5f1a8d9 # tag win/1.7.2.1 + SOURCE_DIR "${CMAKE_BINARY_DIR}/openfst" + BINARY_DIR "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) diff --git a/cmake/third_party/openfst_lib_target.cmake b/cmake/third_party/openfst_lib_target.cmake new file mode 100644 index 000000000..dde5efc40 --- /dev/null +++ b/cmake/third_party/openfst_lib_target.cmake @@ -0,0 +1,31 @@ +if(NOT OPENFST_ROOT_DIR) + message(FATAL_ERROR) +endif() + +set(fst_source_dir ${OPENFST_ROOT_DIR}/src/lib) +set(fst_include_dir ${OPENFST_ROOT_DIR}/src/include) + +include_directories(${fst_include_dir}) +file(GLOB fst_sources "${fst_source_dir}/*.cc") + +add_library(fst ${fst_sources}) +target_include_directories(fst PUBLIC + $ + $ +) + +install(TARGETS fst + EXPORT kaldi-targets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} +) + +install(DIRECTORY ${fst_include_dir}/fst + DESTINATION include/openfst + PATTERN "test/*.h" EXCLUDE +) + +unset(fst_source_dir) +unset(fst_include_dir) +unset(fst_sources) diff --git a/egs/babel/s5d/local/chain2/run_tdnn.sh b/egs/babel/s5d/local/chain2/run_tdnn.sh new file mode 100755 index 000000000..58d1e0cc0 --- /dev/null +++ b/egs/babel/s5d/local/chain2/run_tdnn.sh @@ -0,0 +1,331 @@ +#!/bin/bash + + +# by default, with cleanup +# please note that the language(s) was not selected for any particular reason (other to represent the various sizes of babel datasets) +# 304-lithuanian | %WER 42.6 | 20041 61492 | 60.3 29.6 10.1 2.9 42.6 29.2 | -0.226 | exp/chain_cleaned/tdnn_sp/decode_dev10h.pem/score_10/dev10h.pem.ctm.sys +# num-iters=48 nj=2..12 num-params=6.7M dim=43+100->3273 combine=-0.192->-0.179 +# xent:train/valid[31,47,final]=(-2.47,-2.34,-2.33/-2.66,-2.57,-2.57) +# logprob:train/valid[31,47,final]=(-0.191,-0.163,-0.162/-0.246,-0.242,-0.243) +# 206-zulu | %WER 54.6 | 22805 52162 | 49.1 39.7 11.2 3.7 54.6 31.1 | -0.567 | exp/chain_cleaned/tdnn_sp/decode_dev10h.pem/score_11/dev10h.pem.ctm.sys +# num-iters=66 nj=2..12 num-params=6.7M dim=43+100->3274 combine=-0.236->-0.227 +# xent:train/valid[43,65,final]=(-2.59,-2.46,-2.46/-2.73,-2.67,-2.66) +# logprob:train/valid[43,65,final]=(-0.236,-0.208,-0.206/-0.289,-0.287,-0.286) +# 104-pashto | %WER 42.7 | 21825 101803 | 61.1 27.5 11.4 3.8 42.7 30.4 | -0.345 | exp/chain_cleaned/tdnn_sp/decode_dev10h.pem/score_10/dev10h.pem.ctm.sys +# num-iters=85 nj=2..12 num-params=6.8M dim=43+100->3328 combine=-0.215->-0.211 +# xent:train/valid[55,84,final]=(-2.44,-2.32,-2.32/-2.63,-2.57,-2.56) +# logprob:train/valid[55,84,final]=(-0.214,-0.192,-0.191/-0.281,-0.276,-0.275) + + +set -e -o pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=-1 +nj=30 +train_set=train_cleaned +gmm=tri5_cleaned # the gmm for the target data +langdir=data/langp/tri5_ali +num_threads_ubm=12 +nnet3_affix=_cleaned # cleanup affix for nnet3 and chain dirs, e.g. _cleaned + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +train_stage=-10 +tree_affix= # affix for tree directory, e.g. "a" or "b", in case we change the configuration. +tdnn_affix= #affix for TDNN directory, e.g. "a" or "b", in case we change the configuration. +common_egs_dir= # you can set this to use previously dumped egs. +chunk_width=150,120,90,75 +langs=default # has multiple values for a multilingual system +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat <data/lang_chain/topo + fi +fi + +if [ $stage -le 15 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj 100 --cmd "$train_cmd" ${lores_train_data_dir} \ + $langdir $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 16 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --leftmost-questions-truncate -1 \ + --cmd "$train_cmd" 4000 ${lores_train_data_dir} data/lang_chain $ali_dir $tree_dir +fi + +xent_regularize=0.1 +if [ $stage -le 17 ]; then + mkdir -p $dir + + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + [ -z $num_targets ] && { echo "$0: error getting num-targets"; exit 1; } + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=43 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=450 + relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=450 + relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=450 + relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=450 + relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=450 + relu-batchnorm-layer name=tdnn7 input=Append(-6,-3,0) dim=450 + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain input=tdnn7 dim=450 target-rms=0.5 + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 + output-layer name=output-default input=prefinal-chain include-log-softmax=false dim=$num_targets max-change=1.5 + + ## adding the layers for chain branch + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn7 dim=450 target-rms=0.5 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + output-layer name=output-default-xent input=prefinal-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ + if [ ! -f $dir/init/default_trans.mdl ]; then # checking this because it may have been copied in a previous run of the same script + copy-transition-model $tree_dir/final.mdl $dir/init/default_trans.mdl || exit 1 & + else + echo "Keeping the old $dir/init/default_trans.mdl as it already exists." + fi + +fi + +init_info=$dir/init/info.txt +if [ $stage -le 18 ]; then + + if [ ! -f $dir/configs/ref.raw ]; then + echo "Expected $dir/configs/ref.raw to exist" + exit + fi + + mkdir -p $dir/init + nnet3-info $dir/configs/ref.raw > $dir/configs/temp.info + model_left_context=`fgrep 'left-context' $dir/configs/temp.info | awk '{print $2}'` + model_right_context=`fgrep 'right-context' $dir/configs/temp.info | awk '{print $2}'` + cat >$init_info <" + echo " " + echo "main options (for others, see top of script file)" + echo " --enhancement # enhancement type (gss or beamformit)" + exit 1; +fi + +jdir=$1 +puttdir=$2 +utt_loc_file=$3 + +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +if [[ ${enhancement} == *gss* ]]; then + local/get_location.py $jdir > $utt_loc_file + local/replace_uttid.py $utt_loc_file $puttdir/per_utt > $puttdir/per_utt_loc +fi + +if [[ ${enhancement} == *beamformit* ]]; then + cat $puttdir/per_utt > $puttdir/per_utt_loc +fi diff --git a/egs/chime6/s5_track1/local/chain/compare_wer.sh b/egs/chime6/s5_track1/local/chain/compare_wer.sh new file mode 100755 index 000000000..cd6be14ed --- /dev/null +++ b/egs/chime6/s5_track1/local/chain/compare_wer.sh @@ -0,0 +1,131 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. local/chain/compare_wer.sh exp/chain/tdnn_{c,d}_sp +# For use with discriminatively trained systems you specify the epochs after a colon: +# for instance, +# local/chain/compare_wer.sh exp/chain/tdnn_c_sp exp/chain/tdnn_c_sp_smbr:{1,2,3} + + +if [ $# == 0 ]; then + echo "Usage: $0: [--looped] [--online] [ ... ]" + echo "e.g.: $0 exp/chain/tdnn_{b,c}_sp" + echo "or (with epoch numbers for discriminative training):" + echo "$0 exp/chain/tdnn_b_sp_disc:{1,2,3}" + exit 1 +fi + +echo "# $0 $*" + +include_looped=false +if [ "$1" == "--looped" ]; then + include_looped=true + shift +fi +include_online=false +if [ "$1" == "--online" ]; then + include_online=true + shift +fi + + +used_epochs=false + +# this function set_names is used to separate the epoch-related parts of the name +# [for discriminative training] and the regular parts of the name. +# If called with a colon-free directory name, like: +# set_names exp/chain/tdnn_lstm1e_sp_bi_smbr +# it will set dir=exp/chain/tdnn_lstm1e_sp_bi_smbr and epoch_infix="" +# If called with something like: +# set_names exp/chain/tdnn_d_sp_smbr:3 +# it will set dir=exp/chain/tdnn_d_sp_smbr and epoch_infix="_epoch3" + + +set_names() { + if [ $# != 1 ]; then + echo "compare_wer_general.sh: internal error" + exit 1 # exit the program + fi + dirname=$(echo $1 | cut -d: -f1) + epoch=$(echo $1 | cut -s -d: -f2) + if [ -z $epoch ]; then + epoch_infix="" + else + used_epochs=true + epoch_infix=_epoch${epoch} + fi +} + + + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +strings=( + "#WER dev_clean_2 (tgsmall) " + "#WER dev_clean_2 (tglarge) ") + +for n in 0 1; do + echo -n "${strings[$n]}" + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + decode_names=(tgsmall_dev_clean_2 tglarge_dev_clean_2) + + wer=$(cat $dirname/decode_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + if $include_looped; then + echo -n "# [looped:] " + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + wer=$(cat $dirname/decode_looped_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + fi + if $include_online; then + echo -n "# [online:] " + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + wer=$(cat ${dirname}_online/decode_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + fi +done + + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. +fi + + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final train prob (xent)" +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob (xent)" +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo diff --git a/egs/chime6/s5_track1/local/chain/run_tdnn.sh b/egs/chime6/s5_track1/local/chain/run_tdnn.sh new file mode 120000 index 000000000..61f8f4991 --- /dev/null +++ b/egs/chime6/s5_track1/local/chain/run_tdnn.sh @@ -0,0 +1 @@ +tuning/run_tdnn_1b.sh \ No newline at end of file diff --git a/egs/chime6/s5_track1/local/chain/tuning/run_tdnn_1a.sh b/egs/chime6/s5_track1/local/chain/tuning/run_tdnn_1a.sh new file mode 100755 index 000000000..daad37e2c --- /dev/null +++ b/egs/chime6/s5_track1/local/chain/tuning/run_tdnn_1a.sh @@ -0,0 +1,270 @@ +#!/bin/bash + +# Set -e here so that we catch if any executable fails immediately +set -euo pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +nj=96 +train_set=train_worn_u100k +test_sets="dev_worn dev_beamformit_ref" +gmm=tri3 +nnet3_affix=_train_worn_u100k +lm_suffix= + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +affix=1a # affix for the TDNN directory name +tree_affix= +train_stage=-10 +get_egs_stage=-10 +decode_iter= + +# training options +# training chunk-options +chunk_width=140,100,160 +common_egs_dir= +xent_regularize=0.1 + +# training options +srand=0 +remove_egs=true + +#decode options +test_online_decoding=false # if true, it will run the last decoding stage. + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 11 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj ${nj} --cmd "$train_cmd" ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 12 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 3500 ${lores_train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +if [ $stage -le 13 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + opts="l2-regularize=0.05" + output_opts="l2-regularize=0.01 bottleneck-dim=320" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-2,-1,0,1,2,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 $opts dim=512 + relu-batchnorm-layer name=tdnn2 $opts dim=512 input=Append(-1,0,1) + relu-batchnorm-layer name=tdnn3 $opts dim=512 + relu-batchnorm-layer name=tdnn4 $opts dim=512 input=Append(-1,0,1) + relu-batchnorm-layer name=tdnn5 $opts dim=512 + relu-batchnorm-layer name=tdnn6 $opts dim=512 input=Append(-3,0,3) + relu-batchnorm-layer name=tdnn7 $opts dim=512 input=Append(-3,0,3) + relu-batchnorm-layer name=tdnn8 $opts dim=512 input=Append(-6,-3,0) + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain $opts dim=512 target-rms=0.5 + output-layer name=output include-log-softmax=false $output_opts dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn8 $opts dim=512 target-rms=0.5 + output-layer name=output-xent $output_opts dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 14 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/chime5-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$decode_cmd" \ + --feat.online-ivector-dir=$train_ivector_dir \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.00005 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=10 \ + --trainer.frames-per-iter=3000000 \ + --trainer.optimization.num-jobs-initial=2 \ + --trainer.optimization.num-jobs-final=4 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.num-chunk-per-minibatch=256,128,64 \ + --trainer.optimization.momentum=0.0 \ + --egs.chunk-width=$chunk_width \ + --egs.chunk-left-context=$chunk_left_context \ + --egs.chunk-right-context=$chunk_right_context \ + --egs.chunk-left-context-initial=0 \ + --egs.chunk-right-context-final=0 \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 15 ]; then + # Note: it's not important to give mkgraph.sh the lang directory with the + # matched topology (since it gets the topology file from the model). + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/lang${lm_suffix}/ \ + $tree_dir $tree_dir/graph${lm_suffix} || exit 1; +fi + +if [ $stage -le 16 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + steps/nnet3/decode.sh \ + --acwt 1.0 --post-decode-acwt 10.0 \ + --frames-per-chunk $frames_per_chunk \ + --nj 8 --cmd "$decode_cmd" --num-threads 4 \ + --online-ivector-dir exp/nnet3${nnet3_affix}/ivectors_${data}_hires \ + $tree_dir/graph${lm_suffix} data/${data}_hires ${dir}/decode${lm_suffix}_${data} || exit 1 + ) || touch $dir/.error & + done + wait + [ -f $dir/.error ] && echo "$0: there was a problem while decoding" && exit 1 +fi + +# Not testing the 'looped' decoding separately, because for +# TDNN systems it would give exactly the same results as the +# normal decoding. + +if $test_online_decoding && [ $stage -le 17 ]; then + # note: if the features change (e.g. you add pitch features), you will have to + # change the options of the following command line. + steps/online/nnet3/prepare_online_decoding.sh \ + --mfcc-config conf/mfcc_hires.conf \ + $lang exp/nnet3${nnet3_affix}/extractor ${dir} ${dir}_online + + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + nspk=$(wc -l 2776 combine=-0.134->-0.133 (over 3) xent:train/valid[285,428,final]=(-2.37,-1.95,-1.95/-2.19,-1.90,-1.91) logprob:train/valid[285,428,final]=(-0.201,-0.125,-0.124/-0.198,-0.147,-0.148) + +set -e + +# configs for 'chain' +stage=0 +nj=96 +train_set=train_worn_u400k +test_sets="dev_worn dev_beamformit_ref" +gmm=tri3 +nnet3_affix=_train_worn_u400k +lm_suffix= + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +affix=1b # affix for the TDNN directory name +tree_affix= +train_stage=-10 +get_egs_stage=-10 +decode_iter= + +num_epochs=4 +common_egs_dir= +# training options +# training chunk-options +chunk_width=140,100,160 +xent_regularize=0.1 +dropout_schedule='0,0@0.20,0.5@0.50,0' + +# training options +srand=0 +remove_egs=true + +#decode options +test_online_decoding=false # if true, it will run the last decoding stage. +skip_decoding=true +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 11 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj ${nj} --cmd "$train_cmd" --generate-ali-from-lats true \ + ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 12 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 3 \ + --cmd "$train_cmd" 3500 ${lores_train_data_dir} \ + $lang $lat_dir $tree_dir +fi + +if [ $stage -le 13 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + affine_opts="l2-regularize=0.01 dropout-proportion=0.0 dropout-per-dim=true dropout-per-dim-continuous=true" + tdnnf_opts="l2-regularize=0.01 dropout-proportion=0.0 bypass-scale=0.66" + linear_opts="l2-regularize=0.01 orthonormal-constraint=-1.0" + prefinal_opts="l2-regularize=0.01" + output_opts="l2-regularize=0.002" + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-dropout-layer name=tdnn1 $affine_opts dim=1536 + tdnnf-layer name=tdnnf2 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=1 + tdnnf-layer name=tdnnf3 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=1 + tdnnf-layer name=tdnnf4 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=1 + tdnnf-layer name=tdnnf5 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=0 + tdnnf-layer name=tdnnf6 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf7 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf8 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf9 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf10 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf11 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf12 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf13 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf14 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + tdnnf-layer name=tdnnf15 $tdnnf_opts dim=1536 bottleneck-dim=160 time-stride=3 + linear-component name=prefinal-l dim=256 $linear_opts + + prefinal-layer name=prefinal-chain input=prefinal-l $prefinal_opts big-dim=1536 small-dim=256 + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + + prefinal-layer name=prefinal-xent input=prefinal-l $prefinal_opts big-dim=1536 small-dim=256 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor $output_opts +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + +if [ $stage -le 14 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/chime5-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage $train_stage \ + --cmd "$train_cmd --mem 4G" \ + --feat.online-ivector-dir=$train_ivector_dir \ + --feat.cmvn-opts "--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient 0.1 \ + --chain.l2-regularize 0.0 \ + --chain.apply-deriv-weights false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.dropout-schedule "$dropout_schedule" \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --egs.dir "$common_egs_dir" \ + --egs.stage $get_egs_stage \ + --egs.opts "--frames-overlap-per-eg 0" \ + --egs.chunk-width $chunk_width \ + --trainer.num-chunk-per-minibatch 64 \ + --trainer.frames-per-iter 1500000 \ + --trainer.num-epochs $num_epochs \ + --trainer.optimization.num-jobs-initial 3 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.00025 \ + --trainer.optimization.final-effective-lrate 0.000025 \ + --trainer.max-param-change 2.0 \ + --cleanup.remove-egs $remove_egs \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir $dir || exit 1; + +fi + +if [ $stage -le 15 ]; then + # Note: it's not important to give mkgraph.sh the lang directory with the + # matched topology (since it gets the topology file from the model). + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/lang${lm_suffix}/ \ + $tree_dir $tree_dir/graph${lm_suffix} || exit 1; +fi + +if [ $stage -le 16 ] && [[ $skip_decoding == "false" ]]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + steps/nnet3/decode.sh \ + --acwt 1.0 --post-decode-acwt 10.0 \ + --frames-per-chunk $frames_per_chunk \ + --nj 8 --cmd "$decode_cmd" --num-threads 4 \ + --online-ivector-dir exp/nnet3${nnet3_affix}/ivectors_${data}_hires \ + $tree_dir/graph${lm_suffix} data/${data}_hires ${dir}/decode${lm_suffix}_${data} || exit 1 + ) || touch $dir/.error & + done + wait + [ -f $dir/.error ] && echo "$0: there was a problem while decoding" && exit 1 +fi + +exit 0; diff --git a/egs/chime6/s5_track1/local/check_tools.sh b/egs/chime6/s5_track1/local/check_tools.sh new file mode 100755 index 000000000..8e80e25ca --- /dev/null +++ b/egs/chime6/s5_track1/local/check_tools.sh @@ -0,0 +1,76 @@ +#!/bin/bash -u + +# Copyright 2015 (c) Johns Hopkins University (Jan Trmal ) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +[ -f ./path.sh ] && . ./path.sh + +command -v uconv &>/dev/null \ + || { echo >&2 "uconv not found on PATH. You will have to install ICU4C"; exit 1; } + +command -v ngram &>/dev/null \ + || { echo >&2 "srilm not found on PATH. Please use the script $KALDI_ROOT/tools/extras/install_srilm.sh to install it"; exit 1; } + +if [ -z ${LIBLBFGS} ]; then + echo >&2 "SRILM is not compiled with the support of MaxEnt models." + echo >&2 "You should use the script in \$KALDI_ROOT/tools/install_srilm.sh" + echo >&2 "which will take care of compiling the SRILM with MaxEnt support" + exit 1; +fi + +sox=`command -v sox 2>/dev/null` \ + || { echo >&2 "sox not found on PATH. Please install it manually (you will need version 14.4.0 and higher)."; exit 1; } + +# If sox is found on path, check if the version is correct +if [ ! -z "$sox" ]; then + sox_version=`$sox --version 2>&1| head -1 | sed -e 's?.*: ??' -e 's?.* ??'` + if [[ ! $sox_version =~ v14.4.* ]]; then + echo "Unsupported sox version $sox_version found on path. You will need version v14.4.0 and higher." + exit 1 + fi +fi + +command -v phonetisaurus-align &>/dev/null \ + || { echo >&2 "Phonetisaurus not found on PATH. Please use the script $KALDI_ROOT/tools/extras/install_phonetisaurus.sh to install it"; exit 1; } + +command -v BeamformIt &>/dev/null \ + || { echo >&2 "BeamformIt not found on PATH. Please use the script $KALDI_ROOT/tools/extras/install_beamformit.sh to install it"; exit 1; } + +miniconda_dir=$HOME/miniconda3/ +if [ ! -d $miniconda_dir ]; then + echo "$miniconda_dir does not exist. Please run '../../../tools/extras/install_miniconda.sh'" +fi + +# check if WPE is installed +result=`$miniconda_dir/bin/python -c "\ +try: + import nara_wpe + print('1') +except ImportError: + print('0')"` + +if [ "$result" != "1" ]; then + echo "WPE is not installed. Please run ../../../tools/extras/install_wpe.sh" + exit 1 +fi + +# this is used for the audio synchronization +sox_conda=`command -v ${miniconda_dir}/bin/sox 2>/dev/null` +if [ -z "${sox_conda}" ]; then + echo "install conda sox (v14.4.2)" + ${miniconda_dir}/bin/conda install -c conda-forge sox +fi + +exit 0 diff --git a/egs/chime6/s5_track1/local/copy_lat_dir_parallel.sh b/egs/chime6/s5_track1/local/copy_lat_dir_parallel.sh new file mode 100755 index 000000000..82839604c --- /dev/null +++ b/egs/chime6/s5_track1/local/copy_lat_dir_parallel.sh @@ -0,0 +1,97 @@ +#!/bin/bash + +cmd=queue.pl +nj=40 +stage=0 +speed_perturb=true + +. ./path.sh +. utils/parse_options.sh + +if [ $# -ne 4 ]; then + echo "Usage: $0 " + exit 1 +fi + +utt_map=$1 +data=$2 +srcdir=$3 +dir=$4 + +mkdir -p $dir + +cp $srcdir/{phones.txt,tree,final.mdl} $dir || exit 1 +cp $srcdir/{final.alimdl,final.occs,splice_opts,cmvn_opts,delta_opts,final.mat,full.mat} 2>/dev/null || true + +nj_src=$(cat $srcdir/num_jobs) || exit 1 + +if [ $stage -le 1 ]; then + $cmd JOB=1:$nj_src $dir/log/copy_lats_orig.JOB.log \ + lattice-copy "ark:gunzip -c $srcdir/lat.JOB.gz |" \ + ark,scp:$dir/lat_orig.JOB.ark,$dir/lat_orig.JOB.scp || exit 1 +fi + +for n in $(seq $nj_src); do + cat $dir/lat_orig.$n.scp +done > $dir/lat_orig.scp || exit 1 + +if $speed_perturb; then + for s in 0.9 1.1; do + awk -v s=$s '{print "sp"s"-"$1" sp"s"-"$2}' $utt_map + done | cat - $utt_map | sort -k1,1 > $dir/utt_map + utt_map=$dir/utt_map +fi + +if [ $stage -le 2 ]; then + utils/filter_scp.pl -f 2 $dir/lat_orig.scp < $utt_map | \ + utils/apply_map.pl -f 2 $dir/lat_orig.scp > \ + $dir/lat.scp || exit 1 + + if [ ! -s $dir/lat.scp ]; then + echo "$0: $dir/lat.scp is empty. Something went wrong!" + exit 1 + fi +fi + +utils/split_data.sh $data $nj + +if [ $stage -le 3 ]; then + $cmd JOB=1:$nj $dir/log/copy_lats.JOB.log \ + lattice-copy "scp:utils/filter_scp.pl $data/split$nj/JOB/utt2spk $dir/lat.scp |" \ + "ark:|gzip -c > $dir/lat.JOB.gz" || exit 1 +fi + +echo $nj > $dir/num_jobs + +if [ -f $srcdir/ali.1.gz ]; then + if [ $stage -le 4 ]; then + $cmd JOB=1:$nj_src $dir/log/copy_ali_orig.JOB.log \ + copy-int-vector "ark:gunzip -c $srcdir/ali.JOB.gz |" \ + ark,scp:$dir/ali_orig.JOB.ark,$dir/ali_orig.JOB.scp || exit 1 + fi + + for n in $(seq $nj_src); do + cat $dir/ali_orig.$n.scp + done > $dir/ali_orig.scp || exit 1 + + if [ $stage -le 5 ]; then + utils/filter_scp.pl -f 2 $dir/ali_orig.scp < $utt_map | \ + utils/apply_map.pl -f 2 $dir/ali_orig.scp > \ + $dir/ali.scp || exit 1 + + if [ ! -s $dir/ali.scp ]; then + echo "$0: $dir/ali.scp is empty. Something went wrong!" + exit 1 + fi + fi + + utils/split_data.sh $data $nj + + if [ $stage -le 6 ]; then + $cmd JOB=1:$nj $dir/log/copy_ali.JOB.log \ + copy-int-vector "scp:utils/filter_scp.pl $data/split$nj/JOB/utt2spk $dir/ali.scp |" \ + "ark:|gzip -c > $dir/ali.JOB.gz" || exit 1 + fi +fi + +rm $dir/lat_orig.*.{ark,scp} $dir/ali_orig.*.{ark,scp} 2>/dev/null || true diff --git a/egs/chime6/s5_track1/local/decode.sh b/egs/chime6/s5_track1/local/decode.sh new file mode 100755 index 000000000..b44716ba4 --- /dev/null +++ b/egs/chime6/s5_track1/local/decode.sh @@ -0,0 +1,253 @@ +#!/bin/bash +# +# Based mostly on the TED-LIUM and Switchboard recipe +# +# Copyright 2017 Johns Hopkins University (Author: Shinji Watanabe and Yenda Trmal) +# Apache 2.0 +# +# This is a subset of run.sh to only perform recognition experiments with evaluation data +# This script can be run from run.sh or standalone.  +# To run it standalone, you can download a pretrained chain ASR model using: +# wget http://kaldi-asr.org/models/12/0012_asr_v1.tar.gz +# Once it is downloaded, extract using: tar -xvzf 0012_asr_v1.tar.gz +# and copy the contents of the {data/ exp/} directory to your {data/ exp/} + +# Begin configuration section. +decode_nj=20 +gss_nj=50 +stage=0 +enhancement=gss # for a new enhancement method, + # change this variable and stage 4 + +# training data +train_set=train_worn_simu_u400k +# End configuration section +. ./utils/parse_options.sh + +. ./cmd.sh +. ./path.sh + + +set -e # exit on error + +# chime5 main directory path +# please change the path accordingly +chime5_corpus=/export/corpora4/CHiME5 +# chime6 data directories, which are generated from ${chime5_corpus}, +# to synchronize audio files across arrays and modify the annotation (JSON) file accordingly +chime6_corpus=${PWD}/CHiME6 +json_dir=${chime6_corpus}/transcriptions +audio_dir=${chime6_corpus}/audio + +enhanced_dir=enhanced +if [[ ${enhancement} == *gss* ]]; then + enhanced_dir=${enhanced_dir}_multiarray + enhancement=${enhancement}_multiarray +fi + +if [[ ${enhancement} == *beamformit* ]]; then + enhanced_dir=${enhanced_dir} + enhancement=${enhancement} +fi + +enhanced_dir=$(utils/make_absolute.sh $enhanced_dir) || exit 1 +test_sets="dev_${enhancement} eval_${enhancement}" + +# This script also needs the phonetisaurus g2p, srilm, beamformit +./local/check_tools.sh || exit 1 + +########################################################################### +# We first generate the synchronized audio files across arrays and +# corresponding JSON files. Note that this requires sox v14.4.2, +# which is installed via miniconda in ./local/check_tools.sh +########################################################################### + +if [ $stage -le 0 ]; then + local/generate_chime6_data.sh \ + --cmd "$train_cmd" \ + ${chime5_corpus} \ + ${chime6_corpus} +fi + +######################################################################################### +# In stage 1, we perform GSS based enhancement or beamformit for the test sets. multiarray = true +#can take around 10hrs for dev and eval set. +######################################################################################### + +if [ $stage -le 1 ] && [[ ${enhancement} == *gss* ]]; then + echo "$0: enhance data..." + # Guided Source Separation (GSS) from Paderborn University + # http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_boeddecker.pdf + # @Article{PB2018CHiME5, + # author = {Boeddeker, Christoph and Heitkaemper, Jens and Schmalenstroeer, Joerg and Drude, Lukas and Heymann, Jahn and Haeb-Umbach, Reinhold}, + # title = {{Front-End Processing for the CHiME-5 Dinner Party Scenario}}, + # year = {2018}, + # booktitle = {CHiME5 Workshop}, + # } + + if [ ! -d pb_chime5/ ]; then + local/install_pb_chime5.sh + fi + + if [ ! -f pb_chime5/cache/chime6.json ]; then + ( + cd pb_chime5 + miniconda_dir=$HOME/miniconda3/ + export PATH=$miniconda_dir/bin:$PATH + export CHIME6_DIR=$chime6_corpus + make cache/chime6.json + ) + fi + + for dset in dev eval; do + local/run_gss.sh \ + --cmd "$train_cmd --max-jobs-run $gss_nj" --nj 160 \ + ${dset} \ + ${enhanced_dir} \ + ${enhanced_dir} || exit 1 + done + + for dset in dev eval; do + local/prepare_data.sh --mictype gss ${enhanced_dir}/audio/${dset} \ + ${json_dir}/${dset} data/${dset}_${enhancement} || exit 1 + done +fi + +####################################################################### +# Prepare the dev and eval data with dereverberation (WPE) and +# beamforming. +####################################################################### + +if [ $stage -le 1 ] && [[ ${enhancement} == *beamformit* ]]; then + # Beamforming using reference arrays + # enhanced WAV directory + enhanced_dir=enhan + dereverb_dir=${PWD}/wav/wpe/ + for dset in dev eval; do + for mictype in u01 u02 u03 u04 u05 u06; do + local/run_wpe.sh --nj 4 --cmd "$train_cmd --mem 20G" \ + ${audio_dir}/${dset} \ + ${dereverb_dir}/${dset} \ + ${mictype} + done + done + + for dset in dev eval; do + for mictype in u01 u02 u03 u04 u05 u06; do + local/run_beamformit.sh --cmd "$train_cmd" \ + ${dereverb_dir}/${dset} \ + ${enhanced_dir}/${dset}_${enhancement}_${mictype} \ + ${mictype} + done + done + + for dset in dev eval; do + local/prepare_data.sh --mictype ref "$PWD/${enhanced_dir}/${dset}_${enhancement}_u0*" \ + ${json_dir}/${dset} data/${dset}_${enhancement} + done +fi + +# In GSS enhancement, we do not have array information in utterance ID +if [ $stage -le 2 ] && [[ ${enhancement} == *gss* ]]; then + # Split speakers up into 3-minute chunks. This doesn't hurt adaptation, and + # lets us use more jobs for decoding etc. + for dset in ${test_sets}; do + utils/copy_data_dir.sh data/${dset} data/${dset}_orig + done + + for dset in ${test_sets}; do + utils/data/modify_speaker_info.sh --seconds-per-spk-max 180 data/${dset}_orig data/${dset} + done +fi + +if [ $stage -le 2 ] && [[ ${enhancement} == *beamformit* ]]; then + # fix speaker ID issue (thanks to Dr. Naoyuki Kanda) + # add array ID to the speaker ID to avoid the use of other array information to meet regulations + # Before this fix + # $ head -n 2 data/eval_beamformit_ref_nosplit/utt2spk + # P01_S01_U02_KITCHEN.ENH-0000192-0001278 P01 + # P01_S01_U02_KITCHEN.ENH-0001421-0001481 P01 + # After this fix + # $ head -n 2 data/eval_beamformit_ref_nosplit_fix/utt2spk + # P01_S01_U02_KITCHEN.ENH-0000192-0001278 P01_U02 + # P01_S01_U02_KITCHEN.ENH-0001421-0001481 P01_U02 + echo "$0: fix data..." + for dset in ${test_sets}; do + utils/copy_data_dir.sh data/${dset} data/${dset}_nosplit + mkdir -p data/${dset}_nosplit_fix + for f in segments text wav.scp; do + if [ -f data/${dset}_nosplit/$f ]; then + cp data/${dset}_nosplit/$f data/${dset}_nosplit_fix + fi + done + awk -F "_" '{print $0 "_" $3}' data/${dset}_nosplit/utt2spk > data/${dset}_nosplit_fix/utt2spk + utils/utt2spk_to_spk2utt.pl data/${dset}_nosplit_fix/utt2spk > data/${dset}_nosplit_fix/spk2utt + done + + # Split speakers up into 3-minute chunks. This doesn't hurt adaptation, and + # lets us use more jobs for decoding etc. + for dset in ${test_sets}; do + utils/data/modify_speaker_info.sh --seconds-per-spk-max 180 data/${dset}_nosplit_fix data/${dset} + done +fi + +########################################################################## +# DECODING: we perform 2 stage decoding. +########################################################################## + +nnet3_affix=_${train_set}_cleaned_rvb +lm_suffix= + +if [ $stage -le 3 ]; then + # First the options that are passed through to run_ivector_common.sh + # (some of which are also used in this script directly). + + # The rest are configs specific to this script. Most of the parameters + # are just hardcoded at this level, in the commands below. + echo "$0: decode data..." + affix=1b # affix for the TDNN directory name + tree_affix= + tree_dir=exp/chain${nnet3_affix}/tree_sp${tree_affix:+_$tree_affix} + dir=exp/chain${nnet3_affix}/tdnn${affix}_sp + + # training options + # training chunk-options + chunk_width=140,100,160 + # we don't need extra left/right context for TDNN systems. + chunk_left_context=0 + chunk_right_context=0 + + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/lang${lm_suffix}/ \ + $tree_dir $tree_dir/graph${lm_suffix} || exit 1; + + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + local/nnet3/decode.sh --affix 2stage --pass2-decode-opts "--min-active 1000" \ + --acwt 1.0 --post-decode-acwt 10.0 \ + --frames-per-chunk 150 --nj $decode_nj \ + --ivector-dir exp/nnet3${nnet3_affix} \ + data/${data} data/lang${lm_suffix} \ + $tree_dir/graph${lm_suffix} \ + exp/chain${nnet3_affix}/tdnn${affix}_sp + ) || touch $dir/.error & + done + wait + [ -f $dir/.error ] && echo "$0: there was a problem while decoding" && exit 1 +fi + +########################################################################## +# Scoring: here we obtain wer per session per location and overall WER +########################################################################## + +if [ $stage -le 4 ]; then + # final scoring to get the official challenge result + # please specify both dev and eval set directories so that the search parameters + # (insertion penalty and language model weight) will be tuned using the dev set + local/score_for_submit.sh --enhancement $enhancement --json $json_dir \ + --dev exp/chain${nnet3_affix}/tdnn1b_sp/decode${lm_suffix}_dev_${enhancement}_2stage \ + --eval exp/chain${nnet3_affix}/tdnn1b_sp/decode${lm_suffix}_eval_${enhancement}_2stage +fi diff --git a/egs/chime6/s5_track1/local/distant_audio_list b/egs/chime6/s5_track1/local/distant_audio_list new file mode 100644 index 000000000..710945b01 --- /dev/null +++ b/egs/chime6/s5_track1/local/distant_audio_list @@ -0,0 +1,372 @@ +S03_U01.CH1 +S03_U01.CH2 +S03_U01.CH3 +S03_U01.CH4 +S03_U02.CH1 +S03_U02.CH2 +S03_U02.CH3 +S03_U02.CH4 +S03_U03.CH1 +S03_U03.CH2 +S03_U03.CH3 +S03_U03.CH4 +S03_U04.CH1 +S03_U04.CH2 +S03_U04.CH3 +S03_U04.CH4 +S03_U05.CH1 +S03_U05.CH2 +S03_U05.CH3 +S03_U05.CH4 +S03_U06.CH1 +S03_U06.CH2 +S03_U06.CH3 +S03_U06.CH4 +S04_U01.CH1 +S04_U01.CH2 +S04_U01.CH3 +S04_U01.CH4 +S04_U02.CH1 +S04_U02.CH2 +S04_U02.CH3 +S04_U02.CH4 +S04_U03.CH1 +S04_U03.CH2 +S04_U03.CH3 +S04_U03.CH4 +S04_U04.CH1 +S04_U04.CH2 +S04_U04.CH3 +S04_U04.CH4 +S04_U05.CH1 +S04_U05.CH2 +S04_U05.CH3 +S04_U05.CH4 +S04_U06.CH1 +S04_U06.CH2 +S04_U06.CH3 +S04_U06.CH4 +S05_U01.CH1 +S05_U01.CH2 +S05_U01.CH3 +S05_U01.CH4 +S05_U02.CH1 +S05_U02.CH2 +S05_U02.CH3 +S05_U02.CH4 +S05_U05.CH1 +S05_U05.CH2 +S05_U05.CH3 +S05_U05.CH4 +S05_U06.CH1 +S05_U06.CH2 +S05_U06.CH3 +S05_U06.CH4 +S06_U01.CH1 +S06_U01.CH2 +S06_U01.CH3 +S06_U01.CH4 +S06_U02.CH1 +S06_U02.CH2 +S06_U02.CH3 +S06_U02.CH4 +S06_U03.CH1 +S06_U03.CH2 +S06_U03.CH3 +S06_U03.CH4 +S06_U04.CH1 +S06_U04.CH2 +S06_U04.CH3 +S06_U04.CH4 +S06_U05.CH1 +S06_U05.CH2 +S06_U05.CH3 +S06_U05.CH4 +S06_U06.CH1 +S06_U06.CH2 +S06_U06.CH3 +S06_U06.CH4 +S07_U01.CH1 +S07_U01.CH2 +S07_U01.CH3 +S07_U01.CH4 +S07_U02.CH1 +S07_U02.CH2 +S07_U02.CH3 +S07_U02.CH4 +S07_U03.CH1 +S07_U03.CH2 +S07_U03.CH3 +S07_U03.CH4 +S07_U04.CH1 +S07_U04.CH2 +S07_U04.CH3 +S07_U04.CH4 +S07_U05.CH1 +S07_U05.CH2 +S07_U05.CH3 +S07_U05.CH4 +S07_U06.CH1 +S07_U06.CH2 +S07_U06.CH3 +S07_U06.CH4 +S08_U01.CH1 +S08_U01.CH2 +S08_U01.CH3 +S08_U01.CH4 +S08_U02.CH1 +S08_U02.CH2 +S08_U02.CH3 +S08_U02.CH4 +S08_U03.CH1 +S08_U03.CH2 +S08_U03.CH3 +S08_U03.CH4 +S08_U04.CH1 +S08_U04.CH2 +S08_U04.CH3 +S08_U04.CH4 +S08_U05.CH1 +S08_U05.CH2 +S08_U05.CH3 +S08_U05.CH4 +S08_U06.CH1 +S08_U06.CH2 +S08_U06.CH3 +S08_U06.CH4 +S12_U01.CH1 +S12_U01.CH2 +S12_U01.CH3 +S12_U01.CH4 +S12_U02.CH1 +S12_U02.CH2 +S12_U02.CH3 +S12_U02.CH4 +S12_U03.CH1 +S12_U03.CH2 +S12_U03.CH3 +S12_U03.CH4 +S12_U04.CH1 +S12_U04.CH2 +S12_U04.CH3 +S12_U04.CH4 +S12_U05.CH1 +S12_U05.CH2 +S12_U05.CH3 +S12_U05.CH4 +S12_U06.CH1 +S12_U06.CH2 +S12_U06.CH3 +S12_U06.CH4 +S13_U01.CH1 +S13_U01.CH2 +S13_U01.CH3 +S13_U01.CH4 +S13_U02.CH1 +S13_U02.CH2 +S13_U02.CH3 +S13_U02.CH4 +S13_U03.CH1 +S13_U03.CH2 +S13_U03.CH3 +S13_U03.CH4 +S13_U04.CH1 +S13_U04.CH2 +S13_U04.CH3 +S13_U04.CH4 +S13_U05.CH1 +S13_U05.CH2 +S13_U05.CH3 +S13_U05.CH4 +S13_U06.CH1 +S13_U06.CH2 +S13_U06.CH3 +S13_U06.CH4 +S16_U01.CH1 +S16_U01.CH2 +S16_U01.CH3 +S16_U01.CH4 +S16_U02.CH1 +S16_U02.CH2 +S16_U02.CH3 +S16_U02.CH4 +S16_U03.CH1 +S16_U03.CH2 +S16_U03.CH3 +S16_U03.CH4 +S16_U04.CH1 +S16_U04.CH2 +S16_U04.CH3 +S16_U04.CH4 +S16_U05.CH1 +S16_U05.CH2 +S16_U05.CH3 +S16_U05.CH4 +S16_U06.CH1 +S16_U06.CH2 +S16_U06.CH3 +S16_U06.CH4 +S17_U01.CH1 +S17_U01.CH2 +S17_U01.CH3 +S17_U01.CH4 +S17_U02.CH1 +S17_U02.CH2 +S17_U02.CH3 +S17_U02.CH4 +S17_U03.CH1 +S17_U03.CH2 +S17_U03.CH3 +S17_U03.CH4 +S17_U04.CH1 +S17_U04.CH2 +S17_U04.CH3 +S17_U04.CH4 +S17_U05.CH1 +S17_U05.CH2 +S17_U05.CH3 +S17_U05.CH4 +S17_U06.CH1 +S17_U06.CH2 +S17_U06.CH3 +S17_U06.CH4 +S18_U01.CH1 +S18_U01.CH2 +S18_U01.CH3 +S18_U01.CH4 +S18_U02.CH1 +S18_U02.CH2 +S18_U02.CH3 +S18_U02.CH4 +S18_U03.CH1 +S18_U03.CH2 +S18_U03.CH3 +S18_U03.CH4 +S18_U04.CH1 +S18_U04.CH2 +S18_U04.CH3 +S18_U04.CH4 +S18_U05.CH1 +S18_U05.CH2 +S18_U05.CH3 +S18_U05.CH4 +S18_U06.CH1 +S18_U06.CH2 +S18_U06.CH3 +S18_U06.CH4 +S19_U01.CH1 +S19_U01.CH2 +S19_U01.CH3 +S19_U01.CH4 +S19_U02.CH1 +S19_U02.CH2 +S19_U02.CH3 +S19_U02.CH4 +S19_U03.CH1 +S19_U03.CH2 +S19_U03.CH3 +S19_U03.CH4 +S19_U04.CH1 +S19_U04.CH2 +S19_U04.CH3 +S19_U04.CH4 +S19_U05.CH1 +S19_U05.CH2 +S19_U05.CH3 +S19_U05.CH4 +S19_U06.CH1 +S19_U06.CH2 +S19_U06.CH3 +S19_U06.CH4 +S20_U01.CH1 +S20_U01.CH2 +S20_U01.CH3 +S20_U01.CH4 +S20_U02.CH1 +S20_U02.CH2 +S20_U02.CH3 +S20_U02.CH4 +S20_U03.CH1 +S20_U03.CH2 +S20_U03.CH3 +S20_U03.CH4 +S20_U04.CH1 +S20_U04.CH2 +S20_U04.CH3 +S20_U04.CH4 +S20_U05.CH1 +S20_U05.CH2 +S20_U05.CH3 +S20_U05.CH4 +S20_U06.CH1 +S20_U06.CH2 +S20_U06.CH3 +S20_U06.CH4 +S22_U01.CH1 +S22_U01.CH2 +S22_U01.CH3 +S22_U01.CH4 +S22_U02.CH1 +S22_U02.CH2 +S22_U02.CH3 +S22_U02.CH4 +S22_U04.CH1 +S22_U04.CH2 +S22_U04.CH3 +S22_U04.CH4 +S22_U05.CH1 +S22_U05.CH2 +S22_U05.CH3 +S22_U05.CH4 +S22_U06.CH1 +S22_U06.CH2 +S22_U06.CH3 +S22_U06.CH4 +S23_U01.CH1 +S23_U01.CH2 +S23_U01.CH3 +S23_U01.CH4 +S23_U02.CH1 +S23_U02.CH2 +S23_U02.CH3 +S23_U02.CH4 +S23_U03.CH1 +S23_U03.CH2 +S23_U03.CH3 +S23_U03.CH4 +S23_U04.CH1 +S23_U04.CH2 +S23_U04.CH3 +S23_U04.CH4 +S23_U05.CH1 +S23_U05.CH2 +S23_U05.CH3 +S23_U05.CH4 +S23_U06.CH1 +S23_U06.CH2 +S23_U06.CH3 +S23_U06.CH4 +S24_U01.CH1 +S24_U01.CH2 +S24_U01.CH3 +S24_U01.CH4 +S24_U02.CH1 +S24_U02.CH2 +S24_U02.CH3 +S24_U02.CH4 +S24_U03.CH1 +S24_U03.CH2 +S24_U03.CH3 +S24_U03.CH4 +S24_U04.CH1 +S24_U04.CH2 +S24_U04.CH3 +S24_U04.CH4 +S24_U05.CH1 +S24_U05.CH2 +S24_U05.CH3 +S24_U05.CH4 +S24_U06.CH1 +S24_U06.CH2 +S24_U06.CH3 +S24_U06.CH4 diff --git a/egs/chime6/s5_track1/local/extract_noises.py b/egs/chime6/s5_track1/local/extract_noises.py new file mode 100755 index 000000000..8f617752f --- /dev/null +++ b/egs/chime6/s5_track1/local/extract_noises.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 + +import argparse +import json +import logging +import os +import sys +import scipy.io.wavfile as siw +import math +import numpy as np + + +def get_args(): + parser = argparse.ArgumentParser( + """Extract noises from the corpus based on the non-speech regions. + e.g. {} /export/corpora4/CHiME5/audio/train/ \\ + /export/corpora4/CHiME5/transcriptions/train/ \\ + /export/b05/zhiqiw/noise/""".format(sys.argv[0])) + + parser.add_argument("--segment-length", default=20) + parser.add_argument("audio_dir", help="""Location of the CHiME5 Audio files. e.g. /export/corpora4/CHiME5/audio/train/""") + parser.add_argument("trans_dir", help="""Location of the CHiME5 Transcriptions. e.g. /export/corpora4/CHiME5/transcriptions/train/""") + parser.add_argument("audio_list", help="""List of ids of the CHiME5 recordings from which noise is extracted. e.g. local/distant_audio_list""") + parser.add_argument("out_dir", help="Output directory to write noise files. e.g. /export/b05/zhiqiw/noise/") + + args = parser.parse_args() + return args + + +def Trans_time(time, fs): + units = time.split(':') + time_second = float(units[0]) * 3600 + float(units[1]) * 60 + float(units[2]) + return int(time_second*fs) + + +# remove mic dependency for CHiME-6 +def Get_time(conf, tag, fs): + for i in conf: + st = Trans_time(i['start_time'], fs) + ed = Trans_time(i['end_time'], fs) + tag[st:ed] = 0 + return tag + + +def write_noise(out_dir, seg, audio, sig, tag, fs, cnt): + sig_noise = sig[np.nonzero(tag)] + for i in range(math.floor(len(sig_noise)/(seg*fs))): + siw.write(out_dir +'/noise'+str(cnt)+'.wav', fs, sig_noise[i*seg*fs:(i+1)*seg*fs]) + cnt += 1 + return cnt + + +def main(): + args = get_args() + + if not os.path.exists(args.out_dir): + os.makedirs(args.out_dir) + + wav_list = open(args.audio_list).readlines() + + cnt = 1 + for i, audio in enumerate(wav_list): + parts = audio.strip().split('.') + if len(parts) == 2: + # Assuming distant mic with name like S03_U01.CH1 + session, mic = parts[0].split('_') + channel = parts[1] + base_name = session + "_" + mic + "." + channel + else: + # Assuming close talk mic with name like S03_P09 + session, mic = audio.strip().split('_') + base_name = session + "_" + mic + fs, sig = siw.read(args.audio_dir + "/" + base_name + '.wav') + tag = np.ones(len(sig)) + if i == 0 or session != session_p: + with open(args.trans_dir + "/" + session + '.json') as f: + conf = json.load(f) + tag = Get_time(conf, tag, fs) + cnt = write_noise(args.out_dir, args.segment_length, audio, sig, tag, fs, cnt) + session_p = session + + +if __name__ == '__main__': + main() diff --git a/egs/chime6/s5_track1/local/extract_vad_weights.sh b/egs/chime6/s5_track1/local/extract_vad_weights.sh new file mode 100755 index 000000000..250b021bd --- /dev/null +++ b/egs/chime6/s5_track1/local/extract_vad_weights.sh @@ -0,0 +1,86 @@ +#!/bin/bash + +# Copyright 2016 Johns Hopkins University (Author: Daniel Povey, Vijayaditya Peddinti) +# 2019 Vimal Manohar +# Apache 2.0. + +# This script converts lattices available from a first pass decode into a per-frame weights file +# The ctms generated from the lattices are filtered. Silence frames are assigned a low weight (e.g.0.00001) +# and voiced frames have a weight of 1. + +set -e + +stage=1 +cmd=run.pl +silence_weight=0.00001 +#end configuration section. + +. ./cmd.sh + +[ -f ./path.sh ] && . ./path.sh +. utils/parse_options.sh || exit 1; +if [ $# -ne 4 ]; then + echo "Usage: $0 [--cmd (run.pl|queue.pl...)] " + echo " Options:" + echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." + exit 1; +fi + +data_dir=$1 +lang=$2 # Note: may be graph directory not lang directory, but has the necessary stuff copied. +decode_dir=$3 +output_wts_file_gz=$4 + +if [ $stage -le 1 ]; then + echo "$0: generating CTM from input lattices" + steps/get_ctm_conf.sh --cmd "$cmd" \ + --use-segments false \ + $data_dir \ + $lang \ + $decode_dir +fi + +if [ $stage -le 2 ]; then + name=`basename $data_dir` + # we just take the ctm from LMWT 10, it doesn't seem to affect the results a lot + ctm=$decode_dir/score_10/$name.ctm + echo "$0: generating weights file from ctm $ctm" + + pad_frames=0 # this did not seem to be helpful but leaving it as an option. + feat-to-len scp:$data_dir/feats.scp ark,t:- >$decode_dir/utt.lengths + if [ ! -f $ctm ]; then echo "$0: expected ctm to exist: $ctm"; exit 1; fi + + cat $ctm | awk '$6 == 1.0 && $4 < 1.0' | \ + grep -v -w mm | grep -v -w mhm | grep -v -F '[noise]' | \ + grep -v -F '[laughter]' | grep -v -F '' | \ + perl -e ' $lengths=shift @ARGV; $pad_frames=shift @ARGV; $silence_weight=shift @ARGV; + $pad_frames >= 0 || die "bad pad-frames value $pad_frames"; + open(L, "<$lengths") || die "opening lengths file"; + @all_utts = (); + $utt2ref = { }; + while () { + ($utt, $len) = split(" ", $_); + push @all_utts, $utt; + $array_ref = [ ]; + for ($n = 0; $n < $len; $n++) { ${$array_ref}[$n] = $silence_weight; } + $utt2ref{$utt} = $array_ref; + } + while () { + @A = split(" ", $_); + @A == 6 || die "bad ctm line $_"; + $utt = $A[0]; $beg = $A[2]; $len = $A[3]; + $beg_int = int($beg * 100) - $pad_frames; + $len_int = int($len * 100) + 2*$pad_frames; + $array_ref = $utt2ref{$utt}; + !defined $array_ref && die "No length info for utterance $utt"; + for ($t = $beg_int; $t < $beg_int + $len_int; $t++) { + if ($t >= 0 && $t < @$array_ref) { + ${$array_ref}[$t] = 1; + } + } + } + foreach $utt (@all_utts) { $array_ref = $utt2ref{$utt}; + print $utt, " [ ", join(" ", @$array_ref), " ]\n"; + } ' $decode_dir/utt.lengths $pad_frames $silence_weight | \ + gzip -c > $output_wts_file_gz +fi diff --git a/egs/chime6/s5_track1/local/generate_chime6_data.sh b/egs/chime6/s5_track1/local/generate_chime6_data.sh new file mode 100755 index 000000000..93106cf60 --- /dev/null +++ b/egs/chime6/s5_track1/local/generate_chime6_data.sh @@ -0,0 +1,121 @@ +#!/bin/bash + +# Copyright 2019, Johns Hopkins University (Author: Shinji Watanabe) +# Apache 2.0 +# +# This script generates synchronized audio data across arrays by considering +# the frame dropping, clock drift etc. done by Prof. Jon Barker at University of +# Sheffield. This script first downloads the synchronization tool and generate +# the synchronized audios and corresponding JSON transcription files +# Note that +# 1) the JSON format is slightly changed from the original CHiME-5 one (simplified +# thanks to the synchronization) +# 2) it requires sox v.14.4.2 and Python 3.6.7 +# Unfortunately, the generated files would be different depending on the sox +# and Python versions and to generate the exactly same audio files, this script uses +# the fixed versions of sox and Python installed in the miniconda instead of system ones + +. ./cmd.sh +. ./path.sh + +# Config: +cmd=run.pl + +. utils/parse_options.sh || exit 1; + +if [ $# != 2 ]; then + echo "Wrong #arguments ($#, expected 2)" + echo "Usage: local/generate_chime6_data.sh [options] " + echo "main options (for others, see top of script file)" + echo " --cmd # Command to run in parallel with" + exit 1; +fi + +sdir=$1 +odir=$2 +expdir=${PWD}/exp/chime6_data + +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +# get chime6-synchronisation tools +SYNC_PATH=${PWD}/chime6-synchronisation +if [ ! -d ${SYNC_PATH} ]; then + git clone https://github.com/chimechallenge/chime6-synchronisation.git +fi + +mkdir -p ${odir} +mkdir -p ${expdir}/log + +# split the session to avoid too much disk access +sessions1="S01 S02 S03 S04 S05 S06 S07" +sessions2="S08 S09 S12 S13 S16 S17 S18" +sessions3="S19 S20 S21 S22 S23 S24" + +CONDA_PATH=${HOME}/miniconda3/bin +IN_PATH=${sdir}/audio +OUT_PATH=${odir}/audio +TMP_PATH=${odir}/audio_tmp + +if [ ! -d "${IN_PATH}" ]; then + echo "please specify the CHiME-5 data path correctly" + exit 1 +fi +mkdir -p $OUT_PATH/train $OUT_PATH/eval $OUT_PATH/dev +mkdir -p $TMP_PATH/train $TMP_PATH/eval $TMP_PATH/dev + +if [ -f ${odir}/audio/dev/S02_P05.wav ]; then + echo "CHiME-6 date already exists" + exit 0 +fi + +pushd ${SYNC_PATH} +echo "Correct for frame dropping" +for session in ${sessions1}; do + $cmd ${expdir}/correct_signals_for_frame_drops.${session}.log \ + ${CONDA_PATH}/python correct_signals_for_frame_drops.py --session=${session} chime6_audio_edits.json $IN_PATH $TMP_PATH & +done +wait +for session in ${sessions2}; do + $cmd ${expdir}/correct_signals_for_frame_drops.${session}.log \ + ${CONDA_PATH}/python correct_signals_for_frame_drops.py --session=${session} chime6_audio_edits.json $IN_PATH $TMP_PATH & +done +wait +for session in ${sessions3}; do + $cmd ${expdir}/correct_signals_for_frame_drops.${session}.log \ + ${CONDA_PATH}/python correct_signals_for_frame_drops.py --session=${session} chime6_audio_edits.json $IN_PATH $TMP_PATH & +done +wait + +echo "Sox processing for correcting clock drift" +for session in ${sessions1}; do + $cmd ${expdir}/correct_signals_for_clock_drift.${session}.log \ + ${CONDA_PATH}/python correct_signals_for_clock_drift.py --session=${session} --sox_path $CONDA_PATH chime6_audio_edits.json $TMP_PATH $OUT_PATH & +done +wait +for session in ${sessions2}; do + $cmd ${expdir}/correct_signals_for_clock_drift.${session}.log \ + ${CONDA_PATH}/python correct_signals_for_clock_drift.py --session=${session} --sox_path $CONDA_PATH chime6_audio_edits.json $TMP_PATH $OUT_PATH & +done +wait +for session in ${sessions3}; do + $cmd ${expdir}/correct_signals_for_clock_drift.${session}.log \ + ${CONDA_PATH}/python correct_signals_for_clock_drift.py --session=${session} --sox_path $CONDA_PATH chime6_audio_edits.json $TMP_PATH $OUT_PATH & +done +wait + +echo "adjust the JSON files" +mkdir -p ${odir}/transcriptions/eval ${odir}/transcriptions/dev ${odir}/transcriptions/train +${CONDA_PATH}/python correct_transcript_for_clock_drift.py --clock_drift_data chime6_audio_edits.json ${sdir}/transcriptions ${odir}/transcriptions +popd + +# finally check md5sum +pushd ${odir} +echo "check MD5 hash value for generated audios" +md5sum -c ${SYNC_PATH}/audio_md5sums.txt || echo "check https://github.com/chimechallenge/chime6-synchronisation" +popd + +echo "`basename $0` Done." diff --git a/egs/chime6/s5_track1/local/get_location.py b/egs/chime6/s5_track1/local/get_location.py new file mode 100755 index 000000000..92351e72e --- /dev/null +++ b/egs/chime6/s5_track1/local/get_location.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# Copyright Ashish Arora +# Apache 2.0 +# This script create a utterance and location mapping file +# It is used in score_for_submit script to get locationwise WER. +# for GSS enhancement + +import json +from datetime import timedelta +from glob import glob +import sys, io +from decimal import Decimal + +SAMPLE_RATE = 16000 + +def to_samples(time: str): + "mapping time in string to int, as mapped in pb_chime5" + "see https://github.com/fgnt/pb_chime5/blob/master/pb_chime5/database/chime5/get_speaker_activity.py" + hours, minutes, seconds = [t for t in time.split(':')] + hours = int(hours) + minutes = int(minutes) + seconds = Decimal(seconds) + + seconds_samples = seconds * SAMPLE_RATE + samples = ( + hours * 3600 * SAMPLE_RATE + + minutes * 60 * SAMPLE_RATE + + seconds_samples + ) + return int(samples) + + +def main(): + output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') + json_file_location= sys.argv[1] + '/*.json' + json_files = glob(json_file_location) + + json_file_location= sys.argv[1] + '/*.json' + json_files = glob(json_file_location) + location_dict = {} + json_file_location= sys.argv[1] + '/*.json' + json_files = glob(json_file_location) + location_dict = {} + for file in json_files: + with open(file, 'r') as f: + session_dict = json.load(f) + + for uttid in session_dict: + try: + ref=uttid['ref'] + speaker_id = uttid['speaker'] + location = uttid['location'] + location=location.upper() + session_id=uttid['session_id'] + words = uttid['words'] + end_sample=to_samples(str(uttid['end_time'])) + start_sample=to_samples(str(uttid['start_time'])) + start_sample_str = str(int(start_sample * 100 / SAMPLE_RATE)).zfill(7) + end_sample_str = str(int(end_sample * 100 / SAMPLE_RATE)).zfill(7) + utt = "{0}_{1}-{2}-{3}".format(speaker_id, session_id, start_sample_str, end_sample_str) + location_dict[utt]=(location) + except: + continue + + for key in sorted(location_dict.keys()): + utt= "{0} {1}".format(key, location_dict[key]) + output.write(utt+ '\n') + +if __name__ == '__main__': + main() diff --git a/egs/chime6/s5_track1/local/install_pb_chime5.sh b/egs/chime6/s5_track1/local/install_pb_chime5.sh new file mode 100755 index 000000000..a151dc60f --- /dev/null +++ b/egs/chime6/s5_track1/local/install_pb_chime5.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +# Installs pb_chime5 +# miniconda should be installed in $HOME/miniconda3/ + +miniconda_dir=$HOME/miniconda3/ + +if [ ! -d $miniconda_dir ]; then + echo "$miniconda_dir does not exist. Please run 'tools/extras/install_miniconda.sh" && exit 1; +fi + +git clone https://github.com/fgnt/pb_chime5.git +cd pb_chime5 +# Download submodule dependencies # https://stackoverflow.com/a/3796947/5766934 +git submodule init +git submodule update + +$miniconda_dir/bin/python -m pip install cython +$miniconda_dir/bin/python -m pip install pymongo +$miniconda_dir/bin/python -m pip install fire +$miniconda_dir/bin/python -m pip install -e pb_bss/ +$miniconda_dir/bin/python -m pip install -e . diff --git a/egs/chime6/s5_track1/local/json2text.py b/egs/chime6/s5_track1/local/json2text.py new file mode 100755 index 000000000..34cf52f08 --- /dev/null +++ b/egs/chime6/s5_track1/local/json2text.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import json +import argparse +import logging +import sys + + +def hms_to_seconds(hms): + hour = hms.split(':')[0] + minute = hms.split(':')[1] + second = hms.split(':')[2].split('.')[0] + + # .xx (10 ms order) + ms10 = hms.split(':')[2].split('.')[1] + + # total seconds + seconds = int(hour) * 3600 + int(minute) * 60 + int(second) + + return '{:07d}'.format(int(str(seconds) + ms10)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('json', type=str, help='JSON transcription file') + parser.add_argument('--mictype', type=str, + choices=['ref', 'worn', 'gss', 'u01', 'u02', 'u03', 'u04', 'u05', 'u06'], + help='Type of microphones') + args = parser.parse_args() + + # logging info + log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s:%(message)s" + logging.basicConfig(level=logging.INFO, format=log_format) + + logging.debug("reading %s", args.json) + with open(args.json, 'rt', encoding="utf-8") as f: + j = json.load(f) + + for x in j: + if '[redacted]' not in x['words']: + session_id = x['session_id'] + speaker_id = x['speaker'] + if args.mictype == 'ref': + mictype = x['ref'] + elif args.mictype == 'worn' or args.mictype == 'gss': + mictype = 'original' + else: + mictype = args.mictype.upper() # convert from u01 to U01 + + # add location tag for scoring (only for dev and eval sets) + if 'location' in x.keys(): + location = x['location'].upper() + else: + location = 'NOLOCATION' + + # remove mic dependency for CHiME-6 + start_time = x['start_time'] + end_time = x['end_time'] + + # remove meta chars and convert to lower + words = x['words'].replace('"', '')\ + .replace('.', '')\ + .replace('?', '')\ + .replace(',', '')\ + .replace(':', '')\ + .replace(';', '')\ + .replace('!', '').lower() + + # remove multiple spaces + words = " ".join(words.split()) + + # convert to seconds, e.g., 1:10:05.55 -> 3600 + 600 + 5.55 = 4205.55 + start_time = hms_to_seconds(start_time) + end_time = hms_to_seconds(end_time) + + uttid = speaker_id + '_' + session_id + if not args.mictype in ['worn', 'gss']: + uttid += '_' + mictype + + if args.mictype == 'gss': + uttid += '-' + start_time + '-' + end_time + else: + uttid += '_' + location + '-' + start_time + '-' + end_time + + # In several utterances, there are inconsistency in the time stamp + # (the end time is earlier than the start time) + # We just ignored such utterances. + if end_time > start_time: + sys.stdout.buffer.write((uttid + ' ' + words + '\n').encode("utf-8")) diff --git a/egs/chime6/s5_track1/local/make_noise_list.py b/egs/chime6/s5_track1/local/make_noise_list.py new file mode 100755 index 000000000..5aaf7fa40 --- /dev/null +++ b/egs/chime6/s5_track1/local/make_noise_list.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 + +import glob +import os +import sys + + +if len(sys.argv) != 2: + print ("Usage: {} ".format(sys.argv[0])) + raise SystemExit(1) + + +for line in glob.glob("{}/*.wav".format(sys.argv[1])): + fname = os.path.basename(line.strip()) + + print ("--noise-id {} --noise-type point-source " + "--bg-fg-type foreground {}".format(fname, line.strip())) diff --git a/egs/chime6/s5_track1/local/nnet3/compare_wer.sh b/egs/chime6/s5_track1/local/nnet3/compare_wer.sh new file mode 100755 index 000000000..095e85cc3 --- /dev/null +++ b/egs/chime6/s5_track1/local/nnet3/compare_wer.sh @@ -0,0 +1,132 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. local/chain/compare_wer.sh exp/chain/tdnn_{c,d}_sp +# For use with discriminatively trained systems you specify the epochs after a colon: +# for instance, +# local/chain/compare_wer.sh exp/chain/tdnn_c_sp exp/chain/tdnn_c_sp_smbr:{1,2,3} + + +if [ $# == 0 ]; then + echo "Usage: $0: [--looped] [--online] [ ... ]" + echo "e.g.: $0 exp/chain/tdnn_{b,c}_sp" + echo "or (with epoch numbers for discriminative training):" + echo "$0 exp/chain/tdnn_b_sp_disc:{1,2,3}" + exit 1 +fi + +echo "# $0 $*" + +include_looped=false +if [ "$1" == "--looped" ]; then + include_looped=true + shift +fi +include_online=false +if [ "$1" == "--online" ]; then + include_online=true + shift +fi + + +used_epochs=false + +# this function set_names is used to separate the epoch-related parts of the name +# [for discriminative training] and the regular parts of the name. +# If called with a colon-free directory name, like: +# set_names exp/chain/tdnn_lstm1e_sp_bi_smbr +# it will set dir=exp/chain/tdnn_lstm1e_sp_bi_smbr and epoch_infix="" +# If called with something like: +# set_names exp/chain/tdnn_d_sp_smbr:3 +# it will set dir=exp/chain/tdnn_d_sp_smbr and epoch_infix="_epoch3" + + +set_names() { + if [ $# != 1 ]; then + echo "compare_wer_general.sh: internal error" + exit 1 # exit the program + fi + dirname=$(echo $1 | cut -d: -f1) + epoch=$(echo $1 | cut -s -d: -f2) + if [ -z $epoch ]; then + epoch_infix="" + else + used_epochs=true + epoch_infix=_epoch${epoch} + fi +} + + + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +strings=( + "#WER dev_clean_2 (tgsmall) " + "#WER dev_clean_2 (tglarge) ") + +for n in 0 1; do + echo -n "${strings[$n]}" + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + decode_names=(tgsmall_dev_clean_2 tglarge_dev_clean_2) + + wer=$(cat $dirname/decode_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + if $include_looped; then + echo -n "# [looped:] " + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + wer=$(cat $dirname/decode_looped_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + fi + if $include_online; then + echo -n "# [online:] " + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + wer=$(cat ${dirname}_online/decode_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + fi +done + + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. +fi + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.{final,combined}.log 2>/dev/null | grep log-like | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.{final,combined}.log 2>/dev/null | grep log-like | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final train acc " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.{final,combined}.log 2>/dev/null | grep accuracy | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid acc " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.{final,combined}.log 2>/dev/null | grep accuracy | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo diff --git a/egs/chime6/s5_track1/local/nnet3/decode.sh b/egs/chime6/s5_track1/local/nnet3/decode.sh new file mode 100755 index 000000000..8fa54e0d4 --- /dev/null +++ b/egs/chime6/s5_track1/local/nnet3/decode.sh @@ -0,0 +1,164 @@ +#!/bin/bash + +# Copyright 2016 Johns Hopkins University (Author: Daniel Povey, Vijayaditya Peddinti) +# 2019 Vimal Manohar +# Apache 2.0. + +# This script does 2-stage decoding where the first stage is used to get +# reliable frames for i-vector extraction. + +set -e + +# general opts +iter= +stage=0 +nj=30 +affix= # affix for decode directory + +# ivector opts +max_count=75 # parameter for extract_ivectors.sh +sub_speaker_frames=6000 +ivector_scale=0.75 +get_weights_from_ctm=true +weights_file= # use weights from this archive (must be compressed using gunzip) +silence_weight=0.00001 # apply this weight to silence frames during i-vector extraction +ivector_dir=exp/nnet3 + +# decode opts +pass2_decode_opts="--min-active 1000" +lattice_beam=8 +extra_left_context=0 # change for (B)LSTM +extra_right_context=0 # change for BLSTM +frames_per_chunk=50 # change for (B)LSTM +acwt=0.1 # important to change this when using chain models +post_decode_acwt=1.0 # important to change this when using chain models +extra_left_context_initial=0 +extra_right_context_final=0 + +graph_affix= + +score_opts="--min-lmwt 6 --max-lmwt 13" + +. ./cmd.sh +[ -f ./path.sh ] && . ./path.sh +. utils/parse_options.sh || exit 1; + +if [ $# -ne 4 ]; then + echo "Usage: $0 [options] " + echo " Options:" + echo " --stage (0|1|2) # start scoring script from part-way through." + echo "e.g.:" + echo "$0 data/dev data/lang exp/tri5a/graph_pp exp/nnet3/tdnn" + exit 1; +fi + +data=$1 # data directory +lang=$2 # data/lang +graph=$3 #exp/tri5a/graph_pp +dir=$4 # exp/nnet3/tdnn + +model_affix=`basename $dir` +ivector_affix=${affix:+_$affix}_chain_${model_affix}${iter:+_iter$iter} +affix=${affix:+_${affix}}${iter:+_iter${iter}} + +if [ $stage -le 1 ]; then + if [ ! -s ${data}_hires/feats.scp ]; then + utils/copy_data_dir.sh $data ${data}_hires + steps/make_mfcc.sh --mfcc-config conf/mfcc_hires.conf --nj $nj --cmd "$train_cmd" ${data}_hires + steps/compute_cmvn_stats.sh ${data}_hires + utils/fix_data_dir.sh ${data}_hires + fi +fi + +data_set=$(basename $data) +if [ $stage -le 2 ]; then + echo "Extracting i-vectors, stage 1" + steps/online/nnet2/extract_ivectors_online.sh --cmd "$train_cmd" --nj $nj \ + --max-count $max_count \ + ${data}_hires $ivector_dir/extractor \ + $ivector_dir/ivectors_${data_set}${ivector_affix}_stage1; + # float comparisons are hard in bash + if [ `bc <<< "$ivector_scale != 1"` -eq 1 ]; then + ivector_scale_affix=_scale$ivector_scale + else + ivector_scale_affix= + fi + + if [ ! -z "$ivector_scale_affix" ]; then + echo "$0: Scaling iVectors, stage 1" + srcdir=$ivector_dir/ivectors_${data_set}${ivector_affix}_stage1 + outdir=$ivector_dir/ivectors_${data_set}${ivector_affix}${ivector_scale_affix}_stage1 + mkdir -p $outdir + $train_cmd $outdir/log/scale_ivectors.log \ + copy-matrix --scale=$ivector_scale scp:$srcdir/ivector_online.scp ark:- \| \ + copy-feats --compress=true ark:- ark,scp:$outdir/ivector_online.ark,$outdir/ivector_online.scp; + cp $srcdir/ivector_period $outdir/ivector_period + fi +fi + +decode_dir=$dir/decode${graph_affix}_${data_set}${affix} +# generate the lattices +if [ $stage -le 3 ]; then + echo "Generating lattices, stage 1" + steps/nnet3/decode.sh --nj $nj --cmd "$decode_cmd" \ + --acwt $acwt --post-decode-acwt $post_decode_acwt \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --extra-left-context-initial $extra_left_context_initial \ + --extra-right-context-final $extra_right_context_final \ + --frames-per-chunk "$frames_per_chunk" \ + --online-ivector-dir $ivector_dir/ivectors_${data_set}${ivector_affix}${ivector_scale_affix}_stage1 \ + --skip-scoring true ${iter:+--iter $iter} \ + $graph ${data}_hires ${decode_dir}_stage1; +fi + +if [ $stage -le 4 ]; then + if $get_weights_from_ctm; then + if [ ! -z $weights_file ]; then + echo "$0: Using provided vad weights file $weights_file" + ivector_extractor_weights=$weights_file + else + echo "$0 : Generating vad weights file" + ivector_extractor_weights=${decode_dir}_stage1/weights${affix}.gz + local/extract_vad_weights.sh --silence-weight $silence_weight \ + --cmd "$decode_cmd" ${iter:+--iter $iter} \ + ${data}_hires $lang \ + ${decode_dir}_stage1 $ivector_extractor_weights + fi + else + # get weights from best path decoding + ivector_extractor_weights=${decode_dir}_stage1 + fi +fi + +if [ $stage -le 5 ]; then + echo "Extracting i-vectors, stage 2 with weights from $ivector_extractor_weights" + # this does offline decoding, except we estimate the iVectors per + # speaker, excluding silence (based on alignments from a DNN decoding), with a + # different script. This is just to demonstrate that script. + # the --sub-speaker-frames is optional; if provided, it will divide each speaker + # up into "sub-speakers" of at least that many frames... can be useful if + # acoustic conditions drift over time within the speaker's data. + steps/online/nnet2/extract_ivectors.sh --cmd "$train_cmd" --nj $nj \ + --silence-weight $silence_weight \ + --sub-speaker-frames $sub_speaker_frames --max-count $max_count \ + ${data}_hires $lang $ivector_dir/extractor \ + $ivector_extractor_weights $ivector_dir/ivectors_${data_set}${ivector_affix}; +fi + +if [ $stage -le 6 ]; then + echo "Generating lattices, stage 2 with --acwt $acwt" + rm -f ${decode_dir}/.error + steps/nnet3/decode.sh --nj $nj --cmd "$decode_cmd" $pass2_decode_opts \ + --acwt $acwt --post-decode-acwt $post_decode_acwt \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --extra-left-context-initial $extra_left_context_initial \ + --extra-right-context-final $extra_right_context_final \ + --frames-per-chunk "$frames_per_chunk" \ + --skip-scoring false ${iter:+--iter $iter} --lattice-beam $lattice_beam \ + --online-ivector-dir $ivector_dir/ivectors_${data_set}${ivector_affix} \ + $graph ${data}_hires ${decode_dir} || touch ${decode_dir}/.error + [ -f ${decode_dir}/.error ] && echo "$0: Error decoding" && exit 1; +fi +exit 0 diff --git a/egs/chime6/s5_track1/local/nnet3/run_ivector_common.sh b/egs/chime6/s5_track1/local/nnet3/run_ivector_common.sh new file mode 100755 index 000000000..3910e1812 --- /dev/null +++ b/egs/chime6/s5_track1/local/nnet3/run_ivector_common.sh @@ -0,0 +1,151 @@ +#!/bin/bash + +set -euo pipefail + +# This script is called from local/nnet3/run_tdnn.sh and +# local/chain/run_tdnn.sh (and may eventually be called by more +# scripts). It contains the common feature preparation and +# iVector-related parts of the script. See those scripts for examples +# of usage. + +stage=0 +train_set=train_worn_u100k +test_sets="dev_worn dev_beamformit_ref" +gmm=tri3 +nj=96 + +nnet3_affix=_train_worn_u100k + +. ./cmd.sh +. ./path.sh +. utils/parse_options.sh + +gmm_dir=exp/${gmm} +ali_dir=exp/${gmm}_ali_${train_set}_sp + +for f in ${gmm_dir}/final.mdl; do + if [ ! -f $f ]; then + echo "$0: expected file $f to exist" + exit 1 + fi +done + +if [ $stage -le 1 ]; then + # Although the nnet will be trained by high resolution data, we still have to + # perturb the normal data to get the alignment _sp stands for speed-perturbed + echo "$0: preparing directory for low-resolution speed-perturbed data (for alignment)" + utils/data/perturb_data_dir_speed_3way.sh data/${train_set} data/${train_set}_sp + echo "$0: making MFCC features for low-resolution speed-perturbed data" + steps/make_mfcc.sh --cmd "$train_cmd" --nj 20 data/${train_set}_sp || exit 1; + steps/compute_cmvn_stats.sh data/${train_set}_sp || exit 1; + utils/fix_data_dir.sh data/${train_set}_sp +fi + +if [ $stage -le 2 ]; then + echo "$0: aligning with the perturbed low-resolution data" + steps/align_fmllr.sh --nj ${nj} --cmd "$train_cmd" \ + data/${train_set}_sp data/lang $gmm_dir $ali_dir || exit 1 +fi + +if [ $stage -le 3 ]; then + # Create high-resolution MFCC features (with 40 cepstra instead of 13). + # this shows how you can split across multiple file-systems. + echo "$0: creating high-resolution MFCC features" + mfccdir=data/${train_set}_sp_hires/data + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl /export/b1{5,6,8,9}/$USER/kaldi-data/mfcc/chime5-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage + fi + + for datadir in ${train_set}_sp ${test_sets}; do + utils/copy_data_dir.sh data/$datadir data/${datadir}_hires + done + + # do volume-perturbation on the training data prior to extracting hires + # features; this helps make trained nnets more invariant to test data volume. + utils/data/perturb_data_dir_volume.sh data/${train_set}_sp_hires || exit 1; + + for datadir in ${train_set}_sp ${test_sets}; do + steps/make_mfcc.sh --nj 20 --mfcc-config conf/mfcc_hires.conf \ + --cmd "$train_cmd" data/${datadir}_hires || exit 1; + steps/compute_cmvn_stats.sh data/${datadir}_hires || exit 1; + utils/fix_data_dir.sh data/${datadir}_hires || exit 1; + done +fi + +if [ $stage -le 4 ]; then + echo "$0: computing a subset of data to train the diagonal UBM." + # We'll use about a quarter of the data. + mkdir -p exp/nnet3${nnet3_affix}/diag_ubm + temp_data_root=exp/nnet3${nnet3_affix}/diag_ubm + + num_utts_total=$(wc -l &2 "$0" "$@" +if [ $# -ne 3 ] ; then + echo >&2 "$0" "$@" + echo >&2 "$0: Error: wrong number of arguments" + echo -e >&2 "Usage:\n $0 [opts] " + echo -e >&2 "eg:\n $0 /corpora/chime5/audio/train /corpora/chime5/transcriptions/train data/train" + exit 1 +fi + +set -e -o pipefail + +adir=$(utils/make_absolute.sh $1) +jdir=$2 +dir=$3 + +json_count=$(find -L $jdir -name "*.json" | wc -l) +wav_count=$(find -L $adir -name "*.wav" | wc -l) + +if [ "$json_count" -eq 0 ]; then + echo >&2 "We expect that the directory $jdir will contain json files." + echo >&2 "That implies you have supplied a wrong path to the data." + exit 1 +fi +if [ "$wav_count" -eq 0 ]; then + echo >&2 "We expect that the directory $adir will contain wav files." + echo >&2 "That implies you have supplied a wrong path to the data." + exit 1 +fi + +echo "$0: Converting transcription to text" + +mkdir -p $dir + +for file in $jdir/*json; do + ./local/json2text.py --mictype $mictype $file +done | \ + sed -e "s/\[inaudible[- 0-9]*\]/[inaudible]/g" |\ + sed -e 's/ - / /g' |\ + sed -e 's/mm-/mm/g' > $dir/text.orig + +echo "$0: Creating datadir $dir for type=\"$mictype\"" + +if [ $mictype == "worn" ]; then + # convert the filenames to wav.scp format, use the basename of the file + # as a the wav.scp key, add .L and .R for left and right channel + # i.e. each file will have two entries (left and right channel) + find -L $adir -name "S[0-9]*_P[0-9]*.wav" | \ + perl -ne '{ + chomp; + $path = $_; + next unless $path; + @F = split "/", $path; + ($f = $F[@F-1]) =~ s/.wav//; + @F = split "_", $f; + print "${F[1]}_${F[0]}.L sox $path -t wav - remix 1 |\n"; + print "${F[1]}_${F[0]}.R sox $path -t wav - remix 2 |\n"; + }' | sort > $dir/wav.scp + + # generate the transcripts for both left and right channel + # from the original transcript in the form + # P09_S03-0006072-0006147 gimme the baker + # create left and right channel transcript + # P09_S03.L-0006072-0006147 gimme the baker + # P09_S03.R-0006072-0006147 gimme the baker + sed -n 's/ *$//; h; s/-/\.L-/p; g; s/-/\.R-/p' $dir/text.orig | sort > $dir/text +elif [ $mictype == "ref" ]; then + # fixed reference array + + # first get a text, which will be used to extract reference arrays + perl -ne 's/-/.ENH-/;print;' $dir/text.orig | sort > $dir/text + + find -L $adir | grep "\.wav" | sort > $dir/wav.flist + # following command provide the argument for grep to extract only reference arrays + grep `cut -f 1 -d"-" $dir/text | awk -F"_" '{print $2 "_" $3}' | sed -e "s/\.ENH//" | sort | uniq | sed -e "s/^/ -e /" | tr "\n" " "` $dir/wav.flist > $dir/wav.flist2 + paste -d" " \ + <(awk -F "/" '{print $NF}' $dir/wav.flist2 | sed -e "s/\.wav/.ENH/") \ + $dir/wav.flist2 | sort > $dir/wav.scp +elif [ $mictype == "gss" ]; then + find -L $adir -name "P[0-9]*_S[0-9]*.wav" | \ + perl -ne '{ + chomp; + $path = $_; + next unless $path; + @F = split "/", $path; + ($f = $F[@F-1]) =~ s/.wav//; + print "$f $path\n"; + }' | sort > $dir/wav.scp + + cat $dir/text.orig | sort > $dir/text +else + # array mic case + # convert the filenames to wav.scp format, use the basename of the file + # as a the wav.scp key + find -L $adir -name "*.wav" -ipath "*${mictype}*" |\ + perl -ne '$p=$_;chomp $_;@F=split "/";$F[$#F]=~s/\.wav//;print "$F[$#F] $p";' |\ + sort -u > $dir/wav.scp + + # convert the transcripts from + # P09_S03-0006072-0006147 gimme the baker + # to the per-channel transcripts + # P09_S03_U01_NOLOCATION.CH1-0006072-0006147 gimme the baker + # P09_S03_U01_NOLOCATION.CH2-0006072-0006147 gimme the baker + # P09_S03_U01_NOLOCATION.CH3-0006072-0006147 gimme the baker + # P09_S03_U01_NOLOCATION.CH4-0006072-0006147 gimme the baker + perl -ne '$l=$_; + for($i=1; $i<=4; $i++) { + ($x=$l)=~ s/-/.CH\Q$i\E-/; + print $x;}' $dir/text.orig | sort > $dir/text + +fi +$cleanup && rm -f $dir/text.* $dir/wav.scp.* $dir/wav.flist + +# Prepare 'segments', 'utt2spk', 'spk2utt' +if [ $mictype == "worn" ]; then + cut -d" " -f 1 $dir/text | \ + awk -F"-" '{printf("%s %s %08.2f %08.2f\n", $0, $1, $2/100.0, $3/100.0)}' |\ + sed -e "s/_[A-Z]*\././2" \ + > $dir/segments +elif [ $mictype == "ref" ]; then + cut -d" " -f 1 $dir/text | \ + awk -F"-" '{printf("%s %s %08.2f %08.2f\n", $0, $1, $2/100.0, $3/100.0)}' |\ + sed -e "s/_[A-Z]*\././2" |\ + sed -e "s/ P.._/ /" > $dir/segments +elif [ $mictype != "gss" ]; then + cut -d" " -f 1 $dir/text | \ + awk -F"-" '{printf("%s %s %08.2f %08.2f\n", $0, $1, $2/100.0, $3/100.0)}' |\ + sed -e "s/_[A-Z]*\././2" |\ + sed -e 's/ P.._/ /' > $dir/segments +fi + +cut -f 1 -d ' ' $dir/text | \ + perl -ne 'chomp;$utt=$_;s/_.*//;print "$utt $_\n";' > $dir/utt2spk + +utils/utt2spk_to_spk2utt.pl $dir/utt2spk > $dir/spk2utt + +# Check that data dirs are okay! +utils/validate_data_dir.sh --no-feats $dir || exit 1 diff --git a/egs/chime6/s5_track1/local/prepare_dict.sh b/egs/chime6/s5_track1/local/prepare_dict.sh new file mode 100755 index 000000000..09083d0e7 --- /dev/null +++ b/egs/chime6/s5_track1/local/prepare_dict.sh @@ -0,0 +1,124 @@ +#!/bin/bash +# Copyright (c) 2018, Johns Hopkins University (Jan "Yenda" Trmal) +# License: Apache 2.0 + +# Begin configuration section. +# End configuration section +. ./utils/parse_options.sh + +. ./path.sh + +set -e -o pipefail +set -o nounset # Treat unset variables as an error + + +# The parts of the output of this that will be needed are +# [in data/local/dict/ ] +# lexicon.txt +# extra_questions.txt +# nonsilence_phones.txt +# optional_silence.txt +# silence_phones.txt + + +# check existing directories +[ $# != 0 ] && echo "Usage: $0" && exit 1; + +dir=data/local/dict + +mkdir -p $dir +echo "$0: Getting CMU dictionary" +if [ ! -f $dir/cmudict.done ]; then + [ -d $dir/cmudict ] && rm -rf $dir/cmudict + svn co https://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict $dir/cmudict + touch $dir/cmudict.done +fi + +# silence phones, one per line. +for w in sil spn inaudible laughs noise; do + echo $w; +done > $dir/silence_phones.txt +echo sil > $dir/optional_silence.txt + +# For this setup we're discarding stress. +cat $dir/cmudict/cmudict-0.7b.symbols | \ + perl -ne 's:[0-9]::g; s:\r::; print lc($_)' | \ + sort -u > $dir/nonsilence_phones.txt + +# An extra question will be added by including the silence phones in one class. +paste -d ' ' -s $dir/silence_phones.txt > $dir/extra_questions.txt + +grep -v ';;;' $dir/cmudict/cmudict-0.7b |\ + uconv -f latin1 -t utf-8 -x Any-Lower |\ + perl -ne 's:(\S+)\(\d+\) :$1 :; s: : :; print;' |\ + perl -ne '@F = split " ",$_,2; $F[1] =~ s/[0-9]//g; print "$F[0] $F[1]";' \ + > $dir/lexicon1_raw_nosil.txt || exit 1; + +# Add prons for laughter, noise, oov +for w in `grep -v sil $dir/silence_phones.txt`; do + echo "[$w] $w" +done | cat - $dir/lexicon1_raw_nosil.txt > $dir/lexicon2_raw.txt || exit 1; + +# we keep all words from the cmudict in the lexicon +# might reduce OOV rate on dev and eval +cat $dir/lexicon2_raw.txt \ + <( echo "mm m" + echo " spn" + echo "cuz k aa z" + echo "cuz k ah z" + echo "cuz k ao z" + echo "mmm m"; \ + echo "hmm hh m"; \ + ) | sort -u | sed 's/[\t ]/\t/' > $dir/iv_lexicon.txt + + +cat data/train*/text | \ + awk '{for (n=2;n<=NF;n++){ count[$n]++; } } END { for(n in count) { print count[n], n; }}' | \ + sort -nr > $dir/word_counts + +cat $dir/word_counts | awk '{print $2}' > $dir/word_list + +awk '{print $1}' $dir/iv_lexicon.txt | \ + perl -e '($word_counts)=@ARGV; + open(W, "<$word_counts")||die "opening word-counts $word_counts"; + while() { chop; $seen{$_}=1; } + while() { + ($c,$w) = split; + if (!defined $seen{$w}) { print; } + } ' $dir/word_counts > $dir/oov_counts.txt + +echo "*Highest-count OOVs (including fragments) are:" +head -n 10 $dir/oov_counts.txt +echo "*Highest-count OOVs (excluding fragments) are:" +grep -v -E '^-|-$' $dir/oov_counts.txt | head -n 10 || true + +echo "*Training a G2P and generating missing pronunciations" +mkdir -p $dir/g2p/ +phonetisaurus-align --input=$dir/iv_lexicon.txt --ofile=$dir/g2p/aligned_lexicon.corpus +ngram-count -order 4 -kn-modify-counts-at-end -ukndiscount\ + -gt1min 0 -gt2min 0 -gt3min 0 -gt4min 0 \ + -text $dir/g2p/aligned_lexicon.corpus -lm $dir/g2p/aligned_lexicon.arpa +phonetisaurus-arpa2wfst --lm=$dir/g2p/aligned_lexicon.arpa --ofile=$dir/g2p/g2p.fst +awk '{print $2}' $dir/oov_counts.txt > $dir/oov_words.txt +phonetisaurus-apply --nbest 2 --model $dir/g2p/g2p.fst --thresh 5 --accumulate \ + --word_list $dir/oov_words.txt > $dir/oov_lexicon.txt + +## The next section is again just for debug purposes +## to show words for which the G2P failed +cat $dir/oov_lexicon.txt $dir/iv_lexicon.txt | sort -u > $dir/lexicon.txt +rm -f $dir/lexiconp.txt 2>/dev/null; # can confuse later script if this exists. +awk '{print $1}' $dir/lexicon.txt | \ + perl -e '($word_counts)=@ARGV; + open(W, "<$word_counts")||die "opening word-counts $word_counts"; + while() { chop; $seen{$_}=1; } + while() { + ($c,$w) = split; + if (!defined $seen{$w}) { print; } + } ' $dir/word_counts > $dir/oov_counts.g2p.txt + +echo "*Highest-count OOVs (including fragments) after G2P are:" +head -n 10 $dir/oov_counts.g2p.txt + +utils/validate_dict_dir.pl $dir +exit 0; + diff --git a/egs/chime6/s5_track1/local/replace_uttid.py b/egs/chime6/s5_track1/local/replace_uttid.py new file mode 100755 index 000000000..96c45b587 --- /dev/null +++ b/egs/chime6/s5_track1/local/replace_uttid.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +# Copyright Ashish Arora +# Apache 2.0 +# This script is used in score_for_submit. It adds locationid to the utteranceid, +# using uttid_location file, for locationwise scoring. + +import sys, io +output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') + +def load_uttid_location(f): + locations = {} + for line in f: + parts=line.strip().split(' ') + uttid, loc = parts[0], parts[1] + locations[uttid] = loc + return locations + +locations = load_uttid_location(open(sys.argv[1],'r', encoding='utf8')) + +for line in open(sys.argv[2],'r', encoding='utf8'): + uttid, res = line.split(None, 1) + try: + location = locations[uttid] + location_uttid = location +'_'+ str(uttid) + output.write(location_uttid + ' ' + res) + except KeyError as e: + raise Exception("Could not find utteranceid in " + "uttid_location file" + "({0})\n".format(str(e))) diff --git a/egs/chime6/s5_track1/local/reverberate_lat_dir.sh b/egs/chime6/s5_track1/local/reverberate_lat_dir.sh new file mode 100755 index 000000000..f601a37c0 --- /dev/null +++ b/egs/chime6/s5_track1/local/reverberate_lat_dir.sh @@ -0,0 +1,93 @@ +#!/bin/bash + +# Copyright 2018 Vimal Manohar +# Apache 2.0 + +num_data_reps=1 +cmd=run.pl +nj=20 +include_clean=false + +. utils/parse_options.sh +. ./path.sh + +if [ $# -ne 4 ]; then + echo "Usage: $0 " + exit 1 +fi + +train_data_dir=$1 +noisy_latdir=$2 +clean_latdir=$3 +dir=$4 + +clean_nj=$(cat $clean_latdir/num_jobs) + +$cmd JOB=1:$clean_nj $dir/copy_clean_lattices.JOB.log \ + lattice-copy "ark:gunzip -c $clean_latdir/lat.JOB.gz |" \ + ark,scp:$dir/lats_clean.JOB.ark,$dir/lats_clean.JOB.scp || exit 1 + +for n in $(seq $clean_nj); do + cat $dir/lats_clean.$n.scp +done > $dir/lats_clean.scp + +for i in $(seq $num_data_reps); do + cat $dir/lats_clean.scp | awk -vi=$i '{print "rev"i"_"$0}' +done > $dir/lats_rvb.scp + +noisy_nj=$(cat $noisy_latdir/num_jobs) +$cmd JOB=1:$noisy_nj $dir/copy_noisy_lattices.JOB>log \ + lattice-copy "ark:gunzip -c $noisy_latdir/lat.JOB.gz |" \ + ark,scp:$dir/lats_noisy.JOB.ark,$dir/lats_noisy.JOB.scp || exit 1 + +optional_clean= +if $include_clean; then + optional_clean=$dir/lats_clean.scp +fi + +for n in $(seq $noisy_nj); do + cat $dir/lats_noisy.$n.scp +done | cat - $dir/lats_rvb.scp ${optional_clean} | sort -k1,1 > $dir/lats.scp + +utils/split_data.sh $train_data_dir $nj +$cmd JOB=1:$nj $dir/copy_lattices.JOB.log \ + lattice-copy "scp:utils/filter_scp.pl $train_data_dir/split$nj/JOB/utt2spk $dir/lats.scp |" \ + "ark:|gzip -c >$dir/lat.JOB.gz" || exit 1 + +echo $nj > $dir/num_jobs + +if [ -f $clean_latdir/ali.1.gz ]; then + $cmd JOB=1:$clean_nj $dir/copy_clean_alignments.JOB.log \ + copy-int-vector "ark:gunzip -c $clean_latdir/ali.JOB.gz |" \ + ark,scp:$dir/ali_clean.JOB.ark,$dir/ali_clean.JOB.scp + + for n in $(seq $clean_nj); do + cat $dir/ali_clean.$n.scp + done > $dir/ali_clean.scp + + for i in $(seq $num_data_reps); do + cat $dir/ali_clean.scp | awk -vi=$i '{print "rev"i"_"$0}' + done > $dir/ali_rvb.scp + + optional_clean= + if $include_clean; then + optional_clean=$dir/ali_clean.scp + fi + + $cmd JOB=1:$noisy_nj $dir/copy_noisy_alignments.JOB.log \ + copy-int-vector "ark:gunzip -c $noisy_latdir/ali.JOB.gz |" \ + ark,scp:$dir/ali_noisy.JOB.ark,$dir/ali_noisy.JOB.scp + + for n in $(seq $noisy_nj); do + cat $dir/ali_noisy.$n.scp + done | cat - $dir/ali_rvb.scp $optional_clean | sort -k1,1 > $dir/ali.scp + + utils/split_data.sh $train_data_dir $nj || exit 1 + $cmd JOB=1:$nj $dir/copy_rvb_alignments.JOB.log \ + copy-int-vector "scp:utils/filter_scp.pl $train_data_dir/split$nj/JOB/utt2spk $dir/ali.scp |" \ + "ark:|gzip -c >$dir/ali.JOB.gz" || exit 1 +fi + +cp $clean_latdir/{final.*,tree,*.mat,*opts,*.txt} $dir || true + +rm $dir/lats_{clean,noisy}.*.{ark,scp} $dir/ali_{clean,noisy}.*.{ark,scp} || true # save space diff --git a/egs/chime6/s5_track1/local/run_beamformit.sh b/egs/chime6/s5_track1/local/run_beamformit.sh new file mode 100755 index 000000000..aa3badd90 --- /dev/null +++ b/egs/chime6/s5_track1/local/run_beamformit.sh @@ -0,0 +1,87 @@ +#!/bin/bash + +# Copyright 2015, Mitsubishi Electric Research Laboratories, MERL (Author: Shinji Watanabe) + +. ./cmd.sh +. ./path.sh + +# Config: +cmd=run.pl +bmf="1 2 3 4" + +. utils/parse_options.sh || exit 1; + +if [ $# != 3 ]; then + echo "Wrong #arguments ($#, expected 3)" + echo "Usage: local/run_beamformit.sh [options] " + echo "main options (for others, see top of script file)" + echo " --cmd # Command to run in parallel with" + echo " --bmf \"1 2 3 4\" # microphones used for beamforming" + exit 1; +fi + +sdir=$1 +odir=$2 +array=$3 +expdir=exp/enhan/`echo $odir | awk -F '/' '{print $NF}'`_`echo $bmf | tr ' ' '_'` + +if ! command -v BeamformIt &>/dev/null ; then + echo "Missing BeamformIt, run 'cd $KALDI_ROOT/tools/; ./extras/install_beamformit.sh; cd -;'" && exit 1 +fi + +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +mkdir -p $odir +mkdir -p $expdir/log + +echo "Will use the following channels: $bmf" +# number of channels +numch=`echo $bmf | tr ' ' '\n' | wc -l` +echo "the number of channels: $numch" + +# wavfiles.list can be used as the name of the output files +output_wavfiles=$expdir/wavfiles.list +find -L ${sdir} | grep -i ${array} | awk -F "/" '{print $NF}' | sed -e "s/\.CH.\.wav//" | sort | uniq > $expdir/wavfiles.list + +# this is an input file list of the microphones +# format: 1st_wav 2nd_wav ... nth_wav +input_arrays=$expdir/channels_$numch +for x in `cat $output_wavfiles`; do + echo -n "$x" + for ch in $bmf; do + echo -n " $x.CH$ch.wav" + done + echo "" +done > $input_arrays + +# split the list for parallel processing +# number of jobs are set by the number of WAV files +nj=`wc -l $expdir/wavfiles.list | awk '{print $1}'` +split_wavfiles="" +for n in `seq $nj`; do + split_wavfiles="$split_wavfiles $output_wavfiles.$n" +done +utils/split_scp.pl $output_wavfiles $split_wavfiles || exit 1; + +echo -e "Beamforming\n" +# making a shell script for each job +for n in `seq $nj`; do +cat << EOF > $expdir/log/beamform.$n.sh +while read line; do + $BEAMFORMIT/BeamformIt -s \$line -c $input_arrays \ + --config_file `pwd`/conf/beamformit.cfg \ + --source_dir $sdir \ + --result_dir $odir +done < $output_wavfiles.$n +EOF +done + +chmod a+x $expdir/log/beamform.*.sh +$cmd JOB=1:$nj $expdir/log/beamform.JOB.log \ + $expdir/log/beamform.JOB.sh + +echo "`basename $0` Done." diff --git a/egs/chime6/s5_track1/local/run_gss.sh b/egs/chime6/s5_track1/local/run_gss.sh new file mode 100755 index 000000000..fbdc4af25 --- /dev/null +++ b/egs/chime6/s5_track1/local/run_gss.sh @@ -0,0 +1,65 @@ +#!/bin/bash + +# Copyright 2015, Mitsubishi Electric Research Laboratories, MERL (Author: Shinji Watanabe) + +. ./cmd.sh +if [ -f ./path.sh ]; then . ./path.sh; fi + +# Config: +cmd=run.pl +nj=4 +multiarray=outer_array_mics +bss_iterations=5 +context_samples=160000 +. utils/parse_options.sh || exit 1; + +if [ $# != 3 ]; then + echo "Wrong #arguments ($#, expected 3)" + echo "Usage: local/run_gss.sh [options] " + echo "main options (for others, see top of script file)" + echo " --cmd # Command to run in parallel with" + echo " --bss_iterations 5 # Number of EM iterations" + echo " --context_samples 160000 # Left-right context in number of samples" + echo " --multiarray # Multiarray configuration" + exit 1; +fi + +# setting multiarray as "true" uses all mics, we didn't see any performance +# gain from this we have chosen settings that makes the enhacement finish +# in around 1/3 of a day without significant change in performance. +# our result during the experiments are as follows: + +#MAF: multi array = False +#MAT: multi array = True +#Enhancement Iterations Num Microphones Context Computational time for GSS #cpus dev WER eval WER +#GSS(MAF) 10 24 17 hrs 30 62.3 57.98 +#GSS(MAT) 5 24 10s 26 hrs 50 53.15 53.77 +#GSS(MAT) 5 12 10s 9.5 hrs 50 53.09 53.75 + +session_id=$1 +log_dir=$2 +enhanced_dir=$3 +if [ ! -d pb_chime5/ ]; then + echo "Missing pb_chime5, run 'local/install_pb_chime5'" + exit 1 +fi + +miniconda_dir=$HOME/miniconda3/ +if [ ! -d $miniconda_dir/ ]; then + echo "$miniconda_dir/ does not exist. Please run '../../../tools/extras/install_miniconda.sh'" + exit 1 +fi + +enhanced_dir=$(utils/make_absolute.sh $enhanced_dir) || \ + { echo "Could not make absolute '$enhanced_dir'" && exit 1; } + +$cmd JOB=1:$nj $log_dir/log/enhance_${session_id}.JOB.log \ + cd pb_chime5/ '&&' \ + $miniconda_dir/bin/python -m pb_chime5.scripts.kaldi_run with \ + chime6=True \ + storage_dir=$enhanced_dir \ + session_id=$session_id \ + job_id=JOB number_of_jobs=$nj \ + bss_iterations=$bss_iterations \ + context_samples=$context_samples \ + multiarray=$multiarray || exit 1 diff --git a/egs/chime6/s5_track1/local/run_wpe.py b/egs/chime6/s5_track1/local/run_wpe.py new file mode 100755 index 000000000..fbb264f2f --- /dev/null +++ b/egs/chime6/s5_track1/local/run_wpe.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python +# Copyright 2018 Johns Hopkins University (Author: Aswin Shanmugam Subramanian) +# Apache 2.0 +# Works with both python2 and python3 +# This script assumes that WPE (nara_wpe) is installed locally using miniconda. +# ../../../tools/extras/install_miniconda.sh and ../../../tools/extras/install_wpe.sh +# needs to be run and this script needs to be launched run with that version of +# python. +# See local/run_wpe.sh for example. + +import numpy as np +import soundfile as sf +import time +import os, errno +from tqdm import tqdm +import argparse + +# to avoid huge memory consumption we decided to use `wpe_v8` instead of the original wpe by +# following the advice from Christoph Boeddeker at Paderborn University +# https://github.com/chimechallenge/kaldi_chime6/commit/2ea6ac07ef66ad98602f073b24a233cb7f61605c#r36147334 +from nara_wpe.wpe import wpe_v8 as wpe +from nara_wpe.utils import stft, istft +from nara_wpe import project_root + +parser = argparse.ArgumentParser() +parser.add_argument('--files', '-f', nargs='+') +args = parser.parse_args() + +input_files = args.files[:len(args.files)//2] +output_files = args.files[len(args.files)//2:] +out_dir = os.path.dirname(output_files[0]) +try: + os.makedirs(out_dir) +except OSError as e: + if e.errno != errno.EEXIST: + raise + +stft_options = dict( + size=512, + shift=128, + window_length=None, + fading=True, + pad=True, + symmetric_window=False +) + +sampling_rate = 16000 +delay = 3 +iterations = 5 +taps = 10 + +signal_list = [ + sf.read(f)[0] + for f in input_files +] +y = np.stack(signal_list, axis=0) +Y = stft(y, **stft_options).transpose(2, 0, 1) +Z = wpe(Y, iterations=iterations, statistics_mode='full').transpose(1, 2, 0) +z = istft(Z, size=stft_options['size'], shift=stft_options['shift']) + +for d in range(len(signal_list)): + sf.write(output_files[d], z[d,:], sampling_rate) diff --git a/egs/chime6/s5_track1/local/run_wpe.sh b/egs/chime6/s5_track1/local/run_wpe.sh new file mode 100755 index 000000000..ed512e69a --- /dev/null +++ b/egs/chime6/s5_track1/local/run_wpe.sh @@ -0,0 +1,86 @@ +#!/bin/bash +# Copyright 2018 Johns Hopkins University (Author: Aswin Shanmugam Subramanian) +# Apache 2.0 + +. ./cmd.sh +. ./path.sh + +# Config: +nj=4 +cmd=run.pl + +. utils/parse_options.sh || exit 1; + +if [ $# != 3 ]; then + echo "Wrong #arguments ($#, expected 3)" + echo "Usage: local/run_wpe.sh [options] " + echo "main options (for others, see top of script file)" + echo " --cmd # Command to run in parallel with" + echo " --nj 50 # number of jobs for parallel processing" + exit 1; +fi + +sdir=$1 +odir=$2 +array=$3 +task=`basename $sdir` +expdir=exp/wpe/${task}_${array} +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +miniconda_dir=$HOME/miniconda3/ +if [ ! -d $miniconda_dir ]; then + echo "$miniconda_dir does not exist. Please run '$KALDI_ROOT/tools/extras/install_miniconda.sh'." + exit 1 +fi + +# check if WPE is installed +result=`$miniconda_dir/bin/python -c "\ +try: + import nara_wpe + print('1') +except ImportError: + print('0')"` + +if [ "$result" == "1" ]; then + echo "WPE is installed" +else + echo "WPE is not installed. Please run ../../../tools/extras/install_wpe.sh" + exit 1 +fi + +mkdir -p $odir +mkdir -p $expdir/log + +# wavfiles.list can be used as the name of the output files +output_wavfiles=$expdir/wavfiles.list +find -L ${sdir} | grep -i ${array} > $expdir/channels_input +cat $expdir/channels_input | awk -F '/' '{print $NF}' | sed "s@S@$odir\/S@g" > $expdir/channels_output +paste -d" " $expdir/channels_input $expdir/channels_output > $output_wavfiles + +# split the list for parallel processing +split_wavfiles="" +for n in `seq $nj`; do + split_wavfiles="$split_wavfiles $output_wavfiles.$n" +done +utils/split_scp.pl $output_wavfiles $split_wavfiles || exit 1; + +echo -e "Dereverberation - $task - $array\n" +# making a shell script for each job +for n in `seq $nj`; do +cat <<-EOF > $expdir/log/wpe.$n.sh +while read line; do + $miniconda_dir/bin/python local/run_wpe.py \ + --file \$line +done < $output_wavfiles.$n +EOF +done + +chmod a+x $expdir/log/wpe.*.sh +$cmd JOB=1:$nj $expdir/log/wpe.JOB.log \ + $expdir/log/wpe.JOB.sh + +echo "`basename $0` Done." diff --git a/egs/chime6/s5_track1/local/score.sh b/egs/chime6/s5_track1/local/score.sh new file mode 120000 index 000000000..6a200b42e --- /dev/null +++ b/egs/chime6/s5_track1/local/score.sh @@ -0,0 +1 @@ +../steps/scoring/score_kaldi_wer.sh \ No newline at end of file diff --git a/egs/chime6/s5_track1/local/score_for_submit.sh b/egs/chime6/s5_track1/local/score_for_submit.sh new file mode 100755 index 000000000..ba7d6cde5 --- /dev/null +++ b/egs/chime6/s5_track1/local/score_for_submit.sh @@ -0,0 +1,132 @@ +#!/bin/bash +# Copyright 2012-2014 Johns Hopkins University (Author: Daniel Povey, Yenda Trmal) +# Copyright 2019 Johns Hopkins University (Author: Shinji Watanabe) +# Apache 2.0 +# +# This script provides official CHiME-6 challenge track 1 submission scores per room and session. +# It first calculates the best search parameter configurations by using the dev set +# and also create the transcriptions for dev and eval sets to be submitted. +# The default setup does not calculate scores of the evaluation set since +# the evaluation transcription is not distributed (July 9 2018) + +cmd=run.pl +dev=exp/chain_train_worn_u100k_cleaned/tdnn1a_sp/decode_dev_beamformit_ref +eval=exp/chain_train_worn_u100k_cleaned/tdnn1a_sp/decode_eval_beamformit_ref +do_eval=true +enhancement=gss +json= + +echo "$0 $@" # Print the command line for logging +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -ne 0 ]; then + echo "Usage: $0 [--cmd (run.pl|queue.pl...)]" + echo "This script provides official CHiME-6 challenge submission scores" + echo " Options:" + echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." + echo " --dev # dev set decoding directory" + echo " --eval # eval set decoding directory" + echo " --enhancement # enhancement type (gss or beamformit)" + echo " --json # directory containing CHiME-6 json files" + exit 1; +fi + +# get language model weight and word insertion penalty from the dev set +best_lmwt=`cat $dev/scoring_kaldi/wer_details/lmwt` +best_wip=`cat $dev/scoring_kaldi/wer_details/wip` + +echo "best LM weight: $best_lmwt" +echo "insertion penalty weight: $best_wip" + +echo "==== development set ====" +# development set +# get uttid location mapping +local/add_location_to_uttid.sh --enhancement $enhancement $json/dev \ + $dev/scoring_kaldi/wer_details/ $dev/scoring_kaldi/wer_details/uttid_location +# get the scoring result per utterance +score_result=$dev/scoring_kaldi/wer_details/per_utt_loc + +for session in S02 S09; do + for room in DINING KITCHEN LIVING; do + # get nerror + nerr=`grep "\#csid" $score_result | grep $room | grep $session | awk '{sum+=$4+$5+$6} END {print sum}'` + # get nwords from references (NF-2 means to exclude utterance id and " ref ") + nwrd=`grep "\#csid" $score_result | grep $room | grep $session | awk '{sum+=$3+$4+$6} END {print sum}'` + # compute wer with scale=2 + wer=`echo "scale=2; 100 * $nerr / $nwrd" | bc` + + # report the results + echo -n "session $session " + echo -n "room $room: " + echo -n "#words $nwrd, " + echo -n "#errors $nerr, " + echo "wer $wer %" + done +done +echo -n "overall: " +# get nerror +nerr=`grep "\#csid" $score_result | awk '{sum+=$4+$5+$6} END {print sum}'` +# get nwords from references (NF-2 means to exclude utterance id and " ref ") +nwrd=`grep "\#csid" $score_result | awk '{sum+=$3+$4+$6} END {print sum}'` +# compute wer with scale=2 +wer=`echo "scale=2; 100 * $nerr / $nwrd" | bc` +echo -n "#words $nwrd, " +echo -n "#errors $nerr, " +echo "wer $wer %" + +echo "==== evaluation set ====" +# evaluation set +# get the scoring result per utterance. Copied from local/score.sh +mkdir -p $eval/scoring_kaldi/wer_details_devbest +$cmd $eval/scoring_kaldi/log/stats1.log \ + cat $eval/scoring_kaldi/penalty_$best_wip/$best_lmwt.txt \| \ + align-text --special-symbol="'***'" ark:$eval/scoring_kaldi/test_filt.txt ark:- ark,t:- \| \ + utils/scoring/wer_per_utt_details.pl --special-symbol "'***'" \> $eval/scoring_kaldi/wer_details_devbest/per_utt + +local/add_location_to_uttid.sh --enhancement $enhancement $json/eval \ + $eval/scoring_kaldi/wer_details_devbest/ $eval/scoring_kaldi/wer_details_devbest/uttid_location + +score_result=$eval/scoring_kaldi/wer_details_devbest/per_utt_loc +for session in S01 S21; do + for room in DINING KITCHEN LIVING; do + if $do_eval; then + # get nerror + nerr=`grep "\#csid" $score_result | grep $room | grep $session | awk '{sum+=$4+$5+$6} END {print sum}'` + # get nwords from references (NF-2 means to exclude utterance id and " ref ") + nwrd=`grep "\#csid" $score_result | grep $room | grep $session | awk '{sum+=$3+$4+$6} END {print sum}'` + # compute wer with scale=2 + wer=`echo "scale=2; 100 * $nerr / $nwrd" | bc` + + # report the results + echo -n "session $session " + echo -n "room $room: " + echo -n "#words $nwrd, " + echo -n "#errors $nerr, " + echo "wer $wer %" + fi + done +done +if $do_eval; then + # get nerror + nerr=`grep "\#csid" $score_result | awk '{sum+=$4+$5+$6} END {print sum}'` + # get nwords from references (NF-2 means to exclude utterance id and " ref ") + nwrd=`grep "\#csid" $score_result | awk '{sum+=$3+$4+$6} END {print sum}'` + # compute wer with scale=2 + wer=`echo "scale=2; 100 * $nerr / $nwrd" | bc` + echo -n "overall: " + echo -n "#words $nwrd, " + echo -n "#errors $nerr, " + echo "wer $wer %" +else + echo "skip evaluation scoring" + echo "" + echo "==== when you submit your result to the CHiME-6 challenge track 1 ====" + echo "Please rename your recognition results of " + echo "$dev/scoring_kaldi/penalty_$best_wip/$best_lmwt.txt" + echo "$eval/scoring_kaldi/penalty_$best_wip/$best_lmwt.txt" + echo "with {dev,eval}__.txt, e.g., dev_watanabe_jhu.txt and eval_watanabe_jhu.txt, " + echo "and submit both of them as your final challenge result" + echo "==================================================================" +fi + diff --git a/egs/chime6/s5_track1/local/train_lms_srilm.sh b/egs/chime6/s5_track1/local/train_lms_srilm.sh new file mode 100755 index 000000000..5a1d56d24 --- /dev/null +++ b/egs/chime6/s5_track1/local/train_lms_srilm.sh @@ -0,0 +1,261 @@ +#!/bin/bash +# Copyright (c) 2017 Johns Hopkins University (Author: Yenda Trmal, Shinji Watanabe) +# Apache 2.0 + +export LC_ALL=C + +# Begin configuration section. +words_file= +train_text= +dev_text= +oov_symbol="" +# End configuration section + +echo "$0 $@" + +[ -f path.sh ] && . ./path.sh +. ./utils/parse_options.sh || exit 1 + +echo "-------------------------------------" +echo "Building an SRILM language model " +echo "-------------------------------------" + +if [ $# -ne 2 ] ; then + echo "Incorrect number of parameters. " + echo "Script has to be called like this:" + echo " $0 [switches] " + echo "For example: " + echo " $0 data data/srilm" + echo "The allowed switches are: " + echo " words_file= word list file -- data/lang/words.txt by default" + echo " train_text= data/train/text is used in case when not specified" + echo " dev_text= last 10 % of the train text is used by default" + echo " oov_symbol=> symbol to use for oov modeling -- by default" + exit 1 +fi + +datadir=$1 +tgtdir=$2 + +##End of configuration +loc=`which ngram-count`; +if [ -z $loc ]; then + echo >&2 "You appear to not have SRILM tools installed, either on your path," + echo >&2 "Use the script \$KALDI_ROOT/tools/install_srilm.sh to install it." + exit 1 +fi + +# Prepare the destination directory +mkdir -p $tgtdir + +for f in $words_file $train_text $dev_text; do + [ ! -s $f ] && echo "No such file $f" && exit 1; +done + +[ -z $words_file ] && words_file=$datadir/lang/words.txt +if [ ! -z "$train_text" ] && [ -z "$dev_text" ] ; then + nr=`cat $train_text | wc -l` + nr_dev=$(($nr / 10 )) + nr_train=$(( $nr - $nr_dev )) + orig_train_text=$train_text + head -n $nr_train $train_text > $tgtdir/train_text + tail -n $nr_dev $train_text > $tgtdir/dev_text + + train_text=$tgtdir/train_text + dev_text=$tgtdir/dev_text + echo "Using words file: $words_file" + echo "Using train text: 9/10 of $orig_train_text" + echo "Using dev text : 1/10 of $orig_train_text" +elif [ ! -z "$train_text" ] && [ ! -z "$dev_text" ] ; then + echo "Using words file: $words_file" + echo "Using train text: $train_text" + echo "Using dev text : $dev_text" + train_text=$train_text + dev_text=$dev_text +else + train_text=$datadir/train/text + dev_text=$datadir/dev2h/text + echo "Using words file: $words_file" + echo "Using train text: $train_text" + echo "Using dev text : $dev_text" + +fi + +[ ! -f $words_file ] && echo >&2 "File $words_file must exist!" && exit 1 +[ ! -f $train_text ] && echo >&2 "File $train_text must exist!" && exit 1 +[ ! -f $dev_text ] && echo >&2 "File $dev_text must exist!" && exit 1 + + +# Extract the word list from the training dictionary; exclude special symbols +sort $words_file | awk '{print $1}' | grep -v '\#0' | grep -v '' | grep -v -F "$oov_symbol" > $tgtdir/vocab +if (($?)); then + echo "Failed to create vocab from $words_file" + exit 1 +else + # wc vocab # doesn't work due to some encoding issues + echo vocab contains `cat $tgtdir/vocab | perl -ne 'BEGIN{$l=$w=0;}{split; $w+=$#_; $w++; $l++;}END{print "$l lines, $w words\n";}'` +fi + +# Kaldi transcript files contain Utterance_ID as the first word; remove it +# We also have to avoid skewing the LM by incorporating the same sentences +# from different channels +sed -e "s/\.CH.//" -e "s/_.\-./_/" -e "s/NOLOCATION\(\.[LR]\)*-//" -e "s/U[0-9][0-9]_//" $train_text | sort -u | \ + perl -ane 'print join(" ", @F[1..$#F]) . "\n" if @F > 1' > $tgtdir/train.txt +if (($?)); then + echo "Failed to create $tgtdir/train.txt from $train_text" + exit 1 +else + echo "Removed first word (uid) from every line of $train_text" + # wc text.train train.txt # doesn't work due to some encoding issues + echo $train_text contains `cat $train_text | perl -ane 'BEGIN{$w=$s=0;}{$w+=@F; $w--; $s++;}END{print "$w words, $s sentences\n";}'` + echo train.txt contains `cat $tgtdir/train.txt | perl -ane 'BEGIN{$w=$s=0;}{$w+=@F; $s++;}END{print "$w words, $s sentences\n";}'` +fi + +# Kaldi transcript files contain Utterance_ID as the first word; remove it +sed -e "s/\.CH.//" -e "s/_.\-./_/" $dev_text | sort -u | \ + perl -ane 'print join(" ", @F[1..$#F]) . "\n" if @F > 1' > $tgtdir/dev.txt +if (($?)); then + echo "Failed to create $tgtdir/dev.txt from $dev_text" + exit 1 +else + echo "Removed first word (uid) from every line of $dev_text" + # wc text.train train.txt # doesn't work due to some encoding issues + echo $dev_text contains `cat $dev_text | perl -ane 'BEGIN{$w=$s=0;}{$w+=@F; $w--; $s++;}END{print "$w words, $s sentences\n";}'` + echo $tgtdir/dev.txt contains `cat $tgtdir/dev.txt | perl -ane 'BEGIN{$w=$s=0;}{$w+=@F; $s++;}END{print "$w words, $s sentences\n";}'` +fi + + +echo "-------------------" +echo "Good-Turing 3grams" +echo "-------------------" +ngram-count -lm $tgtdir/3gram.gt011.gz -gt1min 0 -gt2min 1 -gt3min 1 -order 3 \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/3gram.gt012.gz -gt1min 0 -gt2min 1 -gt3min 2 -order 3 \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/3gram.gt022.gz -gt1min 0 -gt2min 2 -gt3min 2 -order 3 \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/3gram.gt023.gz -gt1min 0 -gt2min 2 -gt3min 3 -order 3 \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" + +echo "-------------------" +echo "Kneser-Ney 3grams" +echo "-------------------" +ngram-count -lm $tgtdir/3gram.kn011.gz -kndiscount1 -gt1min 0 \ + -kndiscount2 -gt2min 1 -kndiscount3 -gt3min 1 -order 3 -interpolate \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/3gram.kn012.gz -kndiscount1 -gt1min 0 \ + -kndiscount2 -gt2min 1 -kndiscount3 -gt3min 2 -order 3 -interpolate \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/3gram.kn022.gz -kndiscount1 -gt1min 0 \ + -kndiscount2 -gt2min 2 -kndiscount3 -gt3min 2 -order 3 -interpolate \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/3gram.kn023.gz -kndiscount1 -gt1min 0 \ + -kndiscount2 -gt2min 2 -kndiscount3 -gt3min 3 -order 3 -interpolate \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/3gram.kn111.gz -kndiscount1 -gt1min 1 \ + -kndiscount2 -gt2min 1 -kndiscount3 -gt3min 1 -order 3 -interpolate \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/3gram.kn112.gz -kndiscount1 -gt1min 1 \ + -kndiscount2 -gt2min 1 -kndiscount3 -gt3min 2 -order 3 -interpolate \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/3gram.kn122.gz -kndiscount1 -gt1min 1 \ + -kndiscount2 -gt2min 2 -kndiscount3 -gt3min 2 -order 3 -interpolate \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/3gram.kn123.gz -kndiscount1 -gt1min 1 \ + -kndiscount2 -gt2min 2 -kndiscount3 -gt3min 3 -order 3 -interpolate \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" + + +echo "-------------------" +echo "Good-Turing 4grams" +echo "-------------------" +ngram-count -lm $tgtdir/4gram.gt0111.gz \ + -gt1min 0 -gt2min 1 -gt3min 1 -gt4min 1 -order 4 \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/4gram.gt0112.gz \ + -gt1min 0 -gt2min 1 -gt3min 1 -gt4min 2 -order 4 \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/4gram.gt0122.gz \ + -gt1min 0 -gt2min 1 -gt3min 2 -gt4min 2 -order 4 \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/4gram.gt0123.gz \ + -gt1min 0 -gt2min 1 -gt3min 2 -gt4min 3 -order 4 \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/4gram.gt0113.gz \ + -gt1min 0 -gt2min 1 -gt3min 1 -gt4min 3 -order 4 \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/4gram.gt0222.gz \ + -gt1min 0 -gt2min 2 -gt3min 2 -gt4min 2 -order 4 \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/4gram.gt0223.gz \ + -gt1min 0 -gt2min 2 -gt3min 2 -gt4min 3 -order 4 \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" + +echo "-------------------" +echo "Kneser-Ney 4grams" +echo "-------------------" +ngram-count -lm $tgtdir/4gram.kn0111.gz \ + -kndiscount1 -gt1min 0 -kndiscount2 -gt2min 1 -kndiscount3 -gt3min 1 -kndiscount4 -gt4min 1 -order 4 \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/4gram.kn0112.gz \ + -kndiscount1 -gt1min 0 -kndiscount2 -gt2min 1 -kndiscount3 -gt3min 1 -kndiscount4 -gt4min 2 -order 4 \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/4gram.kn0113.gz \ + -kndiscount1 -gt1min 0 -kndiscount2 -gt2min 1 -kndiscount3 -gt3min 1 -kndiscount4 -gt4min 3 -order 4 \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/4gram.kn0122.gz \ + -kndiscount1 -gt1min 0 -kndiscount2 -gt2min 1 -kndiscount3 -gt3min 2 -kndiscount4 -gt4min 2 -order 4 \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/4gram.kn0123.gz \ + -kndiscount1 -gt1min 0 -kndiscount2 -gt2min 1 -kndiscount3 -gt3min 2 -kndiscount4 -gt4min 3 -order 4 \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/4gram.kn0222.gz \ + -kndiscount1 -gt1min 0 -kndiscount2 -gt2min 2 -kndiscount3 -gt3min 2 -kndiscount4 -gt4min 2 -order 4 \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" +ngram-count -lm $tgtdir/4gram.kn0223.gz \ + -kndiscount1 -gt1min 0 -kndiscount2 -gt2min 2 -kndiscount3 -gt3min 2 -kndiscount4 -gt4min 3 -order 4 \ + -text $tgtdir/train.txt -vocab $tgtdir/vocab -unk -sort -map-unk "$oov_symbol" + +if [ ! -z ${LIBLBFGS} ]; then + #please note that if the switch -map-unk "$oov_symbol" is used with -maxent-convert-to-arpa, ngram-count will segfault + #instead of that, we simply output the model in the maxent format and convert it using the "ngram" + echo "-------------------" + echo "Maxent 3grams" + echo "-------------------" + sed 's/'${oov_symbol}'//g' $tgtdir/train.txt | \ + ngram-count -lm - -order 3 -text - -vocab $tgtdir/vocab -unk -sort -maxent -maxent-convert-to-arpa|\ + ngram -lm - -order 3 -unk -map-unk "$oov_symbol" -prune-lowprobs -write-lm - |\ + sed 's//'${oov_symbol}'/g' | gzip -c > $tgtdir/3gram.me.gz || exit 1 + + echo "-------------------" + echo "Maxent 4grams" + echo "-------------------" + sed 's/'${oov_symbol}'//g' $tgtdir/train.txt | \ + ngram-count -lm - -order 4 -text - -vocab $tgtdir/vocab -unk -sort -maxent -maxent-convert-to-arpa|\ + ngram -lm - -order 4 -unk -map-unk "$oov_symbol" -prune-lowprobs -write-lm - |\ + sed 's//'${oov_symbol}'/g' | gzip -c > $tgtdir/4gram.me.gz || exit 1 +else + echo >&2 "SRILM is not compiled with the support of MaxEnt models." + echo >&2 "You should use the script in \$KALDI_ROOT/tools/install_srilm.sh" + echo >&2 "which will take care of compiling the SRILM with MaxEnt support" + exit 1; +fi + + +echo "--------------------" +echo "Computing perplexity" +echo "--------------------" +( + for f in $tgtdir/3gram* ; do ( echo $f; ngram -order 3 -lm $f -unk -map-unk "$oov_symbol" -prune-lowprobs -ppl $tgtdir/dev.txt ) | paste -s -d ' ' ; done + for f in $tgtdir/4gram* ; do ( echo $f; ngram -order 4 -lm $f -unk -map-unk "$oov_symbol" -prune-lowprobs -ppl $tgtdir/dev.txt ) | paste -s -d ' ' ; done +) | sort -r -n -k 15,15g | column -t | tee $tgtdir/perplexities.txt + +echo "The perlexity scores report is stored in $tgtdir/perplexities.txt " +echo "" + +for best_ngram in {3,4}gram ; do + outlm=best_${best_ngram}.gz + lmfilename=$(grep "${best_ngram}" $tgtdir/perplexities.txt | head -n 1 | cut -f 1 -d ' ') + echo "$outlm -> $lmfilename" + (cd $tgtdir; rm -f $outlm; ln -sf $(basename $lmfilename) $outlm ) +done diff --git a/egs/chime6/s5_track1/local/wer_output_filter b/egs/chime6/s5_track1/local/wer_output_filter new file mode 100755 index 000000000..6f4b64007 --- /dev/null +++ b/egs/chime6/s5_track1/local/wer_output_filter @@ -0,0 +1,25 @@ +#!/bin/bash +# Copyright (c) 2017 Johns Hopkins University (Author: Yenda Trmal ) +# Apache 2.0 + + +## Filter for scoring of the STT results. Convert everything to lowercase +## and add some ad-hoc fixes for the hesitations + +perl -e ' + while() { + @A = split(" ", $_); + $id = shift @A; print "$id "; + foreach $a (@A) { + print lc($a) . " " unless $a =~ /\[.*\]/; + } + print "\n"; + }' | \ +sed -e ' + s/\/hmm/g; + s/\/hmm/g; + s/\/hmm/g; +' + +#| uconv -f utf-8 -t utf-8 -x Latin-ASCII + diff --git a/egs/chime6/s5_track1/path.sh b/egs/chime6/s5_track1/path.sh new file mode 100644 index 000000000..fb1c04893 --- /dev/null +++ b/egs/chime6/s5_track1/path.sh @@ -0,0 +1,7 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C + diff --git a/egs/chime6/s5_track1/run.sh b/egs/chime6/s5_track1/run.sh new file mode 100755 index 000000000..0890a939f --- /dev/null +++ b/egs/chime6/s5_track1/run.sh @@ -0,0 +1,280 @@ +#!/bin/bash +# +# Based mostly on the TED-LIUM and Switchboard recipe +# +# Copyright 2017 Johns Hopkins University (Author: Shinji Watanabe and Yenda Trmal) +# Apache 2.0 +# + +# Begin configuration section. +nj=96 +decode_nj=20 +stage=0 +nnet_stage=-10 +decode_stage=1 +decode_only=false +num_data_reps=4 +foreground_snrs="20:10:15:5:0" +background_snrs="20:10:15:5:0" +enhancement=beamformit # gss or beamformit + +# End configuration section +. ./utils/parse_options.sh + +. ./cmd.sh +. ./path.sh + +if [ $decode_only == "true" ]; then + stage=16 +fi + +set -e # exit on error + +# chime5 main directory path +# please change the path accordingly +chime5_corpus=/export/corpora4/CHiME5 +# chime6 data directories, which are generated from ${chime5_corpus}, +# to synchronize audio files across arrays and modify the annotation (JSON) file accordingly +chime6_corpus=${PWD}/CHiME6 +json_dir=${chime6_corpus}/transcriptions +audio_dir=${chime6_corpus}/audio + +if [[ ${enhancement} == *gss* ]]; then + enhanced_dir=${enhanced_dir}_multiarray + enhancement=${enhancement}_multiarray +fi + +if [[ ${enhancement} == *beamformit* ]]; then + enhanced_dir=${enhanced_dir} + enhancement=${enhancement} +fi + +test_sets="dev_${enhancement} eval_${enhancement}" +train_set=train_worn_simu_u400k + +# This script also needs the phonetisaurus g2p, srilm, beamformit +./local/check_tools.sh || exit 1 + +########################################################################### +# We first generate the synchronized audio files across arrays and +# corresponding JSON files. Note that this requires sox v14.4.2, +# which is installed via miniconda in ./local/check_tools.sh +########################################################################### + +if [ $stage -le 0 ]; then + local/generate_chime6_data.sh \ + --cmd "$train_cmd" \ + ${chime5_corpus} \ + ${chime6_corpus} +fi + +########################################################################### +# We prepare dict and lang in stages 1 to 3. +########################################################################### + +if [ $stage -le 1 ]; then + echo "$0: prepare data..." + # skip u03 and u04 as they are missing + for mictype in worn u01 u02 u05 u06; do + local/prepare_data.sh --mictype ${mictype} \ + ${audio_dir}/train ${json_dir}/train data/train_${mictype} + done + for dataset in dev; do + for mictype in worn; do + local/prepare_data.sh --mictype ${mictype} \ + ${audio_dir}/${dataset} ${json_dir}/${dataset} \ + data/${dataset}_${mictype} + done + done +fi + +if [ $stage -le 2 ]; then + echo "$0: train lm ..." + local/prepare_dict.sh + + utils/prepare_lang.sh \ + data/local/dict "" data/local/lang data/lang + + local/train_lms_srilm.sh \ + --train-text data/train_worn/text --dev-text data/dev_worn/text \ + --oov-symbol "" --words-file data/lang/words.txt \ + data/ data/srilm +fi + +LM=data/srilm/best_3gram.gz +if [ $stage -le 3 ]; then + # Compiles G for chime5 trigram LM + echo "$0: prepare lang..." + utils/format_lm.sh \ + data/lang $LM data/local/dict/lexicon.txt data/lang + +fi + +######################################################################################### +# In stages 4 to 7, we augment and fix train data for our training purpose. point source +# noises are extracted from chime corpus. Here we use 400k utterances from array microphones, +# its augmentation and all the worn set utterances in train. +######################################################################################### + +if [ $stage -le 4 ]; then + # remove possibly bad sessions (P11_S03, P52_S19, P53_S24, P54_S24) + # see http://spandh.dcs.shef.ac.uk/chime_challenge/data.html for more details + utils/copy_data_dir.sh data/train_worn data/train_worn_org # back up + grep -v -e "^P11_S03" -e "^P52_S19" -e "^P53_S24" -e "^P54_S24" data/train_worn_org/text > data/train_worn/text + utils/fix_data_dir.sh data/train_worn +fi + +if [ $stage -le 5 ]; then + local/extract_noises.py $chime6_corpus/audio/train $chime6_corpus/transcriptions/train \ + local/distant_audio_list distant_noises + local/make_noise_list.py distant_noises > distant_noise_list + + noise_list=distant_noise_list + + if [ ! -d RIRS_NOISES/ ]; then + # Download the package that includes the real RIRs, simulated RIRs, isotropic noises and point-source noises + wget --no-check-certificate http://www.openslr.org/resources/28/rirs_noises.zip + unzip rirs_noises.zip + fi + + # This is the config for the system using simulated RIRs and point-source noises + rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/smallroom/rir_list") + rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/mediumroom/rir_list") + rvb_opts+=(--noise-set-parameters $noise_list) + + steps/data/reverberate_data_dir.py \ + "${rvb_opts[@]}" \ + --prefix "rev" \ + --foreground-snrs $foreground_snrs \ + --background-snrs $background_snrs \ + --speech-rvb-probability 1 \ + --pointsource-noise-addition-probability 1 \ + --isotropic-noise-addition-probability 1 \ + --num-replications $num_data_reps \ + --max-noises-per-minute 1 \ + --source-sampling-rate 16000 \ + data/train_worn data/train_worn_rvb +fi + +if [ $stage -le 6 ]; then + # combine mix array and worn mics + # randomly extract first 400k utterances from all mics + # if you want to include more training data, you can increase the number of array mic utterances + utils/combine_data.sh data/train_uall data/train_u01 data/train_u02 data/train_u05 data/train_u06 + utils/subset_data_dir.sh data/train_uall 400000 data/train_u400k + utils/combine_data.sh data/${train_set} data/train_worn data/train_worn_rvb data/train_u400k + + # only use left channel for worn mic recognition + # you can use both left and right channels for training + for dset in train dev; do + utils/copy_data_dir.sh data/${dset}_worn data/${dset}_worn_stereo + grep "\.L-" data/${dset}_worn_stereo/text > data/${dset}_worn/text + utils/fix_data_dir.sh data/${dset}_worn + done +fi + +if [ $stage -le 7 ]; then + # Split speakers up into 3-minute chunks. This doesn't hurt adaptation, and + # lets us use more jobs for decoding etc. + for dset in ${train_set}; do + utils/copy_data_dir.sh data/${dset} data/${dset}_nosplit + utils/data/modify_speaker_info.sh --seconds-per-spk-max 180 data/${dset}_nosplit data/${dset} + done +fi + +################################################################################## +# Now make 13-dim MFCC features. We use 13-dim fetures for GMM-HMM systems. +################################################################################## + +if [ $stage -le 8 ]; then + # Now make MFCC features. + # mfccdir should be some place with a largish disk where you + # want to store MFCC features. + echo "$0: make features..." + mfccdir=mfcc + for x in ${train_set}; do + steps/make_mfcc.sh --nj 20 --cmd "$train_cmd" \ + data/$x exp/make_mfcc/$x $mfccdir + steps/compute_cmvn_stats.sh data/$x exp/make_mfcc/$x $mfccdir + utils/fix_data_dir.sh data/$x + done +fi + +################################################################################### +# Stages 9 to 13 train monophone and triphone models. They will be used for +# generating lattices for training the chain model +################################################################################### + +if [ $stage -le 9 ]; then + # make a subset for monophone training + utils/subset_data_dir.sh --shortest data/${train_set} 100000 data/${train_set}_100kshort + utils/subset_data_dir.sh data/${train_set}_100kshort 30000 data/${train_set}_30kshort +fi + +if [ $stage -le 10 ]; then + # Starting basic training on MFCC features + steps/train_mono.sh --nj $nj --cmd "$train_cmd" \ + data/${train_set}_30kshort data/lang exp/mono +fi + +if [ $stage -le 11 ]; then + steps/align_si.sh --nj $nj --cmd "$train_cmd" \ + data/${train_set} data/lang exp/mono exp/mono_ali + + steps/train_deltas.sh --cmd "$train_cmd" \ + 2500 30000 data/${train_set} data/lang exp/mono_ali exp/tri1 +fi + +if [ $stage -le 12 ]; then + steps/align_si.sh --nj $nj --cmd "$train_cmd" \ + data/${train_set} data/lang exp/tri1 exp/tri1_ali + + steps/train_lda_mllt.sh --cmd "$train_cmd" \ + 4000 50000 data/${train_set} data/lang exp/tri1_ali exp/tri2 +fi + +if [ $stage -le 13 ]; then + steps/align_si.sh --nj $nj --cmd "$train_cmd" \ + data/${train_set} data/lang exp/tri2 exp/tri2_ali + + steps/train_sat.sh --cmd "$train_cmd" \ + 5000 100000 data/${train_set} data/lang exp/tri2_ali exp/tri3 +fi + +####################################################################### +# Perform data cleanup for training data. +####################################################################### + +if [ $stage -le 14 ]; then + # The following script cleans the data and produces cleaned data + steps/cleanup/clean_and_segment_data.sh --nj ${nj} --cmd "$train_cmd" \ + --segmentation-opts "--min-segment-length 0.3 --min-new-segment-length 0.6" \ + data/${train_set} data/lang exp/tri3 exp/tri3_cleaned data/${train_set}_cleaned +fi + +########################################################################## +# CHAIN MODEL TRAINING +# skipping decoding here and performing it in step 16 +########################################################################## + +if [ $stage -le 15 ]; then + # chain TDNN + local/chain/run_tdnn.sh --nj ${nj} \ + --stage $nnet_stage \ + --train-set ${train_set}_cleaned \ + --test-sets "$test_sets" \ + --gmm tri3_cleaned --nnet3-affix _${train_set}_cleaned_rvb +fi + +########################################################################## +# DECODING is done in the local/decode.sh script. This script performs +# enhancement, fixes test sets performs feature extraction and 2 stage decoding +########################################################################## + +if [ $stage -le 16 ]; then + local/decode.sh --stage $decode_stage \ + --enhancement $enhancement \ + --train_set "$train_set" +fi + +exit 0; diff --git a/egs/chime6/s5_track1/steps b/egs/chime6/s5_track1/steps new file mode 120000 index 000000000..1b186770d --- /dev/null +++ b/egs/chime6/s5_track1/steps @@ -0,0 +1 @@ +../../wsj/s5/steps/ \ No newline at end of file diff --git a/egs/chime6/s5_track1/utils b/egs/chime6/s5_track1/utils new file mode 120000 index 000000000..a3279dc86 --- /dev/null +++ b/egs/chime6/s5_track1/utils @@ -0,0 +1 @@ +../../wsj/s5/utils/ \ No newline at end of file diff --git a/egs/chime6/s5_track2/RESULTS b/egs/chime6/s5_track2/RESULTS new file mode 100644 index 000000000..eacee1965 --- /dev/null +++ b/egs/chime6/s5_track2/RESULTS @@ -0,0 +1,19 @@ +# Results for Chime-6 track 2 for dev and eval, using pretrained models +# available at http://kaldi-asr.org/models/m12. + +# Speech Activity Detection (SAD) + Missed speech False alarm Total error +Dev 4.3 2.1 6.4 +Eval 5.6 5.9 11.5 + +# The results for the remaining pipeline are only for array U06. + +# Diarization + DER JER +Dev 31.37 20.45 +Eval 30.67 18.97 + +# ASR nnet3 tdnn+chain +Dev: U06 58881 48061 81.62 +Eval: U06 55132 47184 85.58 + diff --git a/egs/chime6/s5_track2/cmd.sh b/egs/chime6/s5_track2/cmd.sh new file mode 100644 index 000000000..86514d94d --- /dev/null +++ b/egs/chime6/s5_track2/cmd.sh @@ -0,0 +1,14 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. + +export train_cmd="retry.pl queue.pl --mem 2G" +export decode_cmd="queue.pl --mem 4G" diff --git a/egs/chime6/s5_track2/conf/beamformit.cfg b/egs/chime6/s5_track2/conf/beamformit.cfg new file mode 100755 index 000000000..70fdd8586 --- /dev/null +++ b/egs/chime6/s5_track2/conf/beamformit.cfg @@ -0,0 +1,50 @@ +#BeamformIt sample configuration file for AMI data (http://groups.inf.ed.ac.uk/ami/download/) + +# scrolling size to compute the delays +scroll_size = 250 + +# cross correlation computation window size +window_size = 500 + +#amount of maximum points for the xcorrelation taken into account +nbest_amount = 4 + +#flag wether to apply an automatic noise thresholding +do_noise_threshold = 1 + +#Percentage of frames with lower xcorr taken as noisy +noise_percent = 10 + +######## acoustic modelling parameters + +#transition probabilities weight for multichannel decoding +trans_weight_multi = 25 +trans_weight_nbest = 25 + +### + +#flag wether to print the feaures after setting them, or not +print_features = 1 + +#flag wether to use the bad frames in the sum process +do_avoid_bad_frames = 1 + +#flag to use the best channel (SNR) as a reference +#defined from command line +do_compute_reference = 1 + +#flag wether to use a uem file or not(process all the file) +do_use_uem_file = 0 + +#flag wether to use an adaptative weights scheme or fixed weights +do_adapt_weights = 1 + +#flag wether to output the sph files or just run the system to create the auxiliary files +do_write_sph_files = 1 + +####directories where to store/retrieve info#### +#channels_file = ./cfg-files/channels + +#show needs to be passed as argument normally, here a default one is given just in case +#show_id = Ttmp + diff --git a/egs/chime6/s5_track2/conf/mfcc.conf b/egs/chime6/s5_track2/conf/mfcc.conf new file mode 100644 index 000000000..32988403b --- /dev/null +++ b/egs/chime6/s5_track2/conf/mfcc.conf @@ -0,0 +1,2 @@ +--use-energy=false +--sample-frequency=16000 diff --git a/egs/chime6/s5_track2/conf/mfcc_hires.conf b/egs/chime6/s5_track2/conf/mfcc_hires.conf new file mode 100644 index 000000000..fd64b62eb --- /dev/null +++ b/egs/chime6/s5_track2/conf/mfcc_hires.conf @@ -0,0 +1,10 @@ +# config for high-resolution MFCC features, intended for neural network training. +# Note: we keep all cepstra, so it has the same info as filterbank features, +# but MFCC is more easily compressible (because less correlated) which is why +# we prefer this method. +--use-energy=false # use average of log energy, not energy. +--sample-frequency=16000 +--num-mel-bins=40 +--num-ceps=40 +--low-freq=40 +--high-freq=-400 diff --git a/egs/chime6/s5_track2/conf/online_cmvn.conf b/egs/chime6/s5_track2/conf/online_cmvn.conf new file mode 100644 index 000000000..7748a4a4d --- /dev/null +++ b/egs/chime6/s5_track2/conf/online_cmvn.conf @@ -0,0 +1 @@ +# configuration file for apply-cmvn-online, used in the script ../local/run_online_decoding.sh diff --git a/egs/chime6/s5_track2/conf/sad.conf b/egs/chime6/s5_track2/conf/sad.conf new file mode 100644 index 000000000..752bb1cf6 --- /dev/null +++ b/egs/chime6/s5_track2/conf/sad.conf @@ -0,0 +1,2 @@ +affix=_1a +nnet_type=stats diff --git a/egs/chime6/s5_track2/diarization b/egs/chime6/s5_track2/diarization new file mode 120000 index 000000000..bad937c14 --- /dev/null +++ b/egs/chime6/s5_track2/diarization @@ -0,0 +1 @@ +../../callhome_diarization/v1/diarization \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/chain b/egs/chime6/s5_track2/local/chain new file mode 120000 index 000000000..dd7910711 --- /dev/null +++ b/egs/chime6/s5_track2/local/chain @@ -0,0 +1 @@ +../../s5_track1/local/chain/ \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/check_tools.sh b/egs/chime6/s5_track2/local/check_tools.sh new file mode 120000 index 000000000..4e835e887 --- /dev/null +++ b/egs/chime6/s5_track2/local/check_tools.sh @@ -0,0 +1 @@ +../../s5_track1/local/check_tools.sh \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/convert_rttm_to_utt2spk_and_segments.py b/egs/chime6/s5_track2/local/convert_rttm_to_utt2spk_and_segments.py new file mode 100755 index 000000000..410dced19 --- /dev/null +++ b/egs/chime6/s5_track2/local/convert_rttm_to_utt2spk_and_segments.py @@ -0,0 +1,98 @@ +#! /usr/bin/env python +# Copyright 2019 Vimal Manohar +# Apache 2.0. + +"""This script converts an RTTM with +speaker info into kaldi utt2spk and segments""" + +import argparse + +def get_args(): + parser = argparse.ArgumentParser( + description="""This script converts an RTTM with + speaker info into kaldi utt2spk and segments""") + parser.add_argument("--use-reco-id-as-spkr", type=str, + choices=["true", "false"], default="false", + help="Use the recording ID based on RTTM and " + "reco2file_and_channel as the speaker") + parser.add_argument("--append-reco-id-to-spkr", type=str, + choices=["true", "false"], default="false", + help="Append recording ID to the speaker ID") + + parser.add_argument("rttm_file", type=str, + help="""Input RTTM file. + The format of the RTTM file is + """ + """ """) + parser.add_argument("reco2file_and_channel", type=str, + help="""Input reco2file_and_channel. + The format is .""") + parser.add_argument("utt2spk", type=str, + help="Output utt2spk file") + parser.add_argument("segments", type=str, + help="Output segments file") + + args = parser.parse_args() + + args.use_reco_id_as_spkr = bool(args.use_reco_id_as_spkr == "true") + args.append_reco_id_to_spkr = bool(args.append_reco_id_to_spkr == "true") + + if args.use_reco_id_as_spkr: + if args.append_reco_id_to_spkr: + raise Exception("Appending recording ID to speaker does not make sense when using --use-reco-id-as-spkr=true") + + return args + +def main(): + args = get_args() + + file_and_channel2reco = {} + utt2spk={} + segments={} + for line in open(args.reco2file_and_channel): + parts = line.strip().split() + file_and_channel2reco[(parts[1], parts[2])] = parts[0] + + utt2spk_writer = open(args.utt2spk, 'w') + segments_writer = open(args.segments, 'w') + for line in open(args.rttm_file): + parts = line.strip().split() + if parts[0] != "SPEAKER": + continue + + file_id = parts[1] + channel = parts[2] + + try: + reco = file_and_channel2reco[(file_id, channel)] + except KeyError as e: + raise Exception("Could not find recording with " + "(file_id, channel) " + "= ({0},{1}) in {2}: {3}\n".format( + file_id, channel, + args.reco2file_and_channel, str(e))) + + start_time = float(parts[3]) + end_time = start_time + float(parts[4]) + + if args.use_reco_id_as_spkr: + spkr = reco + else: + if args.append_reco_id_to_spkr: + spkr = reco + "-" + parts[7] + else: + spkr = parts[7] + + st = int(start_time * 100) + end = int(end_time * 100) + utt = "{0}-{1:06d}-{2:06d}".format(spkr, st, end) + utt2spk[utt]=spkr + segments[utt]=(reco, start_time, end_time) + + for uttid_id in sorted(utt2spk): + utt2spk_writer.write("{0} {1}\n".format(uttid_id, utt2spk[uttid_id])) + segments_writer.write("{0} {1} {2:7.2f} {3:7.2f}\n".format( + uttid_id, segments[uttid_id][0], segments[uttid_id][1], segments[uttid_id][2])) + +if __name__ == '__main__': + main() diff --git a/egs/chime6/s5_track2/local/copy_lat_dir_parallel.sh b/egs/chime6/s5_track2/local/copy_lat_dir_parallel.sh new file mode 120000 index 000000000..a168a917d --- /dev/null +++ b/egs/chime6/s5_track2/local/copy_lat_dir_parallel.sh @@ -0,0 +1 @@ +../../s5_track1/local/copy_lat_dir_parallel.sh \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/decode.sh b/egs/chime6/s5_track2/local/decode.sh new file mode 100755 index 000000000..66a96fce3 --- /dev/null +++ b/egs/chime6/s5_track2/local/decode.sh @@ -0,0 +1,173 @@ +#!/bin/bash +# +# This script decodes raw utterances through the entire pipeline: +# Feature extraction -> SAD -> Diarization -> ASR +# +# Copyright 2017 Johns Hopkins University (Author: Shinji Watanabe and Yenda Trmal) +# 2019 Desh Raj, David Snyder, Ashish Arora +# Apache 2.0 + +# Begin configuration section. +nj=8 +decode_nj=10 +stage=0 +sad_stage=0 +diarizer_stage=0 +decode_diarize_stage=0 +score_stage=0 +enhancement=beamformit + +# chime5 main directory path +# please change the path accordingly +chime5_corpus=/export/corpora4/CHiME5 +# chime6 data directories, which are generated from ${chime5_corpus}, +# to synchronize audio files across arrays and modify the annotation (JSON) file accordingly +chime6_corpus=${PWD}/CHiME6 +json_dir=${chime6_corpus}/transcriptions +audio_dir=${chime6_corpus}/audio + +enhanced_dir=enhanced +enhanced_dir=$(utils/make_absolute.sh $enhanced_dir) || exit 1 + +# training data +train_set=train_worn_simu_u400k +test_sets="dev_${enhancement}_dereverb eval_${enhancement}_dereverb" + +. ./utils/parse_options.sh + +. ./cmd.sh +. ./path.sh +. ./conf/sad.conf + +# This script also needs the phonetisaurus g2p, srilm, beamformit +./local/check_tools.sh || exit 1 + +########################################################################### +# We first generate the synchronized audio files across arrays and +# corresponding JSON files. Note that this requires sox v14.4.2, +# which is installed via miniconda in ./local/check_tools.sh +########################################################################### + +if [ $stage -le 0 ]; then + local/generate_chime6_data.sh \ + --cmd "$train_cmd" \ + ${chime5_corpus} \ + ${chime6_corpus} +fi + +####################################################################### +# Prepare the dev and eval data with dereverberation (WPE) and +# beamforming. +####################################################################### +if [ $stage -le 1 ]; then + # Beamforming using reference arrays + # enhanced WAV directory + enhandir=enhan + dereverb_dir=${PWD}/wav/wpe/ + + for dset in dev eval; do + for mictype in u01 u02 u03 u04 u06; do + local/run_wpe.sh --nj 4 --cmd "$train_cmd --mem 20G" \ + ${audio_dir}/${dset} \ + ${dereverb_dir}/${dset} \ + ${mictype} + done + done + + for dset in dev eval; do + for mictype in u01 u02 u03 u04 u06; do + local/run_beamformit.sh --cmd "$train_cmd" \ + ${dereverb_dir}/${dset} \ + ${enhandir}/${dset}_${enhancement}_${mictype} \ + ${mictype} + done + done + + # Note that for the evaluation sets, we use the flag + # "--train false". This keeps the files segments, text, + # and utt2spk with .bak extensions, so that they can + # be used later for scoring if needed but are not used + # in the intermediate stages. + for dset in dev eval; do + local/prepare_data.sh --mictype ref --train false \ + "$PWD/${enhandir}/${dset}_${enhancement}_u0*" \ + ${json_dir}/${dset} data/${dset}_${enhancement}_dereverb + done +fi + +if [ $stage -le 2 ]; then + # mfccdir should be some place with a largish disk where you + # want to store MFCC features. + mfccdir=mfcc + for x in ${test_sets}; do + steps/make_mfcc.sh --nj $decode_nj --cmd "$train_cmd" \ + --mfcc-config conf/mfcc_hires.conf \ + data/$x exp/make_mfcc/$x $mfccdir + done +fi + +####################################################################### +# Perform SAD on the dev/eval data +####################################################################### +dir=exp/segmentation${affix} +sad_work_dir=exp/sad${affix}_${nnet_type}/ +sad_nnet_dir=$dir/tdnn_${nnet_type}_sad_1a + +if [ $stage -le 3 ]; then + for datadir in ${test_sets}; do + test_set=data/${datadir} + if [ ! -f ${test_set}/wav.scp ]; then + echo "$0: Not performing SAD on ${test_set}" + exit 0 + fi + # Perform segmentation + local/segmentation/detect_speech_activity.sh --nj $decode_nj --stage $sad_stage \ + $test_set $sad_nnet_dir mfcc $sad_work_dir \ + data/${datadir} || exit 1 + + mv data/${datadir}_seg data/${datadir}_${nnet_type}_seg + # Generate RTTM file from segmentation performed by SAD. This can + # be used to evaluate the performance of the SAD as an intermediate + # step. + steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \ + data/${datadir}_${nnet_type}_seg/utt2spk data/${datadir}_${nnet_type}_seg/segments \ + data/${datadir}_${nnet_type}_seg/rttm + done +fi + +####################################################################### +# Perform diarization on the dev/eval data +####################################################################### +if [ $stage -le 4 ]; then + for datadir in ${test_sets}; do + local/diarize.sh --nj 10 --cmd "$train_cmd" --stage $diarizer_stage \ + exp/xvector_nnet_1a \ + data/${datadir}_${nnet_type}_seg \ + exp/${datadir}_${nnet_type}_seg_diarization + done +fi + +####################################################################### +# Decode diarized output using trained chain model +####################################################################### +if [ $stage -le 5 ]; then + for datadir in ${test_sets}; do + local/decode_diarized.sh --nj $nj --cmd "$decode_cmd" --stage $decode_diarize_stage \ + exp/${datadir}_${nnet_type}_seg_diarization data/$datadir data/lang_chain \ + exp/chain_${train_set}_cleaned_rvb exp/nnet3_${train_set}_cleaned_rvb \ + data/${datadir}_diarized + done +fi + +####################################################################### +# Score decoded dev/eval sets +####################################################################### +if [ $stage -le 6 ]; then + for datadir in ${test_sets}; do + local/multispeaker_score.sh --cmd "$train_cmd" --stage $score_stage \ + --datadir $datadir data/${datadir}_diarized_hires/text \ + exp/chain_${train_set}_cleaned_rvb/tdnn1b_sp/decode_${datadir}_diarized_2stage/scoring_kaldi/penalty_1.0/10.txt \ + exp/chain_${train_set}_cleaned_rvb/tdnn1b_sp/decode_${datadir}_diarized_2stage/scoring_kaldi_multispeaker + done +fi +exit 0; diff --git a/egs/chime6/s5_track2/local/decode_diarized.sh b/egs/chime6/s5_track2/local/decode_diarized.sh new file mode 100755 index 000000000..2d0ad6a3b --- /dev/null +++ b/egs/chime6/s5_track2/local/decode_diarized.sh @@ -0,0 +1,71 @@ +#!/bin/bash +# Copyright 2019 Ashish Arora, Vimal Manohar +# Apache 2.0. +# This script takes an rttm file, and performs decoding on on a test directory. +# The output directory contains a text file which can be used for scoring. + + +stage=0 +nj=8 +cmd=queue.pl +echo "$0 $@" # Print the command line for logging +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; +if [ $# != 6 ]; then + echo "Usage: $0 " + echo "e.g.: $0 data/rttm data/dev data/lang_chain exp/chain_train_worn_simu_u400k_cleaned_rvb \ + exp/nnet3_train_worn_simu_u400k_cleaned_rvb data/dev_diarized" + echo "Options: " + echo " --nj # number of parallel jobs." + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + exit 1; +fi + +rttm_dir=$1 +data_in=$2 +lang_dir=$3 +asr_model_dir=$4 +ivector_extractor=$5 +out_dir=$6 + +for f in $rttm_dir/rttm $data_in/wav.scp $data_in/text.bak \ + $lang_dir/L.fst $asr_model_dir/tree_sp/graph/HCLG.fst \ + $asr_model_dir/tdnn1b_sp/final.mdl; do + [ ! -f $f ] && echo "$0: No such file $f" && exit 1; +done + +if [ $stage -le 0 ]; then + echo "$0 copying data files in output directory" + cp $rttm_dir/rttm $rttm_dir/rttm_1 + sed -i 's/'.ENH'/''/g' $rttm_dir/rttm_1 + mkdir -p ${out_dir}_hires + cp ${data_in}/{wav.scp,utt2spk} ${out_dir}_hires + utils/data/get_reco2dur.sh ${out_dir}_hires +fi + +if [ $stage -le 1 ]; then + echo "$0 creating segments file from rttm and utt2spk, reco2file_and_channel " + local/convert_rttm_to_utt2spk_and_segments.py --append-reco-id-to-spkr=true $rttm_dir/rttm_1 \ + <(awk '{print $2".ENH "$2" "$3}' $rttm_dir/rttm_1 |sort -u) \ + ${out_dir}_hires/utt2spk ${out_dir}_hires/segments + + utils/utt2spk_to_spk2utt.pl ${out_dir}_hires/utt2spk > ${out_dir}_hires/spk2utt + + awk '{print $1" "$1" 1"}' ${out_dir}_hires/wav.scp > ${out_dir}_hires/reco2file_and_channel + utils/fix_data_dir.sh ${out_dir}_hires || exit 1; +fi + +if [ $stage -le 2 ]; then + echo "$0 extracting mfcc freatures using segments file" + steps/make_mfcc.sh --mfcc-config conf/mfcc_hires.conf --nj $nj --cmd queue.pl ${out_dir}_hires + steps/compute_cmvn_stats.sh ${out_dir}_hires + cp $data_in/text.bak ${out_dir}_hires/text +fi + +if [ $stage -le 3 ]; then + echo "$0 performing decoding on the extracted features" + local/nnet3/decode.sh --affix 2stage --acwt 1.0 --post-decode-acwt 10.0 \ + --frames-per-chunk 150 --nj $nj --ivector-dir $ivector_extractor \ + $out_dir $lang_dir $asr_model_dir/tree_sp/graph $asr_model_dir/tdnn1b_sp/ +fi + diff --git a/egs/chime6/s5_track2/local/diarize.sh b/egs/chime6/s5_track2/local/diarize.sh new file mode 100755 index 000000000..2ca95dc0f --- /dev/null +++ b/egs/chime6/s5_track2/local/diarize.sh @@ -0,0 +1,95 @@ +#!/bin/bash +# Copyright 2019 David Snder +# Apache 2.0. +# +# This script takes an input directory that has a segments file (and +# a feats.scp file), and performs diarization on it. The output directory +# contains an RTTM file which can be used to resegment the input data. + +stage=0 +nj=10 +cmd="run.pl" +ref_rttm= + +echo "$0 $@" # Print the command line for logging +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; +if [ $# != 3 ]; then + echo "Usage: $0 " + echo "e.g.: $0 exp/xvector_nnet_1a data/dev exp/dev_diarization" + echo "Options: " + echo " --nj # number of parallel jobs." + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --ref-rttm # if present, used to score output RTTM." + exit 1; +fi + +model_dir=$1 +data_in=$2 +out_dir=$3 + +name=`basename $data_in` + +for f in $data_in/feats.scp $data_in/segments $model_dir/plda \ + $model_dir/final.raw $model_dir/extract.config; do + [ ! -f $f ] && echo "$0: No such file $f" && exit 1; +done + +if [ $stage -le 0 ]; then + echo "$0: keeping only data corresponding to array U06 " + echo "$0: we can skip this stage, to perform diarization on all arrays " + cp -r data/$name data/${name}.bak + mv data/$name/wav.scp data/$name/wav.scp.bak + grep 'U06' data/$name/wav.scp.bak > data/$name/wav.scp + utils/fix_data_dir.sh data/$name + nj=2 # since we have reduced number of "speakers" now +fi + +if [ $stage -le 1 ]; then + echo "$0: computing features for x-vector extractor" + utils/fix_data_dir.sh data/${name} + rm -rf data/${name}_cmn + local/nnet3/xvector/prepare_feats.sh --nj $nj --cmd "$cmd" \ + data/$name data/${name}_cmn exp/${name}_cmn + cp data/$name/segments exp/${name}_cmn/ + utils/fix_data_dir.sh data/${name}_cmn +fi + +if [ $stage -le 2 ]; then + echo "$0: extracting x-vectors for all segments" + diarization/nnet3/xvector/extract_xvectors.sh --cmd "$cmd" \ + --nj $nj --window 1.5 --period 0.75 --apply-cmn false \ + --min-segment 0.5 $model_dir \ + data/${name}_cmn $out_dir/xvectors_${name} +fi + +# Perform PLDA scoring +if [ $stage -le 3 ]; then + # Perform PLDA scoring on all pairs of segments for each recording. + echo "$0: performing PLDA scoring between all pairs of x-vectors" + diarization/nnet3/xvector/score_plda.sh --cmd "$cmd" \ + --target-energy 0.5 \ + --nj $nj $model_dir/ $out_dir/xvectors_${name} \ + $out_dir/xvectors_${name}/plda_scores +fi + +if [ $stage -le 4 ]; then + echo "$0: performing clustering using PLDA scores (we assume 4 speakers per recording)" + awk '{print $1, "4"}' data/$name/wav.scp > data/$name/reco2num_spk + diarization/cluster.sh --cmd "$cmd" --nj $nj \ + --reco2num-spk data/$name/reco2num_spk \ + --rttm-channel 1 \ + $out_dir/xvectors_${name}/plda_scores $out_dir + echo "$0: wrote RTTM to output directory ${out_dir}" +fi + +if [ $stage -le 5 ]; then + if [ -f $ref_rttm ]; then + echo "$0: computing diariztion error rate (DER) using reference ${ref_rttm}" + mkdir -p $out_dir/tuning/ + md-eval.pl -c 0.25 -1 -r $ref_rttm -s $out_dir/rttm 2> $out_dir/log/der.log > $out_dir/der + der=$(grep -oP 'DIARIZATION\ ERROR\ =\ \K[0-9]+([.][0-9]+)?' ${out_dir}/der) + echo "DER: $der%" + fi +fi + diff --git a/egs/chime6/s5_track2/local/distant_audio_list b/egs/chime6/s5_track2/local/distant_audio_list new file mode 120000 index 000000000..0455876cf --- /dev/null +++ b/egs/chime6/s5_track2/local/distant_audio_list @@ -0,0 +1 @@ +../../s5_track1/local/distant_audio_list \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/extract_noises.py b/egs/chime6/s5_track2/local/extract_noises.py new file mode 120000 index 000000000..04a638991 --- /dev/null +++ b/egs/chime6/s5_track2/local/extract_noises.py @@ -0,0 +1 @@ +../../s5_track1/local/extract_noises.py \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/extract_vad_weights.sh b/egs/chime6/s5_track2/local/extract_vad_weights.sh new file mode 120000 index 000000000..0db29cded --- /dev/null +++ b/egs/chime6/s5_track2/local/extract_vad_weights.sh @@ -0,0 +1 @@ +../../s5_track1/local/extract_vad_weights.sh \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/gen_aligned_hyp.py b/egs/chime6/s5_track2/local/gen_aligned_hyp.py new file mode 100755 index 000000000..acaa3a13a --- /dev/null +++ b/egs/chime6/s5_track2/local/gen_aligned_hyp.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright 2019 Yusuke Fujita +# Apache 2.0. + +"""This script generates hypothesis utterances aligned with reference segments. + Usage: gen_align_hyp.py alignment.txt wc.txt > hyp.txt + alignment.txt is a session-level word alignment generated by align-text command. + wc.txt is a sequence of utt-id:reference_word_count generated by 'local/get_ref_perspeaker_persession_file.py'. +""" + +import sys, io +import string +output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') + +def load_align_text(f): + alignments = {} + for line in f: + recoid, res = line.split(None, 1) + alignments[recoid] = [] + toks = res.split(';') + for tok in toks: + ref, hyp = tok.split() + alignments[recoid].append((ref, hyp)) + return alignments + +alignments = load_align_text(open(sys.argv[1],'r', encoding='utf8')) + +for line in open(sys.argv[2],'r', encoding='utf8'): + recoid, res = line.split(None, 1) + ali = iter(alignments[recoid]) + toks = res.split() + for tok in toks: + uttid, count = tok.split(':') + count = int(count) + text = '' + for i in range(count): + while True: + ref, hyp = ali.__next__() + if hyp != '': + text += ' ' + hyp + if ref != '': + break + output.write(uttid + ' ' + text.strip() + '\n') diff --git a/egs/chime6/s5_track2/local/generate_chime6_data.sh b/egs/chime6/s5_track2/local/generate_chime6_data.sh new file mode 120000 index 000000000..62882cd62 --- /dev/null +++ b/egs/chime6/s5_track2/local/generate_chime6_data.sh @@ -0,0 +1 @@ +../../s5_track1/local/generate_chime6_data.sh \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/get_best_error.py b/egs/chime6/s5_track2/local/get_best_error.py new file mode 100755 index 000000000..651a63229 --- /dev/null +++ b/egs/chime6/s5_track2/local/get_best_error.py @@ -0,0 +1,86 @@ +#! /usr/bin/env python3 +# Copyright 2019 Ashish Arora +# Apache 2.0. +"""This script finds best matching of reference and hypothesis speakers. + For the best matching speakers,it provides the WER for the reference session + (eg:S02) and hypothesis recording (eg: S02_U02)""" + +import itertools +import numpy as np +import argparse +from munkres import Munkres + +def get_args(): + parser = argparse.ArgumentParser( + description="""This script finds best matching of reference and hypothesis speakers. + For the best matching it provides the WER""") + parser.add_argument("WER_dir", type=str, + help="path of WER files") + parser.add_argument("recording_id", type=str, + help="recording_id name") + args = parser.parse_args() + return args + + +def get_results(filename): + with open(filename) as f: + first_line = f.readline() + parts = first_line.strip().split(',') + total_words = parts[0].split()[-1] + ins = parts[1].split()[0] + deletions = parts[2].split()[0] + sub = parts[3].split()[0] + return total_words, ins, deletions, sub + + +def get_min_wer(recording_id, num_speakers, WER_dir): + best_wer_file = WER_dir + '/' + 'best_wer' + '_' + recording_id + best_wer_writer = open(best_wer_file, 'w') + m = Munkres() + total_error_mat = [0] * num_speakers + all_errors_mat = [0] * num_speakers + for i in range(num_speakers): + total_error_mat[i] = [0] * num_speakers + all_errors_mat[i] = [0] * num_speakers + for i in range(1, num_speakers+1): + for j in range(1, num_speakers+1): + filename = '/wer_' + recording_id + '_' + 'r' + str(i)+ 'h' + str(j) + filename = WER_dir + filename + total_words, ins, deletions, sub = get_results(filename) + ins = int(ins) + dele = int(deletions) + sub = int(sub) + total_error = ins + dele + sub + total_error_mat[i-1][j-1]=total_error + all_errors_mat[i-1][j-1]= (total_words, total_error, ins, dele, sub) + + indexes = m.compute(total_error_mat) + total_errors=total_words=total_ins=total_del=total_sub=0 + spk_order = '(' + for row, column in indexes: + words, errs, ins, dele, sub = all_errors_mat[row][column] + total_errors += int(errs) + total_words += int(words) + total_ins += int(ins) + total_del += int(deletions) + total_sub += int(sub) + spk_order = spk_order + str(column+1) + ', ' + spk_order = spk_order + ')' + text = "Best error: (#T #E #I #D #S) " + str(total_words)+ ', '+str(total_errors)+ ', '+str(total_ins)+ ', '+str(total_del)+ ', '+str(total_sub) + best_wer_writer.write(" recording_id: "+ recording_id + ' ') + best_wer_writer.write(' best hypothesis speaker order: ' + spk_order + ' ') + best_wer_writer.write(text+ '\n') + print("recording_id: "+ recording_id + ' ') + print('best hypothesis speaker order: ' + spk_order + ' ') + print(text) + best_wer_writer.close() + + +def main(): + args = get_args() + num_speakers = 4 + get_min_wer(args.recording_id, num_speakers, args.WER_dir) + + +if __name__ == '__main__': + main() diff --git a/egs/chime6/s5_track2/local/get_hyp_perspeaker_perarray_file.py b/egs/chime6/s5_track2/local/get_hyp_perspeaker_perarray_file.py new file mode 100755 index 000000000..7b3e14aaa --- /dev/null +++ b/egs/chime6/s5_track2/local/get_hyp_perspeaker_perarray_file.py @@ -0,0 +1,56 @@ +#! /usr/bin/env python +# Copyright 2019 Ashish Arora +# Apache 2.0. +"""This script splits a kaldi (text) file + into per_array per_session per_speaker hypothesis (text) files""" + +import argparse +def get_args(): + parser = argparse.ArgumentParser( + description="""This script splits a kaldi text file + into per_array per_session per_speaker text files""") + parser.add_argument("input_text_path", type=str, + help="path of text files") + parser.add_argument("output_dir_path", type=str, + help="Output path for per_array per_session per_speaker reference files") + args = parser.parse_args() + return args + + +def main(): + # S09_U06.ENH-4-704588-704738 + args = get_args() + sessionid_micid_speakerid_dict= {} + for line in open(args.input_text_path): + parts = line.strip().split() + uttid_id = parts[0] + temp = uttid_id.strip().split('.')[0] + micid = temp.strip().split('_')[1] + speakerid = uttid_id.strip().split('-')[1] + sessionid = uttid_id.strip().split('_')[0] + sessionid_micid_speakerid = sessionid + '_' + micid + '_' + speakerid + if sessionid_micid_speakerid not in sessionid_micid_speakerid_dict: + sessionid_micid_speakerid_dict[sessionid_micid_speakerid]=list() + sessionid_micid_speakerid_dict[sessionid_micid_speakerid].append(line) + + for sessionid_micid_speakerid in sorted(sessionid_micid_speakerid_dict): + hyp_file = args.output_dir_path + '/' + 'hyp' + '_' + sessionid_micid_speakerid + hyp_writer = open(hyp_file, 'w') + combined_hyp_file = args.output_dir_path + '/' + 'hyp' + '_' + sessionid_micid_speakerid + '_comb' + combined_hyp_writer = open(combined_hyp_file, 'w') + utterances = sessionid_micid_speakerid_dict[sessionid_micid_speakerid] + text = '' + for line in utterances: + parts = line.strip().split() + text = text + ' ' + ' '.join(parts[1:]) + hyp_writer.write(line) + combined_utterance = 'utt' + " " + text + combined_hyp_writer.write(combined_utterance) + combined_hyp_writer.write('\n') + combined_hyp_writer.close() + hyp_writer.close() + + +if __name__ == '__main__': + main() + diff --git a/egs/chime6/s5_track2/local/get_ref_perspeaker_persession_file.py b/egs/chime6/s5_track2/local/get_ref_perspeaker_persession_file.py new file mode 100755 index 000000000..6b00e29e6 --- /dev/null +++ b/egs/chime6/s5_track2/local/get_ref_perspeaker_persession_file.py @@ -0,0 +1,79 @@ +#! /usr/bin/env python +# Copyright 2019 Ashish Arora +# Apache 2.0. +"""This script splits a kaldi (text) file + into per_speaker per_session reference (text) file""" + +import argparse + +def get_args(): + parser = argparse.ArgumentParser( + description="""This script splits a kaldi text file + into per_speaker per_session text files""") + parser.add_argument("input_text_path", type=str, + help="path of text file") + parser.add_argument("output_dir_path", type=str, + help="Output path for per_session per_speaker reference files") + args = parser.parse_args() + return args + + +def main(): + args = get_args() + sessionid_speakerid_dict= {} + spkrid_mapping = {} + for line in open(args.input_text_path): + parts = line.strip().split() + uttid_id = parts[0] + speakerid = uttid_id.strip().split('_')[0] + sessionid = uttid_id.strip().split('_')[1] + sessionid_speakerid = sessionid + '_' + speakerid + if sessionid_speakerid not in sessionid_speakerid_dict: + sessionid_speakerid_dict[sessionid_speakerid]=list() + sessionid_speakerid_dict[sessionid_speakerid].append(line) + + spkr_num = 1 + prev_sessionid = '' + for sessionid_speakerid in sorted(sessionid_speakerid_dict): + spkr_id = sessionid_speakerid.strip().split('_')[1] + curr_sessionid = sessionid_speakerid.strip().split('_')[0] + if prev_sessionid != curr_sessionid: + prev_sessionid = curr_sessionid + spkr_num = 1 + if spkr_id not in spkrid_mapping: + spkrid_mapping[spkr_id] = spkr_num + spkr_num += 1 + + for sessionid_speakerid in sorted(sessionid_speakerid_dict): + ref_file = args.output_dir_path + '/ref_' + sessionid_speakerid.split('_')[0] + '_' + str( + spkrid_mapping[sessionid_speakerid.split('_')[1]]) + ref_writer = open(ref_file, 'w') + wc_file = args.output_dir_path + '/ref_wc_' + sessionid_speakerid.split('_')[0] + '_' + str( + spkrid_mapping[sessionid_speakerid.split('_')[1]]) + wc_writer = open(wc_file, 'w') + combined_ref_file = args.output_dir_path + '/ref_' + sessionid_speakerid.split('_')[0] + '_' + str( + spkrid_mapping[sessionid_speakerid.split('_')[1]]) + '_comb' + combined_ref_writer = open(combined_ref_file, 'w') + utterances = sessionid_speakerid_dict[sessionid_speakerid] + text = '' + uttid_wc = 'utt' + for line in utterances: + parts = line.strip().split() + uttid_id = parts[0] + utt_text = ' '.join(parts[1:]) + text = text + ' ' + ' '.join(parts[1:]) + ref_writer.write(line) + length = str(len(utt_text.split())) + uttid_id_len = uttid_id + ":" + length + uttid_wc = uttid_wc + ' ' + uttid_id_len + combined_utterance = 'utt' + " " + text + combined_ref_writer.write(combined_utterance) + combined_ref_writer.write('\n') + combined_ref_writer.close() + wc_writer.write(uttid_wc) + wc_writer.write('\n') + wc_writer.close() + ref_writer.close() + +if __name__ == '__main__': + main() diff --git a/egs/chime6/s5_track2/local/install_pb_chime5.sh b/egs/chime6/s5_track2/local/install_pb_chime5.sh new file mode 120000 index 000000000..ce5ea5f9f --- /dev/null +++ b/egs/chime6/s5_track2/local/install_pb_chime5.sh @@ -0,0 +1 @@ +../../s5_track1/local/install_pb_chime5.sh \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/json2text.py b/egs/chime6/s5_track2/local/json2text.py new file mode 120000 index 000000000..2aa0a8dd1 --- /dev/null +++ b/egs/chime6/s5_track2/local/json2text.py @@ -0,0 +1 @@ +../../s5_track1/local/json2text.py \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/make_noise_list.py b/egs/chime6/s5_track2/local/make_noise_list.py new file mode 120000 index 000000000..d8dcc7822 --- /dev/null +++ b/egs/chime6/s5_track2/local/make_noise_list.py @@ -0,0 +1 @@ +../../s5_track1/local/make_noise_list.py \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/make_voxceleb1.pl b/egs/chime6/s5_track2/local/make_voxceleb1.pl new file mode 100755 index 000000000..2268c20ab --- /dev/null +++ b/egs/chime6/s5_track2/local/make_voxceleb1.pl @@ -0,0 +1,130 @@ +#!/usr/bin/perl +# +# Copyright 2018 Ewald Enzinger +# 2018 David Snyder +# +# Usage: make_voxceleb1.pl /export/voxceleb1 data/ + +if (@ARGV != 2) { + print STDERR "Usage: $0 \n"; + print STDERR "e.g. $0 /export/voxceleb1 data/\n"; + exit(1); +} + +($data_base, $out_dir) = @ARGV; +my $out_test_dir = "$out_dir/voxceleb1_test"; +my $out_train_dir = "$out_dir/voxceleb1_train"; + +if (system("mkdir -p $out_test_dir") != 0) { + die "Error making directory $out_test_dir"; +} + +if (system("mkdir -p $out_train_dir") != 0) { + die "Error making directory $out_train_dir"; +} + +opendir my $dh, "$data_base/voxceleb1_wav" or die "Cannot open directory: $!"; +my @spkr_dirs = grep {-d "$data_base/voxceleb1_wav/$_" && ! /^\.{1,2}$/} readdir($dh); +closedir $dh; + +if (! -e "$data_base/voxceleb1_test.txt") { + system("wget -O $data_base/voxceleb1_test.txt http://www.openslr.org/resources/49/voxceleb1_test.txt"); +} + +if (! -e "$data_base/vox1_meta.csv") { + system("wget -O $data_base/vox1_meta.csv http://www.openslr.org/resources/49/vox1_meta.csv"); +} + +open(TRIAL_IN, "<", "$data_base/voxceleb1_test.txt") or die "Could not open the verification trials file $data_base/voxceleb1_test.txt"; +open(META_IN, "<", "$data_base/vox1_meta.csv") or die "Could not open the meta data file $data_base/vox1_meta.csv"; +open(SPKR_TEST, ">", "$out_test_dir/utt2spk") or die "Could not open the output file $out_test_dir/utt2spk"; +open(WAV_TEST, ">", "$out_test_dir/wav.scp") or die "Could not open the output file $out_test_dir/wav.scp"; +open(SPKR_TRAIN, ">", "$out_train_dir/utt2spk") or die "Could not open the output file $out_train_dir/utt2spk"; +open(WAV_TRAIN, ">", "$out_train_dir/wav.scp") or die "Could not open the output file $out_train_dir/wav.scp"; +open(TRIAL_OUT, ">", "$out_test_dir/trials") or die "Could not open the output file $out_test_dir/trials"; + +my %id2spkr = (); +while () { + chomp; + my ($vox_id, $spkr_id, $gender, $nation, $set) = split; + $id2spkr{$vox_id} = $spkr_id; +} + +my $test_spkrs = (); +while () { + chomp; + my ($tar_or_non, $path1, $path2) = split; + + # Create entry for left-hand side of trial + my ($spkr_id, $filename) = split('/', $path1); + my $rec_id = substr($filename, 0, 11); + my $segment = substr($filename, 12, 7); + my $utt_id1 = "$spkr_id-$rec_id-$segment"; + $test_spkrs{$spkr_id} = (); + + # Create entry for right-hand side of trial + my ($spkr_id, $filename) = split('/', $path2); + my $rec_id = substr($filename, 0, 11); + my $segment = substr($filename, 12, 7); + my $utt_id2 = "$spkr_id-$rec_id-$segment"; + $test_spkrs{$spkr_id} = (); + + my $target = "nontarget"; + if ($tar_or_non eq "1") { + $target = "target"; + } + print TRIAL_OUT "$utt_id1 $utt_id2 $target\n"; +} + +foreach (@spkr_dirs) { + my $spkr_id = $_; + my $new_spkr_id = $spkr_id; + # If we're using a newer version of VoxCeleb1, we need to "deanonymize" + # the speaker labels. + if (exists $id2spkr{$spkr_id}) { + $new_spkr_id = $id2spkr{$spkr_id}; + } + opendir my $dh, "$data_base/voxceleb1_wav/$spkr_id/" or die "Cannot open directory: $!"; + my @files = map{s/\.[^.]+$//;$_}grep {/\.wav$/} readdir($dh); + closedir $dh; + foreach (@files) { + my $filename = $_; + my $rec_id = substr($filename, 0, 11); + my $segment = substr($filename, 12, 7); + my $wav = "$data_base/voxceleb1_wav/$spkr_id/$filename.wav"; + my $utt_id = "$new_spkr_id-$rec_id-$segment"; + if (exists $test_spkrs{$new_spkr_id}) { + print WAV_TEST "$utt_id", " $wav", "\n"; + print SPKR_TEST "$utt_id", " $new_spkr_id", "\n"; + } else { + print WAV_TRAIN "$utt_id", " $wav", "\n"; + print SPKR_TRAIN "$utt_id", " $new_spkr_id", "\n"; + } + } +} + +close(SPKR_TEST) or die; +close(WAV_TEST) or die; +close(SPKR_TRAIN) or die; +close(WAV_TRAIN) or die; +close(TRIAL_OUT) or die; +close(TRIAL_IN) or die; +close(META_IN) or die; + +if (system( + "utils/utt2spk_to_spk2utt.pl $out_test_dir/utt2spk >$out_test_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_test_dir"; +} +system("env LC_COLLATE=C utils/fix_data_dir.sh $out_test_dir"); +if (system("env LC_COLLATE=C utils/validate_data_dir.sh --no-text --no-feats $out_test_dir") != 0) { + die "Error validating directory $out_test_dir"; +} + +if (system( + "utils/utt2spk_to_spk2utt.pl $out_train_dir/utt2spk >$out_train_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_train_dir"; +} +system("env LC_COLLATE=C utils/fix_data_dir.sh $out_train_dir"); +if (system("env LC_COLLATE=C utils/validate_data_dir.sh --no-text --no-feats $out_train_dir") != 0) { + die "Error validating directory $out_train_dir"; +} diff --git a/egs/chime6/s5_track2/local/make_voxceleb2.pl b/egs/chime6/s5_track2/local/make_voxceleb2.pl new file mode 100755 index 000000000..34c1591eb --- /dev/null +++ b/egs/chime6/s5_track2/local/make_voxceleb2.pl @@ -0,0 +1,70 @@ +#!/usr/bin/perl +# +# Copyright 2018 Ewald Enzinger +# +# Usage: make_voxceleb2.pl /export/voxceleb2 dev data/dev +# +# Note: This script requires ffmpeg to be installed and its location included in $PATH. + +if (@ARGV != 3) { + print STDERR "Usage: $0 \n"; + print STDERR "e.g. $0 /export/voxceleb2 dev data/dev\n"; + exit(1); +} + +# Check that ffmpeg is installed. +if (`which ffmpeg` eq "") { + die "Error: this script requires that ffmpeg is installed."; +} + +($data_base, $dataset, $out_dir) = @ARGV; + +if ("$dataset" ne "dev" && "$dataset" ne "test") { + die "dataset parameter must be 'dev' or 'test'!"; +} + +opendir my $dh, "$data_base/$dataset/aac" or die "Cannot open directory: $!"; +my @spkr_dirs = grep {-d "$data_base/$dataset/aac/$_" && ! /^\.{1,2}$/} readdir($dh); +closedir $dh; + +if (system("mkdir -p $out_dir") != 0) { + die "Error making directory $out_dir"; +} + +open(SPKR, ">", "$out_dir/utt2spk") or die "Could not open the output file $out_dir/utt2spk"; +open(WAV, ">", "$out_dir/wav.scp") or die "Could not open the output file $out_dir/wav.scp"; + +foreach (@spkr_dirs) { + my $spkr_id = $_; + + opendir my $dh, "$data_base/$dataset/aac/$spkr_id/" or die "Cannot open directory: $!"; + my @rec_dirs = grep {-d "$data_base/$dataset/aac/$spkr_id/$_" && ! /^\.{1,2}$/} readdir($dh); + closedir $dh; + + foreach (@rec_dirs) { + my $rec_id = $_; + + opendir my $dh, "$data_base/$dataset/aac/$spkr_id/$rec_id/" or die "Cannot open directory: $!"; + my @files = map{s/\.[^.]+$//;$_}grep {/\.m4a$/} readdir($dh); + closedir $dh; + + foreach (@files) { + my $name = $_; + my $wav = "ffmpeg -v 8 -i $data_base/$dataset/aac/$spkr_id/$rec_id/$name.m4a -f wav -acodec pcm_s16le -|"; + my $utt_id = "$spkr_id-$rec_id-$name"; + print WAV "$utt_id", " $wav", "\n"; + print SPKR "$utt_id", " $spkr_id", "\n"; + } + } +} +close(SPKR) or die; +close(WAV) or die; + +if (system( + "utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_dir"; +} +system("env LC_COLLATE=C utils/fix_data_dir.sh $out_dir"); +if (system("env LC_COLLATE=C utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { + die "Error validating directory $out_dir"; +} diff --git a/egs/chime6/s5_track2/local/multispeaker_score.sh b/egs/chime6/s5_track2/local/multispeaker_score.sh new file mode 100755 index 000000000..e632381ad --- /dev/null +++ b/egs/chime6/s5_track2/local/multispeaker_score.sh @@ -0,0 +1,130 @@ +#!/bin/bash +# Copyright 2019 Ashish Arora, Yusuke Fujita +# Apache 2.0. +# This script takes a reference and hypothesis text file, and performs +# multispeaker scoring. + +stage=0 +cmd=queue.pl +num_spkrs=4 +num_hyp_spk=4 +datadir=dev_beamformit_dereverb +declare -a recording_id_array=("S02_U06" "S09_U06") +echo "$0 $@" # Print the command line for logging +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# != 3 ]; then + echo "Usage: $0 " + echo "e.g.: $0 data/diarized/text data/dev \ + exp/chain_train_worn_simu_u400k_cleaned_rvb/tdnn1b_sp/decode_dev_xvector_sad/scoring_kaldi/penalty_1.0/10.txt \ + exp/chain_train_worn_simu_u400k_cleaned_rvb/tdnn1b_sp/decode_dev_xvector_sad/scoring_kaldi_multispeaker" + echo "Options: " + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + exit 1; +fi + +ref_file=$1 +hyp_file=$2 +out_dir=$3 + +output_dir=$out_dir/per_speaker_output +wer_dir=$out_dir/per_speaker_wer + +# For dev and evaluation set, we take corresopnding arrays +if [[ ${datadir} == *dev* ]]; then + recording_id_array=("S02_U06" "S09_U06") +fi + +if [[ ${datadir} == *eval* ]]; then + recording_id_array=("S01_U06" "S21_U06") +fi + +for f in $ref_file $hyp_file; do + [ ! -f $f ] && echo "$0: No such file $f" && exit 1; +done + +if [ $stage -le 0 ]; then + echo "$0 generate per speaker per session file at paragraph level for the reference" + echo "and per speaker per array file at paraghaph level for the hypothesis" + mkdir -p $output_dir $wer_dir + local/wer_output_filter < $ref_file > $output_dir/ref_filt.txt + local/wer_output_filter < $hyp_file > $output_dir/hyp_filt.txt + local/get_ref_perspeaker_persession_file.py $output_dir/ref_filt.txt $output_dir + local/get_hyp_perspeaker_perarray_file.py $output_dir/hyp_filt.txt $output_dir +fi + +if [ $stage -le 1 ]; then + if [ $num_hyp_spk -le 3 ]; then + echo "$0 create dummy per speaker per array hypothesis files for if the" + echo " perdicted number of speakers by diarization is less than 4 " + for recording_id in "${recording_id_array[@]}"; do + for (( i=$num_hyp_spk+1; i<$num_spkrs+1; i++ )); do + echo 'utt ' > ${dir}/hyp_${recording_id}_${i}_comb + done + done + fi +fi + +if [ $stage -le 2 ]; then + echo "$0 calculate wer for each ref and hypothesis speaker" + for recording_id in "${recording_id_array[@]}"; do + for (( i=0; i<$((num_spkrs * num_spkrs)); i++ )); do + ind_r=$((i / num_spkrs + 1)) + ind_h=$((i % num_spkrs + 1)) + sessionid="$(echo $recording_id | cut -d'_' -f1)" + + # compute WER with combined texts + compute-wer --text --mode=present ark:${output_dir}/ref_${sessionid}_${ind_r}_comb \ + ark:${output_dir}/hyp_${recording_id}_${ind_h}_comb \ + > $wer_dir/wer_${recording_id}_r${ind_r}h${ind_h} 2>/dev/null + done + + local/get_best_error.py $wer_dir $recording_id + done +fi + +if [ $stage -le 3 ]; then + echo "$0 print best word error rate" + echo "$0 it will print best wer for each recording and each array" + cat $wer_dir/best_wer* > $wer_dir/all.txt + cat $wer_dir/all.txt | local/print_dset_error.py $output_dir/recordinid_spkorder +fi + +mkdir -p $wer_dir/wer_details $wer_dir/wer_details/log/ +if [ $stage -le 4 ]; then + echo "$0 generate per utterance wer details at utterance level" + while read -r line; + do + recording_id=$(echo "$line" | cut -f1 -d ":") + spkorder_str=$(echo "$line" | cut -f2 -d ":") + sessionid=$(echo "$line" | cut -f1 -d "_") + IFS='_' read -r -a spkorder_list <<< "$spkorder_str" + IFS=" " + ind_r=1 + for ind_h in "${spkorder_list[@]}"; do + + $cmd $wer_dir/wer_details/log/${recording_id}_r${ind_r}h${ind_h}_comb.log \ + align-text ark:${output_dir}/ref_${sessionid}_${ind_r}_comb ark:${output_dir}/hyp_${recording_id}_${ind_h}_comb ark:$output_dir/alignment_${sessionid}_r${ind_r}h${ind_h}.txt + + # split hypothesis texts along with reference utterances using word alignment of combined texts + local/gen_aligned_hyp.py $output_dir/alignment_${sessionid}_r${ind_r}h${ind_h}.txt ${output_dir}/ref_wc_${sessionid}_${ind_r} > ${output_dir}/hyp_${recording_id}_r${ind_r}h${ind_h}_ref_segmentation + + ## compute per utterance alignments + $cmd $wer_dir/wer_details/log/${recording_id}_r${ind_r}h${ind_h}_per_utt.log \ + cat ${output_dir}/hyp_${recording_id}_r${ind_r}h${ind_h}_ref_segmentation \| \ + align-text --special-symbol="'***'" ark:${output_dir}/ref_${sessionid}_${ind_r} ark:- ark,t:- \| \ + utils/scoring/wer_per_utt_details.pl --special-symbol "'***'" \| tee $wer_dir/wer_details/per_utt_${recording_id}_r${ind_r}h${ind_h} || exit 1 + + $cmd $wer_dir/wer_details/log/${recording_id}_r${ind_r}h${ind_h}_ops.log \ + cat $wer_dir/wer_details/per_utt_${recording_id}_r${ind_r}h${ind_h} \| \ + utils/scoring/wer_ops_details.pl --special-symbol "'***'" \| \ + sort -b -i -k 1,1 -k 4,4rn -k 2,2 -k 3,3 \> $wer_dir/wer_details/ops_${recording_id}_r${ind_r}h${ind_h} || exit 1; + + ind_r=$(( ind_r + 1 )) + done + done < $output_dir/recordinid_spkorder + echo "$0 done generating per utterance wer details" +fi + +echo "$0 done scoring" diff --git a/egs/chime6/s5_track2/local/nnet3/compare_wer.sh b/egs/chime6/s5_track2/local/nnet3/compare_wer.sh new file mode 120000 index 000000000..87041e833 --- /dev/null +++ b/egs/chime6/s5_track2/local/nnet3/compare_wer.sh @@ -0,0 +1 @@ +../../../s5_track1/local/nnet3/compare_wer.sh \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/nnet3/decode.sh b/egs/chime6/s5_track2/local/nnet3/decode.sh new file mode 120000 index 000000000..32595cced --- /dev/null +++ b/egs/chime6/s5_track2/local/nnet3/decode.sh @@ -0,0 +1 @@ +../../../s5_track1/local/nnet3/decode.sh \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/nnet3/run_ivector_common.sh b/egs/chime6/s5_track2/local/nnet3/run_ivector_common.sh new file mode 120000 index 000000000..4161993c2 --- /dev/null +++ b/egs/chime6/s5_track2/local/nnet3/run_ivector_common.sh @@ -0,0 +1 @@ +../../../s5_track1/local/nnet3/run_ivector_common.sh \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/nnet3/xvector/prepare_feats.sh b/egs/chime6/s5_track2/local/nnet3/xvector/prepare_feats.sh new file mode 100755 index 000000000..cb8fe2e63 --- /dev/null +++ b/egs/chime6/s5_track2/local/nnet3/xvector/prepare_feats.sh @@ -0,0 +1,89 @@ +#!/bin/bash +# +# Apache 2.0. + +# This script applies sliding window CMVN and writes the features to disk. +# +# Although this kind of script isn't necessary in speaker recognition recipes, +# it can be helpful in the diarization recipes. The script +# diarization/nnet3/xvector/extract_xvectors.sh extracts x-vectors from very +# short (e.g., 1-2 seconds) segments. Therefore, in order to apply the sliding +# window CMVN in a meaningful way, it must be performed prior to performing +# the subsegmentation. + +nj=40 +cmd="run.pl" +stage=0 +norm_vars=false +center=true +compress=true +cmn_window=300 + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; +if [ $# != 3 ]; then + echo "Usage: $0 " + echo "e.g.: $0 data/train data/train_no_sil exp/make_xvector_features" + echo "Options: " + echo " --nj # number of parallel jobs" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --norm-vars # If true, normalize variances in the sliding window cmvn" + exit 1; +fi + +data_in=$1 +data_out=$2 +dir=$3 + +name=`basename $data_in` + +for f in $data_in/feats.scp ; do + [ ! -f $f ] && echo "$0: No such file $f" && exit 1; +done + +# Set various variables. +mkdir -p $dir/log +mkdir -p $data_out +featdir=$(utils/make_absolute.sh $dir) + +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $featdir/storage ]; then + utils/create_split_dir.pl \ + /export/b{14,15,16,17}/$USER/kaldi-data/egs/callhome_diarization/v2/xvector-$(date +'%m_%d_%H_%M')/xvector_cmvn_feats/storage $featdir/storage +fi + +for n in $(seq $nj); do + # the next command does nothing unless $featdir/storage/ exists, see + # utils/create_data_link.pl for more info. + utils/create_data_link.pl $featdir/xvector_cmvn_feats_${name}.${n}.ark +done + +cp $data_in/utt2spk $data_out/utt2spk +cp $data_in/spk2utt $data_out/spk2utt +cp $data_in/wav.scp $data_out/wav.scp +for f in $data_in/segments $data_in/segments/vad.scp ; do + [ -f $f ] && cp $f $data_out/`basename $f`; +done + +write_num_frames_opt="--write-num-frames=ark,t:$featdir/log/utt2num_frames.JOB" + +sdata_in=$data_in/split$nj; +utils/split_data.sh $data_in $nj || exit 1; + +$cmd JOB=1:$nj $dir/log/create_xvector_cmvn_feats_${name}.JOB.log \ + apply-cmvn-sliding --norm-vars=false --center=true --cmn-window=$cmn_window \ + scp:${sdata_in}/JOB/feats.scp ark:- \| \ + copy-feats --compress=$compress $write_num_frames_opt ark:- \ + ark,scp:$featdir/xvector_cmvn_feats_${name}.JOB.ark,$featdir/xvector_cmvn_feats_${name}.JOB.scp || exit 1; + +for n in $(seq $nj); do + cat $featdir/xvector_cmvn_feats_${name}.$n.scp || exit 1; +done > ${data_out}/feats.scp || exit 1 + +for n in $(seq $nj); do + cat $featdir/log/utt2num_frames.$n || exit 1; +done > $data_out/utt2num_frames || exit 1 +rm $featdir/log/utt2num_frames.* + +echo "$0: Succeeded creating xvector features for $name" diff --git a/egs/chime6/s5_track2/local/nnet3/xvector/prepare_feats_for_egs.sh b/egs/chime6/s5_track2/local/nnet3/xvector/prepare_feats_for_egs.sh new file mode 100755 index 000000000..dcdbe1b15 --- /dev/null +++ b/egs/chime6/s5_track2/local/nnet3/xvector/prepare_feats_for_egs.sh @@ -0,0 +1,83 @@ +#!/bin/bash +# +# Apache 2.0. + +# This script applies sliding window CMVN and removes silence frames. This +# is performed on the raw features prior to generating examples for training +# the x-vector system. Once the training examples are generated, the features +# created by this script can be removed. + +nj=40 +cmd="run.pl" +stage=0 +norm_vars=false +center=true +compress=true +cmn_window=300 + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; +if [ $# != 3 ]; then + echo "Usage: $0 " + echo "e.g.: $0 data/train data/train_no_sil exp/make_xvector_features" + echo "Options: " + echo " --nj # number of parallel jobs" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --norm-vars # If true, normalize variances in the sliding window cmvn" + exit 1; +fi + +data_in=$1 +data_out=$2 +dir=$3 + +name=`basename $data_in` + +for f in $data_in/feats.scp $data_in/vad.scp ; do + [ ! -f $f ] && echo "$0: No such file $f" && exit 1; +done + +# Set various variables. +mkdir -p $dir/log +mkdir -p $data_out +featdir=$(utils/make_absolute.sh $dir) + +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $featdir/storage ]; then + utils/create_split_dir.pl \ + /export/b{14,15,16,17}/$USER/kaldi-data/egs/callhome_diarization/v2/xvector-$(date +'%m_%d_%H_%M')/xvector_feats/storage $featdir/storage +fi + +for n in $(seq $nj); do + # the next command does nothing unless $featdir/storage/ exists, see + # utils/create_data_link.pl for more info. + utils/create_data_link.pl $featdir/xvector_feats_${name}.${n}.ark +done + +cp $data_in/utt2spk $data_out/utt2spk +cp $data_in/spk2utt $data_out/spk2utt +cp $data_in/wav.scp $data_out/wav.scp + +write_num_frames_opt="--write-num-frames=ark,t:$featdir/log/utt2num_frames.JOB" + +sdata_in=$data_in/split$nj; +utils/split_data.sh $data_in $nj || exit 1; + +$cmd JOB=1:$nj $dir/log/create_xvector_feats_${name}.JOB.log \ + apply-cmvn-sliding --norm-vars=false --center=true --cmn-window=$cmn_window \ + scp:${sdata_in}/JOB/feats.scp ark:- \| \ + select-voiced-frames ark:- scp,s,cs:${sdata_in}/JOB/vad.scp ark:- \| \ + copy-feats --compress=$compress $write_num_frames_opt ark:- \ + ark,scp:$featdir/xvector_feats_${name}.JOB.ark,$featdir/xvector_feats_${name}.JOB.scp || exit 1; + +for n in $(seq $nj); do + cat $featdir/xvector_feats_${name}.$n.scp || exit 1; +done > ${data_out}/feats.scp || exit 1 + +for n in $(seq $nj); do + cat $featdir/log/utt2num_frames.$n || exit 1; +done > $data_out/utt2num_frames || exit 1 +rm $featdir/log/utt2num_frames.* + +echo "$0: Succeeded creating xvector features for $name" diff --git a/egs/chime6/s5_track2/local/nnet3/xvector/run_xvector.sh b/egs/chime6/s5_track2/local/nnet3/xvector/run_xvector.sh new file mode 120000 index 000000000..585b63fd2 --- /dev/null +++ b/egs/chime6/s5_track2/local/nnet3/xvector/run_xvector.sh @@ -0,0 +1 @@ +tuning/run_xvector_1a.sh \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/nnet3/xvector/tuning/run_xvector_1a.sh b/egs/chime6/s5_track2/local/nnet3/xvector/tuning/run_xvector_1a.sh new file mode 100755 index 000000000..94fc7e768 --- /dev/null +++ b/egs/chime6/s5_track2/local/nnet3/xvector/tuning/run_xvector_1a.sh @@ -0,0 +1,149 @@ +#!/bin/bash +# Copyright 2018 David Snyder +# 2018 Johns Hopkins University (Author: Daniel Garcia-Romero) +# 2018 Johns Hopkins University (Author: Daniel Povey) +# Apache 2.0. + +# This script trains the x-vector DNN. The recipe is similar to the one +# described in "Diarization is Hard: Some Experiences and Lessons Learned +# for the JHU Team in the Inaugural DIHARD Challenge" by Sell et al. + +. ./cmd.sh +set -e + +stage=1 +train_stage=-1 +use_gpu=true +remove_egs=false + +data=data/train +nnet_dir=exp/xvector_nnet_1a/ +egs_dir=exp/xvector_nnet_1a/egs + +. ./path.sh +. ./cmd.sh +. ./utils/parse_options.sh + +num_pdfs=$(awk '{print $2}' $data/utt2spk | sort | uniq -c | wc -l) + +# Now we create the nnet examples using sid/nnet3/xvector/get_egs.sh. +# The argument --num-repeats is related to the number of times a speaker +# repeats per archive. If it seems like you're getting too many archives +# (e.g., more than 200) try increasing the --frames-per-iter option. The +# arguments --min-frames-per-chunk and --max-frames-per-chunk specify the +# minimum and maximum length (in terms of number of frames) of the features +# in the examples. +# +# To make sense of the egs script, it may be necessary to put an "exit 1" +# command immediately after stage 3. Then, inspect +# exp//egs/temp/ranges.* . The ranges files specify the examples that +# will be created, and which archives they will be stored in. Each line of +# ranges.* has the following form: +# +# For example: +# 100304-f-sre2006-kacg-A 1 2 4079 881 23 + +# If you're satisfied with the number of archives (e.g., 50-150 archives is +# reasonable) and with the number of examples per speaker (e.g., 1000-5000 +# is reasonable) then you can let the script continue to the later stages. +# Otherwise, try increasing or decreasing the --num-repeats option. You might +# need to fiddle with --frames-per-iter. Increasing this value decreases the +# the number of archives and increases the number of examples per archive. +# Decreasing this value increases the number of archives, while decreasing the +# number of examples per archive. +if [ $stage -le 6 ]; then + echo "$0: Getting neural network training egs"; + # dump egs. + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $egs_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/callhome_diarization/v2/xvector-$(date +'%m_%d_%H_%M')/$egs_dir/storage $egs_dir/storage + fi + sid/nnet3/xvector/get_egs.sh --cmd "$train_cmd" \ + --nj 8 \ + --stage 0 \ + --frames-per-iter 1000000000 \ + --frames-per-iter-diagnostic 500000 \ + --min-frames-per-chunk 200 \ + --max-frames-per-chunk 400 \ + --num-diagnostic-archives 3 \ + --num-repeats 40 \ + "$data" $egs_dir +fi + +if [ $stage -le 7 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + num_targets=$(wc -w $egs_dir/pdf2num | awk '{print $1}') + feat_dim=$(cat $egs_dir/info/feat_dim) + + # This chunk-size corresponds to the maximum number of frames the + # stats layer is able to pool over. In this script, it corresponds + # to 4 seconds. If the input recording is greater than 4 seconds, + # we will compute multiple xvectors from the same recording and average + # to produce the final xvector. + max_chunk_size=400 + + # The smallest number of frames we're comfortable computing an xvector from. + # Note that the hard minimum is given by the left and right context of the + # frame-level layers. + min_chunk_size=20 + mkdir -p $nnet_dir/configs + cat < $nnet_dir/configs/network.xconfig + # please note that it is important to have input layer with the name=input + + # The frame-level layers + input dim=${feat_dim} name=input + relu-batchnorm-layer name=tdnn1 input=Append(-2,-1,0,1,2) dim=512 + relu-batchnorm-layer name=tdnn2 input=Append(-2,0,2) dim=512 + relu-batchnorm-layer name=tdnn3 input=Append(-3,0,3) dim=512 + relu-batchnorm-layer name=tdnn4 dim=512 + relu-batchnorm-layer name=tdnn5 dim=1500 + + # The stats pooling layer. Layers after this are segment-level. + # In the config below, the first and last argument (0, and ${max_chunk_size}) + # means that we pool over an input segment starting at frame 0 + # and ending at frame ${max_chunk_size} or earlier. The other arguments (1:1) + # mean that no subsampling is performed. + stats-layer name=stats config=mean+stddev(0:1:1:${max_chunk_size}) + + # This is where we usually extract the embedding (aka xvector) from. + relu-batchnorm-layer name=tdnn6 dim=128 input=stats + output-layer name=output include-log-softmax=true dim=${num_targets} +EOF + + steps/nnet3/xconfig_to_configs.py \ + --xconfig-file $nnet_dir/configs/network.xconfig \ + --config-dir $nnet_dir/configs/ + cp $nnet_dir/configs/final.config $nnet_dir/nnet.config + + # These three files will be used by sid/nnet3/xvector/extract_xvectors.sh + echo "output-node name=output input=tdnn6.affine" > $nnet_dir/extract.config + echo "$max_chunk_size" > $nnet_dir/max_chunk_size + echo "$min_chunk_size" > $nnet_dir/min_chunk_size +fi + +dropout_schedule='0,0@0.20,0.1@0.50,0' +srand=123 +if [ $stage -le 8 ]; then + steps/nnet3/train_raw_dnn.py --stage=$train_stage \ + --cmd="$train_cmd" \ + --trainer.optimization.proportional-shrink 10 \ + --trainer.optimization.momentum=0.5 \ + --trainer.optimization.num-jobs-initial=3 \ + --trainer.optimization.num-jobs-final=8 \ + --trainer.optimization.initial-effective-lrate=0.001 \ + --trainer.optimization.final-effective-lrate=0.0001 \ + --trainer.optimization.minibatch-size=64 \ + --trainer.srand=$srand \ + --trainer.max-param-change=2 \ + --trainer.num-epochs=3 \ + --trainer.dropout-schedule="$dropout_schedule" \ + --trainer.shuffle-buffer-size=1000 \ + --egs.frames-per-eg=1 \ + --egs.dir="$egs_dir" \ + --cleanup.remove-egs $remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --dir=$nnet_dir || exit 1; +fi + +exit 0; diff --git a/egs/chime6/s5_track2/local/prepare_data.sh b/egs/chime6/s5_track2/local/prepare_data.sh new file mode 100755 index 000000000..c6b8121da --- /dev/null +++ b/egs/chime6/s5_track2/local/prepare_data.sh @@ -0,0 +1,149 @@ +#!/bin/bash +# +# Copyright 2017 Johns Hopkins University (Author: Shinji Watanabe, Yenda Trmal) +# Apache 2.0 + +# Begin configuration section. +mictype=worn # worn, ref or others +cleanup=true +train=true + +# End configuration section +. ./utils/parse_options.sh # accept options.. you can run this run.sh with the + +. ./path.sh + +echo >&2 "$0" "$@" +if [ $# -ne 3 ] ; then + echo >&2 "$0" "$@" + echo >&2 "$0: Error: wrong number of arguments" + echo -e >&2 "Usage:\n $0 [opts] " + echo -e >&2 "eg:\n $0 /corpora/chime5/audio/train /corpora/chime5/transcriptions/train data/train" + exit 1 +fi + +set -e -o pipefail + +adir=$1 +jdir=$2 +dir=$3 + +json_count=$(find -L $jdir -name "*.json" | wc -l) +wav_count=$(find -L $adir -name "*.wav" | wc -l) + +if [ "$json_count" -eq 0 ]; then + echo >&2 "We expect that the directory $jdir will contain json files." + echo >&2 "That implies you have supplied a wrong path to the data." + exit 1 +fi +if [ "$wav_count" -eq 0 ]; then + echo >&2 "We expect that the directory $adir will contain wav files." + echo >&2 "That implies you have supplied a wrong path to the data." + exit 1 +fi + +echo "$0: Converting transcription to text" + +mkdir -p $dir +for file in $jdir/*json; do + ./local/json2text.py --mictype $mictype $file +done | \ + sed -e "s/\[inaudible[- 0-9]*\]/[inaudible]/g" |\ + sed -e 's/ - / /g' |\ + sed -e 's/mm-/mm/g' > $dir/text.orig + +echo "$0: Creating datadir $dir for type=\"$mictype\"" + +if [ $mictype == "worn" ]; then + # convert the filenames to wav.scp format, use the basename of the file + # as a the wav.scp key, add .L and .R for left and right channel + # i.e. each file will have two entries (left and right channel) + find -L $adir -name "S[0-9]*_P[0-9]*.wav" | \ + perl -ne '{ + chomp; + $path = $_; + next unless $path; + @F = split "/", $path; + ($f = $F[@F-1]) =~ s/.wav//; + @F = split "_", $f; + print "${F[1]}_${F[0]}.L sox $path -t wav - remix 1 |\n"; + print "${F[1]}_${F[0]}.R sox $path -t wav - remix 2 |\n"; + }' | sort > $dir/wav.scp + + # generate the transcripts for both left and right channel + # from the original transcript in the form + # P09_S03-0006072-0006147 gimme the baker + # create left and right channel transcript + # P09_S03.L-0006072-0006147 gimme the baker + # P09_S03.R-0006072-0006147 gimme the baker + sed -n 's/ *$//; h; s/-/\.L-/p; g; s/-/\.R-/p' $dir/text.orig | sort > $dir/text +elif [ $mictype == "ref" ]; then + # fixed reference array + + # first get a text, which will be used to extract reference arrays + perl -ne 's/-/.ENH-/;print;' $dir/text.orig | sort > $dir/text + + find -L $adir | grep "\.wav" | sort > $dir/wav.flist + # following command provide the argument for grep to extract only reference arrays + #grep `cut -f 1 -d"-" $dir/text | awk -F"_" '{print $2 "_" $3}' | sed -e "s/\.ENH//" | sort | uniq | sed -e "s/^/ -e /" | tr "\n" " "` $dir/wav.flist > $dir/wav.flist2 + paste -d" " \ + <(awk -F "/" '{print $NF}' $dir/wav.flist | sed -e "s/\.wav/.ENH/") \ + $dir/wav.flist | sort > $dir/wav.scp +else + # array mic case + # convert the filenames to wav.scp format, use the basename of the file + # as a the wav.scp key + find -L $adir -name "*.wav" -ipath "*${mictype}*" |\ + perl -ne '$p=$_;chomp $_;@F=split "/";$F[$#F]=~s/\.wav//;print "$F[$#F] $p";' |\ + sort -u > $dir/wav.scp + + # convert the transcripts from + # P09_S03-0006072-0006147 gimme the baker + # to the per-channel transcripts + # P09_S03_U01_NOLOCATION.CH1-0006072-0006147 gimme the baker + # P09_S03_U01_NOLOCATION.CH2-0006072-0006147 gimme the baker + # P09_S03_U01_NOLOCATION.CH3-0006072-0006147 gimme the baker + # P09_S03_U01_NOLOCATION.CH4-0006072-0006147 gimme the baker + perl -ne '$l=$_; + for($i=1; $i<=4; $i++) { + ($x=$l)=~ s/-/.CH\Q$i\E-/; + print $x;}' $dir/text.orig | sort > $dir/text + +fi +$cleanup && rm -f $dir/text.* $dir/wav.scp.* $dir/wav.flist + +# Prepare 'segments', 'utt2spk', 'spk2utt' +if [ $mictype == "worn" ]; then + cut -d" " -f 1 $dir/text | \ + awk -F"-" '{printf("%s %s %08.2f %08.2f\n", $0, $1, $2/100.0, $3/100.0)}' |\ + sed -e "s/_[A-Z]*\././2" \ + > $dir/segments +elif [ $mictype == "ref" ]; then + cut -d" " -f 1 $dir/text | \ + awk -F"-" '{printf("%s %s %08.2f %08.2f\n", $0, $1, $2/100.0, $3/100.0)}' |\ + sed -e "s/_[A-Z]*\././2" |\ + sed -e "s/ P.._/ /" > $dir/segments +else + cut -d" " -f 1 $dir/text | \ + awk -F"-" '{printf("%s %s %08.2f %08.2f\n", $0, $1, $2/100.0, $3/100.0)}' |\ + sed -e "s/_[A-Z]*\././2" |\ + sed -e 's/ P.._/ /' > $dir/segments +fi +cut -f 1 -d ' ' $dir/segments | \ + perl -ne 'chomp;$utt=$_;s/_.*//;print "$utt $_\n";' > $dir/utt2spk + +utils/utt2spk_to_spk2utt.pl $dir/utt2spk > $dir/spk2utt + +if [ $train != 'true' ]; then + # For scoring the final system, we need the original utt2spk + # and text file. So we keep them with the extension .bak here + # so that they don't affect the validate_data_dir steps in + # the intermediate steps. + for file in text utt2spk spk2utt segments; do + mv $dir/$file $dir/$file.bak + done + + # For dev and eval data, prepare pseudo utt2spk. + awk '{print $1, $1}' $dir/wav.scp > $dir/utt2spk + utils/utt2spk_to_spk2utt.pl $dir/utt2spk > $dir/spk2utt +fi diff --git a/egs/chime6/s5_track2/local/prepare_dict.sh b/egs/chime6/s5_track2/local/prepare_dict.sh new file mode 120000 index 000000000..ada309474 --- /dev/null +++ b/egs/chime6/s5_track2/local/prepare_dict.sh @@ -0,0 +1 @@ +../../s5_track1/local/prepare_dict.sh \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/print_dset_error.py b/egs/chime6/s5_track2/local/print_dset_error.py new file mode 100755 index 000000000..8d7988e27 --- /dev/null +++ b/egs/chime6/s5_track2/local/print_dset_error.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright 2019 Ashish Arora +# Apache 2.0. + +import sys, io +import string +infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') +output = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') +spkorder_writer = open(sys.argv[1],'w', encoding='utf8') +total_words={} +total_errors={} +spk_order={} +total_errors_arrayid={} +total_words_arrayid={} + +output.write('WER for each recording: \n') +for line in infile: + toks = line.strip().split() + recordingid = toks[1] + total_words[recordingid] = toks[-5][:-1] + total_errors[recordingid] = toks[-4][:-1] + spk_order[recordingid] = toks[6][1] + '_' + toks[7][0] + '_' + toks[8][0] + '_' + toks[9][0] + arrayid=recordingid.strip().split('_')[1] + if arrayid not in total_errors_arrayid: + total_errors_arrayid[arrayid]=0 + total_words_arrayid[arrayid]=0 + total_errors_arrayid[arrayid]+=int(total_errors[recordingid]) + total_words_arrayid[arrayid]+=int(total_words[recordingid]) + wer = float(total_errors[recordingid])/float(total_words[recordingid])*100 + utt = "{0} {1} {2} {3} {4:5.2f}".format(recordingid, spk_order[recordingid], total_words[recordingid], total_errors[recordingid], wer) + output.write(utt + '\n') + spkorder_writer.write(recordingid + ':' + str(spk_order[recordingid]) + '\n') + + +output.write('WER for each array: \n') +for arrayid in sorted(total_errors_arrayid): + wer = float(total_errors_arrayid[arrayid])/float(total_words_arrayid[arrayid])*100 + utt = "{0} {1} {2} {3:5.2f}".format(arrayid, total_words_arrayid[arrayid], total_errors_arrayid[arrayid], wer) + output.write(utt + '\n') + diff --git a/egs/chime6/s5_track2/local/reverberate_lat_dir.sh b/egs/chime6/s5_track2/local/reverberate_lat_dir.sh new file mode 120000 index 000000000..57302268f --- /dev/null +++ b/egs/chime6/s5_track2/local/reverberate_lat_dir.sh @@ -0,0 +1 @@ +../../s5_track1/local/reverberate_lat_dir.sh \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/run_beamformit.sh b/egs/chime6/s5_track2/local/run_beamformit.sh new file mode 120000 index 000000000..832a16e3b --- /dev/null +++ b/egs/chime6/s5_track2/local/run_beamformit.sh @@ -0,0 +1 @@ +../../s5_track1/local/run_beamformit.sh \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/run_ivector_common.sh b/egs/chime6/s5_track2/local/run_ivector_common.sh new file mode 120000 index 000000000..df7fca843 --- /dev/null +++ b/egs/chime6/s5_track2/local/run_ivector_common.sh @@ -0,0 +1 @@ +../../s5_track1/local/nnet3/run_ivector_common.sh \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/run_wpe.py b/egs/chime6/s5_track2/local/run_wpe.py new file mode 120000 index 000000000..6621607c9 --- /dev/null +++ b/egs/chime6/s5_track2/local/run_wpe.py @@ -0,0 +1 @@ +../../s5_track1/local/run_wpe.py \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/run_wpe.sh b/egs/chime6/s5_track2/local/run_wpe.sh new file mode 120000 index 000000000..187080e62 --- /dev/null +++ b/egs/chime6/s5_track2/local/run_wpe.sh @@ -0,0 +1 @@ +../../s5_track1/local/run_wpe.sh \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/score.sh b/egs/chime6/s5_track2/local/score.sh new file mode 120000 index 000000000..6a200b42e --- /dev/null +++ b/egs/chime6/s5_track2/local/score.sh @@ -0,0 +1 @@ +../steps/scoring/score_kaldi_wer.sh \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/segmentation/detect_speech_activity.sh b/egs/chime6/s5_track2/local/segmentation/detect_speech_activity.sh new file mode 100755 index 000000000..91d52b392 --- /dev/null +++ b/egs/chime6/s5_track2/local/segmentation/detect_speech_activity.sh @@ -0,0 +1,217 @@ +#!/bin/bash + +# Copyright 2016-17 Vimal Manohar +# 2017 Nagendra Kumar Goel +# Apache 2.0. + +# This script does nnet3-based speech activity detection given an input +# kaldi data directory and outputs a segmented kaldi data directory. +# This script can also do music detection and other similar segmentation +# using appropriate options such as --output-name output-music. + +set -e +set -o pipefail +set -u + +if [ -f ./path.sh ]; then . ./path.sh; fi + +affix= # Affix for the segmentation +nj=32 +cmd=queue.pl +stage=-1 + +# Feature options (Must match training) +mfcc_config=conf/mfcc_hires.conf +feat_affix= # Affix for the type of feature used + +output_name=output # The output node in the network +sad_name=sad # Base name for the directory storing the computed loglikes + # Can be music for music detection +segmentation_name=segmentation # Base name for the directory doing segmentation + # Can be segmentation_music for music detection + +# SAD network config +iter=final # Model iteration to use + +# Contexts must ideally match training for LSTM models, but +# may not necessarily for stats components +extra_left_context=0 # Set to some large value, typically 40 for LSTM (must match training) +extra_right_context=0 +extra_left_context_initial=-1 +extra_right_context_final=-1 +frames_per_chunk=150 + +# Decoding options +graph_opts="--min-silence-duration=0.03 --min-speech-duration=0.3 --max-speech-duration=10.0" +acwt=0.3 + +# These _in__weight represent the fraction of probability +# to transfer to class. +# e.g. --speech-in-sil-weight=0.0 --garbage-in-sil-weight=0.0 --sil-in-speech-weight=0.0 --garbage-in-speech-weight=0.3 +transform_probs_opts="" + +# Postprocessing options +segment_padding=0.2 # Duration (in seconds) of padding added to segments +min_segment_dur=0 # Minimum duration (in seconds) required for a segment to be included + # This is before any padding. Segments shorter than this duration will be removed. + # This is an alternative to --min-speech-duration above. +merge_consecutive_max_dur=0 # Merge consecutive segments as long as the merged segment is no longer than this many + # seconds. The segments are only merged if their boundaries are touching. + # This is after padding by --segment-padding seconds. + # 0 means do not merge. Use 'inf' to not limit the duration. + +echo $* + +. utils/parse_options.sh + +if [ $# -ne 5 ]; then + echo "This script does nnet3-based speech activity detection given an input kaldi " + echo "data directory and outputs an output kaldi data directory." + echo "See script for details of the options to be supplied." + echo "Usage: $0 " + echo " e.g.: $0 ~/workspace/egs/ami/s5b/data/sdm1/dev exp/nnet3_sad_snr/nnet_tdnn_j_n4 \\" + echo " mfcc_hires exp/segmentation_sad_snr/nnet_tdnn_j_n4 data/ami_sdm1_dev" + echo "" + echo "Options: " + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --nj # number of parallel jobs to run." + echo " --stage # stage to do partial re-run from." + echo " --convert-data-dir-to-whole # If true, the input data directory is " + echo " # first converted to whole data directory (i.e. whole recordings) " + echo " # and segmentation is done on that." + echo " # If false, then the original segments are " + echo " # retained and they are split into sub-segments." + echo " --output-name # The output node in the network" + echo " --extra-left-context # Set to some large value, typically 40 for LSTM (must match training)" + echo " --extra-right-context # For BLSTM or statistics pooling" + exit 1 +fi + +src_data_dir=$1 # The input data directory that needs to be segmented. + # If convert_data_dir_to_whole is true, any segments in that will be ignored. +sad_nnet_dir=$2 # The SAD neural network +mfcc_dir=$3 # The directory to store the features +dir=$4 # Work directory +data_dir=$5 # The output data directory will be ${data_dir}_seg + +affix=${affix:+_$affix} +feat_affix=${feat_affix:+_$feat_affix} + +data_id=`basename $data_dir` +sad_dir=${dir}/${sad_name}${affix}_${data_id}${feat_affix} +seg_dir=${dir}/${segmentation_name}${affix}_${data_id}${feat_affix} +test_data_dir=data/${data_id}${feat_affix} + +############################################################################### +## Forward pass through the network network and dump the log-likelihoods. +############################################################################### + +frame_subsampling_factor=1 +if [ -f $sad_nnet_dir/frame_subsampling_factor ]; then + frame_subsampling_factor=$(cat $sad_nnet_dir/frame_subsampling_factor) +fi + +mkdir -p $dir +if [ $stage -le 1 ]; then + if [ "$(readlink -f $sad_nnet_dir)" != "$(readlink -f $dir)" ]; then + cp $sad_nnet_dir/cmvn_opts $dir || exit 1 + fi + + ######################################################################## + ## Initialize neural network for decoding using the output $output_name + ######################################################################## + + if [ ! -z "$output_name" ] && [ "$output_name" != output ]; then + $cmd $dir/log/get_nnet_${output_name}.log \ + nnet3-copy --edits="rename-node old-name=$output_name new-name=output" \ + $sad_nnet_dir/$iter.raw $dir/${iter}_${output_name}.raw || exit 1 + iter=${iter}_${output_name} + else + if ! diff $sad_nnet_dir/$iter.raw $dir/$iter.raw; then + cp $sad_nnet_dir/$iter.raw $dir/ + fi + fi + + steps/nnet3/compute_output.sh --nj $nj --cmd "$cmd" \ + --iter ${iter} \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --extra-left-context-initial $extra_left_context_initial \ + --extra-right-context-final $extra_right_context_final \ + --frames-per-chunk $frames_per_chunk --apply-exp true \ + --frame-subsampling-factor $frame_subsampling_factor \ + ${test_data_dir} $dir $sad_dir || exit 1 +fi + +############################################################################### +## Prepare FST we search to make speech/silence decisions. +############################################################################### + +utils/data/get_utt2dur.sh --nj $nj --cmd "$cmd" $test_data_dir || exit 1 +frame_shift=$(utils/data/get_frame_shift.sh $test_data_dir) || exit 1 + +graph_dir=${dir}/graph_${output_name} +if [ $stage -le 2 ]; then + mkdir -p $graph_dir + + # 1 for silence and 2 for speech + cat < $graph_dir/words.txt + 0 +silence 1 +speech 2 +EOF + + $cmd $graph_dir/log/make_graph.log \ + steps/segmentation/internal/prepare_sad_graph.py $graph_opts \ + --frame-shift=$(perl -e "print $frame_shift * $frame_subsampling_factor") - \| \ + fstcompile --isymbols=$graph_dir/words.txt --osymbols=$graph_dir/words.txt '>' \ + $graph_dir/HCLG.fst +fi + +############################################################################### +## Do Viterbi decoding to create per-frame alignments. +############################################################################### + +post_vec=$sad_nnet_dir/post_${output_name}.vec +if [ ! -f $sad_nnet_dir/post_${output_name}.vec ]; then + if [ ! -f $sad_nnet_dir/post_${output_name}.txt ]; then + echo "$0: Could not find $sad_nnet_dir/post_${output_name}.vec. " + echo "Re-run the corresponding stage in the training script possibly " + echo "with --compute-average-posteriors=true or compute the priors " + echo "from the training labels" + exit 1 + else + post_vec=$sad_nnet_dir/post_${output_name}.txt + fi +fi + +mkdir -p $seg_dir +if [ $stage -le 3 ]; then + steps/segmentation/internal/get_transform_probs_mat.py \ + --priors="$post_vec" $transform_probs_opts > $seg_dir/transform_probs.mat + + steps/segmentation/decode_sad.sh --acwt $acwt --cmd "$cmd" \ + --nj $nj \ + --transform "$seg_dir/transform_probs.mat" \ + $graph_dir $sad_dir $seg_dir +fi + +############################################################################### +## Post-process segmentation to create kaldi data directory. +############################################################################### + +if [ $stage -le 4 ]; then + steps/segmentation/post_process_sad_to_segments.sh \ + --segment-padding $segment_padding --min-segment-dur $min_segment_dur \ + --merge-consecutive-max-dur $merge_consecutive_max_dur \ + --cmd "$cmd" --frame-shift $(perl -e "print $frame_subsampling_factor * $frame_shift") \ + ${test_data_dir} ${seg_dir} ${seg_dir} +fi + +if [ $stage -le 5 ]; then + utils/data/subsegment_data_dir.sh ${test_data_dir} ${seg_dir}/segments \ + ${data_dir}_seg +fi + +echo "$0: Created output segmented kaldi data directory in ${data_dir}_seg" +exit 0 diff --git a/egs/chime6/s5_track2/local/segmentation/tuning/train_lstm_sad_1a.sh b/egs/chime6/s5_track2/local/segmentation/tuning/train_lstm_sad_1a.sh new file mode 100755 index 000000000..570142486 --- /dev/null +++ b/egs/chime6/s5_track2/local/segmentation/tuning/train_lstm_sad_1a.sh @@ -0,0 +1,140 @@ +#!/bin/bash + +# Copyright 2017 Nagendra Kumar Goel +# 2018 Vimal Manohar +# Apache 2.0 + +# This is a script to train a TDNN for speech activity detection (SAD) +# using LSTM for long-context information. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= + +chunk_width=20 + +extra_left_context=60 +extra_right_context=10 +relu_dim=256 +cell_dim=256 +projection_dim=64 + +# training options +num_epochs=1 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=true +max_param_change=0.2 # Small max-param change for small network +dropout_schedule='0,0@0.20,0.1@0.50,0' + +egs_dir= +nj=40 + +dir= +affix=1a + +data_dir= +targets_dir= + +. ./cmd.sh +if [ -f ./path.sh ]; then . ./path.sh; fi +. ./utils/parse_options.sh + +set -o pipefail +set -u + +if [ -z "$dir" ]; then + dir=exp/segmentation_1a/tdnn_lstm_asr_sad +fi +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/cmvn_opts + +if [ $stage -le 5 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$data_dir/feats.scp -` name=input + fixed-affine-layer name=lda input=Append(-2,-1,0,1,2) affine-transform-file=$dir/configs/lda.mat + + relu-renorm-layer name=tdnn1 input=lda dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true + fast-lstmp-layer name=lstm1 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim decay-time=20 delay=-3 dropout-proportion=0.0 + relu-renorm-layer name=tdnn4 input=Append(-6,0,6,12) add-log-stddev=true dim=$relu_dim + fast-lstmp-layer name=lstm2 cell-dim=$cell_dim recurrent-projection-dim=$projection_dim non-recurrent-projection-dim=$projection_dim decay-time=20 delay=-3 dropout-proportion=0.0 + relu-renorm-layer name=tdnn5 input=Append(-12,0,12,24) dim=$relu_dim + + output-layer name=output include-log-softmax=true dim=3 learning-rate-factor=0.1 input=tdnn5 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ + + cat <> $dir/configs/vars +num_targets=3 +EOF +fi + +if [ $stage -le 6 ]; then + num_utts=`cat $data_dir/utt2spk | wc -l` + # Set num_utts_subset for diagnostics to a reasonable value + # of max(min(0.005 * num_utts, 300), 12) + num_utts_subset=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 300 ? 300 : ($n < 12 ? 12 : $n))' $num_utts` + + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.chunk-left-context-initial=0 \ + --egs.chunk-right-context-final=0 \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=0.99 \ + --trainer.dropout-schedule="$dropout_schedule" \ + --trainer.rnn.num-chunk-per-minibatch=128,64 \ + --trainer.optimization.momentum=0.5 \ + --trainer.deriv-truncate-margin=10 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj $nj \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=true \ + --feat-dir=$data_dir \ + --targets-scp="$targets_dir/targets.scp" \ + --egs.opts="--frame-subsampling-factor 3 --num-utts-subset $num_utts_subset" \ + --dir=$dir || exit 1 +fi + +if [ $stage -le 7 ]; then + # Use a subset to compute prior over the output targets + $train_cmd $dir/log/get_priors.log \ + matrix-sum-rows "scp:utils/subset_scp.pl --quiet 1000 $targets_dir/targets.scp |" \ + ark:- \| vector-sum --binary=false ark:- $dir/post_output.vec || exit 1 + + echo 3 > $dir/frame_subsampling_factor +fi diff --git a/egs/chime6/s5_track2/local/segmentation/tuning/train_stats_sad_1a.sh b/egs/chime6/s5_track2/local/segmentation/tuning/train_stats_sad_1a.sh new file mode 100755 index 000000000..bb985462f --- /dev/null +++ b/egs/chime6/s5_track2/local/segmentation/tuning/train_stats_sad_1a.sh @@ -0,0 +1,150 @@ +#!/bin/bash + +# Copyright 2017 Nagendra Kumar Goel +# 2018 Vimal Manohar +# Apache 2.0 + +# This is a script to train a TDNN for speech activity detection (SAD) +# using statistics pooling for long-context information. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= + +chunk_width=20 + +# The context is chosen to be around 1 second long. The context at test time +# is expected to be around the same. +extra_left_context=79 +extra_right_context=21 + +relu_dim=256 + +# training options +num_epochs=1 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=true +max_param_change=0.2 # Small max-param change for small network + +egs_dir= +nj=40 + +dir= +affix=1a + +data_dir= +targets_dir= + +. ./cmd.sh +if [ -f ./path.sh ]; then . ./path.sh; fi +. ./utils/parse_options.sh + +set -o pipefail +set -u + +if [ -z "$dir" ]; then + dir=exp/segmentation_1a/tdnn_stats_sad +fi +dir=$dir${affix:+_$affix} + +if ! cuda-compiled; then + cat < $dir/cmvn_opts + +if [ $stage -le 5 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$data_dir/feats.scp -` name=input + fixed-affine-layer name=lda input=Append(-2,-1,0,1,2) affine-transform-file=$dir/configs/lda.mat + + relu-renorm-layer name=tdnn1 input=lda dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn2 input=Append(-1,0,1,2) dim=$relu_dim add-log-stddev=true + relu-renorm-layer name=tdnn3 input=Append(-3,0,3,6) dim=$relu_dim add-log-stddev=true + stats-layer name=tdnn3_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn4 input=Append(tdnn3@-6,tdnn3@0,tdnn3@6,tdnn3@12,tdnn3_stats) add-log-stddev=true dim=$relu_dim + stats-layer name=tdnn4_stats config=mean+count(-108:6:18:108) + relu-renorm-layer name=tdnn5 input=Append(tdnn4@-12,tdnn4@0,tdnn4@12,tdnn4@24,tdnn4_stats) dim=$relu_dim + + output-layer name=output include-log-softmax=true dim=3 learning-rate-factor=0.1 input=tdnn5 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ + + cat <> $dir/configs/vars +num_targets=3 +EOF +fi + +if [ $stage -le 6 ]; then + num_utts=`cat $data_dir/utt2spk | wc -l` + # Set num_utts_subset for diagnostics to a reasonable value + # of max(min(0.005 * num_utts, 300), 12) + num_utts_subset=`perl -e '$n=int($ARGV[0] * 0.005); print ($n > 300 ? 300 : ($n < 12 ? 12 : $n))' $num_utts` + + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts=$cmvn_opts \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + --egs.chunk-left-context-initial=0 \ + --egs.chunk-right-context-final=0 \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.rnn.num-chunk-per-minibatch=128,64 \ + --trainer.optimization.momentum=0.5 \ + --trainer.deriv-truncate-margin=10 \ + --trainer.max-param-change=$max_param_change \ + --trainer.compute-per-dim-accuracy=true \ + --cmd="$decode_cmd" --nj $nj \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=true \ + --feat-dir=$data_dir \ + --targets-scp="$targets_dir/targets.scp" \ + --egs.opts="--frame-subsampling-factor 3 --num-utts-subset $num_utts_subset" \ + --dir=$dir || exit 1 +fi + +if [ $stage -le 7 ]; then + # Use a subset to compute prior over the output targets + #$train_cmd $dir/log/get_priors.log \ + # matrix-sum-rows "scp:utils/subset_scp.pl --quiet 1000 $targets_dir/targets.scp |" \ + # ark:- \| vector-sum --binary=false ark:- $dir/post_output.vec || exit 1 + + # Since the train data is individual microphones, while the dev and + # eval are beamformed, it is likely that the train contains a much + # higher ratio of silences. So using priors computed from the train + # data may miss a lot of speech in the dev/eval sets. Hence we manually + # tune the prior on the dev set. + # With the following prior, the SAD system results are: + # Dev (using -c 0.25) + # MISSED SPEECH = 1188.59 secs ( 3.3 percent of scored time) + # FALARM SPEECH = 539.37 secs ( 1.5 percent of scored time) + echo "[ 30 2 1 ]" > $dir/post_output.vec || exit 1 + + echo 3 > $dir/frame_subsampling_factor +fi + diff --git a/egs/chime6/s5_track2/local/train_diarizer.sh b/egs/chime6/s5_track2/local/train_diarizer.sh new file mode 100755 index 000000000..71918e7ca --- /dev/null +++ b/egs/chime6/s5_track2/local/train_diarizer.sh @@ -0,0 +1,186 @@ +#!/bin/bash +# Copyright +# 2019 David Snyder +# Apache 2.0. +# +# This script is based on the run.sh script in the Voxceleb v2 recipe. +# It trains an x-vector DNN for diarization. + +mfccdir=`pwd`/mfcc +vaddir=`pwd`/mfcc + +voxceleb1_root=/export/corpora/VoxCeleb1 +voxceleb2_root=/export/corpora/VoxCeleb2 +data_dir=train_worn_simu_u400k +model_dir=exp/xvector_nnet_1a + +stage=0 +train_stage=-1 + +. ./cmd.sh + +if [ -f ./path.sh ]; then . ./path.sh; fi +set -e -u -o pipefail +. utils/parse_options.sh + +if [ $# -ne 0 ]; then + exit 1 +fi + +if [ $stage -le 0 ]; then + echo "$0: preparing voxceleb 2 data" + local/make_voxceleb2.pl $voxceleb2_root dev data/voxceleb2_train + local/make_voxceleb2.pl $voxceleb2_root test data/voxceleb2_test + + echo "$0: preparing voxceleb 1 data (see comments if this step fails)" + # The format of the voxceleb 1 corpus has changed several times since it was + # released. Therefore, our dataprep scripts may or may not fail depending + # on the version of the corpus you obtained. + # If you downloaded the corpus soon after it was first released, this + # version of the dataprep script might work: + local/make_voxceleb1.pl $voxceleb1_root data/voxceleb1 + # However, if you've downloaded the corpus recently, you may need to use the + # the following scripts instead: + #local/make_voxceleb1_v2.pl $voxceleb1_root dev data/voxceleb1_train + #local/make_voxceleb1_v2.pl $voxceleb1_root test data/voxceleb1_test + + # We should now have about 7,351 speakers and 1,277,503 utterances. + utils/combine_data.sh data/voxceleb data/voxceleb2_train data/voxceleb2_test +fi + +if [ $stage -le 1 ]; then + echo "$0: preparing features for training data (voxceleb 1 + 2)" + steps/make_mfcc.sh --write-utt2num-frames true \ + --mfcc-config conf/mfcc_hires.conf --nj 40 --cmd "$train_cmd" \ + data/voxceleb exp/make_mfcc $mfccdir + utils/fix_data_dir.sh data/voxceleb + # Note that we apply CMN to the MFCCs and write these to the disk. These + # features will later be used to train the x-vector DNN. +fi + +# In this section, we augment the voxceleb data with reverberation. +# Note that we can probably improve the x-vector DNN if we include +# augmentations from the nonspeech regions of the Chime 6 training +# dataset. +if [ $stage -le 2 ]; then + echo "$0: applying augmentation to x-vector training data (just reverb for now)" + frame_shift=0.01 + awk -v frame_shift=$frame_shift '{print $1, $2*frame_shift;}' data/voxceleb/utt2num_frames > data/voxceleb/reco2dur + + if [ ! -d "RIRS_NOISES" ]; then + echo "$0: downloading simulated room impulse response dataset" + # Download the package that includes the real RIRs, simulated RIRs, isotropic noises and point-source noises + wget --no-check-certificate http://www.openslr.org/resources/28/rirs_noises.zip + unzip rirs_noises.zip + fi + + # Make a version with reverberated speech + rvb_opts=() + rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/smallroom/rir_list") + rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/mediumroom/rir_list") + + # Make a reverberated version of the training data. Note that we don't add any + # additive noise here. + steps/data/reverberate_data_dir.py \ + "${rvb_opts[@]}" \ + --speech-rvb-probability 1 \ + --pointsource-noise-addition-probability 0 \ + --isotropic-noise-addition-probability 0 \ + --num-replications 1 \ + --source-sampling-rate 16000 \ + data/voxceleb data/voxceleb_reverb + utils/copy_data_dir.sh --utt-suffix "-reverb" data/voxceleb_reverb data/voxceleb_reverb.new + rm -rf data/voxceleb_reverb + mv data/voxceleb_reverb.new data/voxceleb_reverb +fi + +if [ $stage -le 3 ]; then + echo "$0: making MFCCs for augmented training data" + # Make MFCCs for the augmented data. Note that we do not compute a new + # vad.scp file here. Instead, we use the vad.scp from the clean version of + # the list. + steps/make_mfcc.sh --mfcc-config conf/mfcc_hires.conf --nj 40 --cmd "$train_cmd" \ + data/voxceleb_reverb exp/make_mfcc $mfccdir + # Combine the clean and augmented training data. This is now roughly + # double the size of the original clean list. + utils/combine_data.sh data/voxceleb_combined data/voxceleb_reverb data/voxceleb +fi + +# Now we prepare the features to generate examples for xvector training. +if [ $stage -le 4 ]; then + # This script applies CMVN and removes nonspeech frames. Note that this is somewhat + # wasteful, as it roughly doubles the amount of training data on disk. After + # creating voxceleb examples, this can be removed. + echo "$0: preparing features to train x-vector DNN" + local/nnet3/xvector/prepare_feats.sh --nj 40 --cmd "$train_cmd" \ + data/voxceleb_combined data/voxceleb_combined_cmn exp/voxceleb_combined_cmn + utils/fix_data_dir.sh data/voxceleb_combined_cmn +fi + +if [ $stage -le 5 ]; then + # Now, we need to remove features that are too short after removing silence + # frames. We want at least 4s (400 frames) per utterance. + min_len=400 + mv data/voxceleb_combined_cmn/utt2num_frames data/voxceleb_combined_cmn/utt2num_frames.bak + awk -v min_len=${min_len} '$2 > min_len {print $1, $2}' data/voxceleb_combined_cmn/utt2num_frames.bak > data/voxceleb_combined_cmn/utt2num_frames + utils/filter_scp.pl data/voxceleb_combined_cmn/utt2num_frames data/voxceleb_combined_cmn/utt2spk > data/voxceleb_combined_cmn/utt2spk.new + mv data/voxceleb_combined_cmn/utt2spk.new data/voxceleb_combined_cmn/utt2spk + utils/fix_data_dir.sh data/voxceleb_combined_cmn + + # We also want several utterances per speaker. Now we'll throw out speakers + # with fewer than 8 utterances. + min_num_utts=8 + awk '{print $1, NF-1}' data/voxceleb_combined_cmn/spk2utt > data/voxceleb_combined_cmn/spk2num + awk -v min_num_utts=${min_num_utts} '$2 >= min_num_utts {print $1, $2}' data/voxceleb_combined_cmn/spk2num | utils/filter_scp.pl - data/voxceleb_combined_cmn/spk2utt > data/voxceleb_combined_cmn/spk2utt.new + mv data/voxceleb_combined_cmn/spk2utt.new data/voxceleb_combined_cmn/spk2utt + utils/spk2utt_to_utt2spk.pl data/voxceleb_combined_cmn/spk2utt > data/voxceleb_combined_cmn/utt2spk + + utils/filter_scp.pl data/voxceleb_combined_cmn/utt2spk data/voxceleb_combined_cmn/utt2num_frames > data/voxceleb_combined_cmn/utt2num_frames.new + mv data/voxceleb_combined_cmn/utt2num_frames.new data/voxceleb_combined_cmn/utt2num_frames + + utils/fix_data_dir.sh data/voxceleb_combined_cmn +fi + +# Stages 6 through 8 are handled in run_xvector.sh. +# This script trains the x-vector DNN on the augmented voxceleb data. +local/nnet3/xvector/run_xvector.sh --stage $stage --train-stage $train_stage \ + --data data/voxceleb_combined_cmn --nnet-dir $model_dir \ + --egs-dir $model_dir/egs + +if [ $stage -le 9 ]; then + echo "$0: preparing a subset of Chime 6 training data to train PLDA model" + utils/subset_data_dir.sh ${data_dir} 100000 data/plda_train + steps/make_mfcc.sh --write-utt2num-frames true \ + --mfcc-config conf/mfcc_hires.conf --nj 40 --cmd "$train_cmd" \ + data/plda_train exp/make_mfcc $mfccdir + utils/fix_data_dir.sh data/plda_train + local/nnet3/xvector/prepare_feats.sh --nj 40 --cmd "$train_cmd" \ + data/plda_train data/plda_train_cmn exp/plda_train_cmn + if [ -f data/plda_train/segments ]; then + cp data/plda_train/segments data/plda_train_cmn/ + fi +fi + +if [ $stage -le 10 ]; then + echo "$0: extracting x-vector for PLDA training data" + utils/fix_data_dir.sh data/plda_train_cmn + diarization/nnet3/xvector/extract_xvectors.sh --cmd "$train_cmd --mem 10G" \ + --nj 40 --window 3.0 --period 10.0 --min-segment 1.5 --apply-cmn false \ + --hard-min true $model_dir \ + data/plda_train_cmn $model_dir/xvectors_plda_train +fi + +# Train PLDA models +if [ $stage -le 11 ]; then + echo "$0: training PLDA model" + $train_cmd $model_dir/xvectors_plda_train/log/plda.log \ + ivector-compute-plda ark:$model_dir/xvectors_plda_train/spk2utt \ + "ark:ivector-subtract-global-mean \ + scp:$model_dir/xvectors_plda_train/xvector.scp ark:- \ + | transform-vec $model_dir/xvectors_plda_train/transform.mat ark:- ark:- \ + | ivector-normalize-length ark:- ark:- |" \ + $model_dir/xvectors_plda_train/plda || exit 1; + cp $model_dir/xvectors_plda_train/plda $model_dir/ + cp $model_dir/xvectors_plda_train/transform.mat $model_dir/ + cp $model_dir/xvectors_plda_train/mean.vec $model_dir/ +fi diff --git a/egs/chime6/s5_track2/local/train_lms_srilm.sh b/egs/chime6/s5_track2/local/train_lms_srilm.sh new file mode 120000 index 000000000..a7666f6cd --- /dev/null +++ b/egs/chime6/s5_track2/local/train_lms_srilm.sh @@ -0,0 +1 @@ +../../s5_track1/local/train_lms_srilm.sh \ No newline at end of file diff --git a/egs/chime6/s5_track2/local/train_sad.sh b/egs/chime6/s5_track2/local/train_sad.sh new file mode 100755 index 000000000..e12a0cad6 --- /dev/null +++ b/egs/chime6/s5_track2/local/train_sad.sh @@ -0,0 +1,155 @@ +#!/bin/bash + +# Copyright 2017 Nagendra Kumar Goel +# 2017 Vimal Manohar +# 2019 Desh Raj +# Apache 2.0 + +# This script is based on local/run_asr_segmentation.sh script in the +# Aspire recipe. It demonstrates nnet3-based speech activity detection for +# segmentation. +# This script: +# 1) Prepares targets (per-frame labels) for a subset of training data +# using GMM models +# 2) Trains TDNN+Stats or TDNN+LSTM neural network using the targets +# 3) Demonstrates using the SAD system to get segments of dev data + +lang=data/lang # Must match the one used to train the models +lang_test=data/lang_test # Lang directory for decoding. + +data_dir= +test_sets= +# Model directory used to align the $data_dir to get target labels for training +# SAD. This should typically be a speaker-adapted system. +sat_model_dir= +# Model direcotry used to decode the whole-recording version of the $data_dir to +# get target labels for training SAD. This should typically be a +# speaker-independent system like LDA+MLLT system. +model_dir= +graph_dir= # Graph for decoding whole-recording version of $data_dir. + # If not provided, a new one will be created using $lang_test + +# List of weights on labels obtained from alignment; +# labels obtained from decoding; and default labels in out-of-segment regions +merge_weights=1.0,0.1,0.5 + +prepare_targets_stage=-10 +nstage=-10 +train_stage=-10 +stage=0 +nj=50 +reco_nj=40 + +# test options +test_nj=10 + +. ./cmd.sh +. ./conf/sad.conf + +if [ -f ./path.sh ]; then . ./path.sh; fi + +set -e -u -o pipefail +. utils/parse_options.sh + +if [ $# -ne 0 ]; then + exit 1 +fi + +dir=exp/segmentation${affix} +sad_work_dir=exp/sad${affix}_${nnet_type}/ +sad_nnet_dir=$dir/tdnn_${nnet_type}_sad_1a + +mkdir -p $dir +mkdir -p ${sad_work_dir} + +# See $lang/phones.txt and decide which should be garbage +garbage_phones="laughs inaudible" +silence_phones="sil spn noise" + +for p in $garbage_phones; do + for a in "" "_B" "_E" "_I" "_S"; do + echo "$p$a" + done +done > $dir/garbage_phones.txt + +for p in $silence_phones; do + for a in "" "_B" "_E" "_I" "_S"; do + echo "$p$a" + done +done > $dir/silence_phones.txt + +if ! cat $dir/garbage_phones.txt $dir/silence_phones.txt | \ + steps/segmentation/internal/verify_phones_list.py $lang/phones.txt; then + echo "$0: Invalid $dir/{silence,garbage}_phones.txt" + exit 1 +fi + +# The training data may already be segmented, so we first prepare +# a "whole" training data (not segmented) for training the SAD +# system. + +whole_data_dir=${data_dir}_whole +whole_data_id=$(basename $whole_data_dir) + +if [ $stage -le 0 ]; then + utils/data/convert_data_dir_to_whole.sh $data_dir $whole_data_dir +fi + +############################################################################### +# Extract features for the whole data directory. We extract 13-dim MFCCs to +# generate targets using the GMM system, and 40-dim MFCCs to train the NN-based +# SAD. +############################################################################### +if [ $stage -le 1 ]; then + steps/make_mfcc.sh --nj $reco_nj --cmd "$train_cmd" --write-utt2num-frames true \ + --mfcc-config conf/mfcc.conf \ + $whole_data_dir exp/make_mfcc/${whole_data_id} + steps/compute_cmvn_stats.sh $whole_data_dir exp/make_mfcc/${whole_data_id} + utils/fix_data_dir.sh $whole_data_dir + + utils/copy_data_dir.sh $whole_data_dir ${whole_data_dir}_hires + steps/make_mfcc.sh --nj $reco_nj --cmd "$train_cmd" --write-utt2num-frames true \ + --mfcc-config conf/mfcc_hires.conf \ + ${whole_data_dir}_hires exp/make_mfcc/${whole_data_id}_hires + steps/compute_cmvn_stats.sh ${whole_data_dir}_hires exp/make_mfcc/${whole_data_id}_hires + utils/fix_data_dir.sh ${whole_data_dir}_hires +fi + +############################################################################### +# Prepare SAD targets for recordings +############################################################################### +targets_dir=$dir/${whole_data_id}_combined_targets_sub3 +if [ $stage -le 2 ]; then + steps/segmentation/prepare_targets_gmm.sh --stage $prepare_targets_stage \ + --train-cmd "$train_cmd" --decode-cmd "$decode_cmd" \ + --nj $nj --reco-nj $reco_nj --lang-test $lang \ + --garbage-phones-list $dir/garbage_phones.txt \ + --silence-phones-list $dir/silence_phones.txt \ + --merge-weights "$merge_weights" \ + --remove-mismatch-frames false \ + --graph-dir "$graph_dir" \ + $lang $data_dir $whole_data_dir $sat_model_dir $model_dir $dir +fi + +############################################################################### +# Train a neural network for SAD +############################################################################### +if [ $stage -le 3 ]; then + if [ $nnet_type == "stats" ]; then + # Train a STATS-pooling network for SAD + local/segmentation/tuning/train_stats_sad_1a.sh \ + --stage $nstage --train-stage $train_stage \ + --targets-dir ${targets_dir} \ + --data-dir ${whole_data_dir}_hires --affix "1a" || exit 1 + + elif [ $nnet_type == "lstm" ]; then + # Train a TDNN+LSTM network for SAD + local/segmentation/tuning/train_lstm_sad_1a.sh \ + --stage $nstage --train-stage $train_stage \ + --targets-dir ${targets_dir} \ + --data-dir ${whole_data_dir}_hires --affix "1a" || exit 1 + + fi +fi + +exit 0; diff --git a/egs/chime6/s5_track2/local/wer_output_filter b/egs/chime6/s5_track2/local/wer_output_filter new file mode 120000 index 000000000..12a6c616d --- /dev/null +++ b/egs/chime6/s5_track2/local/wer_output_filter @@ -0,0 +1 @@ +../../s5_track1/local/wer_output_filter \ No newline at end of file diff --git a/egs/chime6/s5_track2/path.sh b/egs/chime6/s5_track2/path.sh new file mode 100644 index 000000000..c2526194b --- /dev/null +++ b/egs/chime6/s5_track2/path.sh @@ -0,0 +1,7 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/tools/sctk/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C + diff --git a/egs/chime6/s5_track2/run.sh b/egs/chime6/s5_track2/run.sh new file mode 100755 index 000000000..1350b8e14 --- /dev/null +++ b/egs/chime6/s5_track2/run.sh @@ -0,0 +1,296 @@ +#!/bin/bash +# +# Chime-6 Track 2 baseline. Based mostly on the Chime-5 recipe, with the exception +# that we are required to perform speech activity detection and speaker +# diarization before ASR, since we do not have access to the oracle SAD and +# diarization labels. +# +# Copyright 2017 Johns Hopkins University (Author: Shinji Watanabe and Yenda Trmal) +# 2019 Desh Raj, David Snyder, Ashish Arora +# Apache 2.0 + +# Begin configuration section. +nj=50 +decode_nj=20 +stage=0 +nnet_stage=-10 +sad_stage=0 +diarizer_stage=0 +decode_stage=1 +enhancement=beamformit # for a new enhancement method, + # change this variable and decode stage +decode_only=false +num_data_reps=4 +snrs="20:10:15:5:0" +foreground_snrs="20:10:15:5:0" +background_snrs="20:10:15:5:0" +# End configuration section +. ./utils/parse_options.sh + +. ./cmd.sh +. ./path.sh + +if [ $decode_only == "true" ]; then + stage=18 +fi + +set -e # exit on error + +# chime5 main directory path +# please change the path accordingly +chime5_corpus=/export/corpora4/CHiME5 +# chime6 data directories, which are generated from ${chime5_corpus}, +# to synchronize audio files across arrays and modify the annotation (JSON) file accordingly +chime6_corpus=${PWD}/CHiME6 +json_dir=${chime6_corpus}/transcriptions +audio_dir=${chime6_corpus}/audio + +# training and test data +train_set=train_worn_simu_u400k +sad_train_set=train_worn_u400k +test_sets="dev_${enhancement}_dereverb eval_${enhancement}_dereverb" + +# This script also needs the phonetisaurus g2p, srilm, beamformit +./local/check_tools.sh || exit 1; + +########################################################################### +# We first generate the synchronized audio files across arrays and +# corresponding JSON files. Note that this requires sox v14.4.2, +# which is installed via miniconda in ./local/check_tools.sh +########################################################################### + +if [ $stage -le 0 ]; then + local/generate_chime6_data.sh \ + --cmd "$train_cmd" \ + ${chime5_corpus} \ + ${chime6_corpus} +fi + +########################################################################### +# We prepare dict and lang in stages 1 to 3. +########################################################################### + +if [ $stage -le 1 ]; then + # skip u03 and u04 as they are missing + for mictype in worn u01 u02 u05 u06; do + local/prepare_data.sh --mictype ${mictype} --train true \ + ${audio_dir}/train ${json_dir}/train data/train_${mictype} + done + for dataset in dev; do + for mictype in worn; do + local/prepare_data.sh --mictype ${mictype} --train true \ + ${audio_dir}/${dataset} ${json_dir}/${dataset} \ + data/${dataset}_${mictype} + done + done +fi + +if [ $stage -le 2 ]; then + local/prepare_dict.sh + + utils/prepare_lang.sh \ + data/local/dict "" data/local/lang data/lang + + local/train_lms_srilm.sh \ + --train-text data/train_worn/text --dev-text data/dev_worn/text \ + --oov-symbol "" --words-file data/lang/words.txt \ + data/ data/srilm +fi + +LM=data/srilm/best_3gram.gz +if [ $stage -le 3 ]; then + # Compiles G for chime5 trigram LM + utils/format_lm.sh \ + data/lang $LM data/local/dict/lexicon.txt data/lang + +fi + +if [ $stage -le 4 ]; then + # remove possibly bad sessions (P11_S03, P52_S19, P53_S24, P54_S24) + # see http://spandh.dcs.shef.ac.uk/chime_challenge/data.html for more details + utils/copy_data_dir.sh data/train_worn data/train_worn_org # back up + grep -v -e "^P11_S03" -e "^P52_S19" -e "^P53_S24" -e "^P54_S24" data/train_worn_org/text > data/train_worn/text + utils/fix_data_dir.sh data/train_worn +fi + + +######################################################################################### +# In stages 5 and 6, we augment and fix train data for our training purpose. point source +# noises are extracted from chime corpus. Here we use 400k utterances from array microphones, +# its augmentation and all the worn set utterances in train. +######################################################################################### + +if [ $stage -le 5 ]; then + echo "$0: Extracting noise list from training data" + local/extract_noises.py $chime6_corpus/audio/train $chime6_corpus/transcriptions/train \ + local/distant_audio_list distant_noises + local/make_noise_list.py distant_noises > distant_noise_list + + noise_list=distant_noise_list + + echo "$0: Preparing simulated RIRs for data augmentation" + if [ ! -d RIRS_NOISES/ ]; then + # Download the package that includes the real RIRs, simulated RIRs, isotropic noises and point-source noises + wget --no-check-certificate http://www.openslr.org/resources/28/rirs_noises.zip + unzip rirs_noises.zip + fi + + # This is the config for the system using simulated RIRs and point-source noises + rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/smallroom/rir_list") + rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/mediumroom/rir_list") + rvb_opts+=(--noise-set-parameters $noise_list) + + steps/data/reverberate_data_dir.py \ + "${rvb_opts[@]}" \ + --prefix "rev" \ + --foreground-snrs $foreground_snrs \ + --background-snrs $background_snrs \ + --speech-rvb-probability 1 \ + --pointsource-noise-addition-probability 1 \ + --isotropic-noise-addition-probability 1 \ + --num-replications $num_data_reps \ + --max-noises-per-minute 1 \ + --source-sampling-rate 16000 \ + data/train_worn data/train_worn_rvb +fi + +if [ $stage -le 6 ]; then + # combine mix array and worn mics + # randomly extract first 400k utterances from all mics + # if you want to include more training data, you can increase the number of array mic utterances + utils/combine_data.sh data/train_uall data/train_u01 data/train_u02 data/train_u05 data/train_u06 + utils/subset_data_dir.sh data/train_uall 400000 data/train_u400k + utils/combine_data.sh data/${train_set} data/train_worn data/train_worn_rvb data/train_u400k + utils/combine_data.sh data/${sad_train_set} data/train_worn data/train_u400k +fi + +if [ $stage -le 7 ]; then + # Split speakers up into 3-minute chunks. This doesn't hurt adaptation, and + # lets us use more jobs for decoding etc. + utils/copy_data_dir.sh data/${train_set} data/${train_set}_nosplit + utils/data/modify_speaker_info.sh --seconds-per-spk-max 180 data/${train_set}_nosplit data/${train_set} +fi + +################################################################################## +# Now make MFCC features. We use 13-dim MFCCs to train the GMM-HMM models. +################################################################################## + +if [ $stage -le 8 ]; then + # Now make MFCC features. + # mfccdir should be some place with a largish disk where you + # want to store MFCC features. + echo "$0: make features..." + mfccdir=mfcc + steps/make_mfcc.sh --nj $nj --cmd "$train_cmd" \ + --mfcc-config conf/mfcc.conf \ + data/${train_set} exp/make_mfcc/${train_set} $mfccdir + steps/compute_cmvn_stats.sh data/${train_set} exp/make_mfcc/${train_set} $mfccdir + utils/fix_data_dir.sh data/${train_set} +fi + +################################################################################### +# Stages 9 to 14 train monophone and triphone models. They will be used for +# generating lattices for training the chain model and for obtaining targets +# for training the SAD system. +################################################################################### + +if [ $stage -le 9 ]; then + # make a subset for monophone training + utils/subset_data_dir.sh --shortest data/${train_set} 100000 data/${train_set}_100kshort + utils/subset_data_dir.sh data/${train_set}_100kshort 30000 data/${train_set}_30kshort +fi + +if [ $stage -le 10 ]; then + # Starting basic training on MFCC features + steps/train_mono.sh --nj $nj --cmd "$train_cmd" \ + data/${train_set}_30kshort data/lang exp/mono +fi + +if [ $stage -le 11 ]; then + steps/align_si.sh --nj $nj --cmd "$train_cmd" \ + data/${train_set} data/lang exp/mono exp/mono_ali + + steps/train_deltas.sh --cmd "$train_cmd" \ + 2500 30000 data/${train_set} data/lang exp/mono_ali exp/tri1 +fi + +if [ $stage -le 12 ]; then + steps/align_si.sh --nj $nj --cmd "$train_cmd" \ + data/${train_set} data/lang exp/tri1 exp/tri1_ali + + steps/train_lda_mllt.sh --cmd "$train_cmd" \ + 4000 50000 data/${train_set} data/lang exp/tri1_ali exp/tri2 +fi + +if [ $stage -le 13 ]; then + steps/align_si.sh --nj $nj --cmd "$train_cmd" \ + data/${train_set} data/lang exp/tri2 exp/tri2_ali + + steps/train_sat.sh --cmd "$train_cmd" \ + 5000 100000 data/${train_set} data/lang exp/tri2_ali exp/tri3 +fi + +if [ $stage -le 14 ]; then + # The following script cleans the data and produces cleaned data + steps/cleanup/clean_and_segment_data.sh --nj $nj --cmd "$train_cmd" \ + --segmentation-opts "--min-segment-length 0.3 --min-new-segment-length 0.6" \ + data/${train_set} data/lang exp/tri3 exp/tri3_cleaned data/${train_set}_cleaned +fi + +########################################################################## +# CHAIN MODEL TRAINING +# You can also download a pretrained chain ASR model using: +# wget http://kaldi-asr.org/models/12/0012_asr_v1.tar.gz +# Once it is downloaded, extract using: tar -xvzf 0012_asr_v1.tar.gz +# and copy the contents of the exp/ directory to your exp/ +########################################################################## +if [ $stage -le 15 ]; then + # chain TDNN + local/chain/run_tdnn.sh --nj $nj \ + --stage $nnet_stage \ + --train-set ${train_set}_cleaned \ + --test-sets "$test_sets" \ + --gmm tri3_cleaned --nnet3-affix _${train_set}_cleaned_rvb +fi + +########################################################################## +# SAD MODEL TRAINING +# You can also download a pretrained SAD model using: +# wget http://kaldi-asr.org/models/12/0012_sad_v1.tar.gz +# Once it is downloaded, extract using: tar -xvzf 0012_sad_v1.tar.gz +# and copy the contents of the exp/ directory to your exp/ +########################################################################## +if [ $stage -le 16 ]; then + local/train_sad.sh --stage $sad_stage --nj $nj \ + --data-dir data/${sad_train_set} --test-sets "${test_sets}" \ + --sat-model-dir exp/tri3_cleaned \ + --model-dir exp/tri2 +fi + +########################################################################## +# DIARIZATION MODEL TRAINING +# You can also download a pretrained diarization model using: +# wget http://kaldi-asr.org/models/12/0012_diarization_v1.tar.gz +# Once it is downloaded, extract using: tar -xvzf 0012_diarization_v1.tar.gz +# and copy the contents of the exp/ directory to your exp/ +########################################################################## +if [ $stage -le 17 ]; then + local/train_diarizer.sh --stage $diarizer_stage \ + --data-dir data/${train_set} \ + --model-dir exp/xvector_nnet_1a +fi + +########################################################################## +# DECODING: In track 2, we are given raw utterances without segment +# or speaker information, so we have to decode the whole pipeline, i.e., +# SAD -> Diarization -> ASR. This is done in the local/decode.sh +# script. +########################################################################## +if [ $stage -le 18 ]; then + local/decode.sh --stage $decode_stage \ + --enhancement $enhancement \ + --test-sets "$test_sets" +fi + +exit 0; + diff --git a/egs/chime6/s5_track2/sid b/egs/chime6/s5_track2/sid new file mode 120000 index 000000000..893a12f30 --- /dev/null +++ b/egs/chime6/s5_track2/sid @@ -0,0 +1 @@ +../../sre08/v1/sid \ No newline at end of file diff --git a/egs/chime6/s5_track2/steps b/egs/chime6/s5_track2/steps new file mode 120000 index 000000000..1b186770d --- /dev/null +++ b/egs/chime6/s5_track2/steps @@ -0,0 +1 @@ +../../wsj/s5/steps/ \ No newline at end of file diff --git a/egs/chime6/s5_track2/utils b/egs/chime6/s5_track2/utils new file mode 120000 index 000000000..a3279dc86 --- /dev/null +++ b/egs/chime6/s5_track2/utils @@ -0,0 +1 @@ +../../wsj/s5/utils/ \ No newline at end of file diff --git a/egs/cmu_cslu_kids/README b/egs/cmu_cslu_kids/README new file mode 100644 index 000000000..0b8512e24 --- /dev/null +++ b/egs/cmu_cslu_kids/README @@ -0,0 +1,21 @@ +This is an ASR recipe for children speech using cmu_kids and cslu_kids. +Both of the corpora can be found on LDC: + - cmu_kids : https://catalog.ldc.upenn.edu/LDC97S63 + - cslu_kids: https://catalog.ldc.upenn.edu/LDC2007S18 + +To run this recipe, you'll need a copy of both corpora: + ./run.sh --cmu_kids --cslu_kids + +By default, this recipe will download an LM pretrained on LibriSpeech from +lm_url=www.openslr.org/resources/11. If you already have a copy of this LM +and do not wish to redownload, you can specify the LM path using the --lm_src option: + ./run.sh --cmu_kids --cslu_kids \ + --lm_src + +This recipe will also download and clean CMU_Dict by default. If you have a clean copy +already, or wish to use your own dictionary, simply copy your version of the dict to + data/local/dict + +To run extra features for triphone models or VLTN, set the following options true: + ./run.sh --cmu_kids --cslu_kids \ + --vtln true --extra_features true diff --git a/egs/cmu_cslu_kids/s5/cmd.sh b/egs/cmu_cslu_kids/s5/cmd.sh new file mode 100644 index 000000000..179307556 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/cmd.sh @@ -0,0 +1,23 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. + +export train_cmd=queue.pl +export decode_cmd="queue.pl --mem 2G" +# the use of cuda_cmd is deprecated, used only in 'nnet1', +export cuda_cmd="queue.pl --gpu 1" + +if [[ "$(hostname -f)" == "*.fit.vutbr.cz" ]]; then + queue_conf=$HOME/queue_conf/default.conf # see example /homes/kazi/iveselyk/queue_conf/default.conf, + export train_cmd="queue.pl --config $queue_conf --mem 2G --matylda 0.2" + export decode_cmd="queue.pl --config $queue_conf --mem 3G --matylda 0.1" + export cuda_cmd="queue.pl --config $queue_conf --gpu 1 --mem 10G --tmp 40G" +fi diff --git a/egs/cmu_cslu_kids/s5/conf/decode.config b/egs/cmu_cslu_kids/s5/conf/decode.config new file mode 100644 index 000000000..10b0eee90 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/conf/decode.config @@ -0,0 +1,4 @@ +# Use wider-than-normal decoding beams for RM. +first_beam=16.0 +beam=20.0 +lattice_beam=10.0 diff --git a/egs/cmu_cslu_kids/s5/conf/decode_dnn.config b/egs/cmu_cslu_kids/s5/conf/decode_dnn.config new file mode 100644 index 000000000..e7cfca747 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/conf/decode_dnn.config @@ -0,0 +1,8 @@ +# In RM, the optimal decode LMWT is in range 2..5, which is different from usual 10..15 +# (it is caused by using simple rule-based LM, instead of n-gram LM), +scoring_opts="--min-lmwt 2 --max-lmwt 10" +# Still, it is better to use --acwt 0.1, both for decoding and sMBR, +acwt=0.1 +# For this small task we can afford to have large beams, +beam=30.0 # beam for decoding. Was 13.0 in the scripts. +lattice_beam=18.0 # this has most effect on size of the lattices. diff --git a/egs/cmu_cslu_kids/s5/conf/mfcc.conf b/egs/cmu_cslu_kids/s5/conf/mfcc.conf new file mode 100644 index 000000000..6bbcb7631 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/conf/mfcc.conf @@ -0,0 +1,2 @@ +--use-energy=false # only non-default option. +--allow_downsample=true diff --git a/egs/cmu_cslu_kids/s5/conf/mfcc_hires.conf b/egs/cmu_cslu_kids/s5/conf/mfcc_hires.conf new file mode 100644 index 000000000..40f95e970 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/conf/mfcc_hires.conf @@ -0,0 +1,11 @@ +# config for high-resolution MFCC features, intended for neural network training +# Note: we keep all cepstra, so it has the same info as filterbank features, +# but MFCC is more easily compressible (because less correlated) which is why +# we prefer this method. +--use-energy=false # use average of log energy, not energy. +--num-mel-bins=40 # similar to Google's setup. +--num-ceps=40 # there is no dimensionality reduction. +--low-freq=20 # low cutoff frequency for mel bins... this is high-bandwidth data, so + # there might be some information at the low end. +--high-freq=-400 # high cutoff frequently, relative to Nyquist of 8000 (=7600) +--allow-downsample=true diff --git a/egs/cmu_cslu_kids/s5/conf/online_cmvn.conf b/egs/cmu_cslu_kids/s5/conf/online_cmvn.conf new file mode 100644 index 000000000..7748a4a4d --- /dev/null +++ b/egs/cmu_cslu_kids/s5/conf/online_cmvn.conf @@ -0,0 +1 @@ +# configuration file for apply-cmvn-online, used in the script ../local/run_online_decoding.sh diff --git a/egs/cmu_cslu_kids/s5/conf/plp.conf b/egs/cmu_cslu_kids/s5/conf/plp.conf new file mode 100644 index 000000000..e7e8a9e14 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/conf/plp.conf @@ -0,0 +1,2 @@ +# No non-default options for now. +--allow_downsample=true diff --git a/egs/cmu_cslu_kids/s5/local/chain/compare_wer.sh b/egs/cmu_cslu_kids/s5/local/chain/compare_wer.sh new file mode 100755 index 000000000..8ee5db232 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/chain/compare_wer.sh @@ -0,0 +1,137 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. local/chain/compare_wer.sh exp/chain/tdnn_{c,d}_sp +# For use with discriminatively trained systems you specify the epochs after a colon: +# for instance, +# local/chain/compare_wer.sh exp/chain/tdnn_c_sp exp/chain/tdnn_c_sp_smbr:{1,2,3} + + +if [ $# == 0 ]; then + echo "Usage: $0: [--looped] [--online] [ ... ]" + echo "e.g.: $0 exp/chain/tdnn_{b,c}_sp" + echo "or (with epoch numbers for discriminative training):" + echo "$0 exp/chain/tdnn_b_sp_disc:{1,2,3}" + exit 1 +fi + +echo "# $0 $*" + +include_looped=false +if [ "$1" == "--looped" ]; then + include_looped=true + shift +fi +include_online=false +if [ "$1" == "--online" ]; then + include_online=true + shift +fi + + +used_epochs=false + +# this function set_names is used to separate the epoch-related parts of the name +# [for discriminative training] and the regular parts of the name. +# If called with a colon-free directory name, like: +# set_names exp/chain/tdnn_lstm1e_sp_bi_smbr +# it will set dir=exp/chain/tdnn_lstm1e_sp_bi_smbr and epoch_infix="" +# If called with something like: +# set_names exp/chain/tdnn_d_sp_smbr:3 +# it will set dir=exp/chain/tdnn_d_sp_smbr and epoch_infix="_epoch3" + + +set_names() { + if [ $# != 1 ]; then + echo "compare_wer_general.sh: internal error" + exit 1 # exit the program + fi + dirname=$(echo $1 | cut -d: -f1) + epoch=$(echo $1 | cut -s -d: -f2) + if [ -z $epoch ]; then + epoch_infix="" + else + used_epochs=true + epoch_infix=_epoch${epoch} + fi +} + + + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +strings=( + "#WER dev_clean_2 (tgsmall) " + "#WER dev_clean_2 (tglarge) ") + +for n in 0 1; do + echo -n "${strings[$n]}" + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + decode_names=(tgsmall_dev_clean_2 tglarge_dev_clean_2) + + wer=$(cat $dirname/decode_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + if $include_looped; then + echo -n "# [looped:] " + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + wer=$(cat $dirname/decode_looped_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + fi + if $include_online; then + echo -n "# [online:] " + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + wer=$(cat ${dirname}_online/decode_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + fi +done + + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. +fi + + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -v xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final train prob (xent)" +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob (xent)" +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.final.log | grep -w xent | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Num-params " +for x in $*; do + printf "% 10s" $(grep num-parameters $x/log/progress.1.log | awk '{print $2}') +done +echo diff --git a/egs/cmu_cslu_kids/s5/local/chain/run_tdnnf.sh b/egs/cmu_cslu_kids/s5/local/chain/run_tdnnf.sh new file mode 120000 index 000000000..344993628 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/chain/run_tdnnf.sh @@ -0,0 +1 @@ +tuning/run_tdnn_1a.sh \ No newline at end of file diff --git a/egs/cmu_cslu_kids/s5/local/chain/tdnnf_decode.sh b/egs/cmu_cslu_kids/s5/local/chain/tdnnf_decode.sh new file mode 100755 index 000000000..8d1241935 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/chain/tdnnf_decode.sh @@ -0,0 +1,82 @@ +#! /bin/bash + +# Copyright Johns Hopkins University +# 2019 Fei Wu + +# Decode on new data set using trained model. +# The data directory should be prepared in kaldi style. +# Usage: +# ./local/chain/tdnnF_decode.sh --data_src + +set -euo pipefail +echo "$0 $@" + +stage=0 +decode_nj=10 +data_src= +affix= +tree_affix= +nnet3_affix= + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat </dev/null || true + + ( + nspk=$(wc -l <$data_hires/spk2utt) + steps/nnet3/decode.sh \ + --acwt 1.0 --post-decode-acwt 10.0 \ + --frames-per-chunk $frames_per_chunk \ + --nj $nspk --cmd "$decode_cmd" --num-threads 4 \ + --online-ivector-dir $ivect_dir \ + $tree_dir/graph_tgsmall $data_hires ${dir}/decode_tgsmall_$data_name || exit 1 + + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + data/lang_test_{tgsmall,tglarge} \ + $data_hires ${dir}/decode_{tgsmall,tglarge}_$data_name || exit 1 + ) || touch $dir/.error & + + wait + [ -f $dir/.error ] && echo "$0: there was a problem while decoding" && exit 1 +fi + diff --git a/egs/cmu_cslu_kids/s5/local/chain/tuning/run_tdnn_1a.sh b/egs/cmu_cslu_kids/s5/local/chain/tuning/run_tdnn_1a.sh new file mode 100755 index 000000000..51e0123d0 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/chain/tuning/run_tdnn_1a.sh @@ -0,0 +1,279 @@ +#!/bin/bash + +# Copyright 2017-2018 Johns Hopkins University (author: Daniel Povey) +# 2017-2018 Yiming Wang +# 2019 Fei Wu + +# Based on material recipe for low-resource languages +# Factored TDNN with skip connectiong and splicing (two bottle neck layers) + +# WER results on dev +# Model LM Corpus WER(%) +# tdnn_1a tg_large Combined 11.72 +# tdnn_1a tg_small Combined 13.61 +# tdnn_1a tg_large CMU_Kids 17.26 +# tdnn_1a tg_small CMU_Kids 26.43 +# tdnn_1a tg_large CSLU_Kids 10.80 +# tdnn_1a tg_small CSLU_Kids 12.50 + +# steps/info/chain_dir_info.pl exp/chain/tdnn1a_sp +# exp/chain/tdnn1a_sp/: num-iters=342 nj=2..5 num-params=17.9M dim=40+100->3192 combine=-0.042->-0.041 (over 8) xent:train/valid[227,341,final]=(-0.451,-0.363,-0.346/-0.524,-0.466,-0.434) logprob:train/valid[227,341,final]=(-0.047,-0.043,-0.042/-0.058,-0.056,-0.054) + +set -euo pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +nj=10 +train_set=train +test_sets="test" +gmm=tri3 +nnet3_affix= + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +affix=1a +tree_affix= +train_stage=-10 +get_egs_stage=-10 +decode_iter= + +# training chunk-options +chunk_width=140,100,160 +dropout_schedule='0,0@0.20,0.3@0.50,0' +common_egs_dir= +xent_regularize=0.1 + +# training options +srand=0 +remove_egs=true +reporting_email= + + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat <$lang/topo + fi +fi + +if [ $stage -le 8 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj 75 --cmd "$train_cmd" ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + + +if [ $stage -le 10 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor 3 \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 3500 ${lores_train_data_dir} \ + $lang $ali_dir $tree_dir +fi + +if [ $stage -le 11 ]; then + mkdir -p $dir + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + opts="l2-regularize=0.004 dropout-proportion=0.0 dropout-per-dim=true dropout-per-dim-continuous=true" + linear_opts="orthonormal-constraint=-1.0 l2-regularize=0.004" + output_opts="l2-regularize=0.002" + + mkdir -p $dir/configs + + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-dropout-layer name=tdnn1 $opts dim=1024 + linear-component name=tdnn2l0 dim=256 $linear_opts input=Append(-1,0) + linear-component name=tdnn2l dim=256 $linear_opts input=Append(-1,0) + relu-batchnorm-dropout-layer name=tdnn2 $opts input=Append(0,1) dim=1024 + linear-component name=tdnn3l dim=256 $linear_opts input=Append(-1,0) + relu-batchnorm-dropout-layer name=tdnn3 $opts dim=1024 input=Append(0,1) + linear-component name=tdnn4l0 dim=256 $linear_opts input=Append(-1,0) + linear-component name=tdnn4l dim=256 $linear_opts input=Append(0,1) + relu-batchnorm-dropout-layer name=tdnn4 $opts input=Append(0,1) dim=1024 + linear-component name=tdnn5l dim=256 $linear_opts + relu-batchnorm-dropout-layer name=tdnn5 $opts dim=1024 input=Append(0, tdnn3l) + linear-component name=tdnn6l0 dim=256 $linear_opts input=Append(-3,0) + linear-component name=tdnn6l dim=256 $linear_opts input=Append(-3,0) + relu-batchnorm-dropout-layer name=tdnn6 $opts input=Append(0,3) dim=1280 + linear-component name=tdnn7l0 dim=256 $linear_opts input=Append(-3,0) + linear-component name=tdnn7l dim=256 $linear_opts input=Append(0,3) + relu-batchnorm-dropout-layer name=tdnn7 $opts input=Append(0,3,tdnn6l,tdnn4l,tdnn2l) dim=1024 + linear-component name=tdnn8l0 dim=256 $linear_opts input=Append(-3,0) + linear-component name=tdnn8l dim=256 $linear_opts input=Append(0,3) + relu-batchnorm-dropout-layer name=tdnn8 $opts input=Append(0,3) dim=1280 + linear-component name=tdnn9l0 dim=256 $linear_opts input=Append(-3,0) + linear-component name=tdnn9l dim=256 $linear_opts input=Append(-3,0) + relu-batchnorm-dropout-layer name=tdnn9 $opts input=Append(0,3,tdnn8l,tdnn6l,tdnn5l) dim=1024 + linear-component name=tdnn10l0 dim=256 $linear_opts input=Append(-3,0) + linear-component name=tdnn10l dim=256 $linear_opts input=Append(0,3) + relu-batchnorm-dropout-layer name=tdnn10 $opts input=Append(0,3) dim=1280 + linear-component name=tdnn11l0 dim=256 $linear_opts input=Append(-3,0) + linear-component name=tdnn11l dim=256 $linear_opts input=Append(-3,0) + relu-batchnorm-dropout-layer name=tdnn11 $opts input=Append(0,3,tdnn10l,tdnn9l,tdnn7l) dim=1024 + linear-component name=prefinal-l dim=256 $linear_opts + + relu-batchnorm-layer name=prefinal-chain input=prefinal-l $opts dim=1280 + linear-component name=prefinal-chain-l dim=256 $linear_opts + batchnorm-component name=prefinal-chain-batchnorm + output-layer name=output include-log-softmax=false dim=$num_targets $output_opts + + relu-batchnorm-layer name=prefinal-xent input=prefinal-l $opts dim=1280 + linear-component name=prefinal-xent-l dim=256 $linear_opts + batchnorm-component name=prefinal-xent-batchnorm + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor $output_opts + +EOF + + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ + +fi + + +if [ $stage -le 12 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/mini_librispeech-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/chain/train.py --stage=$train_stage \ + --cmd="$decode_cmd" \ + --feat.online-ivector-dir=$train_ivector_dir \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --chain.xent-regularize $xent_regularize \ + --chain.leaky-hmm-coefficient=0.1 \ + --chain.l2-regularize=0.0 \ + --chain.apply-deriv-weights=false \ + --chain.lm-opts="--num-extra-lm-states=2000" \ + --trainer.dropout-schedule $dropout_schedule \ + --trainer.add-option="--optimization.memory-compression-level=2" \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=20 \ + --trainer.frames-per-iter=3000000 \ + --trainer.optimization.num-jobs-initial=2 \ + --trainer.optimization.num-jobs-final=5 \ + --trainer.optimization.initial-effective-lrate=0.002 \ + --trainer.optimization.final-effective-lrate=0.0002 \ + --trainer.num-chunk-per-minibatch=128,64 \ + --egs.chunk-width=$chunk_width \ + --egs.dir="$common_egs_dir" \ + --egs.opts="--frames-overlap-per-eg 0" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --tree-dir=$tree_dir \ + --lat-dir=$lat_dir \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 13 ]; then + # Note: it's not important to give mkgraph.sh the lang directory with the + # matched topology (since it gets the topology file from the model). + utils/mkgraph.sh \ + --self-loop-scale 1.0 data/lang_test_tgsmall \ + $tree_dir $tree_dir/graph_tgsmall || exit 1; +fi + +if [ $stage -le 14 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + nspk=$(wc -l tmp + cut -f 3- < tmp > out + + tr '[:lower:]' '[:upper:]' < out > tmp + tr -d '[:cntrl:]' < tmp > out + sent=$( out + tr '[:lower:]' '[:upper:]' < tmp > out + trans=$(> $data/$target/utt2spk + echo "$uttID $KALDI_ROOT/tools/sph2pipe_v2.5/sph2pipe -f wav -p -c 1 $utt|" >> $data/$target/wav.scp + echo "$spkID f" >> $data/$target/spk2gender + echo "$uttID $sent" >> $data/$target/text + fi + done + fi + fi +done + +for d in $data/train $data/test; do + utils/utt2spk_to_spk2utt.pl $d/utt2spk > $d/spk2utt + utils/fix_data_dir.sh $d +done + +printf "\t total: %s; train: %s; test: %s.\n" "$total_cnt" "$train_cnt" "$test_cnt" +rm -f out tmp + +# Optional +# Get data duration, just for book keeping +# for data in $data/train $data/test; do +# ./local/data_duration.sh $data +# done +# + diff --git a/egs/cmu_cslu_kids/s5/local/cslu_aud_prep.sh b/egs/cmu_cslu_kids/s5/local/cslu_aud_prep.sh new file mode 100755 index 000000000..735f87eca --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/cslu_aud_prep.sh @@ -0,0 +1,43 @@ +#/bin/bash + +# Copyright Johns Hopkins University +# 2019 Fei Wu + +# Called by local/cslu_DataPrep.shi + +Assignment() +{ + rnd=$((1+RANDOM % 100)) + if [ $rnd -le $test_percentage ]; then + target="test" + else + target="train" + fi +} +audio= +test_percentage=30 # Percent of data reserved as test set +debug=debug/cslu_dataprep_debug +data=data/data_cslu +. ./utils/parse_options.sh + +uttID=$(basename $audio) +uttID=${uttID%'.wav'} +sentID=${uttID: -3} +spkID=${uttID%$sentID} +sentID=${sentID%"0"} +sentID=$(echo "$sentID" | tr '[:lower:]' '[:upper:]' ) + +line=$(grep $sentID cslu/docs/all.map) + +if [ -z "$line" ]; then # Can't map utterance to transcript + echo $audio $sentID >> $debug +else + txt=$(echo $line | grep -oP '"\K.*?(?=")') + cap_txt=${txt^^} + Assignment + echo "$uttID $cap_txt" >> $data/$target/text + echo "$uttID $spkID" >> $data/$target/utt2spk + echo "$spkID f" >> $data/$target/spk2gender + echo "$uttID $audio" >> $data/$target/wav.scp +fi + diff --git a/egs/cmu_cslu_kids/s5/local/cslu_prepare_data.sh b/egs/cmu_cslu_kids/s5/local/cslu_prepare_data.sh new file mode 100755 index 000000000..621179079 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/cslu_prepare_data.sh @@ -0,0 +1,49 @@ +#! /bin/bash + +# Copyright Johns Hopkins University +# 2019 Fei Wu + +# Prepares cslu_kids +# Should be run from egs/cmu_csli_kids + +set -e +Looper() +{ + # echo "Looping through $1" + for f in $1/*; do + if [ -d $f ]; then + Looper $f + else + ./local/cslu_aud_prep.sh --data $data --audio $f + fi + done +} + +data=data/data_cslu +corpus=cslu +. ./utils/parse_options.sh + +rm -f debug/cslu_dataprep_debug +mkdir -p debug +# File check, remove previous data and features files +for d in $data/test $data/train; do + mkdir -p $d + ./local/file_check.sh $d +done + +echo "Preparing cslu_kids..." +Looper $corpus/speech/scripted + +for d in $data/test $data/train; do + ./utils/utt2spk_to_spk2utt.pl $d + ./utils/fix_data_dir.sh $d +done +if [ -f debug/cslu_dataprep_debug ]; then + echo "Missing transcripts for some utterances. See cslu_dataprep_debug" +fi + +# Optional +# Get data duration, just for book keeping +# for data in data/data_cslu/test data/data_cslu/train; do +# ./local/data_duration.sh $data +# done diff --git a/egs/cmu_cslu_kids/s5/local/data_duration.sh b/egs/cmu_cslu_kids/s5/local/data_duration.sh new file mode 100755 index 000000000..e838e365e --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/data_duration.sh @@ -0,0 +1,19 @@ +#! /bin/bash + +# Copyright Johns Hopkins University +# 2019 Fei Wu + +# Get duration of the utterance given data dir +set -eu +echo $0 $@ + +data_dir=$1 +mkdir -p duration + +./utils/data/get_utt2dur.sh $data_dir + +echo "$data_dir" +python local/sum_duration.py $data_dir/utt2dur +echo "" + + diff --git a/egs/cmu_cslu_kids/s5/local/download_cmu_dict.sh b/egs/cmu_cslu_kids/s5/local/download_cmu_dict.sh new file mode 100755 index 000000000..0248dd0ca --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/download_cmu_dict.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright 2019 Fei Wu +set -eu +# Adapted from the local/prepare_dict script in +# the librispeech recipe. Download and prepare CMU_dict. +# For childresn speech ASR tasks, since the vocabulary in cmu_kids and +# cslu_kids is relatively easy comparing to librispeech, we use only the +# CMU_dict, and do not handle OOV with G2P. +# Should be run from egs/cmu_cslu_kids. +# Usage: +# local/download_cmu_dict.sh --dict_dir + +dict_dir=data/local/dict +OOV="" + +. ./utils/parse_options.sh || exit 1; +. ./path.sh || exit 1 + +if [ ! -d $dict_dir ]; then + echo "Downloading and preparing CMU dict" + svn co -r 12440 https://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict $dict_dir/raw_dict || exit 1; + + echo "Removing the pronunciation variant markers ..." + grep -v ';;;' $dict_dir/raw_dict/cmudict.0.7a | \ + perl -ane 'if(!m:^;;;:){ s:(\S+)\(\d+\) :$1 :; print; }' | \ + sort -u > $dict_dir/lexicon.txt || exit 1; + + tr -d '\r' < $dict_dir/raw_dict/cmudict.0.7a.symbols > $dict_dir/nonsilence_phones.txt + + echo "$OOV SIL" >> $dict_dir/lexicon.txt + + echo "SIL" > $dict_dir/silence_phones.txt + echo "SPN" >> $dict_dir/silence_phones.txt + echo "SIL" > $dict_dir/optional_silence.txt + + rm -rf $dict_dir/raw_dict +fi diff --git a/egs/cmu_cslu_kids/s5/local/download_lm.sh b/egs/cmu_cslu_kids/s5/local/download_lm.sh new file mode 100755 index 000000000..382f313df --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/download_lm.sh @@ -0,0 +1,76 @@ +#!/bin/bash + +# Copyright 2014 Vassil Panayotov +# Apache 2.0 + +if [ $# -ne "2" ]; then + echo "Usage: $0 " + echo "e.g.: $0 http://www.openslr.org/resources/11 data/local/lm" + exit 1 +fi + +base_url=$1 +dst_dir=$2 + +# given a filename returns the corresponding file size in bytes +# The switch cases below can be autogenerated by entering the data directory and running: +# for f in *; do echo "\"$f\") echo \"$(du -b $f | awk '{print $1}')\";;"; done +function filesize() { + case $1 in + "3-gram.arpa.gz") echo "759636181";; + "3-gram.pruned.1e-7.arpa.gz") echo "34094057";; + "3-gram.pruned.3e-7.arpa.gz") echo "13654242";; + "4-gram.arpa.gz") echo "1355172078";; + "g2p-model-5") echo "20098243";; + "librispeech-lexicon.txt") echo "5627653";; + "librispeech-lm-corpus.tgz") echo "1803499244";; + "librispeech-lm-norm.txt.gz") echo "1507274412";; + "librispeech-vocab.txt") echo "1737588";; + *) echo "";; + esac +} + +function check_and_download () { + [[ $# -eq 1 ]] || { echo "check_and_download() expects exactly one argument!"; return 1; } + fname=$1 + echo "Downloading file '$fname' into '$dst_dir'..." + expect_size="$(filesize $fname)" + [[ ! -z "$expect_size" ]] || { echo "Unknown file size for '$fname'"; return 1; } + if [[ -s $dst_dir/$fname ]]; then + # In the following statement, the first version works on linux, and the part + # after '||' works on Linux. + f=$dst_dir/$fname + fsize=$(set -o pipefail; du -b $f 2>/dev/null | awk '{print $1}' || stat '-f %z' $f) + if [[ "$fsize" -eq "$expect_size" ]]; then + echo "'$fname' already exists and appears to be complete" + return 0 + else + echo "WARNING: '$fname' exists, but the size is wrong - re-downloading ..." + fi + fi + wget --no-check-certificate -O $dst_dir/$fname $base_url/$fname || { + echo "Error while trying to download $fname!" + return 1 + } + f=$dst_dir/$fname + # In the following statement, the first version works on linux, and the part after '||' + # works on Linux. + fsize=$(set -o pipefail; du -b $f 2>/dev/null | awk '{print $1}' || stat '-f %z' $f) + [[ "$fsize" -eq "$expect_size" ]] || { echo "$fname: file size mismatch!"; return 1; } + return 0 +} + +mkdir -p $dst_dir + +for f in 3-gram.arpa.gz 3-gram.pruned.1e-7.arpa.gz 3-gram.pruned.3e-7.arpa.gz 4-gram.arpa.gz \ + g2p-model-5 librispeech-lm-corpus.tgz librispeech-vocab.txt librispeech-lexicon.txt; do + check_and_download $f || exit 1 +done + +cd $dst_dir +ln -sf 3-gram.pruned.1e-7.arpa.gz lm_tgmed.arpa.gz +ln -sf 3-gram.pruned.3e-7.arpa.gz lm_tgsmall.arpa.gz +ln -sf 3-gram.arpa.gz lm_tglarge.arpa.gz +ln -sf 4-gram.arpa.gz lm_fglarge.arpa.gz + +exit 0 diff --git a/egs/cmu_cslu_kids/s5/local/file_check.sh b/egs/cmu_cslu_kids/s5/local/file_check.sh new file mode 100755 index 000000000..859f22805 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/file_check.sh @@ -0,0 +1,17 @@ +#! /bin/bash + +# Copyright Johns Hopkins University +# 2019 Fei Wu + + +printf "\t File Check in folder: %s.\n" "$1" + +WavScp="$1/wav.scp" +Text="$1/text" +Utt2Spk="$1/utt2spk" +Gend="$1/utt2gender" +Spk2Utt="$1/spk2utt" +rm -f $WavScp $Text $Utt2Spk $Gend $Spk2Utt + + + diff --git a/egs/cmu_cslu_kids/s5/local/format_lms.sh b/egs/cmu_cslu_kids/s5/local/format_lms.sh new file mode 100755 index 000000000..b530f61d2 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/format_lms.sh @@ -0,0 +1,60 @@ +#!/bin/bash + +# Copyright 2014 Vassil Panayotov +# Apache 2.0 + +# Prepares the test time language model(G) transducers +# (adapted from wsj/s5/local/wsj_format_data.sh) + +. ./path.sh || exit 1; + +# begin configuration section +src_dir=data/lang +# end configuration section + +. utils/parse_options.sh || exit 1; + +set -e + +if [ $# -ne 1 ]; then + echo "Usage: $0 " + echo "e.g.: $0 /export/a15/vpanayotov/data/lm" + echo ", where:" + echo " is the directory in which the language model is stored/downloaded" + echo "Options:" + echo " --src-dir

# source lang directory, default data/lang" + exit 1 +fi + +lm_dir=$1 + +if [ ! -d $lm_dir ]; then + echo "$0: expected source LM directory $lm_dir to exist" + exit 1; +fi +if [ ! -f $src_dir/words.txt ]; then + echo "$0: expected $src_dir/words.txt to exist." + exit 1; +fi + + +tmpdir=data/local/lm_tmp.$$ +trap "rm -r $tmpdir" EXIT + +mkdir -p $tmpdir + +for lm_suffix in tgsmall tgmed; do + # tglarge is prepared by a separate command, called from run.sh; we don't + # want to compile G.fst for tglarge, as it takes a while. + test=${src_dir}_test_${lm_suffix} + mkdir -p $test + cp -r ${src_dir}/* $test + gunzip -c $lm_dir/lm_${lm_suffix}.arpa.gz | \ + arpa2fst --disambig-symbol=#0 \ + --read-symbol-table=$test/words.txt - $test/G.fst + utils/validate_lang.pl --skip-determinization-check $test || exit 1; +done + +echo "Succeeded in formatting data." + +exit 0 diff --git a/egs/cmu_cslu_kids/s5/local/make_lm.pl b/egs/cmu_cslu_kids/s5/local/make_lm.pl new file mode 100755 index 000000000..80eea5a61 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/make_lm.pl @@ -0,0 +1,119 @@ +#!/usr/bin/env perl + +# Copyright 2010-2011 Yanmin Qian Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# This file takes as input the file wp_gram.txt that comes with the RM +# distribution, and creates the language model as an acceptor in FST form. + +# make_rm_lm.pl wp_gram.txt > G.txt + +if (@ARGV != 1) { + print "usage: make_rm_lm.pl wp_gram.txt > G.txt\n"; + exit(0); +} +unless (open(IN_FILE, "@ARGV[0]")) { + die ("can't open @ARGV[0]"); +} + + +$flag = 0; +$count_wrd = 0; +$cnt_ends = 0; +$init = ""; + +while ($line = ) +{ + chop($line); # Return the last char + + $line =~ s/ //g; # Selete all spaces + + if(($line =~ /^>/)) # If line has ">" + { + if($flag == 0) # Flip flag + { + $flag = 1; + } + $line =~ s/>//g; # Delete ">" + $hashcnt{$init} = $i; + $init = $line; + $i = 0; + $count_wrd++; + @LineArray[$count_wrd - 1] = $init; + $hashwrd{$init} = 0; + } + elsif($flag != 0) + { + + $hash{$init}[$i] = $line; + $i++; + if($line =~ /SENTENCE-END/) + { + $cnt_ends++; + } + } + else + {} +} + +$hashcnt{$init} = $i; + +$num = 0; +$weight = 0; +$init_wrd = "SENTENCE-END"; +$hashwrd{$init_wrd} = @LineArray; +for($i = 0; $i < $hashcnt{$init_wrd}; $i++) +{ + $weight = -log(1/$hashcnt{$init_wrd}); + $hashwrd{$hash{$init_wrd}[$i]} = $i + 1; + print "0 $hashwrd{$hash{$init_wrd}[$i]} $hash{$init_wrd}[$i] $hash{$init_wrd}[$i] $weight\n"; +} +$num = $i; + +for($i = 0; $i < @LineArray; $i++) +{ + if(@LineArray[$i] eq 'SENTENCE-END') + {} + else + { + if($hashwrd{@LineArray[$i]} == 0) + { + $num++; + $hashwrd{@LineArray[$i]} = $num; + } + for($j = 0; $j < $hashcnt{@LineArray[$i]}; $j++) + { + $weight = -log(1/$hashcnt{@LineArray[$i]}); + if($hashwrd{$hash{@LineArray[$i]}[$j]} == 0) + { + $num++; + $hashwrd{$hash{@LineArray[$i]}[$j]} = $num; + } + if($hash{@LineArray[$i]}[$j] eq 'SENTENCE-END') + { + print "$hashwrd{@LineArray[$i]} $hashwrd{$hash{@LineArray[$i]}[$j]} $weight\n" + } + else + { + print "$hashwrd{@LineArray[$i]} $hashwrd{$hash{@LineArray[$i]}[$j]} $hash{@LineArray[$i]}[$j] $hash{@LineArray[$i]}[$j] $weight\n"; + } + } + } +} + +print "$hashwrd{$init_wrd} 0\n"; +close(IN_FILE); + + diff --git a/egs/cmu_cslu_kids/s5/local/nnet3/compare_wer.sh b/egs/cmu_cslu_kids/s5/local/nnet3/compare_wer.sh new file mode 100755 index 000000000..095e85cc3 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/nnet3/compare_wer.sh @@ -0,0 +1,132 @@ +#!/bin/bash + +# this script is used for comparing decoding results between systems. +# e.g. local/chain/compare_wer.sh exp/chain/tdnn_{c,d}_sp +# For use with discriminatively trained systems you specify the epochs after a colon: +# for instance, +# local/chain/compare_wer.sh exp/chain/tdnn_c_sp exp/chain/tdnn_c_sp_smbr:{1,2,3} + + +if [ $# == 0 ]; then + echo "Usage: $0: [--looped] [--online] [ ... ]" + echo "e.g.: $0 exp/chain/tdnn_{b,c}_sp" + echo "or (with epoch numbers for discriminative training):" + echo "$0 exp/chain/tdnn_b_sp_disc:{1,2,3}" + exit 1 +fi + +echo "# $0 $*" + +include_looped=false +if [ "$1" == "--looped" ]; then + include_looped=true + shift +fi +include_online=false +if [ "$1" == "--online" ]; then + include_online=true + shift +fi + + +used_epochs=false + +# this function set_names is used to separate the epoch-related parts of the name +# [for discriminative training] and the regular parts of the name. +# If called with a colon-free directory name, like: +# set_names exp/chain/tdnn_lstm1e_sp_bi_smbr +# it will set dir=exp/chain/tdnn_lstm1e_sp_bi_smbr and epoch_infix="" +# If called with something like: +# set_names exp/chain/tdnn_d_sp_smbr:3 +# it will set dir=exp/chain/tdnn_d_sp_smbr and epoch_infix="_epoch3" + + +set_names() { + if [ $# != 1 ]; then + echo "compare_wer_general.sh: internal error" + exit 1 # exit the program + fi + dirname=$(echo $1 | cut -d: -f1) + epoch=$(echo $1 | cut -s -d: -f2) + if [ -z $epoch ]; then + epoch_infix="" + else + used_epochs=true + epoch_infix=_epoch${epoch} + fi +} + + + +echo -n "# System " +for x in $*; do printf "% 10s" " $(basename $x)"; done +echo + +strings=( + "#WER dev_clean_2 (tgsmall) " + "#WER dev_clean_2 (tglarge) ") + +for n in 0 1; do + echo -n "${strings[$n]}" + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + decode_names=(tgsmall_dev_clean_2 tglarge_dev_clean_2) + + wer=$(cat $dirname/decode_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + if $include_looped; then + echo -n "# [looped:] " + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + wer=$(cat $dirname/decode_looped_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + fi + if $include_online; then + echo -n "# [online:] " + for x in $*; do + set_names $x # sets $dirname and $epoch_infix + wer=$(cat ${dirname}_online/decode_${decode_names[$n]}/wer_* | utils/best_wer.sh | awk '{print $2}') + printf "% 10s" $wer + done + echo + fi +done + + +if $used_epochs; then + exit 0; # the diagnostics aren't comparable between regular and discriminatively trained systems. +fi + +echo -n "# Final train prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.{final,combined}.log 2>/dev/null | grep log-like | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid prob " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.{final,combined}.log 2>/dev/null | grep log-like | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final train acc " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_train.{final,combined}.log 2>/dev/null | grep accuracy | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo -n "# Final valid acc " +for x in $*; do + prob=$(grep Overall $x/log/compute_prob_valid.{final,combined}.log 2>/dev/null | grep accuracy | awk '{printf("%.4f", $8)}') + printf "% 10s" $prob +done +echo + +echo diff --git a/egs/cmu_cslu_kids/s5/local/nnet3/run_ivector_common.sh b/egs/cmu_cslu_kids/s5/local/nnet3/run_ivector_common.sh new file mode 100755 index 000000000..c695f2c9f --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/nnet3/run_ivector_common.sh @@ -0,0 +1,148 @@ +#!/bin/bash + +set -euo pipefail + +# This script is called from local/nnet3/run_tdnn.sh and +# local/chain/run_tdnn.sh (and may eventually be called by more +# scripts). It contains the common feature preparation and +# iVector-related parts of the script. See those scripts for examples +# of usage. + +stage=0 +train_set=train +test_sets="test" +gmm=tri3b + +nnet3_affix= + +. ./cmd.sh +. ./path.sh +. utils/parse_options.sh + +gmm_dir=exp/${gmm} +ali_dir=exp/${gmm}_ali_${train_set}_sp + +for f in data/${train_set}/feats.scp ${gmm_dir}/final.mdl; do + if [ ! -f $f ]; then + echo "$0: expected file $f to exist" + exit 1 + fi +done + +if [ $stage -le 1 ]; then + # Although the nnet will be trained by high resolution data, we still have to + # perturb the normal data to get the alignment _sp stands for speed-perturbed + echo "$0: preparing directory for low-resolution speed-perturbed data (for alignment)" + utils/data/perturb_data_dir_speed_3way.sh data/${train_set} data/${train_set}_sp + echo "$0: making MFCC features for low-resolution speed-perturbed data" + steps/make_mfcc.sh --cmd "$train_cmd" --nj 10 data/${train_set}_sp || exit 1; + steps/compute_cmvn_stats.sh data/${train_set}_sp || exit 1; + utils/fix_data_dir.sh data/${train_set}_sp +fi + +if [ $stage -le 2 ]; then + echo "$0: aligning with the perturbed low-resolution data" + steps/align_fmllr.sh --nj 20 --cmd "$train_cmd" \ + data/${train_set}_sp data/lang $gmm_dir $ali_dir || exit 1 +fi + +if [ $stage -le 3 ]; then + # Create high-resolution MFCC features (with 40 cepstra instead of 13). + # this shows how you can split across multiple file-systems. + echo "$0: creating high-resolution MFCC features" + mfccdir=data/${train_set}_sp_hires/data + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl /export/fs0{1,2}/$USER/kaldi-data/mfcc/mini_librispeech-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage + fi + + for datadir in ${train_set}_sp ${test_sets}; do + utils/copy_data_dir.sh data/$datadir data/${datadir}_hires + done + + # do volume-perturbation on the training data prior to extracting hires + # features; this helps make trained nnets more invariant to test data volume. + utils/data/perturb_data_dir_volume.sh data/${train_set}_sp_hires || exit 1; + + for datadir in ${train_set}_sp ${test_sets}; do + steps/make_mfcc.sh --nj 10 --mfcc-config conf/mfcc_hires.conf \ + --cmd "$train_cmd" data/${datadir}_hires || exit 1; + steps/compute_cmvn_stats.sh data/${datadir}_hires || exit 1; + utils/fix_data_dir.sh data/${datadir}_hires || exit 1; + done +fi + +if [ $stage -le 4 ]; then + echo "$0: computing a subset of data to train the diagonal UBM." + # We'll use about a quarter of the data. + mkdir -p exp/nnet3${nnet3_affix}/diag_ubm + temp_data_root=exp/nnet3${nnet3_affix}/diag_ubm + + num_utts_total=$(wc -l 2041 combine=-0.47->-0.38 loglike:train/valid[20,31,combined]=(-0.62,-0.38,-0.37/-1.03,-1.03,-1.02) accuracy:train/valid[20,31,combined]=(0.79,0.87,0.87/0.70,0.72,0.72) + +# Below, comparing with the chain TDNN system. It's a little better with the +# small-vocab decoding. Both systems are probably super-badly tuned, and the +# chain system probably used too many jobs. +# +# local/nnet3/compare_wer.sh exp/chain/tdnn1a_sp exp/nnet3/tdnn_lstm1a_sp +# System tdnn1a_sp tdnn_lstm1a_sp +#WER dev_clean_2 (tgsmall) 18.43 17.37 +#WER dev_clean_2 (tglarge) 13.15 13.43 +# Final train prob -0.3933 +# Final valid prob -0.9662 +# Final train acc 0.8652 +# Final valid acc 0.7206 + +# Set -e here so that we catch if any executable fails immediately +set -euo pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +decode_nj=10 +train_set=train_clean_5 +test_sets=dev_clean_2 +gmm=tri3b +nnet3_affix= + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +affix=1a # affix for the TDNN directory name +train_stage=-10 +get_egs_stage=-10 +decode_iter= + +# training options +# training chunk-options +chunk_width=40,30,20 +chunk_left_context=40 +chunk_right_context=0 +common_egs_dir= +xent_regularize=0.1 + +# training options +srand=0 +remove_egs=true +reporting_email= + +#decode options +test_online_decoding=true # if true, it will run the last decoding stage. + + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda delay=$label_delay input=Append(-2,-1,0,1,2,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + relu-renorm-layer name=tdnn1 dim=520 + relu-renorm-layer name=tdnn2 dim=520 input=Append(-1,0,1) + fast-lstmp-layer name=lstm1 cell-dim=520 recurrent-projection-dim=130 non-recurrent-projection-dim=130 decay-time=20 delay=-3 + relu-renorm-layer name=tdnn3 dim=520 input=Append(-3,0,3) + relu-renorm-layer name=tdnn4 dim=520 input=Append(-3,0,3) + fast-lstmp-layer name=lstm2 cell-dim=520 recurrent-projection-dim=130 non-recurrent-projection-dim=130 decay-time=20 delay=-3 + relu-renorm-layer name=tdnn5 dim=520 input=Append(-3,0,3) + relu-renorm-layer name=tdnn6 dim=520 input=Append(-3,0,3) + fast-lstmp-layer name=lstm3 cell-dim=520 recurrent-projection-dim=130 non-recurrent-projection-dim=130 decay-time=20 delay=-3 + + output-layer name=output input=lstm3 output-delay=$label_delay dim=$num_targets max-change=1.5 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 11 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/mini_librispeech-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/train_rnn.py --stage=$train_stage \ + --cmd="$decode_cmd" \ + --feat.online-ivector-dir=$train_ivector_dir \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=6 \ + --trainer.deriv-truncate-margin=10 \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=1 \ + --trainer.optimization.num-jobs-final=2 \ + --trainer.optimization.initial-effective-lrate=0.0003 \ + --trainer.optimization.final-effective-lrate=0.00003 \ + --trainer.optimization.shrink-value=0.99 \ + --trainer.rnn.num-chunk-per-minibatch=128,64 \ + --trainer.optimization.momentum=0.5 \ + --egs.chunk-width=$chunk_width \ + --egs.chunk-left-context=$chunk_left_context \ + --egs.chunk-right-context=$chunk_right_context \ + --egs.chunk-left-context-initial=0 \ + --egs.chunk-right-context-final=0 \ + --egs.dir="$common_egs_dir" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --ali-dir=$ali_dir \ + --lang=$lang \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 12 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + nspk=$(wc -l /dev/null || true + + for data in $test_sets; do + ( + nspk=$(wc -l 2041 combine=-0.71->-0.58 loglike:train/valid[20,31,combined]=(-2.78,-0.95,-0.57/-2.94,-1.31,-0.98) accuracy:train/valid[20,31,combined]=(0.48,0.75,0.81/0.45,0.67,0.71) + +# local/nnet3/compare_wer.sh --online exp/nnet3/tdnn_lstm1a_sp exp/nnet3/tdnn_lstm1b_sp +# System tdnn_lstm1a_sp tdnn_lstm1b_sp +#WER dev_clean_2 (tgsmall) 17.67 17.01 +# [online:] 18.06 17.26 +#WER dev_clean_2 (tglarge) 13.43 12.63 +# [online:] 13.73 12.94 +# Final train prob -0.3660 -0.5680 +# Final valid prob -1.0236 -0.9771 +# Final train acc 0.8737 0.8067 +# Final valid acc 0.7222 0.7144 + + + +# Set -e here so that we catch if any executable fails immediately +set -euo pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +decode_nj=10 +train_set=train_clean_5 +test_sets=dev_clean_2 +gmm=tri3b +nnet3_affix= + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +affix=1b # affix for the TDNN+LSTM directory name +train_stage=-10 +get_egs_stage=-10 +decode_iter= + +# training options +# training chunk-options +chunk_width=40,30,20 +chunk_left_context=40 +chunk_right_context=0 +common_egs_dir= +xent_regularize=0.1 +dropout_schedule='0,0@0.20,0.3@0.50,0' + +# training options +srand=0 +remove_egs=true +reporting_email= + +#decode options +test_online_decoding=true # if true, it will run the last decoding stage. + + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda delay=$label_delay input=Append(-2,-1,0,1,2,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + relu-renorm-layer name=tdnn1 dim=520 + relu-renorm-layer name=tdnn2 dim=520 input=Append(-1,0,1) + fast-lstmp-layer name=lstm1 cell-dim=520 recurrent-projection-dim=130 non-recurrent-projection-dim=130 $lstm_opts + relu-renorm-layer name=tdnn3 dim=520 input=Append(-3,0,3) + relu-renorm-layer name=tdnn4 dim=520 input=Append(-3,0,3) + fast-lstmp-layer name=lstm2 cell-dim=520 recurrent-projection-dim=130 non-recurrent-projection-dim=130 $lstm_opts + relu-renorm-layer name=tdnn5 dim=520 input=Append(-3,0,3) + relu-renorm-layer name=tdnn6 dim=520 input=Append(-3,0,3) + fast-lstmp-layer name=lstm3 cell-dim=520 recurrent-projection-dim=130 non-recurrent-projection-dim=130 $lstm_opts + + output-layer name=output input=lstm3 output-delay=$label_delay dim=$num_targets max-change=1.5 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 11 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/mini_librispeech-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/train_rnn.py --stage=$train_stage \ + --cmd="$decode_cmd" \ + --feat.online-ivector-dir=$train_ivector_dir \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=6 \ + --trainer.deriv-truncate-margin=10 \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=1 \ + --trainer.optimization.num-jobs-final=2 \ + --trainer.optimization.initial-effective-lrate=0.0003 \ + --trainer.optimization.final-effective-lrate=0.00003 \ + --trainer.optimization.shrink-value=0.99 \ + --trainer.dropout-schedule="$dropout_schedule" \ + --trainer.rnn.num-chunk-per-minibatch=128,64 \ + --trainer.optimization.momentum=0.5 \ + --egs.chunk-width=$chunk_width \ + --egs.chunk-left-context=$chunk_left_context \ + --egs.chunk-right-context=$chunk_right_context \ + --egs.chunk-left-context-initial=0 \ + --egs.chunk-right-context-final=0 \ + --egs.dir="$common_egs_dir" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --ali-dir=$ali_dir \ + --lang=$lang \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 12 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + nspk=$(wc -l /dev/null || true + + for data in $test_sets; do + ( + nspk=$(wc -l 2041 combine=-0.99->-0.81 loglike:train/valid[20,31,combined]=(-1.22,-0.69,-0.61/-1.34,-1.02,-0.91) accuracy:train/valid[20,31,combined]=(0.68,0.779,0.800/0.64,0.70,0.724) + + + + +# Set -e here so that we catch if any executable fails immediately +set -euo pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +decode_nj=10 +train_set=train_clean_5 +test_sets=dev_clean_2 +gmm=tri3b +nnet3_affix= + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +affix=1c # affix for the TDNN+LSTM directory name +train_stage=-10 +get_egs_stage=-10 +decode_iter= + +# training options +# training chunk-options +chunk_width=40,30,20 +chunk_left_context=40 +chunk_right_context=0 +common_egs_dir= +xent_regularize=0.1 +dropout_schedule='0,0@0.20,0.3@0.50,0' + +# training options +srand=0 +remove_egs=true +reporting_email= + +#decode options +test_online_decoding=true # if true, it will run the last decoding stage. + + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda delay=$label_delay input=Append(-2,-1,0,1,2,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + relu-batchnorm-layer name=tdnn1 dim=520 $tdnn_opts + relu-batchnorm-layer name=tdnn2 dim=520 $tdnn_opts input=Append(-1,0,1) + fast-lstmp-layer name=lstm1 cell-dim=520 recurrent-projection-dim=130 non-recurrent-projection-dim=130 $lstm_opts + relu-batchnorm-layer name=tdnn3 dim=520 $tdnn_opts input=Append(-3,0,3) + relu-batchnorm-layer name=tdnn4 dim=520 $tdnn_opts input=Append(-3,0,3) + fast-lstmp-layer name=lstm2 cell-dim=520 recurrent-projection-dim=130 non-recurrent-projection-dim=130 $lstm_opts + relu-batchnorm-layer name=tdnn5 dim=520 $tdnn_opts input=Append(-3,0,3) + relu-batchnorm-layer name=tdnn6 dim=520 $tdnn_opts input=Append(-3,0,3) + fast-lstmp-layer name=lstm3 cell-dim=520 recurrent-projection-dim=130 non-recurrent-projection-dim=130 $lstm_opts + + output-layer name=output input=lstm3 $output_opts output-delay=$label_delay dim=$num_targets max-change=1.5 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ +fi + + +if [ $stage -le 11 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/mini_librispeech-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/train_rnn.py --stage=$train_stage \ + --cmd="$decode_cmd" \ + --feat.online-ivector-dir=$train_ivector_dir \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --trainer.srand=$srand \ + --trainer.max-param-change=2.0 \ + --trainer.num-epochs=6 \ + --trainer.deriv-truncate-margin=10 \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=1 \ + --trainer.optimization.num-jobs-final=2 \ + --trainer.optimization.initial-effective-lrate=0.0003 \ + --trainer.optimization.final-effective-lrate=0.00003 \ + --trainer.dropout-schedule="$dropout_schedule" \ + --trainer.rnn.num-chunk-per-minibatch=128,64 \ + --trainer.optimization.momentum=0.5 \ + --egs.chunk-width=$chunk_width \ + --egs.chunk-left-context=$chunk_left_context \ + --egs.chunk-right-context=$chunk_right_context \ + --egs.chunk-left-context-initial=0 \ + --egs.chunk-right-context-final=0 \ + --egs.dir="$common_egs_dir" \ + --cleanup.remove-egs=$remove_egs \ + --use-gpu=true \ + --reporting.email="$reporting_email" \ + --feat-dir=$train_data_dir \ + --ali-dir=$ali_dir \ + --lang=$lang \ + --dir=$dir || exit 1; +fi + +if [ $stage -le 12 ]; then + frames_per_chunk=$(echo $chunk_width | cut -d, -f1) + rm $dir/.error 2>/dev/null || true + + for data in $test_sets; do + ( + nspk=$(wc -l /dev/null || true + + for data in $test_sets; do + ( + nspk=$(wc -l data/lang/G.fst || exit 1; + +# Checking that G is stochastic [note, it wouldn't be for an Arpa] +fstisstochastic data/lang/G.fst || echo Error: G is not stochastic + +# Checking that G.fst is determinizable. +fstdeterminize data/lang/G.fst /dev/null || echo Error determinizing G. + +# Checking that L_disambig.fst is determinizable. +fstdeterminize data/lang/L_disambig.fst /dev/null || echo Error determinizing L. + +# Checking that disambiguated lexicon times G is determinizable +fsttablecompose data/lang/L_disambig.fst data/lang/G.fst | \ + fstdeterminize >/dev/null || echo Error + +# Checking that LG is stochastic: +fsttablecompose data/lang/L.fst data/lang/G.fst | \ + fstisstochastic || echo Error: LG is not stochastic. + +# Checking that L_disambig.G is stochastic: +fsttablecompose data/lang/L_disambig.fst data/lang/G.fst | \ + fstisstochastic || echo Error: LG is not stochastic. + +echo "Succeeded preparing grammar for CMU_kids." diff --git a/egs/cmu_cslu_kids/s5/local/score.sh b/egs/cmu_cslu_kids/s5/local/score.sh new file mode 100755 index 000000000..c812199fc --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/score.sh @@ -0,0 +1,63 @@ +#!/bin/bash +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey) +# 2014 Guoguo Chen +# Apache 2.0 + +[ -f ./path.sh ] && . ./path.sh + +# begin configuration section. +cmd=run.pl +stage=0 +decode_mbr=true +word_ins_penalty=0.0,0.5,1.0 +min_lmwt=7 +max_lmwt=17 +iter=final +#end configuration section. + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: local/score.sh [--cmd (run.pl|queue.pl...)] " + echo " Options:" + echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." + echo " --stage (0|1|2) # start scoring script from part-way through." + echo " --decode_mbr (true/false) # maximum bayes risk decoding (confusion network)." + echo " --min_lmwt # minumum LM-weight for lattice rescoring " + echo " --max_lmwt # maximum LM-weight for lattice rescoring " + exit 1; +fi + +data=$1 +lang_or_graph=$2 +dir=$3 + +symtab=$lang_or_graph/words.txt + +for f in $symtab $dir/lat.1.gz $data/text; do + [ ! -f $f ] && echo "score.sh: no such file $f" && exit 1; +done + +mkdir -p $dir/scoring/log + +cat $data/text | sed 's:::g' | sed 's:::g' > $dir/scoring/test_filt.txt + +for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do + $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/best_path.LMWT.$wip.log \ + lattice-scale --inv-acoustic-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ + lattice-add-penalty --word-ins-penalty=$wip ark:- ark:- \| \ + lattice-best-path --word-symbol-table=$symtab \ + ark:- ark,t:$dir/scoring/LMWT.$wip.tra || exit 1; +done + +# Note: the double level of quoting for the sed command +for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do + $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/score.LMWT.$wip.log \ + cat $dir/scoring/LMWT.$wip.tra \| \ + utils/int2sym.pl -f 2- $symtab \| sed 's:\::g' \| \ + compute-wer --text --mode=present \ + ark:$dir/scoring/test_filt.txt ark,p:- ">&" $dir/wer_LMWT_$wip || exit 1; +done + +exit 0; diff --git a/egs/cmu_cslu_kids/s5/local/sort_result.sh b/egs/cmu_cslu_kids/s5/local/sort_result.sh new file mode 100755 index 000000000..aedec9dc3 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/sort_result.sh @@ -0,0 +1,46 @@ +#! /bin/bash + +# Copyright Johns Hopkins University +# 2019 Fei Wu + +# Sorts and reports results in results/results.txt +# for all models in exp. Expects decode directories +# to be named as exp//decode* or exp/chain/tdnn*/decode* +# Should be run from egs/cmu_cslu_kids. + +res=${1:-"results/results.txt"} +exp=exp +mkdir -p results +rm -f $res + +echo "Sorting results in: " +echo "# ---------- GMM-HMM Models ----------" >> $res +for mdl in $exp/mono* $exp/tri*; do + echo " $mdl" + if [ -d $mdl ];then + for dec in $mdl/decode*;do + echo " $dec" + if [ -d $dec ];then + grep WER $dec/wer* | \ + sort -k2 -n > $dec/WERs + head -n 1 $dec/WERs >> $res + fi + done + fi +done + +echo "# ---------- DNN-HMM Models ----------" >> $res +# DNN results +for mdl in $exp/chain/tdnn*; do + echo " $mdl" + for dec in $mdl/decode*; do + if [ -d $dec ]; then + echo " $dec" + grep WER $dec/wer* | \ + sort -k2 -n > $dec/WERs + head -n 1 $dec/WERs >> $res + fi + done +done + +sed -i "s/:/ /g" $res diff --git a/egs/cmu_cslu_kids/s5/local/subset_dataset.sh b/egs/cmu_cslu_kids/s5/local/subset_dataset.sh new file mode 100755 index 000000000..050128247 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/subset_dataset.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +# Copyright 2017 Luminar Technologies, Inc. (author: Daniel Galvez) +# Apache 2.0 + +# The following commands were used to generate the mini_librispeech dataset: +# +# Note that data generation is random. This could be fixed by +# providing a seed argument to the shuf program. + +if [ "$#" -ne 3 ]; then + echo "Usage: $0 " + echo "e.g.: $0 /export/a05/dgalvez/LibriSpeech/train-clean-100 \\ + /export/a05/dgalvez/LibriSpeech/train-clean-5 5" + exit 1 +fi + +src_dir=$1 +dest_dir=$2 +dest_num_hours=$3 + +src=$(basename $src_dir) +dest=$(basename $dest_dir) +librispeech_dir=$(dirname $src_dir) + +# TODO: Possibly improve this to ensure gender balance and speaker +# balance. +# TODO: Use actual time values instead of assuming that to make sure we get $dest_num_hours of data +src_num_hours=$(grep "$src" $librispeech_dir/CHAPTERS.TXT | awk -F'|' '{ print $3 }' | \ +python -c ' +from __future__ import print_function +from sys import stdin +minutes_str = stdin.read().split() +print(int(round(sum([float(minutes) for minutes in minutes_str]) / 60.0)))') +src_num_chapters=$(grep "$src" $librispeech_dir/CHAPTERS.TXT | \ + awk -F'|' '{ print $1 }' | sort -u | wc -l) +mkdir -p data/subset_tmp +grep "$src" $librispeech_dir/CHAPTERS.TXT | \ + awk -F'|' '{ print $1 }' | \ + shuf -n $(((dest_num_hours * src_num_chapters) / src_num_hours)) > \ + data/subset_tmp/${dest}_chapter_id_list.txt + +while read -r chapter_id || [[ -n "$chapter_id" ]]; do + chapter_dir=$(find $src_dir/ -mindepth 2 -name "$chapter_id" -type d) + speaker_id=$(basename $(dirname $chapter_dir)) + mkdir -p $dest_dir/$speaker_id/ + cp -r $chapter_dir $dest_dir/$speaker_id/ +done < data/subset_tmp/${dest}_chapter_id_list.txt diff --git a/egs/cmu_cslu_kids/s5/local/sum_duration.py b/egs/cmu_cslu_kids/s5/local/sum_duration.py new file mode 100644 index 000000000..0af7ba621 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/sum_duration.py @@ -0,0 +1,15 @@ +# Sum duration obtained by using +# utils/data/get_utt2dur.sh + +import sys +file = sys.argv[1] +sum = 0 +with open(file, 'r') as fp: + line = fp.readline() + while(line): + toks = line.strip().split() + sum += float(toks[1]) + line = fp.readline() +fp.close() +h=sum/3600 +sys.stdout.write("%f hour data.\n"%h) diff --git a/egs/cmu_cslu_kids/s5/local/train_lms.sh b/egs/cmu_cslu_kids/s5/local/train_lms.sh new file mode 100755 index 000000000..0807210be --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/train_lms.sh @@ -0,0 +1,217 @@ +#!/bin/bash + +# This script trains LMs on the WSJ LM-training data. +# It requires that you have already run wsj_extend_dict.sh, +# to get the larger-size dictionary including all of CMUdict +# plus any OOVs and possible acronyms that we could easily +# derive pronunciations for. + +dict_suffix= + +echo "$0 $@" # Print the command line for logging +. utils/parse_options.sh || exit 1; + +dir=data/local/local_lm +srcdir=data/local/dict${dict_suffix}_larger +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/kaldi_lm:$PATH +( # First make sure the kaldi_lm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d kaldi_lm ]; then + echo Not installing the kaldi_lm toolkit since it is already there. + else + echo Downloading and installing the kaldi_lm tools + if [ ! -f kaldi_lm.tar.gz ]; then + wget http://www.danielpovey.com/files/kaldi/kaldi_lm.tar.gz || exit 1; + fi + tar -xvzf kaldi_lm.tar.gz || exit 1; + cd kaldi_lm + make || exit 1; + echo Done making the kaldi_lm tools + fi +) || exit 1; + + + +if [ ! -f $srcdir/cleaned.gz -o ! -f $srcdir/lexicon.txt ]; then + echo "Expecting files $srcdir/cleaned.gz and $srcdir/lexicon.txt to exist"; + echo "You need to run local/wsj_extend_dict.sh before running this script." + exit 1; +fi + +# Get a wordlist-- keep everything but silence, which should not appear in +# the LM. +awk '{print $1}' $srcdir/lexicon.txt | grep -v -w '!SIL' > $dir/wordlist.txt + +# Get training data with OOV words (w.r.t. our current vocab) replaced with . +echo "Getting training data with OOV words replaced with (train_nounk.gz)" +gunzip -c $srcdir/cleaned.gz | awk -v w=$dir/wordlist.txt \ + 'BEGIN{while((getline0) v[$1]=1;} + {for (i=1;i<=NF;i++) if ($i in v) printf $i" ";else printf " ";print ""}'|sed 's/ $//g' \ + | gzip -c > $dir/train_nounk.gz + +# Get unigram counts (without bos/eos, but this doens't matter here, it's +# only to get the word-map, which treats them specially & doesn't need their +# counts). +# Add a 1-count for each word in word-list by including that in the data, +# so all words appear. +gunzip -c $dir/train_nounk.gz | cat - $dir/wordlist.txt | \ + awk '{ for(x=1;x<=NF;x++) count[$x]++; } END{for(w in count){print count[w], w;}}' | \ + sort -nr > $dir/unigram.counts + +# Get "mapped" words-- a character encoding of the words that makes the common words very short. +cat $dir/unigram.counts | awk '{print $2}' | get_word_map.pl "" "" "" > $dir/word_map + +gunzip -c $dir/train_nounk.gz | awk -v wmap=$dir/word_map 'BEGIN{while((getline0)map[$1]=$2;} + { for(n=1;n<=NF;n++) { printf map[$n]; if(n$dir/train.gz + +# To save disk space, remove the un-mapped training data. We could +# easily generate it again if needed. +rm $dir/train_nounk.gz + +train_lm.sh --arpa --lmtype 3gram-mincount $dir +#Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 141.444826 +# 7.8 million N-grams. + +prune_lm.sh --arpa 6.0 $dir/3gram-mincount/ +# 1.45 million N-grams. +# Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 165.394139 + +train_lm.sh --arpa --lmtype 4gram-mincount $dir +#Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 126.734180 +# 10.3 million N-grams. + +prune_lm.sh --arpa 7.0 $dir/4gram-mincount +# 1.50 million N-grams +# Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 155.663757 + + +exit 0 + +### Below here, this script is showing various commands that +## were run during LM tuning. + +train_lm.sh --arpa --lmtype 3gram-mincount $dir +#Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 141.444826 +# 7.8 million N-grams. + +prune_lm.sh --arpa 3.0 $dir/3gram-mincount/ +#Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 156.408740 +# 2.5 million N-grams. + +prune_lm.sh --arpa 6.0 $dir/3gram-mincount/ +# 1.45 million N-grams. +# Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 165.394139 + +train_lm.sh --arpa --lmtype 4gram-mincount $dir +#Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 126.734180 +# 10.3 million N-grams. + +prune_lm.sh --arpa 3.0 $dir/4gram-mincount +#Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 143.206294 +# 2.6 million N-grams. + +prune_lm.sh --arpa 4.0 $dir/4gram-mincount +# Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 146.927717 +# 2.15 million N-grams. + +prune_lm.sh --arpa 5.0 $dir/4gram-mincount +# 1.86 million N-grams +# Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 150.162023 + +prune_lm.sh --arpa 7.0 $dir/4gram-mincount +# 1.50 million N-grams +# Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 155.663757 + +train_lm.sh --arpa --lmtype 3gram $dir +# Perplexity over 228518.000000 words (excluding 478.000000 OOVs) is 135.692866 +# 20.0 million N-grams + +! which ngram-count \ + && echo "SRILM tools not installed so not doing the comparison" && exit 1; + +################# +# You could finish the script here if you wanted. +# Below is to show how to do baselines with SRILM. +# You'd have to install the SRILM toolkit first. + +heldout_sent=10000 # Don't change this if you want result to be comparable with + # kaldi_lm results +sdir=$dir/srilm # in case we want to use SRILM to double-check perplexities. +mkdir -p $sdir +gunzip -c $srcdir/cleaned.gz | head -$heldout_sent > $sdir/cleaned.heldout +gunzip -c $srcdir/cleaned.gz | tail -n +$heldout_sent > $sdir/cleaned.train +(echo ""; echo "" ) | cat - $dir/wordlist.txt > $sdir/wordlist.final.s + +# 3-gram: +ngram-count -text $sdir/cleaned.train -order 3 -limit-vocab -vocab $sdir/wordlist.final.s -unk \ + -map-unk "" -kndiscount -interpolate -lm $sdir/srilm.o3g.kn.gz +ngram -lm $sdir/srilm.o3g.kn.gz -ppl $sdir/cleaned.heldout # consider -debug 2 +#file data/local/local_lm/srilm/cleaned.heldout: 10000 sentences, 218996 words, 478 OOVs +#0 zeroprobs, logprob= -491456 ppl= 141.457 ppl1= 177.437 + +# Trying 4-gram: +ngram-count -text $sdir/cleaned.train -order 4 -limit-vocab -vocab $sdir/wordlist.final.s -unk \ + -map-unk "" -kndiscount -interpolate -lm $sdir/srilm.o4g.kn.gz +ngram -order 4 -lm $sdir/srilm.o4g.kn.gz -ppl $sdir/cleaned.heldout +#file data/local/local_lm/srilm/cleaned.heldout: 10000 sentences, 218996 words, 478 OOVs +#0 zeroprobs, logprob= -480939 ppl= 127.233 ppl1= 158.822 + +#3-gram with pruning: +ngram-count -text $sdir/cleaned.train -order 3 -limit-vocab -vocab $sdir/wordlist.final.s -unk \ + -prune 0.0000001 -map-unk "" -kndiscount -interpolate -lm $sdir/srilm.o3g.pr7.kn.gz +ngram -lm $sdir/srilm.o3g.pr7.kn.gz -ppl $sdir/cleaned.heldout +#file data/local/local_lm/srilm/cleaned.heldout: 10000 sentences, 218996 words, 478 OOVs +#0 zeroprobs, logprob= -510828 ppl= 171.947 ppl1= 217.616 +# Around 2.25M N-grams. +# Note: this is closest to the experiment done with "prune_lm.sh --arpa 3.0 $dir/3gram-mincount/" +# above, which gave 2.5 million N-grams and a perplexity of 156. + +# Note: all SRILM experiments above fully discount all singleton 3 and 4-grams. +# You can use -gt3min=0 and -gt4min=0 to stop this (this will be comparable to +# the kaldi_lm experiments above without "-mincount". + +## From here is how to train with +# IRSTLM. This is not really working at the moment. + +if [ -z $IRSTLM ] ; then + export IRSTLM=$KALDI_ROOT/tools/irstlm/ +fi +export PATH=${PATH}:$IRSTLM/bin +if ! command -v prune-lm >/dev/null 2>&1 ; then + echo "$0: Error: the IRSTLM is not available or compiled" >&2 + echo "$0: Error: We used to install it by default, but." >&2 + echo "$0: Error: this is no longer the case." >&2 + echo "$0: Error: To install it, go to $KALDI_ROOT/tools" >&2 + echo "$0: Error: and run extras/install_irstlm.sh" >&2 + exit 1 +fi + +idir=$dir/irstlm +mkdir $idir +gunzip -c $srcdir/cleaned.gz | tail -n +$heldout_sent | add-start-end.sh | \ + gzip -c > $idir/train.gz + +dict -i=WSJ.cleaned.irstlm.txt -o=dico -f=y -sort=no + cat dico | gawk 'BEGIN{while (getline<"vocab.20k.nooov") v[$1]=1; print "DICTIONARY 0 "length(v);}FNR>1{if ($1 in v)\ +{print $0;}}' > vocab.irstlm.20k + + +build-lm.sh -i "gunzip -c $idir/train.gz" -o $idir/lm_3gram.gz -p yes \ + -n 3 -s improved-kneser-ney -b yes +# Testing perplexity with SRILM tools: +ngram -lm $idir/lm_3gram.gz -ppl $sdir/cleaned.heldout +#data/local/local_lm/irstlm/lm_3gram.gz: line 162049: warning: non-zero probability for in closed-vocabulary LM +#file data/local/local_lm/srilm/cleaned.heldout: 10000 sentences, 218996 words, 0 OOVs +#0 zeroprobs, logprob= -513670 ppl= 175.041 ppl1= 221.599 + +# Perplexity is very bad (should be ~141, since we used -p option, +# not 175), +# but adding -debug 3 to the command line shows that +# the IRSTLM LM does not seem to sum to one properly, so it seems that +# it produces an LM that isn't interpretable in the normal way as an ARPA +# LM. + + + diff --git a/egs/cmu_cslu_kids/s5/local/vtln.sh b/egs/cmu_cslu_kids/s5/local/vtln.sh new file mode 100755 index 000000000..0ca179ce8 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/local/vtln.sh @@ -0,0 +1,61 @@ +#!/bin/bash + +# Copyright Johns Hopkins University +# 2019 Fei Wu + +# Run VTLN. This will be run if the vtln option +# is set to be true in run.sh. + +set -eu +stage=0 +featdir=mfcc/vtln +data=data +mdl=exp/tri3 +mdl_vtln=${mdl}_vtln +vtln_lda=exp/tri4 +vtln_sat=exp/tri5 + +. ./cmd.sh +. ./utils/parse_options.sh + +mkdir -p $featdir + +steps/train_lvtln.sh --cmd "$train_cmd" 1800 9000 $data/train $data/lang $mdl $mdl_vtln + +if [ $stage -le 0 ]; then + mkdir -p $data/train_vtln + cp $data/train/* $data/train_vtln || true + cp $mdl_vtln/final.warp $data/train_vtln/spk2warp + steps/make_mfcc.sh --nj 8 --cmd "$train_cmd" $data/train_vtln exp/make_mfcc/train_vtln $featdir + steps/compute_cmvn_stats.sh $data/train_vtln exp/make_mfcc/train_vtln $featdir +fi + +if [ $stage -le 1 ]; then + utils/mkgraph.sh $data/lang_test_tgmed $mdl_vtln $mdl_vtln/graph + steps/decode_lvtln.sh --config conf/decode.config --nj 20 --cmd "$decode_cmd" \ + $mdl_vtln/graph $data/test $mdl_vtln/decode +fi + +if [ $stage -le 2 ]; then + mkdir -p $data/test_vtln + cp $data/test/* $data/test_vtln || true + cp $mdl_vtln/decode/final.warp $data/test_vtln/spk2warp + steps/make_mfcc.sh --nj 8 --cmd "$train_cmd" $data/test_vtln exp/make_mfcc/test_vtln $featdir + steps/compute_cmvn_stats.sh $data/test_vtln exp/make_mfcc/test_vtln $featdir +fi + +if [ $stage -le 3 ]; then + steps/train_lda_mllt.sh --cmd "$train_cmd" --splice-opts "--left-context=3 --right-context=3" 1800 9000 \ + $data/train_vtln $data/lang $mdl_vtln $vtln_lda + utils/mkgraph.sh $data/lang_test_tgmed $vtln_lda $vtln_lda/graph + echo "$mdl_vtln + lda + mllt" > $vtln_lda/mcodel_discription + steps/decode.sh --config conf/decode.config --nj 20 --cmd "$decode_cmd" \ + $vtln_lda/graph $data/test_vtln $vtln_lda/decode +fi + +if [ $stage -le 4 ]; then + steps/train_sat.sh 1800 9000 $data/train_vtln $data/lang $vtln_lda $vtln_sat + utils/mkgraph.sh $data/lang_test_tgmed $vtln_sat $vtln_sat/graph + steps/decode_fmllr.sh --config conf/decode.config --nj 20 --cmd "$decode_cmd" $vtln_sat/graph $data/test_vtln $vtln_sat/decode + echo "$mdl_vtln + lda + mllt + SAT" > $vtln_sat/model_discription +fi diff --git a/egs/cmu_cslu_kids/s5/path.sh b/egs/cmu_cslu_kids/s5/path.sh new file mode 100755 index 000000000..2d17b17a8 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/path.sh @@ -0,0 +1,6 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C diff --git a/egs/cmu_cslu_kids/s5/run.sh b/egs/cmu_cslu_kids/s5/run.sh new file mode 100755 index 000000000..43ae1ea94 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/run.sh @@ -0,0 +1,177 @@ +#! /bin/bash + +# Copyright Johns Hopkins University +# 2019 Fei Wu + +set -eo + +stage=0 +cmu_kids= # path to cmu_kids corpus +cslu_kids= # path to cslu_kids corpus +lm_src= # path of existing librispeech lm +extra_features=false # Extra features for GMM model (MMI, boosting and MPE) +vtln=false # Optional, run VLTN on gmm and tdnnf models if set true +email= # Reporting email for tdnn-f training + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +lm_url=www.openslr.org/resources/11 +mkdir -p data +mkdir -p data/local + +# Prepare data +if [ $stage -le 0 ]; then + # Make soft link to the corpora + if [ ! -e cmu_kids ]; then + ln -sf $cmu_kids cmu_kids + fi + if [ ! -e cslu ]; then + ln -sf $cslu_kids cslu + fi + + # Make softlink to lm, if lm_src provided + if [ ! -z "$lm_src" ] && [ ! -e data/local/lm ] ; then + ln -sf $lm_src data/local/lm + fi + + # Remove old data dirs + rm -rf data/data_cmu + rm -rf data/data_cslu + + # Data Prep + ./local/cmu_prepare_data.sh --corpus cmu_kids/kids --data data/data_cmu + ./local/cslu_prepare_data.sh --corpus cslu --data data/data_cslu +fi + +# Combine data +if [ $stage -le 1 ]; then + mkdir -p data/train + mkdir -p data/test + rm -rf data/train/* + rm -rf data/test/* + ./utils/combine_data.sh data/train data/data_cmu/train data/data_cslu/train + ./utils/combine_data.sh data/test data/data_cmu/test data/data_cslu/test +fi + +# LM, WFST Preparation +if [ $stage -le 2 ]; then + if [ ! -d data/local/dict ]; then + ./local/download_cmu_dict.sh + fi + + if [ ! -e data/local/lm ]; then + echo "lm_src not provided. Downloading lm from openslr." + ./local/download_lm.sh $lm_url data/local/lm + fi + + utils/prepare_lang.sh data/local/dict "" data/local/lang data/lang + local/format_lms.sh --src_dir data/lang data/local/lm + + # Create ConstArpaLm format language model for full 3-gram and 4-gram LMs + utils/build_const_arpa_lm.sh data/local/lm/lm_tglarge.arpa.gz data/lang data/lang_test_tglarge + utils/build_const_arpa_lm.sh data/local/lm/lm_fglarge.arpa.gz data/lang data/lang_test_fglarge +fi + +# Make MFCC features +if [ $stage -le 3 ]; then + mkdir -p mfcc + mkdir -p exp + steps/make_mfcc.sh --nj 40 --cmd "$train_cmd" data/test exp/make_feat/test mfcc + steps/compute_cmvn_stats.sh data/test exp/make_feat/test mfcc + steps/make_mfcc.sh --nj 40 --cmd "$train_cmd" data/train exp/make_feat/train mfcc + steps/compute_cmvn_stats.sh data/train exp/make_feat/train mfcc +fi + +# Mono-phone +if [ $stage -le 4 ]; then + # Train + steps/train_mono.sh --nj 40 --cmd "$train_cmd" data/train data/lang exp/mono + #Decode + utils/mkgraph.sh data/lang_test_tgsmall exp/mono exp/mono/graph + steps/decode.sh --config conf/decode.config --nj 40 --cmd "$decode_cmd" exp/mono/graph data/test exp/mono/decode + #Align + steps/align_si.sh --nj 20 --cmd "$train_cmd" data/train data/lang exp/mono exp/mono_ali +fi + +# Tri1 [Vanilla tri phone model] +if [ $stage -le 5 ]; then + # Train + steps/train_deltas.sh --cmd "$train_cmd" 1800 9000 data/train data/lang exp/mono_ali exp/tri1 + # Decode + utils/mkgraph.sh data/lang_test_tgmed exp/tri1 exp/tri1/graph + steps/decode.sh --config conf/decode.config --nj 40 --cmd "$decode_cmd" exp/tri1/graph data/test exp/tri1/decode + # Align - make graph - decode again + steps/align_si.sh --nj 20 --cmd "queue.pl" --use-graphs true data/train data/lang_test_tgmed exp/tri1 exp/tri1_ali + utils/mkgraph.sh data/lang_test_tgmed exp/tri1_ali exp/tri1_ali/graph + steps/decode.sh --config conf/decode.config --nj 40 --cmd "$decode_cmd" exp/tri1_ali/graph data/test exp/tri1_ali/decode +fi + +# Add LDA and MLLT +if [ $stage -le 6 ]; then + # Train + steps/train_lda_mllt.sh --cmd "$train_cmd" --splice-opts "--left-context=3 --right-context=3" 1800 9000 data/train data/lang exp/tri1_ali exp/tri2 + utils/mkgraph.sh data/lang_test_tgmed exp/tri2 exp/tri2/graph + # Decode + steps/decode.sh --config conf/decode.config --nj 40 --cmd "$decode_cmd" exp/tri2/graph data/test exp/tri2/decode + # Align - make graph - dcode again + steps/align_si.sh --nj 20 --cmd "$train_cmd" --use-graphs true data/train data/lang_test_tgmed exp/tri2 exp/tri2_ali + utils/mkgraph.sh data/lang_test_tgmed exp/tri2_ali exp/tri2_ali/graph + steps/decode_fmllr.sh --config conf/decode.config --nj 40 --cmd "$decode_cmd" exp/tri2_ali/graph data/test exp/tri2_ali/decode +fi + +# Add other features +if [ $stage -le 7 ]; then + if [ $extra_features = true ]; then + # Add MMI + steps/make_denlats.sh --nj 20 --cmd "$train_cmd" data/train data/lang exp/tri2 exp/tri2_denlats + steps/train_mmi.sh data/train data/lang exp/tri2_ali exp/tri2_denlats exp/tri2_mmi + steps/decode.sh --config conf/decode.config --iter 4 --nj 20 --cmd "$decode_cmd" exp/tri2/graph data/test exp/tri2_mmi/decode_it4 + steps/decode.sh --config conf/decode.config --iter 3 --nj 20 --cmd "$decode_cmd" exp/tri2/graph data/test exp/tri2_mmi/decode_it3 + + # Add Boosting + steps/train_mmi.sh --boost 0.05 data/train data/lang exp/tri2_ali exp/tri2_denlats exp/tri2_mmi_b0.05 + steps/decode.sh --config conf/decode.config --iter 4 --nj 20 --cmd "$decode_cmd" exp/tri2/graph data/test exp/tri2_mmi_b0.05/decode_it4 + steps/decode.sh --config conf/decode.config --iter 3 --nj 20 --cmd "$decode_cmd" exp/tri2/graph data/test exp/tri2_mmi_b0.05/decode_it3 + + # Add MPE + steps/train_mpe.sh data/train data/lang exp/tri2_ali exp/tri2_denlats exp/tri2_mpe + steps/decode.sh --config conf/decode.config --iter 4 --nj 20 --cmd "$decode_cmd" exp/tri2/graph data/test exp/tri2_mpe/decode_it4 + steps/decode.sh --config conf/decode.config --iter 3 --nj 20 --cmd "$decode_cmd" exp/tri2/graph data/test exp/tri2_mpe/decode_it3 + fi +fi + +# Add SAT +if [ $stage -le 8 ]; then + # Do LDA+MLLT+SAT, and decode. + steps/train_sat.sh 1800 9000 data/train data/lang exp/tri2_ali exp/tri3 + utils/mkgraph.sh data/lang_test_tgmed exp/tri3 exp/tri3/graph + steps/decode_fmllr.sh --config conf/decode.config --nj 40 --cmd "$decode_cmd" exp/tri3/graph data/test exp/tri3/decode +fi + +if [ $stage -le 9 ]; then + # Align all data with LDA+MLLT+SAT system (tri3) + steps/align_fmllr.sh --nj 20 --cmd "$train_cmd" --use-graphs true data/train data/lang_test_tgmed exp/tri3 exp/tri3_ali + utils/mkgraph.sh data/lang_test_tgmed exp/tri3_ali exp/tri3_ali/graph + steps/decode_fmllr.sh --config conf/decode.config --nj 40 --cmd "$decode_cmd" exp/tri3_ali/graph data/test exp/tri3_ali/decode +fi + +if [ $stage -le 10 ]; then + # Uncomment reporting email option to get training progress updates by email + ./local/chain/run_tdnnf.sh --train_set train \ + --test_sets test --gmm tri3 # --reporting_email $email +fi + + +# Optional VTLN. Run if vtln is set to true +if [ $stage -le 11 ]; then + if [ $vtln = true ]; then + ./local/vtln.sh + ./local/chain/run_tdnnf.sh --nnet3_affix vtln --train_set train_vtln \ + --test_sets test_vtln --gmm tri5 # --reporting_email $email + fi +fi + +# Collect and resport WER results for all models +./local/sort_result.sh diff --git a/egs/cmu_cslu_kids/s5/steps b/egs/cmu_cslu_kids/s5/steps new file mode 120000 index 000000000..1b186770d --- /dev/null +++ b/egs/cmu_cslu_kids/s5/steps @@ -0,0 +1 @@ +../../wsj/s5/steps/ \ No newline at end of file diff --git a/egs/cmu_cslu_kids/s5/utils b/egs/cmu_cslu_kids/s5/utils new file mode 120000 index 000000000..a3279dc86 --- /dev/null +++ b/egs/cmu_cslu_kids/s5/utils @@ -0,0 +1 @@ +../../wsj/s5/utils/ \ No newline at end of file diff --git a/egs/gop/README.md b/egs/gop/README.md new file mode 100644 index 000000000..d95f4e966 --- /dev/null +++ b/egs/gop/README.md @@ -0,0 +1,98 @@ +There is a copy of this document on Google Docs, which renders the equations better: +[link](https://docs.google.com/document/d/1pie-PU6u2NZZC_FzocBGGm6mpfBJMiCft9UoG0uA1kA/edit?usp=sharing) + +* * * + +# GOP on Kaldi + +The Goodness of Pronunciation (GOP) is a variation of the posterior probability, for phone level pronunciation scoring. +GOP is widely used in pronunciation evaluation and mispronunciation detection tasks. + +This implementation is mainly based on the following paper: + +Hu, W., Qian, Y., Soong, F. K., & Wang, Y. (2015). Improved mispronunciation detection with deep neural network trained acoustic models and transfer learning based logistic regression classifiers. Speech Communication, 67(January), 154-166. + +## GOP-GMM + +In the conventional GMM-HMM based system, GOP was first proposed in (Witt et al., 2000). It was defined as the duration normalised log of the posterior: + +$$ +GOP(p)=\frac{1}{t_e-t_s+1} \log p(p|\mathbf o) +$$ + +where $\mathbf o$ is the input observations, $p$ is the canonical phone, $t_s, t_e$ are the start and end frame indexes. + +Assuming $p(q_i)\approx p(q_j)$ for any $q_i, q_j$, we have: + +$$ +\log p(p|\mathbf o)=\frac{p(\mathbf o|p)p(p)}{\sum_{q\in Q} p(\mathbf o|q)p(q)} + \approx\frac{p(\mathbf o|p)}{\sum_{q\in Q} p(\mathbf o|q)} +$$ + +where $Q$ is the whole phone set. + +The numerator of the equation is calculated from forced alignment result and the denominator is calculated from an Viterbi decoding with a unconstrained phone loop. + +We do not implement GOP-GMM for Kaldi, as GOP-NN performs much better than GOP-GMM. + +## GOP-NN + +The definition of GOP-NN is a bit different from the GOP-GMM. GOP-NN was defined as the log phone posterior ratio between the canonical phone and the one with the highest score (Hu et al., 2015). + +Firstly we define Log Phone Posterior (LPP): + +$$ +LPP(p)=\log p(p|\mathbf o; t_s,t_e) +$$ + +Then we define the GOP-NN using LPP: + +$$ +GOP(p)=\log \frac{LPP(p)}{\max_{q\in Q} LPP(q)} +$$ + +LPP could be calculated as: + +$$ +LPP(p) \approx \frac{1}{t_e-t_s+1} \sum_{t=t_s}^{t_e}\log p(p|o_t) +$$ + +$$ +p(p|o_t) = \sum_{s \in p} p(s|o_t) +$$ + +where $s$ is the senone label, $\{s|s \in p\}$ is the states belonging to those triphones whose current phone is $p$. + +## Phone-level Feature + +Normally the classifier-based approach archives better performance than GOP-based approach. + +Different from GOP based method, an extra supervised training process is needed. The input features for supervised training are phone-level, segmental features. The phone-level feature is defined as: + +$$ +{[LPP(p_1),\cdots,LPP(p_M), LPR(p_1|p_i), \cdots, LPR(p_j|p_i),\cdots]}^T +$$ + +where the Log Posterior Ratio (LPR) between phone $p_j$ and $p_i$ is defined as: + +$$ +LPR(p_j|p_i) = \log p(p_j|\mathbf o; t_s, t_e) - \log p(p_i|\mathbf o; t_s, t_e) +$$ + +## Implementation + +This implementation consists of a executable binary `bin/compute-gop` and some scripts. + +`compute-gop` computes GOP and extracts phone-level features using nnet output probabilities. +The output probabilities are assumed to be from a log-softmax layer. + +The script `run.sh` shows a typical pipeline based on librispeech's model and data. + +In Hu's paper, GOP was computed using a feed-forward DNN. +We have tried to use the output-xent of a chain model to compute GOP, but the result was not good. +We guess the HMM topo of chain model may not fit for GOP. + +The nnet3's TDNN (no chain) model performs well in GOP computing, so this recipe uses it. + +## Acknowledgement +The author of this recipe would like to thank Xingyu Na for his works of model tuning and his helpful suggestions. diff --git a/egs/gop/s5/cmd.sh b/egs/gop/s5/cmd.sh new file mode 100644 index 000000000..9139633e5 --- /dev/null +++ b/egs/gop/s5/cmd.sh @@ -0,0 +1,13 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. + +export cmd="run.pl" diff --git a/egs/gop/s5/local/make_testcase.sh b/egs/gop/s5/local/make_testcase.sh new file mode 100755 index 000000000..884563066 --- /dev/null +++ b/egs/gop/s5/local/make_testcase.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +src=$1 +dst=$2 + +# Select a very small set for testing +utils/subset_data_dir.sh --shortest $src 10 $dst + +# make fake transcripts as negative examples +cp $dst/text $dst/text.ori +sed -i "s/ THERE / THOSE /" $dst/text +sed -i "s/ IN / ON /" $dst/text diff --git a/egs/gop/s5/local/remove_phone_markers.pl b/egs/gop/s5/local/remove_phone_markers.pl new file mode 100755 index 000000000..16236a749 --- /dev/null +++ b/egs/gop/s5/local/remove_phone_markers.pl @@ -0,0 +1,72 @@ +#!/usr/bin/env perl +# Copyright 2019 Junbo Zhang + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +use strict; +use warnings; + +my $Usage = <new phone mapping file, in which each line is: "old-integer-id new-integer-id. + +Usage: utils/remove_phone_markers.pl + e.g.: utils/remove_phone_markers.pl phones.txt phones-pure.txt phone-to-pure-phone.int +EOU + +if (@ARGV < 3) { + die $Usage; +} + +my $old_phone_symbols_filename = shift @ARGV; +my $new_phone_symbols_filename = shift @ARGV; +my $mapping_filename = shift @ARGV; + +my %id_of_old_phone; +open(IN, $old_phone_symbols_filename) or die "Can't open $old_phone_symbols_filename"; +while () { + chomp; + my ($phone, $id) = split; + next if $phone =~ /\#/; + $id_of_old_phone{$phone} = $id; +} +close IN; + +my $new_id = 0; +my %id_of_new_phone; +my %id_old_to_new; +foreach (sort { $id_of_old_phone{$a} <=> $id_of_old_phone{$b} } keys %id_of_old_phone) { + my $old_phone = $_; + s/_[BIES]//; + s/\d//; + my $new_phone = $_; + $id_of_new_phone{$new_phone} = $new_id++ if not exists $id_of_new_phone{$new_phone}; + $id_old_to_new{$id_of_old_phone{$old_phone}} = $id_of_new_phone{$new_phone}; +} + +# Write to file +open(OUT, ">$new_phone_symbols_filename") or die "Can\'t write to $new_phone_symbols_filename"; +foreach (sort { $id_of_new_phone{$a} <=> $id_of_new_phone{$b} } keys %id_of_new_phone) { + print OUT "$_\t$id_of_new_phone{$_}\n"; +} +close OUT; + +open(OUT, ">$mapping_filename") or die "Can\'t write to $mapping_filename"; +foreach (sort { $a <=> $b } keys %id_old_to_new) { + next if $_ == 0; + print OUT "$_ $id_old_to_new{$_}\n"; +} +close OUT; diff --git a/egs/gop/s5/path.sh b/egs/gop/s5/path.sh new file mode 100755 index 000000000..03df6dd9f --- /dev/null +++ b/egs/gop/s5/path.sh @@ -0,0 +1,27 @@ +export KALDI_ROOT=`pwd`/../../.. +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C + +# we use this both in the (optional) LM training and the G2P-related scripts +PYTHON='python2.7' + +### Below are the paths used by the optional parts of the recipe + +# We only need the Festival stuff below for the optional text normalization(for LM-training) step +FEST_ROOT=tools/festival +NSW_PATH=${FEST_ROOT}/festival/bin:${FEST_ROOT}/nsw/bin +export PATH=$PATH:$NSW_PATH + +# SRILM is needed for LM model building +SRILM_ROOT=$KALDI_ROOT/tools/srilm +SRILM_PATH=$SRILM_ROOT/bin:$SRILM_ROOT/bin/i686-m64 +export PATH=$PATH:$SRILM_PATH + +# Sequitur G2P executable +sequitur=$KALDI_ROOT/tools/sequitur/g2p.py +sequitur_path="$(dirname $sequitur)/lib/$PYTHON/site-packages" + +# Directory under which the LM training corpus should be extracted +LM_CORPUS_ROOT=./lm-corpus diff --git a/egs/gop/s5/run.sh b/egs/gop/s5/run.sh new file mode 100755 index 000000000..a731b9135 --- /dev/null +++ b/egs/gop/s5/run.sh @@ -0,0 +1,102 @@ +#!/bin/bash + +# Copyright 2019 Junbo Zhang +# Apache 2.0 + +# This script shows how to calculate Goodness of Pronunciation (GOP) and +# extract phone-level pronunciation feature for mispronunciations detection +# tasks. Read ../README.md or the following paper for details: +# +# "Hu et al., Improved mispronunciation detection with deep neural network +# trained acoustic models and transfer learning based logistic regression +# classifiers, 2015." + +# You might not want to do this for interactive shells. +set -e + +# Before running this recipe, you have to run the librispeech recipe firstly. +# This script assumes the following paths exist. +librispeech_eg=../../librispeech/s5 +model=$librispeech_eg/exp/nnet3_cleaned/tdnn_sp +ivector=$librispeech_eg/exp/nnet3_cleaned/ivectors_test_clean_hires +lang=$librispeech_eg/data/lang +test_data=$librispeech_eg/data/test_clean_hires + +for d in $model $ivector $lang $test_data; do + [ ! -d $d ] && echo "$0: no such path $d" && exit 1; +done + +# Global configurations +stage=0 +nj=4 + +data=test_10short +dir=exp/gop_$data + +. ./cmd.sh +. ./path.sh +. parse_options.sh + +if [ $stage -le 0 ]; then + # Prepare test data + [ -d data ] || mkdir -p data/$data + local/make_testcase.sh $test_data data/$data +fi + +if [ $stage -le 1 ]; then + # Compute Log-likelihoods + steps/nnet3/compute_output.sh --cmd "$cmd" --nj $nj \ + --online-ivector-dir $ivector data/$data $model exp/probs_$data +fi + +if [ $stage -le 2 ]; then + steps/nnet3/align.sh --cmd "$cmd" --nj $nj --use_gpu false \ + --online_ivector_dir $ivector data/$data $lang $model $dir +fi + +if [ $stage -le 3 ]; then + # make a map which converts phones to "pure-phones" + # "pure-phone" means the phone whose stress and pos-in-word markers are ignored + # eg. AE1_B --> AE, EH2_S --> EH, SIL --> SIL + local/remove_phone_markers.pl $lang/phones.txt $dir/phones-pure.txt \ + $dir/phone-to-pure-phone.int + + # Convert transition-id to pure-phone id + $cmd JOB=1:$nj $dir/log/ali_to_phones.JOB.log \ + ali-to-phones --per-frame=true $model/final.mdl "ark,t:gunzip -c $dir/ali.JOB.gz|" \ + "ark,t:-" \| utils/apply_map.pl -f 2- $dir/phone-to-pure-phone.int \| \ + gzip -c \>$dir/ali-pure-phone.JOB.gz || exit 1; +fi + +if [ $stage -le 4 ]; then + # The outputs of the binary compute-gop are the GOPs and the phone-level features. + # + # An example of the GOP result (extracted from "ark,t:$dir/gop.3.txt"): + # 4446-2273-0031 [ 1 0 ] [ 12 0 ] [ 27 -5.382001 ] [ 40 -13.91807 ] [ 1 -0.2555897 ] \ + # [ 21 -0.2897284 ] [ 5 0 ] [ 31 0 ] [ 33 0 ] [ 3 -11.43557 ] [ 25 0 ] \ + # [ 16 0 ] [ 30 -0.03224623 ] [ 5 0 ] [ 25 0 ] [ 33 0 ] [ 1 0 ] + # It is in the posterior format, where each pair stands for [pure-phone-index gop-value]. + # For example, [ 27 -5.382001 ] means the GOP of the pure-phone 27 (it corresponds to the + # phone "OW", according to "$dir/phones-pure.txt") is -5.382001, indicating the audio + # segment of this phone should be a mispronunciation. + # + # The phone-level features are in matrix format: + # 4446-2273-0031 [ -0.2462088 -10.20292 -11.35369 ... + # -8.584108 -7.629755 -13.04877 ... + # ... + # ... ] + # The row number is the phone number of the utterance. In this case, it is 17. + # The column number is 2 * (pure-phone set size), as the feature is consist of LLR + LPR. + # The phone-level features can be used to train a classifier with human labels. See Hu's + # paper for detail. + $cmd JOB=1:$nj $dir/log/compute_gop.JOB.log \ + compute-gop --phone-map=$dir/phone-to-pure-phone.int $model/final.mdl \ + "ark,t:gunzip -c $dir/ali-pure-phone.JOB.gz|" \ + "ark:exp/probs_$data/output.JOB.ark" \ + "ark,t:$dir/gop.JOB.txt" "ark,t:$dir/phonefeat.JOB.txt" || exit 1; + echo "Done compute-gop, the results: \"$dir/gop..txt\" in posterior format." + + # We set -5 as a universal empirical threshold here. You can also determine multiple phone + # dependent thresholds based on the human-labeled mispronunciation data. + echo "The phones whose gop values less than -5 could be treated as mispronunciations." +fi diff --git a/egs/gop/s5/steps b/egs/gop/s5/steps new file mode 120000 index 000000000..6e99bf5b5 --- /dev/null +++ b/egs/gop/s5/steps @@ -0,0 +1 @@ +../../wsj/s5/steps \ No newline at end of file diff --git a/egs/gop/s5/utils b/egs/gop/s5/utils new file mode 120000 index 000000000..b24088521 --- /dev/null +++ b/egs/gop/s5/utils @@ -0,0 +1 @@ +../../wsj/s5/utils \ No newline at end of file diff --git a/egs/librispeech/s5/RESULTS b/egs/librispeech/s5/RESULTS index b45271765..dbf54b938 100644 --- a/egs/librispeech/s5/RESULTS +++ b/egs/librispeech/s5/RESULTS @@ -1,6 +1,6 @@ # In the results below, "tgsmall" is the pruned 3-gram LM, which is used for lattice generation. # The following language models are then used for rescoring: -# a) tgmed- slightly less pruned 3-gram LM +# a) tgmed- slightly less pruned 3-gram LM # b) tglarge- the full, non-pruned 3-gram LM # c) fglarge- non-pruned 4-gram LM # @@ -337,7 +337,7 @@ %WER 4.39 [ 2387 / 54402, 377 ins, 199 del, 1811 sub ] exp/nnet2_online/nnet_ms_a_smbr_0.000005/decode_epoch3_dev_clean_tglarge/wer_14 %WER 5.36 [ 2918 / 54402, 328 ins, 338 del, 2252 sub ] exp/nnet2_online/nnet_ms_a_smbr_0.000005/decode_epoch3_dev_clean_tgmed/wer_17 %WER 6.08 [ 3305 / 54402, 369 ins, 396 del, 2540 sub ] exp/nnet2_online/nnet_ms_a_smbr_0.000005/decode_epoch3_dev_clean_tgsmall/wer_15 -%WER 4.40 [ 2395 / 54402, 375 ins, 200 del, 1820 sub ] exp/nnet2_online/nnet_ms_a_smbr_0.000005/decode_epoch4_dev_clean_tglarge/wer_14 +%WER 4.40 [ 2395 / 54402, 375 ins, 200 del, 1820 sub ] exp/nnet2_online/nnet_ms_a_smbr_0.000005/decode_epoch4_dev_clean_tglarge/wer_14 %WER 5.35 [ 2909 / 54402, 328 ins, 339 del, 2242 sub ] exp/nnet2_online/nnet_ms_a_smbr_0.000005/decode_epoch4_dev_clean_tgmed/wer_17 %WER 6.05 [ 3291 / 54402, 384 ins, 381 del, 2526 sub ] exp/nnet2_online/nnet_ms_a_smbr_0.000005/decode_epoch4_dev_clean_tgsmall/wer_14 %WER 13.45 [ 6850 / 50948, 808 ins, 876 del, 5166 sub ] exp/nnet2_online/nnet_ms_a_smbr_0.000005/decode_epoch0_dev_other_tglarge/wer_15 @@ -423,7 +423,7 @@ %WER 17.64 [ 9231 / 52343, 764 ins, 1662 del, 6805 sub ] exp/nnet2_online/nnet_ms_a_online/decode_pp_test_other_tgsmall_utt_offline/wer_14 # Results with nnet3 tdnn -# local/nnet3/run_tdnn.sh +# local/nnet3/run_tdnn.sh (with old configs, now moved to local/nnet3/tuning/run_tdnn_1a.sh) # (4 epoch training on speed-perturbed data) # num_params=19.3M %WER 4.43 [ 2410 / 54402, 306 ins, 278 del, 1826 sub ] exp/nnet3/tdnn_sp/decode_dev_clean_fglarge/wer_13_1.0 @@ -444,7 +444,7 @@ %WER 16.29 [ 8528 / 52343, 828 ins, 1320 del, 6380 sub ] exp/nnet3/tdnn_sp/decode_test_other_tgsmall/wer_14_0.0 # Results with nnet3 tdnn -# local/nnet3/run_tdnn.sh +# local/nnet3/run_tdnn.sh (with old configs, now moved to local/nnet3/tuning/run_tdnn_1a.sh) # (4 epoch training on speed-perturbed and volumn-perturbed "cleaned" data) # num_params=19.3M, average training time=68.8s per job(on Tesla K80), real-time factor=1.23161 # for x in exp/nnet3_cleaned/tdnn_sp/decode_*; do grep WER $x/wer_* | utils/best_wer.sh ; done @@ -465,6 +465,24 @@ %WER 14.78 [ 7737 / 52343, 807 ins, 1115 del, 5815 sub ] exp/nnet3_cleaned/tdnn_sp/decode_test_other_tgmed/wer_15_0.0 %WER 16.28 [ 8521 / 52343, 843 ins, 1258 del, 6420 sub ] exp/nnet3_cleaned/tdnn_sp/decode_test_other_tgsmall/wer_14_0.0 +# Results with nnet3 tdnn with new configs, a.k.a. xconfig +# local/nnet3/run_tdnn.sh (linked to local/nnet3/tuning/run_tdnn_1b.sh) +%WER 4.60 [ 2502 / 54402, 324 ins, 286 del, 1892 sub ] exp/nnet3_cleaned/tdnn_sp/decode_dev_clean_fglarge/wer_13_1.0 +%WER 4.80 [ 2612 / 54402, 350 ins, 285 del, 1977 sub ] exp/nnet3_cleaned/tdnn_sp/decode_dev_clean_tglarge/wer_11_1.0 +%WER 5.97 [ 3248 / 54402, 460 ins, 310 del, 2478 sub ] exp/nnet3_cleaned/tdnn_sp/decode_dev_clean_tgmed/wer_11_0.0 +%WER 6.66 [ 3625 / 54402, 479 ins, 392 del, 2754 sub ] exp/nnet3_cleaned/tdnn_sp/decode_dev_clean_tgsmall/wer_11_0.0 +%WER 12.29 [ 6262 / 50948, 863 ins, 665 del, 4734 sub ] exp/nnet3_cleaned/tdnn_sp/decode_dev_other_fglarge/wer_15_0.0 +%WER 12.89 [ 6565 / 50948, 773 ins, 853 del, 4939 sub ] exp/nnet3_cleaned/tdnn_sp/decode_dev_other_tglarge/wer_14_0.5 +%WER 15.41 [ 7849 / 50948, 894 ins, 1083 del, 5872 sub ] exp/nnet3_cleaned/tdnn_sp/decode_dev_other_tgmed/wer_15_0.0 +%WER 16.81 [ 8562 / 50948, 896 ins, 1215 del, 6451 sub ] exp/nnet3_cleaned/tdnn_sp/decode_dev_other_tgsmall/wer_14_0.0 +%WER 4.99 [ 2624 / 52576, 393 ins, 253 del, 1978 sub ] exp/nnet3_cleaned/tdnn_sp/decode_test_clean_fglarge/wer_13_0.5 +%WER 5.16 [ 2715 / 52576, 359 ins, 319 del, 2037 sub ] exp/nnet3_cleaned/tdnn_sp/decode_test_clean_tglarge/wer_12_1.0 +%WER 6.29 [ 3307 / 52576, 471 ins, 341 del, 2495 sub ] exp/nnet3_cleaned/tdnn_sp/decode_test_clean_tgmed/wer_12_0.0 +%WER 7.13 [ 3750 / 52576, 473 ins, 452 del, 2825 sub ] exp/nnet3_cleaned/tdnn_sp/decode_test_clean_tgsmall/wer_13_0.0 +%WER 12.73 [ 6665 / 52343, 894 ins, 711 del, 5060 sub ] exp/nnet3_cleaned/tdnn_sp/decode_test_other_fglarge/wer_14_0.0 +%WER 13.33 [ 6979 / 52343, 920 ins, 796 del, 5263 sub ] exp/nnet3_cleaned/tdnn_sp/decode_test_other_tglarge/wer_14_0.0 +%WER 15.90 [ 8323 / 52343, 921 ins, 1126 del, 6276 sub ] exp/nnet3_cleaned/tdnn_sp/decode_test_other_tgmed/wer_13_0.0 +%WER 17.28 [ 9044 / 52343, 894 ins, 1372 del, 6778 sub ] exp/nnet3_cleaned/tdnn_sp/decode_test_other_tgsmall/wer_14_0.0 # Results with nnet3 tdnn+sMBR # local/nnet3/run_tdnn_discriminative.sh diff --git a/egs/librispeech/s5/local/nnet3/run_tdnn.sh b/egs/librispeech/s5/local/nnet3/run_tdnn.sh deleted file mode 100755 index 28ee2b920..000000000 --- a/egs/librispeech/s5/local/nnet3/run_tdnn.sh +++ /dev/null @@ -1,127 +0,0 @@ -#!/bin/bash - -# this is the standard "tdnn" system, built in nnet3; it's what we use to -# call multi-splice. - -# without cleanup: -# local/nnet3/run_tdnn.sh --train-set train960 --gmm tri6b --nnet3-affix "" & - - -# At this script level we don't support not running on GPU, as it would be painfully slow. -# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, -# --num-threads 16 and --minibatch-size 128. - -# First the options that are passed through to run_ivector_common.sh -# (some of which are also used in this script directly). -stage=0 -decode_nj=30 -train_set=train_960_cleaned -gmm=tri6b_cleaned # this is the source gmm-dir for the data-type of interest; it - # should have alignments for the specified training data. -nnet3_affix=_cleaned - -# Options which are not passed through to run_ivector_common.sh -affix= -train_stage=-10 -common_egs_dir= -reporting_email= -remove_egs=true - -. ./cmd.sh -. ./path.sh -. ./utils/parse_options.sh - - -if ! cuda-compiled; then - cat </dev/null || true - for test in test_clean test_other dev_clean dev_other; do - ( - steps/nnet3/decode.sh --nj $decode_nj --cmd "$decode_cmd" \ - --online-ivector-dir exp/nnet3${nnet3_affix}/ivectors_${test}_hires \ - ${graph_dir} data/${test}_hires $dir/decode_${test}_tgsmall || exit 1 - steps/lmrescore.sh --cmd "$decode_cmd" data/lang_test_{tgsmall,tgmed} \ - data/${test}_hires $dir/decode_${test}_{tgsmall,tgmed} || exit 1 - steps/lmrescore_const_arpa.sh \ - --cmd "$decode_cmd" data/lang_test_{tgsmall,tglarge} \ - data/${test}_hires $dir/decode_${test}_{tgsmall,tglarge} || exit 1 - steps/lmrescore_const_arpa.sh \ - --cmd "$decode_cmd" data/lang_test_{tgsmall,fglarge} \ - data/${test}_hires $dir/decode_${test}_{tgsmall,fglarge} || exit 1 - ) || touch $dir/.error & - done - wait - [ -f $dir/.error ] && echo "$0: there was a problem while decoding" && exit 1 -fi - -exit 0; diff --git a/egs/librispeech/s5/local/nnet3/run_tdnn.sh b/egs/librispeech/s5/local/nnet3/run_tdnn.sh new file mode 120000 index 000000000..61f8f4991 --- /dev/null +++ b/egs/librispeech/s5/local/nnet3/run_tdnn.sh @@ -0,0 +1 @@ +tuning/run_tdnn_1b.sh \ No newline at end of file diff --git a/egs/librispeech/s5/local/nnet3/tuning/run_tdnn_1a.sh b/egs/librispeech/s5/local/nnet3/tuning/run_tdnn_1a.sh new file mode 100755 index 000000000..28ee2b920 --- /dev/null +++ b/egs/librispeech/s5/local/nnet3/tuning/run_tdnn_1a.sh @@ -0,0 +1,127 @@ +#!/bin/bash + +# this is the standard "tdnn" system, built in nnet3; it's what we use to +# call multi-splice. + +# without cleanup: +# local/nnet3/run_tdnn.sh --train-set train960 --gmm tri6b --nnet3-affix "" & + + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +decode_nj=30 +train_set=train_960_cleaned +gmm=tri6b_cleaned # this is the source gmm-dir for the data-type of interest; it + # should have alignments for the specified training data. +nnet3_affix=_cleaned + +# Options which are not passed through to run_ivector_common.sh +affix= +train_stage=-10 +common_egs_dir= +reporting_email= +remove_egs=true + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat </dev/null || true + for test in test_clean test_other dev_clean dev_other; do + ( + steps/nnet3/decode.sh --nj $decode_nj --cmd "$decode_cmd" \ + --online-ivector-dir exp/nnet3${nnet3_affix}/ivectors_${test}_hires \ + ${graph_dir} data/${test}_hires $dir/decode_${test}_tgsmall || exit 1 + steps/lmrescore.sh --cmd "$decode_cmd" data/lang_test_{tgsmall,tgmed} \ + data/${test}_hires $dir/decode_${test}_{tgsmall,tgmed} || exit 1 + steps/lmrescore_const_arpa.sh \ + --cmd "$decode_cmd" data/lang_test_{tgsmall,tglarge} \ + data/${test}_hires $dir/decode_${test}_{tgsmall,tglarge} || exit 1 + steps/lmrescore_const_arpa.sh \ + --cmd "$decode_cmd" data/lang_test_{tgsmall,fglarge} \ + data/${test}_hires $dir/decode_${test}_{tgsmall,fglarge} || exit 1 + ) || touch $dir/.error & + done + wait + [ -f $dir/.error ] && echo "$0: there was a problem while decoding" && exit 1 +fi + +exit 0; diff --git a/egs/librispeech/s5/local/nnet3/tuning/run_tdnn_1b.sh b/egs/librispeech/s5/local/nnet3/tuning/run_tdnn_1b.sh new file mode 100755 index 000000000..a96a1b33e --- /dev/null +++ b/egs/librispeech/s5/local/nnet3/tuning/run_tdnn_1b.sh @@ -0,0 +1,135 @@ +#!/bin/bash + +# 1b is as 1a but uses xconfigs. + +# this is the standard "tdnn" system, built in nnet3; it's what we use to +# call multi-splice. + +# without cleanup: +# local/nnet3/run_tdnn.sh --train-set train960 --gmm tri6b --nnet3-affix "" & + + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +decode_nj=30 +train_set=train_960_cleaned +gmm=tri6b_cleaned # this is the source gmm-dir for the data-type of interest; it + # should have alignments for the specified training data. +nnet3_affix=_cleaned + +# Options which are not passed through to run_ivector_common.sh +affix= +train_stage=-10 +common_egs_dir= +reporting_email= +remove_egs=true + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + fixed-affine-layer name=lda input=Append(-2,-1,0,1,2,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + relu-batchnorm-layer name=tdnn0 dim=1280 + relu-batchnorm-layer name=tdnn1 dim=1280 input=Append(-1,2) + relu-batchnorm-layer name=tdnn2 dim=1280 input=Append(-3,3) + relu-batchnorm-layer name=tdnn3 dim=1280 input=Append(-7,2) + relu-batchnorm-layer name=tdnn4 dim=1280 + output-layer name=output input=tdnn4 dim=$num_targets max-change=1.5 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs || exit 1; +fi + +if [ $stage -le 12 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/librispeech-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + steps/nnet3/train_dnn.py --stage=$train_stage \ + --cmd="$decode_cmd" \ + --feat.online-ivector-dir $train_ivector_dir \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --trainer.num-epochs 4 \ + --trainer.optimization.num-jobs-initial 3 \ + --trainer.optimization.num-jobs-final 16 \ + --trainer.optimization.initial-effective-lrate 0.0017 \ + --trainer.optimization.final-effective-lrate 0.00017 \ + --egs.dir "$common_egs_dir" \ + --cleanup.remove-egs $remove_egs \ + --cleanup.preserve-model-interval 100 \ + --feat-dir=$train_data_dir \ + --ali-dir $ali_dir \ + --lang data/lang \ + --reporting.email="$reporting_email" \ + --dir=$dir || exit 1; + +fi + +if [ $stage -le 13 ]; then + # this does offline decoding that should give about the same results as the + # real online decoding (the one with --per-utt true) + rm $dir/.error 2>/dev/null || true + for test in test_clean test_other dev_clean dev_other; do + ( + steps/nnet3/decode.sh --nj $decode_nj --cmd "$decode_cmd" \ + --online-ivector-dir exp/nnet3${nnet3_affix}/ivectors_${test}_hires \ + ${graph_dir} data/${test}_hires $dir/decode_${test}_tgsmall || exit 1 + steps/lmrescore.sh --cmd "$decode_cmd" data/lang_test_{tgsmall,tgmed} \ + data/${test}_hires $dir/decode_${test}_{tgsmall,tgmed} || exit 1 + steps/lmrescore_const_arpa.sh \ + --cmd "$decode_cmd" data/lang_test_{tgsmall,tglarge} \ + data/${test}_hires $dir/decode_${test}_{tgsmall,tglarge} || exit 1 + steps/lmrescore_const_arpa.sh \ + --cmd "$decode_cmd" data/lang_test_{tgsmall,fglarge} \ + data/${test}_hires $dir/decode_${test}_{tgsmall,fglarge} || exit 1 + ) || touch $dir/.error & + done + wait + [ -f $dir/.error ] && echo "$0: there was a problem while decoding" && exit 1 +fi + +exit 0; diff --git a/egs/mini_librispeech/s5/RESULTS b/egs/mini_librispeech/s5/RESULTS index 0b7471204..089b7c918 100755 --- a/egs/mini_librispeech/s5/RESULTS +++ b/egs/mini_librispeech/s5/RESULTS @@ -20,3 +20,7 @@ exit 0 %WER 18.58 [ 3742 / 20138, 366 ins, 763 del, 2613 sub ] exp/chain/tdnn1a_sp/decode_tgsmall_dev_clean_2/wer_10_0.0 %WER 13.35 [ 2689 / 20138, 318 ins, 491 del, 1880 sub ] exp/chain/tdnn1a_sp/decode_tglarge_dev_clean_2/wer_9_0.5 + +# Results with new chain recipe (based on chaina branch). Results are w/o final model combination +%WER 15.64 [ 3150 / 20138, 395 ins, 584 del, 2171 sub ] exp/chaina/tdnn2c_v4_sp/decode_tglarge_dev_clean_2//wer_11_0.0 +%WER 21.38 [ 4305 / 20138, 449 ins, 740 del, 3116 sub ] exp/chaina/tdnn2c_v4_sp/decode_tgsmall_dev_clean_2//wer_10_0.0 diff --git a/egs/mini_librispeech/s5/local/chain2/data_prep_common.sh b/egs/mini_librispeech/s5/local/chain2/data_prep_common.sh new file mode 100755 index 000000000..21b36cce4 --- /dev/null +++ b/egs/mini_librispeech/s5/local/chain2/data_prep_common.sh @@ -0,0 +1,78 @@ +#!/bin/bash +# Copyright 2019 Daniel Povey +# 2019 Srikanth Madikeri (Idiap Research Institute) + +set -euo pipefail + +# This script is called from local/chain/tuning/run_tdnn_2a.sh and +# similar scripts. It contains the common feature preparation and +# lattice-alignment preparation parts of the chaina training. +# See those scripts for examples of usage. + +stage=0 +train_set=train_clean_5 +test_sets="dev_clean_2" +gmm=tri3b + +. ./cmd.sh +. ./path.sh +. utils/parse_options.sh + +gmm_dir=exp/${gmm} +ali_dir=exp/${gmm}_ali_${train_set}_sp + +for f in data/${train_set}/feats.scp ${gmm_dir}/final.mdl; do + if [ ! -f $f ]; then + echo "$0: expected file $f to exist" + exit 1 + fi +done + +# Our default data augmentation method is 3-way speed augmentation followed by +# volume perturbation. We are looking into better ways of doing this, +# e.g. involving noise and reverberation. + +if [ $stage -le 1 ]; then + # Although the nnet will be trained by high resolution data, we still have to + # perturb the normal data to get the alignment. _sp stands for speed-perturbed + echo "$0: preparing directory for low-resolution speed-perturbed data (for alignment)" + utils/data/perturb_data_dir_speed_3way.sh data/${train_set} data/${train_set}_sp + echo "$0: making MFCC features for low-resolution speed-perturbed data" + steps/make_mfcc.sh --cmd "$train_cmd" --nj 10 data/${train_set}_sp || exit 1; + steps/compute_cmvn_stats.sh data/${train_set}_sp || exit 1; + utils/fix_data_dir.sh data/${train_set}_sp +fi + +if [ $stage -le 2 ]; then + echo "$0: aligning with the perturbed low-resolution data" + steps/align_fmllr.sh --nj 20 --cmd "$train_cmd" \ + data/${train_set}_sp data/lang $gmm_dir $ali_dir || exit 1 +fi + +if [ $stage -le 3 ]; then + # Create high-resolution MFCC features (with 40 cepstra instead of 13). + # this shows how you can split across multiple file-systems. + echo "$0: creating high-resolution MFCC features" + mfccdir=data/${train_set}_sp_hires/data + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl /export/fs0{1,2}/$USER/kaldi-data/mfcc/mini_librispeech-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage + fi + + for datadir in ${train_set}_sp ${test_sets}; do + utils/copy_data_dir.sh data/$datadir data/${datadir}_hires + done + + # do volume-perturbation on the training data prior to extracting hires + # features; this helps make trained nnets more invariant to test data volume. + utils/data/perturb_data_dir_volume.sh data/${train_set}_sp_hires || exit 1; + + for datadir in ${train_set}_sp ${test_sets}; do + steps/make_mfcc.sh --nj 10 --mfcc-config conf/mfcc_hires.conf \ + --cmd "$train_cmd" data/${datadir}_hires || exit 1; + steps/compute_cmvn_stats.sh data/${datadir}_hires || exit 1; + utils/fix_data_dir.sh data/${datadir}_hires || exit 1; + done +fi + + +exit 0 diff --git a/egs/mini_librispeech/s5/local/chain2/run_tdnn.sh b/egs/mini_librispeech/s5/local/chain2/run_tdnn.sh new file mode 120000 index 000000000..344993628 --- /dev/null +++ b/egs/mini_librispeech/s5/local/chain2/run_tdnn.sh @@ -0,0 +1 @@ +tuning/run_tdnn_1a.sh \ No newline at end of file diff --git a/egs/mini_librispeech/s5/local/chain2/tuning/run_tdnn_1a.sh b/egs/mini_librispeech/s5/local/chain2/tuning/run_tdnn_1a.sh new file mode 100755 index 000000000..55141424a --- /dev/null +++ b/egs/mini_librispeech/s5/local/chain2/tuning/run_tdnn_1a.sh @@ -0,0 +1,332 @@ +#!/bin/bash + +# Copyright 2019 Srikanth Madikeri (Idiap Research Institute) +# +# This script is a modification of local/chain/run_tdnn.sh adapted to the chain2 recipes. + +# Set -e here so that we catch if any executable fails immediately +set -euo pipefail + +# First the options that are passed through to run_ivector_common.sh +# (some of which are also used in this script directly). +stage=0 +decode_nj=10 +train_set=train_clean_5 +test_sets=dev_clean_2 +gmm=tri3b +srand=0 +nnet3_affix= + +# The rest are configs specific to this script. Most of the parameters +# are just hardcoded at this level, in the commands below. +affix=2c # affix for the TDNN directory name +tree_affix= +train_stage=-10 +get_egs_stage=-10 + + +# training chunk-options +chunk_width=140 +dropout_schedule='0,0@0.20,0.3@0.50,0' +xent_regularize=0.1 +bottom_subsampling_factor=1 # I'll set this to 3 later, 1 is for compatibility with a broken ru. +frame_subsampling_factor=3 +langs="default" # list of language names + +# The amount of extra left/right context we put in the egs. Note: this could +# easily be zero, since we're not using a recurrent topology, but we put in a +# little extra context so that we have more room to play with the configuration +# without re-dumping egs. +egs_extra_left_context=5 +egs_extra_right_context=5 + +# The number of chunks (of length: see $chunk_width above) that we group +# together for each "speaker" (actually: pseudo-speaker, since we may have +# to group multiple speaker together in some cases). +chunks_per_group=4 + + +# End configuration section. +echo "$0 $@" # Print the command line for logging + +. ./cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +# if ! cuda-compiled; then +# cat <$lang/topo + fi +fi + +if [ $stage -le 11 ]; then + # Get the alignments as lattices (gives the chain training more freedom). + # use the same num-jobs as the alignments + steps/align_fmllr_lats.sh --nj 75 --cmd "$train_cmd" ${lores_train_data_dir} \ + data/lang $gmm_dir $lat_dir + rm $lat_dir/fsts.*.gz # save space +fi + +if [ $stage -le 12 ]; then + # Build a tree using our new topology. We know we have alignments for the + # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use + # those. The num-leaves is always somewhat less than the num-leaves from + # the GMM baseline. + # This will be a two-level tree (with the smaller number of leaves specified + # by the '--num-clusters' option); this is needed by the adaptation framework + # search below for 'tree.map' + if [ -f $tree_dir/final.mdl ]; then + echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it." + exit 1; + fi + steps/nnet3/chain/build_tree.sh \ + --frame-subsampling-factor ${frame_subsampling_factor} \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 3500 ${lores_train_data_dir} \ + $lang $ali_dir $tree_dir +fi + + +# $dir/configs will contain xconfig and config files for the initial +# models. It's a scratch space used by this script but not by +# scripts called from here. +mkdir -p $dir/configs/ +# $dir/init will contain the initial models +mkdir -p $dir/init/ + +learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python) + +if [ $stage -le 14 ]; then + + # Note: we'll use --bottom-subsampling-factor=3, so all time-strides for the + # top network should be interpreted at the 30ms frame subsampling rate. + num_leaves=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}') + + echo "$0: creating top model" + cat < $dir/configs/default.xconfig + input name=input dim=40 + # the first splicing is moved before the lda layer, so no splicing here + fixed-affine-layer name=lda input=Append(-2,-1,0,1,2) affine-transform-file=$dir/configs/lda.mat + relu-renorm-layer name=tdnn1 dim=512 input=Append(-2,-1,0,1,2) + relu-renorm-layer name=tdnn2 dim=512 input=Append(-1,0,1) + relu-renorm-layer name=tdnn3 dim=512 input=Append(-1,0,1) + relu-renorm-layer name=tdnn4 dim=512 input=Append(-3,0,3) + relu-renorm-layer name=tdnn5 dim=512 input=Append(-3,0,3) + relu-renorm-layer name=tdnn6 dim=512 input=Append(-6,-3,0) + relu-renorm-layer name=prefinal-chain dim=512 target-rms=0.5 + output-layer name=output include-log-softmax=false dim=$num_leaves max-change=1.5 + output-layer name=output-default input=prefinal-chain include-log-softmax=false dim=$num_leaves max-change=1.5 + relu-renorm-layer name=prefinal-xent input=tdnn6 dim=512 target-rms=0.5 + output-layer name=output-xent dim=$num_leaves learning-rate-factor=$learning_rate_factor max-change=1.5 + output-layer name=output-default-xent input=prefinal-xent dim=$num_leaves learning-rate-factor=$learning_rate_factor max-change=1.5 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/default.xconfig --config-dir $dir/configs/ + if [ $dir/init/default_trans.mdl ]; then # checking this because it may have been copied in a previous run of the same script + copy-transition-model $tree_dir/final.mdl $dir/init/default_trans.mdl || exit 1 & + else + echo "Keeping the old $dir/init/default_trans.mdl as it already exists." + fi +fi +wait; + +init_info=$dir/init/info.txt +if [ $stage -le 15 ]; then + + if [ ! -f $dir/configs/ref.raw ]; then + echo "Expected $dir/configs/ref.raw to exist" + exit + fi + + nnet3-info $dir/configs/ref.raw > $dir/configs/temp.info + model_left_context=`fgrep 'left-context' $dir/configs/temp.info | awk '{print $2}'` + model_right_context=`fgrep 'right-context' $dir/configs/temp.info | awk '{print $2}'` + cat >$init_info <$lang/topo +fi + +if [ $stage -le 11 ]; then + # Build a tree using our new topology. This is the critically different + # step compared with other recipes. + steps/nnet3/chain/build_tree.sh --frame-subsampling-factor $frame_subsampling_factor \ + --context-opts "--context-width=2 --central-position=1" \ + --cmd "$train_cmd" 7000 data/$train_set $lang $ali_dir $treedir +fi + +if [ $stage -le 12 ]; then + echo "$0: creating neural net configs using the xconfig parser"; + + num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}') + learning_rate_factor=$(echo "print (0.5/$xent_regularize)" | python) + + mkdir -p $dir/configs + cat < $dir/configs/network.xconfig + input dim=100 name=ivector + input dim=40 name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat + + # the first splicing is moved before the lda layer, so no splicing here + relu-batchnorm-layer name=tdnn1 dim=625 + relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=625 + relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1) dim=625 + relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=625 + relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=625 + relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=625 + relu-batchnorm-layer name=tdnn7 input=Append(-3,0,3) dim=625 + + ## adding the layers for chain branch + relu-batchnorm-layer name=prefinal-chain input=tdnn7 dim=625 target-rms=0.5 + output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5 + output-layer name=output-default input=prefinal-chain include-log-softmax=false dim=$num_targets max-change=1.5 + + # adding the layers for xent branch + # This block prints the configs for a separate output that will be + # trained with a cross-entropy objective in the 'chain' models... this + # has the effect of regularizing the hidden parts of the model. we use + # 0.5 / args.xent_regularize as the learning rate factor- the factor of + # 0.5 / args.xent_regularize is suitable as it means the xent + # final-layer learns at a rate independent of the regularization + # constant; and the 0.5 was tuned so as to make the relative progress + # similar in the xent and regular final layers. + relu-batchnorm-layer name=prefinal-xent input=tdnn7 dim=625 target-rms=0.5 + output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + output-layer name=output-default-xent input=prefinal-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5 + +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/ + if [ ! -f $dir/init/default_trans.mdl ]; then # checking this because it may have been copied in a previous run of the same script + copy-transition-model $treedir/final.mdl $dir/init/default_trans.mdl || exit 1 & + else + echo "Keeping the old $dir/init/default_trans.mdl as it already exists." + fi +fi + +init_info=$dir/init/info.txt +if [ $stage -le 13 ]; then + + if [ ! -f $dir/configs/ref.raw ]; then + echo "Expected $dir/configs/ref.raw to exist" + exit + fi + + mkdir -p $dir/init + nnet3-info $dir/configs/ref.raw > $dir/configs/temp.info + model_left_context=`fgrep 'left-context' $dir/configs/temp.info | awk '{print $2}'` + model_right_context=`fgrep 'right-context' $dir/configs/temp.info | awk '{print $2}'` + cat >$init_info </dev/null || true + for decode_set in eval2000; do + ( + decode_nj=`wc -l data/${decode_set}_hires/spk2utt | cut -d ' ' -f1` + steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \ + --nj $decode_nj --cmd "$decode_cmd" $iter_opts \ + --online-ivector-dir exp/nnet3/ivectors_${decode_set} \ + $graph_dir data/${decode_set}_hires \ + $dir/decode_${decode_set}${decode_iter:+_$decode_iter}${decode_suff}_sw1_tg || exit 1; + if $has_fisher; then + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + data/lang_sw1_{tg,fsh_fg} data/${decode_set}_hires \ + $dir/decode_${decode_set}${decode_iter:+_$decode_iter}${decode_suff}_sw1_{tg,fsh_fg} || exit 1; + fi + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in decoding" + exit 1 + fi +fi + +if $test_online_decoding && [ $stage -le 16 ]; then + # note: if the features change (e.g. you add pitch features), you will have to + # change the options of the following command line. + steps/online/nnet3/prepare_online_decoding.sh \ + --mfcc-config conf/mfcc_hires.conf \ + $lang exp/nnet3/extractor $dir ${dir}_online + + rm $dir/.error 2>/dev/null || true + for decode_set in train_dev eval2000; do + ( + # note: we just give it "$decode_set" as it only uses the wav.scp, the + # feature type does not matter. + + steps/online/nnet3/decode.sh --nj $decode_nj --cmd "$decode_cmd" \ + --acwt 1.0 --post-decode-acwt 10.0 \ + $graph_dir data/${decode_set}_hires \ + ${dir}_online/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_tg || exit 1; + if $has_fisher; then + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \ + data/lang_sw1_{tg,fsh_fg} data/${decode_set}_hires \ + ${dir}_online/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_{tg,fsh_fg} || exit 1; + fi + ) || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: something went wrong in decoding" + exit 1 + fi +fi + + +exit 0; + diff --git a/egs/wsj/s5/steps/chain2 b/egs/wsj/s5/steps/chain2 new file mode 120000 index 000000000..cf32f4266 --- /dev/null +++ b/egs/wsj/s5/steps/chain2 @@ -0,0 +1 @@ +nnet3/chain2 \ No newline at end of file diff --git a/egs/wsj/s5/steps/combine_ali_dirs.sh b/egs/wsj/s5/steps/combine_ali_dirs.sh index b74b004ca..39f2ff2b3 100755 --- a/egs/wsj/s5/steps/combine_ali_dirs.sh +++ b/egs/wsj/s5/steps/combine_ali_dirs.sh @@ -166,10 +166,13 @@ do_combine() { # Merge (presumed already sorted) scp's into a single script. sort -m $temp_dir/$ark.*.scp > $temp_dir/$ark.scp || exit 1 + inputs=$(for n in `seq $nj`; do echo $temp_dir/$ark.$n.scp; done) + utils/split_scp.pl --utt2spk=$data/utt2spk $temp_dir/$ark.scp $inputs + echo "$0: Splitting combined $entities into $nj archives on speaker boundary." $cmd JOB=1:$nj $dest/log/chop_combined_$entities.JOB.log \ $copy_program \ - "scp:utils/split_scp.pl --utt2spk=$data/utt2spk --one-based -j $nj JOB $temp_dir/$ark.scp |" \ + "scp:$temp_dir/$ark.JOB.scp" \ "ark:| gzip -c > $dest/$ark.JOB.gz" || exit 1 # Get some interesting stats, and signal an error if error threshold exceeded. diff --git a/egs/wsj/s5/steps/combine_trans_dirs.sh b/egs/wsj/s5/steps/combine_trans_dirs.sh new file mode 100644 index 000000000..2edb9b9f6 --- /dev/null +++ b/egs/wsj/s5/steps/combine_trans_dirs.sh @@ -0,0 +1,133 @@ +#!/bin/bash +# Copyright 2016 Xiaohui Zhang Apache 2.0. +# Copyright 2019 SmartAction (kkm) +# Copyright 2019 manhong wang (marvin) + +# This script only combines transform file in the aligments dirs, egs: trans.1, and +# validates matching of the utterances and alignments after combining. you would need this fmllr trans +# files after you combine ali or lat dirs(combine_ali_dirs.sh or combine_lat_dis.sh). + +# Begin configuration section. +cmd=run.pl +tolerance=10 +# End configuration section. +echo "$0 $@" # Print the command line for logging. + +[[ -f path.sh ]] && . ./path.sh +. parse_options.sh || exit 1 + +export LC_ALL=C + +if [[ $# -lt 3 ]]; then + cat >&2 < ... + e.g.: $0 data/train exp/tri3_trans_combined exp/tri3_trans_1 exp_tri3_trans_2 +Options: + --tolerance # maximum percentage of missing trans + # w.r.t. total utterances in before error is + # reported [10] + +Note:we do not checks that certain important files are present and compatible in all +source directories (phones.txt, tree) here.Because you would run combine_trans_dirs.sh +or combine_lat_dis.sh first. + +EOF + exit 1; +fi + + +data=$1 +dest=$2 +shift 2 +first_src=$1 + +do_trans=true + + +# All checks passed, ok to prepare directory. but we do not Copy model and other files from +# the first source. + +for src in $@; do + if [[ "$(cd 2>/dev/null -P -- "$src" && pwd)" = \ + "$(cd 2>/dev/null -P -- "$dest" && pwd)" ]]; then + echo "$0: error: Source $src is same as target $dest." + exit 1 + fi + if $do_trans && [[ ! -f $src/trans.1 ]]; then + echo "$0: warning: transform (trans.*) are not present in $src, not" \ + "combining. please check you files" + exit 1 + fi +done + +if [ ! -f $dest/ali.1.gz ] && [ ! -f $dest/lat.1.gz ] ; then + echo "$0: warning: we assume you have combined the ali or lat dirs " \ + "please run combine_ali_dir.sh or combine_lat_dir.sh firstly" + exit 1 +fi + +nj=$(cat $dest/num_jobs) + +if [ -f $dest/trans.1 ] ; then rm $dest/trans.* ;fi #remove old trans.* + +# Make temporary directory, delete on signal, but not on 'exit 1'. +temp_dir=$(mktemp -d $dest/temp.XXXXXX) || exit 1 +cleanup() { rm -rf "$temp_dir"; } +trap cleanup HUP INT TERM +echo "$0: note: Temporary directory $temp_dir will not be deleted in case of" \ + "script failure, so you could examine it for troubleshooting." + +do_combine_trans() { + local ark=$1 entities=$2 copy_program=$3 + shift 3 + + echo "$0: Gathering $entities from each source directory." + # Assign all source gzipped archive names to an exported variable, one each + # per source directory, so that we can copy archives in a job per source. + src_id=0 + for src in $@; do + src_id=$((src_id + 1)) + nj_src=$(cat $src/num_jobs) || exit 1 + # Create and export variable src_arcs_${src_id} for the job runner. + # Each numbered variable will contain the list of archives, e. g.: + # src_arcs_1="exp/tri3_ali/trans.1 exp/tri3_ali/trans.1 ..." + # ('printf' repeats its format as long as there are more arguments). + printf "$src/$ark.%d " $(seq $nj_src) > $temp_dir/src_arks.${src_id} + done + + # Gather archives in parallel jobs. + $cmd JOB=1:$src_id $dest/log/gather_$entities.JOB.log \ + $copy_program \ + "ark:cat \$(cat $temp_dir/src_arks.JOB) |" \ + "ark,scp:$temp_dir/$ark.JOB,$temp_dir/$ark.JOB.scp" || exit 1 + + # Merge (presumed already sorted) scp's into a single script. + sort -m $temp_dir/$ark.*.scp > $temp_dir/$ark.scp || exit 1 + + echo "$0: Splitting combined $entities into $nj archives on speaker boundary." + $cmd JOB=1:$nj $dest/log/chop_combined_$entities.JOB.log \ + $copy_program \ + "scp:utils/split_scp.pl -j $nj JOB --one-based $temp_dir/$ark.scp |" \ + "ark:$dest/$ark.JOB" || exit 1 + + # Get some interesting stats. + n_utt=$(wc -l <$data/spk2utt) + n_trans=$(wc -l <$temp_dir/$ark.scp) + n_utt_no_trans_pct=$(perl -e "print int(($n_utt - $n_trans)/$n_utt * 100 + .5);") + echo "$0: Combined $n_trans $entities for $n_utt utterances." + + if (( $n_utt_no_trans_pct >= $tolerance )); then + echo "$0: error: Percentage of utterances missing $entities," \ + "${n_utt_no_trans_pct}%, is at or above error tolerance ${tolerance}%." + exit 1 + fi + + return 0 +} + +$do_trans && do_combine_trans trans 'transforms' copy-matrix "$@" + +cleanup # Delete the temporary directory on success. + +echo "$0: Stored combined fmllr trans in $dest" +exit 0 diff --git a/egs/wsj/s5/steps/copy_lat_dir.sh b/egs/wsj/s5/steps/copy_lat_dir.sh index dd1e10fb3..67b2a6638 100755 --- a/egs/wsj/s5/steps/copy_lat_dir.sh +++ b/egs/wsj/s5/steps/copy_lat_dir.sh @@ -69,6 +69,6 @@ rm $dir/lat_tmp.* echo $nj > $dir/num_jobs -for f in cmvn_opts splice_opts final.mdl splice_opts tree frame_subsampling_factor; do +for f in phones.txt cmvn_opts splice_opts final.mdl splice_opts tree frame_subsampling_factor; do if [ -f $src_dir/$f ]; then cp $src_dir/$f $dir/$f; fi done diff --git a/egs/wsj/s5/steps/copy_trans_dir.sh b/egs/wsj/s5/steps/copy_trans_dir.sh new file mode 100644 index 000000000..3344d749e --- /dev/null +++ b/egs/wsj/s5/steps/copy_trans_dir.sh @@ -0,0 +1,80 @@ +#!/bin/bash +# Copyright 2019 Phani Sankar Nidadavolu +# Copyright 2019 manhong wang(marvin) +# Apache 2.0. + +#This script creates fmllr transform for the aug dirs by copying +#the trans of original train dir after you copy_ali_dirs.sh or copy_lat_dirs.sh +#Note : wo do not accept --nj here ,which shoud keep same as ali file +prefixes="reverb1 babble music noise" +include_original=true +cmd=run.pl +write_binary=true + +. ./path.sh +. utils/parse_options.sh + +if [ $# -ne 3 ]; then + echo "Usage: $0 " + echo "This script creates fmllr transform for the aug dirs by copying " + echo " the trans of original train dir" + echo "While copying it adds prefix to the utterances specified by prefixes option" + echo "Note that the original train dir does not have any prefix" + echo "To include the original training directory in the copied " + echo "version set the --include-original option to true" + echo "main options (for others, see top of script file)" + echo " --prefixes # All the prefixes of aug data to be included" + echo " --include-original # If true, will copy the alignements of original dir" + exit 1 +fi + +data=$1 +src_dir=$2 +dir=$3 + +if [ ! -d $dir ]; then + echo "$0: warning : you may need combine ali or lat first !" && exit 1 +fi + +if [ ! -f $src_dir/trans.1 ] ; then + echo "$0: no trans exist in $src_dir dir" && exit 1 +fi + + +nj=$(cat $dir/num_jobs) +rm -f $dir/trans* 2>/dev/null + +# Copy the fmllr trans temporarily +echo "creating temporary trans in $dir" +$cmd JOB=1:$nj $dir/log/copy_trans_temp.JOB.log \ + copy-matrix --binary=$write_binary \ + "ark:cat $src_dir/trans.JOB |" \ + ark,scp:$dir/trans_tmp.JOB.ark,$dir/trans_tmp.JOB.scp || exit 1 + +# Make copies of utterances for perturbed data +for p in $prefixes; do + cat $dir/trans_tmp.*.scp | awk -v p=$p '{print p"-"$0}' +done | sort -k1,1 > $dir/trans_out.scp.aug + +if [ "$include_original" == "true" ]; then + cat $dir/trans_tmp.*.scp | awk '{print $0}' | sort -k1,1 > $dir/trans_out.scp.clean + cat $dir/trans_out.scp.clean $dir/trans_out.scp.aug | sort -k1,1 > $dir/trans_out.scp +else + cat $dir/trans_out.scp.aug | sort -k1,1 > $dir/trans_out.scp.old +fi + +utils/filter_scp.pl ${data}/spk2utt $dir/trans_out.scp.old > $dir/trans_out.scp +utils/split_data.sh ${data} $nj + +# Copy and dump the trans for perturbed data +echo Creating fmllr trans for augmented data by copying fmllr trans from clean data +$cmd JOB=1:$nj $dir/log/copy_out_trans.JOB.log \ + copy-matrix --binary=$write_binary \ + "scp:utils/split_scp.pl --one-based -j $nj JOB $dir/trans_out.scp |" \ + ark:$dir/trans.JOB || exit 1 + +n_aug_trans=`wc -l $data/spk2utt` +n_copy_trans=`wc -l $dir/trans_out.scp` +echo "copy $n_copy_trans speaker's fmllr trans of total $n_aug_trans" +rm $dir/trans_out.scp.aug $dir/trans_out.scp.old $dir/trans_out.scp $dir/trans_tmp.* +exit 0 diff --git a/egs/wsj/s5/steps/data/reverberate_data_dir.py b/egs/wsj/s5/steps/data/reverberate_data_dir.py index 4ea44aad9..ea504244d 100755 --- a/egs/wsj/s5/steps/data/reverberate_data_dir.py +++ b/egs/wsj/s5/steps/data/reverberate_data_dir.py @@ -136,7 +136,10 @@ def pick_item_with_probability(x): collection (list or dictionary) where the values contain a field called probability """ if isinstance(x, dict): - plist = list(set(x.values())) + keylist = list(x.keys()) + keylist.sort() + random.shuffle(keylist) + plist = [x[k] for k in keylist] else: plist = x total_p = sum(item.probability for item in plist) diff --git a/egs/wsj/s5/steps/diagnostic/analyze_lattice_depth_stats.py b/egs/wsj/s5/steps/diagnostic/analyze_lattice_depth_stats.py index 6ed2bf781..8ae5e1ef6 100755 --- a/egs/wsj/s5/steps/diagnostic/analyze_lattice_depth_stats.py +++ b/egs/wsj/s5/steps/diagnostic/analyze_lattice_depth_stats.py @@ -9,6 +9,15 @@ import argparse import sys, os from collections import defaultdict +from io import open +import codecs + +# reference: http://www.macfreek.nl/memory/Encoding_of_Python_stdout +if sys.version_info.major == 2: + sys.stdout = codecs.getwriter('utf-8')(sys.stdout, 'strict') +else: + assert sys.version_info.major == 3 + sys.stdout = codecs.getwriter('utf-8')(sys.stdout.buffer, 'strict') parser = argparse.ArgumentParser(description="This script reads stats created in analyze_lats.sh " @@ -29,13 +38,13 @@ # set up phone_int2text to map from phone to printed form. phone_int2text = {} try: - f = open(args.lang + "/phones.txt", "r"); + f = open(args.lang + "/phones.txt", "r", encoding='utf-8') for line in f.readlines(): [ word, number] = line.split() phone_int2text[int(number)] = word f.close() except: - sys.exit("analyze_lattice_depth_stats.py: error opening or reading {0}/phones.txt".format( + sys.exit(u"analyze_lattice_depth_stats.py: error opening or reading {0}/phones.txt".format( args.lang)) # this is a special case... for begin- and end-of-sentence stats, # we group all nonsilence phones together. @@ -49,14 +58,14 @@ # open lang/phones/silence.csl-- while there are many ways of obtaining the # silence/nonsilence phones, we read this because it's present in graph # directories as well as lang directories. - filename = "{0}/phones/silence.csl".format(args.lang) + filename = u"{0}/phones/silence.csl".format(args.lang) f = open(filename, "r") line = f.readline() for silence_phone in line.split(":"): nonsilence.remove(int(silence_phone)) f.close() except Exception as e: - sys.exit("analyze_lattice_depth_stats.py: error processing {0}/phones/silence.csl: {1}".format( + sys.exit(u"analyze_lattice_depth_stats.py: error processing {0}/phones/silence.csl: {1}".format( args.lang, str(e))) # phone_depth_counts is a dict of dicts. @@ -80,7 +89,7 @@ break a = line.split() if len(a) != 3: - sys.exit("analyze_lattice_depth_stats.py: reading stdin, could not interpret line: " + line) + sys.exit(u"analyze_lattice_depth_stats.py: reading stdin, could not interpret line: " + line) try: phone, depth, count = [ int(x) for x in a ] @@ -92,11 +101,11 @@ universal_phone = -1 phone_depth_counts[universal_phone][depth] += count except Exception as e: - sys.exit("analyze_lattice_depth_stats.py: unexpected phone {0} " - "seen (lang directory mismatch?): line is {1}, error is {2}".format(phone, line, str(e))) + sys.exit(u"analyze_lattice_depth_stats.py: unexpected phone {0} " + u"seen (lang directory mismatch?): line is {1}, error is {2}".format(phone, line, str(e))) if total_frames == 0: - sys.exit("analyze_lattice_depth_stats.py: read no input") + sys.exit(u"analyze_lattice_depth_stats.py: read no input") # If depth_to_count is a map from depth-in-frames to count, @@ -125,8 +134,8 @@ def GetMean(depth_to_count): return this_total_depth / this_total_frames -print("The total amount of data analyzed assuming 100 frames per second " - "is {0} hours".format("%.1f" % (total_frames / 360000.0))) +print(u"The total amount of data analyzed assuming 100 frames per second " + u"is {0} hours".format("%.1f" % (total_frames / 360000.0))) # the next block prints lines like (to give some examples): # Nonsilence phones as a group account for 74.4% of phone occurrences, with lattice depth (10,50,90-percentile)=(1,2,7) and mean=3.1 @@ -152,18 +161,18 @@ def GetMean(depth_to_count): try: phone_text = phone_int2text[phone] except: - sys.exit("analyze_lattice_depth_stats.py: phone {0} is not covered on phones.txt " - "(lang/alignment mismatch?)".format(phone)) - preamble = "Phone {phone_text} accounts for {percent}% of frames, with".format( + sys.exit(u"analyze_lattice_depth_stats.py: phone {0} is not covered on phones.txt " + u"(lang/alignment mismatch?)".format(phone)) + preamble = u"Phone {phone_text} accounts for {percent}% of frames, with".format( phone_text = phone_text, percent = "%.1f" % frequency_percentage) elif phone == 0: - preamble = "Nonsilence phones as a group account for {percent}% of frames, with".format( + preamble = u"Nonsilence phones as a group account for {percent}% of frames, with".format( percent = "%.1f" % frequency_percentage) else: assert phone == -1 preamble = "Overall,"; - print("{preamble} lattice depth (10,50,90-percentile)=({p10},{p50},{p90}) and mean={mean}".format( + print(u"{preamble} lattice depth (10,50,90-percentile)=({p10},{p50},{p90}) and mean={mean}".format( preamble = preamble, p10 = depth_percentile_10, p50 = depth_percentile_50, diff --git a/egs/wsj/s5/steps/diagnostic/analyze_phone_length_stats.py b/egs/wsj/s5/steps/diagnostic/analyze_phone_length_stats.py index 5ebd9e736..549c1875a 100755 --- a/egs/wsj/s5/steps/diagnostic/analyze_phone_length_stats.py +++ b/egs/wsj/s5/steps/diagnostic/analyze_phone_length_stats.py @@ -8,6 +8,15 @@ import argparse import sys, os from collections import defaultdict +from io import open +import codecs + +# reference: http://www.macfreek.nl/memory/Encoding_of_Python_stdout +if sys.version_info.major == 2: + sys.stdout = codecs.getwriter('utf-8')(sys.stdout, 'strict') +else: + assert sys.version_info.major == 3 + sys.stdout = codecs.getwriter('utf-8')(sys.stdout.buffer, 'strict') parser = argparse.ArgumentParser(description="This script reads stats created in analyze_alignments.sh " @@ -31,7 +40,7 @@ # set up phone_int2text to map from phone to printed form. phone_int2text = {} try: - f = open(args.lang + "/phones.txt", "r"); + f = open(args.lang + "/phones.txt", "r", encoding='utf-8') for line in f.readlines(): [ word, number] = line.split() phone_int2text[int(number)] = word @@ -112,8 +121,8 @@ optional_silence_phone_text = phone_int2text[optional_silence_phone] f.close() if optional_silence_phone in nonsilence: - print("analyze_phone_length_stats.py: was expecting the optional-silence phone to " - "be a member of the silence phones, it is not. This script won't work correctly.") + print(u"analyze_phone_length_stats.py: was expecting the optional-silence phone to " + u"be a member of the silence phones, it is not. This script won't work correctly.") except: largest_count = 0 optional_silence_phone = 1 @@ -124,8 +133,8 @@ largest_count = this_count optional_silence_phone = p optional_silence_phone_text = phone_int2text[optional_silence_phone] - print("analyze_phone_length_stats.py: could not get optional-silence phone from " - "{0}/phones/optional_silence.int, guessing that it's {1} from the stats. ".format( + print(u"analyze_phone_length_stats.py: could not get optional-silence phone from " + u"{0}/phones/optional_silence.int, guessing that it's {1} from the stats. ".format( args.lang, optional_silence_phone_text)) @@ -175,8 +184,8 @@ def GetMean(length_to_count): # maybe half a second. If your database is not like this, you should know; # you may want to mess with the segmentation to add more silence. if frequency_percentage < 80.0: - print("analyze_phone_length_stats.py: WARNING: optional-silence {0} is seen only {1}% " - "of the time at utterance {2}. This may not be optimal.".format( + print(u"analyze_phone_length_stats.py: WARNING: optional-silence {0} is seen only {1}% " + u"of the time at utterance {2}. This may not be optimal.".format( optional_silence_phone_text, frequency_percentage, boundary_type)) @@ -213,8 +222,8 @@ def GetMean(length_to_count): except: sys.exit("analyze_phone_length_stats.py: phone {0} is not covered on phones.txt " "(lang/alignment mismatch?)".format(phone)) - print("{text}, {phone_text} accounts for {percent}% of phone occurrences, with " - "duration (median, mean, 95-percentile) is ({median},{mean},{percentile95}) frames.".format( + print(u"{text}, {phone_text} accounts for {percent}% of phone occurrences, with " + u"duration (median, mean, 95-percentile) is ({median},{mean},{percentile95}) frames.".format( text = text, phone_text = phone_text, percent = "%.1f" % frequency_percentage, median = duration_median, mean = "%.1f" % duration_mean, @@ -245,16 +254,16 @@ def GetMean(length_to_count): opt_sil_total_frame_percent = total_optsil_frames * 100.0 / total_frames['all'] internal_frame_percent = total_frames['internal'] * 100.0 / total_frames['all'] - print("The optional-silence phone {0} occupies {1}% of frames overall ".format( + print(u"The optional-silence phone {0} occupies {1}% of frames overall ".format( optional_silence_phone_text, "%.1f" % opt_sil_total_frame_percent)) hours_total = total_frames['all'] / 360000.0; hours_nonsil = (total_frames['all'] - total_optsil_frames) / 360000.0 - print("Limiting the stats to the {0}% of frames not covered by an utterance-[begin/end] phone, " - "optional-silence {1} occupies {2}% of frames.".format("%.1f" % internal_frame_percent, + print(u"Limiting the stats to the {0}% of frames not covered by an utterance-[begin/end] phone, " + u"optional-silence {1} occupies {2}% of frames.".format("%.1f" % internal_frame_percent, optional_silence_phone_text, "%.1f" % opt_sil_internal_frame_percent)) - print("Assuming 100 frames per second, the alignments represent {0} hours of data, " - "or {1} hours if {2} frames are excluded.".format( + print(u"Assuming 100 frames per second, the alignments represent {0} hours of data, " + u"or {1} hours if {2} frames are excluded.".format( "%.1f" % hours_total, "%.1f" % hours_nonsil, optional_silence_phone_text)) opt_sil_internal_phone_percent = (sum(internal_opt_sil_phone_lengths.values()) * @@ -262,7 +271,7 @@ def GetMean(length_to_count): duration_median = GetPercentile(internal_opt_sil_phone_lengths, 0.5) duration_mean = GetMean(internal_opt_sil_phone_lengths) duration_percentile_95 = GetPercentile(internal_opt_sil_phone_lengths, 0.95) - print("Utterance-internal optional-silences {0} comprise {1}% of utterance-internal phones, with duration " - "(median, mean, 95-percentile) = ({2},{3},{4})".format( + print(u"Utterance-internal optional-silences {0} comprise {1}% of utterance-internal phones, with duration " + u"(median, mean, 95-percentile) = ({2},{3},{4})".format( optional_silence_phone_text, "%.1f" % opt_sil_internal_phone_percent, duration_median, "%0.1f" % duration_mean, duration_percentile_95)) diff --git a/egs/wsj/s5/steps/libs/nnet3/train/dropout_schedule.py b/egs/wsj/s5/steps/libs/nnet3/train/dropout_schedule.py index 0de907451..af641237f 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/dropout_schedule.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/dropout_schedule.py @@ -223,6 +223,45 @@ def _get_dropout_proportions(dropout_schedule, data_fraction): component_dropout_schedule, data_fraction))) return dropout_proportions +def get_dropout_edit_option(dropout_schedule, data_fraction, iter_): + """Return an option to be passed to nnet3-copy (or nnet3-am-copy) + that will set the appropriate dropout proportion. If no dropout + is being used (dropout_schedule is None), returns the empty + string, otherwise returns something like + "--edits='set-dropout-proportion name=* proportion=0.625'" + Arguments: + dropout_schedule: Value for the --trainer.dropout-schedule option. + See help for --trainer.dropout-schedule. + See _self_test() for examples. + data_fraction: real number in [0,1] that says how far along + in training we are. + iter_: iteration number (needed for debug printing only) + See ReadEditConfig() in nnet3/nnet-utils.h to see how + set-dropout-proportion directive works. + """ + + if dropout_schedule is None: + return "" + + dropout_proportions = _get_dropout_proportions( + dropout_schedule, data_fraction) + + edit_config_lines = [] + dropout_info = [] + + for component_name, dropout_proportion in dropout_proportions: + edit_config_lines.append( + "set-dropout-proportion name={0} proportion={1}".format( + component_name, dropout_proportion)) + dropout_info.append("pattern/dropout-proportion={0}/{1}".format( + component_name, dropout_proportion)) + + if _debug_dropout: + logger.info("On iteration %d, %s", iter_, ', '.join(dropout_info)) + + return "--edits='{0}'".format(";".join(edit_config_lines)) + + def get_dropout_edit_string(dropout_schedule, data_fraction, iter_): """Return an nnet3-copy --edits line to modify raw_model_string to diff --git a/egs/wsj/s5/steps/nnet3/chain/align_lats.sh b/egs/wsj/s5/steps/nnet3/chain/align_lats.sh new file mode 100644 index 000000000..a8c169429 --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/chain/align_lats.sh @@ -0,0 +1,149 @@ +#!/bin/bash +# Copyright 2012 Brno University of Technology (Author: Karel Vesely) +# 2013 Johns Hopkins University (Author: Daniel Povey) +# 2015 Vijayaditya Peddinti +# 2016 Vimal Manohar +# 2017 Pegah Ghahremani +# Apache 2.0 + +# Computes training alignments using nnet3 DNN, with output to lattices. + +# Begin configuration section. +nj=4 +cmd=run.pl +stage=-1 +# Begin configuration. +scale_opts="--transition-scale=1.0 --self-loop-scale=1.0" +acoustic_scale=1.0 +post_decode_acwt=10.0 +beam=20 +iter=final +frames_per_chunk=50 +extra_left_context=0 +extra_right_context=0 +extra_left_context_initial=-1 +extra_right_context_final=-1 +online_ivector_dir= +graphs_scp= +# End configuration options. + +echo "$0 $@" # Print the command line for logging + +[ -f path.sh ] && . ./path.sh # source the path. +. parse_options.sh || exit 1; + +if [ $# != 4 ]; then + echo "Usage: $0 " + echo "e.g.: $0 data/train data/lang exp/nnet4 exp/nnet4_ali" + echo "main options (for others, see top of script file)" + echo " --config # config containing options" + echo " --nj # number of parallel jobs" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + exit 1; +fi + +data=$1 +lang=$2 +srcdir=$3 +dir=$4 + +oov=`cat $lang/oov.int` || exit 1; +mkdir -p $dir/log +echo $nj > $dir/num_jobs +sdata=$data/split${nj} +[[ -d $sdata && $data/feats.scp -ot $sdata ]] || \ + split_data.sh $data $nj || exit 1; + +extra_files= +if [ ! -z "$online_ivector_dir" ]; then + steps/nnet2/check_ivectors_compatible.sh $srcdir $online_ivector_dir || exit 1 + extra_files="$online_ivector_dir/ivector_online.scp $online_ivector_dir/ivector_period" +fi + +for f in $srcdir/tree $srcdir/${iter}.mdl $data/feats.scp $lang/L.fst $extra_files; do + [ ! -f $f ] && echo "$0: no such file $f" && exit 1; +done + +cp $srcdir/{tree,${iter}.mdl} $dir || exit 1; + +utils/lang/check_phones_compatible.sh $lang/phones.txt $srcdir/phones.txt || exit 1; +cp $lang/phones.txt $dir || exit 1; +## Set up features. Note: these are different from the normal features +## because we have one rspecifier that has the features for the entire +## training set, not separate ones for each batch. +echo "$0: feature type is raw" + +cmvn_opts=`cat $srcdir/cmvn_opts 2>/dev/null` +cp $srcdir/cmvn_opts $dir 2>/dev/null + +feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- |" + +ivector_opts= +if [ ! -z "$online_ivector_dir" ]; then + ivector_period=$(cat $online_ivector_dir/ivector_period) || exit 1; + ivector_opts="--online-ivectors=scp:$online_ivector_dir/ivector_online.scp --online-ivector-period=$ivector_period" +fi + +echo "$0: aligning data in $data using model from $srcdir, putting alignments in $dir" + +frame_subsampling_opt= +if [ -f $srcdir/frame_subsampling_factor ]; then + # e.g. for 'chain' systems + frame_subsampling_factor=$(cat $srcdir/frame_subsampling_factor) + frame_subsampling_opt="--frame-subsampling-factor=$frame_subsampling_factor" + cp $srcdir/frame_subsampling_factor $dir + if [ "$frame_subsampling_factor" -gt 1 ] && \ + [ "$scale_opts" == "--transition-scale=1.0 --self-loop-scale=0.1" ]; then + echo "$0: frame-subsampling-factor is not 1 (so likely a chain system)," + echo "... but the scale opts are the defaults. You probably want" + echo "--scale-opts '--transition-scale=1.0 --self-loop-scale=1.0'" + sleep 1 + fi +fi + +if [ ! -z "$graphs_scp" ]; then + if [ ! -f $graphs_scp ]; then + echo "Could not find graphs $graphs_scp" && exit 1 + fi + tra="scp:utils/filter_scp.pl $sdata/JOB/utt2spk $graphs_scp |" + prog=compile-train-graphs-fsts +else + tra="ark:utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt $sdata/JOB/text|"; + prog=compile-train-graphs +fi + +if [ $stage -le 0 ]; then + ## because nnet3-latgen-faster doesn't support adding the transition-probs to the + ## graph itself, we need to bake them into the compiled graphs. This means we can't reuse previously compiled graphs, + ## because the other scripts write them without transition probs. + $cmd JOB=1:$nj $dir/log/compile_graphs.JOB.log \ + $prog --read-disambig-syms=$lang/phones/disambig.int \ + $scale_opts \ + $dir/tree $srcdir/${iter}.mdl $lang/L.fst "$tra" \ + "ark:|gzip -c >$dir/fsts.JOB.gz" || exit 1 +fi + +if [ $stage -le 1 ]; then + # Warning: nnet3-latgen-faster doesn't support a retry-beam so you may get more + # alignment errors (however, it does have a default min-active=200 so this + # will tend to reduce alignment errors). + # --allow_partial=false makes sure we reach the end of the decoding graph. + # --word-determinize=false makes sure we retain the alternative pronunciations of + # words (including alternatives regarding optional silences). + # --lattice-beam=$beam keeps all the alternatives that were within the beam, + # it means we do no pruning of the lattice (lattices from a training transcription + # will be small anyway). + $cmd JOB=1:$nj $dir/log/generate_lattices.JOB.log \ + nnet3-latgen-faster --acoustic-scale=$acoustic_scale $ivector_opts $frame_subsampling_opt \ + --frames-per-chunk=$frames_per_chunk \ + --extra-left-context=$extra_left_context \ + --extra-right-context=$extra_right_context \ + --extra-left-context-initial=$extra_left_context_initial \ + --extra-right-context-final=$extra_right_context_final \ + --beam=$beam --lattice-beam=$beam \ + --allow-partial=false --word-determinize=false \ + $srcdir/${iter}.mdl "ark:gunzip -c $dir/fsts.JOB.gz |" \ + "$feats" "ark:|lattice-copy --acoustic-scale=$post_decode_acwt ark:- ark:- | gzip -c >$dir/lat.JOB.gz" || exit 1; +fi + +echo "$0: done generating lattices from training transcripts." \ No newline at end of file diff --git a/egs/wsj/s5/steps/nnet3/chain/get_model_context.sh b/egs/wsj/s5/steps/nnet3/chain/get_model_context.sh new file mode 100755 index 000000000..39b7bbab6 --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/chain/get_model_context.sh @@ -0,0 +1,107 @@ +#!/bin/bash + +# Copyright 2019 Johns Hopkins University (Author: Daniel Povey). Apache 2.0. +# 2019 Idiap Research Institute (Author: Srikanth Madikeri) +# +# This script computes the total left and right context needed for example (eg) +# creation from a set of 'chain' models. +# See the usage message for more information about input and output formats. + +# Begin configuration section. +frame_subsampling_factor=1 # The total frame subsampling factor of the bottom + # + top model, i.e. the relative difference in + # frame rate between the input of the bottom model + # and the output of the top model. Would normally + # be 3. + +langs=default # the list of languages. This script checks that + # in the dir (first arg to the script), each + # language exists as $lang.mdl, and it warns if + # any model files appear (which might indicate a + # script bug). +# End configuration section + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + + +if [ $# != 2 ]; then + cat 1>&2 < +This script works out some acoustic-context-related information, +and writes it, long with the options provided to the script, +to the provided. An example of what +output-info-file> might contain after this script is called, is: +langs default +frame_subsampling_factor 3 +bottom_subsampling_factor 3 +model_left_context 22 +model_right_context 22 + e.g.: $0 --frame-subsampling-factor 3 + --langs 'default' exp/chaina/tdnn1a_sp/0 exp/chaina/tdnn1a_sp/0/info.txt + Options: + --frame-subsampling-factor # (default: 1) Total frame subsampling factor of + # both models combined, i.e. ratio of + # frame rate of input features vs. + # alignments and decoding (e.g. 3). + --bottom-subsampling-factor # (default: 1) Controls the frequency at which + # the output of the bottom model is + # evaluated, and the interpretation of frame + # offsets in the top config file. Must be a + # divisor of --frame-subsampling-factor + --langs # The list of languages (must be in quotes, + # to be parsed as a single arg). May be + # 'default' or e.g. 'english french' +EOF + exit 1; +fi + + +dir=$1 +info_file=$2 + +# die on error or undefined variable. +set -e -u + +if [ ! -d $dir ]; then + echo 1>&2 "$0: expected directory $dir to exist" + exit 1 +fi + +if [ -z $langs ]; then + echo 1>&2 "$0: list of languages (--langs option) is empty" + exit 1 +fi + +if ! [ $frame_subsampling_factor -ge 1 ]; then + echo 1>&2 "$0: there was a problem with the options --frame-subsampling-factor=$frame_subsampling_factor" + exit 1 +fi + +mkdir -p $dir/temp + +for lang in $langs; do + if [ ! -s $dir/$lang.mdl ]; then + echo 1>&2 "$0: expected file $dir/$lang.mdl to exist and be nonempty (check --langs option)" + exit 1 + fi + nnet3-am-info $dir/$lang.mdl > $dir/temp/$lang.info + this_left_context=$(grep '^left-context:' $dir/temp/$lang.info | awk '{print $2}') + this_right_context=$(grep '^right-context:' $dir/temp/$lang.info | awk '{print $2}') +done + +left_context=$this_left_context +right_context=$this_right_context + + +cat >$info_file < " + echo "e.g. $0 exp/chain/tdnn1a_sp/configs/init.raw exp/chain/tdnn1a_sp/egs/ exp/chain/tdnn1a_sp" + echo "" + echo "This script computes pre-conditioning matrix given the model (usually init.raw file from the config folder)," + echo "egs-folder which has train.*.scp files to be used to train LDA, and" + echo "lda-output-folder that will contain lda.mat file." + echo "" + echo "Main options (for others, see top of script file)" + echo " --cmd (utils/run.pl;utils/queue.pl ) # how to run jobs." + echo " --nj # number of jobs. this is also the number of train.*.scp files in egs/" + echo " --lda-acc-opts # options to be passed to nnet3-chain-acc-lda-stats" + echo " --lda-sum-opts # options to be passed to sum-lda-accs" + echo " --lda-transform-opts # options to be passed to nnet-get-feature-transform" + exit 1; +fi + +model=$1 +egs=$2 +ldafolder=$3 + +if [ ! -d $ldafolder ]; then + echo "Creating $ldafolder" + mkdir -p $ldafolder || exit 1 +fi + + +if [ $stage -le 0 ]; then + if $use_scp; then + egs_rspecifier="ark:nnet3-chain-copy-egs $egs_opts scp:$egs/train.JOB.scp ark:- |" + else + egs_rspecifier="ark:nnet3-chain-copy-egs $egs_opts ark:$egs/train.JOB.ark ark:- |" + fi + echo "$0: Accumulating LDA stats" + $cmd JOB=1:$nj $ldafolder/log/acc.JOB.log \ + nnet3-chain-acc-lda-stats $lda_acc_opts --rand-prune=${rand_prune} \ + $model "${egs_rspecifier}" \ + $ldafolder/JOB.lda_stats || exit 1 +fi + +if [ $stage -le 1 ]; then + echo "$0: Summing LDA stats" + lda_stats_files= + for i in `seq 1 $nj`; do + lda_stats_files="$lda_stats_files $ldafolder/$i.lda_stats" + done + + $cmd $ldafolder/log/sum_transform_stats.log \ + sum-lda-accs $lda_sum_opts $ldafolder/lda_stats $lda_stats_files || exit 1 + rm $lda_stats_files +fi + +if [ $stage -le 2 ]; then + echo "$0: Computing LDA transform" + $cmd $ldafolder/log/get_transform.log \ + nnet-get-feature-transform $lda_transform_opts \ + $ldafolder/lda.mat $ldafolder/lda_stats || exit 1 + + rm $ldafolder/lda_stats + ln -rs $ldafolder/lda.mat $ldafolder/configs/lda.mat +fi + +echo "$0: Finished computing LDA transform" +exit 0; diff --git a/egs/wsj/s5/steps/nnet3/chain2/get_raw_egs.sh b/egs/wsj/s5/steps/nnet3/chain2/get_raw_egs.sh new file mode 100755 index 000000000..d1fae9754 --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/chain2/get_raw_egs.sh @@ -0,0 +1,310 @@ +#!/bin/bash + +# Copyright 2019 Johns Hopkins University (Author: Daniel Povey). Apache 2.0. +# Copyright 2019 Idiap Research Institute (Author: Srikanth Madikeri). Apache 2.0. +# +# This script dumps 'raw' egs for 'chain' training. What 'raw' means in this +# context is that they need to be further processed to merge egs of the same +# speaker, etc. So they won't be directly consumed by training, but by +# by the script process_egs.sh. + + + +# Begin configuration section. +cmd=run.pl +frames_per_chunk=150 # Number of frames (at feature frame rate) per example. You + # are allowed to make this a comma-separated list, + # e.g. 150,110,100, meaning that a range of eg widths are + # allowed (but this may not be as helpful when using our + # adaptation framework, since it will tend to split up + # utterances into separate minibatches. + +frame_subsampling_factor=3 # frames-per-second of features we train on divided + # by frames-per-second at output of chain model +alignment_subsampling_factor=3 # frames-per-second of input alignments divided + # by frames-per-second at output of chain model +constrained=true # 'constrained=true' is the traditional setup; 'constrained=false' + # gives you the 'unconstrained' egs creation in which the time + # boundaries are not enforced inside chunks. +left_context=0 # amount of left-context per eg (i.e. extra frames of input + # features not present in the output supervision). Would + # normally depend on the model context, plus desired 'extra' + # context (e.g. for LSTM). +right_context=0 # amount of right-context per eg. + +left_context_initial=-1 # if >=0, right-context for last chunk of an utterance. +right_context_final=-1 # if >=0, right-context for last chunk of an utterance. + +compress=true # set this to false to disable compression (e.g. if you want to + # see whether results are affected). Note: if the features on + # disk were originally compressed, nnet3-chain-get-egs will dump + # compressed features regardless (since there is no further loss + # in that case). + +lang=default # the language name. will usually be 'default' in single-language + # setups. Requires because it's part of the name of some of + # the input files. + +right_tolerance= # chain right tolerance == max label delay. Only relevant if + # constrained=true. At frame rate of alignments. Code + # default is 5. +left_tolerance= # chain left tolerance (versus alignments from lattices). + # Only relevant if constrained=true. At frame rate of + # alignments. Code default is 5. + +stage=0 +max_jobs_run=40 # This should be set to the maximum number of + # nnet3-chain-get-egs jobs you are comfortable to run in + # parallel; you can increase it if your disk speed is + # greater and you have more machines. + + +srand=0 # rand seed for nnet3-chain-get-egs, nnet3-chain-copy-egs and nnet3-chain-shuffle-egs +online_ivector_dir= # can be used if we are including speaker information as iVectors. +cmvn_opts= # can be used for specifying CMVN options, if feature type is not lda (if lda, + # it doesn't make sense to use different options than were used as input to the + # LDA transform). This is used to turn off CMVN in the online-nnet experiments. + +lattice_lm_scale= # If supplied, the graph/lm weight of the lattices will be + # used (with this scale) in generating supervisions + # This is 0 by default for conventional supervised training, + # but may be close to 1 for the unsupervised part of the data + # in semi-supervised training. The optimum is usually + # 0.5 for unsupervised data. +lattice_prune_beam= # If supplied, the lattices will be pruned to this beam, + # before being used to get supervisions. + +acwt=0.1 # For pruning. Should be, for instance, 1.0 for chain lattices. +deriv_weights_scp= + +# end configuration section + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + + +if [ $# != 4 ]; then + echo "Usage: $0 [opts] " + echo " e.g.: $0 data/train exp/chain/tdnn1a_sp exp/tri3_lats exp/chain/tdnn1a_sp/raw_egs" + echo "" + echo "From , 0/.mdl (for the transition-model), .tree (the tree), " + echo " den_fsts/.den.fst, and den_fsts/.normalization.fst (the normalization " + echo " FST, derived from the denominator FST echo are read (where is specified" + echo " by the --lang option (its default values is 'default')" + echo "" + echo "Main options (for others, see top of script file)" + echo " --config # config file containing options (alternative to this" + echo " # command line)" + echo " --max-jobs-run # The maximum number of jobs you want to run in" + echo " # parallel (increase this only if you have good disk and" + echo " # network speed). default=6" + echo " --cmd (utils/run.pl;utils/queue.pl ) # how to run jobs." + echo " --frame-subsampling-factor # factor by which num-frames at nnet output is reduced " + echo " --lang # Name of the language, determines names of some inputs." + echo " --frames-per-chunk # number of supervised frames per chunk on disk" + echo " # ... may be a comma separated list, but we advise a single" + echo " # number in most cases, due to interaction with the need " + echo " # to group egs from the same speaker into groups." + echo " --left-context # Number of frames on left side to append for feature input" + echo " --right-context # Number of frames on right side to append for feature input" + echo " --left-context-initial # Left-context for first chunk of an utterance" + echo " --right-context-final # Right-context for last chunk of an utterance" + echo " --lattice-lm-scale # If supplied, the graph/lm weight of the lattices will be " + echo " # used (with this scale) in generating supervisions" + echo " --lattice-prune-beam # If supplied, the lattices will be pruned to this beam, " + echo " # before being used to get supervisions." + echo " --acwt # Acoustic scale -- should be acoustic scale at which the " + echo " # supervision lattices are to be interpreted. Affects pruning" + echo " --deriv-weights-scp # If supplied, adds per-frame weights to the supervision." + echo " # (e.g., might be relevant for unsupervised training)." + echo " --stage # Used to run this script from somewhere in" + echo " # the middle." + exit 1; +fi + +data=$1 +chaindir=$2 +latdir=$3 +dir=$4 + +tree=$chaindir/${lang}.tree +trans_mdl=$chaindir/init/${lang}.mdl # contains the transition model and a nnet, but + # we won't be making use of the nnet part. +normalization_fst=$chaindir/den_fsts/${lang}.normalization.fst +den_fst=$chaindir/den_fsts/${lang}.den.fst + +[ ! -z "$online_ivector_dir" ] && \ + extra_files="$online_ivector_dir/ivector_online.scp $online_ivector_dir/ivector_period" + +for f in $data/feats.scp $latdir/lat.1.gz $latdir/final.mdl \ + $tree $normalization_fst $den_fst $extra_files; do + [ ! -f $f ] && echo "$0: no such file $f" && exit 1; +done +if [ ! -f $trans_mdl ]; then + trans_mdl=$chaindir/init/${lang}_trans.mdl + if [ ! -f $trans_mdl ]; then + echo "$0: cannot find transition model in $chaindir/init/${lang}_trans.mdl or $trans_mdl" + exit 1 + fi +fi + +nj=$(cat $latdir/num_jobs) || exit 1 +if [ -f $latdir/per_utt ]; then + sdata=$data/split${nj}utt + utils/split_data.sh --per-utt $data $nj +else + sdata=$data/split$nj + utils/split_data.sh $data $nj +fi + +mkdir -p $dir/log $dir/misc + +cp $tree $dir/misc/ +copy-transition-model $trans_mdl $dir/misc/${lang}.trans_mdl +cp $normalization_fst $den_fst $dir/misc/ +cp $data/utt2spk $dir/misc/ +if [ -f $data/utt2uniq ]; then + cp $data/utt2uniq $dir/misc/ +elif [ -f $dir/misc/utt2uniq ]; then + rm $dir/misc/utt2uniq +fi + +if [ -e $dir/storage ]; then + # Make soft links to storage directories, if distributing this way.. See + # utils/create_split_dir.pl. + echo "$0: creating data links" + utils/create_data_link.pl $(for x in $(seq $nj); do echo $dir/cegs.$x.ark; done) +fi + + +lats_rspecifier="ark:gunzip -c $latdir/lat.JOB.gz |" +if [ ! -z $lattice_prune_beam ]; then + if [ "$lattice_prune_beam" == "0" ] || [ "$lattice_prune_beam" == "0.0" ]; then + lats_rspecifier="$lats_rspecifier lattice-1best --acoustic-scale=$acwt ark:- ark:- |" + else + lats_rspecifier="$lats_rspecifier lattice-prune --acoustic-scale=$acwt --beam=$lattice_prune_beam ark:- ark:- |" + fi +fi + +egs_opts="--long-key=true --left-context=$left_context --right-context=$right_context --num-frames=$frames_per_chunk --frame-subsampling-factor=$frame_subsampling_factor --compress=$compress" +[ $left_context_initial -ge 0 ] && egs_opts="$egs_opts --left-context-initial=$left_context_initial" +[ $right_context_final -ge 0 ] && egs_opts="$egs_opts --right-context-final=$right_context_final" + +[ ! -z "$deriv_weights_scp" ] && egs_opts="$egs_opts --deriv-weights-rspecifier=scp:$deriv_weights_scp" + + +chain_supervision_all_opts="--lattice-input=true --frame-subsampling-factor=$alignment_subsampling_factor" +[ ! -z $right_tolerance ] && \ + chain_supervision_all_opts="$chain_supervision_all_opts --right-tolerance=$right_tolerance" + +[ ! -z $left_tolerance ] && \ + chain_supervision_all_opts="$chain_supervision_all_opts --left-tolerance=$left_tolerance" + +if ! $constrained; then + # e2e supervision + chain_supervision_all_opts="$chain_supervision_all_opts --convert-to-pdfs=false" + egs_opts="$egs_opts --transition-model=$chaindir/0.trans_mdl" +fi + +if [ ! -z "$lattice_lm_scale" ]; then + chain_supervision_all_opts="$chain_supervision_all_opts --lm-scale=$lattice_lm_scale" + + normalization_fst_scale=$(perl -e " + if ($lattice_lm_scale >= 1.0 || $lattice_lm_scale < 0) { + print STDERR \"Invalid --lattice-lm-scale $lattice_lm_scale\"; exit(1); + } + print (1.0 - $lattice_lm_scale);") || exit 1 + egs_opts="$egs_opts --normalization-fst-scale=$normalization_fst_scale" +fi + +if [ ! -z "$online_ivector_dir" ]; then + ivector_period=$(cat $online_ivector_dir/ivector_period) || exit 1; + ivector_opts="--online-ivectors=scp:$online_ivector_dir/ivector_online.scp --online-ivector-period=$ivector_period" +else + ivector_opts="" +fi + +feats="scp:$sdata/JOB/feats.scp" +if [ ! -z $cmvn_opts ]; then + if [ ! -f $data/cmvn.scp ]; then + echo "Cannot find $data/cmvn.scp. But cmvn_opts=$cmvn_opts" + exit 1 + fi + if [ `echo $cmvn_opts | fgrep -c true` -eq 1 ]; then + feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- |" + fi +fi + +if [ $stage -le 0 ]; then + $cmd --max-jobs-run $max_jobs_run JOB=1:$nj $dir/log/get_egs.JOB.log \ + lattice-align-phones --replace-output-symbols=true $latdir/final.mdl \ + "$lats_rspecifier" ark:- \| \ + chain-get-supervision $chain_supervision_all_opts \ + $dir/misc/${lang}.tree $dir/misc/${lang}.trans_mdl ark:- ark:- \| \ + nnet3-chain-get-egs $ivector_opts --srand=\$[JOB+$srand] $egs_opts \ + "$normalization_fst" "$feats" ark,s,cs:- \ + ark,scp:$dir/cegs.JOB.ark,$dir/cegs.JOB.scp || exit 1; +fi + + +if [ $stage -le 1 ]; then + num_input_frames=$(steps/nnet2/get_num_frames.sh $data) + frames_and_chunks=$(for n in $(seq $nj); do cat $dir/log/get_egs.$n.log; done | \ + perl -e '$nc=0; $nf=0; while() { + if (m/Split .+ into (\d+) chunks/) { $this_nc = $1; } + if (m/Average chunk length was (\d+) frames/) { $nf += $1 * $this_nc; $nc += $this_nc; } + } print "$nf $nc"; ') + echo $frames_and_chunks + num_chunks=$(echo $frames_and_chunks | awk '{print $2}') + frames_per_chunk_avg=$[num_input_frames/num_chunks] + feat_dim=$(feat-to-dim scp:$sdata/1/feats.scp -) + num_leaves=$(tree-info $tree | awk '/^num-pdfs/ {print $2}') + if [ $left_context_initial -lt 0 ]; then + left_context_initial=$left_context + fi + if [ $right_context_final -lt 0 ]; then + right_context_final=$right_context + fi + + cat >$dir/info.txt < $dir/info/ivector_dim + echo ivector_dim $ivector_dim >> $dir/info.txt + ivector_id=`steps/nnet2/get_ivector_id.sh $online_ivector_dir || exit 1` + echo ivector_id $ivector_id + ivector_period=$(cat $online_ivector_dir/ivector_period) || exit 1; + echo ivector_period $ivector_period + ivector_opts="--online-ivectors=scp:$online_ivector_dir/ivector_online.scp --online-ivector-period=$ivector_period" + else + ivector_opts="" + fi + + if ! cat $dir/info.txt | awk '{if (NF == 1) exit(1);}'; then + echo "$0: we failed to obtain at least one of the fields in $dir/info.txt" + exit 1 + fi +fi + + +if [ $stage -le 2 ]; then + for n in $(seq $nj); do cat $dir/cegs.$n.scp; done > $dir/all.scp +fi + +echo "$0: Finished preparing raw egs" diff --git a/egs/wsj/s5/steps/nnet3/chain2/internal/get_best_model.sh b/egs/wsj/s5/steps/nnet3/chain2/internal/get_best_model.sh new file mode 100755 index 000000000..8cc46a006 --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/chain2/internal/get_best_model.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +# Copyright 2019 Idiap Research Institute (Author: Srikanth Madikeri). Apache 2.0. +# This script is the equivalent of get_successful_models function in the python library. +# It takes a list of models and returns either the best model (the deafult) or a list of +# models to average. + +models_to_average=false +difference_threshold=1.0 +output=output + + +# echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# -le 1 ]; then + echo "Usage: $0: [options] .... " + echo "where is one of the n models to choose from." + echo "" + echo "--models-to-average: when true, returns the models to be averaged rather than the single best model" + echo "--difference-threshold: used to reject models. models with objf < max-value - difference_threshold are rejected" + echo "--output: the objf of the this output layer is used for model selection" + echo "" + exit 1; +fi + +if ! $models_to_average; then + if [ $# -eq 1 ]; then + basename $1 | tr '.' ' ' | awk '{ print $(NF-1) }' + exit 0; + fi + model_log_list=$(for arg in $*; do echo $arg; done) + first_log=$1 + log_line=`fgrep -m 1 "Overall average objective function for '$output' is" $first_log` + colno=`echo $log_line | cut -d '=' -f1 | wc -w` + ((colno+=2)) + filename=$(fgrep -m 1 "Overall average objective function for '$output' is" $model_log_list | \ + cut -d ' ' -f1,$colno | tr ':' ' ' | \ + awk '{print $1,$3}' | \ + sort -k2,2 -g | tail -1 | cut -d ' ' -f1) + basename $filename | tr '.' ' ' | awk '{ print $(NF-1) }' +fi diff --git a/egs/wsj/s5/steps/nnet3/chain2/internal/get_train_schedule.py b/egs/wsj/s5/steps/nnet3/chain2/internal/get_train_schedule.py new file mode 100755 index 000000000..a7bf72c2a --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/chain2/internal/get_train_schedule.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 + +# Copyright 2019 Johns Hopkins University (author: Daniel Povey) +# Copyright Hossein Hadian +# Copyright 2019 Idiap Research Institute (Author: Srikanth Madikeri). + + +# Apache 2.0. + +""" This script outputs information about a neural net training schedule, + to be used by ../train.sh, in the form of lines that can be selected + and sourced by the shell. +""" + +import argparse +import sys + +sys.path.insert(0, 'steps') +import libs.nnet3.train.common as common_train_lib +import libs.common as common_lib + +def get_args(): + parser = argparse.ArgumentParser( + description="""Output training schedule information to be consumed by ../train.sh""", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("--frame-subsampling-factor", type=int, default=3, + help="""Frame subsampling factor for the combined model + (bottom+top), will normally be 3. Required here in order + to deal with frame-shifted versions of the input.""") + parser.add_argument("--initial-effective-lrate", + type=float, + dest='initial_effective_lrate', default=0.001, + help="""Effective learning rate used on the first iteration, + determines schedule via geometric interpolation with + --final-effective-lrate. Actual learning rate is + this times the num-jobs on that iteration.""") + parser.add_argument("--final-effective-lrate", type=float, + dest='final_effective_lrate', default=0.0001, + help="""Learning rate used on the final iteration, see + --initial-effective-lrate for more documentation.""") + parser.add_argument("--num-jobs-initial", type=int, default=1, + help="""Number of parallel neural net jobs to use at + the start of training""") + parser.add_argument("--num-jobs-final", type=int, default=1, + help="""Number of parallel neural net jobs to use at + the end of training. Would normally + be >= --num-jobs-initial""") + parser.add_argument("--num-epochs", type=float, default=4.0, + help="""The number of epochs to train for. + Note: the 'real' number of times we see each + utterance is this number times --frame-subsampling-factor + (to cover frame-shifted copies of the data), times + the value of --num-repeats given to process_egs.sh, + times any factor arising from data augmentation.""") + parser.add_argument("--dropout-schedule", type=str, + help="""Use this to specify the dropout schedule (how the dropout probability varies + with time, 0 == no dropout). You specify a piecewise + linear function on the domain [0,1], where 0 is the + start and 1 is the end of training; the + function-argument (x) rises linearly with the amount of + data you have seen, not iteration number (this improves + invariance to num-jobs-{initial-final}). E.g. '0,0.2,0' + means 0 at the start; 0.2 after seeing half the data; + and 0 at the end. You may specify the x-value of + selected points, e.g. '0,0.2@0.25,0' means that the 0.2 + dropout-proportion is reached a quarter of the way + through the data. The start/end x-values are at + x=0/x=1, and other unspecified x-values are interpolated + between known x-values. You may specify different rules + for different component-name patterns using + 'pattern1=func1 pattern2=func2', e.g. 'relu*=0,0.1,0 + lstm*=0,0.2,0'. More general should precede less + general patterns, as they are applied sequentially.""") + + parser.add_argument("--num-scp-files", type=int, default=0, required=True, + help="""The number of .scp files in the egs dir.""") + parser.add_argument("--schedule-out", type=str, required=True, + help="""Output file containing the training schedule. The output + is lines, one per training iteration. + Each line (one per iteration) is a list of ;-separated commands setting shell + variables. Currently the following variables are set: + iter, num_jobs, inv_num_jobs, scp_indexes, frame_shifts, dropout_opt, lrate. + """) + + print(sys.argv, file=sys.stderr) + args = parser.parse_args() + + return args + +def get_schedules(args): + num_scp_files_expanded = args.num_scp_files * args.frame_subsampling_factor + num_scp_files_to_process = int(args.num_epochs * num_scp_files_expanded) + num_scp_files_processed = 0 + num_iters = ((num_scp_files_to_process * 2) + // (args.num_jobs_initial + args.num_jobs_final)) + + with open(args.schedule_out, 'w', encoding='latin-1') as ostream: + for iter in range(num_iters): + current_num_jobs = int(0.5 + args.num_jobs_initial + + (args.num_jobs_final - args.num_jobs_initial) + * float(iter) / num_iters) + # as a special case, for iteration zero we use just one job + # regardless of the --num-jobs-initial and --num-jobs-final. This + # is because the model averaging does not work reliably for a + # freshly initialized model. + # if iter == 0: + # current_num_jobs = 1 + + lrate = common_train_lib.get_learning_rate(iter, current_num_jobs, + num_iters, + num_scp_files_processed, + num_scp_files_to_process, + args.initial_effective_lrate, + args.final_effective_lrate) + + if args.dropout_schedule == "": + args.dropout_schedule = None + dropout_edit_option = common_train_lib.get_dropout_edit_option( + args.dropout_schedule, + float(num_scp_files_processed) / max(1, (num_scp_files_to_process - args.num_jobs_final)), + iter) + + frame_shifts = [] + egs = [] + for job in range(1, current_num_jobs + 1): + # k is a zero-based index that we will derive the other indexes from. + k = num_scp_files_processed + job - 1 + # work out the 1-based scp index. + scp_index = (k % args.num_scp_files) + 1 + # previous : frame_shift = (k/num_scp_files) % frame_subsampling_factor + frame_shift = ((scp_index + k // args.num_scp_files) + % args.frame_subsampling_factor) + + # Instead of frame shifts like [0, 1, 2], we make them more like + # [0, 1, -1]. This is clearer in intent, and keeps the + # supervision starting at frame zero, which IIRC is a + # requirement somewhere in the 'chaina' code. +# TODO: delete this section if no longer useful + # if frame_shift > (args.frame_subsampling_factor // 2): + # frame_shift = frame_shift - args.frame_subsampling_factor + + frame_shifts.append(str(frame_shift)) + egs.append(str(scp_index)) + + + print("""iter={iter}; num_jobs={nj}; inv_num_jobs={nj_inv}; scp_indexes=(pad {indexes}); frame_shifts=(pad {shifts}); dropout_opt="{opt}"; lrate={lrate}""".format( + iter=iter, nj=current_num_jobs, nj_inv=(1.0 / current_num_jobs), + indexes = ' '.join(egs), shifts=' '.join(frame_shifts), + opt=dropout_edit_option, lrate=lrate), file=ostream) + num_scp_files_processed = num_scp_files_processed + current_num_jobs + + +def main(): + args = get_args() + get_schedules(args) + +if __name__ == "__main__": + main() diff --git a/egs/wsj/s5/steps/nnet3/chain2/process_egs.sh b/egs/wsj/s5/steps/nnet3/chain2/process_egs.sh new file mode 100755 index 000000000..57894330f --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/chain2/process_egs.sh @@ -0,0 +1,171 @@ +#!/bin/bash + +# Copyright 2019 Johns Hopkins University (Author: Daniel Povey). Apache 2.0. +# Copyright 2019 Idiap Research Institute (Author: Srikanth Madikeri). Apache 2.0. +# +# This script takes nnet examples dumped by steps/chain/get_raw_egs.sh and +# combines the chunks into groups by speaker (to the extent possible; it may +# need to combine speakers in some cases), locally randomizes the result, and +# dumps the resulting egs to disk. Chunks of these will later be globally +# randomized (at the scp level) by steps/chaina/randomize_egs.sh + + +# Begin configuration section. +cmd=run.pl +chunks_per_group=4 +num_repeats=2 # number of times we repeat the same chunks with different + # grouping. Recommend 1 or 2; must divide chunks_per_group +compress=true # set this to false to disable compression (e.g. if you want to see whether + # results are affected). + +num_utts_subset=300 # number of utterances in validation and training + # subsets used for shrinkage and diagnostics. + + +shuffle_buffer_size=5000 # Size of buffer (containing grouped egs) to use + # for random shuffle. + +stage=0 +nj=5 # the number of parallel jobs to run. +srand=0 + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + + +if [ $# != 2 ]; then + echo "Usage: $0 [opts] " + echo " e.g.: $0 --chunks-per-group 4 exp/chaina/tdnn1a_sp/raw_egs exp/chaina/tdnn1a_sp/processed_egs" + echo "" + echo "Main options (for others, see top of script file)" + echo " --config # config file containing options (alternative to this" + echo " # command line)" + echo " --cmd (utils/run.pl;utils/queue.pl ) # how to run jobs." + echo " --chunks-per-group # Number of chunks (preferentially, from a single speaker" + echo " # to combine into each example. This grouping of" + echo " # egs is part of the 'chaina' framework; the adaptation" + echo " # parameters will be estimated from these groups." + echo " --num-repeats # Number of times we group the same chunks into different" + echo " # groups. For now only the values 1 and 2 are" + echo " # recommended, due to the very simple way we choose" + echo " # the groups (it's consecutive)." + echo " --nj # Number of jobs to run in parallel. Usually quite a" + echo " # small number, as we'll be limited by disk access" + echo " # speed." + echo " --compress # True if you want the egs to be compressed" + echo " # (e.g. you may set to false for debugging purposes, to" + echo " # check that the compression is not hurting)." + echo " --num-heldout-egs # Number of egs to put in train_subset.scp and heldout_subset.scp." + echo " # These will be used for diagnostics. Note: this number is" + echo " # the number of grouped egs, after merging --chunks-per-group" + echo " # chunks into a single eg." + echo " # ... may be a comma separated list, but we advise a single" + echo " # number in most cases, due to interaction with the need " + echo " # to group egs from the same speaker into groups." + echo " --stage # Used to run this script from somewhere in" + echo " # the middle." + exit 1; +fi + +raw_egs_dir=$1 +dir=$2 + +# die on error or undefined variable. +set -e -u + +if ! steps/chain/validate_raw_egs.sh $raw_egs_dir; then + echo "$0: failed to validate input directory $raw_egs_dir" + exit 1 +fi + + +mkdir -p $dir/temp $dir/log + + +if [ $stage -le 0 ]; then + echo "$0: choosing heldout_subset and train_subset" + + utt2uniq_opt= + if [ -f $raw_egs_dir/misc/utt2uniq ]; then + utt2uniq_opt="--utt2uniq=$raw_egs_dir/misc/utt2uniq" + echo "$0: File $raw_egs_dir/misc/utt2uniq exists, so ensuring the hold-out set" \ + "includes all perturbed versions of the same source utterance." + utils/utt2spk_to_spk2utt.pl $raw_egs_dir/misc/utt2uniq 2>/dev/null | \ + utils/shuffle_list.pl 2>/dev/null | \ + awk -v max_utt=$num_utts_subset '{ + for (n=2;n<=NF;n++) print $n; + printed += NF-1; + if (printed >= max_utt) nextfile; }' \ + | fgrep -f - $raw_egs_dir/all.scp | sort -k1,1 > $dir/temp/heldout_subset.list + else + awk '{print $1}' $raw_egs_dir/misc/utt2spk | \ + utils/shuffle_list.pl 2>/dev/null | \ + head -$num_utts_subset | fgrep -f - $raw_egs_dir/all.scp | sort -k1,1 > $dir/temp/heldout_subset.list + fi + + awk '{print $1}' $raw_egs_dir/misc/utt2spk | \ + utils/filter_scp.pl --exclude $dir/temp/heldout_subset.list | \ + utils/shuffle_list.pl 2>/dev/null | \ + head -$num_utts_subset | fgrep -f - $raw_egs_dir/all.scp | sort -k1,1 > $dir/temp/train_subset.list + + awk '{print $1}' $raw_egs_dir/misc/utt2spk | \ + utils/filter_scp.pl --exclude $dir/temp/heldout_subset.list | fgrep -f - $raw_egs_dir/all.scp > $dir/temp/train.list + fi +len_valid_uttlist=$(wc -l < $dir/temp/heldout_subset.list) +len_trainsub_uttlist=$(wc -l <$dir/temp/train_subset.list) + +if [ $stage -le 1 ]; then + + for name in heldout_subset train_subset; do + echo "$0: merging and shuffling $name egs" + + cp $dir/temp/${name}.list $dir/temp/${name}.scp + + $cmd $dir/log/shuffle_${name}_egs.log \ + nnet3-chain-shuffle-egs --srand=$srand scp:$dir/temp/${name}.scp ark,scp:$dir/${name}.ark,$dir/${name}.scp + done + + # Split up the training list into multiple smaller lists, as it could be long. + utils/split_scp.pl $dir/temp/train.list $(for j in $(seq $nj); do echo $dir/temp/train.$j.list; done) + # Linearize these lists and add keys to make them in scp format; + # nnet3-chain-merge-egs will merge the right groups, it's deterministic + # and we specified --minibatch-size=$chunks_per_group. + for j in $(seq $nj); do + #awk '{for (n=1;n<=NF;n++) { count++; print count, $n; }}' <$dir/temp/train.$j.list >$dir/temp/train.$j.scp + cp $dir/temp/train.${j}.list $dir/temp/train.${j}.scp + done + + if [ -e $dir/storage ]; then + # Make soft links to storage directories, if distributing this way.. See + # utils/create_split_dir.pl. + echo "$0: creating data links" + utils/create_data_link.pl $(for j in $(seq $nj); do echo $dir/train.$j.ark; done) || true + fi + + $cmd JOB=1:$nj $dir/log/shuffle_train_egs.JOB.log \ + nnet3-chain-shuffle-egs --buffer-size=$shuffle_buffer_size \ + --srand=\$[JOB+$srand] scp:$dir/temp/train.JOB.scp ark,scp:$dir/train.JOB.ark,$dir/train.JOB.scp || exit 1; + cat $(for j in $(seq $nj); do echo $dir/train.$j.scp; done) > $dir/train.scp +fi + +cat $raw_egs_dir/info.txt | awk -v num_repeats=$num_repeats \ + ' + /^dir_type / { print "dir_type processed_chain_egs"; next; } + /^num_input_frames / { print "num_input_frames "$2 * num_repeats; next; } # approximate; ignores held-out egs. + /^num_chunks / { print "num_chunks " $2 * num_repeats; next; } + {print;} + END{print "num_repeats " num_repeats;}' >$dir/info.txt + + + +if ! cat $dir/info.txt | awk '{if (NF == 1) exit(1);}'; then + echo "$0: we failed to obtain at least one of the fields in $dir/info.txt" + exit 1 +fi + +cp -r $raw_egs_dir/misc/ $dir/ + + +echo "$0: Finished processing egs" diff --git a/egs/wsj/s5/steps/nnet3/chain2/randomize_egs.sh b/egs/wsj/s5/steps/nnet3/chain2/randomize_egs.sh new file mode 100755 index 000000000..99b0b237d --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/chain2/randomize_egs.sh @@ -0,0 +1,161 @@ +#!/bin/bash + +# Copyright 2019 Johns Hopkins University (Author: Daniel Povey). Apache 2.0. +# Copyright 2019 Idiap Research Institute (Author: Srikanth Madikeri). Apache 2.0. +# +# This script takes nnet examples dumped by steps/chain/process_egs.sh, +# globally randomizes the egs, and divides into multiple .scp files. This is +# the form of egs which is consumed by the training script. All this is done +# only by manipulating the contents of .scp files. To keep locality of disk +# access, we only randomize blocks of egs (e.g. blocks containing 128 groups of +# sequences). This doesn't defeat randomization, because both process_egs.sh +# and the training script use nnet3-shuffle-egs to do more local randomization. + +# Later on, we'll have a multilingual/multi-input-dir version fo this script +# that combines egs from various data sources and possibly multiple languages. +# This version assumes there is just one language. + +# Begin configuration section. +cmd=run.pl + +groups_per_block=128 # The 'groups' are the egs in the scp file from + # process_egs.sh, containing '--chunks-per-group' sequences + # each. +num_blocks=256 + +frames_per_job=3000000 # The number of frames of data we want to process per + # training job (will determine how long each job takes, + # and the frequency of model averaging. This was + # previously called --frames-per-iter, but + # --frames-per-job is clearer as each job does this + # many. + +num_groups_combine=1000 # the number of groups from the training set that we + # randomly choose as input to nnet3-chain-combine; + # these will go to combine.scp. train_subset.scp and + # heldout_subset.scp are, for now, just copied over + # from the input. + +# Later we may provide a mechanism to change the language name; for now we +# just copy it from the input. + + +srand=0 +stage=0 + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + + +if [ $# != 2 ]; then + echo "Usage: $0 [opts] " + echo " e.g.: $0 --frames-per-job 2000000 exp/chain/tdnn1a_sp/processed_egs exp/chain/tdnn1a_sp/egs" + echo "" + echo "Main options (for others, see top of script file)" + echo " --config # config file containing options (alternative to this" + echo " # command line)" + echo " --cmd (utils/run.pl;utils/queue.pl ) # how to run jobs." + echo " --groups-per-block # The number of groups (i.e. previously merged egs" + echo " # containing --chunks-per-group chunks) to to consider " + echo " # as one block, where whole blocks are randomized;" + echo " # smaller means more complete randomization but less" + echo " # local disk access." + echo " --frames-per-job # The number of input frames (not counting context)" + echo " # that we aim to have in each scp file after" + echo " # randomization and splitting." + echo " --num-groups-combine # The number of randomly chosen groups to" + echo " # put in the subset in 'combine.scp' which will" + echo " # be used in nnet3-chain-combine to decide which" + echo " # models to average over." + echo " --stage # Used to run this script from somewhere in" + echo " # the middle." + echo " --srand # Random seed, affects randomization." + exit 1; +fi + +processed_egs_dir=$1 +dir=$2 + +# die on error or undefined variable. +set -e -u + +if ! steps/chain/validate_processed_egs.sh $processed_egs_dir; then + echo "$0: could not validate input directory $processed_egs_dir" + exit 1 +fi + +# Work out how many groups per job and how many frames per job we'll have + +info_in=$processed_egs_dir/info.txt + +# num_scp_files is the number of archives +num_input_frames=$(awk '/^num_input_frames/ { nif=$2; print nif}' $info_in) +frames_per_chunk_avg=$(awk '/^frames_per_chunk_avg/ { fpc=$2; print fpc}' $info_in) +num_chunks=$(awk '/^num_chunks/ { nc=$2; print nc}' $info_in) +num_scp_files=$[(num_chunks * frames_per_chunk_avg)/frames_per_job +1] +[ $num_scp_files -eq 0 ] && num_scp_files=1 + +frames_per_scp_file=$[(num_chunks*frames_per_chunk_avg)/num_scp_files] # because it may be slightly different from frames_per_job + + +mkdir -p $dir/temp + +if [ -d $dir/misc ]; then + rm -r $dir/misc +fi + +mkdir -p $dir/misc +cp $processed_egs_dir/misc/* $dir/misc + +utils/shuffle_list.pl $processed_egs_dir/train.scp > $dir/temp/train.scp +utils/split_scp.pl $dir/temp/train.scp $(for i in $(seq $num_blocks); do echo $dir/temp/train.$i.scp; done) +for i in `seq $num_blocks`; do + utils/split_scp.pl <(utils/shuffle_list.pl $dir/temp/train.$i.scp) $(for j in $(seq $num_scp_files); do echo $dir/temp/train.$i.$j.scp; done) +done +for j in `seq $num_scp_files`; do + cat $dir/temp/train.*.$j.scp | utils/shuffle_list.pl > $dir/train.$j.scp +done +rm -rf $dir/temp & + +cp $processed_egs_dir/heldout_subset.scp $processed_egs_dir/train_subset.scp $dir/ + + +# note: there is only one language in $processed_egs_dir (any +# merging would be done at the randomization stage but that is not supported yet). + +lang=$(awk '/^lang / { print $2; }' <$processed_egs_dir/info.txt) + +# We'll store info files per language, containing the part of the information +# that is language-specific, plus a single global info.txt containing stuff that +# is not language specific. +# This will get more complicated once we actually support multiple languages, +# and when we allow multiple input processed egs dirs for the same language. + +grep -v -E '^dir_type|^lang|^feat_dim' <$processed_egs_dir/info.txt | \ + cat <(echo "dir_type randomized_chain_egs") - > $dir/info_$lang.txt + + +cat <$dir/info.txt +dir_type randomized_chain_egs +num_scp_files $num_scp_files +langs $lang +frames_per_scp_file $frames_per_scp_file +EOF +# frames_per_job, after rounding, becomes frames_per_scp_file. + +# note: frames_per_chunk_avg will be present in the info.txt file as well as +# the per-language files. +grep -E '^feat_dim|^frames_per_chunk_avg' <$processed_egs_dir/info.txt >>$dir/info.txt + + + +if ! cat $dir/info.txt | awk '{if (NF == 1) exit(1);}'; then + echo "$0: we failed to obtain at least one of the fields in $dir/info.txt" + exit 1 +fi + + +wait; +echo "$0: Finished randomizing egs" diff --git a/egs/wsj/s5/steps/nnet3/chain2/train.sh b/egs/wsj/s5/steps/nnet3/chain2/train.sh new file mode 100755 index 000000000..d2c653626 --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/chain2/train.sh @@ -0,0 +1,263 @@ +#!/bin/bash + +# Copyright 2019 Johns Hopkins University (Author: Daniel Povey). Apache 2.0. +# Copyright 2019 Idiap Research Institute (Author: Srikanth Madikeri). Apache 2.0. + + +# Begin configuration section +stage=-2 +cmd=run.pl +gpu_cmd_opt= +leaky_hmm_coefficient=0.1 +xent_regularize=0.1 +apply_deriv_weights=false # you might want to set this to true in unsupervised training + # scenarios. +memory_compression_level=2 # Enables us to use larger minibatch size than we + # otherwise could, but may not be optimal for speed + # (--> set to 0 if you have plenty of memory. +dropout_schedule= +srand=0 +max_param_change=1.0 # we use a smaller than normal default (it's normally + # 2.0), because there are two models (bottom and top). +use_gpu=yes # can be "yes", "no", "optional", "wait" +print_interval=10 +momentum=0.0 +parallel_train_opts= +verbose_opt= + +common_opts= # Options passed through to nnet3-chain-train and nnet3-chain-combine + +num_epochs=4.0 # Note: each epoch may actually contain multiple repetitions of + # the data, for various reasons: + # using the --num-repeats option in process_egs.sh + # data augmentation + # different data shifts (this includes 3 different shifts + # of the data if frame_subsampling_factor=3 (see $dir/init/info.txt) + +num_jobs_initial=1 +num_jobs_final=1 +initial_effective_lrate=0.001 +final_effective_lrate=0.0001 +groups_per_minibatch=32 # This is how you set the minibatch size. Note: if + # chunks_per_group=4, this would mean 128 chunks per + # minibatch. + +max_iters_combine=80 +max_models_combine=20 +diagnostic_period=5 # Get diagnostics every this-many iterations + +shuffle_buffer_size=1000 # This "buffer_size" variable controls randomization of the groups + # on each iter. + + +l2_regularize= + +# End configuration section + + + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + + +if [ $# != 2 ]; then + echo "Usage: $0 [options] " + echo " e.g.: $0 exp/chain/tdnn1a_sp/egs exp/chain/tdnn1a_sp" + echo "" + echo " TODO: more documentation" + exit 1 +fi + +egs_dir=$1 +dir=$2 + +set -e -u # die on failed command or undefined variable + +steps/chain2/validate_randomized_egs.sh $egs_dir + +for f in $dir/init/info.txt; do + if [ ! -f $f ]; then + echo "$0: expected file $f to exist" + exit 1 + fi +done +cat $egs_dir/info.txt >> $dir/init/info.txt + + +frame_subsampling_factor=$(awk '/^frame_subsampling_factor/ {print $2}' <$dir/init/info.txt) +num_scp_files=$(awk '/^num_scp_files/ {print $2}' <$egs_dir/info.txt) + +if [ $stage -le -2 ]; then + echo "$0: Generating training schedule" + steps/chain2/internal/get_train_schedule.py \ + --frame-subsampling-factor=$frame_subsampling_factor \ + --num-jobs-initial=$num_jobs_initial \ + --num-jobs-final=$num_jobs_final \ + --num-epochs=$num_epochs \ + --dropout-schedule="$dropout_schedule" \ + --num-scp-files=$num_scp_files \ + --frame-subsampling-factor=$frame_subsampling_factor \ + --initial-effective-lrate=$initial_effective_lrate \ + --final-effective-lrate=$final_effective_lrate \ + --schedule-out=$dir/schedule.txt +fi + + +# won't work at Idiap +#if [ "$use_gpu" != "no" ]; then gpu_cmd_opt="--gpu 1"; else gpu_cmd_opt=""; fi + +num_iters=$(wc -l <$dir/schedule.txt) + +echo "$0: will train for $num_epochs epochs = $num_iters iterations" + +# source the 1st line of schedule.txt in the shell; this sets +# lrate and dropout_opt, among other variables. +. <(head -n 1 $dir/schedule.txt) +langs=$(awk '/^langs/ { $1=""; print; }' <$dir/init/info.txt) + +mkdir -p $dir/log + +# Copy models with initial learning rate and dropout options from $dir/init to $dir/0 +#for lang in $langs; do +if [ $stage -le -1 ]; then + # run.pl $dir/log/init_model_default.log \ + # nnet3-am-copy --learning-rate=$lrate $dropout_opt $dir/init/default.mdl $dir/0.mdl + echo "$0: Copying transition model" + cp $dir/init/default.raw $dir/0.raw + cp $dir/init/default_trans.mdl $dir/0_trans.mdl + # nnet3-am-copy "--edits=rename-node old-name=output new-name=output-default; rename-node old-name=output-xent new-name=output-default-xent;" - $dir/0.mdl +#done +fi + + +l2_regularize_opt="" +if [ ! -z $l2_regularize ]; then + l2_regularize_opt="--l2-regularize=$l2_regularize" +fi + +x=0 +if [ $stage -gt $x ]; then x=$stage; fi + +[ $max_models_combine -gt $[num_iters/2] ] && max_models_combine=$[num_iters/2]; +combine_start_iter=$[num_iters+1-max_models_combine] + +while [ $x -lt $num_iters ]; do + # Source some variables fromm schedule.txt. The effect will be something + # like the following: + # iter=0; num_jobs=2; inv_num_jobs=0.5; scp_indexes=(pad 1 2); frame_shifts=(pad 1 2); dropout_opt="--edits='set-dropout-proportion name=* proportion=0.0'" lrate=0.002 + . <(grep "^iter=$x;" $dir/schedule.txt) + + echo "$0: training, iteration $x of $num_iters, num-jobs is $num_jobs" + + next_x=$[$x+1] + den_fst_dir=$egs_dir/misc + transform_dir=$dir/init + model_out_prefix=$dir/${next_x} + model_out=${model_out_prefix}.mdl + + + # for the first 4 iterations, plus every $diagnostic_period iterations, launch + # some diagnostic processes. We don't do this on iteration 0, because + # the batchnorm stats wouldn't be ready + if [ $x -gt 0 ] && [ $[x%diagnostic_period] -eq 0 -o $x -lt 5 ]; then + + [ -f $dir/.error_diagnostic ] && rm $dir/.error_diagnostic + for name in train heldout; do + $cmd $gpu_cmd_opt $dir/log/diagnostic_${name}.$x.log \ + nnet3-chain-train2 --use-gpu=$use_gpu \ + --leaky-hmm-coefficient=$leaky_hmm_coefficient \ + --xent-regularize=$xent_regularize \ + $l2_regularize_opt \ + --print-interval=10 \ + "nnet3-am-init $dir/0_trans.mdl $dir/${x}.raw - | nnet3-am-copy --learning-rate=$lrate - - |" $den_fst_dir \ + "ark:nnet3-chain-merge-egs --minibatch-size=$groups_per_minibatch scp:$egs_dir/${name}_subset.scp ark:-|" \ + $dir/${next_x}_${name}.mdl || touch $dir/.error_diagnostic & + done + fi + + cache_io_opt="--write-cache=$dir/cache.$next_x" + if [ $x -gt 1 -a -f $dir/cache.$x ]; then + cache_io_opt="$cache_io_opt --read-cache=$dir/cache.$x" + fi + for j in $(seq $num_jobs); do + scp_index=${scp_indexes[$j]} + frame_shift=${frame_shifts[$j]} + + # not implemented yet + $cmd $gpu_cmd_opt $dir/log/train.$x.$j.log \ + nnet3-chain-train2 \ + $parallel_train_opts $verbose_opt \ + $cache_io_opt \ + --use-gpu=$use_gpu --apply-deriv-weights=$apply_deriv_weights \ + --leaky-hmm-coefficient=$leaky_hmm_coefficient --xent-regularize=$xent_regularize \ + --print-interval=$print_interval --max-param-change=$max_param_change \ + --momentum=$momentum \ + --l2-regularize-factor=$inv_num_jobs \ + $l2_regularize_opt \ + --srand=$srand \ + "nnet3-am-init $dir/0_trans.mdl $dir/${x}.raw - | nnet3-am-copy --learning-rate=$lrate - - |" $den_fst_dir \ + "ark:nnet3-chain-copy-egs --frame-shift=$frame_shift scp:$egs_dir/train.$scp_index.scp ark:- | nnet3-chain-shuffle-egs --buffer-size=$shuffle_buffer_size --srand=$x ark:- ark:- | nnet3-chain-merge-egs --minibatch-size=$groups_per_minibatch ark:- ark:-|" \ + ${model_out_prefix}.$j.raw || touch $dir/.error & + done + wait + if [ -f $dir/.error ]; then + echo "$0: error detected training on iteration $x" + exit 1 + fi + if [ $x -ge 1 ]; then + models_to_average=$(for j in `seq $num_jobs`; do echo ${model_out_prefix}.$j.raw; done) + $cmd $dir/log/average.$x.log \ + nnet3-average $models_to_average $dir/$next_x.raw || exit 1; + rm $models_to_average + else + lang=$(echo $langs | awk '{print $1}') + model_index=`steps/nnet3/chain2/internal/get_best_model.sh --output output-${lang} $dir/log/train.$x.*.log` + cp ${model_out_prefix}.$model_index.raw $dir/$next_x.raw + rm ${model_out_prefix}.*.raw + fi + [ -f $dir/$x/.error_diagnostic ] && echo "$0: error getting diagnostics on iter $x" && exit 1; + + # TODO: cleanup + if [ -f $dir/cache.$x ]; then + rm $dir/cache.$x + fi + delete_iter=$[x-2] + if [ $delete_iter -lt $combine_start_iter ]; then + if [ -f $dir/$delete_iter.raw ]; then + rm $dir/$delete_iter.raw + fi + fi + rm $dir/${next_x}_{train,heldout}.mdl + x=$[x+1] +done + + + +if [ $stage -le $num_iters ]; then + echo "$0: doing model combination" + # nnet3-copy --edits="rename-node old-name=output new-name=output-dummy; rename-node old-name=output-default new-name=output" \ + # $dir/$num_iters.mdl $dir/final.raw + # nnet3-am-init $dir/0.mdl $dir/final.raw $dir/final.mdl + # exit 0 + den_fst_dir=$egs_dir/misc + input_models=$(for x in $(seq $combine_start_iter $num_iters); do echo $dir/${x}.raw; done) + output_model_dir=$dir/final + transform_dir=$dir/init + + $cmd $gpu_cmd_opt $dir/log/combine.log \ + nnet3-chain-combine2 --use-gpu=$use_gpu \ + --leaky-hmm-coefficient=$leaky_hmm_coefficient \ + --print-interval=10 \ + $den_fst_dir $input_models \ + "ark:nnet3-chain-merge-egs scp:$egs_dir/train_subset.scp ark:-|" \ + $dir/final.raw || exit 1; + nnet3-copy --edits="rename-node old-name=output new-name=output-dummy; rename-node old-name=output-default new-name=output" \ + $dir/final.raw - | \ + nnet3-am-init $dir/0_trans.mdl - $dir/final.mdl + +fi + +echo "$0: done" +exit 0 diff --git a/egs/wsj/s5/steps/nnet3/chain2/validate_processed_egs.sh b/egs/wsj/s5/steps/nnet3/chain2/validate_processed_egs.sh new file mode 100755 index 000000000..66067f7d9 --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/chain2/validate_processed_egs.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +# Copyright 2019 Johns Hopkins University (Author: Daniel Povey). Apache 2.0. +# Copyright 2019 Idiap Research Institute (Author: Srikanth Madikeri). Apache 2.0. +# +# This script validates a directory containing 'processed' egs for 'chain' +# training, i.e. the output of process_egs.sh. It also helps to document the +# expectations on such a directory. + + +if [ -f path.sh ]; then . ./path.sh; fi + + +if [ $# != 1 ]; then + echo "Usage: $0 " + echo " e.g.: $0 exp/chain/tdnn1a_sp/processed_egs" + echo "" + echo "Validates that the processed-egs dir has the expected format" +fi + +dir=$1 + +# Note: the .ark files are not actually consumed directly downstream (only via +# the top-level .scp files), but we check them anyway for now. +for f in $dir/train.scp $dir/info.txt \ + $dir/heldout_subset.{ark,scp} $dir/train_subset.{ark,scp} \ + $dir/train.1.scp $dir/train.1.ark; do + if ! [ -f $f -a -s $f ]; then + echo "$0: expected file $f to exist and be nonempty." + exit 1 + fi +done + + +if [ $(awk '/^dir_type/ { print $2; }' <$dir/info.txt) != "processed_chain_egs" ]; then + grep dir_type $dir/info.txt + echo "$0: dir_type should be processed_chain_egs in $dir/info.txt" + exit 1 +fi + +lang=$(awk '/^lang / {print $2; }' <$dir/info.txt) + +for f in $dir/misc/$lang.{trans_mdl,normalization.fst,den.fst}; do + if ! [ -f $f -a -s $f ]; then + echo "$0: expected file $f to exist and be nonempty." + exit 1 + fi +done + +echo "$0: sucessfully validated processed egs in $dir" diff --git a/egs/wsj/s5/steps/nnet3/chain2/validate_randomized_egs.sh b/egs/wsj/s5/steps/nnet3/chain2/validate_randomized_egs.sh new file mode 100755 index 000000000..e16755fd2 --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/chain2/validate_randomized_egs.sh @@ -0,0 +1,66 @@ +#!/bin/bash + +# Copyright 2019 Johns Hopkins University (Author: Daniel Povey). Apache 2.0. +# Copyright 2019 Idiap Research Institute (Author: Srikanth Madikeri). Apache 2.0. +# +# This script validates a directory containing 'randomized' egs for 'chain' +# training, i.e. the output of randomize_egs.sh (this is the final form of the +# egs which is consumed by the training script). It also helps to document the +# expectations on such a directory. + + +if [ -f path.sh ]; then . ./path.sh; fi + + +if [ $# != 1 ]; then + echo "Usage: $0 " + echo " e.g.: $0 exp/chain/tdnn1a_sp/egs" + echo "" + echo "Validates that the final (randomized) egs dir has the expected format" +fi + +dir=$1 + +# Note: the .ark files are not actually consumed directly downstream (only via +# the top-level .scp files), but we check them anyway for now. +for f in $dir/train.1.scp $dir/info.txt \ + $dir/heldout_subset.scp $dir/train_subset.scp; do + if ! [ -f $f -a -s $f ]; then + echo "$0: expected file $f to exist and be nonempty." + exit 1 + fi +done + + +if [ $(awk '/^dir_type/ { print $2; }' <$dir/info.txt) != "randomized_chain_egs" ]; then + grep dir_type $dir/info.txt + echo "$0: dir_type should be randomized_chaina_egs in $dir/info.txt" + exit 1 +fi + +langs=$(awk '/^langs / {$1 = ""; print; }' <$dir/info.txt) +num_scp_files=$(awk '/^num_scp_files / { print $2; }' <$dir/info.txt) + +if [ -z "$langs" ]; then + echo "$0: expecting the list of languages to be nonempty in $dir/info.txt" + exit 1 +fi + +for lang in $langs; do + for f in $dir/misc/$lang.{trans_mdl,normalization.fst,den.fst} $dir/info_${lang}.txt; do + if ! [ -f $f -a -s $f ]; then + echo "$0: expected file $f to exist and be nonempty." + exit 1 + fi + done +done + +for i in $(seq $num_scp_files); do + if ! [ -s $dir/train.$i.scp ]; then + echo "$0: expected file $dir/train.$i.scp to exist and be nonempty." + exit 1 + fi +done + + +echo "$0: sucessfully validated randomized egs in $dir" diff --git a/egs/wsj/s5/steps/nnet3/chain2/validate_raw_egs.sh b/egs/wsj/s5/steps/nnet3/chain2/validate_raw_egs.sh new file mode 100755 index 000000000..2c29693bb --- /dev/null +++ b/egs/wsj/s5/steps/nnet3/chain2/validate_raw_egs.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +# Copyright 2019 Johns Hopkins University (Author: Daniel Povey). Apache 2.0. +# Copyright 2019 Idiap Research Institute (Author: Srikanth Madikeri). Apache 2.0. +# +# This script validates a directory containing 'raw' egs for 'chain' training. +# It also helps to document the expectations on such a directory. + + + +if [ -f path.sh ]; then . ./path.sh; fi + + +if [ $# != 1 ]; then + echo "Usage: $0 " + echo " e.g.: $0 exp/chaina/tdnn1a_sp/raw_egs" + echo "" + echo "Validates that the raw-egs dir has the expected format" +fi + +dir=$1 + +for f in $dir/all.scp $dir/cegs.1.ark $dir/info.txt \ + $dir/misc/utt2spk; do + if ! [ -s $f ]; then + echo "$0: expected file $f to exist and be nonempty." + exit 1 + fi +done + + +if [ $(awk '/^dir_type/ { print $2; }' <$dir/info.txt) != "raw_chain_egs" ]; then + grep dir_type $dir/info.txt + echo "$0: dir_type should be raw_chain_egs in $dir/info.txt" + exit 1 +fi + +lang=$(awk '/^lang / {print $2; }' <$dir/info.txt) + +for f in $dir/misc/$lang.{trans_mdl,normalization.fst,den.fst}; do + if ! [ -s $f ]; then + echo "$0: expected file $f to exist and be nonempty." + exit 1 + fi +done + +echo "$0: sucessfully validated raw egs in $dir" diff --git a/egs/wsj/s5/steps/nnet3/convert_nnet2_to_nnet3.py b/egs/wsj/s5/steps/nnet3/convert_nnet2_to_nnet3.py index edc2f7e46..d08925899 100755 --- a/egs/wsj/s5/steps/nnet3/convert_nnet2_to_nnet3.py +++ b/egs/wsj/s5/steps/nnet3/convert_nnet2_to_nnet3.py @@ -100,7 +100,7 @@ class Nnet3Model(object): def __init__(self): self.input_dim = -1 self.output_dim = -1 - self.ivector_dim = -1 + self.ivector_dim = 0 self.counts = defaultdict(int) self.num_components = 0 self.components_read = 0 @@ -121,7 +121,7 @@ def add_component(self, component, pairs): if "" in pairs and self.input_dim == -1: self.input_dim = int(pairs[""]) - if "" in pairs and self.ivector_dim == -1: + if "" in pairs and self.ivector_dim == 0: self.ivector_dim = int(pairs[""]) # remove nnet2 specific tokens and catch descriptors @@ -163,7 +163,7 @@ def write_config(self, filename): config_string=config_string)) f.write("\n# Component nodes\n") - if self.ivector_dim != -1: + if self.ivector_dim != 0: f.write("input-node name=input dim={0}\n".format(self.input_dim-self.ivector_dim)) f.write("input-node name=ivector dim={0}\n".format(self.ivector_dim)) else: @@ -294,7 +294,7 @@ def parse_component(line, line_buffer): def parse_standard_component(component, line, line_buffer): # Ignores stats such as ValueSum and DerivSum line = consume_token(component, line) - pairs = re.findall("(<\w+>) ([\w.]+)", line) + pairs = re.findall("(<\w+>) ([\w.-]+)", line) return dict(pairs) @@ -364,7 +364,7 @@ def parse_end_of_component(component, line, line_buffer): def parse_affine_component(component, line, line_buffer): assert ("" in line) - pairs = dict(re.findall("(<\w+>) ([\w.]+)", line)) + pairs = dict(re.findall("(<\w+>) ([\w.-]+)", line)) # read the linear params and bias and convert it to a matrix weights = parse_weights(line_buffer) diff --git a/egs/wsj/s5/steps/nnet3/decode.sh b/egs/wsj/s5/steps/nnet3/decode.sh index bbd81a6db..c0b539a63 100755 --- a/egs/wsj/s5/steps/nnet3/decode.sh +++ b/egs/wsj/s5/steps/nnet3/decode.sh @@ -79,7 +79,11 @@ for f in $graphdir/HCLG.fst $data/feats.scp $model $extra_files; do done sdata=$data/split$nj; -cmvn_opts=`cat $srcdir/cmvn_opts` || exit 1; +if [ -f $srcdir/cmvn_opts ]; then + cmvn_opts=`cat $srcdir/cmvn_opts` +else + cmvn_opts="--norm-means=false --norm-vars=false" +fi thread_string= if $use_gpu; then if [ $num_threads -eq 1 ]; then @@ -122,6 +126,11 @@ frame_subsampling_opt= if [ -f $srcdir/frame_subsampling_factor ]; then # e.g. for 'chain' systems frame_subsampling_opt="--frame-subsampling-factor=$(cat $srcdir/frame_subsampling_factor)" +elif [ -f $srcdir/init/info.txt ]; then + frame_subsampling_factor=$(awk '/^frame_subsampling_factor/ {print $2}' <$srcdir/init/info.txt) + if [ ! -z $frame_subsampling_factor ]; then + frame_subsampling_opt="--frame-subsampling-factor=$frame_subsampling_factor" + fi fi if [ $stage -le 1 ]; then diff --git a/egs/wsj/s5/steps/segmentation/prepare_targets_gmm.sh b/egs/wsj/s5/steps/segmentation/prepare_targets_gmm.sh index 20bcfd96d..76025f4a3 100755 --- a/egs/wsj/s5/steps/segmentation/prepare_targets_gmm.sh +++ b/egs/wsj/s5/steps/segmentation/prepare_targets_gmm.sh @@ -46,6 +46,7 @@ overlap_duration=2.5 max_remaining_duration=5 # If the last remaining piece when splitting uniformly # is smaller than this duration, then the last piece # is merged with the previous. +remove_mismatch_frames=true # List of weights on labels obtained from alignment, # labels obtained from decoding and default labels in out-of-segment regions @@ -108,7 +109,7 @@ for f in $in_whole_data_dir/feats.scp $in_data_dir/segments \ fi done -utils/validate_data_dir.sh $in_data_dir || exit 1 +utils/validate_data_dir.sh --no-feats $in_data_dir || exit 1 utils/validate_data_dir.sh --no-text $in_whole_data_dir || exit 1 if ! cat $garbage_phones_list $silence_phones_list | \ @@ -159,7 +160,7 @@ whole_data_dir=$dir/$whole_data_id # Obtain supervision-constrained lattices ############################################################################### sup_lats_dir=$dir/`basename ${ali_model_dir}`_sup_lats_${data_id} -if [ $stage -le 2 ]; then +if [ $stage -le 3 ]; then steps/align_fmllr_lats.sh --nj $nj --cmd "$train_cmd" \ ${data_dir} ${lang} ${ali_model_dir} $sup_lats_dir || exit 1 fi @@ -170,7 +171,7 @@ fi uniform_seg_data_dir=$dir/${whole_data_id}_uniformseg_${max_segment_duration}sec uniform_seg_data_id=`basename $uniform_seg_data_dir` -if [ $stage -le 3 ]; then +if [ $stage -le 4 ]; then utils/data/get_segments_for_data.sh ${whole_data_dir} > \ ${whole_data_dir}/segments @@ -193,7 +194,7 @@ model_id=$(basename $model_dir) ############################################################################### if [ -z "$graph_dir" ]; then graph_dir=$dir/$model_id/graph - if [ $stage -le 4 ]; then + if [ $stage -le 5 ]; then if [ ! -f $graph_dir/HCLG.fst ]; then rm -r $dir/lang_test 2>/dev/null || true cp -r $lang_test/ $dir/lang_test @@ -207,7 +208,7 @@ fi ############################################################################### model_id=$(basename $model_dir) decode_dir=$dir/${model_id}/decode_${uniform_seg_data_id} -if [ $stage -le 5 ]; then +if [ $stage -le 6 ]; then mkdir -p $decode_dir cp $model_dir/{final.mdl,final.mat,*_opts,tree} $dir/${model_id} @@ -228,7 +229,7 @@ ali_model_id=`basename $ali_model_dir` # The target values are obtained by summing up posterior probabilites of # arcs from lattice-arc-post over silence, speech and garbage phones. ############################################################################### -if [ $stage -le 6 ]; then +if [ $stage -le 7 ]; then steps/segmentation/lats_to_targets.sh --cmd "$train_cmd" \ --silence-phones "$silence_phones_list" \ --garbage-phones "$garbage_phones_list" \ @@ -237,7 +238,7 @@ if [ $stage -le 6 ]; then $dir/${ali_model_id}_${data_id}_sup_targets fi -if [ $stage -le 7 ]; then +if [ $stage -le 8 ]; then steps/segmentation/lats_to_targets.sh --cmd "$train_cmd" \ --silence-phones "$silence_phones_list" \ --garbage-phones "$garbage_phones_list" \ @@ -253,7 +254,7 @@ fi # for the manual segments, these are converted to whole recording-levels # by inserting [ 0 0 0 ] for the out-of-manual segment regions. ############################################################################### -if [ $stage -le 8 ]; then +if [ $stage -le 9 ]; then steps/segmentation/convert_targets_dir_to_whole_recording.sh --cmd "$train_cmd" --nj $reco_nj \ $data_dir $whole_data_dir \ $dir/${ali_model_id}_${data_id}_sup_targets \ @@ -268,7 +269,7 @@ fi ############################################################################### # Convert the targets from decoding to whole recording. ############################################################################### -if [ $stage -le 9 ]; then +if [ $stage -le 10 ]; then steps/segmentation/convert_targets_dir_to_whole_recording.sh --cmd "$train_cmd" --nj $reco_nj \ $dir/${uniform_seg_data_id} $whole_data_dir \ $dir/${model_id}_${uniform_seg_data_id}_targets \ @@ -285,7 +286,7 @@ fi # We assume in this setup that this is silence i.e. [ 1 0 0 ]. ############################################################################### -if [ $stage -le 10 ]; then +if [ $stage -le 11 ]; then echo " [ 1 0 0 ]" > $dir/default_targets.vec steps/segmentation/get_targets_for_out_of_segments.sh --cmd "$train_cmd" \ --nj $reco_nj --frame-subsampling-factor 3 \ @@ -301,9 +302,9 @@ fi # disagree (more than 0.5 probability on different classes), then those frames # are removed by setting targets to [ 0 0 0 ]. ############################################################################### -if [ $stage -le 11 ]; then +if [ $stage -le 12 ]; then steps/segmentation/merge_targets_dirs.sh --cmd "$train_cmd" --nj $reco_nj \ - --weights $merge_weights --remove-mismatch-frames true \ + --weights $merge_weights --remove-mismatch-frames $remove_mismatch_frames \ $whole_data_dir \ $dir/${ali_model_id}_${whole_data_id}_sup_targets_sub3 \ $dir/${model_id}_${whole_data_id}_targets_sub3 \ diff --git a/egs/wsj/s5/utils/filter_scps.pl b/egs/wsj/s5/utils/filter_scps.pl index d701f5fd2..418f8f73e 100755 --- a/egs/wsj/s5/utils/filter_scps.pl +++ b/egs/wsj/s5/utils/filter_scps.pl @@ -90,6 +90,7 @@ # Some variables that we set to produce a warning. $warn_uncovered = 0; +$warn_multiply_covered = 0; for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) { $idlist_n = $idlist; @@ -132,6 +133,9 @@ $warn_uncovered = 1; } else { @jobs = @{$id2jobs{$id}}; # this dereferences the array reference. + if (@jobs > 1) { + $warn_multiply_covered = 1; + } foreach $job_id (@jobs) { if (!defined $job2output{$job_id}) { die "Likely code error"; @@ -160,3 +164,7 @@ if ($warn_uncovered && $print_warnings) { print STDERR "filter_scps.pl: warning: some input lines did not get output\n"; } +if ($warn_multiply_covered && $print_warnings) { + print STDERR "filter_scps.pl: warning: some input lines were output to multiple files [OK if splitting per utt] " . + join(" ", @ARGV) . "\n"; +} diff --git a/egs/wsj/s5/utils/parallel/slurm.pl b/egs/wsj/s5/utils/parallel/slurm.pl index cfa634aeb..4a2a3b7c4 100755 --- a/egs/wsj/s5/utils/parallel/slurm.pl +++ b/egs/wsj/s5/utils/parallel/slurm.pl @@ -180,9 +180,10 @@ sub exec_command { default gpu=0 option gpu=0 -p shared option gpu=* -p gpu --gres=gpu:$0 --time 4:0:0 # this has to be figured out +EOF + # note: the --max-jobs-run option is supported as a special case # by slurm.pl and you don't have to handle it in the config file. -EOF # Here the configuration options specified by the user on the command line # (e.g. --mem 2G) are converted to options to the qsub system as defined in diff --git a/egs/wsj/s5/utils/split_data.sh b/egs/wsj/s5/utils/split_data.sh index a3105351a..bc5894e75 100755 --- a/egs/wsj/s5/utils/split_data.sh +++ b/egs/wsj/s5/utils/split_data.sh @@ -67,11 +67,6 @@ if [ -f $data/text ] && [ $nu -ne $nt ]; then echo "** use utils/fix_data_dir.sh to fix this." fi -ns=`cat $data/spk2utt | wc -l` -if [ $numsplit -gt $ns ] && [ $split_per_spk = "true" ]; then - echo "You should reduce the number of jobs ($numsplit) as there are not enough speakers ($ns)." - exit 1 -fi if $split_per_spk; then utt2spk_opt="--utt2spk=$data/utt2spk" @@ -81,11 +76,6 @@ else utt="utt" fi -utt2dur_opt= -if [ -f $data/utt2dur ]; then - utt2dur_opt="--utt2dur=$data/utt2dur" -fi - s1=$data/split${numsplit}${utt}/1 if [ ! -d $s1 ]; then need_to_split=true @@ -118,7 +108,7 @@ fi which lockfile >&/dev/null && lockfile -l 60 $data/.split_lock trap 'rm -f $data/.split_lock' EXIT HUP INT PIPE TERM -utils/split_scp.pl $utt2spk_opt $utt2dur_opt $data/utt2spk $utt2spks || exit 1 +utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1 for n in `seq $numsplit`; do dsn=$data/split${numsplit}${utt}/$n diff --git a/egs/wsj/s5/utils/split_scp.pl b/egs/wsj/s5/utils/split_scp.pl index 3ca14dbea..dc798282f 100755 --- a/egs/wsj/s5/utils/split_scp.pl +++ b/egs/wsj/s5/utils/split_scp.pl @@ -47,7 +47,6 @@ $num_jobs = 0; $job_id = 0; $utt2spk_file = ""; -$utt2dur_file = ""; $one_based = 0; for ($x = 1; $x <= 3 && @ARGV > 0; $x++) { @@ -60,12 +59,6 @@ $utt2spk_file=$1; shift; } - - if ($ARGV[0] =~ "--utt2dur=(.+)") { - $utt2dur_file=$1; - shift; - } - if ($ARGV[0] eq '--one-based') { $one_based = 1; shift @ARGV; @@ -76,7 +69,6 @@ $job_id - $one_based >= $num_jobs)) { die "$0: Invalid job number/index values for '-j $num_jobs $job_id" . ($one_based ? " --one-based" : "") . "'\n" - } $one_based @@ -84,8 +76,8 @@ if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) { die -"Usage: split_scp.pl [--utt2spk=] [--utt2dur=] in.scp out1.scp out2.scp ... - or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=] [--utt2dur=] in.scp [out.scp] +"Usage: split_scp.pl [--utt2spk=] in.scp out1.scp out2.scp ... + or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=] in.scp [out.scp] ... where 0 <= job-id < num-jobs, or 1 <= job-id <- num-jobs if --one-based.\n"; } @@ -103,113 +95,8 @@ } } } -if ($utt2spk_file ne "" && $utt2dur_file ne "" ) { # --utt2spk and --utt2dur - open(U, "<$utt2spk_file") || die "Failed to open utt2spk file $utt2spk_file"; - while() { - @A = split; - @A == 2 || die "Bad line $_ in utt2spk file $utt2spk_file"; - ($u,$s) = @A; - $utt2spk{$u} = $s; - } - $dursum = 0.0; - open(U, "<$utt2dur_file") || die "Failed to open utt2dur file $utt2dur_file"; - while() { - @A = split; - @A == 2 || die "Bad line $_ in utt2spk file $utt2dur_file"; - ($u,$d) = @A; - $utt2dur{$u} = $d; - $dursum += $d; - } - open(I, "<$inscp") || die "Opening input scp file $inscp"; - @spkrs = (); - while() { - @A = split; - if(@A == 0) { die "Empty or space-only line in scp file $inscp"; } - $u = $A[0]; - $s = $utt2spk{$u}; - if(!defined $s) { die "No such utterance $u in utt2spk file $utt2spk_file"; } - if(!defined $spk_count{$s}) { - push @spkrs, $s; - $spk_count{$s} = 0; - $spk_data{$s} = []; # ref to new empty array. - } - if(!defined $spk2utt{$s}) { - $spk2utt{$s} = []; - } - $spk_count{$s}++; - push @{$spk_data{$s}}, $_; - push @{$spk2utt{$s}}, $u; - } - - $numspks = @spkrs; # number of speakers. - $numscps = @OUTPUTS; # number of output files. - if ($numspks < $numscps) { - die "Refusing to split data because number of speakers $numspks is less " . - "than the number of output .scp files $numscps"; - } - for($scpidx = 0; $scpidx < $numscps; $scpidx++) { - $scparray[$scpidx] = []; # [] is array reference. - } - $splitdur = $dursum / $numscps; - $dursum = 0.0; - $scpidx = 0; - for my $spk (sort (keys %spk2utt)) { - $scpcount[$scpidx] += $spk_count{$spk}; - push @{$scparray[$scpidx]}, $spk; - for my $utt (@{$spk2utt{$spk}}) { - $dur = $utt2dur{$utt}; - $dursum += $dur; - } - if ( $dursum >= $splitdur ) { - $scpidx += 1; - $dursum = 0.0; - } - } - - # Because scpidx might not have gone up to numscps (because all utts from one - # speaker go into one split means a major imbalance will mean not all splits - # are filled), move one speaker inside scparray to the indices which don't have - # any. - if ( $scpidx + 1 < $numscps || @{$scparray[$scpidx]} == 0 ) { - $scpdone = $scpidx; - if ( @{$scparray[$scpidx]} == 0 ) { - $scpdone -= 1; - } - for(; $scpidx < $numscps; $scpidx++) { - $i = 0; - for(; $i < $scpdone; $i++) { - $numspk = @{$scparray[$i]}; - if ($numspk > 1) { - last; - } - } - $spk = pop @{$scparray[$i]}; - $scpcount[$i] -= $spk_count{$spk}; - - push @{$scparray[$scpidx]}, $spk; - $scpcount[$scpidx] += $spk_count{$spk}; - } - } - - # Now print out the files... - for($scpidx = 0; $scpidx < $numscps; $scpidx++) { - $scpfn = $OUTPUTS[$scpidx]; - open(F, ">$scpfn") || die "Could not open scp file $scpfn for writing."; - $count = 0; - if(@{$scparray[$scpidx]} == 0) { - print STDERR "Error: split_scp.pl producing empty .scp file $scpfn (too many splits and too few speakers?)\n"; - $error = 1; - } else { - foreach $spk ( sort @{$scparray[$scpidx]} ) { - print F @{$spk_data{$spk}}; - $count += $spk_count{$spk}; - } - if($count != $scpcount[$scpidx]) { die "Count mismatch [code error]"; } - } - close(F); - } -} elsif ($utt2spk_file ne "") { # We have the --utt2spk option... +if ($utt2spk_file ne "") { # We have the --utt2spk option... open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n"; while(<$u_fh>) { @A = split; diff --git a/src/base/kaldi-error.cc b/src/base/kaldi-error.cc index 2dbc73182..12f972ee8 100644 --- a/src/base/kaldi-error.cc +++ b/src/base/kaldi-error.cc @@ -33,7 +33,11 @@ #include "base/kaldi-common.h" #include "base/kaldi-error.h" + +// KALDI_GIT_HEAD is useless currently in full repo +#if !defined(KALDI_VERSION) #include "base/version.h" +#endif namespace kaldi { diff --git a/src/bin/Makefile b/src/bin/Makefile index 7cb01b501..a04a84e21 100644 --- a/src/bin/Makefile +++ b/src/bin/Makefile @@ -22,7 +22,7 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \ matrix-sum build-pfile-from-ali get-post-on-ali tree-info am-info \ vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \ transform-vec align-text matrix-dim post-to-smat compile-graph \ - compare-int-vector + compare-int-vector latgen-incremental-mapped compute-gop OBJFILES = diff --git a/src/bin/compute-gop.cc b/src/bin/compute-gop.cc new file mode 100644 index 000000000..63b42212e --- /dev/null +++ b/src/bin/compute-gop.cc @@ -0,0 +1,227 @@ +// bin/compute-gop.cc + +// Copyright 2019 Junbo Zhang + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +/** + This code computes Goodness of Pronunciation (GOP) and extracts phone-level + pronunciation feature for mispronunciations detection tasks, the reference: + + "Improved mispronunciation detection with deep neural network trained acoustic + models and transfer learning based logistic regression classifiers" + by Hu et al., Speech Comunication, 2015. + + GOP is widely used to detect mispronunciations. The DNN-based GOP was defined + as the log phone posterior ratio between the canonical phone and the one with + the highest score. + + To compute GOP, we need to compute Log Phone Posterior (LPP): + LPP(p) = \log p(p|\mathbf o; t_s,t_e) + where {\mathbf o} is the input observations, p is the canonical phone, + {t_s, t_e} are the start and end frame indexes. + + LPP could be calculated as the average of the frame-level LPP, i.e. p(p|o_t): + LPP(p) = \frac{1}{t_e-t_s+1} \sum_{t=t_s}^{t_e}\log p(p|o_t) + p(p|o_t) = \sum_{s \in p} p(s|o_t) + where s is the senone label, {s|s \in p} is the states belonging to those + triphones whose current phone is p. + + GOP is extracted from LPP: + GOP(p) = \log \frac{LPP(p)}{\max_{q\in Q} LPP(q)} + + An array of a phone-level feature for each phone is extracted as well, which + could be used to train a classifier to detect mispronunciations. Normally the + classifier-based approach archives better performance than the GOP-based approach. + + The phone-level feature is defined as: + {[LPP(p_1),\cdots,LPP(p_M), LPR(p_1|p_i), \cdots, LPR(p_j|p_i),\cdots]}^T + + where the Log Posterior Ratio (LPR) between phone p_j and p_i is defined as: + LPR(p_j|p_i) = \log p(p_j|\mathbf o; t_s, t_e) - \log p(p_i|\mathbf o; t_s, t_e) + */ + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "hmm/transition-model.h" +#include "hmm/hmm-utils.h" +#include "hmm/tree-accu.h" +#include "hmm/posterior.h" + +namespace kaldi { + +/** FrameLevelLpp compute a log posterior for pure-phones by sum the posterior + of the states belonging to those triphones whose current phone is the canonical + phone: + + p(p|o_t) = \sum_{s \in p} p(s|o_t), + + where s is the senone label, {s|s \in p} is the states belonging to those + riphones whose current phone is the canonical phone p. + + */ +void FrameLevelLpp(const SubVector &prob_row, + const std::vector > &pdf2phones, + const std::vector *phone_map, + Vector *out_frame_level_lpp) { + for (int32 i = 0; i < prob_row.Dim(); i++) { + std::set dest_idxs; + for (int32 ph : pdf2phones.at(i)) { + dest_idxs.insert((phone_map != NULL) ? (*phone_map)[ph] - 1 : ph - 1); + } + + for (int32 idx : dest_idxs) { + KALDI_ASSERT(idx < out_frame_level_lpp->Dim()); + (*out_frame_level_lpp)(idx) += prob_row(i); + } + } + out_frame_level_lpp->ApplyLog(); +} + +} // namespace kaldi + +int main(int argc, char *argv[]) { + using namespace kaldi; + typedef kaldi::int32 int32; + try { + const char *usage = + "Compute Goodness Of Pronunciation (GOP) from a matrix of " + "probabilities (e.g. from nnet3-compute).\n" + "Usage: compute-gop [options] " + " " + "[]\n" + "e.g.:\n" + " nnet3-compute [args] | compute-gop 1.mdl ark:ali-phone.1 ark:-" + " ark:gop.1 ark:phone-feat.1\n"; + + ParseOptions po(usage); + + bool log_applied = true; + std::string phone_map_rxfilename; + + po.Register("log-applied", &log_applied, + "If true, assume the input probabilities have been applied log."); + po.Register("phone-map", &phone_map_rxfilename, + "File name containing old->new phone mapping (each line is: " + "old-integer-id new-integer-id)"); + + po.Read(argc, argv); + + if (po.NumArgs() != 4 && po.NumArgs() != 5) { + po.PrintUsage(); + exit(1); + } + + std::string model_filename = po.GetArg(1), + alignments_rspecifier = po.GetArg(2), + prob_rspecifier = po.GetArg(3), + gop_wspecifier = po.GetArg(4), + feat_wspecifier = po.GetArg(5); + + TransitionModel trans_model; + { + bool binary; + Input ki(model_filename, &binary); + trans_model.Read(ki.Stream(), binary); + } + std::vector > pdf2phones; + GetPdfToPhonesMap(trans_model, &pdf2phones); + int32 phone_num = trans_model.NumPhones(); + + std::vector phone_map; + if (phone_map_rxfilename != "") { + ReadPhoneMap(phone_map_rxfilename, &phone_map); + phone_num = phone_map[phone_map.size() - 1]; + } + + RandomAccessInt32VectorReader alignment_reader(alignments_rspecifier); + SequentialBaseFloatMatrixReader prob_reader(prob_rspecifier); + PosteriorWriter gop_writer(gop_wspecifier); + BaseFloatMatrixWriter feat_writer(feat_wspecifier); + + int32 num_done = 0; + for (; !prob_reader.Done(); prob_reader.Next()) { + std::string key = prob_reader.Key(); + auto alignment = alignment_reader.Value(key); + Matrix &probs = prob_reader.Value(); + if (log_applied) probs.ApplyExp(); + + int32 frame_num = alignment.size(); + if (alignment.size() != probs.NumRows()) { + KALDI_WARN << "The frame numbers of alignment and prob are not equal."; + if (frame_num > probs.NumRows()) frame_num = probs.NumRows(); + } + + KALDI_ASSERT(frame_num > 0); + int32 cur_phone_id = alignment[0] - 1; // start by 0, skipping + int32 duration = 0; + Vector phone_level_feat(phone_num * 2); // LPPs and LPRs + SubVector lpp_part(phone_level_feat, 0, phone_num); + std::vector > phone_level_feat_stdvector; + Posterior posterior_gop; + for (int32 i = 0; i < frame_num; i++) { + // Calculate LPP and LPR for each pure-phone + Vector frame_level_lpp(phone_num); + FrameLevelLpp(probs.Row(i), pdf2phones, + (phone_map_rxfilename != "") ? &phone_map : NULL, + &frame_level_lpp); + + // LPP(p)=\frac{1}{t_e-t_s+1} \sum_{t=t_s}^{t_e}\log p(p|o_t) + lpp_part.AddVec(1, frame_level_lpp); + duration++; + + int32 next_phone_id = (i < frame_num - 1) ? alignment[i + 1] - 1: -1; + if (next_phone_id != cur_phone_id) { + // The current phone's feature have been ready + lpp_part.Scale(1.0 / duration); + + // LPR(p_j|p_i)=\log p(p_j|\mathbf o; t_s, t_e)-\log p(p_i|\mathbf o; t_s, t_e) + for (int k = 0; k < phone_num; k++) + phone_level_feat(phone_num + k) = lpp_part(cur_phone_id) - lpp_part(k); + phone_level_feat_stdvector.push_back(phone_level_feat); + + // Compute GOP from LPP + // GOP(p)=\log \frac{LPP(p)}{\max_{q\in Q} LPP(q)} + BaseFloat gop = lpp_part(cur_phone_id) - lpp_part.Max(); + std::vector > posterior_item; + posterior_item.push_back(std::make_pair(cur_phone_id + 1, gop)); + posterior_gop.push_back(posterior_item); + + // Reset + phone_level_feat.Set(0); + duration = 0; + } + cur_phone_id = next_phone_id; + } + + // Write GOPs and the phone-level features + Matrix feats(phone_level_feat_stdvector.size(), phone_num * 2); + for (int32 i = 0; i < phone_level_feat_stdvector.size(); i++) { + SubVector row(feats, i); + row.AddVec(1.0, phone_level_feat_stdvector[i]); + } + feat_writer.Write(key, feats); + gop_writer.Write(key, posterior_gop); + num_done++; + } + + KALDI_LOG << "Processed " << num_done << " prob matrices."; + return (num_done != 0 ? 0 : 1); + } catch (const std::exception &e) { + std::cerr << e.what() << '\n'; + return -1; + } +} diff --git a/src/bin/latgen-incremental-mapped.cc b/src/bin/latgen-incremental-mapped.cc new file mode 100644 index 000000000..80c65bfb5 --- /dev/null +++ b/src/bin/latgen-incremental-mapped.cc @@ -0,0 +1,183 @@ +// bin/latgen-incremental-mapped.cc + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "tree/context-dep.h" +#include "hmm/transition-model.h" +#include "fstext/fstext-lib.h" +#include "decoder/decoder-wrappers.h" +#include "decoder/decodable-matrix.h" +#include "base/timer.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + using fst::SymbolTable; + using fst::Fst; + using fst::StdArc; + + const char *usage = + "Generate lattices, reading log-likelihoods as matrices\n" + " (model is needed only for the integer mappings in its transition-model)\n" + "The lattice determinization algorithm here can operate\n" + "incrementally.\n" + "Usage: latgen-incremental-mapped [options] trans-model-in " + "(fst-in|fsts-rspecifier) loglikes-rspecifier" + " lattice-wspecifier [ words-wspecifier [alignments-wspecifier] ]\n"; + ParseOptions po(usage); + Timer timer; + bool allow_partial = false; + BaseFloat acoustic_scale = 0.1; + LatticeIncrementalDecoderConfig config; + + std::string word_syms_filename; + config.Register(&po); + po.Register("acoustic-scale", &acoustic_scale, + "Scaling factor for acoustic likelihoods"); + + po.Register("word-symbol-table", &word_syms_filename, + "Symbol table for words [for debug output]"); + po.Register("allow-partial", &allow_partial, + "If true, produce output even if end state was not reached."); + + po.Read(argc, argv); + + if (po.NumArgs() < 4 || po.NumArgs() > 6) { + po.PrintUsage(); + exit(1); + } + + std::string model_in_filename = po.GetArg(1), fst_in_str = po.GetArg(2), + feature_rspecifier = po.GetArg(3), lattice_wspecifier = po.GetArg(4), + words_wspecifier = po.GetOptArg(5), + alignment_wspecifier = po.GetOptArg(6); + + TransitionModel trans_model; + ReadKaldiObject(model_in_filename, &trans_model); + + bool determinize = true; + CompactLatticeWriter compact_lattice_writer; + LatticeWriter lattice_writer; + if (!(determinize ? compact_lattice_writer.Open(lattice_wspecifier) + : lattice_writer.Open(lattice_wspecifier))) + KALDI_ERR << "Could not open table for writing lattices: " + << lattice_wspecifier; + + Int32VectorWriter words_writer(words_wspecifier); + + Int32VectorWriter alignment_writer(alignment_wspecifier); + + fst::SymbolTable *word_syms = NULL; + if (word_syms_filename != "") + if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename))) + KALDI_ERR << "Could not read symbol table from file " << word_syms_filename; + + double tot_like = 0.0; + kaldi::int64 frame_count = 0; + int num_success = 0, num_fail = 0; + + if (ClassifyRspecifier(fst_in_str, NULL, NULL) == kNoRspecifier) { + SequentialBaseFloatMatrixReader loglike_reader(feature_rspecifier); + // Input FST is just one FST, not a table of FSTs. + Fst *decode_fst = fst::ReadFstKaldiGeneric(fst_in_str); + timer.Reset(); + + { + LatticeIncrementalDecoder decoder(*decode_fst, trans_model, config); + + for (; !loglike_reader.Done(); loglike_reader.Next()) { + std::string utt = loglike_reader.Key(); + Matrix loglikes(loglike_reader.Value()); + loglike_reader.FreeCurrent(); + if (loglikes.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_fail++; + continue; + } + + DecodableMatrixScaledMapped decodable(trans_model, loglikes, + acoustic_scale); + + double like; + if (DecodeUtteranceLatticeIncremental( + decoder, decodable, trans_model, word_syms, utt, acoustic_scale, + determinize, allow_partial, &alignment_writer, &words_writer, + &compact_lattice_writer, &lattice_writer, &like)) { + tot_like += like; + frame_count += loglikes.NumRows(); + num_success++; + } else { + num_fail++; + } + } + } + delete decode_fst; // delete this only after decoder goes out of scope. + } else { // We have different FSTs for different utterances. + SequentialTableReader fst_reader(fst_in_str); + RandomAccessBaseFloatMatrixReader loglike_reader(feature_rspecifier); + for (; !fst_reader.Done(); fst_reader.Next()) { + std::string utt = fst_reader.Key(); + if (!loglike_reader.HasKey(utt)) { + KALDI_WARN << "Not decoding utterance " << utt + << " because no loglikes available."; + num_fail++; + continue; + } + const Matrix &loglikes = loglike_reader.Value(utt); + if (loglikes.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << utt; + num_fail++; + continue; + } + LatticeIncrementalDecoder decoder(fst_reader.Value(), trans_model, config); + DecodableMatrixScaledMapped decodable(trans_model, loglikes, acoustic_scale); + double like; + if (DecodeUtteranceLatticeIncremental( + decoder, decodable, trans_model, word_syms, utt, acoustic_scale, + determinize, allow_partial, &alignment_writer, &words_writer, + &compact_lattice_writer, &lattice_writer, &like)) { + tot_like += like; + frame_count += loglikes.NumRows(); + num_success++; + } else { + num_fail++; + } + } + } + + double elapsed = timer.Elapsed(); + KALDI_LOG << "Time taken " << elapsed + << "s: real-time factor assuming 100 frames/sec is " + << (elapsed * 100.0 / frame_count); + KALDI_LOG << "Done " << num_success << " utterances, failed for " << num_fail; + KALDI_LOG << "Overall log-likelihood per frame is " << (tot_like / frame_count) + << " over " << frame_count << " frames."; + + delete word_syms; + if (num_success != 0) + return 0; + else + return 1; + } catch (const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/chainbin/Makefile b/src/chainbin/Makefile index 41ac7342d..1438cfa71 100644 --- a/src/chainbin/Makefile +++ b/src/chainbin/Makefile @@ -11,7 +11,8 @@ BINFILES = chain-est-phone-lm chain-get-supervision chain-make-den-fst \ nnet3-chain-shuffle-egs nnet3-chain-subset-egs \ nnet3-chain-acc-lda-stats nnet3-chain-train nnet3-chain-compute-prob \ nnet3-chain-combine nnet3-chain-normalize-egs \ - nnet3-chain-e2e-get-egs nnet3-chain-compute-post + nnet3-chain-e2e-get-egs nnet3-chain-compute-post \ + nnet3-chain-train2 OBJFILES = diff --git a/src/chainbin/nnet3-chain-combine2.cc b/src/chainbin/nnet3-chain-combine2.cc new file mode 100644 index 000000000..58f3c2013 --- /dev/null +++ b/src/chainbin/nnet3-chain-combine2.cc @@ -0,0 +1,221 @@ +// chainbin/nnet3-chain-combine.cc + +// Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey) +// 2017 Yiming Wang +// 2019 Srikanth Madikeri (Idiap Research Institute) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "nnet3/nnet-utils.h" +#include "nnet3/nnet-compute.h" +#include "nnet3/nnet-chain-diagnostics.h" +#include "nnet3/nnet-chain-diagnostics2.h" + + +namespace kaldi { +namespace nnet3 { + +// Computes and returns the objective function for the examples in 'egs' given +// the model in 'nnet'. If either of batchnorm/dropout test modes is true, we +// make a copy of 'nnet', set test modes on that and evaluate its objective. +// Note: the object that prob_computer->nnet_ refers to should be 'nnet'. +double ComputeObjf(bool batchnorm_test_mode, bool dropout_test_mode, + std::vector &egs, const Nnet &nnet, + const chain::ChainTrainingOptions &chain_config, + NnetChainModel2 &model, + NnetChainComputeProb2 *prob_computer) { + if (batchnorm_test_mode || dropout_test_mode) { + Nnet nnet_copy(nnet); + if (batchnorm_test_mode) + SetBatchnormTestMode(true, &nnet_copy); + if (dropout_test_mode) + SetDropoutTestMode(true, &nnet_copy); + NnetComputeProbOptions compute_prob_opts; + NnetChainComputeProb2 prob_computer_test(compute_prob_opts, chain_config, + model, nnet_copy); + return ComputeObjf(false, false, egs, nnet_copy, + chain_config, model, &prob_computer_test); + } else { + prob_computer->Reset(); + std::vector::iterator iter = egs.begin(), + end = egs.end(); + for (; iter != end; ++iter) + prob_computer->Compute(*iter); + + double tot_weight = 0.0; + double tot_objf = prob_computer->GetTotalObjective(&tot_weight); + + KALDI_ASSERT(tot_weight > 0.0); + // inf/nan tot_objf->return -inf objective. + if (!(tot_objf == tot_objf && tot_objf - tot_objf == 0)) + return -std::numeric_limits::infinity(); + // we prefer to deal with normalized objective functions. + return tot_objf / tot_weight; + } +} + +// Updates moving average over num_models nnets, given the average over +// previous (num_models - 1) nnets, and the new nnet. +void UpdateNnetMovingAverage(int32 num_models, + const Nnet &nnet, Nnet *moving_average_nnet) { + KALDI_ASSERT(NumParameters(nnet) == NumParameters(*moving_average_nnet)); + ScaleNnet((num_models - 1.0) / num_models, moving_average_nnet); + AddNnet(nnet, 1.0 / num_models, moving_average_nnet); +} + +} +} + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace kaldi::nnet3; + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Using a subset of training or held-out nnet3+chain examples, compute\n" + "the average over the first n nnet models where we maximize the\n" + "'chain' objective function for n. Note that the order of models has\n" + "been reversed before feeding into this binary. So we are actually\n" + "combining last n models.\n" + "Inputs and outputs are nnet3 raw nnets.\n" + "\n" + "Usage: nnet3-chain-combine [options] ... \n" + "\n" + "e.g.:\n" + " nnet3-combine den.fst 35.raw 36.raw 37.raw 38.raw ark:valid.cegs final.raw\n"; + + bool binary_write = true; + int32 max_objective_evaluations = 30; + bool batchnorm_test_mode = false, + dropout_test_mode = true; + std::string use_gpu = "yes"; + chain::ChainTrainingOptions chain_config; + NnetChainTraining2Options opts; + + ParseOptions po(usage); + po.Register("binary", &binary_write, "Write output in binary mode"); + po.Register("max-objective-evaluations", &max_objective_evaluations, "The " + "maximum number of objective evaluations in order to figure " + "out the best number of models to combine. It helps to speedup " + "if the number of models provided to this binary is quite " + "large (e.g. several hundred)."); + po.Register("use-gpu", &use_gpu, + "yes|no|optional|wait, only has effect if compiled with CUDA"); + po.Register("batchnorm-test-mode", &batchnorm_test_mode, + "If true, set test-mode to true on any BatchNormComponents " + "while evaluating objectives."); + po.Register("dropout-test-mode", &dropout_test_mode, + "If true, set test-mode to true on any DropoutComponents and " + "DropoutMaskComponents while evaluating objectives."); + + chain_config.Register(&po); + opts.Register(&po); + + po.Read(argc, argv); + + if (po.NumArgs() < 4) { + po.PrintUsage(); + exit(1); + } + +#if HAVE_CUDA==1 + CuDevice::Instantiate().SelectGpuId(use_gpu); +#endif + + std::string + den_fst_dirname = po.GetArg(1), + raw_nnet_rxfilename = po.GetArg(2), + valid_examples_rspecifier = po.GetArg(po.NumArgs() - 1), + nnet_wxfilename = po.GetArg(po.NumArgs()); + + Nnet nnet; + ReadKaldiObject(raw_nnet_rxfilename, &nnet); + NnetChainModel2 model(opts, &nnet, den_fst_dirname); + Nnet moving_average_nnet(nnet), best_nnet(nnet); + NnetComputeProbOptions compute_prob_opts; + NnetChainComputeProb2 prob_computer(compute_prob_opts, chain_config, + model, moving_average_nnet); + + std::vector egs; + egs.reserve(10000); // reserve a lot of space to minimize the chance of + // reallocation. + + { // This block adds training examples to "egs". + SequentialNnetChainExampleReader example_reader( + valid_examples_rspecifier); + for (; !example_reader.Done(); example_reader.Next()) + egs.push_back(example_reader.Value()); + KALDI_LOG << "Read " << egs.size() << " examples."; + KALDI_ASSERT(!egs.empty()); + } + + // first evaluates the objective using the last model. + int32 best_num_to_combine = 1; + double + init_objf = ComputeObjf(batchnorm_test_mode, dropout_test_mode, + egs, moving_average_nnet, chain_config, model, &prob_computer), + best_objf = init_objf; + KALDI_LOG << "objective function using the last model is " << init_objf; + + int32 num_nnets = po.NumArgs() - 3; + // then each time before we re-evaluate the objective function, we will add + // num_to_add models to the moving average. + int32 num_to_add = (num_nnets + max_objective_evaluations - 1) / + max_objective_evaluations; + for (int32 n = 1; n < num_nnets; n++) { + std::string this_nnet_rxfilename = po.GetArg(n + 2); + ReadKaldiObject(this_nnet_rxfilename, &nnet); + // updates the moving average + UpdateNnetMovingAverage(n + 1, nnet, &moving_average_nnet); + // evaluates the objective everytime after adding num_to_add model or + // all the models to the moving average. + if ((n - 1) % num_to_add == num_to_add - 1 || n == num_nnets - 1) { + double objf = ComputeObjf(batchnorm_test_mode, dropout_test_mode, + egs, moving_average_nnet, chain_config, model, &prob_computer); + KALDI_LOG << "Combining last " << n + 1 + << " models, objective function is " << objf; + if (objf > best_objf) { + best_objf = objf; + best_nnet = moving_average_nnet; + best_num_to_combine = n + 1; + } + } + } + KALDI_LOG << "Combining " << best_num_to_combine + << " nnets, objective function changed from " << init_objf + << " to " << best_objf; + + if (HasBatchnorm(nnet)) + RecomputeStats2(egs, chain_config, model, &best_nnet); + +#if HAVE_CUDA==1 + CuDevice::Instantiate().PrintProfile(); +#endif + + WriteKaldiObject(best_nnet, nnet_wxfilename, binary_write); + KALDI_LOG << "Finished combining neural nets, wrote model to " + << nnet_wxfilename; + } catch(const std::exception &e) { + std::cerr << e.what() << '\n'; + return -1; + } +} + diff --git a/src/chainbin/nnet3-chain-get-egs.cc b/src/chainbin/nnet3-chain-get-egs.cc index 1032b7e21..9a53ef8ed 100644 --- a/src/chainbin/nnet3-chain-get-egs.cc +++ b/src/chainbin/nnet3-chain-get-egs.cc @@ -95,7 +95,7 @@ static bool ProcessFile(const TransitionModel *trans_mdl, const VectorBase *deriv_weights, int32 supervision_length_tolerance, const std::string &utt_id, - bool compress, + bool compress, bool long_key, UtteranceSplitter *utt_splitter, NnetChainExampleWriter *example_writer) { KALDI_ASSERT(supervision.num_sequences == 1); @@ -228,9 +228,14 @@ static bool ProcessFile(const TransitionModel *trans_mdl, nnet_chain_eg.Compress(); std::ostringstream os; - os << utt_id << "-" << chunk.first_frame; + if (long_key) + os << utt_id + << "-" << chunk.first_frame << "-" << chunk.left_context + << "-" << chunk.num_frames << "-" << chunk.right_context << "-v1"; + else // key is - + os << utt_id << "-" << chunk.first_frame; - std::string key = os.str(); // key is - + std::string key = os.str(); example_writer->Write(key, nnet_chain_eg); } @@ -265,7 +270,7 @@ int main(int argc, char *argv[]) { "Note: the --frame-subsampling-factor option must be the same as given to\n" "chain-get-supervision.\n"; - bool compress = true; + bool compress = true, long_key = false; int32 length_tolerance = 100, online_ivector_period = 1, supervision_length_tolerance = 1; @@ -283,7 +288,7 @@ int main(int argc, char *argv[]) { "in compressed format (recommended). Update: this is now " "only relevant if the features being read are un-compressed; " "if already compressed, we keep we same compressed format when " - "dumping-egs."); + "dumping egs."); po.Register("ivectors", &online_ivector_rspecifier, "Alias for " "--online-ivectors option, for back compatibility"); po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier of " @@ -311,6 +316,8 @@ int main(int argc, char *argv[]) { "Filename of transition model to read; should only be supplied " "if you want 'unconstrained' egs, and if you supplied " "--convert-to-pdfs=false to chain-get-supervision."); + po.Register("long-key", &long_key, "If true, a long format will be used " + "for the key, which encodes context info, etc."); eg_config.Register(&po); @@ -426,7 +433,7 @@ int main(int argc, char *argv[]) { if (!ProcessFile(trans_mdl_ptr, normalization_fst, feats, online_ivector_feats, online_ivector_period, supervision, deriv_weights, supervision_length_tolerance, - key, compress, + key, compress, long_key, &utt_splitter, &example_writer)) num_err++; } diff --git a/src/chainbin/nnet3-chain-train2.cc b/src/chainbin/nnet3-chain-train2.cc new file mode 100644 index 000000000..083b2637c --- /dev/null +++ b/src/chainbin/nnet3-chain-train2.cc @@ -0,0 +1,105 @@ +// nnet3bin/nnet3-chain-train.cc + +// Copyright 2015 Johns Hopkins University (author: Daniel Povey) +// 2019 Idiap Research Institute (author: Srikanth Madikeri) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "nnet3/nnet-chain-training2.h" +#include "cudamatrix/cu-allocator.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace kaldi::nnet3; + using namespace kaldi::chain; + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Train nnet3+chain neural network parameters with backprop and stochastic\n" + "gradient descent. Minibatches are to be created by nnet3-chain-merge-egs in\n" + "the input pipeline. This training program is single-threaded (best to\n" + "use it with a GPU).\n" + "\n" + "Usage: nnet3-chain-train [options] \n" + "\n" + "nnet3-chain-train 1.raw den.fst 'ark:nnet3-merge-egs 1.cegs ark:-|' 2.raw\n"; + + int32 srand_seed = 0; + bool binary_write = true; + std::string use_gpu = "yes"; + NnetChainTraining2Options opts; + + ParseOptions po(usage); + po.Register("srand", &srand_seed, "Seed for random number generator "); + po.Register("binary", &binary_write, "Write output in binary mode"); + po.Register("use-gpu", &use_gpu, + "yes|no|optional|wait, only has effect if compiled with CUDA"); + + opts.Register(&po); + RegisterCuAllocatorOptions(&po); + + po.Read(argc, argv); + + srand(srand_seed); + + if (po.NumArgs() != 4) { + po.PrintUsage(); + exit(1); + } + +#if HAVE_CUDA==1 + CuDevice::Instantiate().SelectGpuId(use_gpu); +#endif + + std::string nnet_rxfilename = po.GetArg(1), + den_fst_dirname = po.GetArg(2), + examples_rspecifier = po.GetArg(3), + nnet_wxfilename = po.GetArg(4); + + Nnet nnet; + ReadKaldiObject(nnet_rxfilename, &nnet); + + bool ok; + + { + NnetChainModel2 model(opts, &nnet, den_fst_dirname); + NnetChainTrainer2 trainer(opts, model, &nnet); + + SequentialNnetChainExampleReader example_reader(examples_rspecifier); + + for (; !example_reader.Done(); example_reader.Next()) + trainer.Train(example_reader.Key(), example_reader.Value()); + + ok = trainer.PrintTotalStats(); + } + +#if HAVE_CUDA==1 + CuDevice::Instantiate().PrintProfile(); +#endif + WriteKaldiObject(nnet, nnet_wxfilename, binary_write); + KALDI_LOG << "Wrote raw model to " << nnet_wxfilename; + return (ok ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what() << '\n'; + return -1; + } +} + diff --git a/src/configure b/src/configure index e6ffdf337..7015946a5 100755 --- a/src/configure +++ b/src/configure @@ -73,6 +73,7 @@ Configuration options: --cudatk-dir=DIR CUDA toolkit directory --cuda-arch=FLAGS Override the default CUDA_ARCH flags. See: https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#nvcc-examples. + --debug-level=N Use assertion level 0 (disabled), 1, or 2 [default=1] --double-precision Build with BaseFloat set to double if yes [default=no], mostly useful for testing purposes. --static-fst Build with static OpenFst libraries [default=no] @@ -739,6 +740,7 @@ ENV_LDFLAGS=$LDFLAGS ENV_LDLIBS=$LDLIBS # Default configuration +debug_level=1 double_precision=false dynamic_kaldi=false use_cuda=true @@ -771,6 +773,9 @@ do static_math=false; static_fst=false; shift ;; + --debug-level=*) + GetSwitchValueOrDie debug_level "$1" + shift ;; --double-precision) double_precision=true; shift ;; @@ -901,6 +906,11 @@ do esac done +case "$debug_level" in + [012]) ;; + *) failure "Invalid value --debug-level=$debug_level. Supported values are 0, 1, and 2." ;; +esac + # The idea here is that if you change the configuration options from using # CUDA to not using it, or vice versa, we want to recompile all parts of the # code that may use a GPU. Touching this file is a way to force this. @@ -1033,6 +1043,7 @@ if $dynamic_kaldi ; then echo "KALDI_FLAVOR := dynamic" >> kaldi.mk echo "KALDILIBDIR := $KALDILIBDIR" >> kaldi.mk fi +echo "DEBUG_LEVEL = $debug_level" >> kaldi.mk if $double_precision; then echo "DOUBLE_PRECISION = 1" >> kaldi.mk else diff --git a/src/cudadecoder/cuda-decoder.cc b/src/cudadecoder/cuda-decoder.cc index 9d274eb80..850257fef 100644 --- a/src/cudadecoder/cuda-decoder.cc +++ b/src/cudadecoder/cuda-decoder.cc @@ -358,6 +358,7 @@ void CudaDecoder::ComputeInitialChannel() { CopyLaneCountersToHostSync(); PostProcessingMainQueue(); + ConcatenateData(); CopyLaneCountersToHostSync(); const int32 main_q_end = diff --git a/src/cudadecoder/thread-pool.h b/src/cudadecoder/thread-pool.h index 12cd27da4..920ea6d33 100644 --- a/src/cudadecoder/thread-pool.h +++ b/src/cudadecoder/thread-pool.h @@ -28,6 +28,7 @@ freely, subject to the following restrictions: #ifndef KALDI_CUDA_DECODER_THREAD_POOL_H_ #define KALDI_CUDA_DECODER_THREAD_POOL_H_ +#include #include #include #include diff --git a/src/cudafeat/feature-spectral-cuda.cu b/src/cudafeat/feature-spectral-cuda.cu index 18b3eed98..8d01a9ac7 100644 --- a/src/cudafeat/feature-spectral-cuda.cu +++ b/src/cudafeat/feature-spectral-cuda.cu @@ -276,7 +276,7 @@ __device__ inline int32 FirstSampleOfFrame(int32 frame, int32 frame_shift, __global__ void extract_window_kernel( int32 frame_shift, int32 frame_length, int32 frame_length_padded, int32 window_size, bool snip_edges, int32_t sample_offset, - const BaseFloat __restrict__ *wave, int32 wave_dim, + const BaseFloat * __restrict__ wave, int32 wave_dim, BaseFloat *__restrict__ windows, int32_t wlda) { int frame = blockIdx.x; int tidx = threadIdx.x; @@ -503,8 +503,8 @@ void CudaSpectralFeatures::ComputeFinalFeatures(int num_frames, BaseFloat vtln_w kTrans, 0.0); apply_lifter_and_floor_energy<<>>( - cu_features->NumRows(), cu_features->NumCols(), - mfcc_opts.cepstral_lifter, mfcc_opts.use_energy, + cu_features->NumRows(), cu_features->NumCols(), + mfcc_opts.cepstral_lifter, mfcc_opts.use_energy, mfcc_opts.energy_floor, cu_signal_log_energy->Data(), cu_lifter_coeffs_.Data(), cu_features->Data(), cu_features->Stride()); } else { diff --git a/src/cudamatrix/cu-array.h b/src/cudamatrix/cu-array.h index 82d07bdab..84f78f00a 100644 --- a/src/cudamatrix/cu-array.h +++ b/src/cudamatrix/cu-array.h @@ -189,8 +189,8 @@ class CuSubArray: public CuArrayBase { CuSubArray(const T* data, MatrixIndexT length) { // Yes, we're evading C's restrictions on const here, and yes, it can be used // to do wrong stuff; unfortunately the workaround would be very difficult. - CuArrayBase::data_ = const_cast(data); - CuArrayBase::dim_ = length; + this->data_ = const_cast(data); + this->dim_ = length; } }; diff --git a/src/cudamatrix/cu-vector.h b/src/cudamatrix/cu-vector.h index 9c532b52f..f1c327568 100644 --- a/src/cudamatrix/cu-vector.h +++ b/src/cudamatrix/cu-vector.h @@ -135,15 +135,15 @@ class CuVectorBase { void Floor(const CuVectorBase &src, Real floor_val, MatrixIndexT *floored_count = NULL); void Ceiling(const CuVectorBase &src, Real ceiling_val, MatrixIndexT *ceiled_count = NULL); void Pow(const CuVectorBase &src, Real power); - + inline void ApplyFloor(Real floor_val, MatrixIndexT *floored_count = NULL) { - this -> Floor(*this, floor_val, floored_count); + this -> Floor(*this, floor_val, floored_count); }; - + inline void ApplyCeiling(Real ceiling_val, MatrixIndexT *ceiled_count = NULL) { this -> Ceiling(*this, ceiling_val, ceiled_count); }; - + inline void ApplyPow(Real power) { this -> Pow(*this, power); }; @@ -329,27 +329,27 @@ class CuSubVector: public CuVectorBase { KALDI_ASSERT(static_cast(origin)+ static_cast(length) <= static_cast(t.Dim())); - CuVectorBase::data_ = const_cast(t.Data()+origin); - CuVectorBase::dim_ = length; + this->data_ = const_cast(t.Data()+origin); + this->dim_ = length; } /// Copy constructor /// this constructor needed for Range() to work in base class. CuSubVector(const CuSubVector &other) : CuVectorBase () { - CuVectorBase::data_ = other.data_; - CuVectorBase::dim_ = other.dim_; + this->data_ = other.data_; + this->dim_ = other.dim_; } CuSubVector(const Real* data, MatrixIndexT length) : CuVectorBase () { // Yes, we're evading C's restrictions on const here, and yes, it can be used // to do wrong stuff; unfortunately the workaround would be very difficult. - CuVectorBase::data_ = const_cast(data); - CuVectorBase::dim_ = length; + this->data_ = const_cast(data); + this->dim_ = length; } /// This operation does not preserve const-ness, so be careful. CuSubVector(const CuMatrixBase &matrix, MatrixIndexT row) { - CuVectorBase::data_ = const_cast(matrix.RowData(row)); - CuVectorBase::dim_ = matrix.NumCols(); + this->data_ = const_cast(matrix.RowData(row)); + this->dim_ = matrix.NumCols(); } diff --git a/src/decoder/Makefile b/src/decoder/Makefile index fbd8386f0..a814931f6 100644 --- a/src/decoder/Makefile +++ b/src/decoder/Makefile @@ -7,7 +7,8 @@ TESTFILES = OBJFILES = training-graph-compiler.o lattice-simple-decoder.o lattice-faster-decoder.o \ lattice-faster-online-decoder.o simple-decoder.o faster-decoder.o \ - decoder-wrappers.o grammar-fst.o decodable-matrix.o + decoder-wrappers.o grammar-fst.o decodable-matrix.o \ + lattice-incremental-decoder.o lattice-incremental-online-decoder.o LIBNAME = kaldi-decoder diff --git a/src/decoder/decoder-wrappers.cc b/src/decoder/decoder-wrappers.cc index 588274e11..f63b3caa7 100644 --- a/src/decoder/decoder-wrappers.cc +++ b/src/decoder/decoder-wrappers.cc @@ -68,7 +68,7 @@ void DecodeUtteranceLatticeFasterClass::operator () () { success_ = true; using fst::VectorFst; if (!decoder_->Decode(decodable_)) { - KALDI_WARN << "Failed to decode file " << utt_; + KALDI_WARN << "Failed to decode utterance with id " << utt_; success_ = false; } if (!decoder_->ReachedFinal()) { @@ -195,6 +195,92 @@ DecodeUtteranceLatticeFasterClass::~DecodeUtteranceLatticeFasterClass() { delete decodable_; } +template +bool DecodeUtteranceLatticeIncremental( + LatticeIncrementalDecoderTpl &decoder, // not const but is really an input. + DecodableInterface &decodable, // not const but is really an input. + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr) { // puts utterance's like in like_ptr on success. + using fst::VectorFst; + if (!decoder.Decode(&decodable)) { + KALDI_WARN << "Failed to decode utterance with id " << utt; + return false; + } + if (!decoder.ReachedFinal()) { + if (allow_partial) { + KALDI_WARN << "Outputting partial output for utterance " << utt + << " since no final-state reached\n"; + } else { + KALDI_WARN << "Not producing output for utterance " << utt + << " since no final-state reached and " + << "--allow-partial=false.\n"; + return false; + } + } + + // Get lattice + CompactLattice clat = decoder.GetLattice(decoder.NumFramesDecoded(), true); + if (clat.NumStates() == 0) + KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt; + + double likelihood; + LatticeWeight weight; + int32 num_frames; + { // First do some stuff with word-level traceback... + CompactLattice decoded_clat; + CompactLatticeShortestPath(clat, &decoded_clat); + Lattice decoded; + fst::ConvertLattice(decoded_clat, &decoded); + + if (decoded.Start() == fst::kNoStateId) + // Shouldn't really reach this point as already checked success. + KALDI_ERR << "Failed to get traceback for utterance " << utt; + + std::vector alignment; + std::vector words; + GetLinearSymbolSequence(decoded, &alignment, &words, &weight); + num_frames = alignment.size(); + KALDI_ASSERT(num_frames == decoder.NumFramesDecoded()); + if (words_writer->IsOpen()) + words_writer->Write(utt, words); + if (alignment_writer->IsOpen()) + alignment_writer->Write(utt, alignment); + if (word_syms != NULL) { + std::cerr << utt << ' '; + for (size_t i = 0; i < words.size(); i++) { + std::string s = word_syms->Find(words[i]); + if (s == "") + KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; + std::cerr << s << ' '; + } + std::cerr << '\n'; + } + likelihood = -(weight.Value1() + weight.Value2()); + } + + // We'll write the lattice without acoustic scaling. + if (acoustic_scale != 0.0) + fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat); + Connect(&clat); + compact_lattice_writer->Write(utt, clat); + KALDI_LOG << "Log-like per frame for utterance " << utt << " is " + << (likelihood / num_frames) << " over " + << num_frames << " frames."; + KALDI_VLOG(2) << "Cost for utterance " << utt << " is " + << weight.Value1() << " + " << weight.Value2(); + *like_ptr = likelihood; + return true; +} + // Takes care of output. Returns true on success. template @@ -215,7 +301,7 @@ bool DecodeUtteranceLatticeFaster( using fst::VectorFst; if (!decoder.Decode(&decodable)) { - KALDI_WARN << "Failed to decode file " << utt; + KALDI_WARN << "Failed to decode utterance with id " << utt; return false; } if (!decoder.ReachedFinal()) { @@ -296,6 +382,37 @@ bool DecodeUtteranceLatticeFaster( } // Instantiate the template above for the two required FST types. +template bool DecodeUtteranceLatticeIncremental( + LatticeIncrementalDecoderTpl > &decoder, + DecodableInterface &decodable, + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr); + +template bool DecodeUtteranceLatticeIncremental( + LatticeIncrementalDecoderTpl &decoder, + DecodableInterface &decodable, + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignment_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr); + + template bool DecodeUtteranceLatticeFaster( LatticeFasterDecoderTpl > &decoder, DecodableInterface &decodable, @@ -345,7 +462,7 @@ bool DecodeUtteranceLatticeSimple( using fst::VectorFst; if (!decoder.Decode(&decodable)) { - KALDI_WARN << "Failed to decode file " << utt; + KALDI_WARN << "Failed to decode utterance with id " << utt; return false; } if (!decoder.ReachedFinal()) { diff --git a/src/decoder/decoder-wrappers.h b/src/decoder/decoder-wrappers.h index 17592d028..085c8e94e 100644 --- a/src/decoder/decoder-wrappers.h +++ b/src/decoder/decoder-wrappers.h @@ -22,6 +22,7 @@ #include "itf/options-itf.h" #include "decoder/lattice-faster-decoder.h" +#include "decoder/lattice-incremental-decoder.h" #include "decoder/lattice-simple-decoder.h" // This header contains declarations from various convenience functions that are called @@ -88,6 +89,23 @@ void AlignUtteranceWrapper( void ModifyGraphForCarefulAlignment( fst::VectorFst *fst); +/// TODO +template +bool DecodeUtteranceLatticeIncremental( + LatticeIncrementalDecoderTpl &decoder, // not const but is really an input. + DecodableInterface &decodable, // not const but is really an input. + const TransitionModel &trans_model, + const fst::SymbolTable *word_syms, + std::string utt, + double acoustic_scale, + bool determinize, + bool allow_partial, + Int32VectorWriter *alignments_writer, + Int32VectorWriter *words_writer, + CompactLatticeWriter *compact_lattice_writer, + LatticeWriter *lattice_writer, + double *like_ptr); // puts utterance's likelihood in like_ptr on success. + /// This function DecodeUtteranceLatticeFaster is used in several decoders, and /// we have moved it here. Note: this is really "binary-level" code as it diff --git a/src/decoder/lattice-faster-decoder.cc b/src/decoder/lattice-faster-decoder.cc index 9106309eb..83c582d3b 100644 --- a/src/decoder/lattice-faster-decoder.cc +++ b/src/decoder/lattice-faster-decoder.cc @@ -229,24 +229,17 @@ void LatticeFasterDecoderTpl::PossiblyResizeHash(size_t num_toks) { extra_cost is used in pruning tokens, to save memory. - Define the 'forward cost' of a token as zero for any token on the frame - we're currently decoding; and for other frames, as the shortest-path cost - between that token and a token on the frame we're currently decoding. - (by "currently decoding" I mean the most recently processed frame). - - Then define the extra_cost of a token (always >= 0) as the forward-cost of - the token minus the smallest forward-cost of any token on the same frame. + extra_cost can be thought of as a beta (backward) cost assuming + we had set the betas on currently-active tokens to all be the negative + of the alphas for those tokens. (So all currently active tokens would + be on (tied) best paths). We can use the extra_cost to accurately prune away tokens that we know will never appear in the lattice. If the extra_cost is greater than the desired lattice beam, the token would provably never appear in the lattice, so we can prune away the token. - The advantage of storing the extra_cost rather than the forward-cost, is that - it is less costly to keep the extra_cost up-to-date when we process new frames. - When we process a new frame, *all* the previous frames' forward-costs would change; - but in general the extra_cost will change only for a finite number of frames. - (Actually we don't update all the extra_costs every time we update a frame; we + (Note: we don't update all the extra_costs every time we update a frame; we only do it every 'config_.prune_interval' frames). */ diff --git a/src/decoder/lattice-faster-decoder.h b/src/decoder/lattice-faster-decoder.h index e0cf7dea8..57cbe5fe1 100644 --- a/src/decoder/lattice-faster-decoder.h +++ b/src/decoder/lattice-faster-decoder.h @@ -43,11 +43,13 @@ struct LatticeFasterDecoderConfig { int32 prune_interval; bool determinize_lattice; // not inspected by this class... used in // command-line program. - BaseFloat beam_delta; // has nothing to do with beam_ratio + BaseFloat beam_delta; BaseFloat hash_ratio; - BaseFloat prune_scale; // Note: we don't make this configurable on the command line, - // it's not a very important parameter. It affects the - // algorithm that prunes the tokens as we go. + // Note: we don't make prune_scale configurable on the command line, it's not + // a very important parameter. It affects the algorithm that prunes the + // tokens as we go. + BaseFloat prune_scale; + // Most of the options inside det_opts are not actually queried by the // LatticeFasterDecoder class itself, but by the code that calls it, for // example in the function DecodeUtteranceLatticeFaster. @@ -316,15 +318,10 @@ class LatticeFasterDecoderTpl { /// This function may be optionally called after AdvanceDecoding(), when you /// do not plan to decode any further. It does an extra pruning step that /// will help to prune the lattices output by GetLattice and (particularly) - /// GetRawLattice more accurately, particularly toward the end of the - /// utterance. It does this by using the final-probs in pruning (if any - /// final-state survived); it also does a final pruning step that visits all - /// states (the pruning that is done during decoding may fail to prune states - /// that are within kPruningScale = 0.1 outside of the beam). If you call - /// this, you cannot call AdvanceDecoding again (it will fail), and you - /// cannot call GetLattice() and related functions with use_final_probs = - /// false. - /// Used to be called PruneActiveTokensFinal(). + /// GetRawLattice more completely, particularly toward the end of the + /// utterance. If you call this, you cannot call AdvanceDecoding again (it + /// will fail), and you cannot call GetLattice() and related functions with + /// use_final_probs = false. Used to be called PruneActiveTokensFinal(). void FinalizeDecoding(); /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives diff --git a/src/decoder/lattice-incremental-decoder.cc b/src/decoder/lattice-incremental-decoder.cc new file mode 100644 index 000000000..81e700833 --- /dev/null +++ b/src/decoder/lattice-incremental-decoder.cc @@ -0,0 +1,1720 @@ +// decoder/lattice-incremental-decoder.cc + +// Copyright 2019 Zhehuai Chen, Daniel Povey + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/lattice-incremental-decoder.h" +#include "lat/lattice-functions.h" +#include "base/timer.h" + +namespace kaldi { + +// instantiate this class once for each thing you have to decode. +template +LatticeIncrementalDecoderTpl::LatticeIncrementalDecoderTpl( + const FST &fst, const TransitionModel &trans_model, + const LatticeIncrementalDecoderConfig &config) + : fst_(&fst), + delete_fst_(false), + num_toks_(0), + config_(config), + determinizer_(trans_model, config) { + config.Check(); + toks_.SetSize(1000); // just so on the first frame we do something reasonable. +} + +template +LatticeIncrementalDecoderTpl::LatticeIncrementalDecoderTpl( + const LatticeIncrementalDecoderConfig &config, FST *fst, + const TransitionModel &trans_model) + : fst_(fst), + delete_fst_(true), + num_toks_(0), + config_(config), + determinizer_(trans_model, config) { + config.Check(); + toks_.SetSize(1000); // just so on the first frame we do something reasonable. +} + +template +LatticeIncrementalDecoderTpl::~LatticeIncrementalDecoderTpl() { + DeleteElems(toks_.Clear()); + ClearActiveTokens(); + if (delete_fst_) delete fst_; +} + +template +void LatticeIncrementalDecoderTpl::InitDecoding() { + // clean up from last time: + DeleteElems(toks_.Clear()); + cost_offsets_.clear(); + ClearActiveTokens(); + warned_ = false; + num_toks_ = 0; + decoding_finalized_ = false; + final_costs_.clear(); + StateId start_state = fst_->Start(); + KALDI_ASSERT(start_state != fst::kNoStateId); + active_toks_.resize(1); + Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL); + active_toks_[0].toks = start_tok; + toks_.Insert(start_state, start_tok); + num_toks_++; + + determinizer_.Init(); + num_frames_in_lattice_ = 0; + token2label_map_.clear(); + next_token_label_ = LatticeIncrementalDeterminizer::kTokenLabelOffset; + ProcessNonemitting(config_.beam); +} + +template +void LatticeIncrementalDecoderTpl::UpdateLatticeDeterminization() { + if (NumFramesDecoded() - num_frames_in_lattice_ < + config_.determinize_max_delay) + return; + + + /* Make sure the token-pruning is active. Note: PruneActiveTokens() has + internal logic that prevents it from doing unnecessary work if you + call it and then immediately call it again. */ + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + + int32 first = num_frames_in_lattice_ + config_.determinize_min_chunk_size, + last = NumFramesDecoded(), + fewest_tokens = std::numeric_limits::max(), + best_frame = -1; + for (int32 t = last; t >= first; t--) { + /* Make sure PruneActiveTokens() has computed num_toks for all these + frames... */ + KALDI_ASSERT(active_toks_[t].num_toks != -1); + if (active_toks_[t].num_toks < fewest_tokens) { + // <= because we want the latest one in case of ties. + fewest_tokens = active_toks_[t].num_toks; + best_frame = t; + } + } + /* OK, determinize the chunk that spans from num_frames_in_lattice_ to + best_frame. */ + bool use_final_probs = false; + GetLattice(best_frame, use_final_probs); + return; +} +// Returns true if any kind of traceback is available (not necessarily from +// a final state). It should only very rarely return false; this indicates +// an unusual search error. +template +bool LatticeIncrementalDecoderTpl::Decode(DecodableInterface *decodable) { + InitDecoding(); + + // We use 1-based indexing for frames in this decoder (if you view it in + // terms of features), but note that the decodable object uses zero-based + // numbering, which we have to correct for when we call it. + + while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) { + if (NumFramesDecoded() % config_.prune_interval == 0) { + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + } + UpdateLatticeDeterminization(); + + BaseFloat cost_cutoff = ProcessEmitting(decodable); + ProcessNonemitting(cost_cutoff); + } + Timer timer; + FinalizeDecoding(); + bool use_final_probs = true; + GetLattice(NumFramesDecoded(), use_final_probs); + KALDI_VLOG(2) << "Delay time during and after FinalizeDecoding()" + << "(secs): " << timer.Elapsed(); + + // Returns true if we have any kind of traceback available (not necessarily + // to the end state; query ReachedFinal() for that). + return !active_toks_.empty() && active_toks_.back().toks != NULL; +} + + +template +void LatticeIncrementalDecoderTpl::PossiblyResizeHash(size_t num_toks) { + size_t new_sz = + static_cast(static_cast(num_toks) * config_.hash_ratio); + if (new_sz > toks_.Size()) { + toks_.SetSize(new_sz); + } +} + +/* + A note on the definition of extra_cost. + + extra_cost is used in pruning tokens, to save memory. + + extra_cost can be thought of as a beta (backward) cost assuming + we had set the betas on currently-active tokens to all be the negative + of the alphas for those tokens. (So all currently active tokens would + be on (tied) best paths). + + + Define the 'forward cost' of a token as zero for any token on the frame + we're currently decoding; and for other frames, as the shortest-path cost + between that token and a token on the frame we're currently decoding. + (by "currently decoding" I mean the most recently processed frame). + + Then define the extra_cost of a token (always >= 0) as the forward-cost of + the token minus the smallest forward-cost of any token on the same frame. + + We can use the extra_cost to accurately prune away tokens that we know will + never appear in the lattice. If the extra_cost is greater than the desired + lattice beam, the token would provably never appear in the lattice, so we can + prune away the token. + + The advantage of storing the extra_cost rather than the forward-cost, is that + it is less costly to keep the extra_cost up-to-date when we process new frames. + When we process a new frame, *all* the previous frames' forward-costs would change; + but in general the extra_cost will change only for a finite number of frames. + (Actually we don't update all the extra_costs every time we update a frame; we + only do it every 'config_.prune_interval' frames). + */ + +// FindOrAddToken either locates a token in hash of toks_, +// or if necessary inserts a new, empty token (i.e. with no forward links) +// for the current frame. [note: it's inserted if necessary into hash toks_ +// and also into the singly linked list of tokens active on this frame +// (whose head is at active_toks_[frame]). +template +inline Token *LatticeIncrementalDecoderTpl::FindOrAddToken( + StateId state, int32 frame_plus_one, BaseFloat tot_cost, Token *backpointer, + bool *changed) { + // Returns the Token pointer. Sets "changed" (if non-NULL) to true + // if the token was newly created or the cost changed. + KALDI_ASSERT(frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; + Elem *e_found = toks_.Find(state); + if (e_found == NULL) { // no such token presently. + const BaseFloat extra_cost = 0.0; + // tokens on the currently final frame have zero extra_cost + // as any of them could end up + // on the winning path. + Token *new_tok = new Token(tot_cost, extra_cost, NULL, toks, backpointer); + // NULL: no forward links yet + toks = new_tok; + num_toks_++; + toks_.Insert(state, new_tok); + if (changed) *changed = true; + return new_tok; + } else { + Token *tok = e_found->val; // There is an existing Token for this state. + if (tok->tot_cost > tot_cost) { // replace old token + tok->tot_cost = tot_cost; + // SetBackpointer() just does tok->backpointer = backpointer in + // the case where Token == BackpointerToken, else nothing. + tok->SetBackpointer(backpointer); + // we don't allocate a new token, the old stays linked in active_toks_ + // we only replace the tot_cost + // in the current frame, there are no forward links (and no extra_cost) + // only in ProcessNonemitting we have to delete forward links + // in case we visit a state for the second time + // those forward links, that lead to this replaced token before: + // they remain and will hopefully be pruned later (PruneForwardLinks...) + if (changed) *changed = true; + } else { + if (changed) *changed = false; + } + return tok; + } +} + +// prunes outgoing links for all tokens in active_toks_[frame] +// it's called by PruneActiveTokens +// all links, that have link_extra_cost > lattice_beam are pruned +template +void LatticeIncrementalDecoderTpl::PruneForwardLinks( + int32 frame_plus_one, bool *extra_costs_changed, bool *links_pruned, + BaseFloat delta) { + // delta is the amount by which the extra_costs must change + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + + *extra_costs_changed = false; + *links_pruned = false; + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + if (active_toks_[frame_plus_one].toks == NULL) { // empty list; should not happen. + if (!warned_) { + KALDI_WARN << "No tokens alive [doing pruning].. warning first " + "time only for each utterance\n"; + warned_ = true; + } + } + + // We have to iterate until there is no more change, because the links + // are not guaranteed to be in topological order. + bool changed = true; // difference new minus old extra cost >= delta ? + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; tok != NULL; + tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost for tok. + BaseFloat tok_extra_cost = std::numeric_limits::infinity(); + // tok_extra_cost is the best (min) of link_extra_cost of outgoing links + for (link = tok->links; link != NULL;) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = + next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) - + next_tok->tot_cost); // difference in brackets is >= 0 + // link_exta_cost is the difference in score between the best paths + // through link source state and through link destination state + KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) + prev_link->next = next_link; + else + tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + *links_pruned = true; + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) tok_extra_cost = link_extra_cost; + prev_link = link; // move to next link + link = link->next; + } + } // for all outgoing links + if (fabs(tok_extra_cost - tok->extra_cost) > delta) + changed = true; // difference new minus old is bigger than delta + tok->extra_cost = tok_extra_cost; + // will be +infinity or <= lattice_beam_. + // infinity indicates, that no forward link survived pruning + } // for all Token on active_toks_[frame] + if (changed) *extra_costs_changed = true; + + // Note: it's theoretically possible that aggressive compiler + // optimizations could cause an infinite loop here for small delta and + // high-dynamic-range scores. + } // while changed +} + +// PruneForwardLinksFinal is a version of PruneForwardLinks that we call +// on the final frame. If there are final tokens active, it uses +// the final-probs for pruning, otherwise it treats all tokens as final. +template +void LatticeIncrementalDecoderTpl::PruneForwardLinksFinal() { + KALDI_ASSERT(!active_toks_.empty()); + int32 frame_plus_one = active_toks_.size() - 1; + + if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen. + KALDI_WARN << "No tokens alive at end of file"; + + typedef typename unordered_map::const_iterator IterType; + ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_); + decoding_finalized_ = true; + // We call DeleteElems() as a nicety, not because it's really necessary; + // otherwise there would be a time, after calling PruneTokensForFrame() on the + // final frame, when toks_.GetList() or toks_.Clear() would contain pointers + // to nonexistent tokens. + DeleteElems(toks_.Clear()); + + // Now go through tokens on this frame, pruning forward links... may have to + // iterate a few times until there is no more change, because the list is not + // in topological order. This is a modified version of the code in + // PruneForwardLinks, but here we also take account of the final-probs. + bool changed = true; + BaseFloat delta = 1.0e-05; + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; tok != NULL; + tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost. It has a term in it that corresponds + // to the "final-prob", so instead of initializing tok_extra_cost to infinity + // below we set it to the difference between the (score+final_prob) of this + // token, + // and the best such (score+final_prob). + BaseFloat final_cost; + if (final_costs_.empty()) { + final_cost = 0.0; + } else { + IterType iter = final_costs_.find(tok); + if (iter != final_costs_.end()) + final_cost = iter->second; + else + final_cost = std::numeric_limits::infinity(); + } + BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_; + // tok_extra_cost will be a "min" over either directly being final, or + // being indirectly final through other links, and the loop below may + // decrease its value: + for (link = tok->links; link != NULL;) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = + next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) - + next_tok->tot_cost); + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) + prev_link->next = next_link; + else + tok->links = next_link; + delete link; + link = next_link; // advance link but leave prev_link the same. + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) tok_extra_cost = link_extra_cost; + prev_link = link; + link = link->next; + } + } + // prune away tokens worse than lattice_beam above best path. This step + // was not necessary in the non-final case because then, this case + // showed up as having no forward links. Here, the tok_extra_cost has + // an extra component relating to the final-prob. + if (tok_extra_cost > config_.lattice_beam) + tok_extra_cost = std::numeric_limits::infinity(); + // to be pruned in PruneTokensForFrame + + if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) changed = true; + tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_. + } + } // while changed +} + +template +BaseFloat LatticeIncrementalDecoderTpl::FinalRelativeCost() const { + BaseFloat relative_cost; + ComputeFinalCosts(NULL, &relative_cost, NULL); + return relative_cost; +} + +// Prune away any tokens on this frame that have no forward links. +// [we don't do this in PruneForwardLinks because it would give us +// a problem with dangling pointers]. +// It's called by PruneActiveTokens if any forward links have been pruned +template +void LatticeIncrementalDecoderTpl::PruneTokensForFrame( + int32 frame_plus_one) { + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; + if (toks == NULL) KALDI_WARN << "No tokens alive [doing pruning]"; + Token *tok, *next_tok, *prev_tok = NULL; + int32 num_toks = 0; + for (tok = toks; tok != NULL; tok = next_tok, num_toks++) { + next_tok = tok->next; + if (tok->extra_cost == std::numeric_limits::infinity()) { + // token is unreachable from end of graph; (no forward links survived) + // excise tok from list and delete tok. + if (prev_tok != NULL) + prev_tok->next = tok->next; + else + toks = tok->next; + delete tok; + num_toks_--; + } else { // fetch next Token + prev_tok = tok; + } + } + active_toks_[frame_plus_one].num_toks = num_toks; +} + +// Go backwards through still-alive tokens, pruning them, starting not from +// the current frame (where we want to keep all tokens) but from the frame before +// that. We go backwards through the frames and stop when we reach a point +// where the delta-costs are not changing (and the delta controls when we consider +// a cost to have "not changed"). +template +void LatticeIncrementalDecoderTpl::PruneActiveTokens(BaseFloat delta) { + int32 cur_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + + if (active_toks_[cur_frame_plus_one].num_toks == -1){ + // The current frame's tokens don't get pruned so they don't get counted + // (the count is needed by the incremental determinization code). + // Fix this. + int this_frame_num_toks = 0; + for (Token *t = active_toks_[cur_frame_plus_one].toks; t != NULL; t = t->next) + this_frame_num_toks++; + active_toks_[cur_frame_plus_one].num_toks = this_frame_num_toks; + } + + // The index "f" below represents a "frame plus one", i.e. you'd have to subtract + // one to get the corresponding index for the decodable object. + for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) { + // Reason why we need to prune forward links in this situation: + // (1) we have never pruned them (new TokenList) + // (2) we have not yet pruned the forward links to the next f, + // after any of those tokens have changed their extra_cost. + if (active_toks_[f].must_prune_forward_links) { + bool extra_costs_changed = false, links_pruned = false; + PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta); + if (extra_costs_changed && f > 0) // any token has changed extra_cost + active_toks_[f - 1].must_prune_forward_links = true; + if (links_pruned) // any link was pruned + active_toks_[f].must_prune_tokens = true; + active_toks_[f].must_prune_forward_links = false; // job done + } + if (f + 1 < cur_frame_plus_one && // except for last f (no forward links) + active_toks_[f + 1].must_prune_tokens) { + PruneTokensForFrame(f + 1); + active_toks_[f + 1].must_prune_tokens = false; + } + } + KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +template +void LatticeIncrementalDecoderTpl::ComputeFinalCosts( + unordered_map *final_costs, BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const { + if (decoding_finalized_) { + // If we finalized decoding, the list toks_ will no longer exist, so return + // something we already computed. + if (final_costs) *final_costs = final_costs_; + if (final_relative_cost) *final_relative_cost = final_relative_cost_; + if (final_best_cost) *final_best_cost = final_best_cost_; + return; + } + if (final_costs != NULL) final_costs->clear(); + const Elem *final_toks = toks_.GetList(); + BaseFloat infinity = std::numeric_limits::infinity(); + BaseFloat best_cost = infinity, best_cost_with_final = infinity; + + while (final_toks != NULL) { + StateId state = final_toks->key; + Token *tok = final_toks->val; + const Elem *next = final_toks->tail; + BaseFloat final_cost = fst_->Final(state).Value(); + BaseFloat cost = tok->tot_cost, cost_with_final = cost + final_cost; + best_cost = std::min(cost, best_cost); + best_cost_with_final = std::min(cost_with_final, best_cost_with_final); + if (final_costs != NULL && final_cost != infinity) + (*final_costs)[tok] = final_cost; + final_toks = next; + } + if (final_relative_cost != NULL) { + if (best_cost == infinity && best_cost_with_final == infinity) { + // Likely this will only happen if there are no tokens surviving. + // This seems the least bad way to handle it. + *final_relative_cost = infinity; + } else { + *final_relative_cost = best_cost_with_final - best_cost; + } + } + if (final_best_cost != NULL) { + if (best_cost_with_final != infinity) { // final-state exists. + *final_best_cost = best_cost_with_final; + } else { // no final-state exists. + *final_best_cost = best_cost; + } + } +} + +template +void LatticeIncrementalDecoderTpl::AdvanceDecoding( + DecodableInterface *decodable, int32 max_num_frames) { + if (std::is_same >::value) { + // if the type 'FST' is the FST base-class, then see if the FST type of fst_ + // is actually VectorFst or ConstFst. If so, call the AdvanceDecoding() + // function after casting *this to the more specific type. + if (fst_->Type() == "const") { + LatticeIncrementalDecoderTpl, Token> *this_cast = + reinterpret_cast< + LatticeIncrementalDecoderTpl, Token> *>( + this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } else if (fst_->Type() == "vector") { + LatticeIncrementalDecoderTpl, Token> *this_cast = + reinterpret_cast< + LatticeIncrementalDecoderTpl, Token> *>( + this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } + } + + KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ && + "You must call InitDecoding() before AdvanceDecoding"); + int32 num_frames_ready = decodable->NumFramesReady(); + // num_frames_ready must be >= num_frames_decoded, or else + // the number of frames ready must have decreased (which doesn't + // make sense) or the decodable object changed between calls + // (which isn't allowed). + KALDI_ASSERT(num_frames_ready >= NumFramesDecoded()); + int32 target_frames_decoded = num_frames_ready; + if (max_num_frames >= 0) + target_frames_decoded = + std::min(target_frames_decoded, NumFramesDecoded() + max_num_frames); + while (NumFramesDecoded() < target_frames_decoded) { + if (NumFramesDecoded() % config_.prune_interval == 0) { + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + } + BaseFloat cost_cutoff = ProcessEmitting(decodable); + ProcessNonemitting(cost_cutoff); + } + UpdateLatticeDeterminization(); +} + +// FinalizeDecoding() is a version of PruneActiveTokens that we call +// (optionally) on the final frame. Takes into account the final-prob of +// tokens. This function used to be called PruneActiveTokensFinal(). +template +void LatticeIncrementalDecoderTpl::FinalizeDecoding() { + int32 final_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // PruneForwardLinksFinal() prunes the final frame (with final-probs), and + // sets decoding_finalized_. + PruneForwardLinksFinal(); + for (int32 f = final_frame_plus_one - 1; f >= 0; f--) { + bool b1, b2; // values not used. + BaseFloat dontcare = 0.0; // delta of zero means we must always update + PruneForwardLinks(f, &b1, &b2, dontcare); + PruneTokensForFrame(f + 1); + } + PruneTokensForFrame(0); + KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin << " to " << num_toks_; +} + +/// Gets the weight cutoff. Also counts the active tokens. +template +BaseFloat LatticeIncrementalDecoderTpl::GetCutoff( + Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, Elem **best_elem) { + BaseFloat best_weight = std::numeric_limits::infinity(); + // positive == high cost == bad. + size_t count = 0; + if (config_.max_active == std::numeric_limits::max() && + config_.min_active == 0) { + for (Elem *e = list_head; e != NULL; e = e->tail, count++) { + BaseFloat w = static_cast(e->val->tot_cost); + if (w < best_weight) { + best_weight = w; + if (best_elem) *best_elem = e; + } + } + if (tok_count != NULL) *tok_count = count; + if (adaptive_beam != NULL) *adaptive_beam = config_.beam; + return best_weight + config_.beam; + } else { + tmp_array_.clear(); + for (Elem *e = list_head; e != NULL; e = e->tail, count++) { + BaseFloat w = e->val->tot_cost; + tmp_array_.push_back(w); + if (w < best_weight) { + best_weight = w; + if (best_elem) *best_elem = e; + } + } + if (tok_count != NULL) *tok_count = count; + + BaseFloat beam_cutoff = best_weight + config_.beam, + min_active_cutoff = std::numeric_limits::infinity(), + max_active_cutoff = std::numeric_limits::infinity(); + + KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() + << " is " << tmp_array_.size(); + + if (tmp_array_.size() > static_cast(config_.max_active)) { + std::nth_element(tmp_array_.begin(), tmp_array_.begin() + config_.max_active, + tmp_array_.end()); + max_active_cutoff = tmp_array_[config_.max_active]; + } + if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam. + if (adaptive_beam) + *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta; + return max_active_cutoff; + } + if (tmp_array_.size() > static_cast(config_.min_active)) { + if (config_.min_active == 0) + min_active_cutoff = best_weight; + else { + std::nth_element(tmp_array_.begin(), tmp_array_.begin() + config_.min_active, + tmp_array_.size() > static_cast(config_.max_active) + ? tmp_array_.begin() + config_.max_active + : tmp_array_.end()); + min_active_cutoff = tmp_array_[config_.min_active]; + } + } + if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam. + if (adaptive_beam) + *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta; + return min_active_cutoff; + } else { + *adaptive_beam = config_.beam; + return beam_cutoff; + } + } +} + +template +BaseFloat LatticeIncrementalDecoderTpl::ProcessEmitting( + DecodableInterface *decodable) { + KALDI_ASSERT(active_toks_.size() > 0); + int32 frame = active_toks_.size() - 1; // frame is the frame-index + // (zero-based) used to get likelihoods + // from the decodable object. + active_toks_.resize(active_toks_.size() + 1); + + Elem *final_toks = toks_.Clear(); // analogous to swapping prev_toks_ / cur_toks_ + // in simple-decoder.h. Removes the Elems from + // being indexed in the hash in toks_. + Elem *best_elem = NULL; + BaseFloat adaptive_beam; + size_t tok_cnt; + BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem); + KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is " + << adaptive_beam; + + PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough. + + BaseFloat next_cutoff = std::numeric_limits::infinity(); + // pruning "online" before having seen all tokens + + BaseFloat cost_offset = 0.0; // Used to keep probabilities in a good + // dynamic range. + + // First process the best token to get a hopefully + // reasonably tight bound on the next cutoff. The only + // products of the next block are "next_cutoff" and "cost_offset". + if (best_elem) { + StateId state = best_elem->key; + Token *tok = best_elem->val; + cost_offset = -tok->tot_cost; + for (fst::ArcIterator aiter(*fst_, state); !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { // propagate.. + BaseFloat new_weight = arc.weight.Value() + cost_offset - + decodable->LogLikelihood(frame, arc.ilabel) + + tok->tot_cost; + if (new_weight + adaptive_beam < next_cutoff) + next_cutoff = new_weight + adaptive_beam; + } + } + } + + // Store the offset on the acoustic likelihoods that we're applying. + // Could just do cost_offsets_.push_back(cost_offset), but we + // do it this way as it's more robust to future code changes. + cost_offsets_.resize(frame + 1, 0.0); + cost_offsets_[frame] = cost_offset; + + // the tokens are now owned here, in final_toks, and the hash is empty. + // 'owned' is a complex thing here; the point is we need to call DeleteElem + // on each elem 'e' to let toks_ know we're done with them. + for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) { + // loop this way because we delete "e" as we go. + StateId state = e->key; + Token *tok = e->val; + if (tok->tot_cost <= cur_cutoff) { + for (fst::ArcIterator aiter(*fst_, state); !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { // propagate.. + BaseFloat ac_cost = + cost_offset - decodable->LogLikelihood(frame, arc.ilabel), + graph_cost = arc.weight.Value(), cur_cost = tok->tot_cost, + tot_cost = cur_cost + ac_cost + graph_cost; + if (tot_cost > next_cutoff) + continue; + else if (tot_cost + adaptive_beam < next_cutoff) + next_cutoff = tot_cost + adaptive_beam; // prune by best current token + // Note: the frame indexes into active_toks_ are one-based, + // hence the + 1. + Token *next_tok = + FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok, NULL); + // NULL: no change indicator needed + + // Add ForwardLink from tok to next_tok (put on head of list tok->links) + tok->links = new ForwardLinkT(next_tok, arc.ilabel, arc.olabel, graph_cost, + ac_cost, tok->links); + } + } // for all arcs + } + e_tail = e->tail; + toks_.Delete(e); // delete Elem + } + return next_cutoff; +} + +// static inline +template +void LatticeIncrementalDecoderTpl::DeleteForwardLinks(Token *tok) { + ForwardLinkT *l = tok->links, *m; + while (l != NULL) { + m = l->next; + delete l; + l = m; + } + tok->links = NULL; +} + +template +void LatticeIncrementalDecoderTpl::ProcessNonemitting(BaseFloat cutoff) { + KALDI_ASSERT(!active_toks_.empty()); + int32 frame = static_cast(active_toks_.size()) - 2; + // Note: "frame" is the time-index we just processed, or -1 if + // we are processing the nonemitting transitions before the + // first frame (called from InitDecoding()). + + // Processes nonemitting arcs for one frame. Propagates within toks_. + // Note-- this queue structure is is not very optimal as + // it may cause us to process states unnecessarily (e.g. more than once), + // but in the baseline code, turning this vector into a set to fix this + // problem did not improve overall speed. + + KALDI_ASSERT(queue_.empty()); + + if (toks_.GetList() == NULL) { + if (!warned_) { + KALDI_WARN << "Error, no surviving tokens: frame is " << frame; + warned_ = true; + } + } + + for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) { + StateId state = e->key; + if (fst_->NumInputEpsilons(state) != 0) queue_.push_back(state); + } + + while (!queue_.empty()) { + StateId state = queue_.back(); + queue_.pop_back(); + + Token *tok = + toks_.Find(state) + ->val; // would segfault if state not in toks_ but this can't happen. + BaseFloat cur_cost = tok->tot_cost; + if (cur_cost > cutoff) // Don't bother processing successors. + continue; + // If "tok" has any existing forward links, delete them, + // because we're about to regenerate them. This is a kind + // of non-optimality (remember, this is the simple decoder), + // but since most states are emitting it's not a huge issue. + DeleteForwardLinks(tok); // necessary when re-visiting + tok->links = NULL; + for (fst::ArcIterator aiter(*fst_, state); !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel == 0) { // propagate nonemitting only... + BaseFloat graph_cost = arc.weight.Value(), tot_cost = cur_cost + graph_cost; + if (tot_cost < cutoff) { + bool changed; + + Token *new_tok = + FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok, &changed); + + tok->links = + new ForwardLinkT(new_tok, 0, arc.olabel, graph_cost, 0, tok->links); + + // "changed" tells us whether the new token has a different + // cost from before, or is new [if so, add into queue]. + if (changed && fst_->NumInputEpsilons(arc.nextstate) != 0) + queue_.push_back(arc.nextstate); + } + } + } // for all arcs + } // while queue not empty +} + +template +void LatticeIncrementalDecoderTpl::DeleteElems(Elem *list) { + for (Elem *e = list, *e_tail; e != NULL; e = e_tail) { + e_tail = e->tail; + toks_.Delete(e); + } +} + +template +void LatticeIncrementalDecoderTpl< + FST, Token>::ClearActiveTokens() { // a cleanup routine, at utt end/begin + for (size_t i = 0; i < active_toks_.size(); i++) { + // Delete all tokens alive on this frame, and any forward + // links they may have. + for (Token *tok = active_toks_[i].toks; tok != NULL;) { + DeleteForwardLinks(tok); + Token *next_tok = tok->next; + delete tok; + num_toks_--; + tok = next_tok; + } + } + active_toks_.clear(); + KALDI_ASSERT(num_toks_ == 0); +} + + +template +const CompactLattice& LatticeIncrementalDecoderTpl::GetLattice( + int32 num_frames_to_include, + bool use_final_probs) { + KALDI_ASSERT(num_frames_to_include >= num_frames_in_lattice_ && + num_frames_to_include <= NumFramesDecoded()); + + if (decoding_finalized_ && !use_final_probs) { + // This is not supported + KALDI_ERR << "You cannot get the lattice without final-probs after " + "calling FinalizeDecoding()."; + } + if (use_final_probs && num_frames_to_include != NumFramesDecoded()) { + /* This is because we only remember the relation between HCLG states and + Tokens for the current frame; the Token does not have a `state` field. */ + KALDI_ERR << "use-final-probs may no be true if you are not " + "getting a lattice for all frames decoded so far."; + } + + + if (num_frames_to_include > num_frames_in_lattice_) { + /* Make sure the token-pruning is up to date. If we just pruned the tokens, + this will do very little work. */ + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + + Lattice chunk_lat; + + unordered_map token_label2state; + if (num_frames_in_lattice_ != 0) { + determinizer_.InitializeRawLatticeChunk(&chunk_lat, + &token_label2state); + } + + // tok_map will map from Token* to state-id in chunk_lat. + // The cur and prev versions alternate on different frames. + unordered_map &tok2state_map(temp_token_map_); + tok2state_map.clear(); + + unordered_map &next_token2label_map(token2label_map_temp_); + next_token2label_map.clear(); + + { // Deal with the last frame in the chunk, the one numbered `num_frames_to_include`. + // (Yes, this is backwards). We allocate token labels, and set tokens as + // final, but don't add any transitions. This may leave some states + // disconnected (e.g. due to chains of nonemitting arcs), but it's OK; we'll + // fix it when we generate the next chunk of lattice. + int32 frame = num_frames_to_include; + // Allocate state-ids for all tokens on this frame. + + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + /* If we included the final-costs at this stage, they will cause + non-final states to be pruned out from the end of the lattice. */ + BaseFloat final_cost; + { // This block computes final_cost + if (decoding_finalized_) { + if (final_costs_.empty()) { + final_cost = 0.0; /* No final-state survived, so treat all as final + * with probability One(). */ + } else { + auto iter = final_costs_.find(tok); + if (iter == final_costs_.end()) + final_cost = std::numeric_limits::infinity(); + else + final_cost = iter->second; + } + } else { + /* this is a `fake` final-cost used to guide pruning. It's as if we + set the betas (backward-probs) on the final frame to the + negatives of the corresponding alphas, so all tokens on the last + frae will be on a best path.. the extra_cost for each token + always corresponds to its alpha+beta on this assumption. We want + the final_cost here to correspond to the beta (backward-prob), so + we get that by final_cost = extra_cost - tot_cost. + [The tot_cost is the forward/alpha cost.] + */ + final_cost = tok->extra_cost - tok->tot_cost; + } + } + + StateId state = chunk_lat.AddState(); + tok2state_map[tok] = state; + if (final_cost < std::numeric_limits::infinity()) { + next_token2label_map[tok] = AllocateNewTokenLabel(); + StateId token_final_state = chunk_lat.AddState(); + LatticeArc::Label ilabel = 0, + olabel = (next_token2label_map[tok] = AllocateNewTokenLabel()); + chunk_lat.AddArc(state, + LatticeArc(ilabel, olabel, + LatticeWeight::One(), + token_final_state)); + chunk_lat.SetFinal(token_final_state, LatticeWeight(final_cost, 0.0)); + } + } + } + + // Go in reverse order over the remaining frames so we can create arcs as we + // go, and their destination-states will already be in the map. + for (int32 frame = num_frames_to_include; + frame >= num_frames_in_lattice_; frame--) { + // The conditional below is needed for the last frame of the utterance. + BaseFloat cost_offset = (frame < cost_offsets_.size() ? + cost_offsets_[frame] : 0.0); + + // For the first frame of the chunk, we need to make sure the states are + // the ones created by InitializeRawLatticeChunk() (where not pruned away). + if (frame == num_frames_in_lattice_ && num_frames_in_lattice_ != 0) { + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + auto iter = token2label_map_.find(tok); + KALDI_ASSERT(iter != token2label_map_.end()); + Label token_label = iter->second; + auto iter2 = token_label2state.find(token_label); + if (iter2 != token_label2state.end()) { + StateId state = iter2->second; + tok2state_map[tok] = state; + } else { + // Some states may have been pruned out, but we should still allocate + // them. They might have been part of chains of nonemitting arcs + // where the state became disconnected because the last chunk didn't + // include arcs starting at this frame. + StateId state = chunk_lat.AddState(); + tok2state_map[tok] = state; + } + } + } else if (frame != num_frames_to_include) { // We already created states + // for the last frame. + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + StateId state = chunk_lat.AddState(); + tok2state_map[tok] = state; + } + } + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + auto iter = tok2state_map.find(tok); + KALDI_ASSERT(iter != tok2state_map.end()); + StateId cur_state = iter->second; + for (ForwardLinkT *l = tok->links; l != NULL; l = l->next) { + auto next_iter = tok2state_map.find(l->next_tok); + if (next_iter == tok2state_map.end()) { + // Emitting arcs from the last frame we're including -- ignore + // these. + KALDI_ASSERT(frame == num_frames_to_include); + continue; + } + StateId next_state = next_iter->second; + BaseFloat this_offset = (l->ilabel != 0 ? cost_offset : 0); + LatticeArc arc(l->ilabel, l->olabel, + LatticeWeight(l->graph_cost, l->acoustic_cost - this_offset), + next_state); + // Note: the epsilons get redundantly included at the end and beginning + // of successive chunks. These will get removed in the determinization. + chunk_lat.AddArc(cur_state, arc); + } + } + } + if (num_frames_in_lattice_ == 0) { + // This block locates the start token. NOTE: we use the fact that in the + // linked list of tokens, things are added at the head, so the start state + // must be at the tail. If this data structure is changed in future, we + // might need to explicitly store the start token as a class member. + Token *tok = active_toks_[0].toks; + if (tok == NULL) { + KALDI_WARN << "No tokens exist on start frame"; + return determinizer_.GetLattice(); // will be empty. + } + while (tok->next != NULL) + tok = tok->next; + Token *start_token = tok; + auto iter = tok2state_map.find(start_token); + KALDI_ASSERT(iter != tok2state_map.end()); + StateId start_state = iter->second; + chunk_lat.SetStart(start_state); + } + token2label_map_.swap(next_token2label_map); + + // bool finished_before_beam = + determinizer_.AcceptRawLatticeChunk(&chunk_lat); + // We are ignoring the return status, which say whether it finished before the beam. + + num_frames_in_lattice_ = num_frames_to_include; + } + + unordered_map token2final_cost; + unordered_map token_label2final_cost; + if (use_final_probs) { + ComputeFinalCosts(&token2final_cost, NULL, NULL); + for (const auto &p: token2final_cost) { + Token *tok = p.first; + BaseFloat cost = p.second; + auto iter = token2label_map_.find(tok); + if (iter != token2label_map_.end()) { + /* Some tokens may not have survived the pruned determinization. */ + Label token_label = iter->second; + bool ret = token_label2final_cost.insert({token_label, cost}).second; + KALDI_ASSERT(ret); /* Make sure it was inserted. */ + } + } + } + /* Note: these final-probs won't affect the next chunk, only the lattice + returned from GetLattice(). They are kind of temporaries. */ + determinizer_.SetFinalCosts(token_label2final_cost.empty() ? NULL : + &token_label2final_cost); + + return determinizer_.GetLattice(); +} + + +template +int32 LatticeIncrementalDecoderTpl::GetNumToksForFrame(int32 frame) { + int32 r = 0; + for (Token *tok = active_toks_[frame].toks; tok; tok = tok->next) r++; + return r; +} + + + +/* This utility function adds an arc to a Lattice, but where the source is a + CompactLatticeArc. If the CompactLatticeArc has a string with length greater + than 1, this will require adding extra states to `lat`. + */ +static void AddCompactLatticeArcToLattice( + const CompactLatticeArc &clat_arc, + LatticeArc::StateId src_state, + Lattice *lat) { + const std::vector &string = clat_arc.weight.String(); + size_t N = string.size(); + if (N == 0) { + LatticeArc arc; + arc.ilabel = 0; + arc.olabel = clat_arc.ilabel; + arc.nextstate = clat_arc.nextstate; + arc.weight = clat_arc.weight.Weight(); + lat->AddArc(src_state, arc); + } else { + LatticeArc::StateId cur_state = src_state; + for (size_t i = 0; i < N; i++) { + LatticeArc arc; + arc.ilabel = string[i]; + arc.olabel = (i == 0 ? clat_arc.ilabel : 0); + arc.nextstate = (i + 1 == N ? clat_arc.nextstate : lat->AddState()); + arc.weight = (i == 0 ? clat_arc.weight.Weight() : LatticeWeight::One()); + lat->AddArc(cur_state, arc); + cur_state = arc.nextstate; + } + } +} + + +void LatticeIncrementalDeterminizer::Init() { + non_final_redet_states_.clear(); + clat_.DeleteStates(); + final_arcs_.clear(); + forward_costs_.clear(); + arcs_in_.clear(); +} + +CompactLattice::StateId LatticeIncrementalDeterminizer::AddStateToClat() { + CompactLattice::StateId ans = clat_.AddState(); + forward_costs_.push_back(std::numeric_limits::infinity()); + KALDI_ASSERT(forward_costs_.size() == ans + 1); + arcs_in_.resize(ans + 1); + return ans; +} + +void LatticeIncrementalDeterminizer::AddArcToClat( + CompactLattice::StateId state, + const CompactLatticeArc &arc) { + BaseFloat forward_cost = forward_costs_[state] + + ConvertToCost(arc.weight); + if (forward_cost == std::numeric_limits::infinity()) + return; + int32 arc_idx = clat_.NumArcs(state); + clat_.AddArc(state, arc); + arcs_in_[arc.nextstate].push_back({state, arc_idx}); + if (forward_cost < forward_costs_[arc.nextstate]) + forward_costs_[arc.nextstate] = forward_cost; +} + +// See documentation in header +void LatticeIncrementalDeterminizer::IdentifyTokenFinalStates( + const CompactLattice &chunk_clat, + std::unordered_map *token_map) const { + token_map->clear(); + using StateId = CompactLattice::StateId; + using Label = CompactLatticeArc::Label; + + StateId num_states = chunk_clat.NumStates(); + for (StateId state = 0; state < num_states; state++) { + for (fst::ArcIterator aiter(chunk_clat, state); + !aiter.Done(); aiter.Next()) { + const CompactLatticeArc &arc = aiter.Value(); + if (arc.olabel >= kTokenLabelOffset && arc.olabel < kMaxTokenLabel) { + StateId nextstate = arc.nextstate; + auto r = token_map->insert({nextstate, arc.olabel}); + // Check consistency of labels on incoming arcs + KALDI_ASSERT(r.first->second == arc.olabel); + } + } + } +} + + + + +void LatticeIncrementalDeterminizer::GetNonFinalRedetStates() { + using StateId = CompactLattice::StateId; + non_final_redet_states_.clear(); + non_final_redet_states_.reserve(final_arcs_.size()); + + std::vector state_queue; + for (const CompactLatticeArc &arc: final_arcs_) { + // Note: we abuse the .nextstate field to store the state which is really + // the source of that arc. + StateId redet_state = arc.nextstate; + if (forward_costs_[redet_state] != std::numeric_limits::infinity()) { + // if it is accessible.. + if (non_final_redet_states_.insert(redet_state).second) { + // it was not already there + state_queue.push_back(redet_state); + } + } + } + // Add any states that are reachable from the states above. + while (!state_queue.empty()) { + StateId s = state_queue.back(); + state_queue.pop_back(); + for (fst::ArcIterator aiter(clat_, s); !aiter.Done(); + aiter.Next()) { + const CompactLatticeArc &arc = aiter.Value(); + StateId nextstate = arc.nextstate; + if (non_final_redet_states_.insert(nextstate).second) + state_queue.push_back(nextstate); // it was not already there + } + } +} + + +void LatticeIncrementalDeterminizer::InitializeRawLatticeChunk( + Lattice *olat, + unordered_map *token_label2state) { + using namespace fst; + + olat->DeleteStates(); + LatticeArc::StateId start_state = olat->AddState(); + olat->SetStart(start_state); + token_label2state->clear(); + + // redet_state_map maps from state-ids in clat_ to state-ids in olat. This + // will be the set of states from which the arcs to final-states in the + // canonical appended lattice leave (physically, these are in the .nextstate + // elements of arcs_, since we use that field for the source state), plus any + // states reachable from those states. + unordered_map redet_state_map; + + for (CompactLattice::StateId redet_state: non_final_redet_states_) + redet_state_map[redet_state] = olat->AddState(); + + // First, process any arcs leaving the non-final redeterminized states that + // are not to final-states. (What we mean by "not to final states" is, not to + // stats that are final in the `canonical appended lattice`.. they may + // actually be physically final in clat_, because we make clat_ what we want + // to return to the user. + for (CompactLattice::StateId redet_state: non_final_redet_states_) { + LatticeArc::StateId lat_state = redet_state_map[redet_state]; + + for (ArcIterator aiter(clat_, redet_state); + !aiter.Done(); aiter.Next()) { + const CompactLatticeArc &arc = aiter.Value(); + CompactLattice::StateId nextstate = arc.nextstate; + LatticeArc::StateId lat_nextstate = olat->NumStates(); + auto r = redet_state_map.insert({nextstate, lat_nextstate}); + if (r.second) { // Was inserted. + LatticeArc::StateId s = olat->AddState(); + KALDI_ASSERT(s == lat_nextstate); + } else { + // was not inserted -> was already there. + lat_nextstate = r.first->second; + } + CompactLatticeArc clat_arc(arc); + clat_arc.nextstate = lat_nextstate; + AddCompactLatticeArcToLattice(clat_arc, lat_state, olat); + } + clat_.DeleteArcs(redet_state); + clat_.SetFinal(redet_state, CompactLatticeWeight::Zero()); + } + + for (const CompactLatticeArc &arc: final_arcs_) { + // We abuse the `nextstate` field to store the source state. + CompactLattice::StateId src_state = arc.nextstate; + auto iter = redet_state_map.find(src_state); + if (forward_costs_[src_state] == std::numeric_limits::infinity()) + continue; /* Unreachable state */ + KALDI_ASSERT(iter != redet_state_map.end()); + LatticeArc::StateId src_lat_state = iter->second; + Label token_label = arc.ilabel; // will be == arc.olabel. + KALDI_ASSERT(token_label >= kTokenLabelOffset && + token_label < kMaxTokenLabel); + auto r = token_label2state->insert({token_label, + olat->NumStates()}); + LatticeArc::StateId dest_lat_state = r.first->second; + if (r.second) { // was inserted + LatticeArc::StateId new_state = olat->AddState(); + KALDI_ASSERT(new_state == dest_lat_state); + } + CompactLatticeArc new_arc; + new_arc.nextstate = dest_lat_state; + /* We convert the token-label to epsilon; it's not needed anymore. */ + new_arc.ilabel = new_arc.olabel = 0; + new_arc.weight = arc.weight; + AddCompactLatticeArcToLattice(new_arc, src_lat_state, olat); + } + + // Now deal with the initial-probs. Arcs from initial-states to + // redeterminized-states in the raw lattice have an olabel that identifies the + // id of that redeterminized-state in clat_, and a cost that is derived from + // its entry in forward_costs_. These forward-probs are used to get the + // pruned lattice determinization to behave correctly, and will be canceled + // out later on. + // + // In the paper this is the second-from-last bullet in Sec. 5.2. NOTE: in the + // paper we state that we only include such arcs for "each redeterminized + // state that is either initial in det(A) or that has an arc entering it from + // a state that is not a redeterminized state." In fact, we include these + // arcs for all redeterminized states. I realized that it won't make a + // difference to the outcome, and it's easier to do it this way. + for (CompactLattice::StateId state_id: non_final_redet_states_) { + BaseFloat forward_cost = forward_costs_[state_id]; + LatticeArc arc; + arc.ilabel = 0; + // The olabel (which appears where the word-id would) is what + // we call a 'state-label'. It identifies a state in clat_. + arc.olabel = state_id + kStateLabelOffset; + // It doesn't matter what field we put forward_cost in (or whether we + // divide it among them both; the effect on pruning is the same, and + // we will cancel it out later anyway. + arc.weight = LatticeWeight(forward_cost, 0); + auto iter = redet_state_map.find(state_id); + KALDI_ASSERT(iter != redet_state_map.end()); + arc.nextstate = iter->second; + olat->AddArc(start_state, arc); + } +} + +void LatticeIncrementalDeterminizer::GetRawLatticeFinalCosts( + const Lattice &raw_fst, + std::unordered_map *old_final_costs) { + LatticeArc::StateId raw_fst_num_states = raw_fst.NumStates(); + for (LatticeArc::StateId s = 0; s < raw_fst_num_states; s++) { + for (fst::ArcIterator aiter(raw_fst, s); !aiter.Done(); + aiter.Next()) { + const LatticeArc &value = aiter.Value(); + if (value.olabel >= (Label)kTokenLabelOffset && + value.olabel < (Label)kMaxTokenLabel) { + LatticeWeight final_weight = raw_fst.Final(value.nextstate); + if (final_weight == LatticeWeight::Zero() || + final_weight.Value2() != 0) { + KALDI_ERR << "Label " << value.olabel << " from state " << s + << " looks like a token-label but its next-state " + << value.nextstate << + " has unexpected final-weight " << final_weight.Value1() << ',' + << final_weight.Value2(); + } + auto r = old_final_costs->insert({value.olabel, + final_weight.Value1()}); + if (!r.second && r.first->second != final_weight.Value1()) { + // For any given token-label, all arcs in raw_fst with that + // olabel should go to the same state, so this should be + // impossible. + KALDI_ERR << "Unexpected mismatch in final-costs for tokens, " + << r.first->second << " vs " << final_weight.Value1(); + } + } + } + } +} + + +bool LatticeIncrementalDeterminizer::ProcessArcsFromChunkStartState( + const CompactLattice &chunk_clat, + std::unordered_map *state_map, + CompactLatticeWeight *extra_start_weight) { + using StateId = CompactLattice::StateId; + StateId clat_num_states = clat_.NumStates(); + + // Process arcs leaving the start state of chunk_clat. These arcs will have + // state-labels on them (unless this is the first chunk). + // For destination-states of those arcs, work out which states in + // clat_ they correspond to and update their forward_costs. + for (fst::ArcIterator aiter(chunk_clat, chunk_clat.Start()); + !aiter.Done(); aiter.Next()) { + const CompactLatticeArc &arc = aiter.Value(); + Label label = arc.ilabel; // ilabel == olabel; would be the olabel + // in a Lattice. + if (!(label >= kStateLabelOffset && + label - kStateLabelOffset < clat_num_states)) { + // The label was not a state-label. This should only be possible on the + // first chunk. + KALDI_ASSERT(state_map->empty()); + return true; // this is the first chunk. + } + StateId clat_state = label - kStateLabelOffset; + StateId chunk_state = arc.nextstate; + auto p = state_map->insert({chunk_state, clat_state}); + StateId dest_clat_state = p.first->second; + // We deleted all its arcs in InitializeRawLatticeChunk + KALDI_ASSERT(clat_.NumArcs(clat_state) == 0); + /* + In almost all cases, dest_clat_state and clat_state will be the same state; + but there may be situations where two arcs with different state-labels + left the start state and entered the same next-state in chunk_clat; and in + these cases, they will be different. + + We didn't address this issue in the paper (or actually realize it could be + a problem). What we do is pick one of the clat_states as the "canonical" + one, and redirect all incoming transitions of the others to enter the + "canonical" one. (Search below for new_in_arc.nextstate = + dest_clat_state). + */ + if (clat_state != dest_clat_state) { + // Check that the start state isn't getting merged with any other state. + // If this were possible, we'd need to deal with it specially, but it + // can't be, because to be merged, 2 states must have identical arcs + // leaving them with identical weights, so we'd need to have another state + // on frame 0 identical to the start state, which is not possible if the + // lattice is deterministic and epsilon-free. + KALDI_ASSERT(clat_state != 0 && dest_clat_state != 0); + } + + // in_weight is an extra weight that we'll include on arcs entering this + // state from the previous chunk. We need to cancel out + // `forward_costs[clat_state]`, which was included in the corresponding arc + // in the raw lattice for pruning purposes; and we need to include + // the weight from the start-state of `chunk_clat` to this state. + CompactLatticeWeight extra_weight_in = arc.weight; + extra_weight_in.SetWeight( + fst::Times(extra_weight_in.Weight(), + LatticeWeight(-forward_costs_[clat_state], 0.0))); + + if (clat_state == 0) { + // if clat_state is the star-state of clat_ (state 0), we can't modify + // incoming arcs; we need to modify outgoing arcs, but we'll do that + // later, after we add them. + *extra_start_weight = extra_weight_in; + forward_costs_[0] = forward_costs_[0] + ConvertToCost(extra_weight_in); + continue; + } + + // Note: 0 is the start state of clat_. This was checked. + forward_costs_[clat_state] = (clat_state == 0 ? 0 : + std::numeric_limits::infinity()); + std::vector > arcs_in; + arcs_in.swap(arcs_in_[clat_state]); + for (auto p: arcs_in) { + // Note: we'll be doing `continue` below if this input arc came from + // another redeterminized-state, because we did DeleteStates() for them in + // InitializeRawLatticeChunk(). Those arcs will be transferred + // from chunk_clat later on. + CompactLattice::StateId src_state = p.first; + int32 arc_pos = p.second; + if (arc_pos >= (int32)clat_.NumArcs(src_state)) + continue; + fst::MutableArcIterator aiter(&clat_, src_state); + aiter.Seek(arc_pos); + if (aiter.Value().nextstate != clat_state) + continue; // This arc record has become invalidated. + CompactLatticeArc new_in_arc(aiter.Value()); + // In most cases we will have dest_clat_state == clat_state, so the next + // line won't change the value of .nextstate + new_in_arc.nextstate = dest_clat_state; + new_in_arc.weight = fst::Times(new_in_arc.weight, extra_weight_in); + aiter.SetValue(new_in_arc); + + BaseFloat new_forward_cost = forward_costs_[src_state] + + ConvertToCost(new_in_arc.weight); + if (new_forward_cost < forward_costs_[dest_clat_state]) + forward_costs_[dest_clat_state] = new_forward_cost; + arcs_in_[dest_clat_state].push_back(p); + } + } + return false; // this is not the first chunk. +} + +void LatticeIncrementalDeterminizer::ReweightStartState( + CompactLatticeWeight &extra_start_weight) { + for (fst::MutableArcIterator aiter(&clat_, 0); + !aiter.Done(); aiter.Next()) { + CompactLatticeArc arc(aiter.Value()); + arc.weight = fst::Times(extra_start_weight, arc.weight); + aiter.SetValue(arc); + } +} + +void LatticeIncrementalDeterminizer::TransferArcsToClat( + const CompactLattice &chunk_clat, + bool is_first_chunk, + const std::unordered_map &state_map, + const std::unordered_map &chunk_state_to_token, + const std::unordered_map &old_final_costs) { + using StateId = CompactLattice::StateId; + StateId chunk_num_states = chunk_clat.NumStates(); + + // Now transfer arcs from chunk_clat to clat_. + for (StateId chunk_state = (is_first_chunk ? 0 : 1); + chunk_state < chunk_num_states; chunk_state++) { + auto iter = state_map.find(chunk_state); + if (iter == state_map.end()) { + KALDI_ASSERT(chunk_state_to_token.count(chunk_state) != 0); + // Don't process token-final states. Anyway they have no arcs leaving + // them. + continue; + } + StateId clat_state = iter->second; + + // We know that this point that `clat_state` is not a token-final state + // (see glossary for definition) as if it were, we would have done + // `continue` above. + // + // Only in the last chunk of the lattice would be there be a final-prob on + // states that are not `token-final states`; these final-probs would + // normally all be Zero() at this point. So in almost all cases the following + // call will do nothing. + clat_.SetFinal(clat_state, chunk_clat.Final(chunk_state)); + + // Process arcs leaving this state. + for (fst::ArcIterator aiter(chunk_clat, chunk_state); + !aiter.Done(); aiter.Next()) { + CompactLatticeArc arc(aiter.Value()); + + auto next_iter = state_map.find(arc.nextstate); + if (next_iter != state_map.end()) { + // The normal case (when the .nextstate has a corresponding + // state in clat_) is very simple. Just copy the arc over. + arc.nextstate = next_iter->second; + KALDI_ASSERT(arc.ilabel < kTokenLabelOffset || + arc.ilabel > kMaxTokenLabel); + AddArcToClat(clat_state, arc); + } else { + // This is the case when the arc is to a `token-final` state (see + // glossary.) + + // TODO: remove the following slightly excessive assertion? + KALDI_ASSERT(chunk_clat.Final(arc.nextstate) != CompactLatticeWeight::Zero() && + arc.olabel >= (Label)kTokenLabelOffset && + arc.olabel < (Label)kMaxTokenLabel && + chunk_state_to_token.count(arc.nextstate) != 0 && + old_final_costs.count(arc.olabel) != 0); + + // Include the final-cost of the next state (which should be final) + // in arc.weight. + arc.weight = fst::Times(arc.weight, + chunk_clat.Final(arc.nextstate)); + + auto cost_iter = old_final_costs.find(arc.olabel); + KALDI_ASSERT(cost_iter != old_final_costs.end()); + BaseFloat old_final_cost = cost_iter->second; + + // `arc` is going to become an element of final_arcs_. These + // contain information about transitions from states in clat_ to + // `token-final` states (i.e. states that have a token-label on the arc + // to them and that are final in the canonical compact lattice). + // We subtract the old_final_cost as it was just a temporary cost + // introduced for pruning purposes. + arc.weight.SetWeight(fst::Times(arc.weight.Weight(), + LatticeWeight{-old_final_cost, 0.0})); + // In a slight abuse of the Arc data structure, the nextstate is set to + // the source state. The label (ilabel == olabel) indicates the + // token it is associated with. + arc.nextstate = clat_state; + final_arcs_.push_back(arc); + } + } + } + +} + +bool LatticeIncrementalDeterminizer::AcceptRawLatticeChunk( + Lattice *raw_fst) { + using Label = CompactLatticeArc::Label; + using StateId = CompactLattice::StateId; + + // old_final_costs is a map from a `token-label` (see glossary) to the + // associated final-prob in a final-state of `raw_fst`, that is associated + // with that Token. These are Tokens that were active at the end of the + // chunk. The final-probs may arise from beta (backward) costs, introduced + // for pruning purposes, and/or from final-probs in HCLG. Those costs will + // not be included in anything we store permamently in this class; they used + // only to guide pruned determinization, and we will use `old_final_costs` + // later to cancel them out. + std::unordered_map old_final_costs; + GetRawLatticeFinalCosts(*raw_fst, &old_final_costs); + + CompactLattice chunk_clat; + bool determinized_till_beam = DeterminizeLatticePhonePrunedWrapper( + trans_model_, raw_fst, config_.lattice_beam, &chunk_clat, + config_.det_opts); + + TopSortCompactLatticeIfNeeded(&chunk_clat); + + std::unordered_map chunk_state_to_token; + IdentifyTokenFinalStates(chunk_clat, + &chunk_state_to_token); + + StateId chunk_num_states = chunk_clat.NumStates(); + if (chunk_num_states == 0) { + // This will be an error but user-level calling code can detect it from the + // lattice being empty. + KALDI_WARN << "Empty lattice, something went wrong."; + clat_.DeleteStates(); + return false; + } + + StateId start_state = chunk_clat.Start(); // would be 0. + KALDI_ASSERT(start_state == 0); + + // Process arcs leaving the start state of chunk_clat. Unless this is the + // first chunk in the lattice, all arcs leaving the start state of chunk_clat + // will have `state labels` on them (identifying redeterminized-states in + // clat_), and will transition to a state in `chunk_clat` that we can identify + // with that redeterminized-state. + + // state_map maps from (non-initial, non-token-final state s in chunk_clat) to + // a state in clat_. + std::unordered_map state_map; + + + CompactLatticeWeight extra_start_weight = CompactLatticeWeight::One(); + bool is_first_chunk = ProcessArcsFromChunkStartState(chunk_clat, &state_map, + &extra_start_weight); + + // Remove any existing arcs in clat_ that leave redeterminized-states, and + // make those states non-final. Below, we'll add arcs leaving those states + // (and possibly new final-probs.) + for (StateId clat_state: non_final_redet_states_) { + clat_.DeleteArcs(clat_state); + clat_.SetFinal(clat_state, CompactLatticeWeight::Zero()); + } + + // The previous final-arc info is no longer relevant; we'll recreate it below. + final_arcs_.clear(); + + // assume chunk_lat.Start() == 0; we asserted it above. Allocate state-ids + // for all remaining states in chunk_clat, except for token-final states. + for (StateId state = (is_first_chunk ? 0 : 1); + state < chunk_num_states; state++) { + if (chunk_state_to_token.count(state) != 0) + continue; // these `token-final` states don't get a state allocated. + + StateId new_clat_state = clat_.NumStates(); + if (state_map.insert({state, new_clat_state}).second) { + // If it was inserted then we need to actually allocate that state + StateId s = AddStateToClat(); + KALDI_ASSERT(s == new_clat_state); + } // else do nothing; it would have been a redeterminized-state and no + } // allocation is needed since they already exist in clat_. and + // in state_map. + + if (is_first_chunk) { + auto iter = state_map.find(start_state); + KALDI_ASSERT(iter != state_map.end()); + CompactLattice::StateId clat_start_state = iter->second; + KALDI_ASSERT(clat_start_state == 0); // topological order. + clat_.SetStart(clat_start_state); + forward_costs_[clat_start_state] = 0.0; + } + + TransferArcsToClat(chunk_clat, is_first_chunk, + state_map, chunk_state_to_token, old_final_costs); + + if (extra_start_weight != CompactLatticeWeight::One()) + ReweightStartState(extra_start_weight); + + GetNonFinalRedetStates(); + + return determinized_till_beam; +} + + + +void LatticeIncrementalDeterminizer::SetFinalCosts( + const unordered_map *token_label2final_cost) { + if (final_arcs_.empty()) { + KALDI_WARN << "SetFinalCosts() called when final_arcs_.empty()... possibly " + "means you are calling this after Finalize()? Not allowed: could " + "indicate a code error. Or possibly decoding failed somehow."; + } + + /* + prefinal states a terminology that does not appear in the paper. What it + means is: the set of states that have an arc with a Token-label as the label + leaving them in the canonical appended lattice. + */ + std::unordered_set &prefinal_states(temp_); + prefinal_states.clear(); + for (const auto &arc: final_arcs_) { + /* Caution: `state` is actually the state the arc would + leave from in the canonical appended lattice; we just store + that in the .nextstate field. */ + CompactLattice::StateId state = arc.nextstate; + prefinal_states.insert(state); + } + + for (int32 state: prefinal_states) + clat_.SetFinal(state, CompactLatticeWeight::Zero()); + + + for (const CompactLatticeArc &arc: final_arcs_) { + Label token_label = arc.ilabel; + /* Note: we store the source state in the .nextstate field. */ + CompactLattice::StateId src_state = arc.nextstate; + BaseFloat graph_final_cost; + if (token_label2final_cost == NULL) { + graph_final_cost = 0.0; + } else { + auto iter = token_label2final_cost->find(token_label); + if (iter == token_label2final_cost->end()) + continue; + else + graph_final_cost = iter->second; + } + /* It might seem odd to set a final-prob on the src-state of the arc.. + the point is that the symbol on the arc is a token-label, which should not + appear in the lattice the user sees, so after that token-label is removed + the arc would just become a final-prob. + */ + clat_.SetFinal(src_state, + fst::Plus(clat_.Final(src_state), + fst::Times(arc.weight, + CompactLatticeWeight( + LatticeWeight(graph_final_cost, 0), {})))); + } +} + + + + +// Instantiate the template for the combination of token types and FST types +// that we'll need. +template class LatticeIncrementalDecoderTpl, decoder::StdToken>; +template class LatticeIncrementalDecoderTpl, + decoder::StdToken>; +template class LatticeIncrementalDecoderTpl, + decoder::StdToken>; +template class LatticeIncrementalDecoderTpl; + +template class LatticeIncrementalDecoderTpl, + decoder::BackpointerToken>; +template class LatticeIncrementalDecoderTpl, + decoder::BackpointerToken>; +template class LatticeIncrementalDecoderTpl, + decoder::BackpointerToken>; +template class LatticeIncrementalDecoderTpl; + +} // end namespace kaldi. diff --git a/src/decoder/lattice-incremental-decoder.h b/src/decoder/lattice-incremental-decoder.h new file mode 100644 index 000000000..7abc37017 --- /dev/null +++ b/src/decoder/lattice-incremental-decoder.h @@ -0,0 +1,752 @@ +// decoder/lattice-incremental-decoder.h + +// Copyright 2019 Zhehuai Chen, Daniel Povey + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_DECODER_LATTICE_INCREMENTAL_DECODER_H_ +#define KALDI_DECODER_LATTICE_INCREMENTAL_DECODER_H_ + +#include "util/stl-utils.h" +#include "util/hash-list.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/grammar-fst.h" +#include "lattice-faster-decoder.h" + +namespace kaldi { +/** + The normal decoder, lattice-faster-decoder.h, sometimes has an issue when + doing real-time applications with long utterances, that each time you get the + lattice the lattice determinization can take a considerable amount of time; + this introduces latency. This version of the decoder spreads the work of + lattice determinization out throughout the decoding process. + + NOTE: + + Please see https://www.danielpovey.com/files/ *TBD* .pdf for a technical + explanation of what is going on here. + + GLOSSARY OF TERMS: + chunk: We do the determinization on chunks of frames; these + may coincide with the chunks on which the user calls + AdvanceDecoding(). The basic idea is to extract chunks + of the raw lattice and determinize them individually, but + it gets much more complicated than that. The chunks + should normally be at least as long as a word (let's say, + at least 20 frames), or the overhead of this algorithm + might become excessive and affect RTF. + + raw lattice chunk: A chunk of raw (i.e. undeterminized) lattice + that we will determinize. In the paper this corresponds + to the FST B that is described in Section 5.2. + + token_label, state_label / token-label, state-label: + + In the paper these are both referred to as `state labels` (these are + special, large integer id's that refer to states in the undeterminized + lattice and in the the determinized lattice); but we use two separate + terms here, for more clarity, when referring to the undeterminized + vs. determinized lattice. + + token_label conceptually refers to states in the + raw lattice, but we don't materialize the entire + raw lattice as a physical FST and and these tokens + are actually tokens (template type Token) held by + the decoder + + state_label when used in this code refers specifically + to labels that identify states in the determinized + lattice (i.e. state indexes in lat_). + + token-final state + A state in a raw lattice or in a determinized chunk that has an arc + entering it that has a `token-label` on it (as defined above). + These states will have nonzero final-probs. + + redeterminized-non-splice-state, aka ns_redet: + A redeterminized state which is not also a splice state; + refer to the paper for explanation. In the already-determinized + part this means a redeterminized state which is not final. + + canonical appended lattice: This is the appended compact lattice + that we conceptually have (i.e. what we described in the paper). + The difference from the "actual appended lattice" stored + in LatticeIncrementalDeterminizer::clat_ is that the + actual appended lattice has all its final-arcs replaced with + final-probs, and we keep the real final-arcs "on the side" in a + separate data structure. The final-probs in clat_ aren't + necessarily related to the costs on the final-arcs; instead + they can have arbitrary values passed in by the user (e.g. + if we want to include final-probs). This means that the + clat_ can be returned without modification to the user who wants + a partially determinized result. + + final-arc: An arc in the canonical appended CompactLattice which + goes to a final-state. These arcs will have `state-labels` as + their labels. + + */ +struct LatticeIncrementalDecoderConfig { + // All the configuration values until det_opts are the same as in + // LatticeFasterDecoder. For clarity we repeat them rather than inheriting. + BaseFloat beam; + int32 max_active; + int32 min_active; + BaseFloat lattice_beam; + int32 prune_interval; + BaseFloat beam_delta; // has nothing to do with beam_ratio + BaseFloat hash_ratio; + BaseFloat prune_scale; // Note: we don't make this configurable on the command line, + // it's not a very important parameter. It affects the + // algorithm that prunes the tokens as we go. + // Most of the options inside det_opts are not actually queried by the + // LatticeIncrementalDecoder class itself, but by the code that calls it, for + // example in the function DecodeUtteranceLatticeIncremental. + fst::DeterminizeLatticePhonePrunedOptions det_opts; + + // The configuration values from this point on are specific to the + // incremental determinization. See where they are registered for + // explanation. + // Caution: these are only inspected in UpdateLatticeDeterminization(). + // If you call + int32 determinize_max_delay; + int32 determinize_min_chunk_size; + + + LatticeIncrementalDecoderConfig() + : beam(16.0), + max_active(std::numeric_limits::max()), + min_active(200), + lattice_beam(10.0), + prune_interval(25), + beam_delta(0.5), + hash_ratio(2.0), + prune_scale(0.1), + determinize_max_delay(60), + determinize_min_chunk_size(20) { + det_opts.minimize = false; + } + void Register(OptionsItf *opts) { + det_opts.Register(opts); + opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate."); + opts->Register("max-active", &max_active, + "Decoder max active states. Larger->slower; " + "more accurate"); + opts->Register("min-active", &min_active, "Decoder minimum #active states."); + opts->Register("lattice-beam", &lattice_beam, + "Lattice generation beam. Larger->slower, " + "and deeper lattices"); + opts->Register("prune-interval", &prune_interval, + "Interval (in frames) at " + "which to prune tokens"); + opts->Register("beam-delta", &beam_delta, + "Increment used in decoding-- this " + "parameter is obscure and relates to a speedup in the way the " + "max-active constraint is applied. Larger is more accurate."); + opts->Register("hash-ratio", &hash_ratio, + "Setting used in decoder to " + "control hash behavior"); + opts->Register("determinize-max-delay", &determinize_max_delay, + "Maximum frames of delay between decoding a frame and " + "determinizing it"); + opts->Register("determinize-min-chunk-size", &determinize_min_chunk_size, + "Minimum chunk size used in determinization"); + + } + void Check() const { + if (!(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 && + min_active <= max_active && prune_interval > 0 && + beam_delta > 0.0 && hash_ratio >= 1.0 && + prune_scale > 0.0 && prune_scale < 1.0 && + determinize_max_delay > determinize_min_chunk_size && + determinize_min_chunk_size > 0)) + KALDI_ERR << "Invalid options given to decoder"; + /* Minimization of the chunks is not compatible withour algorithm (or at + least, would require additional complexity to implement.) */ + if (det_opts.minimize || !det_opts.word_determinize) + KALDI_ERR << "Invalid determinization options given to decoder."; + } +}; + + + +/** + This class is used inside LatticeIncrementalDecoderTpl; it handles + some of the details of incremental determinization. + https://www.danielpovey.com/files/ *TBD*.pdf for the paper. + +*/ +class LatticeIncrementalDeterminizer { + public: + using Label = typename LatticeArc::Label; /* Actualy the same labels appear + in both lattice and compact + lattice, so we don't use the + specific type all the time but + just say 'Label' */ + LatticeIncrementalDeterminizer( + const TransitionModel &trans_model, + const LatticeIncrementalDecoderConfig &config): + trans_model_(trans_model), config_(config) { } + + // Resets the lattice determinization data for new utterance + void Init(); + + // Returns the current determinized lattice. + const CompactLattice &GetDeterminizedLattice() const { return clat_; } + + /** + Starts the process of creating a raw lattice chunk. (Search the glossary + for "raw lattice chunk"). This just sets up the initial states and + redeterminized-states in the chunk. Relates to sec. 5.2 in the paper, + specifically the initial-state i and the redeterminized-states. + + After calling this, the caller would add the remaining arcs and states + to `olat` and then call AcceptRawLatticeChunk() with the result. + + @param [out] olat The lattice to be (partially) created + + @param [out] token_label2state This function outputs to here + a map from `token-label` to the state we created for + it in *olat. See glossary for `token-label`. + The keys actually correspond to the .nextstate fields + in the arcs in final_arcs_; values are states in `olat`. + See the last bullet point before Sec. 5.3 in the paper. + */ + void InitializeRawLatticeChunk( + Lattice *olat, + unordered_map *token_label2state); + + /** + This function accepts the raw FST (state-level lattice) corresponding to a + single chunk of the lattice, determinizes it and appends it to this->clat_. + Unless this was the + + Note: final-probs in `raw_fst` are treated specially: they are used to + guide the pruned determinization, but when you call GetLattice() it will be + -- except for pruning effects-- as if all nonzero final-probs in `raw_fst` + were: One() if final_costs == NULL; else the value present in `final_costs`. + + @param [in] raw_fst (Consumed destructively). The input + raw (state-level) lattice. Would correspond to the + FST A in the paper if first_frame == 0, and B + otherwise. + + @return returns false if determinization finished earlier than the beam + or the determinized lattice was empty; true otherwise. + + NOTE: if this is not the final chunk, you will probably want to call + SetFinalCosts() directly after calling this. + */ + bool AcceptRawLatticeChunk(Lattice *raw_fst); + + /* + Sets final-probs in `clat_`. Must only be called if the final chunk + has not been processed. (The final chunk is whenever GetLattice() is + called with finalize == true). + + The reason this is a separate function from AcceptRawLatticeChunk() is that + there may be situations where a user wants to get the latice with + final-probs in it, after previously getting it without final-probs; or + vice versa. By final-probs, we mean the Final() probabilities in the + HCLG (decoding graph; this->fst_). + + @param [in] token_label2final_cost A map from the token-label + corresponding to Tokens active on the final frame of the + lattice in the object, to the final-cost we want to use for + those tokens. If NULL, it means all Tokens should be treated + as final with probability One(). If non-NULL, and a particular + token-label is not a key of this map, it means that Token + corresponded to a state that was not final in HCLG; and + such tokens will be treated as non-final. However, + if this would result in no states in the lattice being final, + we will treat all Tokens as final with probability One(), + a warning will be printed (this should not happen.) + */ + void SetFinalCosts(const unordered_map *token_label2final_cost = NULL); + + const CompactLattice &GetLattice() { return clat_; } + + // kStateLabelOffset is what we add to state-ids in clat_ to produce labels + // to identify them in the raw lattice chunk + // kTokenLabelOffset is where we start allocating labels corresponding to Tokens + // (these correspond with raw lattice states); + enum { kStateLabelOffset = (int)1e8, kTokenLabelOffset = (int)2e8, kMaxTokenLabel = (int)3e8 }; + + private: + + // [called from AcceptRawLatticeChunk()] + // Gets the final costs from token-final states in the raw lattice (see + // glossary for definition). These final costs will be subtracted after + // determinization; in the normal case they are `temporaries` used to guide + // pruning. NOTE: the index of the array is not the FST state that is final, + // but the label on arcs entering it (these will be `token-labels`). Each + // token-final state will have the same label on all arcs entering it. + // + // `old_final_costs` is assumed to be empty at entry. + void GetRawLatticeFinalCosts(const Lattice &raw_fst, + std::unordered_map *old_final_costs); + + // Sets up non_final_redet_states_. See documentation for that variable. + void GetNonFinalRedetStates(); + + /** [called from AcceptRawLatticeChunk()] Processes arcs that leave the + start-state of `chunk_clat` (if this is not the first chunk); does nothing + if this is the first chunk. This includes using the `state-labels` to + work out which states in clat_ these states correspond to, and writing + that mapping to `state_map`. + + Also modifies forward_costs_, because it has to do a kind of reweighting + of the clat states that are the values it puts in `state_map`, to take + account of the probabilities on the arcs from the start state of + chunk_clat to the states corresponding to those redeterminized-states + (i.e. the states in clat corresponding to the values it puts in + `*state_map`). It also modifies arcs_in_, mostly because there + are rare cases when we end up `merging` sets of those redeterminized-states, + because the determinization process mapped them to a single state, + and that means we need to reroute the arcs into members of that + set into one single member (which will appear as a value in + `*state_map`). + + @param [in] chunk_clat The determinized chunk of lattice we are + processing + @param [out] state_map Mapping from states in chunk_clat to + the state in clat_ they correspond to. + @param [out] extra_start_weight If the start-state of + clat_ (its state 0) needs to be modified as + if its incoming arcs were multiplied by + `extra_start_weight`, this isn't possible + using the `in_arcs_` data-structure, + so we remember the extra weight and multiply + it in later, after processing arcs leaving + the start state of clat_. This is set + only if the start-state of clat_ is a + redeterminized state. + @return Returns true if this is the first chunk. + */ + bool ProcessArcsFromChunkStartState( + const CompactLattice &chunk_clat, + std::unordered_map *state_map, + CompactLatticeWeight *extra_start_weight); + + /** + This function, called from AcceptRawLatticeChunk(), takes care of an + unusual situation where we need to reweight the start state of clat_. This + `extra_start_weight` is to be thought of as an extra `incoming` weight, and + we need to left-multiply all the arcs leaving the start state, by it. + + This function does not need to modify forward_costs_; that will + already have been done by ProcessArcsFromChunkStartState(). + */ + void ReweightStartState(CompactLatticeWeight &extra_start_weight); + + + /** + This function, called from AcceptRawLatticeChunk(), transfers arcs from + `chunk_clat` to clat_. For those arcs that have `token-labels` on them, + they don't get written to clat_ but instead are stored in the arcs_ array. + + @param [in] chunk_clat The determinized lattice for the chunk + we are processing; this is the source of the arcs + we are moving. + @param [in] is_first_chunk True if this is the first chunk in the + utterance; it's needed because if it is, we + will also transfer arcs from the start state of + chunk_clat. + @param [in] state_map Map from state-ids in chunk_clat to state-ids + in clat_. + @param [in] chunk_state_to_token Map from `token-final states` + (see glossary) in chunk_clat, to the token-label + on arcs entering those states. + @param [in] old_final_costs Map from token-label to the + final-costs that were on the corresponding + token-final states in the undeterminized lattice; + these final-costs need to be removed when + we record the weights in final_arcs_, because + they were just temporary. + */ + void TransferArcsToClat( + const CompactLattice &chunk_clat, + bool is_first_chunk, + const std::unordered_map &state_map, + const std::unordered_map &chunk_state_to_token, + const std::unordered_map &old_final_costs); + + + + /** + Adds one arc to `clat_`. It's like clat_.AddArc(state, arc), except + it also modifies arcs_in_ and forward_costs_. + */ + void AddArcToClat(CompactLattice::StateId state, + const CompactLatticeArc &arc); + CompactLattice::StateId AddStateToClat(); + + + // Identifies token-final states in `chunk_clat`; see glossary above for + // definition of `token-final`. This function outputs a map from such states + // in chunk_clat, to the `token-label` on arcs entering them. (It is not + // possible that the same state would have multiple arcs entering it with + // different token-labels, or some arcs entering with one token-label and some + // another, or be both initial and have such arcs; this is true due to how we + // construct the raw lattice.) + void IdentifyTokenFinalStates( + const CompactLattice &chunk_clat, + std::unordered_map *token_map) const; + + // trans_model_ is needed by DeterminizeLatticePhonePrunedWrapper() which this + // class calls. + const TransitionModel &trans_model_; + // config_ is needed by DeterminizeLatticePhonePrunedWrapper() which this + // class calls. + const LatticeIncrementalDecoderConfig &config_; + + + // Contains the set of redeterminized-states which are not final in the + // canonical appended lattice. Since the final ones don't physically appear + // in clat_, this means the set of redeterminized-states which are physically + // in clat_. In code terms, this means set of .first elements in final_arcs, + // plus whatever other states in clat_ are reachable from such states. + std::unordered_set non_final_redet_states_; + + + // clat_ is the appended lattice (containing all chunks processed so + // far), except its `final-arcs` (i.e. arcs which in the canonical + // lattice would go to final-states) are not present (they are stored + // separately in final_arcs_) and states which in the canonical lattice + // should have final-arcs leaving them will instead have a final-prob. + CompactLattice clat_; + + + // arcs_in_ is indexed by (state-id in clat_), and is a list of + // arcs that come into this state, in the form (prev-state, + // arc-index). CAUTION: not all these input-arc records will always + // be valid (some may be out-of-date, and may refer to an out-of-range + // arc or an arc that does not point to this state). But all + // input arcs will always be listed. + std::vector > > arcs_in_; + + // final_arcs_ contains arcs which would appear in the canonical appended + // lattice but for implementation reasons are not physically present in clat_. + // These are arcs to final states in the canonical appended lattice. The + // .first elements are the source states in clat_ (these will all be elements + // of non_final_redet_states_); the .nextstate elements of the arcs does not + // contain a physical state, but contain state-labels allocated by + // AllocateNewStateLabel(). + std::vector final_arcs_; + + // forward_costs_, indexed by the state-id in clat_, stores the alpha + // (forward) costs, i.e. the minimum cost from the start state to each state + // in clat_. This is relevant for pruned determinization. The BaseFloat can + // be thought of as the sum of a Value1() + Value2() in a LatticeWeight. + std::vector forward_costs_; + + // temporary used in a function, kept here to avoid excessive reallocation. + std::unordered_set temp_; + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDeterminizer); +}; + + +/** This is an extention to the "normal" lattice-generating decoder. + See \ref lattices_generation \ref decoders_faster and \ref decoders_simple + for more information. + + The main difference is the incremental determinization which will be + discussed in the function GetLattice(). This means that the work of determinizatin + isn't done all at once at the end of the file, but incrementally while decoding. + See the comment at the top of this file for more explanation. + + The decoder is templated on the FST type and the token type. The token type + will normally be StdToken, but also may be BackpointerToken which is to support + quick lookup of the current best path (see lattice-faster-online-decoder.h) + + The FST you invoke this decoder with is expected to be of type + Fst::Fst, a.k.a. StdFst, or GrammarFst. If you invoke it with + FST == StdFst and it notices that the actual FST type is + fst::VectorFst or fst::ConstFst, the decoder object + will internally cast itself to one that is templated on those more specific + types; this is an optimization for speed. + */ +template +class LatticeIncrementalDecoderTpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using ForwardLinkT = decoder::ForwardLink; + + // Instantiate this class once for each thing you have to decode. + // This version of the constructor does not take ownership of + // 'fst'. + LatticeIncrementalDecoderTpl(const FST &fst, const TransitionModel &trans_model, + const LatticeIncrementalDecoderConfig &config); + + // This version of the constructor takes ownership of the fst, and will delete + // it when this object is destroyed. + LatticeIncrementalDecoderTpl(const LatticeIncrementalDecoderConfig &config, + FST *fst, const TransitionModel &trans_model); + + void SetOptions(const LatticeIncrementalDecoderConfig &config) { config_ = config; } + + const LatticeIncrementalDecoderConfig &GetOptions() const { return config_; } + + ~LatticeIncrementalDecoderTpl(); + + /** + CAUTION: it's unlikely that you will ever want to call this function. In a + scenario where you have the entire file and just want to decode it, there + is no point using this decoder. + + An example of how to do decoding together with incremental + determinization. It decodes until there are no more frames left in the + "decodable" object. + + In this example, config_.determinize_delay, config_.determinize_period + and config_.determinize_max_active are used to determine the time to + call GetLattice(). + + Users will probably want to use appropriate combinations of + AdvanceDecoding() and GetLattice() to build their application; this just + gives you some idea how. + + The function returns true if any kind of traceback is available (not + necessarily from a final state). + */ + bool Decode(DecodableInterface *decodable); + + /// says whether a final-state was active on the last frame. If it was not, + /// the lattice (or traceback) will end with states that are not final-states. + bool ReachedFinal() const { + return FinalRelativeCost() != std::numeric_limits::infinity(); + } + + /** + This decoder has no GetBestPath() function. + If you need that functionality you should probably use lattice-incremental-online-decoder.h, + which makes it very efficient to obtain the best path. */ + + /** + This GetLattice() function returns the lattice containing + `num_frames_to_decode` frames; this will be all frames decoded so + far, if you let num_frames_to_decode == NumFramesDecoded(), + but it will generally be better to make it a few frames less than + that to avoid the lattice having too many active states at + the end. + + @param [in] num_frames_to_include The number of frames that you want + to be included in the lattice. Must be >= + NumFramesInLattice() and <= NumFramesDecoded(). + + @param [in] use_final_probs True if you want the final-probs + of HCLG to be included in the output lattice. Must not + be set to true if num_frames_to_include != + NumFramesDecoded(). Must be set to true if you have + previously called FinalizeDecoding(). + + (If no state was final on frame `num_frames_to_include`, the + final-probs won't be included regardless of + `use_final_probs`; you can test whether this + was the case by calling ReachedFinal(). + + @return clat The CompactLattice representing what has been decoded + up until `num_frames_to_include` (e.g., LatticeStateTimes() + on this lattice would return `num_frames_to_include`). + + See also UpdateLatticeDeterminizaton(). Caution: this const ref + is only valid until the next time you call AdvanceDecoding() or + GetLattice(). + + CAUTION: the lattice may contain disconnnected states; you should + call Connect() on the output before writing it out. + */ + const CompactLattice &GetLattice(int32 num_frames_to_include, + bool use_final_probs = false); + + /* + Returns the number of frames in the currently-determinized part of the + lattice which will be a number in [0, NumFramesDecoded()]. It will + be the largest number that GetLattice() was called with, but note + that GetLattice() may be called from UpdateLatticeDeterminization(). + + Made available in case the user wants to give that same number to + GetLattice(). + */ + int NumFramesInLattice() const { return num_frames_in_lattice_; } + + /** + InitDecoding initializes the decoding, and should only be used if you + intend to call AdvanceDecoding(). If you call Decode(), you don't need to + call this. You can also call InitDecoding if you have already decoded an + utterance and want to start with a new utterance. + */ + void InitDecoding(); + + /** + This will decode until there are no more frames ready in the decodable + object. You can keep calling it each time more frames become available + (this is the normal pattern in a real-time/online decoding scenario). + If max_num_frames is specified, it specifies the maximum number of frames + the function will decode before returning. + */ + void AdvanceDecoding(DecodableInterface *decodable, int32 max_num_frames = -1); + + + /** FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives + more information. It returns the difference between the best (final-cost + plus cost) of any token on the final frame, and the best cost of any token + on the final frame. If it is infinity it means no final-states were + present on the final frame. It will usually be nonnegative. If it not + too positive (e.g. < 5 is my first guess, but this is not tested) you can + take it as a good indication that we reached the final-state with + reasonable likelihood. */ + BaseFloat FinalRelativeCost() const; + + /** Returns the number of frames decoded so far. */ + inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; } + + /** + Finalizes the decoding, doing an extra pruning step on the last frame + that uses the final-probs. May be called only once. + */ + void FinalizeDecoding(); + + protected: + /* Some protected things are needed in LatticeIncrementalOnlineDecoderTpl. */ + + /** NOTE: for parts the internal implementation that are shared with LatticeFasterDecoer, + we have removed the comments.*/ + inline static void DeleteForwardLinks(Token *tok); + struct TokenList { + Token *toks; + bool must_prune_forward_links; + bool must_prune_tokens; + int32 num_toks; /* Note: you can only trust `num_toks` if must_prune_tokens + * == false, because it is only set in + * PruneTokensForFrame(). */ + TokenList() + : toks(NULL), must_prune_forward_links(true), must_prune_tokens(true), + num_toks(-1) {} + }; + using Elem = typename HashList::Elem; + void PossiblyResizeHash(size_t num_toks); + inline Token *FindOrAddToken(StateId state, int32 frame_plus_one, + BaseFloat tot_cost, Token *backpointer, bool *changed); + void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed, + bool *links_pruned, BaseFloat delta); + void ComputeFinalCosts(unordered_map *final_costs, + BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const; + void PruneForwardLinksFinal(); + void PruneTokensForFrame(int32 frame_plus_one); + void PruneActiveTokens(BaseFloat delta); + BaseFloat GetCutoff(Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, + Elem **best_elem); + BaseFloat ProcessEmitting(DecodableInterface *decodable); + void ProcessNonemitting(BaseFloat cost_cutoff); + + HashList toks_; + std::vector active_toks_; // indexed by frame. + std::vector queue_; // temp variable used in ProcessNonemitting, + std::vector tmp_array_; // used in GetCutoff. + const FST *fst_; + bool delete_fst_; + std::vector cost_offsets_; + int32 num_toks_; + bool warned_; + bool decoding_finalized_; + + unordered_map final_costs_; + BaseFloat final_relative_cost_; + BaseFloat final_best_cost_; + + /*********************** + Variables below this point relate to the incremental + determinization. + *********************/ + LatticeIncrementalDecoderConfig config_; + /** Much of the the incremental determinization algorithm is encapsulated in + the determinize_ object. */ + LatticeIncrementalDeterminizer determinizer_; + + + /* Just a temporary used in a function; stored here to avoid reallocation. */ + unordered_map temp_token_map_; + + /** num_frames_in_lattice_ is the highest `num_frames_to_include_` argument + for any prior call to GetLattice(). */ + int32 num_frames_in_lattice_; + + // A map from Token to its token_label. Will contain an entry for + // each Token in active_toks_[num_frames_in_lattice_]. + unordered_map token2label_map_; + + // A temporary used in a function, kept here to avoid reallocation. + unordered_map token2label_map_temp_; + + // we allocate a unique id for each Token + Label next_token_label_; + + inline Label AllocateNewTokenLabel() { return next_token_label_++; } + + + // There are various cleanup tasks... the the toks_ structure contains + // singly linked lists of Token pointers, where Elem is the list type. + // It also indexes them in a hash, indexed by state (this hash is only + // maintained for the most recent frame). toks_.Clear() + // deletes them from the hash and returns the list of Elems. The + // function DeleteElems calls toks_.Delete(elem) for each elem in + // the list, which returns ownership of the Elem to the toks_ structure + // for reuse, but does not delete the Token pointer. The Token pointers + // are reference-counted and are ultimately deleted in PruneTokensForFrame, + // but are also linked together on each frame by their own linked-list, + // using the "next" pointer. We delete them manually. + void DeleteElems(Elem *list); + + void ClearActiveTokens(); + + + // Returns the number of active tokens on frame `frame`. Can be used as part + // of a heuristic to decide which frame to determinize until, if you are not + // at the end of an utterance. + int32 GetNumToksForFrame(int32 frame); + + /** + UpdateLatticeDeterminization() ensures the work of determinization is kept + up to date so that when you do need the lattice you can get it fast. It + uses the configuration values `determinize_delay`, `determinize_max_delay` + and `determinize_min_chunk_size` to decide whether and when to call + GetLattice(). You can safely call this as often as you want (e.g. after + each time you call AdvanceDecoding(); it won't do subtantially more work if + it is called frequently. + */ + void UpdateLatticeDeterminization(); + + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalDecoderTpl); +}; + +typedef LatticeIncrementalDecoderTpl + LatticeIncrementalDecoder; + + +} // end namespace kaldi. + +#endif diff --git a/src/decoder/lattice-incremental-online-decoder.cc b/src/decoder/lattice-incremental-online-decoder.cc new file mode 100644 index 000000000..85f902bde --- /dev/null +++ b/src/decoder/lattice-incremental-online-decoder.cc @@ -0,0 +1,150 @@ +// decoder/lattice-incremental-online-decoder.cc + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// see note at the top of lattice-faster-decoder.cc, about how to maintain this +// file in sync with lattice-faster-decoder.cc + +#include "decoder/lattice-incremental-decoder.h" +#include "decoder/lattice-incremental-online-decoder.h" +#include "lat/lattice-functions.h" +#include "base/timer.h" + +namespace kaldi { + +// Outputs an FST corresponding to the single best path through the lattice. +template +bool LatticeIncrementalOnlineDecoderTpl::GetBestPath(Lattice *olat, + bool use_final_probs) const { + olat->DeleteStates(); + BaseFloat final_graph_cost; + BestPathIterator iter = BestPathEnd(use_final_probs, &final_graph_cost); + if (iter.Done()) + return false; // would have printed warning. + StateId state = olat->AddState(); + olat->SetFinal(state, LatticeWeight(final_graph_cost, 0.0)); + while (!iter.Done()) { + LatticeArc arc; + iter = TraceBackBestPath(iter, &arc); + arc.nextstate = state; + StateId new_state = olat->AddState(); + olat->AddArc(new_state, arc); + state = new_state; + } + olat->SetStart(state); + return true; +} + +template +typename LatticeIncrementalOnlineDecoderTpl::BestPathIterator LatticeIncrementalOnlineDecoderTpl::BestPathEnd( + bool use_final_probs, + BaseFloat *final_cost_out) const { + if (this->decoding_finalized_ && !use_final_probs) + KALDI_ERR << "You cannot call FinalizeDecoding() and then call " + << "BestPathEnd() with use_final_probs == false"; + KALDI_ASSERT(this->NumFramesDecoded() > 0 && + "You cannot call BestPathEnd if no frames were decoded."); + + unordered_map final_costs_local; + + const unordered_map &final_costs = + (this->decoding_finalized_ ? this->final_costs_ :final_costs_local); + if (!this->decoding_finalized_ && use_final_probs) + this->ComputeFinalCosts(&final_costs_local, NULL, NULL); + + // Singly linked list of tokens on last frame (access list through "next" + // pointer). + BaseFloat best_cost = std::numeric_limits::infinity(); + BaseFloat best_final_cost = 0; + Token *best_tok = NULL; + for (Token *tok = this->active_toks_.back().toks; + tok != NULL; tok = tok->next) { + BaseFloat cost = tok->tot_cost, final_cost = 0.0; + if (use_final_probs && !final_costs.empty()) { + // if we are instructed to use final-probs, and any final tokens were + // active on final frame, include the final-prob in the cost of the token. + typename unordered_map::const_iterator + iter = final_costs.find(tok); + if (iter != final_costs.end()) { + final_cost = iter->second; + cost += final_cost; + } else { + cost = std::numeric_limits::infinity(); + } + } + if (cost < best_cost) { + best_cost = cost; + best_tok = tok; + best_final_cost = final_cost; + } + } + if (best_tok == NULL) { // this should not happen, and is likely a code error or + // caused by infinities in likelihoods, but I'm not making + // it a fatal error for now. + KALDI_WARN << "No final token found."; + } + if (final_cost_out == NULL) + *final_cost_out = best_final_cost; + return BestPathIterator(best_tok, this->NumFramesDecoded() - 1); +} + + +template +typename LatticeIncrementalOnlineDecoderTpl::BestPathIterator LatticeIncrementalOnlineDecoderTpl::TraceBackBestPath( + BestPathIterator iter, LatticeArc *oarc) const { + KALDI_ASSERT(!iter.Done() && oarc != NULL); + Token *tok = static_cast(iter.tok); + int32 cur_t = iter.frame, ret_t = cur_t; + if (tok->backpointer != NULL) { + ForwardLinkT *link; + for (link = tok->backpointer->links; + link != NULL; link = link->next) { + if (link->next_tok == tok) { // this is the link to "tok" + oarc->ilabel = link->ilabel; + oarc->olabel = link->olabel; + BaseFloat graph_cost = link->graph_cost, + acoustic_cost = link->acoustic_cost; + if (link->ilabel != 0) { + KALDI_ASSERT(static_cast(cur_t) < this->cost_offsets_.size()); + acoustic_cost -= this->cost_offsets_[cur_t]; + ret_t--; + } + oarc->weight = LatticeWeight(graph_cost, acoustic_cost); + break; + } + } + if (link == NULL) { // Did not find correct link. + KALDI_ERR << "Error tracing best-path back (likely " + << "bug in token-pruning algorithm)"; + } + } else { + oarc->ilabel = 0; + oarc->olabel = 0; + oarc->weight = LatticeWeight::One(); // zero costs. + } + return BestPathIterator(tok->backpointer, ret_t); +} + +// Instantiate the template for the FST types that we'll need. +template class LatticeIncrementalOnlineDecoderTpl >; +template class LatticeIncrementalOnlineDecoderTpl >; +template class LatticeIncrementalOnlineDecoderTpl >; +template class LatticeIncrementalOnlineDecoderTpl; + + +} // end namespace kaldi. diff --git a/src/decoder/lattice-incremental-online-decoder.h b/src/decoder/lattice-incremental-online-decoder.h new file mode 100644 index 000000000..8bd41c851 --- /dev/null +++ b/src/decoder/lattice-incremental-online-decoder.h @@ -0,0 +1,132 @@ +// decoder/lattice-incremental-online-decoder.h + +// Copyright 2019 Zhehuai Chen +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// see note at the top of lattice-faster-decoder.h, about how to maintain this +// file in sync with lattice-faster-decoder.h + + +#ifndef KALDI_DECODER_LATTICE_INCREMENTAL_ONLINE_DECODER_H_ +#define KALDI_DECODER_LATTICE_INCREMENTAL_ONLINE_DECODER_H_ + +#include "util/stl-utils.h" +#include "util/hash-list.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/lattice-incremental-decoder.h" + +namespace kaldi { + + + +/** LatticeIncrementalOnlineDecoderTpl is as LatticeIncrementalDecoderTpl but also + supports an efficient way to get the best path (see the function + BestPathEnd()), which is useful in endpointing and in situations where you + might want to frequently access the best path. + + This is only templated on the FST type, since the Token type is required to + be BackpointerToken. Actually it only makes sense to instantiate + LatticeIncrementalDecoderTpl with Token == BackpointerToken if you do so indirectly via + this child class. + */ +template +class LatticeIncrementalOnlineDecoderTpl: + public LatticeIncrementalDecoderTpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using Token = decoder::BackpointerToken; + using ForwardLinkT = decoder::ForwardLink; + + // Instantiate this class once for each thing you have to decode. + // This version of the constructor does not take ownership of + // 'fst'. + LatticeIncrementalOnlineDecoderTpl(const FST &fst, + const TransitionModel &trans_model, + const LatticeIncrementalDecoderConfig &config): + LatticeIncrementalDecoderTpl(fst, trans_model, config) { } + + // This version of the initializer takes ownership of 'fst', and will delete + // it when this object is destroyed. + LatticeIncrementalOnlineDecoderTpl(const LatticeIncrementalDecoderConfig &config, + FST *fst, + const TransitionModel &trans_model): + LatticeIncrementalDecoderTpl(config, fst, trans_model) { } + + + struct BestPathIterator { + void *tok; + int32 frame; + // note, "frame" is the frame-index of the frame you'll get the + // transition-id for next time, if you call TraceBackBestPath on this + // iterator (assuming it's not an epsilon transition). Note that this + // is one less than you might reasonably expect, e.g. it's -1 for + // the nonemitting transitions before the first frame. + BestPathIterator(void *t, int32 f): tok(t), frame(f) { } + bool Done() { return tok == NULL; } + }; + + + /// Outputs an FST corresponding to the single best path through the lattice. + /// This is quite efficient because it doesn't get the entire raw lattice and find + /// the best path through it; instead, it uses the BestPathEnd and BestPathIterator + /// so it basically traces it back through the lattice. + /// Returns true if result is nonempty (using the return status is deprecated, + /// it will become void). If "use_final_probs" is true AND we reached the + /// final-state of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. + bool GetBestPath(Lattice *ofst, + bool use_final_probs = true) const; + + + + /// This function returns an iterator that can be used to trace back + /// the best path. If use_final_probs == true and at least one final state + /// survived till the end, it will use the final-probs in working out the best + /// final Token, and will output the final cost to *final_cost (if non-NULL), + /// else it will use only the forward likelihood, and will put zero in + /// *final_cost (if non-NULL). + /// Requires that NumFramesDecoded() > 0. + BestPathIterator BestPathEnd(bool use_final_probs, + BaseFloat *final_cost = NULL) const; + + + /// This function can be used in conjunction with BestPathEnd() to trace back + /// the best path one link at a time (e.g. this can be useful in endpoint + /// detection). By "link" we mean a link in the graph; not all links cross + /// frame boundaries, but each time you see a nonzero ilabel you can interpret + /// that as a frame. The return value is the updated iterator. It outputs + /// the ilabel and olabel, and the (graph and acoustic) weight to the "arc" pointer, + /// while leaving its "nextstate" variable unchanged. + BestPathIterator TraceBackBestPath( + BestPathIterator iter, LatticeArc *arc) const; + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeIncrementalOnlineDecoderTpl); +}; + +typedef LatticeIncrementalOnlineDecoderTpl LatticeIncrementalOnlineDecoder; + + +} // end namespace kaldi. + +#endif diff --git a/src/feat/feature-spectrogram.cc b/src/feat/feature-spectrogram.cc index 7eee2643c..1a9ea3a93 100644 --- a/src/feat/feature-spectrogram.cc +++ b/src/feat/feature-spectrogram.cc @@ -62,6 +62,11 @@ void SpectrogramComputer::Compute(BaseFloat signal_raw_log_energy, else // An alternative algorithm that works for non-powers-of-two RealFft(signal_frame, true); + if (opts_.return_raw_fft) { + feature->CopyFromVec(*signal_frame); + return; + } + // Convert the FFT into a power spectrum. ComputePowerSpectrum(signal_frame); SubVector power_spectrum(*signal_frame, diff --git a/src/feat/feature-spectrogram.h b/src/feat/feature-spectrogram.h index 132a6875e..0b1061ad9 100644 --- a/src/feat/feature-spectrogram.h +++ b/src/feat/feature-spectrogram.h @@ -39,10 +39,14 @@ struct SpectrogramOptions { FrameExtractionOptions frame_opts; BaseFloat energy_floor; bool raw_energy; // If true, compute energy before preemphasis and windowing + bool return_raw_fft; // If true, return the raw FFT spectrum + // Note that in that case the Dim() will return double + // the expected dimension (because of the complex domain of it) SpectrogramOptions() : energy_floor(0.0), - raw_energy(true) {} + raw_energy(true), + return_raw_fft(false) {} void Register(OptionsItf *opts) { frame_opts.Register(opts); @@ -54,6 +58,8 @@ struct SpectrogramOptions { "std::numeric_limits::epsilon()."); opts->Register("raw-energy", &raw_energy, "If true, compute energy before preemphasis and windowing"); + opts->Register("return-raw-fft", &return_raw_fft, + "If true, return raw FFT complex numbers instead of log magnitudes"); } }; @@ -68,7 +74,13 @@ class SpectrogramComputer { return opts_.frame_opts; } - int32 Dim() const { return opts_.frame_opts.PaddedWindowSize() / 2 + 1; } + int32 Dim() const { + if (opts_.return_raw_fft) { + return opts_.frame_opts.PaddedWindowSize(); + } else { + return opts_.frame_opts.PaddedWindowSize() / 2 + 1; + } + } bool NeedRawLogEnergy() const { return opts_.raw_energy; } diff --git a/src/feat/feature-window.cc b/src/feat/feature-window.cc index c5d4cc298..1dea03d67 100644 --- a/src/feat/feature-window.cc +++ b/src/feat/feature-window.cc @@ -115,6 +115,10 @@ FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts) double i_fl = static_cast(i); if (opts.window_type == "hanning") { window(i) = 0.5 - 0.5*cos(a * i_fl); + } else if (opts.window_type == "sine") { + // when you are checking ws wikipedia, please + // note that 0.5 * a = M_PI/(frame_length-1) + window(i) = sin(0.5 * a * i_fl); } else if (opts.window_type == "hamming") { window(i) = 0.54 - 0.46*cos(a * i_fl); } else if (opts.window_type == "povey") { // like hamming but goes to zero at edges. diff --git a/src/feat/feature-window.h b/src/feat/feature-window.h index a7abba50e..e6d673937 100644 --- a/src/feat/feature-window.h +++ b/src/feat/feature-window.h @@ -40,7 +40,7 @@ struct FrameExtractionOptions { BaseFloat preemph_coeff; // Preemphasis coefficient. bool remove_dc_offset; // Subtract mean of wave before FFT. std::string window_type; // e.g. Hamming window - // May be "hamming", "rectangular", "povey", "hanning", "blackman" + // May be "hamming", "rectangular", "povey", "hanning", "sine", "blackman" // "povey" is a window I made to be similar to Hamming but to go to zero at the // edges, it's pow((0.5 - 0.5*cos(n/N*2*pi)), 0.85) // I just don't think the Hamming window makes sense as a windowing function. @@ -81,7 +81,7 @@ struct FrameExtractionOptions { "option, e.g. to 1.0 or 0.1"); opts->Register("window-type", &window_type, "Type of window " "(\"hamming\"|\"hanning\"|\"povey\"|\"rectangular\"" - "|\"blackmann\")"); + "|\"sine\"|\"blackmann\")"); opts->Register("blackman-coeff", &blackman_coeff, "Constant coefficient for generalized Blackman window."); opts->Register("round-to-power-of-two", &round_to_power_of_two, diff --git a/src/hmm/hmm-utils.cc b/src/hmm/hmm-utils.cc index 06edf8d59..15a1edfd2 100644 --- a/src/hmm/hmm-utils.cc +++ b/src/hmm/hmm-utils.cc @@ -1289,5 +1289,16 @@ void ChangeReorderingOfAlignment(const TransitionModel &trans_model, } } +void GetPdfToPhonesMap(const TransitionModel &trans_model, + std::vector > *pdf2phones) { + pdf2phones->clear(); + pdf2phones->resize(trans_model.NumPdfs()); + for (int32 i = 0; i < trans_model.NumTransitionIds(); i++) { + int32 trans_id = i + 1; + int32 pdf_id = trans_model.TransitionIdToPdf(trans_id); + int32 phone = trans_model.TransitionIdToPhone(trans_id); + (*pdf2phones)[pdf_id].insert(phone); + } +} } // namespace kaldi diff --git a/src/hmm/hmm-utils.h b/src/hmm/hmm-utils.h index a8ad84694..4415927df 100644 --- a/src/hmm/hmm-utils.h +++ b/src/hmm/hmm-utils.h @@ -329,6 +329,12 @@ void GetRandomAlignmentForPhone(const ContextDependencyInterface &ctx_dep, void ChangeReorderingOfAlignment(const TransitionModel &trans_model, std::vector *alignment); + +// GetPdfToPhonesMap creates a map which maps each pdf-id into its +// corresponding monophones. +void GetPdfToPhonesMap(const TransitionModel &trans_model, + std::vector > *pdf2phones); + /// @} end "addtogroup hmm_group" } // end namespace kaldi diff --git a/src/lat/determinize-lattice-pruned.h b/src/lat/determinize-lattice-pruned.h index 8e1858aa2..353237098 100644 --- a/src/lat/determinize-lattice-pruned.h +++ b/src/lat/determinize-lattice-pruned.h @@ -105,8 +105,8 @@ namespace fst { representation" and hence the "minimal representation" will be the same. We can use this to reduce compute. Note that if two initial representations are different, this does not preclude the other representations from being the same. - -*/ + +*/ struct DeterminizeLatticePrunedOptions { @@ -190,7 +190,7 @@ template bool DeterminizeLatticePruned( const ExpandedFst > &ifst, double prune, - MutableFst > *ofst, + MutableFst > *ofst, DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions()); @@ -199,7 +199,7 @@ bool DeterminizeLatticePruned( (i.e. the sequences of output symbols are represented directly as strings The input FST must be topologically sorted in order for the algorithm to work. For efficiency it is recommended to sort the ilabel for the input FST as well. - Returns true on success, and false if it had to terminate the determinization + Returns true on normal success, and false if it had to terminate the determinization earlier than specified by the "prune" beam-- that is, if it terminated because of the max_mem, max_loop or max_arcs constraints in the options. CAUTION: if Lattice is the input, you need to Invert() before calling this, @@ -261,7 +261,7 @@ bool DeterminizeLatticePhonePruned( = DeterminizeLatticePhonePrunedOptions()); /** "Destructive" version of DeterminizeLatticePhonePruned() where the input - lattice might be changed. + lattice might be changed. */ template bool DeterminizeLatticePhonePruned( diff --git a/src/lat/lattice-functions.cc b/src/lat/lattice-functions.cc index 7f484f952..f4a184f3c 100644 --- a/src/lat/lattice-functions.cc +++ b/src/lat/lattice-functions.cc @@ -1107,7 +1107,6 @@ void CompactLatticeShortestPath(const CompactLattice &clat, // Now we can assume it's topologically sorted. shortest_path->DeleteStates(); if (clat.Start() == kNoStateId) return; - KALDI_ASSERT(clat.Start() == 0); // since top-sorted. typedef CompactLatticeArc Arc; typedef Arc::StateId StateId; typedef CompactLatticeWeight Weight; @@ -1117,7 +1116,7 @@ void CompactLatticeShortestPath(const CompactLattice &clat, best_cost_and_pred[s].first = std::numeric_limits::infinity(); best_cost_and_pred[s].second = fst::kNoStateId; } - best_cost_and_pred[0].first = 0; + best_cost_and_pred[clat.Start()].first = 0; for (StateId s = 0; s < clat.NumStates(); s++) { double my_cost = best_cost_and_pred[s].first; for (ArcIterator aiter(clat, s); @@ -1139,8 +1138,8 @@ void CompactLatticeShortestPath(const CompactLattice &clat, } } std::vector states; // states on best path. - StateId cur_state = superfinal; - while (cur_state != 0) { + StateId cur_state = superfinal, start_state = clat.Start(); + while (cur_state != start_state) { StateId prev_state = best_cost_and_pred[cur_state].second; if (prev_state == kNoStateId) { KALDI_WARN << "Failure in best-path algorithm for lattice (infinite costs?)"; diff --git a/src/makefiles/linux_atlas.mk b/src/makefiles/linux_atlas.mk index bcbd019c0..bd3086e0c 100644 --- a/src/makefiles/linux_atlas.mk +++ b/src/makefiles/linux_atlas.mk @@ -1,5 +1,8 @@ # ATLAS specific Linux configuration +ifndef DEBUG_LEVEL +$(error DEBUG_LEVEL not defined.) +endif ifndef DOUBLE_PRECISION $(error DOUBLE_PRECISION not defined.) endif @@ -22,12 +25,19 @@ CXXFLAGS = -std=c++11 -I.. -isystem $(OPENFSTINC) -O1 $(EXTRA_CXXFLAGS) \ -DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) \ -DHAVE_EXECINFO_H=1 -DHAVE_CXXABI_H -DHAVE_ATLAS -I$(ATLASINC) \ -msse -msse2 -pthread \ - -g # -O0 -DKALDI_PARANOID + -g ifeq ($(KALDI_FLAVOR), dynamic) CXXFLAGS += -fPIC endif +ifeq ($(DEBUG_LEVEL), 0) +CXXFLAGS += -DNDEBUG +endif +ifeq ($(DEBUG_LEVEL), 2) +CXXFLAGS += -O0 -DKALDI_PARANOID +endif + # Compiler specific flags COMPILER = $(shell $(CXX) -v 2>&1) ifeq ($(findstring clang,$(COMPILER)),clang) diff --git a/src/makefiles/linux_atlas_ppc64le.mk b/src/makefiles/linux_atlas_ppc64le.mk index 58b41c8e8..fdb2618c9 100644 --- a/src/makefiles/linux_atlas_ppc64le.mk +++ b/src/makefiles/linux_atlas_ppc64le.mk @@ -1,5 +1,8 @@ # ATLAS specific Linux ppc64le configuration +ifndef DEBUG_LEVEL +$(error DEBUG_LEVEL not defined.) +endif ifndef DOUBLE_PRECISION $(error DOUBLE_PRECISION not defined.) endif @@ -23,12 +26,19 @@ CXXFLAGS = -std=c++11 -I.. -isystem $(OPENFSTINC) -O1 $(EXTRA_CXXFLAGS) \ -DHAVE_EXECINFO_H=1 -DHAVE_CXXABI_H -DHAVE_ATLAS -I$(ATLASINC) \ -m64 -maltivec -mcpu=power8 -mtune=power8 -mpower8-vector -mvsx \ -pthread \ - -g # -O0 -DKALDI_PARANOID + -g ifeq ($(KALDI_FLAVOR), dynamic) CXXFLAGS += -fPIC endif +ifeq ($(DEBUG_LEVEL), 0) +CXXFLAGS += -DNDEBUG +endif +ifeq ($(DEBUG_LEVEL), 2) +CXXFLAGS += -O0 -DKALDI_PARANOID +endif + # Compiler specific flags COMPILER = $(shell $(CXX) -v 2>&1) ifeq ($(findstring clang,$(COMPILER)),clang) diff --git a/src/makefiles/linux_clapack.mk b/src/makefiles/linux_clapack.mk index 5c670bfb8..058c4eeab 100644 --- a/src/makefiles/linux_clapack.mk +++ b/src/makefiles/linux_clapack.mk @@ -1,5 +1,8 @@ # CLAPACK specific Linux configuration +ifndef DEBUG_LEVEL +$(error DEBUG_LEVEL not defined.) +endif ifndef DOUBLE_PRECISION $(error DOUBLE_PRECISION not defined.) endif @@ -16,12 +19,19 @@ CXXFLAGS = -std=c++11 -I.. -isystem $(OPENFSTINC) -O1 $(EXTRA_CXXFLAGS) \ -DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) \ -DHAVE_EXECINFO_H=1 -DHAVE_CXXABI_H -DHAVE_CLAPACK -I../../tools/CLAPACK \ -msse -msse2 -pthread \ - -g # -O0 -DKALDI_PARANOID + -g ifeq ($(KALDI_FLAVOR), dynamic) CXXFLAGS += -fPIC endif +ifeq ($(DEBUG_LEVEL), 0) +CXXFLAGS += -DNDEBUG +endif +ifeq ($(DEBUG_LEVEL), 2) +CXXFLAGS += -O0 -DKALDI_PARANOID +endif + # Compiler specific flags COMPILER = $(shell $(CXX) -v 2>&1) ifeq ($(findstring clang,$(COMPILER)),clang) diff --git a/src/makefiles/linux_clapack_arm.mk b/src/makefiles/linux_clapack_arm.mk index fb5a3821f..c80710bd0 100644 --- a/src/makefiles/linux_clapack_arm.mk +++ b/src/makefiles/linux_clapack_arm.mk @@ -1,5 +1,8 @@ # CLAPACK specific Linux ARM configuration +ifndef DEBUG_LEVEL +$(error DEBUG_LEVEL not defined.) +endif ifndef DOUBLE_PRECISION $(error DOUBLE_PRECISION not defined.) endif @@ -16,12 +19,19 @@ CXXFLAGS = -std=c++11 -I.. -isystem $(OPENFSTINC) -O1 $(EXTRA_CXXFLAGS) \ -DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) \ -DHAVE_EXECINFO_H=1 -DHAVE_CXXABI_H -DHAVE_CLAPACK -I../../tools/CLAPACK \ -ftree-vectorize -mfloat-abi=hard -mfpu=neon -pthread \ - -g # -O0 -DKALDI_PARANOID + -g ifeq ($(KALDI_FLAVOR), dynamic) CXXFLAGS += -fPIC endif +ifeq ($(DEBUG_LEVEL), 0) +CXXFLAGS += -DNDEBUG +endif +ifeq ($(DEBUG_LEVEL), 2) +CXXFLAGS += -O0 -DKALDI_PARANOID +endif + # Compiler specific flags COMPILER = $(shell $(CXX) -v 2>&1) ifeq ($(findstring clang,$(COMPILER)),clang) diff --git a/src/makefiles/linux_openblas.mk b/src/makefiles/linux_openblas.mk index 8135f1e91..f2bd7ec42 100644 --- a/src/makefiles/linux_openblas.mk +++ b/src/makefiles/linux_openblas.mk @@ -1,5 +1,8 @@ # OpenBLAS specific Linux configuration +ifndef DEBUG_LEVEL +$(error DEBUG_LEVEL not defined.) +endif ifndef DOUBLE_PRECISION $(error DOUBLE_PRECISION not defined.) endif @@ -22,12 +25,19 @@ CXXFLAGS = -std=c++11 -I.. -isystem $(OPENFSTINC) -O1 $(EXTRA_CXXFLAGS) \ -DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) \ -DHAVE_EXECINFO_H=1 -DHAVE_CXXABI_H -DHAVE_OPENBLAS -I$(OPENBLASINC) \ -msse -msse2 -pthread \ - -g # -O0 -DKALDI_PARANOID + -g ifeq ($(KALDI_FLAVOR), dynamic) CXXFLAGS += -fPIC endif +ifeq ($(DEBUG_LEVEL), 0) +CXXFLAGS += -DNDEBUG +endif +ifeq ($(DEBUG_LEVEL), 2) +CXXFLAGS += -O0 -DKALDI_PARANOID +endif + # Compiler specific flags COMPILER = $(shell $(CXX) -v 2>&1) ifeq ($(findstring clang,$(COMPILER)),clang) diff --git a/src/makefiles/linux_openblas_aarch64.mk b/src/makefiles/linux_openblas_aarch64.mk index 55287d344..7098f8b6a 100644 --- a/src/makefiles/linux_openblas_aarch64.mk +++ b/src/makefiles/linux_openblas_aarch64.mk @@ -1,5 +1,8 @@ # OpenBLAS specific Linux ARM configuration +ifndef DEBUG_LEVEL +$(error DEBUG_LEVEL not defined.) +endif ifndef DOUBLE_PRECISION $(error DOUBLE_PRECISION not defined.) endif @@ -22,12 +25,19 @@ CXXFLAGS = -std=c++11 -I.. -isystem $(OPENFSTINC) -O1 $(EXTRA_CXXFLAGS) \ -DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) \ -DHAVE_EXECINFO_H=1 -DHAVE_CXXABI_H -DHAVE_OPENBLAS -I$(OPENBLASINC) \ -ftree-vectorize -pthread \ - -g # -O0 -DKALDI_PARANOID + -g ifeq ($(KALDI_FLAVOR), dynamic) CXXFLAGS += -fPIC endif +ifeq ($(DEBUG_LEVEL), 0) +CXXFLAGS += -DNDEBUG +endif +ifeq ($(DEBUG_LEVEL), 2) +CXXFLAGS += -O0 -DKALDI_PARANOID +endif + # Compiler specific flags COMPILER = $(shell $(CXX) -v 2>&1) ifeq ($(findstring clang,$(COMPILER)),clang) diff --git a/src/makefiles/linux_openblas_arm.mk b/src/makefiles/linux_openblas_arm.mk index 30603c1b8..5a79d8244 100644 --- a/src/makefiles/linux_openblas_arm.mk +++ b/src/makefiles/linux_openblas_arm.mk @@ -1,5 +1,8 @@ # OpenBLAS specific Linux ARM configuration +ifndef DEBUG_LEVEL +$(error DEBUG_LEVEL not defined.) +endif ifndef DOUBLE_PRECISION $(error DOUBLE_PRECISION not defined.) endif @@ -22,12 +25,19 @@ CXXFLAGS = -std=c++11 -I.. -isystem $(OPENFSTINC) -O1 $(EXTRA_CXXFLAGS) \ -DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) \ -DHAVE_EXECINFO_H=1 -DHAVE_CXXABI_H -DHAVE_OPENBLAS -I$(OPENBLASINC) \ -ftree-vectorize -mfloat-abi=hard -mfpu=neon -pthread \ - -g # -O0 -DKALDI_PARANOID + -g ifeq ($(KALDI_FLAVOR), dynamic) CXXFLAGS += -fPIC endif +ifeq ($(DEBUG_LEVEL), 0) +CXXFLAGS += -DNDEBUG +endif +ifeq ($(DEBUG_LEVEL), 2) +CXXFLAGS += -O0 -DKALDI_PARANOID +endif + # Compiler specific flags COMPILER = $(shell $(CXX) -v 2>&1) ifeq ($(findstring clang,$(COMPILER)),clang) diff --git a/src/makefiles/linux_openblas_ppc64le.mk b/src/makefiles/linux_openblas_ppc64le.mk index 89e882cb2..4d3919e1f 100644 --- a/src/makefiles/linux_openblas_ppc64le.mk +++ b/src/makefiles/linux_openblas_ppc64le.mk @@ -1,5 +1,8 @@ # OpenBLAS specific Linux configuration +ifndef DEBUG_LEVEL +$(error DEBUG_LEVEL not defined.) +endif ifndef DOUBLE_PRECISION $(error DOUBLE_PRECISION not defined.) endif @@ -23,12 +26,19 @@ CXXFLAGS = -std=c++11 -I.. -isystem $(OPENFSTINC) -O1 $(EXTRA_CXXFLAGS) \ -DHAVE_EXECINFO_H=1 -DHAVE_CXXABI_H -DHAVE_OPENBLAS -I$(OPENBLASINC) \ -m64 -maltivec -mcpu=power8 -mtune=power8 -mpower8-vector -mvsx \ -pthread \ - -g # -O0 -DKALDI_PARANOID + -g ifeq ($(KALDI_FLAVOR), dynamic) CXXFLAGS += -fPIC endif +ifeq ($(DEBUG_LEVEL), 0) +CXXFLAGS += -DNDEBUG +endif +ifeq ($(DEBUG_LEVEL), 2) +CXXFLAGS += -O0 -DKALDI_PARANOID +endif + # Compiler specific flags COMPILER = $(shell $(CXX) -v 2>&1) ifeq ($(findstring clang,$(COMPILER)),clang) diff --git a/src/makefiles/linux_x86_64_mkl.mk b/src/makefiles/linux_x86_64_mkl.mk index d1c399d97..dc1fa7a73 100644 --- a/src/makefiles/linux_x86_64_mkl.mk +++ b/src/makefiles/linux_x86_64_mkl.mk @@ -9,6 +9,9 @@ # Use the options obtained from this website to manually configure for other # platforms using MKL. +ifndef DEBUG_LEVEL +$(error DEBUG_LEVEL not defined.) +endif ifndef DOUBLE_PRECISION $(error DOUBLE_PRECISION not defined.) endif @@ -30,12 +33,19 @@ CXXFLAGS = -std=c++11 -I.. -isystem $(OPENFSTINC) -O1 $(EXTRA_CXXFLAGS) \ -DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) \ -DHAVE_EXECINFO_H=1 -DHAVE_CXXABI_H -DHAVE_MKL -I$(MKLROOT)/include \ -m64 -msse -msse2 -pthread \ - -g # -O0 -DKALDI_PARANOID + -g ifeq ($(KALDI_FLAVOR), dynamic) CXXFLAGS += -fPIC endif +ifeq ($(DEBUG_LEVEL), 0) +CXXFLAGS += -DNDEBUG +endif +ifeq ($(DEBUG_LEVEL), 2) +CXXFLAGS += -O0 -DKALDI_PARANOID +endif + # Compiler specific flags COMPILER = $(shell $(CXX) -v 2>&1) ifeq ($(findstring clang,$(COMPILER)),clang) diff --git a/src/matrix/kaldi-matrix.h b/src/matrix/kaldi-matrix.h index 4387538c4..bf634b0ec 100644 --- a/src/matrix/kaldi-matrix.h +++ b/src/matrix/kaldi-matrix.h @@ -574,6 +574,11 @@ class MatrixBase { void SymPosSemiDefEig(VectorBase *s, MatrixBase *P, Real check_thresh = 0.001); + // There are some weird issue with template friend function in a class + // template in Windows version of nvcc. This is simple an ugly walkaround. +#if defined(__NVCC__) && defined(_MSC_VER) + template +#endif friend Real kaldi::TraceMatMat(const MatrixBase &A, const MatrixBase &B, MatrixTransposeType trans); // tr (A B) diff --git a/src/matrix/kaldi-vector.h b/src/matrix/kaldi-vector.h index 2a032354b..a5baa3c2d 100644 --- a/src/matrix/kaldi-vector.h +++ b/src/matrix/kaldi-vector.h @@ -510,36 +510,36 @@ class SubVector : public VectorBase { KALDI_ASSERT(static_cast(origin)+ static_cast(length) <= static_cast(t.Dim())); - VectorBase::data_ = const_cast (t.Data()+origin); - VectorBase::dim_ = length; + this->data_ = const_cast (t.Data()+origin); + this->dim_ = length; } /// This constructor initializes the vector to point at the contents /// of this packed matrix (SpMatrix or TpMatrix). SubVector(const PackedMatrix &M) { - VectorBase::data_ = const_cast (M.Data()); - VectorBase::dim_ = (M.NumRows()*(M.NumRows()+1))/2; + this->data_ = const_cast (M.Data()); + this->dim_ = (M.NumRows()*(M.NumRows()+1))/2; } /// Copy constructor SubVector(const SubVector &other) : VectorBase () { // this copy constructor needed for Range() to work in base class. - VectorBase::data_ = other.data_; - VectorBase::dim_ = other.dim_; + this->data_ = other.data_; + this->dim_ = other.dim_; } /// Constructor from a pointer to memory and a length. Keeps a pointer /// to the data but does not take ownership (will never delete). /// Caution: this constructor enables you to evade const constraints. SubVector(const Real *data, MatrixIndexT length) : VectorBase () { - VectorBase::data_ = const_cast(data); - VectorBase::dim_ = length; + this->data_ = const_cast(data); + this->dim_ = length; } /// This operation does not preserve const-ness, so be careful. SubVector(const MatrixBase &matrix, MatrixIndexT row) { - VectorBase::data_ = const_cast(matrix.RowData(row)); - VectorBase::dim_ = matrix.NumCols(); + this->data_ = const_cast(matrix.RowData(row)); + this->dim_ = matrix.NumCols(); } ~SubVector() {} ///< Destructor (does nothing; no pointers are owned here). diff --git a/src/nnet3/Makefile b/src/nnet3/Makefile index 5e67211c3..f8f9d32fe 100644 --- a/src/nnet3/Makefile +++ b/src/nnet3/Makefile @@ -31,7 +31,8 @@ OBJFILES = nnet-common.o nnet-compile.o nnet-component-itf.o \ nnet-compile-looped.o decodable-simple-looped.o \ decodable-online-looped.o convolution.o \ nnet-convolutional-component.o attention.o \ - nnet-attention-component.o nnet-tdnn-component.o nnet-batch-compute.o + nnet-attention-component.o nnet-tdnn-component.o nnet-batch-compute.o \ + nnet-chain-training2.o LIBNAME = kaldi-nnet3 diff --git a/src/nnet3/nnet-chain-diagnostics2.cc b/src/nnet3/nnet-chain-diagnostics2.cc new file mode 100644 index 000000000..858ec4027 --- /dev/null +++ b/src/nnet3/nnet-chain-diagnostics2.cc @@ -0,0 +1,295 @@ +// nnet3/nnet-chain-diagnostics.cc + +// Copyright 2015 Johns Hopkins University (author: Daniel Povey) +// 2019 Idiap Research Institute (author: Srikanth Madikeri) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "nnet3/nnet-chain-diagnostics2.h" +#include "nnet3/nnet-utils.h" + +namespace kaldi { +namespace nnet3 { + +NnetChainComputeProb2::NnetChainComputeProb2( + const NnetComputeProbOptions &nnet_config, + const chain::ChainTrainingOptions &chain_config, + NnetChainModel2 &model, + const Nnet &nnet): + nnet_config_(nnet_config), + chain_config_(chain_config), + nnet_(nnet), + compiler_(nnet, nnet_config_.optimize_config, nnet_config_.compiler_config), + deriv_nnet_owned_(true), + deriv_nnet_(NULL), + model_(model), + num_minibatches_processed_(0) { + if (nnet_config_.compute_deriv) { + deriv_nnet_ = new Nnet(nnet_); + ScaleNnet(0.0, deriv_nnet_); + SetNnetAsGradient(deriv_nnet_); // force simple update + } else if (nnet_config_.store_component_stats) { + KALDI_ERR << "If you set store_component_stats == true and " + << "compute_deriv == false, use the other constructor."; + } +} + +NnetChainComputeProb2::NnetChainComputeProb2( + const NnetComputeProbOptions &nnet_config, + const chain::ChainTrainingOptions &chain_config, + NnetChainModel2 &model, + Nnet *nnet): + nnet_config_(nnet_config), + chain_config_(chain_config), + nnet_(*nnet), + compiler_(*nnet, nnet_config_.optimize_config, nnet_config_.compiler_config), + deriv_nnet_owned_(false), + deriv_nnet_(nnet), + model_(model), + num_minibatches_processed_(0) { + KALDI_ASSERT(nnet_config.store_component_stats && !nnet_config.compute_deriv); +} + +const Nnet &NnetChainComputeProb2::GetDeriv() const { + if (!nnet_config_.compute_deriv) + KALDI_ERR << "GetDeriv() called when no derivatives were requested."; + return *deriv_nnet_; +} + +NnetChainComputeProb2::~NnetChainComputeProb2() { + if (deriv_nnet_owned_) + delete deriv_nnet_; // delete does nothing if pointer is NULL. +} + +void NnetChainComputeProb2::Reset() { + num_minibatches_processed_ = 0; + objf_info_.clear(); + if (deriv_nnet_) { + ScaleNnet(0.0, deriv_nnet_); + SetNnetAsGradient(deriv_nnet_); + } +} + +void NnetChainComputeProb2::Compute(NnetChainExample &chain_eg) { + std::string default_lang = "default"; + Compute(default_lang, chain_eg); +} + +void NnetChainComputeProb2::Compute(const std::string &lang_key, NnetChainExample &chain_eg) { + bool need_model_derivative = nnet_config_.compute_deriv, + store_component_stats = nnet_config_.store_component_stats; + ComputationRequest request; + // if the options specify cross-entropy regularization, we'll be computing + // this objective (not interpolated with the regular objective-- we give it a + // separate name), but currently we won't make it contribute to the + // derivative-- we just compute the derivative of the regular output. + // This is because in the place where we use the derivative (the + // model-combination code) we decided to keep it simple and just use the + // regular objective. + bool use_xent_regularization = (chain_config_.xent_regularize != 0.0), + use_xent_derivative = false; + std::string lang_name = "default"; + ParseFromQueryString(lang_key, "lang", &lang_name); + for (size_t i = 0; i < chain_eg.outputs.size(); i++) { + // there will normally be exactly one output , named "output" + if(chain_eg.outputs[i].name.compare("output")==0) { + chain_eg.outputs[i].name = "output-" + lang_name; + break; + } + } + + GetChainComputationRequest(nnet_, chain_eg, need_model_derivative, + store_component_stats, use_xent_regularization, + use_xent_derivative, &request); + std::shared_ptr computation = compiler_.Compile(request); + NnetComputer computer(nnet_config_.compute_config, *computation, + nnet_, deriv_nnet_); + // give the inputs to the computer object. + computer.AcceptInputs(nnet_, chain_eg.inputs); + computer.Run(); + this->ProcessOutputs(lang_name, chain_eg, &computer); + if (nnet_config_.compute_deriv) + computer.Run(); +} + +void NnetChainComputeProb2::ProcessOutputs(const std::string &lang_name, NnetChainExample &eg, + NnetComputer *computer) { + // There will normally be just one output here, named 'output', + // but the code is more general than this. + std::vector::const_iterator iter = eg.outputs.begin(), + end = eg.outputs.end(); + for (; iter != end; ++iter) { + const NnetChainSupervision &sup = *iter; + int32 node_index = nnet_.GetNodeIndex(sup.name); + if (node_index < 0 || + !nnet_.IsOutputNode(node_index)) + KALDI_ERR << "Network has no output named " << sup.name; + + const CuMatrixBase &nnet_output = computer->GetOutput(sup.name); + bool use_xent = (chain_config_.xent_regularize != 0.0); + std::string xent_name = sup.name + "-xent"; // typically "output-xent". + CuMatrix nnet_output_deriv, xent_deriv; + if (nnet_config_.compute_deriv) + nnet_output_deriv.Resize(nnet_output.NumRows(), nnet_output.NumCols(), + kUndefined); + if (use_xent) + xent_deriv.Resize(nnet_output.NumRows(), nnet_output.NumCols(), + kUndefined); + + BaseFloat tot_like, tot_l2_term, tot_weight; + + ComputeChainObjfAndDeriv(chain_config_, *(model_.GetDenGraphForLang(lang_name)), + sup.supervision, nnet_output, + &tot_like, &tot_l2_term, &tot_weight, + (nnet_config_.compute_deriv ? &nnet_output_deriv : + NULL), (use_xent ? &xent_deriv : NULL)); + + // note: in this context we don't want to apply 'sup.deriv_weights' because + // this code is used only in combination, where it's part of an L-BFGS + // optimization algorithm, and in that case if there is a mismatch between + // the computed objective function and the derivatives, it may cause errors + // in the optimization procedure such as early termination. (line search + // and conjugate gradient descent both rely on the derivatives being + // accurate, and don't fail gracefully if the derivatives are not accurate). + + ChainObjectiveInfo &totals = objf_info_[sup.name]; + totals.tot_weight += tot_weight; + totals.tot_like += tot_like; + totals.tot_l2_term += tot_l2_term; + + if (nnet_config_.compute_deriv) + computer->AcceptInput(sup.name, &nnet_output_deriv); + + if (use_xent) { + ChainObjectiveInfo &xent_totals = objf_info_[xent_name]; + // this block computes the cross-entropy objective. + const CuMatrixBase &xent_output = computer->GetOutput( + xent_name); + // at this point, xent_deriv is posteriors derived from the numerator + // computation. note, xent_deriv has a factor of '.supervision.weight', + // but so does tot_weight. + BaseFloat xent_objf = TraceMatMat(xent_output, xent_deriv, kTrans); + xent_totals.tot_weight += tot_weight; + xent_totals.tot_like += xent_objf; + } + num_minibatches_processed_++; + } +} + +bool NnetChainComputeProb2::PrintTotalStats() const { + bool ans = false; + unordered_map::const_iterator + iter, end; + iter = objf_info_.begin(); + end = objf_info_.end(); + for (; iter != end; ++iter) { + const std::string &name = iter->first; + int32 node_index = nnet_.GetNodeIndex(name); + KALDI_ASSERT(node_index >= 0); + const ChainObjectiveInfo &info = iter->second; + BaseFloat like = (info.tot_like / info.tot_weight), + l2_term = (info.tot_l2_term / info.tot_weight), + tot_objf = like + l2_term; + if (info.tot_l2_term == 0.0) { + KALDI_LOG << "Overall log-probability for '" + << name << "' is " + << like << " per frame" + << ", over " << info.tot_weight << " frames."; + } else { + KALDI_LOG << "Overall log-probability for '" + << name << "' is " + << like << " + " << l2_term << " = " << tot_objf << " per frame" + << ", over " << info.tot_weight << " frames."; + } + if (info.tot_weight > 0) + ans = true; + } + return ans; +} + + +const ChainObjectiveInfo* NnetChainComputeProb2::GetObjective( + const std::string &output_name) const { + unordered_map::const_iterator + iter = objf_info_.find(output_name); + if (iter != objf_info_.end()) + return &(iter->second); + else + return NULL; +} + +double NnetChainComputeProb2::GetTotalObjective(double *total_weight) const { + double tot_objectives = 0.0; + double tot_weight = 0.0; + unordered_map::const_iterator + iter = objf_info_.begin(), end = objf_info_.end(); + for (; iter != end; ++iter) { + tot_objectives += iter->second.tot_like + iter->second.tot_l2_term; + tot_weight += iter->second.tot_weight; + } + + if (total_weight) *total_weight = tot_weight; + return tot_objectives; +} + +static bool HasXentOutputs2(const Nnet &nnet) { + const std::vector node_names = nnet.GetNodeNames(); + for (std::vector::const_iterator it = node_names.begin(); + it != node_names.end(); ++it) { + int32 node_index = nnet.GetNodeIndex(*it); + if (nnet.IsOutputNode(node_index) && + it->find("-xent") != std::string::npos) { + return true; + } + } + return false; +} + +void RecomputeStats2(std::vector &egs, + const chain::ChainTrainingOptions &chain_config_in, + NnetChainModel2 &model, + Nnet *nnet) { + RecomputeStats2("default", egs, chain_config_in, model, nnet); +} + +// TODO: Note this only works for lang=default for now. So we will have to generalize this later +void RecomputeStats2(const std::string &lang_name, std::vector &egs, + const chain::ChainTrainingOptions &chain_config_in, + NnetChainModel2 &model, + Nnet *nnet) { + KALDI_LOG << "Recomputing stats on nnet (affects batch-norm)"; + chain::ChainTrainingOptions chain_config(chain_config_in); + if (HasXentOutputs2(*nnet) && + chain_config.xent_regularize == 0) { + // this forces it to compute the output for xent outputs, + // usually 'output-xent', which + // means that we'll be computing batch-norm stats for any + // components in that branch that have batch-norm. + chain_config.xent_regularize = 0.1; + } + + ZeroComponentStats(nnet); + NnetComputeProbOptions nnet_config; + nnet_config.store_component_stats = true; + NnetChainComputeProb2 prob_computer(nnet_config, chain_config, model, nnet); + for (size_t i = 0; i < egs.size(); i++) + prob_computer.Compute(egs[i]); + /* prob_computer.PrintTotalStats(); */ + KALDI_LOG << "Done recomputing stats."; +} +} // namespace nnet3 +} // namespace kaldi + diff --git a/src/nnet3/nnet-chain-diagnostics2.h b/src/nnet3/nnet-chain-diagnostics2.h new file mode 100644 index 000000000..e62307453 --- /dev/null +++ b/src/nnet3/nnet-chain-diagnostics2.h @@ -0,0 +1,114 @@ +// nnet3/nnet-chain-diagnostics.h + +// Copyright 2015 Johns Hopkins University (author: Daniel Povey) +// 2019 Idiap Research Institute (author: Srikanth Madikeri) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_NNET3_NNET_CHAIN_DIAGNOSTICS2_H_ +#define KALDI_NNET3_NNET_CHAIN_DIAGNOSTICS2_H_ + +#include "nnet3/nnet-example.h" +#include "nnet3/nnet-computation.h" +#include "nnet3/nnet-compute.h" +#include "nnet3/nnet-optimize.h" +#include "nnet3/nnet-chain-example.h" +#include "nnet3/nnet-chain-training2.h" +#include "nnet3/nnet-diagnostics.h" +#include "nnet3/nnet-chain-diagnostics.h" +#include "chain/chain-training.h" +#include "chain/chain-den-graph.h" + +namespace kaldi { +namespace nnet3 { + + +/** This class is for computing objective-function values in a nnet3+chain + setup, for diagnostics. It also supports computing model derivatives. + Note: if the --xent-regularization option is nonzero, the cross-entropy + objective will be computed, and displayed when you call PrintTotalStats(), + but it will not contribute to model derivatives (there is no code to compute + the regularized objective function, and anyway it's not clear that we really + need this regularization in the combination phase). + */ +class NnetChainComputeProb2 { + public: + // does not store a reference to 'config' but does store one to 'nnet'. + NnetChainComputeProb2(const NnetComputeProbOptions &nnet_config, + const chain::ChainTrainingOptions &chain_config, + NnetChainModel2 &model, + const Nnet &nnet); + + NnetChainComputeProb2(const NnetComputeProbOptions &nnet_config, + const chain::ChainTrainingOptions &chain_config, + NnetChainModel2 &model, + Nnet *nnet); + + // Reset the likelihood stats, and the derivative stats (if computed). + void Reset(); + + // compute objective on one minibatch. + void Compute(NnetChainExample &chain_eg); + void Compute(const std::string &lang_key, NnetChainExample &chain_eg); + + // Prints out the final stats, and return true if there was a nonzero count. + bool PrintTotalStats() const; + + // returns the objective-function info for this output name (e.g. "output"), + // or NULL if there is no such info. + const ChainObjectiveInfo *GetObjective(const std::string &output_name) const; + + // This function returns the total objective over all output nodes recorded here, and + // outputs to 'tot_weight' the total weight (typically the number of frames) + // corresponding to it. + double GetTotalObjective(double *tot_weight) const; + + // if config.compute_deriv == true, returns a reference to the + // computed derivative. Otherwise crashes. + const Nnet &GetDeriv() const; + + ~NnetChainComputeProb2(); + private: + void ProcessOutputs(const std::string &key, NnetChainExample &chain_eg, + NnetComputer *computer); + + NnetComputeProbOptions nnet_config_; + chain::ChainTrainingOptions chain_config_; + const Nnet &nnet_; + CachingOptimizingCompiler compiler_; + bool deriv_nnet_owned_; + Nnet *deriv_nnet_; + NnetChainModel2 &model_; + int32 num_minibatches_processed_; // this is only for diagnostics + + unordered_map objf_info_; + +}; + +void RecomputeStats2(const std::string &lang_name, std::vector &egs, + const chain::ChainTrainingOptions &chain_config_in, + NnetChainModel2 &model, + Nnet *nnet); + +void RecomputeStats2(std::vector &egs, + const chain::ChainTrainingOptions &chain_config_in, + NnetChainModel2 &model, + Nnet *nnet); +} // namespace nnet3 +} // namespace kaldi + +#endif // KALDI_NNET3_NNET_CHAIN_DIAGNOSTICS2_H_ + diff --git a/src/nnet3/nnet-chain-example.cc b/src/nnet3/nnet-chain-example.cc index 53da15d6f..1ceb0de84 100644 --- a/src/nnet3/nnet-chain-example.cc +++ b/src/nnet3/nnet-chain-example.cc @@ -350,6 +350,29 @@ void GetChainComputationRequest(const Nnet &nnet, KALDI_ERR << "No outputs in computation request."; } +// Returns the frame subsampling factor, which is the difference between the +// first 't' value we encounter in 'indexes', and the next 't' value that is +// different from the first 't'. It will typically be 3. +// This function will crash if it could not figure it out (e.g. because +// 'indexes' was empty or had only one element). +static int32 GetFrameSubsamplingFactor(const std::vector &indexes) { + + auto iter = indexes.begin(), end = indexes.end(); + int32 cur_t_value; + if (iter != end) { + cur_t_value = iter->t; + ++iter; + } + for (; iter != end; ++iter) { + if (iter->t != cur_t_value) { + KALDI_ASSERT(iter->t > cur_t_value); + return iter->t - cur_t_value; + } + } + KALDI_ERR << "Error getting frame subsampling factor"; + return 0; // Shouldn't be reached, this is to avoid compiler warnings. +} + void ShiftChainExampleTimes(int32 frame_shift, const std::vector &exclude_names, NnetChainExample *eg) { @@ -377,10 +400,11 @@ void ShiftChainExampleTimes(int32 frame_shift, sup_end = eg->outputs.end(); for (; sup_iter != sup_end; ++sup_iter) { std::vector &indexes = sup_iter->indexes; - KALDI_ASSERT(indexes.size() >= 2 && indexes[0].n == indexes[1].n && - indexes[0].x == indexes[1].x); - int32 frame_subsampling_factor = indexes[1].t - indexes[0].t; - KALDI_ASSERT(frame_subsampling_factor > 0); + int32 frame_subsampling_factor = GetFrameSubsamplingFactor(indexes); + /* KALDI_ASSERT(indexes.size() >= 2 && indexes[0].n == indexes[1].n && */ + /* indexes[0].x == indexes[1].x); */ + /* int32 frame_subsampling_factor = indexes[1].t - indexes[0].t; */ + /* KALDI_ASSERT(frame_subsampling_factor > 0); */ // We need to shift by a multiple of frame_subsampling_factor. // Round to the closest multiple. @@ -551,6 +575,52 @@ void ChainExampleMerger::Finish() { } +bool ParseFromQueryString(const std::string &string, + const std::string &key_name, + std::string *value) { + size_t question_mark_location = string.find_last_of("?"); + if (question_mark_location == std::string::npos) + return false; + std::string key_name_plus_equals = key_name + "="; + // the following do/while and the initialization of key_name_location is a + // little convoluted. We want to find "key_name_plus_equals" but if we find + // it and it's not preceded by '?' or '&' then it's part of a longer key and we + // need to ignore it and see if there's a next one. + size_t key_name_location = question_mark_location; + do { + key_name_location = string.find(key_name_plus_equals, + key_name_location + 1); + } while (key_name_location != std::string::npos && + key_name_location != question_mark_location + 1 && + string[key_name_location - 1] != '&'); + + if (key_name_location == std::string::npos) + return false; + size_t value_location = key_name_location + key_name_plus_equals.length(); + size_t next_ampersand = string.find_first_of("&", value_location); + size_t value_len; + if (next_ampersand == std::string::npos) + value_len = std::string::npos; // will mean "rest of string" + else + value_len = next_ampersand - value_location; + *value = string.substr(value_location, value_len); + return true; +} + + +bool ParseFromQueryString(const std::string &string, + const std::string &key_name, + BaseFloat *value) { + std::string s; + if (!ParseFromQueryString(string, key_name, &s)) + return false; + bool ans = ConvertStringToReal(s, value); + if (!ans) + KALDI_ERR << "For key " << key_name << ", expected float but found '" + << s << "', in string: " << string; + return true; +} + } // namespace nnet3 } // namespace kaldi diff --git a/src/nnet3/nnet-chain-example.h b/src/nnet3/nnet-chain-example.h index 187bb4ef3..40d58c568 100644 --- a/src/nnet3/nnet-chain-example.h +++ b/src/nnet3/nnet-chain-example.h @@ -274,6 +274,13 @@ MapType eg_to_egs_; }; +bool ParseFromQueryString(const std::string &string, + const std::string &key_name, + std::string *value); + +bool ParseFromQueryString(const std::string &string, + const std::string &key_name, + BaseFloat *value); } // namespace nnet3 } // namespace kaldi diff --git a/src/nnet3/nnet-chain-training2.cc b/src/nnet3/nnet-chain-training2.cc new file mode 100644 index 000000000..da25efe44 --- /dev/null +++ b/src/nnet3/nnet-chain-training2.cc @@ -0,0 +1,388 @@ +// nnet3/nnet-chain-training.cc + +// Copyright 2015 Johns Hopkins University (author: Daniel Povey) +// 2016 Xiaohui Zhang +// 2019 Idiap Research Institute (author: Srikanth Madikeri) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "nnet3/nnet-chain-training2.h" +#include "nnet3/nnet-utils.h" + +namespace kaldi { +namespace nnet3 { + +NnetChainTrainer2::NnetChainTrainer2(const NnetChainTraining2Options &opts, + const NnetChainModel2 &model, + Nnet *nnet): + opts_(opts), + model_(model), + nnet_(nnet), + compiler_(*nnet, opts_.nnet_config.optimize_config, + opts_.nnet_config.compiler_config), + num_minibatches_processed_(0), + max_change_stats_(*nnet), + srand_seed_(RandInt(0, 100000)) { + + if (opts.nnet_config.zero_component_stats) + ZeroComponentStats(nnet); + KALDI_ASSERT(opts.nnet_config.momentum >= 0.0 && + opts.nnet_config.max_param_change >= 0.0 && + opts.nnet_config.backstitch_training_interval > 0); + delta_nnet_ = nnet_->Copy(); + ScaleNnet(0.0, delta_nnet_); + + if (opts.nnet_config.read_cache != "") { + bool binary; + try { + Input ki(opts.nnet_config.read_cache, &binary); + compiler_.ReadCache(ki.Stream(), binary); + KALDI_LOG << "Read computation cache from " << opts.nnet_config.read_cache; + } catch (...) { + KALDI_WARN << "Could not open cached computation. " + "Probably this is the first training iteration."; + } + } +} + + +void NnetChainTrainer2::Train(const std::string &key, NnetChainExample &chain_eg) { + bool need_model_derivative = true; + const NnetTrainerOptions &nnet_config = opts_.nnet_config; + bool use_xent_regularization = (opts_.chain_config.xent_regularize != 0.0); + ComputationRequest request; + std::string lang_name = "default"; + ParseFromQueryString(key, "lang", &lang_name); + for (size_t i = 0; i < chain_eg.outputs.size(); i++) { + // there will normally be exactly one output , named "output" + if(chain_eg.outputs[i].name.compare("output")==0) + chain_eg.outputs[i].name = "output-" + lang_name; + } + GetChainComputationRequest(*nnet_, chain_eg, need_model_derivative, + nnet_config.store_component_stats, + use_xent_regularization, need_model_derivative, + &request); + std::shared_ptr computation = compiler_.Compile(request); + + if (nnet_config.backstitch_training_scale > 0.0 && num_minibatches_processed_ + % nnet_config.backstitch_training_interval == + srand_seed_ % nnet_config.backstitch_training_interval) { + // backstitch training is incompatible with momentum > 0 + KALDI_ASSERT(nnet_config.momentum == 0.0); + FreezeNaturalGradient(true, delta_nnet_); + bool is_backstitch_step1 = true; + srand(srand_seed_ + num_minibatches_processed_); + ResetGenerators(nnet_); + TrainInternalBackstitch(key, chain_eg, *computation, is_backstitch_step1); + FreezeNaturalGradient(false, delta_nnet_); // un-freeze natural gradient + is_backstitch_step1 = false; + srand(srand_seed_ + num_minibatches_processed_); + ResetGenerators(nnet_); + TrainInternalBackstitch(key, chain_eg, *computation, is_backstitch_step1); + } else { // conventional training + TrainInternal(key, chain_eg, *computation); + } + if (num_minibatches_processed_ == 0) { + ConsolidateMemory(nnet_); + ConsolidateMemory(delta_nnet_); + } + num_minibatches_processed_++; +} + +void NnetChainTrainer2::TrainInternal(const std::string &key, + const NnetChainExample &eg, + const NnetComputation &computation) { + const NnetTrainerOptions &nnet_config = opts_.nnet_config; + // note: because we give the 1st arg (nnet_) as a pointer to the + // constructor of 'computer', it will use that copy of the nnet to + // store stats. + NnetComputer computer(nnet_config.compute_config, computation, + nnet_, delta_nnet_); + + std::string lang_name = "default"; + ParseFromQueryString(key, "lang", &lang_name); + + // give the inputs to the computer object. + computer.AcceptInputs(*nnet_, eg.inputs); + computer.Run(); + + this->ProcessOutputs(false, lang_name, eg, &computer); + computer.Run(); + + // If relevant, add in the part of the gradient that comes from + // parameter-level L2 regularization. + ApplyL2Regularization(*nnet_, + GetNumNvalues(eg.inputs, false) * + nnet_config.l2_regularize_factor, + delta_nnet_); + + // Updates the parameters of nnet + bool success = UpdateNnetWithMaxChange( + *delta_nnet_, + nnet_config.max_param_change, + 1.0, 1.0 - nnet_config.momentum, nnet_, + &max_change_stats_); + + // Scale down the batchnorm stats (keeps them fresh... this affects what + // happens when we use the model with batchnorm test-mode set). + ScaleBatchnormStats(nnet_config.batchnorm_stats_scale, nnet_); + + // The following will only do something if we have a LinearComponent + // or AffineComponent with orthonormal-constraint set to a nonzero value. + ConstrainOrthonormal(nnet_); + + // Scale delta_nnet + if (success) + ScaleNnet(nnet_config.momentum, delta_nnet_); + else + ScaleNnet(0.0, delta_nnet_); +} + +void NnetChainTrainer2::TrainInternalBackstitch(const std::string key, const NnetChainExample &eg, + const NnetComputation &computation, + bool is_backstitch_step1) { + const NnetTrainerOptions &nnet_config = opts_.nnet_config; + // note: because we give the 1st arg (nnet_) as a pointer to the + // constructor of 'computer', it will use that copy of the nnet to + // store stats. + NnetComputer computer(nnet_config.compute_config, computation, + nnet_, delta_nnet_); + // give the inputs to the computer object. + computer.AcceptInputs(*nnet_, eg.inputs); + computer.Run(); + + bool is_backstitch_step2 = !is_backstitch_step1; + this->ProcessOutputs(is_backstitch_step2, key, eg, &computer); + computer.Run(); + + BaseFloat max_change_scale, scale_adding; + if (is_backstitch_step1) { + // max-change is scaled by backstitch_training_scale; + // delta_nnet is scaled by -backstitch_training_scale when added to nnet; + max_change_scale = nnet_config.backstitch_training_scale; + scale_adding = -nnet_config.backstitch_training_scale; + } else { + // max-change is scaled by 1 + backstitch_training_scale; + // delta_nnet is scaled by 1 + backstitch_training_scale when added to nnet; + max_change_scale = 1.0 + nnet_config.backstitch_training_scale; + scale_adding = 1.0 + nnet_config.backstitch_training_scale; + // If relevant, add in the part of the gradient that comes from L2 + // regularization. It may not be optimally inefficient to do it on both + // passes of the backstitch, like we do here, but it probably minimizes + // any harmful interactions with the max-change. + ApplyL2Regularization(*nnet_, + 1.0 / scale_adding * GetNumNvalues(eg.inputs, false) * + nnet_config.l2_regularize_factor, delta_nnet_); + } + + // Updates the parameters of nnet + UpdateNnetWithMaxChange( + *delta_nnet_, nnet_config.max_param_change, + max_change_scale, scale_adding, nnet_, + &max_change_stats_); + + if (is_backstitch_step1) { + // The following will only do something if we have a LinearComponent or + // AffineComponent with orthonormal-constraint set to a nonzero value. We + // choose to do this only on the 1st backstitch step, for efficiency. + ConstrainOrthonormal(nnet_); + } + + if (!is_backstitch_step1) { + // Scale down the batchnorm stats (keeps them fresh... this affects what + // happens when we use the model with batchnorm test-mode set). Do this + // after backstitch step 2 so that the stats are scaled down before we start + // the next minibatch. + ScaleBatchnormStats(nnet_config.batchnorm_stats_scale, nnet_); + } + + ScaleNnet(0.0, delta_nnet_); +} + +void NnetChainTrainer2::ProcessOutputs(bool is_backstitch_step2, + const std::string &lang_name, + const NnetChainExample &eg, + NnetComputer *computer) { + // normally the eg will have just one output named 'output', but + // we don't assume this. + // In backstitch training, the output-name with the "_backstitch" suffix is + // the one computed after the first, backward step of backstitch. + const std::string suffix = (is_backstitch_step2 ? "_backstitch" : ""); + std::vector::const_iterator iter = eg.outputs.begin(), + end = eg.outputs.end(); + for (; iter != end; ++iter) { + const NnetChainSupervision &sup = *iter; + std::string node_name = "output-" + lang_name; + /* sup.name = node_name; */ + int32 node_index = nnet_->GetNodeIndex(node_name); + if (node_index < 0 || + !nnet_->IsOutputNode(node_index)) + KALDI_ERR << "Network has no output named " << node_name; + + const CuMatrixBase &nnet_output = computer->GetOutput(node_name); + CuMatrix nnet_output_deriv(nnet_output.NumRows(), + nnet_output.NumCols(), + kUndefined); + + bool use_xent = (opts_.chain_config.xent_regularize != 0.0); + std::string xent_name = node_name + "-xent"; // "output-${lang_name}-xent". + CuMatrix xent_deriv; + + BaseFloat tot_objf, tot_l2_term, tot_weight; + + ComputeChainObjfAndDeriv(opts_.chain_config, *(model_.GetDenGraphForLang(lang_name)), + sup.supervision, nnet_output, + &tot_objf, &tot_l2_term, &tot_weight, + &nnet_output_deriv, + (use_xent ? &xent_deriv : NULL)); + + if (use_xent) { + // this block computes the cross-entropy objective. + const CuMatrixBase &xent_output = computer->GetOutput( + xent_name); + // at this point, xent_deriv is posteriors derived from the numerator + // computation. note, xent_objf has a factor of '.supervision.weight' + BaseFloat xent_objf = TraceMatMat(xent_output, xent_deriv, kTrans); + objf_info_[xent_name + suffix].UpdateStats(xent_name + suffix, + opts_.nnet_config.print_interval, + num_minibatches_processed_, + tot_weight, xent_objf); + } + + if (opts_.apply_deriv_weights && sup.deriv_weights.Dim() != 0) { + CuVector cu_deriv_weights(sup.deriv_weights); + nnet_output_deriv.MulRowsVec(cu_deriv_weights); + if (use_xent) + xent_deriv.MulRowsVec(cu_deriv_weights); + } + + /* computer->AcceptInput(sup.name, &nnet_output_deriv); */ + computer->AcceptInput(node_name, &nnet_output_deriv); + + /* objf_info_[sup.name + suffix].UpdateStats(sup.name + suffix, */ + objf_info_[node_name + suffix].UpdateStats(sup.name + suffix, + opts_.nnet_config.print_interval, + num_minibatches_processed_, + tot_weight, tot_objf, tot_l2_term); + + if (use_xent) { + xent_deriv.Scale(opts_.chain_config.xent_regularize); + computer->AcceptInput(xent_name, &xent_deriv); + } + } +} + +bool NnetChainTrainer2::PrintTotalStats() const { + unordered_map::const_iterator + iter = objf_info_.begin(), + end = objf_info_.end(); + bool ans = false; + for (; iter != end; ++iter) { + const std::string &name = iter->first; + const ObjectiveFunctionInfo &info = iter->second; + ans = info.PrintTotalStats(name) || ans; + } + max_change_stats_.Print(*nnet_); + return ans; +} + +NnetChainTrainer2::~NnetChainTrainer2() { + if (opts_.nnet_config.write_cache != "") { + Output ko(opts_.nnet_config.write_cache, opts_.nnet_config.binary_write_cache); + compiler_.WriteCache(ko.Stream(), opts_.nnet_config.binary_write_cache); + KALDI_LOG << "Wrote computation cache to " << opts_.nnet_config.write_cache; + } + delete delta_nnet_; +} + +NnetChainModel2::NnetChainModel2( + const NnetChainTraining2Options &opts, + Nnet *nnet, + const std::string &den_fst_dir + ): + opts_(opts), + nnet(nnet), + den_fst_dir_(den_fst_dir) { +} + +NnetChainModel2::~NnetChainModel2() { +} + +NnetChainModel2::LanguageInfo::LanguageInfo( + const NnetChainModel2::LanguageInfo &other): + name(other.name), + den_graph_(other.den_graph_) + { } + + +NnetChainModel2::LanguageInfo::LanguageInfo( + const std::string &name, + const fst::StdVectorFst &den_fst, + int32 num_pdfs): + name(name), + den_graph_(den_fst, num_pdfs){ +} + +void NnetChainModel2::GetPathname(const std::string &dir, + const std::string &name, + const std::string &suffix, + std::string *pathname) { + std::ostringstream str; + str << dir << '/' << name << '.' << suffix; + *pathname = str.str(); +} + +void NnetChainModel2::GetPathname(const std::string &dir, + const std::string &name, + int32 job_id, + const std::string &suffix, + std::string *pathname) { + std::ostringstream str; + str << dir << '/' << name << '.' << job_id << '.' << suffix; + *pathname = str.str(); +} + +NnetChainModel2::LanguageInfo *NnetChainModel2::GetInfoForLang( + const std::string &lang) { + auto iter = lang_info_.find(lang); + if (iter != lang_info_.end()) { + return iter->second; + } else { + std::string den_fst_filename; + GetPathname(den_fst_dir_, lang, "den.fst", &den_fst_filename); + fst::StdVectorFst den_fst; + ReadFstKaldi(den_fst_filename, &den_fst); + std::string outputname = "output-" + lang; + + LanguageInfo *info = new LanguageInfo(lang, den_fst, nnet->OutputDim(outputname)); + lang_info_[lang] = info; + return info; + } +} + +/* fst::StdVectorFst* NnetChainModel2::GetDenFstForLang( */ +/* const std::string &language_name) { */ +/* LanguageInfo *info = GetInfoForLang(language_name); */ +/* return &(info->den_fst); */ +/* } */ + +chain::DenominatorGraph *NnetChainModel2::GetDenGraphForLang(const std::string &language_name){ + LanguageInfo *info = GetInfoForLang(language_name); + return &(info->den_graph_); +} +} // namespace nnet3 +} // namespace kaldi + diff --git a/src/nnet3/nnet-chain-training2.h b/src/nnet3/nnet-chain-training2.h new file mode 100644 index 000000000..8654a77f3 --- /dev/null +++ b/src/nnet3/nnet-chain-training2.h @@ -0,0 +1,192 @@ +// nnet3/nnet-chain-training.h + +// Copyright 2015 Johns Hopkins University (author: Daniel Povey) +// 2019 Idiap Research Institute (author: Srikanth Madikeri) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_NNET3_NNET_CHAIN_TRAINING_H_ +#define KALDI_NNET3_NNET_CHAIN_TRAINING_H_ + +#include "nnet3/nnet-example.h" +#include "nnet3/nnet-computation.h" +#include "nnet3/nnet-compute.h" +#include "nnet3/nnet-optimize.h" +#include "nnet3/nnet-chain-example.h" +#include "nnet3/nnet-training.h" +#include "nnet3/nnet-chain-training.h" +#include "chain/chain-training.h" +#include "chain/chain-den-graph.h" +#include "nnet3/nnet-chain-example.h" + +namespace kaldi { +namespace nnet3 { + +struct NnetChainTraining2Options { + NnetTrainerOptions nnet_config; + chain::ChainTrainingOptions chain_config; + bool apply_deriv_weights; + NnetChainTraining2Options(): apply_deriv_weights(true) { } + + void Register(OptionsItf *opts) { + nnet_config.Register(opts); + chain_config.Register(opts); + opts->Register("apply-deriv-weights", &apply_deriv_weights, + "If true, apply the per-frame derivative weights stored with " + "the example"); + } +}; + +class NnetChainModel2 { + public: + /** + Constructor to which you pass the model directory and the den-fst + directory. There is no requirement that all these directories be distinct. + + For each language called "lang" the following files should exist: + /lang.den.fst /lang.normalization.fst + + In practice, the language name will be either "default", in the + typical (monolingual) setup, or it might be arbitrary strings + representing languages such as "english", "french", and so on. + In general the language can be any string containing ASCII letters, numbers + or underscores. + + The models and denominator FSTs will only be read when they are actually + required, so languages that are not used by a particular job (e.g. because + they were not represented in the egs) will not actually be read. + + **/ + + NnetChainModel2(const NnetChainTraining2Options &opts, + Nnet *nnet, + const std::string &den_fst_dir); + + /* fst::StdVectorFst *GetDenFstForLang(const std::string &language_name); */ + chain::DenominatorGraph *GetDenGraphForLang(const std::string &language_name); + + ~NnetChainModel2(); + + private: + // This function sets "pathname" to the string: + // /. + void GetPathname(const std::string &dir, + const std::string &name, + const std::string &suffix, + std::string *pathname); + + // If job_id is >= 0, then this version of GetPathname() sets "pathname" to + // the string: + // /.. + // otherwise (job_id < 0) it sets it to + // /. + void GetPathname(const std::string &dir, + const std::string &name, + int32 job_id, + const std::string &suffix, + std::string *pathname); + + // struct LanguageInfo contains the data that is stored per language. + struct LanguageInfo { + // name of the language + std::string name; + // den_fst comes from /.den.fst + // fst::StdVectorFst den_fst; + chain::DenominatorGraph den_graph_; + + // transform comes from /.ada + LanguageInfo() { } + + LanguageInfo(const std::string &name, const fst::StdVectorFst &den_fst, int32 num_pdfs); + // Copy constructor + LanguageInfo(const LanguageInfo &other); + }; + + // get the LanguageInfo* for this language, creating it (and reading its + // contents from disk) if it does not already exist. + LanguageInfo *GetInfoForLang(const std::string &lang); + + const NnetChainTraining2Options &opts_; + Nnet *nnet; + // Directory where denominator FSTs are located. + std::string den_fst_dir_; + + std::unordered_map lang_info_; +}; // class NnetChainModel2 + + +/** + This class is for single-threaded training of neural nets using the 'chain' + model. +*/ +class NnetChainTrainer2 { + public: + NnetChainTrainer2(const NnetChainTraining2Options &config, + const NnetChainModel2 &model, + Nnet *nnet); + + // train on one minibatch. + void Train(const std::string &key, NnetChainExample &eg); + + // Prints out the final stats, and return true if there was a nonzero count. + bool PrintTotalStats() const; + + ~NnetChainTrainer2(); + private: + // The internal function for doing one step of conventional SGD training. + void TrainInternal(const std::string &key, const NnetChainExample &eg, + const NnetComputation &computation); + + // The internal function for doing one step of backstitch training. Depending + // on whether is_backstitch_step1 is true, It could be either the first + // (backward) step, or the second (forward) step of backstitch. + void TrainInternalBackstitch(const std::string key, const NnetChainExample &eg, + const NnetComputation &computation, + bool is_backstitch_step1); + + void ProcessOutputs(bool is_backstitch_step2, const std::string &key, const NnetChainExample &eg, + NnetComputer *computer); + + const NnetChainTraining2Options opts_; + + NnetChainModel2 model_; + Nnet *nnet_; + Nnet *delta_nnet_; // stores the change to the parameters on each training + // iteration. + CachingOptimizingCompiler compiler_; + + // This code supports multiple output layers, even though in the + // normal case there will be just one output layer named "output". + // So we store the objective functions per output layer. + int32 num_minibatches_processed_; + + // stats for max-change. + MaxChangeStats max_change_stats_; + + unordered_map objf_info_; + + // This value is used in backstitch training when we need to ensure + // consistent dropout masks. It's set to a value derived from rand() + // when the class is initialized. + int32 srand_seed_; +}; + + +}// namespace nnet3 +} // namespace kaldi + +#endif // KALDI_NNET3_NNET_CHAIN_TRAINING_H_ + diff --git a/src/nnet3/nnet-computation-graph.cc b/src/nnet3/nnet-computation-graph.cc index ea30b0040..cf99bbcd8 100644 --- a/src/nnet3/nnet-computation-graph.cc +++ b/src/nnet3/nnet-computation-graph.cc @@ -857,7 +857,6 @@ void ComputationGraphBuilder::UpdateComputableInfo(int32 cindex_id) { void ComputationGraphBuilder::IncrementUsableCount(int32 cindex_id) { - KALDI_PARANOID_ASSERT(static_cast(cindex_id)(cindex_id) 0); if (--cindex_info_[cindex_id].usable_count == 0 && cindex_info_[cindex_id].computable != kNotComputable) { diff --git a/src/nnet3/nnet-example-utils.cc b/src/nnet3/nnet-example-utils.cc index 15004092e..6b917483b 100644 --- a/src/nnet3/nnet-example-utils.cc +++ b/src/nnet3/nnet-example-utils.cc @@ -81,7 +81,13 @@ static void GetIoSizes(const std::vector &src, } - +static int32 FindMaxNValue(const NnetIo &io) { + int32 max_n = 0; + for (auto &index: io.indexes) + if (index.n > max_n) + max_n = index.n; + return max_n; +} // Do the final merging of NnetIo, once we have obtained the names, dims and // sizes for each feature/supervision type. @@ -98,6 +104,9 @@ static void MergeIo(const std::vector &src, // The features in the different NnetIo in the Indexes across all examples std::vector > output_lists(num_feats); + // This is 1 for single examples and larger than 1 for already-merged egs, and + // it must be the same for all io's across all examples: + int32 example_stride = FindMaxNValue(src[0].io[0]) + 1; // Initialize the merged_eg merged_eg->io.clear(); merged_eg->io.resize(num_feats); @@ -139,9 +148,11 @@ static void MergeIo(const std::vector &src, for (int32 i = this_offset; i < this_offset + this_size; i++) { // we could easily support merging already-merged egs, but I don't see a // need for it right now. - KALDI_ASSERT(output_iter[i].n == 0 && - "Merging already-merged egs? Not currentlysupported."); - output_iter[i].n = n; + /* KALDI_ASSERT(output_iter[i].n == 0 && */ + /* "Merging already-merged egs? Not currentlysupported."); */ + KALDI_ASSERT(output_iter[i].n < example_stride); + output_iter[i].n += n * example_stride; + //output_iter[i].n = n; } this_offset += this_size; // note: this_offset is a reference. } @@ -354,10 +365,15 @@ UtteranceSplitter::UtteranceSplitter(const ExampleGenerationConfig &config): } UtteranceSplitter::~UtteranceSplitter() { + /* KALDI_LOG << "Split " << total_num_utterances_ << " utts, with " */ + /* << "total length " << total_input_frames_ << " frames (" */ + /* << (total_input_frames_ / 360000.0) << " hours assuming " */ + /* << "100 frames per second)"; */ KALDI_LOG << "Split " << total_num_utterances_ << " utts, with " << "total length " << total_input_frames_ << " frames (" << (total_input_frames_ / 360000.0) << " hours assuming " - << "100 frames per second)"; + << "100 frames per second) into " << total_num_chunks_ + << " chunks."; float average_chunk_length = total_frames_in_chunks_ * 1.0 / total_num_chunks_, overlap_percent = total_frames_overlap_ * 100.0 / total_input_frames_, output_percent = total_frames_in_chunks_ * 100.0 / total_input_frames_, diff --git a/src/online2/Makefile b/src/online2/Makefile index 242c7be6d..bbc7ac07b 100644 --- a/src/online2/Makefile +++ b/src/online2/Makefile @@ -9,7 +9,7 @@ OBJFILES = online-gmm-decodable.o online-feature-pipeline.o online-ivector-featu online-nnet2-feature-pipeline.o online-gmm-decoding.o online-timing.o \ online-endpoint.o onlinebin-util.o online-speex-wrapper.o \ online-nnet2-decoding.o online-nnet2-decoding-threaded.o \ - online-nnet3-decoding.o + online-nnet3-decoding.o online-nnet3-incremental-decoding.o LIBNAME = kaldi-online2 diff --git a/src/online2/online-endpoint.cc b/src/online2/online-endpoint.cc index aa7752c44..a3be0791f 100644 --- a/src/online2/online-endpoint.cc +++ b/src/online2/online-endpoint.cc @@ -71,10 +71,10 @@ bool EndpointDetected(const OnlineEndpointConfig &config, return false; } -template +template int32 TrailingSilenceLength(const TransitionModel &tmodel, const std::string &silence_phones_str, - const LatticeFasterOnlineDecoderTpl &decoder) { + const DEC &decoder) { std::vector silence_phones; if (!SplitStringToIntegers(silence_phones_str, ":", false, &silence_phones)) KALDI_ERR << "Bad --silence-phones option in endpointing config: " @@ -87,7 +87,7 @@ int32 TrailingSilenceLength(const TransitionModel &tmodel, ConstIntegerSet silence_set(silence_phones); bool use_final_probs = false; - typename LatticeFasterOnlineDecoderTpl::BestPathIterator iter = + typename DEC::BestPathIterator iter = decoder.BestPathEnd(use_final_probs, NULL); int32 num_silence_frames = 0; while (!iter.Done()) { // we're going backwards in time from the most @@ -117,7 +117,7 @@ bool EndpointDetected( BaseFloat final_relative_cost = decoder.FinalRelativeCost(); int32 num_frames_decoded = decoder.NumFramesDecoded(), - trailing_silence_frames = TrailingSilenceLength(tmodel, + trailing_silence_frames = TrailingSilenceLength>(tmodel, config.silence_phones, decoder); @@ -125,6 +125,26 @@ bool EndpointDetected( frame_shift_in_seconds, final_relative_cost); } +template +bool EndpointDetected( + const OnlineEndpointConfig &config, + const TransitionModel &tmodel, + BaseFloat frame_shift_in_seconds, + const LatticeIncrementalOnlineDecoderTpl &decoder) { + if (decoder.NumFramesDecoded() == 0) return false; + + BaseFloat final_relative_cost = decoder.FinalRelativeCost(); + + int32 num_frames_decoded = decoder.NumFramesDecoded(), + trailing_silence_frames = TrailingSilenceLength>(tmodel, + config.silence_phones, + decoder); + + return EndpointDetected(config, num_frames_decoded, trailing_silence_frames, + frame_shift_in_seconds, final_relative_cost); +} + + // Instantiate EndpointDetected for the types we need. // It will require TrailingSilenceLength so we don't have to instantiate that. @@ -143,5 +163,21 @@ bool EndpointDetected( BaseFloat frame_shift_in_seconds, const LatticeFasterOnlineDecoderTpl &decoder); +template +bool EndpointDetected >( + const OnlineEndpointConfig &config, + const TransitionModel &tmodel, + BaseFloat frame_shift_in_seconds, + const LatticeIncrementalOnlineDecoderTpl > &decoder); + + +template +bool EndpointDetected( + const OnlineEndpointConfig &config, + const TransitionModel &tmodel, + BaseFloat frame_shift_in_seconds, + const LatticeIncrementalOnlineDecoderTpl &decoder); + + } // namespace kaldi diff --git a/src/online2/online-endpoint.h b/src/online2/online-endpoint.h index aaf9232db..3171f0c53 100644 --- a/src/online2/online-endpoint.h +++ b/src/online2/online-endpoint.h @@ -35,6 +35,7 @@ #include "lat/kaldi-lattice.h" #include "hmm/transition-model.h" #include "decoder/lattice-faster-online-decoder.h" +#include "decoder/lattice-incremental-online-decoder.h" namespace kaldi { /// @addtogroup onlinedecoding OnlineDecoding @@ -187,10 +188,10 @@ bool EndpointDetected(const OnlineEndpointConfig &config, /// integer id's of phones that we consider silence. We use the the /// BestPathEnd() and TraceBackOneLink() functions of LatticeFasterOnlineDecoder /// to do this efficiently. -template +template int32 TrailingSilenceLength(const TransitionModel &tmodel, const std::string &silence_phones, - const LatticeFasterOnlineDecoderTpl &decoder); + const DEC &decoder); /// This is a higher-level convenience function that works out the @@ -202,6 +203,15 @@ bool EndpointDetected( BaseFloat frame_shift_in_seconds, const LatticeFasterOnlineDecoderTpl &decoder); +/// This is a higher-level convenience function that works out the +/// arguments to the EndpointDetected function above, from the decoder. +template +bool EndpointDetected( + const OnlineEndpointConfig &config, + const TransitionModel &tmodel, + BaseFloat frame_shift_in_seconds, + const LatticeIncrementalOnlineDecoderTpl &decoder); + diff --git a/src/online2/online-ivector-feature.cc b/src/online2/online-ivector-feature.cc index 32a4db700..3a15ac9a3 100644 --- a/src/online2/online-ivector-feature.cc +++ b/src/online2/online-ivector-feature.cc @@ -519,6 +519,57 @@ void OnlineSilenceWeighting::ComputeCurrentTraceback( } } +template +void OnlineSilenceWeighting::ComputeCurrentTraceback( + const LatticeIncrementalOnlineDecoderTpl &decoder) { + int32 num_frames_decoded = decoder.NumFramesDecoded(), + num_frames_prev = frame_info_.size(); + // note, num_frames_prev is not the number of frames previously decoded, + // it's the generally-larger number of frames that we were requested to + // provide weights for. + if (num_frames_prev < num_frames_decoded) + frame_info_.resize(num_frames_decoded); + if (num_frames_prev > num_frames_decoded && + frame_info_[num_frames_decoded].transition_id != -1) + KALDI_ERR << "Number of frames decoded decreased"; // Likely bug + + if (num_frames_decoded == 0) + return; + int32 frame = num_frames_decoded - 1; + bool use_final_probs = false; + typename LatticeIncrementalOnlineDecoderTpl::BestPathIterator iter = + decoder.BestPathEnd(use_final_probs, NULL); + while (frame >= 0) { + LatticeArc arc; + arc.ilabel = 0; + while (arc.ilabel == 0) // the while loop skips over input-epsilons + iter = decoder.TraceBackBestPath(iter, &arc); + // note, the iter.frame values are slightly unintuitively defined, + // they are one less than you might expect. + KALDI_ASSERT(iter.frame == frame - 1); + + if (frame_info_[frame].token == iter.tok) { + // we know that the traceback from this point back will be identical, so + // no point tracing back further. Note: we are comparing memory addresses + // of tokens of the decoder; this guarantees it's the same exact token, + // because tokens, once allocated on a frame, are only deleted, never + // reallocated for that frame. + break; + } + + if (num_frames_output_and_correct_ > frame) + num_frames_output_and_correct_ = frame; + + frame_info_[frame].token = iter.tok; + frame_info_[frame].transition_id = arc.ilabel; + frame--; + // leave frame_info_.current_weight at zero for now (as set in the + // constructor), reflecting that we haven't already output a weight for that + // frame. + } +} + + // Instantiate the template OnlineSilenceWeighting::ComputeCurrentTraceback(). template void OnlineSilenceWeighting::ComputeCurrentTraceback >( @@ -526,6 +577,13 @@ void OnlineSilenceWeighting::ComputeCurrentTraceback >( template void OnlineSilenceWeighting::ComputeCurrentTraceback( const LatticeFasterOnlineDecoderTpl &decoder); +template +void OnlineSilenceWeighting::ComputeCurrentTraceback >( + const LatticeIncrementalOnlineDecoderTpl > &decoder); +template +void OnlineSilenceWeighting::ComputeCurrentTraceback( + const LatticeIncrementalOnlineDecoderTpl &decoder); + void OnlineSilenceWeighting::GetDeltaWeights( int32 num_frames_ready, int32 first_decoder_frame, diff --git a/src/online2/online-ivector-feature.h b/src/online2/online-ivector-feature.h index 12bc5c6bb..0d02ab06e 100644 --- a/src/online2/online-ivector-feature.h +++ b/src/online2/online-ivector-feature.h @@ -33,6 +33,7 @@ #include "feat/online-feature.h" #include "ivector/ivector-extractor.h" #include "decoder/lattice-faster-online-decoder.h" +#include "decoder/lattice-incremental-online-decoder.h" namespace kaldi { /// @addtogroup onlinefeat OnlineFeatureExtraction @@ -480,6 +481,8 @@ class OnlineSilenceWeighting { // It will be instantiated for FST == fst::Fst and fst::GrammarFst. template void ComputeCurrentTraceback(const LatticeFasterOnlineDecoderTpl &decoder); + template + void ComputeCurrentTraceback(const LatticeIncrementalOnlineDecoderTpl &decoder); // Calling this function gets the changes in weight that require us to modify // the stats... the output format is (frame-index, delta-weight). diff --git a/src/online2/online-nnet3-incremental-decoding.cc b/src/online2/online-nnet3-incremental-decoding.cc new file mode 100644 index 000000000..5e7acf147 --- /dev/null +++ b/src/online2/online-nnet3-incremental-decoding.cc @@ -0,0 +1,75 @@ +// online2/online-nnet3-incremental-decoding.cc + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "online2/online-nnet3-incremental-decoding.h" +#include "lat/lattice-functions.h" +#include "lat/determinize-lattice-pruned.h" +#include "decoder/grammar-fst.h" + +namespace kaldi { + +template +SingleUtteranceNnet3IncrementalDecoderTpl::SingleUtteranceNnet3IncrementalDecoderTpl( + const LatticeIncrementalDecoderConfig &decoder_opts, + const TransitionModel &trans_model, + const nnet3::DecodableNnetSimpleLoopedInfo &info, + const FST &fst, + OnlineNnet2FeaturePipeline *features): + decoder_opts_(decoder_opts), + input_feature_frame_shift_in_seconds_(features->FrameShiftInSeconds()), + trans_model_(trans_model), + decodable_(trans_model_, info, + features->InputFeature(), features->IvectorFeature()), + decoder_(fst, trans_model, decoder_opts_) { + decoder_.InitDecoding(); +} + +template +void SingleUtteranceNnet3IncrementalDecoderTpl::InitDecoding(int32 frame_offset) { + decoder_.InitDecoding(); + decodable_.SetFrameOffset(frame_offset); +} + +template +void SingleUtteranceNnet3IncrementalDecoderTpl::AdvanceDecoding() { + decoder_.AdvanceDecoding(&decodable_); +} + +template +void SingleUtteranceNnet3IncrementalDecoderTpl::GetBestPath(bool end_of_utterance, + Lattice *best_path) const { + decoder_.GetBestPath(best_path, end_of_utterance); +} + +template +bool SingleUtteranceNnet3IncrementalDecoderTpl::EndpointDetected( + const OnlineEndpointConfig &config) { + BaseFloat output_frame_shift = + input_feature_frame_shift_in_seconds_ * + decodable_.FrameSubsamplingFactor(); + return kaldi::EndpointDetected(config, trans_model_, + output_frame_shift, decoder_); +} + + +// Instantiate the template for the types needed. +template class SingleUtteranceNnet3IncrementalDecoderTpl >; +template class SingleUtteranceNnet3IncrementalDecoderTpl; + +} // namespace kaldi diff --git a/src/online2/online-nnet3-incremental-decoding.h b/src/online2/online-nnet3-incremental-decoding.h new file mode 100644 index 000000000..e407cc2be --- /dev/null +++ b/src/online2/online-nnet3-incremental-decoding.h @@ -0,0 +1,148 @@ +// online2/online-nnet3-incremental-decoding.h + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_ONLINE2_ONLINE_NNET3_INCREMENTAL_DECODING_H_ +#define KALDI_ONLINE2_ONLINE_NNET3_INCREMENTAL_DECODING_H_ + +#include +#include +#include + +#include "nnet3/decodable-online-looped.h" +#include "matrix/matrix-lib.h" +#include "util/common-utils.h" +#include "base/kaldi-error.h" +#include "itf/online-feature-itf.h" +#include "online2/online-endpoint.h" +#include "online2/online-nnet2-feature-pipeline.h" +#include "decoder/lattice-incremental-online-decoder.h" +#include "hmm/transition-model.h" +#include "hmm/posterior.h" + +namespace kaldi { +/// @addtogroup onlinedecoding OnlineDecoding +/// @{ + + +/** + You will instantiate this class when you want to decode a single utterance + using the online-decoding setup for neural nets. The template will be + instantiated only for FST = fst::Fst and FST = fst::GrammarFst. +*/ + +template +class SingleUtteranceNnet3IncrementalDecoderTpl { + public: + + // Constructor. The pointer 'features' is not being given to this class to own + // and deallocate, it is owned externally. + SingleUtteranceNnet3IncrementalDecoderTpl(const LatticeIncrementalDecoderConfig &decoder_opts, + const TransitionModel &trans_model, + const nnet3::DecodableNnetSimpleLoopedInfo &info, + const FST &fst, + OnlineNnet2FeaturePipeline *features); + + /// Initializes the decoding and sets the frame offset of the underlying + /// decodable object. This method is called by the constructor. You can also + /// call this method when you want to reset the decoder state, but want to + /// keep using the same decodable object, e.g. in case of an endpoint. + void InitDecoding(int32 frame_offset = 0); + + /// Advances the decoding as far as we can. + void AdvanceDecoding(); + + /// Finalizes the decoding. Cleans up and prunes remaining tokens, so the + /// GetLattice() call will return faster. You must not call this before + /// calling (TerminateDecoding() or InputIsFinished()) and then Wait(). + void FinalizeDecoding() { decoder_.FinalizeDecoding(); } + + int32 NumFramesDecoded() const { return decoder_.NumFramesDecoded(); } + + int32 NumFramesInLattice() const { return decoder_.NumFramesInLattice(); } + + /* Gets the lattice. The output lattice has any acoustic scaling in it + (which will typically be desirable in an online-decoding context); if you + want an un-scaled lattice, scale it using ScaleLattice() with the inverse + of the acoustic weight. + + @param [in] num_frames_to_include The number of frames you want + to be included in the lattice. Must be in the range + [NumFramesInLattice().. NumFramesDecoded()]. If you + make it a few frames less than NumFramesDecoded(), it + will save significant computation. + @param [in] use_final_probs True if you want the lattice to + contain final-probs (if at least one state was final + on the most recently decoded frame). Must be false + if num_frames_to_include < NumFramesDecoded(). + Must be true if you have previously called + FinalizeDecoding(). + */ + const CompactLattice &GetLattice(int32 num_frames_to_include, + bool use_final_probs = false) { + return decoder_.GetLattice(num_frames_to_include, use_final_probs); + } + + + + + + /// Outputs an FST corresponding to the single best path through the current + /// lattice. If "use_final_probs" is true AND we reached the final-state of + /// the graph then it will include those as final-probs, else it will treat + /// all final-probs as one. + void GetBestPath(bool end_of_utterance, + Lattice *best_path) const; + + + /// This function calls EndpointDetected from online-endpoint.h, + /// with the required arguments. + bool EndpointDetected(const OnlineEndpointConfig &config); + + const LatticeIncrementalOnlineDecoderTpl &Decoder() const { return decoder_; } + + ~SingleUtteranceNnet3IncrementalDecoderTpl() { } + private: + + const LatticeIncrementalDecoderConfig &decoder_opts_; + + // this is remembered from the constructor; it's ultimately + // derived from calling FrameShiftInSeconds() on the feature pipeline. + BaseFloat input_feature_frame_shift_in_seconds_; + + // we need to keep a reference to the transition model around only because + // it's needed by the endpointing code. + const TransitionModel &trans_model_; + + nnet3::DecodableAmNnetLoopedOnline decodable_; + + LatticeIncrementalOnlineDecoderTpl decoder_; + +}; + + +typedef SingleUtteranceNnet3IncrementalDecoderTpl > SingleUtteranceNnet3IncrementalDecoder; + +/// @} End of "addtogroup onlinedecoding" + +} // namespace kaldi + + + +#endif // KALDI_ONLINE2_ONLINE_NNET3_DECODING_H_ diff --git a/src/online2bin/Makefile b/src/online2bin/Makefile index 28c135eb9..2552e7148 100644 --- a/src/online2bin/Makefile +++ b/src/online2bin/Makefile @@ -12,7 +12,7 @@ BINFILES = online2-wav-gmm-latgen-faster apply-cmvn-online \ online2-wav-dump-features ivector-randomize \ online2-wav-nnet2-am-compute online2-wav-nnet2-latgen-threaded \ online2-wav-nnet3-latgen-faster online2-wav-nnet3-latgen-grammar \ - online2-tcp-nnet3-decode-faster + online2-tcp-nnet3-decode-faster online2-wav-nnet3-latgen-incremental OBJFILES = diff --git a/src/online2bin/online2-wav-nnet3-latgen-faster.cc b/src/online2bin/online2-wav-nnet3-latgen-faster.cc index 1549dd6ae..c7fb3806e 100644 --- a/src/online2bin/online2-wav-nnet3-latgen-faster.cc +++ b/src/online2bin/online2-wav-nnet3-latgen-faster.cc @@ -58,7 +58,8 @@ void GetDiagnosticsAndPrintOutput(const std::string &utt, *tot_like += likelihood; KALDI_VLOG(2) << "Likelihood per frame for utterance " << utt << " is " << (likelihood / num_frames) << " over " << num_frames - << " frames."; + << " frames, = " << (-weight.Value1() / num_frames) + << ',' << (weight.Value2() / num_frames); if (word_syms != NULL) { std::cerr << utt << ' '; diff --git a/src/online2bin/online2-wav-nnet3-latgen-incremental.cc b/src/online2bin/online2-wav-nnet3-latgen-incremental.cc new file mode 100644 index 000000000..aaa87f24d --- /dev/null +++ b/src/online2bin/online2-wav-nnet3-latgen-incremental.cc @@ -0,0 +1,306 @@ +// online2bin/online2-wav-nnet3-latgen-incremental.cc + +// Copyright 2019 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "feat/wave-reader.h" +#include "online2/online-nnet3-incremental-decoding.h" +#include "online2/online-nnet2-feature-pipeline.h" +#include "online2/onlinebin-util.h" +#include "online2/online-timing.h" +#include "online2/online-endpoint.h" +#include "fstext/fstext-lib.h" +#include "lat/lattice-functions.h" +#include "util/kaldi-thread.h" +#include "nnet3/nnet-utils.h" + +namespace kaldi { + +void GetDiagnosticsAndPrintOutput(const std::string &utt, + const fst::SymbolTable *word_syms, + const CompactLattice &clat, + int64 *tot_num_frames, + double *tot_like) { + if (clat.NumStates() == 0) { + KALDI_WARN << "Empty lattice."; + return; + } + CompactLattice best_path_clat; + CompactLatticeShortestPath(clat, &best_path_clat); + + Lattice best_path_lat; + ConvertLattice(best_path_clat, &best_path_lat); + + double likelihood; + LatticeWeight weight; + int32 num_frames; + std::vector alignment; + std::vector words; + GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight); + num_frames = alignment.size(); + likelihood = -(weight.Value1() + weight.Value2()); + *tot_num_frames += num_frames; + *tot_like += likelihood; + KALDI_VLOG(2) << "Likelihood per frame for utterance " << utt << " is " + << (likelihood / num_frames) << " over " << num_frames + << " frames, = " << (-weight.Value1() / num_frames) + << ',' << (weight.Value2() / num_frames); + + if (word_syms != NULL) { + std::cerr << utt << ' '; + for (size_t i = 0; i < words.size(); i++) { + std::string s = word_syms->Find(words[i]); + if (s == "") + KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; + std::cerr << s << ' '; + } + std::cerr << std::endl; + } +} + +} + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace fst; + + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Reads in wav file(s) and simulates online decoding with neural nets\n" + "(nnet3 setup), with optional iVector-based speaker adaptation and\n" + "optional endpointing. Note: some configuration values and inputs are\n" + "set via config files whose filenames are passed as options\n" + "The lattice determinization algorithm here can operate\n" + "incrementally.\n" + "\n" + "Usage: online2-wav-nnet3-latgen-incremental [options] " + " \n" + "The spk2utt-rspecifier can just be if\n" + "you want to decode utterance by utterance.\n"; + + ParseOptions po(usage); + + std::string word_syms_rxfilename; + + // feature_opts includes configuration for the iVector adaptation, + // as well as the basic features. + OnlineNnet2FeaturePipelineConfig feature_opts; + nnet3::NnetSimpleLoopedComputationOptions decodable_opts; + LatticeIncrementalDecoderConfig decoder_opts; + OnlineEndpointConfig endpoint_opts; + + BaseFloat chunk_length_secs = 0.18; + bool do_endpointing = false; + bool online = true; + + po.Register("chunk-length", &chunk_length_secs, + "Length of chunk size in seconds, that we process. Set to <= 0 " + "to use all input in one chunk."); + po.Register("word-symbol-table", &word_syms_rxfilename, + "Symbol table for words [for debug output]"); + po.Register("do-endpointing", &do_endpointing, + "If true, apply endpoint detection"); + po.Register("online", &online, + "You can set this to false to disable online iVector estimation " + "and have all the data for each utterance used, even at " + "utterance start. This is useful where you just want the best " + "results and don't care about online operation. Setting this to " + "false has the same effect as setting " + "--use-most-recent-ivector=true and --greedy-ivector-extractor=true " + "in the file given to --ivector-extraction-config, and " + "--chunk-length=-1."); + po.Register("num-threads-startup", &g_num_threads, + "Number of threads used when initializing iVector extractor."); + + feature_opts.Register(&po); + decodable_opts.Register(&po); + decoder_opts.Register(&po); + endpoint_opts.Register(&po); + + + po.Read(argc, argv); + + if (po.NumArgs() != 5) { + po.PrintUsage(); + return 1; + } + + std::string nnet3_rxfilename = po.GetArg(1), + fst_rxfilename = po.GetArg(2), + spk2utt_rspecifier = po.GetArg(3), + wav_rspecifier = po.GetArg(4), + clat_wspecifier = po.GetArg(5); + + OnlineNnet2FeaturePipelineInfo feature_info(feature_opts); + + if (!online) { + feature_info.ivector_extractor_info.use_most_recent_ivector = true; + feature_info.ivector_extractor_info.greedy_ivector_extractor = true; + chunk_length_secs = -1.0; + } + + TransitionModel trans_model; + nnet3::AmNnetSimple am_nnet; + { + bool binary; + Input ki(nnet3_rxfilename, &binary); + trans_model.Read(ki.Stream(), binary); + am_nnet.Read(ki.Stream(), binary); + SetBatchnormTestMode(true, &(am_nnet.GetNnet())); + SetDropoutTestMode(true, &(am_nnet.GetNnet())); + nnet3::CollapseModel(nnet3::CollapseModelConfig(), &(am_nnet.GetNnet())); + } + + // this object contains precomputed stuff that is used by all decodable + // objects. It takes a pointer to am_nnet because if it has iVectors it has + // to modify the nnet to accept iVectors at intervals. + nnet3::DecodableNnetSimpleLoopedInfo decodable_info(decodable_opts, + &am_nnet); + + + fst::Fst *decode_fst = ReadFstKaldiGeneric(fst_rxfilename); + + fst::SymbolTable *word_syms = NULL; + if (word_syms_rxfilename != "") + if (!(word_syms = fst::SymbolTable::ReadText(word_syms_rxfilename))) + KALDI_ERR << "Could not read symbol table from file " + << word_syms_rxfilename; + + int32 num_done = 0, num_err = 0; + double tot_like = 0.0; + int64 num_frames = 0; + + SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); + RandomAccessTableReader wav_reader(wav_rspecifier); + CompactLatticeWriter clat_writer(clat_wspecifier); + + OnlineTimingStats timing_stats; + + for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { + std::string spk = spk2utt_reader.Key(); + const std::vector &uttlist = spk2utt_reader.Value(); + OnlineIvectorExtractorAdaptationState adaptation_state( + feature_info.ivector_extractor_info); + for (size_t i = 0; i < uttlist.size(); i++) { + std::string utt = uttlist[i]; + if (!wav_reader.HasKey(utt)) { + KALDI_WARN << "Did not find audio for utterance " << utt; + num_err++; + continue; + } + const WaveData &wave_data = wav_reader.Value(utt); + // get the data for channel zero (if the signal is not mono, we only + // take the first channel). + SubVector data(wave_data.Data(), 0); + + OnlineNnet2FeaturePipeline feature_pipeline(feature_info); + feature_pipeline.SetAdaptationState(adaptation_state); + + OnlineSilenceWeighting silence_weighting( + trans_model, + feature_info.silence_weighting_config, + decodable_opts.frame_subsampling_factor); + + SingleUtteranceNnet3IncrementalDecoder decoder(decoder_opts, trans_model, + decodable_info, + *decode_fst, &feature_pipeline); + OnlineTimer decoding_timer(utt); + + BaseFloat samp_freq = wave_data.SampFreq(); + int32 chunk_length; + if (chunk_length_secs > 0) { + chunk_length = int32(samp_freq * chunk_length_secs); + if (chunk_length == 0) chunk_length = 1; + } else { + chunk_length = std::numeric_limits::max(); + } + + int32 samp_offset = 0; + std::vector > delta_weights; + + while (samp_offset < data.Dim()) { + int32 samp_remaining = data.Dim() - samp_offset; + int32 num_samp = chunk_length < samp_remaining ? chunk_length + : samp_remaining; + + SubVector wave_part(data, samp_offset, num_samp); + feature_pipeline.AcceptWaveform(samp_freq, wave_part); + + samp_offset += num_samp; + decoding_timer.WaitUntil(samp_offset / samp_freq); + if (samp_offset == data.Dim()) { + // no more input. flush out last frames + feature_pipeline.InputFinished(); + } + + if (silence_weighting.Active() && + feature_pipeline.IvectorFeature() != NULL) { + silence_weighting.ComputeCurrentTraceback(decoder.Decoder()); + silence_weighting.GetDeltaWeights(feature_pipeline.NumFramesReady(), + &delta_weights); + feature_pipeline.IvectorFeature()->UpdateFrameWeights(delta_weights); + } + + decoder.AdvanceDecoding(); + + if (do_endpointing && decoder.EndpointDetected(endpoint_opts)) { + break; + } + } + decoder.FinalizeDecoding(); + + bool use_final_probs = true; + CompactLattice clat = decoder.GetLattice(decoder.NumFramesDecoded(), + use_final_probs); + + Connect(&clat); + GetDiagnosticsAndPrintOutput(utt, word_syms, clat, + &num_frames, &tot_like); + + decoding_timer.OutputStats(&timing_stats); + + // In an application you might avoid updating the adaptation state if + // you felt the utterance had low confidence. See lat/confidence.h + feature_pipeline.GetAdaptationState(&adaptation_state); + + // we want to output the lattice with un-scaled acoustics. + BaseFloat inv_acoustic_scale = + 1.0 / decodable_opts.acoustic_scale; + ScaleLattice(AcousticLatticeScale(inv_acoustic_scale), &clat); + + clat_writer.Write(utt, clat); + KALDI_LOG << "Decoded utterance " << utt; + num_done++; + } + } + timing_stats.Print(online); + + KALDI_LOG << "Decoded " << num_done << " utterances, " + << num_err << " with errors."; + KALDI_LOG << "Overall likelihood per frame was " << (tot_like / num_frames) + << " per frame over " << num_frames << " frames."; + delete decode_fst; + delete word_syms; // will delete if non-NULL. + return (num_done != 0 ? 0 : 1); + } catch(const std::exception& e) { + std::cerr << e.what(); + return -1; + } +} // main() diff --git a/src/onlinebin/online-net-client.cc b/src/onlinebin/online-net-client.cc index dfcfa9361..64d157886 100644 --- a/src/onlinebin/online-net-client.cc +++ b/src/onlinebin/online-net-client.cc @@ -30,6 +30,7 @@ int main(int argc, char *argv[]) { try { +#ifndef KALDI_NO_PORTAUDIO using namespace kaldi; typedef kaldi::int32 int32; @@ -122,6 +123,9 @@ int main(int argc, char *argv[]) { } freeaddrinfo(server_addr); return 0; +#else + throw std::runtime_error("kaldi is compiled with KALDI_NO_PORTAUDIO"); +#endif } catch(const std::exception& e) { std::cerr << e.what(); return -1; diff --git a/tools/extras/install_openblas.sh b/tools/extras/install_openblas.sh index bcf75b6d8..e8a67ceb3 100755 --- a/tools/extras/install_openblas.sh +++ b/tools/extras/install_openblas.sh @@ -10,7 +10,7 @@ if ! command -v gfortran 2>/dev/null; then echo "$0: gfortran is not installed. Please install it, e.g. by:" echo " apt-get install gfortran" echo "(if on Debian or Ubuntu), or:" - echo " yum install fortran" + echo " yum install gcc-gfortran" echo "(if on RedHat/CentOS). On a Mac, if brew is installed, it's:" echo " brew install gfortran" exit 1