flatnode.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2021, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#pragma once

#include <cstdint>
18 
// FLATNODE_HD marks member functions that must be callable from both host and
// device code. The CUDA execution-space qualifiers are only understood by the
// CUDA compiler (which defines __CUDACC__); for a plain host compiler the macro
// expands to nothing so this header still compiles there.
#if defined(__CUDACC__)
#define FLATNODE_HD __host__ __device__
#else
#define FLATNODE_HD
#endif
26 
33 template <typename DataT, typename LabelT, typename IdxT = int>
35  private:
36  IdxT colid = 0;
37  DataT quesval = DataT(0);
38  DataT best_metric_val = DataT(0);
39  IdxT left_child_id = -1;
40  IdxT instance_count = 0;
42  IdxT colid, DataT quesval, DataT best_metric_val, int64_t left_child_id, IdxT instance_count)
43  : colid(colid),
44  quesval(quesval),
45  best_metric_val(best_metric_val),
46  left_child_id(left_child_id),
47  instance_count(instance_count)
48  {
49  }
50 
51  public:
52  FLATNODE_HD IdxT ColumnId() const { return colid; }
53  FLATNODE_HD DataT QueryValue() const { return quesval; }
54  FLATNODE_HD DataT BestMetric() const { return best_metric_val; }
55  FLATNODE_HD int64_t LeftChildId() const { return left_child_id; }
56  FLATNODE_HD int64_t RightChildId() const { return left_child_id + 1; }
57  FLATNODE_HD IdxT InstanceCount() const { return instance_count; }
58 
60  IdxT colid, DataT quesval, DataT best_metric_val, int64_t left_child_id, IdxT instance_count)
61  {
63  colid, quesval, best_metric_val, left_child_id, instance_count};
64  }
65  FLATNODE_HD static SparseTreeNode CreateLeafNode(IdxT instance_count)
66  {
67  return SparseTreeNode<DataT, LabelT>{0, 0, 0, -1, instance_count};
68  }
69  FLATNODE_HD bool IsLeaf() const { return left_child_id == -1; }
70  bool operator==(const SparseTreeNode& other) const
71  {
72  return (this->colid == other.colid) && (this->quesval == other.quesval) &&
73  (this->best_metric_val == other.best_metric_val) &&
74  (this->left_child_id == other.left_child_id) &&
75  (this->instance_count == other.instance_count);
76  }
77 };
#define FLATNODE_HD
Definition: flatnode.h:24
Definition: flatnode.h:34
static FLATNODE_HD SparseTreeNode CreateSplitNode(IdxT colid, DataT quesval, DataT best_metric_val, int64_t left_child_id, IdxT instance_count)
Definition: flatnode.h:59
FLATNODE_HD int64_t LeftChildId() const
Definition: flatnode.h:55
bool operator==(const SparseTreeNode &other) const
Definition: flatnode.h:70
FLATNODE_HD DataT QueryValue() const
Definition: flatnode.h:53
static FLATNODE_HD SparseTreeNode CreateLeafNode(IdxT instance_count)
Definition: flatnode.h:65
FLATNODE_HD DataT BestMetric() const
Definition: flatnode.h:54
FLATNODE_HD int64_t RightChildId() const
Definition: flatnode.h:56
FLATNODE_HD bool IsLeaf() const
Definition: flatnode.h:69
FLATNODE_HD IdxT ColumnId() const
Definition: flatnode.h:52
FLATNODE_HD IdxT InstanceCount() const
Definition: flatnode.h:57