Template performing a matrix diff-squared-add operation within a thread. More...
#include <custom_accum.h>

Public Types | |
| typedef cutlass::Shape< 1, 1, 1, 1 > | InstructionShape | 
| The shape of the instruction.  More... | |
| typedef AccumulatorsPerThread_ | AccumulatorsPerThread | 
| The number of accumulators per thread.  More... | |
| typedef ThreadsPerWarp_ | ThreadsPerWarp | 
| The number of threads per warp.  More... | |
| typedef cutlass::ShapeMul< AccumulatorsPerThread, ThreadsPerWarp >::Shape | AccumulatorsPerWarp | 
| The number of accumulators per warp.  More... | |
| typedef ScalarA_ | ScalarA | 
| The type for A.  More... | |
| typedef cutlass::Fragment< ScalarA, AccumulatorsPerThread::kW > | FragmentA | 
| The fragment for A.  More... | |
| typedef ScalarB_ | ScalarB | 
| The type for B.  More... | |
| typedef cutlass::Fragment< ScalarB, AccumulatorsPerThread::kH > | FragmentB | 
| The fragment for B.  More... | |
| typedef ScalarC_ | ScalarC | 
| The type for C and D.  More... | |
| typedef cutlass::Fragment< ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW, 16 > | Accumulators | 
| The accumulators.  More... | |
Public Member Functions | |
| CUTLASS_DEVICE | ThreadDiffSquaredAdd () | 
| Constructor.  More... | |
| CUTLASS_DEVICE void | multiply_add (FragmentA const &a, FragmentB const &b, Accumulators const &c, Accumulators &d) | 
| Despite its name, computes the diff-squared-add: d = (a-b)^2 + c.  More... | |
Template performing a matrix diff-squared-add operation, d = (a-b)^2 + c, within a thread.
| typedef cutlass::Fragment<ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW, 16> MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::Accumulators | 
The accumulators.
| typedef AccumulatorsPerThread_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::AccumulatorsPerThread | 
The number of accumulators per thread.
| typedef cutlass::ShapeMul<AccumulatorsPerThread, ThreadsPerWarp>::Shape MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::AccumulatorsPerWarp | 
The number of accumulators per warp.
| typedef cutlass::Fragment<ScalarA, AccumulatorsPerThread::kW> MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::FragmentA | 
The fragment for A.
| typedef cutlass::Fragment<ScalarB, AccumulatorsPerThread::kH> MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::FragmentB | 
The fragment for B.
| typedef cutlass::Shape<1, 1, 1, 1> MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::InstructionShape | 
The shape of the instruction.
| typedef ScalarA_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ScalarA | 
The type for A.
| typedef ScalarB_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ScalarB | 
The type for B.
| typedef ScalarC_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ScalarC | 
The type for C and D.
| typedef ThreadsPerWarp_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ThreadsPerWarp | 
The number of threads per warp.
      
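| CUTLASS_DEVICE MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ThreadDiffSquaredAdd () | inline | 
Constructor.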
      
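| CUTLASS_DEVICE void MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::multiply_add (FragmentA const &a, FragmentB const &b, Accumulators const &c, Accumulators &d) | inline | 
Despite its name, computes the diff-squared-add: d = (a-b)^2 + c.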