d2f98d74b9c22d7e92bb72404eaf597734320e15
[ieee754fpu.git] / src / ieee754 / part / layout_experiment.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-3-or-later
3 # See Notices.txt for copyright information
4 """
5 Links:
6 * https://libre-soc.org/3d_gpu/architecture/dynamic_simd/shape/
7 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
8 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c30
9 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
10 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
11 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c22
12 * https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
13 """
14
15 from nmigen import Signal, Module, Elaboratable, Mux, Cat, Shape, Repl
16 from nmigen.back.pysim import Simulator, Delay, Settle
17 from nmigen.cli import rtlil
18
19 from collections.abc import Mapping
20 from functools import reduce
21 import operator
22 from collections import defaultdict
23 from pprint import pprint
24
25 from ieee754.part_mul_add.partpoints import PartitionPoints
26
27
28 # main fn, which started out here in the bugtracker:
29 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c20
30 # note that signed is **NOT** part of the layout, and will NOT
31 # be added (because it is not relevant or appropriate).
32 # sign belongs in ast.Shape and is the only appropriate location.
33 # there is absolutely nothing within this function that in any
34 # way requires a sign. it is *purely* performing numerical width
35 # computations that have absolutely nothing to do with whether the
36 # actual data is signed or unsigned.
def layout(elwid, vec_el_counts, lane_shapes=None, fixed_width=None):
    """calculate a SIMD layout.

    For every element-width value ``k`` in ``vec_el_counts``, the
    ``vec_el_counts[k]`` elements ("lanes") of that width are spread
    evenly across the vector: lane ``n`` starts at bit
    ``n * (width // vec_el_counts[k])`` and is ``lane_shapes[k]`` bits
    wide.  Every lane start and lane end, other than bit 0 and bit
    ``width``, becomes a partition point.  Each partition point is
    enabled by an expression that is true when ``elwid`` equals any of
    the element-widths that have a lane boundary at that point.

    Glossary:
    * element: a single scalar value that is an element of a SIMD vector.
      it has a width in bits. Every element is made of 1 or
      more parts.
    * ElWid: the element-width (really the element type) of an instruction.
      Either an integer or a FP type. Integer `ElWid`s are sign-agnostic.
      In Python, `ElWid` is either an enum type or is `int`.
      Example `ElWid` definition for integers:

      class ElWid(Enum):
          I64 = ...       # SVP64 value 0b00
          I32 = ...       # SVP64 value 0b01
          I16 = ...       # SVP64 value 0b10
          I8 = ...        # SVP64 value 0b11

      Example `ElWid` definition for floats:

      class ElWid(Enum):
          F64 = ...       # SVP64 value 0b00
          F32 = ...       # SVP64 value 0b01
          F16 = ...       # SVP64 value 0b10
          BF16 = ...      # SVP64 value 0b11

    * part: the smallest piece of a SIMD vector.  Here the vector is
      divided into as many parts as the largest element count, so a part
      is ``width // max(vec_el_counts.values())`` bits wide.  (The inputs
      are counts of *elements*, not "fixed width parts": describing a
      layout in fixed-width parts would create e.g. 14 parts for the
      5-6-6-6 example where 5 partition points suffice.)

    * elwid: ElWid or nmigen Value with ElWid as the shape
      the current element-width

    * vec_el_counts: dict[ElWid, int]
      a map from `ElWid` values `k` to the number of vector elements
      required within the SIMD vector when `elwid == k`.

      Example:
      vec_el_counts = {ElWid.I8(==0b11): 8, # 8 vector elements
                       ElWid.I16(==0b10): 4, # 4 vector elements
                       ElWid.I32(==0b01): 2, # 2 vector elements
                       ElWid.I64(==0b00): 1} # 1 vector (aka scalar) element

      Another Example:
      vec_el_counts = {ElWid.BF16(==0b11): 4, # 4 vector elements
                       ElWid.F16(==0b10): 4, # 4 vector elements
                       ElWid.F32(==0b01): 2, # 2 vector elements
                       ElWid.F64(==0b00): 1} # 1 (aka scalar) vector element

    * lane_shapes: int or Mapping[ElWid, int] (optional)
      the bit-width of all elements in a SIMD layout.  A plain int is
      used for every ElWid.  If not provided, each lane shape is computed
      as ``fixed_width // vec_el_counts[k]``.

    * fixed_width: int (optional)
      the total width of a SIMD vector. One or both of lane_shapes or
      fixed_width may be provided. Both may not be left out.

    Returns a 6-tuple ``(points, bitp, bmask, width, lane_shapes,
    part_wid)``:

    * points: PartitionPoints built from a dict mapping each partition
      bit-position (ascending, excluding 0 and ``width``) to
      ``(elwid == k1) | (elwid == k2) | ...`` for every ElWid ``k`` with
      a lane boundary at that position.
    * bitp: dict mapping each ElWid ``k`` to an int whose bit ``n`` is
      set when the ``n``-th (ascending) partition point is used by ``k``.
      This is the value the points should evaluate to when
      ``elwid == k``.
    * bmask: int with a bit set for every partition point used by no
      ElWid (see the NOTE below: as written this is always 0).
    * width: the total vector width in bits: ``fixed_width`` if given,
      otherwise the largest ``lane_shapes[k] * vec_el_counts[k]``.
    * lane_shapes: the normalised ``dict[ElWid, int]`` of lane widths.
    * part_wid: ``width // max(vec_el_counts.values())``.

    Raises AssertionError if neither lane_shapes nor fixed_width is
    given, if the lanes do not fit into ``fixed_width``, or if
    ``fixed_width`` is not a multiple of the largest element count.
    """
    # when there are no lane_shapes specified, this indicates a
    # desire to use the maximum available space based on the fixed width
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c67
    if lane_shapes is None:
        assert fixed_width is not None, \
            "both fixed_width and lane_shapes cannot be None"
        lane_shapes = {i: fixed_width // vec_el_counts[i]
                       for i in vec_el_counts}
        print("lane_shapes", fixed_width, lane_shapes)

    # identify if the lane_shapes is a mapping (dict, etc.)
    # if not, then assume that it is an integer (width) that
    # needs to be requested across all partitions
    if not isinstance(lane_shapes, Mapping):
        lane_shapes = {i: lane_shapes for i in vec_el_counts}

    # first stage: the minimum overall width is the largest
    # (lane width * element count) over all elwidths.  cpart_wid (the
    # lane width that produced that maximum) is only used for debug output.
    print("lane_shapes", lane_shapes, "vec_el_counts", vec_el_counts)
    cpart_wid = 0
    width = 0
    for i, lwid in lane_shapes.items():
        required_width = lwid * vec_el_counts[i]
        print(" required width", cpart_wid, i, lwid, required_width)
        if required_width > width:
            cpart_wid = lwid
            width = required_width

    # the vector is divided into as many parts as the largest element
    # count.  if fixed_width is specified it overrides the minimum width:
    # it must fit every elwidth and divide evenly into parts.
    part_count = max(vec_el_counts.values())
    print("width", width, cpart_wid, part_count)
    if fixed_width is not None:  # override the width and part_wid
        assert width <= fixed_width, "not enough space to fit partitions"
        part_wid = fixed_width // part_count
        assert part_wid * part_count == fixed_width, \
            "calculated width not aligned multiples"
        width = fixed_width
        print("part_wid", part_wid, "count", part_count, "width", width)
    else:
        part_wid = width // part_count

    # second stage: create the breakpoints dictionary, mapping each bit
    # position to the list of elwidths that have a lane boundary there.
    # do multi-stage version https://bugs.libre-soc.org/show_bug.cgi?id=713#c34
    # https://stackoverflow.com/questions/26367812/
    dpoints = defaultdict(list)  # if empty key, create a (empty) list
    for i, c in vec_el_counts.items():
        print("dpoints", i, "count", c)
        # lanes of this elwidth are spread evenly, one every lane_stride
        # bits.  kept separate from part_wid so that the returned
        # part_wid does not depend on the ordering of vec_el_counts.
        lane_stride = width // c
        for start in range(c):
            lane_lo = start * lane_stride
            for msg, p in (("start", lane_lo),
                           ("end ", lane_lo + lane_shapes[i])):
                print(" adding dpoint", msg, start, lane_stride, i, c, p)
                # adjacent lanes share a boundary: record each elwidth
                # only once per point to avoid duplicate OR terms.
                if i not in dpoints[p]:
                    dpoints[p].append(i)

    # do not need the breakpoints at the very start or the very end.
    # (when fixed_width was given, width == fixed_width at this point.)
    dpoints.pop(0, None)
    dpoints.pop(width, None)
    plist = sorted(dpoints)
    print("dpoints")
    pprint(dict(dpoints))

    # third stage, add (map to) the elwidth==i expressions: a point is
    # enabled when elwid is any of the elwidths that need it.
    # TODO: use nmutil.treereduce?
    points = {}
    for p in plist:
        points[p] = reduce(operator.or_, (elwid == i for i in dpoints[p]))

    # fourth stage, create the binary values which *if* elwidth is set to i
    # *would* result in the mask at that elwidth being set to this value
    # (bit n corresponds to plist[n]).  these can easily be double-checked
    # through Assertion
    bitpos = {p: n for n, p in enumerate(plist)}
    bitp = {}
    for i in vec_el_counts:
        bitp[i] = 0
        for p, elwidths in dpoints.items():
            if i in elwidths:
                bitp[i] |= 1 << bitpos[p]

    # fifth stage: determine which partitions are 100% unused.
    # these can then be "blanked out".
    # NOTE(review): every key of dpoints was created by adding at least
    # one elwidth, so every bit is set in some bitp value and bmask is
    # always 0 as written.  detecting unused *regions* between points
    # would need a different computation - TODO confirm intent.
    bmask = (1 << len(plist)) - 1
    for p in bitp.values():
        bmask &= ~p
    return (PartitionPoints(points), bitp, bmask, width, lane_shapes,
            part_wid)
195
196
if __name__ == '__main__':

    # for each element-width (elwidth 0-3) the number of Vector Elements is:
    # elwidth=0b00 QTY 1 partitions: |             ?             |
    # elwidth=0b01 QTY 1 partitions: |             ?             |
    # elwidth=0b10 QTY 2 partitions: |      ?      |      ?      |
    # elwidth=0b11 QTY 4 partitions: |  ?   |  ?   |  ?   |  ?   |
    # actual widths of Signals *within* those partitions is given separately
    vec_el_counts = {
        0: 1,
        1: 1,
        2: 2,
        3: 4,
    }

    # width=3 indicates "same width Vector Elements (3) at all elwidths".
    # the overall width is therefore 12 (the largest requirement, 4x 3-bit)
    # and lane n of an elwidth with count c starts at bit n * (12 // c):
    # elwidth=0b00 1x 3-bit: bits [0:3]                   (rest unused)
    # elwidth=0b01 1x 3-bit: bits [0:3]                   (rest unused)
    # elwidth=0b10 2x 3-bit: bits [0:3], [6:9]            (rest unused)
    # elwidth=0b11 4x 3-bit: bits [0:3], [3:6], [6:9], [9:12]
    # expected partition points: 3, 6, 9 (0 and 12 are implicit)
    width_in_all_parts = 3

    for i in range(4):
        result = layout(i, vec_el_counts, width_in_all_parts)
        pprint((i, result))
        assert list(result[0].keys()) == [3, 6, 9]

    # specify that the Vector Element lengths are to be *different* at
    # each of the elwidths.
    # combined with vec_el_counts we have:
    # elwidth=0b00 1x 5-bit |<----unused---------->....5|
    # elwidth=0b01 1x 6-bit |<----unused--------->.....6|
    # elwidth=0b10 2x 6-bit |unused>.....6|unused>.....6|
    # elwidth=0b11 4x 6-bit |.....6|.....6|.....6|.....6|
    # expected partition points: 5, 6, 12, 18 (0 and 24 are implicit)
    widths_at_elwidth = {
        0: 5,
        1: 6,
        2: 6,
        3: 6
    }

    print ("5,6,6,6 elements", widths_at_elwidth)
    for i in range(4):
        pp, bitp, bm, b, c, d = \
            layout(i, vec_el_counts, widths_at_elwidth)
        pprint((i, (pp, bitp, bm, b, c, d)))
        # now check that the expected partition points occur
        print("5,6,6,6 ppt keys", pp.keys())
        assert list(pp.keys()) == [5, 6, 12, 18]

    # this example was probably what the 5,6,6,6 one was supposed to be.
    # combined with vec_el_counts {0:1, 1:1, 2:2, 3:4} we have:
    # elwidth=0b00 1x 24-bit |.........................24|
    # elwidth=0b01 1x 12-bit |<--unused--->|...........12|
    # elwidth=0b10 2x  5-bit |unused>|....5|unused>|....5|
    # elwidth=0b11 4x  6-bit |.....6|.....6|.....6|.....6|
    # expected partition points: 5, 6, 12, 17, 18 (0 and 24 are implicit)
    widths_at_elwidth = {
        0: 24,  # QTY 1x 24
        1: 12,  # QTY 1x 12
        2: 5,   # QTY 2x 5
        3: 6    # QTY 4x 6
    }

    print ("24,12,5,6 elements", widths_at_elwidth)
    for i in range(4):
        pp, bitp, bm, b, c, d = \
            layout(i, vec_el_counts, widths_at_elwidth)
        pprint((i, (pp, bitp, bm, b, c, d)))
        # now check that the expected partition points occur
        print("24,12,5,6 ppt keys", pp.keys())
        assert list(pp.keys()) == [5, 6, 12, 17, 18]

    # this tests elwidth as an actual Signal. layout is allowed to
    # determine arbitrarily the overall length
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c30

    elwid = Signal(2)
    pp, bitp, bm, b, c, d = layout(
        elwid, vec_el_counts, widths_at_elwidth)
    pprint((pp, b, c, d))
    for k, v in bitp.items():
        print("bitp elwidth=%d" % k, bin(v))
    print("bmask", bin(bm))

    m = Module()

    def process():
        for i in range(4):
            yield elwid.eq(i)
            yield Settle()
            ppt = []
            for pval in list(pp.values()):
                val = yield pval  # get nmigen to evaluate pp
                ppt.append(val)
            pprint((i, (ppt, b, c, d)))
            # check the results against bitp static-expected partition points
            # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
            # https://stackoverflow.com/a/27165694
            ival = int(''.join(map(str, ppt[::-1])), 2)
            assert ival == bitp[i]

    sim = Simulator(m)
    sim.add_process(process)
    sim.run()

    # this tests elwidth as an actual Signal. layout is *not* allowed to
    # determine arbitrarily the overall length, it is fixed to 64
    # https://bugs.libre-soc.org/show_bug.cgi?id=713#c22

    elwid = Signal(2)
    pp, bitp, bm, b, c, d = layout(elwid, vec_el_counts,
                                   widths_at_elwidth,
                                   fixed_width=64)
    pprint((pp, b, c, d))
    for k, v in bitp.items():
        print("bitp elwidth=%d" % k, bin(v))
    print("bmask", bin(bm))

    m = Module()

    def process():
        for i in range(4):
            yield elwid.eq(i)
            yield Settle()
            ppt = []
            for pval in list(pp.values()):
                val = yield pval  # get nmigen to evaluate pp
                ppt.append(val)
            print("test elwidth=%d" % i)
            pprint((i, (ppt, b, c, d)))
            # check the results against bitp static-expected partition points
            # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47
            # https://stackoverflow.com/a/27165694
            ival = int(''.join(map(str, ppt[::-1])), 2)
            assert ival == bitp[i], "ival %s actual %s" % (bin(ival),
                                                           bin(bitp[i]))

    sim = Simulator(m)
    sim.add_process(process)
    sim.run()

    # fixed_width=32 and no lane_widths says "allocate maximum"
    # i.e. Vector Element Widths are auto-allocated as 32 // count:
    # elwidth=0b00 1x 32-bit |.............................32|
    # elwidth=0b01 1x 32-bit |.............................32|
    # elwidth=0b10 2x 16-bit |.............16|.............16|
    # elwidth=0b11 4x  8-bit |......8|......8|......8|......8|
    # expected partition points: 8, 16, 24 (0 and 32 are implicit)

    # NOTE(review): an earlier TODO said this case was broken; tracing
    # layout() by hand gives partition points 8, 16, 24, which is now
    # asserted below.
    print ("maximum allocation from fixed_width=32")
    for i in range(4):
        result = layout(i, vec_el_counts, fixed_width=32)
        pprint((i, result))
        assert list(result[0].keys()) == [8, 16, 24]

    # example "exponent"
    # https://libre-soc.org/3d_gpu/architecture/dynamic_simd/shape/
    # 1xFP64: 11 bits, one exponent
    # 2xFP32: 8 bits, two exponents
    # 4xFP16: 5 bits, four exponents
    # 4xBF16: 8 bits, four exponents
    vec_el_counts = {
        0: 1,  # QTY 1x FP64
        1: 2,  # QTY 2x FP32
        2: 4,  # QTY 4x FP16
        3: 4,  # QTY 4x BF16
    }
    widths_at_elwidth = {
        0: 11,  # FP64 ew=0b00
        1: 8,   # FP32 ew=0b01
        2: 5,   # FP16 ew=0b10
        3: 8    # BF16 ew=0b11
    }

    # expected results with fixed_width=32 (lane n of an elwidth with
    # count c starts at bit n * (32 // c)):
    # ew=0b00 1x 11-bit: [0:11]
    # ew=0b01 2x  8-bit: [0:8], [16:24]
    # ew=0b10 4x  5-bit: [0:5], [8:13], [16:21], [24:29]
    # ew=0b11 4x  8-bit: [0:8], [8:16], [16:24], [24:32]
    # expected partition points: 5, 8, 11, 13, 16, 21, 24, 29

    print ("11,8,5,8 elements (FP64/32/16/BF exponents)", widths_at_elwidth)
    for i in range(4):
        pp, bitp, bm, b, c, d = \
            layout(i, vec_el_counts, widths_at_elwidth,
                   fixed_width=32)
        pprint((i, (pp, bitp, bin(bm), b, c, d)))
        # now check that the expected partition points occur
        print("11,8,5,8 pp keys", pp.keys())
        assert list(pp.keys()) == [5, 8, 11, 13, 16, 21, 24, 29]

    ######                                                          ######
    ###### 2nd test, different from the above, elwid=0b01 ==> 11 bit ######
    ######                                                          ######

    # example "exponent"
    vec_el_counts = {
        0: 1,  # QTY 1x FP64
        1: 2,  # QTY 2x FP32
        2: 4,  # QTY 4x FP16
        3: 4,  # QTY 4x BF16
    }
    widths_at_elwidth = {
        0: 11,  # FP64 ew=0b00
        1: 11,  # FP32 ew=0b01
        2: 5,   # FP16 ew=0b10
        3: 8    # BF16 ew=0b11
    }

    # expected results with fixed_width=32 (lane n of an elwidth with
    # count c starts at bit n * (32 // c)):
    # ew=0b00 1x 11-bit: [0:11]
    # ew=0b01 2x 11-bit: [0:11], [16:27]
    # ew=0b10 4x  5-bit: [0:5], [8:13], [16:21], [24:29]
    # ew=0b11 4x  8-bit: [0:8], [8:16], [16:24], [24:32]
    # expected partition points: 5, 8, 11, 13, 16, 21, 24, 27, 29

    print ("11,11,5,8 elements (FP64/32/16/BF exponents)", widths_at_elwidth)
    for i in range(4):
        pp, bitp, bm, b, c, d = \
            layout(i, vec_el_counts, widths_at_elwidth,
                   fixed_width=32)
        pprint((i, (pp, bitp, bin(bm), b, c, d)))
        # now check that the expected partition points occur
        print("11,11,5,8 pp keys", pp.keys())
        assert list(pp.keys()) == [5, 8, 11, 13, 16, 21, 24, 27, 29]
434