#!/usr/bin/env python
# Reducefiltvol - run binvol to reduce and/or mtffilter to filter a volume
#
# Author: David Mastronarde
#
# $Id: reducefiltvol,v 371d0dd0ef98 2024/10/03 19:51:21 mast $

progname = 'reducefiltvol'

# Leading text for error messages that this script writes directly
prefix = 'ERROR: %s - ' % progname

# load System Libraries
import os, sys, math

#
# Setup runtime environment
# IMOD_DIR must be defined so that its pylib directory can be put at the front of the
# module search path before imodpy is imported
IMOD_DIR = os.environ.get('IMOD_DIR')
if IMOD_DIR is None:
   sys.stdout.write(prefix + " IMOD_DIR is not defined!\n")
   sys.exit(1)

if sys.platform == 'cygwin' and sys.version_info[0] > 2:

   # Switch backslashes to forward slashes, then turn a leading drive specification
   # "X:/" into "/cygdrive/x/".  The slice comparison cannot raise an IndexError when
   # the path is shorter than three characters
   IMOD_DIR = IMOD_DIR.replace('\\', '/')
   if IMOD_DIR[1:3] == ':/':
      IMOD_DIR = '/cygdrive/' + IMOD_DIR[0].lower() + IMOD_DIR[2:]

sys.path.insert(0, os.path.join(IMOD_DIR, 'pylib'))
from imodpy import *
addIMODbinIgnoreSIGHUP()

#
# load IMOD Libraries
from pip import *

# Initializations
InputFile = ""
OutputFile = ""

# Fallbacks from ../manpages/autodoc2man 3 1 reducefiltvol
# PhaseShift (fetched as a float) and ModeToOutput (fetched as an integer) are both
# read by this script, so each needs its own entry here.
# NOTE(review): the short name "phase" is assumed; confirm it against the autodoc
# file the next time these fallbacks are regenerated
options = ["input:InputFile:FN:", "output:OutputFile:FN:", "reduce:ReductionFactor:F:",
           "zfactor:ZReductionFactor:F:", "lowpass:LowPassRadiusSigma:FP:",
           "deconv:DeconvolutionStrength:F:", "snr:SNRFalloff:F:",
           "dchigh:HighPassNyquist:F:", "pixel:PixelSize:F:", "volt:Voltage:I:",
           "cs:SphericalAberration:F:", "defocus:DefocusInMicrons:F:",
           "phase:PhaseShift:F:", "mode:ModeToOutput:I:",
           "setup:SetupChunksIfMemoryError:B:", "param:ParameterFile:PF:",
           "help:usage:B:"]

#
(opts, nonopts) = PipReadOrParseOptions(sys.argv, options, progname, 3, 1, 1)

# Input and output files: both names are required, and the input must already exist
inputFile = PipGetInOutFile('InputFile', 0)
if not inputFile:
   exitError('An input file must be entered')

outputFile = PipGetInOutFile('OutputFile', 1)
if not outputFile:
   exitError('An output file must be entered')

# The existence test comes last so that a missing output name is reported first
if not os.path.exists(inputFile):
   exitError("Input file %s does not exist" % inputFile)

# Reduction factors: the Z factor takes the value of the main factor when not entered
factor = PipGetFloat('ReductionFactor', 1.)
ifFactor = 1 - PipGetErrNo()
zFactor = PipGetFloat('ZReductionFactor', factor)
gotZFactor = PipGetErrNo() == 0

# Reduce only if at least one factor was entered and at least one differs from 1
anyFactorEntered = bool(ifFactor or gotZFactor)
factorNotUnity = factor != 1. or zFactor != 1.
doReduce = anyFactorEntered and factorNotUnity

if factor < 1. or zFactor < 1.:
   exitError('Reduction factors must be greater than or equal to 1')

# Filtering main options: a Gaussian low-pass filter and a deconvolution filter are
# mutually exclusive, and each is selected simply by entering its option
radius, sigma = PipGetTwoFloats('LowPassRadiusSigma', 0., 0.)
doGaussian = 1 - PipGetErrNo()
deconvStrength = PipGetFloat('DeconvolutionStrength', 0.)
doDeconv = 1 - PipGetErrNo()
if doGaussian and doDeconv:
   exitError('You cannot do both a Gaussian filter and a deconvolution filter')

doFilter = doGaussian or doDeconv
setupChunks = PipGetBoolean('SetupChunksIfMemoryError', 0)

# With neither a reduction nor a filter there is nothing to run
if not (doReduce or doFilter):
   exitError('There is no meaningful operation specified')

# Get all the optional entries, keep track if gotten
def _getOptional(getter, option, default):
   """Fetch one optional entry with the given PIP get function.

   getter is the PIP function to call (PipGetFloat or PipGetInteger here), option
   is the long option name, and default is the value passed to the getter for use
   when the option is absent.  Returns a tuple of the value and 1 - PipGetErrNo(),
   which this script uses as the flag that the option was entered.
   """
   value = getter(option, default)
   return (value, 1 - PipGetErrNo())

(snrFalloff, ifSNR) = _getOptional(PipGetFloat, 'SNRFalloff', 0.)
(highPass, ifHighPass) = _getOptional(PipGetFloat, 'HighPassNyquist', 0.)
(defocus, ifDefocus) = _getOptional(PipGetFloat, 'DefocusInMicrons', 0.)
(phase, ifPhase) = _getOptional(PipGetFloat, 'PhaseShift', 0.)
(pixelSize, ifPixel) = _getOptional(PipGetFloat, 'PixelSize', 0.)
(spherAber, ifCs) = _getOptional(PipGetFloat, 'SphericalAberration', 0.)

# Integer options get integer defaults
(voltage, ifVoltage) = _getOptional(PipGetInteger, 'Voltage', 0)
(outMode, ifMode) = _getOptional(PipGetInteger, 'ModeToOutput', 0)

# Set up file names if intermediate file: when both reducing and filtering, the
# reduced volume is written to a file named by inserting '.filttemp' before the
# extension of the output file, and that file is the input to the filter
if doReduce and doFilter:
   redOutput = '.filttemp'.join(os.path.splitext(outputFile))
   filtInput = redOutput
else:
   redOutput = outputFile
   filtInput = inputFile

# Now get ctfplotter com file if it is needed for anything in there, and analyze the
# defocus file if no defocus was entered.  This is needed only for deconvolution, and
# only when pixel size, Cs, voltage, or defocus is missing
if doDeconv and not (ifPixel and ifCs and ifVoltage and defocus):

   plotcom = 'ctfplotter.com'
   (comExt, dualNum, rootName, typeExt, stackExt) = findRootAxisAndExtensions()
   if dualNum == 2:
      plotcom = 'ctfplottera.com'

   # Each message stays empty unless its item cannot be obtained from the com file.
   # The com file is looked for only when one of the three items is missing, so that
   # entering all three without a defocus still reaches the defocus analysis below
   mess = ''
   vmess = ''
   csMess = ''
   if not (ifPixel and ifCs and ifVoltage):
      if not os.path.exists(plotcom):
         mess = plotcom + ' does not exist'
         vmess = mess
         csMess = mess
      else:
         ctfLines = readTextFile(plotcom)

   # Find the pixel size and raw stack
   if not ifPixel and not mess:
      rawPixel = optionValue(ctfLines, 'PixelSize', FLOAT_VALUE, numVal = 1)
      if not rawPixel:
         mess = 'could not find raw stack pixel size in ' + plotcom
      else:
         stackInput = optionValue(ctfLines, 'InputStack', STRING_VALUE)
         if not stackInput:
            mess = 'could not find name of raw stack in ' + plotcom
         elif not os.path.exists(stackInput):
            mess = 'raw stack file ' + stackInput + ' not found'

   # Read the headers to determine binning and get final pixel size
   if not ifPixel and not mess:
      try:

         # "which" names the file being read so the error message can identify it
         which = stackInput
         (nx, ny, nz, mode, pxRaw, pyRaw, pzRaw) = getmrc(stackInput)
         which = inputFile
         (nx, ny, nz, mode, pxVol, pyVol, pzVol) = getmrc(inputFile)

         # Scale the com file pixel size by the ratio of the X pixel sizes in the two
         # headers, and by the reduction factor if the volume is being reduced first
         pixelSize = rawPixel * pxVol / pxRaw
         if doReduce:
            pixelSize *= factor
         ifPixel = 1

      except ImodpyError:
         mess = 'there was an error reading the header of ' + which

   # Report a fallback only for an item that was neither entered nor found; a missing
   # com file must not produce a message about a value the user supplied
   if mess and not ifPixel:
      prnstr('Assuming pixel size in volume header is correct; ' + mess)

   # Collect voltage and/or Cs and just go on if not there
   if not ifVoltage and not vmess:
      voltage = optionValue(ctfLines, 'Voltage', INT_VALUE, numVal = 1)
      if not voltage:
         vmess = 'could not find value in ' + plotcom
      else:
         ifVoltage = 1

   if vmess and not ifVoltage:
      prnstr('Assuming voltage is 300; ' + vmess)

   if not ifCs and not csMess:
      spherAber = optionValue(ctfLines, 'SphericalAberration', FLOAT_VALUE, numVal = 1)
      if not spherAber:
         csMess = 'could not find value in ' + plotcom
      else:
         ifCs = 1

   if csMess and not ifCs:
      prnstr('Assuming spherical aberration is 2.7; ' + csMess)

   # For defocus, need a file to analyze
   if not defocus:
      if not rootName:
         exitError('Defocus must be entered; could not determine root name of dataset')
      defFile = rootName + '.defocus'
      if dualNum == 2:
         defFile = rootName + 'a.defocus'
         if not os.path.exists(defFile):
            defFile = rootName + 'b.defocus'

      if not os.path.exists(defFile):
         exitError('Defocus must be entered; could not find ' + defFile)

      # Figure out what's there from the header
      defLines = readTextFile(defFile)
      if not defLines:
         exitError('Defocus must be entered; the defocus file ' + defFile + ' is empty')
      lsplit = defLines[0].split()
      astig = False
      startLine = 0
      if len(lsplit) < 5:
         exitError('Defocus must be entered; the defocus file ' + defFile +
                   ' has too few entries on first line')

      # A six-entry first line whose last entry is above 2 is taken as a header line:
      # an odd value in its first entry means each data line has two defocus values
      if len(lsplit) == 6:
         try:
            versNum = int(lsplit[5])
            if versNum > 2:
               flags = int(lsplit[0])
               astig = flags % 2 != 0
               startLine = 1

         except ValueError:
            exitError('Defocus must be entered; an error occurred converting a ' +
                      'value to integer on first line of ' + defFile)

      if startLine > 0 and len(defLines) < 2:
         exitError('Defocus must be entered; there is only a header line in ' + defFile)

      # Get the defocus at minimum angle: use the line whose mean of the two angle
      # entries is closest to zero, average its two values if astigmatism is present,
      # and divide by 1000 (presumably nm to microns; verify against the file format)
      minAngle = 1000.
      for line in defLines[startLine:]:
         lsplit = line.split()
         try:
            lowAngle = float(lsplit[2])
            highAngle = float(lsplit[3])
            angle = math.fabs(lowAngle + highAngle) / 2.
            if angle < minAngle:
               minAngle = angle
               defocus = float(lsplit[4])
               if astig:
                  defocus = (defocus + float(lsplit[5])) / 2.
               defocus /= 1000.

         except Exception:
            exitError('Defocus must be entered; an error occurred trying to analyze ' +
                      defFile)

# Set up filter com lines before running in case of program error
if doFilter:
   comlines = ['InputFile ' + filtInput,
               'OutputFile ' + outputFile,
               'FilterIn3D 1']
   if ifMode:
      comlines.append('ModeToOutput ' + str(outMode))
   if doGaussian:
      comlines.append(fmtstr('LowPassRadiusSigma {},{}', radius, sigma))
   else:
      comlines.append('DeconvolutionStrength ' + str(deconvStrength))

      # Pass a pixel size only when one was entered or derived; otherwise the value
      # here is still the 0 default.  Omitting the entry is assumed to make mtffilter
      # take the pixel size from the volume header, as the script's message for that
      # case says - TODO confirm against mtffilter
      if ifPixel:
         comlines.append('PixelSize ' + str(pixelSize))
      comlines.append('Defocus ' + str(defocus))

      # The remaining entries are likewise passed only if a value is available
      if ifSNR:
         comlines.append('SNRFalloff ' + str(snrFalloff))
      if ifHighPass:
         comlines.append('HighPassNyquist ' + str(highPass))
      if ifPhase:
         comlines.append('PhaseShift ' + str(phase))
      if ifVoltage:
         comlines.append('Voltage ' + str(voltage))
      if ifCs:
         comlines.append('SphericalAberration ' + str(spherAber))
            
# Run reduction
# NOTE(review): IMOD_BRIEF_HEADER presumably shortens the header output of the
# programs run below; it is set whether or not a reduction is done
os.environ['IMOD_BRIEF_HEADER'] = '1'
if doReduce:

   # The output mode is passed on to binvol only if one was entered
   modeOpt = ('-mode ' + str(outMode)) if ifMode else ''
   try:
      prnstr("Reducing the volume with Binvol...")
      binCommand = fmtstr('binvol -xbin {} -ybin {} -zbin {} -anti -1 {} "{}" "{}"',
                          factor, factor, zFactor, modeOpt, inputFile, redOutput)
      runcmd(binCommand, None, 'stdout')
   except ImodpyError:
      exitFromImodError(progname)

# Run filter
if doFilter:
   prnstr('Filtering the volume with Mtffilter...')
   try:

      # A successful run ends the script here: echo the filter output and remove the
      # intermediate reduced volume if there was one
      for outLine in runcmd('mtffilter -StandardInput', comlines):
         prnstr(outLine.rstrip())
      if doReduce:
         cleanupFiles([filtInput])
      sys.exit(0)

   except ImodpyError:

      # Check for memory error, but only if chunks were requested for that case; the
      # first error string containing the [MTF1] tag is the one reported
      memErrLine = None
      if setupChunks:
         for errLine in getErrStrings():
            if '[MTF1]' in errLine:
               memErrLine = errLine
               break

      if memErrLine is None:
         exitFromImodError(progname)
      else:
         prnstr('Mtffilter exited with: ' + memErrLine)
         prnstr('Setting up chunks for filtering the volume in parallel',
                flush = True)

   # Have to make chunk files: replace the two file entries with placeholders and put
   # the command line in front to make the master com file for chunksetup
   comlines[0:2] = ['$mtffilter -StandardInput',
                    'InputFile INPUTFILE',
                    'OutputFile OUTPUTFILE']
   writeTextFile('rfvfilter.com', comlines)
   try:
      chunkCommand = fmtstr('chunksetup -master rfvfilter.com "{}" "{}"', filtInput,
                            outputFile)
      runcmd(chunkCommand, None, 'stdout')
   except ImodpyError:
      exitFromImodError(progname)

   # Add cleanup line to finish file, forgive errors
   if doReduce:
      finish = 'rfvfilter-finish.com'
      removeWarning = 'WARNING: You will have to remove ' + filtInput + ' when done; '

      # With returnOnErr, a string coming back from the read is the error message
      finLines = readTextFile(finish, returnOnErr = True)
      if isinstance(finLines, str):
         prnstr(removeWarning + 'failed to read ' + finish + ' with error: ' + finLines)
      else:
         finLines.append('$b3dremove ' + filtInput)
         if writeTextFile(finish, finLines, returnOnErr = True):
            prnstr(removeWarning + 'failed to write ' + finish)

sys.exit(0)
