1 # SPDX-License-Identifier: LGPL-3-or-later
2 # See Notices.txt for copyright information
4 from nmigen
.hdl
.ast
import Array
, Cat
, Repl
, Signal
5 from nmigen
.hdl
.dsl
import Module
6 from nmigen
.hdl
.ir
import Elaboratable
class BitwiseMux(Elaboratable):
    """Bit-vector multiplexer that selects each bit independently.

    Unlike an integer ``Mux``, which picks one whole operand, every output
    bit is chosen on its own: ``output[i]`` follows ``t[i]`` when
    ``sel[i]`` is 1 and ``f[i]`` when ``sel[i]`` is 0.

    :param width: bit width shared by ``sel``, ``t``, ``f`` and ``output``.
    """

    def __init__(self, width):
        # NOTE(review): each Signal is bound straight to its attribute (no
        # tuple unpacking or helper); nMigen is believed to derive unnamed
        # Signals' names from the assignment target -- confirm before
        # restructuring these lines.
        self.sel = Signal(width)
        self.t = Signal(width)
        self.f = Signal(width)
        self.output = Signal(width)

    def elaborate(self, platform):
        """Build the combinational per-bit select; ``platform`` is unused."""
        m = Module()
        # Bits of f survive where sel is 0, bits of t where sel is 1.
        from_f = ~self.sel & self.f
        from_t = self.sel & self.t
        m.d.comb += self.output.eq(from_f | from_t)
        return m
class BitwiseLut(Elaboratable):
    """Bitwise lookup table implemented with a single indexed ``Array``.

    Each bit position ``i`` of the ``width``-bit vectors is looked up
    independently: bit ``j`` of the address is ``inputs[j][i]`` (so
    ``inputs[0]`` supplies the least-significant address bit), and the
    addressed bit of ``lut`` drives ``output[i]``.

    :param input_count: number of input vectors, i.e. LUT address bits;
        ``lut`` is ``2 ** input_count`` bits wide.
    :param width: bit width of every input vector and of ``output``.
    """

    def __init__(self, input_count, width):
        assert isinstance(input_count, int)
        assert isinstance(width, int)
        self.input_count = input_count
        self.width = width
        self.inputs = tuple(
            Signal(width, name=f"input{idx}") for idx in range(input_count)
        )
        self.output = Signal(width)
        self.lut = Signal(2 ** input_count)
        # One LUT address signal per bit position of the vectors.
        self._lut_indexes = [
            Signal(input_count, name=f"lut_index_{bit}")
            for bit in range(width)
        ]

    def elaborate(self, platform):
        """Route every bit position through the shared LUT.

        ``platform`` is unused.
        """
        m = Module()
        # Array of the individual LUT bits so that a Signal can index it.
        table = Array(self.lut[pos] for pos in range(len(self.lut)))
        for bit, address in enumerate(self._lut_indexes):
            # Gather this bit position from every input into the address,
            # inputs[sel] providing address bit sel.
            m.d.comb += [
                address[sel].eq(vec[bit])
                for sel, vec in enumerate(self.inputs)
            ]
            m.d.comb += self.output[bit].eq(table[address])
        return m
class TreeBitwiseLut(Elaboratable):
    """Tree-based version of ``BitwiseLut``.

    Exposes the same ``inputs``, ``output`` and ``lut`` attributes, but
    resolves the lookup with a binary tree of ``BitwiseMux`` instances
    instead of an indexed ``Array``.

    Tree nodes live in ``self._mux_inputs``, keyed by a tuple of booleans
    describing the path from the root: element ``k`` is the value assumed
    for ``inputs[k]``. The root key ``()`` drives ``output``. A node at
    depth ``d`` is a mux whose select is ``inputs[d]`` and whose ``f``/``t``
    operands are its ``False``/``True`` children. Leaves (paths of length
    ``input_count``) are driven by the ``lut`` bit whose address has bit
    ``k`` set exactly when path element ``k`` is true.

    :param input_count: number of input vectors (tree depth / LUT address
        bits); ``lut`` is ``2 ** input_count`` bits wide.
    :param width: bit width of every input vector and of ``output``.
    """

    def __init__(self, input_count, width):
        assert isinstance(input_count, int)
        assert isinstance(width, int)
        self.input_count = input_count
        self.width = width
        self.inputs = tuple(
            Signal(width, name=f"input{idx}") for idx in range(input_count)
        )
        self.output = Signal(width)
        self.lut = Signal(2 ** input_count)
        # path tuple -> Signal carrying that tree node's value
        self._mux_inputs = {}
        self._build_mux_inputs()

    def _make_key_str(self, *sel_values):
        """Render a tree path as a string for Signal/submodule names.

        Undecided inputs show as ``x`` and the highest-numbered input comes
        first, e.g. with ``input_count == 3`` the path ``(True, False)``
        renders as ``'0bx01'``.
        """
        chars = ['x'] * self.input_count
        for pos, bit in enumerate(sel_values):
            chars[pos] = '1' if bit else '0'
        chars.reverse()
        return '0b' + ''.join(chars)

    def _build_mux_inputs(self, *sel_values):
        """Create node Signals for ``sel_values`` and its whole subtree.

        Nodes are registered in preorder, the ``False`` subtree before the
        ``True`` subtree.
        """
        key_str = self._make_key_str(*sel_values)
        self._mux_inputs[sel_values] = Signal(
            self.width, name=f"mux_input_{key_str}")
        if len(sel_values) >= self.input_count:
            return  # leaf: every input is fixed along this path
        for branch in (False, True):
            self._build_mux_inputs(*sel_values, branch)

    def elaborate(self, platform):
        """Instantiate the mux tree and drive its leaves from ``lut``.

        ``platform`` is unused.
        """
        m = Module()
        # The root (no inputs decided yet) is the final result.
        m.d.comb += self.output.eq(self._mux_inputs[()])
        for path, node in self._mux_inputs.items():
            depth = len(path)
            if depth >= self.input_count:
                # Leaf: the path fixes every input, so it addresses exactly
                # one LUT bit, replicated across all bit positions.
                address = sum(1 << k for k, taken in enumerate(path) if taken)
                m.d.comb += node.eq(Repl(self.lut[address], self.width))
                continue
            mux = BitwiseMux(self.width)
            setattr(m.submodules, f"mux_{self._make_key_str(*path)}", mux)
            m.d.comb += [
                mux.f.eq(self._mux_inputs[(*path, False)]),
                mux.t.eq(self._mux_inputs[(*path, True)]),
                mux.sel.eq(self.inputs[depth]),
                node.eq(mux.output),
            ]
        return m