# -*- coding: utf-8 -*-
"""
Created on Wed Aug 12 09:35:34 2020

@author: ppantina
"""

# -*- coding: utf-8 -*-
"""
Created on Mon Jul 20 13:16:24 2020

@author: ppantina
"""
'''
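headers, dataout, extraPackets = L0_sub_readDigRxMixed_singleCh(fullName, kind, gates, maxPacketsToRead, firstPackets, socketID, saveExtraPackets)
 (Python port of the Matlab readDigRxMixed; note that the Python version returns headers first.)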
 fullName: path and file name
 kind: channel type (1x8 array)
       0: Raw
       1: M0
       2: M0, M1normal
       3: M0, M1staggered
       4: M0, M1staggered, M2
 gates           : number of range gates per channel (1x8 array)
 maxPacketsToRead: number of packets to read from file.  Default: 1e20 (a large number)
 firstPackets    : Any packets that go before this file, from a previous iteration of this function.
 socketID        : socketID to read.  Default: [] (everything)
 saveExtraPackets: boolean to process all data or to save a partial block as extraPackets.
                   Default: False (process all data and return an empty extraPackets list)

 dataout     : dict of output data
 headers     : dict of headers
 extraPackets: if desired, return the last partial block of continuous data.
               Put these before the next file as 'firstPackets' if desired.

 This program is designed to read moments or raw data coming from the 2015 version of the Remote Sensing Solutions
 IRAP digital receiver.

 Written in   Matlab by Matt McLinden, NASA/Goddard Space Flight Center
 Rewritten in Python by Peter Pantina, SSAI/NASA GSFC, July 2020
'''

def _decode_profile_headers(packetArray, ind):
    """Decode the 18-word profile headers starting at indices ``ind``.

    ``packetArray`` is the flattened uint32 payload of one contiguous block and
    ``ind`` holds the indices of the profile-start words of ONE channel.
    Returns a dict of 1-D arrays (one entry per profile). The bit layout was
    copied from Matt McLinden's Matlab reader (July 2020). Most fields stay
    uint32; the two 64-bit counters are assembled in uint64.
    """
    import numpy as np

    # One row per profile, one column per header word (offsets 0..17).
    words = packetArray[ind[:, np.newaxis] + np.arange(18)]
    return {
        'ind':        ind,
        'priCnt':     words[:, 10],
        'latch':      words[:, 2] & 0xFFFF,
        'encoder':    words[:, 4],
        'ppsFracCnt': words[:, 12].astype(np.uint64) + words[:, 13].astype(np.uint64) * 2**32,
        'ppsCnt':     words[:, 14],
        'tenMHzCnt':  words[:, 8].astype(np.uint64) + words[:, 9].astype(np.uint64) * 2**32,
        'rxid':       (words[:, 16] >> 24) % 2,
        'channel':    (words[:, 16] >> 20) % 2**3,
        'numProds':   (words[:, 16] >> 16) % 2**2 + 1,
        'numGates':   (words[:, 17] >> 16) * 2 + 2,
        'numPRIs':    (words[:, 17] & 0xFFFF) + 1,
    }


def _split_moments(data, kind, gates):
    """Split one block of float32 moment words into named moments.

    ``data`` has one row per profile. Its column layout is fixed by ``kind``:
    M0 occupies the first 2*gates words (every other word is used). For
    kind 2, M1 follows as real/imag pairs. For kinds 3 and 4, staggered M1a/M1b
    follow in 4*gates words. For kind 4, M2 follows as real/imag pairs starting
    at word 6*gates. Decoding routine taken from Matt M., July 2020.
    """
    import numpy as np

    g = int(gates)
    out = {'m0': data[:, 0:g * 2:2]}

    if kind == 2:  # M0, M1 normal
        real = 2 * g + np.arange(0, 2 * g, 2)
        imag = 2 * g + np.arange(1, 2 * g, 2)
        out['m1'] = (data[:, real] + 1j * data[:, imag]).astype(np.complex64)

    elif kind in (3, 4):  # staggered M1a / M1b (even and odd gates interleaved)
        base = 2 * g
        span = 4 * g
        m1a = np.zeros_like(out['m0'], dtype='complex64')
        m1b = np.zeros_like(out['m0'], dtype='complex64')
        m1a[:, 0:g:2] = data[:, base + np.arange(0, span, 8)] + 1j * data[:, base + np.arange(1, span, 8)]
        m1a[:, 1:g:2] = data[:, base + np.arange(2, span, 8)] + 1j * data[:, base + np.arange(3, span, 8)]
        m1b[:, 0:g:2] = data[:, base + np.arange(4, span, 8)] + 1j * data[:, base + np.arange(5, span, 8)]
        m1b[:, 1:g:2] = data[:, base + np.arange(6, span, 8)] + 1j * data[:, base + np.arange(7, span, 8)]
        out['m1a'] = m1a
        out['m1b'] = m1b

    if kind == 4:  # M2 follows the staggered M1 words
        out['m2'] = data[:, 6 * g + np.arange(0, 2 * g, 2)] + 1j * data[:, 6 * g + np.arange(1, 2 * g, 2)]

    return out


def _concatenate_blocks(parts):
    """Concatenate a list of per-block dicts (same keys) key-by-key along axis 0."""
    import numpy as np

    return {key: np.concatenate([part[key] for part in parts], axis=0) for key in parts[0]}


def L0_sub_readDigRxMixed_singleCh(fullName, kind, gates, maxPacketsToRead=1e20,
                                   firstPackets=None, socketID=None,
                                   saveExtraPackets=False):
    """Read raw and/or moments data written by the 2015 RSS IRAP digital receiver.

    Parameters
    ----------
    fullName : str
        Path of the binary file.
    kind : array_like of int, one entry per channel
        0 = raw, 1 = M0, 2 = M0 + M1 (normal), 3 = M0 + M1 staggered,
        4 = M0 + M1 staggered + M2.
    gates : array_like of int, one entry per channel
        Range gates per channel; channels with 0 gates are ignored.
    maxPacketsToRead : number, optional
        Read limit. NOTE: the number of 740-word packets actually read is
        ``int(min(maxPacketsToRead, fileBytes // 740) / 4)``. The limit is
        therefore effectively counted in 740-byte units. This behaviour is
        inherited and kept so existing callers read the same amount of data.
    firstPackets : ndarray or None, optional
        ``extraPackets`` from a previous call (one packet per column). They are
        prepended to the packets of this file. ``[]`` is accepted as "none".
    socketID : int, sequence of int or None, optional
        Keep only packets whose header socket ID is listed. None or an empty
        sequence keeps everything.
    saveExtraPackets : bool, optional
        If True and the packet counter has gaps, the packets after the final
        gap are not parsed. They are returned as ``extraPackets`` instead.

    Returns
    -------
    headers : dict
        ``headers['chanN'][field]`` holds 1-D per-profile header values
        (``ind`` is relative to the block that contained the profile).
    dataout : dict
        ``dataout['chanN']`` holds ``'raw'`` for raw channels. Moments channels
        hold ``'m0'`` plus ``'m1'``, ``'m1a'``/``'m1b'`` and/or ``'m2'``,
        depending on ``kind``.
    extraPackets : ndarray or list
        Packets held back for the next call, or ``[]``.

    Raises
    ------
    ValueError
        If a channel has an unsupported ``kind``, or no packets remain to parse.
    """
    import os
    import numpy as np

    ##Fixed packet/profile layout
    packetSize = 740          # 32-bit words per packet (6 header + 734 payload words)
    headerWords = 6           # packet-header words at the top of every packet
    payloadWords = packetSize - headerWords
    profileHeaderWords = 18   # profile data starts 18 words after its start marker
    profileStart = 32768      # 0x00008000: first word of every profile
    minBlockPackets = 10      # blocks with 2 or fewer packets failed; 10 leaves margin
    wordsPerGate = {0: 2, 1: 2, 2: 4, 3: 6, 4: 8}  # 32-bit words per gate, by kind

    kind = np.asarray(kind)
    gates = np.asarray(gates)
    channels = np.where(gates > 0)[0]                       # ignore channels with no data
    channelsRaw = np.where((gates > 0) & (kind == 0))[0]
    channelsMoments = np.where((gates > 0) & (kind > 0))[0]
    for ch in channels:
        if int(kind[ch]) not in wordsPerGate:
            raise ValueError('Unsupported kind %r for channel %d' % (kind[ch], ch))

    numBytes = os.path.getsize(fullName)
    numPackets = int(np.floor(numBytes / packetSize))  # in 740-BYTE units (see docstring)
    packetsToRead = np.min([maxPacketsToRead, numPackets])

    ##Read the file into memory as little-endian uint32 words
    print('Opening ' + fullName + '...')
    with open(fullName, 'rb') as f:
        print('Reading up to %i packets, total of %5.1f MB...' % (packetsToRead, packetsToRead * packetSize / 1024 / 1024))
        words = np.fromfile(f, dtype='<I', count=int(packetSize * int(packetsToRead / 4)))

    # packetBlock[word, packet]: one packet per column.
    packetBlock = words.reshape(-1, packetSize).T
    if firstPackets is not None and len(firstPackets) > 0:
        packetBlock = np.concatenate((firstPackets, packetBlock), axis=1)

    ##Header words 0, 1, 2 and 5 are big-endian, so byte-swap them (stored as float64).
    print('Swapping header bytes...')
    digrx = {name: np.asarray(packetBlock[row], dtype='<u4').byteswap().astype(np.float64)
             for name, row in (('digrxsec', 0), ('digrxusec', 1), ('socketID', 2), ('count', 5))}

    ##Optional socketID filter: drop the packets AND their decoded headers together.
    if socketID is not None and np.size(socketID) > 0:
        keep = np.isin(digrx['socketID'], np.atleast_1d(socketID))
        packetBlock = packetBlock[:, keep]
        digrx = {name: values[keep] for name, values in digrx.items()}

    if packetBlock.shape[1] == 0:
        raise ValueError('No complete packets to process in ' + fullName)

    ##Find gaps in the packet counter; breaks[i] is the LAST packet of a contiguous run.
    count = digrx['count']
    breaks = np.where(np.diff(count) != 1)[0]
    expected = np.max(count) - np.min(count) + 1
    missing = expected - len(count)
    print('Finding breaks... missing %d of %d packets (%3.1f percent), with %d gaps'
          % (missing, expected, 100 * missing / expected, len(breaks)))

    # Half-open [start, stop) packet ranges of every contiguous run.
    runStarts = np.concatenate(([0], breaks + 1))
    runStops = np.concatenate((breaks + 1, [packetBlock.shape[1]]))

    # '== True' (not truthiness) as before, so e.g. the string 'false' does not enable it.
    if len(breaks) > 0 and saveExtraPackets == True:  # noqa: E712
        extraPackets = packetBlock[:, runStarts[-1]:]
        print('    Saving %d of %d packets in extraPackets\n' % (extraPackets.shape[1], packetBlock.shape[1]))
        runStarts, runStops = runStarts[:-1], runStops[:-1]
    else:
        extraPackets = []

    ##Parse every contiguous run; collect per-block results per channel.
    headParts = {ch: [] for ch in channels}
    dataParts = {ch: [] for ch in channels}

    print('Parsing data...')
    for start, stop in zip(runStarts, runStops):
        if stop - start < minBlockPackets:
            continue

        # Column-major flatten concatenates the packet payloads in order,
        # with each packet's 6 header words removed.
        packetArray = packetBlock[headerWords:, start:stop].flatten('F')

        proInd = np.where(packetArray == profileStart)[0]
        if len(proInd) == 0:
            continue
        # The last profile is almost always cut off by the end of the run.
        proInd = proInd[:-1]
        proIndChs = (packetArray[proInd + 16] >> 20) % 2**3  # same as 'channel' decode

        for ch in channels:
            ind = proInd[proIndChs == ch]

            hdr = _decode_profile_headers(packetArray, ind)
            # 0-based packet holding each profile start (Matlab used ceil on 1-based
            # indices), used to look up that packet's receiver time stamp.
            pkt = (np.floor(ind / payloadWords) + start).astype('int')
            hdr['cpusec'] = digrx['digrxsec'][pkt]
            hdr['cpuusec'] = digrx['digrxusec'][pkt]
            headParts[ch].append(hdr)

            nWords = int(gates[ch]) * wordsPerGate[int(kind[ch])]
            cols = ind[:, np.newaxis] + profileHeaderWords + np.arange(nWords)
            # Reinterpret (not convert) the uint32 words as IEEE float32.
            data = packetArray[cols].view(np.float32)

            if kind[ch] == 0:
                # Raw data is [real, imag, real, imag, ...] per gate.
                dataParts[ch].append({'raw': data[:, 0::2] + 1j * data[:, 1::2]})
            else:
                dataParts[ch].append(_split_moments(data, kind[ch], gates[ch]))

    ##Concatenate the blocks of every channel
    print ('Concatinating Data... ')
    headers = {}
    dataout = {}
    if not any(headParts[ch] for ch in channels):
        print('    Warning: no block contained parsable profiles')

    for ch in channels:
        if headParts[ch]:
            headers['chan' + str(ch)] = _concatenate_blocks(headParts[ch])

    for ch in [*channelsMoments, *channelsRaw]:
        if dataParts[ch]:
            dataout['chan' + str(ch)] = _concatenate_blocks(dataParts[ch])

    return (headers, dataout, extraPackets)

