split out adder code (PartitionedAdder) into module, PartitionPoints too
[ieee754fpu.git] / src / ieee754 / part_mul_add / partpoints.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Value, Cat, C
6
7
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at which an ALU is partitioned, along with ``Value``s that
    specify whether the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: optional mapping from partition point
            (a non-negative ``int`` bit index) to its enable value; every
            value is converted with ``Value.wrap``.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is not None:
            for point, enabled in partition_points.items():
                if not isinstance(point, int):
                    raise TypeError("point must be a non-negative integer")
                if point < 0:
                    raise ValueError("point must be a non-negative integer")
                self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0, mul=1):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.  When ``None``,
            the name nmigen infers for a ``Signal`` created at the caller's
            site (i.e. the variable being assigned) is used.
        :param src_loc_at: extra stack depth passed to that name inference,
            for callers that wrap this method.
        :param mul: a multiplication factor on the indices; every ``point``
            becomes key ``point * mul`` in the result.
        :returns: a new ``PartitionPoints`` whose values are fresh
            ``Signal``s with the same shape as the originals, named
            ``{name}_{point * mul}``.
        """
        if name is None:
            # A throwaway Signal is built only to borrow the variable name
            # nmigen derives from the caller.  This call must stay directly
            # inside ``like`` so the ``1 + src_loc_at`` depth stays correct.
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        retval = PartitionPoints()
        for point, enabled in self.items():
            scaled = point * mul
            retval[scaled] = Signal(enabled.shape(), name=f"{name}_{scaled}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``.

        :param rhs: a mapping with exactly the same set of points as
            ``self``.
        :returns: a generator of the statements produced by
            ``self[point].eq(rhs[point])``, one per point.
        :raises ValueError: immediately (at call time, not when the
            generator is consumed) if the point sets differ.
        """
        # Validate eagerly: in a generator function this check would only
        # run once the caller started iterating, far from the faulty call.
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        return (enabled.eq(rhs[point]) for point, enabled in self.items())

    def as_mask(self, width, mul=1):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        :param mul: a "multiplier" which in-place expands the partition points
                    typically set to "2" when used for multipliers; bit ``i``
                    corresponds to point ``i // mul`` and bits whose index is
                    not a multiple of ``mul`` are always set.
        :returns: ``Cat`` of ``width`` one-bit values, LSB first.
        """
        bits = []
        for i in range(width):
            # Integer division instead of ``i / mul``: avoids indexing the
            # dict with float keys that only match via hash(2.0) == hash(2).
            point, offset = divmod(i, mul)
            if offset == 0 and point in self:
                bits.append(~self[point])
            else:
                bits.append(True)
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled:
        one more than the number of points strictly below ``width``.
        """
        return 1 + sum(1 for point in self if point < width)

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`.

        Returns ``True`` for an empty ``PartitionPoints``.
        """
        return all(point < width for point in self)

    def part_byte(self, index, mfactor=1):  # mfactor used for "expanding"
        """Return the enable of the partition point at the top of a byte.

        Byte ``index`` covers bits ``8*index .. 8*index + 7``; this returns
        the value stored for point ``(8*index + 8) * mfactor``, i.e. the
        boundary just above that byte, with ``mfactor`` scaling the index
        like ``mul`` does in ``as_mask``/``like``.

        ``index == -1`` (point 0) and ``index == 7`` (point 64) are the outer
        edges of the eight-byte range and are always boundaries, so a
        constant 1-bit true is returned for them.

        :param index: byte index, -1 through 7.
        :param mfactor: multiplication factor applied to the point index.
        :raises AssertionError: if ``index`` is outside -1..7 (when asserts
            are enabled).
        :raises KeyError: if the computed point is not in ``self``.
        """
        if index == -1 or index == 7:
            return C(True, 1)
        assert 0 <= index < 8
        return self[(index * 8 + 8) * mfactor]
111
112