check that bgt test ALU works
[soc.git] / src / experiment / alu_hier.py
1 from nmigen import Elaboratable, Signal, Module, Const, Mux
2 from nmigen.cli import main
3 from nmigen.cli import verilog, rtlil
4
5 import operator
6
7
class Adder(Elaboratable):
    """Combinational ``width``-bit adder driving ``o`` with ``a + b``.

    The sum is truncated to ``width`` bits when it is assigned to ``o``,
    so the carry out is dropped and the result wraps modulo ``2**width``.
    """

    def __init__(self, width):
        self.a = Signal(width)  # first operand
        self.b = Signal(width)  # second operand
        self.o = Signal(width)  # wrapped sum

    def elaborate(self, platform):
        module = Module()
        total = self.a + self.b
        module.d.comb += [self.o.eq(total)]
        return module
18
19
class Subtractor(Elaboratable):
    """Combinational ``width``-bit subtractor driving ``o`` with ``a - b``.

    Only the low ``width`` bits of the difference are kept on assignment,
    so the result wraps modulo ``2**width`` when ``b > a``.
    """

    def __init__(self, width):
        self.a = Signal(width)  # minuend
        self.b = Signal(width)  # subtrahend
        self.o = Signal(width)  # wrapped difference

    def elaborate(self, platform):
        module = Module()
        difference = self.a - self.b
        module.d.comb += [self.o.eq(difference)]
        return module
30
31
class Multiplier(Elaboratable):
    """Combinational ``width``-bit multiplier driving ``o`` with ``a * b``.

    The full product is wider than ``width``; assigning it to ``o`` keeps
    only its low ``width`` bits.
    """

    def __init__(self, width):
        self.a = Signal(width)  # multiplicand
        self.b = Signal(width)  # multiplier
        self.o = Signal(width)  # low half of the product

    def elaborate(self, platform):
        module = Module()
        product = self.a * self.b
        module.d.comb += [self.o.eq(product)]
        return module
42
43
class Shifter(Elaboratable):
    """Combinational logical right shift driving ``o`` with ``a >> b``.

    All ``width`` bits of ``b`` are used as the shift amount. ``a`` is an
    unsigned signal, so zeros are shifted in, and any shift amount greater
    than or equal to ``width`` produces 0.
    """

    def __init__(self, width):
        self.width = width
        self.a = Signal(width)  # value to shift
        self.b = Signal(width)  # shift amount (every bit is significant)
        self.o = Signal(width)  # shifted result

    def elaborate(self, platform):
        m = Module()
        # No masking of ``b`` is needed: ANDing a ``width``-bit signal with
        # ``(1 << width) - 1`` leaves every bit unchanged, so a separate
        # "truncated" copy of ``b`` would be dead logic.
        # NOTE(review): if the intent was to use only the low log2(width)
        # bits of ``b`` (shift amount taken modulo ``width``), that needs an
        # explicit slice; this keeps the existing full-width behaviour.
        m.d.comb += self.o.eq(self.a >> self.b)
        return m
57
58
class ALU(Elaboratable):
    """Four-function combinational ALU built from one submodule per operation.

    Every unit receives the same ``a``/``b`` operands, and ``op`` selects
    which unit's output is routed to ``o``:

    * ``op == 0``: ``a + b``  (Adder)
    * ``op == 1``: ``a - b``  (Subtractor)
    * ``op == 2``: ``a * b``  (Multiplier)
    * ``op == 3``: ``a >> b`` (Shifter)
    """

    def __init__(self, width):
        self.op = Signal(2)
        self.a = Signal(width)
        self.b = Signal(width)
        self.o = Signal(width)
        self.width = width

    def elaborate(self, platform):
        m = Module()
        # The list position is the ``op`` value that selects the unit. The
        # string is the submodule name used in the generated netlist.
        units = [
            ("add", Adder(self.width)),
            ("sub", Subtractor(self.width)),
            ("mul", Multiplier(self.width)),
            ("shf", Shifter(self.width)),
        ]
        for name, unit in units:
            setattr(m.submodules, name, unit)
            # Operands are wired outside the Switch so that every unit
            # sees them regardless of ``op``.
            m.d.comb += [unit.a.eq(self.a), unit.b.eq(self.b)]
        with m.Switch(self.op):
            for sel, (_, unit) in enumerate(units):
                with m.Case(sel):
                    m.d.comb += self.o.eq(unit.o)
        return m

    def __iter__(self):
        # Port order: op, a, b, o.
        yield from (self.op, self.a, self.b, self.o)

    def ports(self):
        """Return the top-level ports (op, a, b, o) as a list."""
        return [*self]
97
98
class BranchOp(Elaboratable):
    """Evaluate a single two-operand condition ``op(a, b)`` combinationally.

    :param width: width of the ``a``, ``b`` and ``o`` signals.
    :param op: two-argument callable applied to the ``a`` and ``b``
        signals (for example ``operator.gt``).

    ``o`` is 1 when the condition holds and 0 otherwise. Only bit 0 can be
    set; the upper bits of ``o`` are always zero.
    """

    def __init__(self, width, op):
        self.a = Signal(width)
        self.b = Signal(width)
        self.o = Signal(width)
        self.op = op

    def elaborate(self, platform):
        module = Module()
        condition = self.op(self.a, self.b)
        # Mux turns any non-zero condition into exactly 1.
        module.d.comb += self.o.eq(Mux(condition, 1, 0))
        return module
110
111
class BranchALU(Elaboratable):
    """Four-way comparison unit built from one ``BranchOp`` per condition.

    Every comparator receives the same ``a``/``b`` operands, and ``op``
    selects which comparator's output is routed to ``o``:

    * ``op == 0``: ``a > b``  (bgt)
    * ``op == 1``: ``a < b``  (blt)
    * ``op == 2``: ``a == b`` (beq)
    * ``op == 3``: ``a != b`` (bne)
    """

    def __init__(self, width):
        self.op = Signal(2)
        self.a = Signal(width)
        self.b = Signal(width)
        self.o = Signal(width)
        self.width = width

    def elaborate(self, platform):
        m = Module()
        # The list position is the ``op`` value that selects the comparator.
        # The string is the submodule name used in the generated netlist.
        units = [
            ("bgt", BranchOp(self.width, operator.gt)),
            ("blt", BranchOp(self.width, operator.lt)),
            ("beq", BranchOp(self.width, operator.eq)),
            ("bne", BranchOp(self.width, operator.ne)),
        ]
        for name, unit in units:
            setattr(m.submodules, name, unit)
            # Operands are wired outside the Switch so that every comparator
            # sees them regardless of ``op``.
            m.d.comb += [unit.a.eq(self.a), unit.b.eq(self.b)]
        with m.Switch(self.op):
            for sel, (_, unit) in enumerate(units):
                with m.Case(sel):
                    m.d.comb += self.o.eq(unit.o)
        return m

    def __iter__(self):
        # Port order: op, a, b, o.
        yield from (self.op, self.a, self.b, self.o)

    def ports(self):
        """Return the top-level ports (op, a, b, o) as a list."""
        return [*self]
150
151
if __name__ == "__main__":
    # Convert each 16-bit design to RTLIL text and write it to its own file.
    # Conversion happens before the file is opened, so a failed conversion
    # leaves no partially written output.
    for design_cls, out_path in ((ALU, "test_alu.il"),
                                 (BranchALU, "test_branch_alu.il")):
        dut = design_cls(width=16)
        il_text = rtlil.convert(dut, ports=dut.ports())
        with open(out_path, "w") as f:
            f.write(il_text)
162