pack.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2018-2024, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <raft/neighbors/ball_cover.cuh>
20 
21 namespace ML {
22 namespace Dbscan {
23 namespace VertexDeg {
24 
25 template <typename Type, typename Index_>
26 struct Pack {
28  raft::neighbors::ball_cover::BallCoverIndex<Index_, Type, Index_, Index_>* rbc_index;
34  Index_* vd;
36  Type* weight_sum;
38  Index_* ia;
39  rmm::device_uvector<Index_>* ja;
41  Index_ max_k;
43  bool* adj;
45  const Type* x;
47  const Type* sample_weight;
49  Type eps;
51  Index_ N;
53  Index_ D;
54 
60  void resetArray(cudaStream_t stream, Index_ vdlen)
61  {
62  RAFT_CUDA_TRY(cudaMemsetAsync(vd, 0, sizeof(Index_) * vdlen, stream));
63  }
64 };
65 
66 } // namespace VertexDeg
67 } // namespace Dbscan
68 } // namespace ML
Definition: dbscan.hpp:30
Definition: pack.h:26
Index_ max_k
Definition: pack.h:41
Index_ N
Definition: pack.h:51
bool * adj
Definition: pack.h:43
Index_ * vd
Definition: pack.h:34
const Type * sample_weight
Definition: pack.h:47
Type eps
Definition: pack.h:49
const Type * x
Definition: pack.h:45
Index_ * ia
Definition: pack.h:38
rmm::device_uvector< Index_ > * ja
Definition: pack.h:39
Index_ D
Definition: pack.h:53
Type * weight_sum
Definition: pack.h:36
void resetArray(cudaStream_t stream, Index_ vdlen)
Reset the output array before calling the actual kernel.
Definition: pack.h:60
raft::neighbors::ball_cover::BallCoverIndex< Index_, Type, Index_, Index_ > * rbc_index
Definition: pack.h:28