251 Inline structure member accesses in compute statements.
253 Converts TYPE%VAR access patterns into single local variables
254 to improve performance by reducing pointer dereferences.
256 Transformation Examples
257 ----------------------
258 Simple member access:
259 - ZA = 1 + CST%XG => ZA = 1 + XCST_G
262 - ZA = 1 + PARAM_ICE%XRTMIN(3) => ZA = 1 + XPARAM_ICE_XRTMIN3
265 - ZRSMIN(1:KRR) = ICED%XRTMIN(1:KRR) => ZRSMIN(1:KRR) = ICEDXRTMIN1KRR(1:KRR)
268 - IF(TURBN%CSUBG_MF_PDF=='NONE') => IF(CTURBNSUBG_MF_PDF=='NONE')
272 - Only handles single-level structure access (not TOTO%CST%XG)
273 - Does not handle arrays with deferred shapes
274 - Type components must have known dimensions
276 def convertOneType(component, newVarList, scope):
278 objType = scope.getParent(component, 2)
279 objTypeStr = alltext(objType).upper()
280 namedENn = objType.find(
'.//{*}N/{*}n')
281 structure = namedENn.text
282 variable = component.find(
'.//{*}ct').text.upper()
283 if variable[0] ==
"T":
288 arrayRall = objType.findall(
'.//{*}array-R')
289 if len(arrayRall) > 0:
290 arrayR = copy.deepcopy(arrayRall[0])
291 txt = alltext(arrayR).replace(
',',
'')
292 txt = txt.replace(
':',
'')
293 txt = txt.replace(
'(',
'')
294 txt = txt.replace(
')',
'')
295 arrayIndices = arrayIndices + txt
296 elif len(objType.findall(
'.//{*}element-LT')) > 0:
298 for elem
in objType.findall(
'.//{*}element'):
299 arrayIndices = arrayIndices + alltext(elem)
300 newName = variable[0] + structure + variable[1:] + arrayIndices
301 newName = newName.upper()
305 namedENn.text = newName
306 objType.remove(objType.find(
'.//{*}R-LT'))
307 if len(arrayRall) > 0:
308 objType.insert(1, arrayR)
311 if newName
not in newVarList:
312 if len(arrayRall) == 0:
313 newVarList[newName] = (
None, objTypeStr)
315 newVarList[newName] = (arrayR, objTypeStr)
317 scopes = self.getScopes()
318 if scopes[0].path.split(
'/')[-1].split(
':')[1][:4] ==
'MODD':
320 for scope
in [scope
for scope
in scopes
321 if 'sub:' in scope.path
and 'interface' not in scope.path]:
323 for ifStmt
in (scope.findall(
'.//{*}if-then-stmt') +
324 scope.findall(
'.//{*}else-if-stmt') +
325 scope.findall(
'.//{*}where-stmt')):
326 compo = ifStmt.findall(
'.//{*}component-R')
328 for elcompo
in compo:
329 convertOneType(elcompo, newVarList, scope)
331 for aStmt
in scope.findall(
'.//{*}a-stmt'):
334 if len(aStmt[0].findall(
'.//{*}component-R')) == 0:
335 compoE2 = aStmt.findall(
'.//{*}component-R')
342 nbNamedEinE2 = len(aStmt.findall(
'.//{*}E-2')[0].findall(
'.//{*}named-E/' +
344 if nbNamedEinE2 > 1
or nbNamedEinE2 == 1
and \
345 len(aStmt[0].findall(
'.//{*}R-LT')) == 1:
346 for elcompoE2
in compoE2:
347 convertOneType(elcompoE2, newVarList, scope)
351 for aStmt
in scope.findall(
'.//{*}a-stmt'):
353 if len(aStmt[0].findall(
'.//{*}component-R')) > 0:
355 for el
in newVarList.items():
356 if alltext(aStmt[0]) == el[1][1]:
357 stmtAffect = createExpr(el[0] +
"=" + alltext(aStmt[0]))
358 par = scope.getParent(aStmt)
360 if tag(par[list(par).index(aStmt)+1]) ==
'C':
362 par.insert(list(par).index(aStmt)+1+iExtra, stmtAffect[0])
366 for el, var
in newVarList.items():
367 if el[0].upper() ==
'X' or el[0].upper() ==
'P' or el[0].upper() ==
'Z':
369 elif el[0].upper() ==
'L' or el[0].upper() ==
'O':
371 elif el[0].upper() ==
'N' or el[0].upper() ==
'I' or el[0].upper() ==
'K':
373 elif el[0].upper() ==
'C':
374 varType =
'CHARACTER(LEN=LEN(' + var[1] +
'))'
376 raise PYFTError(
'Case not implemented for the first letter of the newVarName ' +
377 el +
' in convertTypesInCompute')
381 varArray =
', DIMENSION('
382 for i, sub
in enumerate(var[0].findall(
'.//{*}section-subscript')):
383 if len(sub.findall(
'.//{*}upper-bound')) > 0:
384 dimSize = simplifyExpr(
385 alltext(sub.findall(
'.//{*}upper-bound')[0]) +
386 '-' + alltext(sub.findall(
'.//{*}lower-bound')[0]) +
388 elif len(sub.findall(
'.//{*}lover-bound')) > 0:
389 dimSize = simplifyExpr(alltext(sub.findall(
'.//{*}lower-bound')[0]))
391 dimSize =
'SIZE(' + var[1] +
',' + str(i+1) +
')'
392 varArray =
', DIMENSION(' + dimSize +
','
393 varArray = varArray[:-1] +
')'
394 scope.addVar([[scope.path, el, varType + varArray +
' :: ' + el,
None]])
397 stmtAffect = createExpr(el +
"=" + var[1])[0]
398 scope.insertStatement(scope.indent(stmtAffect), first=
True)
504 Convert MODULE to SUBMODULE statements and add INTERFACE of SUBROUTINEs of PHYEX
505 ==> Applied only on MODE_
507 - if an INTERFACE already exists
508 - if no subroutine is present in the module
509 - to CONTAINS routines
510 1) Create interface statement if any
511 2) Add subroutines declaration (with MODULE statement)
512 3) Add SUBMODULE statements and convert SUBROUTINE to MODULE SUBROUTINE statements
514 scopes = self.getScopes()
517 oldModNode = self.find(
'.//{*}program-unit')
518 modNode = copy.deepcopy(self.find(
'.//{*}program-unit'))
520 interfaceStmt = self.findall(
'.//{*}interface-stmt')
521 subStmt = self.findall(
'.//{*}subroutine-stmt')
523 if modScope.path.split(
'/')[-1].split(
':')[1][:4] ==
'MODE' and \
524 len(interfaceStmt) == 0
and len(subStmt) > 0:
525 moduleName = modScope.path.split(
'/')[-1].split(
':')[1][:]
527 newMod = createElem(
'program-unit', text=
'MODULE ' + moduleName, tail=
'\n')
529 newMod.append(createElem(
'implicit-none-stmt', text=
'IMPLICIT NONE', tail=
'\n'))
530 interfaceStmt = createElem(
'interface-construct')
531 interfaceStmt.append(createElem(
'interface-stmt', text=
'INTERFACE', tail=
'\n'))
532 interfaceStmt.append(createElem(
'end-interface-stmt', text=
'END INTERFACE', tail=
'\n'))
533 newMod.append(interfaceStmt)
537 for scope
in scopes[1:]:
539 if sum(
'sub' in s
for s
in scope.path.split(
'/')) == 1:
540 subsModified.append(scope.path.split(
'/')[-1].split(
':')[1][:])
541 subroutineDecl = createElem(
'module-unit')
543 subroutineStmt = copy.deepcopy(scope[0])
544 declType = subroutineStmt.text
545 prefix = createElem(
'prefix')
546 prefix.text =
'MODULE'
547 subroutineStmt.text =
''
548 subroutineStmt.insert(0, prefix)
549 prefix.tail =
' ' + declType
550 subroutineDecl.append(subroutineStmt)
552 for use
in scope.findall(
'.//{*}use-stmt'):
553 subroutineDecl.append(copy.deepcopy(use))
554 subroutineDecl.append(createElem(
'implicit-none-stmt', text=
'IMPLICIT NONE',
557 for var
in [var
for var
in scope.varList
if var[
'arg']
or var[
'result']]:
558 subroutineDecl.append(createExpr(self.varSpec2stmt(var,
True))[0])
559 for external
in scope.findall(
'./{*}external-stmt'):
560 subroutineDecl.append(copy.deepcopy(external))
561 if 'SUBROUTINE' in declType:
562 endStmt = createElem(
'end-subroutine-stmt')
563 declName = subroutineStmt.find(
'./{*}subroutine-N/{*}N/{*}n').text
564 elif 'FUNCTION' in declType:
565 endStmt = createElem(
'end-function-stmt')
566 declName = subroutineStmt.find(
'./{*}function-N/{*}N/{*}n').text
568 raise PYFTError(
'declType in addSubmodulePHYEX not handled')
570 endStmt.text =
'END ' + declType + declName +
'\n'
571 subroutineDecl.append(endStmt)
572 interfaceStmt.insert(1, subroutineDecl)
575 newMod.append(createElem(
'end-program-unit', text=
'END MODULE ' + moduleName,
577 self[0].insert(0, newMod)
580 progUnit = createElem(
'program-unit')
591 submoduleStmt = createElem(
'submodule-stmt', text=
'SUBMODULE (')
592 parentId = createElem(
'parent-identifier')
594 ancestorModule = createElem(
'ancestor-module-N')
595 ancestorModuleN = createElem(
'n', text=moduleName)
596 ancestorModule.append(ancestorModuleN)
597 parentId.append(ancestorModule)
598 submoduleStmt.append(parentId)
599 submoduleModule = createElem(
'submodule-module-N')
600 submoduleModuleN = createElem(
'n', text=
'S' + moduleName, tail=
'\n')
601 submoduleModule.append(submoduleModuleN)
602 submoduleStmt.append(submoduleModule)
603 progUnit.append(submoduleStmt)
606 endSubmoduleStmt = createElem(
'end-submodule-stmt', text=
'END SUBMODULE ')
607 submoduleN = createElem(
'submodule-N')
608 submoduleNN = createElem(
'N')
609 submoduleNNn = createElem(
'n', text=
'S' + moduleName, tail=
'\n')
610 submoduleNN.append(submoduleNNn)
611 submoduleN.append(submoduleNN)
612 endSubmoduleStmt.append(submoduleN)
614 progUnit.append(endSubmoduleStmt)
615 progUnit.append(createElem(
'end-program-unit'))
619 modStmt = modNode.find(
'.//{*}module-stmt')
620 modEndStmt = modNode.find(
'.//{*}end-module-stmt')
621 modNode.remove(modStmt)
622 modNode.remove(modEndStmt)
625 publicStmts = modNode.findall(
'.//{*}public-stmt')
626 privateStmts = modNode.findall(
'.//{*}private-stmt')
627 if len(publicStmts) > 0:
628 for publicStmt
in publicStmts:
629 modNode.remove(publicStmt)
630 if len(privateStmts) > 0:
631 for privateStmt
in privateStmts:
632 modNode.remove(privateStmt)
635 subroutines = modNode.findall(
'.//{*}subroutine-stmt')
636 for sub
in subroutines:
637 if sub.find(
'.//{*}N/{*}n').text
in subsModified:
638 prefix = createElem(
'prefix')
639 prefix.text =
'MODULE'
641 sub.insert(0, prefix)
642 prefix.tail =
' SUBROUTINE '
643 progUnit.insert(1, modNode)
645 self.insert(1, progUnit)
648 self.remove(oldModNode)
654 Add MPPDB_CHEKS on all intent REAL arrays on subroutines.
655 ****** Not applied on modd_ routines. ********
656 Handle optional arguments.
657 Example, for a BL89 routine with 4 arguments, 1 INTENT(IN),
658 2 INTENT(INOUT), 1 INTENT(OUT), it produces :
659 IF (MPPDB_INITIALIZED) THEN
661 CALL MPPDB_CHECK(PZZ, "BL89 beg:PZZ")
662 !Check all INOUT arrays
663 CALL MPPDB_CHECK(PDZZ, "BL89 beg:PDZZ")
664 CALL MPPDB_CHECK(PTHVREF, "BL89 beg:PTHVREF")
667 IF (MPPDB_INITIALIZED) THEN
668 !Check all INOUT arrays
669 CALL MPPDB_CHECK(PDZZ, "BL89 end:PDZZ")
670 CALL MPPDB_CHECK(PTHVREF, "BL89 end:PTHVREF")
671 !Check all OUT arrays
672 CALL MPPDB_CHECK(PLM, "BL89 end:PLM")
674 param printsMode: if True, instead of CALL MPPDB_CHECK, add fortran prints for debugging
676 def addPrints_statement(var, typeofPrints='minmax'):
677 ifBeg, ifEnd =
'',
''
680 if typeofPrints ==
'minmax':
681 strMSG = f
'MINMAX {varName} = \",MINVAL({varName}), MAXVAL({varName})'
682 elif typeofPrints ==
'shape':
683 strMSG = f
'SHAPE {varName} = \",SHAPE({varName})'
685 raise PYFTError(
'typeofPrints is either minmax or shape in addPrints_statement')
687 strMSG = var[
'n'] +
' = \",' + var[
'n']
689 ifBeg = ifBeg +
'IF (PRESENT(' + var[
'n'] +
')) THEN\n '
690 ifEnd = ifEnd +
'\nEND IF'
691 return createExpr(ifBeg +
"print*,\"" + strMSG + ifEnd)[0]
693 def addMPPDB_CHECK_statement(var, subRoutineName, strMSG='beg:
'):
694 ifBeg, ifEnd, addD, addLastDim, addSecondDimType = '',
'',
'',
'',
''
698 if 'D%NIJT' in var[
'as'][0][1]:
700 if len(var[
'as']) == 2:
702 addLastDim =
', ' + var[
'as'][1][1]
703 if len(var[
'as']) >= 2:
706 if 'D%NK' in var[
'as'][1][1]:
707 addSecondDimType =
',' +
'''"VERTICAL"'''
709 addSecondDimType =
',' +
'''"OTHER"'''
710 if 'MERGE' in var[
'as'][-1][1]:
711 keyDimMerge = var[
'as'][-1][1].split(
',')[2][:-1]
712 ifBeg =
'IF (' + keyDimMerge +
') THEN\n'
715 ifBeg = ifBeg +
'IF (PRESENT(' + var[
'n'] +
')) THEN\n IF (SIZE(' + \
716 var[
'n'] +
',1) > 0) THEN\n'
717 ifEnd = ifEnd +
'\nEND IF\nEND IF'
718 argsMPPDB = var[
'n'] +
", " +
"\"" + subRoutineName +
" " + strMSG+var[
'n'] +
"\""
719 return createExpr(ifBeg +
"CALL MPPDB_CHECK(" + addD + argsMPPDB +
720 addLastDim + addSecondDimType +
")" + ifEnd)[0]
721 scopes = self.getScopes()
722 if scopes[0].path.split(
'/')[-1].split(
':')[1][:4] ==
'MODD':
729 if 'sub:' in scope.path
and 'func' not in scope.path
and 'interface' not in scope.path:
730 subRoutineName = scope.path.split(
'/')[-1].split(
':')[1]
733 arraysIn, arraysInOut, arraysOut = [], [], []
735 for var
in scope.varList:
736 if var[
'arg']
and var[
'as']
and 'TYPE' not in var[
't']
and \
737 'REAL' in var[
't']
and var[
'scopePath'] == scope.path:
740 if var[
'i'] ==
'INOUT':
741 arraysInOut.append(var)
742 if var[
'i'] ==
'OUT':
743 arraysOut.append(var)
745 for var
in scope.varList:
746 if not var[
't']
or var[
't']
and 'TYPE' not in var[
't']:
749 if var[
'i'] ==
'INOUT':
750 arraysInOut.append(var)
751 if var[
'i'] ==
'OUT':
752 arraysOut.append(var)
754 if len(arraysIn) + len(arraysInOut) + len(arraysOut) == 0:
759 scope.addModuleVar([(scope.path,
'MODE_MPPDB',
None)])
761 scope.addModuleVar([(scope.path,
'MODD_BLANK_n', [
'LDUMMY1'])])
764 commentIN = createElem(
'C', text=
'!Check all IN arrays', tail=
'\n')
765 commentINOUT = createElem(
'C', text=
'!Check all INOUT arrays', tail=
'\n')
766 commentOUT = createElem(
'C', text=
'!Check all OUT arrays', tail=
'\n')
769 if len(arraysIn) + len(arraysInOut) > 0:
771 ifMPPDBinit = createExpr(
"IF (MPPDB_INITIALIZED) THEN\n END IF")[0]
773 ifMPPDBinit = createExpr(
"IF (LDUMMY1) THEN\n END IF")[0]
774 ifMPPDB = ifMPPDBinit.find(
'.//{*}if-block')
777 if len(arraysIn) > 0:
778 ifMPPDB.insert(1, commentIN)
779 for i, var
in enumerate(arraysIn):
781 ifMPPDB.insert(2 + i, addMPPDB_CHECK_statement(var, subRoutineName,
784 ifMPPDB.insert(2 + i, addPrints_statement(var,
785 typeofPrints=
'minmax'))
786 ifMPPDB.insert(3 + i, addPrints_statement(var,
787 typeofPrints=
'shape'))
790 if len(arraysInOut) > 0:
791 shiftLineNumber = 2
if len(arraysIn) > 0
else 1
793 ifMPPDB.insert(len(arraysIn) + shiftLineNumber, commentINOUT)
795 ifMPPDB.insert(len(arraysIn)*2 + shiftLineNumber-1, commentINOUT)
797 for i, var
in enumerate(arraysInOut):
799 ifMPPDB.insert(len(arraysIn) + shiftLineNumber + 1 + i,
800 addMPPDB_CHECK_statement(var, subRoutineName,
803 ifMPPDB.insert(len(arraysIn) + shiftLineNumber + 1 + i,
804 addPrints_statement(var, typeofPrints=
'minmax'))
807 scope.insertStatement(scope.indent(ifMPPDBinit), first=
True)
810 if len(arraysInOut) + len(arraysOut) > 0:
812 ifMPPDBend = createExpr(
"IF (MPPDB_INITIALIZED) THEN\n END IF")[0]
814 ifMPPDBend = createExpr(
"IF (LDUMMY1) THEN\n END IF")[0]
815 ifMPPDB = ifMPPDBend.find(
'.//{*}if-block')
818 if len(arraysInOut) > 0:
819 ifMPPDB.insert(1, commentINOUT)
820 for i, var
in enumerate(arraysInOut):
822 ifMPPDB.insert(2 + i, addMPPDB_CHECK_statement(var, subRoutineName,
825 ifMPPDB.insert(2 + i, addPrints_statement(var,
826 typeofPrints=
'minmax'))
829 if len(arraysOut) > 0:
830 shiftLineNumber = 2
if len(arraysInOut) > 0
else 1
832 ifMPPDB.insert(len(arraysInOut) + shiftLineNumber, commentOUT)
834 ifMPPDB.insert(len(arraysInOut)*2 + shiftLineNumber-1, commentOUT)
835 for i, var
in enumerate(arraysOut):
837 ifMPPDB.insert(len(arraysInOut) + shiftLineNumber + 1 + i,
838 addMPPDB_CHECK_statement(var, subRoutineName,
841 ifMPPDB.insert(len(arraysInOut) + shiftLineNumber + 1 + i,
842 addPrints_statement(var, typeofPrints=
'minmax'))
845 scope.insertStatement(scope.indent(ifMPPDBend), first=
False)
970 def removeIJDim(self, stopScopes, parserOptions=None, wrapH=False, simplify=False):
972 Transform routines to be called in a loop on columns
973 :param stopScopes: scope paths where we stop to add the D argument (if needed)
974 :param parserOptions, wrapH: see the PYFT class
975 :param simplify: try to simplify code (remove useless dimensions in call)
977 ComputeInSingleColumn :
978 - Remove all Do loops on JI and JJ
979 - Initialize former indexes JI, JJ, JIJ to first array element:
980 JI=D%NIB, JJ=D%NJB, JIJ=D%NIJB
981 - If simplify is True, replace (:,*) on I/J/IJ dimension on argument
982 with explicit (:,*) on CALL statements:
983 e.g. CALL FOO(D, A(:,JK,1), B(:,:))
984 ==> CALL FOO(D, A(JIJ,JK,1), B(:,:)) only if the target argument is not an array
987 indexToCheck = {
'JI': (
'D%NIB',
'D%NIT'),
988 'JJ': (
'D%NJB',
'D%NJT'),
989 'JIJ': (
'D%NIJB',
'D%NIJT')}
990 hUupperBounds = [v[1]
for v
in indexToCheck.values()]
992 def slice2index(namedE, scope):
994 Transform a slice on the horizontal dimension into an index
995 Eg.: X(1:D%NIJT, 1:D%NKT) => X(JIJ, 1:D%NKT) Be careful, this array is not contiguous.
996 X(1:D%NIJT, JK) => X(JIJ, JK)
997 :param namedE: array to transform
998 :param scope: scope where the array is
1001 for isub, sub
in enumerate(namedE.findall(
'./{*}R-LT/{*}array-R/' +
1002 '{*}section-subscript-LT/' +
1003 '{*}section-subscript')):
1004 if ':' in alltext(sub):
1005 loopIndex, _, _ = scope.findIndexArrayBounds(namedE, isub, _loopVarPHYEX)
1006 if loopIndex
in indexToCheck:
1009 lowerBound = createElem(
'lower-bound')
1010 sub.insert(0, lowerBound)
1012 lowerBound = sub.find(
'./{*}lower-bound')
1013 lowerBound.tail =
''
1014 for item
in lowerBound:
1015 lowerBound.remove(item)
1016 upperBound = sub.find(
'./{*}upper-bound')
1017 if upperBound
is not None:
1018 sub.remove(upperBound)
1019 lowerBound.append(createExprPart(loopIndex))
1020 if loopIndex
not in indexRemoved:
1021 indexRemoved.append(loopIndex)
1024 if ':' not in alltext(namedE.find(
'./{*}R-LT/{*}array-R/{*}section-subscript-LT')):
1025 namedE.find(
'./{*}R-LT/{*}array-R').tag = f
'{{{NAMESPACE}}}parens-R'
1026 namedE.find(
'./{*}R-LT/{*}parens-R/' +
1027 '{*}section-subscript-LT').tag = f
'{{{NAMESPACE}}}element-LT'
1028 for ss
in namedE.findall(
'./{*}R-LT/{*}parens-R/' +
1029 '{*}element-LT/{*}section-subscript'):
1030 ss.tag = f
'{{{NAMESPACE}}}element'
1031 lowerBound = ss.find(
'./{*}lower-bound')
1032 for item
in lowerBound:
1034 ss.remove(lowerBound)
1037 self.attachArraySpecToEntity()
1041 for scope
in [scope
for scope
in self.getScopes()[::-1]
1042 if 'func:' not in scope.path
and
1043 (scope.path
in stopScopes
or
1044 self.tree.isUnderStopScopes(scope.path, stopScopes,
1045 includeInterfaces=
True))]:
1047 scope.addArrayParentheses()
1048 scope.expandAllArraysPHYEX()
1055 for doNode
in scope.findall(
'.//{*}do-construct')[::-1]:
1056 for loopI
in doNode.findall(
'./{*}do-stmt/{*}do-V/{*}named-E/{*}N'):
1057 loopIname = n2name(loopI).upper()
1058 if loopIname
in indexToCheck:
1061 par = scope.getParent(doNode)
1062 index = list(par).index(doNode)
1063 for item
in doNode[1:-1][::-1]:
1064 par.insert(index, item)
1066 if loopIname
not in indexRemoved:
1067 indexRemoved.append(loopIname)
1072 for intr
in scope.findall(
'.//{*}R-LT/{*}parens-R/../..'):
1073 intrName = n2name(intr.find(
'./{*}N')).upper()
1074 if intrName
in (
'PACK',
'UNPACK',
'COUNT',
'MAXVAL',
'MINVAL',
'ALL',
'ANY',
'SUM'):
1082 while par
is not None and not isStmt(par):
1083 par = scope.getParent(par)
1084 if tag(par)
in (
'a-stmt',
'op-E'):
1089 for namedE
in parToUse.findall(
'.//{*}R-LT/{*}array-R/../..'):
1090 slice2index(namedE, scope)
1093 if intr.find(
'.//{*}R-LT/{*}array-R')
is None:
1094 if intrName
in (
'MAXVAL',
'MINVAL',
'SUM',
'ALL',
'ANY'):
1096 parens = intr.find(
'./{*}R-LT/{*}parens-R')
1097 parens.tag = f
'{{{NAMESPACE}}}parens-E'
1098 intrPar = scope.getParent(intr)
1099 intrPar.insert(list(intrPar).index(intr), parens)
1100 intrPar.remove(intr)
1101 elif intrName ==
'COUNT':
1103 nodeN = intr.find(
'./{*}N')
1104 for item
in nodeN[1:]:
1106 nodeN.find(
'./{*}n').text =
'MERGE'
1107 elementLT = intr.find(
'./{*}R-LT/{*}parens-R/{*}element-LT')
1109 element = createElem(
'element', tail=
', ')
1110 element.append(createExprPart(val))
1111 elementLT.insert(0, element)
1124 assert scope.find(
'.//{*}include')
is None and \
1125 scope.find(
'.//{*}include-stmt')
is None, \
1126 "inlining must be performed before removing horizontal dimensions"
1128 if scope.path
in stopScopes:
1130 preserveShape = [v[
'n']
for v
in scope.varList
if v[
'arg']]
1135 if 'sub:' in scope.path:
1139 for namedE
in scope.findall(
'.//{*}named-E/{*}R-LT/{*}parens-R/../..'):
1140 if n2name(namedE.find(
'./{*}N')).upper()
not in preserveShape:
1141 var = scope.varList.findVar(n2name(namedE.find(
'./{*}N')).upper())
1142 if var
is not None and var[
'as']
is not None and len(var[
'as']) > 0:
1143 subs = namedE.findall(
'./{*}R-LT/{*}parens-R/' +
1144 '{*}element-LT/{*}element')
1145 if (len(subs) == 1
and var[
'as'][0][1]
in hUupperBounds)
or \
1146 (len(subs) == 2
and (var[
'as'][0][1]
in hUupperBounds
and
1147 var[
'as'][1][1]
in hUupperBounds)):
1148 namedE.remove(namedE.find(
'./{*}R-LT'))
1152 for call
in scope.findall(
'.//{*}call-stmt'):
1153 for namedE
in call.findall(
'./{*}arg-spec//{*}named-E'):
1154 subs = namedE.findall(
'.//{*}section-subscript')
1155 var = scope.varList.findVar(n2name(namedE.find(
'./{*}N')).upper())
1156 if len(subs) > 0
and (var
is None or var[
'as']
is None or
1157 len(var[
'as']) < len(subs)):
1165 elif (len(subs) >= 2
and
1166 ':' in alltext(subs[0])
and var[
'as'][0][1]
in hUupperBounds
and
1167 ':' in alltext(subs[1])
and var[
'as'][1][1]
in hUupperBounds):
1169 remove = len(subs) == 2
1170 index = (len(subs) > 2
and
1171 len([sub
for sub
in subs
if ':' in alltext(sub)]) == 2)
1172 elif (len(subs) >= 1
and
1173 ':' in alltext(subs[0])
and var[
'as'][0][1]
in hUupperBounds):
1175 remove = len(subs) == 1
1176 index = (len(subs) > 1
and
1177 len([sub
for sub
in subs
if ':' in alltext(sub)]) == 1)
1182 if n2name(namedE.find(
'./{*}N')).upper()
in preserveShape:
1183 slice2index(namedE, scope)
1185 nodeRLT = namedE.find(
'.//{*}R-LT')
1186 scope.getParent(nodeRLT).remove(nodeRLT)
1188 slice2index(namedE, scope)
1191 subs = namedE.findall(
'.//{*}section-subscript')
1192 if len(subs) > 0
and all(alltext(sub) ==
':' for sub
in subs):
1193 nodeRLT = namedE.find(
'.//{*}R-LT')
1194 scope.getParent(nodeRLT).remove(nodeRLT)
1199 for decl
in scope.findall(
'.//{*}T-decl-stmt/{*}EN-decl-LT/{*}EN-decl'):
1200 name = n2name(decl.find(
'./{*}EN-N/{*}N')).upper()
1201 if name
not in preserveShape:
1202 varsShape = decl.findall(
'.//{*}shape-spec-LT')
1203 for varShape
in varsShape:
1204 subs = varShape.findall(
'.//{*}shape-spec')
1205 if (len(subs) == 1
and alltext(subs[0])
in hUupperBounds)
or \
1206 (len(subs) == 2
and (alltext(subs[0])
in hUupperBounds
and
1207 alltext(subs[1])
in hUupperBounds)):
1209 itemToRemove = scope.getParent(varShape)
1210 scope.getParent(itemToRemove).remove(itemToRemove)
1215 for loopIndex
in indexRemoved:
1218 scope.insertStatement(
1219 createExpr(loopIndex +
" = " + indexToCheck[loopIndex][0])[0],
True)
1220 if len(indexRemoved) > 0:
1221 scope.addArgInTree(
'D',
'TYPE(DIMPHYEX_t), INTENT(IN) :: D',
1222 0, stopScopes, moduleVarList=[(
'MODD_DIMPHYEX', [
'DIMPHYEX_t'])],
1223 parserOptions=parserOptions, wrapH=wrapH)
1225 scope.addVar([[scope.path, loopIndex,
'INTEGER :: ' + loopIndex,
None]
1226 for loopIndex
in indexRemoved
1227 if scope.varList.findVar(loopIndex, exactScope=
True)
is None])
1423 Convert all calling of functions and gradient present in shumansGradients
1424 table into the use of subroutines
1425 and use mnh_expand_directives to handle intermediate computations
1427 def getDimsAndMNHExpandIndexes(zshugradwkDim, dimWorkingVar=''):
1429 if zshugradwkDim == 1:
1430 dimSuffRoutine =
'2D'
1432 mnhExpandArrayIndexes =
'JIJ=IIJB:IIJE'
1433 localVariables = [
'JIJ']
1434 elif zshugradwkDim == 2:
1436 if 'D%NKT' in dimWorkingVar:
1437 mnhExpandArrayIndexes =
'JIJ=IIJB:IIJE,JK=1:IKT'
1438 localVariables = [
'JIJ',
'JK']
1439 elif 'D%NIT' in dimWorkingVar
and 'D%NJT' in dimWorkingVar:
1441 mnhExpandArrayIndexes =
'JI=1:IIT,JJ=1:IJT'
1442 localVariables = [
'JI',
'JJ']
1443 dimSuffRoutine =
'2D'
1449 mnhExpandArrayIndexes =
'JIJ=IIJB:IIJE,JK=1:IKT'
1450 localVariables = [
'JIJ',
'JK']
1451 elif zshugradwkDim == 3:
1453 mnhExpandArrayIndexes =
'JI=1:IIT,JJ=1:IJT,JK=1:IKT'
1454 localVariables = [
'JI',
'JJ',
'JK']
1456 raise PYFTError(
'Shuman func to routine conversion not implemented ' +
1457 'for 4D+ dimensions variables')
1458 return dimSuffRoutine, dimSuffVar, mnhExpandArrayIndexes, localVariables
1460 def FUNCtoROUTINE(scope, stmt, itemFuncN, localShumansCount, inComputeStmt,
1461 nbzshugradwk, zshugradwkDim, dimWorkingVar):
1463 :param scope: node on which the calling function is present before transformation
1464 :param stmt: statement node (a-stmt or call-stmt) that contains the function(s) to be
1466 :param itemFuncN: <n>FUNCTIONNAME</n> node
1467 :param localShumansCount: instance of the shumansGradients dictionnary
1468 for the given scope (which contains the number of times a
1469 function has been called within a transformation)
1470 :param dimWorkingVar: string of the declaration of a potential working variable
1471 depending on the array on wich the shuman is applied
1472 (e.g. MZM(PRHODJ(:,IKB));
1473 dimWorkingVar = 'REAL, DIMENSION(D%NIJT) :: ' )
1475 :return callStmt: the new CALL to the routines statement
1476 :return computeStmt: the a-stmt computation statement if there was an operation
1477 in the calling function in stmt
1478 :return localVariables: list of local variables needed for the mnh_expand directive
1482 parStmt = scope.getParent(stmt)
1483 parItemFuncN = scope.getParent(itemFuncN)
1485 grandparItemFuncN = scope.getParent(itemFuncN, level=2)
1486 funcName = alltext(itemFuncN)
1489 indexForCall = list(parStmt).index(stmt)
1494 siblsItemFuncN = scope.getSiblings(parItemFuncN, after=
True, before=
False)
1495 workingItem = siblsItemFuncN[0][0][0]
1498 if len(siblsItemFuncN[0][0]) > 1:
1500 workingItem = scope.updateContinuation(siblsItemFuncN[0][0], removeALL=
True,
1501 align=
False, addBegin=
False)[0]
1505 opE = workingItem.findall(
'.//{*}op-E')
1506 scope.removeArrayParenthesesInNode(workingItem)
1507 computeStmt, remaningArgsofFunc = [],
''
1508 dimSuffVar = str(zshugradwkDim) +
'D'
1509 dimSuffRoutine, dimSuffVar, mnhExpandArrayIndexes, _ = \
1510 getDimsAndMNHExpandIndexes(zshugradwkDim, dimWorkingVar)
1513 computingVarName =
'ZSHUGRADWK'+str(nbzshugradwk)+
'_'+str(zshugradwkDim)+
'D'
1515 if not scope.varList.findVar(computingVarName):
1516 scope.addVar([[scope.path, computingVarName,
1517 dimWorkingVar + computingVarName,
None]])
1521 computeVar = scope.varList.findVar(computingVarName)
1522 dimWorkingVar =
'REAL, DIMENSION('
1523 for dims
in computeVar[
'as'][:arrayDim]:
1524 dimWorkingVar += dims[1] +
','
1525 dimWorkingVar = dimWorkingVar[:-1] +
') ::'
1527 dimSuffRoutine, dimSuffVar, mnhExpandArrayIndexes, localVariables = \
1528 getDimsAndMNHExpandIndexes(zshugradwkDim, dimWorkingVar)
1531 mnhOpenDir =
"!$mnh_expand_array(" + mnhExpandArrayIndexes +
")"
1532 mnhCloseDir =
"!$mnh_end_expand_array(" + mnhExpandArrayIndexes +
")"
1535 workingComputeItem = workingItem[0]
1537 if len(workingItem) == 2:
1538 remaningArgsofFunc =
',' + alltext(workingItem[1])
1539 elif len(workingItem) > 2:
1540 raise PYFTError(
'ShumanFUNCtoCALL: expected maximum 1 argument in shuman ' +
1541 'function to transform')
1542 computeStmt = createExpr(computingVarName +
" = " + alltext(workingComputeItem))[0]
1543 workingItem = computeStmt.find(
'.//{*}E-1')
1545 parStmt.insert(indexForCall, createElem(
'C', text=
'!$acc kernels', tail=
'\n'))
1546 parStmt.insert(indexForCall + 1, createElem(
'C', text=mnhOpenDir, tail=
'\n'))
1547 parStmt.insert(indexForCall + 2, computeStmt)
1548 parStmt.insert(indexForCall + 3, createElem(
'C', text=mnhCloseDir, tail=
'\n'))
1549 parStmt.insert(indexForCall + 4, createElem(
'C',
1550 text=
'!$acc end kernels', tail=
'\n'))
1551 parStmt.insert(indexForCall + 5, createElem(
'C',
1552 text=
'!', tail=
'\n'))
1556 if zshugradwkDim == 1:
1557 dimSuffRoutine =
'2D'
1558 workingVar =
'Z' + funcName + dimSuffVar +
'_WORK' + str(localShumansCount[funcName])
1559 if funcName
in (
'GY_U_UV',
'GX_V_UV'):
1560 gpuGradientImplementation =
'_DEVICE('
1561 newFuncName = funcName + dimSuffRoutine +
'_DEVICE'
1563 gpuGradientImplementation =
'_PHY(D, '
1564 newFuncName = funcName + dimSuffRoutine +
'_PHY'
1565 callStmt = createExpr(
"CALL " + funcName + dimSuffRoutine + gpuGradientImplementation
1566 + alltext(workingItem) + remaningArgsofFunc +
1567 ", " + workingVar +
")")[0]
1568 parStmt.insert(indexForCall, callStmt)
1571 parOfgrandparItemFuncN = scope.getParent(grandparItemFuncN)
1572 indexWorkingVar = list(parOfgrandparItemFuncN).index(grandparItemFuncN)
1573 savedTail = grandparItemFuncN.tail
1574 parOfgrandparItemFuncN.remove(grandparItemFuncN)
1577 xmlWorkingvar = createExprPart(workingVar)
1578 xmlWorkingvar.tail = savedTail
1579 parOfgrandparItemFuncN.insert(indexWorkingVar, xmlWorkingvar)
1582 if not scope.varList.findVar(workingVar):
1583 scope.addVar([[scope.path, workingVar, dimWorkingVar + workingVar,
None]])
1585 return callStmt, computeStmt, nbzshugradwk, newFuncName, localVariables
1587 shumansGradients = {
'MZM': 0,
'MXM': 0,
'MYM': 0,
'MZF': 0,
'MXF': 0,
'MYF': 0,
1588 'DZM': 0,
'DXM': 0,
'DYM': 0,
'DZF': 0,
'DXF': 0,
'DYF': 0,
1589 'GZ_M_W': 0,
'GZ_W_M': 0,
'GZ_U_UW': 0,
'GZ_V_VW': 0,
1590 'GX_M_U': 0,
'GX_U_M': 0,
'GX_W_UW': 0,
'GX_M_M': 0,
1591 'GY_V_M': 0,
'GY_M_V': 0,
'GY_W_VW': 0,
'GY_M_M': 0,
1592 'GX_V_UV': 0,
'GY_U_UV': 0}
1593 scopes = self.getScopes()
1594 if len(scopes) == 0
or scopes[0].path.split(
'/')[-1].split(
':')[1][:4] ==
'MODD':
1596 for scope
in scopes:
1597 if 'sub:' in scope.path
and 'func' not in scope.path \
1598 and 'interface' not in scope.path:
1601 localVariablesToAdd = set()
1602 foundStmtandCalls, computeStmtforParenthesis = {}, []
1603 aStmt = scope.findall(
'.//{*}a-stmt')
1604 callStmts = scope.findall(
'.//{*}call-stmt')
1605 aStmtandCallStmts = aStmt + callStmts
1606 funcToSuppress = set()
1607 for stmt
in aStmtandCallStmts:
1608 elemN = stmt.findall(
'.//{*}n')
1610 if alltext(el)
in list(shumansGradients):
1611 funcToSuppress.add(alltext(el))
1614 parStmt = scope.getParent(stmt)
1615 if tag(parStmt) ==
'action-stmt':
1616 scope.changeIfStatementsInIfConstructs(
1617 singleItem=scope.getParent(parStmt))
1619 if str(stmt)
in foundStmtandCalls:
1620 foundStmtandCalls[str(stmt)][1] += 1
1622 foundStmtandCalls[str(stmt)] = [stmt, 1]
1625 subToInclude = set()
1626 for stmt
in foundStmtandCalls:
1627 localShumansGradients = copy.deepcopy(shumansGradients)
1628 elemToLookFor = [foundStmtandCalls[stmt][0]]
1629 previousComputeStmt = []
1632 while len(elemToLookFor) > 0:
1634 for elem
in elemToLookFor:
1635 elemN = elem.findall(
'.//{*}n')
1637 if alltext(el)
in list(localShumansGradients.keys()):
1642 nodeE1var = foundStmtandCalls[stmt][0].findall(
1643 './/{*}E-1/{*}named-E/{*}N')
1644 if len(nodeE1var) > 0:
1645 var = scope.varList.findVar(alltext(nodeE1var[0]))
1646 allSubscripts = foundStmtandCalls[stmt][0].findall(
1647 './/{*}E-1//{*}named-E/{*}R-LT/' +
1648 '{*}array-R/{*}section-subscript-LT')
1652 elPar = scope.getParent(el, level=2)
1653 callVar = elPar.findall(
'.//{*}named-E/{*}N')
1654 if alltext(el)[0] ==
'G':
1659 var = scope.varList.findVar(alltext(callVar[-1]))
1660 shumanIsCalledOn = scope.getParent(callVar[-1])
1663 var, inested =
None, 0
1665 while (
not var
or var[
'as']
is None or
1666 len(var[
'as']) == 0):
1670 var = scope.varList.findVar(
1671 alltext(callVar[inested]))
1673 shumanIsCalledOn = scope.getParent(callVar[inested-1])
1674 allSubscripts = shumanIsCalledOn.findall(
1675 './/{*}R-LT/{*}array-R/' +
1676 '{*}section-subscript-LT')
1680 arrayDim = len(var[
'as'])
1684 if len(allSubscripts) > 0:
1685 for subLT
in allSubscripts:
1687 lowerBound = sub.findall(
'.//{*}lower-bound')
1688 if len(lowerBound) > 0:
1689 if len(sub.findall(
'.//{*}upper-bound')) > 0:
1692 raise PYFTError(
'ShumanFUNCtoCALL does ' +
1693 'not handle conversion ' +
1694 'to routine of array ' +
1695 'subselection lower:upper' +
1696 ': how to set up the ' +
1697 'shape of intermediate ' +
1707 dimWorkingVar =
'REAL, DIMENSION('
1708 for dims
in var[
'as'][:arrayDim]:
1709 dimWorkingVar += dims[1] +
','
1710 dimWorkingVar = dimWorkingVar[:-1] +
') ::'
1713 localShumansGradients[alltext(el)] += 1
1717 if foundStmtandCalls[stmt][0].tail:
1718 foundStmtandCalls[stmt][0].tail = \
1719 foundStmtandCalls[stmt][0].tail.replace(
'\n',
'') +
'\n'
1721 foundStmtandCalls[stmt][0].tail =
'\n'
1724 result = FUNCtoROUTINE(scope, elem, el,
1725 localShumansGradients,
1726 elem
in previousComputeStmt,
1727 nbzshugradwk, arrayDim,
1729 (newCallStmt, newComputeStmt,
1730 nbzshugradwk, newFuncName, lv) = result
1731 localVariablesToAdd.update(lv)
1732 subToInclude.add(newFuncName)
1735 elemToLookFor.append(newCallStmt)
1739 if len(newComputeStmt) > 0:
1740 elemToLookFor.append(newComputeStmt)
1741 computeStmtforParenthesis.append(newComputeStmt)
1745 previousComputeStmt.append(newComputeStmt)
1749 elemToLookForNew = []
1750 for i
in elemToLookFor:
1751 nodeNs = i.findall(
'.//{*}n')
1754 if alltext(nnn)
in list(localShumansGradients):
1755 elemToLookForNew.append(i)
1757 elemToLookFor = elemToLookForNew
1760 if nbzshugradwk > maxnbZshugradwk:
1761 maxnbZshugradwk = nbzshugradwk
1764 if tag(foundStmtandCalls[stmt][0]) !=
'call-stmt':
1765 scope.addArrayParenthesesInNode(foundStmtandCalls[stmt][0])
1769 if tag(foundStmtandCalls[stmt][0]) !=
'call-stmt':
1772 dimSuffRoutine, dimSuffVar, mnhExpandArrayIndexes, lv = \
1773 getDimsAndMNHExpandIndexes(arrayDim, dimWorkingVar)
1774 localVariablesToAdd.update(lv)
1776 parStmt = scope.getParent(foundStmtandCalls[stmt][0])
1777 indexForCall = list(parStmt).index(foundStmtandCalls[stmt][0])
1778 mnhOpenDir =
"!$mnh_expand_array(" + mnhExpandArrayIndexes +
")"
1779 mnhCloseDir =
"!$mnh_end_expand_array(" + mnhExpandArrayIndexes +
")"
1780 parStmt.insert(indexForCall,
1781 createElem(
'C', text=
"!$acc kernels", tail=
'\n'))
1782 parStmt.insert(indexForCall + 1,
1783 createElem(
'C', text=mnhOpenDir, tail=
'\n'))
1784 parStmt.insert(indexForCall + 3,
1785 createElem(
'C', text=mnhCloseDir, tail=
'\n'))
1786 parStmt.insert(indexForCall + 4,
1787 createElem(
'C', text=
"!$acc end kernels", tail=
'\n'))
1788 parStmt.insert(indexForCall + 5,
1789 createElem(
'C', text=
"!", tail=
'\n'))
1792 for stmt
in computeStmtforParenthesis:
1793 scope.addArrayParenthesesInNode(stmt)
1797 for sub
in sorted(subToInclude):
1798 if re.match(
r'[MD][XYZ][MF](2D)?_PHY', sub):
1799 moduleVars.append((scope.path,
'MODE_SHUMAN_PHY', sub))
1800 if re.match(
r'[MD][XYZ][MF](2D)?_DEVICE', sub):
1801 moduleVars.append((scope.path,
'MODI_SHUMAN_DEVICE', sub))
1803 for kind
in (
'M',
'U',
'V',
'W'):
1804 if re.match(
r'G[XYZ]_' + kind +
r'_[MUVW]{1,2}_PHY', sub):
1805 moduleVars.append((scope.path, f
'MODE_GRADIENT_{kind}_PHY', sub))
1806 elif re.match(
r'G[XYZ]_' + kind +
r'_[MUVW]{1,2}_DEVICE', sub):
1807 moduleVars.append((scope.path, f
'MODI_GRADIENT_{kind}', sub))
1808 scope.addModuleVar(moduleVars)
1811 for sub
in funcToSuppress:
1812 if scope.varList.findVar(sub):
1813 scope.removeVar([(scope.path, sub)])
1816 for varName
in localVariablesToAdd:
1817 if not scope.varList.findVar(varName):
1818 var = {
'as': [],
'asx': [],
1819 'n': varName,
'i':
None,
't':
'INTEGER',
'arg':
False,
1820 'use':
False,
'opt':
False,
'allocatable':
False,
1821 'parameter':
False,
'init':
None,
'scopePath': scope.path}
1822 scope.addVar([[scope.path, var[
'n'], scope.varSpec2stmt(var),
None]])