Halide 18.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
ThreadInfo.h
Go to the documentation of this file.
1#ifndef THREAD_INFO_H
2#define THREAD_INFO_H
3
4/** \file
5 *
6 * Data structure containing information about GPU threads for a particular
7 * location in the loop nest and its surrounding block. Useful when computing
8 * GPU features.
9 */
10
11#include <vector>
12
13#include "Errors.h"
14#include "FunctionDAG.h"
15
16namespace Halide {
17namespace Internal {
18namespace Autoscheduler {
19
20static constexpr int MAX_THREADS_PER_BLOCK = 1024;
21
22struct LoopNest;
23
24// Sort / filter thread tile options
// Strict ordering on max_idle_lane_wastage: an option compares less than
// another when its max_idle_lane_wastage is smaller, so an ascending sort
// (e.g. std::sort with the default comparator) puts the option with the
// least wastage first.
28 bool operator<(const ThreadTileOption &other) const {
29 return max_idle_lane_wastage < other.max_idle_lane_wastage;
30 }
31
32 // Ensure we don't accidentally copy this type
33 ThreadTileOption() = default;
38};
39
40struct ThreadInfo {
/**
 * Build the thread layout for one loop nest inside a GPU block.
 *
 * \param vectorized_loop_index index into `size` / `loop` of the vectorized
 *        loop, or -1 when there is none.
 * \param size per-loop values used as thread counts; entries equal to 1 are
 *        never assigned a thread dimension.
 * \param loop loop descriptors; only `.var` is read here, to record the
 *        name of each loop that is mapped to threads.
 * \param max_thread_counts forwarded unchanged to
 *        init_threads_in_this_block().
 *
 * After construction loop_indices and loop_vars each hold between 1 and 3
 * entries (enforced by the assertions at the end).
 */
41 ThreadInfo(int vectorized_loop_index, const std::vector<int64_t> &size, const std::vector<FunctionDAG::Node::Loop> &loop, const std::vector<int64_t> &max_thread_counts) {
42 init_threads_in_this_block(max_thread_counts);
43
// Index of the next free slot in threads[].
// NOTE(review): the statements that advance num_thread_loops are not
// visible in this view of the file — confirm against the full header.
44 std::size_t num_thread_loops = 0;
45
// The vectorized loop, when present and non-trivial, is recorded first and
// so takes the first threads[] slot.
46 if (vectorized_loop_index != -1 && size[vectorized_loop_index] != 1) {
47 threads[num_thread_loops] = size[vectorized_loop_index];
48 num_threads *= size[vectorized_loop_index];
50 loop_indices.push_back(vectorized_loop_index);
51 loop_vars.push_back(loop[vectorized_loop_index].var);
52 }
53
// Assign the remaining loops in order, for as long as fewer than 3 thread
// loops have been used.
54 for (std::size_t i = 0; i < size.size() && num_thread_loops < 3; i++) {
// Skip loops of extent 1 and the vectorized loop handled above.
55 if (size[i] == 1 || (int)i == vectorized_loop_index) {
56 continue;
57 }
58
// Stop at the first loop that would push the running product of thread
// counts past MAX_THREADS_PER_BLOCK; later loops are not considered.
59 if (num_threads * size[i] > MAX_THREADS_PER_BLOCK) {
60 break;
61 }
62
63 threads[num_thread_loops] = size[i];
64 num_threads *= size[i];
66 loop_indices.push_back(i);
67 loop_vars.push_back(loop[i].var);
68 }
69
// Nothing was mapped to threads: fall back to recording loop 0 so that the
// index/var lists are never empty.
70 if (loop_indices.empty()) {
71 internal_assert(!size.empty());
73 loop_indices.push_back(0);
74 loop_vars.push_back(loop[0].var);
75 }
76
80 internal_assert(!loop_indices.empty() && loop_indices.size() <= 3);
81 internal_assert(!loop_vars.empty() && loop_vars.size() <= 3);
82
83 count_num_active_warps_per_block();
84 }
85
/**
 * Walk every thread position of the enclosing block, with x as the
 * innermost (fastest-varying) dimension, then y, then z. thread_id starts
 * at 0 and is incremented once per visited position, so it is the
 * linearized index of (x, y, z) within threads_in_this_block.
 *
 * NOTE(review): the call that hands each position to `fn` is not visible
 * in this view of the file — presumably fn receives thread_id, the
 * coordinates and `active`; confirm against the full header.
 */
86 template<typename Fn>
87 void for_each_thread_id(const Fn &fn) const {
88 int thread_id = 0;
89 for (int z = 0; z < threads_in_this_block[2]; z++) {
90 for (int y = 0; y < threads_in_this_block[1]; y++) {
91 for (int x = 0; x < threads_in_this_block[0]; x++) {
92 // Skip any threads in this loop nest with extent less than the
93 // extents of the largest thread loops in this block
94 // for thread.x in [0, 10]:
95 // ...
96 // for thread.x in [0, 5]:
97 // ...
98 // For the 2nd loop, skip threads with x id >= 5
// A position is active only if it lies inside this loop nest's own
// thread extents in all three dimensions.
99 bool active = x < threads[0] && y < threads[1] && z < threads[2];
100
102 ++thread_id;
103 }
104 }
105 }
106 }
107
/**
 * Same traversal order as the full-block walk (x innermost, then y, then
 * z), but it stops after the thread with linear id 31, i.e. it visits at
 * most the first 32 thread positions of the block.
 *
 * For each visited position fn is called as
 * fn(thread_id, x, y, z, active, last_thread), where `active` says whether
 * the position lies inside this loop nest's thread extents and
 * `last_thread` is true only for thread_id 31. If the block holds fewer
 * than 32 positions, every position is visited and last_thread is never
 * true.
 */
108 template<typename Fn>
110 int thread_id = 0;
111 for (int z = 0; z < threads_in_this_block[2]; z++) {
112 for (int y = 0; y < threads_in_this_block[1]; y++) {
113 for (int x = 0; x < threads_in_this_block[0]; x++) {
114 // Skip any threads in this loop nest with extent less than the
115 // extents of the largest thread loops in this block
116 // for thread.x in [0, 10]:
117 // ...
118 // for thread.x in [0, 5]:
119 // ...
120 // For the 2nd loop, skip threads with x id >= 5
121 bool active = x < threads[0] && y < threads[1] && z < threads[2];
122
123 bool last_thread = thread_id == 31;
124 fn(thread_id, x, y, z, active, last_thread);
125 ++thread_id;
126
// Leave all three loops at once after the 32nd position has been reported.
127 if (last_thread) {
128 return;
129 }
130 }
131 }
132 }
133 }
134
135 template<typename Fn>
154
/**
 * Visit only the active thread positions: the callback body returns early
 * for any position whose is_active flag is false.
 *
 * NOTE(review): the enclosing iteration call and the forwarding call to
 * `fn` are not visible in this view of the file — presumably this wraps
 * for_each_thread_id(); confirm against the full header.
 */
155 template<typename Fn>
156 void for_each_active_thread_id(const Fn &fn) const {
158 if (!is_active) {
159 return;
160 }
161
163 });
164 }
165
/**
 * Warp lane utilization, returned as a double.
 *
 * NOTE(review): the return expression is not visible in this view of the
 * file — confirm the exact definition against the full header.
 */
166 double warp_lane_utilization() const {
168 }
169
/**
 * Idle lanes as a fraction of MAX_THREADS_PER_BLOCK: the lanes occupied by
 * the active warps (num_active_warps_per_block * 32, 32 presumably being
 * the warp width) minus num_active_threads, divided by
 * MAX_THREADS_PER_BLOCK. The operands are converted to double before the
 * subtraction and division.
 */
170 double idle_lane_wastage() const {
171 return ((double)(num_active_warps_per_block * 32) - (double)num_active_threads) / MAX_THREADS_PER_BLOCK;
172 }
173
/**
 * Fraction of the per-block thread limit used by this loop nest:
 * num_threads divided by MAX_THREADS_PER_BLOCK, computed in double.
 */
174 double block_occupancy() const {
175 return (double)num_threads / MAX_THREADS_PER_BLOCK;
176 }
177
181 bool has_tail_warp = false;
184
185 int threads_in_this_block[3] = {1, 1, 1};
187
188 int threads[3] = {1, 1, 1};
191
192 std::vector<int> loop_indices;
193 std::vector<std::string> loop_vars;
194
195private:
/**
 * Walk max_thread_counts in order to set up the block-wide thread extents.
 * Counts equal to 1 are skipped; the walk stops at the first count that
 * would need a fourth thread loop or would push num_threads_in_this_block
 * past MAX_THREADS_PER_BLOCK.
 *
 * NOTE(review): the statements that store each accepted count, and the
 * body of the final "% 32" check, are not visible in this view of the
 * file — confirm against the full header.
 */
196 void init_threads_in_this_block(const std::vector<int64_t> &max_thread_counts) {
197 int num_thread_loops = 0;
198 for (auto c : max_thread_counts) {
// A count of 1 does not consume a thread dimension.
199 if (c == 1) {
200 continue;
201 }
202
203 if (num_thread_loops >= 3 || num_threads_in_this_block * c > MAX_THREADS_PER_BLOCK) {
204 break;
205 }
206
210 }
211
// Taken when the block's thread count is not a whole multiple of 32.
213 if (num_threads_in_this_block % 32 != 0) {
215 }
216 }
217
/**
 * Scan thread ids in groups of 32: a group is closed when (thread_id + 1)
 * is a multiple of 32 or when the last thread is reached. first_warp is
 * cleared when the first group is closed, and has_tail_warp is consulted
 * once the scan is finished.
 *
 * NOTE(review): the per-thread callback header and most of the bookkeeping
 * statements are not visible in this view of the file — confirm what is
 * accumulated against the full header before relying on this summary.
 */
218 void count_num_active_warps_per_block() {
219 bool current_warp_is_active = false;
223 bool first_warp = true;
224
227
228 if (is_active) {
231 }
233
// End of a 32-thread group, or end of the block.
234 if ((thread_id + 1) % 32 == 0 || is_last_thread) {
237
238 if (first_warp) {
239 first_warp = false;
241 }
242
243 if (is_last_thread) {
247
249 }
250 }
251
255 }
256 });
257
259 if (has_tail_warp) {
261 }
262 }
263};
264
265} // namespace Autoscheduler
266} // namespace Internal
267} // namespace Halide
268
269#endif // THREAD_INFO_H
#define internal_assert(c)
Definition Errors.h:19
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
Internal::ConstantInterval cast(Type t, const Internal::ConstantInterval &a)
Cast operators for ConstantIntervals.
signed __INT64_TYPE__ int64_t
std::vector< std::string > loop_vars
Definition ThreadInfo.h:193
void for_each_active_thread_id(const Fn &fn) const
Definition ThreadInfo.h:156
ThreadInfo(int vectorized_loop_index, const std::vector< int64_t > &size, const std::vector< FunctionDAG::Node::Loop > &loop, const std::vector< int64_t > &max_thread_counts)
Definition ThreadInfo.h:41
void for_each_thread_id(const Fn &fn) const
Definition ThreadInfo.h:87
void for_each_thread_id_in_first_warp(Fn &fn) const
Definition ThreadInfo.h:109
ThreadTileOption & operator=(const ThreadTileOption &)=delete
bool operator<(const ThreadTileOption &other) const
Definition ThreadInfo.h:28
IntrusivePtr< const LoopNest > loop_nest
Definition ThreadInfo.h:26
ThreadTileOption & operator=(ThreadTileOption &&)=default
ThreadTileOption(ThreadTileOption &&)=default
ThreadTileOption(const ThreadTileOption &)=delete
Intrusive shared pointers have a reference count (a RefCount object) stored in the class itself.