#!/usr/bin/env python
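# cryoposition - Makes a tomopitch model by running findsection on a binned tomogram
#                reconstructed from a stack with beads or other dense material erased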
#
# Author: David Mastronarde
#
# $Id: cryoposition,v 937343107256 2023/02/19 22:44:16 mast $
#

progname = 'cryoposition'
# Prefix for error messages, e.g. the one issued when IMOD_DIR is not defined
prefix = 'ERROR: %s - ' % progname

def cleanup():
   """Remove this run's temporary files unless they are to be left (leaveTemp > 0).

   Removes every file named in the global cleanList, plus the same names with '~'
   appended (presumably backup copies made by IMOD programs).  The combined list is
   built locally instead of being appended to cleanList, so a repeated call removes
   the same files rather than growing the list with names such as 'file~~'.
   """
   if leaveTemp <= 0:
      prnstr('Cleaning up temporary files', flush = True)
      cleanupFiles(cleanList + [name + '~' for name in cleanList])

def getThresholdFromClipOutput(clipLines, key):
   """Return the threshold clip reported on the first output line containing key.

   The matching line is echoed (stripped) with prnstr, and its last whitespace-separated
   field is converted to float.  If no line contains key, exitError is called.
   """
   found = next((text.strip() for text in clipLines if key in text), None)
   if found is None:
      exitError('Could not find threshold value in clip output')
   else:
      prnstr(found)
      return float(found.split()[-1])


#### MAIN PROGRAM  ####
#
# load System Libraries
import os, sys, glob, math, shutil
from math import fabs

#
# Setup runtime environment: exit at once if IMOD_DIR is not set; otherwise make IMOD's
# Python library importable and load the IMOD utilities
if os.getenv('IMOD_DIR') is None:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

IMOD_DIR = os.environ['IMOD_DIR']
if sys.platform == 'cygwin' and sys.version_info[0] > 2:

   # Use forward slashes, and turn a drive-letter path like C:/... into /cygdrive/c/...
   IMOD_DIR = IMOD_DIR.replace('\\', '/')
   if IMOD_DIR[1] == ':' and IMOD_DIR[2] == '/':
      IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]
sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
from imodpy import *
addIMODbinIgnoreSIGHUP()

#
# load IMOD Libraries
from pip import *
from pysed import *

# Fallbacks from ../manpages/autodoc2man 3 1 cryoposition
options = ["root:RootName:CH:", "thickness:ThicknessOfTomograms:I:",
           "find:FindBeadsInVolume:I:", "size:BeadSize:F:", "light:LightFeatures:B:",
           "binning:BinningToApply:I:", "erase:EraseFraction:F:",
           "high:HighSDCriterion:F:", "boost:BoostThickness:F:",
           "scales:ScalesToApply:IA:", "box:BoxSizeInXYZ:IT:",
           "spacing:SpacingOfBoxesInXYZ:IT:", "gpu:UseGPU:I:",
           "pitch:TomoPitchModel:FN:", "control:ControlValue:FPM:",
           "fsopt:FindSecOptions:CH:", "leave:LeaveTempFiles:I:", "use:UseTempFiles:I:"]

# Parse the command line against the option table above.
# NOTE(review): the trailing numeric arguments presumably give the minimum number of
# arguments and the numbers of non-option input/output files - confirm against the
# PipReadOrParseOptions documentation.
(opts, nonopts) = PipReadOrParseOptions(sys.argv, options, progname, 2, 0, 0)
# Setting this in os.environ is inherited by programs started from this script;
# presumably it keeps them from printing their PIP entries
os.environ['PIP_PRINT_ENTRIES'] = '0'

# Default parameters.  Many can be overridden with -control entries, whose type number
# is given in parentheses.

# Binned bead diameters (beadSize / binning) that guide the automatic choice of binning
beadOptimal = 5.0          # target binned diameter
beadMinimum = 4.2          # smallest acceptable binned diameter
beadBigVolMin = 3.0        # smallest diameter allowed when binning more for big volumes
beadMaximum = 7.0          # largest acceptable binned diameter (1)

# Binning limits based on the volume size
maxBinning = 4             # largest automatic binning (2)
volOptimal = 650           # target binned size of the volume

# Oversizing of the aligned stack in X and Y (3, 4)
xOversizeFrac = 1.18
yOversizeFrac = 1.06

# Bead finding, histogram analysis, and thresholding
findAvgFallback = 0.5      # fallback thresholds passed to findbeads3d (5, 6)
findStoreFallback = 0.5
extraHistFrac = 0.95       # fraction given to 'clip hist -E' (7)
beadVolFrac = 0.33         # fraction of bead volume for the minimum feature size (8)
noBeadDiameter = 3.0       # diameter assumed when no bead size is available (9)
threshSumFactor = 0.05     # last value of ThresholdedReproj for tilt (10)

# Limits on the fraction of voxels the threshold may select (12 for max, 11 for min)
maxBeadEraseFrac = 0.01
minBeadEraseFrac = 0.0005
refEraseLimit = 0.01           # Good maxErase value for a reference volume
voxelsAtRefLimit = 6.7e-5      # Fraction of voxels in beads for the reference volume (13)

lightBeads = 0             # placeholder; replaced by the LightFeatures option value


# Identify the command-file extension, the axis status (dualNum > 0 for a dual-axis set,
# negative on failure), the dataset root, and the output-type and raw-stack extensions
comExt, dualNum, dsRootName, typeExt, stackExt = findRootAxisAndExtensions()
if not (dualNum >= 0 and comExt and typeExt is not None):
   exitError('Command files like tilt.com either are missing or have conflicting '
             'entries about critical information')

# Get options
rootName = PipGetInOutFile('RootName', 0)
if not rootName:
   exitError('The root name of image files must be entered')

# For a dual-axis set, the axis letter at the end of the root selects the command files
if dualNum > 0 and rootName[-1] not in ('a', 'b'):
   exitError('Command files indicate this is a dual axis set but the entered root ' +\
             'name does not end in a or b')
thickness = PipGetInteger('ThicknessOfTomograms', 0)
if thickness <= 0:
   exitError('A sample thickness must be entered')
findBeads = PipGetInteger('FindBeadsInVolume', 0)
beadSize = PipGetFloat('BeadSize', 0.)

# Use the exact long name from the option table ("binning:BinningToApply:I:") instead
# of relying on 'Binning' being accepted as an abbreviation of it
binning = PipGetInteger('BinningToApply', 0)
eraseFrac = PipGetFloat('EraseFraction', 0.002)
highSDcrit = PipGetFloat('HighSDCriterion', 5.)
fsOpts = PipGetString('FindSecOptions', '')
leaveTemp = PipGetInteger('LeaveTempFiles', 0)
useTemp = PipGetInteger('UseTempFiles', 0)
boostThickness = PipGetFloat('BoostThickness', 0.1)
useGPU = PipGetInteger('UseGPU', 0)

# Whether -gpu was entered; assumes PipGetErrNo() is 0 when the last option was found
# and 1 when it was not entered - confirm
gpuEntered = 1 - PipGetErrNo()

# Get control values.  Each -control entry is a (type, value) pair that overrides one of
# the module-level defaults.  Type 2 (maxBinning) is stored as an integer, the types in
# the table below store the value as given, and unrecognized types are ignored.
controlTargets = {1: 'beadMaximum', 3: 'xOversizeFrac', 4: 'yOversizeFrac',
                  5: 'findAvgFallback', 6: 'findStoreFallback', 7: 'extraHistFrac',
                  8: 'beadVolFrac', 9: 'noBeadDiameter', 10: 'threshSumFactor',
                  11: 'minBeadEraseFrac', 12: 'maxBeadEraseFrac', 13: 'voxelsAtRefLimit'}
numControl = PipNumberOfEntries('ControlValue')
for iControl in range(numControl):
   conType, conVal = PipGetTwoFloats('ControlValue', 0, 0)
   conInt = int(round(conType))
   if conInt == 2:
      maxBinning = int(conVal)
   elif conInt in controlTargets:

      # This runs at module level, so globals() holds the default being overridden
      globals()[controlTargets[conInt]] = conVal

# Command files of a dual-axis set carry the axis letter before the extension
if dualNum > 0:
   comSuffix = rootName[-1] + '.' + comExt
else:
   comSuffix = '.' + comExt

# The default tomopitch model name replaces the last three characters of the command
# suffix (the command extension, presumably 'com') with 'mod'
pitchModel = PipGetString('TomoPitchModel', 'tomopitch' + comSuffix[:-3] + 'mod')

stackName = '%s.%s' % (rootName, stackExt)
if not os.path.exists(stackName):
   exitError('The raw stack %s does not exist' % stackName)
nxStack, nyStack, nzStack = getmrcsize(stackName)
volSize = math.sqrt(nxStack * nyStack)     # geometric mean of the stack X and Y sizes

# Light features: the option value, and whether it was entered at all (if not, it may
# be taken from the track command file when finding beads)
lightBeads = PipGetBoolean('LightFeatures', 0)
lightEntered = not PipGetErrNo()

# If doing beads, get size and whether beads are light, taking them from the track
# command file when they were not entered as options
if findBeads:
   trackLines = readTextFile('track' + comSuffix)
   if beadSize <= 0.:
      beadSize = optionValue(trackLines, 'BeadDiameter', FLOAT_VALUE, numVal = 1)

   # optionValue presumably returns None when the entry is absent; test for that before
   # comparing, because None <= 0. raises TypeError under Python 3
   if beadSize is None or beadSize <= 0.:
      exitError('There is no positive BeadDiameter entry in track' + comSuffix + \
                   '; fix this or enter a size with -size')
   if not lightEntered:
      lightBeads = optionValue(trackLines, 'LightBeads', BOOL_VALUE)

   # Get a binning based on bead size.  Start with the binning that brings the binned
   # diameter closest to beadOptimal.  Bin more if the bead is above beadMaximum and
   # would stay above beadMinimum, and bin less while it is below beadMinimum.  Then bin
   # more for large volumes while beads stay above beadBigVolMin.  Finally, come back
   # down toward maxBinning as long as beads do not exceed beadMaximum.
   if binning <= 0:
      binning = max(1, int(round(beadSize / beadOptimal)))
      if beadSize / binning > beadMaximum and beadSize / (binning + 1) > beadMinimum:
         binning += 1
      while beadSize / binning < beadMinimum and binning > 1:
         binning -= 1
      while volSize / binning > 2 * volOptimal and \
             beadSize / (binning + 1) > beadBigVolMin:
         binning += 1

      while binning > maxBinning and beadSize / (binning - 1) <= beadMaximum:
         binning -= 1

# Or get a binning based on volume size
elif binning <= 0:
   binning = min(maxBinning, max(1, int(round(volSize / volOptimal))))

# Now that binning is known, pick the default scales to apply in findsection (more and
# larger scales at lower binning), unless scales were entered
scales = {1: (3, 4, 6, 8), 2: (2, 3, 4), 3: (1, 2, 3)}.get(binning, (1, 2))
scalesIn = PipGetIntegerArray('ScalesToApply', 0)
if scalesIn:
   scales = scalesIn

# Default box sizes depend on the net binning at the first scale
netBin = scales[0] * binning
nxzBox, nyBox = {1: (48, 12), 2: (32, 8)}.get(netBin, (16, 4))

nxBox, nyBox, nzBox = PipGetThreeIntegers('BoxSizeInXYZ', nxzBox, nyBox, nxzBox)
xSpacing, ySpacing, zSpacing = PipGetThreeIntegers(
   'SpacingOfBoxesInXYZ', max(1, nxBox // 2), max(1, nyBox // 4), max(1, nzBox // 2))

# Polarity given to clip: 1 for light features, -1 for dark ones
polarity = 1 if lightBeads else -1

# Without a known bead size, assume a nominal diameter so that erasure still has a
# feature size to work with
if beadSize <= 0.:
   beadSize = noBeadDiameter

# Binned diameter and volume of a sphere of that diameter
beadBinned = beadSize / binning
beadVol = 3.1416 * beadBinned**3 / 6.

# Raise the volume fraction for beads smaller than beadOptimal, so that the minimum
# feature size used for thresholding does not shrink as fast as the beads do
bvf = min(1., beadVolFrac * beadOptimal / beadBinned) if beadBinned < beadOptimal \
   else beadVolFrac
minForThresh = max(2, int(round(bvf * beadVol)))

# Get the newstack com lines, make sure we can get transform file
newstLines = readTextFile('newst' + comSuffix)
newstLines = extractProgramEntries(newstLines, 'newstack', '-Standard')
if not newstLines:
   exitError('The file newst' + comSuffix + ' is in an older format and cannot be used')
xfFile = optionValue(newstLines, 'TransformFile', STRING_VALUE)
if not xfFile:
   exitError('Cannot find transform file name in newst' + comSuffix)

# Get transforms and set ali size, transposing if middle terms are bigger than outer.
# The middle line of the file is taken as representative.  When its off-diagonal terms
# (fields 2 and 3) dominate the diagonal ones, the transform swaps X and Y, e.g. for a
# rotation near 90 degrees.  Reading the line is inside the try so that an empty or
# malformed transform file gives an error message rather than a traceback.
xfLines = readTextFile(xfFile)
nxFullAli = nxStack
nyFullAli = nyStack
try:
   line = xfLines[len(xfLines) // 2]
   lsplit = line.split()
   if abs(float(lsplit[0])) + abs(float(lsplit[3])) < \
          abs(float(lsplit[1])) + abs(float(lsplit[2])):
      nyFullAli = nxStack
      nxFullAli = nyStack
except (IndexError, ValueError):
   exitError('Trying to interpret transform in ' + xfFile)

# Get the tilt lines then look up the X-axis tilt and increase thickness if needed.
# The analysis tomograms are built with XAXISTILT set to 0, so the thickness must grow
# to contain the tilted specimen.  The magnitude of the tilt is used because a negative
# X-axis tilt needs the same extra thickness as a positive one.
tiltAllLines = readTextFile('tilt' + comSuffix)
tiltLines = extractProgramEntries(tiltAllLines, 'tilt', '-Standard')
xtilt = optionValue(tiltLines, 'XAXISTILT', FLOAT_VALUE, numVal = 1)
if xtilt:
   extraThick = int(round(thickness * math.tan(fabs(xtilt) * 0.0174533)))
   extraThick += extraThick % 2          # keep the increment even
   if extraThick > 0.01 * thickness:
      thickness += extraThick
      prnstr('Increasing thickness to ' + str(thickness) +
             ' to compensate for the X-tilt in ' + 'tilt' + comSuffix)

# Set up the oversized dimensions.  The full-size aligned area is expanded by
# xOversizeFrac and yOversizeFrac, nxAli/nyAli are those sizes after binning, and the
# subset start offsets center the oversized area on the aligned one.
nxFullOver, nyFullOver = int(nxFullAli * xOversizeFrac), int(nyFullAli * yOversizeFrac)
nxAli, nyAli = nxFullOver // binning, nyFullOver // binning
xSubsetStart = (nxFullAli - nxFullOver) // 2
ySubsetStart = (nyFullAli - nyFullOver) // 2

# Temporary file names
setRootAndExtension(rootName, typeExt)
(aliName, fullRecName, eraseRec, threshRec, reprojName, boxStack, eraseAli) = \
   [datasetFilename(suffix) for suffix in
    ('_cpos.ali', '_cpos.rec', '_cposErase.rec', '_cposThresh.rec', '_cpos.reproj',
     '_cposBox.st', '_cposErase.ali')]
peakModel = rootName + '_cpos.pkmod'
autoModel = rootName + '_cposAuto.mod'
cleanList = [aliName, peakModel, autoModel, threshRec, reprojName, boxStack, eraseAli]

# A negative LeaveTempFiles value keeps selected volumes: -1 keeps the oversized
# reconstruction, -2 the reconstruction from the erased stack, -3 both
if leaveTemp < 0:
   keepFullRec = (-leaveTemp) % 2 != 0
   keepEraseRec = (-leaveTemp) // 2 != 0
else:
   keepFullRec = keepEraseRec = False
if not keepFullRec:
   cleanList.append(fullRecName)
if not keepEraseRec:
   cleanList.append(eraseRec)

# Set environment variable to produce files of the right type
setOutputFormatIfNeeded(typeExt)

# Run the processing steps.  A step whose output can be reused is generally redone only
# when useTemp (-use) is below that step's level or its output file does not exist.
# Errors from IMOD programs (ImodpyError) or in interpreting their output
# (IndexError/ValueError) lead to cleanup of temporary files and an error exit.
try:

   # Oversized aligned stack
   sedcom = [sedModify('OutputFile', aliName)] + \
      sedDelAndAdd('SizeToOutputInXandY',
                   fmtstr('{},{}', nxFullOver // binning, nyFullOver // binning),
                   'OutputFile') + \
      sedDelAndAdd('BinByFactor', binning, 'OutputFile') + \
      sedDelAndAdd('TaperAtFill', '1,0', 'OutputFile') + \
      sedDelAndAdd('AntialiasFilter', -1, 'OutputFile')
   sedlines = pysed(sedcom, newstLines)
   needNewAli = not os.path.exists(aliName)
   if useTemp < 1 or needNewAli:
      prnstr('Building oversized aligned stack with binning = ' + str(binning),
             flush = True)
      runcmd('newstack -StandardInput', sedlines)

   # Oversized tomogram : First build base tilt sed command, shared by all tilt runs
   # below (binning, origin adjustment, thickness, subset start, and zero X-axis tilt)
   tiltBase = sedDelAndAdd('IMAGEBINNED', binning, 'OutputFile') + \
      sedDelAndAdd('AdjustOrigin', 1, 'OutputFile') + \
      [sedModify('THICKNESS', thickness),
       sedModify('SUBSETSTART', fmtstr('{} {}', xSubsetStart, ySubsetStart)),
       sedModify('XAXISTILT', 0.)]

   # Get rid of a log value; modify scale value if log was there OR it was seemingly
   # not modified yet
   foundLog = optionValue(tiltLines, 'LOG', FLOAT_VALUE)
   if foundLog:
      tiltBase.append('/^ *LOG/d')
   scaleArr = optionValue(tiltLines, 'SCALE', FLOAT_VALUE)
   if not scaleArr or len(scaleArr) < 2:
      exitError('Cannot verify or modify SCALE value in tilt' + comSuffix + \
                   ' for linear scaling')
   if foundLog or scaleArr[1] > 3.:
      tiltBase.append(sedModify('SCALE', fmtstr('{} {:.3f}', scaleArr[0],
                                                scaleArr[1] / 5000.)))

   if gpuEntered:
      tiltBase += sedDelAndAdd('UseGPU', useGPU, 'OutputFile')

   # Get the rest of the command for oversized tomo
   sedcom = tiltBase + [sedModify('OutputFile', fullRecName),
                        sedModify('InputProjections', aliName),
                        sedModify('WIDTH', nxFullOver),
                        '/^ *SLICE/d']
   sedlines = pysed(sedcom, tiltLines)
   if useTemp < 2 or not os.path.exists(fullRecName):
      prnstr('Building oversized binned tomogram', flush = True)
      runcmd('tilt -StandardInput', sedlines)

   indLowest = -1
   indStoring = -1
   numStored = 0
   recThreshold = None

   # Bead finding, threshold determination, and reprojection are done only when an
   # existing reprojection is not being reused
   needReproj = useTemp < 5 or not os.path.exists(reprojName)
   if findBeads and needReproj:

      # Find beads in the subvolume corresponding to regular size reconstruction
      beadcom = ['InputFile ' + fullRecName,
                 'OutputFile ' + peakModel,
                 'BeadSize ' + str(beadSize),
                 'StorageThreshold -1',
                 'BinningOfVolume ' + str(binning),
                 'TiltFile ' + rootName + '.tlt',
                 'YAxisElongated',
                 fmtstr('XMinAndMax {},{}', -xSubsetStart // binning,
                        (nxFullOver + xSubsetStart) // binning),
                 fmtstr('ZMinAndMax {},{}', -ySubsetStart // binning,
                        (nyFullOver + ySubsetStart) // binning),
                 fmtstr('FallbackThresholds {},{}', findAvgFallback, findStoreFallback)]

      prnstr('Finding beads in tomogram', flush = True)
      findLines = runcmd('findbeads3d -StandardInput', beadcom)

      # Look for results and see if fallback storage is used, or nothing
      elongation = 1.5
      for ind, line in enumerate(findLines):
         line = line.strip()
         if 'lowest dip' in line:
            indLowest = ind
         if 'Storing' in line and 'peaks in model' in line:
            indStoring = ind
            lsplit = line.split()
            try:
               numStored = int(lsplit[1])
            except Exception:
               numStored = 0
         if 'using fallback storage threshold' in line:
            prnstr(line)
            indLowest = -1
         if 'Elongation factor is' in line:
            lsplit = line.split()
            elongation = float(lsplit[3])

      if indLowest > 0:
         prnstr(findLines[indLowest].strip())
      if indStoring < 0:
         prnstr('Bead-finding failed, falling back to erasing a small fraction of ' + \
                   'dense material')
      else:

         # Now if anything was stored, extract boxes around the found peaks.  Each box
         # holds a bead diameter plus a margin of max(radius, 6) binned pixels (Y scaled
         # by the elongation factor), with sizes rounded up to be even
         prnstr(findLines[indStoring].strip())
         radius = 0.5 * beadSize / binning
         nxzBoxSE = int(round(2 * radius + max(radius, 6.)))
         nyBoxSE = int(round(elongation * (2 * radius + max(radius, 6.))))
         nxzBoxSE += nxzBoxSE % 2
         nyBoxSE += nyBoxSE % 2
         boxcom = ['InputImageFile ' + fullRecName,
                   'ModelFile ' + peakModel,
                   'OutputFile ' + boxStack,
                   fmtstr('VolumeSizeXYZ {},{},{}', nxzBoxSE, nyBoxSE, nxzBoxSE)]
         if useTemp < 3 or not os.path.exists(boxStack):
            prnstr('Extracting boxed beads or densities')
            runcmd('boxstartend -StandardInput', boxcom)

         # Look for threshold of extra counts on one side of peak in histogram
         prnstr('Analyzing histogram of boxed beads')
         try:
            clipLines = runcmd(fmtstr('clip hist -E {},{} "{}"', extraHistFrac, polarity,
                                      boxStack))
            recThreshold = getThresholdFromClipOutput(clipLines, 'extra counts')
         except ImodpyError:
            errStrn = getErrStrings()
            for line in errStrn:
               if 'fewer counts' in line or 'too close to end' in line:
                  prnstr(line.strip())
                  prnstr('Falling back to erasing a small fraction of dense material')
                  break
            else:   #ELSE ON FOR
               cleanup()
               exitFromImodError(progname)

         # Test against None so that a threshold of exactly 0 is still limited
         if recThreshold is not None:

            # NOTE(review): beadVol is in binned voxels, but this divides the unbinned
            # product by binning only once; a binned volume size would be
            # nxFullOver * nyFullOver * thickness / binning**3.  Left unchanged because
            # voxelsAtRefLimit may have been calibrated with this formula - confirm.
            fullRecSize = nxFullOver * nyFullOver * thickness / binning
            maxErase = maxBeadEraseFrac
            if indStoring >= 0 and numStored > 0:
               voxelFrac = beadVol * numStored / fullRecSize
               eraseLimFactor = refEraseLimit / voxelsAtRefLimit
               prnstr(fmtstr('Fraction of voxels in beads : {:g}', voxelFrac))
               maxErase = voxelFrac * eraseLimFactor
               maxErase = min(maxBeadEraseFrac, max(minBeadEraseFrac, maxErase))
            prnstr(fmtstr('Making sure that threshold does not select more than {:.4f}' +\
                       ' of voxels', maxErase))
            if lightBeads:
               maxErase = 1. - maxErase
            clipLines = runcmd(fmtstr('clip hist -t {} "{}"', maxErase, fullRecName))
            limThreshold = getThresholdFromClipOutput(clipLines, 'Threshold value')
            if (lightBeads and limThreshold > recThreshold) or \
                   (not lightBeads and limThreshold < recThreshold):
               prnstr('Using that threshold to limit number of selected pixels')
               recThreshold = limThreshold

   # If no beads, or no threshold was gotten that way, do histogram on whole volume
   if recThreshold is None and needReproj:
      frac = eraseFrac
      if lightBeads:
         frac = 1. - frac
      prnstr('Getting fallback threshold value from histogram of full volume')
      clipLines = runcmd(fmtstr('clip hist -t {} "{}"', frac, fullRecName))
      recThreshold = getThresholdFromClipOutput(clipLines, 'Threshold value')

   # Threshold the volume after determining good min and max values for it: the range
   # runs from the mean to the threshold reflected about itself, clamped to the data
   (nxt, nyt, nzt, mode, xPix, yPix, zPix, xOrig, yOrig, zOrig, tmin, tmax, tmean) = \
      getmrc(fullRecName, doAll = True)
   if lightBeads and needReproj:
      lowForThresh = tmean
      highForThresh = min(tmax, 2 * recThreshold - tmean)
   elif needReproj:
      highForThresh = tmean
      lowForThresh = max(tmin, 2 * recThreshold - tmean)
   if useTemp < 4 or (needReproj and not os.path.exists(threshRec)):
      prnstr('Creating thresholded volume with minimum feature size ' + str(minForThresh),
             flush = True)
      runcmd(fmtstr('clip thresh -t {} -M {},{} -l {} -h {} "{}" "{}"', recThreshold,
                    minForThresh, polarity, lowForThresh, highForThresh, fullRecName,
                    threshRec))

   if needReproj:

      # Reproject thresholded volume
      sedcom = tiltBase + [sedModify('OutputFile', reprojName),
                           sedModify('InputProjections', aliName),
                           sedModify('WIDTH', nxFullOver),
                           '/^ *SLICE/d',
                           '/^ *EXCLUDELIST/d'] + \
         sedDelAndAdd('RecFileToReproject', threshRec, 'OutputFile') + \
         sedDelAndAdd('ThresholdedReproj',
                      fmtstr('{} {} {}', (lowForThresh + highForThresh) / 2,
                             polarity, threshSumFactor), 'OutputFile')
      sedlines = pysed(sedcom, tiltLines)
      prnstr('Reprojecting thresholded volume', flush = True)
      runcmd('tilt -StandardInput', sedlines)

      # Fix the header in the reprojection to match the oversize ali
      (nxt, nyt, nzt, mode, xPix, yPix, zPix, xOrig, yOrig, zOrig, tmin, tmax, tmean) = \
         getmrc(aliName, doAll = True)
      runcmd(fmtstr('alterheader -del {},{},{} -org {},{},{} "{}"',
                    xPix, yPix, zPix, xOrig, yOrig, zOrig, reprojName))

   # Make contours around thresholded density
   opt = '-l'
   if lightBeads:
      opt = '-h'
   if useTemp < 6 or not os.path.exists(autoModel):
      prnstr('Making contours around reprojected density', flush = True)
      runcmd(fmtstr('imodauto {} {} -m 1 -f 3 -x "{}" "{}"', opt, 128,
                    reprojName, autoModel))

   # Erase the contoured regions from the aligned stack into a separate output file.
   # Redo this if the aligned stack was just rebuilt or the erased stack does not exist
   ccdcom = ['InputFile ' + aliName,
             'OutputFile ' + eraseAli,
             'ModelFile ' + autoModel,
             'BoundaryObjects 1',
             'PolynomialOrder 0']
   if useTemp < 7 or needNewAli or not os.path.exists(eraseAli):
      prnstr('Erasing high-density regions from aligned stack', flush = True)
      runcmd('ccderaser -StandardInput', ccdcom)

   # Make new tomogram, back down to regular size, from erased stack; it is reused
   # only if its own output file (eraseRec) exists
   sedcom = tiltBase + [sedModify('OutputFile', eraseRec),
                        sedModify('InputProjections', eraseAli)] + \
      sedDelAndAdd('WIDTH', nxFullAli, 'SCALE') + \
      sedDelAndAdd('SLICE', fmtstr('{} {}', -ySubsetStart,
                                   nyFullOver + ySubsetStart - binning),
                   'SCALE')
   sedlines = pysed(sedcom, tiltLines)
   if useTemp < 8 or not os.path.exists(eraseRec):
      prnstr('Building tomogram from erased stack', flush = True)
      runcmd('tilt -StandardInput', sedlines)

   # Find the section at last
   findcom = ['TomogramFile ' + eraseRec,
              'HighSDboxCriterion ' + str(highSDcrit),
              'BoostHighSDThickness ' + str(boostThickness),
              fmtstr('SizeOfBoxesInXYZ {},{},{}', nxBox, nyBox, nzBox),
              fmtstr('SpacingInXYZ {},{},{}', xSpacing, ySpacing, zSpacing),
              'TomoPitchModel ' + pitchModel]
   for scale in scales:
      findcom.append(fmtstr('BinningInXYZ {0},{0},{0}', scale))

   # NOTE(review): because of the findBeads > 1 condition, the inner findBeads == 1 test
   # can never be true, so ControlValue 29 is never passed; confirm which is intended.
   if ((not needReproj and os.path.exists(boxStack)) or indStoring >= 0) and findBeads >1:
      findcom += ['BeadModelFile ' + peakModel,
                  'BeadDiameter ' + str(beadSize / binning)]
      if findBeads == 1:
         findcom.append('ControlValue 29,0.')
   prnstr('Analyzing structure to find material to include', flush = True)
   runcmd('findsection ' + fsOpts + ' -StandardInput', findcom, 'stdout')
   prnstr('Tomopitch model created', flush = True)
   #if xtilt:
   #   pysed([sedModify('XAXISTILT', 0.)], tiltAllLines, 'tilt' + comSuffix)
   #   prnstr('Set XAXISTILT to 0 in tilt' + comSuffix + ' to match volume analyzed')
   cleanup()
   sys.exit(0)

except ImodpyError:
   cleanup()
   exitFromImodError(progname)
except (IndexError, ValueError):
   cleanup()
   exitError('An error occurred interpreting program output')
