Public Types | Public Member Functions | List of all members
MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ > Struct Template Reference

Template performing matrix diff-squared-add operation within a thread. More...

#include <custom_accum.h>

Collaboration diagram for MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >:
Collaboration graph

Public Types

typedef cutlass::Shape< 1, 1, 1, 1 > InstructionShape
 The shape of the instruction. More...
 
typedef AccumulatorsPerThread_ AccumulatorsPerThread
 The number of accumulators per thread. More...
 
typedef ThreadsPerWarp_ ThreadsPerWarp
 The number of threads per warp. More...
 
typedef cutlass::ShapeMul< AccumulatorsPerThread, ThreadsPerWarp >::Shape AccumulatorsPerWarp
 The number of accumulators per warp. More...
 
typedef ScalarA_ ScalarA
 The type for A. More...
 
typedef cutlass::Fragment< ScalarA, AccumulatorsPerThread::kW > FragmentA
 The fragment for A. More...
 
typedef ScalarB_ ScalarB
 The type for B. More...
 
typedef cutlass::Fragment< ScalarB, AccumulatorsPerThread::kH > FragmentB
 The fragment for B. More...
 
typedef ScalarC_ ScalarC
 The type for C and D. More...
 
typedef cutlass::Fragment< ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW, 16 > Accumulators
 The accumulators. More...
 

Public Member Functions

CUTLASS_DEVICE ThreadDiffSquaredAdd ()
 Ctor. More...
 
CUTLASS_DEVICE void multiply_add (FragmentA const &a, FragmentB const &b, Accumulators const &c, Accumulators &d)
 Diff-squared-add: d = (a-b)^2 + c. More...
 

Detailed Description

template<typename AccumulatorsPerThread_, typename ThreadsPerWarp_, typename ScalarA_, typename ScalarB_, typename ScalarC_>
struct MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >

Template performing matrix diff-squared-add operation within a thread.

Member Typedef Documentation

◆ Accumulators

template<typename AccumulatorsPerThread_ , typename ThreadsPerWarp_ , typename ScalarA_ , typename ScalarB_ , typename ScalarC_ >
typedef cutlass::Fragment<ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW, 16> MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::Accumulators

The accumulators.

◆ AccumulatorsPerThread

template<typename AccumulatorsPerThread_ , typename ThreadsPerWarp_ , typename ScalarA_ , typename ScalarB_ , typename ScalarC_ >
typedef AccumulatorsPerThread_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::AccumulatorsPerThread

The number of accumulators per thread.

◆ AccumulatorsPerWarp

template<typename AccumulatorsPerThread_ , typename ThreadsPerWarp_ , typename ScalarA_ , typename ScalarB_ , typename ScalarC_ >
typedef cutlass::ShapeMul<AccumulatorsPerThread, ThreadsPerWarp>::Shape MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::AccumulatorsPerWarp

The number of accumulators per warp.

◆ FragmentA

template<typename AccumulatorsPerThread_ , typename ThreadsPerWarp_ , typename ScalarA_ , typename ScalarB_ , typename ScalarC_ >
typedef cutlass::Fragment<ScalarA, AccumulatorsPerThread::kW> MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::FragmentA

The fragment for A.

◆ FragmentB

template<typename AccumulatorsPerThread_ , typename ThreadsPerWarp_ , typename ScalarA_ , typename ScalarB_ , typename ScalarC_ >
typedef cutlass::Fragment<ScalarB, AccumulatorsPerThread::kH> MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::FragmentB

The fragment for B.

◆ InstructionShape

template<typename AccumulatorsPerThread_ , typename ThreadsPerWarp_ , typename ScalarA_ , typename ScalarB_ , typename ScalarC_ >
typedef cutlass::Shape<1, 1, 1, 1> MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::InstructionShape

The shape of the instruction.

◆ ScalarA

template<typename AccumulatorsPerThread_ , typename ThreadsPerWarp_ , typename ScalarA_ , typename ScalarB_ , typename ScalarC_ >
typedef ScalarA_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ScalarA

The type for A.

◆ ScalarB

template<typename AccumulatorsPerThread_ , typename ThreadsPerWarp_ , typename ScalarA_ , typename ScalarB_ , typename ScalarC_ >
typedef ScalarB_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ScalarB

The type for B.

◆ ScalarC

template<typename AccumulatorsPerThread_ , typename ThreadsPerWarp_ , typename ScalarA_ , typename ScalarB_ , typename ScalarC_ >
typedef ScalarC_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ScalarC

The type for C and D.

◆ ThreadsPerWarp

template<typename AccumulatorsPerThread_ , typename ThreadsPerWarp_ , typename ScalarA_ , typename ScalarB_ , typename ScalarC_ >
typedef ThreadsPerWarp_ MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ThreadsPerWarp

The number of threads per warp.

Constructor & Destructor Documentation

◆ ThreadDiffSquaredAdd()

template<typename AccumulatorsPerThread_ , typename ThreadsPerWarp_ , typename ScalarA_ , typename ScalarB_ , typename ScalarC_ >
CUTLASS_DEVICE MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::ThreadDiffSquaredAdd ( )
inline

Ctor.

Member Function Documentation

◆ multiply_add()

template<typename AccumulatorsPerThread_ , typename ThreadsPerWarp_ , typename ScalarA_ , typename ScalarB_ , typename ScalarC_ >
CUTLASS_DEVICE void MLCommon::LinAlg::ThreadDiffSquaredAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >::multiply_add ( FragmentA const &  a,
FragmentB const &  b,
Accumulators const &  c,
Accumulators &  d 
)
inline

Diff-squared-add: d = (a-b)^2 + c.


The documentation for this struct was generated from the following file: