Source code for mne_nirs.visualisation._plot_GLM_topo

# Authors: Robert Luke <mail@robertluke.net>
#
# License: BSD (3-clause)

import inspect
from copy import deepcopy

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from mne import Info, pick_info
from mne.channels.layout import _merge_ch_data
from mne.io.pick import _picks_to_idx
from mne.utils import warn
from mne.viz import plot_topomap
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable


def _plot_glm_topo(
    inst,
    glm_estimates,
    design_matrix,
    *,
    requested_conditions=None,
    axes=None,
    vlim=None,
    vmin=None,
    vmax=None,
    colorbar=True,
    figsize=(12, 7),
    sphere=None,
):
    """Plot GLM estimates as topomaps, one per channel type and condition.

    Parameters
    ----------
    inst : instance of Info or object with an ``info`` attribute
        Measurement info (or a container holding it) describing the channels.
    glm_estimates : dict
        Mapping from channel name to a regression result exposing ``theta``.
    design_matrix : DataFrame
        Design matrix whose columns name the regressors (conditions).
    requested_conditions : list of str | str | None
        Conditions to plot. If None, every column of ``design_matrix`` is
        plotted. A single string is treated as one condition name.
    axes : instance of Axes | array of Axes | None
        Axes to plot into. If None, a new figure is created with one row per
        channel type and one column per plotted condition.
    vlim : tuple of length 2 | None
        Lower and upper colormap limits. None (or a None entry) lets
        ``_handle_vlim`` derive the limit from the data.
    vmin : float | None
        Deprecated, use ``vlim`` instead.
    vmax : float | None
        Deprecated, use ``vlim`` instead.
    colorbar : bool
        If True, a colorbar is appended to the last axes of each row.
    figsize : tuple
        Figure size used when ``axes`` is None.
    sphere : object
        Forwarded to the topomap preparation and to ``plot_topomap``.

    Returns
    -------
    fig : figure
        The figure that owns ``axes``.

    Raises
    ------
    RuntimeError
        If the channel names of ``inst`` and ``glm_estimates`` differ and
        ``inst`` does not have fewer channels than ``glm_estimates``.
    ValueError
        If none of the requested conditions is a column of ``design_matrix``.
    """
    info = deepcopy(inst if isinstance(inst, Info) else inst.info)

    glm_ch_names = list(glm_estimates.keys())
    if info.ch_names != glm_ch_names:
        if len(info.ch_names) < len(glm_ch_names):
            warn("Reducing GLM results to match MNE data")
            glm_estimates = {a: glm_estimates[a] for a in info.ch_names}
        else:
            raise RuntimeError(
                "MNE data structure does not match regression "
                f"results. Raw = {len(info.ch_names)}. "
                f"GLM = {len(glm_ch_names)}"
            )

    # One row per channel, one column per regressor of the design matrix.
    estimates = np.zeros((len(glm_estimates), len(design_matrix.columns)))
    for idx, result in enumerate(glm_estimates.values()):
        estimates[idx, :] = result.theta.T

    types = np.unique(info.get_channel_types())

    if requested_conditions is None:
        requested_conditions = design_matrix.columns
    elif isinstance(requested_conditions, str):
        # A bare string would otherwise be matched by substring below.
        requested_conditions = [requested_conditions]
    # Keep the design-matrix column order, whatever order was requested.
    requested_conditions = [
        x for x in design_matrix.columns if x in requested_conditions
    ]
    if not requested_conditions:
        raise ValueError(
            "None of the requested conditions are present in the design matrix"
        )

    # Plotting setup
    if axes is None:
        _, axes = plt.subplots(
            nrows=len(types), ncols=len(requested_conditions), figsize=figsize
        )

    keep = [c in requested_conditions for c in design_matrix.columns]
    estimates = estimates[:, keep]

    # Rescale so the values match the "uM" colorbar label below.
    estimates = estimates * 1e6
    design_matrix = design_matrix[requested_conditions]

    # The signature default is None, but _handle_vlim unpacks a 2-tuple.
    if vlim is None:
        vlim = (None, None)
    vlim, vlim_kwargs = _handle_vlim(vlim, vmin, vmax, estimates)
    del vmin, vmax
    cmap = mpl.cm.RdBu_r
    norm = mpl.colors.Normalize(vmin=vlim[0], vmax=vlim[1])

    single_condition = len(requested_conditions) == 1
    single_type = len(types) == 1

    for t_idx, t in enumerate(types):
        estmrg, pos, chs, sphere = _handle_overlaps(info, t, sphere, estimates)

        for idx, label in enumerate(design_matrix.columns):
            # plt.subplots squeezes singleton dimensions, so the indexing
            # depends on how many rows/columns are actually present.
            if single_condition and single_type:
                ax = axes
            elif single_condition:
                ax = axes[t_idx]
            elif single_type:
                ax = axes[idx]
            else:
                ax = axes[t_idx, idx]

            plot_topomap(
                estmrg[:, idx],
                pos,
                extrapolate="local",
                names=chs,
                cmap=cmap,
                axes=ax,
                show=False,
                sphere=sphere,
                **vlim_kwargs,
            )
            ax.set_title(label)

        if colorbar:
            # Attach the colorbar to the last axes drawn for this type.
            ax1_divider = make_axes_locatable(ax)
            cax1 = ax1_divider.append_axes("right", size="7%", pad="2%")
            cbar = mpl.colorbar.ColorbarBase(
                cax1, cmap=cmap, norm=norm, orientation="vertical"
            )
            cbar.set_label("Haemoglobin (uM)", rotation=270)

    return _get_fig_from_axes(axes)


def _plot_glm_contrast_topo(inst, contrast, figsize=(12, 7), sphere=None):
    """Plot the effect of a contrast as one topomap per channel type.

    Parameters
    ----------
    inst : instance of Info or object with an ``info`` attribute
        Measurement info (or a container holding it) describing the channels.
    contrast : object
        Contrast result exposing an ``effect`` array (1D, or 2D with a
        single row).
    figsize : tuple
        Size of the created figure.
    sphere : object
        Forwarded to the topomap preparation and to ``plot_topomap``.

    Returns
    -------
    fig : figure
        The newly created figure.
    """
    info = deepcopy(inst if isinstance(inst, Info) else inst.info)

    # One subplot column is drawn for every channel type found in the info
    ch_types = np.unique(info.get_channel_types())
    n_types = len(ch_types)

    # Values to plot, multiplied by 1e6
    effect = contrast.effect
    if effect.ndim == 2:  # old nilearn
        assert effect.shape[0] == 1
        effect = effect[0]
    effect = effect * 1e6

    fig, axes = plt.subplots(nrows=1, ncols=n_types, figsize=figsize)

    # Colour limits shared by every topomap and by the colorbar
    vlim, vlim_kwargs = _handle_vlim((None, None), None, None, effect)
    cmap = mpl.cm.RdBu_r
    norm = mpl.colors.Normalize(vmin=vlim[0], vmax=vlim[1])

    # Human-readable titles; any other type falls back to its own name
    titles = {"hbo": "Oxyhaemoglobin", "hbr": "Deoxyhaemoglobin"}

    for column, ch_type in enumerate(ch_types):
        merged, pos, ch_names, sphere = _handle_overlaps(
            info, ch_type, sphere, effect
        )

        # With a single type, plt.subplots returns a bare Axes, not an array
        ax = axes if n_types == 1 else axes[column]

        plot_topomap(
            merged,
            pos,
            extrapolate="local",
            names=ch_names,
            cmap=cmap,
            axes=ax,
            show=False,
            sphere=sphere,
            **vlim_kwargs,
        )
        ax.set_title(titles.get(ch_type, ch_type))

    # A single colorbar, attached to the last axes that was drawn
    divider = make_axes_locatable(ax)
    colorbar_ax = divider.append_axes("right", size="7%", pad="2%")
    cbar = mpl.colorbar.ColorbarBase(
        colorbar_ax, cmap=cmap, norm=norm, orientation="vertical"
    )
    cbar.set_label("Contrast Effect", rotation=270)

    return fig


def plot_glm_group_topo(
    inst,
    statsmodel_df,
    value="Coef.",
    axes=None,
    threshold=False,
    *,
    vlim=(None, None),
    vmin=None,
    vmax=None,
    cmap=None,
    sensors=True,
    res=64,
    sphere=None,
    colorbar=True,
    names=False,
    show_names=None,
    extrapolate="local",
    image_interp="cubic",
):
    """
    Plot topomap of NIRS group level GLM results.

    Parameters
    ----------
    inst : instance of Info or Raw
        Raw data or info structure used to generate the GLM results.
    statsmodel_df : DataFrame
        Dataframe created from a statsmodel summary. It must contain the
        columns ``ch_name`` and ``Chroma`` (a single chroma only), the
        column named by ``value``, ``P>|z|`` when ``threshold`` is True,
        and optionally ``Condition`` (a single condition only). The
        dataframe passed in is not modified.
    value : String
        Which column in the `statsmodel_df` to use in the topo map.
    axes : instance of Axes | None
        The axes to plot to. If None, a new figure with a single axes is
        created.
    threshold : Bool
        If threshold is true, all values with P>|z| greater than 0.05 will
        be set to zero.
    vlim : tuple of length 2
        Colormap limits to use. If a :class:`tuple` of floats, specifies the
        lower and upper bounds of the colormap (in that order). If the upper
        bound is ``None`` it is set to the maximum absolute value of the
        data; if the lower bound is ``None`` it is set to the negative of
        the upper bound.
    vmin : float | None
        Deprecated, use 'vlim' instead.
    vmax : float | None
        Deprecated, use 'vlim' instead.
    cmap : matplotlib colormap | None
        Colormap to use. If None, 'RdBu_r' is used.
    sensors : bool | str
        Add markers for sensor locations to the plot. Accepts matplotlib
        plot format string (e.g., 'r+' for red plusses). If True (default),
        circles will be used.
    res : int
        The resolution of the topomap image (n pixels along each side).
    sphere : numbers
        As specified in mne.
    colorbar : bool
        Should a colorbar be plotted.
    names : bool
        Whether the channel names are displayed on the topomap.
    show_names : bool
        Deprecated, use ``names`` instead.
    extrapolate : str
        Type of extrapolation for image.
    image_interp : str
        Type of interpolation for image.

    Returns
    -------
    axes : instance of Axes
        The axes holding the topographic representation of the
        statsmodel_df value.
    """
    info = deepcopy(inst if isinstance(inst, Info) else inst.info)

    if show_names is not None:
        names = show_names
        warn(
            "show_names is deprecated and will be removed in the next "
            "release, use names instead",
            FutureWarning,
        )
    del show_names

    # Check that the channels in two inputs match
    df_ch_names = list(statsmodel_df["ch_name"].values)
    if info.ch_names != df_ch_names:
        if len(info.ch_names) < len(df_ch_names):
            print("Reducing GLM results to match MNE data")
            # Filter with a mask rather than adding a helper column, so the
            # caller's dataframe is left untouched.
            keep = statsmodel_df["ch_name"].isin(info.ch_names)
            statsmodel_df = statsmodel_df[keep]
        else:
            warn("MNE data structure does not match regression results")

    # Order the rows like the channels of the info structure
    statsmodel_df = statsmodel_df.set_index("ch_name")
    statsmodel_df = statsmodel_df.reindex(info.ch_names)

    # Extract estimate of interest to plot. An explicit copy is taken because
    # thresholding writes into the array, and the array backing a dataframe
    # column can be read-only or shared with the dataframe.
    estimates = statsmodel_df[value].to_numpy(copy=True)

    if threshold:
        not_significant = statsmodel_df["P>|z|"].values > 0.05
        estimates[not_significant] = 0.0

    chromas = np.unique(statsmodel_df["Chroma"])
    assert len(chromas) == 1, "Only one Chroma allowed"

    if "Condition" in statsmodel_df.columns:
        conditions = np.unique(statsmodel_df["Condition"])
        assert len(conditions) == 1, "Only one condition allowed"
        c = conditions[0]
    else:
        c = "Contrast"

    t = chromas[0]

    # Plotting setup
    if axes is None:
        _, axes = plt.subplots(nrows=1, ncols=1, figsize=(12, 7))

    # Set limits of topomap and colors
    vlim, vlim_kwargs = _handle_vlim(vlim, vmin, vmax, estimates)
    del vmin, vmax
    if cmap is None:
        cmap = mpl.cm.RdBu_r
    norm = mpl.colors.Normalize(vmin=vlim[0], vmax=vlim[1])

    estmrg, pos, chs, sphere = _handle_overlaps(info, t, sphere, estimates)

    # Choose the keyword arguments by what plot_topomap's signature offers:
    # either blank out the labels via ``names`` or use ``show_names``.
    if "names" in inspect.signature(plot_topomap).parameters:
        names_kwarg = dict(names=chs if names else [""] * len(chs))
    else:
        names_kwarg = dict(show_names=names, names=chs)

    plot_topomap(
        estmrg,
        pos,
        extrapolate=extrapolate,
        image_interp=image_interp,
        cmap=cmap,
        axes=axes,
        sensors=sensors,
        res=res,
        show=False,
        sphere=sphere,
        **vlim_kwargs,
        **names_kwarg,
    )
    axes.set_title(c)

    if colorbar:
        ax1_divider = make_axes_locatable(axes)
        cax1 = ax1_divider.append_axes("right", size="7%", pad="2%")
        cbar = mpl.colorbar.ColorbarBase(
            cax1, cmap=cmap, norm=norm, orientation="vertical"
        )
        cbar.set_label(value, rotation=270)

    return axes


def _handle_overlaps(info, t, sphere, estimates):
    """Prepare for topomap including merging channels.

    Parameters
    ----------
    info : instance of Info
        Measurement info describing all channels.
    t : str
        Channel type to select.
    sphere : object
        Forwarded to ``_prepare_topomap_plot``.
    estimates : array
        Values with one entry (or row) per channel of ``info``.

    Returns
    -------
    estmrg : array
        Estimates of the selected channels as returned by ``_merge_ch_data``.
    pos : array
        Positions returned by ``_prepare_topomap_plot``.
    ch_names : list
        Channel names returned by ``_merge_ch_data``.
    sphere : object
        Sphere returned by ``_prepare_topomap_plot``.
    """
    from mne.viz.topomap import _prepare_topomap_plot

    picks = _picks_to_idx(info, t, exclude=[], allow_empty=True)
    info_subset = pick_info(info, picks)
    # Only the positions, names and sphere are needed from the 7-tuple.
    _, pos, _, ch_names, _, sphere, _ = _prepare_topomap_plot(
        info_subset, t, sphere=sphere
    )
    # Copy first so the merge cannot alter the caller's array.
    estmrg, ch_names = _merge_ch_data(estimates.copy()[picks], t, ch_names)
    return estmrg, pos, ch_names, sphere


def _get_fig_from_axes(ax):
    """Return the figure owning an Axes or a (nested) array of Axes.

    Raises
    ------
    RuntimeError
        If ``ax`` is neither an Axes nor a non-empty array of Axes.
    """
    # Test against Axes: recent matplotlib no longer has a SubplotBase class,
    # and every subplot is an Axes in all versions.
    if isinstance(ax, mpl.axes.Axes):
        return ax.figure
    if isinstance(ax, np.ndarray) and ax.size:
        # ``flat`` reaches the first element whatever the dimensionality.
        return _get_fig_from_axes(ax.flat[0])
    raise RuntimeError(f"Unable to extract figure from {ax}")


def _handle_vlim(vlim, vmin, vmax, estimates):
    """Resolve colormap limits and build the matching plot_topomap kwargs.

    Parameters
    ----------
    vlim : tuple of length 2 | None
        Requested (lower, upper) limits; None entries are derived from data.
    vmin : float | None
        Deprecated lower limit. If either ``vmin`` or ``vmax`` is given, they
        take precedence over ``vlim`` and a FutureWarning is emitted.
    vmax : float | None
        Deprecated upper limit.
    estimates : array
        Data used to derive missing limits.

    Returns
    -------
    vlim : tuple of length 2
        The resolved (lower, upper) limits, with no None entries, so callers
        can build a colorbar norm that matches the topomap.
    kwargs : dict
        Either ``vlim=(lower, upper)`` or ``vmin=..., vmax=...``, chosen by
        which parameters the ``plot_topomap`` signature exposes.
    """
    if vmin is not None or vmax is not None:
        warn(
            "vmin and vmax are deprecated and will be removed in the next "
            "release, please use vlim instead",
            FutureWarning,
        )
    elif vlim is not None:
        vmin, vmax = vlim

    # Missing limits default to a range symmetric about zero.
    if vmax is None:
        vmax = np.max(np.abs(estimates))
    if vmin is None:
        vmin = vmax * -1.0

    # Return the resolved limits rather than the raw request, which may
    # still contain None.
    vlim = (vmin, vmax)
    if "vlim" in inspect.signature(plot_topomap).parameters:
        kwargs = dict(vlim=vlim)
    else:
        kwargs = dict(vmin=vmin, vmax=vmax)
    return vlim, kwargs