Finish refactoring out py3compat and delete the file
[pyelftools.git] / elftools / common / utils.py
1 #-------------------------------------------------------------------------------
2 # elftools: common/utils.py
3 #
4 # Miscellaneous utilities for elftools
5 #
6 # Eli Bendersky (eliben@gmail.com)
7 # This code is in the public domain
8 #-------------------------------------------------------------------------------
9 from contextlib import contextmanager
10 from .exceptions import ELFParseError, ELFError, DWARFError
11 from ..construct import ConstructError, ULInt8
12 import os
13
14
def merge_dicts(*dicts):
    """Merge any number of dicts into a new dict.

    Dicts are applied left to right with dict.update, so when a key appears
    in several inputs the value from the last one wins. The input dicts are
    not modified; a fresh dict is always returned (empty if no dicts given).
    """
    result = {}
    for d in dicts:
        result.update(d)
    return result
21
def bytes2str(b):
    """Return the text obtained by decoding *b* with the Latin-1 codec.

    Latin-1 maps every byte value to exactly one character, so each byte of
    *b* becomes one character of the result.
    """
    text = b.decode(encoding='latin-1')
    return text
25
def bytelist2string(bytelist):
    r"""Convert an iterable of byte values to a bytes object.

    For example, [0x10, 0x20, 0x00] becomes b'\x10\x20\x00'.

    Raises ValueError if a value is outside range(256) and TypeError if a
    value is not an integer (both raised by the bytes constructor).
    """
    # bytes() consumes an iterable of ints directly, instead of allocating a
    # one-byte object per element and joining them.
    return bytes(bytelist)
31
32
def struct_parse(struct, stream, stream_pos=None):
    """Parse *stream* using the given construct *struct*.

    If stream_pos is provided, the stream is seeked to this position before
    the parsing is done. Otherwise, the current position of the stream is
    used.

    Returns whatever struct.parse_stream(stream) returns.

    Raises ELFParseError wrapping any ConstructError thrown by construct;
    the original ConstructError is kept as the exception's __cause__.
    """
    try:
        if stream_pos is not None:
            stream.seek(stream_pos)
        return struct.parse_stream(stream)
    except ConstructError as e:
        # Chain explicitly so tracebacks show the construct error as the
        # direct cause, not as an incidental "during handling" exception.
        raise ELFParseError(str(e)) from e
46
47
def parse_cstring_from_stream(stream, stream_pos=None):
    """Read a NUL-terminated C string from *stream*.

    When stream_pos is not None the stream is first seeked to that offset;
    otherwise reading starts at the stream's current position.

    Returns the string's bytes without the terminating b'\\x00'. Returns
    None if the stream runs out (a read returns fewer bytes than requested)
    before a terminator is seen. The result is a bytes object because that
    is what is read from the binary file.

    Data is read in 64-byte chunks, so on return the stream is positioned
    after the last chunk read, which may be past the terminator.
    """
    if stream_pos is not None:
        stream.seek(stream_pos)

    chunk_size = 64
    pieces = []
    while True:
        block = stream.read(chunk_size)
        terminator_at = block.find(b'\x00')
        if terminator_at != -1:
            pieces.append(block[:terminator_at])
            return b''.join(pieces)
        pieces.append(block)
        if len(block) < chunk_size:
            # Short read: stream exhausted without finding a terminator.
            return None
75
76
def elf_assert(cond, msg=''):
    """Raise ELFError(msg) unless *cond* is truthy; otherwise return None."""
    _assert_with_exception(cond, msg, exception_type=ELFError)
81
82
def dwarf_assert(cond, msg=''):
    """Raise DWARFError(msg) unless *cond* is truthy; otherwise return None."""
    _assert_with_exception(cond, msg, exception_type=DWARFError)
87
88
@contextmanager
def preserve_stream_pos(stream):
    """Context manager that restores *stream*'s position on exit.

    Usage:
        # stream has some position FOO (return value of stream.tell())
        with preserve_stream_pos(stream):
            # do stuff that manipulates the stream
        # stream still has position FOO

    The position is restored even when the body of the with-statement
    raises; the exception then continues to propagate.
    """
    saved_pos = stream.tell()
    try:
        yield
    finally:
        # Without finally, an exception thrown into the generator at the
        # yield would skip this seek and leave the stream moved.
        stream.seek(saved_pos)
100
101
def roundup(num, bits):
    """Round *num* up to the nearest multiple of 2**bits.

    In the result, the lowest *bits* bits are all zero. A num that is
    already such a multiple is returned unchanged.
    """
    low_mask = (1 << bits) - 1
    # Fill the low bits of (num - 1), then step over into the next multiple.
    return ((num - 1) | low_mask) + 1
107
def read_blob(stream, length):
    """Read *length* bytes from *stream* and return them as a list of ints.

    Each byte is parsed individually with struct_parse using an unsigned
    8-bit construct field; any error raised by struct_parse propagates.
    Returns an empty list when length is 0.
    """
    # The field definition is the same for every byte, so build it once
    # instead of on every iteration.
    byte_field = ULInt8('')
    return [struct_parse(byte_field, stream) for _ in range(length)]
112
def save_dwarf_section(section, filename):
    """Debug helper: dump a section's contents into a file.

    *section* is expected to be one of the debug_xxx_sec elements of
    DWARFInfo. Only its ``stream`` and ``size`` attributes are used:
    section.size bytes are read from offset 0 of section.stream and written
    to *filename* (opened in binary write mode, truncating it).

    The stream's position is restored afterwards, even if reading from the
    stream or writing the file fails.
    """
    stream = section.stream
    saved_pos = stream.tell()
    try:
        stream.seek(0, os.SEEK_SET)
        data = stream.read(section.size)
    finally:
        stream.seek(saved_pos, os.SEEK_SET)
    with open(filename, 'wb') as out_file:
        out_file.write(data)
125
def iterbytes(b):
    """Yield each element of the bytes-like object *b* as a length-1 slice.

    Slicing (rather than indexing) keeps the element type: b'abc' yields
    b'a', b'b' and then b'c' instead of the ints 97, 98, 99.
    """
    position, end = 0, len(b)
    while position < end:
        yield b[position:position + 1]
        position += 1
133
def bytes2hex(b, sep=''):
    """Return the lowercase hexadecimal representation of *b*.

    Each byte becomes two hex digits. When *sep* is non-empty it is placed
    between consecutive byte pairs (e.g. b'\\x01\\xab' with sep=':' gives
    '01:ab'); otherwise the digits are concatenated ('01ab').
    """
    if sep:
        return sep.join('{:02x}'.format(byte) for byte in b)
    return b.hex()
138
139 #------------------------- PRIVATE -------------------------
140
141 def _assert_with_exception(cond, msg, exception_type):
142 if not cond:
143 raise exception_type(msg)