switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / part_mul_add / partpoints.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Value, Cat, C
6
7
def make_partition(mask, width):
    """Create partition points spaced ``len(mask)`` bits apart.

    The assumption is that the mask marks break-points at regular
    intervals and that its last bit (MSB) is therefore *ignored*: entry
    ``k`` of ``mask`` is placed at bit ``(k + 1) * len(mask)``, and
    generation stops at the first position that is not below ``width``.

    mask len = 4, width == 16 will return:
        {4: mask[0], 8: mask[1], 12: mask[2]}
    mask len = 8, width == 64 will return:
        {8: mask[0], 16: mask[1], 24: mask[2], .... 56: mask[6]}

    NOTE(review): the spacing is ``len(mask)``, not ``width // len(mask)``,
    so the points are spread evenly across ``width`` only when
    ``width == len(mask) ** 2`` (as in both examples above).  For other
    widths (e.g. len 4, width 64) the points bunch at the bottom and the
    last mask entry is used too -- confirm this is intended.

    :param mask: an indexable, sized sequence of partition-enable values.
    :param width: total bit-width; positions ``>= width`` are not emitted.
    :returns: dict mapping bit positions (ascending) to mask entries.
    """
    step = len(mask)
    points = {}
    for k in range(step):
        position = (k + 1) * step
        if position >= width:
            # positions only grow, so nothing later can fit either
            break
        points[position] = mask[k]
    return points
27
28
def make_partition2(mask, width):
    """Create evenly-spaced partition points from a mask and a bit-width.

    Every entry of ``mask`` is an actual partition point, so ``len(mask)``
    must be ONE LESS than the number of partitions required.  The size of
    each partition is ``width // (len(mask) + 1)`` and entry ``k`` of the
    mask is placed at bit ``(k + 1) * size``.  When ``width`` is not an
    exact multiple of the partition count, the top partition absorbs the
    remainder.

    mask len = 3, width == 16 will return:
        {4: mask[0], 8: mask[1], 12: mask[2]}
    mask len = 7, width == 64 will return:
        {8: mask[0], 16: mask[1], 24: mask[2], .... 56: mask[6]}

    :param mask: an indexable, sized sequence of partition-enable values,
        or a dict (e.g. a ``PartitionPoints``) whose values are used in
        iteration order.
    :param width: total bit-width being partitioned.
    :returns: dict mapping bit positions (ascending) to mask entries.
    :raises ValueError: if ``width`` is too small to give every partition
        at least one bit.
    """
    if isinstance(mask, dict):  # convert dict/partpoints to sequential list
        mask = list(mask.values())
    mlen = len(mask) + 1  # ONE MORE partition than break-points
    jumpsize = width // mlen  # amount to jump by (size of each partition)
    # Explicit check rather than ``assert``: asserts vanish under ``-O``,
    # and a zero jump would then silently produce garbage.
    if jumpsize <= 0:
        raise ValueError("incorrect width // mlen (%d // %d)" % (width, mlen))
    # Only len(mask) points exist.  Iterating up to mlen (as before) indexed
    # past the end of mask whenever width % mlen != 0.  Every generated
    # position is < jumpsize * mlen <= width, so no width check is needed.
    return {jumpsize * (idx + 1): mask[idx] for idx in range(len(mask))}
56
57
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at which an ALU is partitioned, along with ``Value``s that
    specify whether the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input mapping of partition points to
            enable values, or ``None`` for an empty set.  Each value is
            converted with ``Value.cast``.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is not None:
            for point, enabled in partition_points.items():
                if not isinstance(point, int):
                    raise TypeError("point must be a non-negative integer")
                if point < 0:
                    raise ValueError("point must be a non-negative integer")
                self[point] = Value.cast(enabled)

    def like(self, name=None, src_loc_at=0, mul=1):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.  If ``None``,
            a throw-away ``Signal`` is created only to borrow its
            automatically chosen name.
        :param src_loc_at: extra stack-frame offset passed on when deriving
            that automatic name.
        :param mul: a multiplication factor on the indices
        :returns: a new ``PartitionPoints`` keyed by ``point * mul``, each
            value a fresh ``Signal`` with the same shape as the original
            value, named ``f"{name}_{point * mul}"``.
        """
        if name is None:
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        retval = PartitionPoints()
        for point, enabled in self.items():
            point *= mul
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``.

        This is a generator: it yields one ``enabled.eq(rhs[point])``
        statement per point.  Because of that, the point-set check (and its
        ``ValueError("incompatible point set")``) only runs once the result
        is iterated.
        """
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width, mul=1):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point
        at the same bit-index is enabled.  Bits with no partition point are
        constant ``True``.

        :param width: the bit width of the resulting mask
        :param mul: a "multiplier" which in-place expands the partition
            points, typically set to "2" when used for multipliers.  Bit
            ``i`` maps to point ``i / mul``, so only exact multiples of
            ``mul`` can carry a partition point.
        :returns: ``Cat`` of ``width`` bits.
        """
        bits = []
        for i in range(width):
            # divmod keeps the index arithmetic exact (no float round-trip)
            point, remainder = divmod(i, mul)
            if remainder == 0 and point in self:
                bits.append(~self[point])
            else:
                bits.append(True)
        return Cat(*bits)

    def as_sig(self):
        """Create a straight concatenation of `self` signals, in
        dict (insertion) order.
        """
        return Cat(self.values())

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled:
        one more than the number of points strictly below ``width``.
        """
        return 1 + sum(1 for point in self if point < width)

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`.

        Returns ``True`` for an empty set of points.
        """
        return all(point < width for point in self)

    def part_byte(self, index, mfactor=1):  # mfactor used for "expanding"
        """Return the enable ``Value`` for the point above byte ``index``.

        The point looked up is ``(index * 8 + 8) * mfactor``, i.e. the
        boundary at the top of byte ``index`` scaled by ``mfactor``.  For
        ``index`` -1 or 7 a constant 1-bit ``True`` is returned instead.
        Presumably this is because those are the outer edges of a 64-bit
        (8-byte) value; TODO confirm.

        :raises ValueError: if ``index`` is outside -1..7.
        :raises KeyError: if the computed point is not present in `self`.
        """
        if index == -1 or index == 7:
            return C(True, 1)
        # explicit check instead of ``assert`` so it survives ``python -O``
        if not 0 <= index < 8:
            raise ValueError("byte index must be in range -1..7, got %r"
                             % (index,))
        return self[(index * 8 + 8) * mfactor]
166
167