# PyForTool (Python-fortran-tool) -- openacc.py
"""
This module implements the Openacc class containing the methods related to OpenACC.
"""

import re

from pyfortool.util import debugDecor, n2name, alltext, tag
from pyfortool.expressions import createElem, createExpr


class Openacc():
    """
    OpenACC directives transformation methods.

    Provides utilities for adding, removing, and transforming OpenACC
    directives for GPU acceleration.

    The methods expect the host object to also provide ``removeComments``,
    ``findall``, ``getParent``, ``createDoConstruct``, ``getScopes`` and
    ``tree`` (they are all called on ``self`` below).
    """

    @debugDecor
    def removeACC(self):
        """
        Remove all OpenACC directives.

        Comments matching ``^!$ACC`` (case-insensitive) are removed through
        ``removeComments``, with no directive protected from the removal.

        Examples
        --------
        >>> pft = PYFT('gpu_code.F90')
        >>> pft.removeACC()
        """
        self.removeComments(exclDirectives=[],
                            pattern=re.compile(r'^\!\$ACC', re.IGNORECASE))

    # NOTE(review): the ``def`` line of this method was missing from the listing this
    # file was recovered from; the name is reconstructed from upstream PyForTool.
    # Confirm it against the callers.
    @debugDecor
    def removebyPassDOCONCURRENT(self):
        """
        Remove MNH OpenACC bypass macros for non-Cray compilers.

        Removes the comments starting with the following directives:
        - !$mnh_undef(LOOP)
        - !$mnh_undef(OPENACC)
        - !$mnh_define(LOOP)
        - !$mnh_define(OPENACC)
        """
        for macro in ('mnh_undef(LOOP)', 'mnh_undef(OPENACC)',
                      'mnh_define(LOOP)', 'mnh_define(OPENACC)'):
            # re.escape turns the parentheses into literal characters (not a regex group)
            self.removeComments(exclDirectives=[],
                                pattern=re.compile(r'^\!\$' + re.escape(macro)))

    # NOTE(review): the ``def`` line of this method was missing from the listing this
    # file was recovered from; the name is reconstructed from upstream PyForTool.
    # Confirm it against the callers.
    @debugDecor
    def craybyPassDOCONCURRENT(self):
        """
        Handle CRAY compiler vectorization issues with GPU directives.

        On compute kernels with !$acc loop independent collapse(X) directives:
        - If BR_ functions are used: removes !$acc loop collapse and uses DO CONCURRENT
        - If !$mnh_undef(OPENACC) is present: removes !$acc loop collapse
        - If !$mnh_undef(LOOP) is present: converts nested loops to DO CONCURRENT

        Comments are scanned in document order: a !$mnh_undef(...) macro applies
        to the directives that follow it, until the matching !$mnh_define(...)
        restores the default. A directive that is not immediately followed by a
        do-construct is left untouched.

        Notes
        -----
        Works around CRAY compiler vectorization bugs with BR_ (bit-reproducibility)
        functions and OpenACC directives.
        """
        # Bit-reproducible math functions whose presence triggers the workaround
        brFunctions = {'BR_' + name for name in ('ALOG', 'LOG', 'EXP', 'COS', 'SIN', 'ASIN',
                                                 'ATAN', 'ATAN2', 'P2', 'P3', 'P4')}

        def checkPresenceofBR(node):
            """Return True if a BR_ (math bit-reproducibility) function is present in the node"""
            return any(alltext(el) in brFunctions
                       for el in node.findall('.//{*}named-E/{*}N/{*}n'))

        def getStatementsInDoConstruct(node, savelist):
            """
            Append to savelist the statements of node, descending into nested
            do-constructs. Nodes whose tag contains 'do-stmt' (the DO statements and,
            presumably, the END DO statements too) are skipped.
            """
            for sNode in node:
                if 'do-construct' in tag(sNode):
                    getStatementsInDoConstruct(sNode, savelist)
                elif 'do-stmt' not in tag(sNode):
                    savelist.append(sNode)

        useNestedLoops = True  # Cray compiler needs nested loops by default
        useAccLoopIndependent = True  # It needs nested !$acc loop independent collapse(X)
        toremove = []  # (parent, node) pairs, removed once every comment has been processed

        for comment in self.findall('.//{*}C'):
            text = comment.text
            if text.startswith('!$mnh_undef(LOOP)'):
                useNestedLoops = False  # Use DO CONCURRENT
            elif text.startswith('!$mnh_undef(OPENACC)'):
                useAccLoopIndependent = False
            elif text.startswith('!$mnh_define(LOOP)'):
                useNestedLoops = True  # Back to nested loops
            elif text.startswith('!$mnh_define(OPENACC)'):
                useAccLoopIndependent = True
            elif text.startswith('!$acc loop independent collapse('):
                par = self.getParent(comment)
                children = list(par)
                ind = children.index(comment)
                # The directive applies to the loop nest that immediately follows it
                if ind + 1 >= len(children) or 'do-construct' not in tag(children[ind + 1]):
                    continue
                nestedLoop = children[ind + 1]

                # Get the statements content in the DO-construct
                statements = []
                getStatementsInDoConstruct(nestedLoop, statements)

                # Check presence of BR_ within the statements
                isBRPresent = any(checkPresenceofBR(stmt) for stmt in statements)

                # Remove !$acc loop independent collapse
                # if BR_ is present or if !$mnh_undef(OPENACC)
                if not useAccLoopIndependent or isBRPresent:
                    # The directive is a child of par, which is not necessarily self
                    toremove.append((par, comment))

                # Use DO CONCURRENT instead of nested loops if BR_ is present
                # or if !$mnh_undef(LOOP)
                if not useNestedLoops or isBRPresent:
                    # Loop index -> [lower bound, upper bound], innermost loop first
                    table = {}
                    for doStmt in reversed(nestedLoop.findall('.//{*}do-stmt')):
                        index = doStmt.find('.//{*}do-V/{*}named-E/{*}N/{*}n').text
                        table[index] = [alltext(doStmt.find('.//{*}lower-bound')),
                                        alltext(doStmt.find('.//{*}upper-bound'))]

                    # Create the do-construct
                    inner, outer, _ = self.createDoConstruct(
                        table, indent=len(nestedLoop.tail or ''), concurrent=True)

                    # Move the statements into the new do-construct; insert(-1) keeps
                    # its last child (presumably the END DO statement) last
                    for stmt in statements:
                        inner.insert(-1, stmt)

                    # Insert the new do-construct and delete the old one afterwards
                    par.insert(ind, outer)
                    toremove.append((par, nestedLoop))

        # Suppression of nodes
        for parent, elem in toremove:
            parent.remove(elem)

    @debugDecor
    def allocatetoHIP(self):
        """
        Convert ALLOCATE/DEALLOCATE to HIP versions for AMD GPUs.

        Transforms memory allocation calls for variables that are:
        - Sent to GPU via !$acc enter data copyin
        - Created on GPU via !$acc enter data create

        For each scope, the variable names are read from the single-line
        ``!$acc enter data`` / ``!$acc exit data`` directives. A name that is the
        left-hand side of a pointer assignment of the scope is replaced by the
        pointer target. Names are compared without blanks and case-insensitively
        (Fortran names are case-insensitive). Matching statements are prefixed
        with ``CALL MNH_HIP``. For ALLOCATE, the bounds are also turned into extra
        arguments (``lower:upper`` becomes ``lower,upper``). MODE_MNH_HIPFORT is
        registered through ``addModuleVar`` in every scope where at least one
        variable was found.

        Required for AMD MI250X GPU managed memory usage (e.g., Adastra cluster).

        Examples
        --------
        >>> pft = PYFT('gpu_code.F90')
        >>> pft.allocatetoHIP()
        # ALLOCATE(X) becomes CALL MNH_HIPALLOCATE(X)
        """
        def normalize(name):
            """Return name without blanks and upper-cased, for Fortran name comparisons."""
            return name.replace(' ', '').upper()

        def accDataVariables(scope):
            """Return the normalized names listed in the single-line !$acc enter/exit data
            directives of scope."""
            names = []
            for com in scope.findall('.//{*}C'):
                text = com.text
                if ('!$acc enter data' not in text and '!$acc exit data' not in text) or \
                   text.count('!') != 1:
                    continue
                if text.count('(') == 1:
                    # !$acc enter data copyin( XRRS, XRRS_CLD ) ==> [' XRRS', ' XRRS_CLD ']
                    names.extend(text.split(')')[0].split('(')[1].split(','))
                else:
                    # $acc exit data delete(xb_mg(level,m)%st) ==> xb_mg(level,m)%st
                    match = re.search(r'\(.*\)', text)
                    if match is not None:  # None for a directive without parentheses
                        names.append(match.group(0)[1:-1])
            return [normalize(name) for name in names]

        def resolvePointers(scope, names):
            """Replace every name that is a pointer assigned in scope by its target."""
            pointers = []
            for stmt in scope.findall('.//{*}pointer-a-stmt'):
                lhs = stmt.find('.//{*}E-1/{*}named-E')
                rhs = stmt.find('.//{*}E-2/{*}named-E')
                if lhs is not None and rhs is not None:
                    pointers.append((normalize(alltext(lhs)), normalize(alltext(rhs))))
            resolved = []
            for name in names:
                target = name
                # Pointer assignments are applied in statement order
                for lhs, rhs in pointers:
                    if lhs == target:
                        target = rhs
                resolved.append(target)
            return resolved

        def allocatedName(allocateArg):
            """Return the normalized name under which an (DE)ALLOCATE argument is looked
            up, or None for the unsupported forms (more than two parens-R)."""
            text = alltext(allocateArg)
            if allocateArg.find('.//{*}array-R') is None:
                parensR = allocateArg.findall('.//{*}parens-R')
                if len(parensR) == 2:
                    # Component case such as Tjacobi(level)%r(nz) ==> Tjacobi(level)%r
                    return normalize(text.split('%', maxsplit=1)[0] + '%' +
                                     text.split('%')[1].split('(', maxsplit=1)[0])
                if len(parensR) > 2:
                    return None
                if len(parensR) == 1 and allocateArg.find('.//{*}component-R') is not None:
                    # Tjacobi(level)%Sr case where Sr is a type itself
                    return normalize(text)
            return normalize(text.split('(', maxsplit=1)[0])

        def toHIPCall(stmt, allocateArg):
            """Rewrite stmt in place as a CALL MNH_HIP(DE)ALLOCATE statement."""
            stmt.text = "CALL MNH_HIP" + stmt.text  # For allocate/deallocate statements
            if tag(stmt) != 'allocate-stmt':
                return
            # Replace lower:upper into lower,upper
            for bounds in allocateArg.findall('.//{*}lower-bound'):
                bounds.tail = ','
            keepTail = False
            arrayR = allocateArg.find('.//{*}array-R')
            if arrayR is not None:
                arrayR.text = ','
            else:
                parensR = allocateArg.findall('.//{*}parens-R')
                if len(parensR) > 1:
                    # Particular case of multiple parensR such as Tjacobi(level)%r(nz)
                    parensR[-1].text = ','
                elif parensR and allocateArg.find('.//{*}component-R') is None:
                    parensR[0].text = ','
                else:
                    # Tjacobi(level)%Sr (Sr being a type itself) or no shape at all:
                    # the argument is passed unchanged
                    keepTail = True
            if not keepTail:
                # The shape's opening parenthesis became a comma, so its closing one now
                # ends the call: drop the text following the argument (presumably the
                # statement's own closing parenthesis)
                allocateArg.tail = ''

        for scope in self.getScopes():
            # Empty names would come from directives such as 'copyin( A, )'
            varsToChange = {name for name in resolvePointers(scope, accDataVariables(scope))
                            if name}
            if not varsToChange:
                continue
            scope.addModuleVar([(scope.path, 'MODE_MNH_HIPFORT', None)])

            allocateStmts = scope.findall('.//{*}allocate-stmt')
            allocateStmts.extend(scope.findall('.//{*}deallocate-stmt'))
            for stmt in allocateStmts:
                allocateArg = stmt.find('.//{*}arg-spec')
                if allocatedName(allocateArg) in varsToChange:
                    toHIPCall(stmt, allocateArg)

    @debugDecor
    def addACCData(self):
        """
        Add !$acc data directives for GPU data transfer.

        For each subroutine, inserts:
        - !$acc data present (array1, array2, ...) after declarations
        - !$acc end data at the end of the routine

        Only arrays declared in the subroutine itself and flagged as arguments
        (``var['arg']``) are listed; arrays whose type contains TYPE are skipped.
        Subroutines without such arrays are left unchanged. Nothing is done when
        the name of the first scope starts with MODD. The names are written seven
        per line, using ``!$acc &`` continuation lines.

        Examples
        --------
        >>> pft = PYFT('gpu_code.F90')
        >>> pft.addACCData()
        """
        scopes = self.getScopes()
        if not scopes or scopes[0].path.split('/')[-1].split(':')[1][:4] == 'MODD':
            return
        namesPerLine = 7
        for scope in scopes:
            # Do not add !$acc data directives to :
            # - MODULE or FUNCTION object,
            # - interface subroutine from a MODI
            # but only to SUBROUTINES
            if 'sub:' not in scope.path or 'func' in scope.path or \
               'interface' in scope.path:
                continue
            # Intent arrays of this scope, not of type TYPE (only REAL, INTEGER, CHARACTER)
            arraysIntent = [var['n'] for var in scope.varList
                            if var['arg'] and var['as'] and 'TYPE' not in var['t'] and
                            var['scopePath'] == scope.path]
            if not arraysIntent:
                # 'continue', not 'break': the following subroutines must still be processed
                continue

            # 1) First !$acc data present(); '&' may only end a continued line
            lines = [', '.join(arraysIntent[i:i + namesPerLine])
                     for i in range(0, len(arraysIntent), namesPerLine)]
            accAddMultipleLines = createExpr('!$acc data present ( ' +
                                             ', &\n!$acc & '.join(lines) + ')')
            idx = scope.insertStatement(scope.indent(accAddMultipleLines[0]), first=True)

            # 2) multi-lines !$acc &
            for iLine, line in enumerate(accAddMultipleLines[1:]):
                scope.insert(idx + 1 + iLine, line)

            # 3) !$acc end data
            comment = createElem('C', text='!$acc end data', tail='\n')
            scope.insertStatement(scope.indent(comment), first=False)

    @debugDecor
    def addACCRoutineSeq(self, stopScopes):
        """
        Add !$acc routine seq directive to subroutines.

        Parameters
        ----------
        stopScopes : list of str
            Scope paths. The directive is added to every scope for which
            ``self.tree.isUnderStopScopes`` returns True (stop scopes themselves
            and interfaces included).

        Examples
        --------
        >>> pft = PYFT('gpu_code.F90')
        >>> pft.addACCRoutineSeq(['module:MOD/sub:SUB'])
        # Adds !$acc routine (SUB) seq to SUB subroutine
        """
        for scope in self.getScopes():
            if not self.tree.isUnderStopScopes(scope.path, stopScopes,
                                               includeInterfaces=True,
                                               includeStopScopes=True):
                continue
            header = scope[0]
            name = n2name(header.find('.//{*}N')).upper()
            # The directive is placed right after the first statement of the scope and
            # takes over that statement's trailing text
            acc = createElem('C', text=f'!$acc routine ({name}) seq', tail=header.tail)
            header.tail = '\n'
            scope.insert(1, acc)
# Cross-reference from the original listing: addACCRoutineSeq(self, stopScopes),
# definition at openacc.py:301