Source code for madmom.audio.stft
# encoding: utf-8
# pylint: disable=no-member
# pylint: disable=invalid-name
# pylint: disable=too-many-arguments
"""
This module contains Short-Time Fourier Transform (STFT) related functionality.
"""
from __future__ import absolute_import, division, print_function
import numpy as np
import scipy.fftpack as fftpack
from ..processors import Processor
from .signal import Signal, FramedSignal
STFT_DTYPE = np.complex64
def fft_frequencies(num_fft_bins, sample_rate):
    """
    Frequencies of the FFT bins.

    Parameters
    ----------
    num_fft_bins : int
        Number of FFT bins (i.e. half the FFT length).
    sample_rate : float
        Sample rate of the signal.

    Returns
    -------
    fft_frequencies : numpy array
        Frequencies of the FFT bins [Hz].

    """
    # the underlying FFT is twice as long as the number of bins we keep
    fft_length = num_fft_bins * 2
    # time between two consecutive samples
    sample_spacing = 1. / sample_rate
    # frequencies of all bins of an FFT with that length
    all_frequencies = np.fft.fftfreq(fft_length, sample_spacing)
    # keep only the first `num_fft_bins` entries
    return all_frequencies[:num_fft_bins]
def stft(frames, window, fft_size=None, circular_shift=False):
    """
    Calculates the complex Short-Time Fourier Transform (STFT) of the given
    framed signal.

    Parameters
    ----------
    frames : numpy array or iterable, shape (num_frames, frame_size)
        Framed signal (e.g. :class:`FramedSignal` instance)
    window : numpy array, shape (frame_size,)
        Window (function); if 'None', the frames are used as they are.
    fft_size : int, optional
        FFT size (should be a power of 2); if 'None', the 'frame_size' given
        by `frames` is used; if the given `fft_size` is greater than the
        'frame_size', the frames are zero-padded, if smaller truncated.
    circular_shift : bool, optional
        Circular shift the individual frames before performing the FFT;
        needed for correct phase.

    Returns
    -------
    stft : numpy array, shape (num_frames, fft_size // 2)
        The complex STFT of the framed signal.

    Raises
    ------
    ValueError
        If `frames` is not two-dimensional.

    Notes
    -----
    The circular shift swaps the two halves of each frame inside a buffer of
    length `fft_size` without truncating, i.e. this path assumes an even
    'frame_size' which is not larger than `fft_size`.

    """
    # check for correct shape of input
    if frames.ndim != 2:
        # TODO: add multi-channel support
        raise ValueError('frames must be a 2D array or iterable, got %s with '
                         'shape %s.' % (type(frames), frames.shape))
    # shape of the frames
    num_frames, frame_size = frames.shape
    # FFT size to use
    if fft_size is None:
        fft_size = frame_size
    # number of FFT bins to store
    num_fft_bins = fft_size >> 1
    # size of the FFT circular shift (needed for correct phase)
    fft_shift = frame_size >> 1
    # init objects
    data = np.empty((num_frames, num_fft_bins), STFT_DTYPE)
    # iterate over all frames
    for f, frame in enumerate(frames):
        # multiply the signal frame with the window (or just use it as it is
        # if no window function is given)
        if window is not None:
            signal = np.multiply(frame, window)
        else:
            signal = frame
        if circular_shift:
            # swap the two halves of the windowed signal; if the FFT size is
            # bigger than the frame size, the (windowed) signal is padded
            # with additional zeros in between the two halves
            fft_signal = np.zeros(fft_size)
            fft_signal[:fft_shift] = signal[fft_shift:]
            fft_signal[-fft_shift:] = signal[:fft_shift]
        else:
            # use the windowed signal directly (i.e. bypass the additional
            # copying step above)
            fft_signal = signal
        # perform DFT; the length is given explicitly so that the FFT itself
        # zero-pads or truncates the frame to `fft_size`, otherwise a
        # `fft_size` different from the frame size would be silently ignored
        # whenever no circular shift is requested
        data[f] = fftpack.fft(fft_signal, n=fft_size, axis=0)[:num_fft_bins]
    # return STFT
    return data
def phase(stft):
    """
    Returns the phase of the complex STFT of a signal.

    Parameters
    ----------
    stft : numpy array, shape (num_frames, frame_size)
        The complex STFT of a signal.

    Returns
    -------
    phase : numpy array
        Phase of the STFT.

    """
    # the phase is the angle of each complex STFT value
    angles = np.angle(stft)
    return angles
def local_group_delay(phase):
    """
    Returns the local group delay of the phase of a signal.

    Parameters
    ----------
    phase : numpy array, shape (num_frames, frame_size)
        Phase of the STFT of a signal.

    Returns
    -------
    lgd : numpy array
        Local group delay of the phase.

    Raises
    ------
    ValueError
        If `phase` is not two-dimensional.

    """
    # only 2D input (frames x bins) is supported
    if phase.ndim != 2:
        raise ValueError('phase must be a 2D array')
    # remove the jumps of the wrapped phase first
    unwrapped = np.unwrap(phase)
    # the local group delay is the derivative over frequency, i.e. each bin
    # minus its upper neighbour
    delay = np.empty_like(unwrapped)
    delay[:, :-1] = unwrapped[:, :-1] - unwrapped[:, 1:]
    # the highest frequency bin has no upper neighbour, set it to 0
    delay[:, -1] = 0
    return delay


# alias
lgd = local_group_delay
# mixin providing `num_frames` & `num_bins` properties
class _PropertyMixin(object):
# pylint: disable=missing-docstring
@property
def num_frames(self):
"""Number of frames."""
return len(self)
@property
def num_bins(self):
"""Number of bins."""
return int(self.shape[1])
# short-time Fourier transform class
class ShortTimeFourierTransform(_PropertyMixin, np.ndarray):
    """
    ShortTimeFourierTransform class.

    Parameters
    ----------
    frames : :class:`.audio.signal.FramedSignal` instance
        Framed signal.
    window : numpy ufunc or numpy array, optional
        Window (function); if a function (e.g. `np.hanning`) is given, a window
        with the frame size of `frames` and the given shape is created.
    fft_size : int, optional
        FFT size (should be a power of 2); if 'None', the `frame_size` given by
        `frames` is used, if the given `fft_size` is greater than the
        `frame_size`, the frames are zero-padded accordingly.
    circular_shift : bool, optional
        Circular shift the individual frames before performing the FFT;
        needed for correct phase.
    fft_window : numpy array, optional
        Window to be used for the FFT as it is; if given, neither a window
        is created from `window` nor is any scaling applied.
    kwargs : dict, optional
        If no :class:`.audio.signal.FramedSignal` instance was given, one is
        instantiated with these additional keyword arguments.

    Notes
    -----
    If the :class:`Signal` (wrapped in the :class:`FramedSignal`) has an
    integer dtype, the `window` is automatically scaled as if the `signal` had
    a float dtype with the values being in the range [-1, 1]. This results in
    same valued STFTs independently of the dtype of the signal. On the other
    hand, this prevents extra memory consumption since the data-type of the
    signal does not need to be converted (and if no decoding is needed, the
    audio signal can be memory-mapped).

    Examples
    --------
    Create a :class:`ShortTimeFourierTransform` from a :class:`Signal` or
    :class:`FramedSignal`:

    >>> sig = Signal('tests/data/audio/sample.wav')
    >>> sig
    Signal([-2494, -2510, ...,   655,   639], dtype=int16)
    >>> frames = FramedSignal(sig, frame_size=2048, hop_size=441)
    >>> frames  # doctest: +ELLIPSIS
    <madmom.audio.signal.FramedSignal object at 0x...>
    >>> stft = ShortTimeFourierTransform(frames)
    >>> stft  # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
    ShortTimeFourierTransform([[-3.15249+0.j     ,  2.62216-3.02425j, ...,
                                -0.03634-0.00005j,  0.03670+0.00029j],
                               [-4.28429+0.j     ,  2.02009+2.01264j, ...,
                                -0.01981-0.00933j, -0.00536+0.02162j],
                               ...,
                               [-4.92274+0.j     ,  4.09839-9.42525j, ...,
                                 0.00550-0.00257j,  0.00137+0.00577j],
                               [-9.22709+0.j     ,  8.76929+4.0005j , ...,
                                 0.00981-0.00014j, -0.00984+0.00006j]],
                              dtype=complex64)

    A ShortTimeFourierTransform can be instantiated directly from a file name:

    >>> stft = ShortTimeFourierTransform('tests/data/audio/sample.wav')
    >>> stft  # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
    ShortTimeFourierTransform([[...]], dtype=complex64)

    Doing the same with a Signal of float data-type will result in a STFT of
    same value range (rounding errors will occur of course):

    >>> sig = Signal('tests/data/audio/sample.wav', dtype=float)
    >>> sig  # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
    Signal([-0.07611, -0.0766 , ...,  0.01999,  0.0195 ])
    >>> frames = FramedSignal(sig, frame_size=2048, hop_size=441)
    >>> frames  # doctest: +ELLIPSIS
    <madmom.audio.signal.FramedSignal object at 0x...>
    >>> stft = ShortTimeFourierTransform(frames)
    >>> stft  # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
    ShortTimeFourierTransform([[-3.15240+0.j     ,  2.62208-3.02415j, ...,
                                -0.03633-0.00005j,  0.03670+0.00029j],
                               [-4.28416+0.j     ,  2.02003+2.01257j, ...,
                                -0.01981-0.00933j, -0.00536+0.02162j],
                               ...,
                               [-4.92259+0.j     ,  4.09827-9.42496j, ...,
                                 0.00550-0.00257j,  0.00137+0.00577j],
                               [-9.22681+0.j     ,  8.76902+4.00038j, ...,
                                 0.00981-0.00014j, -0.00984+0.00006j]],
                              dtype=complex64)

    Additional arguments are passed to :class:`FramedSignal` and
    :class:`Signal` respectively:

    >>> stft = ShortTimeFourierTransform('tests/data/audio/sample.wav', \
frame_size=2048, fps=100, sample_rate=22050)
    >>> stft.frames  # doctest: +ELLIPSIS
    <madmom.audio.signal.FramedSignal object at 0x...>
    >>> stft.frames.frame_size
    2048
    >>> stft.frames.hop_size
    220.5
    >>> stft.frames.signal.sample_rate
    22050

    """
    # pylint: disable=super-on-old-class
    # pylint: disable=super-init-not-called
    # pylint: disable=attribute-defined-outside-init

    def __init__(self, frames, window=np.hanning, fft_size=None,
                 circular_shift=False, **kwargs):
        # this method is for documentation purposes only
        pass

    def __new__(cls, frames, window=np.hanning, fft_size=None,
                circular_shift=False, fft_window=None, **kwargs):
        # pylint: disable=unused-argument
        # an existing STFT is rebuilt from the frames it was computed from
        if isinstance(frames, ShortTimeFourierTransform):
            frames = frames.frames
        # everything which is not framed yet gets wrapped in a FramedSignal
        if not isinstance(frames, FramedSignal):
            frames = FramedSignal(frames, **kwargs)
        # length of a single frame
        frame_length = frames.shape[1]
        # derive the window used for the FFT only if none was handed over
        if fft_window is None:
            # a callable window function is turned into a window with the
            # length of a frame
            if hasattr(window, '__call__'):
                window = window(frame_length)
            try:
                # integer signals are not scaled, thus determine the maximum
                # of the value range (raises a ValueError for other dtypes)
                int_max = float(np.iinfo(frames.signal.dtype).max)
                try:
                    # scale the window instead of the signal
                    fft_window = window / int_max
                except TypeError:
                    # the window is None and can't be scaled, thus scale a
                    # uniform window instead
                    fft_window = np.ones(frame_length) / int_max
            except ValueError:
                # no scaling needed, use the window as is (can also be None)
                fft_window = window
        # compute the complex STFT of all frames
        spectrum = stft(frames, fft_window, fft_size=fft_size,
                        circular_shift=circular_shift)
        # view the result as an instance of this class
        instance = np.asarray(spectrum).view(cls)
        # remember how the STFT was computed
        instance.frames = frames
        instance.window = window
        instance.fft_window = fft_window
        instance.fft_size = fft_size or frame_length
        instance.circular_shift = circular_shift
        return instance

    def __array_finalize__(self, obj):
        if obj is None:
            return
        # set default values here, also needed for views
        defaults = (('frames', None),
                    ('window', np.hanning),
                    ('fft_window', None),
                    ('fft_size', None),
                    ('circular_shift', False))
        for name, default in defaults:
            setattr(self, name, getattr(obj, name, default))

    @property
    def bin_frequencies(self):
        """Bin frequencies."""
        sample_rate = self.frames.signal.sample_rate
        return fft_frequencies(self.num_bins, sample_rate)

    def spec(self, **kwargs):
        """
        Returns the magnitude spectrogram of the STFT.

        Parameters
        ----------
        kwargs : dict, optional
            Keyword arguments passed to
            :class:`.audio.spectrogram.Spectrogram`.

        Returns
        -------
        spec : :class:`.audio.spectrogram.Spectrogram`
            :class:`.audio.spectrogram.Spectrogram` instance.

        """
        # import Spectrogram here, otherwise we have circular imports
        from .spectrogram import Spectrogram
        return Spectrogram(self, **kwargs)

    def phase(self, **kwargs):
        """
        Returns the phase of the STFT.

        Parameters
        ----------
        kwargs : dict, optional
            keyword arguments passed to :class:`Phase`.

        Returns
        -------
        phase : :class:`Phase`
            :class:`Phase` instance.

        """
        return Phase(self, **kwargs)


STFT = ShortTimeFourierTransform
class ShortTimeFourierTransformProcessor(Processor):
    """
    ShortTimeFourierTransformProcessor class.

    Parameters
    ----------
    window : numpy ufunc, optional
        Window function.
    fft_size : int, optional
        FFT size (should be a power of 2); if 'None', it is determined by the
        size of the frames; if is greater than the frame size, the frames are
        zero-padded accordingly.
    circular_shift : bool, optional
        Circular shift the individual frames before performing the FFT;
        needed for correct phase.

    Examples
    --------
    Create a :class:`ShortTimeFourierTransformProcessor` and call it with
    either a file name or the output of a (Framed-)SignalProcessor to obtain
    a :class:`ShortTimeFourierTransform` instance.

    >>> proc = ShortTimeFourierTransformProcessor()
    >>> stft = proc('tests/data/audio/sample.wav')
    >>> stft  # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
    ShortTimeFourierTransform([[-3.15249+0.j     ,  2.62216-3.02425j, ...,
                                -0.03634-0.00005j,  0.03670+0.00029j],
                               [-4.28429+0.j     ,  2.02009+2.01264j, ...,
                                -0.01981-0.00933j, -0.00536+0.02162j],
                               ...,
                               [-4.92274+0.j     ,  4.09839-9.42525j, ...,
                                 0.00550-0.00257j,  0.00137+0.00577j],
                               [-9.22709+0.j     ,  8.76929+4.0005j , ...,
                                 0.00981-0.00014j, -0.00984+0.00006j]],
                              dtype=complex64)

    """

    def __init__(self, window=np.hanning, fft_size=None, circular_shift=False,
                 **kwargs):
        # pylint: disable=unused-argument
        # caching only, not intended for general use
        self.fft_window = None
        # parameters handed over to every STFT created by `process`
        self.circular_shift = circular_shift
        self.fft_size = fft_size
        self.window = window

    def process(self, data, **kwargs):
        """
        Perform FFT on a framed signal and return the STFT.

        Parameters
        ----------
        data : numpy array
            Data to be processed.
        kwargs : dict, optional
            Keyword arguments passed to :class:`ShortTimeFourierTransform`.

        Returns
        -------
        stft : :class:`ShortTimeFourierTransform`
            :class:`ShortTimeFourierTransform` instance.

        """
        # compute the STFT, re-using the FFT window of a previous call
        transformed = ShortTimeFourierTransform(
            data, window=self.window, fft_size=self.fft_size,
            circular_shift=self.circular_shift, fft_window=self.fft_window,
            **kwargs)
        # remember the window used for FFT for the next call
        # Note: depending on the signal this may be scaled already
        self.fft_window = transformed.fft_window
        return transformed

    @staticmethod
    def add_arguments(parser, window=None, fft_size=None):
        """
        Add STFT related arguments to an existing parser.

        Parameters
        ----------
        parser : argparse parser instance
            Existing argparse parser.
        window : numpy ufunc, optional
            Window function.
        fft_size : int, optional
            Use this size for FFT (should be a power of 2).

        Returns
        -------
        argparse argument group
            STFT argument parser group.

        Notes
        -----
        Parameters are included in the group only if they are not 'None'.

        """
        # add STFT related options to the existing parser
        group = parser.add_argument_group(
            'short-time Fourier transform arguments')
        # each option is only offered if a default value was given
        if window is not None:
            group.add_argument('--window', dest='window',
                               action='store', default=window,
                               help='window function to use for FFT')
        if fft_size is not None:
            group.add_argument('--fft_size', action='store', type=int,
                               default=fft_size,
                               help='use this size for FFT (should be a '
                                    'power of 2) [default=%(default)i]')
        # hand the group back to the caller
        return group


STFTProcessor = ShortTimeFourierTransformProcessor
# phase of STFT
class Phase(_PropertyMixin, np.ndarray):
    """
    Phase class.

    Parameters
    ----------
    stft : :class:`ShortTimeFourierTransform` instance
        :class:`ShortTimeFourierTransform` instance.
    kwargs : dict, optional
        If no :class:`ShortTimeFourierTransform` instance was given, one is
        instantiated with these additional keyword arguments.

    Examples
    --------
    Create a :class:`Phase` from a :class:`ShortTimeFourierTransform` (or
    anything it can be instantiated from):

    >>> stft = ShortTimeFourierTransform('tests/data/audio/sample.wav')
    >>> phase = Phase(stft)
    >>> phase  # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
    Phase([[ 3.14159, -0.85649, ..., -3.14016,  0.00779],
           [ 3.14159,  0.78355, ..., -2.70136,  1.81393],
           ...,
           [ 3.14159, -1.16063, ..., -0.4373 ,  1.33774],
           [ 3.14159,  0.42799, ..., -0.0142 ,  3.13592]], dtype=float32)

    """
    # pylint: disable=super-on-old-class
    # pylint: disable=super-init-not-called
    # pylint: disable=attribute-defined-outside-init

    def __init__(self, stft, **kwargs):
        # this method is for documentation purposes only
        pass

    def __new__(cls, stft, **kwargs):
        # pylint: disable=unused-argument
        # a Phase is recomputed from the STFT it was derived from
        if isinstance(stft, Phase):
            stft = stft.stft
        # everything else which is not a STFT yet is turned into one
        if not isinstance(stft, ShortTimeFourierTransform):
            # use circular shifting unless it was disabled explicitly
            use_circular_shift = kwargs.pop('circular_shift', True)
            stft = ShortTimeFourierTransform(
                stft, circular_shift=use_circular_shift, **kwargs)
        # TODO: just recalculate with circular_shift set?
        if not stft.circular_shift:
            import warnings
            warnings.warn("`circular_shift` of the STFT must be set to 'True' "
                          "for correct phase")
        # compute the phase of the STFT and view it as an instance of this
        # class
        instance = np.asarray(phase(stft)).view(cls)
        # keep a reference to the STFT the phase was computed from
        instance.stft = stft
        return instance

    def __array_finalize__(self, obj):
        if obj is None:
            return
        # set default values here, also needed for views
        self.stft = getattr(obj, 'stft', None)

    @property
    def bin_frequencies(self):
        """Bin frequencies (as given by the underlying STFT)."""
        return self.stft.bin_frequencies

    def local_group_delay(self, **kwargs):
        """
        Returns the local group delay of the phase.

        Parameters
        ----------
        kwargs : dict, optional
            Keyword arguments passed to :class:`LocalGroupDelay`.

        Returns
        -------
        lgd : :class:`LocalGroupDelay` instance
            :class:`LocalGroupDelay` instance.

        """
        return LocalGroupDelay(self, **kwargs)

    # short alias of the method above
    lgd = local_group_delay
# local group delay of STFT
class LocalGroupDelay(_PropertyMixin, np.ndarray):
    """
    Local Group Delay class.

    Parameters
    ----------
    phase : :class:`Phase` instance
        :class:`Phase` instance.
    kwargs : dict, optional
        If no :class:`Phase` instance was given, one is instantiated with
        these additional keyword arguments.

    Examples
    --------
    Create a :class:`LocalGroupDelay` from a :class:`ShortTimeFourierTransform`
    (or anything it can be instantiated from):

    >>> stft = ShortTimeFourierTransform('tests/data/audio/sample.wav')
    >>> lgd = LocalGroupDelay(stft)
    >>> lgd  # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
    LocalGroupDelay([[-2.2851 , -2.25605, ...,  3.13525,  0. ],
                     [ 2.35804,  2.53786, ...,  1.76788,  0. ],
                     ...,
                     [-1.98..., -2.93039, ..., -1.77505,  0. ],
                     [ 2.7136 ,  2.60925, ...,  3.13318,  0. ]])

    """
    # pylint: disable=super-on-old-class
    # pylint: disable=super-init-not-called
    # pylint: disable=attribute-defined-outside-init

    def __init__(self, phase, **kwargs):
        # this method is for documentation purposes only
        pass

    def __new__(cls, phase, **kwargs):
        # pylint: disable=unused-argument
        # instantiate a Phase object if needed; the check must be done on the
        # given `phase` argument (and not on the module-level `stft`
        # function), otherwise an existing Phase is always recomputed
        if not isinstance(phase, Phase):
            phase = Phase(phase, circular_shift=True, **kwargs)
        # a given Phase may have been computed without circular shifting
        if not phase.stft.circular_shift:
            import warnings
            warnings.warn("`circular_shift` of the STFT must be set to 'True' "
                          "for correct local group delay")
        # process the phase and cast the result as LocalGroupDelay
        obj = np.asarray(local_group_delay(phase)).view(cls)
        # save additional attributes
        obj.phase = phase
        obj.stft = phase.stft
        # return the object
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return
        # set default values here, also needed for views
        self.phase = getattr(obj, 'phase', None)
        self.stft = getattr(obj, 'stft', None)

    @property
    def bin_frequencies(self):
        """Bin frequencies (as given by the underlying STFT)."""
        return self.stft.bin_frequencies


LGD = LocalGroupDelay