Use len(sig) instead of sig.shape()[0]
[ieee754fpu.git] / src / ieee754 / part_mux / part_mux.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """
5 Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
6
7 dynamically-partitionable "Mux" class, directly equivalent
8 to nmigen Mux except SIMD-partitionable
9
10 See:
11
12 * http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/mux
13 * http://bugs.libre-riscv.org/show_bug.cgi?id=132
14 """
15
16 from nmigen import Signal, Module, Elaboratable, Mux
17 from ieee754.part_mul_add.partpoints import PartitionPoints
18 from ieee754.part_mul_add.partpoints import make_partition
19
modcount = 0 # global for now
def PMux(m, mask, sel, a, b):
    """Instantiate a PartitionedMux inside module *m* and return its output.

    * m: module that receives the new submodule and the combinatorial
      assignments wiring it up
    * mask: passed, together with the signal width, to make_partition
      to obtain the partition points
    * sel: selector, assigned to the mux's ``sel`` input
    * a, b: operands; each must expose its underlying signal as ``.sig``

    The submodule is registered under the name "pmux<N>", where N comes
    from the module-level counter ``modcount``, which is incremented on
    every call so that each name is distinct.
    """
    global modcount
    modcount += 1
    nbits = len(a.sig) # width is taken from operand a only
    pmux = PartitionedMux(nbits, make_partition(mask, nbits))
    # connect both operands and the selector in one go
    m.d.comb += [
        pmux.a.eq(a.sig),
        pmux.b.eq(b.sig),
        pmux.sel.eq(sel),
    ]
    setattr(m.submodules, "pmux%d" % modcount, pmux)
    return pmux.output
32
class PartitionedMux(Elaboratable):
    """PartitionedMux: Partitioned "Mux"

    takes a partition point set, subdivides a and b into blocks
    and "selects" them.  the assumption is that "sel" has had
    its LSB propagated up throughout the entire partition, and
    consequently the incoming selector (sel) can completely
    ignore what the *actual* partition bits are.

    Signals (all created in __init__, all reset_less):

    * a, b: the two ``width``-bit inputs
    * sel: one select bit per block, ``len(partition_points)+1`` bits.
      bit i is used as the Mux selector for block i, so a set bit
      takes the block from ``a`` and a clear bit takes it from ``b``
    * output: the ``width``-bit result
    """
    def __init__(self, width, partition_points):
        """Create the input, selector and output signals.

        * width: bit-width of a, b and output
        * partition_points: converted to a PartitionPoints instance

        Asserts that the partition points fit within ``width``
        (as reported by PartitionPoints.fits_in_width).
        """
        self.width = width
        self.partition_points = PartitionPoints(partition_points)
        # one more block than there are partition points
        self.mwidth = len(self.partition_points)+1
        self.a = Signal(width, reset_less=True)
        self.b = Signal(width, reset_less=True)
        self.sel = Signal(self.mwidth, reset_less=True)
        self.output = Signal(width, reset_less=True)
        assert self.partition_points.fits_in_width(width), \
                    "partition_points doesn't fit in width"

    def elaborate(self, platform):
        """Build one Mux per block, driving the matching output slice."""
        m = Module()
        comb = m.d.comb

        # loop across all partition ranges.
        # drop the selection directly into the output.
        # each partition point ends one block; the full width ends the last.
        # NOTE(review): relies on partition_points iterating in ascending
        # bit order - confirm against PartitionPoints / make_partition.
        ends = list(self.partition_points.keys()) + [self.width]
        stt = 0
        for i, end in enumerate(ends):
            comb += self.output[stt:end].eq(
                Mux(self.sel[i], self.a[stt:end], self.b[stt:end]))
            stt = end # for next time round loop

        return m

    def ports(self):
        """Return the list of signals forming this module's interface."""
        # a and b are plain Signals here (see __init__): they have no
        # ".sig" attribute, so they are returned directly.
        return [self.a, self.b, self.sel, self.output]
71