52b8391d0bc06abe4d76f35e0a4d10f572d4ac63
[pyelftools.git] / test / test_utils.py
1 #-------------------------------------------------------------------------------
2 # elftools tests
3 #
4 # Eli Bendersky (eliben@gmail.com)
5 # This code is in the public domain
6 #-------------------------------------------------------------------------------
7 import unittest
8 from io import BytesIO
9 from random import randint
10
11 from elftools.common.py3compat import int2byte
12 from elftools.common.utils import (parse_cstring_from_stream, merge_dicts,
13 preserve_stream_pos)
14
15
class Test_parse_cstring_from_stream(unittest.TestCase):
    """Tests for parse_cstring_from_stream.

    Each test reads a NUL-terminated byte string out of an in-memory stream,
    either with no offset argument or with an explicit byte offset passed as
    the second argument.
    """

    # Fixed seed for the pseudo-random test data. With the unseeded global
    # RNG every run used different input, so a failure could not be
    # reproduced; a private seeded generator also leaves global state alone.
    _RANDOM_SEED = 0x5EED

    def _make_random_bytes(self, n):
        """Return *n* reproducible pseudo-random bytes, none of them NUL.

        Each byte is drawn from 32..127 inclusive. The property the tests rely
        on is that no byte is 0, so the result holds no C-string terminator.
        """
        from random import Random  # local: a seeded generator, not the global one
        rng = Random(self._RANDOM_SEED)
        return b''.join(int2byte(rng.randint(32, 127)) for _ in range(n))

    def test_small1(self):
        stream = BytesIO(b'abcdefgh\x0012345')
        # No offset on a fresh stream: the string starting at byte 0.
        self.assertEqual(parse_cstring_from_stream(stream), b'abcdefgh')
        # An explicit offset starts parsing in the middle of the string.
        self.assertEqual(parse_cstring_from_stream(stream, 2), b'cdefgh')
        # An offset pointing straight at the terminator gives an empty string.
        self.assertEqual(parse_cstring_from_stream(stream, 8), b'')

    def test_small2(self):
        stream = BytesIO(b'12345\x006789\x00abcdefg\x00iii')
        self.assertEqual(parse_cstring_from_stream(stream), b'12345')
        # Offset 5 is the first terminator itself.
        self.assertEqual(parse_cstring_from_stream(stream, 5), b'')
        # Offset 6 is the start of the second string.
        self.assertEqual(parse_cstring_from_stream(stream, 6), b'6789')

    def test_large1(self):
        payload = b'i' * 400
        stream = BytesIO(payload + b'\x00' + b'bb')
        self.assertEqual(parse_cstring_from_stream(stream), payload)
        # Starting 150 bytes in leaves the remaining 250 bytes of the payload.
        self.assertEqual(parse_cstring_from_stream(stream, 150), b'i' * 250)

    def test_large2(self):
        text = self._make_random_bytes(5000) + b'\x00' + b'jujajaja'
        stream = BytesIO(text)
        # The random prefix contains no NUL, so the first string is exactly it.
        self.assertEqual(parse_cstring_from_stream(stream), text[:5000])
        self.assertEqual(parse_cstring_from_stream(stream, 2348),
                         text[2348:5000])
43
44
class Test_preserve_stream_pos(unittest.TestCase):
    """Tests for the preserve_stream_pos context manager."""

    def _seek_inside(self, stream, offset):
        """Seek *stream* to *offset* while preserve_stream_pos is active."""
        with preserve_stream_pos(stream):
            stream.seek(offset)

    def test_basic(self):
        stream = BytesIO(b'abcdef')

        # A fresh stream sits at offset 0; a seek to 4 inside the
        # context must be undone when the context exits.
        self._seek_inside(stream, 4)
        self.assertEqual(stream.tell(), 0)

        # Likewise, offset 5 must survive a seek back to 0 inside the context.
        stream.seek(5)
        self._seek_inside(stream, 0)
        self.assertEqual(stream.tell(), 5)
56
57
class Test_merge_dicts(unittest.TestCase):
    """Tests for merge_dicts."""

    def test_basic(self):
        # Disjoint keys: the result holds every entry from both inputs.
        first, second = {10: 20, 20: 30}, {30: 40, 50: 60}
        expected = {10: 20, 20: 30, 30: 40, 50: 60}
        self.assertEqual(merge_dicts(first, second), expected)

    def test_keys_resolve(self):
        # Key 20 appears in both inputs; the later dict's value (40) wins.
        first, second = {10: 20, 20: 30}, {20: 40, 50: 60}
        expected = {10: 20, 20: 40, 50: 60}
        self.assertEqual(merge_dicts(first, second), expected)
66
67
if __name__ == '__main__':
    # Run this module's test cases when the file is executed as a script.
    unittest.main()