35 Remove MNH OpenACC bypass macros for non-Cray compilers.
37 Removes the following directives:
39 - !$mnh_undef(OPENACC)
41 - !$mnh_define(OPENACC)
# Strip the MNH bypass macro comments one directive at a time, in the same
# order as before: undef(LOOP), undef(OPENACC), define(LOOP), define(OPENACC).
bypassMacroRegexes = (
    r'^\!\$mnh_undef\(LOOP\)',
    r'^\!\$mnh_undef\(OPENACC\)',
    r'^\!\$mnh_define\(LOOP\)',
    r'^\!\$mnh_define\(OPENACC\)',
)
for macroRegex in bypassMacroRegexes:
    # A fresh empty list is passed as exclDirectives on every call.
    self.removeComments(exclDirectives=[], pattern=re.compile(macroRegex))
55 Handle CRAY compiler vectorization issues with GPU directives.
57 On compute kernels with !$acc loop independent collapse(X) directives:
58 - If BR_ functions are used: removes !$acc loop collapse and uses DO CONCURRENT
59 - If !$mnh_undef(OPENACC) is present: removes !$acc loop collapse
60 - If !$mnh_undef(LOOP) is present: converts nested loops to DO CONCURRENT
64 Works around CRAY compiler vectorization bugs with BR_ (bit-reproducibility)
65 functions and OpenACC directives.
67 def checkPresenceofBR(node):
68 """Return True if a BR_ (math BIT-REPRODUCIBILITY) function is present in the node"""
69 mathBRList = [
'ALOG',
'LOG',
'EXP',
'COS',
'SIN',
'ASIN',
'ATAN',
'ATAN2',
71 namedE = node.findall(
'.//{*}named-E/{*}N/{*}n')
73 if alltext(el)
in [
'BR_' + e
for e
in mathBRList]:
77 def getStatementsInDoConstruct(node, savelist):
79 if 'do-construct' in tag(sNode):
80 getStatementsInDoConstruct(sNode, savelist)
81 elif 'do-stmt' not in tag(sNode):
82 savelist.append(sNode)
85 useAccLoopIndependent =
True
87 comments = self.findall(
'.//{*}C')
89 for comment
in comments:
90 if comment.text.startswith(
'!$mnh_undef(LOOP)'):
91 useNestedLoops =
False
92 if comment.text.startswith(
'!$mnh_undef(OPENACC)'):
93 useAccLoopIndependent =
False
95 if comment.text.startswith(
'!$mnh_define(LOOP)'):
97 if comment.text.startswith(
'!$mnh_define(OPENACC)'):
98 useAccLoopIndependent =
True
100 if comment.text.startswith(
'!$acc loop independent collapse('):
102 par = self.getParent(comment)
103 ind = list(par).index(comment)
104 nestedLoop = par[ind + 1]
106 getStatementsInDoConstruct(nestedLoop, statements)
110 for stmt
in statements:
111 if checkPresenceofBR(stmt):
117 if not useAccLoopIndependent
or isBRPresent:
118 toremove.append((self, comment))
121 if not useNestedLoops
or isBRPresent:
123 doStmt = nestedLoop.findall(
'.//{*}do-stmt')
125 for do
in reversed(doStmt):
126 table[do.find(
'.//{*}do-V/{*}named-E/{*}N/{*}n').text] = \
127 [alltext(do.find(
'.//{*}lower-bound')),
128 alltext(do.find(
'.//{*}upper-bound'))]
131 inner, outer, _ = self.createDoConstruct(table,
132 indent=len(nestedLoop.tail),
136 for stmt
in statements:
137 inner.insert(-1, stmt)
140 par.insert(ind, outer)
141 toremove.append((par, nestedLoop))
144 for parent, elem
in toremove:
150 Convert ALLOCATE/DEALLOCATE to HIP versions for AMD GPUs.
152 Transforms memory allocation calls for variables that are:
153 - Sent to GPU via !$acc enter data copyin
154 - Created on GPU via !$acc enter data create
156 Required for AMD MI250X GPU managed memory usage (e.g., Adastra cluster).
160 >>> pft = PYFT('gpu_code.F90')
161 >>> pft.allocatetoHIP()
162 # ALLOCATE(X) becomes CALL MNH_HIPALLOCATE(X)
164 scopes = self.getScopes()
167 comments = scope.findall(
'.//{*}C')
168 pointers = scope.findall(
'.//{*}pointer-a-stmt')
170 for coms
in comments:
171 if (
'!$acc enter data' in coms.text
or '!$acc exit data' in coms.text) \
172 and coms.text.count(
'!') == 1:
173 if coms.text.count(
'(') == 1:
175 varsToChange.extend(coms.text.split(
')')[0].split(
'(')[1:][0].split(
','))
178 variableName = re.search(
r'\(.*\)', coms.text).group(0)[1:-1]
179 varsToChange.insert(0, variableName)
181 if len(pointers) > 0:
182 for i, var
in enumerate(varsToChange):
183 for pointer
in pointers:
184 if alltext(pointer.find(
'.//{*}E-1/{*}named-E')) == var:
185 varsToChange[i] = alltext(pointer.find(
'.//{*}E-2/{*}named-E'))
187 for i, var
in enumerate(varsToChange):
188 varsToChange[i] = var.replace(
' ',
'')
190 if len(varsToChange) > 0:
191 scope.addModuleVar([(scope.path,
'MODE_MNH_HIPFORT',
None)])
193 allocateStmts = scope.findall(
'.//{*}allocate-stmt')
194 allocateStmts.extend(scope.findall(
'.//{*}deallocate-stmt'))
196 for stmt
in allocateStmts:
198 allocateArg = stmt.find(
'.//{*}arg-spec')
199 arrayR = allocateArg.find(
'.//{*}array-R')
202 parensR = allocateArg.findall(
'.//{*}parens-R')
203 if len(parensR) == 2:
205 varsChecking = alltext(allocateArg).split(
'%', maxsplit=1)[0]+
'%' + \
206 alltext(allocateArg).split(
'%')[1].split(
'(', maxsplit=1)[0]
207 elif len(parensR) == 1:
209 if allocateArg.find(
'.//{*}component-R'):
210 varsChecking = alltext(allocateArg)
212 varsChecking = alltext(allocateArg).split(
'(', maxsplit=1)[0]
213 elif len(parensR) == 0:
214 varsChecking = alltext(allocateArg).split(
'(', maxsplit=1)[0]
216 varsChecking = alltext(allocateArg).split(
'(', maxsplit=1)[0]
218 if varsChecking
in varsToChange:
219 removeLastComma =
True
220 stmt.text =
"CALL MNH_HIP" + stmt.text
221 if tag(stmt) ==
'allocate-stmt':
223 lowerBounds = allocateArg.findall(
'.//{*}lower-bound')
224 if lowerBounds
is not None:
225 for bounds
in lowerBounds:
227 arrayR = allocateArg.find(
'.//{*}array-R')
230 parensR = allocateArg.findall(
'.//{*}parens-R')
232 parensR[-1].text =
','
234 if allocateArg.find(
'.//{*}component-R'):
235 removeLastComma =
False
237 parensR[0].text =
','
241 allocateArg.tail =
''
246 Add !$acc data directives for GPU data transfer.
248 For each subroutine, inserts:
249 - !$acc data present (array1, array2, ...) after declarations
250 - !$acc end data at the end of the routine
252 Only affects INTENT arrays (IN, OUT, INOUT).
256 >>> pft = PYFT('gpu_code.F90')
259 scopes = self.getScopes()
260 if scopes[0].path.split(
'/')[-1].split(
':')[1][:4] ==
'MODD':
267 if 'sub:' in scope.path
and 'func' not in scope.path
and 'interface' not in scope.path:
270 for var
in scope.varList:
272 if var[
'arg']
and var[
'as']
and 'TYPE' not in var[
't']
and \
273 var[
'scopePath'] == scope.path:
274 arraysIntent.append(var[
'n'])
276 if len(arraysIntent) == 0:
280 listVar =
"!$acc data present ( "
282 for var
in arraysIntent:
284 listVar = listVar +
'\n!$acc & '
286 listVar = listVar + var +
", &"
288 listVarEnd = listVar[:-3]
289 accAddMultipleLines = createExpr(listVarEnd +
')')
290 idx = scope.insertStatement(scope.indent(accAddMultipleLines[0]), first=
True)
293 for iLine, line
in enumerate(accAddMultipleLines[1:]):
294 scope.insert(idx + 1 + iLine, line)
297 comment = createElem(
'C', text=
'!$acc end data', tail=
'\n')
298 scope.insertStatement(scope.indent(comment), first=
False)