7c8ce7c9ae92450b9326c80fb2a69c231da02e1a
[riscv-tests.git] / benchmarks / dgemm / dgemm_main.c
1 //**************************************************************************
2 // Double-precision general matrix multiplication benchmark
3 //--------------------------------------------------------------------------
4
//--------------------------------------------------------------------------
// Macros
//
// Build-time configuration. Each knob can be overridden from the compiler
// command line (e.g. -DHOST_DEBUG=1); any knob left undefined defaults to 0.

// HOST_DEBUG
//   1: build for a host machine (e.g. Athena/Linux) for debugging; the
//      helpers then report through stdio.
//   0: build with the smips-gcc toolchain for the target.
#if !defined(HOST_DEBUG)
#  define HOST_DEBUG 0
#endif

// PREALLOCATE
//   1: run the benchmark function once, untimed, before the measured run,
//      so that instruction/data cache miss overhead is not counted.
//   0: no warm-up run.
#if !defined(PREALLOCATE)
#  define PREALLOCATE 0
#endif

// SET_STATS
//   1: enable stats collection only around the computation itself
//      (see setStats); has no effect in HOST_DEBUG builds.
//   0: setStats() is a no-op.
#if !defined(SET_STATS)
#  define SET_STATS 0
#endif
31
32 //--------------------------------------------------------------------------
33 // Input/Reference Data
34
35 #include "dataset1.h"
36
37 //--------------------------------------------------------------------------
38 // Helper functions
39
// Compares the first n elements of test[] against correct[].
//
// Returns 1 if every element compares equal and 2 at the first mismatch;
// finishTest() treats 1 as PASSED and anything else as FAILED. n <= 0
// yields 1.
//
// The comparison is exact floating-point equality (no tolerance; note that
// a NaN never compares equal), so results must match the reference exactly.
int verify( long n, const double test[], const double correct[] )
{
  // The index is long to match n: an int index would overflow (undefined
  // behavior) for n > INT_MAX.
  for ( long i = 0; i < n; i++ ) {
    if ( test[i] != correct[i] )
      return 2;
  }
  return 1;
}
50
#if HOST_DEBUG
#include <stdio.h>
#include <stdlib.h>   // exit(), used by finishTest() in host-debug builds
// Host-debug helper: prints `name` right-aligned in a 10-character field,
// then the n elements of arr[] with one decimal place each, on one line.
// `name` is only read, so it is const-qualified (string literals are
// passed by main).
void printArray( const char name[], long n, const double arr[] )
{
  printf( " %10s :", name );
  // long index to match n; an int index would overflow for n > INT_MAX.
  for ( long i = 0; i < n; i++ )
    printf( " %8.1f ", arr[i] );
  printf( "\n" );
}
#endif
63
// Reports the benchmark outcome; never returns.
//
// toHostValue follows verify()'s convention: 1 means pass, any other value
// means failure.
//  - HOST_DEBUG builds print a PASSED/FAILED banner and exit with status 0
//    (NOTE(review): the exit status is 0 even on failure).
//  - Target builds write toHostValue to control register cr30 with mtpcr
//    (presumably how the result reaches the host, given the name), then
//    spin forever.
void finishTest( int toHostValue )
{
#if HOST_DEBUG
  const int passed = ( toHostValue == 1 );
  if ( !passed )
    printf( "*** FAILED *** (tohost = %d)\n", toHostValue );
  else
    printf( "*** PASSED ***\n" );
  exit(0);
#else
  __asm__ ( "mtpcr %0, cr30" : : "r" (toHostValue) );
  for ( ;; ) { }
#endif
}
77
// Turns statistics collection on (enable != 0) or off around the measured
// region. Active only in target builds with SET_STATS set, where `enable`
// is written to control register cr10 via mtpcr; otherwise a no-op.
void setStats( int enable )
{
#if SET_STATS && !HOST_DEBUG
  __asm__ ( "mtpcr %0, cr10" : : "r" (enable) );
#else
  (void) enable;   // unused in this configuration
#endif
}
84
85 //--------------------------------------------------------------------------
86 // square_dgemm function
87
// Computes c0 = A * B, where a0 (A), b0 (B) and c0 are n0-by-n0 matrices
// stored row-major (element (r, col) at index r*n0 + col).
//
// Strategy:
//   1. Pack A into a[] and the transpose of B into b[], both zero-padded to
//      n-by-n with n = n0 rounded up to a multiple of 3, so every output
//      element is a dot product of two contiguous rows.
//   2. Compute c = a * b^T in 3x3 output tiles, consuming three columns of
//      each row per inner-loop step.
//   3. Copy the top-left n0-by-n0 corner of c back into c0.
//
// Each dot product is accumulated three terms at a time, so rounding can
// differ from a naive i-j-k loop. verify() compares exactly, so the
// reference data must use the same summation order
// (NOTE(review): confirm against the dataset generator).
//
// n0 <= 0 is a no-op. The scratch matrices are VLAs of n*n doubles each,
// so stack use is 3*n*n*sizeof(double).
void square_dgemm( long n0, const double a0[], const double b0[], double c0[] )
{
  // A VLA must have a positive size; an empty problem has nothing to do.
  if ( n0 <= 0 )
    return;

  const long n = (n0+2)/3*3;   // n0 rounded up to a multiple of 3
  double a[n*n], b[n*n], c[n*n];

  // Step 1: a = A, b = B^T, both zero-padded to n x n.
  for (long i = 0; i < n0; i++)
  {
    long j;
    for (j = 0; j < n0; j++)
    {
      a[i*n+j] = a0[i*n0+j];
      b[i*n+j] = b0[j*n0+i];
    }
    for ( ; j < n; j++)
      a[i*n+j] = b[i*n+j] = 0;
  }
  for (long i = n0; i < n; i++)
    for (long j = 0; j < n; j++)
      a[i*n+j] = b[i*n+j] = 0;

  // Step 2: c = a * b^T, one 3x3 tile c[i..i+2][j..j+2] per iteration.
  for (long i = 0; i < n; i += 3)
  {
    for (long j = 0; j < n; j += 3)
    {
      // Rows i..i+2 of a and rows j..j+2 of b (= columns j..j+2 of B).
      // Named ra*/rb* so they do not shadow the a0/b0 parameters.
      const double *ra0 = a + (i+0)*n, *rb0 = b + (j+0)*n;
      const double *ra1 = a + (i+1)*n, *rb1 = b + (j+1)*n;
      const double *ra2 = a + (i+2)*n, *rb2 = b + (j+2)*n;
      const double *const ra0_end = a + (i+1)*n;   // end of row i

      // s<r><c> accumulates c[i+r][j+c].
      double s00 = 0, s01 = 0, s02 = 0;
      double s10 = 0, s11 = 0, s12 = 0;
      double s20 = 0, s21 = 0, s22 = 0;

      while (ra0 < ra0_end)
      {
        double a00 = ra0[0], a01 = ra0[1], a02 = ra0[2];
        double b00 = rb0[0], b01 = rb0[1], b02 = rb0[2];
        double a10 = ra1[0], a11 = ra1[1], a12 = ra1[2];
        double b10 = rb1[0], b11 = rb1[1], b12 = rb1[2];
        // Compiler barrier: memory accesses may not be moved across it.
        // Presumably meant to split the loads into two groups for
        // scheduling; confirm before removing. __asm__ (not asm) so it
        // also compiles in strict ISO modes.
        __asm__ __volatile__ ("" ::: "memory");
        double a20 = ra2[0], a21 = ra2[1], a22 = ra2[2];
        double b20 = rb2[0], b21 = rb2[1], b22 = rb2[2];

        // Keep these exact expressions: association order determines the
        // rounding, and verify() checks for exact equality.
        s00 = a00*b00 + (a01*b01 + (a02*b02 + s00));
        s01 = a00*b10 + (a01*b11 + (a02*b12 + s01));
        s02 = a00*b20 + (a01*b21 + (a02*b22 + s02));
        s10 = a10*b00 + (a11*b01 + (a12*b02 + s10));
        s11 = a10*b10 + (a11*b11 + (a12*b12 + s11));
        s12 = a10*b20 + (a11*b21 + (a12*b22 + s12));
        s20 = a20*b00 + (a21*b01 + (a22*b02 + s20));
        s21 = a20*b10 + (a21*b11 + (a22*b12 + s21));
        s22 = a20*b20 + (a21*b21 + (a22*b22 + s22));

        ra0 += 3; rb0 += 3;
        ra1 += 3; rb1 += 3;
        ra2 += 3; rb2 += 3;
      }

      c[(i+0)*n+j+0] = s00; c[(i+0)*n+j+1] = s01; c[(i+0)*n+j+2] = s02;
      c[(i+1)*n+j+0] = s10; c[(i+1)*n+j+1] = s11; c[(i+1)*n+j+2] = s12;
      c[(i+2)*n+j+0] = s20; c[(i+2)*n+j+1] = s21; c[(i+2)*n+j+2] = s22;
    }
  }

  // Step 3: copy the valid n0 x n0 corner of c back to c0.
  for (long i = 0; i < n0; i++)
  {
    long j;
    for (j = 0; j < n0 - 2; j += 3)
    {
      c0[i*n0+j+0] = c[i*n+j+0];
      c0[i*n0+j+1] = c[i*n+j+1];
      c0[i*n0+j+2] = c[i*n+j+2];
    }
    for ( ; j < n0; j++)
      c0[i*n0+j] = c[i*n+j];
  }
}
167
168 //--------------------------------------------------------------------------
169 // Main
170
171 int main( int argc, char* argv[] )
172 {
173 double results_data[DATA_SIZE*DATA_SIZE];
174
175 // Output the input array
176
177 #if HOST_DEBUG
178 printArray( "input1", DATA_SIZE*DATA_SIZE, input1_data );
179 printArray( "input2", DATA_SIZE*DATA_SIZE, input2_data );
180 printArray( "verify", DATA_SIZE*DATA_SIZE, verify_data );
181 #endif
182
183 // If needed we preallocate everything in the caches
184
185 #if PREALLOCATE
186 square_dgemm( DATA_SIZE, input1_data, input2_data, results_data );
187 #endif
188
189 // Do the dgemm
190
191 setStats(1);
192 square_dgemm( DATA_SIZE, input1_data, input2_data, results_data );
193 setStats(0);
194
195 // Print out the results
196
197 #if HOST_DEBUG
198 printArray( "results", DATA_SIZE*DATA_SIZE, results_data );
199 #endif
200
201 // Check the results
202
203 finishTest(verify( DATA_SIZE*DATA_SIZE, results_data, verify_data ));
204
205 }