245 loopVar=None, reuseLoop=True, funcList=None,
246 updateMemSet=False, updateCopy=False, addAccIndependentCollapse=True):
248 Transform array syntax assignments into explicit DO loops.
250 Converts Fortran array syntax (e.g., A(:) = B(:)) into equivalent DO loop form.
254 concurrent : bool, optional
255 If True, use 'DO CONCURRENT' loops instead of simple 'DO' loops.
257 useMnhExpand : bool, optional
258 If True, respect mnh_expand directives to transform entire blocks
259 into a single loop. Default is True.
260 everywhere : bool, optional
261 If True, transform all array syntax in the code.
262 If False, only transform sections marked with !$mnh_expand directives.
264 loopVar : callable or None, optional
265 Function to determine loop index variable name.
266 Takes arguments: (lowerDecl, upperDecl, lowerUsed, upperUsed, name, index)
267 Returns: str (variable name), True (auto-generate name), or False (skip).
268 None (default) auto-generates variable names (J1, J2, etc.).
269 reuseLoop : bool, optional
270 If True, attempt to reuse loops when consecutive statements
271 have identical bounds. Default is True.
272 funcList : list of str, optional
273 Additional function names to recognize as array functions.
274 These functions will not be expanded. Default is None (empty list).
275 updateMemSet : bool, optional
276 If True, transform constant array initializations (e.g., A(:) = 0)
277 into DO loops. Default is False.
278 updateCopy : bool, optional
279 If True, transform array copy operations (e.g., A(:) = B(:))
280 into DO loops. Default is False.
281 addAccIndependentCollapse : bool, optional
282 If True, add !$acc loop independent collapse(N) directive
283 before DO constructs. Default is True.
289 Transformation Examples
290 ----------------------
297 DO J1 = LBOUND(A, 1), UBOUND(A, 1)
298 A(J1) = B(J1) + C(J1)
302 DO CONCURRENT (J1=LBOUND(A, 1):UBOUND(A, 1))
303 A(J1) = B(J1) + C(J1)
309 WHERE (MASK(:)) X(:) = Y(:)
312 DO J1 = 1, SIZE(X, 1)
313 IF (MASK(J1)) X(J1) = Y(J1)
318 - Only transforms array syntax using explicit ':' notation.
319 - Intrinsic array functions (COUNT, ANY, SUM, etc.) are preserved.
320 - Does not transform:
321 - A=A(:) (no-op on left side)
322 - A(:)=A (single array without slice on right side)
323 - When useMnhExpand=True, requires specific directive format:
324 !$mnh_expand_array(INDEX=bounds)
325 ... code to transform ...
326 !$mnh_end_expand_array(INDEX=bounds)
362 def decode(directive):
364 Decode mnh_expand directive
365 :param directive: mnh directive text
366 :return: (table, kind) where
367 table is a dictionnary: keys are variable names, values are tuples with first
369 kind is 'array' or 'where'
377 table = directive.split(
'(')[1].split(
')')[0].split(
',')
378 table = {c.split(
'=')[0]: c.split(
'=')[1].split(
':')
380 table.pop(
'OPENACC',
None)
381 if directive.lstrip(
' ').startswith(
'!$mnh_expand'):
382 kind = directive[13:].lstrip(
' ').split(
'(')[0].strip()
384 kind = directive[17:].lstrip(
' ').split(
'(')[0].strip()
387 def updateStmt(stmt, table, kind, extraindent, parent, scope):
389 Updates the statement given the table dictionnary '(:, :)' is replaced by '(JI, JK)' if
390 table.keys() is ['JI', 'JK']
391 :param stmt: statement to update
392 :param table: dictionnary retruned by the decode function
393 :param kind: kind of mnh directives: 'array' or 'where'
394 or None if transformation is not governed by
396 :param scope: current scope
399 def addExtra(node, extra):
400 """Helper function to add indentation spaces"""
401 if extra != 0
and (node.tail
is not None)
and '\n' in node.tail:
407 node.tail = re.sub(
r"(\n[ ]*)(\Z|[^\n ]+)",
408 r"\1" + extra *
' ' +
r"\2", node.tail)
410 addExtra(stmt, extraindent)
413 elif tag(stmt) ==
'cpp':
414 i = list(parent).index(stmt)
419 parent[i - 1].tail = parent[i - 1].tail.rstrip(
' ')
420 elif tag(stmt) ==
'a-stmt':
421 sss = stmt.findall(
'./{*}E-1/{*}named-E/{*}R-LT/{*}array-R/' +
422 '{*}section-subscript-LT/{*}section-subscript')
423 if len([ss
for ss
in sss
if ':' in alltext(ss)]) != len(table):
424 raise PYFTError(
"Inside code sections to transform in DO loops, " +
425 "all affectations must use ':'.\n" +
426 "This is not the case in:\n{stmt}".format(stmt=alltext(stmt)))
427 if stmt.find(
'./{*}E-1/{*}named-E/{*}N').tail
is not None and kind
is not None:
428 raise PYFTError(
"To keep the compatibility with the filepp version of loop " +
429 "expansion, nothing must appear between array names and " +
430 "opening parethesis inside mnh directive sections.")
433 for namedE
in stmt.findall(
'.//{*}R-LT/..'):
434 scope.arrayR2parensR(namedE, table)
435 for cnt
in stmt.findall(
'.//{*}cnt'):
436 addExtra(cnt, extraindent)
437 elif tag(stmt) ==
'if-stmt':
439 "An if statement is inside a code section transformed in DO loop in %s",
442 updateStmt(stmt.find(
'./{*}action-stmt')[0], table, kind, 0, stmt, scope)
443 elif tag(stmt) ==
'if-construct':
445 "An if construct is inside a code section transformed in DO loop in %s",
448 for ifBlock
in stmt.findall(
'./{*}if-block'):
449 for child
in ifBlock:
450 if tag(child)
not in (
'if-then-stmt',
'else-if-stmt',
451 'else-stmt',
'end-if-stmt'):
452 updateStmt(child, table, kind, extraindent, ifBlock, scope)
455 addExtra(child, extraindent)
456 for cnt
in child.findall(
'.//{*}cnt'):
458 addExtra(cnt, extraindent)
459 elif tag(stmt) ==
'where-stmt':
461 stmt.tag = f
'{{{NAMESPACE}}}if-stmt'
462 stmt.text =
'IF (' + stmt.text.split(
'(', 1)[1]
464 updateStmt(stmt.find(
'./{*}action-stmt')[0], table, kind,
465 extraindent, stmt, scope)
466 mask = stmt.find(
'./{*}mask-E')
467 mask.tag = f
'{{{NAMESPACE}}}condition-E'
468 for namedE
in mask.findall(
'.//{*}R-LT/..'):
469 scope.arrayR2parensR(namedE, table)
470 for cnt
in stmt.findall(
'.//{*}cnt'):
471 addExtra(cnt, extraindent)
472 elif tag(stmt) ==
'where-construct':
473 if kind !=
'where' and kind
is not None:
474 raise PYFTError(
'To keep the compatibility with the filepp version of loop " + \
475 "expansion, no where construct must appear " + \
476 "in mnh_expand_array blocks.')
478 stmt.tag = f
'{{{NAMESPACE}}}if-construct'
480 for whereBlock
in stmt.findall(
'./{*}where-block'):
481 whereBlock.tag = f
'{{{NAMESPACE}}}if-block'
482 for child
in whereBlock:
483 if tag(child) ==
'end-where-stmt':
485 child.tag = f
'{{{NAMESPACE}}}end-if-stmt'
486 child.text =
'END IF'
488 addExtra(child, extraindent)
489 elif tag(child)
in (
'where-construct-stmt',
'else-where-stmt'):
491 addExtra(child, extraindent)
492 if tag(child) ==
'where-construct-stmt':
494 child.tag = f
'{{{NAMESPACE}}}if-then-stmt'
495 child.text =
'IF (' + child.text.split(
'(', 1)[1]
500 if '(' in child.text:
502 child.tag = f
'{{{NAMESPACE}}}else-if-stmt'
503 child.text =
'ELSE IF (' + child.text.split(
'(', 1)[1]
506 child.tag = f
'{{{NAMESPACE}}}else-stmt'
508 for mask
in child.findall(
'./{*}mask-E'):
510 mask.tag = f
'{{{NAMESPACE}}}condition-E'
512 for namedE
in mask.findall(
'.//{*}R-LT/..'):
514 scope.arrayR2parensR(namedE, table)
515 for cnt
in child.findall(
'.//{*}cnt'):
517 addExtra(cnt, extraindent)
519 updateStmt(child, table, kind, extraindent, whereBlock, scope)
521 raise PYFTError(
'Unexpected tag found in mnh_expand ' +
522 'directives: {t}'.format(t=tag(stmt)))
525 def closeLoop(loopdesc):
526 """Helper function to deal with indentation"""
528 inner, outer, indent, extraindent = loopdesc
529 if inner[-2].tail
is not None:
531 outer.tail = inner[-2].tail[:-extraindent]
532 inner[-2].tail =
'\n' + (indent + extraindent - 2) *
' '
539 def recur(elem, scope):
543 for ie, sElem
in enumerate(list(elem)):
544 if tag(sElem) ==
'C' and sElem.text.lstrip(
' ').startswith(
'!$mnh_expand')
and \
548 raise PYFTError(
'Nested mnh_directives are not allowed')
550 inEverywhere = closeLoop(inEverywhere)
553 table, kind = decode(sElem.text)
555 indent = len(sElem.tail) - len(sElem.tail.rstrip(
' '))
556 toremove.append((elem, sElem))
560 if elem[ie - 1].tail
is None:
561 elem[ie - 1].tail =
''
562 elem[ie - 1].tail += sElem.tail.replace(
'\n',
'', 1).rstrip(
' ')
565 if addAccIndependentCollapse:
566 accCollapse = createElem(
'C', text=
'!$acc loop independent collapse(' +
567 str(len(table.keys())) +
')',
568 tail=
'\n' + indent *
' ')
569 toinsert.append((elem, accCollapse, ie))
572 inner, outer, extraindent = scope.createDoConstruct(table, indent=indent,
573 concurrent=concurrent)
574 toinsert.append((elem, outer, ie))
576 elif (tag(sElem) ==
'C' and
577 sElem.text.lstrip(
' ').startswith(
'!$mnh_end_expand')
and useMnhExpand):
580 raise PYFTError(
'End mnh_directive found before begin directive ' +
581 'in {f}'.format(f=scope.getFileName()))
582 if (table, kind) != decode(sElem.text):
583 raise PYFTError(
"Opening and closing mnh directives must be conform " +
584 "in {f}".format(f=scope.getFileName()))
586 toremove.append((elem, sElem))
590 outer.tail += sElem.tail.replace(
'\n',
'', 1)
592 elem[ie - 1].tail = elem[ie - 1].tail[:-2]
596 toremove.append((elem, sElem))
597 inner.insert(-1, sElem)
599 updateStmt(sElem, table, kind, extraindent, inner, scope)
601 elif everywhere
and tag(sElem)
in (
'a-stmt',
'if-stmt',
'where-stmt',
606 if tag(sElem) ==
'a-stmt':
608 arr = sElem.find(
'./{*}E-1/{*}named-E/{*}R-LT/{*}array-R/../..')
612 nodeE2 = sElem.find(
'./{*}E-2')
614 num = len(nodeE2.findall(
'.//{*}array-R'))
621 elif (len(nodeE2) == 1
and tag(nodeE2[0]) ==
'named-E' and num == 1
and
622 nodeE2[0].find(
'.//{*}parens-R')
is None):
628 if (isMemSet
and not updateMemSet)
or (isCopy
and not updateCopy):
630 elif tag(sElem) ==
'if-stmt':
632 arr = sElem.find(
'./{*}action-stmt/{*}a-stmt/{*}E-1/' +
633 '{*}named-E/{*}R-LT/{*}array-R/../..')
636 scope.changeIfStatementsInIfConstructs(singleItem=sElem)
639 elif tag(sElem) ==
'where-stmt':
640 arr = sElem.find(
'./{*}mask-E//{*}named-E/{*}R-LT/{*}array-R/../..')
641 elif tag(sElem) ==
'where-construct':
642 arr = sElem.find(
'./{*}where-block/{*}where-construct-stmt/' +
643 '{*}mask-E//{*}named-E/{*}R-LT/{*}array-R/../..')
650 elif len(set(alltext(a).count(
':')
651 for a
in sElem.findall(
'.//{*}R-LT/{*}array-R'))) > 1:
655 elif len(set([
'ALL',
'ANY',
'COSHAPE',
'COUNT',
'CSHIFT',
'DIMENSION',
656 'DOT_PRODUCT',
'EOSHIFT',
'LBOUND',
'LCOBOUND',
'MATMUL',
657 'MAXLOC',
'MAXVAL',
'MERGE',
'MINLOC',
'MINVAL',
'PACK',
658 'PRODUCT',
'REDUCE',
'RESHAPE',
'SHAPE',
'SIZE',
'SPREAD',
659 'SUM',
'TRANSPOSE',
'UBOUND',
'UCOBOUND',
'UNPACK'] +
660 (funcList
if funcList
is not None else [])
661 ).intersection(set(n2name(nodeN)
for nodeN
662 in sElem.findall(
'.//{*}named-E/{*}N')))) > 0:
668 newtable, varNew = scope.findArrayBounds(arr, loopVar, newVarList)
671 if var
not in newVarList:
672 newVarList.append(var)
679 inEverywhere = closeLoop(inEverywhere)
682 if not (inEverywhere
and table == newtable):
684 inEverywhere = closeLoop(inEverywhere)
686 if ie != 0
and elem[ie - 1].tail
is not None:
689 tail = tailSave.get(elem[ie - 1], elem[ie - 1].tail)
690 indent = len(tail) - len(tail.rstrip(
' '))
696 inner, outer, extraindent = scope.createDoConstruct(
697 table, indent=indent, concurrent=concurrent)
698 toinsert.append((elem, outer, ie))
699 inEverywhere = (inner, outer, indent, extraindent)
700 tailSave[sElem] = sElem.tail
701 toremove.append((elem, sElem))
702 inner.insert(-1, sElem)
704 updateStmt(sElem, table, kind, extraindent, inner, scope)
707 inEverywhere = closeLoop(inEverywhere)
710 inEverywhere = closeLoop(inEverywhere)
714 inEverywhere = closeLoop(inEverywhere)
716 for scope
in self.getScopes():
719 for elem, outer, ie
in toinsert[::-1]:
720 elem.insert(ie, outer)
722 for parent, elem
in toremove:
725 self.addVar([(v[
'scopePath'], v[
'n'], f
"INTEGER :: {v['n']}",
None)
726 for v
in newVarList])
810 def inline(self, subContained, callStmt, mainScope,
811 simplify=False, loopVar=None):
813 Inline a single contained subroutine at its call site.
815 This method performs the actual inlining of a contained subroutine
816 into the calling scope. It handles:
817 - ELEMENTAL subroutines with array arguments
818 - Optional arguments (PRESENT intrinsic)
819 - Variable name conflicts
820 - USE statement merging
824 subContained : xml element
825 XML fragment corresponding to the contained subroutine scope.
826 callStmt : xml element
827 The call-stmt node to replace with inlined code.
828 mainScope : PYFTscope
829 Scope of the main (calling) subroutine.
830 simplify : bool, optional
831 If True, remove empty constructs and unused variables
832 after inlining. Default is False.
833 loopVar : callable or None, optional
834 Function to determine loop index variable name.
835 Used when inlining ELEMENTAL subroutines called on arrays.
836 Takes: (lowerDecl, upperDecl, lowerUsed, upperUsed, name, index)
837 Returns: str, True (auto-generate), or False (skip).
841 - For ELEMENTAL subroutines on arrays: DO loops are introduced.
842 - Optional arguments: PRESENT(var) is replaced with .TRUE. or .FALSE.
843 - Missing optional arguments: code paths using them are removed.
844 - Name conflicts: local variables are renamed with _N suffixes.
846 def setPRESENTby(node, var, val):
848 Replace PRESENT(var) by .TRUE. if val is True, by .FALSE. otherwise on node
850 :param node: xml node to work on (a contained subroutine)
851 :param var: string of the name of the optional variable to check
853 for namedE
in node.findall(
'.//{*}named-E/{*}N/..'):
854 if n2name(namedE.find(
'./{*}N')).upper() ==
'PRESENT':
855 presentarg = n2name(namedE.find(
'./{*}R-LT/{*}parens-R/{*}element-LT/'
856 '{*}element/{*}named-E/{*}N'))
857 if presentarg.upper() == var.upper():
858 for nnn
in namedE[:]:
860 namedE.tag = f
'{{{NAMESPACE}}}literal-E'
861 namedE.text =
'.TRUE.' if val
else '.FALSE.'
864 parent = mainScope.getParent(callStmt)
867 if tag(parent) ==
'action-stmt':
868 mainScope.changeIfStatementsInIfConstructs(mainScope.getParent(parent))
869 parent = mainScope.getParent(callStmt)
873 prefix = subContained.findall(
'.//{*}prefix')
874 if len(prefix) > 0
and 'ELEMENTAL' in [p.text.upper()
for p
in prefix]:
876 mainScope.addArrayParenthesesInNode(callStmt)
879 mainScope.addExplicitArrayBounds(node=callStmt)
882 arrayRincallStmt = callStmt.findall(
'.//{*}array-R')
883 if len(arrayRincallStmt) > 0:
885 table, _ = mainScope.findArrayBounds(mainScope.getParent(arrayRincallStmt[0], 2),
889 for varName
in table.keys():
890 if not mainScope.varList.findVar(varName):
891 var = {
'as': [],
'asx': [],
892 'n': varName,
'i':
None,
't':
'INTEGER',
'arg':
False,
893 'use':
False,
'opt':
False,
'allocatable':
False,
894 'parameter':
False,
'init':
None,
'scopePath': mainScope.path}
895 mainScope.addVar([[mainScope.path, var[
'n'],
896 mainScope.varSpec2stmt(var),
None]])
899 inner, outer, _ = mainScope.createDoConstruct(table)
902 inner.insert(-1, callStmt)
904 parent.insert(list(parent).index(callStmt), outer)
905 parent.remove(callStmt)
907 for namedE
in callStmt.findall(
'./{*}arg-spec/{*}arg/{*}named-E'):
909 if namedE.find(
'./{*}R-LT'):
910 mainScope.arrayR2parensR(namedE, table)
913 node = copy.deepcopy(subContained)
918 varList = copy.deepcopy(self.varList)
919 for var
in [var
for var
in varList.restrict(subContained.path,
True)
920 if not var[
'arg']
and not var[
'use']]:
922 if varList.restrict(mainScope.path,
True).findVar(var[
'n']):
925 newName = re.sub(
r'_\d+$',
'', var[
'n'])
927 while (varList.restrict(subContained.path,
True).findVar(newName +
'_' + str(i))
or
928 varList.restrict(mainScope.path,
True).findVar(newName +
'_' + str(i))):
930 newName +=
'_' + str(i)
931 node.renameVar(var[
'n'], newName)
932 subst.append((var[
'n'], newName))
935 var[
'scopePath'] = mainScope.path
936 localVarToAdd.append(var)
939 for oldName, newName
in subst:
940 for var
in localVarToAdd + varList[:]:
941 if var[
'as']
is not None:
942 var[
'as'] = [[re.sub(
r'\b' + oldName +
r'\b', newName, dim[i])
943 if dim[i]
is not None else None
945 for dim
in var[
'as']]
952 localUseToAdd = node.findall(
'./{*}use-stmt')
953 for sNode
in node.findall(
'./{*}T-decl-stmt') + localUseToAdd + \
954 node.findall(
'./{*}implicit-none-stmt'):
957 while tag(node[icom]) ==
'C':
958 if node[icom].text.startswith(
'!$acc'):
959 if 'routine' in node[icom].text:
960 node.remove(node[icom])
964 node.remove(node[icom])
972 for argN
in subContained.findall(
'.//{*}subroutine-stmt/{*}dummy-arg-LT/{*}arg-N'):
973 vartable[alltext(argN).upper()] =
None
974 for iarg, arg
in enumerate(callStmt.findall(
'.//{*}arg')):
975 key = arg.find(
'.//{*}arg-N')
978 dummyName = alltext(key).upper()
981 dummyName = list(vartable.keys())[iarg]
983 nodeRLTarray = argnode.findall(
'.//{*}R-LT/{*}array-R')
984 if len(nodeRLTarray) > 0:
986 if len(nodeRLTarray) > 1
or \
987 argnode.find(
'./{*}R-LT/{*}array-R')
is None or \
988 tag(argnode) !=
'named-E':
990 raise PYFTError(
'Argument to complicated: ' + str(alltext(argnode)))
991 dim = nodeRLTarray[0].find(
'./{*}section-subscript-LT')[:]
997 argname =
"".join(argnode.itertext())
1000 tmp = copy.deepcopy(argnode)
1001 nodeRLT = tmp.find(
'./{*}R-LT')
1002 nodeRLT.remove(nodeRLT.find(
'./{*}array-R'))
1003 argname =
"".join(tmp.itertext())
1004 vartable[dummyName] = {
'node': argnode,
'name': argname,
'dim': dim}
1007 for dummyName
in [dummyName
for (dummyName, value)
in vartable.items()
1008 if value
is not None]:
1009 setPRESENTby(node, dummyName,
True)
1010 for dummyName
in [dummyName
for (dummyName, value)
in vartable.items()
1012 setPRESENTby(node, dummyName,
False)
1015 for dummyName
in [dummyName
for (dummyName, value)
in vartable.items()
if value
is None]:
1016 for nodeN
in [nodeN
for nodeN
in node.findall(
'.//{*}named-E/{*}N')
1017 if n2name(nodeN).upper() == dummyName]:
1020 par = node.getParent(nodeN, level=2)
1021 allreadySuppressed = []
1022 while par
and not removed
and par
not in allreadySuppressed:
1025 if tagName
in (
'a-stmt',
'print-stmt'):
1028 elif tagName ==
'call-stmt':
1033 raise NotImplementedError(
'call-stmt not (yet?) implemented')
1034 elif tagName
in (
'if-stmt',
'where-stmt'):
1037 elif tagName
in (
'if-then-stmt',
'else-if-stmt',
'where-construct-stmt',
1041 toSuppress = node.getParent(par, 2)
1042 elif tagName
in (
'select-case-stmt',
'case-stmt'):
1045 toSuppress = node.getParent(par, 2)
1046 elif tagName.endswith(
'-block')
or tagName.endswith(
'-stmt')
or \
1047 tagName.endswith(
'-construct'):
1054 raise PYFTError((
"We shouldn't be here. A case may have been " +
1055 "overlooked (tag={tag}).".format(tag=tagName)))
1056 if toSuppress
is not None:
1058 if toSuppress
not in allreadySuppressed:
1061 node.removeStmtNode(toSuppress,
False, simplify)
1062 allreadySuppressed.extend(list(toSuppress.iter()))
1064 par = node.getParent(par)
1067 for name, dummy
in vartable.items():
1071 for namedE
in [namedE
for namedE
in node.findall(
'.//{*}named-E/{*}N/{*}n/../..')
1072 if n2name(namedE.find(
'{*}N')).upper() == name]:
1074 nodeN = namedE.find(
'./{*}N')
1075 ns = nodeN.findall(
'./{*}n')
1076 ns[0].text = n2name(nodeN)
1082 descMain = varList.restrict(mainScope.path,
True).findVar(dummy[
'name'])
1083 descSub = varList.restrict(subContained.path,
True).findVar(name)
1087 if var[
'as']
is not None:
1088 var[
'as'] = [[re.sub(
r'\b' + name +
r'\b', dummy[
'name'], dim[i])
1089 if dim[i]
is not None else None
1091 for dim
in var[
'as']]
1096 nodeRLT = namedE.find(
'./{*}R-LT')
1097 if nodeRLT
is not None and tag(nodeRLT[0]) !=
'component-R':
1099 assert tag(nodeRLT[0])
in (
'array-R',
'parens-R'),
'Internal error'
1100 slices = nodeRLT[0].findall(
'./{*}section-subscript-LT/' +
1101 '{*}section-subscript')
1102 slices += nodeRLT[0].findall(
'./{*}element-LT/{*}element')
1105 if (descMain
is not None and descMain[
'as']
is not None and
1106 len(descMain[
'as']) > 0)
or \
1107 len(descSub[
'as']) > 0
or dummy[
'dim']
is not None:
1109 if len(descSub[
'as']) > 0:
1110 ndim = len(descSub[
'as'])
1113 if dummy[
'dim']
is not None:
1116 ndim = len([d
for d
in dummy[
'dim']
if ':' in alltext(d)])
1119 ndim = len(descMain[
'as'])
1120 ns[0].text +=
'(' + (
', '.join([
':'] * ndim)) +
')'
1121 updatedNamedE = createExprPart(alltext(namedE))
1122 namedE.tag = updatedNamedE.tag
1123 namedE.text = updatedNamedE.text
1124 for nnn
in namedE[:]:
1126 namedE.extend(updatedNamedE[:])
1127 slices = namedE.find(
'./{*}R-LT')[0].findall(
1128 './{*}section-subscript-LT/{*}section-subscript')
1134 namedE.find(
'./{*}N')[0].text = dummy[
'name']
1142 for isl, sl
in enumerate(slices):
1144 if len(descSub[
'as']) == 0
or descSub[
'as'][isl][1]
is None:
1147 if dummy[
'dim']
is not None:
1149 tagName =
'./{*}lower-bound' if i == 0
else './{*}upper-bound'
1150 descSub[i] = dummy[
'dim'][isl].find(tagName)
1153 if descSub[i]
is not None:
1155 descSub[i] = alltext(descSub[i])
1158 if descMain
is not None and descMain[
'as'][isl][1]
is not None:
1160 descSub[i] = descMain[
'as'][isl][i]
1161 if i == 0
and descSub[i]
is None:
1164 descSub[i] =
"L" if i == 0
else "U"
1165 descSub[i] +=
"BOUND({name}, {isl})".format(
1166 name=dummy[
'name'], isl=isl + 1)
1168 descSub[0] = descSub[
'as'][isl][0]
1169 if descSub[0]
is None:
1171 descSub[1] = descSub[
'as'][isl][1]
1183 if dummy[
'dim']
is not None and \
1184 not alltext(dummy[
'dim'][isl]).strip().startswith(
':'):
1185 offset = alltext(dummy[
'dim'][isl].find(
'./{*}lower-bound'))
1187 if descMain
is not None and descMain[
'as']
is not None:
1188 offset = descMain[
'as'][isl][0]
1191 elif offset.strip().startswith(
'-'):
1192 offset =
'(' + offset +
')'
1194 offset =
"LBOUND({name}, {isl})".format(
1195 name=dummy[
'name'], isl=isl + 1)
1196 if offset.upper() == descSub[0].upper():
1199 if descSub[0].strip().startswith(
'-'):
1200 offset +=
'- (' + descSub[0] +
')'
1202 offset +=
'-' + descSub[0]
1206 if tag(sl) ==
'element' or \
1207 (tag(sl) ==
'section-subscript' and ':' not in alltext(sl)):
1211 low = sl.find(
'./{*}lower-bound')
1213 low = createElem(
'lower-bound', tail=sl.text)
1214 low.append(createExprPart(descSub[0]))
1217 up = sl.find(
'./{*}upper-bound')
1219 up = createElem(
'upper-bound')
1220 up.append(createExprPart(descSub[1]))
1223 for bound
in bounds:
1225 if bound[-1].tail
is None:
1227 bound[-1].tail +=
'+' + offset
1233 if dummy[
'dim']
is not None and len(dummy[
'dim']) > len(slices):
1234 slices[-1].tail =
', '
1235 par = node.getParent(slices[-1])
1236 par.extend(dummy[
'dim'][len(slices):])
1241 updatedNamedE = createExprPart(alltext(namedE))
1242 namedE.tag = updatedNamedE.tag
1243 namedE.text = updatedNamedE.text
1244 for nnn
in namedE[:]:
1246 namedE.extend(updatedNamedE[:])
1248 node.remove(node.find(
'./{*}subroutine-stmt'))
1249 node.remove(node.find(
'./{*}end-subroutine-stmt'))
1252 mainScope.addVar([[mainScope.path, var[
'n'], mainScope.varSpec2stmt(var),
None]
1253 for var
in localVarToAdd])
1254 mainScope.addModuleVar([[mainScope.path, n2name(useStmt.find(
'.//{*}module-N//{*}N')),
1255 [n2name(v.find(
'.//{*}N'))
1256 for v
in useStmt.findall(
'.//{*}use-N')]]
1257 for useStmt
in localUseToAdd])
1260 index = list(parent).index(callStmt)
1261 parent.remove(callStmt)
1262 if callStmt.tail
is not None:
1263 if node[-1].tail
is None:
1264 node[-1].tail = callStmt.tail
1266 node[-1].tail = node[-1].tail + callStmt.tail
1267 for node
in node[::-1]:
1269 parent.insert(index, node)
1451 Remove statement nodes with optional code simplification.
1455 nodes : xml element or list of xml elements
1456 Node(s) to remove from the code tree.
1458 If True, also remove variables that become unused after
1459 the deletion of the nodes.
1460 simplifyStruct : bool
1461 If True, also remove empty enclosing constructs
1462 (IF blocks, loops) that become empty after node removal.
1466 >>> pft = PYFT('input.F90')
1467 >>> nodes = pft.findall('.//{*}call-stmt')
1468 >>> pft.removeStmtNode(nodes, simplifyVar=True, simplifyStruct=True)
1472 - Handles nested structures (removes inner statements first).
1473 - When simplifyStruct=True:
1474 - Empty IF blocks are removed
1475 - Empty loops are removed
1476 - WHERE constructs are handled
1477 - When simplifyVar=True:
1478 - Unused local variables are removed
1479 - Empty type declarations are cleaned up
1484 nodesToSuppress = []
1485 if not isinstance(nodes, list):
1488 if tag(node)
in (
'if-stmt',
'where-stmt'):
1489 action = node.find(
'./{*}action-stmt')
1490 if action
is not None and len(action) != 0:
1491 nodesToSuppress.append(action[0])
1493 nodesToSuppress.append(node)
1494 elif tag(node) ==
'}action-stmt':
1496 nodesToSuppress.append(node[0])
1498 nodesToSuppress.append(node)
1500 nodesToSuppress.append(node)
1505 for node
in nodesToSuppress:
1506 scopePath = self.getScopePath(node)
1507 if tag(node) ==
'do-construct':
1509 varToCheck.extend([(scopePath, n2name(arg))
1510 for arg
in node.find(
'./{*}do-stmt').findall(
'.//{*}N')])
1511 elif tag(node)
in (
'if-construct',
'if-stmt'):
1513 varToCheck.extend([(scopePath, n2name(arg))
1514 for arg
in node.findall(
'.//{*}condition-E//{*}N')])
1515 elif tag(node)
in (
'where-construct',
'where-stmt'):
1517 varToCheck.extend([(scopePath, n2name(arg))
1518 for arg
in node.findall(
'.//{*}mask-E//{*}N')])
1519 elif tag(node) ==
'call-stmt':
1521 varToCheck.extend([(scopePath, n2name(arg))
1522 for arg
in node.findall(
'./{*}arg-spec//{*}N')])
1524 varToCheck.append((scopePath,
1525 n2name(node.find(
'./{*}procedure-designator//{*}N'))))
1526 elif tag(node)
in (
'a-stmt',
'print-stmt'):
1527 varToCheck.extend([(scopePath, n2name(arg))
for arg
in node.findall(
'.//{*}N')])
1528 elif tag(node) ==
'selectcase-construct':
1530 varToCheck.extend([(scopePath, n2name(arg))
1531 for arg
in node.findall(
'.//{*}case-E//{*}N')])
1532 varToCheck.extend([(scopePath, n2name(arg))
1533 for arg
in node.findall(
'.//{*}case-value//{*}N')])
1537 for node
in nodesToSuppress:
1538 parent = self.getParent(node)
1539 parents[id(node)] = parent
1540 newlines =
'\n' * (alltext(node).count(
'\n')
if tag(node).endswith(
'-construct')
else 0)
1541 if node.tail
is not None or len(newlines) > 0:
1542 previous = self.getSiblings(node, after=
False)
1543 if len(previous) == 0:
1546 previous = previous[-1]
1547 if previous.tail
is None:
1549 previous.tail = (previous.tail.replace(
'\n',
'') +
1550 (node.tail
if node.tail
is not None else ''))
1554 self.removeVarIfUnused(varToCheck, excludeDummy=
True,
1555 excludeModule=
True, simplify=simplifyVar)
1558 newNodesToSuppress = []
1559 for node
in nodesToSuppress:
1560 parent = parents[id(node)]
1563 if tag(parent) ==
'action-stmt':
1564 newNodesToSuppress.append(self.getParent(parent))
1566 elif simplifyStruct:
1567 if tag(parent) ==
'do-construct' and len(
_nodesInDo(parent)) == 0:
1568 newNodesToSuppress.append(parent)
1569 elif tag(parent) ==
'if-block':
1570 parPar = self.getParent(parent)
1572 newNodesToSuppress.append(parPar)
1573 elif tag(parent) ==
'where-block':
1574 parPar = self.getParent(parent)
1576 newNodesToSuppress.append(parPar)
1577 elif tag(parent) ==
'selectcase-block':
1578 parPar = self.getParent(parent)
1580 newNodesToSuppress.append(parPar)
1582 constructNodes, otherNodes = [], []
1583 for nnn
in newNodesToSuppress:
1584 if tag(nnn).endswith(
'-construct'):
1585 if nnn
not in constructNodes:
1586 constructNodes.append(nnn)
1588 if nnn
not in otherNodes:
1589 otherNodes.append(nnn)
1591 if len(otherNodes) > 0:
1594 for nnn
in constructNodes:
1651 :param loopVariables: ordered dictionnary with loop variables as key and bounds as values.
1652 Bounds are expressed with a 2-tuple.
1653 Keys must be in the same order as the order used when addressing an
1654 element: if loopVariables.keys is [JI, JK], arrays are
1655 addressed with (JI, JK)
1656 :param indent: current indentation
1657 :param concurrent: if False, output is made of nested 'DO' loops
1658 if True, output is made of a single 'DO CONCURRENT' loop
1659 :return: (inner, outer, extraindent) with
1660 - inner the inner do-construct where statements must be added
1661 - outer the outer do-construct to be inserted somewhere
1662 - extraindent the number of added indentation
1663 (2 if concurrent else 2*len(loopVariables))
1689 for var, (lo, up)
in list(loopVariables.items())[::-1]:
1690 nodeV = createElem(
'V', tail=
'=')
1691 nodeV.append(createExprPart(var))
1692 lower, upper = createArrayBounds(lo, up,
'DOCONCURRENT')
1694 triplet = createElem(
'forall-triplet-spec')
1695 triplet.extend([nodeV, lower, upper])
1697 triplets.append(triplet)
1699 tripletLT = createElem(
'forall-triplet-spec-LT', tail=
')')
1700 for triplet
in triplets[:-1]:
1702 tripletLT.extend(triplets)
1704 dostmt = createElem(
'do-stmt', text=
'DO CONCURRENT (', tail=
'\n')
1705 dostmt.append(tripletLT)
1706 enddostmt = createElem(
'end-do-stmt', text=
'END DO')
1708 doconstruct = createElem(
'do-construct', tail=
'\n')
1709 doconstruct.extend([dostmt, enddostmt])
1710 inner = outer = doconstruct
1711 doconstruct[0].tail += (indent + 2) *
' '
1722 def makeDo(var, lo, up):
1723 doV = createElem(
'do-V', tail=
'=')
1724 doV.append(createExprPart(var))
1725 lower, upper = createArrayBounds(lo, up,
'DO')
1727 dostmt = createElem(
'do-stmt', text=
'DO ', tail=
'\n')
1728 dostmt.extend([doV, lower, upper])
1730 enddostmt = createElem(
'end-do-stmt', text=
'END DO')
1732 doconstruct = createElem(
'do-construct', tail=
'\n')
1733 doconstruct.extend([dostmt, enddostmt])
1738 for i, (var, (lo, up))
in enumerate(list(loopVariables.items())[::-1]):
1739 doconstruct = makeDo(var, lo, up)
1741 doconstruct[0].tail += (indent + 2 * i + 2) *
' '
1746 inner.insert(1, doconstruct)
1749 doconstruct.tail += (indent + 2 * i - 2) *
' '
1750 return inner, outer, 2
if concurrent
else 2 * len(loopVariables)
1775 :param item: item to remove from list
1776 :param itemPar: the parent of item (the list)
1779 nodesToSuppress = [item]
1782 i = list(itemPar).index(item)
1783 if item.tail
is not None and ',' in item.tail:
1786 item.tail = tail.replace(
',',
'')
1787 elif i != 0
and ',' in itemPar[i - 1].tail:
1789 tail = itemPar[i - 1].tail
1790 itemPar[i - 1].tail = tail.replace(
',',
'')
1796 while j < len(itemPar)
and not found:
1797 if nonCode(itemPar[j]):
1799 if itemPar[j].tail
is not None and ',' in itemPar[j].tail:
1802 tail = itemPar[j].tail
1803 itemPar[j].tail = tail.replace(
',',
'')
1813 while j >= 0
and not found:
1814 if itemPar[j].tail
is not None and ',' in itemPar[j].tail:
1817 tail = itemPar[j].tail
1818 itemPar[j].tail = tail.replace(
',',
'')
1820 if nonCode(itemPar[j]):
1828 len([e
for e
in itemPar
if not nonCode(e)]) != 1:
1829 raise RuntimeError(
"Something went wrong here....")
1832 if i + 1 < len(itemPar)
and tag(itemPar[i + 1]) ==
'cnt':
1834 reason =
'lastOnLine'
1846 elif len([itemPar[j]
for j
in range(i + 1, len(itemPar))
if not nonCode(itemPar[j])]) == 0:
1860 if reason
is not None:
1861 def _getPrecedingCnt(itemPar, i):
1863 Return the index of the preceding node which is a continuation character
1864 :param itemPar: the list containig the node to suppress
1865 :param i: the index of the current node, i-1 is the starting index for the search
1866 :return: a tuple with three elements:
1867 - the node containing the preceding '&' character
1868 - the parent of the node containing the preceding '&' character
1869 - index of the preceding '&' in the parent (previsous element
1872 - In the general case the preceding '&' belongs to the same list:
1873 USE MODD, ONLY: X, &
1875 - But it exists a special case, where the preceding '&' don't belong to
1876 the same list (in the following example, both '&' are attached to the parent):
1881 while j >= 0
and tag(itemPar[j]) ==
'C':
1883 if j >= 0
and tag(itemPar[j]) ==
'cnt':
1884 return itemPar[j], itemPar, j
1890 siblings = self.getSiblings(itemPar, before=
True, after=
False)
1891 j2 = len(siblings) - 1
1892 while j2 >= 0
and tag(siblings[j2]) ==
'C':
1894 if j2 >= 0
and tag(siblings[j2]) ==
'cnt':
1895 return siblings[j2], siblings, j2
1896 return None,
None,
None
1898 precCnt, newl, j = _getPrecedingCnt(itemPar, i)
1899 if precCnt
is not None:
1901 nodesToSuppress.append(precCnt
if reason ==
'last' else itemPar[i + 1])
1903 precCnt2, _, _ = _getPrecedingCnt(newl, j)
1904 if precCnt2
is not None:
1906 nodesToSuppress.append(precCnt2
if reason ==
'last' else precCnt)
1909 for node
in nodesToSuppress:
1916 parent = self.getParent(itemPar)
1917 i = list(parent).index(node)
1918 if i != 0
and node.tail
is not None:
1919 if parent[i - 1].tail
is None:
1920 parent[i - 1].tail =
''
1921 parent[i - 1].tail = parent[i - 1].tail + node.tail