import math
import numpy as np
import sys

import vtk


class BlockInfo:
    """Dimensions, origin and spacing of a single block, stored as arrays.

    ``level`` starts out as ``None``; it is meant to be filled in by the
    caller (see ``levelRelativeTo``).
    """

    def __init__(self, dims, origin, spacing):
        """Convert the three inputs with ``np.asarray`` and store them."""
        self.dims, self.origin, self.spacing = (
            np.asarray(dims),
            np.asarray(origin),
            np.asarray(spacing),
        )
        self.level = None

    @property
    def nnode(self):
        """Product of all entries of ``dims``."""
        return np.prod(self.dims)

    def levelRelativeTo(self, dxOther):
        """Return the rounded base-2 log of ``|dxOther| / |spacing[0]|``.

        Only the first component of this block's spacing is consulted, and
        signs are ignored on both sides.
        """
        ownStep = abs(self.spacing[0])
        otherStep = abs(dxOther)
        return round(math.log(otherStep / ownStep, 2))


# Exactly one command-line argument (the input file) is expected.
args = sys.argv[1:]
if len(args) != 1:
    print('Usage:', sys.argv[0], 'file.npy')
    sys.exit(1)

# The file holds three arrays saved back to back: block dimensions, block
# origins and block spacings, in that order.
with open(args[0], 'rb') as fin:
    dims, origins, spacings = (np.load(fin) for _ in range(3))

# Wrap the i'th entry of each array in a BlockInfo.  The number of blocks is
# taken from the dimensions array; the other two arrays are indexed to match.
nblock = len(dims)
blocks = [BlockInfo(blockDims, origins[i], spacings[i])
          for i, blockDims in enumerate(dims)]

# Classify the orientation of the dataset.  "Good" means every spacing
# component is positive; anything else (a mix of positive and negative) is
# "bad".  Only the first block is inspected, on the assumption that all other
# blocks are oriented the same way.
firstSpacing = blocks[0].spacing
goodOrientation = np.all(firstSpacing > 0)
print('Good orientation' if goodOrientation else 'Bad orientation')

# The coarsest spacing is the largest absolute first-component spacing found
# in any block.
absSpacings = [abs(blk.spacing[0]) for blk in blocks]
smax = max(absSpacings)
print('Coarsest spacing in dataset =', smax)

# Expand the scalar into a three-component spacing for the coarsest level.
# The sign pattern follows the orientation detected above.
coarseSpacing = np.array(
    (smax, smax, smax) if goodOrientation else (-smax, smax, -smax)
)
print('Spacing vector on coarsest level =', coarseSpacing)

# Assign every block its AMR level from the ratio between the coarsest
# spacing and the block's own spacing.
for blk in blocks:
    blk.level = blk.levelRelativeTo(smax)
levels = [blk.level for blk in blocks]

# Levels are numbered from 0, so the count is one more than the deepest level.
maxLevel = max([0, *levels])
numLevels = maxLevel + 1
print(f'Total of {numLevels} levels of refinement')

# Tally how many blocks live on each level
blocksPerLevel = [levels.count(lvl) for lvl in range(numLevels)]
print('Number of blocks per level =', blocksPerLevel)

# Derive the global origin from the origins of the level-0 blocks only.  Which
# extreme is taken along each axis depends on the orientation.
level0Origins = [blk.origin for blk in blocks if blk.level == 0]
origins = np.array(level0Origins)
if not goodOrientation:
    # Bad orientation: maximum along x and z, minimum along y (this matches
    # the sign pattern used for the coarse spacing).
    xs, ys, zs = origins[:, 0], origins[:, 1], origins[:, 2]
    globalOrigin = [xs.max(), ys.min(), zs.max()]
else:
    # Good orientation: component-wise minimum.
    globalOrigin = origins.min(axis=0) # FIXME
print('Global origin =', globalOrigin)

# Create the overlapping-AMR container and give it its overall shape: number
# of levels, number of blocks on each level, global origin and grid type.
amr = vtk.vtkOverlappingAMR()
amr.Initialize(numLevels, blocksPerLevel)
amr.SetOrigin(globalOrigin)
amr.SetGridDescription(vtk.VTK_XYZ_GRID)

# Each level halves the spacing of the level before it.
amrInfo = amr.GetAMRInfo()
for level in range(numLevels):
    levelSpacing = coarseSpacing / 2**level
    amrInfo.SetSpacing(level, levelSpacing)

# Group one AMR box per block by level: amrBoxes[i][j] is the j'th box on the
# i'th level.  While doing so, record for every (level, index) slot the
# position of the originating block in the flat `blocks` list.
amrBoxes = [[] for _ in range(numLevels)]
for flatIdx, blk in enumerate(blocks):
    levelBoxes = amrBoxes[blk.level]
    index = len(levelBoxes)  # slot the new box will occupy on its level
    box = vtk.vtkAMRBox(blk.origin, blk.dims, blk.spacing, globalOrigin)
    levelBoxes.append(box)
    amr.SetAMRBlockSourceIndex(blk.level, index, flatIdx)

# Register every box with the AMR metadata
for level, levelBoxes in enumerate(amrBoxes):
    for index, box in enumerate(levelBoxes):
        amr.GetAMRInfo().SetAMRBox(level, index, box)

# NOTE: this is where the segfault happens when the orientation is "bad"
# NOTE(review): possibly because vtkAMRBox does not handle negative spacing --
# confirm against the VTK documentation.
print('Generating parent/child information')
amr.GenerateParentChildInformation()

# Construct the actual grid object in each AMR box
for level in range(numLevels):
    for index, box in enumerate(amrBoxes[level]):
        # Look up the originating block using the level currently being
        # visited.  This previously used `blk.level`, where `blk` was a stale
        # variable left over from an earlier loop (and then from the previous
        # iteration), so grids were built from blocks on the wrong level and
        # attached to boxes that did not describe them.
        flatIdx = amr.GetAMRBlockSourceIndex(level, index)
        blk = blocks[flatIdx]

        grid = vtk.vtkUniformGrid()
        grid.SetDimensions(blk.dims)
        grid.SetOrigin(blk.origin)
        grid.SetSpacing(blk.spacing)

        # The box was already registered through GetAMRInfo().SetAMRBox();
        # the call is kept here as in the original.
        amr.SetAMRBox(level, index, box)
        amr.SetDataSet(level, index, grid)

# NOTE: this used to result in "malloc(): corrupted top size" regardless of
# orientation.
# NOTE(review): the mismatched grids produced by the stale-level lookup above
# are a likely cause -- confirm the crash is gone now that each grid matches
# its box.
print('Blanking cells')
vtk.vtkAMRUtilities.BlankCells(amr)
