Remove dependence on machine/syscall.h
[riscv-tests.git] / benchmarks / common / syscalls.c
1 #include <stdint.h>
2 #include <string.h>
3 #include <stdarg.h>
4 #include <stdio.h>
5 #include <limits.h>
6 #include "util.h"
7
// Syscall numbers recognised by handle_trap(): SYS_exit and SYS_stats are
// serviced locally, every other number (e.g. SYS_write) is forwarded to the
// host front-end.
// NOTE(review): 64 and 93 look like the RISC-V Linux numbers for write/exit;
// 1234 appears to be private to this runtime -- confirm.
#define SYS_write 64
#define SYS_exit 93
#define SYS_stats 1234

// initialized in crt.S
// NOTE(review): not read anywhere in the visible part of this file.
int have_vec;
14
15 static long handle_frontend_syscall(long which, long arg0, long arg1, long arg2)
16 {
17 volatile uint64_t magic_mem[8] __attribute__((aligned(64)));
18 magic_mem[0] = which;
19 magic_mem[1] = arg0;
20 magic_mem[2] = arg1;
21 magic_mem[3] = arg2;
22 __sync_synchronize();
23 write_csr(tohost, (long)magic_mem);
24 while (swap_csr(fromhost, 0) == 0);
25 return magic_mem[0];
26 }
27
// Read a CSR whose access may trap (the uarch-specific counters sampled for
// setStats). The result is pinned to a0 because the trap handler skips the
// faulting instruction and supplies 0 only when a0 is the destination
// register.
#define read_csr_safe(reg) \
  ({ \
    register long csr_value_ asm("a0"); \
    asm volatile ("csrr %0, " #reg : "=r"(csr_value_)); \
    csr_value_; \
  })
34
// Number of counters sampled by handle_stats(): cycle, instret and
// uarch0..uarch15.
#define NUM_COUNTERS 18
// Per-counter snapshot (after enabling stats) or delta (after disabling).
static long counters[NUM_COUNTERS];
// Printable name of each counter. Entries only ever point at string
// literals, so the pointee is const.
static const char* counter_names[NUM_COUNTERS];
38 static int handle_stats(int enable)
39 {
40 //use csrs to set stats register
41 if (enable)
42 asm volatile ("csrrs a0, stats, 1" ::: "a0");
43 int i = 0;
44 #define READ_CTR(name) do { \
45 while (i >= NUM_COUNTERS) ; \
46 long csr = read_csr_safe(name); \
47 if (!enable) { csr -= counters[i]; counter_names[i] = #name; } \
48 counters[i++] = csr; \
49 } while (0)
50 READ_CTR(cycle); READ_CTR(instret);
51 READ_CTR(uarch0); READ_CTR(uarch1); READ_CTR(uarch2); READ_CTR(uarch3);
52 READ_CTR(uarch4); READ_CTR(uarch5); READ_CTR(uarch6); READ_CTR(uarch7);
53 READ_CTR(uarch8); READ_CTR(uarch9); READ_CTR(uarch10); READ_CTR(uarch11);
54 READ_CTR(uarch12); READ_CTR(uarch13); READ_CTR(uarch14); READ_CTR(uarch15);
55 #undef READ_CTR
56 if (!enable)
57 asm volatile ("csrrc a0, stats, 1" ::: "a0");
58 return 0;
59 }
60
61 static void tohost_exit(int code)
62 {
63 write_csr(tohost, (code << 1) | 1);
64 while (1);
65 }
66
67 long handle_trap(long cause, long epc, long regs[32])
68 {
69 int* csr_insn;
70 asm ("jal %0, 1f; csrr a0, stats; 1:" : "=r"(csr_insn));
71 long sys_ret = 0;
72
73 if (cause == CAUSE_ILLEGAL_INSTRUCTION &&
74 (*(int*)epc & *csr_insn) == *csr_insn)
75 ;
76 else if (cause != CAUSE_SYSCALL)
77 tohost_exit(1337);
78 else if (regs[17] == SYS_exit)
79 tohost_exit(regs[10]);
80 else if (regs[17] == SYS_stats)
81 sys_ret = handle_stats(regs[10]);
82 else
83 sys_ret = handle_frontend_syscall(regs[17], regs[10], regs[11], regs[12]);
84
85 regs[10] = sys_ret;
86 return epc+4;
87 }
88
89 static long syscall(long num, long arg0, long arg1, long arg2)
90 {
91 register long a7 asm("a7") = num;
92 register long a0 asm("a0") = arg0;
93 register long a1 asm("a1") = arg1;
94 register long a2 asm("a2") = arg2;
95 asm volatile ("scall" : "+r"(a0) : "r"(a1), "r"(a2), "r"(a7));
96 return a0;
97 }
98
99 void exit(int code)
100 {
101 syscall(SYS_exit, code, 0, 0);
102 while (1);
103 }
104
105 void setStats(int enable)
106 {
107 syscall(SYS_stats, enable, 0, 0);
108 }
109
110 void printstr(const char* s)
111 {
112 syscall(SYS_write, 1, (long)s, strlen(s));
113 }
114
// Default (weak) per-core entry hook.
//
// multi-threaded programs override this function.
// for the case of single-threaded programs, only let core 0 proceed:
// returns immediately when cid == 0 and never returns otherwise.
void __attribute__((weak)) thread_entry(int cid, int nc)
{
  (void)nc;

  if (cid == 0)
    return;

  // Park every other core. The loop condition is a constant expression on
  // purpose: C11 lets the compiler assume that a side-effect-free loop with
  // a non-constant condition (such as `while (cid != 0);`) terminates, which
  // would allow it to drop the loop and release the core.
  while (1)
    ;
}
121
// Default (weak) main.
// single-threaded programs override this function; this fallback only prints
// a reminder and returns -1.
int __attribute__((weak)) main(int argc, char** argv)
{
  const char* reminder = "Implement main(), foo!\n";

  printstr(reminder);
  return -1;
}
128
129 static void init_tls()
130 {
131 register void* thread_pointer asm("tp");
132 extern char _tls_data;
133 extern __thread char _tdata_begin, _tdata_end, _tbss_end;
134 size_t tdata_size = &_tdata_end - &_tdata_begin;
135 memcpy(thread_pointer, &_tls_data, tdata_size);
136 size_t tbss_size = &_tbss_end - &_tdata_end;
137 memset(thread_pointer + tdata_size, 0, tbss_size);
138 }
139
140 void _init(int cid, int nc)
141 {
142 init_tls();
143 thread_entry(cid, nc);
144
145 // only single-threaded programs should ever get here.
146 int ret = main(0, 0);
147
148 char buf[NUM_COUNTERS * 32] __attribute__((aligned(64)));
149 char* pbuf = buf;
150 for (int i = 0; i < NUM_COUNTERS; i++)
151 if (counters[i])
152 pbuf += sprintf(pbuf, "%s = %d\n", counter_names[i], counters[i]);
153 if (pbuf != buf)
154 printstr(buf);
155
156 exit(ret);
157 }
158
159 #undef putchar
160 int putchar(int ch)
161 {
162 static __thread char buf[64] __attribute__((aligned(64)));
163 static __thread int buflen = 0;
164
165 buf[buflen++] = ch;
166
167 if (ch == '\n' || buflen == sizeof(buf))
168 {
169 syscall(SYS_write, 1, (long)buf, buflen);
170 buflen = 0;
171 }
172
173 return 0;
174 }
175
// Print x as exactly 16 lower-case hexadecimal digits (zero padded, no
// prefix, no newline) via printstr().
void printhex(uint64_t x)
{
  static const char hex_digits[] = "0123456789abcdef";
  char text[17];

  text[16] = 0;
  // Fill from the least significant nibble backwards.
  for (int pos = 15; pos >= 0; pos--) {
    text[pos] = hex_digits[x & 0xF];
    x >>= 4;
  }

  printstr(text);
}
189
// Emit an unsigned number through putch, padded to at least `width`
// characters.
//
// putch/putdat: output callback and its opaque state
// num:          value to print
// base:         radix; callers in this file pass 8, 10 or 16 (must be >= 2)
// width:        minimum field width; values <= the digit count add no padding
// padc:         pad character placed before the digits. The value '-' is
//               what vprintfmt passes for the '-' flag, so it selects
//               left-justification: the digits come first and the field is
//               filled with trailing spaces, as the %s conversion does.
//               Previously '-' was emitted as the pad character itself
//               ("%-5d" printed "---42").
static inline void printnum(void (*putch)(int, void**), void **putdat,
                            unsigned long long num, unsigned base, int width, int padc)
{
  // One slot per bit is enough for any base >= 2.
  unsigned digs[sizeof(num)*CHAR_BIT];
  int pos = 0;

  // Collect the digits, least significant first.
  while (1)
  {
    digs[pos++] = num % base;
    if (num < base)
      break;
    num /= base;
  }

  int padding = (width > pos) ? width - pos : 0;
  int left_justify = (padc == '-');

  if (!left_justify)
    for (int i = 0; i < padding; i++)
      putch(padc, putdat);

  // Emit the digits, most significant first.
  while (pos-- > 0)
    putch(digs[pos] + (digs[pos] >= 10 ? 'a' - 10 : '0'), putdat);

  if (left_justify)
    for (int i = 0; i < padding; i++)
      putch(' ', putdat);
}
210
// Fetch the next unsigned integer argument from *ap, widened to
// unsigned long long. lflag selects the argument's type:
// 0 -> unsigned int, 1 -> unsigned long, 2 or more -> unsigned long long.
static unsigned long long getuint(va_list *ap, int lflag)
{
  unsigned long long value;

  if (lflag >= 2)
    value = va_arg(*ap, unsigned long long);
  else if (lflag != 0)
    value = va_arg(*ap, unsigned long);
  else
    value = va_arg(*ap, unsigned int);

  return value;
}
220
// Fetch the next signed integer argument from *ap, widened to long long.
// lflag selects the argument's type:
// 0 -> int, 1 -> long, 2 or more -> long long.
static long long getint(va_list *ap, int lflag)
{
  long long value;

  if (lflag >= 2)
    value = va_arg(*ap, long long);
  else if (lflag != 0)
    value = va_arg(*ap, long);
  else
    value = va_arg(*ap, int);

  return value;
}
230
// Minimal printf-style formatter.
//
// putch/putdat: output callback and its opaque state; every produced
//               character is delivered as putch(ch, putdat)
// fmt:          format string
// ap:           arguments; the caller's list is not consumed (a private copy
//               is used)
//
// Conversions: %c %s %d %u %o %x %p %%. Modifiers: '-' (left-justify),
// '0' (zero pad), a decimal width, '.' followed by a precision, '*' (take the
// number from the arguments), 'l' / 'll'. '#' is accepted and ignored.
// An unrecognised conversion prints '%' and resumes with the character that
// followed it.
static void vprintfmt(void (*putch)(int, void**), void **putdat, const char *fmt, va_list ap)
{
  const char* p;
  const char* last_fmt;
  int ch;
  unsigned long long num;
  int base, lflag, width, precision;
  char padc;

  // getint()/getuint() need a va_list*. Taking the address of a va_list
  // *parameter* is not portable (where va_list is an array type the
  // parameter decays to a pointer), so work on a local copy instead.
  va_list args;
  va_copy(args, ap);

  while (1) {
    // Pass literal text through until the next '%' or the end of fmt.
    while ((ch = *(unsigned char *) fmt) != '%') {
      if (ch == '\0') {
        va_end(args);
        return;
      }
      fmt++;
      putch(ch, putdat);
    }
    fmt++;

    // Process a %-escape sequence
    last_fmt = fmt;
    padc = ' ';
    width = -1;
    precision = -1;
    lflag = 0;
  reswitch:
    switch (ch = *(unsigned char *) fmt++) {

    // flag to pad on the right
    case '-':
      padc = '-';
      goto reswitch;

    // flag to pad with 0's instead of spaces
    case '0':
      padc = '0';
      goto reswitch;

    // width field
    case '1':
    case '2':
    case '3':
    case '4':
    case '5':
    case '6':
    case '7':
    case '8':
    case '9':
      for (precision = 0; ; ++fmt) {
        precision = precision * 10 + ch - '0';
        ch = *fmt;
        if (ch < '0' || ch > '9')
          break;
      }
      goto process_precision;

    case '*':
      precision = va_arg(args, int);
      goto process_precision;

    // after a '.', the next number is kept as the precision
    case '.':
      if (width < 0)
        width = 0;
      goto reswitch;

    // alternate-form flag: accepted, but it has no effect
    case '#':
      goto reswitch;

    // the first number seen is the width, not the precision
    process_precision:
      if (width < 0) {
        width = precision;
        precision = -1;
      }
      goto reswitch;

    // long flag (doubled for long long)
    case 'l':
      lflag++;
      goto reswitch;

    // character
    case 'c':
      putch(va_arg(args, int), putdat);
      break;

    // string
    case 's':
      if ((p = va_arg(args, char *)) == NULL)
        p = "(null)";
      if (width > 0 && padc != '-')
        for (width -= strnlen(p, precision); width > 0; width--)
          putch(padc, putdat);
      for (; (ch = *p) != '\0' && (precision < 0 || --precision >= 0); width--) {
        putch(ch, putdat);
        p++;
      }
      for (; width > 0; width--)
        putch(' ', putdat);
      break;

    // (signed) decimal
    case 'd':
      num = getint(&args, lflag);
      if ((long long) num < 0) {
        putch('-', putdat);
        // Negate as unsigned: negating LLONG_MIN as a signed value overflows.
        num = 0ULL - num;
      }
      base = 10;
      goto signed_number;

    // unsigned decimal
    case 'u':
      base = 10;
      goto unsigned_number;

    // (unsigned) octal
    case 'o':
      // should do something with padding so it's always 3 octits
      base = 8;
      goto unsigned_number;

    // pointer
    case 'p':
      static_assert(sizeof(long) == sizeof(void*));
      lflag = 1;
      putch('0', putdat);
      putch('x', putdat);
      /* fall through to 'x' */

    // (unsigned) hexadecimal
    case 'x':
      base = 16;
    unsigned_number:
      num = getuint(&args, lflag);
    signed_number:
      printnum(putch, putdat, num, base, width, padc);
      break;

    // escaped '%' character
    case '%':
      putch(ch, putdat);
      break;

    // unrecognized escape sequence - just print it literally
    default:
      putch('%', putdat);
      fmt = last_fmt;
      break;
    }
  }
}
381
// Adapter giving putchar() the callback signature vprintfmt() expects.
// Calling putchar through a pointer cast to a different function type, as
// was done before, is undefined behaviour.
static void printf_putch(int ch, void** data)
{
  (void)data;
  putchar(ch);
}

// Formatted output through putchar(). Supports the conversions implemented
// by vprintfmt(). Always returns 0 rather than the number of characters
// written.
int printf(const char* fmt, ...)
{
  va_list ap;
  va_start(ap, fmt);

  vprintfmt(printf_putch, 0, fmt, ap);

  va_end(ap);
  return 0; // incorrect return value, but who cares, anyway?
}
392
// vprintfmt() callback for sprintf(): `data` points at the caller's write
// cursor (a char*); store ch there and advance the cursor.
// Kept at file scope rather than nested inside sprintf(), because nested
// functions are a GNU C extension that other compilers reject.
static void sprintf_putch(int ch, void** data)
{
  char** pstr = (char**)data;
  **pstr = ch;
  (*pstr)++;
}

// Format into str using the conversions implemented by vprintfmt() and
// NUL-terminate the result. No bounds checking is performed: the caller must
// provide a large enough buffer.
// Returns the number of characters written, excluding the terminating NUL.
int sprintf(char* str, const char* fmt, ...)
{
  va_list ap;
  char* str0 = str;
  va_start(ap, fmt);

  vprintfmt(sprintf_putch, (void**)&str, fmt, ap);
  *str = 0;

  va_end(ap);
  return str - str0;
}