# SPDX-License-Identifier: LGPL-2.1-or-later
# See Notices.txt for copyright information
"""Integer Multiplication."""
from nmigen import Signal, Value, Cat, C
def make_partition(mask, width):
    """Create partition points from a mask and a bit-width.

    The mask is assumed to indicate the break-points at regular intervals
    across ``width``, and the last bit (MSB) of the mask is therefore
    *ignored*.

    mask len = 4, width == 16 will return:
        {4: mask[0], 8: mask[1], 12: mask[2]}
    mask len = 8, width == 64 will return:
        {8: mask[0], 16: mask[1], 24: mask[2], .... 56: mask[6]}

    :param mask: a sized, indexable sequence of partition-enable values.
    :param width: the total bit-width being partitioned.
    :returns: a dict mapping each break-point bit position to the mask
        entry that enables it.  The dict is empty when the mask is empty
        or has more entries than ``width`` has bits.
    """
    mlen = len(mask)
    # NOTE(review): the stride is derived from the two documented examples
    # (16 // 4 == 4 and 64 // 8 == 8) -- confirm against the callers.
    jumpsize = width // mlen if mlen else 0
    if jumpsize <= 0:
        # no room for even a one-bit partition: nothing to split
        return {}
    # range(mlen - 1): the -1 is what ignores the last (MSB) mask bit
    return {jumpsize * (midx + 1): mask[midx] for midx in range(mlen - 1)}
def make_partition2(mask, width):
    """Create partition points from a mask and a bit-width.

    Here the mask represents the actual partition points, and therefore
    must have ONE LESS entry than the number of required partitions.

    mask len = 3, width == 16 will return:
        {4: mask[0], 8: mask[1], 12: mask[2]}
    mask len = 7, width == 64 will return:
        {8: mask[0], 16: mask[1], 24: mask[2], .... 56: mask[6]}

    :param mask: a sized, indexable sequence of partition-enable values,
        or a dict whose values (in iteration order) are used as that
        sequence.
    :param width: the total bit-width being partitioned.
    :returns: a dict mapping each break-point bit position to the mask
        entry that enables it.
    :raises AssertionError: if ``width // (len(mask) + 1)`` is not positive.
    """
    if isinstance(mask, dict):  # convert dict/partpoints to sequential list
        mask = list(mask.values())
    npoints = len(mask)
    mlen = npoints + 1  # ONE MORE partitions than break-points
    jumpsize = width // mlen  # amount to jump by (size of each partition)
    # NOTE(review): the print calls below are debug tracing on stdout; they
    # are kept as-is so observable behaviour is unchanged.
    print ("make_partition2", width, mask, len(mask), mlen, jumpsize)
    assert jumpsize > 0, "incorrect width // mlen (%d // %d)" % (width, mlen)
    ppoints = {}
    ppos = jumpsize
    midx = 0
    # Bounded by the number of break-points (not the number of partitions):
    # when width is not a multiple of mlen, position jumpsize * mlen is
    # still below width and would otherwise index one past the end of mask.
    while ppos < width and midx < npoints:
        print (" make_partition2", ppos, width, midx, mlen)
        ppoints[ppos] = mask[midx]
        ppos += jumpsize
        midx += 1
    print (" make_partition2", mask, width, ppoints)
    return ppoints
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at where an ALU is partitioned along with ``Value``s that
    specify if the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:

    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:

    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values
            mapping, or ``None`` for an empty set of points.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is None:
            return
        for point, enabled in partition_points.items():
            if not isinstance(point, int):
                raise TypeError("point must be a non-negative integer")
            if point < 0:
                raise ValueError("point must be a non-negative integer")
            # every stored value is normalised through Value.cast
            self[point] = Value.cast(enabled)

    def like(self, name=None, src_loc_at=0, mul=1):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.
        :param src_loc_at: forwarded (plus one) to ``Signal`` when a
            default base name has to be derived.
        :param mul: a multiplication factor on the indices
        :returns: a new ``PartitionPoints`` whose keys are the points of
            ``self`` scaled by ``mul`` and whose values are fresh
            ``Signal``s named ``{name}_{point}`` with the shape of the
            corresponding value in ``self``.
        """
        if name is None:
            # a throw-away Signal is created only to get the variable name
            name = Signal(src_loc_at=1+src_loc_at).name
        retval = PartitionPoints()
        for point, enabled in self.items():
            point *= mul
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    # NOTE(review): the names ``eq`` and ``as_sig`` below are inferred from
    # their docstrings and bodies -- confirm against the callers.
    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``.

        This is a generator: the key check and the assignments are only
        performed once the result is iterated.

        :param rhs: a mapping with exactly the same points as ``self``.
        :raises ValueError: if the two sets of points differ.
        """
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width, mul=1):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        :param mul: a "multiplier" which in-place expands the partition points
                    typically set to "2" when used for multipliers
        :returns: a ``Cat`` of ``width`` one-bit entries.
        """
        bits = []
        for i in range(width):
            # integer divmod instead of float division: bit i corresponds to
            # partition point i // mul only when i is an exact multiple of mul
            point, remainder = divmod(i, mul)
            if remainder == 0 and point in self:
                bits.append(~self[point])
            else:
                bits.append(C(True, 1))
        return Cat(*bits)

    def as_sig(self):
        """Create a straight concatenation of `self` signals
        """
        return Cat(self.values())

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.

        :param width: the bit width being partitioned; points at or above
            it do not split anything and are not counted.
        """
        # NOTE(review): the counting rule (one partition, plus one per point
        # below ``width``) is reconstructed from the docstring -- confirm.
        return 1 + sum(1 for point in self.keys() if point < width)

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        return all(point < width for point in self.keys())

    def part_byte(self, index, mfactor=1):  # mfactor used for "expanding"
        """Return the partition-point value at the top of byte ``index``.

        :param index: byte index; ``-1`` and ``7`` yield a constant instead
            of a lookup, any other value must satisfy ``0 <= index < 8``.
        :param mfactor: multiplier applied to the looked-up bit position.
        :raises KeyError: if no partition point exists at
            ``(index * 8 + 8) * mfactor``.
        """
        if index == -1 or index == 7:
            # NOTE(review): presumably the outer edges of the word always
            # count as enabled partition boundaries -- confirm the constant.
            return C(True, 1)
        assert index >= 0 and index < 8
        return self[(index * 8 + 8)*mfactor]