tfmri.layers.DWT2D

class DWT2D(*args, **kwargs)

Bases: tensorflow_mri.python.layers.signal_layers.DWT

Single-level 2D discrete wavelet transform (DWT) layer.

The input must be a tensor of shape [batch_size, height, width, channels].

The output format is controlled by the format_dict argument. If format_dict is True (the default), the output is a dict with keys 'aa', 'ad', 'da' and 'dd', where 'a' stands for approximation and 'd' for detail. Each value is a tensor of shape [batch_size, out_height, out_width, channels]. Each output spatial dimension is given by out_dim = (in_dim + filter_len - 1) // 2, where filter_len is the length of the decomposition filters of the selected wavelet. If format_dict is False, the output is a list of tensors corresponding to the keys above, in that order.
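To make the shape rule concrete, here is a minimal NumPy sketch of a single-level 2D Haar DWT applied to a [batch_size, height, width, channels] array. It is not the tfmri implementation. It assumes the Haar decomposition filters, approximates mode='symmetric' with np.pad(..., mode="symmetric"), and uses a PyWavelets-style scheme (full convolution of the extended signal, then keeping the odd-indexed samples). It also assumes that the first letter of each key refers to the height axis and the second to the width axis. The names haar_dwt2d and dwt_output_length are hypothetical.

```python
import numpy as np

# Assumption: Haar decomposition filters (low-pass, high-pass) as in PyWavelets.
HAAR_DEC_LO = np.array([1.0, 1.0]) / np.sqrt(2.0)
HAAR_DEC_HI = np.array([-1.0, 1.0]) / np.sqrt(2.0)


def dwt_output_length(in_dim, filter_len):
    """Hypothetical helper: out_dim = (in_dim + filter_len - 1) // 2."""
    return (in_dim + filter_len - 1) // 2


def _dwt_1d(x, filt):
    """Single-level 1D analysis: extend, convolve, keep odd-indexed samples."""
    n_ext = len(filt) - 1
    # Half-sample symmetric extension, standing in for mode='symmetric'.
    padded = np.pad(x, n_ext, mode="symmetric")
    # Full convolution of the extended signal: in_dim + filter_len - 1 samples.
    full = np.convolve(padded, filt, mode="valid")
    # Downsample by 2, leaving (in_dim + filter_len - 1) // 2 samples.
    return full[1::2]


def haar_dwt2d(images):
    """Hypothetical sketch of a single-level 2D Haar DWT on [batch, height, width, channels]."""
    coeffs = {}
    for key_h, filt_h in (("a", HAAR_DEC_LO), ("d", HAAR_DEC_HI)):
        # Filter and downsample along the height axis (axis 1).
        tmp = np.apply_along_axis(_dwt_1d, 1, images, filt_h)
        for key_w, filt_w in (("a", HAAR_DEC_LO), ("d", HAAR_DEC_HI)):
            # Then filter and downsample along the width axis (axis 2).
            coeffs[key_h + key_w] = np.apply_along_axis(_dwt_1d, 2, tmp, filt_w)
    return coeffs


images = np.random.default_rng(0).standard_normal((2, 5, 8, 3))
coeffs = haar_dwt2d(images)
print(sorted(coeffs))  # ['aa', 'ad', 'da', 'dd']
print(coeffs["aa"].shape)  # (2, 3, 4, 3)
print(dwt_output_length(5, 2), dwt_output_length(8, 2))  # 3 4
```

With a Haar filter (filter_len = 2), a height of 5 maps to (5 + 1) // 2 = 3 and a width of 8 maps to 4. For a constant input, this sketch gives an 'aa' subband equal to 2.0 everywhere and zero detail subbands, because the Haar high-pass filter sums to zero.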

Parameters
  • wavelet

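    A str or a length-2 list of str naming the wavelet. If a list is passed, each element specifies the wavelet applied along the corresponding spatial axis, so different wavelets can be used along each axis.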

  • mode

    A str. The padding or signal extension mode. Must be one of the values supported by tfmri.signal.dwt. Defaults to 'symmetric'.

  • format_dict

    A boolean. If True, the output is a dict; otherwise, it is a list. Defaults to True.
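The effect of the wavelet and format_dict parameters can be illustrated with a small hypothetical sketch of the argument handling they describe. The helpers normalize_wavelet and format_output are not part of tfmri. The sketch assumes that a list-valued wavelet is ordered (height, width) and that the list returned when format_dict is False follows the key order 'aa', 'ad', 'da', 'dd'.

```python
from typing import Dict, Sequence, Tuple, Union

import numpy as np

# Assumption: list output follows this key order.
KEYS = ("aa", "ad", "da", "dd")


def normalize_wavelet(wavelet: Union[str, Sequence[str]]) -> Tuple[str, str]:
    """Hypothetical: expand `wavelet` into one name per spatial axis (height, width)."""
    if isinstance(wavelet, str):
        # A single str applies the same wavelet along both axes.
        return (wavelet, wavelet)
    wavelet = tuple(wavelet)
    if len(wavelet) != 2:
        raise ValueError(f"expected a str or a length-2 list of str, got {wavelet!r}")
    return wavelet


def format_output(coeffs: Dict[str, np.ndarray], format_dict: bool = True):
    """Hypothetical: return subbands as a dict, or as a list in KEYS order."""
    if format_dict:
        return coeffs
    return [coeffs[key] for key in KEYS]


print(normalize_wavelet("haar"))  # ('haar', 'haar')
print(normalize_wavelet(["db2", "haar"]))  # ('db2', 'haar')
```

Keeping the dict as the default makes each subband addressable by name, while the list form is convenient when downstream code expects a fixed positional order.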