tfmri.layers.DWT3D

class DWT3D(*args, **kwargs)

Bases: tensorflow_mri.python.layers.signal_layers.DWT

Single-level 3D discrete wavelet transform (DWT) layer.

The input must be a tensor of shape [batch_size, depth, height, width, channels].

The output format is determined by the format_dict argument. If format_dict is True (the default), the output is a dict with the eight keys 'aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda' and 'ddd', where 'a' denotes approximation (low-pass) coefficients and 'd' denotes detail (high-pass) coefficients. If format_dict is False, the output is a list of eight tensors, one for each of these keys.

Each output tensor has shape [batch_size, out_depth, out_height, out_width, channels]. The size of each spatial output dimension is out_dim = (in_dim + filter_len - 1) // 2, where filter_len is the length of the decomposition filters of the wavelet applied along that axis.
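The sketch below is a hypothetical, dependency-free illustration of what a single-level 3D DWT computes for the 'haar' wavelet (filter length 2). It is not the tfmri implementation. It assumes PyWavelets' Haar conventions (low-pass (x0 + x1) / sqrt(2), high-pass (x0 - x1) / sqrt(2)) and that the i-th letter of each key refers to the i-th spatial axis. It also ignores the batch and channel dimensions and only accepts even sizes, where no signal extension is needed. Because the Haar filters have length 2, every output coefficient depends on a single 2×2×2 block of the input, so the separable transform reduces to a signed sum over each block.

```python
import itertools
import math


def haar_dwt3d(volume):
    """Hypothetical single-level 3D Haar DWT of a [depth][height][width] nested list.

    Returns a dict keyed 'aaa', 'aad', ..., 'ddd'. Assumption: the i-th letter
    of a key refers to the i-th spatial axis (depth, height, width).
    Only even sizes are handled, so each output axis has in_dim // 2 samples,
    which equals (in_dim + 2 - 1) // 2 for the length-2 Haar filters.
    """
    depth, height, width = len(volume), len(volume[0]), len(volume[0][0])
    if depth % 2 or height % 2 or width % 2:
        raise ValueError("this sketch only handles even sizes")
    # One factor of 1/sqrt(2) per axis, i.e. (1/sqrt(2)) ** 3.
    scale = 1.0 / (2.0 * math.sqrt(2.0))
    # itertools.product yields the keys in the documented order: aaa, aad, ..., ddd.
    keys = ["".join(k) for k in itertools.product("ad", repeat=3)]
    out = {
        k: [[[0.0] * (width // 2) for _ in range(height // 2)]
            for _ in range(depth // 2)]
        for k in keys
    }
    positions = itertools.product(
        range(depth // 2), range(height // 2), range(width // 2))
    for z, y, x in positions:
        for key in keys:
            total = 0.0
            # Visit the 2x2x2 block of input samples that feeds this coefficient.
            for offsets in itertools.product((0, 1), repeat=3):
                sign = 1.0
                for letter, off in zip(key, offsets):
                    # Low-pass adds both samples; high-pass subtracts the second one.
                    if letter == "d" and off == 1:
                        sign = -sign
                dz, dy, dx = offsets
                total += sign * volume[2 * z + dz][2 * y + dy][2 * x + dx]
            out[key][z][y][x] = total * scale
    return out


# Usage: a 4x4x4 volume yields eight 2x2x2 subbands.
coeffs = haar_dwt3d([[[float(x) for x in range(4)] for _ in range(4)] for _ in range(4)])
print(sorted(coeffs))
print(len(coeffs["aaa"]), len(coeffs["aaa"][0]), len(coeffs["aaa"][0][0]))
```

For a 4×4×4 input, each subband is 2×2×2, consistent with (4 + 2 - 1) // 2 = 2 per axis. Since this input varies only along the last axis, only the 'aaa' and 'aad' subbands are nonzero under the assumed key convention.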

Parameters
  • wavelet

    A str or a list of three str. If a list is passed, a different wavelet can be applied along each spatial axis (one wavelet per axis).

  • mode

    A str. The padding or signal extension mode. Must be one of the values supported by tfmri.signal.dwt. Defaults to 'symmetric'.

  • format_dict

    A boolean. If True, the output is a dict; otherwise, it is a list. Defaults to True.
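To predict the output shapes before building a model, the size formula can be applied per spatial axis. The helper below, dwt3d_output_shapes, is hypothetical and not part of tfmri. It takes decomposition-filter lengths instead of wavelet names (for example, 2 for 'haar' and 4 for 'db2') and assumes that, when wavelet is a list, its i-th entry applies to the i-th spatial axis. It also assumes that the list returned with format_dict=False follows the documented key order.

```python
import itertools


def dwt3d_output_shapes(input_shape, filter_lens, format_dict=True):
    """Hypothetical helper: predicts the output shapes of a single-level 3D DWT.

    input_shape: (batch_size, depth, height, width, channels).
    filter_lens: one decomposition-filter length per spatial axis, e.g.
        (2, 2, 2) for 'haar' on every axis, or (2, 4, 4) for 'haar' along
        depth and 'db2' along height and width (assumed axis order).
    """
    batch_size, *spatial, channels = input_shape
    if len(spatial) != 3 or len(filter_lens) != 3:
        raise ValueError("expected 3 spatial dims and 3 filter lengths")
    # out_dim = (in_dim + filter_len - 1) // 2, applied to each spatial axis.
    out_spatial = [(n + f - 1) // 2 for n, f in zip(spatial, filter_lens)]
    # Batch and channel dimensions are left unchanged.
    shape = (batch_size, *out_spatial, channels)
    keys = ["".join(k) for k in itertools.product("ad", repeat=3)]
    if format_dict:
        return {k: shape for k in keys}
    # All eight subbands share one shape; the list has one entry per key.
    return [shape for _ in keys]


print(dwt3d_output_shapes((1, 16, 32, 32, 1), (2, 4, 4))["ddd"])
```

Here the depth axis gives (16 + 2 - 1) // 2 = 8 and each in-plane axis gives (32 + 4 - 1) // 2 = 17, so the longer 'db2' filters produce one extra coefficient compared with halving the size.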