run tests in parallel
[ieee754fpu.git] / src / ieee754 / part_mul_add / partpoints.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Value, Cat, C
6
def make_partition(mask, width):
    """Create partition points from a mask and a bit width.

    The mask is assumed to mark breakpoints at regular intervals. Only
    positions strictly below ``width`` are emitted, so in the examples
    below the last bit (MSB) of the mask is *ignored*.

    Examples::

        mask len = 4, width == 16 will return:
            {4: mask[0], 8: mask[1], 12: mask[2]}
        mask len = 8, width == 64 will return:
            {8: mask[0], 16: mask[1], 24: mask[2], .... 56: mask[6]}

    NOTE(review): consecutive points are ``len(mask)`` bits apart. That
    coincides with ``width // len(mask)`` only when
    ``width == len(mask) ** 2`` (true for both examples above) -- confirm
    this is the intended spacing for other mask/width combinations.

    :param mask: an indexable value with a length; bit ``i`` of the mask
        becomes the enable for the ``i``-th partition point.
    :param width: the total bit width being partitioned.
    :returns: a plain ``dict`` mapping bit position to the mask bit that
        enables the partition point at that position.
    """
    # len() instead of mask.shape()[0]: for an nmigen ``Value`` it yields
    # the same bit width, and it additionally accepts ordinary sequences.
    mlen = len(mask)
    # Point number ``idx`` sits at bit ``mlen * (idx + 1)``; positions at
    # or beyond ``width`` are dropped.
    return {
        mlen * (idx + 1): mask[idx]
        for idx in range(mlen)
        if mlen * (idx + 1) < width
    }
26
27
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at where an ALU is partitioned along with ``Value``s that
    specify if the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:

    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:

    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values
            mapping, or ``None`` for an empty set of points. Each value
            is converted with ``Value.cast``.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is None:
            return
        for point, enabled in partition_points.items():
            if not isinstance(point, int):
                raise TypeError("point must be a non-negative integer")
            if point < 0:
                raise ValueError("point must be a non-negative integer")
            self[point] = Value.cast(enabled)

    def like(self, name=None, src_loc_at=0, mul=1):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        Each new ``Signal`` has the shape of the corresponding value in
        ``self`` and is named ``{name}_{point}``, where ``point`` is the
        already-multiplied index.

        :param name: the base name for the new ``Signal``s. When ``None``
            a name is derived from a temporary ``Signal``.
        :param src_loc_at: extra stack depth forwarded to that temporary
            ``Signal`` when deriving the name.
        :param mul: a multiplication factor on the indices
        :returns: the new ``PartitionPoints``.
        """
        if name is None:
            # A throw-away Signal is built only to read its name back
            # ("get variable name"). This call must stay directly in
            # this method: the lookup depth is relative to this frame.
            name = Signal(src_loc_at=1+src_loc_at).name
        retval = PartitionPoints()
        for point, enabled in self.items():
            point *= mul
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``.

        This is a generator: both the key-set check and the assignments
        happen lazily, when the result is iterated.

        :param rhs: mapping with exactly the same points as ``self``.
        :raises ValueError: (on first iteration) if the point sets differ.
        """
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width, mul=1):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        :param mul: a "multiplier" which in-place expands the partition points
                    typically set to "2" when used for multipliers
        :returns: the ``Cat`` of one entry per bit index in ``range(width)``.
        """
        bits = []
        for bit_index in range(width):
            # Integer divmod instead of float division: bit ``bit_index``
            # corresponds to point ``bit_index / mul`` only when that
            # division is exact, and the dict is indexed by an int key.
            point, remainder = divmod(bit_index, mul)
            if remainder == 0 and point in self:
                bits.append(~self[point])
            else:
                bits.append(True)
        return Cat(*bits)

    def as_sig(self):
        """Create a straight concatenation of `self` signals
        """
        return Cat(self.values())

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.
        Only points strictly below ``width`` are counted.

        :param width: the bit width being partitioned.
        :returns: one more than the number of points below ``width``.
        """
        return 1 + sum(1 for point in self if point < width)

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        return not any(point >= width for point in self)

    def part_byte(self, index, mfactor=1):  # mfactor used for "expanding"
        """Return the partition-point value at the top of byte ``index``.

        For ``index`` of ``-1`` or ``7`` a constant 1-bit true is returned
        without consulting the stored points. Otherwise the value stored
        at bit position ``(index * 8 + 8) * mfactor`` is returned.

        :param index: byte index; must be ``-1`` or in ``0..7``.
        :param mfactor: factor applied to the bit position ("expanding").
        :raises AssertionError: if ``index`` is out of range (only while
            assertions are enabled).
        :raises KeyError: if no point is stored at the computed position.
        """
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        return self[(index * 8 + 8)*mfactor]
136
137