#!/usr/bin/env python3

import sys
from pathlib import Path

import vtkmodules.vtkRenderingOpenGL2  # noqa: F401

from vtkmodules.vtkFiltersSources import vtkSphereSource
from vtkmodules.vtkIOImage import vtkImageReader2Factory
from vtkmodules.vtkImagingCore import vtkImageFlip
from vtkmodules.vtkRenderingCore import (
    vtkActor,
    vtkPolyDataMapper,
    vtkRenderWindow,
    vtkRenderWindowInteractor,
    vtkRenderer,
    vtkSkybox,
    vtkTexture,
)
from vtkmodules.vtkInteractionStyle import vtkInteractorStyleTrackballCamera


# Base names of the six cube-map face images, in the order they are wired to
# the texture's input ports 0-5: +X, -X, +Y, -Y, +Z, -Z.
CUBEMAP_NAMES = [
    f"{sign}{axis}" for axis in "xyz" for sign in ("pos", "neg")
]


def read_cubemap_from_dir(directory: Path, *, suffix: str = ".jpg") -> vtkTexture:
    """Build a cube-map texture from six face images stored in *directory*.

    Each face is loaded from ``directory / f"{name}{suffix}"`` for every
    ``name`` in ``CUBEMAP_NAMES``. The faces are connected to the texture's
    input ports 0-5 in that same order.

    Args:
        directory: Folder containing the six face images.
        suffix: File extension of the face images, including the leading
            dot. Defaults to ``".jpg"``.

    Returns:
        A ``vtkTexture`` with cube mapping, interpolation and mipmapping
        enabled.

    Exits the process via ``sys.exit`` with a message if a face file is
    missing or if no VTK image reader can handle it.
    """
    texture = vtkTexture()
    texture.CubeMapOn()

    factory = vtkImageReader2Factory()

    for port, name in enumerate(CUBEMAP_NAMES):
        fn = directory / f"{name}{suffix}"
        if not fn.is_file():
            sys.exit(f"Missing cubemap face: {fn}")

        # The factory returns None when no registered reader recognizes the
        # file; without this check the following call fails with an opaque
        # AttributeError on NoneType.
        reader = factory.CreateImageReader2(str(fn))
        if reader is None:
            sys.exit(f"Unsupported image format for cubemap face: {fn}")
        reader.SetFileName(str(fn))

        # Flip Y axis to match OpenGL cubemap orientation
        flip = vtkImageFlip()
        flip.SetInputConnection(reader.GetOutputPort())
        flip.SetFilteredAxis(1)

        texture.SetInputConnection(port, flip.GetOutputPort())

    texture.InterpolateOn()
    texture.MipmapOn()
    return texture


def make_sphere():
    """Return the polydata of a sphere of radius 1.0.

    The sphere is tessellated with 128 subdivisions in both theta and phi.
    The source is updated before returning, so the output is already
    populated.
    """
    resolution = 128

    sphere_source = vtkSphereSource()
    sphere_source.SetRadius(1.0)
    sphere_source.SetThetaResolution(resolution)
    sphere_source.SetPhiResolution(resolution)

    # Execute the pipeline now so the caller receives filled-in geometry.
    sphere_source.Update()
    polydata = sphere_source.GetOutput()
    return polydata


def main():
    """Show a metallic PBR sphere lit by a cube-map environment.

    Takes exactly one command-line argument: a directory holding the six
    cube-map face images. Exits with a usage message on a wrong argument
    count, and with an error if the path is not a directory. Otherwise it
    builds the scene, renders once and starts the interactive event loop in
    a 1000x700 window.
    """
    args = sys.argv[1:]
    if len(args) != 1:
        usage = (
            "Usage:\n"
            "  pbr.py <cubemap_dir>\n\n"
            "Expected files:\n"
            "  posx.jpg  negx.jpg\n"
            "  posy.jpg  negy.jpg\n"
            "  posz.jpg  negz.jpg"
        )
        sys.exit(usage)

    cubemap_dir = Path(args[0])
    if not cubemap_dir.is_dir():
        sys.exit(f"Not a directory: {cubemap_dir}")

    renderer = vtkRenderer()
    render_window = vtkRenderWindow()
    render_window.AddRenderer(renderer)
    render_window.SetSize(1000, 700)

    interactor = vtkRenderWindowInteractor()
    interactor.SetRenderWindow(render_window)
    interactor.SetInteractorStyle(vtkInteractorStyleTrackballCamera())

    environment = read_cubemap_from_dir(cubemap_dir)

    # No lights are added to the renderer; automatic light creation is
    # switched off, so illumination comes from the environment texture
    # through image-based lighting.
    renderer.AutomaticLightCreationOff()
    renderer.UseImageBasedLightingOn()
    renderer.UseSphericalHarmonicsOff()
    # Second argument marks the texture as sRGB-encoded.
    renderer.SetEnvironmentTexture(environment, True)

    # The same cube map is shown as the visible background.
    skybox = vtkSkybox()
    skybox.SetTexture(environment)
    renderer.AddActor(skybox)

    sphere_data = make_sphere()

    sphere_mapper = vtkPolyDataMapper()
    sphere_mapper.SetInputData(sphere_data)

    sphere_actor = vtkActor()
    sphere_actor.SetMapper(sphere_mapper)

    material = sphere_actor.GetProperty()
    material.SetInterpolationToPBR()
    material.SetMetallic(1.0)
    material.SetRoughness(0.1)

    renderer.AddActor(sphere_actor)

    render_window.Render()
    interactor.Start()


if __name__ == "__main__":
    # Launch the viewer only when executed as a script, not on import.
    main()
