from timeit import default_timer as timer
from itertools import cycle

# noinspection PyUnresolvedReferences
import vtkmodules.vtkInteractionStyle
# noinspection PyUnresolvedReferences
import vtkmodules.vtkRenderingOpenGL2
from vtkmodules.vtkCommonDataModel import vtkPartitionedDataSetCollection, vtkDataAssembly
from vtkmodules.vtkCommonColor import vtkNamedColors
from vtkmodules.vtkFiltersSources import vtkSphereSource
from vtkmodules.vtkRenderingCore import (
    vtkActor,
    vtkRenderer,
    vtkRenderWindow,
    vtkRenderWindowInteractor,
    vtkCompositePolyDataMapper,
    vtkCompositeDataDisplayAttributes,
    vtkColorTransferFunction
)

import numpy as np
from vtkmodules.numpy_interface import dataset_adapter as dsa, algorithms as algs

class SimplePDSCDemo:
    """Render five offset spheres stored in a vtkPartitionedDataSetCollection.

    Spheres at partition indices 0, 2 and 4 get two point-data arrays,
    'ScalarsX' and 'ScalarsY', taken from their point x/y coordinates. Each of
    those blocks is colored by one of the arrays via per-block
    vtkCompositeDataDisplayAttributes settings. Every other block gets a solid
    color drawn from ``self.dset_colors``.
    """

    # Numeric value of VTK_SCALAR_MODE_USE_POINT_FIELD_DATA (vtkMapper.h).
    _SCALAR_MODE_USE_POINT_FIELD_DATA = 3
    # Numeric value of VTK_GET_ARRAY_BY_NAME (vtkAbstractMapper.h).
    _GET_ARRAY_BY_NAME = 1

    # Partition index -> name of the point array used to color that block.
    # The keys are also the spheres that receive the scalar arrays in start().
    _BLOCK_SCALAR_ARRAYS = {0: 'ScalarsX', 2: 'ScalarsY', 4: 'ScalarsX'}

    def __init__(self):
        # Endless cycle of RGB triples handed out to blocks without scalar coloring.
        self.dset_colors = cycle([
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 1.0, 0.0],
            [0.0, 1.0, 1.0],
            [1.0, 0.0, 1.0],
        ])

    def start(self):
        """Build the sphere collection, set up rendering and start the interactor."""
        pdsc = vtkPartitionedDataSetCollection()
        data_assembly = vtkDataAssembly()
        pdsc.SetDataAssembly(data_assembly)
        data_assembly.SetRootNodeName('SimplePDSCDemo')

        # Create offset spheres, attach coordinate-derived scalar point arrays
        # to some of them, and add each sphere to the collection as partition 0
        # of its own partitioned dataset.
        for dset_index in range(5):
            sphere = vtkSphereSource()
            sphere.SetCenter(0, 0, 0 + dset_index * 2.)
            sphere.Update()
            poly = sphere.GetOutputDataObject(0)
            if dset_index in self._BLOCK_SCALAR_ARRAYS:
                poly_np = dsa.WrapDataObject(poly)
                poly_np.PointData.append(poly_np.Points[:, 0], 'ScalarsX')

                # Setting active scalars here, e.g.
                #   poly.GetPointData().SetActiveScalars('ScalarsX')
                # would force the mapper to use that array for this block.
                # For fine-grained control, set it per block in the
                # vtkCompositeDataDisplayAttributes (see _set_display_attributes).

                poly_np.PointData.append(poly_np.Points[:, 1], 'ScalarsY')

            # One assembly node per sphere, parented to the root node (id 0),
            # referencing this sphere's dataset index.
            current_node_id = data_assembly.AddNode('Part_' + str(dset_index), 0)
            data_assembly.AddDataSetIndex(current_node_id, dset_index)
            pdsc.SetPartition(dset_index, 0, poly)

        # Mapper, actor, renderer, and render window setup
        mapper = vtkCompositePolyDataMapper()
        mapper.SetInputDataObject(pdsc)

        # Calling mapper.SetScalarModeToUsePointFieldData() would force that
        # mode onto every block; to then use a single color for a block you
        # would have to disable scalar coloring for that block explicitly.

        actor = vtkActor()
        actor.SetMapper(mapper)

        renderer = vtkRenderer()
        renderer.SetBackground(vtkNamedColors().GetColor3d('SteelBlue'))
        renderWindow = vtkRenderWindow()
        renderWindow.AddRenderer(renderer)
        renderWindowInteractor = vtkRenderWindowInteractor()
        renderWindowInteractor.SetRenderWindow(renderWindow)

        renderer.AddActor(actor)
        renderWindow.SetWindowName('CompositePolyDataMapper')

        renderWindow.Render()
        self._set_display_attributes(mapper)
        renderWindow.Render()
        renderWindowInteractor.Start()

    def _set_display_attributes(self, mapper):
        """Configure per-block coloring on ``mapper``'s composite display attributes.

        A block whose partition index appears in ``_BLOCK_SCALAR_ARRAYS`` and
        whose point data contains the named array is colored by that array.
        Every other block is given the next solid color from
        ``self.dset_colors``.

        :param mapper: composite mapper whose input (port 0, connection 0) is
            a vtkPartitionedDataSetCollection carrying a vtkDataAssembly.
        """
        pdsc = mapper.GetInputDataObject(0, 0)
        # Dataset indices referenced from the assembly root (node id 0).
        indices = pdsc.GetDataAssembly().GetDataSetIndices(0)
        if mapper.GetCompositeDataDisplayAttributes() is None:
            mapper.SetCompositeDataDisplayAttributes(vtkCompositeDataDisplayAttributes())
        display_attributes = mapper.GetCompositeDataDisplayAttributes()

        for index in indices:
            dset = pdsc.GetPartitionAsDataObject(index, 0)
            array_name = self._BLOCK_SCALAR_ARRAYS.get(index)
            if array_name is not None and dset.GetPointData().HasArray(array_name):
                # Color this block by the named point-field array.
                display_attributes.SetBlockScalarMode(
                    dset, self._SCALAR_MODE_USE_POINT_FIELD_DATA)
                display_attributes.SetBlockArrayAccessMode(
                    dset, self._GET_ARRAY_BY_NAME)
                display_attributes.SetBlockArrayName(dset, array_name)
            else:
                display_attributes.SetBlockColor(dset, next(self.dset_colors))

if __name__ == '__main__':
    # Script entry point: build the scene and hand control to the VTK interactor.
    demo = SimplePDSCDemo()
    demo.start()