133 Convert each STR%VAR structure component into a single local variable in assignment statements (a-stmt)
134 and in if-then-stmt, else-if-stmt and where-stmt
136 ZA = 1 + CST%XG ==> ZA = 1 + XCSTG
137 ZA = 1 + PARAM_ICE%XRTMIN(3) ==> ZA = 1 + XPARAM_ICERTMIN3
138 ZRSMIN(1:KRR) = ICED%XRTMIN(1:KRR) => ZRSMIN(1:KRR) = XICEDRTMIN1KRR(1:KRR)
139 IF(TURBN%CSUBG_MF_PDF=='NONE')THEN => IF(CTURBNSUBG_MF_PDF=='NONE')THEN
141 RESTRICTION : works only if the r-component variable is contained in 1 parent structure.
142 Allowed for conversion : CST%XG
143 Not converted : TOTO%CST%XG (for now, recursion must be coded)
144 Not converted : TOTO%ARRAY(:) (shape of the array must be determined from E1)
146 def convertOneType(component, newVarList, scope):
148 objType = scope.getParent(component, 2)
149 objTypeStr = alltext(objType).upper()
150 namedENn = objType.find(
'.//{*}N/{*}n')
151 structure = namedENn.text
152 variable = component.find(
'.//{*}ct').text.upper()
156 arrayRall = objType.findall(
'.//{*}array-R')
157 if len(arrayRall) > 0:
158 arrayR = copy.deepcopy(arrayRall[0])
159 txt = alltext(arrayR).replace(
',',
'')
160 txt = txt.replace(
':',
'')
161 txt = txt.replace(
'(',
'')
162 txt = txt.replace(
')',
'')
163 arrayIndices = arrayIndices + txt
164 elif len(objType.findall(
'.//{*}element-LT')) > 0:
166 for elem
in objType.findall(
'.//{*}element'):
167 arrayIndices = arrayIndices + alltext(elem)
168 newName = variable[0] + structure + variable[1:] + arrayIndices
169 newName = newName.upper()
173 namedENn.text = newName
174 objType.remove(objType.find(
'.//{*}R-LT'))
175 if len(arrayRall) > 0:
176 objType.insert(1, arrayR)
179 if newName
not in newVarList:
180 if len(arrayRall) == 0:
181 newVarList[newName] = (
None, objTypeStr)
183 newVarList[newName] = (arrayR, objTypeStr)
185 scopes = self.getScopes()
186 if scopes[0].path.split(
'/')[-1].split(
':')[1][:4] ==
'MODD':
188 for scope
in [scope
for scope
in scopes
189 if 'sub:' in scope.path
and 'interface' not in scope.path]:
191 for ifStmt
in (scope.findall(
'.//{*}if-then-stmt') +
192 scope.findall(
'.//{*}else-if-stmt') +
193 scope.findall(
'.//{*}where-stmt')):
194 compo = ifStmt.findall(
'.//{*}component-R')
196 for elcompo
in compo:
197 convertOneType(elcompo, newVarList, scope)
199 for aStmt
in scope.findall(
'.//{*}a-stmt'):
202 if len(aStmt[0].findall(
'.//{*}component-R')) == 0:
203 compoE2 = aStmt.findall(
'.//{*}component-R')
210 nbNamedEinE2 = len(aStmt.findall(
'.//{*}E-2')[0].findall(
'.//{*}named-E/' +
212 if nbNamedEinE2 > 1
or nbNamedEinE2 == 1
and \
213 len(aStmt[0].findall(
'.//{*}R-LT')) == 1:
214 for elcompoE2
in compoE2:
215 convertOneType(elcompoE2, newVarList, scope)
218 for el, var
in newVarList.items():
219 if el[0].upper() ==
'X' or el[0].upper() ==
'P' or el[0].upper() ==
'Z':
221 elif el[0].upper() ==
'L' or el[0].upper() ==
'O':
223 elif el[0].upper() ==
'N' or el[0].upper() ==
'I' or el[0].upper() ==
'K':
225 elif el[0].upper() ==
'C':
226 varType =
'CHARACTER(LEN=LEN(' + var[1] +
'))'
228 raise PYFTError(
'Case not implemented for the first letter of the newVarName' +
229 el +
' in convertTypesInCompute')
233 varArray =
', DIMENSION('
234 for i, sub
in enumerate(var[0].findall(
'.//{*}section-subscript')):
235 if len(sub.findall(
'.//{*}upper-bound')) > 0:
236 dimSize = simplifyExpr(
237 alltext(sub.findall(
'.//{*}upper-bound')[0]) +
238 '-' + alltext(sub.findall(
'.//{*}lower-bound')[0]) +
240 elif len(sub.findall(
'.//{*}lover-bound')) > 0:
241 dimSize = simplifyExpr(alltext(sub.findall(
'.//{*}lower-bound')[0]))
243 dimSize =
'SIZE(' + var[1] +
',' + str(i+1) +
')'
244 varArray =
', DIMENSION(' + dimSize +
','
245 varArray = varArray[:-1] +
')'
246 scope.addVar([[scope.path, el, varType + varArray +
' :: ' + el,
None]])
249 stmtAffect = createExpr(el +
"=" + var[1])[0]
250 scope.insertStatement(scope.indent(stmtAffect), first=
True)
319 Add MPPDB_CHECK calls on all INTENT REAL array arguments of subroutines.
320 ****** Not applied on modd_ routines. ********
321 Handle optional arguments.
322 For example, for a BL89 routine with 4 arguments (1 INTENT(IN),
323 2 INTENT(INOUT), 1 INTENT(OUT)), it produces:
324 IF (MPPDB_INITIALIZED) THEN
326 CALL MPPDB_CHECK(PZZ, "BL89 beg:PZZ")
327 !Check all INOUT arrays
328 CALL MPPDB_CHECK(PDZZ, "BL89 beg:PDZZ")
329 CALL MPPDB_CHECK(PTHVREF, "BL89 beg:PTHVREF")
332 IF (MPPDB_INITIALIZED) THEN
333 !Check all INOUT arrays
334 CALL MPPDB_CHECK(PDZZ, "BL89 end:PDZZ")
335 CALL MPPDB_CHECK(PTHVREF, "BL89 end:PTHVREF")
336 !Check all OUT arrays
337 CALL MPPDB_CHECK(PLM, "BL89 end:PLM")
339 :param printsMode: if True, add Fortran print statements for debugging instead of CALL MPPDB_CHECK
341 def addPrints_statement(var, typeofPrints='minmax'):
342 ifBeg, ifEnd =
'',
''
345 if typeofPrints ==
'minmax':
346 strMSG = f
'MINMAX {varName} = \",MINVAL({varName}), MAXVAL({varName})'
347 elif typeofPrints ==
'shape':
348 strMSG = f
'SHAPE {varName} = \",SHAPE({varName})'
350 raise PYFTError(
'typeofPrints is either minmax or shape in addPrints_statement')
352 strMSG = var[
'n'] +
' = \",' + var[
'n']
354 ifBeg = ifBeg +
'IF (PRESENT(' + var[
'n'] +
')) THEN\n '
355 ifEnd = ifEnd +
'\nEND IF'
356 return createExpr(ifBeg +
"print*,\"" + strMSG + ifEnd)[0]
358 def addMPPDB_CHECK_statement(var, subRoutineName, strMSG='beg:
'):
359 ifBeg, ifEnd, addD, addLastDim, addSecondDimType = '',
'',
'',
'',
''
363 if 'D%NIJT' in var[
'as'][0][1]:
365 if len(var[
'as']) == 2:
367 addLastDim =
', ' + var[
'as'][1][1]
368 if len(var[
'as']) >= 2:
371 if 'D%NK' in var[
'as'][1][1]:
372 addSecondDimType =
',' +
'''"VERTICAL"'''
374 addSecondDimType =
',' +
'''"OTHER"'''
375 if 'MERGE' in var[
'as'][-1][1]:
376 keyDimMerge = var[
'as'][-1][1].split(
',')[2][:-1]
377 ifBeg =
'IF (' + keyDimMerge +
') THEN\n'
380 ifBeg = ifBeg +
'IF (PRESENT(' + var[
'n'] +
')) THEN\n IF (SIZE(' + \
381 var[
'n'] +
',1) > 0) THEN\n'
382 ifEnd = ifEnd +
'\nEND IF\nEND IF'
383 argsMPPDB = var[
'n'] +
", " +
"\"" + subRoutineName +
" " + strMSG+var[
'n'] +
"\""
384 return createExpr(ifBeg +
"CALL MPPDB_CHECK(" + addD + argsMPPDB +
385 addLastDim + addSecondDimType +
")" + ifEnd)[0]
386 scopes = self.getScopes()
387 if scopes[0].path.split(
'/')[-1].split(
':')[1][:4] ==
'MODD':
394 if 'sub:' in scope.path
and 'func' not in scope.path
and 'interface' not in scope.path:
395 subRoutineName = scope.path.split(
'/')[-1].split(
':')[1]
398 arraysIn, arraysInOut, arraysOut = [], [], []
400 for var
in scope.varList:
401 if var[
'arg']
and var[
'as']
and 'TYPE' not in var[
't']
and \
402 'REAL' in var[
't']
and var[
'scopePath'] == scope.path:
405 if var[
'i'] ==
'INOUT':
406 arraysInOut.append(var)
407 if var[
'i'] ==
'OUT':
408 arraysOut.append(var)
410 for var
in scope.varList:
411 if not var[
't']
or var[
't']
and 'TYPE' not in var[
't']:
414 if var[
'i'] ==
'INOUT':
415 arraysInOut.append(var)
416 if var[
'i'] ==
'OUT':
417 arraysOut.append(var)
419 if len(arraysIn) + len(arraysInOut) + len(arraysOut) == 0:
424 scope.addModuleVar([(scope.path,
'MODE_MPPDB',
None)])
426 scope.addModuleVar([(scope.path,
'MODD_BLANK_n', [
'LDUMMY1'])])
429 commentIN = createElem(
'C', text=
'!Check all IN arrays', tail=
'\n')
430 commentINOUT = createElem(
'C', text=
'!Check all INOUT arrays', tail=
'\n')
431 commentOUT = createElem(
'C', text=
'!Check all OUT arrays', tail=
'\n')
434 if len(arraysIn) + len(arraysInOut) > 0:
436 ifMPPDBinit = createExpr(
"IF (MPPDB_INITIALIZED) THEN\n END IF")[0]
438 ifMPPDBinit = createExpr(
"IF (LDUMMY1) THEN\n END IF")[0]
439 ifMPPDB = ifMPPDBinit.find(
'.//{*}if-block')
442 if len(arraysIn) > 0:
443 ifMPPDB.insert(1, commentIN)
444 for i, var
in enumerate(arraysIn):
446 ifMPPDB.insert(2 + i, addMPPDB_CHECK_statement(var, subRoutineName,
449 ifMPPDB.insert(2 + i, addPrints_statement(var,
450 typeofPrints=
'minmax'))
451 ifMPPDB.insert(3 + i, addPrints_statement(var,
452 typeofPrints=
'shape'))
455 if len(arraysInOut) > 0:
456 shiftLineNumber = 2
if len(arraysIn) > 0
else 1
458 ifMPPDB.insert(len(arraysIn) + shiftLineNumber, commentINOUT)
460 ifMPPDB.insert(len(arraysIn)*2 + shiftLineNumber-1, commentINOUT)
462 for i, var
in enumerate(arraysInOut):
464 ifMPPDB.insert(len(arraysIn) + shiftLineNumber + 1 + i,
465 addMPPDB_CHECK_statement(var, subRoutineName,
468 ifMPPDB.insert(len(arraysIn) + shiftLineNumber + 1 + i,
469 addPrints_statement(var, typeofPrints=
'minmax'))
472 scope.insertStatement(scope.indent(ifMPPDBinit), first=
True)
475 if len(arraysInOut) + len(arraysOut) > 0:
477 ifMPPDBend = createExpr(
"IF (MPPDB_INITIALIZED) THEN\n END IF")[0]
479 ifMPPDBend = createExpr(
"IF (LDUMMY1) THEN\n END IF")[0]
480 ifMPPDB = ifMPPDBend.find(
'.//{*}if-block')
483 if len(arraysInOut) > 0:
484 ifMPPDB.insert(1, commentINOUT)
485 for i, var
in enumerate(arraysInOut):
487 ifMPPDB.insert(2 + i, addMPPDB_CHECK_statement(var, subRoutineName,
490 ifMPPDB.insert(2 + i, addPrints_statement(var,
491 typeofPrints=
'minmax'))
494 if len(arraysOut) > 0:
495 shiftLineNumber = 2
if len(arraysInOut) > 0
else 1
497 ifMPPDB.insert(len(arraysInOut) + shiftLineNumber, commentOUT)
499 ifMPPDB.insert(len(arraysInOut)*2 + shiftLineNumber-1, commentOUT)
500 for i, var
in enumerate(arraysOut):
502 ifMPPDB.insert(len(arraysInOut) + shiftLineNumber + 1 + i,
503 addMPPDB_CHECK_statement(var, subRoutineName,
506 ifMPPDB.insert(len(arraysInOut) + shiftLineNumber + 1 + i,
507 addPrints_statement(var, typeofPrints=
'minmax'))
510 scope.insertStatement(scope.indent(ifMPPDBend), first=
False)
618 def removeIJDim(self, stopScopes, parserOptions=None, wrapH=False, simplify=False):
620 Transform routines to be called in a loop on columns
621 :param stopScopes: scope paths where we stop to add the D argument (if needed)
622 :param parserOptions, wrapH: see the PYFT class
623 :param simplify: try to simplify code (remove useless dimensions in call)
625 ComputeInSingleColumn :
626 - Remove all DO loops on JI, JJ and JIJ
627 - Initialize former indexes JI, JJ, JIJ to first array element:
628 JI=D%NIB, JJ=D%NJB, JIJ=D%NIJB
629 - If simplify is True, replace (:,*) on I/J/IJ dimension on argument
630 with explicit (:,*) on CALL statements:
631 e.g. CALL FOO(D, A(:,JK,1), B(:,:))
632 ==> CALL FOO(D, A(JIJ,JK,1), B(:,:)) only if the target argument is not an array
635 indexToCheck = {
'JI': (
'D%NIB',
'D%NIT'),
636 'JJ': (
'D%NJB',
'D%NJT'),
637 'JIJ': (
'D%NIJB',
'D%NIJT')}
639 def slice2index(namedE, scope):
641 Transform a slice on the horizontal dimension into an index
642 Eg.: X(1:D%NIJT, 1:D%NKT) => X(JIJ, 1:D%NKT) Be careful, this array is not contiguous.
643 X(1:D%NIJT, JK) => X(JIJ, JK)
644 :param namedE: array to transform
645 :param scope: scope where the array is
648 for isub, sub
in enumerate(namedE.findall(
'./{*}R-LT/{*}array-R/' +
649 '{*}section-subscript-LT/' +
650 '{*}section-subscript')):
651 if ':' in alltext(sub):
652 loopIndex, _, _ = scope.findIndexArrayBounds(namedE, isub, _loopVarPHYEX)
653 if loopIndex
in indexToCheck:
656 lowerBound = createElem(
'lower-bound')
657 sub.insert(0, lowerBound)
659 lowerBound = sub.find(
'./{*}lower-bound')
661 for item
in lowerBound:
662 lowerBound.remove(item)
663 upperBound = sub.find(
'./{*}upper-bound')
664 if upperBound
is not None:
665 sub.remove(upperBound)
666 lowerBound.append(createExprPart(loopIndex))
667 if loopIndex
not in indexRemoved:
668 indexRemoved.append(loopIndex)
671 if ':' not in alltext(namedE.find(
'./{*}R-LT/{*}array-R/{*}section-subscript-LT')):
672 namedE.find(
'./{*}R-LT/{*}array-R').tag = f
'{{{NAMESPACE}}}parens-R'
673 namedE.find(
'./{*}R-LT/{*}parens-R/' +
674 '{*}section-subscript-LT').tag = f
'{{{NAMESPACE}}}element-LT'
675 for ss
in namedE.findall(
'./{*}R-LT/{*}parens-R/' +
676 '{*}element-LT/{*}section-subscript'):
677 ss.tag = f
'{{{NAMESPACE}}}element'
678 lowerBound = ss.find(
'./{*}lower-bound')
679 for item
in lowerBound:
681 ss.remove(lowerBound)
684 self.addArrayParentheses()
687 self.attachArraySpecToEntity()
688 hUupperBounds = [v[1]
for v
in indexToCheck.values()]
692 for scope
in [scope
for scope
in self.getScopes()[::-1]
693 if 'func:' not in scope.path
and
694 (scope.path
in stopScopes
or
695 self.tree.isUnderStopScopes(scope.path, stopScopes,
696 includeInterfaces=
True))]:
702 for doNode
in scope.findall(
'.//{*}do-construct')[::-1]:
703 for loopI
in doNode.findall(
'./{*}do-stmt/{*}do-V/{*}named-E/{*}N'):
704 loopIname = n2name(loopI).upper()
705 if loopIname
in indexToCheck:
708 par = scope.getParent(doNode)
709 index = list(par).index(doNode)
710 for item
in doNode[1:-1][::-1]:
711 par.insert(index, item)
713 if loopIname
not in indexRemoved:
714 indexRemoved.append(loopIname)
719 for intr
in scope.findall(
'.//{*}R-LT/{*}parens-R/../..'):
720 intrName = n2name(intr.find(
'./{*}N')).upper()
721 if intrName
in (
'PACK',
'UNPACK',
'COUNT',
'MAXVAL',
'MINVAL',
'ALL',
'ANY',
'SUM'):
729 while par
is not None and not isStmt(par):
730 par = scope.getParent(par)
731 if tag(par)
in (
'a-stmt',
'op-E'):
736 for namedE
in parToUse.findall(
'.//{*}R-LT/{*}array-R/../..'):
737 slice2index(namedE, scope)
740 if intr.find(
'.//{*}R-LT/{*}array-R')
is None:
741 if intrName
in (
'MAXVAL',
'MINVAL',
'SUM',
'ALL',
'ANY'):
743 parens = intr.find(
'./{*}R-LT/{*}parens-R')
744 parens.tag = f
'{{{NAMESPACE}}}parens-E'
745 intrPar = scope.getParent(intr)
746 intrPar.insert(list(intrPar).index(intr), parens)
748 elif intrName ==
'COUNT':
750 nodeN = intr.find(
'./{*}N')
751 for item
in nodeN[1:]:
753 nodeN.find(
'./{*}n').text =
'MERGE'
754 elementLT = intr.find(
'./{*}R-LT/{*}parens-R/{*}element-LT')
756 element = createElem(
'element', tail=
', ')
757 element.append(createExprPart(val))
758 elementLT.insert(0, element)
771 assert scope.find(
'.//{*}include')
is None and \
772 scope.find(
'.//{*}include-stmt')
is None, \
773 "inlining must be performed before removing horizontal dimensions"
775 if scope.path
in stopScopes:
777 preserveShape = [v[
'n']
for v
in scope.varList
if v[
'arg']]
782 if 'sub:' in scope.path:
786 for namedE
in scope.findall(
'.//{*}named-E/{*}R-LT/{*}parens-R/../..'):
787 if n2name(namedE.find(
'./{*}N')).upper()
not in preserveShape:
788 var = scope.varList.findVar(n2name(namedE.find(
'./{*}N')).upper())
789 if var
is not None and var[
'as']
is not None and len(var[
'as']) > 0:
790 subs = namedE.findall(
'./{*}R-LT/{*}parens-R/' +
791 '{*}element-LT/{*}element')
792 if (len(subs) == 1
and var[
'as'][0][1]
in hUupperBounds)
or \
793 (len(subs) == 2
and (var[
'as'][0][1]
in hUupperBounds
and
794 var[
'as'][1][1]
in hUupperBounds)):
795 namedE.remove(namedE.find(
'./{*}R-LT'))
799 for call
in scope.findall(
'.//{*}call-stmt'):
800 for namedE
in call.findall(
'./{*}arg-spec//{*}named-E'):
801 subs = namedE.findall(
'.//{*}section-subscript')
802 var = scope.varList.findVar(n2name(namedE.find(
'./{*}N')).upper())
803 if len(subs) > 0
and (var
is None or var[
'as']
is None or
804 len(var[
'as']) < len(subs)):
812 elif (len(subs) >= 2
and
813 ':' in alltext(subs[0])
and var[
'as'][0][1]
in hUupperBounds
and
814 ':' in alltext(subs[1])
and var[
'as'][1][1]
in hUupperBounds):
816 remove = len(subs) == 2
817 index = (len(subs) > 2
and
818 len([sub
for sub
in subs
if ':' in alltext(sub)]) == 2)
819 elif (len(subs) >= 1
and
820 ':' in alltext(subs[0])
and var[
'as'][0][1]
in hUupperBounds):
822 remove = len(subs) == 1
823 index = (len(subs) > 1
and
824 len([sub
for sub
in subs
if ':' in alltext(sub)]) == 1)
829 if n2name(namedE.find(
'./{*}N')).upper()
in preserveShape:
830 slice2index(namedE, scope)
832 nodeRLT = namedE.find(
'.//{*}R-LT')
833 scope.getParent(nodeRLT).remove(nodeRLT)
835 slice2index(namedE, scope)
840 for decl
in scope.findall(
'.//{*}T-decl-stmt/{*}EN-decl-LT/{*}EN-decl'):
841 name = n2name(decl.find(
'./{*}EN-N/{*}N')).upper()
842 if name
not in preserveShape:
843 varsShape = decl.findall(
'.//{*}shape-spec-LT')
844 for varShape
in varsShape:
845 subs = varShape.findall(
'.//{*}shape-spec')
846 if (len(subs) == 1
and alltext(subs[0])
in hUupperBounds)
or \
847 (len(subs) == 2
and (alltext(subs[0])
in hUupperBounds
and
848 alltext(subs[1])
in hUupperBounds)):
850 itemToRemove = scope.getParent(varShape)
851 scope.getParent(itemToRemove).remove(itemToRemove)
856 for loopIndex
in indexRemoved:
859 scope.insertStatement(
860 createExpr(loopIndex +
" = " + indexToCheck[loopIndex][0])[0],
True)
861 if len(indexRemoved) > 0:
862 scope.addArgInTree(
'D',
'TYPE(DIMPHYEX_t) :: D',
863 0, stopScopes, moduleVarList=[(
'MODD_DIMPHYEX', [
'DIMPHYEX_t'])],
864 parserOptions=parserOptions, wrapH=wrapH)
866 scope.addVar([[scope.path, loopIndex,
'INTEGER :: ' + loopIndex,
None]
867 for loopIndex
in indexRemoved
868 if scope.varList.findVar(loopIndex, exactScope=
True)
is None])
1063 Convert all calls to the Shuman functions and gradients listed in the shumansGradients
1064 table into calls to subroutines,
1065 and use mnh_expand directives to handle intermediate computations
1067 def getDimsAndMNHExpandIndexes(zshugradwkDim, dimWorkingVar=''):
1069 if zshugradwkDim == 1:
1070 dimSuffRoutine =
'2D'
1072 mnhExpandArrayIndexes =
'JIJ=IIJB:IIJE'
1073 elif zshugradwkDim == 2:
1075 if 'D%NKT' in dimWorkingVar:
1076 mnhExpandArrayIndexes =
'JIJ=IIJB:IIJE,JK=1:IKT'
1077 elif 'D%NIT' in dimWorkingVar
and 'D%NJT' in dimWorkingVar:
1079 mnhExpandArrayIndexes =
'JI=1:IIT,JJ=1:IJT'
1080 dimSuffRoutine =
'2D'
1086 mnhExpandArrayIndexes =
'JIJ=IIJB:IIJE,JK=1:IKT'
1087 elif zshugradwkDim == 3:
1089 mnhExpandArrayIndexes =
'JI=1:IIT,JJ=1:IJT,JK=1:IKT'
1091 raise PYFTError(
'Shuman func to routine conversion not implemented ' +
1092 'for 4D+ dimensions variables')
1093 return dimSuffRoutine, dimSuffVar, mnhExpandArrayIndexes
1095 def FUNCtoROUTINE(scope, stmt, itemFuncN, localShumansCount, inComputeStmt,
1096 nbzshugradwk, zshugradwkDim, dimWorkingVar):
1098 :param scope: node on which the calling function is present before transformation
1099 :param stmt: statement node (a-stmt or call-stmt) that contains the function(s) to be
1101 :param itemFuncN: <n>FUNCTIONNAME</n> node
1102 :param localShumansCount: instance of the shumansGradients dictionary
1103 for the given scope (which contains the number of times a
1104 function has been called within a transformation)
1105 :param dimWorkingVar: string of the declaration of a potential working variable
1106 depending on the array on which the shuman is applied
1107 (e.g. MZM(PRHODJ(:,IKB));
1108 dimWorkingVar = 'REAL, DIMENSION(D%NIJT) :: ' )
1110 :return callStmt: the new CALL to the routines statement
1111 :return computeStmt: the a-stmt computation statement if there was an operation
1112 in the calling function in stmt
1115 parStmt = scope.getParent(stmt)
1116 parItemFuncN = scope.getParent(itemFuncN)
1118 grandparItemFuncN = scope.getParent(itemFuncN, level=2)
1119 funcName = alltext(itemFuncN)
1122 indexForCall = list(parStmt).index(stmt)
1127 siblsItemFuncN = scope.getSiblings(parItemFuncN, after=
True, before=
False)
1128 workingItem = siblsItemFuncN[0][0][0]
1131 if len(siblsItemFuncN[0][0]) > 1:
1133 workingItem = scope.updateContinuation(siblsItemFuncN[0][0], removeALL=
True,
1134 align=
False, addBegin=
False)[0]
1138 opE = workingItem.findall(
'.//{*}op-E')
1139 scope.removeArrayParenthesesInNode(workingItem)
1140 computeStmt, remaningArgsofFunc = [],
''
1141 dimSuffVar = str(zshugradwkDim) +
'D'
1142 dimSuffRoutine, dimSuffVar, mnhExpandArrayIndexes = \
1143 getDimsAndMNHExpandIndexes(zshugradwkDim, dimWorkingVar)
1146 computingVarName =
'ZSHUGRADWK'+str(nbzshugradwk)+
'_'+str(zshugradwkDim)+
'D'
1148 if not scope.varList.findVar(computingVarName):
1149 scope.addVar([[scope.path, computingVarName,
1150 dimWorkingVar + computingVarName,
None]])
1154 computeVar = scope.varList.findVar(computingVarName)
1155 dimWorkingVar =
'REAL, DIMENSION('
1156 for dims
in computeVar[
'as'][:arrayDim]:
1157 dimWorkingVar += dims[1] +
','
1158 dimWorkingVar = dimWorkingVar[:-1] +
') ::'
1160 dimSuffRoutine, dimSuffVar, mnhExpandArrayIndexes = \
1161 getDimsAndMNHExpandIndexes(zshugradwkDim, dimWorkingVar)
1164 mnhOpenDir =
"!$mnh_expand_array(" + mnhExpandArrayIndexes +
")"
1165 mnhCloseDir =
"!$mnh_end_expand_array(" + mnhExpandArrayIndexes +
")"
1168 workingComputeItem = workingItem[0]
1170 if len(workingItem) == 2:
1171 remaningArgsofFunc =
',' + alltext(workingItem[1])
1172 elif len(workingItem) > 2:
1173 raise PYFTError(
'ShumanFUNCtoCALL: expected maximum 1 argument in shuman ' +
1174 'function to transform')
1175 computeStmt = createExpr(computingVarName +
" = " + alltext(workingComputeItem))[0]
1176 workingItem = computeStmt.find(
'.//{*}E-1')
1178 parStmt.insert(indexForCall, createElem(
'C', text=
'!$acc kernels', tail=
'\n'))
1179 parStmt.insert(indexForCall + 1, createElem(
'C', text=mnhOpenDir, tail=
'\n'))
1180 parStmt.insert(indexForCall + 2, computeStmt)
1181 parStmt.insert(indexForCall + 3, createElem(
'C', text=mnhCloseDir, tail=
'\n'))
1182 parStmt.insert(indexForCall + 4, createElem(
'C',
1183 text=
'!$acc end kernels', tail=
'\n'))
1184 parStmt.insert(indexForCall + 5, createElem(
'C',
1185 text=
'!', tail=
'\n'))
1189 if zshugradwkDim == 1:
1190 dimSuffRoutine =
'2D'
1191 workingVar =
'Z' + funcName + dimSuffVar +
'_WORK' + str(localShumansCount[funcName])
1192 if funcName
in (
'GY_U_UV',
'GX_V_UV'):
1193 gpuGradientImplementation =
'_DEVICE('
1194 newFuncName = funcName + dimSuffRoutine +
'_DEVICE'
1196 gpuGradientImplementation =
'_PHY(D, '
1197 newFuncName = funcName + dimSuffRoutine +
'_PHY'
1198 callStmt = createExpr(
"CALL " + funcName + dimSuffRoutine + gpuGradientImplementation
1199 + alltext(workingItem) + remaningArgsofFunc +
1200 ", " + workingVar +
")")[0]
1201 parStmt.insert(indexForCall, callStmt)
1204 parOfgrandparItemFuncN = scope.getParent(grandparItemFuncN)
1205 indexWorkingVar = list(parOfgrandparItemFuncN).index(grandparItemFuncN)
1206 savedTail = grandparItemFuncN.tail
1207 parOfgrandparItemFuncN.remove(grandparItemFuncN)
1210 xmlWorkingvar = createExprPart(workingVar)
1211 xmlWorkingvar.tail = savedTail
1212 parOfgrandparItemFuncN.insert(indexWorkingVar, xmlWorkingvar)
1215 if not scope.varList.findVar(workingVar):
1216 scope.addVar([[scope.path, workingVar, dimWorkingVar + workingVar,
None]])
1218 return callStmt, computeStmt, nbzshugradwk, newFuncName
1220 shumansGradients = {
'MZM': 0,
'MXM': 0,
'MYM': 0,
'MZF': 0,
'MXF': 0,
'MYF': 0,
1221 'DZM': 0,
'DXM': 0,
'DYM': 0,
'DZF': 0,
'DXF': 0,
'DYF': 0,
1222 'GZ_M_W': 0,
'GZ_W_M': 0,
'GZ_U_UW': 0,
'GZ_V_VW': 0,
1223 'GX_M_U': 0,
'GX_U_M': 0,
'GX_W_UW': 0,
'GX_M_M': 0,
1224 'GY_V_M': 0,
'GY_M_V': 0,
'GY_W_VW': 0,
'GY_M_M': 0,
1225 'GX_V_UV': 0,
'GY_U_UV': 0}
1226 scopes = self.getScopes()
1227 if len(scopes) == 0
or scopes[0].path.split(
'/')[-1].split(
':')[1][:4] ==
'MODD':
1229 for scope
in scopes:
1230 if 'sub:' in scope.path
and 'func' not in scope.path \
1231 and 'interface' not in scope.path:
1234 foundStmtandCalls, computeStmtforParenthesis = {}, []
1235 aStmt = scope.findall(
'.//{*}a-stmt')
1236 callStmts = scope.findall(
'.//{*}call-stmt')
1237 aStmtandCallStmts = aStmt + callStmts
1238 for stmt
in aStmtandCallStmts:
1239 elemN = stmt.findall(
'.//{*}n')
1241 if alltext(el)
in list(shumansGradients):
1244 parStmt = scope.getParent(stmt)
1245 if tag(parStmt) ==
'action-stmt':
1246 scope.changeIfStatementsInIfConstructs(
1247 singleItem=scope.getParent(parStmt))
1249 if str(stmt)
in foundStmtandCalls:
1250 foundStmtandCalls[str(stmt)][1] += 1
1252 foundStmtandCalls[str(stmt)] = [stmt, 1]
1255 subToInclude = set()
1256 for stmt
in foundStmtandCalls:
1257 localShumansGradients = copy.deepcopy(shumansGradients)
1258 elemToLookFor = [foundStmtandCalls[stmt][0]]
1259 previousComputeStmt = []
1262 while len(elemToLookFor) > 0:
1264 for elem
in elemToLookFor:
1265 elemN = elem.findall(
'.//{*}n')
1267 if alltext(el)
in list(localShumansGradients.keys()):
1272 nodeE1var = foundStmtandCalls[stmt][0].findall(
1273 './/{*}E-1/{*}named-E/{*}N')
1274 if len(nodeE1var) > 0:
1275 var = scope.varList.findVar(alltext(nodeE1var[0]))
1276 allSubscripts = foundStmtandCalls[stmt][0].findall(
1277 './/{*}E-1//{*}named-E/{*}R-LT/' +
1278 '{*}array-R/{*}section-subscript-LT')
1282 elPar = scope.getParent(el, level=2)
1283 callVar = elPar.findall(
'.//{*}named-E/{*}N')
1284 if alltext(el)[0] ==
'G':
1289 var = scope.varList.findVar(alltext(callVar[-1]))
1290 shumanIsCalledOn = scope.getParent(callVar[-1])
1293 var, inested =
None, 0
1295 while not var
or len(var[
'as']) == 0:
1299 var = scope.varList.findVar(
1300 alltext(callVar[inested]))
1302 shumanIsCalledOn = scope.getParent(callVar[inested-1])
1303 allSubscripts = shumanIsCalledOn.findall(
1304 './/{*}R-LT/{*}array-R/' +
1305 '{*}section-subscript-LT')
1309 arrayDim = len(var[
'as'])
1313 if len(allSubscripts) > 0:
1314 for subLT
in allSubscripts:
1316 lowerBound = sub.findall(
'.//{*}lower-bound')
1317 if len(lowerBound) > 0:
1318 if len(sub.findall(
'.//{*}upper-bound')) > 0:
1321 raise PYFTError(
'ShumanFUNCtoCALL does ' +
1322 'not handle conversion ' +
1323 'to routine of array ' +
1324 'subselection lower:upper' +
1325 ': how to set up the ' +
1326 'shape of intermediate ' +
1336 dimWorkingVar =
'REAL, DIMENSION('
1337 for dims
in var[
'as'][:arrayDim]:
1338 dimWorkingVar += dims[1] +
','
1339 dimWorkingVar = dimWorkingVar[:-1] +
') ::'
1342 localShumansGradients[alltext(el)] += 1
1346 if foundStmtandCalls[stmt][0].tail:
1347 foundStmtandCalls[stmt][0].tail = \
1348 foundStmtandCalls[stmt][0].tail.replace(
'\n',
'') +
'\n'
1350 foundStmtandCalls[stmt][0].tail =
'\n'
1353 result = FUNCtoROUTINE(scope, elem, el,
1354 localShumansGradients,
1355 elem
in previousComputeStmt,
1356 nbzshugradwk, arrayDim,
1358 (newCallStmt, newComputeStmt,
1359 nbzshugradwk, newFuncName) = result
1360 subToInclude.add(newFuncName)
1363 elemToLookFor.append(newCallStmt)
1367 if len(newComputeStmt) > 0:
1368 elemToLookFor.append(newComputeStmt)
1369 computeStmtforParenthesis.append(newComputeStmt)
1373 previousComputeStmt.append(newComputeStmt)
1377 elemToLookForNew = []
1378 for i
in elemToLookFor:
1379 nodeNs = i.findall(
'.//{*}n')
1382 if alltext(nnn)
in list(localShumansGradients):
1383 elemToLookForNew.append(i)
1385 elemToLookFor = elemToLookForNew
1388 if nbzshugradwk > maxnbZshugradwk:
1389 maxnbZshugradwk = nbzshugradwk
1392 scope.addArrayParenthesesInNode(foundStmtandCalls[stmt][0])
1396 if tag(foundStmtandCalls[stmt][0]) !=
'call-stmt':
1399 dimSuffRoutine, dimSuffVar, mnhExpandArrayIndexes = \
1400 getDimsAndMNHExpandIndexes(arrayDim, dimWorkingVar)
1402 parStmt = scope.getParent(foundStmtandCalls[stmt][0])
1403 indexForCall = list(parStmt).index(foundStmtandCalls[stmt][0])
1404 mnhOpenDir =
"!$mnh_expand_array(" + mnhExpandArrayIndexes +
")"
1405 mnhCloseDir =
"!$mnh_end_expand_array(" + mnhExpandArrayIndexes +
")"
1406 parStmt.insert(indexForCall,
1407 createElem(
'C', text=
"!$acc kernels", tail=
'\n'))
1408 parStmt.insert(indexForCall + 1,
1409 createElem(
'C', text=mnhOpenDir, tail=
'\n'))
1410 parStmt.insert(indexForCall + 3,
1411 createElem(
'C', text=mnhCloseDir, tail=
'\n'))
1412 parStmt.insert(indexForCall + 4,
1413 createElem(
'C', text=
"!$acc end kernels", tail=
'\n'))
1414 parStmt.insert(indexForCall + 5,
1415 createElem(
'C', text=
"!", tail=
'\n'))
1418 for stmt
in computeStmtforParenthesis:
1419 scope.addArrayParenthesesInNode(stmt)
1423 for sub
in sorted(subToInclude):
1424 if re.match(
r'[MD][XYZ][MF](2D)?_PHY', sub):
1425 moduleVars.append((scope.path,
'MODE_SHUMAN_PHY', sub))
1427 for kind
in (
'M',
'U',
'V',
'W'):
1428 if re.match(
r'G[XYZ]_' + kind +
r'_[MUVW]{1,2}_PHY', sub):
1429 moduleVars.append((scope.path, f
'MODE_GRADIENT_{kind}_PHY', sub))
1430 scope.addModuleVar(moduleVars)