diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..59ad507 --- /dev/null +++ b/.gitignore @@ -0,0 +1,252 @@ +# Created by https://www.toptal.com/developers/gitignore/api/python,windows,macos,linux +# Edit at https://www.toptal.com/developers/gitignore?templates=python,windows,macos,linux + +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### macOS ### +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +### macOS Patch ### +# iCloud generated files +*.icloud + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +### Windows ### +# Windows thumbnail cache files +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db + +# Dump file +*.stackdump + +# Folder config file +[Dd]esktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp + +# Windows shortcuts +*.lnk + +# End of https://www.toptal.com/developers/gitignore/api/python,windows,macos,linux + +data/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..05e59f2 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,28 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.3.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace +- repo: local + hooks: + - id: ruff + name: Lint python files + entry: ruff + args: [ + "check", + "--fix", + ".", + ] + language: system + types: [python] + - id: ruff-format + name: Format python files + entry: ruff + args: [ + "format", + ".", + ] + language: system + types: [python] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..24078f6 --- /dev/null +++ b/LICENSE @@ -0,0 +1,7 @@ +Copyright 2024 wattai + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..3131deb --- /dev/null +++ b/README.md @@ -0,0 +1,34 @@ +# WIP: Independent Vector Analysis; IVA + +[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) + +This repository aims to provide tools and resources for implementing IVA algorithms and applications. + +![IVA Example Image](https://wallpapers.miyanova.com/data/367_600.jpg) + +## Setup + +1. Clone the repository + + ```shell + git clone git@github.com:wattai/Independent-Vector-Analysis.git + cd Independent-Vector-Analysis + ``` + +1. Install dependencies + + ```shell + pip install -e . + ``` + +## Run the IVA algorithm + +1. Run a demo script + + ```shell + python demo_script.py + ``` + +## What is Independent Vector Analysis? + +Independent Vector Analysis (IVA) is a computational technique used in signal processing to separate mixed signals into their original, independent components. It is an extension of Independent Component Analysis (ICA) that is particularly useful when dealing with multiple datasets or multidimensional data. IVA assumes that the source signals are statistically independent and aims to maximize this independence to achieve separation. This method is widely applied in fields such as biomedical signal processing, audio source separation, and telecommunications. IVA is advantageous over ICA when the datasets have dependencies across different dimensions, as it can exploit these dependencies to improve the separation performance. The technique often involves optimization algorithms and requires careful consideration of the model parameters to ensure accurate results. IVA's effectiveness can be influenced by the choice of cost functions and constraints, which are crucial for capturing the statistical properties of the source signals. Additionally, the performance of IVA can be enhanced by incorporating prior knowledge about the signal structure or by using advanced algorithms that adaptively adjust to the data characteristics. diff --git a/mystft.py b/mystft.py deleted file mode 100644 index 7510ba9..0000000 --- a/mystft.py +++ /dev/null @@ -1,101 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Created on Sat Mar 12 13:22:07 2016 - -@author: wattai -""" - -# 参考ページからコピペ. - -# ================================== -# -# Short Time Fourier Trasform -# -# ================================== -from scipy import ceil, complex128, float64, hamming, zeros -from scipy.fftpack import fft# , ifft -from scipy import ifft # こっちじゃないとエラー出るときあった気がする -from scipy.io.wavfile import read -from matplotlib import pylab as pl -import numpy as np - -# ====== -# STFT -# ====== -""" -x : 入力信号(モノラル) -win : 窓関数 -step : シフト幅 -""" -def stft(x, win, step): - l = len(x) # 入力信号の長さ - N = len(win) # 窓幅、つまり切り出す幅 - M = int(ceil(float(l - N + step) / step)) # スペクトログラムの時間フレーム数 - - new_x = zeros( int( N + ((M - 1) * step) ), dtype = float64) - new_x[: l] = x # 信号をいい感じの長さにする - - # リスト内包表記版 - #X = np.array([ fft(new_x[int(step * m) : int(step * m + N)] * win) for m in np.arange(M) ]) - - # 純粋に繰り返し(なぜかこっちのが速い) - X = zeros([M, N], dtype = complex128) # スペクトログラムの初期化(複素数型) - for m in range(M): - start = int( step * m ) - X[m, :] = fft(new_x[start : start + N] * win) - - return X - -# ======= -# iSTFT -# ======= -def istft(X, win, step): - M, N = X.shape - assert (len(win) == N), "FFT length and window length are different." - - l = int( (M - 1) * step + N ) - x = zeros(l, dtype = float64) - wsum = zeros(l, dtype = float64) - - for m in range(M): - start = int( step * m ) - ### 滑らかな接続 - x[start : start + N] = x[start : start + N] + ifft(X[m, :]).real * win - wsum[start : start + N] += win ** 2 - - pos = (wsum != 0) - #x_pre = x.copy() - ### 窓分のスケール合わせ - x[pos] /= wsum[pos] - return x - - -if __name__ == "__main__": - wavfile = "./townofdeath.wav" - fs, data = read(wavfile) - data = data[:, 0] - - fftLen = 512 # とりあえず - win = hamming(fftLen) # ハミング窓 - step = fftLen / 4 - - ### STFT - spectrogram = stft(data, win, step) - - ### iSTFT - resyn_data = istft(spectrogram, win, step) - - ### Plot - fig = pl.figure() - fig.add_subplot(311) - pl.plot(data) - pl.xlim([0, len(data)]) - pl.title("Input signal", fontsize = 10) - fig.add_subplot(312) - pl.imshow(abs(spectrogram[:, : int(fftLen / 2 + 1) ].T), aspect = "auto", origin = "lower") - pl.title("Spectrogram", fontsize = 10) - fig.add_subplot(313) - pl.plot(resyn_data) - pl.xlim([0, len(resyn_data)]) - pl.title("Resynthesized signal", fontsize = 10) - pl.show() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..443913d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,33 @@ +[build-system] +requires = ["setuptools >= 61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "iva" +dynamic = ["version"] +authors = [ + {name = "wattai", email = "evamark.wattai@gmail.com"}, +] +description = "An implementation of IndependentVectorAnalysis that decompose mixed vector signals into their original sources." +readme = "README.md" +license = {file = "LICENSE"} +classifiers = [ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", +] +dependencies = [ + "numpy~=1.26.4", + "scipy~=1.14.1", + "numba~=0.60.0", + "librosa~=0.10.2.post1", +] +[project.optional-dependencies] +dev = [ + "pytest~=8.3.3", + "pre-commit~=4.0.1", + "ruff", +] + +[tool.setuptools] +package-dir = {"" = "src"} +packages = ["iva"] diff --git a/scripts/run.py b/scripts/run.py new file mode 100644 index 0000000..2e26aba --- /dev/null +++ b/scripts/run.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- + +import numpy as np + +from iva.independent_vector_analysis import _IndependentVectorAnalysis + + +def load_dummy_signals( + time_length_sec=15, + num_channels=2, + fs=16000, +) -> np.ndarray: + return np.random.randn(int(time_length_sec * fs), num_channels), fs + + +if __name__ == "__main__": + # data, fs = sf.read("yuki_stereo_VM00_VF00_0750.wav") # 2人の会話 + data, fs = load_dummy_signals() + + iva = _IndependentVectorAnalysis( + num_components=2, + fs=fs, + num_iterations=5, + fft_window_length=1024, + ) + result = iva.fit_transform(data) diff --git a/src/iva/independent_vector_analysis.py b/src/iva/independent_vector_analysis.py new file mode 100644 index 0000000..3c0d23d --- /dev/null +++ b/src/iva/independent_vector_analysis.py @@ -0,0 +1,267 @@ +# -*- coding: utf-8 -*- + +import time + +import numpy as np +import soundfile as sf + +from iva.time_frequency_analysis import BaseShortTimeFFT, HammingShortTimeFFT + + +class IndependentVectorAnalysis: + def __init__(self, data, num_sources): + pass + + def fit(self): + """ + Fit the model to the data using an iterative algorithm. + + This method estimates the independent components from the mixed data + using a specified number of sources. + + Returns: + self: The fitted model. + """ + # Implementation of the fitting algorithm goes here + return self + + def transform(self, data): + pass + + +class _IndependentVectorAnalysis: + def __init__( + self, + num_iterations: int = 10, + num_components: int = 4, + sft: BaseShortTimeFFT = HammingShortTimeFFT( + fs=16000, window_length=2048, hop=1024 + ), + ): + """Independent Vector Analysis. + + Args: + num_iterations: The number of the iteration of IVA computaition. + num_components: The number of source signals. + sft: An instance of a short-time Fourier transform class. + + """ + self.N = num_iterations + self.fftLen = sft.window_length + self.n_components = num_components + self.fs = sft.fs + self.sft = sft + + self.W = None # Separation Matrix + self.r = None # + self.spectrogram = None + self.rebuild_spectrogram = None + + def _auxiva(self): + # 独立ベクトル分析開始 + N_t = self.spectrogram.shape[0] + N_omega = self.spectrogram.shape[1] + K = self.spectrogram.shape[2] + E = np.eye(K, dtype="complex") + self.W = np.zeros([K, K, N_omega], dtype="complex") # 分離行列初期化 + self.W[:, :, :] = E[:, :, None] + self.r = np.zeros([K, N_t], dtype="complex") # 時系列信号パワー初期化 + self.rebuild_spectrogram = np.zeros( + [ + self.spectrogram.shape[0], + self.spectrogram.shape[1], + self.spectrogram.shape[2], + ], + dtype="complex", + ) + z = 0 + 0j + + print(self.spectrogram.shape) + + # 反復回数N + for i in range(self.N): + for k in range(K): + # 補助変数の更新1 + self.r[k, :] = np.squeeze( + np.sum( + ( + np.abs( + self.spectrogram[:, :, :].T.transpose([1, 2, 0]) + @ self.W[k, :, :][:, :, None] + .conj() + .transpose([1, 0, 2]) + ) + ) + ** 2, + axis=0, + ) + ) + + self.r[k, :] = np.sqrt(self.r[k, :]) + dr = np.gradient(self.r[k, :]) + G_R = self.r[k, :].copy() # np.log(r[k, :]) # コントラスト関数指定 + fi = np.gradient(G_R, dr) / self.r[k, :] + fi0 = 1000.0 + fi[fi0 < fi] = fi0 + + # 補助変数の更新2 + V = ( + (1 / N_t) + * ( + ((fi * self.spectrogram.T).transpose([1, 0, 2])) + @ (self.spectrogram.conj().transpose([1, 0, 2])) + ) + ).transpose([1, 2, 0]) + # 分離行列の更新1(solve) + self.W[k, :, :] = ( + np.linalg.inv( + self.W.conj().transpose([2, 0, 1]) @ V.transpose([2, 0, 1]) + ) + @ E[k, :] + ).T + # 分離行列の更新2 + self.W[k, :, :] /= np.sqrt( + ( + self.W[k, :, :][:, :, None].conj().transpose([1, 2, 0]) + @ V.transpose([2, 0, 1]) + ) + @ self.W[k, :, :][:, :, None].transpose([1, 0, 2]) + ).squeeze() + + # 分離行列の正規化 + self.rebuild_spectrogram = ( + self.W.conj().transpose([2, 0, 1]) + @ self.spectrogram.transpose([1, 2, 0]) + ).transpose([2, 0, 1]) + z = np.sum(np.linalg.norm(self.rebuild_spectrogram, axis=1) ** 2) + self.W[:] /= np.sqrt(z / (N_omega * N_t * K)) + + print(str(i + 1) + "/" + str(self.N)) + + # 信号源復元(分離処理) + self.rebuild_spectrogram = ( + self.W.conj().transpose([2, 0, 1]) @ self.spectrogram.transpose([1, 2, 0]) + ).transpose([2, 0, 1]) + + return self.rebuild_spectrogram + + def fit_transform(self, data): + L, sigch = data.shape + if self.n_components is None: + self.n_components = data.shape[1] + elif self.n_components > data.shape[1]: + self.n_components = data.shape[1] + + start = time.time() + # Whitening --------------------------------------------------- + # whited_data = whitening(data, self.n_components) + whited_data = zca_whitening(data, self.n_components) + ### -------------------------------------------------------------- + elapsed_time1 = time.time() - start + print(whited_data.shape) + + sum_time = elapsed_time1 + elapsed_time2 = time.time() - start - sum_time + + # 時間領域 to 時間-周波数領域 -------------------------------------- + print(whited_data.shape) + self.spectrogram = self.sft.stft(whited_data.transpose(1, 0)).transpose(2, 1, 0) + print(self.spectrogram.shape) + # ----------------------------------------------------------------- + sum_time += elapsed_time2 + elapsed_time3 = time.time() - start - sum_time + + ### AuxIVA -------------------------------------------------------- + self.rebuild_spectrogram = self._auxiva() + ### -------------------------------------------------------------- + sum_time += elapsed_time3 + elapsed_time4 = time.time() - start - sum_time + + # 時間-周波数領域 to 時間領域 -------------------------------------- + result = self.sft.istft(self.rebuild_spectrogram.transpose(2, 1, 0)).transpose( + 1, 0 + ) + result = result[ + len(result) - len(whited_data) :, : + ] # STFTで生じた余分な信号長のカット + # result = multi_icwt(rebuild_spectrogram, omega0, sigma, fs) # iCWT(complex morlet) + # ----------------------------------------------------------------- + sum_time += elapsed_time4 + elapsed_time5 = time.time() - start - sum_time + + print("PCA : {0:6.2f}".format(elapsed_time1) + "[sec]") + print("FICA: {0:6.2f}".format(elapsed_time2) + "[sec]") + print("STFT: {0:6.2f}".format(elapsed_time3) + "[sec]") + print("IVA : {0:6.2f}".format(elapsed_time4) + "[sec]") + print("iSTFT: {0:5.2f}".format(elapsed_time5) + "[sec]") + + print(np.sqrt(np.average(np.abs(data[:, :]) ** 2))) + print(np.sqrt(np.average(np.abs(result[:, :]) ** 2))) + + print(np.linalg.norm(data[:, 0])) + print(np.linalg.norm(result[:, 0])) + + # 振幅補正(RMSの比を基準に) + # result[:, :] *= np.sqrt(np.average(np.abs(data[:, :])**2)) / np.sqrt(np.average(np.abs(result[:, :])**2)) + # 振幅補正(L2-norm の比を基準に) + # result *= np.average(np.linalg.norm(data)) / np.linalg.norm(result) + + return result + + +def whitening(x, n_components): + import numpy.linalg as LA + + x = x.copy() + nData, nDim = np.shape(x) + # 中心化centering + x = x - np.mean(x, axis=0) + # 相関行列 + C = np.dot(x.T, x) / nData + # 共分散行列の固有値分解でE,Dを求める + E, D, E_T = LA.svd(C) # 元の + # D, E = LA.eig(C) + print(D) + D = np.diag(D[:n_components] ** (-0.5)) + # 白色化行列V + # V = np.dot(E, np.dot(D, E_T)) # 元の + E = E[:, :n_components].copy() + print(E.shape) + print(E) + V = D @ E.T.conj() # PCA + # 線形変換z + z = x @ V.T + return z + + +def zca_whitening(x, n_components): + eps = 1e-6 + import numpy.linalg as LA + + x = x.copy() + nData, nDim = np.shape(x) + # 中心化centering + x = x - np.mean(x, axis=0) + # 相関行列 + C = np.dot(x.T, x) / nData + # 共分散行列の固有値分解でE,Dを求める + E, D, E_T = LA.svd(C) # 元の + # D, E = LA.eig(C) + print(D) + D = np.diag(1.0 / (np.sqrt(D[:n_components]) + eps)) + # D = np.diag(D[:n_components] ** (-0.5)) + # 白色化行列V + # V = np.dot(E, np.dot(D, E_T)) # 元の + E = E[:, :n_components].copy() + print(E.shape) + print(E) + V = E @ D @ E.T.conj() # ZCA + # 線形変換z + z = x @ V.T + return z + + +if __name__ == "__main__": + data, samplerate = sf.read("yuki_stereo_VM00_VF00_0750.wav") # 2人の会話 + iva = IndependentVectorAnalysis(N=5, fftLen=1024, n_components=2, fs=samplerate) + result = iva.fit_transform(data) diff --git a/src/iva/legacies/mystft.py b/src/iva/legacies/mystft.py new file mode 100644 index 0000000..3c83cf3 --- /dev/null +++ b/src/iva/legacies/mystft.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- +""" +Created on Sat Mar 12 13:22:07 2016 + +@author: wattai +""" + +# 参考ページからコピペ. + +# ================================== +# +# Short Time Fourier Trasform +# +# ================================== +import numpy as np +# from scipy.fftpack import fft # , ifft +# from scipy import ifft # こっちじゃないとエラー出るときあった気がする +# from scipy.io.wavfile import read +# from matplotlib import pylab as pl + +# ====== +# STFT +# ====== +""" +x : 入力信号(モノラル) +win : 窓関数 +step : シフト幅 +""" + + +def stft(x, win, step): + signal_length = len(x) # 入力信号の長さ + N = len(win) # 窓幅、つまり切り出す幅 + M = int( + np.ceil(float(signal_length - N + step) / step) + ) # スペクトログラムの時間フレーム数 + + new_x = np.zeros(int(N + ((M - 1) * step)), dtype=np.float64) + new_x[:signal_length] = x # 信号をいい感じの長さにする + + # リスト内包表記版 + # X = np.array([ fft(new_x[int(step * m) : int(step * m + N)] * win) for m in np.arange(M) ]) + + # 純粋に繰り返し(なぜかこっちのが速い) + X = np.zeros([M, N], dtype=np.complex128) # スペクトログラムの初期化(複素数型) + for m in range(M): + start = int(step * m) + X[m, :] = np.fft.fft(new_x[start : start + N] * win) + + return X + + +# ======= +# iSTFT +# ======= +def istft(X, win, step): + M, N = X.shape + assert len(win) == N, "FFT length and window length are different." + + signal_length = int((M - 1) * step + N) + x = np.zeros(signal_length, dtype=np.float64) + wsum = np.zeros(signal_length, dtype=np.float64) + + for m in range(M): + start = int(step * m) + ### 滑らかな接続 + x[start : start + N] = x[start : start + N] + np.fft.ifft(X[m, :]).real * win + wsum[start : start + N] += win**2 + + pos = wsum != 0 + # x_pre = x.copy() + ### 窓分のスケール合わせ + x[pos] /= wsum[pos] + return x + + +# if __name__ == "__main__": +# wavfile = "./townofdeath.wav" +# fs, data = read(wavfile) +# data = data[:, 0] +# +# fftLen = 512 # とりあえず +# win = np.hamming(fftLen) # ハミング窓 +# step = fftLen / 4 +# +# ### STFT +# spectrogram = stft(data, win, step) +# +# ### iSTFT +# resyn_data = istft(spectrogram, win, step) +# +# ### Plot +# fig = pl.figure() +# fig.add_subplot(311) +# pl.plot(data) +# pl.xlim([0, len(data)]) +# pl.title("Input signal", fontsize=10) +# fig.add_subplot(312) +# pl.imshow( +# abs(spectrogram[:, : int(fftLen / 2 + 1)].T), aspect="auto", origin="lower" +# ) +# pl.title("Spectrogram", fontsize=10) +# fig.add_subplot(313) +# pl.plot(resyn_data) +# pl.xlim([0, len(resyn_data)]) +# pl.title("Resynthesized signal", fontsize=10) +# pl.show() +# diff --git a/stft_iva_istft.py b/src/iva/legacies/stft_iva_istft.py similarity index 61% rename from stft_iva_istft.py rename to src/iva/legacies/stft_iva_istft.py index c9addf8..be16650 100644 --- a/stft_iva_istft.py +++ b/src/iva/legacies/stft_iva_istft.py @@ -6,36 +6,32 @@ """ import numpy as np -from scipy import ceil, complex64, float64, hamming, zeros, hanning -from scipy.fftpack import fft, ifft import scipy as sp import soundfile as sf -from sklearn.decomposition import PCA, KernelPCA, FactorAnalysis, TruncatedSVD import time -import concurrent.futures -from scipy.signal import fftconvolve, resample -import pandas as pd +from scipy.signal import resample import matplotlib.pyplot as plt -import numba -#from numba.decorators import jit, autojit -#from numba import guvectorize +# from numba.decorators import jit, autojit +# from numba import guvectorize + def whitening(x, n_components): import numpy.linalg as LA + x = x.copy() nData, nDim = np.shape(x) - #中心化centering + # 中心化centering x = x - np.mean(x, axis=0) # 相関行列 C = np.dot(x.T, x) / nData # 共分散行列の固有値分解でE,Dを求める - E, D, E_T = LA.svd(C) # 元の - #D, E = LA.eig(C) + E, D, E_T = LA.svd(C) # 元の + # D, E = LA.eig(C) print(D) D = np.diag(D[:n_components] ** (-0.5)) - #白色化行列V - #V = np.dot(E, np.dot(D, E_T)) # 元の + # 白色化行列V + # V = np.dot(E, np.dot(D, E_T)) # 元の E = E[:, :n_components].copy() print(E.shape) print(E) @@ -44,23 +40,25 @@ def whitening(x, n_components): z = x @ V.T return z + def zca_whitening(x, n_components): eps = 1e-6 import numpy.linalg as LA + x = x.copy() nData, nDim = np.shape(x) - #中心化centering + # 中心化centering x = x - np.mean(x, axis=0) # 相関行列 C = np.dot(x.T, x) / nData # 共分散行列の固有値分解でE,Dを求める - E, D, E_T = LA.svd(C) # 元の - #D, E = LA.eig(C) + E, D, E_T = LA.svd(C) # 元の + # D, E = LA.eig(C) print(D) D = np.diag(1.0 / (np.sqrt(D[:n_components]) + eps)) - #D = np.diag(D[:n_components] ** (-0.5)) - #白色化行列V - #V = np.dot(E, np.dot(D, E_T)) # 元の + # D = np.diag(D[:n_components] ** (-0.5)) + # 白色化行列V + # V = np.dot(E, np.dot(D, E_T)) # 元の E = E[:, :n_components].copy() print(E.shape) print(E) @@ -68,211 +66,259 @@ def zca_whitening(x, n_components): # 線形変換z z = x @ V.T return z - + class IndependentVectorAnalysis: - def __init__(self, N=10, fftLen=128, n_components=4, fs=16000): - self.N = N self.fftLen = fftLen self.n_components = n_components self.fs = fs - self.W = None # Separation Matrix - self.r = None # + self.W = None # Separation Matrix + self.r = None # self.spectrogram = None self.rebuild_spectrogram = None - - - def _auxiva(self): + def _auxiva(self): # 独立ベクトル分析開始 N_t = self.spectrogram.shape[0] N_omega = self.spectrogram.shape[1] K = self.spectrogram.shape[2] - E = np.eye( K, dtype='complex' ) - self.W = np.zeros([K, K, N_omega], dtype='complex') # 分離行列初期化 + E = np.eye(K, dtype="complex") + self.W = np.zeros([K, K, N_omega], dtype="complex") # 分離行列初期化 self.W[:, :, :] = E[:, :, None] - self.r = np.zeros([ K, N_t ], dtype='complex') # 時系列信号パワー初期化 - #V2 = np.zeros([ K, K, K, N_omega ], dtype='complex') - self.rebuild_spectrogram = np.zeros([ self.spectrogram.shape[0], self.spectrogram.shape[1], self.spectrogram.shape[2] ] , dtype='complex' ) - z = 0+0j - + self.r = np.zeros([K, N_t], dtype="complex") # 時系列信号パワー初期化 + # V2 = np.zeros([ K, K, K, N_omega ], dtype='complex') + self.rebuild_spectrogram = np.zeros( + [ + self.spectrogram.shape[0], + self.spectrogram.shape[1], + self.spectrogram.shape[2], + ], + dtype="complex", + ) + z = 0 + 0j + print(self.spectrogram.shape) - + # 反復回数N - for i in range( self.N ): - for k in range( K ): + for i in range(self.N): + for k in range(K): # 補助変数の更新1 - self.r[k, :] = np.squeeze(np.sum((np.abs( self.spectrogram[:, :, :].T.transpose([1,2,0]) @ self.W[k, :, :][:, :, None].conj().transpose([1,0,2]) ))**2, axis=0)) - - self.r[k, :] = np.sqrt( self.r[k, :] ) + self.r[k, :] = np.squeeze( + np.sum( + ( + np.abs( + self.spectrogram[:, :, :].T.transpose([1, 2, 0]) + @ self.W[k, :, :][:, :, None] + .conj() + .transpose([1, 0, 2]) + ) + ) + ** 2, + axis=0, + ) + ) + + self.r[k, :] = np.sqrt(self.r[k, :]) dr = np.gradient(self.r[k, :]) - G_R = self.r[k, :].copy() #np.log(r[k, :]) # コントラスト関数指定 - fi = ( np.gradient(G_R, dr)/self.r[k, :] ) - fi0 = 1000. + G_R = self.r[k, :].copy() # np.log(r[k, :]) # コントラスト関数指定 + fi = np.gradient(G_R, dr) / self.r[k, :] + fi0 = 1000.0 fi[fi0 < fi] = fi0 - + # 補助変数の更新2 - V = ((1/N_t) * ((((fi*self.spectrogram.T).transpose([1,0,2]) ) @ (self.spectrogram.conj().transpose([1,0,2]) )))).transpose([1,2,0]) - # 分離行列の更新1(solve) - self.W[k, :, :] = (np.linalg.inv(self.W.conj().transpose([2,0,1]) @ V.transpose([2,0,1])) @ E[k, :]).T + V = ( + (1 / N_t) + * ( + ((fi * self.spectrogram.T).transpose([1, 0, 2])) + @ (self.spectrogram.conj().transpose([1, 0, 2])) + ) + ).transpose([1, 2, 0]) + # 分離行列の更新1(solve) + self.W[k, :, :] = ( + np.linalg.inv( + self.W.conj().transpose([2, 0, 1]) @ V.transpose([2, 0, 1]) + ) + @ E[k, :] + ).T # 分離行列の更新2 - self.W[k, :, :] /= np.sqrt( (self.W[k, :, :][:, :, None].conj().transpose([1,2,0]) @ V.transpose([2,0,1])) @ self.W[k, :, :][:, :, None].transpose([1,0,2]) ).squeeze() - - # 分離行列の正規化 - self.rebuild_spectrogram = ( self.W.conj().transpose([2,0,1]) @ self.spectrogram.transpose([1,2,0]) ).transpose([2,0,1]) - z = np.sum(np.linalg.norm(self.rebuild_spectrogram, axis=1)**2) - self.W[:] /= (np.sqrt( z /(N_omega *N_t *K) )) - - print( str(i+1) +"/" +str(self.N) ) - + self.W[k, :, :] /= np.sqrt( + ( + self.W[k, :, :][:, :, None].conj().transpose([1, 2, 0]) + @ V.transpose([2, 0, 1]) + ) + @ self.W[k, :, :][:, :, None].transpose([1, 0, 2]) + ).squeeze() + + # 分離行列の正規化 + self.rebuild_spectrogram = ( + self.W.conj().transpose([2, 0, 1]) + @ self.spectrogram.transpose([1, 2, 0]) + ).transpose([2, 0, 1]) + z = np.sum(np.linalg.norm(self.rebuild_spectrogram, axis=1) ** 2) + self.W[:] /= np.sqrt(z / (N_omega * N_t * K)) + + print(str(i + 1) + "/" + str(self.N)) + # 信号源復元(分離処理) - self.rebuild_spectrogram = ( self.W.conj().transpose([2,0,1]) @ self.spectrogram.transpose([1,2,0]) ).transpose([2,0,1]) - + self.rebuild_spectrogram = ( + self.W.conj().transpose([2, 0, 1]) @ self.spectrogram.transpose([1, 2, 0]) + ).transpose([2, 0, 1]) + return self.rebuild_spectrogram - - def fit_transform(self, data): + def fit_transform(self, data): L, sigch = data.shape - win = sp.hamming(self.fftLen) # ハミング窓 - step = (self.fftLen/2) /2 # フレーム窓シフト幅(論文[一般で言われているシフト幅]のもう/2で合致?) + win = sp.hamming(self.fftLen) # ハミング窓 + step = ( + self.fftLen / 2 + ) / 2 # フレーム窓シフト幅(論文[一般で言われているシフト幅]のもう/2で合致?) if self.n_components is None: self.n_components = data.shape[1] elif self.n_components > data.shape[1]: self.n_components = data.shape[1] - #from sklearn.decomposition import PCA, KernelPCA, FactorAnalysis, TruncatedSVD, SparsePCA + # from sklearn.decomposition import PCA, KernelPCA, FactorAnalysis, TruncatedSVD, SparsePCA start = time.time() # Whitening --------------------------------------------------- # pca = PCA(n_components=self.n_components, copy=True, whiten=True) # whited_data = pca.fit_transform( data[:, :] ) - #tSVD = TruncatedSVD(n_components=n_components, algorithm='arpack') - #whited_data = tSVD.fit_transform( data ) + # tSVD = TruncatedSVD(n_components=n_components, algorithm='arpack') + # whited_data = tSVD.fit_transform( data ) # whited_data = whitening(data, self.n_components) # 必ずこっちを使うこと - whited_data = zca_whitening(data, self.n_components) # 必ずこっちを使うこと + whited_data = zca_whitening(data, self.n_components) # 必ずこっちを使うこと ### -------------------------------------------------------------- elapsed_time1 = time.time() - start print(whited_data.shape) - + sum_time = elapsed_time1 elapsed_time2 = time.time() - start - sum_time - + # CWT用パラメータ --------------------------------------------------- - omega0 = 6 - sigma = 6 - N_scale = 256 + # omega0 = 6 + # sigma = 6 + # N_scale = 256 # ---------------------------------------------------------------- # 時間領域 to 時間-周波数領域 -------------------------------------- - self.spectrogram = multi_stft(whited_data, win, step) # STFT + self.spectrogram = multi_stft(whited_data, win, step) # STFT # self.spectrogram = multi_cwt(whited_data, omega0, sigma, self.fs, N_scale) # CWT(complex morlet) # ----------------------------------------------------------------- - sum_time += elapsed_time2 + sum_time += elapsed_time2 elapsed_time3 = time.time() - start - sum_time - + ### AuxIVA -------------------------------------------------------- self.rebuild_spectrogram = IndependentVectorAnalysis._auxiva(self) ### -------------------------------------------------------------- sum_time += elapsed_time3 elapsed_time4 = time.time() - start - sum_time - + # 時間-周波数領域 to 時間領域 -------------------------------------- - result = multi_istft(self.rebuild_spectrogram, win, step) # iSTFT - result = result[len(result)-len(whited_data):, :] # STFTで生じた余分な信号長のカット - #result = multi_icwt(rebuild_spectrogram, omega0, sigma, fs) # iCWT(complex morlet) + result = multi_istft(self.rebuild_spectrogram, win, step) # iSTFT + result = result[ + len(result) - len(whited_data) :, : + ] # STFTで生じた余分な信号長のカット + # result = multi_icwt(rebuild_spectrogram, omega0, sigma, fs) # iCWT(complex morlet) # ----------------------------------------------------------------- - sum_time += elapsed_time4 + sum_time += elapsed_time4 elapsed_time5 = time.time() - start - sum_time - - print('PCA : {0:6.2f}'.format(elapsed_time1) + "[sec]") - print('FICA: {0:6.2f}'.format(elapsed_time2) + "[sec]") - print('STFT: {0:6.2f}'.format(elapsed_time3) + "[sec]") - print('IVA : {0:6.2f}'.format(elapsed_time4) + "[sec]") - print('iSTFT: {0:5.2f}'.format(elapsed_time5) + "[sec]") - - print(np.sqrt(np.average(np.abs(data[:, :])**2))) - print(np.sqrt(np.average(np.abs(result[:, :])**2))) - + + print("PCA : {0:6.2f}".format(elapsed_time1) + "[sec]") + print("FICA: {0:6.2f}".format(elapsed_time2) + "[sec]") + print("STFT: {0:6.2f}".format(elapsed_time3) + "[sec]") + print("IVA : {0:6.2f}".format(elapsed_time4) + "[sec]") + print("iSTFT: {0:5.2f}".format(elapsed_time5) + "[sec]") + + print(np.sqrt(np.average(np.abs(data[:, :]) ** 2))) + print(np.sqrt(np.average(np.abs(result[:, :]) ** 2))) + print(np.linalg.norm(data[:, 0])) print(np.linalg.norm(result[:, 0])) - + # 振幅補正(RMSの比を基準に) - #result[:, :] *= np.sqrt(np.average(np.abs(data[:, :])**2)) / np.sqrt(np.average(np.abs(result[:, :])**2)) + # result[:, :] *= np.sqrt(np.average(np.abs(data[:, :])**2)) / np.sqrt(np.average(np.abs(result[:, :])**2)) # 振幅補正(L2-norm の比を基準に) - #result *= np.average(np.linalg.norm(data)) / np.linalg.norm(result) - + # result *= np.average(np.linalg.norm(data)) / np.linalg.norm(result) + return result - + def transform(self, data, W): + self.W = W.copy() # 分離行列の取得 - self.W = W.copy() # 分離行列の取得 - L, sigch = data.shape - win = sp.hamming(self.fftLen) # ハミング窓 - step = (self.fftLen/2) #/2 # フレーム窓シフト幅(論文[一般で言われているシフト幅]のもう/2しないとダメ?) - if self.n_components is None: self.n_components = data.shape[1] - elif self.n_components > data.shape[1]: self.n_components = data.shape[1] - - - #from sklearn.decomposition import PCA, KernelPCA, FactorAnalysis, TruncatedSVD, SparsePCA + win = sp.hamming(self.fftLen) # ハミング窓 + step = ( + self.fftLen / 2 + ) # /2 # フレーム窓シフト幅(論文[一般で言われているシフト幅]のもう/2しないとダメ?) + if self.n_components is None: + self.n_components = data.shape[1] + elif self.n_components > data.shape[1]: + self.n_components = data.shape[1] + + # from sklearn.decomposition import PCA, KernelPCA, FactorAnalysis, TruncatedSVD, SparsePCA start = time.time() # Whitening --------------------------------------------------- - #pca = PCA(n_components=n_components, copy=True, whiten=True) - #whited_data = pca.fit_transform( data[:, :] ) - #tSVD = TruncatedSVD(n_components=n_components, algorithm='arpack') - #whited_data = tSVD.fit_transform( data ) - whited_data = whitening(data, self.n_components) # 必ずこっちを使うこと + # pca = PCA(n_components=n_components, copy=True, whiten=True) + # whited_data = pca.fit_transform( data[:, :] ) + # tSVD = TruncatedSVD(n_components=n_components, algorithm='arpack') + # whited_data = tSVD.fit_transform( data ) + whited_data = whitening(data, self.n_components) # 必ずこっちを使うこと ### -------------------------------------------------------------- elapsed_time1 = time.time() - start print(whited_data.shape) - + sum_time = elapsed_time1 elapsed_time2 = time.time() - start - sum_time - + # CWT用パラメータ --------------------------------------------------- - omega0 = 6 - sigma = 6 - N_scale = 256 + # omega0 = 6 + # sigma = 6 + # N_scale = 256 # ---------------------------------------------------------------- # 時間領域 to 時間-周波数領域 -------------------------------------- - self.spectrogram = multi_stft(whited_data, win, step) # STFT - #spectrogram = multi_cwt(whited_data, omega0, sigma, fs, N_scale) # CWT(complex morlet) + self.spectrogram = multi_stft(whited_data, win, step) # STFT + # spectrogram = multi_cwt(whited_data, omega0, sigma, fs, N_scale) # CWT(complex morlet) # ----------------------------------------------------------------- - sum_time += elapsed_time2 + sum_time += elapsed_time2 elapsed_time3 = time.time() - start - sum_time - + ### 信号源復元(分離処理) - self.rebuild_spectrogram = ( self.W.conj().transpose([2,0,1]) @ self.spectrogram.transpose([1,2,0]) ).transpose([2,0,1]) + self.rebuild_spectrogram = ( + self.W.conj().transpose([2, 0, 1]) @ self.spectrogram.transpose([1, 2, 0]) + ).transpose([2, 0, 1]) ### -------------------------------------------------------------- sum_time += elapsed_time3 elapsed_time4 = time.time() - start - sum_time - + # 時間-周波数領域 to 時間領域 -------------------------------------- - result = multi_istft(self.rebuild_spectrogram, win, step) # iSTFT - result = result[len(result)-len(whited_data):, :] # STFTで生じた余分な信号長のカット - #result = multi_icwt(rebuild_spectrogram, omega0, sigma, fs) # iCWT(complex morlet) + result = multi_istft(self.rebuild_spectrogram, win, step) # iSTFT + result = result[ + len(result) - len(whited_data) :, : + ] # STFTで生じた余分な信号長のカット + # result = multi_icwt(rebuild_spectrogram, omega0, sigma, fs) # iCWT(complex morlet) # ----------------------------------------------------------------- - sum_time += elapsed_time4 + sum_time += elapsed_time4 elapsed_time5 = time.time() - start - sum_time - - print('PCA : {0:6.2f}'.format(elapsed_time1) + "[sec]") - print('FICA: {0:6.2f}'.format(elapsed_time2) + "[sec]") - print('STFT: {0:6.2f}'.format(elapsed_time3) + "[sec]") - print('IVA : {0:6.2f}'.format(elapsed_time4) + "[sec]") - print('iSTFT: {0:5.2f}'.format(elapsed_time5) + "[sec]") - - print(np.sqrt(np.average(np.abs(data[:, :])**2))) - print(np.sqrt(np.average(np.abs(result[:, :])**2))) - + + print("PCA : {0:6.2f}".format(elapsed_time1) + "[sec]") + print("FICA: {0:6.2f}".format(elapsed_time2) + "[sec]") + print("STFT: {0:6.2f}".format(elapsed_time3) + "[sec]") + print("IVA : {0:6.2f}".format(elapsed_time4) + "[sec]") + print("iSTFT: {0:5.2f}".format(elapsed_time5) + "[sec]") + + print(np.sqrt(np.average(np.abs(data[:, :]) ** 2))) + print(np.sqrt(np.average(np.abs(result[:, :]) ** 2))) + print(np.linalg.norm(data[:, 0])) print(np.linalg.norm(result[:, 0])) - + # 振幅補正(RMSの比を基準に) - #result[:, :] *= np.sqrt(np.average(np.abs(data[:, :])**2)) / np.sqrt(np.average(np.abs(result[:, :])**2)) + # result[:, :] *= np.sqrt(np.average(np.abs(data[:, :])**2)) / np.sqrt(np.average(np.abs(result[:, :])**2)) # 振幅補正(L2-norm の比を基準に) result *= np.average(np.linalg.norm(data)) / np.linalg.norm(result) - - return result + + return result + """ # AuxIVA1(多次元対応) ------------------------------------------------------- @@ -290,9 +336,9 @@ def auxiva1(spectrogram, N): V2 = np.zeros([ K, K, K, N_omega ], dtype='complex') rebuild_spectrogram = np.zeros([ spectrogram.shape[0], spectrogram.shape[1], spectrogram.shape[2] ] , dtype='complex' ) z = 0+0j - + print(spectrogram.shape) - + # 反復回数N for i in range( N ): for k in range( K ): @@ -322,14 +368,14 @@ def auxiva1(spectrogram, N): z += np.linalg.norm(rebuild_spectrogram[t, :, k])**2 #z = np.sum(r[:])**2 W[:] /= (np.sqrt( z /(N_omega *N_t *K) )) - + print( str(i+1) +"/" +str(N) ) - + # 信号源復元(分離処理) for omega in range( N_omega ): rebuild_spectrogram[:, omega, :] = ( W[:, :, omega].conj() @ spectrogram[:, omega, :].T ).T - - + + # Projection Back #n = 0 #for omega in range( N_omega ): @@ -340,59 +386,93 @@ def auxiva1(spectrogram, N): # U = rebuild_spectrogram[t, omega, :] # v = (A - E) @ U # rebuild_spectrogram[t, omega, :] = v / (2*(K**2)) - + return rebuild_spectrogram """ -def auxiva2(spectrogram, N): +def auxiva2(spectrogram, N): # 独立ベクトル分析開始 N_t = spectrogram.shape[0] N_omega = spectrogram.shape[1] K = spectrogram.shape[2] - E = np.identity( K, dtype='complex' ) - W = np.zeros([K, K, N_omega], dtype='complex') + E = np.identity(K, dtype="complex") + W = np.zeros([K, K, N_omega], dtype="complex") W[:, :, :] = E[:, :, None] - r = np.zeros([ K, N_t ], dtype='complex') - #V2 = np.zeros([ K, K, K, N_omega ], dtype='complex') - rebuild_spectrogram = np.zeros([ spectrogram.shape[0], spectrogram.shape[1], spectrogram.shape[2] ] , dtype='complex' ) - z = 0+0j - + r = np.zeros([K, N_t], dtype="complex") + # V2 = np.zeros([ K, K, K, N_omega ], dtype='complex') + rebuild_spectrogram = np.zeros( + [spectrogram.shape[0], spectrogram.shape[1], spectrogram.shape[2]], + dtype="complex", + ) + z = 0 + 0j + print(spectrogram.shape) - + # 反復回数N - for i in range( N ): - for k in range( K ): + for i in range(N): + for k in range(K): # 補助変数の更新1 - r[k, :] = np.squeeze(np.sum((np.abs( spectrogram[:, :, :].T.transpose([1,2,0]) @ W[k, :, :][:, :, None].conj().transpose([1,0,2]) ))**2, axis=0)) - - r[k, :] = np.sqrt( r[k, :] ) + r[k, :] = np.squeeze( + np.sum( + ( + np.abs( + spectrogram[:, :, :].T.transpose([1, 2, 0]) + @ W[k, :, :][:, :, None].conj().transpose([1, 0, 2]) + ) + ) + ** 2, + axis=0, + ) + ) + + r[k, :] = np.sqrt(r[k, :]) dr = np.gradient(r[k, :]) - G_R = r[k, :].copy() #np.log(r[k, :]) # コントラスト関数指定 - fi = ( np.gradient(G_R, dr)/r[k, :] ) - fi0 = 1000. + G_R = r[k, :].copy() # np.log(r[k, :]) # コントラスト関数指定 + fi = np.gradient(G_R, dr) / r[k, :] + fi0 = 1000.0 fi[fi0 < fi] = fi0 - + # 補助変数の更新2 - V = ((1/N_t) * ( (((fi*spectrogram.T).transpose([1,0,2]) ) @ (spectrogram.conj().transpose([1,0,2]) )))).transpose([1,2,0]) - # 分離行列の更新1(solve) - W[k, :, :] = (np.linalg.inv(W.conj().transpose([2,0,1]) @ V.transpose([2,0,1])) @ E[k, :]).T + V = ( + (1 / N_t) + * ( + ((fi * spectrogram.T).transpose([1, 0, 2])) + @ (spectrogram.conj().transpose([1, 0, 2])) + ) + ).transpose([1, 2, 0]) + # 分離行列の更新1(solve) + W[k, :, :] = ( + np.linalg.inv(W.conj().transpose([2, 0, 1]) @ V.transpose([2, 0, 1])) + @ E[k, :] + ).T # 分離行列の更新2 - W[k, :, :] /= np.sqrt( (W[k, :, :][:, :, None].conj().transpose([1,2,0]) @ V.transpose([2,0,1])) @ W[k, :, :][:, :, None].transpose([1,0,2]) ).squeeze() + W[k, :, :] /= np.sqrt( + ( + W[k, :, :][:, :, None].conj().transpose([1, 2, 0]) + @ V.transpose([2, 0, 1]) + ) + @ W[k, :, :][:, :, None].transpose([1, 0, 2]) + ).squeeze() + + # 分離行列の正規化 + rebuild_spectrogram = ( + W.conj().transpose([2, 0, 1]) @ spectrogram.transpose([1, 2, 0]) + ).transpose([2, 0, 1]) + z = np.sum(np.linalg.norm(rebuild_spectrogram, axis=1) ** 2) + W[:] /= np.sqrt(z / (N_omega * N_t * K)) + + print(str(i + 1) + "/" + str(N)) - # 分離行列の正規化 - rebuild_spectrogram = ( W.conj().transpose([2,0,1]) @ spectrogram.transpose([1,2,0]) ).transpose([2,0,1]) - z = np.sum(np.linalg.norm(rebuild_spectrogram, axis=1)**2) - W[:] /= (np.sqrt( z /(N_omega *N_t *K) )) - - print( str(i+1) +"/" +str(N) ) - # 信号源復元(分離処理) - rebuild_spectrogram = ( W.conj().transpose([2,0,1]) @ spectrogram.transpose([1,2,0]) ).transpose([2,0,1]) - + rebuild_spectrogram = ( + W.conj().transpose([2, 0, 1]) @ spectrogram.transpose([1, 2, 0]) + ).transpose([2, 0, 1]) + return rebuild_spectrogram - -""" + + +""" def auxiva3(spectrogram, N): # 独立ベクトル分析開始 @@ -406,16 +486,16 @@ def auxiva3(spectrogram, N): #V2 = np.zeros([ K, K, K, N_omega ], dtype='complex') rebuild_spectrogram = np.zeros([ spectrogram.shape[0], spectrogram.shape[1], spectrogram.shape[2] ] , dtype='complex' ) z = 0+0j - + print(spectrogram.shape) spectrogram = spectrogram.transpose([1,0,2]) - + # 反復回数N for i in range( N ): for k in range( K ): # 補助変数の更新1 r[k, :] = np.squeeze(np.sum((np.abs( spectrogram @ W[:, k, :][:, :, None].conj() ))**2, axis=0)) - + r[k, :] = np.sqrt( r[k, :] ) dr = np.gradient(r[k, :]) G_R = r[k, :].copy() #np.log(r[k, :]) # コントラスト関数指定 @@ -427,23 +507,23 @@ def auxiva3(spectrogram, N): V = ((1/N_t) * ( (((fi*spectrogram.transpose([0,2,1])) ) @ (spectrogram.conj() )))) # 分離行列の更新1(solve) W[:, k, :] = (np.linalg.inv(W.conj() @ V) @ E[k, :]) - + # 分離行列の更新2 w = W[:, k, :][:, :, None] W[:, k, :] = ( W[:, k, :].T / np.sqrt( w.conj().transpose([0,2,1]) @ V @ w ).squeeze() ).T - # 分離行列の正規化 - rebuild_spectrogram = ( W.conj() @ spectrogram.transpose([0,2,1]) ).transpose([2,0,1]) + # 分離行列の正規化 + rebuild_spectrogram = ( W.conj() @ spectrogram.transpose([0,2,1]) ).transpose([2,0,1]) z = np.sum(np.linalg.norm(rebuild_spectrogram, axis=1)**2) W[:] /= (np.sqrt( z /(N_omega *N_t *K) )) - + print( str(i+1) +"/" +str(N) ) - + # 信号源復元(分離処理) rebuild_spectrogram = ( W.conj() @ spectrogram.transpose([0,2,1]) ).transpose([2,0,1]) - - return rebuild_spectrogram - + + return rebuild_spectrogram + def auxiva4(spectrogram, N): # 独立ベクトル分析開始 @@ -457,12 +537,12 @@ def auxiva4(spectrogram, N): #V2 = np.zeros([ K, K, K, N_omega ], dtype='complex') rebuild_spectrogram = np.zeros([ spectrogram.shape[0], spectrogram.shape[1], spectrogram.shape[2] ] , dtype='complex' ) z = 0+0j - + print(spectrogram.shape) - + # 反復回数N for i in range( N ): - + # 補助変数の更新1 r[:, :] = np.einsum('ijk->kj', (np.abs( spectrogram.transpose([1,0,2]) @ W.conj().transpose([2,0,1]) ))**2) r[:, :] = np.sqrt( r[:, :] ) @@ -471,28 +551,28 @@ def auxiva4(spectrogram, N): fi_ = ( np.gradient(G_R, dr, axis=1) / r[:, :] ) fi0 = 1000. fi_[fi0 < fi_] = fi0 - + # 補助変数の更新2 V = ((1/N_t) * ( (((np.einsum('ijk,li->lijk', spectrogram, fi_).transpose([0,2,3,1]) ) @ (spectrogram.conj().transpose([1,0,2]) )))).transpose([0,2,3,1])) - + # 分離行列の更新1(solve) - W = np.einsum( 'ijkk,ik->ikj', np.linalg.inv(W.conj().transpose([2,0,1]) @ V.transpose([0,3,1,2])), E[:, :] ) - + W = np.einsum( 'ijkk,ik->ikj', np.linalg.inv(W.conj().transpose([2,0,1]) @ V.transpose([0,3,1,2])), E[:, :] ) + # 分離行列の更新2 W /= np.sqrt( (W[:, :, :, None].conj().transpose([0,2,3,1]) @ V.transpose([0,3,1,2])) @ W[:, :, :, None].transpose([0,2,1,3]) ).squeeze() - # 分離行列の正規化 - rebuild_spectrogram = ( W.conj().transpose([2,0,1]) @ spectrogram.transpose([1,2,0]) ).transpose([2,0,1]) + # 分離行列の正規化 + rebuild_spectrogram = ( W.conj().transpose([2,0,1]) @ spectrogram.transpose([1,2,0]) ).transpose([2,0,1]) z = np.sum(np.linalg.norm(rebuild_spectrogram, axis=1)**2) W[:] /= (np.sqrt( z /(N_omega *N_t *K) )) - + print( str(i+1) +"/" +str(N) ) - + # 信号源復元(分離処理) rebuild_spectrogram = ( W.conj().transpose([2,0,1]) @ spectrogram.transpose([1,2,0]) ).transpose([2,0,1]) - + return rebuild_spectrogram - + def auxiva5(spectrogram, N): from joblib import Parallel, delayed @@ -507,40 +587,40 @@ def auxiva5(spectrogram, N): #V2 = np.zeros([ K, K, K, N_omega ], dtype='complex') rebuild_spectrogram = np.zeros([ spectrogram.shape[0], spectrogram.shape[1], spectrogram.shape[2] ] , dtype='complex' ) z = 0+0j - + print(spectrogram.shape) - + # 反復回数N for i in range( N ): W = Parallel(n_jobs=-1)( [delayed(aux_iva_sub_process)(k, r, spectrogram, W, N_t, E) for k in range(K)] ) W = np.asarray(W, dtype='complex') - # 分離行列の正規化 - rebuild_spectrogram = ( W.conj().transpose([2,0,1]) @ spectrogram.transpose([1,2,0]) ).transpose([2,0,1]) + # 分離行列の正規化 + rebuild_spectrogram = ( W.conj().transpose([2,0,1]) @ spectrogram.transpose([1,2,0]) ).transpose([2,0,1]) z = np.sum(np.linalg.norm(rebuild_spectrogram, axis=1)**2) W[:] /= (np.sqrt( z /(N_omega *N_t *K) )) - + print( str(i+1) +"/" +str(N) ) - + # 信号源復元(分離処理) rebuild_spectrogram = ( W.conj().transpose([2,0,1]) @ spectrogram.transpose([1,2,0]) ).transpose([2,0,1]) - + return rebuild_spectrogram def aux_iva_sub_process(k, r, spectrogram, W, N_t, E): # 補助変数の更新1 r[k, :] = np.squeeze(np.sum((np.abs( spectrogram[:, :, :].T.transpose([1,2,0]) @ W[k, :, :][:, :, None].conj().transpose([1,0,2]) ))**2, axis=0)) - + r[k, :] = np.sqrt( r[k, :] ) dr = np.gradient(r[k, :]) G_R = r[k, :].copy() #np.log(r[k, :]) # コントラスト関数指定 fi = ( np.gradient(G_R, dr)/r[k, :] ) fi0 = 1000. fi[fi0 < fi] = fi0 - + # 補助変数の更新2 V = ((1/N_t) * ( (((fi*spectrogram.T).transpose([1,0,2]) ) @ (spectrogram.conj().transpose([1,0,2]) )))).transpose([1,2,0]) - # 分離行列の更新1(solve) + # 分離行列の更新1(solve) W[k, :, :] = (np.linalg.inv(W.conj().transpose([2,0,1]) @ V.transpose([2,0,1])) @ E[k, :]).T # 分離行列の更新2 W[k, :, :] /= np.sqrt( (W[k, :, :][:, :, None].conj().transpose([1,2,0]) @ V.transpose([2,0,1])) @ W[k, :, :][:, :, None].transpose([1,0,2]) ).squeeze() @@ -549,15 +629,19 @@ def aux_iva_sub_process(k, r, spectrogram, W, N_t, E): """ + def multi_stft(data, win, step): import mystft + ### STFT --------------------------------------------------------- for i in range(data.shape[1]): - if i==0: + if i == 0: buff = mystft.stft(data[:, i], win, step) - spectrogram_ = np.empty([buff.shape[0], buff.shape[1], data.shape[1]], dtype='complex') + spectrogram_ = np.empty( + [buff.shape[0], buff.shape[1], data.shape[1]], dtype="complex" + ) spectrogram_[:, :, i] = buff - if i>0: + if i > 0: spectrogram_[:, :, i] = mystft.stft(data[:, i], win, step) ### --------------------------------------------------------------- return spectrogram_ @@ -565,16 +649,19 @@ def multi_stft(data, win, step): def multi_istft(rebuild_spectrogram, win, step): import mystft + ### iSTFT --------------------------------------------------------- - for i in range( rebuild_spectrogram.shape[2] ): - if i==0: + for i in range(rebuild_spectrogram.shape[2]): + if i == 0: buff = mystft.istft(rebuild_spectrogram[:, :, i], win, step) resyn_data = np.empty([buff.shape[0], rebuild_spectrogram.shape[2]]) resyn_data[:, i] = buff - if i>0: + if i > 0: resyn_data[:, i] = mystft.istft(rebuild_spectrogram[:, :, i], win, step) ### --------------------------------------------------------------- return resyn_data + + """ def multi_cwt(data, omega0, sigma, fs, N_scale): import sys,os @@ -591,7 +678,7 @@ def multi_cwt(data, omega0, sigma, fs, N_scale): spectrogram[:, :, i] = fe.cwt(data[:, i], omega0, sigma, fs, N_scale).T.copy() ### --------------------------------------------------------------- return spectrogram - + def multi_icwt(rebuild_spectrogram, omega0, sigma, fs): import sys,os sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../biological_signal_processing') @@ -607,111 +694,116 @@ def multi_icwt(rebuild_spectrogram, omega0, sigma, fs): ### --------------------------------------------------------------- return resyn_data """ + + def IVA(data, N=10, fftLen=128, n_components=4, fs=16000): - L, sigch = data.shape - #N = 10 # 反復回数 - #fftLen = 8192/4 # とりあえず # フレーム窓長 - win = sp.hamming(fftLen) # ハミング窓 - step = (fftLen/2) #/2 # フレーム窓シフト幅(論文[一般で言われているシフト幅]のもう/2しないとダメ) - #n_components = 4 - if n_components is None: n_components = data.shape[1] - elif n_components > data.shape[1]: n_components = data.shape[1] - - - #from sklearn.decomposition import PCA, KernelPCA, FactorAnalysis, TruncatedSVD, SparsePCA + # N = 10 # 反復回数 + # fftLen = 8192/4 # とりあえず # フレーム窓長 + win = sp.hamming(fftLen) # ハミング窓 + step = ( + fftLen / 2 + ) # /2 # フレーム窓シフト幅(論文[一般で言われているシフト幅]のもう/2しないとダメ) + # n_components = 4 + if n_components is None: + n_components = data.shape[1] + elif n_components > data.shape[1]: + n_components = data.shape[1] + + # from sklearn.decomposition import PCA, KernelPCA, FactorAnalysis, TruncatedSVD, SparsePCA start = time.time() # Whitening --------------------------------------------------- - #pca = PCA(n_components=n_components, copy=True, whiten=True) - #whited_data = pca.fit_transform( data[:, :] ) - #tSVD = TruncatedSVD(n_components=n_components, algorithm='arpack') - #whited_data = tSVD.fit_transform( data ) - whited_data = whitening(data, n_components) # 必ずこっちを使うこと + # pca = PCA(n_components=n_components, copy=True, whiten=True) + # whited_data = pca.fit_transform( data[:, :] ) + # tSVD = TruncatedSVD(n_components=n_components, algorithm='arpack') + # whited_data = tSVD.fit_transform( data ) + whited_data = whitening(data, n_components) # 必ずこっちを使うこと ### -------------------------------------------------------------- elapsed_time1 = time.time() - start print(whited_data.shape) - + sum_time = elapsed_time1 elapsed_time2 = time.time() - start - sum_time - + # CWT用パラメータ --------------------------------------------------- - omega0 = 6 - sigma = 6 - N_scale = 256 + # omega0 = 6 + # sigma = 6 + # N_scale = 256 # ---------------------------------------------------------------- # 時間領域 to 時間-周波数領域 -------------------------------------- - spectrogram_ = multi_stft(whited_data, win, step) # STFT - #spectrogram = multi_cwt(whited_data, omega0, sigma, fs, N_scale) # CWT(complex morlet) + spectrogram_ = multi_stft(whited_data, win, step) # STFT + # spectrogram = multi_cwt(whited_data, omega0, sigma, fs, N_scale) # CWT(complex morlet) # ----------------------------------------------------------------- - sum_time += elapsed_time2 + sum_time += elapsed_time2 elapsed_time3 = time.time() - start - sum_time - + ### AuxIVA -------------------------------------------------------- # Cython # python setup.py build_ext --inplace # でコンパイル - #import pyximport; pyximport.install()#pyimport = True) - #import iva_test - #cProfile.run('iva_test.auxiva1( spectrogram, N )') - #rebuild_spectrogram = iva_test.auxiva1( spectrogram, N ) - rebuild_spectrogram = auxiva2( spectrogram_, N ) + # import pyximport; pyximport.install()#pyimport = True) + # import iva_test + # cProfile.run('iva_test.auxiva1( spectrogram, N )') + # rebuild_spectrogram = iva_test.auxiva1( spectrogram, N ) + rebuild_spectrogram = auxiva2(spectrogram_, N) ### -------------------------------------------------------------- sum_time += elapsed_time3 elapsed_time4 = time.time() - start - sum_time - + # 時間-周波数領域 to 時間領域 -------------------------------------- - result = multi_istft(rebuild_spectrogram, win, step) # iSTFT + result = multi_istft(rebuild_spectrogram, win, step) # iSTFT # hanning窓の前半分をかけて入りを滑らかに - #result[:int(fftLen/2), :] = (hanning(fftLen)[:int(fftLen/2)] * result[:int(fftLen/2), :].T).T - result = result[len(result)-len(whited_data):, :] # STFTで生じた余分な信号長のカット - #result = multi_icwt(rebuild_spectrogram, omega0, sigma, fs) # iCWT(complex morlet) + # result[:int(fftLen/2), :] = (hanning(fftLen)[:int(fftLen/2)] * result[:int(fftLen/2), :].T).T + result = result[ + len(result) - len(whited_data) :, : + ] # STFTで生じた余分な信号長のカット + # result = multi_icwt(rebuild_spectrogram, omega0, sigma, fs) # iCWT(complex morlet) # ----------------------------------------------------------------- - sum_time += elapsed_time4 + sum_time += elapsed_time4 elapsed_time5 = time.time() - start - sum_time - - print('PCA : {0:6.2f}'.format(elapsed_time1) + "[sec]") - print('FICA: {0:6.2f}'.format(elapsed_time2) + "[sec]") - print('STFT: {0:6.2f}'.format(elapsed_time3) + "[sec]") - print('IVA : {0:6.2f}'.format(elapsed_time4) + "[sec]") - print('iSTFT: {0:5.2f}'.format(elapsed_time5) + "[sec]") - - print(np.sqrt(np.average(np.abs(data[:, :])**2))) - print(np.sqrt(np.average(np.abs(result[:, :])**2))) - + + print("PCA : {0:6.2f}".format(elapsed_time1) + "[sec]") + print("FICA: {0:6.2f}".format(elapsed_time2) + "[sec]") + print("STFT: {0:6.2f}".format(elapsed_time3) + "[sec]") + print("IVA : {0:6.2f}".format(elapsed_time4) + "[sec]") + print("iSTFT: {0:5.2f}".format(elapsed_time5) + "[sec]") + + print(np.sqrt(np.average(np.abs(data[:, :]) ** 2))) + print(np.sqrt(np.average(np.abs(result[:, :]) ** 2))) + print(np.linalg.norm(data[:, 0])) print(np.linalg.norm(result[:, 0])) - + # 振幅補正(RMSの比を基準に) - #result[:, :] *= np.sqrt(np.average(np.abs(data[:, :])**2)) / np.sqrt(np.average(np.abs(result[:, :])**2)) + # result[:, :] *= np.sqrt(np.average(np.abs(data[:, :])**2)) / np.sqrt(np.average(np.abs(result[:, :])**2)) # 振幅補正(L2-norm の比を基準に) result *= np.average(np.linalg.norm(data)) / np.linalg.norm(result) - #import matplotlib.pyplot as plt - #plt.figure() - #plt.plot(data[:, 0]) - - #plt.figure() - #plt.plot(result[:, 0]) - - + # import matplotlib.pyplot as plt + # plt.figure() + # plt.plot(data[:, 0]) + + # plt.figure() + # plt.plot(result[:, 0]) + return result -if __name__ == "__main__": +if __name__ == "__main__": data, samplerate = sf.read("yuki_stereo_VM00_VF00_0750.wav") # 2人の会話 - #data, samplerate = sf.read("townofdeath.wav") # freeの曲 + # data, samplerate = sf.read("townofdeath.wav") # freeの曲 if samplerate > 16000: print(samplerate) print("to") - seconds = int(np.floor(len(data)/samplerate)) - size = samplerate *seconds - data = data[:size, :].copy() # np.c_[data[:size, 0], data[:size, 1]] + seconds = int(np.floor(len(data) / samplerate)) + size = samplerate * seconds + data = data[:size, :].copy() # np.c_[data[:size, 0], data[:size, 1]] samplerate = 16000 - data = resample(data, samplerate*seconds) + data = resample(data, samplerate * seconds) print(samplerate) print(seconds, "[sec]") - # data = data[0*samplerate:30*samplerate, :] - + # data = data[0*samplerate:30*samplerate, :] + """ data1, samplerate = sf.read("dev2_ASY2016/dev2_mix4_asynchrec_realmix_ch12.wav") data2, samplerate = sf.read("dev2_ASY2016/dev2_mix4_asynchrec_realmix_ch34.wav") @@ -719,52 +811,48 @@ def IVA(data, N=10, fftLen=128, n_components=4, fs=16000): data4, samplerate = sf.read("dev2_ASY2016/dev2_mix4_asynchrec_realmix_ch78.wav") data =np.c_[data1, data2, data3, data4] """ - - origin_data = data.copy() # 生データ保存 - - + + origin_data = data.copy() # 生データ保存 + # 関数版IVA -------------------------- # result = IVA(data, N=20, fftLen=128*(2**1), n_components=2, fs=samplerate) # -------------------------------- - + # クラス版IVA -------------------------- - iva = IndependentVectorAnalysis(N=5, fftLen=1024, - n_components=2, fs=samplerate) + iva = IndependentVectorAnalysis(N=5, fftLen=1024, n_components=2, fs=samplerate) result = iva.fit_transform(data) # -------------------------------- print(result) - + W_im = np.abs(iva.W.reshape(128, -1)) - #print(W_im.shape) + # print(W_im.shape) plt.pcolormesh(W_im) plt.colorbar() plt.show() xx = np.linalg.det(np.abs(iva.W.T)) plt.plot(xx) plt.show() - + plt.plot(iva.r.T.real) plt.show() - + plt.plot(np.abs(iva.W.T.reshape(-1, 1))) plt.show() - + ss = [] for i in range(iva.W.shape[2]): ss.append(np.sum(np.abs(iva.W[:, :, i]))) ss = np.array(ss) plt.plot(ss) plt.show() - + for i in range(origin_data.shape[1]): - sf.write('origin_file%d.wav' % i, - origin_data[:, i], samplerate, 'PCM_16') + sf.write("origin_file%d.wav" % i, origin_data[:, i], samplerate, "PCM_16") for i in range(result.shape[1]): - sf.write('iva_file%d.wav' % i, - result[:, i], samplerate, 'PCM_16') + sf.write("iva_file%d.wav" % i, result[:, i], samplerate, "PCM_16") - sf.write('STEREO_ORIGIN.wav', origin_data, samplerate, 'PCM_16') - sf.write('STEREO_IVA.wav', result, samplerate, 'PCM_16') + sf.write("STEREO_ORIGIN.wav", origin_data, samplerate, "PCM_16") + sf.write("STEREO_IVA.wav", result, samplerate, "PCM_16") """ # 音声認識API --------------------------------------------------------------- diff --git a/src/iva/time_frequency_analysis.py b/src/iva/time_frequency_analysis.py new file mode 100644 index 0000000..60f6ebf --- /dev/null +++ b/src/iva/time_frequency_analysis.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- + +import abc + +import numpy as np + +from scipy.signal import ShortTimeFFT +from scipy.signal.windows import hamming + + +class BaseShortTimeFFT(abc.ABC): + @property + def fs(self): + return self._fs + + @property + def window_length(self): + return self._window_length + + @abc.abstractmethod + def stft(self, x: np.ndarray) -> np.ndarray: + pass + + @abc.abstractmethod + def istft(self, S: np.ndarray) -> np.ndarray: + pass + + +class HammingShortTimeFFT(BaseShortTimeFFT): + def __init__(self, fs: float, window_length: int, hop: int): + self._fs = fs + self._window_length = window_length + self.sft = ShortTimeFFT(hamming(window_length), hop, fs) + self.len_x = None + + def stft(self, x: np.ndarray) -> np.ndarray: + self.len_x = x.shape[-1] + return self.sft.stft(x) + + def istft(self, S: np.ndarray) -> np.ndarray: + if self.len_x is None: + raise ValueError("`len_x`: length of the original signal must be provided.") + return self.sft.istft(S, k1=self.len_x) diff --git a/tests/test_independent_vector_analysis.py b/tests/test_independent_vector_analysis.py new file mode 100644 index 0000000..d27f9e2 --- /dev/null +++ b/tests/test_independent_vector_analysis.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- + +import pytest +import numpy as np + +from iva.independent_vector_analysis import _IndependentVectorAnalysis + + +def load_dummy_signals( + time_length_sec=15, + num_channels=2, + fs=16000, +) -> np.ndarray: + return np.random.randn(int(time_length_sec * fs), num_channels), fs + + +@pytest.mark.parametrize( + "num_components,num_iterations", + [ + (2, 5), + (3, 4), + (4, 3), + ], +) +def test_iva( + num_components, + num_iterations, +): + input, _ = load_dummy_signals() + + iva = _IndependentVectorAnalysis( + num_components=num_components, + num_iterations=num_iterations, + ) + out = iva.fit_transform(input) + print(out) diff --git a/tests/test_time_frequency_analysis.py b/tests/test_time_frequency_analysis.py new file mode 100644 index 0000000..9e274ee --- /dev/null +++ b/tests/test_time_frequency_analysis.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +import pytest +import numpy as np + +from iva.time_frequency_analysis import HammingShortTimeFFT + +SAMPLING_FREQUENCY = 16000 +WINDOW_LENGTH = 512 +HOP_LENGTH = 256 +NUM_CHANNELS = 2 + + +@pytest.fixture +def sft(): + return HammingShortTimeFFT(SAMPLING_FREQUENCY, WINDOW_LENGTH, HOP_LENGTH) + + +@pytest.mark.parametrize( + "x", + [ + np.random.randn(SAMPLING_FREQUENCY), + np.random.randn(NUM_CHANNELS, SAMPLING_FREQUENCY), + ], +) +def test_sft(sft, x): + out = sft.istft(sft.stft(x)) + assert out.shape == x.shape + assert np.allclose(out, x)