394d3497bb2d7f0b3345932063742541c6278a6c
[nmigen.git] / nmigen / test / test_hdl_xfrm.py
1 # nmigen: UnusedElaboratable=no
2
3 from ..hdl.ast import *
4 from ..hdl.cd import *
5 from ..hdl.ir import *
6 from ..hdl.xfrm import *
7 from ..hdl.mem import *
8 from .utils import *
9
10
class DomainRenamerTestCase(FHDLTestCase):
    """Tests for ``DomainRenamer``.

    Covers renaming of ``ClockSignal``/``ResetSignal`` references and driver
    domains, renaming of ``ClockDomain`` objects (including ones shared with a
    subfragment), and rejection of renames to or from the ``comb`` domain.
    """

    def setUp(self):
        # The expected representations below refer to these signals by their
        # attribute names (s1 .. s5).
        self.s1 = Signal()
        self.s2 = Signal()
        self.s3 = Signal()
        self.s4 = Signal()
        self.s5 = Signal()

    def test_rename_signals(self):
        f = Fragment()
        f.add_statements(
            self.s1.eq(ClockSignal()),
            ResetSignal().eq(self.s2),
            self.s3.eq(0),
            self.s4.eq(ClockSignal("other")),
            self.s5.eq(ResetSignal("other")),
        )
        f.add_driver(self.s1, None)
        f.add_driver(self.s2, None)
        f.add_driver(self.s3, "sync")

        # A single string renames only the default "sync" domain; references
        # to "other" and comb (None) drivers must be left untouched.
        f = DomainRenamer("pix")(f)
        self.assertRepr(f.statements, """
        (
            (eq (sig s1) (clk pix))
            (eq (rst pix) (sig s2))
            (eq (sig s3) (const 1'd0))
            (eq (sig s4) (clk other))
            (eq (sig s5) (rst other))
        )
        """)
        self.assertEqual(f.drivers, {
            None: SignalSet((self.s1, self.s2)),
            "pix": SignalSet((self.s3,)),
        })

    def test_rename_multi(self):
        f = Fragment()
        f.add_statements(
            self.s1.eq(ClockSignal()),
            self.s2.eq(ResetSignal("other")),
        )

        # A dict renames several domains in a single pass.
        f = DomainRenamer({"sync": "pix", "other": "pix2"})(f)
        self.assertRepr(f.statements, """
        (
            (eq (sig s1) (clk pix))
            (eq (sig s2) (rst pix2))
        )
        """)

    def test_rename_cd(self):
        cd_sync = ClockDomain()
        cd_pix = ClockDomain()

        f = Fragment()
        f.add_domains(cd_sync, cd_pix)

        # The ClockDomain object itself is renamed and re-keyed in f.domains.
        f = DomainRenamer("ext")(f)
        self.assertEqual(cd_sync.name, "ext")
        self.assertEqual(f.domains, {
            "ext": cd_sync,
            "pix": cd_pix,
        })

    def test_rename_cd_preserves_allow_reset_less(self):
        cd_pix = ClockDomain(reset_less=True)

        f = Fragment()
        f.add_domains(cd_pix)
        f.add_statements(
            self.s1.eq(ResetSignal(allow_reset_less=True)),
        )

        # After "sync" -> "pix", the reset reference points at the reset-less
        # "pix" domain. It must keep allow_reset_less through the rename so that
        # lowering produces a constant 0 rather than a reset-less domain error.
        f = DomainRenamer("pix")(f)
        f = DomainLowerer()(f)
        self.assertRepr(f.statements, """
        (
            (eq (sig s1) (const 1'd0))
        )
        """)

    def test_rename_cd_subfragment(self):
        cd_sync = ClockDomain()
        cd_pix = ClockDomain()

        f1 = Fragment()
        f1.add_domains(cd_sync, cd_pix)
        # The same ClockDomain object is also registered in the subfragment.
        f2 = Fragment()
        f2.add_domains(cd_sync)
        f1.add_subfragment(f2)

        f1 = DomainRenamer("ext")(f1)
        self.assertEqual(cd_sync.name, "ext")
        self.assertEqual(f1.domains, {
            "ext": cd_sync,
            "pix": cd_pix,
        })

    def test_rename_wrong_to_comb(self):
        # NOTE(review): `msg` is expected to be compared against the exception
        # text by FHDLTestCase.assertRaises (defined elsewhere); stock unittest
        # would only use it as the failure message — confirm.
        with self.assertRaises(ValueError,
                msg="Domain 'sync' may not be renamed to 'comb'"):
            DomainRenamer("comb")

    def test_rename_wrong_from_comb(self):
        with self.assertRaises(ValueError,
                msg="Domain 'comb' may not be renamed"):
            DomainRenamer({"comb": "sync"})
121
122
class DomainLowererTestCase(FHDLTestCase):
    """Tests for ``DomainLowerer``.

    ``ClockSignal``/``ResetSignal`` references, both in statements and in
    driver sets, are replaced by the clock/reset signals of the named domain.
    A permitted reset reference to a reset-less domain becomes a constant 0.
    References to unknown domains, or to the reset of a reset-less domain,
    raise ``DomainError``.
    """

    def setUp(self):
        self.s = Signal()

    def test_lower_clk(self):
        frag = Fragment()
        # NOTE: the unnamed ClockDomain takes its name from the variable, so
        # `sync` must keep this exact name.
        sync = ClockDomain()
        frag.add_domains(sync)
        frag.add_statements(self.s.eq(ClockSignal("sync")))

        lowered = DomainLowerer()(frag)
        self.assertRepr(lowered.statements, """
        (
            (eq (sig s) (sig clk))
        )
        """)

    def test_lower_rst(self):
        frag = Fragment()
        sync = ClockDomain()
        frag.add_domains(sync)
        frag.add_statements(self.s.eq(ResetSignal("sync")))

        lowered = DomainLowerer()(frag)
        self.assertRepr(lowered.statements, """
        (
            (eq (sig s) (sig rst))
        )
        """)

    def test_lower_rst_reset_less(self):
        frag = Fragment()
        sync = ClockDomain(reset_less=True)
        frag.add_domains(sync)
        # allow_reset_less permits the reference; it lowers to constant 0.
        frag.add_statements(
            self.s.eq(ResetSignal("sync", allow_reset_less=True)))

        lowered = DomainLowerer()(frag)
        self.assertRepr(lowered.statements, """
        (
            (eq (sig s) (const 1'd0))
        )
        """)

    def test_lower_drivers(self):
        sync = ClockDomain()
        pix = ClockDomain()
        frag = Fragment()
        frag.add_domains(sync, pix)
        frag.add_driver(ClockSignal("pix"), None)
        frag.add_driver(ResetSignal("pix"), "sync")

        lowered = DomainLowerer()(frag)
        expected_drivers = {
            None: SignalSet((pix.clk,)),
            "sync": SignalSet((pix.rst,)),
        }
        self.assertEqual(lowered.drivers, expected_drivers)

    def test_lower_wrong_domain(self):
        frag = Fragment()
        frag.add_statements(self.s.eq(ClockSignal("xxx")))

        # NOTE(review): `msg` is expected to be checked against the exception
        # text by FHDLTestCase.assertRaises (defined elsewhere) — confirm.
        with self.assertRaises(DomainError,
                msg="Signal (clk xxx) refers to nonexistent domain 'xxx'"):
            DomainLowerer()(frag)

    def test_lower_wrong_reset_less_domain(self):
        frag = Fragment()
        sync = ClockDomain(reset_less=True)
        frag.add_domains(sync)
        # Without allow_reset_less this reference is an error.
        frag.add_statements(self.s.eq(ResetSignal("sync")))

        with self.assertRaises(DomainError,
                msg="Signal (rst sync) refers to reset of reset-less domain 'sync'"):
            DomainLowerer()(frag)
207
208
class SampleLowererTestCase(FHDLTestCase):
    """Tests for ``SampleLowerer``.

    Each ``Sample`` is replaced by a chain of intermediate ``$sample$...``
    signals, and those signals are driven from the sample's domain.
    """

    def setUp(self):
        self.i = Signal()
        self.o1 = Signal()
        self.o2 = Signal()
        self.o3 = Signal()

    def test_lower_signal(self):
        frag = Fragment()
        frag.add_statements(
            self.o1.eq(Sample(self.i, 2, "sync")),
            self.o2.eq(Sample(self.i, 1, "sync")),
            self.o3.eq(Sample(self.i, 1, "pix")),
        )

        lowered = SampleLowerer()(frag)
        # Both "sync" samples share the $1 stage, so "sync" drives only two
        # signals; the "pix" sample gets its own single stage.
        self.assertRepr(lowered.statements, """
        (
            (eq (sig o1) (sig $sample$s$i$sync$2))
            (eq (sig o2) (sig $sample$s$i$sync$1))
            (eq (sig o3) (sig $sample$s$i$pix$1))
            (eq (sig $sample$s$i$sync$1) (sig i))
            (eq (sig $sample$s$i$sync$2) (sig $sample$s$i$sync$1))
            (eq (sig $sample$s$i$pix$1) (sig i))
        )
        """)
        drivers = lowered.drivers
        self.assertEqual(len(drivers["sync"]), 2)
        self.assertEqual(len(drivers["pix"]), 1)

    def test_lower_const(self):
        frag = Fragment()
        frag.add_statements(self.o1.eq(Sample(1, 2, "sync")))

        lowered = SampleLowerer()(frag)
        # Constants are sampled like signals, with a "$c$" name component.
        self.assertRepr(lowered.statements, """
        (
            (eq (sig o1) (sig $sample$c$1$sync$2))
            (eq (sig $sample$c$1$sync$1) (const 1'd1))
            (eq (sig $sample$c$1$sync$2) (sig $sample$c$1$sync$1))
        )
        """)
        self.assertEqual(len(lowered.drivers["sync"]), 2)
253
254
class SwitchCleanerTestCase(FHDLTestCase):
    """Tests for ``SwitchCleaner``, which removes switches left with nothing
    to do after their (nested) contents turn out to be empty."""

    def test_clean(self):
        # Signal names come from these variables and appear in the repr below.
        a = Signal()
        b = Signal()
        stmts = [
            Switch(a, {
                1: a.eq(0),
                0: [
                    b.eq(1),
                    # Switch(a|b) has no cases; once it is removed, Switch(b)
                    # has only an empty case, so both disappear from the output.
                    Switch(b, {1: [
                        Switch(a|b, {})
                    ]})
                ]
            })
        ]

        self.assertRepr(SwitchCleaner()(stmts), """
        (
            (switch (sig a)
                (case 1
                    (eq (sig a) (const 1'd0)))
                (case 0
                    (eq (sig b) (const 1'd1)))
            )
        )
        """)
282
283
class LHSGroupAnalyzerTestCase(FHDLTestCase):
    """Tests for ``LHSGroupAnalyzer``, which partitions assigned signals into
    groups of signals that appear together on a left-hand side (e.g. in one
    ``Cat``)."""

    def test_no_group_unrelated(self):
        a = Signal()
        b = Signal()
        stmts = [a.eq(0), b.eq(0)]

        actual = list(LHSGroupAnalyzer()(stmts).values())
        self.assertEqual(actual, [SignalSet((a,)), SignalSet((b,))])

    def test_group_related(self):
        a = Signal()
        b = Signal()
        # Cat(a, b) on the LHS merges a's group with b.
        stmts = [a.eq(0), Cat(a, b).eq(0)]

        actual = list(LHSGroupAnalyzer()(stmts).values())
        self.assertEqual(actual, [SignalSet((a, b))])

    def test_no_loops(self):
        a = Signal()
        b = Signal()
        # Assigning the same Cat twice must still yield a single group.
        stmts = [a.eq(0), Cat(a, b).eq(0), Cat(a, b).eq(0)]

        actual = list(LHSGroupAnalyzer()(stmts).values())
        self.assertEqual(actual, [SignalSet((a, b))])

    def test_switch(self):
        a = Signal()
        b = Signal()
        # Using `a` as the switch subject does not group it with `b`.
        stmts = [
            a.eq(0),
            Switch(a, {1: b.eq(0)}),
        ]

        actual = list(LHSGroupAnalyzer()(stmts).values())
        self.assertEqual(actual, [SignalSet((a,)), SignalSet((b,))])

    def test_lhs_empty(self):
        stmts = [Cat().eq(0)]

        actual = list(LHSGroupAnalyzer()(stmts).values())
        self.assertEqual(actual, [])
350
351
class LHSGroupFilterTestCase(FHDLTestCase):
    """Tests for ``LHSGroupFilter``, which keeps only assignments to the given
    signals while leaving the enclosing switch structure in place."""

    def test_filter(self):
        # Signal names come from these variables and appear in the repr below.
        a = Signal()
        b = Signal()
        stmts = [
            Switch(a, {
                1: a.eq(0),
                0: [
                    b.eq(1),
                    Switch(b, {1: []})
                ]
            })
        ]

        # Filtering on {a}: the assignment to `b` and the switch on `b` are not
        # in the output, leaving "case 0" empty.
        self.assertRepr(LHSGroupFilter(SignalSet((a,)))(stmts), """
        (
            (switch (sig a)
                (case 1
                    (eq (sig a) (const 1'd0)))
                (case 0 )
            )
        )
        """)

    def test_lhs_empty(self):
        stmts = [
            Cat().eq(0)
        ]

        # An assignment whose LHS has no signals is filtered out entirely.
        self.assertRepr(LHSGroupFilter(SignalSet())(stmts), "()")
383
384
class ResetInserterTestCase(FHDLTestCase):
    """Tests for ``ResetInserter``.

    The transform appends a switch on the control signal that assigns each
    signal driven in the targeted domain(s) its reset value. Reset-less signals
    are skipped.
    """

    def setUp(self):
        self.s1 = Signal()
        self.s2 = Signal(reset=1)
        self.s3 = Signal(reset=1, reset_less=True)
        self.c1 = Signal()

    def test_reset_default(self):
        frag = Fragment()
        frag.add_statements(self.s1.eq(1))
        frag.add_driver(self.s1, "sync")

        transformed = ResetInserter(self.c1)(frag)
        self.assertRepr(transformed.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (switch (sig c1)
                (case 1 (eq (sig s1) (const 1'd0)))
            )
        )
        """)

    def test_reset_cd(self):
        frag = Fragment()
        frag.add_statements(
            self.s1.eq(1),
            self.s2.eq(0),
        )
        frag.add_domains(ClockDomain("sync"))
        frag.add_driver(self.s1, "sync")
        frag.add_driver(self.s2, "pix")

        # Only "pix" is targeted, so only s2 (driven from "pix") is reset.
        transformed = ResetInserter({"pix": self.c1})(frag)
        self.assertRepr(transformed.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (eq (sig s2) (const 1'd0))
            (switch (sig c1)
                (case 1 (eq (sig s2) (const 1'd1)))
            )
        )
        """)

    def test_reset_value(self):
        frag = Fragment()
        frag.add_statements(self.s2.eq(0))
        frag.add_driver(self.s2, "sync")

        # s2 has reset=1, so the inserted assignment uses 1, not 0.
        transformed = ResetInserter(self.c1)(frag)
        self.assertRepr(transformed.statements, """
        (
            (eq (sig s2) (const 1'd0))
            (switch (sig c1)
                (case 1 (eq (sig s2) (const 1'd1)))
            )
        )
        """)

    def test_reset_less(self):
        frag = Fragment()
        frag.add_statements(self.s3.eq(0))
        frag.add_driver(self.s3, "sync")

        # s3 is reset-less: the switch is still emitted, but its case is empty.
        transformed = ResetInserter(self.c1)(frag)
        self.assertRepr(transformed.statements, """
        (
            (eq (sig s3) (const 1'd0))
            (switch (sig c1)
                (case 1 )
            )
        )
        """)
463
464
class EnableInserterTestCase(FHDLTestCase):
    """Tests for ``EnableInserter``.

    For fragments, the transform appends a switch that holds each signal driven
    in the targeted domain(s) at its current value when the enable is 0. This
    also applies to subfragments. For memory ports, the port's EN input is
    gated by the enable signal.
    """

    def setUp(self):
        # The expected representations refer to these by attribute name.
        self.s1 = Signal()
        self.s2 = Signal()
        self.c1 = Signal()

    def test_enable_default(self):
        f = Fragment()
        f.add_statements(
            self.s1.eq(1)
        )
        f.add_driver(self.s1, "sync")

        f = EnableInserter(self.c1)(f)
        self.assertRepr(f.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (switch (sig c1)
                (case 0 (eq (sig s1) (sig s1)))
            )
        )
        """)

    def test_enable_cd(self):
        f = Fragment()
        f.add_statements(
            self.s1.eq(1),
            self.s2.eq(0),
        )
        f.add_driver(self.s1, "sync")
        f.add_driver(self.s2, "pix")

        # Only "pix" is targeted, so only s2 is held when c1 is low.
        f = EnableInserter({"pix": self.c1})(f)
        self.assertRepr(f.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (eq (sig s2) (const 1'd0))
            (switch (sig c1)
                (case 0 (eq (sig s2) (sig s2)))
            )
        )
        """)

    def test_enable_subfragment(self):
        f1 = Fragment()
        f1.add_statements(
            self.s1.eq(1)
        )
        f1.add_driver(self.s1, "sync")

        f2 = Fragment()
        f2.add_statements(
            self.s2.eq(1)
        )
        f2.add_driver(self.s2, "sync")
        f1.add_subfragment(f2)

        f1 = EnableInserter(self.c1)(f1)
        # The transform produces new fragments, so fetch the transformed child.
        (f2, _), = f1.subfragments
        self.assertRepr(f1.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (switch (sig c1)
                (case 0 (eq (sig s1) (sig s1)))
            )
        )
        """)
        self.assertRepr(f2.statements, """
        (
            (eq (sig s2) (const 1'd1))
            (switch (sig c1)
                (case 0 (eq (sig s2) (sig s2)))
            )
        )
        """)

    def test_enable_read_port(self):
        # Signal names such as "mem_r_en" derive from this variable name.
        mem = Memory(width=8, depth=4)
        f = EnableInserter(self.c1)(mem.read_port(transparent=False)).elaborate(platform=None)
        self.assertRepr(f.named_ports["EN"][0], """
        (m (sig c1) (sig mem_r_en) (const 1'd0))
        """)

    def test_enable_write_port(self):
        mem = Memory(width=8, depth=4)
        f = EnableInserter(self.c1)(mem.write_port()).elaborate(platform=None)
        self.assertRepr(f.named_ports["EN"][0], """
        (m (sig c1) (cat (repl (slice (sig mem_w_en) 0:1) 8)) (const 8'd0))
        """)
555
556
class _MockElaboratable(Elaboratable):
    """Minimal elaboratable whose fragment drives ``s1`` to 1 from "sync"."""

    def __init__(self):
        self.s1 = Signal()

    def elaborate(self, platform):
        # `platform` is unused; a fresh fragment is built on every call.
        frag = Fragment()
        frag.add_statements(self.s1.eq(1))
        frag.add_driver(self.s1, "sync")
        return frag
568
569
class TransformedElaboratableTestCase(FHDLTestCase):
    """Tests for ``TransformedElaboratable``: attribute passthrough to the
    wrapped elaboratable, and stacking of several transforms."""

    def setUp(self):
        self.c1 = Signal()
        self.c2 = Signal()

    def test_getattr(self):
        inner = _MockElaboratable()
        wrapped = EnableInserter(self.c1)(inner)

        # Attribute access on the wrapper reaches the wrapped object.
        self.assertIs(wrapped.s1, inner.s1)

    def test_composition(self):
        inner = _MockElaboratable()
        enabled = EnableInserter(self.c1)(inner)
        with_reset = ResetInserter(self.c2)(enabled)

        # Applying a second transform returns the same wrapper object.
        self.assertIsInstance(enabled, TransformedElaboratable)
        self.assertIs(enabled, with_reset)

        # The transforms are applied in the order they were added.
        frag = Fragment.get(with_reset, None)
        self.assertRepr(frag.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (switch (sig c1)
                (case 0 (eq (sig s1) (sig s1)))
            )
            (switch (sig c2)
                (case 1 (eq (sig s1) (const 1'd0)))
            )
        )
        """)
601
602
class MockUserValue(UserValue):
    """``UserValue`` whose lowering is fixed at construction time.

    Args:
        lowered: the value returned by ``lower()``. It may itself be another
            ``UserValue``.
    """

    def __init__(self, lowered):
        super(MockUserValue, self).__init__()
        self.lowered = lowered

    def lower(self):
        # Hand back the constructor argument unchanged.
        return self.lowered
610
611
class UserValueTestCase(FHDLTestCase):
    """Tests that transforms operate on the signal a ``UserValue`` lowers to.

    Subclasses may override ``setUp`` to supply a different ``self.uv`` that
    still lowers to ``self.s``.
    """

    def setUp(self):
        self.s = Signal()
        self.c = Signal()
        self.uv = MockUserValue(self.s)

    def test_lower(self):
        # The unnamed ClockDomain takes its name ("sync") from this variable.
        sync = ClockDomain()
        frag = Fragment()
        frag.add_domains(sync)
        frag.add_statements(self.uv.eq(1))
        for sig in self.uv._lhs_signals():
            frag.add_driver(sig, "sync")

        frag = ResetInserter(self.c)(frag)
        frag = DomainLowerer()(frag)
        # The user-supplied reset (c) and the domain reset (rst) both act on s.
        self.assertRepr(frag.statements, """
        (
            (eq (sig s) (const 1'd1))
            (switch (sig c)
                (case 1 (eq (sig s) (const 1'd0)))
            )
            (switch (sig rst)
                (case 1 (eq (sig s) (const 1'd0)))
            )
        )
        """)
641
642
class UserValueRecursiveTestCase(UserValueTestCase):
    """Runs the ``UserValueTestCase`` checks with a ``UserValue`` that lowers
    to another ``UserValue``, which in turn lowers to ``self.s``."""

    def setUp(self):
        self.s = Signal()
        self.c = Signal()
        inner = MockUserValue(self.s)
        self.uv = MockUserValue(inner)

    # test_lower is inherited unchanged: the expected output is the same
    # because both layers ultimately lower to self.s.