import os
import warnings
from timeit import default_timer as timer

import numpy as np
import vtk
from vtk.util import numpy_support

def export_vtk_file(dataset, filename, binary=True):
    """Write *dataset* to disk with the matching VTK XML writer.

    The writer's default file extension (e.g. ``vtp`` for ``vtkPolyData``)
    is appended to *filename* unless *filename* already ends with it.

    Args:
        dataset: VTK data object to write. Supported types (and their
            subclasses) are ``vtkMultiBlockDataSet``, ``vtkUnstructuredGrid``,
            ``vtkPolyData``, ``vtkImageData``, ``vtkStructuredGrid`` and
            ``vtkRectilinearGrid``.
        filename: Output path, with or without the extension (str or
            path-like).
        binary: If truthy (default), use the writer's binary data mode;
            otherwise write ASCII.

    Returns:
        The path that was written. Returns ``None``, after issuing a
        ``UserWarning``, if the dataset type is unsupported or the writer
        reports a failure.
    """
    # Checked in order with isinstance, so subclasses (e.g. vtkStructuredPoints
    # is a vtkImageData) are routed to their parent type's writer.
    writer_types = (
        (vtk.vtkMultiBlockDataSet, vtk.vtkXMLMultiBlockDataWriter),
        (vtk.vtkUnstructuredGrid, vtk.vtkXMLUnstructuredGridWriter),
        (vtk.vtkPolyData, vtk.vtkXMLPolyDataWriter),
        (vtk.vtkImageData, vtk.vtkXMLImageDataWriter),
        (vtk.vtkStructuredGrid, vtk.vtkXMLStructuredGridWriter),
        (vtk.vtkRectilinearGrid, vtk.vtkXMLRectilinearGridWriter),
    )
    writer_cls = next(
        (cls for data_type, cls in writer_types if isinstance(dataset, data_type)),
        None,
    )
    if writer_cls is None:
        # Previously a silent no-op; keep not writing, but make it visible.
        warnings.warn(
            f"export_vtk_file: unsupported dataset type "
            f"{type(dataset).__name__}; nothing written",
            stacklevel=2,
        )
        return None

    writer = writer_cls()
    if binary:
        writer.SetDataModeToBinary()
    else:
        writer.SetDataModeToAscii()
    writer.SetInputData(dataset)

    path = os.fspath(filename)
    suffix = '.' + writer.GetDefaultFileExtension()
    # Avoid producing names like "mesh.vtp.vtp" when the caller already
    # supplied the extension.
    if not path.lower().endswith(suffix.lower()):
        path += suffix
    writer.SetFileName(path)

    # vtkXMLWriter.Write() returns 1 on success and 0 on failure; it does not
    # raise in Python.
    if not writer.Write():
        warnings.warn(f"export_vtk_file: failed to write {path!r}", stacklevel=2)
        return None
    return path

f_source = r"C:\Scratch\from_mesh.vtp"
f_target = r"C:\Scratch\to_mesh.vtp"

# Create a reader for the source mesh (the mesh whose point data is later
# probed onto the target) and initialize it with the file name.
reader_source = vtk.vtkXMLPolyDataReader()
reader_source.SetFileName(f_source)

# Split the source cells into simplices (2D cells become triangles).
triangle = vtk.vtkDataSetTriangleFilter()
triangle.SetInputConnection(reader_source.GetOutputPort())

# The triangle filter produces an unstructured grid; extract its geometry as
# vtkPolyData so it can be exported as .vtp and used as the probe source.
geometry = vtk.vtkGeometryFilter()
geometry.SetInputConnection(triangle.GetOutputPort())
geometry.Update()

triangular_source = geometry.GetOutput()
# A missing or unreadable file does not raise in Python: VTK reports it on its
# error output and the pipeline simply yields an empty mesh. Fail loudly
# instead of exporting (and later probing) nothing.
if triangular_source.GetNumberOfPoints() == 0:
    raise RuntimeError(f"no geometry read from source mesh {f_source!r}")
export_vtk_file(triangular_source, r'C:\Scratch\triangular_source')

# Read the target mesh, whose points are the locations the source data is
# probed at.
reader_target = vtk.vtkXMLPolyDataReader()
reader_target.SetFileName(f_target)
reader_target.Update()
target = reader_target.GetOutput()
# A missing or unreadable file does not raise in Python; the reader just
# yields an empty mesh, so check explicitly before probing onto it.
if target.GetNumberOfPoints() == 0:
    raise RuntimeError(f"no geometry read from target mesh {f_target!r}")

# Disabled experiment: snap each target point onto the closest point of the
# source surface using a static cell locator.
# NOTE(review): `source_normals` is not defined anywhere in this script —
# presumably `triangular_source` was meant; fix before re-enabling. Newer VTK
# releases may also expect `vtk.reference` in place of `vtk.mutable` — verify.
# target_points = numpy_support.vtk_to_numpy(target.GetPoints().GetData())
#
# cell_locator = vtk.vtkStaticCellLocator()
# cell_locator.SetDataSet(source_normals)
# cell_locator.BuildLocator()
#
# cellId = vtk.mutable(0)
# closest_point = [0.0, 0.0, 0.0]
# subId = vtk.mutable(0)
# d = vtk.mutable(0.0)
# closest_point_coords = np.zeros(shape = target_points.shape, dtype=np.float32)
#
# n=0
# for point in target_points:
#     cell_locator.FindClosestPoint(point, closest_point_coords[n], cellId, subId, d)
#     n+=1
#
# #target.GetPoints().SetData(numpy_support.numpy_to_vtk(closest_point_coords, deep=False, array_type=vtk.VTK_FLOAT))
# export_vtk_file(target, r'C:\Scratch\modified_target', binary = False)

# Make 'Pressure' the active point scalars of the source. SetActiveScalars
# returns -1 when no array with that name exists; without this check the
# mapped output would silently lack pressure data.
if triangular_source.GetPointData().SetActiveScalars('Pressure') < 0:
    raise RuntimeError("source mesh has no point-data array named 'Pressure'")

# Interpolate the source point data at every target point. With
# ComputeTolerance off, the given Tolerance (an absolute distance in the
# meshes' coordinate units) is used instead of an automatically computed one.
probe = vtk.vtkProbeFilter()
probe.SetComputeTolerance(False)
probe.SetTolerance(5.0)
# Used by the probe as the prototype for the cell locator it builds over the
# source; it is not built here.
locator = vtk.vtkCellLocator()
probe.SetCellLocatorPrototype(locator)

# NOTE(review): the disabled lines below reference `triangles_out`, which is
# not defined in this script; kept for reference only.
# locator.Initialize()
# locator.SetDataSet(triangles_out)
# locator.BuildLocator()

#locator_strategy = vtk.vtkClosestNPointsStrategy()
#locator_strategy.SetClosestNPoints(5)
#probe.SetFindCellStrategy(vtk.vtkCellLocatorStrategy())
#probe.SetFindCellStrategy(locator_strategy)

probe.SetInputData(target)
probe.SetSourceData(triangular_source)
probe.Update()
mapped = probe.GetOutput()

# Target points with no source cell within the tolerance get no interpolated
# values; they are marked 0 in the probe's valid-point mask array. Report
# them instead of exporting a partially mapped mesh without any notice.
n_total = mapped.GetNumberOfPoints()
n_missed = n_total - probe.GetValidPoints().GetNumberOfTuples()
if n_missed:
    warnings.warn(
        f"{n_missed} of {n_total} target points were not matched within "
        f"tolerance {probe.GetTolerance()}; see the "
        f"'{probe.GetValidPointMaskArrayName()}' array in the output"
    )

export_vtk_file(mapped, r'C:\Scratch\simple_mapped')