custom_accum.h
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2018-2021, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
8 #include <cutlass/cutlass.h>
9 #include <cutlass/fragment.h>
10 
11 namespace MLCommon {
12 namespace LinAlg {
13 
15 template <typename AccumulatorsPerThread_,
16  typename ThreadsPerWarp_,
17  typename ScalarA_,
18  typename ScalarB_,
19  typename ScalarC_>
22  typedef cutlass::Shape<1, 1, 1, 1> InstructionShape;
24  typedef AccumulatorsPerThread_ AccumulatorsPerThread;
26  typedef ThreadsPerWarp_ ThreadsPerWarp;
28  typedef
29  typename cutlass::ShapeMul<AccumulatorsPerThread, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
31  typedef ScalarA_ ScalarA;
33  typedef cutlass::Fragment<ScalarA, AccumulatorsPerThread::kW> FragmentA;
35  typedef ScalarB_ ScalarB;
37  typedef cutlass::Fragment<ScalarB, AccumulatorsPerThread::kH> FragmentB;
39  typedef ScalarC_ ScalarC;
41  typedef cutlass::Fragment<ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW, 16>
43 
45  CUTLASS_DEVICE ThreadDiffSquaredAdd() {}
46 
48  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
49  FragmentB const& b,
50  Accumulators const& c,
51  Accumulators& d)
52  {
53  for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
54  for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
55  auto diff = a[i] - b[j];
56  const auto idx = j * AccumulatorsPerThread::kW + i;
57  d[idx] = diff * diff + c[idx];
58  }
59  }
60  }
61 };
62 
64 template <typename AccumulatorsPerThread_,
65  typename ThreadsPerWarp_,
66  typename ScalarA_,
67  typename ScalarB_,
68  typename ScalarC_>
71  typedef cutlass::Shape<1, 1, 1, 1> InstructionShape;
73  typedef AccumulatorsPerThread_ AccumulatorsPerThread;
75  typedef ThreadsPerWarp_ ThreadsPerWarp;
77  typedef
78  typename cutlass::ShapeMul<AccumulatorsPerThread, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
80  typedef ScalarA_ ScalarA;
82  typedef cutlass::Fragment<ScalarA, AccumulatorsPerThread::kW> FragmentA;
84  typedef ScalarB_ ScalarB;
86  typedef cutlass::Fragment<ScalarB, AccumulatorsPerThread::kH> FragmentB;
88  typedef ScalarC_ ScalarC;
90  typedef cutlass::Fragment<ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW, 16>
92 
94  CUTLASS_DEVICE ThreadL1NormAdd() {}
95 
97  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
98  FragmentB const& b,
99  Accumulators const& c,
100  Accumulators& d)
101  {
102  for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
103  for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
104  auto diff = a[i] < b[j] ? b[j] - a[i] : a[i] - b[j];
105  const auto idx = j * AccumulatorsPerThread::kW + i;
106  d[idx] = diff + c[idx];
107  }
108  }
109  }
110 };
111 
112 }; // end namespace LinAlg
113 }; // end namespace MLCommon
Definition: Timer.h:9
Template performing matrix diff-squared-add operation within a thread.
Definition: custom_accum.h:20
CUTLASS_DEVICE void multiply_add(FragmentA const &a, FragmentB const &b, Accumulators const &c, Accumulators &d)
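Multiply-add step: d = (a - b)^2 + c, i.e. d[j*kW + i] = (a[i] - b[j])^2 + c[j*kW + i] for every row j and column i of the thread's accumulator tile.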
Definition: custom_accum.h:48
ThreadsPerWarp_ ThreadsPerWarp
The number of threads per warp.
Definition: custom_accum.h:26
CUTLASS_DEVICE ThreadDiffSquaredAdd()
Ctor.
Definition: custom_accum.h:45
cutlass::Shape< 1, 1, 1, 1 > InstructionShape
The shape of the instruction.
Definition: custom_accum.h:22
ScalarC_ ScalarC
The type for C and D.
Definition: custom_accum.h:39
ScalarA_ ScalarA
The type for A.
Definition: custom_accum.h:31
cutlass::ShapeMul< AccumulatorsPerThread, ThreadsPerWarp >::Shape AccumulatorsPerWarp
The number of accumulators per warp.
Definition: custom_accum.h:29
cutlass::Fragment< ScalarB, AccumulatorsPerThread::kH > FragmentB
The fragment for B.
Definition: custom_accum.h:37
cutlass::Fragment< ScalarC, AccumulatorsPerThread::kH *AccumulatorsPerThread::kW, 16 > Accumulators
The accumulators.
Definition: custom_accum.h:42
AccumulatorsPerThread_ AccumulatorsPerThread
The number of accumulators per thread.
Definition: custom_accum.h:24
ScalarB_ ScalarB
The type for B.
Definition: custom_accum.h:35
cutlass::Fragment< ScalarA, AccumulatorsPerThread::kW > FragmentA
The fragment for A.
Definition: custom_accum.h:33
Template performing matrix L1-norm (absolute-difference-add) operation within a thread.
Definition: custom_accum.h:69
cutlass::Fragment< ScalarA, AccumulatorsPerThread::kW > FragmentA
The fragment for A.
Definition: custom_accum.h:82
ScalarB_ ScalarB
The type for B.
Definition: custom_accum.h:84
cutlass::ShapeMul< AccumulatorsPerThread, ThreadsPerWarp >::Shape AccumulatorsPerWarp
The number of accumulators per warp.
Definition: custom_accum.h:78
ThreadsPerWarp_ ThreadsPerWarp
The number of threads per warp.
Definition: custom_accum.h:75
cutlass::Fragment< ScalarC, AccumulatorsPerThread::kH *AccumulatorsPerThread::kW, 16 > Accumulators
The accumulators.
Definition: custom_accum.h:91
cutlass::Fragment< ScalarB, AccumulatorsPerThread::kH > FragmentB
The fragment for B.
Definition: custom_accum.h:86
ScalarC_ ScalarC
The type for C and D.
Definition: custom_accum.h:88
AccumulatorsPerThread_ AccumulatorsPerThread
The number of accumulators per thread.
Definition: custom_accum.h:73
ScalarA_ ScalarA
The type for A.
Definition: custom_accum.h:80
CUTLASS_DEVICE void multiply_add(FragmentA const &a, FragmentB const &b, Accumulators const &c, Accumulators &d)
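Multiply-add step: d = |a - b| + c, i.e. d[j*kW + i] = |a[i] - b[j]| + c[j*kW + i] for every row j and column i of the thread's accumulator tile.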
Definition: custom_accum.h:97
CUTLASS_DEVICE ThreadL1NormAdd()
Ctor.
Definition: custom_accum.h:94
cutlass::Shape< 1, 1, 1, 1 > InstructionShape
The shape of the instruction.
Definition: custom_accum.h:71