custom_accum.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2018-2021, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <cutlass/cutlass.h>
20 #include <cutlass/fragment.h>
21 
22 namespace MLCommon {
23 namespace LinAlg {
24 
26 template <typename AccumulatorsPerThread_,
27  typename ThreadsPerWarp_,
28  typename ScalarA_,
29  typename ScalarB_,
30  typename ScalarC_>
33  typedef cutlass::Shape<1, 1, 1, 1> InstructionShape;
35  typedef AccumulatorsPerThread_ AccumulatorsPerThread;
37  typedef ThreadsPerWarp_ ThreadsPerWarp;
39  typedef
40  typename cutlass::ShapeMul<AccumulatorsPerThread, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
42  typedef ScalarA_ ScalarA;
44  typedef cutlass::Fragment<ScalarA, AccumulatorsPerThread::kW> FragmentA;
46  typedef ScalarB_ ScalarB;
48  typedef cutlass::Fragment<ScalarB, AccumulatorsPerThread::kH> FragmentB;
50  typedef ScalarC_ ScalarC;
52  typedef cutlass::Fragment<ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW, 16>
54 
56  CUTLASS_DEVICE ThreadDiffSquaredAdd() {}
57 
59  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
60  FragmentB const& b,
61  Accumulators const& c,
62  Accumulators& d)
63  {
64  for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
65  for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
66  auto diff = a[i] - b[j];
67  const auto idx = j * AccumulatorsPerThread::kW + i;
68  d[idx] = diff * diff + c[idx];
69  }
70  }
71  }
72 };
73 
75 template <typename AccumulatorsPerThread_,
76  typename ThreadsPerWarp_,
77  typename ScalarA_,
78  typename ScalarB_,
79  typename ScalarC_>
82  typedef cutlass::Shape<1, 1, 1, 1> InstructionShape;
84  typedef AccumulatorsPerThread_ AccumulatorsPerThread;
86  typedef ThreadsPerWarp_ ThreadsPerWarp;
88  typedef
89  typename cutlass::ShapeMul<AccumulatorsPerThread, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
91  typedef ScalarA_ ScalarA;
93  typedef cutlass::Fragment<ScalarA, AccumulatorsPerThread::kW> FragmentA;
95  typedef ScalarB_ ScalarB;
97  typedef cutlass::Fragment<ScalarB, AccumulatorsPerThread::kH> FragmentB;
99  typedef ScalarC_ ScalarC;
101  typedef cutlass::Fragment<ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW, 16>
103 
105  CUTLASS_DEVICE ThreadL1NormAdd() {}
106 
108  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
109  FragmentB const& b,
110  Accumulators const& c,
111  Accumulators& d)
112  {
113  for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
114  for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
115  auto diff = a[i] < b[j] ? b[j] - a[i] : a[i] - b[j];
116  const auto idx = j * AccumulatorsPerThread::kW + i;
117  d[idx] = diff + c[idx];
118  }
119  }
120  }
121 };
122 
123 }; // end namespace LinAlg
124 }; // end namespace MLCommon
Definition: kernelparams.h:21
Template performing matrix diff-squared-add operation within a thread.
Definition: custom_accum.h:31
CUTLASS_DEVICE void multiply_add(FragmentA const &a, FragmentB const &b, Accumulators const &c, Accumulators &d)
Multiply : d = (a-b)^2 + c.
Definition: custom_accum.h:59
ThreadsPerWarp_ ThreadsPerWarp
The number of threads per warp.
Definition: custom_accum.h:37
CUTLASS_DEVICE ThreadDiffSquaredAdd()
Ctor.
Definition: custom_accum.h:56
cutlass::Shape< 1, 1, 1, 1 > InstructionShape
The shape of the instruction.
Definition: custom_accum.h:33
ScalarC_ ScalarC
The type for C and D.
Definition: custom_accum.h:50
ScalarA_ ScalarA
The type for A.
Definition: custom_accum.h:42
cutlass::ShapeMul< AccumulatorsPerThread, ThreadsPerWarp >::Shape AccumulatorsPerWarp
The number of accumulators per warp.
Definition: custom_accum.h:40
cutlass::Fragment< ScalarB, AccumulatorsPerThread::kH > FragmentB
The fragment for B.
Definition: custom_accum.h:48
cutlass::Fragment< ScalarC, AccumulatorsPerThread::kH *AccumulatorsPerThread::kW, 16 > Accumulators
The accumulators.
Definition: custom_accum.h:53
AccumulatorsPerThread_ AccumulatorsPerThread
The number of accumulators per thread.
Definition: custom_accum.h:35
ScalarB_ ScalarB
The type for B.
Definition: custom_accum.h:46
cutlass::Fragment< ScalarA, AccumulatorsPerThread::kW > FragmentA
The fragment for A.
Definition: custom_accum.h:44
Template performing matrix L1-norm operation within a thread.
Definition: custom_accum.h:80
cutlass::Fragment< ScalarA, AccumulatorsPerThread::kW > FragmentA
The fragment for A.
Definition: custom_accum.h:93
ScalarB_ ScalarB
The type for B.
Definition: custom_accum.h:95
cutlass::ShapeMul< AccumulatorsPerThread, ThreadsPerWarp >::Shape AccumulatorsPerWarp
The number of accumulators per warp.
Definition: custom_accum.h:89
ThreadsPerWarp_ ThreadsPerWarp
The number of threads per warp.
Definition: custom_accum.h:86
cutlass::Fragment< ScalarC, AccumulatorsPerThread::kH *AccumulatorsPerThread::kW, 16 > Accumulators
The accumulators.
Definition: custom_accum.h:102
cutlass::Fragment< ScalarB, AccumulatorsPerThread::kH > FragmentB
The fragment for B.
Definition: custom_accum.h:97
ScalarC_ ScalarC
The type for C and D.
Definition: custom_accum.h:99
AccumulatorsPerThread_ AccumulatorsPerThread
The number of accumulators per thread.
Definition: custom_accum.h:84
ScalarA_ ScalarA
The type for A.
Definition: custom_accum.h:91
CUTLASS_DEVICE void multiply_add(FragmentA const &a, FragmentB const &b, Accumulators const &c, Accumulators &d)
Multiply : d = |a-b| + c.
Definition: custom_accum.h:108
CUTLASS_DEVICE ThreadL1NormAdd()
Ctor.
Definition: custom_accum.h:105
cutlass::Shape< 1, 1, 1, 1 > InstructionShape
The shape of the instruction.
Definition: custom_accum.h:82