From 9a318256b74054b8d592efe7be298764d0de415a Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Tue, 12 Oct 2021 20:39:51 -0700 Subject: [PATCH] fix layout bugs --- src/ieee754/part/layout_experiment.py | 266 ++++++++++++++++---------- 1 file changed, 170 insertions(+), 96 deletions(-) diff --git a/src/ieee754/part/layout_experiment.py b/src/ieee754/part/layout_experiment.py index 346b8595..6ddea00d 100644 --- a/src/ieee754/part/layout_experiment.py +++ b/src/ieee754/part/layout_experiment.py @@ -13,18 +13,42 @@ Links: """ from nmigen import Signal, Module, Elaboratable, Mux, Cat, Shape, Repl -from nmigen.back.pysim import Simulator, Delay, Settle +from nmigen.sim import Simulator, Delay, Settle from nmigen.cli import rtlil +from enum import Enum from collections.abc import Mapping from functools import reduce import operator from collections import defaultdict -from pprint import pprint +import dataclasses from ieee754.part_mul_add.partpoints import PartitionPoints +@dataclasses.dataclass +class LayoutResult: + ppoints: PartitionPoints + bitp: dict + bmask: int + width: int + lane_shapes: dict + part_wid: int + full_part_count: int + + def __repr__(self): + fields = [] + for field in dataclasses.fields(LayoutResult): + field_v = getattr(self, field.name) + if isinstance(field_v, PartitionPoints): + field_v = ',\n '.join( + f"{k}: {v}" for k, v in field_v.items()) + field_v = f"{{{field_v}}}" + fields.append(f"{field.name}={field_v}") + fields = ",\n ".join(fields) + return f"LayoutResult({fields})" + + # main fn, which started out here in the bugtracker: # https://bugs.libre-soc.org/show_bug.cgi?id=713#c20 def layout(elwid, signed, part_counts, lane_shapes=None, fixed_width=None): @@ -33,7 +57,10 @@ def layout(elwid, signed, part_counts, lane_shapes=None, fixed_width=None): Glossary: * element: a single scalar value that is an element of a SIMD vector. it has a width in bits, and a signedness. Every element is made of 1 or - more parts. + more parts. An element optionally includes the padding associated with + it. + * lane: an element. An element optionally includes the padding associated + with it. * ElWid: the element-width (really the element type) of an instruction. Either an integer or a FP type. Integer `ElWid`s are sign-agnostic. In Python, `ElWid` is either an enum type or is `int`. @@ -52,13 +79,14 @@ def layout(elwid, signed, part_counts, lane_shapes=None, fixed_width=None): BF16 = ... F32 = ... F64 = ... - * part: A piece of a SIMD vector, every SIMD vector is made of a - non-negative integer of parts. Elements are made of a power-of-two - number of parts. A part is a fixed number of bits wide for each - different SIMD layout, it doesn't vary when `elwid` changes. A part - can have a bit width of any non-negative integer, it is not restricted - to power-of-two. SIMD vectors should have as few parts as necessary, - since some circuits have size proportional to the number of parts. + * part: (not to be confused with a partition) A piece of a SIMD vector, + every SIMD vector is made of a non-negative integer of parts. Elements + are made of a power-of-two number of parts. A part is a fixed number + of bits wide for each different SIMD layout, it doesn't vary when + `elwid` changes. A part can have a bit width of any non-negative + integer, it is not restricted to power-of-two. SIMD vectors should + have as few parts as necessary, since some circuits have size + proportional to the number of parts. * elwid: ElWid or nmigen Value with ElWid as the shape @@ -83,60 +111,88 @@ def layout(elwid, signed, part_counts, lane_shapes=None, fixed_width=None): the total width of a SIMD vector. One of lane_shapes and fixed_width must be provided. """ + print(f"layout(elwid={elwid},\n" + f" signed={signed},\n" + f" part_counts={part_counts},\n" + f" lane_shapes={lane_shapes},\n" + f" fixed_width={fixed_width})") + assert isinstance(part_counts, Mapping) + # assert all part_counts are powers of two + assert all(v != 0 and (v & (v - 1)) == 0 for v in part_counts.values()),\ + "part_counts values must all be powers of two" + + full_part_count = max(part_counts.values()) + # when there are no lane_shapes specified, this indicates a # desire to use the maximum available space based on the fixed width # https://bugs.libre-soc.org/show_bug.cgi?id=713#c67 if lane_shapes is None: assert fixed_width is not None, \ "both fixed_width and lane_shapes cannot be None" - lane_shapes = {i: fixed_width // part_counts[i] for i in part_counts} + lane_shapes = {} + for k, cur_part_count in part_counts.items(): + cur_element_count = full_part_count // cur_part_count + assert fixed_width % cur_element_count == 0, ( + f"fixed_width ({fixed_width}) can't be split evenly into " + f"{cur_element_count} elements") + lane_shapes[k] = fixed_width // cur_element_count print("lane_shapes", fixed_width, lane_shapes) # identify if the lane_shapes is a mapping (dict, etc.) # if not, then assume that it is an integer (width) that # needs to be requested across all partitions if not isinstance(lane_shapes, Mapping): lane_shapes = {i: lane_shapes for i in part_counts} - # compute a set of partition widths - cpart_wid = [-lane_shapes[i] // c for i, c in part_counts.items()] - print("cpart_wid", cpart_wid, "part_counts", part_counts) - cpart_wid = -min(cpart_wid) - part_count = max(part_counts.values()) - # calculate the minumum width required - width = cpart_wid * part_count - print("width", width, cpart_wid, part_count) + # calculate the minimum possible bit-width of a part. + # we divide each element's width by the number of parts in an element, + # giving the number of bits needed per part. + # we use `-min(-a // b for ...)` to get `max(ceil(a / b) for ...)`, + # but using integers. + min_part_wid = -min(-lane_shapes[i] // c for i, c in part_counts.items()) + # calculate the minimum bit-width required + min_width = min_part_wid * full_part_count + print("width", min_width, min_part_wid, full_part_count) if fixed_width is not None: # override the width and part_wid - assert width < fixed_width, "not enough space to fit partitions" - part_wid = fixed_width // part_count - assert part_wid * part_count == fixed_width, \ - "calculated width not aligned multiples" + assert min_width <= fixed_width, "not enough space to fit partitions" + part_wid = fixed_width // full_part_count + assert fixed_width % full_part_count == 0, \ + "fixed_width must be a multiple of full_part_count" width = fixed_width - print("part_wid", part_wid, "count", part_count) + print("part_wid", part_wid, "count", full_part_count) else: # go with computed width - part_wid = cpart_wid + width = min_width + part_wid = min_part_wid # create the breakpoints dictionary. # do multi-stage version https://bugs.libre-soc.org/show_bug.cgi?id=713#c34 # https://stackoverflow.com/questions/26367812/ - dpoints = defaultdict(list) # if empty key, create a (empty) list - for i, c in part_counts.items(): - def add_p(p): - dpoints[p].append(i) # auto-creates list if key non-existent - for start in range(0, part_count, c): - add_p(start * part_wid) # start of lane - add_p(start * part_wid + lane_shapes[i]) # start of padding + # dpoints: dict from bit-index to dict[ElWid, None] + # we use a dict from ElWid to None as the values of dpoints in order to + # get an ordered set + dpoints = defaultdict(dict) # if empty key, create a (empty) dict + for i, cur_part_count in part_counts.items(): + def add_p(bit_index): + # auto-creates dict if key non-existent + dpoints[bit_index][i] = None + # go through all elements for elwid `i`, each element starts at + # part index `start_part`, and goes for `cur_part_count` parts + for start_part in range(0, full_part_count, cur_part_count): + start_bit = start_part * part_wid + add_p(start_bit) # start of lane + add_p(start_bit + lane_shapes[i]) # start of padding # do not need the breakpoints at the very start or the very end dpoints.pop(0, None) dpoints.pop(width, None) plist = list(dpoints.keys()) plist.sort() print("dpoints") - pprint(dict(dpoints)) + for k in plist: + print(f"{k}: {list(dpoints[k].keys())}") # second stage, add (map to) the elwidth==i expressions. # TODO: use nmutil.treereduce? points = {} for p in plist: - points[p] = map(lambda i: elwid == i, dpoints[p]) - points[p] = reduce(operator.or_, points[p]) + it = map(lambda i: elwid == i, dpoints[p]) + points[p] = reduce(operator.or_, it) # third stage, create the binary values which *if* elwidth is set to i # *would* result in the mask at that elwidth being set to this value # these can easily be double-checked through Assertion @@ -149,92 +205,111 @@ def layout(elwid, signed, part_counts, lane_shapes=None, fixed_width=None): bitp[i] |= 1 << bitpos # fourth stage: determine which partitions are 100% unused. # these can then be "blanked out" - bmask = (1 << len(plist))-1 + bmask = (1 << len(plist)) - 1 for p in bitp.values(): bmask &= ~p - return (PartitionPoints(points), bitp, bmask, width, lane_shapes, - part_wid, part_count) + return LayoutResult(PartitionPoints(points), bitp, bmask, width, + lane_shapes, part_wid, full_part_count) if __name__ == '__main__': - # for each element-width (elwidth 0-3) the number of partitions is given - # elwidth=0b00 QTY 1 partitions: | ? | - # elwidth=0b01 QTY 1 partitions: | ? | - # elwidth=0b10 QTY 2 partitions: | ? | ? | - # elwidth=0b11 QTY 4 partitions: | ? | ? | ? | ? | + class FpElWid(Enum): + F64 = 0 + F32 = 1 + F16 = 2 + BF16 = 3 + + def __repr__(self): + return super().__str__() + + class IntElWid(Enum): + I64 = 0 + I32 = 1 + I16 = 2 + I8 = 3 + + def __repr__(self): + return super().__str__() + + # for each element-width (elwidth 0-3) the number of parts in an element + # is given: + # | part0 | part1 | part2 | part3 | + # elwid=F64 4 parts per element: |<-------------F64------------->| + # elwid=F32 2 parts per element: |<-----F32----->|<-----F32----->| + # elwid=F16 1 part per element: |<-F16->|<-F16->|<-F16->|<-F16->| + # elwid=BF16 1 part per element: ||||| # actual widths of Signals *within* those partitions is given separately part_counts = { - 0: 1, - 1: 1, - 2: 2, - 3: 4, + FpElWid.F64: 4, + FpElWid.F32: 2, + FpElWid.F16: 1, + FpElWid.BF16: 1, } - # width=3 indicates "we want the same width (3) at all elwidths" - # elwidth=0b00 1x 5-bit | ..3 | - # elwidth=0b01 1x 6-bit | ..3 | - # elwidth=0b10 2x 12-bit | ..3 | ..3 | - # elwidth=0b11 3x 24-bit | ..3| ..3 | ..3 |..3 | - width_in_all_parts = 3 + # width=3 indicates "we want the same element bit-width (3) at all elwids" + # elwid=F64 1x 3-bit |<--------i3------->| + # elwid=F32 2x 3-bit |<---i3-->|<---i3-->| + # elwid=F16 4x 3-bit ||||| + # elwid=BF16 4x 3-bit ||||| + width_for_all_els = 3 - for i in range(4): - pprint((i, layout(i, True, part_counts, width_in_all_parts))) + for i in FpElWid: + print(i, layout(i, True, part_counts, width_for_all_els)) # fixed_width=32 and no lane_widths says "allocate maximum" - # elwidth=0b00 1x 32-bit | .................32 | - # elwidth=0b01 1x 32-bit | .................32 | - # elwidth=0b10 2x 12-bit | ......16 | ......16 | - # elwidth=0b11 3x 24-bit | ..8| ..8 | ..8 |..8 | + # elwid=F64 1x 32-bit |<-------i32------->| + # elwid=F32 2x 16-bit |<--i16-->|<--i16-->| + # elwid=F16 4x 8-bit ||||| + # elwid=BF16 4x 8-bit ||||| - #print ("maximum allocation from fixed_width=32") - # for i in range(4): - # pprint((i, layout(i, True, part_counts, fixed_width=32))) + print("maximum allocation from fixed_width=32") + for i in FpElWid: + print(i, layout(i, True, part_counts, fixed_width=32)) # specify that the length is to be *different* at each of the elwidths. # combined with part_counts we have: - # elwidth=0b00 1x 5-bit | ....5 | - # elwidth=0b01 1x 6-bit | .....6 | - # elwidth=0b10 2x 12-bit | ....12 | .....12 | - # elwidth=0b11 3x 24-bit | 24 | 24 | 24 | 24 | + # elwid=F64 1x 24-bit |<-------i24------->| + # elwid=F32 2x 12-bit |<--i12-->|<--i12-->| + # elwid=F16 4x 6-bit ||||| + # elwid=BF16 4x 5-bit ||||| widths_at_elwidth = { - 0: 5, - 1: 6, - 2: 12, - 3: 24 + FpElWid.F64: 24, + FpElWid.F32: 12, + FpElWid.F16: 6, + FpElWid.BF16: 5, } - for i in range(4): - pprint((i, layout(i, False, part_counts, widths_at_elwidth))) + for i in FpElWid: + print(i, layout(i, False, part_counts, widths_at_elwidth)) # this tests elwidth as an actual Signal. layout is allowed to # determine arbitrarily the overall length # https://bugs.libre-soc.org/show_bug.cgi?id=713#c30 - elwid = Signal(2) - pp, bitp, bm, b, c, d, e = layout( - elwid, False, part_counts, widths_at_elwidth) - pprint((pp, b, c, d, e)) - for k, v in bitp.items(): - print("bitp elwidth=%d" % k, bin(v)) - print("bmask", bin(bm)) + elwid = Signal(FpElWid) + lr = layout(elwid, False, part_counts, widths_at_elwidth) + print(lr) + for k, v in lr.bitp.items(): + print(f"bitp elwidth={k}", bin(v)) + print("bmask", bin(lr.bmask)) m = Module() def process(): - for i in range(4): + for i in FpElWid: yield elwid.eq(i) yield Settle() ppt = [] - for pval in list(pp.values()): + for pval in lr.ppoints.values(): val = yield pval # get nmigen to evaluate pp ppt.append(val) - pprint((i, (ppt, b, c, d, e))) + print(i, ppt) # check the results against bitp static-expected partition points # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47 # https://stackoverflow.com/a/27165694 ival = int(''.join(map(str, ppt[::-1])), 2) - assert ival == bitp[i] + assert ival == lr.bitp[i] sim = Simulator(m) sim.add_process(process) @@ -244,32 +319,31 @@ if __name__ == '__main__': # determine arbitrarily the overall length, it is fixed to 64 # https://bugs.libre-soc.org/show_bug.cgi?id=713#c22 - elwid = Signal(2) - pp, bitp, bm, b, c, d, e = layout(elwid, False, part_counts, widths_at_elwidth, - fixed_width=64) - pprint((pp, b, c, d, e)) - for k, v in bitp.items(): - print("bitp elwidth=%d" % k, bin(v)) - print("bmask", bin(bm)) + elwid = Signal(FpElWid) + lr = layout(elwid, False, part_counts, widths_at_elwidth, fixed_width=64) + print(lr) + for k, v in lr.bitp.items(): + print(f"bitp elwidth={k}", bin(v)) + print("bmask", bin(lr.bmask)) m = Module() def process(): - for i in range(4): + for i in FpElWid: yield elwid.eq(i) yield Settle() ppt = [] - for pval in list(pp.values()): + for pval in list(lr.ppoints.values()): val = yield pval # get nmigen to evaluate pp ppt.append(val) - print("test elwidth=%d" % i) - pprint((i, (ppt, b, c, d, e))) + print(f"test elwidth={i}") + print(i, ppt) # check the results against bitp static-expected partition points # https://bugs.libre-soc.org/show_bug.cgi?id=713#c47 # https://stackoverflow.com/a/27165694 ival = int(''.join(map(str, ppt[::-1])), 2) - assert ival == bitp[i], "ival %s actual %s" % (bin(ival), - bin(bitp[i])) + assert ival == lr.bitp[i], \ + f"ival {bin(ival)} actual {bin(lr.bitp[i])}" sim = Simulator(m) sim.add_process(process) -- 2.30.2