# SPDX-License-Identifier: LGPL-2.1-or-later
# See Notices.txt for copyright information
"""
Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>

Dynamically partitionable shifter. Unlike part_shift_scalar, both
operands can be partitioned.

See:

* http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/shift/
* http://bugs.libre-riscv.org/show_bug.cgi?id=173
"""
import math

from nmigen import Signal, Module, Elaboratable, Cat, Mux, C

from ieee754.part_mul_add.partpoints import PartitionPoints
from ieee754.part_shift.bitrev import GatedBitReverse


class ShifterMask(Elaboratable):
    """Build the dynamic truncation mask for one partition's shift amount.

    The low ``min_bits`` bits of ``mask`` are always set.  Above those,
    mask bit ``min_bits + j`` is set only while ``gates[0]`` .. ``gates[j]``
    are all clear (a clear gate merges the next partition into this lane),
    so a wider lane is allowed a proportionally wider shift amount.  The
    result is finally clipped to its low ``max_bits`` bits.

    :param pwid: number of gate bits feeding this mask (may be 0)
    :param bwid: width of the shift-amount field the mask is applied to
    :param max_bits: maximum number of mask bits that may be set
    :param min_bits: number of mask bits that are always set

    Attributes:
        mask: output, ``bwid`` bits
        gates: input, ``pwid`` bits of partition gates
    """

    def __init__(self, pwid, bwid, max_bits, min_bits):
        self.max_bits = max_bits
        self.min_bits = min_bits
        self.pwid = pwid

        self.mask = Signal(bwid, reset_less=True)
        self.gates = Signal(pwid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        minm = (1 << self.min_bits) - 1
        maxm = (1 << self.max_bits) - 1

        # zero-width mustn't try to do anything: only the minimum mask
        if self.pwid == 0:
            comb += self.mask.eq(minm)
            return m

        # bit-cascade: bl[j] is set only while gates[0..j] are all clear
        bits = Signal(self.pwid, reset_less=True)
        bl = []
        for j in range(self.pwid):
            bit = Signal(name="bit%d" % j, reset_less=True)
            if j == 0:
                comb += bit.eq(~self.gates[j])
            else:
                comb += bit.eq(~self.gates[j] & bl[j-1])
            bl.append(bit)

        # XXX ARGH, really annoying: simulation bug, can't use Cat(*bl).
        for j, bit in enumerate(bl):
            comb += bits[j].eq(bit)

        # give minm an explicit width so Cat() places the cascade bits
        # exactly at bit min_bits (a bare int would get an inferred width)
        comb += self.mask.eq(Cat(C(minm, self.min_bits), bits) &
                             C(maxm, len(self.mask)))

        return m


class PartialResult(Elaboratable):
    """Compute one partition's contribution to the shifted output.

    ``a_interval`` (one partition of ``a``) is shifted left by ``b``.  The
    shift amount is first clamped to ``reswid``, because shifting further
    than that cannot leave any bits inside the ``reswid``-bit result.

    :param pwid: number of partition gates (stored only; not used by the logic)
    :param bwid: width of the ``a_interval`` and ``b`` inputs
    :param reswid: width of the ``partial`` output

    Attributes:
        b: input, shift amount for this partition
        a_interval: input, the partition of ``a`` to shift
        gate: input, partition gate below this partition (not used to
            form ``partial``; see NOTE in ``elaborate``)
        partial: output, ``a_interval << min(b, reswid)``
    """

    def __init__(self, pwid, bwid, reswid):
        self.pwid = pwid
        self.bwid = bwid
        self.reswid = reswid

        self.b = Signal(bwid, reset_less=True)
        self.a_interval = Signal(bwid, reset_less=True)
        self.gate = Signal(reset_less=True)
        self.partial = Signal(reswid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        # wide enough to hold any clamped shift amount 0..reswid
        shiftbits = math.ceil(math.log2(self.reswid+1))+1  # hmmm...

        # The per-partition shift amount is already chosen by a mux chain
        # over the partition gates before it arrives on ``b``.
        # NOTE(review): ``gate`` is deliberately not applied here; forcing
        # the amount to 0 for a merged partition would leave that
        # partition's bits unshifted.  Confirm against upstream.
        element = self.b

        # The shift amount is truncated because there's no point trying
        # to shift data by 64 bits if the result width is only 8.  Compare
        # against a plain int: C(reswid, len(element)) would silently
        # truncate the limit whenever reswid >= 2**len(element).
        shifter = Signal(shiftbits, reset_less=True)
        with m.If(element > self.reswid):
            comb += shifter.eq(self.reswid)
        with m.Else():
            comb += shifter.eq(element)
        comb += self.partial.eq(self.a_interval << shifter)

        return m


class PartitionedDynamicShift(Elaboratable):
    """Dynamically-partitioned shifter where ``a`` and ``b`` are both split
    at the same (runtime-selectable) partition points.

    Each lane of ``a`` is shifted by the amount held in the lowest
    partition of the matching lane of ``b``, with that amount masked to a
    width that grows with the lane size.  Left shifts are computed
    directly; right shifts reuse the same logic by bit-reversing ``a``,
    the partition gates and the output, and reversing the order of the
    per-partition shift amounts.

    :param width: width of ``a``, ``b`` and ``output``
    :param partition_points: converted with ``PartitionPoints(...)``

    Attributes:
        a: input data to be shifted
        b: input shift amounts (one per lane)
        shift_right: input, 1 selects a right shift, 0 a left shift
        output: shifted result
    """

    def __init__(self, width, partition_points):
        self.width = width
        self.partition_points = PartitionPoints(partition_points)

        self.a = Signal(width, reset_less=True)
        self.b = Signal(width, reset_less=True)
        self.shift_right = Signal(reset_less=True)
        self.output = Signal(width, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        width = self.width

        pwid = self.partition_points.get_max_partition_count(width)-1
        gates = Signal(pwid, reset_less=True)
        comb += gates.eq(self.partition_points.as_sig())

        # interval boundaries: every partition point plus the top end
        keys = list(self.partition_points.keys()) + [width]

        # create gated-reversed versions of a, the gates and the output
        # (b itself is not reversed).  left-shift is non-reversed,
        # right-shift is reversed
        m.submodules.a_br = a_br = GatedBitReverse(len(self.a))
        comb += a_br.data.eq(self.a)
        comb += a_br.reverse_en.eq(self.shift_right)

        m.submodules.out_br = out_br = GatedBitReverse(len(self.output))
        comb += out_br.reverse_en.eq(self.shift_right)
        comb += self.output.eq(out_br.output)

        m.submodules.gate_br = gate_br = GatedBitReverse(pwid)
        comb += gate_br.data.eq(gates)
        comb += gate_br.reverse_en.eq(self.shift_right)

        # break out both the input and output into partition-stratified blocks
        a_intervals = []
        b_intervals = []
        intervals = []
        start = 0
        for end in keys:
            a_intervals.append(a_br.output[start:end])
            b_intervals.append(self.b[start:end])
            intervals.append([start, end])
            start = end

        min_bits = math.ceil(math.log2(intervals[0][1] - intervals[0][0]))

        # shifts are normally done as (e.g. for 32 bit) result = a <<
        # (b&0b11111), truncating the b input. however here of course
        # the size of the partition varies dynamically.
        shifter_masks = []
        for i, b_interval in enumerate(b_intervals):
            bwid = len(b_interval)
            bitwid = pwid - i  # number of gates at or above this interval
            if bitwid == 0:
                # topmost interval: no gates above it, minimum mask only
                shifter_masks.append(C((1 << min_bits)-1, bwid))
                continue
            max_bits = math.ceil(math.log2(width - intervals[i][0]))
            sm = ShifterMask(bitwid, bwid, max_bits, min_bits)
            setattr(m.submodules, "sm%d" % i, sm)
            comb += sm.gates.eq(gates[i:pwid])
            shifter_masks.append(sm.mask)

        # Instead of generating the matrix described in the wiki, I
        # instead calculate the shift amounts for each partition, then
        # calculate the partial results of each partition << shift
        # amount. On the wiki, the following table is given for output #3:
        # p2 p1 p0 | o3
        # 0 0 0 | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b0[7:0]
        # 0 0 1 | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b1[7:0]
        # 0 1 0 | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b2[7:0]
        # 0 1 1 | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b2[7:0]
        # 1 0 0 | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b3[7:0]
        # 1 0 1 | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b3[7:0]
        # 1 1 0 | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b3[7:0]
        # 1 1 1 | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b3[7:0]
        #
        # Each output for o3 is given by a3bx and the partial results
        # for o2 (namely, a2bx, a1bx, and a0b0). If I calculate the
        # partial results [a0b0, a1bx, a2bx, a3bx], I can use just
        # those partial results to calculate a0, a1, a2, and a3
        masked_b = []
        for i, (b_interval, shifter_mask) in enumerate(
                zip(b_intervals, shifter_masks)):
            masked = Signal(len(b_interval), name="masked%d" % i,
                            reset_less=True)
            comb += masked.eq(b_interval & shifter_mask)
            masked_b.append(masked)

        # mux chain: a partition whose lower gate is set starts a new lane
        # and uses its own masked b; otherwise it inherits the amount of
        # the partition below it (i.e. the lowest partition of its lane)
        b_shl_amount = []
        element = Signal(len(b_intervals[0]), reset_less=True)
        comb += element.eq(masked_b[0])
        b_shl_amount.append(element)
        for i in range(1, len(keys)):
            elem_sig = Signal(len(b_intervals[i]), reset_less=True)
            comb += elem_sig.eq(Mux(gates[i-1], masked_b[i], element))
            element = elem_sig
            b_shl_amount.append(elem_sig)

        # because the right-shift input is reversed, we have to also
        # reverse the *order* of the shift amounts (not the bits *in* the
        # shift amounts)
        b_shr_amount = list(reversed(b_shl_amount))

        # select shift-amount (b) for partition based on op being left or right
        shift_amounts = []
        for i, (shl, shr) in enumerate(zip(b_shl_amount, b_shr_amount)):
            shift_amount = Signal(len(masked_b[i]),
                                  name="shift_amount%d" % i,
                                  reset_less=True)
            comb += shift_amount.eq(Mux(self.shift_right, shr, shl))
            shift_amounts.append(shift_amount)

        # now calculate partial results

        # first item (simple)
        partial_results = []
        partial = Signal(width, name="partial0", reset_less=True)
        comb += partial.eq(a_intervals[0] << shift_amounts[0])
        partial_results.append(partial)

        # remaining items: each result spans from its interval to the top
        for i in range(1, len(keys)):
            reswid = width - intervals[i][0]
            pr = PartialResult(pwid, len(b_intervals[i]), reswid)
            setattr(m.submodules, "pr%d" % i, pr)
            comb += pr.gate.eq(gate_br.output[i-1])
            comb += pr.b.eq(shift_amounts[i])
            comb += pr.a_interval.eq(a_intervals[i])
            partial_results.append(pr.partial)

        # This calculates the outputs o0-o3 from the partial results
        # table above. Note: only relevant bits of the partial result equal
        # to the width of the output column are accumulated in a Mux-cascade.
        # Both the carry slice and the output slice use interval 0's width,
        # so partitions are assumed to be of equal width.
        s, e = intervals[0]
        result = partial_results[0]
        out = [result[s:e]]
        for i in range(1, len(keys)):
            start = intervals[i][0]
            # gate set: this partition starts a new lane, nothing carries
            # in.  gate clear: OR in the running result's bits that spilled
            # above the previous partition.
            sel = Mux(gate_br.output[i-1], 0,
                      result[e:][:width-start])
            res = Signal(width-start+1, name="res%d" % i, reset_less=True)
            comb += res.eq(partial_results[i] | sel)
            result = res
            out.append(res[s:e])

        comb += out_br.data.eq(Cat(*out))

        return m