Basic Example (2d Convolution)
"""Compare forward pass of a 2d convolution layer."""
from torch import allclose, manual_seed, rand
from torch.nn import Conv2d
from einconv.modules import ConvNd
manual_seed(0) # make deterministic
x = rand(10, 4, 28, 28) # random input
conv_params = {
"in_channels": 4,
"out_channels": 8,
"kernel_size": 4, # can also use tuple
"padding": 1, # can also use tuple, or string
"stride": 3, # can also use tuple
"dilation": 2, # can also use tuple
"groups": 2,
"bias": True,
}
N = 2 # convolution dimension
torch_layer = Conv2d(**conv_params)
ein_layer = ConvNd(N, **conv_params)
ein_layer.weight.data = torch_layer.weight.data
ein_layer.bias.data = torch_layer.bias.data
assert allclose(torch_layer(x), ein_layer(x), rtol=1e-4, atol=1e-6)