245 loopVar=None, reuseLoop=True, funcList=None,
246 updateMemSet=False, updateCopy=False, addAccIndependentCollapse=True):
248 Transform array syntax assignments into explicit DO loops.
250 Converts Fortran array syntax (e.g., A(:) = B(:)) into equivalent DO loop form.
254 concurrent : bool, optional
255 If True, use 'DO CONCURRENT' loops instead of simple 'DO' loops.
257 useMnhExpand : bool, optional
258 If True, respect mnh_expand directives to transform entire blocks
259 into a single loop. Default is True.
260 everywhere : bool, optional
261 If True, transform all array syntax in the code.
262 If False, only transform sections marked with !$mnh_expand directives.
264 loopVar : callable or None, optional
265 Function to determine loop index variable name.
266 Takes arguments: (lowerDecl, upperDecl, lowerUsed, upperUsed, name, index)
267 Returns: str (variable name), True (auto-generate name), or False (skip).
268 None (default) auto-generates variable names (J1, J2, etc.).
269 reuseLoop : bool, optional
270 If True, attempt to reuse loops when consecutive statements
271 have identical bounds. Default is True.
272 funcList : list of str, optional
273 Additional function names to recognize as array functions.
274 These functions will not be expanded. Default is None (empty list).
275 updateMemSet : bool, optional
276 If True, transform constant array initializations (e.g., A(:) = 0)
277 into DO loops. Default is False.
278 updateCopy : bool, optional
279 If True, transform array copy operations (e.g., A(:) = B(:))
280 into DO loops. Default is False.
281 addAccIndependentCollapse : bool, optional
282 If True, add !$acc loop independent collapse(N) directive
283 before DO constructs. Default is True.
289 Transformation Examples
290 ----------------------
297 DO J1 = LBOUND(A, 1), UBOUND(A, 1)
298 A(J1) = B(J1) + C(J1)
302 DO CONCURRENT (J1=LBOUND(A, 1):UBOUND(A, 1))
303 A(J1) = B(J1) + C(J1)
309 WHERE (MASK(:)) X(:) = Y(:)
312 DO J1 = 1, SIZE(X, 1)
313 IF (MASK(J1)) X(J1) = Y(J1)
318 - Only transforms array syntax using explicit ':' notation.
319 - Intrinsic array functions (COUNT, ANY, SUM, etc.) are preserved.
320 - Does not transform:
321 - A=A(:) (no-op on left side)
322 - A(:)=A (single array without slice on right side)
323 - When useMnhExpand=True, requires specific directive format:
324 !$mnh_expand_array(INDEX=bounds)
325 ... code to transform ...
326 !$mnh_end_expand_array(INDEX=bounds)
362 def decode(directive):
364 Decode mnh_expand directive
365 :param directive: mnh directive text
366 :return: (table, kind) where
367 table is a dictionnary: keys are variable names, values are tuples with first
369 kind is 'array' or 'where'
377 table = directive.split(
'(')[1].split(
')')[0].split(
',')
378 table = {c.split(
'=')[0]: c.split(
'=')[1].split(
':')
380 table.pop(
'OPENACC',
None)
381 if directive.lstrip(
' ').startswith(
'!$mnh_expand'):
382 kind = directive[13:].lstrip(
' ').split(
'(')[0].strip()
384 kind = directive[17:].lstrip(
' ').split(
'(')[0].strip()
387 def updateStmt(stmt, table, kind, extraindent, parent, scope):
389 Updates the statement given the table dictionnary '(:, :)' is replaced by '(JI, JK)' if
390 table.keys() is ['JI', 'JK']
391 :param stmt: statement to update
392 :param table: dictionnary retruned by the decode function
393 :param kind: kind of mnh directives: 'array' or 'where'
394 or None if transformation is not governed by
396 :param scope: current scope
399 def addExtra(node, extra):
400 """Helper function to add indentation spaces"""
401 if extra != 0
and (node.tail
is not None)
and '\n' in node.tail:
407 node.tail = re.sub(
r"(\n[ ]*)(\Z|[^\n ]+)",
408 r"\1" + extra *
' ' +
r"\2", node.tail)
410 addExtra(stmt, extraindent)
413 elif tag(stmt) ==
'cpp':
414 i = list(parent).index(stmt)
419 parent[i - 1].tail = parent[i - 1].tail.rstrip(
' ')
420 elif tag(stmt) ==
'a-stmt':
421 sss = stmt.findall(
'./{*}E-1/{*}named-E/{*}R-LT/{*}array-R/' +
422 '{*}section-subscript-LT/{*}section-subscript')
423 if len([ss
for ss
in sss
if ':' in alltext(ss)]) != len(table):
424 raise PYFTError(
"Inside code sections to transform in DO loops, " +
425 "all affectations must use ':'.\n" +
426 "This is not the case in:\n{stmt}".format(stmt=alltext(stmt)))
427 if stmt.find(
'./{*}E-1/{*}named-E/{*}N').tail
is not None and kind
is not None:
428 raise PYFTError(
"To keep the compatibility with the filepp version of loop " +
429 "expansion, nothing must appear between array names and " +
430 "opening parethesis inside mnh directive sections.")
433 for namedE
in stmt.findall(
'.//{*}R-LT/..'):
434 scope.arrayR2parensR(namedE, table)
435 for cnt
in stmt.findall(
'.//{*}cnt'):
436 addExtra(cnt, extraindent)
437 elif tag(stmt) ==
'if-stmt':
439 "An if statement is inside a code section transformed in DO loop in %s",
442 updateStmt(stmt.find(
'./{*}action-stmt')[0], table, kind, 0, stmt, scope)
443 elif tag(stmt) ==
'if-construct':
445 "An if construct is inside a code section transformed in DO loop in %s",
448 for ifBlock
in stmt.findall(
'./{*}if-block'):
449 for child
in ifBlock:
450 if tag(child)
not in (
'if-then-stmt',
'else-if-stmt',
451 'else-stmt',
'end-if-stmt'):
452 updateStmt(child, table, kind, extraindent, ifBlock, scope)
455 addExtra(child, extraindent)
456 for cnt
in child.findall(
'.//{*}cnt'):
458 addExtra(cnt, extraindent)
459 elif tag(stmt) ==
'where-stmt':
461 stmt.tag = f
'{{{NAMESPACE}}}if-stmt'
462 stmt.text =
'IF (' + stmt.text.split(
'(', 1)[1]
464 updateStmt(stmt.find(
'./{*}action-stmt')[0], table, kind,
465 extraindent, stmt, scope)
466 mask = stmt.find(
'./{*}mask-E')
467 mask.tag = f
'{{{NAMESPACE}}}condition-E'
468 for namedE
in mask.findall(
'.//{*}R-LT/..'):
469 scope.arrayR2parensR(namedE, table)
470 for cnt
in stmt.findall(
'.//{*}cnt'):
471 addExtra(cnt, extraindent)
472 elif tag(stmt) ==
'where-construct':
473 if kind !=
'where' and kind
is not None:
474 raise PYFTError(
'To keep the compatibility with the filepp version of loop " + \
475 "expansion, no where construct must appear " + \
476 "in mnh_expand_array blocks.')
478 stmt.tag = f
'{{{NAMESPACE}}}if-construct'
480 for whereBlock
in stmt.findall(
'./{*}where-block'):
481 whereBlock.tag = f
'{{{NAMESPACE}}}if-block'
482 for child
in whereBlock:
483 if tag(child) ==
'end-where-stmt':
485 child.tag = f
'{{{NAMESPACE}}}end-if-stmt'
486 child.text =
'END IF'
488 addExtra(child, extraindent)
489 elif tag(child)
in (
'where-construct-stmt',
'else-where-stmt'):
491 addExtra(child, extraindent)
492 if tag(child) ==
'where-construct-stmt':
494 child.tag = f
'{{{NAMESPACE}}}if-then-stmt'
495 child.text =
'IF (' + child.text.split(
'(', 1)[1]
500 if '(' in child.text:
502 child.tag = f
'{{{NAMESPACE}}}else-if-stmt'
503 child.text =
'ELSE IF (' + child.text.split(
'(', 1)[1]
506 child.tag = f
'{{{NAMESPACE}}}else-stmt'
508 for mask
in child.findall(
'./{*}mask-E'):
510 mask.tag = f
'{{{NAMESPACE}}}condition-E'
512 for namedE
in mask.findall(
'.//{*}R-LT/..'):
514 scope.arrayR2parensR(namedE, table)
515 for cnt
in child.findall(
'.//{*}cnt'):
517 addExtra(cnt, extraindent)
519 updateStmt(child, table, kind, extraindent, whereBlock, scope)
521 raise PYFTError(
'Unexpected tag found in mnh_expand ' +
522 'directives: {t}'.format(t=tag(stmt)))
525 def closeLoop(loopdesc):
526 """Helper function to deal with indentation"""
528 inner, outer, indent, extraindent = loopdesc
529 if inner[-2].tail
is not None:
531 outer.tail = inner[-2].tail[:-extraindent]
532 inner[-2].tail =
'\n' + (indent + extraindent - 2) *
' '
539 def recur(elem, scope):
543 for ie, sElem
in enumerate(list(elem)):
544 if tag(sElem) ==
'C' and sElem.text.lstrip(
' ').startswith(
'!$mnh_expand')
and \
548 raise PYFTError(
'Nested mnh_directives are not allowed')
550 inEverywhere = closeLoop(inEverywhere)
553 table, kind = decode(sElem.text)
555 indent = len(sElem.tail) - len(sElem.tail.rstrip(
' '))
556 toremove.append((elem, sElem))
560 if elem[ie - 1].tail
is None:
561 elem[ie - 1].tail =
''
562 elem[ie - 1].tail += sElem.tail.replace(
'\n',
'', 1).rstrip(
' ')
565 if addAccIndependentCollapse:
566 accCollapse = createElem(
'C', text=
'!$acc loop independent collapse(' +
567 str(len(table.keys())) +
')',
568 tail=
'\n' + indent *
' ')
569 toinsert.append((elem, accCollapse, ie))
572 inner, outer, extraindent = scope.createDoConstruct(table, indent=indent,
573 concurrent=concurrent)
574 toinsert.append((elem, outer, ie))
576 elif (tag(sElem) ==
'C' and
577 sElem.text.lstrip(
' ').startswith(
'!$mnh_end_expand')
and useMnhExpand):
580 raise PYFTError(
'End mnh_directive found before begin directive ' +
581 'in {f}'.format(f=scope.getFileName()))
582 if (table, kind) != decode(sElem.text):
583 raise PYFTError(
"Opening and closing mnh directives must be conform " +
584 "in {f}".format(f=scope.getFileName()))
586 toremove.append((elem, sElem))
590 outer.tail += sElem.tail.replace(
'\n',
'', 1)
592 elem[ie - 1].tail = elem[ie - 1].tail[:-2]
596 toremove.append((elem, sElem))
597 inner.insert(-1, sElem)
599 updateStmt(sElem, table, kind, extraindent, inner, scope)
601 elif everywhere
and tag(sElem)
in (
'a-stmt',
'if-stmt',
'where-stmt',
606 if tag(sElem) ==
'a-stmt':
608 arr = sElem.find(
'./{*}E-1/{*}named-E/{*}R-LT/{*}array-R/../..')
612 nodeE2 = sElem.find(
'./{*}E-2')
614 num = len(nodeE2.findall(
'.//{*}array-R'))
621 elif (len(nodeE2) == 1
and tag(nodeE2[0]) ==
'named-E' and num == 1
and
622 nodeE2[0].find(
'.//{*}parens-R')
is None):
628 if (isMemSet
and not updateMemSet)
or (isCopy
and not updateCopy):
630 elif tag(sElem) ==
'if-stmt':
632 arr = sElem.find(
'./{*}action-stmt/{*}a-stmt/{*}E-1/' +
633 '{*}named-E/{*}R-LT/{*}array-R/../..')
636 scope.changeIfStatementsInIfConstructs(singleItem=sElem)
639 elif tag(sElem) ==
'where-stmt':
640 arr = sElem.find(
'./{*}mask-E//{*}named-E/{*}R-LT/{*}array-R/../..')
641 elif tag(sElem) ==
'where-construct':
642 arr = sElem.find(
'./{*}where-block/{*}where-construct-stmt/' +
643 '{*}mask-E//{*}named-E/{*}R-LT/{*}array-R/../..')
650 elif len(set(alltext(a).count(
':')
651 for a
in sElem.findall(
'.//{*}R-LT/{*}array-R'))) > 1:
655 elif len(set([
'ALL',
'ANY',
'COSHAPE',
'COUNT',
'CSHIFT',
'DIMENSION',
656 'DOT_PRODUCT',
'EOSHIFT',
'LBOUND',
'LCOBOUND',
'MATMUL',
657 'MAXLOC',
'MAXVAL',
'MERGE',
'MINLOC',
'MINVAL',
'PACK',
658 'PRODUCT',
'REDUCE',
'RESHAPE',
'SHAPE',
'SIZE',
'SPREAD',
659 'SUM',
'TRANSPOSE',
'UBOUND',
'UCOBOUND',
'UNPACK'] +
660 (funcList
if funcList
is not None else [])
661 ).intersection(set(n2name(nodeN)
for nodeN
662 in sElem.findall(
'.//{*}named-E/{*}N')))) > 0:
668 newtable, varNew = scope.findArrayBounds(arr, loopVar, newVarList)
671 if var
not in newVarList:
672 newVarList.append(var)
679 inEverywhere = closeLoop(inEverywhere)
682 if not (inEverywhere
and table == newtable):
684 inEverywhere = closeLoop(inEverywhere)
686 if ie != 0
and elem[ie - 1].tail
is not None:
689 tail = tailSave.get(elem[ie - 1], elem[ie - 1].tail)
690 indent = len(tail) - len(tail.rstrip(
' '))
696 inner, outer, extraindent = scope.createDoConstruct(
697 table, indent=indent, concurrent=concurrent)
698 toinsert.append((elem, outer, ie))
699 inEverywhere = (inner, outer, indent, extraindent)
700 tailSave[sElem] = sElem.tail
701 toremove.append((elem, sElem))
702 inner.insert(-1, sElem)
704 updateStmt(sElem, table, kind, extraindent, inner, scope)
707 inEverywhere = closeLoop(inEverywhere)
710 inEverywhere = closeLoop(inEverywhere)
714 inEverywhere = closeLoop(inEverywhere)
716 for scope
in self.getScopes():
719 for elem, outer, ie
in toinsert[::-1]:
720 elem.insert(ie, outer)
722 for parent, elem
in toremove:
725 self.addVar([(v[
'scopePath'], v[
'n'], f
"INTEGER :: {v['n']}",
None)
726 for v
in newVarList])
810 def inline(self, subContained, callStmt, mainScope,
811 simplify=False, loopVar=None):
813 Inline a single contained subroutine at its call site.
815 This method performs the actual inlining of a contained subroutine
816 into the calling scope. It handles:
817 - ELEMENTAL subroutines with array arguments
818 - Optional arguments (PRESENT intrinsic)
819 - Variable name conflicts
820 - USE statement merging
824 subContained : xml element
825 XML fragment corresponding to the contained subroutine scope.
826 callStmt : xml element
827 The call-stmt node to replace with inlined code.
828 mainScope : PYFTscope
829 Scope of the main (calling) subroutine.
830 simplify : bool, optional
831 If True, remove empty constructs and unused variables
832 after inlining. Default is False.
833 loopVar : callable or None, optional
834 Function to determine loop index variable name.
835 Used when inlining ELEMENTAL subroutines called on arrays.
836 Takes: (lowerDecl, upperDecl, lowerUsed, upperUsed, name, index)
837 Returns: str, True (auto-generate), or False (skip).
841 - For ELEMENTAL subroutines on arrays: DO loops are introduced.
842 - Optional arguments: PRESENT(var) is replaced with .TRUE. or .FALSE.
843 - Missing optional arguments: code paths using them are removed.
844 - Name conflicts: local variables are renamed with _N suffixes.
846 def setPRESENTby(node, var, val):
848 Replace PRESENT(var) by .TRUE. if val is True, by .FALSE. otherwise on node
850 :param node: xml node to work on (a contained subroutine)
851 :param var: string of the name of the optional variable to check
853 for namedE
in node.findall(
'.//{*}named-E/{*}N/..'):
854 if n2name(namedE.find(
'./{*}N')).upper() ==
'PRESENT':
855 presentarg = n2name(namedE.find(
'./{*}R-LT/{*}parens-R/{*}element-LT/'
856 '{*}element/{*}named-E/{*}N'))
857 if presentarg.upper() == var.upper():
858 for nnn
in namedE[:]:
860 namedE.tag = f
'{{{NAMESPACE}}}literal-E'
861 namedE.text =
'.TRUE.' if val
else '.FALSE.'
864 parent = mainScope.getParent(callStmt)
867 if tag(parent) ==
'action-stmt':
868 mainScope.changeIfStatementsInIfConstructs(mainScope.getParent(parent))
869 parent = mainScope.getParent(callStmt)
873 prefix = subContained.findall(
'.//{*}prefix')
874 if len(prefix) > 0
and 'ELEMENTAL' in [p.text.upper()
for p
in prefix]:
876 mainScope.addArrayParenthesesInNode(callStmt)
879 mainScope.addExplicitArrayBounds(node=callStmt)
882 arrayRincallStmt = callStmt.findall(
'.//{*}array-R')
883 if len(arrayRincallStmt) > 0:
885 table, _ = mainScope.findArrayBounds(mainScope.getParent(arrayRincallStmt[0], 2),
889 for varName
in table.keys():
890 if not mainScope.varList.findVar(varName):
891 var = {
'as': [],
'asx': [],
892 'n': varName,
'i':
None,
't':
'INTEGER',
'arg':
False,
893 'use':
False,
'opt':
False,
'allocatable':
False,
894 'parameter':
False,
'init':
None,
'scopePath': mainScope.path}
895 mainScope.addVar([[mainScope.path, var[
'n'],
896 mainScope.varSpec2stmt(var),
None]])
899 inner, outer, _ = mainScope.createDoConstruct(table)
902 inner.insert(-1, callStmt)
904 parent.insert(list(parent).index(callStmt), outer)
905 parent.remove(callStmt)
907 for namedE
in callStmt.findall(
'./{*}arg-spec/{*}arg/{*}named-E'):
909 if namedE.find(
'./{*}R-LT'):
910 mainScope.arrayR2parensR(namedE, table)
913 node = copy.deepcopy(subContained)
918 varList = copy.deepcopy(self.varList)
919 for var
in [var
for var
in varList.restrict(subContained.path,
True)
920 if not var[
'arg']
and not var[
'use']]:
922 if varList.restrict(mainScope.path,
True).findVar(var[
'n']):
925 newName = re.sub(
r'_\d+$',
'', var[
'n'])
927 while (varList.restrict(subContained.path,
True).findVar(newName +
'_' + str(i))
or
928 varList.restrict(mainScope.path,
True).findVar(newName +
'_' + str(i))):
930 newName +=
'_' + str(i)
931 node.renameVar(var[
'n'], newName)
932 subst.append((var[
'n'], newName))
935 var[
'scopePath'] = mainScope.path
936 localVarToAdd.append(var)
939 for oldName, newName
in subst:
940 for var
in localVarToAdd + varList[:]:
941 if var[
'as']
is not None:
942 var[
'as'] = [[re.sub(
r'\b' + oldName +
r'\b', newName, dim[i])
943 if dim[i]
is not None else None
945 for dim
in var[
'as']]
952 localUseToAdd = node.findall(
'./{*}use-stmt')
953 for sNode
in node.findall(
'./{*}T-decl-stmt') + localUseToAdd + \
954 node.findall(
'./{*}implicit-none-stmt'):
957 while tag(node[icom]) ==
'C':
958 if node[icom].text.startswith(
'!$acc'):
959 if 'routine' in node[icom].text:
960 node.remove(node[icom])
964 node.remove(node[icom])
972 for argN
in subContained.findall(
'.//{*}subroutine-stmt/{*}dummy-arg-LT/{*}arg-N'):
973 vartable[alltext(argN).upper()] =
None
974 for iarg, arg
in enumerate(callStmt.findall(
'.//{*}arg')):
975 key = arg.find(
'.//{*}arg-N')
978 dummyName = alltext(key).upper()
981 dummyName = list(vartable.keys())[iarg]
983 nodeRLTarray = argnode.findall(
'.//{*}R-LT/{*}array-R')
984 if len(nodeRLTarray) > 0:
986 if len(nodeRLTarray) > 1
or \
987 argnode.find(
'./{*}R-LT/{*}array-R')
is None or \
988 tag(argnode) !=
'named-E':
990 raise PYFTError(
'Argument to complicated: ' + str(alltext(argnode)))
991 dim = nodeRLTarray[0].find(
'./{*}section-subscript-LT')[:]
997 argname =
"".join(argnode.itertext())
1000 tmp = copy.deepcopy(argnode)
1001 nodeRLT = tmp.find(
'./{*}R-LT')
1002 nodeRLT.remove(nodeRLT.find(
'./{*}array-R'))
1003 argname =
"".join(tmp.itertext())
1004 vartable[dummyName] = {
'node': argnode,
'name': argname,
'dim': dim}
1007 for dummyName
in [dummyName
for (dummyName, value)
in vartable.items()
1008 if value
is not None]:
1009 setPRESENTby(node, dummyName,
True)
1010 for dummyName
in [dummyName
for (dummyName, value)
in vartable.items()
1012 setPRESENTby(node, dummyName,
False)
1015 for dummyName
in [dummyName
for (dummyName, value)
in vartable.items()
if value
is None]:
1016 for nodeN
in [nodeN
for nodeN
in node.findall(
'.//{*}named-E/{*}N')
1017 if n2name(nodeN).upper() == dummyName]:
1020 par = node.getParent(nodeN, level=2)
1021 allreadySuppressed = []
1022 while par
and not removed
and par
not in allreadySuppressed:
1025 if tagName
in (
'a-stmt',
'print-stmt'):
1028 elif tagName ==
'call-stmt':
1033 raise NotImplementedError(
'call-stmt not (yet?) implemented')
1034 elif tagName
in (
'if-stmt',
'where-stmt'):
1037 elif tagName
in (
'if-then-stmt',
'else-if-stmt',
'where-construct-stmt',
1041 toSuppress = node.getParent(par, 2)
1042 elif tagName
in (
'select-case-stmt',
'case-stmt'):
1045 toSuppress = node.getParent(par, 2)
1046 elif tagName.endswith(
'-block')
or tagName.endswith(
'-stmt')
or \
1047 tagName.endswith(
'-construct'):
1054 raise PYFTError((
"We shouldn't be here. A case may have been " +
1055 "overlooked (tag={tag}).".format(tag=tagName)))
1056 if toSuppress
is not None:
1058 if toSuppress
not in allreadySuppressed:
1061 node.removeStmtNode(toSuppress,
False, simplify)
1062 allreadySuppressed.extend(list(toSuppress.iter()))
1064 par = node.getParent(par)
1067 for name, dummy
in vartable.items():
1071 for namedE
in [namedE
for namedE
in node.findall(
'.//{*}named-E/{*}N/{*}n/../..')
1072 if n2name(namedE.find(
'{*}N')).upper() == name]:
1074 nodeN = namedE.find(
'./{*}N')
1075 ns = nodeN.findall(
'./{*}n')
1076 ns[0].text = n2name(nodeN)
1082 descMain = varList.restrict(mainScope.path,
True).findVar(dummy[
'name'])
1083 descSub = varList.restrict(subContained.path,
True).findVar(name)
1087 if var[
'as']
is not None:
1088 var[
'as'] = [[re.sub(
r'\b' + name +
r'\b', dummy[
'name'], dim[i])
1089 if dim[i]
is not None else None
1091 for dim
in var[
'as']]
1096 nodeRLT = namedE.find(
'./{*}R-LT')
1097 if nodeRLT
is not None and tag(nodeRLT[0]) !=
'component-R':
1099 assert tag(nodeRLT[0])
in (
'array-R',
'parens-R'),
'Internal error'
1100 slices = nodeRLT[0].findall(
'./{*}section-subscript-LT/' +
1101 '{*}section-subscript')
1102 slices += nodeRLT[0].findall(
'./{*}element-LT/{*}element')
1105 if (descMain
is not None and descMain[
'as']
is not None and
1106 len(descMain[
'as']) > 0)
or \
1107 len(descSub[
'as']) > 0
or dummy[
'dim']
is not None:
1109 if len(descSub[
'as']) > 0:
1110 ndim = len(descSub[
'as'])
1113 if dummy[
'dim']
is not None:
1116 ndim = len([d
for d
in dummy[
'dim']
if ':' in alltext(d)])
1119 ndim = len(descMain[
'as'])
1120 ns[0].text +=
'(' + (
', '.join([
':'] * ndim)) +
')'
1121 updatedNamedE = createExprPart(alltext(namedE))
1122 namedE.tag = updatedNamedE.tag
1123 namedE.text = updatedNamedE.text
1124 for nnn
in namedE[:]:
1126 namedE.extend(updatedNamedE[:])
1127 slices = namedE.find(
'./{*}R-LT')[0].findall(
1128 './{*}section-subscript-LT/{*}section-subscript')
1134 namedE.find(
'./{*}N')[0].text = dummy[
'name']
1142 for isl, sl
in enumerate(slices):
1144 if len(descSub[
'as']) == 0
or descSub[
'as'][isl][1]
is None:
1147 if dummy[
'dim']
is not None:
1149 tagName =
'./{*}lower-bound' if i == 0
else './{*}upper-bound'
1150 descSub[i] = dummy[
'dim'][isl].find(tagName)
1153 if descSub[i]
is not None:
1155 descSub[i] = alltext(descSub[i])
1158 if descMain
is not None and descMain[
'as'][isl][1]
is not None:
1160 descSub[i] = descMain[
'as'][isl][i]
1161 if i == 0
and descSub[i]
is None:
1164 descSub[i] =
"L" if i == 0
else "U"
1165 descSub[i] +=
"BOUND({name}, {isl})".format(
1166 name=dummy[
'name'], isl=isl + 1)
1168 descSub[0] = descSub[
'as'][isl][0]
1169 if descSub[0]
is None:
1171 descSub[1] = descSub[
'as'][isl][1]
1183 if dummy[
'dim']
is not None and \
1184 not alltext(dummy[
'dim'][isl]).strip().startswith(
':'):
1185 offset = alltext(dummy[
'dim'][isl].find(
'./{*}lower-bound'))
1187 if descMain
is not None and descMain[
'as']
is not None:
1188 offset = descMain[
'as'][isl][0]
1191 elif offset.strip().startswith(
'-'):
1192 offset =
'(' + offset +
')'
1194 offset =
"LBOUND({name}, {isl})".format(
1195 name=dummy[
'name'], isl=isl + 1)
1196 if offset.upper() == descSub[0].upper():
1199 if descSub[0].strip().startswith(
'-'):
1200 offset +=
'- (' + descSub[0] +
')'
1202 offset +=
'-' + descSub[0]
1206 if tag(sl) ==
'element' or \
1207 (tag(sl) ==
'section-subscript' and ':' not in alltext(sl)):
1211 low = sl.find(
'./{*}lower-bound')
1213 low = createElem(
'lower-bound', tail=sl.text)
1214 low.append(createExprPart(descSub[0]))
1217 up = sl.find(
'./{*}upper-bound')
1219 up = createElem(
'upper-bound')
1220 up.append(createExprPart(descSub[1]))
1223 for bound
in bounds:
1225 if bound[-1].tail
is None:
1227 bound[-1].tail +=
'+' + offset
1233 if dummy[
'dim']
is not None and len(dummy[
'dim']) > len(slices):
1234 slices[-1].tail =
', '
1235 par = node.getParent(slices[-1])
1236 par.extend(dummy[
'dim'][len(slices):])
1241 updatedNamedE = createExprPart(alltext(namedE))
1242 namedE.tag = updatedNamedE.tag
1243 namedE.text = updatedNamedE.text
1244 for nnn
in namedE[:]:
1246 namedE.extend(updatedNamedE[:])
1248 node.remove(node.find(
'./{*}subroutine-stmt'))
1249 node.remove(node.find(
'./{*}end-subroutine-stmt'))
1252 mainScope.addVar([[mainScope.path, var[
'n'], mainScope.varSpec2stmt(var),
None]
1253 for var
in localVarToAdd])
1254 mainScope.addModuleVar([[mainScope.path, n2name(useStmt.find(
'.//{*}module-N//{*}N')),
1255 [n2name(v.find(
'.//{*}N'))
1256 for v
in useStmt.findall(
'.//{*}use-N')]]
1257 for useStmt
in localUseToAdd])
1260 index = list(parent).index(callStmt)
1261 parent.remove(callStmt)
1262 if callStmt.tail
is not None:
1263 if node[-1].tail
is None:
1264 node[-1].tail = callStmt.tail
1266 node[-1].tail = node[-1].tail + callStmt.tail
1267 for node
in node[::-1]:
1269 parent.insert(index, node)
1449 Remove statement nodes with optional code simplification.
1453 nodes : xml element or list of xml elements
1454 Node(s) to remove from the code tree.
1456 If True, also remove variables that become unused after
1457 the deletion of the nodes.
1458 simplifyStruct : bool
1459 If True, also remove empty enclosing constructs
1460 (IF blocks, loops) that become empty after node removal.
1464 >>> pft = PYFT('input.F90')
1465 >>> nodes = pft.findall('.//{*}call-stmt')
1466 >>> pft.removeStmtNode(nodes, simplifyVar=True, simplifyStruct=True)
1470 - Handles nested structures (removes inner statements first).
1471 - When simplifyStruct=True:
1472 - Empty IF blocks are removed
1473 - Empty loops are removed
1474 - WHERE constructs are handled
1475 - When simplifyVar=True:
1476 - Unused local variables are removed
1477 - Empty type declarations are cleaned up
1482 nodesToSuppress = []
1483 if not isinstance(nodes, list):
1486 if tag(node)
in (
'if-stmt',
'where-stmt'):
1487 action = node.find(
'./{*}action-stmt')
1488 if action
is not None and len(action) != 0:
1489 nodesToSuppress.append(action[0])
1491 nodesToSuppress.append(node)
1492 elif tag(node) ==
'}action-stmt':
1494 nodesToSuppress.append(node[0])
1496 nodesToSuppress.append(node)
1498 nodesToSuppress.append(node)
1503 for node
in nodesToSuppress:
1504 scopePath = self.getScopePath(node)
1505 if tag(node) ==
'do-construct':
1507 varToCheck.extend([(scopePath, n2name(arg))
1508 for arg
in node.find(
'./{*}do-stmt').findall(
'.//{*}N')])
1509 elif tag(node)
in (
'if-construct',
'if-stmt'):
1511 varToCheck.extend([(scopePath, n2name(arg))
1512 for arg
in node.findall(
'.//{*}condition-E//{*}N')])
1513 elif tag(node)
in (
'where-construct',
'where-stmt'):
1515 varToCheck.extend([(scopePath, n2name(arg))
1516 for arg
in node.findall(
'.//{*}mask-E//{*}N')])
1517 elif tag(node) ==
'call-stmt':
1519 varToCheck.extend([(scopePath, n2name(arg))
1520 for arg
in node.findall(
'./{*}arg-spec//{*}N')])
1522 varToCheck.append((scopePath,
1523 n2name(node.find(
'./{*}procedure-designator//{*}N'))))
1524 elif tag(node)
in (
'a-stmt',
'print-stmt'):
1525 varToCheck.extend([(scopePath, n2name(arg))
for arg
in node.findall(
'.//{*}N')])
1526 elif tag(node) ==
'selectcase-construct':
1528 varToCheck.extend([(scopePath, n2name(arg))
1529 for arg
in node.findall(
'.//{*}case-E//{*}N')])
1530 varToCheck.extend([(scopePath, n2name(arg))
1531 for arg
in node.findall(
'.//{*}case-value//{*}N')])
1535 for node
in nodesToSuppress:
1536 parent = self.getParent(node)
1537 parents[id(node)] = parent
1538 newlines =
'\n' * (alltext(node).count(
'\n')
if tag(node).endswith(
'-construct')
else 0)
1539 if node.tail
is not None or len(newlines) > 0:
1540 previous = self.getSiblings(node, after=
False)
1541 if len(previous) == 0:
1544 previous = previous[-1]
1545 if previous.tail
is None:
1547 previous.tail = (previous.tail.replace(
'\n',
'') +
1548 (node.tail
if node.tail
is not None else ''))
1552 self.removeVarIfUnused(varToCheck, excludeDummy=
True,
1553 excludeModule=
True, simplify=simplifyVar)
1556 newNodesToSuppress = []
1557 for node
in nodesToSuppress:
1558 parent = parents[id(node)]
1561 if tag(parent) ==
'action-stmt':
1562 newNodesToSuppress.append(self.getParent(parent))
1564 elif simplifyStruct:
1565 if tag(parent) ==
'do-construct' and len(
_nodesInDo(parent)) == 0:
1566 newNodesToSuppress.append(parent)
1567 elif tag(parent) ==
'if-block':
1568 parPar = self.getParent(parent)
1570 newNodesToSuppress.append(parPar)
1571 elif tag(parent) ==
'where-block':
1572 parPar = self.getParent(parent)
1574 newNodesToSuppress.append(parPar)
1575 elif tag(parent) ==
'selectcase-block':
1576 parPar = self.getParent(parent)
1578 newNodesToSuppress.append(parPar)
1580 constructNodes, otherNodes = [], []
1581 for nnn
in newNodesToSuppress:
1582 if tag(nnn).endswith(
'-construct'):
1583 if nnn
not in constructNodes:
1584 constructNodes.append(nnn)
1586 if nnn
not in otherNodes:
1587 otherNodes.append(nnn)
1589 if len(otherNodes) > 0:
1592 for nnn
in constructNodes:
1649 :param loopVariables: ordered dictionnary with loop variables as key and bounds as values.
1650 Bounds are expressed with a 2-tuple.
1651 Keys must be in the same order as the order used when addressing an
1652 element: if loopVariables.keys is [JI, JK], arrays are
1653 addressed with (JI, JK)
1654 :param indent: current indentation
1655 :param concurrent: if False, output is made of nested 'DO' loops
1656 if True, output is made of a single 'DO CONCURRENT' loop
1657 :return: (inner, outer, extraindent) with
1658 - inner the inner do-construct where statements must be added
1659 - outer the outer do-construct to be inserted somewhere
1660 - extraindent the number of added indentation
1661 (2 if concurrent else 2*len(loopVariables))
1687 for var, (lo, up)
in list(loopVariables.items())[::-1]:
1688 nodeV = createElem(
'V', tail=
'=')
1689 nodeV.append(createExprPart(var))
1690 lower, upper = createArrayBounds(lo, up,
'DOCONCURRENT')
1692 triplet = createElem(
'forall-triplet-spec')
1693 triplet.extend([nodeV, lower, upper])
1695 triplets.append(triplet)
1697 tripletLT = createElem(
'forall-triplet-spec-LT', tail=
')')
1698 for triplet
in triplets[:-1]:
1700 tripletLT.extend(triplets)
1702 dostmt = createElem(
'do-stmt', text=
'DO CONCURRENT (', tail=
'\n')
1703 dostmt.append(tripletLT)
1704 enddostmt = createElem(
'end-do-stmt', text=
'END DO')
1706 doconstruct = createElem(
'do-construct', tail=
'\n')
1707 doconstruct.extend([dostmt, enddostmt])
1708 inner = outer = doconstruct
1709 doconstruct[0].tail += (indent + 2) *
' '
1720 def makeDo(var, lo, up):
1721 doV = createElem(
'do-V', tail=
'=')
1722 doV.append(createExprPart(var))
1723 lower, upper = createArrayBounds(lo, up,
'DO')
1725 dostmt = createElem(
'do-stmt', text=
'DO ', tail=
'\n')
1726 dostmt.extend([doV, lower, upper])
1728 enddostmt = createElem(
'end-do-stmt', text=
'END DO')
1730 doconstruct = createElem(
'do-construct', tail=
'\n')
1731 doconstruct.extend([dostmt, enddostmt])
1736 for i, (var, (lo, up))
in enumerate(list(loopVariables.items())[::-1]):
1737 doconstruct = makeDo(var, lo, up)
1739 doconstruct[0].tail += (indent + 2 * i + 2) *
' '
1744 inner.insert(1, doconstruct)
1747 doconstruct.tail += (indent + 2 * i - 2) *
' '
1748 return inner, outer, 2
if concurrent
else 2 * len(loopVariables)
1773 :param item: item to remove from list
1774 :param itemPar: the parent of item (the list)
1777 nodesToSuppress = [item]
1780 i = list(itemPar).index(item)
1781 if item.tail
is not None and ',' in item.tail:
1784 item.tail = tail.replace(
',',
'')
1785 elif i != 0
and ',' in itemPar[i - 1].tail:
1787 tail = itemPar[i - 1].tail
1788 itemPar[i - 1].tail = tail.replace(
',',
'')
1794 while j < len(itemPar)
and not found:
1795 if nonCode(itemPar[j]):
1797 if itemPar[j].tail
is not None and ',' in itemPar[j].tail:
1800 tail = itemPar[j].tail
1801 itemPar[j].tail = tail.replace(
',',
'')
1811 while j >= 0
and not found:
1812 if itemPar[j].tail
is not None and ',' in itemPar[j].tail:
1815 tail = itemPar[j].tail
1816 itemPar[j].tail = tail.replace(
',',
'')
1818 if nonCode(itemPar[j]):
1826 len([e
for e
in itemPar
if not nonCode(e)]) != 1:
1827 raise RuntimeError(
"Something went wrong here....")
1830 if i + 1 < len(itemPar)
and tag(itemPar[i + 1]) ==
'cnt':
1832 reason =
'lastOnLine'
1844 elif len([itemPar[j]
for j
in range(i + 1, len(itemPar))
if not nonCode(itemPar[j])]) == 0:
1858 if reason
is not None:
1859 def _getPrecedingCnt(itemPar, i):
1861 Return the index of the preceding node which is a continuation character
1862 :param itemPar: the list containig the node to suppress
1863 :param i: the index of the current node, i-1 is the starting index for the search
1864 :return: a tuple with three elements:
1865 - the node containing the preceding '&' character
1866 - the parent of the node containing the preceding '&' character
1867 - index of the preceding '&' in the parent (previsous element
1870 - In the general case the preceding '&' belongs to the same list:
1871 USE MODD, ONLY: X, &
1873 - But it exists a special case, where the preceding '&' don't belong to
1874 the same list (in the following example, both '&' are attached to the parent):
1879 while j >= 0
and tag(itemPar[j]) ==
'C':
1881 if j >= 0
and tag(itemPar[j]) ==
'cnt':
1882 return itemPar[j], itemPar, j
1888 siblings = self.getSiblings(itemPar, before=
True, after=
False)
1889 j2 = len(siblings) - 1
1890 while j2 >= 0
and tag(siblings[j2]) ==
'C':
1892 if j2 >= 0
and tag(siblings[j2]) ==
'cnt':
1893 return siblings[j2], siblings, j2
1894 return None,
None,
None
1896 precCnt, newl, j = _getPrecedingCnt(itemPar, i)
1897 if precCnt
is not None:
1899 nodesToSuppress.append(precCnt
if reason ==
'last' else itemPar[i + 1])
1901 precCnt2, _, _ = _getPrecedingCnt(newl, j)
1902 if precCnt2
is not None:
1904 nodesToSuppress.append(precCnt2
if reason ==
'last' else precCnt)
1907 for node
in nodesToSuppress:
1914 parent = self.getParent(itemPar)
1915 i = list(parent).index(node)
1916 if i != 0
and node.tail
is not None:
1917 if parent[i - 1].tail
is None:
1918 parent[i - 1].tail =
''
1919 parent[i - 1].tail = parent[i - 1].tail + node.tail