8e3677a523c8a4ff8aa7b890d048be89ba112208
[soc.git] / src / experiment / alu_hier.py
1 from nmigen import Elaboratable, Signal, Module, Const, Mux
2 from nmigen.cli import main
3 from nmigen.cli import verilog, rtlil
4
5 import operator
6
7
class Adder(Elaboratable):
    """Combinational ``width``-bit adder.

    ``o`` is driven with ``a + b``.  Because ``o`` is only ``width`` bits
    wide, the carry out of the top bit is dropped by the assignment.
    """

    def __init__(self, width):
        # attribute names double as netlist names (nmigen tracer), so keep
        # the plain ``self.x = Signal(...)`` form
        self.a = Signal(width)
        self.b = Signal(width)
        self.o = Signal(width)

    def elaborate(self, platform):
        m = Module()
        total = self.a + self.b
        m.d.comb += [self.o.eq(total)]
        return m
18
19
class Subtractor(Elaboratable):
    """Combinational ``width``-bit subtractor.

    ``o`` is driven with ``a - b``.  ``o`` is only ``width`` bits wide, so
    only the low ``width`` bits of the difference are kept (it wraps).
    """

    def __init__(self, width):
        # attribute names double as netlist names (nmigen tracer)
        self.a = Signal(width)
        self.b = Signal(width)
        self.o = Signal(width)

    def elaborate(self, platform):
        m = Module()
        difference = self.a - self.b
        m.d.comb += [self.o.eq(difference)]
        return m
30
31
class Multiplier(Elaboratable):
    """Combinational ``width``-bit multiplier.

    ``o`` is driven with ``a * b``.  Only the low ``width`` bits of the
    product fit in ``o``; the upper half is discarded by the assignment.
    """

    def __init__(self, width):
        # attribute names double as netlist names (nmigen tracer)
        self.a = Signal(width)
        self.b = Signal(width)
        self.o = Signal(width)

    def elaborate(self, platform):
        m = Module()
        product = self.a * self.b
        m.d.comb += [self.o.eq(product)]
        return m
42
43
class Shifter(Elaboratable):
    """Combinational logical right shift: ``o = a >> b``.

    The shift amount is first passed through ``btrunc``, which is ``b``
    ANDed with an all-ones ``width``-bit mask.
    """

    def __init__(self, width):
        self.width = width
        # attribute names double as netlist names (nmigen tracer)
        self.a = Signal(width)
        self.b = Signal(width)
        self.o = Signal(width)

    def elaborate(self, platform):
        m = Module()
        # NOTE(review): ``b`` is already ``width`` bits wide, so this mask
        # keeps every bit and the AND is a no-op.  If the intent was to
        # limit the shift to log2(width) bits, that would change results
        # for large ``b`` -- confirm before narrowing.
        all_ones = Const((1 << self.width) - 1)
        btrunc = Signal(self.width)
        m.d.comb += [
            btrunc.eq(self.b & all_ones),
            self.o.eq(self.a >> btrunc),
        ]
        return m
57
58
class ALU(Elaboratable):
    """Toy "fake pipeline" ALU wrapping Adder/Subtractor/Multiplier/Shifter.

    ``op`` selects the unit: 0 = add, 1 = sub, 2 = mul, 3 = right shift.

    Handshake:
      * input side:  ``p_valid_i`` (in) / ``p_ready_o`` (out)
      * output side: ``n_valid_o`` (out) / ``n_ready_i`` (in)

    When ``p_valid_i`` is high and ``p_ready_o`` is still low, the selected
    unit's result is latched into ``o``, ``p_ready_o`` is raised and
    ``counter`` is loaded with a fake per-op latency.  ``n_valid_o`` rises
    when ``counter`` reaches 1 (or immediately for ADD, via ``go_now``).
    It is cleared, together with ``counter`` and ``o``, when the consumer
    acknowledges with ``n_ready_i``.  ``p_ready_o`` drops once
    ``p_valid_i`` drops.
    """

    def __init__(self, width):
        self.p_valid_i = Signal()
        self.p_ready_o = Signal()
        self.n_ready_i = Signal()
        self.n_valid_o = Signal()
        self.counter = Signal(4)
        self.op = Signal(2)
        self.a = Signal(width)
        self.b = Signal(width)
        self.o = Signal(width)
        self.width = width

    def elaborate(self, platform):
        m = Module()
        add = Adder(self.width)
        sub = Subtractor(self.width)
        mul = Multiplier(self.width)
        shf = Shifter(self.width)

        m.submodules.add = add
        m.submodules.sub = sub
        m.submodules.mul = mul
        m.submodules.shf = shf

        # index == op code
        units = [add, sub, mul, shf]
        # fake latency loaded into ``counter`` per op:
        # ADD 2 (but also fires at once via go_now), SUB 1, MUL 5, SHIFT 7
        latency = [2, 1, 5, 7]

        for mod in units:
            m.d.comb += [
                mod.a.eq(self.a),
                mod.b.eq(self.b),
            ]

        go_now = Signal(reset_less=True)  # testing no-delay ALU

        with m.If(self.p_valid_i):
            # input is valid. next check, if we already said "ready" or not
            with m.If(~self.p_ready_o):
                # we didn't say "ready" yet, so say so and initialise
                m.d.sync += self.p_ready_o.eq(1)

                # as this is a "fake" pipeline, just grab the output right
                # now and start the per-op countdown
                with m.Switch(self.op):
                    for i, (mod, cycles) in enumerate(zip(units, latency)):
                        with m.Case(i):
                            m.d.sync += self.o.eq(mod.o)
                            m.d.sync += self.counter.eq(cycles)
                            if mod is add:
                                m.d.comb += go_now.eq(1)
        with m.Else():
            # input says no longer valid, so drop ready as well.
            # a "proper" ALU would have had to sync in the opcode and a/b ops
            m.d.sync += self.p_ready_o.eq(0)

        # ok so the counter's running: when it gets to 1, fire the output
        with m.If((self.counter == 1) | go_now):
            m.d.sync += self.n_valid_o.eq(1)

        # countdown to 1 (transition from 1 to 0 only on acknowledgement).
        # This must come BEFORE the acknowledge block: in nmigen the last
        # ``sync`` assignment wins.  Otherwise an acknowledge while
        # counter > 1 (ADD: output fired via go_now with counter == 2)
        # would be overridden to counter - 1, and a spurious second
        # n_valid_o would follow with the already-cleared o == 0.
        with m.If(self.counter > 1):
            m.d.sync += self.counter.eq(self.counter - 1)

        # recipient took the output: reset back to known-good
        # NOTE(review): a new request accepted in this same cycle also has
        # its counter/o overwritten here -- overlapping issue is unsupported.
        with m.If(self.n_ready_i & self.n_valid_o):
            m.d.sync += self.n_valid_o.eq(0)
            m.d.sync += self.counter.eq(0)  # reset the counter
            m.d.sync += self.o.eq(0)  # clear the output for tidiness sake

        return m

    def __iter__(self):
        # data-path signals only; handshake signals are added in ports()
        yield self.op
        yield self.a
        yield self.b
        yield self.o

    def ports(self):
        """Return the top-level ports, including the handshake signals.

        Without the handshake signals, the driven outputs ``p_ready_o`` and
        ``n_valid_o`` would not be top-level ports of the converted design.
        """
        return list(self) + [
            self.p_valid_i,
            self.p_ready_o,
            self.n_ready_i,
            self.n_valid_o,
        ]
139
140
class BranchOp(Elaboratable):
    """Single comparison unit for the branch ALU.

    ``op`` is a two-argument Python callable (e.g. ``operator.gt``) applied
    to the ``a`` and ``b`` signals.  ``o`` is 1 when the comparison holds
    and 0 otherwise.
    """

    def __init__(self, width, op):
        # attribute names double as netlist names (nmigen tracer)
        self.a = Signal(width)
        self.b = Signal(width)
        self.o = Signal(width)
        self.op = op

    def elaborate(self, platform):
        m = Module()
        condition = self.op(self.a, self.b)
        m.d.comb += [self.o.eq(Mux(condition, 1, 0))]
        return m
152
153
class BranchALU(Elaboratable):
    """Toy "fake pipeline" branch-compare unit with an ALU-style handshake.

    ``op`` selects the comparison: 0 = a > b, 1 = a < b, 2 = a == b,
    3 = a != b.  ``o`` is 1 when it holds, 0 otherwise.

    Handshake:
      * input side:  ``p_valid_i`` (in) / ``p_ready_o`` (out)
      * output side: ``n_valid_o`` (out) / ``n_ready_i`` (in)

    Every comparison latches its result into ``o`` on acceptance and loads
    ``counter`` with a fake 5-cycle latency.  ``n_valid_o`` rises when
    ``counter`` reaches 1.  It is cleared, with ``counter`` and ``o``, when
    the consumer acknowledges with ``n_ready_i``.
    """

    def __init__(self, width):
        self.p_valid_i = Signal()
        self.p_ready_o = Signal()
        self.n_ready_i = Signal()
        self.n_valid_o = Signal()
        self.counter = Signal(4)
        self.op = Signal(2)
        self.a = Signal(width)
        self.b = Signal(width)
        self.o = Signal(width)
        self.width = width

    def elaborate(self, platform):
        m = Module()
        bgt = BranchOp(self.width, operator.gt)
        blt = BranchOp(self.width, operator.lt)
        beq = BranchOp(self.width, operator.eq)
        bne = BranchOp(self.width, operator.ne)

        m.submodules.bgt = bgt
        m.submodules.blt = blt
        m.submodules.beq = beq
        m.submodules.bne = bne

        # index == op code
        units = [bgt, blt, beq, bne]
        for mod in units:
            m.d.comb += [
                mod.a.eq(self.a),
                mod.b.eq(self.b),
            ]

        # testing no-delay ALU: nothing drives this, so it stays 0.  Drive
        # it high (comb) in the accept branch to fire the output at once.
        go_now = Signal(reset_less=True)

        with m.If(self.p_valid_i):
            # input is valid. next check, if we already said "ready" or not
            with m.If(~self.p_ready_o):
                # we didn't say "ready" yet, so say so and initialise
                m.d.sync += self.p_ready_o.eq(1)

                # as this is a "fake" pipeline, just grab the output right now
                with m.Switch(self.op):
                    for i, mod in enumerate(units):
                        with m.Case(i):
                            m.d.sync += self.o.eq(mod.o)
                m.d.sync += self.counter.eq(5)  # branch to take 5 cycles (fake)
        with m.Else():
            # input says no longer valid, so drop ready as well.
            # a "proper" ALU would have had to sync in the opcode and a/b ops
            m.d.sync += self.p_ready_o.eq(0)

        # ok so the counter's running: when it gets to 1, fire the output
        with m.If((self.counter == 1) | go_now):
            m.d.sync += self.n_valid_o.eq(1)

        # countdown to 1 (transition from 1 to 0 only on acknowledgement).
        # Placed BEFORE the acknowledge block so the acknowledge reset wins
        # (last nmigen ``sync`` assignment wins).  Otherwise an acknowledge
        # while counter > 1 would be overridden and re-fire n_valid_o with
        # the cleared o == 0.
        with m.If(self.counter > 1):
            m.d.sync += self.counter.eq(self.counter - 1)

        # recipient took the output: reset back to known-good
        # NOTE(review): a new request accepted in this same cycle also has
        # its counter/o overwritten here -- overlapping issue is unsupported.
        with m.If(self.n_ready_i & self.n_valid_o):
            m.d.sync += self.n_valid_o.eq(0)
            m.d.sync += self.counter.eq(0)  # reset the counter
            m.d.sync += self.o.eq(0)  # clear the output for tidiness sake

        return m

    def __iter__(self):
        # data-path signals only; handshake signals are added in ports()
        yield self.op
        yield self.a
        yield self.b
        yield self.o

    def ports(self):
        """Return the top-level ports, including the handshake signals.

        Without the handshake signals, the driven outputs ``p_ready_o`` and
        ``n_valid_o`` would not be top-level ports of the converted design.
        """
        return list(self) + [
            self.p_valid_i,
            self.p_ready_o,
            self.n_ready_i,
            self.n_valid_o,
        ]
227
228
if __name__ == "__main__":
    # Emit RTLIL for each 16-bit design, ALU first, then BranchALU.
    for design_cls, out_name in ((ALU, "test_alu.il"),
                                 (BranchALU, "test_branch_alu.il")):
        alu = design_cls(width=16)
        vl = rtlil.convert(alu, ports=alu.ports())
        with open(out_name, "w") as f:
            f.write(vl)
239