|
11 | 11 |
|
12 | 12 | #include "common.glslh" |
13 | 13 |
|
| 14 | +#extension GL_EXT_control_flow_attributes : require |
| 15 | + |
14 | 16 | #define DIMLIMIT 8 |
15 | 17 | #define DIMLIMIT_DIV4 2 |
16 | 18 |
|
| 19 | +// |
| 20 | +// Hashed layout utils |
| 21 | +// |
| 22 | + |
/*
 * The hashed layout is a packed int32 where each group of 4 bits contains some
 * information about the memory layout of a tensor buffer or texture. It is
 * passed into shaders as a specialization constant and allows shader compilers
 * to select optimized code paths suited for a particular memory layout.
 *
 * Currently the following information is packed into the layout integer:
 * - bits 0-15: the first 4 elements of the dim order array, i.e.
 *   - bits 0-3:   dim_order[0]
 *   - bits 4-7:   dim_order[1]
 *   - bits 8-11:  dim_order[2]
 *   - bits 12-15: dim_order[3]
 */
| 36 | + |
// Extracts the 4-bit packed value at the given position (0..7) from a 32-bit
// int. Position 0 corresponds to the least-significant 4 bits; position 7 to
// the most-significant.
int extract_4b(const int packed, const int pos) {
  const int shifted = packed >> (4 * pos);
  return shifted & 15;
}
| 43 | + |
| 44 | + |
// Layout IDs are the lower 16 bits of a hashed layout: dim_order[i] is packed
// into bits [4*i, 4*i + 4) for i in 0..3.

// Corresponds to dim_order[:4] = [0, 1, 2, 3]
#define CONTIGUOUS_BUFFER_LAYOUT_ID 12816
// Corresponds to dim_order[:4] = [2, 0, 1, 3]
#define CHANNELS_LAST_BUFFER_LAYOUT_ID 12546

// Used as a default value for hashed layout ints, representing the most common
// layout used for buffer-backed tensors (i.e. contiguous buffers). Defined in
// terms of CONTIGUOUS_BUFFER_LAYOUT_ID (rather than repeating the literal
// 12816) so the two constants cannot drift apart.
#define CONTIG_LAYOUT_INT CONTIGUOUS_BUFFER_LAYOUT_ID
| 53 | + |
// Returns the layout ID portion (the lower 16 bits) of a hashed layout int.
int layout_id(const int hashed_layout) {
  const int layout_id_mask = 0xFFFF;
  return hashed_layout & layout_id_mask;
}
| 58 | + |
// True when the hashed layout's ID marks a contiguous buffer layout
// (dim_order[:4] == [0, 1, 2, 3]).
bool is_contiguous(const int hashed_layout) {
  const int id = layout_id(hashed_layout);
  return id == CONTIGUOUS_BUFFER_LAYOUT_ID;
}
| 62 | + |
// True when the hashed layout's ID marks a channels-last buffer layout
// (dim_order[:4] == [2, 0, 1, 3]).
bool is_channels_last(const int hashed_layout) {
  const int id = layout_id(hashed_layout);
  return id == CHANNELS_LAST_BUFFER_LAYOUT_ID;
}
| 66 | + |
17 | 67 | // |
18 | 68 | // BufferMetadata |
19 | 69 | // |
@@ -126,15 +176,6 @@ uint idx_at(const TensorIndex tidx, const uint dim) { |
126 | 176 | return tidx.data[div_4(dim)][mod_4(dim)]; |
127 | 177 | } |
128 | 178 |
|
129 | | -void permute(inout TensorIndex tidx, const ivec4 permute_order[DIMLIMIT_DIV4]) { |
130 | | - TensorIndex new_tidx = tidx; |
131 | | - for (int d = 0; d < DIMLIMIT; ++d) { |
132 | | - int src_dim = permute_order[div_4(d)][mod_4(d)]; |
133 | | - new_tidx.data[div_4(d)][mod_4(d)] = idx_at(tidx, src_dim); |
134 | | - } |
135 | | - tidx = new_tidx; |
136 | | -} |
137 | | - |
138 | 179 | uint x(const TensorIndex tidx) { |
139 | 180 | return tidx.data[0][0]; |
140 | 181 | } |
@@ -174,83 +215,110 @@ struct TextureElementIndex { |
174 | 215 | // Index Conversions |
175 | 216 | // |
176 | 217 |
|
177 | | -void contiguous_idx_to_tensor_idx( |
/*
 * Converts a flat index over a contiguous (row-major over the padded DIMLIMIT
 * dims) buffer into a per-dim TensorIndex, using the sizes stored in meta.
 */
TensorIndex contiguous_idx_to_tensor_idx(
    const BufferMetadata meta,
    uint contiguous_idx) {
  TensorIndex tidx;

  // Build the strides a contiguous layout would have: dim 0 has stride 1 and
  // each subsequent dim's stride is the product of all lower dim sizes.
  uint contiguous_strides[DIMLIMIT];
  contiguous_strides[0] = 1;
  [[unroll]] for (int d = 1; d < DIMLIMIT; ++d) {
    contiguous_strides[d] = contiguous_strides[d - 1] * size_at(meta, d - 1);
  }

  // Peel one coordinate off per dim, starting from the dim with the largest
  // stride. Every entry of tidx is written, so no initialization is needed.
  [[unroll]] for (int d = DIMLIMIT - 1; d >= 0; --d) {
    const uint dim_stride = contiguous_strides[d];
    tidx.data[div_4(d)][mod_4(d)] = contiguous_idx / dim_stride;
    contiguous_idx %= dim_stride;
  }

  return tidx;
}
206 | 237 |
|
// Inverse of contiguous_idx_to_tensor_idx: folds a TensorIndex back into a
// flat index assuming a contiguous memory layout for the sizes in meta.
uint tensor_idx_to_contiguous_idx(
    const BufferMetadata meta,
    const TensorIndex tidx) {
  // Recompute the strides of a contiguous layout from the tensor sizes.
  uint contiguous_strides[DIMLIMIT];
  contiguous_strides[0] = 1;
  [[unroll]] for (int d = 1; d < DIMLIMIT; ++d) {
    contiguous_strides[d] = contiguous_strides[d - 1] * size_at(meta, d - 1);
  }

  // Accumulate the dot product of per-dim indices and strides.
  uint contig_idx = 0;
  [[unroll]] for (int d = 0; d < DIMLIMIT; ++d) {
    contig_idx += idx_at(tidx, d) * contiguous_strides[d];
  }

  return contig_idx;
}
222 | 255 |
|
// General path: converts a linear buffer index into a TensorIndex using the
// tensor's actual dim order and strides stored in meta. Handles any layout,
// at the cost of a data-dependent (non-unrollable) loop bound.
TensorIndex linear_idx_to_tensor_idx(
    const BufferMetadata meta,
    uint linear_idx) {
  TensorIndex tidx;
  // Only the first ndim entries are written below, so zero the rest up front.
  initialize(tidx);
  int dim = int_ndim(meta);
  // Walk dims from dim_order[dim - 1] (the dim with the largest stride) down
  // to dim_order[0], peeling off one coordinate per step.
  for (int d = max(dim - 1, 0); d >= 0; d--) {
    uint dim_idx = meta.dim_order[div_4(d)][mod_4(d)];
    uint dim_stride = meta.strides[div_4(dim_idx)][mod_4(dim_idx)];

    tidx.data[div_4(dim_idx)][mod_4(dim_idx)] = linear_idx / dim_stride;
    linear_idx = linear_idx % dim_stride;
  }
  return tidx;
}
238 | 272 |
|
// Fast path of linear_idx_to_tensor_idx for contiguous tensors: the dim order
// is the identity, so the strides can be walked in plain descending dim order
// with a fully unrolled loop.
TensorIndex linear_idx_to_tensor_idx_contig_case(
    const BufferMetadata meta,
    uint linear_idx) {
  TensorIndex tidx;

  [[unroll]] for (int d = DIMLIMIT - 1; d >= 0; --d) {
    const uint dim_stride = stride_at(meta, d);
    tidx.data[div_4(d)][mod_4(d)] = linear_idx / dim_stride;
    linear_idx %= dim_stride;
  }

  return tidx;
}
| 285 | + |
// Fast path of linear_idx_to_tensor_idx for channels-last tensors. The dim
// order is baked in as a compile-time constant so the loop can be fully
// unrolled without reading meta.dim_order.
TensorIndex linear_idx_to_tensor_idx_channelslast_case(
    const BufferMetadata meta,
    uint linear_idx) {
  TensorIndex tidx;

  // dim_order[:4] = [2, 0, 1, 3] matches CHANNELS_LAST_BUFFER_LAYOUT_ID; the
  // upper four entries cover the padded higher dims. NOTE(review): assumes
  // [6, 5, 4, 7] matches the host-side channels-last dim order for tensors
  // padded to 8 dims — confirm against the CPU-side layout hashing.
  const uint dim_order[DIMLIMIT] = uint[DIMLIMIT](2, 0, 1, 3, 6, 5, 4, 7);

  [[unroll]] for (int i = DIMLIMIT - 1; i >= 0; --i) {
    const uint dim = dim_order[i];
    const uint dim_stride = stride_at(meta, dim);
    tidx.data[div_4(dim)][mod_4(dim)] = linear_idx / dim_stride;
    linear_idx %= dim_stride;
  }

  return tidx;
}
246 | 301 |
|
// Layout-aware overload of linear_idx_to_tensor_idx. When the hashed layout
// (typically a specialization constant) identifies a contiguous or
// channels-last tensor, dispatch to the corresponding unrolled fast path;
// otherwise fall back to the general dim_order-driven conversion.
TensorIndex linear_idx_to_tensor_idx(
    const BufferMetadata meta,
    uint linear_idx,
    int hashed_layout) {
  if (is_contiguous(hashed_layout)) {
    return linear_idx_to_tensor_idx_contig_case(meta, linear_idx);
  }
  if (is_channels_last(hashed_layout)) {
    return linear_idx_to_tensor_idx_channelslast_case(meta, linear_idx);
  }
  return linear_idx_to_tensor_idx(meta, linear_idx);
}
| 313 | + |
// Folds a TensorIndex into a linear buffer index using the tensor's actual
// strides. The unrolled loop covers all DIMLIMIT dims; entries past ndim are
// expected to contribute nothing (index 0 and/or padded stride) — as produced
// by the conversion helpers above.
uint tensor_idx_to_linear_idx(
    const BufferMetadata meta,
    const TensorIndex tidx) {
  uint lin_idx = 0;
  [[unroll]] for (int d = 0; d < DIMLIMIT; ++d) {
    lin_idx += idx_at(tidx, d) * stride_at(meta, d);
  }
  return lin_idx;
}
256 | 324 |
|
|
0 commit comments