Template performing matrix diff-squared-add operation within a thread. More...
#include <custom_accum.h>
Public Types | |
typedef cutlass::Shape< 1, 1, 1, 1 > | InstructionShape |
The shape of the instruction. More... | |
typedef AccumulatorsPerThread_ | AccumulatorsPerThread |
The number of accumulators per thread. More... | |
typedef ThreadsPerWarp_ | ThreadsPerWarp |
The number of threads per warp. More... | |
typedef cutlass::ShapeMul< AccumulatorsPerThread, ThreadsPerWarp >::Shape | AccumulatorsPerWarp |
The number of accumulators per warp. More... | |
typedef ScalarA_ | ScalarA |
The type for A. More... | |
typedef cutlass::Fragment< ScalarA, AccumulatorsPerThread::kW > | FragmentA |
The fragment for A. More... | |
typedef ScalarB_ | ScalarB |
The type for B. More... | |
typedef cutlass::Fragment< ScalarB, AccumulatorsPerThread::kH > | FragmentB |
The fragment for B. More... | |
typedef ScalarC_ | ScalarC |
The type for C and D. More... | |
typedef cutlass::Fragment< ScalarC, AccumulatorsPerThread::kH *AccumulatorsPerThread::kW, 16 > | Accumulators |
The accumulators. More... | |
Public Member Functions | |
CUTLASS_DEVICE | ThreadDiffSquaredAdd () |
Ctor. More... | |
CUTLASS_DEVICE void | multiply_add (FragmentA const &a, FragmentB const &b, Accumulators const &c, Accumulators &d) |
Diff-squared-add: d = (a-b)^2 + c. More... | |
Template performing matrix diff-squared-add operation within a thread.
typedef cutlass::Fragment<ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW, 16> MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::Accumulators |
The accumulators.
typedef AccumulatorsPerThread_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::AccumulatorsPerThread |
The number of accumulators per thread.
typedef cutlass::ShapeMul<AccumulatorsPerThread, ThreadsPerWarp>::Shape MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::AccumulatorsPerWarp |
The number of accumulators per warp.
typedef cutlass::Fragment<ScalarA, AccumulatorsPerThread::kW> MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::FragmentA |
The fragment for A.
typedef cutlass::Fragment<ScalarB, AccumulatorsPerThread::kH> MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::FragmentB |
The fragment for B.
typedef cutlass::Shape<1, 1, 1, 1> MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::InstructionShape |
The shape of the instruction.
typedef ScalarA_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ScalarA |
The type for A.
typedef ScalarB_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ScalarB |
The type for B.
typedef ScalarC_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ScalarC |
The type for C and D.
typedef ThreadsPerWarp_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ThreadsPerWarp |
The number of threads per warp.
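CUTLASS_DEVICE MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ThreadDiffSquaredAdd () |
inline |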
Ctor.
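CUTLASS_DEVICE void MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::multiply_add (FragmentA const &a, FragmentB const &b, Accumulators const &c, Accumulators &d) |
inline |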
Diff-squared-add: d = (a-b)^2 + c.