ELF: Find all symbols of a given name in `get_symbol_by_name`.
authorMatthew Fernandez <matthew.fernandez@gmail.com>
Tue, 7 Apr 2015 00:00:55 +0000 (10:00 +1000)
committerMatthew Fernandez <matthew.fernandez@nicta.com.au>
Fri, 10 Apr 2015 05:51:50 +0000 (15:51 +1000)
It is possible for an ELF file's symbol table to contain multiple entries under
the same symbol name. Prior to this commit, it was only possible to retrieve
the first of these via `get_symbol_by_name`. This commit alters this function
to return a list of all symbols under the given name, rather than just the
first entry. Functionality when a symbol name does not exist remains
unaffected.

elftools/elf/sections.py
test/test_get_symbol_by_name.py

index 83b5781b8d916515cd5d40f53789c617cc776f11..b198868b1fccf1ba2dc09dbd9964a91cb0aba7b9 100644 (file)
@@ -7,7 +7,7 @@
 # This code is in the public domain
 #-------------------------------------------------------------------------------
 from ..common.utils import struct_parse, elf_assert, parse_cstring_from_stream
-
+from collections import defaultdict
 
 class Section(object):
     """ Base class for ELF sections. Also used for all sections types that have
@@ -102,18 +102,18 @@ class SymbolTableSection(Section):
         return Symbol(entry, name)
 
     def get_symbol_by_name(self, name):
-        """ Get a symbol by its name. Return None if no symbol by the given
-            name exists.
+        """ Get a symbol(s) by name. Return None if no symbol by the given name
+            exists.
         """
         # The first time this method is called, construct a name to number
         # mapping
         #
         if self._symbol_name_map is None:
-            self._symbol_name_map = {}
+            self._symbol_name_map = defaultdict(list)
             for i, sym in enumerate(self.iter_symbols()):
-                self._symbol_name_map[sym.name] = i
-        symnum = self._symbol_name_map.get(name)
-        return None if symnum is None else self.get_symbol(symnum)
+                self._symbol_name_map[sym.name].append(i)
+        symnums = self._symbol_name_map.get(name)
+        return [self.get_symbol(i) for i in symnums] if symnums else None
 
     def iter_symbols(self):
         """ Yield all the symbols in the table
index cad104fe46e583a6788b544e91a72bbde455e33a..d5e7f440a84fdf75e37b351bf71d0d3b79a9b4ab 100644 (file)
@@ -20,10 +20,13 @@ class TestGetSymbolByName(unittest.TestCase):
             self.assertIsNotNone(symtab)
 
             # Test we can find a symbol by its name.
-            main = symtab.get_symbol_by_name(b'main')
-            self.assertIsNotNone(main)
+            mains = symtab.get_symbol_by_name(b'main')
+            self.assertIsNotNone(mains)
 
             # Test it is actually the symbol we expect.
+            self.assertIsInstance(mains, list)
+            self.assertEqual(len(mains), 1)
+            main = mains[0]
             self.assertEqual(main.name, b'main')
             self.assertEqual(main['st_value'], 0x8068)
             self.assertEqual(main['st_size'], 0x28)
@@ -41,5 +44,27 @@ class TestGetSymbolByName(unittest.TestCase):
             undef = symtab.get_symbol_by_name(b'non-existent symbol')
             self.assertIsNone(undef)
 
+    def test_duplicated_symbol(self):
+        with open(os.path.join('test', 'testfiles_for_unittests',
+                               'simple_gcc.elf.arm'), 'rb') as f:
+            elf = ELFFile(f)
+
+            # Find the symbol table.
+            symtab = elf.get_section_by_name(b'.symtab')
+            self.assertIsNotNone(symtab)
+
+            # The '$a' symbols that are present in the test file.
+            expected_symbols = [0x8000, 0x8034, 0x8090, 0x800c, 0x809c, 0x8018,
+                                0x8068]
+
+            # Test we get all expected instances of the symbol '$a'.
+            arm_markers = symtab.get_symbol_by_name(b'$a')
+            self.assertIsNotNone(arm_markers)
+            self.assertIsInstance(arm_markers, list)
+            self.assertEqual(len(arm_markers), len(expected_symbols))
+            for symbol in arm_markers:
+                self.assertEqual(symbol.name, b'$a')
+                self.assertIn(symbol['st_value'], expected_symbols)
+
 if __name__ == '__main__':
     unittest.main()