a47fefdc2eaac394c87b86274467c0f5a556ae54
[soc.git] / src / soc / fu / div / fsm.py
1 import enum
2 from nmigen import Elaboratable, Module, Signal, Shape, unsigned, Cat, Mux
3 from soc.fu.div.pipe_data import CoreInputData, CoreOutputData, DivPipeSpec
4 from nmutil.iocontrol import PrevControl, NextControl
5 from nmutil.singlepipe import ControlBase
6 from ieee754.div_rem_sqrt_rsqrt.core import DivPipeCoreOperation
7
8
class FSMDivCoreConfig:
    """Fixed configuration values for the FSM divide core.

    Exposes ``n_stages``, ``bit_width`` and ``fract_width`` as plain class
    attributes; nothing in this class computes or validates them.
    """

    # the FSM core is counted as a single stage
    n_stages = 1
    # integer width and fractional width, both fixed at 64 bits
    bit_width = 64
    fract_width = 64
13
14
class FSMDivCoreInputData:
    """Operand bundle for the FSM divide core.

    Holds a 128-bit ``dividend``, a 64-bit ``divisor_radicand`` and the
    requested ``operation``.  ``core_config`` is stored as-is; the signal
    widths here are fixed and do not read from it.
    """

    def __init__(self, core_config, reset_less=True):
        self.core_config = core_config
        # attribute assignment form is kept so nmigen's tracer derives the
        # same signal names as before
        self.dividend = Signal(128, reset_less=reset_less)
        self.divisor_radicand = Signal(64, reset_less=reset_less)
        self.operation = DivPipeCoreOperation.create_signal(
            reset_less=reset_less)

    def __iter__(self):
        """ Get member signals. """
        yield from (self.dividend, self.divisor_radicand, self.operation)

    def eq(self, rhs):
        """ Assign member signals. """
        # same member order as __iter__
        member_names = ("dividend", "divisor_radicand", "operation")
        return [getattr(self, member).eq(getattr(rhs, member))
                for member in member_names]
35
36
class FSMDivCoreOutputData:
    """Result bundle of the FSM divide core.

    ``quotient_root`` is 64 bits wide and ``remainder`` is ``3 * 64`` bits
    wide.  ``core_config`` is stored as-is; the widths do not read from it.
    """

    def __init__(self, core_config, reset_less=True):
        self.core_config = core_config
        self.quotient_root = Signal(64, reset_less=reset_less)
        self.remainder = Signal(3 * 64, reset_less=reset_less)

    def __iter__(self):
        """ Get member signals. """
        for member in (self.quotient_root, self.remainder):
            yield member

    def eq(self, rhs):
        """ Assign member signals. """
        quotient_assign = self.quotient_root.eq(rhs.quotient_root)
        remainder_assign = self.remainder.eq(rhs.remainder)
        return [quotient_assign, remainder_assign]
53
54
class FSMDivCorePrevControl(PrevControl):
    """``PrevControl`` whose ``data_i`` payload is a ``CoreInputData``.

    :param pspec: pipeline spec forwarded to ``CoreInputData``; also kept
        on the instance as ``self.pspec``.
    """

    def __init__(self, pspec):
        super().__init__()
        self.pspec = pspec
        # replace the default input payload with the divider's input type
        self.data_i = CoreInputData(pspec)
60
61
class FSMDivCoreNextControl(NextControl):
    """``NextControl`` whose ``data_o`` payload is a ``CoreOutputData``.

    :param pspec: pipeline spec forwarded to ``CoreOutputData``; also kept
        on the instance as ``self.pspec``.
    """

    def __init__(self, pspec):
        super().__init__()
        self.pspec = pspec
        # replace the default output payload with the divider's output type
        self.data_o = CoreOutputData(pspec)
67
68
class DivStateNext(Elaboratable):
    """Combinational logic for one step of restoring binary division.

    ``i`` is the current ``DivState`` and ``divisor`` the divisor.  ``o`` is
    ``i`` advanced by one quotient bit, or ``i`` passed through unchanged
    once ``i.done``.

    Each step subtracts ``divisor << (quotient_width - 1)`` from the
    ``2 * quotient_width``-bit ``dividend_quotient``.  If the MSB of that
    difference is clear, the new quotient bit is 1 and the difference is
    kept; otherwise the bit is 0 and the old value is kept.  The kept value
    is then shifted left by one with the quotient bit inserted at bit 0.

    NOTE(review): treating the difference's MSB as a sign bit only works if
    the upper half of the initial dividend is less than the divisor; this
    appears to be assumed of the caller -- TODO confirm where it is enforced.
    """

    def __init__(self, quotient_width):
        self.quotient_width = quotient_width
        self.i = DivState(quotient_width=quotient_width, name="i")
        self.divisor = Signal(quotient_width)
        self.o = DivState(quotient_width=quotient_width, name="o")

    def elaborate(self, platform):
        m = Module()
        width = self.quotient_width
        state_in, state_out = self.i, self.o
        state_width = state_in.quotient_width * 2

        # trial subtraction of the aligned divisor; explicit names keep the
        # same signal names the tracer previously derived from locals
        diff = Signal(state_width, name="difference")
        aligned_divisor = self.divisor << (width - 1)
        m.d.comb += diff.eq(state_in.dividend_quotient - aligned_divisor)

        # a clear MSB means the trial subtraction did not go negative
        q_bit = Signal(name="next_quotient_bit")
        m.d.comb += q_bit.eq(~diff[width * 2 - 1])

        # restoring step: keep the difference only when it succeeded
        kept = Signal(state_width, name="value")
        m.d.comb += kept.eq(Mux(q_bit, diff, state_in.dividend_quotient))

        with m.If(state_in.done):
            # every quotient bit is known: pass the state through unchanged
            m.d.comb += state_out.eq(state_in)
        with m.Else():
            m.d.comb += [
                state_out.q_bits_known.eq(state_in.q_bits_known + 1),
                # shift left by one with the new quotient bit at bit 0; the
                # extra top bit of the Cat is truncated by the assignment
                state_out.dividend_quotient.eq(Cat(q_bit, kept)),
            ]
        return m
98
99
class DivStateInit(Elaboratable):
    """Combinationally build the starting ``DivState`` for a division.

    ``o.q_bits_known`` is driven to 0 and ``o.dividend_quotient`` is driven
    from the ``2 * quotient_width``-bit ``dividend`` input.
    """

    def __init__(self, quotient_width):
        self.quotient_width = quotient_width
        self.dividend = Signal(quotient_width * 2)
        self.o = DivState(quotient_width=quotient_width, name="o")

    def elaborate(self, platform):
        m = Module()
        initial = self.o
        m.d.comb += [
            initial.q_bits_known.eq(0),
            initial.dividend_quotient.eq(self.dividend),
        ]
        return m
111
112
class DivState:
    """State carried between steps of the FSM divider.

    ``q_bits_known`` counts quotient bits produced so far (0 through
    ``quotient_width``).  ``dividend_quotient`` is ``2 * quotient_width``
    bits wide: once ``done``, its low half is the quotient and its high half
    is the remainder.
    """

    def __init__(self, quotient_width, name):
        self.quotient_width = quotient_width
        self.q_bits_known = Signal(range(1 + quotient_width),
                                   name=name + "_q_bits_known")
        self.dividend_quotient = Signal(unsigned(2 * quotient_width),
                                        name=name + "_dividend_quotient")

    @property
    def done(self):
        """Expression that is true once every quotient bit is known."""
        return self.q_bits_known == self.quotient_width

    @property
    def quotient(self):
        """ get the quotient -- requires self.done is True """
        return self.dividend_quotient[:self.quotient_width]

    @property
    def remainder(self):
        """ get the remainder -- requires self.done is True """
        width = self.quotient_width
        return self.dividend_quotient[width:2 * width]

    def eq(self, rhs):
        """Return the assignments that copy ``rhs``'s state into this one."""
        fields = ("q_bits_known", "dividend_quotient")
        return [getattr(self, field).eq(getattr(rhs, field))
                for field in fields]
138
139
class FSMDivCoreStage(ControlBase):
    """Multi-cycle divide core with a ready/valid handshake.

    While ``empty`` it is ready for input.  On the cycle an input is
    accepted, the first division step is computed from the live input, and
    the input data is latched.  Each clock after that, ``saved_state`` is
    advanced by ``DivStateNext``, one quotient bit per cycle, until all 64
    quotient bits are known.  The result is then presented as valid on
    ``n``, and the core becomes empty again once downstream accepts it.
    """

    def __init__(self, pspec: DivPipeSpec):
        super().__init__()
        self.pspec = pspec
        quotient_width = 64
        # override p and n
        self.p = FSMDivCorePrevControl(pspec)
        self.n = FSMDivCoreNextControl(pspec)
        self.saved_input_data = CoreInputData(pspec)
        self.empty = Signal(reset=1)
        self.saved_state = DivState(quotient_width, name="saved_state")
        self.div_state_next = DivStateNext(quotient_width)
        self.div_state_init = DivStateInit(quotient_width)
        # NOTE(review): not referenced by elaborate() in this file
        self.divisor = Signal(unsigned(64))

    def elaborate(self, platform):
        m = super().elaborate(platform)
        step = self.div_state_next
        init = self.div_state_init
        m.submodules.div_state_next = step
        m.submodules.div_state_init = init

        data_i = self.p.data_i
        data_o = self.n.data_o
        core_i: FSMDivCoreInputData = data_i.core
        core_o: FSMDivCoreOutputData = data_o.core
        core_saved_i: FSMDivCoreInputData = self.saved_input_data.core

        # TODO: handle cancellation

        # fract width of `DivPipeCoreOutputData.remainder`
        remainder_fract_width = 64 * 3
        # fract width of `DivPipeCoreInputData.dividend`
        dividend_fract_width = 64 * 2
        # realign the integer-division remainder to the output's format
        rem_start = remainder_fract_width - dividend_fract_width

        # non-core output fields always come from the latched input
        m.d.comb += data_o.eq_without_core(self.saved_input_data)
        m.d.comb += [
            init.dividend.eq(core_i.dividend),
            core_o.quotient_root.eq(step.o.quotient),
            core_o.remainder.eq(step.o.remainder << rem_start),
            self.n.valid_o.eq(~self.empty & step.o.done),
            self.p.ready_o.eq(self.empty),
        ]

        # the state register always captures the step logic's output
        m.d.sync += self.saved_state.eq(step.o)

        with m.If(self.empty):
            # idle: the step logic works directly on the live input
            m.d.comb += [
                step.i.eq(init.o),
                step.divisor.eq(core_i.divisor_radicand),
            ]
            with m.If(self.p.valid_i):
                m.d.sync += [
                    self.empty.eq(0),
                    self.saved_input_data.eq(data_i),
                ]
        with m.Else():
            # busy: keep iterating on the registered state and divisor
            m.d.comb += [
                step.i.eq(self.saved_state),
                step.divisor.eq(core_saved_i.divisor_radicand),
            ]
            with m.If(self.n.ready_i & self.n.valid_o):
                m.d.sync += self.empty.eq(1)

        return m

    def __iter__(self):
        """Yield the ports of ``p`` followed by those of ``n``."""
        yield from self.p
        yield from self.n

    def ports(self):
        """Return all ports, in ``__iter__`` order, as a list."""
        return [*self]