Einsum Expressions

einconv.expressions.convNd_forward

Generates einsum expression of the forward pass of a convolution.

einsum_expression

einsum_expression(x: Tensor, weight: Union[Tensor, Parameter], stride: Union[int, Tuple[int, ...]] = 1, padding: Union[int, str, Tuple[int, ...]] = 0, dilation: Union[int, Tuple[int, ...]] = 1, groups: int = 1, simplify: bool = True) -> Tuple[str, List[Union[Tensor, Parameter]], Tuple[int, ...]]

Generate einsum expression of a convolution's forward pass.

Parameters:

  • x (Tensor) –

    Convolution input. Has shape [batch_size, in_channels, *input_sizes] where len(input_sizes) == N.

  • weight (Union[Tensor, Parameter]) –

    Kernel of the convolution. Has shape [out_channels, in_channels / groups, *kernel_size] where kernel_size is an N-tuple of kernel dimensions.

  • stride (Union[int, Tuple[int, ...]], default: 1 ) –

    Stride of the convolution. Can be a single integer (shared along all spatial dimensions), or an N-tuple of integers. Default: 1.

  • padding (Union[int, str, Tuple[int, ...]], default: 0 ) –

    Padding of the convolution. Can be a single integer (shared along all spatial dimensions), an N-tuple of integers, or a string. Default: 0. Allowed strings are 'same' and 'valid'.

  • dilation (Union[int, Tuple[int, ...]], default: 1 ) –

    Dilation of the convolution. Can be a single integer (shared along all spatial dimensions), or an N-tuple of integers. Default: 1.

  • groups (int, default: 1 ) –

    In how many groups to split the input channels. Default: 1.

  • simplify (bool, default: True ) –

    Whether to simplify the einsum expression. Default: True.

Returns:

  • str

    Einsum equation

  • List[Union[Tensor, Parameter]]

    Einsum operands in order un-grouped input, patterns, un-grouped weight

  • Tuple[int, ...]

    Output shape: [batch_size, out_channels, *output_sizes].
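
The role of the index patterns among the operands can be sketched in plain Python for a single spatial dimension. This is only an illustration under stated assumptions: `index_pattern` is a hypothetical helper (not einconv's `create_conv_index_patterns`), it returns a nested list indexed as `[k][o][i]` to match the `k o i` index string, and it assumes the usual convolution relation `i = o * stride + k * dilation - padding`.

```python
def index_pattern(input_size, kernel_size, stride=1, padding=0, dilation=1):
    """Hypothetical helper: nested [K][O][I] list of 0/1 entries for one dimension."""
    # Standard convolution output size.
    output_size = (
        input_size + 2 * padding - dilation * (kernel_size - 1) - 1
    ) // stride + 1
    return [
        [
            # Assumed relation: kernel tap k at output o reads input i = o*s + k*d - p.
            [
                1 if i == o * stride + k * dilation - padding else 0
                for i in range(input_size)
            ]
            for o in range(output_size)
        ]
        for k in range(kernel_size)
    ]


x = [1.0, 2.0, 3.0, 4.0]  # one sample, one channel, input size 4
w = [1.0, 10.0]  # kernel size 2
pattern = index_pattern(len(x), len(w))
num_outputs = len(pattern[0])

# Contraction "i, k o i, k -> o", written as explicit sums.
y = [
    sum(x[i] * pattern[k][o][i] * w[k] for k in range(len(w)) for i in range(len(x)))
    for o in range(num_outputs)
]
print(y)  # [21.0, 32.0, 43.0]
```

Batch, group, and channel indices are left out here; in the full expression they are carried along as additional einsum indices.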

Source code in einconv/expressions/convNd_forward.py
def einsum_expression(
    x: Tensor,
    weight: Union[Tensor, Parameter],
    stride: Union[int, Tuple[int, ...]] = 1,
    padding: Union[int, str, Tuple[int, ...]] = 0,
    dilation: Union[int, Tuple[int, ...]] = 1,
    groups: int = 1,
    simplify: bool = True,
) -> Tuple[str, List[Union[Tensor, Parameter]], Tuple[int, ...]]:
    """Generate einsum expression of a convolution's forward pass.

    Args:
        x: Convolution input. Has shape ``[batch_size, in_channels, *input_sizes]``
            where ``len(input_sizes) == N``.
        weight: Kernel of the convolution. Has shape ``[out_channels,
            in_channels / groups, *kernel_size]`` where ``kernel_size`` is an
            ``N``-tuple of kernel dimensions.
        stride: Stride of the convolution. Can be a single integer (shared along all
            spatial dimensions), or an ``N``-tuple of integers. Default: ``1``.
        padding: Padding of the convolution. Can be a single integer (shared along
            all spatial dimensions), an ``N``-tuple of integers, or a string.
            Default: ``0``. Allowed strings are ``'same'`` and ``'valid'``.
        dilation: Dilation of the convolution. Can be a single integer (shared along
            all spatial dimensions), or an ``N``-tuple of integers. Default: ``1``.
        groups: In how many groups to split the input channels. Default: ``1``.
        simplify: Whether to simplify the einsum expression. Default: ``True``.

    Returns:
        Einsum equation
        Einsum operands in order un-grouped input, patterns, un-grouped weight
        Output shape: ``[batch_size, out_channels, *output_sizes]``.
    """
    N = x.dim() - 2

    # construct einsum equation
    x_str = "n g c_in " + " ".join([f"i{i}" for i in range(N)])
    pattern_strs: List[str] = [f"k{i} o{i} i{i}" for i in range(N)]
    weight_str = "g c_out c_in " + " ".join([f"k{i}" for i in range(N)])
    lhs = ",".join([x_str, *pattern_strs, weight_str])

    rhs = "n g c_out " + " ".join([f"o{i}" for i in range(N)])

    equation = "->".join([lhs, rhs])
    equation = translate_to_torch(equation)

    # construct einsum operands
    patterns = create_conv_index_patterns(
        N,
        input_size=x.shape[2:],
        kernel_size=weight.shape[2:],
        stride=stride,
        padding=padding,
        dilation=dilation,
        device=weight.device,
        dtype=weight.dtype,
    )
    x_ungrouped = rearrange(x, "n (g c_in) ... -> n g c_in ...", g=groups)
    weight_ungrouped = rearrange(weight, "(g c_out) ... -> g c_out ...", g=groups)
    operands = [x_ungrouped, *patterns, weight_ungrouped]

    # construct output shape
    output_size = [p.shape[1] for p in patterns]
    batch_size = x.shape[0]
    out_channels = weight.shape[0]
    shape = (batch_size, out_channels, *output_size)

    if simplify:
        equation, operands = einconv.simplify(equation, operands)

    return equation, operands, shape
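
The string-building step in the listing above can be reproduced on its own to see what the human-readable equation looks like before it is passed to `translate_to_torch`. The function name `forward_equation` is made up for this sketch; the string operations mirror the source.

```python
def forward_equation(N):
    """Build the readable einsum equation of an N-dimensional convolution."""
    x_str = "n g c_in " + " ".join(f"i{i}" for i in range(N))
    pattern_strs = [f"k{i} o{i} i{i}" for i in range(N)]
    weight_str = "g c_out c_in " + " ".join(f"k{i}" for i in range(N))
    lhs = ",".join([x_str, *pattern_strs, weight_str])
    rhs = "n g c_out " + " ".join(f"o{i}" for i in range(N))
    return "->".join([lhs, rhs])


equation_2d = forward_equation(2)
print(equation_2d)
# n g c_in i0 i1,k0 o0 i0,k1 o1 i1,g c_out c_in k0 k1->n g c_out o0 o1
```

Each spatial dimension contributes one pattern operand `k{i} o{i} i{i}`, so the number of operands grows linearly with `N`.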

einconv.expressions.convNd_input_vjp

Generates einsum expression of the input VJP of a convolution.

einsum_expression

einsum_expression(weight: Union[Tensor, Parameter], v: Tensor, input_size: Union[int, Tuple[int, ...]], stride: Union[int, Tuple[int, ...]] = 1, padding: Union[int, Tuple[int, ...]] = 0, dilation: Union[int, Tuple[int, ...]] = 1, groups: int = 1, simplify: bool = True) -> Tuple[str, List[Union[Tensor, Parameter]], Tuple[int, ...]]

Generate einsum expression of a convolution's input VJP.

Parameters:

  • weight (Union[Tensor, Parameter]) –

    Kernel of the convolution. Has shape [out_channels, in_channels / groups, *kernel_size] where kernel_size is an N-tuple of kernel dimensions.

  • v (Tensor) –

    Vector multiplied by the Jacobian. Has shape [batch_size, out_channels, *output_sizes] where len(output_sizes) == N (same shape as the convolution's output).

  • input_size (Union[int, Tuple[int, ...]]) –

    Spatial dimensions of the convolution's input. Can be a single integer (shared along all spatial dimensions), or an N-tuple of integers.

  • stride (Union[int, Tuple[int, ...]], default: 1 ) –

    Stride of the convolution. Can be a single integer (shared along all spatial dimensions), or an N-tuple of integers. Default: 1.

  • padding (Union[int, Tuple[int, ...]], default: 0 ) –

    Padding of the convolution. Can be a single integer (shared along all spatial dimensions), an N-tuple of integers, or a string. Default: 0. Allowed strings are 'same' and 'valid'.

  • dilation (Union[int, Tuple[int, ...]], default: 1 ) –

    Dilation of the convolution. Can be a single integer (shared along all spatial dimensions), or an N-tuple of integers. Default: 1.

  • groups (int, default: 1 ) –

    In how many groups to split the input channels. Default: 1.

  • simplify (bool, default: True ) –

    Whether to simplify the einsum expression. Default: True.

Returns:

  • str

    Einsum equation

  • List[Union[Tensor, Parameter]]

    Einsum operands in order un-grouped vector, patterns, un-grouped weight

  • Tuple[int, ...]

    Output shape: [batch_size, in_channels, *input_sizes]

Source code in einconv/expressions/convNd_input_vjp.py
def einsum_expression(
    weight: Union[Tensor, Parameter],
    v: Tensor,
    input_size: Union[int, Tuple[int, ...]],
    stride: Union[int, Tuple[int, ...]] = 1,
    padding: Union[int, Tuple[int, ...]] = 0,
    dilation: Union[int, Tuple[int, ...]] = 1,
    groups: int = 1,
    simplify: bool = True,
) -> Tuple[str, List[Union[Tensor, Parameter]], Tuple[int, ...]]:
    """Generate einsum expression of a convolution's input VJP.

    Args:
        weight: Kernel of the convolution. Has shape ``[out_channels,
            in_channels / groups, *kernel_size]`` where ``kernel_size`` is an
            ``N``-tuple of kernel dimensions.
        v: Vector multiplied by the Jacobian. Has shape
            ``[batch_size, out_channels, *output_sizes]``
            where ``len(output_sizes) == N`` (same shape as the convolution's output).
        input_size: Spatial dimensions of the convolution's input. Can be a single
            integer (shared along all spatial dimensions), or an ``N``-tuple of
            integers.
        stride: Stride of the convolution. Can be a single integer (shared along all
            spatial dimensions), or an ``N``-tuple of integers. Default: ``1``.
        padding: Padding of the convolution. Can be a single integer (shared along
            all spatial dimensions), an ``N``-tuple of integers, or a string.
            Default: ``0``. Allowed strings are ``'same'`` and ``'valid'``.
        dilation: Dilation of the convolution. Can be a single integer (shared along
            all spatial dimensions), or an ``N``-tuple of integers. Default: ``1``.
        groups: In how many groups to split the input channels. Default: ``1``.
        simplify: Whether to simplify the einsum expression. Default: ``True``.

    Returns:
        Einsum equation
        Einsum operands in order un-grouped vector, patterns, un-grouped weight
        Output shape: ``[batch_size, in_channels, *input_sizes]``
    """
    N = weight.dim() - 2

    # construct einsum equation
    v_str = "n g c_out " + " ".join([f"o{i}" for i in range(N)])
    pattern_strs: List[str] = [f"k{i} o{i} i{i}" for i in range(N)]
    weight_str = "g c_out c_in " + " ".join([f"k{i}" for i in range(N)])
    lhs = ",".join([v_str, *pattern_strs, weight_str])

    rhs = "n g c_in " + " ".join([f"i{i}" for i in range(N)])

    equation = "->".join([lhs, rhs])
    equation = translate_to_torch(equation)

    # construct einsum operands
    patterns = create_conv_index_patterns(
        N,
        input_size,
        kernel_size=weight.shape[2:],
        stride=stride,
        padding=padding,
        dilation=dilation,
        device=weight.device,
        dtype=weight.dtype,
    )
    v_ungrouped = rearrange(v, "n (g c_out) ... -> n g c_out ...", g=groups)
    weight_ungrouped = rearrange(weight, "(g c_out) ... -> g c_out ...", g=groups)
    operands = [v_ungrouped, *patterns, weight_ungrouped]

    # construct output shape
    batch_size = v.shape[0]
    group_in_channels = weight.shape[1]
    t_input_size = _tuple(input_size, N)
    shape = (batch_size, groups * group_in_channels, *t_input_size)

    if simplify:
        equation, operands = einconv.simplify(equation, operands)

    return equation, operands, shape
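
The input VJP reuses the same index patterns as the forward pass and only changes which index is summed out. A plain-Python sketch for one spatial dimension can check this through the adjoint identity `<v, conv(x)> == <vjp(v), x>`. The helper `index_pattern` is hypothetical and assumes the relation `i = o * stride + k * dilation - padding`; it is not einconv's implementation.

```python
def index_pattern(input_size, kernel_size, stride=1, padding=0, dilation=1):
    """Hypothetical helper: nested [K][O][I] list of 0/1 entries for one dimension."""
    output_size = (
        input_size + 2 * padding - dilation * (kernel_size - 1) - 1
    ) // stride + 1
    return [
        [
            [
                1 if i == o * stride + k * dilation - padding else 0
                for i in range(input_size)
            ]
            for o in range(output_size)
        ]
        for k in range(kernel_size)
    ]


x = [1.0, 2.0, 3.0, 4.0, 5.0]
w = [1.0, -1.0, 2.0]
pattern = index_pattern(len(x), len(w), stride=2)
K, O, I = len(pattern), len(pattern[0]), len(pattern[0][0])
v = [1.0, 0.5]  # one entry per output position (O == 2)

# Forward pass "i, k o i, k -> o".
y = [
    sum(x[i] * pattern[k][o][i] * w[k] for k in range(K) for i in range(I))
    for o in range(O)
]
# Input VJP "o, k o i, k -> i": same pattern, but o is summed out instead of i.
grad_x = [
    sum(v[o] * pattern[k][o][i] * w[k] for k in range(K) for o in range(O))
    for i in range(I)
]

inner_output = sum(v[o] * y[o] for o in range(O))
inner_input = sum(grad_x[i] * x[i] for i in range(I))
print(grad_x)  # [1.0, -1.0, 2.5, -0.5, 1.0]
print(inner_output, inner_input)  # 9.5 9.5
```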

einconv.expressions.convNd_weight_vjp

Generates einsum expression of the weight VJP of a convolution.

einsum_expression

einsum_expression(x: Tensor, v: Tensor, kernel_size: Union[int, Tuple[int, ...]], dilation: Union[int, Tuple[int, ...]] = 1, padding: Union[int, Tuple[int, ...]] = 0, stride: Union[int, Tuple[int, ...]] = 1, groups: int = 1, simplify: bool = True) -> Tuple[str, List[Tensor], Tuple[int, ...]]

Generate einsum expression of a convolution's weight VJP.

Parameters:

  • x (Tensor) –

    Convolution input. Has shape [batch_size, in_channels, *input_sizes] where len(input_sizes) == N.

  • v (Tensor) –

    Vector multiplied by the Jacobian. Has shape [batch_size, out_channels, *output_sizes] where len(output_sizes) == N (same shape as the convolution's output).

  • kernel_size (Union[int, Tuple[int, ...]]) –

    Kernel dimensions. Can be a single integer (shared along all spatial dimensions), or an N-tuple of integers.

  • stride (Union[int, Tuple[int, ...]], default: 1 ) –

    Stride of the convolution. Can be a single integer (shared along all spatial dimensions), or an N-tuple of integers. Default: 1.

  • padding (Union[int, Tuple[int, ...]], default: 0 ) –

    Padding of the convolution. Can be a single integer (shared along all spatial dimensions), an N-tuple of integers, or a string. Default: 0. Allowed strings are 'same' and 'valid'.

  • dilation (Union[int, Tuple[int, ...]], default: 1 ) –

    Dilation of the convolution. Can be a single integer (shared along all spatial dimensions), or an N-tuple of integers. Default: 1.

  • groups (int, default: 1 ) –

    In how many groups to split the input channels. Default: 1.

  • simplify (bool, default: True ) –

    Whether to simplify the einsum expression. Default: True.

Returns:

  • str

    Einsum equation

  • List[Tensor]

    Einsum operands in order ungrouped input, patterns, ungrouped vector.

  • Tuple[int, ...]

    Output shape: [out_channels, in_channels // groups, *kernel_size]

Source code in einconv/expressions/convNd_weight_vjp.py
def einsum_expression(
    x: Tensor,
    v: Tensor,
    kernel_size: Union[int, Tuple[int, ...]],
    dilation: Union[int, Tuple[int, ...]] = 1,
    padding: Union[int, Tuple[int, ...]] = 0,
    stride: Union[int, Tuple[int, ...]] = 1,
    groups: int = 1,
    simplify: bool = True,
) -> Tuple[str, List[Tensor], Tuple[int, ...]]:
    """Generate einsum expression of a convolution's weight VJP.

    Args:
        x: Convolution input. Has shape ``[batch_size, in_channels, *input_sizes]``
            where ``len(input_sizes) == N``.
        v: Vector multiplied by the Jacobian. Has shape
            ``[batch_size, out_channels, *output_sizes]``
            where ``len(output_sizes) == N`` (same shape as the convolution's output).
        kernel_size: Kernel dimensions. Can be a single integer (shared along all
            spatial dimensions), or an ``N``-tuple of integers.
        stride: Stride of the convolution. Can be a single integer (shared along all
            spatial dimensions), or an ``N``-tuple of integers. Default: ``1``.
        padding: Padding of the convolution. Can be a single integer (shared along
            all spatial dimensions), an ``N``-tuple of integers, or a string.
            Default: ``0``. Allowed strings are ``'same'`` and ``'valid'``.
        dilation: Dilation of the convolution. Can be a single integer (shared along
            all spatial dimensions), or an ``N``-tuple of integers. Default: ``1``.
        groups: In how many groups to split the input channels. Default: ``1``.
        simplify: Whether to simplify the einsum expression. Default: ``True``.

    Returns:
        Einsum equation
        Einsum operands in order ungrouped input, patterns, ungrouped vector.
        Output shape: ``[out_channels, in_channels // groups, *kernel_size]``
    """
    N = x.dim() - 2

    # construct einsum equation
    v_str = "n g c_out " + " ".join([f"o{i}" for i in range(N)])
    x_str = "n g c_in " + " ".join([f"i{i}" for i in range(N)])
    pattern_strs: List[str] = [f"k{i} o{i} i{i}" for i in range(N)]
    lhs = ",".join([x_str, *pattern_strs, v_str])

    rhs = "g c_out c_in " + " ".join([f"k{i}" for i in range(N)])

    equation = "->".join([lhs, rhs])
    equation = translate_to_torch(equation)

    # construct einsum operands
    patterns = create_conv_index_patterns(
        N,
        input_size=x.shape[2:],
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        device=x.device,
        dtype=x.dtype,
    )
    x_ungrouped = rearrange(x, "n (g c_in) ... -> n g c_in ...", g=groups)
    v_ungrouped = rearrange(v, "n (g c_out) ... -> n g c_out ...", g=groups)
    operands = [x_ungrouped, *patterns, v_ungrouped]

    # construct output shape
    in_channels = x.shape[1]
    out_channels = v.shape[1]
    t_kernel_size = _tuple(kernel_size, N)
    shape = (out_channels, in_channels // groups, *t_kernel_size)

    if simplify:
        equation, operands = einconv.simplify(equation, operands)

    return equation, operands, shape
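
For the weight VJP, the input and the vector are contracted with the patterns and the kernel index is kept. The following plain-Python sketch does this for one spatial dimension with dilation 2 and checks the adjoint identity `<v, conv(x, w)> == <weight_vjp(x, v), w>`. As before, `index_pattern` is a hypothetical stand-in that assumes `i = o * stride + k * dilation - padding`.

```python
def index_pattern(input_size, kernel_size, stride=1, padding=0, dilation=1):
    """Hypothetical helper: nested [K][O][I] list of 0/1 entries for one dimension."""
    output_size = (
        input_size + 2 * padding - dilation * (kernel_size - 1) - 1
    ) // stride + 1
    return [
        [
            [
                1 if i == o * stride + k * dilation - padding else 0
                for i in range(input_size)
            ]
            for o in range(output_size)
        ]
        for k in range(kernel_size)
    ]


x = [1.0, 2.0, 3.0, 4.0, 5.0]
pattern = index_pattern(len(x), 2, dilation=2)
K, O, I = len(pattern), len(pattern[0]), len(pattern[0][0])
v = [1.0, 2.0, 3.0]  # one entry per output position (O == 3)

# Weight VJP "i, k o i, o -> k": both i and o are summed out.
grad_w = [
    sum(x[i] * pattern[k][o][i] * v[o] for o in range(O) for i in range(I))
    for k in range(K)
]
print(grad_w)  # [14.0, 26.0]

# Adjoint check against the forward pass with an arbitrary kernel.
w = [0.5, -1.0]
y = [
    sum(x[i] * pattern[k][o][i] * w[k] for k in range(K) for i in range(I))
    for o in range(O)
]
inner_output = sum(v[o] * y[o] for o in range(O))
inner_weight = sum(grad_w[k] * w[k] for k in range(K))
print(inner_output, inner_weight)  # -19.0 -19.0
```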

einconv.expressions.convNd_unfold

einsum_expression

einsum_expression(x: Tensor, kernel_size: Union[int, Tuple[int, ...]], stride: Union[int, Tuple[int, ...]] = 1, padding: Union[str, int, Tuple[int, ...]] = 0, dilation: Union[int, Tuple[int, ...]] = 1, simplify: bool = True) -> Tuple[str, List[Tensor], Tuple[int, ...]]

Generate einsum expression to unfold the input of a convolution.

Parameters:

  • x (Tensor) –

    Convolution input. Has shape [batch_size, in_channels, *input_sizes] where len(input_sizes) == N.

  • kernel_size (Union[int, Tuple[int, ...]]) –

    Kernel dimensions. Can be a single integer (shared along all spatial dimensions), or an N-tuple of integers.

  • stride (Union[int, Tuple[int, ...]], default: 1 ) –

    Stride of the convolution. Can be a single integer (shared along all spatial dimensions), or an N-tuple of integers. Default: 1.

  • padding (Union[str, int, Tuple[int, ...]], default: 0 ) –

    Padding of the convolution. Can be a single integer (shared along all spatial dimensions), an N-tuple of integers, or a string. Default: 0. Allowed strings are 'same' and 'valid'.

  • dilation (Union[int, Tuple[int, ...]], default: 1 ) –

    Dilation of the convolution. Can be a single integer (shared along all spatial dimensions), or an N-tuple of integers. Default: 1.

  • simplify (bool, default: True ) –

    Whether to simplify the einsum expression. Default: True.

Returns:

  • str

    Einsum equation

  • List[Tensor]

    Einsum operands in order input, patterns

  • Tuple[int, ...]

    Output shape: [batch_size, in_channels * tot_kernel_size, tot_output_size]

Source code in einconv/expressions/convNd_unfold.py
def einsum_expression(
    x: Tensor,
    kernel_size: Union[int, Tuple[int, ...]],
    stride: Union[int, Tuple[int, ...]] = 1,
    padding: Union[str, int, Tuple[int, ...]] = 0,
    dilation: Union[int, Tuple[int, ...]] = 1,
    simplify: bool = True,
) -> Tuple[str, List[Tensor], Tuple[int, ...]]:
    """Generate einsum expression to unfold the input of a convolution.

    Args:
        x: Convolution input. Has shape ``[batch_size, in_channels, *input_sizes]``
            where ``len(input_sizes) == N``.
        kernel_size: Kernel dimensions. Can be a single integer (shared along all
            spatial dimensions), or an ``N``-tuple of integers.
        stride: Stride of the convolution. Can be a single integer (shared along all
            spatial dimensions), or an ``N``-tuple of integers. Default: ``1``.
        padding: Padding of the convolution. Can be a single integer (shared along
            all spatial dimensions), an ``N``-tuple of integers, or a string.
            Default: ``0``. Allowed strings are ``'same'`` and ``'valid'``.
        dilation: Dilation of the convolution. Can be a single integer (shared along
            all spatial dimensions), or an ``N``-tuple of integers. Default: ``1``.
        simplify: Whether to simplify the einsum expression. Default: ``True``.

    Returns:
        Einsum equation
        Einsum operands in order input, patterns
        Output shape: ``[batch_size, in_channels * tot_kernel_size, tot_output_size]``
    """
    N = x.dim() - 2

    # construct einsum equation
    x_str = "n c_in " + " ".join([f"i{i}" for i in range(N)])
    pattern_strs: List[str] = [f"k{i} o{i} i{i}" for i in range(N)]
    lhs = ",".join([x_str, *pattern_strs])

    rhs = (
        "n c_in "
        + " ".join([f"k{i}" for i in range(N)])
        + " "
        + " ".join([f"o{i}" for i in range(N)])
    )

    equation = "->".join([lhs, rhs])
    equation = translate_to_torch(equation)

    # construct einsum operands
    patterns = create_conv_index_patterns(
        N,
        input_size=x.shape[2:],
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        device=x.device,
        dtype=x.dtype,
    )
    operands = [x, *patterns]

    # construct output shape
    output_tot_size = int(Tensor([p.shape[1] for p in patterns]).int().prod())
    t_kernel_size = _tuple(kernel_size, N)
    kernel_tot_size = int(Tensor(t_kernel_size).int().prod())
    batch_size, in_channels = x.shape[:2]
    shape = (batch_size, in_channels * kernel_tot_size, output_tot_size)

    if simplify:
        equation, operands = einconv.simplify(equation, operands)

    return equation, operands, shape
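
The shape computed at the end of the listing above flattens channels and kernel positions into one axis and all output positions into another. A small sketch of that arithmetic, followed by a one-dimensional unfold written as an explicit contraction, might look as follows. The helper names `output_size` and `unfold_shape` are invented for this example, the hyperparameters are taken as plain integers shared across dimensions, and the output-size formula is the standard one for convolutions.

```python
from math import prod


def output_size(input_size, kernel_size, stride=1, padding=0, dilation=1):
    """Standard convolution output size along one spatial dimension."""
    return (input_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def unfold_shape(x_shape, kernel_size, stride=1, padding=0, dilation=1):
    """Return (batch_size, in_channels * prod(kernel_size), prod(output_sizes))."""
    batch_size, in_channels, *input_sizes = x_shape
    output_sizes = [
        output_size(i, k, stride, padding, dilation)
        for i, k in zip(input_sizes, kernel_size)
    ]
    return (batch_size, in_channels * prod(kernel_size), prod(output_sizes))


print(unfold_shape((8, 3, 32, 32), (3, 3), padding=1))  # (8, 27, 1024)

# One-dimensional unfold "i, k o i -> k o" with stride 1 and no padding:
# the 0/1 pattern entry is 1 exactly where i == o + k.
x = [1.0, 2.0, 3.0, 4.0]
K = 2
O = output_size(len(x), K)
unfolded = [
    [sum(x[i] * (1 if i == o + k else 0) for i in range(len(x))) for o in range(O)]
    for k in range(K)
]
print(unfolded)  # [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]]
```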

einconv.expressions.convNd_kfc

Input-based factor of the KFC Fisher approximation for convolutions.

KFC was introduced by:

  • Grosse, R., & Martens, J. (2016). A Kronecker-factored approximate Fisher matrix for convolution layers. International Conference on Machine Learning (ICML).

einsum_expression

einsum_expression(x: Tensor, kernel_size: Union[int, Tuple[int, ...]], stride: Union[int, Tuple[int, ...]] = 1, padding: Union[int, str, Tuple[int, ...]] = 0, dilation: Union[int, Tuple[int, ...]] = 1, groups: int = 1, simplify: bool = True) -> Tuple[str, List[Tensor], Tuple[int, ...]]

Generate einsum expression of input-based KFC factor for convolution.

Parameters:

  • x (Tensor) –

    Convolution input. Has shape [batch_size, in_channels, *input_sizes] where len(input_sizes) == N.

  • kernel_size (Union[int, Tuple[int, ...]]) –

    Kernel dimensions. Can be a single integer (shared along all spatial dimensions), or an N-tuple of integers.

  • stride (Union[int, Tuple[int, ...]], default: 1 ) –

    Stride of the convolution. Can be a single integer (shared along all spatial dimensions), or an N-tuple of integers. Default: 1.

  • padding (Union[int, str, Tuple[int, ...]], default: 0 ) –

    Padding of the convolution. Can be a single integer (shared along all spatial dimensions), an N-tuple of integers, or a string. Default: 0. Allowed strings are 'same' and 'valid'.

  • dilation (Union[int, Tuple[int, ...]], default: 1 ) –

    Dilation of the convolution. Can be a single integer (shared along all spatial dimensions), or an N-tuple of integers. Default: 1.

  • groups (int, default: 1 ) –

    In how many groups to split the input channels. Default: 1.

  • simplify (bool, default: True ) –

    Whether to simplify the einsum expression. Default: True.

Returns:

  • str

    Einsum equation

  • List[Tensor]

    Einsum operands in order un-grouped input, patterns, patterns, un-grouped input, normalization scaling

  • Tuple[int, ...]

    Output shape: [groups, in_channels // groups * tot_kernel_sizes, in_channels // groups * tot_kernel_sizes]

Source code in einconv/expressions/convNd_kfc.py
def einsum_expression(
    x: Tensor,
    kernel_size: Union[int, Tuple[int, ...]],
    stride: Union[int, Tuple[int, ...]] = 1,
    padding: Union[int, str, Tuple[int, ...]] = 0,
    dilation: Union[int, Tuple[int, ...]] = 1,
    groups: int = 1,
    simplify: bool = True,
) -> Tuple[str, List[Tensor], Tuple[int, ...]]:
    """Generate einsum expression of input-based KFC factor for convolution.

    Args:
        x: Convolution input. Has shape ``[batch_size, in_channels, *input_sizes]``
            where ``len(input_sizes) == N``.
        kernel_size: Kernel dimensions. Can be a single integer (shared along all
            spatial dimensions), or an ``N``-tuple of integers.
        stride: Stride of the convolution. Can be a single integer (shared along all
            spatial dimensions), or an ``N``-tuple of integers. Default: ``1``.
        padding: Padding of the convolution. Can be a single integer (shared along
            all spatial dimensions), an ``N``-tuple of integers, or a string.
            Default: ``0``. Allowed strings are ``'same'`` and ``'valid'``.
        dilation: Dilation of the convolution. Can be a single integer (shared along
            all spatial dimensions), or an ``N``-tuple of integers. Default: ``1``.
        groups: In how many groups to split the input channels. Default: ``1``.
        simplify: Whether to simplify the einsum expression. Default: ``True``.

    Returns:
        Einsum equation
        Einsum operands in order un-grouped input, patterns, patterns, \
        un-grouped input, normalization scaling
        Output shape: ``[groups, in_channels // groups * tot_kernel_sizes,\
        in_channels // groups * tot_kernel_sizes]``
    """
    N = x.dim() - 2

    # construct einsum equation
    x1_str = "n g c_in " + " ".join([f"i{i}" for i in range(N)])
    x2_str = "n g c_in_ " + " ".join([f"i{i}_" for i in range(N)])
    pattern1_strs: List[str] = [f"k{i} o{i} i{i}" for i in range(N)]
    pattern2_strs: List[str] = [f"k{i}_ o{i} i{i}_" for i in range(N)]
    scale_str = "s"
    lhs = ",".join([x1_str, *pattern1_strs, *pattern2_strs, x2_str, scale_str])

    rhs = (
        "g c_in "
        + " ".join([f"k{i}" for i in range(N)])
        + " c_in_ "
        + " ".join([f"k{i}_" for i in range(N)])
    )

    equation = "->".join([lhs, rhs])
    equation = translate_to_torch(equation)

    # construct einsum operands
    patterns = create_conv_index_patterns(
        N,
        input_size=x.shape[2:],
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        device=x.device,
        dtype=x.dtype,
    )
    x_ungrouped = rearrange(x, "n (g c_in) ... -> n g c_in ...", g=groups)
    batch_size = x.shape[0]
    scale = Tensor([1.0 / batch_size]).to(x.device).to(x.dtype)
    operands = [x_ungrouped, *patterns, *patterns, x_ungrouped, scale]

    # construct output shape
    t_kernel_size = _tuple(kernel_size, N)
    kernel_tot_sizes = int(Tensor(t_kernel_size).int().prod())
    in_channels = x.shape[1]
    shape = (
        groups,
        in_channels // groups * kernel_tot_sizes,
        in_channels // groups * kernel_tot_sizes,
    )

    if simplify:
        equation, operands = einconv.simplify(equation, operands)

    return equation, operands, shape
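
In the equation built above, both groups of patterns share the output indices `o{i}`, so the two unfolded inputs are multiplied position by position and summed over output positions and the batch. A plain-Python sketch for one spatial dimension, one channel, and one group illustrates this. `unfold_1d` is a hypothetical helper restricted to stride 1 without padding or dilation.

```python
def unfold_1d(x, kernel_size):
    """Hypothetical helper: [K][O] patches for stride 1, no padding, no dilation."""
    num_outputs = len(x) - kernel_size + 1
    return [[x[o + k] for o in range(num_outputs)] for k in range(kernel_size)]


batch = [[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]]  # two samples, input size 3
K = 2
unfolded = [unfold_1d(x, K) for x in batch]  # indexed as [n][k][o]
O = len(unfolded[0][0])
scale = 1.0 / len(batch)  # normalization by the batch size, as in the source

# Shared output index o: sum over samples n and output positions o.
kfc = [
    [
        scale * sum(u[k][o] * u[k_][o] for u in unfolded for o in range(O))
        for k_ in range(K)
    ]
    for k in range(K)
]
print(kfc)  # [[3.0, 4.0], [4.0, 7.0]]
```

The result is symmetric, as expected for a sum of outer products.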

einconv.expressions.convNd_kfac_reduce

Input-based factor of the KFAC-reduce approximation for convolutions.

KFAC-reduce was introduced by:

  • Eschenhagen, R. (2022). Kronecker-factored approximate curvature for linear weight-sharing layers, Master's thesis.

einsum_expression

einsum_expression(x: Tensor, kernel_size: Union[int, Tuple[int, ...]], stride: Union[int, Tuple[int, ...]] = 1, padding: Union[int, str, Tuple[int, ...]] = 0, dilation: Union[int, Tuple[int, ...]] = 1, groups: int = 1, simplify: bool = True) -> Tuple[str, List[Tensor], Tuple[int, ...]]

Generate einsum expression of input-based KFAC-reduce factor for convolution.

Parameters:

  • x (Tensor) –

    Convolution input. Has shape [batch_size, in_channels, *input_sizes] where len(input_sizes) == N.

  • kernel_size (Union[int, Tuple[int, ...]]) –

    Kernel dimensions. Can be a single integer (shared along all spatial dimensions), or an N-tuple of integers.

  • stride (Union[int, Tuple[int, ...]], default: 1 ) –

    Stride of the convolution. Can be a single integer (shared along all spatial dimensions), or an N-tuple of integers. Default: 1.

  • padding (Union[int, str, Tuple[int, ...]], default: 0 ) –

    Padding of the convolution. Can be a single integer (shared along all spatial dimensions), an N-tuple of integers, or a string. Default: 0. Allowed strings are 'same' and 'valid'.

  • dilation (Union[int, Tuple[int, ...]], default: 1 ) –

    Dilation of the convolution. Can be a single integer (shared along all spatial dimensions), or an N-tuple of integers. Default: 1.

  • groups (int, default: 1 ) –

    In how many groups to split the input channels. Default: 1.

  • simplify (bool, default: True ) –

    Whether to simplify the einsum expression. Default: True.

Returns:

  • str

    Einsum equation

  • List[Tensor]

    Einsum operands in order un-grouped input, patterns, patterns, un-grouped input, normalization scaling

  • Tuple[int, ...]

    Output shape: [groups, in_channels // groups * tot_kernel_sizes, in_channels // groups * tot_kernel_sizes]

Source code in einconv/expressions/convNd_kfac_reduce.py
def einsum_expression(
    x: Tensor,
    kernel_size: Union[int, Tuple[int, ...]],
    stride: Union[int, Tuple[int, ...]] = 1,
    padding: Union[int, str, Tuple[int, ...]] = 0,
    dilation: Union[int, Tuple[int, ...]] = 1,
    groups: int = 1,
    simplify: bool = True,
) -> Tuple[str, List[Tensor], Tuple[int, ...]]:
    """Generate einsum expression of input-based KFAC-reduce factor for convolution.

    Args:
        x: Convolution input. Has shape ``[batch_size, in_channels, *input_sizes]``
            where ``len(input_sizes) == N``.
        kernel_size: Kernel dimensions. Can be a single integer (shared along all
            spatial dimensions), or an ``N``-tuple of integers.
        stride: Stride of the convolution. Can be a single integer (shared along all
            spatial dimensions), or an ``N``-tuple of integers. Default: ``1``.
        padding: Padding of the convolution. Can be a single integer (shared along
            all spatial dimensions), an ``N``-tuple of integers, or a string.
            Default: ``0``. Allowed strings are ``'same'`` and ``'valid'``.
        dilation: Dilation of the convolution. Can be a single integer (shared along
            all spatial dimensions), or an ``N``-tuple of integers. Default: ``1``.
        groups: In how many groups to split the input channels. Default: ``1``.
        simplify: Whether to simplify the einsum expression. Default: ``True``.

    Returns:
        Einsum equation
        Einsum operands in order un-grouped input, patterns, patterns, \
        un-grouped input, normalization scaling
        Output shape: ``[groups, in_channels // groups * tot_kernel_sizes,\
        in_channels // groups * tot_kernel_sizes]``
    """
    N = x.dim() - 2

    # construct einsum equation
    x1_str = "n g c_in " + " ".join([f"i{i}" for i in range(N)])
    x2_str = "n g c_in_ " + " ".join([f"i{i}_" for i in range(N)])
    pattern1_strs: List[str] = [f"k{i} o{i} i{i}" for i in range(N)]
    pattern2_strs: List[str] = [f"k{i}_ o{i}_ i{i}_" for i in range(N)]
    scale_str = "s"
    lhs = ",".join([x1_str, *pattern1_strs, *pattern2_strs, x2_str, scale_str])

    rhs = (
        "g c_in "
        + " ".join([f"k{i}" for i in range(N)])
        + " c_in_ "
        + " ".join([f"k{i}_" for i in range(N)])
    )

    equation = "->".join([lhs, rhs])
    equation = translate_to_torch(equation)

    # construct einsum operands
    patterns = create_conv_index_patterns(
        N,
        input_size=x.shape[2:],
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        device=x.device,
        dtype=x.dtype,
    )
    x_ungrouped = rearrange(x, "n (g c_in) ... -> n g c_in ...", g=groups)
    output_tot_size = int(Tensor([p.shape[1] for p in patterns]).int().prod())
    batch_size = x.shape[0]
    scale = Tensor([1.0 / (batch_size * output_tot_size**2)]).to(x.device).to(x.dtype)
    operands = [x_ungrouped, *patterns, *patterns, x_ungrouped, scale]

    # construct output shape
    t_kernel_size = _tuple(kernel_size, N)
    kernel_tot_size = int(Tensor(t_kernel_size).int().prod())
    in_channels = x.shape[1]
    shape = (
        groups,
        in_channels // groups * kernel_tot_size,
        in_channels // groups * kernel_tot_size,
    )

    if simplify:
        equation, operands = einconv.simplify(equation, operands)

    return equation, operands, shape
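
Compared with the KFC expression, the second group of patterns here uses its own output indices `o{i}_`, and the scaling divides by the squared number of output positions as well as the batch size. Because the two output indices are independent, each unfolded input can be summed over its output positions before the outer product is formed. A plain-Python sketch for one spatial dimension, one channel, and one group is given below. `unfold_1d` is a hypothetical helper restricted to stride 1 without padding or dilation.

```python
def unfold_1d(x, kernel_size):
    """Hypothetical helper: [K][O] patches for stride 1, no padding, no dilation."""
    num_outputs = len(x) - kernel_size + 1
    return [[x[o + k] for o in range(num_outputs)] for k in range(kernel_size)]


batch = [[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]]  # two samples, input size 3
K = 2
unfolded = [unfold_1d(x, K) for x in batch]  # indexed as [n][k][o]
O = len(unfolded[0][0])
scale = 1.0 / (len(batch) * O**2)  # same scaling as in the source

# Independent indices o and o_: reduce each unfolded input over o first.
reduced = [[sum(u[k]) for k in range(K)] for u in unfolded]  # indexed as [n][k]
kfac_reduce = [
    [scale * sum(r[k] * r[k_] for r in reduced) for k_ in range(K)]
    for k in range(K)
]
print(kfac_reduce)  # [[1.25, 2.0], [2.0, 3.25]]
```

Reducing first keeps the intermediate at size `[n][k]` rather than `[n][k][o]`, which is the practical difference from the KFC contraction.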