input1: $(N, *, H_{in1})$, where $H_{in1} = \text{in1\_features}$
and $*$ means any number of additional dimensions.
All but the last dimension of the inputs should be the same.
input2: $(N, *, H_{in2})$, where $H_{in2} = \text{in2\_features}$
weight: $(\text{out\_features}, \text{in1\_features}, \text{in2\_features})$
bias: $(\text{out\_features})$
output: $(N, *, H_{out})$, where $H_{out} = \text{out\_features}$
and all but the last dimension have the same shape as the inputs.
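The shape rules above can be illustrated with a minimal pure-Python sketch. It assumes the layer computes the standard bilinear form $y_k = x_1^\top W_k x_2 + b_k$ (as PyTorch's `nn.Bilinear` does) and uses non-empty nested lists in place of tensors; the function names `bilinear` and `shape` are hypothetical helpers for illustration only, not a library API. The leading dimensions $(N, *)$ are walked recursively and must match between the two inputs, while only the last dimension is contracted against the weight.

```python
def shape(x):
    """Return the shape of a nested list as a tuple (hypothetical helper)."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        if not x:
            break
        x = x[0]
    return tuple(dims)


def bilinear(x1, x2, weight, bias=None):
    """Hypothetical sketch of y_k = x1^T W_k x2 + b_k over the last dimension.

    x1:     nested lists of shape (N, *, in1_features)
    x2:     nested lists of shape (N, *, in2_features)
    weight: shape (out_features, in1_features, in2_features)
    bias:   shape (out_features,) or None
    Returns nested lists of shape (N, *, out_features).
    """
    # Leading (N, *) dimensions: recurse in lockstep; they must match exactly.
    if isinstance(x1[0], list):
        if not isinstance(x2[0], list) or len(x1) != len(x2):
            raise ValueError("all but the last dimension of the inputs must match")
        return [bilinear(a, b, weight, bias) for a, b in zip(x1, x2)]
    if isinstance(x2[0], list):
        raise ValueError("all but the last dimension of the inputs must match")

    # Last dimension: x1 has in1_features entries, x2 has in2_features entries.
    out = []
    for k, w_k in enumerate(weight):
        if len(w_k) != len(x1) or any(len(row) != len(x2) for row in w_k):
            raise ValueError(
                "weight must have shape (out_features, in1_features, in2_features)"
            )
        # x1^T W_k x2 = sum over i, j of x1[i] * W_k[i][j] * x2[j]
        y = sum(
            x1[i] * w_k[i][j] * x2[j]
            for i in range(len(x1))
            for j in range(len(x2))
        )
        if bias is not None:
            y += bias[k]
        out.append(y)
    return out


# Usage: N = 2, one extra dimension of size 3,
# in1_features = 4, in2_features = 5, out_features = 2.
x1 = [[[1.0] * 4 for _ in range(3)] for _ in range(2)]         # (2, 3, 4)
x2 = [[[1.0] * 5 for _ in range(3)] for _ in range(2)]         # (2, 3, 5)
weight = [[[1.0] * 5 for _ in range(4)] for _ in range(2)]     # (2, 4, 5)
bias = [0.5, -0.5]                                             # (2,)
y = bilinear(x1, x2, weight, bias)
print(shape(y))  # → (2, 3, 2)
```

The output keeps the leading $(N, *)$ dimensions of the inputs and replaces only the last one with $\text{out\_features}$; inputs whose leading dimensions differ are rejected rather than broadcast.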