flatnode.h
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2019-2021, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
8 // Functions marked FLATNODE_HD must be callable from device code when a CUDA compiler builds this file,
9 // but a plain host compiler does not accept __host__/__device__, so the macro then expands to nothing.
10 #ifdef __CUDACC__  // CUDA compilation: qualify as callable from both host and device
11 #define FLATNODE_HD __host__ __device__
12 #else
13 #define FLATNODE_HD  // host-only compilation: no qualifiers
14 #endif  // __CUDACC__
15 
22 template <typename DataT, typename LabelT, typename IdxT = int>  // DataT: type of quesval/best_metric_val; IdxT: type of colid, instance_count and the stored left_child_id (defaults to int)
24  private:  // state is written only by the private ctor (used by the static factories) and read via the public accessors
25  IdxT colid = 0;  // reported by ColumnId()
26  DataT quesval = DataT(0);  // reported by QueryValue(); presumably the split threshold — TODO confirm
27  DataT best_metric_val = DataT(0);  // reported by BestMetric()
28  IdxT left_child_id = -1;  // -1 marks a leaf (see IsLeaf). NOTE(review): held as IdxT, but the ctor takes and LeftChildId()/RightChildId() return int64_t, so ids outside IdxT's range are silently narrowed — confirm intended
29  IdxT instance_count = 0;  // reported by InstanceCount()
31  IdxT colid, DataT quesval, DataT best_metric_val, int64_t left_child_id, IdxT instance_count)  // private ctor parameters; outside code constructs nodes through CreateSplitNode / CreateLeafNode
32  : colid(colid),
33  quesval(quesval),
34  best_metric_val(best_metric_val),
35  left_child_id(left_child_id),  // implicit int64_t -> IdxT conversion; narrows when IdxT is smaller than int64_t (e.g. the default int)
36  instance_count(instance_count)
37  {
38  }
39 
40  public:
41  FLATNODE_HD IdxT ColumnId() const { return colid; }
42  FLATNODE_HD DataT QueryValue() const { return quesval; }
43  FLATNODE_HD DataT BestMetric() const { return best_metric_val; }
44  FLATNODE_HD int64_t LeftChildId() const { return left_child_id; }  // -1 for a leaf
45  FLATNODE_HD int64_t RightChildId() const { return left_child_id + 1; }  // right child is stored right after the left one; the sum is evaluated before widening to int64_t, and a leaf yields 0
46  FLATNODE_HD IdxT InstanceCount() const { return instance_count; }
47 
49  IdxT colid, DataT quesval, DataT best_metric_val, int64_t left_child_id, IdxT instance_count)  // CreateSplitNode parameters: children are left_child_id and left_child_id + 1
50  {
52  colid, quesval, best_metric_val, left_child_id, instance_count};
53  }
54  FLATNODE_HD static SparseTreeNode CreateLeafNode(IdxT instance_count)  // leaf: colid/quesval/best_metric_val zeroed, left_child_id = -1
55  {
56  return SparseTreeNode<DataT, LabelT>{0, 0, 0, -1, instance_count};  // NOTE(review): IdxT is omitted, so this builds SparseTreeNode<DataT, LabelT, int>; for IdxT != int that does not convert to the declared return type and will not compile — prefer SparseTreeNode{...}
57  }
58  FLATNODE_HD bool IsLeaf() const { return left_child_id == -1; }
59  bool operator==(const SparseTreeNode& other) const  // not FLATNODE_HD, so host-only under a CUDA compiler; exact field-wise comparison (floating-point fields compared with ==)
60  {
61  return (this->colid == other.colid) && (this->quesval == other.quesval) &&
62  (this->best_metric_val == other.best_metric_val) &&
63  (this->left_child_id == other.left_child_id) &&
64  (this->instance_count == other.instance_count);
65  }
66 };
#define FLATNODE_HD
Definition: flatnode.h:13
SparseTreeNode
Definition: flatnode.h:23
static FLATNODE_HD SparseTreeNode CreateSplitNode(IdxT colid, DataT quesval, DataT best_metric_val, int64_t left_child_id, IdxT instance_count)
Definition: flatnode.h:48
FLATNODE_HD int64_t LeftChildId() const
Definition: flatnode.h:44
bool operator==(const SparseTreeNode &other) const
Definition: flatnode.h:59
FLATNODE_HD DataT QueryValue() const
Definition: flatnode.h:42
static FLATNODE_HD SparseTreeNode CreateLeafNode(IdxT instance_count)
Definition: flatnode.h:54
FLATNODE_HD DataT BestMetric() const
Definition: flatnode.h:43
FLATNODE_HD int64_t RightChildId() const
Definition: flatnode.h:45
FLATNODE_HD bool IsLeaf() const
Definition: flatnode.h:58
FLATNODE_HD IdxT ColumnId() const
Definition: flatnode.h:41
FLATNODE_HD IdxT InstanceCount() const
Definition: flatnode.h:46