Skip to content

Commit 116fcd1

Browse files
AlejandroFernandezLucespyansys-ci-botclatapiegerma89
authored
fix(plotting): Improve interface of the plotting class. (#3702)
* fix(plotting): Test plotting changes * chore: adding changelog file 3702.fixed.md [dependabot-skip] * fix(plotting): Small bug * fix: Bug caused tests failure * fix: Add warning * chore: adding changelog file 3702.fixed.md [dependabot-skip] * test: Add mesh testing * test: Add docstring to test * fix: PyVista lazy import --------- Co-authored-by: pyansys-ci-bot <92810346+pyansys-ci-bot@users.noreply.github.com> Co-authored-by: Camille <78221213+clatapie@users.noreply.github.com> Co-authored-by: German <28149841+germa89@users.noreply.github.com>
1 parent 9f5e4e9 commit 116fcd1

3 files changed

Lines changed: 74 additions & 25 deletions

File tree

doc/changelog.d/3702.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
fix(plotting): Improve interface of the plotting class.

src/ansys/mapdl/core/plotting/visualizer.py

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,14 @@
2222

2323
"""Module for the MapdlPlotter class."""
2424
from collections import OrderedDict
25-
from typing import Any, Iterable, Optional, Union
25+
from typing import Any, Dict, Iterable, Optional, Union
2626

2727
from ansys.tools.visualization_interface import Plotter
2828
from ansys.tools.visualization_interface.backends.pyvista import PyVistaBackendInterface
2929
import numpy as np
3030
from numpy.typing import NDArray
3131

32+
from ansys.mapdl.core import LOG as logger
3233
from ansys.mapdl.core import _HAS_VISUALIZER
3334
from ansys.mapdl.core.misc import get_bounding_box
3435
from ansys.mapdl.core.plotting.consts import (
@@ -277,9 +278,9 @@ def plot_iter(
277278

278279
def add_mesh(
279280
self,
280-
meshes,
281-
points,
282-
labels,
281+
meshes: Union[pv.PolyData, Dict[str, Any]] = [],
282+
points=[],
283+
labels=[],
283284
*,
284285
cpos=None,
285286
show_bounds=False,
@@ -328,6 +329,10 @@ def add_mesh(
328329
plotter_kwargs : dict, optional
329330
Extra kwargs, by default {}
330331
"""
332+
if not meshes and not points and not labels:
333+
logger.warning("No meshes, points or labels to plot.")
334+
return
335+
331336
if theme is None:
332337
theme = MapdlTheme()
333338

@@ -375,31 +380,43 @@ def add_mesh(
375380
**(add_points_kwargs or {}),
376381
)
377382

383+
if isinstance(meshes, pv.PolyData):
384+
meshes = [meshes]
378385
for mesh in meshes:
379-
scalars: Optional[NDArray[Any]] = mesh.get("scalars")
380-
381-
if (
382-
"scalars" in mesh
383-
and scalars.ndim == 2
384-
and (scalars.shape[1] == 3 or scalars.shape[1] == 4)
385-
):
386-
# for the case we are using scalars for plotting
387-
rgb = True
386+
rgb = False
387+
if isinstance(mesh, Dict):
388+
scalars: Optional[NDArray[Any]] = mesh.get("scalars")
389+
390+
if (
391+
"scalars" in mesh
392+
and scalars.ndim == 2
393+
and (scalars.shape[1] == 3 or scalars.shape[1] == 4)
394+
):
395+
# for the case we are using scalars for plotting
396+
rgb = True
397+
398+
# To avoid index error.
399+
mesh_ = mesh["mesh"]
400+
if not isinstance(mesh_, list):
401+
mesh_ = [mesh_]
388402
else:
389-
rgb = False
390-
391-
# To avoid index error.
392-
mesh_ = mesh["mesh"]
393-
if not isinstance(mesh_, list):
394-
mesh_ = [mesh_]
395-
403+
scalars = None
404+
mesh_ = meshes
396405
for each_mesh in mesh_:
397406
self.scene.add_mesh(
398407
each_mesh,
399408
scalars=scalars,
400409
scalar_bar_args=scalar_bar_args,
401-
color=mesh.get("color", color),
402-
style=mesh.get("style", style),
410+
color=(
411+
mesh.get("color", color)
412+
if isinstance(mesh, Dict) and "color" in mesh
413+
else color
414+
),
415+
style=(
416+
mesh.get("style", style)
417+
if isinstance(mesh, Dict) and "style" in mesh
418+
else style
419+
),
403420
show_edges=show_edges,
404421
edge_color=edge_color,
405422
smooth_shading=smooth_shading,
@@ -648,9 +665,9 @@ def bc_nodes_plot(
648665

649666
def plot(
650667
self,
651-
meshes,
652-
points,
653-
labels,
668+
meshes: Union[pv.PolyData, Dict[str, Any]] = [],
669+
points=[],
670+
labels=[],
654671
*,
655672
title="",
656673
cpos=None,

tests/test_plotting.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,6 +1262,37 @@ def test_plot_path(mapdl, tmpdir):
12621262
mapdl.eplot(vtk=False)
12631263

12641264

1265+
def test_add_mesh():
1266+
"""Test the add_mesh method from MapdlPlotter class."""
1267+
import pyvista as pv
1268+
1269+
cube1 = pv.Cube()
1270+
pl1 = MapdlPlotter()
1271+
pl1.add_mesh(cube1)
1272+
1273+
cube2 = pv.Cube()
1274+
meshes_dict = [
1275+
{
1276+
"mesh": cube2,
1277+
"scalars": np.random.default_rng(seed=1).random((8, 3)),
1278+
}
1279+
]
1280+
1281+
pl2 = MapdlPlotter()
1282+
pl2.add_mesh(meshes_dict)
1283+
1284+
cube3 = pv.Cube()
1285+
sphere = pv.Sphere()
1286+
1287+
pl3 = MapdlPlotter()
1288+
pl3.add_mesh([cube3, sphere])
1289+
1290+
assert pl1.meshes[0] == cube1
1291+
assert pl2.meshes[0] == cube2
1292+
assert pl3.meshes[0] == cube3
1293+
assert pl3.meshes[1] == sphere
1294+
1295+
12651296
def test_plot_path_screenshoot(mapdl, cleared, tmpdir):
12661297
mapdl.graphics("POWER")
12671298
# mapdl.screenshot is not affected by the device.

0 commit comments

Comments
 (0)