how to speedup vtkCellPicker ?

Hi all,

below is a snippet of a molecular viewer written using vtk/PyQt which I adapted from a colleague’s code. Basically, the viewer is quite efficient regarding the mouse interaction (translation, rotation, zoom). I also need to add a callback which gives me the index of the atom which is picked, this index being transferred in other part of the whole PyQt application through signal/slot mechanism. Here the troubles begin. To do so, a vtkCellPicker is used but I found that it was painfully slow to retrieve that index (see on_pick method). Eventually, it even freezes my application. I tried to use a vtkCellLocator by doing:

locator = vtk.vtkCellLocator()
locator.SetDataSet(self._polydata)
locator.BuildLocator()
picker.AddLocator(locator)

But it did not improve the situation.

I also tried other pickers which seem to be much much faster ( vtkPropPicker , vtkPointPicker ) but I do not know how to retrieve the index of the picked atom from those pickers. Would you have any idea how to make it ? Being new in vtk, code snippet would be very welcome. Many thanks in advance.

import sys

import numpy as np

from PyQt5 import QtCore, QtWidgets

import vtk
from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor

vtk.vtkObject.GlobalWarningDisplayOff()

RGB_COLOURS = {}
RGB_COLOURS[“selection”] = (0, (1.00, 0.20, 1.00))
RGB_COLOURS[“default”] = (1, (1.00, 0.90, 0.90))

CHEMICAL_ELEMENTS = {}
CHEMICAL_ELEMENTS[‘H’] = {‘vdw_radius’: 1.09, ‘color’: ‘255;255;255’}
CHEMICAL_ELEMENTS[‘C’] = {‘vdw_radius’: 1.70, ‘color’: ‘0;255;0’}
CHEMICAL_ELEMENTS[‘N’] = {‘vdw_radius’: 1.55, ‘color’: ‘0;0;255’}
CHEMICAL_ELEMENTS[‘O’] = {‘vdw_radius’: 1.52, ‘color’: ‘255;0;0’}

NUMBER_OF_ATOMS = 100000

RESOLUTION = int(np.sqrt(5000000.0 / NUMBER_OF_ATOMS))
RESOLUTION = 10 if RESOLUTION > 10 else RESOLUTION
RESOLUTION = 4 if RESOLUTION < 4 else RESOLUTION

def color_string_to_rgb(color):
“”"Convert a color stored in r;g;b format to [r/255.0,g/255.0,b/255.0] format.

Args:
    color (str): the color to convert
"""

if not color.strip():
    color = "1;1;1"

return np.array(color.split(';')).astype(np.float32)/255.

def ndarray_to_vtkarray(colors, scales, n_atoms):
“”"Convert the colors and scales NumPy arrays to vtk arrays.

Args:
    colors (numpy.array): the colors
    scales (numpy.array): the scales
    n_atoms (int): the number of atoms
"""
# define the colours
color_scalars = vtk.vtkFloatArray()
color_scalars.SetNumberOfValues(len(colors))
for i, c in enumerate(colors):
    color_scalars.SetValue(i, c)
color_scalars.SetName("colors")

# some scales
scales_scalars = vtk.vtkFloatArray()
scales_scalars.SetNumberOfValues(scales.shape[0])
for i, r in enumerate(scales):
    scales_scalars.SetValue(i, r)
scales_scalars.SetName("scales")

# the original index
index_scalars = vtk.vtkIntArray()
index_scalars.SetNumberOfValues(n_atoms)
for i in range(n_atoms):
    index_scalars.SetValue(i, i)
index_scalars.SetName("index")

scalars = vtk.vtkFloatArray()
scalars.SetNumberOfComponents(3)
scalars.SetNumberOfTuples(scales_scalars.GetNumberOfTuples())
scalars.CopyComponent(0, scales_scalars, 0)
scalars.CopyComponent(1, color_scalars, 0)
scalars.CopyComponent(2, index_scalars, 0)
scalars.SetName("scalars")
return scalars

class MolecularViewer(QtWidgets.QWidget):
“”“This class implements a molecular viewer.
“””

def __init__(self, parent):

    super(MolecularViewer, self).__init__(parent)

    self._iren = QVTKRenderWindowInteractor(self)

    self._renderer = vtk.vtkRenderer()

    self._iren.GetRenderWindow().AddRenderer(self._renderer)

    self._iren.GetRenderWindow().SetPosition((0, 0))

    self._iren.GetInteractorStyle().SetCurrentStyleToTrackballCamera()

    self._iren.Enable()

    self._camera = vtk.vtkCamera()
    self._renderer.SetActiveCamera(self._camera)
    self._camera.SetFocalPoint(0, 0, 0)
    self._camera.SetPosition(0, 0, 20)

    self._previously_picked_atom = None

    self._iren.Initialize()

    self._atoms = np.random.choice(list(CHEMICAL_ELEMENTS.keys()), NUMBER_OF_ATOMS).tolist()

    self._atom_colours, self._lut = self.build_color_transfer_function()

    self._atom_scales = np.array([CHEMICAL_ELEMENTS[at]['vdw_radius'] for at in self._atoms]).astype(np.float32)

    scalars = ndarray_to_vtkarray(self._atom_colours, self._atom_scales, len(self._atoms))

    self._polydata = vtk.vtkPolyData()
    self._polydata.GetPointData().SetScalars(scalars)

    coordinates = np.random.uniform(-100, 100, (NUMBER_OF_ATOMS, 3))

    self.set_coordinates(coordinates)

    self._iren.AddObserver("LeftButtonPressEvent", self.on_pick)

@property
def iren(self):
    return self._iren

@property
def renderer(self):
    return self._renderer

def build_color_transfer_function(self):
    """Returns the colors and their associated transfer function
    """

    lut = vtk.vtkColorTransferFunction()

    for (idx, color) in RGB_COLOURS.values():
        lut.AddRGBPoint(idx, *color)

    colours = []
    unic_colours = {}

    color_string_list = [color_string_to_rgb(CHEMICAL_ELEMENTS[at]['color']) for at in self._atoms]

    col_ids = len(RGB_COLOURS)

    for col in color_string_list:
        tup_col = tuple(col)
        if not (tup_col in unic_colours.keys()):
            unic_colours[tup_col] = col_ids
            lut.AddRGBPoint(col_ids, *tup_col)
            colours.append(col_ids)
            col_ids += 1
        else:
            colours.append(unic_colours[tup_col])

    return colours, lut

def build_scene(self):
    '''
    build a vtkPolyData object for a given frame of the trajectory
    '''

    actor_list = []
    line_actor = None

    line_mapper = vtk.vtkPolyDataMapper()
    if vtk.vtkVersion.GetVTKMajorVersion() < 6:
        line_mapper.SetInput(self._polydata)
    else:
        line_mapper.SetInputData(self._polydata)

    line_mapper.SetLookupTable(self._lut)
    line_mapper.ScalarVisibilityOn()
    line_mapper.ColorByArrayComponent("scalars", 1)
    line_actor = vtk.vtkLODActor()
    line_actor.GetProperty().SetLineWidth(3)
    line_actor.SetMapper(line_mapper)
    actor_list.append(line_actor)

    sphere = vtk.vtkSphereSource()
    sphere.SetCenter(0, 0, 0)
    sphere.SetRadius(0.2)
    sphere.SetThetaResolution(RESOLUTION)
    sphere.SetPhiResolution(RESOLUTION)
    glyph = vtk.vtkGlyph3D()
    glyph.SetInputData(self._polydata)

    glyph.SetScaleModeToScaleByScalar()
    glyph.SetColorModeToColorByScalar()
    glyph.SetScaleFactor(1)
    glyph.SetSourceConnection(sphere.GetOutputPort())
    glyph.SetIndexModeToScalar()
    sphere_mapper = vtk.vtkPolyDataMapper()
    sphere_mapper.SetLookupTable(self._lut)
    sphere_mapper.SetScalarRange(self._polydata.GetScalarRange())
    sphere_mapper.SetInputConnection(glyph.GetOutputPort())
    sphere_mapper.ScalarVisibilityOn()
    sphere_mapper.ColorByArrayComponent("scalars", 1)
    ball_actor = vtk.vtkLODActor()
    ball_actor.SetMapper(sphere_mapper)
    ball_actor.GetProperty().SetAmbient(0.2)
    ball_actor.GetProperty().SetDiffuse(0.5)
    ball_actor.GetProperty().SetSpecular(0.3)
    ball_actor.SetNumberOfCloudPoints(30000)
    actor_list.append(ball_actor)
    self.glyph = glyph

    self._picking_domain = ball_actor

    assembly = vtk.vtkAssembly()
    for actor in actor_list:
        assembly.AddPart(actor)

    return assembly

def clear_actors(self):
    """Clear the trajectory and the vtk scene.
    """

    if not hasattr(self, "_actors"):
        return

    self._actors.VisibilityOff()
    self._actors.ReleaseGraphicsResources(self._iren.GetRenderWindow())
    self._renderer.RemoveActor(self._actors)

    del self._actors

def get_atom_index(self, pid):
    """Return the atom index from the vtk data point index.

    Args:
        pid (int): the data point index
    """

    _, _, idx = self.glyph.GetOutput().GetPointData().GetArray("scalars").GetTuple3(pid)

    return int(idx)

def on_pick(self, obj, event=None):
    """Event handler when an atom is mouse-picked with the left mouse button
    """

    # Get the picked position and retrieve the index of the atom that was picked from it
    pos = obj.GetEventPosition()

    picker = vtk.vtkCellPicker()
    picker.SetTolerance(0.005)

    picker.AddPickList(self._picking_domain)
    picker.PickFromListOn()
    picker.Pick(pos[0], pos[1], 0, self._renderer)
    pid = picker.GetPointId()
    if pid > 0:
        idx = self.get_atom_index(pid)
        self.on_pick_atom(idx)

def on_pick_atom(self, picked_atom):
    """Change the color of a selected atom
    """

    # If an atom was previously picked, restore its scale and color
    if self._previously_picked_atom is not None:
        index, scale, color = self._previously_picked_atom
        self._atom_scales[index] = scale
        self._atom_colours[index] = color
        self._polydata.GetPointData().GetArray("scalars").SetTuple3(
            index, self._atom_scales[index], self._atom_colours[index], index)

    # Save the scale and color of the picked atom
    self._previously_picked_atom = (
        picked_atom, self._atom_scales[picked_atom], self._atom_colours[picked_atom])

    # Set its colors with the default value for atom selection and increase its size
    self._atom_colours[picked_atom] = RGB_COLOURS['selection'][0]
    self._atom_scales[picked_atom] *= 2

    self._polydata.GetPointData().GetArray("scalars").SetTuple3(picked_atom, self._atom_scales[picked_atom], self._atom_colours[picked_atom], picked_atom)

    self._polydata.Modified()

    # self._iren.Render()

def set_coordinates(self, coords):
    '''
    Sets a new configuration

    @param frame: the configuration number
    @type frame: integer
    '''

    self.clear_actors()

    points = vtk.vtkPoints()
    points.SetNumberOfPoints(len(self._atoms))
    for i, (x, y, z) in enumerate(coords):
        points.SetPoint(i, x, y, z)

    self._polydata.SetPoints(points)

    # Update the view.
    self.update_renderer()

def update_renderer(self):
    '''
    Update the renderer
    '''
    # deleting old frame
    self.clear_actors()

    # creating new polydata
    self._actors = self.build_scene()

    # adding polydata to renderer
    self._renderer.AddActor(self._actors)

    # rendering
    self._iren.Render()

class MainWindow(QtWidgets.QMainWindow):
“”“This class implements the main window of the application.
“””

def __init__(self, parent=None):
    super(MainWindow, self).__init__(parent)

    self.init_ui()

def build_layout(self):
    """Build the layout of the main window.
    """

    self._vl = QtWidgets.QVBoxLayout()
    self._vl.addWidget(self._molecular_viewer.iren)

    self._main_frame.setLayout(self._vl)

def build_widgets(self):
    """Build the widgets of the main window.
    """

    self._main_frame = QtWidgets.QFrame(self)

    self._molecular_viewer = MolecularViewer(self._main_frame)
    self._molecular_viewer.renderer.ResetCamera()
    self._molecular_viewer.iren.Initialize()
    self._molecular_viewer.iren.Start()

    self.setCentralWidget(self._main_frame)

    self.setGeometry(0, 0, 800, 800)

    self.show()

def init_ui(self):
    """Set the widgets of the main window
    """

    self.build_widgets()
    self.build_layout()

app = QtWidgets.QApplication(sys.argv)

window = MainWindow()

sys.exit(app.exec_())