deal with zero-width ShiftMask
[ieee754fpu.git] / src / ieee754 / part_shift / part_shift_dynamic.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """
5 Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
6 Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
7
8 dynamically partitionable shifter. Unlike part_shift_scalar, both
9 operands can be partitioned
10
11 See:
12
13 * http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/shift/
14 * http://bugs.libre-riscv.org/show_bug.cgi?id=173
15 """
16 from nmigen import Signal, Module, Elaboratable, Cat, Mux, C
17 from ieee754.part_mul_add.partpoints import PartitionPoints
18 import math
19
class ShifterMask(Elaboratable):
    """Generate the AND-mask applied to one partition's shift amount.

    The low ``min_bits`` bits of ``mask`` are always set.  Above them, bit
    ``j`` is set only while ``gates[0]`` .. ``gates[j]`` are all clear, so
    the mask widens for as long as no partition gate is raised.  The result
    is finally limited to ``max_bits`` bits.

    Parameters:
        pwid:     number of gate bits (may be zero)
        bwid:     width of the ``mask`` output signal
        max_bits: upper limit on the number of mask bits that may be set
        min_bits: number of low mask bits that are unconditionally set
    """

    def __init__(self, pwid, bwid, max_bits, min_bits):
        self.max_bits = max_bits
        self.min_bits = min_bits
        self.pwid = pwid
        self.mask = Signal(bwid, reset_less=True)    # output
        self.gates = Signal(pwid, reset_less=True)   # input

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        # bits that are set regardless of the gates
        low_ones = (1 << self.min_bits) - 1

        # zero-width: there are no gates to examine, so the mask is just
        # the unconditional low bits.  The assignment must be added to the
        # comb domain, otherwise the mask output is never driven.
        if self.pwid == 0:
            comb += self.mask.eq(low_ones)
            return m

        bits = Signal(self.pwid, reset_less=True)
        # XXX ARGH, really annoying: simulation bug, can't use Cat(*bl),
        # so each bit of the chain is assigned individually.
        for j in range(self.pwid):
            if j == 0:
                comb += bits[j].eq(~self.gates[j])
            else:
                # stays set only if every lower gate was clear as well
                comb += bits[j].eq((~self.gates[j]) & bits[j-1])

        comb += self.mask.eq(Cat(low_ones, bits)
                             & ((1 << self.max_bits) - 1))

        return m
51
52
class PartialResult(Elaboratable):
    """Compute one partition's partial shift result.

    Selects the shift amount (``masked`` when ``gate`` is set, otherwise
    the incoming ``element``), exposes that selection on ``elmux``, clamps
    it to ``reswid`` and outputs ``a_interval`` shifted left by the clamped
    amount on ``partial``.

    Parameters:
        pwid:   stored on the instance; only used in the debug trace here
        bwid:   width of the ``element``, ``elmux``, ``a_interval`` and
                ``masked`` signals
        reswid: width of the ``partial`` output, and the largest shift
                amount that is ever applied
    """

    def __init__(self, pwid, bwid, reswid):
        self.pwid = pwid
        self.bwid = bwid
        self.reswid = reswid
        self.element = Signal(bwid, reset_less=True)     # input
        self.elmux = Signal(bwid, reset_less=True)       # output
        self.a_interval = Signal(bwid, reset_less=True)  # input
        self.masked = Signal(bwid, reset_less=True)      # input
        self.gate = Signal(reset_less=True)              # input
        self.partial = Signal(reswid, reset_less=True)   # output

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        shiftbits = math.ceil(math.log2(self.reswid+1))+1  # hmmm...
        # elaboration-time debug trace
        print ("partial", self.reswid, self.pwid, shiftbits)

        # choose which shift amount applies to this partition
        comb += self.elmux.eq(Mux(self.gate, self.masked, self.element))
        element = self.elmux

        # The shift amount is clamped because there's no point trying to
        # shift data by 64 bits if the result width is only 8.
        shifter = Signal(shiftbits, reset_less=True)
        # Let the constant take its natural width: forcing it to the
        # element's width would truncate reswid whenever reswid does not
        # fit in bwid bits, and the comparison would clamp too early.
        maxval = C(self.reswid)
        with m.If(element > maxval):
            comb += shifter.eq(maxval)
        with m.Else():
            comb += shifter.eq(element)
        comb += self.partial.eq(self.a_interval << shifter)

        return m
94
95
class PartitionedDynamicShift(Elaboratable):
    """Dynamically partitionable left-shifter: ``a`` shifted by ``b``.

    Both operands are split at the given partition points.  For every
    partition a masked shift amount and a partial result are produced;
    the partial results are then combined, under control of the partition
    gates, into ``output``.

    Parameters:
        width:            width of ``a``, ``b`` and ``output``
        partition_points: passed to ``PartitionPoints``; its keys are used
                          as the partition boundaries
    """

    def __init__(self, width, partition_points):
        self.width = width
        self.partition_points = PartitionPoints(partition_points)

        self.a = Signal(width, reset_less=True)
        self.b = Signal(width, reset_less=True)
        self.output = Signal(width, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        width = self.width
        pwid = self.partition_points.get_max_partition_count(width)-1
        gates = Signal(pwid, reset_less=True)
        comb += gates.eq(self.partition_points.as_sig())

        keys = list(self.partition_points.keys()) + [self.width]

        # break out both the input and output into partition-stratified blocks
        a_intervals = []
        b_intervals = []
        intervals = []
        start = 0
        for end in keys:
            a_intervals.append(self.a[start:end])
            b_intervals.append(self.b[start:end])
            intervals.append([start, end])
            start = end

        min_bits = math.ceil(math.log2(intervals[0][1] - intervals[0][0]))

        # shifts are normally done as (e.g. for 32 bit) result = a & (b&0b11111)
        # truncating the b input.  however here of course the size of the
        # partition varies dynamically.
        shifter_masks = []
        for i, b_interval in enumerate(b_intervals):
            max_bits = math.ceil(math.log2(width-intervals[i][0]))
            sm = ShifterMask(pwid-i, len(b_interval), max_bits, min_bits)
            setattr(m.submodules, "sm%d" % i, sm)
            comb += sm.gates.eq(gates[i:pwid])
            shifter_masks.append(sm.mask)

        # elaboration-time debug trace
        print(shifter_masks)

        # Instead of generating the matrix described in the wiki, I
        # instead calculate the shift amounts for each partition, then
        # calculate the partial results of each partition << shift
        # amount. On the wiki, the following table is given for output #3:
        # p2p1p0 | o3
        # 0 0 0  | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b0[7:0]
        # 0 0 1  | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b1[7:0]
        # 0 1 0  | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b2[7:0]
        # 0 1 1  | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b2[7:0]
        # 1 0 0  | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b3[7:0]
        # 1 0 1  | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b3[7:0]
        # 1 1 0  | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b3[7:0]
        # 1 1 1  | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b3[7:0]

        # Each output for o3 is given by a3bx and the partial results
        # for o2 (namely, a2bx, a1bx, and a0b0). If I calculate the
        # partial results [a0b0, a1bx, a2bx, a3bx], I can use just
        # those partial results to calculate a0, a1, a2, and a3
        element = Signal(b_intervals[0].shape(), reset_less=True)
        comb += element.eq(b_intervals[0] & shifter_masks[0])
        partial = Signal(width, name="partial0", reset_less=True)
        comb += partial.eq(a_intervals[0] << element)
        partial_results = [partial]
        for i in range(1, len(keys)):
            reswid = width - intervals[i][0]
            shiftbits = math.ceil(math.log2(reswid+1))+1  # hmmm...
            # elaboration-time debug trace
            print ("partial", reswid, width, intervals[i], shiftbits)
            pr = PartialResult(pwid, len(b_intervals[i]), reswid)
            setattr(m.submodules, "pr%d" % i, pr)
            comb += pr.masked.eq(b_intervals[i] & shifter_masks[i])
            comb += pr.gate.eq(gates[i-1])
            comb += pr.element.eq(element)
            comb += pr.a_interval.eq(a_intervals[i])
            partial_results.append(pr.partial)
            # the shift amount selected here feeds the next partition
            element = pr.elmux

        out = []

        # This calculates the outputs o0-o3 from the partial results
        # table above.  Note: only relevant bits of the partial result equal
        # to the width of the output column are accumulated in a Mux-cascade.
        # NOTE(review): every slice below uses the bounds of the *first*
        # interval, which presumably assumes all partitions have the same
        # width -- confirm before using unequal partition points.
        s, e = intervals[0]
        result = partial_results[0]
        out.append(result[s:e])
        for i in range(1, len(keys)):
            start, end = (intervals[i][0], width)
            # carry the overflow of the previous result into this
            # partition unless the gate between the two is set
            sel = Mux(gates[i-1], 0, result[intervals[0][1]:][:end-start])
            # elaboration-time debug trace
            print("select: [%d:%d]" % (start, end))
            res = Signal(end-start+1, name="res%d" % i, reset_less=True)
            comb += res.eq(partial_results[i] | sel)
            result = res
            out.append(res[s:e])

        comb += self.output.eq(Cat(*out))

        return m
209