252 Inline structure member accesses in compute statements.
254 Converts TYPE%VAR access patterns into single local variables
255 to improve performance by reducing pointer dereferences.
257 Transformation Examples
258 ----------------------
259 Simple member access:
260 - ZA = 1 + CST%XG => ZA = 1 + XCST_G
263 - ZA = 1 + PARAM_ICE%XRTMIN(3) => ZA = 1 + XPARAM_ICE_XRTMIN3
266 - ZRSMIN(1:KRR) = ICED%XRTMIN(1:KRR) => ZRSMIN(1:KRR) = ICEDXRTMIN1KRR(1:KRR)
269 - IF(TURBN%CSUBG_MF_PDF=='NONE') => IF(CTURBNSUBG_MF_PDF=='NONE')
273 - Only handles single-level structure access (not TOTO%CST%XG)
274 - Does not handle arrays with deferred shapes
275 - Type components must have known dimensions
277 def _getShapeFromLHS(aStmt, scope):
279 Get array shape from the LHS variable declaration.
282 varNode = e1.find(
'.//{*}N/{*}n')
285 varDesc = scope.varList.findVar(varNode.text)
286 if varDesc
is None or not varDesc.get(
'as')
or len(varDesc[
'as']) == 0:
290 def convertOneType(component, newVarList, scope, aStmt=None):
292 objType = scope.getParent(component, 2)
293 objTypeStr = alltext(objType).upper()
294 namedENn = objType.find(
'.//{*}N/{*}n')
295 structure = namedENn.text
296 variable = component.find(
'.//{*}ct').text.upper()
297 if variable[0] ==
"T":
302 arrayRall = objType.findall(
'.//{*}array-R')
303 if len(arrayRall) > 0:
304 arrayR = copy.deepcopy(arrayRall[0])
305 txt = alltext(arrayR).replace(
',',
'')
306 txt = txt.replace(
':',
'')
307 txt = txt.replace(
'(',
'')
308 txt = txt.replace(
')',
'')
309 arrayIndices = arrayIndices + txt
310 elif len(objType.findall(
'.//{*}element-LT')) > 0:
314 for elem
in objType.findall(
'.//{*}element'):
317 arrayIndices = arrayIndices + txt
318 allConst = all(t.lstrip(
'-').isdigit()
for t
in elements)
319 if not allConst
and aStmt
is not None:
321 memberShape = _getShapeFromLHS(aStmt, scope)
322 if memberShape
is not None:
324 newName = variable[0] + structure + variable[1:]
325 newName = newName.upper()
327 namedENn.text = newName
328 rlt = objType.find(
'.//{*}R-LT')
329 compR = rlt.find(
'.//{*}component-R')
332 if newName
not in newVarList:
333 newVarList[newName] = (memberShape, objTypeStr,
336 newName = variable[0] + structure + variable[1:] + arrayIndices
337 newName = newName.upper()
341 namedENn.text = newName
342 objType.remove(objType.find(
'.//{*}R-LT'))
343 if len(arrayRall) > 0:
344 objType.insert(1, arrayR)
347 if newName
not in newVarList:
348 if len(arrayRall) == 0:
349 newVarList[newName] = (
None, objTypeStr)
351 newVarList[newName] = (arrayR, objTypeStr)
353 scopes = self.getScopes()
354 if scopes[0].path.split(
'/')[-1].split(
':')[1][:4] ==
'MODD':
356 for scope
in [scope
for scope
in scopes
357 if 'sub:' in scope.path
and 'interface' not in scope.path]:
359 for ifStmt
in (scope.findall(
'.//{*}if-then-stmt') +
360 scope.findall(
'.//{*}else-if-stmt') +
361 scope.findall(
'.//{*}where-stmt')):
362 compo = ifStmt.findall(
'.//{*}component-R')
364 for elcompo
in compo:
365 convertOneType(elcompo, newVarList, scope)
367 for aStmt
in scope.findall(
'.//{*}a-stmt'):
370 if len(aStmt[0].findall(
'.//{*}component-R')) == 0:
371 compoE2 = aStmt.findall(
'.//{*}component-R')
378 nbNamedEinE2 = len(aStmt.findall(
'.//{*}E-2')[0].findall(
'.//{*}named-E/' +
380 if nbNamedEinE2 > 1
or nbNamedEinE2 == 1
and \
381 len(aStmt[0].findall(
'.//{*}R-LT')) == 1:
382 for elcompoE2
in compoE2:
383 convertOneType(elcompoE2, newVarList, scope,
388 for aStmt
in scope.findall(
'.//{*}a-stmt'):
390 if len(aStmt[0].findall(
'.//{*}component-R')) > 0:
392 for el
in newVarList.items():
393 if alltext(aStmt[0]) == el[1][1]:
394 stmtAffect = createExpr(el[0] +
"=" + alltext(aStmt[0]))
395 par = scope.getParent(aStmt)
397 if tag(par[list(par).index(aStmt)+1]) ==
'C':
399 par.insert(list(par).index(aStmt)+1+iExtra, stmtAffect[0])
403 for el, var
in newVarList.items():
404 if el[0].upper() ==
'X' or el[0].upper() ==
'P' or el[0].upper() ==
'Z':
406 elif el[0].upper() ==
'L' or el[0].upper() ==
'O':
408 elif el[0].upper() ==
'N' or el[0].upper() ==
'I' or el[0].upper() ==
'K':
410 elif el[0].upper() ==
'C':
411 varType =
'CHARACTER(LEN=LEN(' + var[1] +
'))'
413 raise PYFTError(
'Case not implemented for the first letter of the newVarName ' +
414 el +
' in convertTypesInCompute')
417 if isinstance(var[0], list):
420 varArray =
', DIMENSION('
421 for i, bound
in enumerate(memberShape):
432 varArray =
', DIMENSION('
433 for i, sub
in enumerate(var[0].findall(
'.//{*}section-subscript')):
434 if len(sub.findall(
'.//{*}upper-bound')) > 0:
435 dimSize = simplifyExpr(
436 alltext(sub.findall(
'.//{*}upper-bound')[0]) +
437 '-' + alltext(sub.findall(
'.//{*}lower-bound')[0]) +
439 elif len(sub.findall(
'.//{*}lover-bound')) > 0:
440 dimSize = simplifyExpr(alltext(sub.findall(
'.//{*}lower-bound')[0]))
442 dimSize =
'SIZE(' + var[1] +
',' + str(i+1) +
')'
443 varArray =
', DIMENSION(' + dimSize +
','
444 varArray = varArray[:-1] +
')'
445 scope.addVar([[scope.path, el, varType + varArray +
' :: ' + el,
None]])
448 if isinstance(var[0], list):
450 stmtAffect = createExpr(el +
"=" + var[2] +
'%' + var[3])[0]
452 stmtAffect = createExpr(el +
"=" + var[1])[0]
453 scope.insertStatement(scope.indent(stmtAffect), first=
True)
559 Convert MODULE to SUBMODULE statements and add INTERFACE of SUBROUTINEs of PHYEX
560 ==> Applied only on MODE_
562 - if an INTERFACE already exists
563 - if no subroutine is present in the module
564 - to CONTAINS routines
565 1) Create interface statement if any
566 2) Add subroutines declaration (with MODULE statement)
567 3) Add SUBMODULE statements and convert SUBROUTINE to MODULE SUBROUTINE statements
569 scopes = self.getScopes()
572 oldModNode = self.find(
'.//{*}program-unit')
573 modNode = copy.deepcopy(self.find(
'.//{*}program-unit'))
575 interfaceStmt = self.findall(
'.//{*}interface-stmt')
576 subStmt = self.findall(
'.//{*}subroutine-stmt')
578 if modScope.path.split(
'/')[-1].split(
':')[1][:4] ==
'MODE' and \
579 len(interfaceStmt) == 0
and len(subStmt) > 0:
580 moduleName = modScope.path.split(
'/')[-1].split(
':')[1][:]
582 newMod = createElem(
'program-unit', text=
'MODULE ' + moduleName, tail=
'\n')
584 newMod.append(createElem(
'implicit-none-stmt', text=
'IMPLICIT NONE', tail=
'\n'))
585 interfaceStmt = createElem(
'interface-construct')
586 interfaceStmt.append(createElem(
'interface-stmt', text=
'INTERFACE', tail=
'\n'))
587 interfaceStmt.append(createElem(
'end-interface-stmt', text=
'END INTERFACE', tail=
'\n'))
588 newMod.append(interfaceStmt)
592 for scope
in scopes[1:]:
594 if sum(
'sub' in s
for s
in scope.path.split(
'/')) == 1:
595 subsModified.append(scope.path.split(
'/')[-1].split(
':')[1][:])
596 subroutineDecl = createElem(
'module-unit')
598 subroutineStmt = copy.deepcopy(scope[0])
599 declType = subroutineStmt.text
600 prefix = createElem(
'prefix')
601 prefix.text =
'MODULE'
602 subroutineStmt.text =
''
603 subroutineStmt.insert(0, prefix)
604 prefix.tail =
' ' + declType
605 subroutineDecl.append(subroutineStmt)
607 for use
in scope.findall(
'.//{*}use-stmt'):
608 subroutineDecl.append(copy.deepcopy(use))
609 subroutineDecl.append(createElem(
'implicit-none-stmt', text=
'IMPLICIT NONE',
612 for var
in [var
for var
in scope.varList
if var[
'arg']
or var[
'result']]:
613 subroutineDecl.append(createExpr(self.varSpec2stmt(var,
True))[0])
614 for external
in scope.findall(
'./{*}external-stmt'):
615 subroutineDecl.append(copy.deepcopy(external))
616 if 'SUBROUTINE' in declType:
617 endStmt = createElem(
'end-subroutine-stmt')
618 declName = subroutineStmt.find(
'./{*}subroutine-N/{*}N/{*}n').text
619 elif 'FUNCTION' in declType:
620 endStmt = createElem(
'end-function-stmt')
621 declName = subroutineStmt.find(
'./{*}function-N/{*}N/{*}n').text
623 raise PYFTError(
'declType in addSubmodulePHYEX not handled')
625 endStmt.text =
'END ' + declType + declName +
'\n'
626 subroutineDecl.append(endStmt)
627 interfaceStmt.insert(1, subroutineDecl)
630 newMod.append(createElem(
'end-program-unit', text=
'END MODULE ' + moduleName,
632 self[0].insert(0, newMod)
635 progUnit = createElem(
'program-unit')
646 submoduleStmt = createElem(
'submodule-stmt', text=
'SUBMODULE (')
647 parentId = createElem(
'parent-identifier')
649 ancestorModule = createElem(
'ancestor-module-N')
650 ancestorModuleN = createElem(
'n', text=moduleName)
651 ancestorModule.append(ancestorModuleN)
652 parentId.append(ancestorModule)
653 submoduleStmt.append(parentId)
654 submoduleModule = createElem(
'submodule-module-N')
655 submoduleModuleN = createElem(
'n', text=
'S' + moduleName, tail=
'\n')
656 submoduleModule.append(submoduleModuleN)
657 submoduleStmt.append(submoduleModule)
658 progUnit.append(submoduleStmt)
661 endSubmoduleStmt = createElem(
'end-submodule-stmt', text=
'END SUBMODULE ')
662 submoduleN = createElem(
'submodule-N')
663 submoduleNN = createElem(
'N')
664 submoduleNNn = createElem(
'n', text=
'S' + moduleName, tail=
'\n')
665 submoduleNN.append(submoduleNNn)
666 submoduleN.append(submoduleNN)
667 endSubmoduleStmt.append(submoduleN)
669 progUnit.append(endSubmoduleStmt)
670 progUnit.append(createElem(
'end-program-unit'))
674 modStmt = modNode.find(
'.//{*}module-stmt')
675 modEndStmt = modNode.find(
'.//{*}end-module-stmt')
676 modNode.remove(modStmt)
677 modNode.remove(modEndStmt)
680 publicStmts = modNode.findall(
'.//{*}public-stmt')
681 privateStmts = modNode.findall(
'.//{*}private-stmt')
682 if len(publicStmts) > 0:
683 for publicStmt
in publicStmts:
684 modNode.remove(publicStmt)
685 if len(privateStmts) > 0:
686 for privateStmt
in privateStmts:
687 modNode.remove(privateStmt)
690 subroutines = modNode.findall(
'.//{*}subroutine-stmt')
691 for sub
in subroutines:
692 if sub.find(
'.//{*}N/{*}n').text
in subsModified:
693 prefix = createElem(
'prefix')
694 prefix.text =
'MODULE'
696 sub.insert(0, prefix)
697 prefix.tail =
' SUBROUTINE '
698 progUnit.insert(1, modNode)
700 self.insert(1, progUnit)
703 self.remove(oldModNode)
709 Add MPPDB_CHEKS on all intent REAL arrays on subroutines.
710 ****** Not applied on modd_ routines. ********
711 Handle optional arguments.
712 Example, for a BL89 routine with 4 arguments, 1 INTENT(IN),
713 2 INTENT(INOUT), 1 INTENT(OUT), it produces :
714 IF (MPPDB_INITIALIZED) THEN
716 CALL MPPDB_CHECK(PZZ, "BL89 beg:PZZ")
717 !Check all INOUT arrays
718 CALL MPPDB_CHECK(PDZZ, "BL89 beg:PDZZ")
719 CALL MPPDB_CHECK(PTHVREF, "BL89 beg:PTHVREF")
722 IF (MPPDB_INITIALIZED) THEN
723 !Check all INOUT arrays
724 CALL MPPDB_CHECK(PDZZ, "BL89 end:PDZZ")
725 CALL MPPDB_CHECK(PTHVREF, "BL89 end:PTHVREF")
726 !Check all OUT arrays
727 CALL MPPDB_CHECK(PLM, "BL89 end:PLM")
729 param printsMode: if True, instead of CALL MPPDB_CHECK, add fortran prints for debugging
731 def addPrints_statement(var, typeofPrints='minmax'):
732 ifBeg, ifEnd =
'',
''
735 if typeofPrints ==
'minmax':
736 strMSG = f
'MINMAX {varName} = \",MINVAL({varName}), MAXVAL({varName})'
737 elif typeofPrints ==
'shape':
738 strMSG = f
'SHAPE {varName} = \",SHAPE({varName})'
740 raise PYFTError(
'typeofPrints is either minmax or shape in addPrints_statement')
742 strMSG = var[
'n'] +
' = \",' + var[
'n']
744 ifBeg = ifBeg +
'IF (PRESENT(' + var[
'n'] +
')) THEN\n '
745 ifEnd = ifEnd +
'\nEND IF'
746 return createExpr(ifBeg +
"print*,\"" + strMSG + ifEnd)[0]
748 def addMPPDB_CHECK_statement(var, subRoutineName, strMSG='beg:
'):
749 ifBeg, ifEnd, addD, addLastDim, addSecondDimType = '',
'',
'',
'',
''
753 if 'D%NIJT' in var[
'as'][0][1]:
755 if len(var[
'as']) == 2:
757 addLastDim =
', ' + var[
'as'][1][1]
758 if len(var[
'as']) >= 2:
761 if 'D%NK' in var[
'as'][1][1]:
762 addSecondDimType =
',' +
'''"VERTICAL"'''
764 addSecondDimType =
',' +
'''"OTHER"'''
765 if 'MERGE' in var[
'as'][-1][1]:
766 keyDimMerge = var[
'as'][-1][1].split(
',')[2][:-1]
767 ifBeg =
'IF (' + keyDimMerge +
') THEN\n'
770 ifBeg = ifBeg +
'IF (PRESENT(' + var[
'n'] +
')) THEN\n IF (SIZE(' + \
771 var[
'n'] +
',1) > 0) THEN\n'
772 ifEnd = ifEnd +
'\nEND IF\nEND IF'
773 argsMPPDB = var[
'n'] +
", " +
"\"" + subRoutineName +
" " + strMSG+var[
'n'] +
"\""
774 return createExpr(ifBeg +
"CALL MPPDB_CHECK(" + addD + argsMPPDB +
775 addLastDim + addSecondDimType +
")" + ifEnd)[0]
776 scopes = self.getScopes()
777 if scopes[0].path.split(
'/')[-1].split(
':')[1][:4] ==
'MODD':
784 if 'sub:' in scope.path
and 'func' not in scope.path
and 'interface' not in scope.path:
785 subRoutineName = scope.path.split(
'/')[-1].split(
':')[1]
788 arraysIn, arraysInOut, arraysOut = [], [], []
790 for var
in scope.varList:
791 if var[
'arg']
and var[
'as']
and 'TYPE' not in var[
't']
and \
792 'REAL' in var[
't']
and var[
'scopePath'] == scope.path:
795 if var[
'i'] ==
'INOUT':
796 arraysInOut.append(var)
797 if var[
'i'] ==
'OUT':
798 arraysOut.append(var)
800 for var
in scope.varList:
801 if not var[
't']
or var[
't']
and 'TYPE' not in var[
't']:
804 if var[
'i'] ==
'INOUT':
805 arraysInOut.append(var)
806 if var[
'i'] ==
'OUT':
807 arraysOut.append(var)
809 if len(arraysIn) + len(arraysInOut) + len(arraysOut) == 0:
814 scope.addModuleVar([(scope.path,
'MODE_MPPDB',
None)])
816 scope.addModuleVar([(scope.path,
'MODD_BLANK_n', [
'LDUMMY1'])])
819 commentIN = createElem(
'C', text=
'!Check all IN arrays', tail=
'\n')
820 commentINOUT = createElem(
'C', text=
'!Check all INOUT arrays', tail=
'\n')
821 commentOUT = createElem(
'C', text=
'!Check all OUT arrays', tail=
'\n')
824 if len(arraysIn) + len(arraysInOut) > 0:
826 ifMPPDBinit = createExpr(
"IF (MPPDB_INITIALIZED) THEN\n END IF")[0]
828 ifMPPDBinit = createExpr(
"IF (LDUMMY1) THEN\n END IF")[0]
829 ifMPPDB = ifMPPDBinit.find(
'.//{*}if-block')
832 if len(arraysIn) > 0:
833 ifMPPDB.insert(1, commentIN)
834 for i, var
in enumerate(arraysIn):
836 ifMPPDB.insert(2 + i, addMPPDB_CHECK_statement(var, subRoutineName,
839 ifMPPDB.insert(2 + i, addPrints_statement(var,
840 typeofPrints=
'minmax'))
841 ifMPPDB.insert(3 + i, addPrints_statement(var,
842 typeofPrints=
'shape'))
845 if len(arraysInOut) > 0:
846 shiftLineNumber = 2
if len(arraysIn) > 0
else 1
848 ifMPPDB.insert(len(arraysIn) + shiftLineNumber, commentINOUT)
850 ifMPPDB.insert(len(arraysIn)*2 + shiftLineNumber-1, commentINOUT)
852 for i, var
in enumerate(arraysInOut):
854 ifMPPDB.insert(len(arraysIn) + shiftLineNumber + 1 + i,
855 addMPPDB_CHECK_statement(var, subRoutineName,
858 ifMPPDB.insert(len(arraysIn) + shiftLineNumber + 1 + i,
859 addPrints_statement(var, typeofPrints=
'minmax'))
862 scope.insertStatement(scope.indent(ifMPPDBinit), first=
True)
865 if len(arraysInOut) + len(arraysOut) > 0:
867 ifMPPDBend = createExpr(
"IF (MPPDB_INITIALIZED) THEN\n END IF")[0]
869 ifMPPDBend = createExpr(
"IF (LDUMMY1) THEN\n END IF")[0]
870 ifMPPDB = ifMPPDBend.find(
'.//{*}if-block')
873 if len(arraysInOut) > 0:
874 ifMPPDB.insert(1, commentINOUT)
875 for i, var
in enumerate(arraysInOut):
877 ifMPPDB.insert(2 + i, addMPPDB_CHECK_statement(var, subRoutineName,
880 ifMPPDB.insert(2 + i, addPrints_statement(var,
881 typeofPrints=
'minmax'))
884 if len(arraysOut) > 0:
885 shiftLineNumber = 2
if len(arraysInOut) > 0
else 1
887 ifMPPDB.insert(len(arraysInOut) + shiftLineNumber, commentOUT)
889 ifMPPDB.insert(len(arraysInOut)*2 + shiftLineNumber-1, commentOUT)
890 for i, var
in enumerate(arraysOut):
892 ifMPPDB.insert(len(arraysInOut) + shiftLineNumber + 1 + i,
893 addMPPDB_CHECK_statement(var, subRoutineName,
896 ifMPPDB.insert(len(arraysInOut) + shiftLineNumber + 1 + i,
897 addPrints_statement(var, typeofPrints=
'minmax'))
900 scope.insertStatement(scope.indent(ifMPPDBend), first=
False)
1025 def removeIJDim(self, stopScopes, parserOptions=None, wrapH=False, simplify=False):
1027 Transform routines to be called in a loop on columns
1028 :param stopScopes: scope paths where we stop to add the D argument (if needed)
1029 :param parserOptions, wrapH: see the PYFT class
1030 :param simplify: try to simplify code (remove useless dimensions in call)
1032 ComputeInSingleColumn :
1033 - Remove all Do loops on JI and JJ
1034 - Initialize former indexes JI, JJ, JIJ to first array element:
1035 JI=D%NIB, JJ=D%NJB, JIJ=D%NIJB
1036 - If simplify is True, replace (:,*) on I/J/IJ dimension on argument
1037 with explicit (:,*) on CALL statements:
1038 e.g. CALL FOO(D, A(:,JK,1), B(:,:))
1039 ==> CALL FOO(D, A(JIJ,JK,1), B(:,:)) only if the target argument is not an array
1042 indexToCheck = {
'JI': (
'D%NIB',
'D%NIT'),
1043 'JJ': (
'D%NJB',
'D%NJT'),
1044 'JIJ': (
'D%NIJB',
'D%NIJT')}
1045 hUupperBounds = [v[1]
for v
in indexToCheck.values()]
1047 def slice2index(namedE, scope):
1049 Transform a slice on the horizontal dimension into an index
1050 Eg.: X(1:D%NIJT, 1:D%NKT) => X(JIJ, 1:D%NKT) Be careful, this array is not contiguous.
1051 X(1:D%NIJT, JK) => X(JIJ, JK)
1052 :param namedE: array to transform
1053 :param scope: scope where the array is
1056 for isub, sub
in enumerate(namedE.findall(
'./{*}R-LT/{*}array-R/' +
1057 '{*}section-subscript-LT/' +
1058 '{*}section-subscript')):
1059 if ':' in alltext(sub):
1060 loopIndex, _, _ = scope.findIndexArrayBounds(namedE, isub, _loopVarPHYEX)
1061 if loopIndex
in indexToCheck:
1064 lowerBound = createElem(
'lower-bound')
1065 sub.insert(0, lowerBound)
1067 lowerBound = sub.find(
'./{*}lower-bound')
1068 lowerBound.tail =
''
1069 for item
in lowerBound:
1070 lowerBound.remove(item)
1071 upperBound = sub.find(
'./{*}upper-bound')
1072 if upperBound
is not None:
1073 sub.remove(upperBound)
1074 lowerBound.append(createExprPart(loopIndex))
1075 if loopIndex
not in indexRemoved:
1076 indexRemoved.append(loopIndex)
1079 if ':' not in alltext(namedE.find(
'./{*}R-LT/{*}array-R/{*}section-subscript-LT')):
1080 namedE.find(
'./{*}R-LT/{*}array-R').tag = f
'{{{NAMESPACE}}}parens-R'
1081 namedE.find(
'./{*}R-LT/{*}parens-R/' +
1082 '{*}section-subscript-LT').tag = f
'{{{NAMESPACE}}}element-LT'
1083 for ss
in namedE.findall(
'./{*}R-LT/{*}parens-R/' +
1084 '{*}element-LT/{*}section-subscript'):
1085 ss.tag = f
'{{{NAMESPACE}}}element'
1086 lowerBound = ss.find(
'./{*}lower-bound')
1087 for item
in lowerBound:
1089 ss.remove(lowerBound)
1092 self.attachArraySpecToEntity()
1096 for scope
in [scope
for scope
in self.getScopes()[::-1]
1097 if 'func:' not in scope.path
and
1098 (scope.path
in stopScopes
or
1099 self.tree.isUnderStopScopes(scope.path, stopScopes,
1100 includeInterfaces=
True))]:
1102 scope.addArrayParentheses()
1103 scope.expandAllArraysPHYEX()
1110 for doNode
in scope.findall(
'.//{*}do-construct')[::-1]:
1111 for loopI
in doNode.findall(
'./{*}do-stmt/{*}do-V/{*}named-E/{*}N'):
1112 loopIname = n2name(loopI).upper()
1113 if loopIname
in indexToCheck:
1116 par = scope.getParent(doNode)
1117 index = list(par).index(doNode)
1118 for item
in doNode[1:-1][::-1]:
1119 par.insert(index, item)
1121 if loopIname
not in indexRemoved:
1122 indexRemoved.append(loopIname)
1127 for intr
in scope.findall(
'.//{*}R-LT/{*}parens-R/../..'):
1128 intrName = n2name(intr.find(
'./{*}N')).upper()
1129 if intrName
in (
'PACK',
'UNPACK',
'COUNT',
'MAXVAL',
'MINVAL',
'ALL',
'ANY',
'SUM'):
1137 while par
is not None and not isStmt(par):
1138 par = scope.getParent(par)
1139 if tag(par)
in (
'a-stmt',
'op-E'):
1144 for namedE
in parToUse.findall(
'.//{*}R-LT/{*}array-R/../..'):
1145 slice2index(namedE, scope)
1148 if intr.find(
'.//{*}R-LT/{*}array-R')
is None:
1149 if intrName
in (
'MAXVAL',
'MINVAL',
'SUM',
'ALL',
'ANY'):
1151 parens = intr.find(
'./{*}R-LT/{*}parens-R')
1152 parens.tag = f
'{{{NAMESPACE}}}parens-E'
1153 intrPar = scope.getParent(intr)
1154 intrPar.insert(list(intrPar).index(intr), parens)
1155 intrPar.remove(intr)
1156 elif intrName ==
'COUNT':
1158 nodeN = intr.find(
'./{*}N')
1159 for item
in nodeN[1:]:
1161 nodeN.find(
'./{*}n').text =
'MERGE'
1162 elementLT = intr.find(
'./{*}R-LT/{*}parens-R/{*}element-LT')
1164 element = createElem(
'element', tail=
', ')
1165 element.append(createExprPart(val))
1166 elementLT.insert(0, element)
1179 assert scope.find(
'.//{*}include')
is None and \
1180 scope.find(
'.//{*}include-stmt')
is None, \
1181 "inlining must be performed before removing horizontal dimensions"
1183 if scope.path
in stopScopes:
1185 preserveShape = [v[
'n']
for v
in scope.varList
if v[
'arg']]
1190 if 'sub:' in scope.path:
1194 for namedE
in scope.findall(
'.//{*}named-E/{*}R-LT/{*}parens-R/../..'):
1195 if n2name(namedE.find(
'./{*}N')).upper()
not in preserveShape:
1196 var = scope.varList.findVar(n2name(namedE.find(
'./{*}N')).upper())
1197 if var
is not None and var[
'as']
is not None and len(var[
'as']) > 0:
1198 subs = namedE.findall(
'./{*}R-LT/{*}parens-R/' +
1199 '{*}element-LT/{*}element')
1200 if (len(subs) == 1
and var[
'as'][0][1]
in hUupperBounds)
or \
1201 (len(subs) == 2
and (var[
'as'][0][1]
in hUupperBounds
and
1202 var[
'as'][1][1]
in hUupperBounds)):
1203 namedE.remove(namedE.find(
'./{*}R-LT'))
1207 for call
in scope.findall(
'.//{*}call-stmt'):
1208 for namedE
in call.findall(
'./{*}arg-spec//{*}named-E'):
1209 subs = namedE.findall(
'.//{*}section-subscript')
1210 var = scope.varList.findVar(n2name(namedE.find(
'./{*}N')).upper())
1211 if len(subs) > 0
and (var
is None or var[
'as']
is None or
1212 len(var[
'as']) < len(subs)):
1220 elif (len(subs) >= 2
and
1221 ':' in alltext(subs[0])
and var[
'as'][0][1]
in hUupperBounds
and
1222 ':' in alltext(subs[1])
and var[
'as'][1][1]
in hUupperBounds):
1224 remove = len(subs) == 2
1225 index = (len(subs) > 2
and
1226 len([sub
for sub
in subs
if ':' in alltext(sub)]) == 2)
1227 elif (len(subs) >= 1
and
1228 ':' in alltext(subs[0])
and var[
'as'][0][1]
in hUupperBounds):
1230 remove = len(subs) == 1
1231 index = (len(subs) > 1
and
1232 len([sub
for sub
in subs
if ':' in alltext(sub)]) == 1)
1237 if n2name(namedE.find(
'./{*}N')).upper()
in preserveShape:
1238 slice2index(namedE, scope)
1240 nodeRLT = namedE.find(
'.//{*}R-LT')
1241 scope.getParent(nodeRLT).remove(nodeRLT)
1243 slice2index(namedE, scope)
1246 subs = namedE.findall(
'.//{*}section-subscript')
1247 if len(subs) > 0
and all(alltext(sub) ==
':' for sub
in subs):
1248 nodeRLT = namedE.find(
'.//{*}R-LT')
1249 scope.getParent(nodeRLT).remove(nodeRLT)
1254 for decl
in scope.findall(
'.//{*}T-decl-stmt/{*}EN-decl-LT/{*}EN-decl'):
1255 name = n2name(decl.find(
'./{*}EN-N/{*}N')).upper()
1256 if name
not in preserveShape:
1257 varsShape = decl.findall(
'.//{*}shape-spec-LT')
1258 for varShape
in varsShape:
1259 subs = varShape.findall(
'.//{*}shape-spec')
1260 if (len(subs) == 1
and alltext(subs[0])
in hUupperBounds)
or \
1261 (len(subs) == 2
and (alltext(subs[0])
in hUupperBounds
and
1262 alltext(subs[1])
in hUupperBounds)):
1264 itemToRemove = scope.getParent(varShape)
1265 scope.getParent(itemToRemove).remove(itemToRemove)
1270 for loopIndex
in indexRemoved:
1273 scope.insertStatement(
1274 createExpr(loopIndex +
" = " + indexToCheck[loopIndex][0])[0],
True)
1275 if len(indexRemoved) > 0:
1276 scope.addArgInTree(
'D',
'TYPE(DIMPHYEX_t), INTENT(IN) :: D',
1277 0, stopScopes, moduleVarList=[(
'MODD_DIMPHYEX', [
'DIMPHYEX_t'])],
1278 parserOptions=parserOptions, wrapH=wrapH)
1280 scope.addVar([[scope.path, loopIndex,
'INTEGER :: ' + loopIndex,
None]
1281 for loopIndex
in indexRemoved
1282 if scope.varList.findVar(loopIndex, exactScope=
True)
is None])
1478 Convert all calling of functions and gradient present in shumansGradients
1479 table into the use of subroutines
1480 and use mnh_expand_directives to handle intermediate computations
1482 def getDimsAndMNHExpandIndexes(zshugradwkDim, dimWorkingVar=''):
1484 if zshugradwkDim == 1:
1485 dimSuffRoutine =
'2D'
1487 mnhExpandArrayIndexes =
'JIJ=IIJB:IIJE'
1488 localVariables = [
'JIJ']
1489 elif zshugradwkDim == 2:
1491 if 'D%NKT' in dimWorkingVar:
1492 mnhExpandArrayIndexes =
'JIJ=IIJB:IIJE,JK=1:IKT'
1493 localVariables = [
'JIJ',
'JK']
1494 elif 'D%NIT' in dimWorkingVar
and 'D%NJT' in dimWorkingVar:
1496 mnhExpandArrayIndexes =
'JI=1:IIT,JJ=1:IJT'
1497 localVariables = [
'JI',
'JJ']
1498 dimSuffRoutine =
'2D'
1504 mnhExpandArrayIndexes =
'JIJ=IIJB:IIJE,JK=1:IKT'
1505 localVariables = [
'JIJ',
'JK']
1506 elif zshugradwkDim == 3:
1508 mnhExpandArrayIndexes =
'JI=1:IIT,JJ=1:IJT,JK=1:IKT'
1509 localVariables = [
'JI',
'JJ',
'JK']
1511 raise PYFTError(
'Shuman func to routine conversion not implemented ' +
1512 'for 4D+ dimensions variables')
1513 return dimSuffRoutine, dimSuffVar, mnhExpandArrayIndexes, localVariables
1515 def FUNCtoROUTINE(scope, stmt, itemFuncN, localShumansCount, inComputeStmt,
1516 nbzshugradwk, zshugradwkDim, dimWorkingVar):
1518 :param scope: node on which the calling function is present before transformation
1519 :param stmt: statement node (a-stmt or call-stmt) that contains the function(s) to be
1521 :param itemFuncN: <n>FUNCTIONNAME</n> node
1522 :param localShumansCount: instance of the shumansGradients dictionnary
1523 for the given scope (which contains the number of times a
1524 function has been called within a transformation)
1525 :param dimWorkingVar: string of the declaration of a potential working variable
1526 depending on the array on wich the shuman is applied
1527 (e.g. MZM(PRHODJ(:,IKB));
1528 dimWorkingVar = 'REAL, DIMENSION(D%NIJT) :: ' )
1530 :return callStmt: the new CALL to the routines statement
1531 :return computeStmt: the a-stmt computation statement if there was an operation
1532 in the calling function in stmt
1533 :return localVariables: list of local variables needed for the mnh_expand directive
1537 parStmt = scope.getParent(stmt)
1538 parItemFuncN = scope.getParent(itemFuncN)
1540 grandparItemFuncN = scope.getParent(itemFuncN, level=2)
1541 funcName = alltext(itemFuncN)
1544 indexForCall = list(parStmt).index(stmt)
1549 siblsItemFuncN = scope.getSiblings(parItemFuncN, after=
True, before=
False)
1550 workingItem = siblsItemFuncN[0][0][0]
1553 if len(siblsItemFuncN[0][0]) > 1:
1555 workingItem = scope.updateContinuation(siblsItemFuncN[0][0], removeALL=
True,
1556 align=
False, addBegin=
False)[0]
1560 opE = workingItem.findall(
'.//{*}op-E')
1561 scope.removeArrayParenthesesInNode(workingItem)
1562 computeStmt, remaningArgsofFunc = [],
''
1563 dimSuffVar = str(zshugradwkDim) +
'D'
1564 dimSuffRoutine, dimSuffVar, mnhExpandArrayIndexes, _ = \
1565 getDimsAndMNHExpandIndexes(zshugradwkDim, dimWorkingVar)
1568 computingVarName =
'ZSHUGRADWK'+str(nbzshugradwk)+
'_'+str(zshugradwkDim)+
'D'
1570 if not scope.varList.findVar(computingVarName):
1571 scope.addVar([[scope.path, computingVarName,
1572 dimWorkingVar + computingVarName,
None]])
1576 computeVar = scope.varList.findVar(computingVarName)
1577 dimWorkingVar =
'REAL, DIMENSION('
1578 for dims
in computeVar[
'as'][:arrayDim]:
1579 dimWorkingVar += dims[1] +
','
1580 dimWorkingVar = dimWorkingVar[:-1] +
') ::'
1582 dimSuffRoutine, dimSuffVar, mnhExpandArrayIndexes, localVariables = \
1583 getDimsAndMNHExpandIndexes(zshugradwkDim, dimWorkingVar)
1586 mnhOpenDir =
"!$mnh_expand_array(" + mnhExpandArrayIndexes +
")"
1587 mnhCloseDir =
"!$mnh_end_expand_array(" + mnhExpandArrayIndexes +
")"
1590 workingComputeItem = workingItem[0]
1592 if len(workingItem) == 2:
1593 remaningArgsofFunc =
',' + alltext(workingItem[1])
1594 elif len(workingItem) > 2:
1595 raise PYFTError(
'ShumanFUNCtoCALL: expected maximum 1 argument in shuman ' +
1596 'function to transform')
1597 computeStmt = createExpr(computingVarName +
" = " + alltext(workingComputeItem))[0]
1598 workingItem = computeStmt.find(
'.//{*}E-1')
1600 parStmt.insert(indexForCall, createElem(
'C', text=
'!$acc kernels', tail=
'\n'))
1601 parStmt.insert(indexForCall + 1, createElem(
'C', text=mnhOpenDir, tail=
'\n'))
1602 parStmt.insert(indexForCall + 2, computeStmt)
1603 parStmt.insert(indexForCall + 3, createElem(
'C', text=mnhCloseDir, tail=
'\n'))
1604 parStmt.insert(indexForCall + 4, createElem(
'C',
1605 text=
'!$acc end kernels', tail=
'\n'))
1606 parStmt.insert(indexForCall + 5, createElem(
'C',
1607 text=
'!', tail=
'\n'))
1611 if zshugradwkDim == 1:
1612 dimSuffRoutine =
'2D'
1613 workingVar =
'Z' + funcName + dimSuffVar +
'_WORK' + str(localShumansCount[funcName])
1614 if funcName
in (
'GY_U_UV',
'GX_V_UV'):
1615 gpuGradientImplementation =
'_DEVICE('
1616 newFuncName = funcName + dimSuffRoutine +
'_DEVICE'
1618 gpuGradientImplementation =
'_PHY(D, '
1619 newFuncName = funcName + dimSuffRoutine +
'_PHY'
1620 callStmt = createExpr(
"CALL " + funcName + dimSuffRoutine + gpuGradientImplementation
1621 + alltext(workingItem) + remaningArgsofFunc +
1622 ", " + workingVar +
")")[0]
1623 parStmt.insert(indexForCall, callStmt)
1626 parOfgrandparItemFuncN = scope.getParent(grandparItemFuncN)
1627 indexWorkingVar = list(parOfgrandparItemFuncN).index(grandparItemFuncN)
1628 savedTail = grandparItemFuncN.tail
1629 parOfgrandparItemFuncN.remove(grandparItemFuncN)
1632 xmlWorkingvar = createExprPart(workingVar)
1633 xmlWorkingvar.tail = savedTail
1634 parOfgrandparItemFuncN.insert(indexWorkingVar, xmlWorkingvar)
1637 if not scope.varList.findVar(workingVar):
1638 scope.addVar([[scope.path, workingVar, dimWorkingVar + workingVar,
None]])
1640 return (callStmt, computeStmt, nbzshugradwk, newFuncName,
1641 localVariables, mnhExpandArrayIndexes)
1643 shumansGradients = {
'MZM': 0,
'MXM': 0,
'MYM': 0,
'MZF': 0,
'MXF': 0,
'MYF': 0,
1644 'DZM': 0,
'DXM': 0,
'DYM': 0,
'DZF': 0,
'DXF': 0,
'DYF': 0,
1645 'GZ_M_W': 0,
'GZ_W_M': 0,
'GZ_U_UW': 0,
'GZ_V_VW': 0,
1646 'GX_M_U': 0,
'GX_U_M': 0,
'GX_W_UW': 0,
'GX_M_M': 0,
1647 'GY_V_M': 0,
'GY_M_V': 0,
'GY_W_VW': 0,
'GY_M_M': 0,
1648 'GX_V_UV': 0,
'GY_U_UV': 0}
1649 scopes = self.getScopes()
1650 if len(scopes) == 0
or scopes[0].path.split(
'/')[-1].split(
':')[1][:4] ==
'MODD':
1652 for scope
in scopes:
1653 if 'sub:' in scope.path
and 'func' not in scope.path \
1654 and 'interface' not in scope.path:
1657 localVariablesToAdd = set()
1658 foundStmtandCalls, computeStmtforParenthesis = {}, []
1659 aStmt = scope.findall(
'.//{*}a-stmt')
1660 callStmts = scope.findall(
'.//{*}call-stmt')
1661 aStmtandCallStmts = aStmt + callStmts
1662 funcToSuppress = set()
1663 for stmt
in aStmtandCallStmts:
1664 elemN = stmt.findall(
'.//{*}n')
1666 if alltext(el)
in list(shumansGradients):
1667 funcToSuppress.add(alltext(el))
1670 parStmt = scope.getParent(stmt)
1671 if tag(parStmt) ==
'action-stmt':
1672 scope.changeIfStatementsInIfConstructs(
1673 singleItem=scope.getParent(parStmt))
1675 if str(stmt)
in foundStmtandCalls:
1676 foundStmtandCalls[str(stmt)][1] += 1
1678 foundStmtandCalls[str(stmt)] = [stmt, 1]
1681 subToInclude = set()
1682 for stmt
in foundStmtandCalls:
1683 localShumansGradients = copy.deepcopy(shumansGradients)
1684 elemToLookFor = [foundStmtandCalls[stmt][0]]
1685 previousComputeStmt = []
1688 while len(elemToLookFor) > 0:
1690 for elem
in elemToLookFor:
1691 elemN = elem.findall(
'.//{*}n')
1693 if alltext(el)
in list(localShumansGradients.keys()):
1698 nodeE1var = foundStmtandCalls[stmt][0].findall(
1699 './/{*}E-1/{*}named-E/{*}N')
1700 if len(nodeE1var) > 0:
1701 var = scope.varList.findVar(alltext(nodeE1var[0]))
1702 allSubscripts = foundStmtandCalls[stmt][0].findall(
1703 './/{*}E-1//{*}named-E/{*}R-LT/' +
1704 '{*}array-R/{*}section-subscript-LT')
1708 elPar = scope.getParent(el, level=2)
1709 callVar = elPar.findall(
'.//{*}named-E/{*}N')
1710 if alltext(el)[0] ==
'G':
1715 var = scope.varList.findVar(alltext(callVar[-1]))
1716 shumanIsCalledOn = scope.getParent(callVar[-1])
1719 var, inested =
None, 0
1721 while (
not var
or var[
'as']
is None or
1722 len(var[
'as']) == 0):
1726 var = scope.varList.findVar(
1727 alltext(callVar[inested]))
1729 shumanIsCalledOn = scope.getParent(callVar[inested-1])
1730 allSubscripts = shumanIsCalledOn.findall(
1731 './/{*}R-LT/{*}array-R/' +
1732 '{*}section-subscript-LT')
1736 arrayDim = len(var[
'as'])
1740 if len(allSubscripts) > 0:
1741 for subLT
in allSubscripts:
1743 lowerBound = sub.findall(
'.//{*}lower-bound')
1744 if len(lowerBound) > 0:
1745 if len(sub.findall(
'.//{*}upper-bound')) > 0:
1748 raise PYFTError(
'ShumanFUNCtoCALL does ' +
1749 'not handle conversion ' +
1750 'to routine of array ' +
1751 'subselection lower:upper' +
1752 ': how to set up the ' +
1753 'shape of intermediate ' +
1763 dimWorkingVar =
'REAL, DIMENSION('
1764 for dims
in var[
'as'][:arrayDim]:
1765 dimWorkingVar += dims[1] +
','
1766 dimWorkingVar = dimWorkingVar[:-1] +
') ::'
1769 localShumansGradients[alltext(el)] += 1
1773 if foundStmtandCalls[stmt][0].tail:
1774 foundStmtandCalls[stmt][0].tail = \
1775 foundStmtandCalls[stmt][0].tail.replace(
'\n',
'') +
'\n'
1777 foundStmtandCalls[stmt][0].tail =
'\n'
1780 result = FUNCtoROUTINE(scope, elem, el,
1781 localShumansGradients,
1782 elem
in previousComputeStmt,
1783 nbzshugradwk, arrayDim,
1785 (newCallStmt, newComputeStmt,
1786 nbzshugradwk, newFuncName, lv,
1787 mnhExpandArrayIndexes) = result
1788 localVariablesToAdd.update(lv)
1789 subToInclude.add(newFuncName)
1792 elemToLookFor.append(newCallStmt)
1796 if len(newComputeStmt) > 0:
1797 elemToLookFor.append(newComputeStmt)
1798 computeStmtforParenthesis.append(
1799 [newComputeStmt, mnhExpandArrayIndexes])
1803 previousComputeStmt.append(newComputeStmt)
1807 elemToLookForNew = []
1808 for i
in elemToLookFor:
1809 nodeNs = i.findall(
'.//{*}n')
1812 if alltext(nnn)
in list(localShumansGradients):
1813 elemToLookForNew.append(i)
1815 elemToLookFor = elemToLookForNew
1818 if nbzshugradwk > maxnbZshugradwk:
1819 maxnbZshugradwk = nbzshugradwk
1824 if tag(foundStmtandCalls[stmt][0]) !=
'call-stmt':
1825 dimSuffRoutine, dimSuffVar, mnhExpandArrayIndexes, lv = \
1826 getDimsAndMNHExpandIndexes(arrayDim, dimWorkingVar)
1827 localVariablesToAdd.update(lv)
1829 scope.addArrayParenthesesInNode(foundStmtandCalls[stmt][0])
1831 table = {c.split(
'=')[0]: c.split(
'=')[1].split(
':')
1832 for c
in mnhExpandArrayIndexes.split(
',')}
1833 table.pop(
'OPENACC',
None)
1834 for namedE
in foundStmtandCalls[stmt][0].findall(
1836 arrayR = namedE.find(
'./{*}R-LT/{*}array-R')
1840 for ss
in arrayR.findall(
1841 './{*}section-subscript-LT/{*}section-subscript'):
1842 if ':' in (ss.text
or ''):
1844 varName = list(table.keys())[ivar]
1845 lowerStr, upperStr = table[varName]
1846 lb, ub = createArrayBounds(
1847 lowerStr, upperStr,
'ARRAY')
1851 parStmt = scope.getParent(foundStmtandCalls[stmt][0])
1852 indexForCall = list(parStmt).index(foundStmtandCalls[stmt][0])
1853 mnhOpenDir =
"!$mnh_expand_array(" + mnhExpandArrayIndexes +
")"
1854 mnhCloseDir =
"!$mnh_end_expand_array(" + mnhExpandArrayIndexes +
")"
1855 parStmt.insert(indexForCall,
1856 createElem(
'C', text=
"!$acc kernels", tail=
'\n'))
1857 parStmt.insert(indexForCall + 1,
1858 createElem(
'C', text=mnhOpenDir, tail=
'\n'))
1859 parStmt.insert(indexForCall + 3,
1860 createElem(
'C', text=mnhCloseDir, tail=
'\n'))
1861 parStmt.insert(indexForCall + 4,
1862 createElem(
'C', text=
"!$acc end kernels", tail=
'\n'))
1863 parStmt.insert(indexForCall + 5,
1864 createElem(
'C', text=
"!", tail=
'\n'))
1868 for stmt, mnhExpandArrayIndexes
in computeStmtforParenthesis:
1869 scope.addArrayParenthesesInNode(stmt)
1870 table = {c.split(
'=')[0]: c.split(
'=')[1].split(
':')
1871 for c
in mnhExpandArrayIndexes.split(
',')}
1872 table.pop(
'OPENACC',
None)
1873 for namedE
in stmt.findall(
'.//{*}R-LT/..'):
1874 arrayR = namedE.find(
'./{*}R-LT/{*}array-R')
1878 for ss
in arrayR.findall(
1879 './{*}section-subscript-LT/{*}section-subscript'):
1880 if ':' in (ss.text
or ''):
1882 varName = list(table.keys())[ivar]
1883 lowerStr, upperStr = table[varName]
1884 lb, ub = createArrayBounds(
1885 lowerStr, upperStr,
'ARRAY')
1891 for sub
in sorted(subToInclude):
1892 if re.match(
r'[MD][XYZ][MF](2D)?_PHY', sub):
1893 moduleVars.append((scope.path,
'MODE_SHUMAN_PHY', sub))
1894 if re.match(
r'[MD][XYZ][MF](2D)?_DEVICE', sub):
1895 moduleVars.append((scope.path,
'MODI_SHUMAN_DEVICE', sub))
1897 for kind
in (
'M',
'U',
'V',
'W'):
1898 if re.match(
r'G[XYZ]_' + kind +
r'_[MUVW]{1,2}_PHY', sub):
1899 moduleVars.append((scope.path, f
'MODE_GRADIENT_{kind}_PHY', sub))
1900 elif re.match(
r'G[XYZ]_' + kind +
r'_[MUVW]{1,2}_DEVICE', sub):
1901 moduleVars.append((scope.path, f
'MODI_GRADIENT_{kind}', sub))
1902 scope.addModuleVar(moduleVars)
1905 for sub
in funcToSuppress:
1906 if scope.varList.findVar(sub):
1907 scope.removeVar([(scope.path, sub)])
1910 for varName
in localVariablesToAdd:
1911 if not scope.varList.findVar(varName):
1912 var = {
'as': [],
'asx': [],
1913 'n': varName,
'i':
None,
't':
'INTEGER',
'arg':
False,
1914 'use':
False,
'opt':
False,
'allocatable':
False,
1915 'parameter':
False,
'init':
None,
'scopePath': scope.path}
1916 scope.addVar([[scope.path, var[
'n'], scope.varSpec2stmt(var),
None]])