Template performing matrix diff-squared-add operation within a thread. More...
#include <custom_accum.h>

Public Types | |
| typedef cutlass::Shape< 1, 1, 1, 1 > | InstructionShape |
| The shape of the instruction. More... | |
| typedef AccumulatorsPerThread_ | AccumulatorsPerThread |
| The number of accumulators per thread. More... | |
| typedef ThreadsPerWarp_ | ThreadsPerWarp |
| The number of threads per warp. More... | |
| typedef cutlass::ShapeMul< AccumulatorsPerThread, ThreadsPerWarp >::Shape | AccumulatorsPerWarp |
| The number of accumulators per warp. More... | |
| typedef ScalarA_ | ScalarA |
| The type for A. More... | |
| typedef cutlass::Fragment< ScalarA, AccumulatorsPerThread::kW > | FragmentA |
| The fragment for A. More... | |
| typedef ScalarB_ | ScalarB |
| The type for B. More... | |
| typedef cutlass::Fragment< ScalarB, AccumulatorsPerThread::kH > | FragmentB |
| The fragment for B. More... | |
| typedef ScalarC_ | ScalarC |
| The type for C and D. More... | |
| typedef cutlass::Fragment< ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW, 16 > | Accumulators |
| The accumulators. More... | |
Public Member Functions | |
| CUTLASS_DEVICE | ThreadDiffSquaredAdd () |
| Ctor. More... | |
| CUTLASS_DEVICE void | multiply_add (FragmentA const &a, FragmentB const &b, Accumulators const &c, Accumulators &d) |
| Diff-squared-add : d = (a-b)^2 + c. More... | |
Template performing matrix diff-squared-add operation within a thread.
| typedef cutlass::Fragment<ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW, 16> MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::Accumulators |
The accumulators.
| typedef AccumulatorsPerThread_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::AccumulatorsPerThread |
The number of accumulators per thread.
| typedef cutlass::ShapeMul<AccumulatorsPerThread, ThreadsPerWarp>::Shape MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::AccumulatorsPerWarp |
The number of accumulators per warp.
| typedef cutlass::Fragment<ScalarA, AccumulatorsPerThread::kW> MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::FragmentA |
The fragment for A.
| typedef cutlass::Fragment<ScalarB, AccumulatorsPerThread::kH> MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::FragmentB |
The fragment for B.
| typedef cutlass::Shape<1, 1, 1, 1> MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::InstructionShape |
The shape of the instruction.
| typedef ScalarA_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ScalarA |
The type for A.
| typedef ScalarB_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ScalarB |
The type for B.
| typedef ScalarC_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ScalarC |
The type for C and D.
| typedef ThreadsPerWarp_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ThreadsPerWarp |
The number of threads per warp.
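| CUTLASS_DEVICE MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ThreadDiffSquaredAdd () |
inline |
Ctor.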
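| CUTLASS_DEVICE void MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::multiply_add (FragmentA const &a, FragmentB const &b, Accumulators const &c, Accumulators &d) |
inline |
Diff-squared-add : d = (a-b)^2 + c.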