diff --git a/pyproject.toml b/pyproject.toml index 2e61bbd..280b999 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "maxplotlibx" -version = "0.1.1" +version = "0.1.2" description = "A reproducible plotting module with various backends and export options." readme = "README.md" requires-python = ">=3.8" @@ -18,6 +18,7 @@ dependencies = [ "matplotlib", "pint", "plotly", + "tikzpics>=0.1.1", ] [project.optional-dependencies] test = [ diff --git a/src/maxplotlib/backends/matplotlib/utils.py b/src/maxplotlib/backends/matplotlib/utils.py index 6e15ca5..6d9080b 100644 --- a/src/maxplotlib/backends/matplotlib/utils.py +++ b/src/maxplotlib/backends/matplotlib/utils.py @@ -1,18 +1,9 @@ # import sys; from os.path import dirname; sys.path.append(f'{dirname(__file__)}/../../') # import matplotlib.pylab as pylab -import math -import pickle -from pathlib import Path -import _pickle as cPickle -import matplotlib.colors as mcolors import matplotlib.pyplot as plt -import numpy as np import pint -from matplotlib.collections import PatchCollection -from mpl_toolkits.mplot3d import Axes3D -from mpl_toolkits.mplot3d.art3d import Line3DCollection, Poly3DCollection def setup_tex_fonts(fontsize=14, usetex=False): @@ -71,7 +62,6 @@ def _2pt(width, dpi=300): elif isinstance(width, str): length_in = convert_to_inches(width) length_pt = length_in * dpi - # print(f"{length_in = } {length_pt = }") return length_pt else: raise NotImplementedError diff --git a/src/maxplotlib/backends/matplotlib/utils_old.py b/src/maxplotlib/backends/matplotlib/utils_old.py deleted file mode 100644 index 2516ab6..0000000 --- a/src/maxplotlib/backends/matplotlib/utils_old.py +++ /dev/null @@ -1,852 +0,0 @@ -# import sys; from os.path import dirname; sys.path.append(f'{dirname(__file__)}/../../') - -# import matplotlib.pylab as pylab -import math -import pickle -from pathlib import Path - -import _pickle as cPickle -import matplotlib.colors as mcolors -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.collections import PatchCollection -from mpl_toolkits.mplot3d import Axes3D -from mpl_toolkits.mplot3d.art3d import Line3DCollection, Poly3DCollection - - -class Color: - def __init__(self, hex_color): - self.hx = hex_color - self.rgb = tuple(int(hex_color.lstrip("#")[i : i + 2], 16) for i in (0, 2, 4)) - - self.rgb_dec = [i / 255 for i in self.rgb] - self.rgb_dec_str = ["{:.6f}".format(i) for i in self.rgb_dec] - - self.rgb_inv = tuple(np.subtract((256, 256, 256), self.rgb)) - self.rgb_dec_inv = [1.0 - c for c in self.rgb_dec] - # self.pgf_col_str = '\definecolor{currentstroke}{rgb}{' - self.pgf_col_str = "{rgb}{" - self.pgf_col_str += self.rgb_dec_str[0] + "," - self.pgf_col_str += self.rgb_dec_str[1] + "," - self.pgf_col_str += self.rgb_dec_str[2] + "}%" - - def invert(self): - inverted_color = Color(self.hx) - - def define_color_str(self, name): - hex_str = self.hx.replace("#", "") - out_str = "\\definecolor{" + name + "}{HTML}{" + hex_str + "}" - out_str += " " * (70 - len(out_str)) + "% https://www.colorhexa.com/" + hex_str - return out_str - - def __str__(self): - return self.hx - - -# https://matplotlib.org/stable/gallery/color/named_colors.html -def mcolors2mplcolors(colors): - names = sorted(colors, key=lambda c: tuple(mcolors.rgb_to_hsv(mcolors.to_rgb(c)))) - col_dict = dict() - for name in names: - col_dict[name] = Color(colors[name]) - return col_dict - - -def import_colors(cmap="pastel"): - col_dict = dict() - col_dict["pastel"] = mcolors2mplcolors(mcolors.CSS4_COLORS) - col_dict["cmap2"] = mcolors2mplcolors(mcolors.CSS4_COLORS) - col_dict["thesis_colors"] = mcolors2mplcolors(mcolors.CSS4_COLORS) - - col_dict["pastel"]["white"] = Color("#ffffff") - col_dict["pastel"]["black"] = Color("#000000") - col_dict["pastel"]["yellow"] = Color("#FFFFB3") - col_dict["pastel"]["dkyellow"] = Color("#FFED6F") - col_dict["pastel"]["purple"] = Color("#BEBADA") - col_dict["pastel"]["dkpurple"] = Color("#BC80BD") - col_dict["pastel"]["red"] = Color("#FB8072") - col_dict["pastel"]["ltred"] = Color("#FFCCCB") - col_dict["pastel"]["dkred"] = Color("#CB0505") - col_dict["pastel"]["orange"] = Color("#FDB462") - col_dict["pastel"]["dkgold"] = Color("#B8860B") - col_dict["pastel"]["blue"] = Color("#80B1D3") - col_dict["pastel"]["dkblue"] = Color("#00008B") - col_dict["pastel"]["deepskyblue"] = Color("#1f78b4") - col_dict["pastel"]["green"] = Color("#B3DE69") - col_dict["pastel"]["ltgreen"] = Color("#CCEBC5") - col_dict["pastel"]["dkgreen"] = Color("#006400") - col_dict["pastel"]["bluegreen"] = Color("#8DD3C7") - col_dict["pastel"]["pink"] = Color("#FCCDE5") - col_dict["pastel"]["ltgray"] = Color("#D9D9D9") - col_dict["pastel"]["dkgray"] = Color("#515151") - col_dict["pastel"]["brown"] = Color("#D2691E") - - col_dict["cmap2"]["white"] = Color("#ffffff") - col_dict["cmap2"]["black"] = Color("#000000") - col_dict["cmap2"]["yellow"] = Color("#ffff99") - col_dict["cmap2"]["dkyellow"] = Color("#FFED6F") - col_dict["cmap2"]["ltpurple"] = Color("#cab2d6") - col_dict["cmap2"]["purple"] = Color("#6a3d9a") - col_dict["cmap2"]["dkpurple"] = Color("#BC80BD") - col_dict["cmap2"]["red"] = Color("#e31a1c") - col_dict["cmap2"]["ltred"] = Color("#fb9a99") - col_dict["cmap2"]["dkred"] = Color("#CB0505") - col_dict["cmap2"]["ltorange"] = Color("#fdbf6f") - col_dict["cmap2"]["orange"] = Color("#ff7f00") - col_dict["cmap2"]["blue"] = Color("#1f78b4") - col_dict["cmap2"]["dkblue"] = Color("#00008B") - col_dict["cmap2"]["deepskyblue"] = Color("#1f78b4") - col_dict["cmap2"]["green"] = Color("#33a02c") - col_dict["cmap2"]["ltgreen"] = Color("#b2df8a") - col_dict["cmap2"]["dkgreen"] = Color("#006400") - col_dict["cmap2"]["bluegreen"] = Color("#8DD3C7") - col_dict["cmap2"]["pink"] = Color("#FCCDE5") - col_dict["cmap2"]["ltgray"] = Color("#D9D9D9") - col_dict["cmap2"]["dkgray"] = Color("#515151") - col_dict["cmap2"]["brown"] = Color("#b15928") - - -def import_col_list(cmap="cmap2"): - col_dict = import_colors(cmap) - col_list = [ - col_dict["black"], - col_dict["blue"], - col_dict["red"], - col_dict["green"], - col_dict["orange"], - col_dict["purple"], - col_dict["gray"], - col_dict["brown"], - # col_dict['green'], - ] - return col_list - - -def id2color(id, cmap="cmap2"): - col_list = import_col_list(cmap=cmap) - return col_list[id % len(col_list)].hx - - -# -# ,------.,--. ,--. -# | .---'`--' ,---. ,--.,--.,--.--. ,---. ,---. ,---. ,-' '-.,--.,--. ,---. -# | `--, ,--.| .-. || || || .--'| .-. : ( .-' | .-. :'-. .-'| || || .-. | -# | |` | |' '-' '' '' '| | \ --. .-' `)\ --. | | ' '' '| '-' ' -# `--' `--'.`- / `----' `--' `----' `----' `----' `--' `----' | |-' -# `---' `--' - - -linestyles = dict() -linestyles["solid"] = "solid" # Same as (0, ()) or '-' -linestyles["dotted"] = "dotted" # Same as (0, (1, 1)) or '.' -linestyles["dashed"] = "dashed" # Same as '--' -linestyles["dashdot"] = "dashdot" # Same as '-.' -linestyles["loosely dotted"] = (0, (1, 10)) -linestyles["dotted"] = (0, (1, 1)) -linestyles["densely dotted"] = (0, (1, 1)) - -linestyles["loosely dashed"] = (0, (5, 10)) -linestyles["dashed"] = (0, (5, 5)) -linestyles["densely dashed"] = (0, (5, 1)) - -linestyles["loosely dashdotted"] = (0, (3, 10, 1, 10)) -linestyles["dashdotted"] = (0, (3, 5, 1, 5)) -linestyles["densely dashdotted"] = (0, (3, 1, 1, 1)) - -linestyles["dashdotdotted"] = (0, (3, 5, 1, 5, 1, 5)) -linestyles["loosely dashdotdotted"] = (0, (3, 10, 1, 10, 1, 10)) -linestyles["densely dashdotdotted"] = (0, (3, 1, 1, 1, 1, 1)) - -linestyle_list = [linestyles[l] for l in linestyles] -linestyle_list_ordered = [ - linestyles["solid"], - linestyles["densely dashdotted"], - linestyles["dashed"], - linestyles["dotted"], -] - - -class figure: - def __init__( - self, - load_file="", - nx_subplots=1, - ny_subplots=1, - width=426.79135, - figsize=None, - scale_width=1, - dpi=300, - threeD=False, - ratio="golden", - legend=True, - axes_grid=False, - gridspec_kw={"wspace": 0.08, "hspace": 0.1}, - legend_position="upper right", - filename="MaxFigureClassInstance", - directory="./", - cmap="cmap2", - fontsize=14, - tex_fonts=True, - ): - # if width == 'singlecol': - # width = 426.79135 / 2.0 - # if width == 'doublecol': - # width = 426.79135 - - self.nx_subplots = nx_subplots - self.ny_subplots = ny_subplots - self.width = width * scale_width - self.dpi = dpi - self.threeD = threeD - self.ratio = ratio - self.legend = legend - self.filename = filename - self.directory = directory - self.axes_grid = axes_grid - self.gridspec_kw = gridspec_kw - self.cmap = cmap - self.col_list = import_colors(self.cmap) - - self.axes_grid_which = "major" - self.grid_alpha = 1.0 - self.grid_linestyle = linestyles["densely dotted"] - self.fontsize = fontsize - # print(self.directory) - if len(self.directory) > 0: - if not self.directory[-1] == "/": - self.directory += "/" - - # - # plt.style.use('seaborn') - # - if not figsize == None: - self.width = figsize[0] - self.ratio = figsize[0] / figsize[1] - - if tex_fonts: - self.setup_tex_fonts() - - if not load_file == "": - self.load(load_file) - elif threeD: - self.create_3dplot() - else: - self.create_lineplot() - - def setup_tex_fonts(self): - self.tex_fonts = { - # Use LaTeX to write all text - "text.usetex": True, - "font.family": "serif", - "pgf.rcfonts": False, # don't setup fonts from rc parameters - # Use 10pt font in plots, to match 10pt font in document - "axes.labelsize": self.fontsize, - "font.size": self.fontsize, - # Make the legend/label fonts a little smaller - "legend.fontsize": self.fontsize, - "xtick.labelsize": self.fontsize, - "ytick.labelsize": self.fontsize, - } - self.setup_plotstyle( - tex_fonts=self.tex_fonts, - axes_grid=self.axes_grid, - axes_grid_which=self.axes_grid_which, - grid_alpha=self.grid_alpha, - grid_linestyle=self.grid_linestyle, - ) - - def create_lineplot(self): - self.fig, self.axs = plt.subplots( - self.ny_subplots, - self.nx_subplots, - figsize=self.set_size( - self.width, - ratio=self.ratio, # - ), # sharex=True,#sharex='all', sharey='all', - dpi=self.dpi, - constrained_layout=False, - gridspec_kw=self.gridspec_kw, - ) - - def create_3dplot(self): - self.fig = plt.figure(figsize=self.set_size(self.width), dpi=self.dpi) - self.axs = self.fig.add_subplot(111, projection="3d") - - def setup_plotstyle( - self, - tex_fonts=True, - axes_grid=True, - axes_grid_which="major", - grid_alpha=0.0, - grid_linestyle="dotted", - ): - if tex_fonts: - plt.rcParams.update(self.tex_fonts) - - plt.rcParams["axes.grid"] = axes_grid # False ## display grid or not - # gridlines at major, minor or both ticks - plt.rcParams["axes.grid.which"] = axes_grid_which - plt.rcParams["grid.alpha"] = grid_alpha # transparency, between 0.0 and 1.0 - plt.rcParams["grid.linestyle"] = grid_linestyle - - plt.rcParams["xtick.direction"] = "in" - plt.rcParams["ytick.direction"] = "in" - - # This is to avoid the overlapping tick labels. - plt.rcParams["xtick.major.pad"] = 8 - plt.rcParams["ytick.major.pad"] = 8 - # plt.rc('text.latex', preamble=r'\usepackage{wasysym}') - - def set_common_xlabel(self, xlabel="common X"): - self.fig.text( - 0.5, - -0.075, - xlabel, - va="center", - ha="center", - fontsize=self.fontsize, - ) - # fig.text(0.04, 0.5, 'common Y', va='center', ha='center', rotation='vertical', fontsize=rcParams['axes.labelsize']) - - def set_size(self, width, fraction=1, ratio="golden"): - """Set figure dimensions to avoid scaling in LaTeX. - Parameters - ---------- - width: float - Document textwidth or columnwidth in pts - fraction: float, optional - Fraction of the width which you wish the figure to occupy - Returns - ------- - fig_dim: tuple - Dimensions of figure in inches - """ - # - # Width of figure (in pts) - if width == "thesis": - width_pt = 426.79135 - elif width == "beamer": - width_pt = 307.28987 - else: - width_pt = width - - # Width of figure - fig_width_pt = width_pt * fraction - - # Convert from pt to inches - inches_per_pt = 1 / 72.27 - - # Golden ratio to set aesthetic figure height - # https://disq.us/p/2940ij3 - golden_ratio = (5**0.5 - 1) / 2 - - # Figure width in inches - fig_width_in = fig_width_pt * inches_per_pt - - if ratio == "golden": - # Figure height in inches - fig_height_in = fig_width_in * golden_ratio - - elif ratio == "square": - # Figure height in inches - fig_height_in = fig_width_in - # print('ratio',ratio) - if type(ratio) == int or type(ratio) == float: - fig_height_in = fig_width_in * ratio - - fig_dim = (fig_width_in, fig_height_in) - - return fig_dim - - def get_axis(self, subfigure): - if subfigure == -1: - return self.axs - elif not isinstance(subfigure, list): - return self.axs[subfigure] - elif isinstance(subfigure, list) and len(subfigure) == 2: - return self.axs[subfigure[0], subfigure[1]] - - def get_limits(self, ax=None): - if ax == None: - xxmin, xxmax = self.axs.get_xlim() - yymin, yymax = self.axs.get_ylim() - else: - xxmin, xxmax = ax.get_xlim() - yymin, yymax = ax.get_ylim() - arr = [xxmin, xxmax, yymin, yymax] - return arr - - def set_labels(self, delta, point, subfigure=-1, axis="x"): - ax = self.get_axis(subfigure) - plt.sca(ax) - if axis == "x": - xmin, xmax = ax.get_xlim() - width = int((xmax - xmin) / delta + 1) * delta - locs, labels = plt.xticks() - i0 = int(xmin / delta) - i1 = int(xmax / delta) - xvec = [] - xvec = np.arange(point - width, point + width + delta, delta) - xvec += point - xvec = xvec[xvec >= xmin] - xvec = xvec[xvec <= xmax] - new_labels = [i * delta for i in range(i0, i1)] - # if precision == 0: new_labels = [int(x) for x in new_labels] - plt.xticks(xvec, xvec) - if axis == "y": - return - - def scale_axis( - self, - subfigure=-1, - axis="x", - axs_in=None, - scale=1.0, - shift=0, - precision=2, - delta=-1, - includepoint=-1, - nticks=5, - locs_labels=None, - ): - # https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.xticks.html - # if subfigure_x == -1 and subfigure_y == -1 and nx_subplots > 1 and ny_subplots > 1: - # print('enter subfigure_x and subfigure_y!') - # return - # if subfigure_x == -1 and subfigure_y == -1: - # ax = self.axs[] - if subfigure == -1: - ax = self.axs - elif not isinstance(subfigure, list): - ax = self.axs[subfigure] - elif isinstance(subfigure, list) and len(subfigure) == 2: - ax = self.axs[subfigure[0], subfigure[1]] - - if not axs_in == None: - ax = axs_in - # print("precision", precision, precision, precision) - plt.sca(ax) - if axis == "x": - if locs_labels == None: - xmin, xmax = ax.get_xlim() - locs, labels = plt.xticks() - if delta == -1 and includepoint == -1: - new_labels = [round((x + shift) * scale, precision) for x in locs] - if precision == 0: - new_labels = [int(x) for x in new_labels] - else: - if delta == -1: - delta = (xmax - xmin) / (nticks - 1) - if includepoint == -1: - includepoint = xmin - width = int((xmax - xmin) / delta + 1) * delta - i0 = int(xmin / delta) - i1 = int(xmax / delta + 1) - locs = np.arange( - includepoint - width, - includepoint + width + delta, - delta, - ) - locs = locs[locs >= xmin - 1e-12] - locs = locs[locs <= xmax + 1e-12] - - new_labels = [round((x + shift) * scale, precision) for x in locs] - if precision == 0: - new_labels = [int(y) for y in new_labels] - new_labels = [f"${l}$" for l in new_labels] - # plt.xticks(locs,new_labels) - ax.set_xticks(locs) - ax.set_xticklabels(new_labels) - ax.axis(xmin=xmin, xmax=xmax) - else: - # plt.xticks(locs_labels['locs'],locs_labels['labels']) - ax.set_xticks(locs_labels["locs"]) - ax.set_xticklabels(locs_labels["labels"]) - - if axis == "y": - if locs_labels == None: - ymin, ymax = ax.get_ylim() - locs, labels = plt.yticks() - if delta == -1 and includepoint == -1: - new_labels = [round((y + shift) * scale, precision) for y in locs] - if precision == 0: - new_labels = [int(y) for y in new_labels] - else: - if delta == -1: - delta = (ymax - ymin) / (nticks - 1) - if includepoint == -1: - includepoint = ymin - width = int((ymax - ymin) / delta + 1) * delta - i0 = int(ymin / delta) - i1 = int(ymax / delta + 1) - locs = np.arange( - includepoint - width, - includepoint + width + delta, - delta, - ) - locs = locs[locs >= ymin - 1e-12] - locs = locs[locs <= ymax + 1e-12] - - new_labels = [round((y + shift) * scale, precision) for y in locs] - if precision == 0: - new_labels = [int(y) for y in new_labels] - new_labels = [f"${l}$" for l in new_labels] - # plt.yticks(locs,new_labels) - - ax.set_yticks(locs) - ax.set_yticklabels(new_labels) - - ax.axis(ymin=ymin, ymax=ymax) - else: - # plt.yticks(locs_labels['locs'],locs_labels['labels']) - ax.set_yticks(locs_labels["locs"]) - ax.set_yticklabels(locs_labels["labels"]) - - def adjustFigAspect(self, aspect=1): - """ - Adjust the subplot parameters so that the figure has the correct - aspect ratio. - """ - xsize, ysize = self.fig.get_size_inches() - minsize = min(xsize, ysize) - xlim = 0.4 * minsize / xsize - ylim = 0.4 * minsize / ysize - if aspect < 1: - xlim *= aspect - else: - ylim /= aspect - self.fig.subplots_adjust( - left=0.5 - xlim, - right=0.5 + xlim, - bottom=0.5 - ylim, - top=0.5 + ylim, - ) - - def add_figure_label( - self, - label, - pos="top left", - bbox=dict(facecolor="white", edgecolor="gray", boxstyle="round"), - ha="left", - va="top", - ax=None, - ): - limits = self.get_limits(ax) - # print(limits) - lx = limits[1] - limits[0] - ly = limits[3] - limits[2] - - if isinstance(pos, str): - if "top" in pos: - y = limits[2] + 0.90 * ly - elif "center in pos": - y = limits[2] + 0.5 * ly - else: - y = limits[2] + 0.05 * ly - - if "left" in pos: - x = limits[0] + 0.1 * lx - else: - x = limits[0] + 0.95 * lx - else: - x = limits[0] + pos[0] * lx - y = limits[2] + pos[1] * ly - - # print(x,y) - if ax == None: - ax = self.axs - ax.text( - x, - y, - f"{label}", - rotation=0, - ha=ha, - va=va, - bbox=bbox, - fontsize=self.fontsize, - ) - - def savefig( - self, - filename="", - formats=["png"], - format="", - create_sh_file=False, - print_imgcat=True, - format_folder=False, - tight_layout=True, - ): - # self.update_figure() - # self.fig.tight_layout() - # print(self.directory) - - if "/" in self.filename: - tmp = self.filename - spl = tmp.split("/") - self.filename = spl[-1] - self.directory = tmp.replace(spl[-1], "") - - if not self.directory == "": - Path(self.directory).mkdir(parents=True, exist_ok=True) - - if isinstance(format, list): - formats = format - format = "" - if not format == "": - formats = [format] - if filename == "": - filename = self.filename - - if formats == "all" or formats == ["all"]: - self.dump() - formats = ["jpg", "pdf", "pgf", "png", "svg", "txt", "pickle", "tex"] - - if isinstance(formats, str): - formats = [formats] - - if format_folder: - for format in formats: - # print('self.directory',self.directory) - Path(self.directory + "/" + format).mkdir(parents=True, exist_ok=True) - - _dir = self.directory - for format in formats: - if format_folder: - self.directory = "{}{}/".format(_dir, format) - # print('pl',format) - if format in [ - "eps", - "jpeg", - "jpg", - "pdf", - "png", - "ps", - "raw", - "rgba", - "svg", - "svgz", - "tif", - "tiff", - ]: - # self.fig.savefig(self.directory + filename + '.' + format,bbox_inches='tight', transparent=False) - if tight_layout: - self.fig.savefig( - self.directory + filename + "." + format, - bbox_inches="tight", - ) - else: - self.fig.savefig(self.directory + filename + "." + format) - elif format == "pgf": - # Save pgf figure - self.fig.savefig( - self.directory + filename + "." + format, - bbox_inches="tight", - ) - - # Replace pgf figure colors with colorlet - # This is based af. - # col_list = self.col_list - file_str = "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n" - file_str += "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n" - file_str += "%% Do not forget to add the following lines" - for cmap in ["pastel", "cmap2"]: - col_list = import_colors(cmap) - file_str += "\n\n%% Definitions for " + cmap + "\n\n" - for col in col_list: - file_str += "%" + col_list[col].define_color_str(col) + "\n" - file_str += "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n" - file_str += "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n\n\n" - with open(self.directory + filename + "." + format, "r") as f: - for line in f: - file_str += line - for cmap in ["pastel", "cmap2"]: - col_list = import_colors(cmap) - for col in col_list: - if col_list[col].pgf_col_str in line: - # print(line) - file_str += ( - "\\colorlet{currentstroke}{" + col + "}%\n" - ) - file_str += ( - "\\colorlet{currentfill}{" + col + "}%\n" - ) - file_str += "\\colorlet{textcolor}{" + col + "}%\n" - - with open(self.directory + filename + "." + format, "w") as f: - f.write(file_str) - elif format in ["txt", "dat", "csv"]: - self.matplotlib2txt(self.directory + filename, format) - elif format == "pickle": - pickle.dump(self.fig, open(self.directory + filename + ".pickle", "wb")) - elif format == "tex": - import tikzplotlib - - # tikzplotlib.clean_figure() - tikzplotlib.save(self.directory + filename + ".tex") - else: - try: - plt.savefig( - self.directory + filename + "." + format, - bbox_inches="tight", - ) - except Exception as e: - print( - "ERROR: Could not save figure: " - + self.directory - + filename - + "." - + format, - ) - print(e) - - imgcat_formats = ["png"] - if create_sh_file: - with open("show_latest_image.sh", "w") as f: - for format in formats: - if format in imgcat_formats: - f.write( - "imgcat " + self.directory + filename + "." + format + "\n", - ) - - if print_imgcat and ("png" in formats or "pdf" in formats): - if format_folder: - self.directory = "{}{}/".format(_dir, "png") - self.imgcat(formats) - self.directory = _dir - - # if format in formats: - # print('imgcat ' + filename + '.' + format) - - def imgcat(self, formats="png"): - imgcat_formats = ["png"] - if isinstance(formats, str): - formats = [formats] - for format in formats: - if format in imgcat_formats: - print("imgcat " + self.directory + self.filename + "." + format) - - def matplotlib2txt(self, filename, format="txt"): - # ax = plt.gca() # get axis handle - - x_unit = "mm" - y_unit = "mm" - - max_len_arr = 0 - - # Create a vector with each axis - axs_vec = [] - if self.nx_subplots > 1 and self.ny_subplots > 1: - for i in range(self.nx_subplots): - for j in range(self.ny_subplots): - axs_vec.append(self.axs[i, j]) - elif self.nx_subplots > 1 or self.ny_subplots > 1: - for i in range(self.nx_subplots * self.ny_subplots): - axs_vec.append(self.axs[i]) - else: - axs_vec.append(self.axs) - - # Save all the data - line_names = [] - line_xdata = [] - line_ydata = [] - for iax, ax in enumerate(axs_vec): - for line in ax.lines: - line_names.append(line) - line_xdata.append(line.get_xdata()) - line_ydata.append(line.get_ydata()) - max_len_arr = max(max_len_arr, len(line.get_xdata())) - # print(line_names) - data_mat = np.empty((len(line_names) * 2, max_len_arr)) - data_mat[:, :] = np.NaN - i = 0 - - header = "" - - for name, xdata, ydata in zip(line_names, line_xdata, line_ydata): - data_mat[2 * i, 0 : len(xdata)] = xdata - data_mat[2 * i + 1, 0 : len(ydata)] = ydata - - header += "Axial position, " + str(name) + "," - - i += 1 - - np.savetxt(filename + "." + format, data_mat.T, header=header, delimiter=",") - - def dump( - self, - filename="", - ): - if filename == "": - filename = self.filename + "_dump.txt" - Path(self.directory + "/dump").mkdir(parents=True, exist_ok=True) - with open(self.directory + "dump/" + filename, "wb") as file: - file.write(cPickle.dumps(self.__dict__)) - - def load(self, filename): - with open(filename, "rb") as file: - self.__dict__ = cPickle.loads(file.read()) - - def get_lines(self): - lines = plt.gca().lines - out = [] - for i, line in enumerate(lines): - line_dict = dict() - line_dict["line"] = line - line_dict["line_name"] = str(line) - line_dict["line_xdat"] = line.get_xdata() - line_dict["line_ydat"] = line.get_ydata() - out.append(line_dict) - return out - - def add_label_box( - self, - label="Test", - xpos=0.05, - ypos=0.95, - rotation=0, - ha="left", - va="top", - bbox=dict(facecolor="white", edgecolor="gray", boxstyle="round"), - ): - self.axs.text( - xpos, - ypos, - label, - rotation=rotation, - ha=ha, - va=va, - transform=self.axs.transAxes, - bbox=bbox, - ) - # print(label) - - -def fmt_scientific(x, pos): - a, b = "{:.1e}".format(x).split("e") - b = int(b) - return r"${} \times 10^{{{}}}$".format(a, b) - - -def fmt_10pow(x, pos): - a, b = "{:.1e}".format(x).split("e") - b = int(b) - return r"$10^{{{}}}$".format(b) - - -def fmt_int(x, pos, num_dec): - return r"${}$".format(int(x)) - - -def fmt_1dec(x, pos): - return r"${}$".format(round(x, 1)) - - -def fmt_2dec(x, pos): - return r"${}$".format(round(x, 2)) - - -if __name__ == "__main__": - print("Testing maxplotlib") - mfig = figure(filename="mpl_test") - mfig.axs.plot([0, 1, 2, 3], [0, 0, 1, 1]) - mfig.savefig(formats=["png"]) diff --git a/src/maxplotlib/canvas/canvas.py b/src/maxplotlib/canvas/canvas.py index 8fbc51f..63c178e 100644 --- a/src/maxplotlib/canvas/canvas.py +++ b/src/maxplotlib/canvas/canvas.py @@ -1,17 +1,20 @@ import os +import re from typing import Dict +import matplotlib.patches as patches import matplotlib.pyplot as plt -import plotly.graph_objects as go from plotly.subplots import make_subplots +from tikzpics import TikzFigure from maxplotlib.backends.matplotlib.utils import ( set_size, setup_plotstyle, setup_tex_fonts, ) +from maxplotlib.colors.colors import Color +from maxplotlib.linestyle.linestyle import Linestyle from maxplotlib.subfigure.line_plot import LinePlot -from maxplotlib.subfigure.tikz_figure import TikzFigure from maxplotlib.utils.options import Backends @@ -145,8 +148,6 @@ def add_tikzfigure( # Initialize the LinePlot for the given subplot position tikz_figure = TikzFigure( - col=col, - row=row, label=label, **kwargs, ) @@ -263,20 +264,33 @@ def savefig( if verbose: print(f"Saved {full_filepath}") - def plot(self, backend: Backends = "matplotlib", savefig=False, layers=None): + def plot( + self, + backend: Backends = "matplotlib", + savefig=False, + layers=None, + ): if backend == "matplotlib": return self.plot_matplotlib(savefig=savefig, layers=layers) elif backend == "plotly": return self.plot_plotly(savefig=savefig) + elif backend == "tikzpics": + return self.plot_tikzpics(savefig=savefig) else: raise ValueError(f"Invalid backend: {backend}") - def show(self, backend="matplotlib"): + def show( + self, + backend: Backends = "matplotlib", + ): if backend == "matplotlib": self.plot(backend="matplotlib", savefig=False, layers=None) self._matplotlib_fig.show() elif backend == "plotly": - plot = self.plot_plotly(savefig=False) + self.plot_plotly(savefig=False) + elif backend == "tikzpics": + fig = self.plot_tikzpics(savefig=False) + fig.show() else: raise ValueError("Invalid backend") @@ -307,19 +321,20 @@ def plot_matplotlib(self, savefig=False, layers=None, usetex=False): dpi=self.dpi, ) - # print(f"{(fig_width / self._dpi, fig_height / self._dpi) = }") - fig, axes = plt.subplots( self.nrows, self.ncols, figsize=(fig_width, fig_height), squeeze=False, - dpi=self._dpi, + dpi=self.dpi, ) for (row, col), subplot in self.subplots.items(): ax = axes[row][col] - subplot.plot_matplotlib(ax, layers=layers) + if isinstance(subplot, TikzFigure): + plot_matplotlib(subplot, ax, layers=layers) + else: + subplot.plot_matplotlib(ax, layers=layers) # ax.set_title(f"Subplot ({row}, {col})") ax.grid() @@ -329,6 +344,22 @@ def plot_matplotlib(self, savefig=False, layers=None, usetex=False): self._matplotlib_axes = axes return fig, axes + def plot_tikzpics( + self, + savefig=None, + verbose=False, + ) -> TikzFigure: + if len(self.subplots) > 1: + raise NotImplementedError( + "Only one subplot is supported for tikzpics backend." + ) + for (row, col), line_plot in self.subplots.items(): + if verbose: + print(f"Plotting subplot at row {row}, col {col}") + print(f"{line_plot = }") + tikz_subplot = line_plot.plot_tikzpics(verbose=verbose) + return tikz_subplot + def plot_plotly(self, show=True, savefig=None, usetex=False): """ Generate and optionally display the subplots using Plotly. @@ -338,7 +369,7 @@ def plot_plotly(self, show=True, savefig=None, usetex=False): savefig (str, optional): Filename to save the figure if provided. """ - tex_fonts = setup_tex_fonts( + setup_tex_fonts( fontsize=self.fontsize, usetex=usetex, ) # adjust or redefine for Plotly if needed @@ -423,9 +454,6 @@ def subplot_matrix(self): return self._subplot_matrix # Property setters - @nrows.setter - def dpi(self, value): - self._dpi = value @nrows.setter def nrows(self, value): @@ -472,75 +500,140 @@ def __setitem__(self, key, value): raise IndexError("Subplot index out of range") self._subplot_matrix[row][col] = value - # def generate_matplotlib_code(self): - # """Generate code for plotting the data using matplotlib.""" - # code = "import matplotlib.pyplot as plt\n\n" - # code += f"fig, axes = plt.subplots({self.nrows}, {self.ncols}, figsize={self.figsize})\n\n" - # if self.nrows == 1 and self.ncols == 1: - # code += "axes = [axes] # Single subplot\n\n" - # else: - # code += "axes = axes.flatten()\n\n" - # for idx, (subplot_idx, lines) in enumerate(self.subplots.items()): - # code += f"# Subplot {subplot_idx}\n" - # code += f"ax = axes[{idx}]\n" - # for line in lines: - # x_data = line['x'] - # y_data = line['y'] - # label = line['label'] - # kwargs = line.get('kwargs', {}) - # kwargs_str = ', '.join(f"{k}={repr(v)}" for k, v in kwargs.items()) - # code += f"ax.plot({x_data}, {y_data}, label={repr(label)}" - # if kwargs_str: - # code += f", {kwargs_str}" - # code += ")\n" - # code += "ax.set_xlabel('X-axis')\n" - # code += "ax.set_ylabel('Y-axis')\n" - # if self.nrows * self.ncols > 1: - # code += f"ax.set_title('Subplot {subplot_idx}')\n" - # code += "ax.legend()\n\n" - # code += "plt.tight_layout()\nplt.show()\n" - # return code - - # def generate_latex_plot(self): - # """Generate LaTeX code for plotting the data using pgfplots in subplots.""" - # latex_code = "\\begin{figure}[h!]\n\\centering\n" - # total_subplots = self.nrows * self.ncols - # for idx in range(total_subplots): - # subplot_idx = divmod(idx, self.ncols) - # lines = self.subplots.get(subplot_idx, []) - # if not lines: - # continue # Skip empty subplots - # latex_code += "\\begin{subfigure}[b]{0.45\\textwidth}\n" - # latex_code += " \\begin{tikzpicture}\n" - # latex_code += " \\begin{axis}[\n" - # latex_code += " xlabel={X-axis},\n" - # latex_code += " ylabel={Y-axis},\n" - # if self.nrows * self.ncols > 1: - # latex_code += f" title={{Subplot {subplot_idx}}},\n" - # latex_code += " legend style={at={(1.05,1)}, anchor=north west},\n" - # latex_code += " legend entries={" + ", ".join(f"{{{line['label']}}}" for line in lines) + "}\n" - # latex_code += " ]\n" - # for line in lines: - # options = [] - # kwargs = line.get('kwargs', {}) - # if 'color' in kwargs: - # options.append(f"color={kwargs['color']}") - # if 'linestyle' in kwargs: - # linestyle_map = {'-': 'solid', '--': 'dashed', '-.': 'dash dot', ':': 'dotted'} - # linestyle = linestyle_map.get(kwargs['linestyle'], kwargs['linestyle']) - # options.append(f"style={linestyle}") - # options_str = f"[{', '.join(options)}]" if options else "" - # latex_code += f" \\addplot {options_str} coordinates {{\n" - # for x, y in zip(line['x'], line['y']): - # latex_code += f" ({x}, {y})\n" - # latex_code += " };\n" - # latex_code += " \\end{axis}\n" - # latex_code += " \\end{tikzpicture}\n" - # latex_code += "\\end{subfigure}\n" - # latex_code += "\\hfill\n" if (idx + 1) % self.ncols != 0 else "\n" - # latex_code += "\\caption{Multiple Subplots}\n" - # latex_code += "\\end{figure}\n" - # return latex_code + +def plot_matplotlib(tikzfigure: TikzFigure, ax, layers=None): + """ + Plot all nodes and paths on the provided axis using Matplotlib. + + Parameters: + - ax (matplotlib.axes.Axes): Axis on which to plot the figure. + """ + + # TODO: Specify which layers to retreive nodes from with layers=layers + nodes = tikzfigure.layers.get_nodes() + paths = tikzfigure.layers.get_paths() + + for path in paths: + x_coords = [node.x for node in path.nodes] + y_coords = [node.y for node in path.nodes] + + # Parse path color + path_color_spec = path.kwargs.get("color", "black") + try: + color = Color(path_color_spec).to_rgb() + except ValueError as e: + print(e) + color = "black" + + # Parse line width + line_width_spec = path.kwargs.get("line_width", 1) + if isinstance(line_width_spec, str): + match = re.match(r"([\d.]+)(pt)?", line_width_spec) + if match: + line_width = float(match.group(1)) + else: + print( + f"Invalid line width specification: '{line_width_spec}', defaulting to 1", + ) + line_width = 1 + else: + line_width = float(line_width_spec) + + # Parse line style using Linestyle class + style_spec = path.kwargs.get("style", "solid") + linestyle = Linestyle(style_spec).to_matplotlib() + + ax.plot( + x_coords, + y_coords, + color=color, + linewidth=line_width, + linestyle=linestyle, + zorder=1, # Lower z-order to place behind nodes + ) + + # Plot nodes after paths so they appear on top + for node in nodes: + # Determine shape and size + shape = node.kwargs.get("shape", "circle") + fill_color_spec = node.kwargs.get("fill", "white") + edge_color_spec = node.kwargs.get("draw", "black") + linewidth = float(node.kwargs.get("line_width", 1)) + size = float(node.kwargs.get("size", 1)) + + # Parse colors using the Color class + try: + facecolor = Color(fill_color_spec).to_rgb() + except ValueError as e: + print(e) + facecolor = "white" + + try: + edgecolor = Color(edge_color_spec).to_rgb() + except ValueError as e: + print(e) + edgecolor = "black" + + # Plot shapes + if shape == "circle": + radius = size / 2 + circle = patches.Circle( + (node.x, node.y), + radius, + facecolor=facecolor, + edgecolor=edgecolor, + linewidth=linewidth, + zorder=2, # Higher z-order to place on top of paths + ) + ax.add_patch(circle) + elif shape == "rectangle": + width = height = size + rect = patches.Rectangle( + (node.x - width / 2, node.y - height / 2), + width, + height, + facecolor=facecolor, + edgecolor=edgecolor, + linewidth=linewidth, + zorder=2, # Higher z-order + ) + ax.add_patch(rect) + else: + # Default to circle if shape is unknown + radius = size / 2 + circle = patches.Circle( + (node.x, node.y), + radius, + facecolor=facecolor, + edgecolor=edgecolor, + linewidth=linewidth, + zorder=2, + ) + ax.add_patch(circle) + + # Add text inside the shape + if node.content: + ax.text( + node.x, + node.y, + node.content, + fontsize=10, + ha="center", + va="center", + wrap=True, + zorder=3, # Even higher z-order for text + ) + + # Remove axes, ticks, and legend + ax.axis("off") + + # Adjust plot limits + all_x = [node.x for node in nodes] + all_y = [node.y for node in nodes] + padding = 1 # Adjust padding as needed + ax.set_xlim(min(all_x) - padding, max(all_x) + padding) + ax.set_ylim(min(all_y) - padding, max(all_y) + padding) + ax.set_aspect("equal", adjustable="datalim") if __name__ == "__main__": diff --git a/src/maxplotlib/colors/colors.py b/src/maxplotlib/colors/colors.py index d1381ef..fdb117e 100644 --- a/src/maxplotlib/colors/colors.py +++ b/src/maxplotlib/colors/colors.py @@ -1,7 +1,6 @@ import re import matplotlib.colors as mcolors -import matplotlib.patches as patches import numpy as np diff --git a/src/maxplotlib/objects/layer.py b/src/maxplotlib/objects/layer.py deleted file mode 100644 index 4ebb71c..0000000 --- a/src/maxplotlib/objects/layer.py +++ /dev/null @@ -1,20 +0,0 @@ -from abc import ABCMeta, abstractmethod - - -class Layer(metaclass=ABCMeta): - def __init__(self, label): - self.label = label - self.items = [] - - -class Tikzlayer(Layer): - def __init__(self, label): - super().__init__(label) - - def generate_tikz(self): - tikz_script = f"\n% Layer {self.label}\n" - tikz_script += f"\\begin{{pgfonlayer}}{{{self.label}}}\n" - for item in self.items: - tikz_script += item.to_tikz() - tikz_script += f"\\end{{pgfonlayer}}{{{self.label}}}\n" - return tikz_script diff --git a/src/maxplotlib/subfigure/line_plot.py b/src/maxplotlib/subfigure/line_plot.py index 9435b37..8a50cfc 100644 --- a/src/maxplotlib/subfigure/line_plot.py +++ b/src/maxplotlib/subfigure/line_plot.py @@ -2,9 +2,7 @@ import numpy as np import plotly.graph_objects as go from mpl_toolkits.axes_grid1 import make_axes_locatable - -import maxplotlib.subfigure.tikz_figure as tf -from maxplotlib.objects.layer import Tikzlayer +from tikzpics import TikzFigure class Node: @@ -224,6 +222,21 @@ def plot_matplotlib( if self.ymax is not None: ax.axis(ymax=self.ymax) + def plot_tikzpics(self, layers=None, verbose: bool = False) -> TikzFigure: + + tikz_figure = TikzFigure() + for layer_name, layer_lines in self.layered_line_data.items(): + if layers and layer_name not in layers: + continue + for line in layer_lines: + if line["plot_type"] == "plot": + x = (line["x"] + self._xshift) * self._xscale + y = (line["y"] + self._yshift) * self._yscale + + nodes = [[xi, yi] for xi, yi in zip(x, y)] + tikz_figure.draw(nodes=nodes, **line["kwargs"]) + return tikz_figure + def plot_plotly(self): """ Plot all lines using Plotly and return a list of traces for each line. @@ -255,68 +268,6 @@ def plot_plotly(self): return traces - def add_node(self, x, y, label=None, content="", layer=0, **kwargs): - """ - Add a node to the TikZ figure. - - Parameters: - - x (float): X-coordinate of the node. - - y (float): Y-coordinate of the node. - - label (str, optional): Label of the node. If None, a default label will be assigned. - - **kwargs: Additional TikZ node options (e.g., shape, color). - - Returns: - - node (Node): The Node object that was added. - """ - if label is None: - label = f"node{self._node_counter}" - node = Node(x=x, y=y, label=label, layer=layer, content=content, **kwargs) - self.nodes.append(node) - if layer in self.layers: - self.layers[layer].add(node) - else: - # print(f"{self.layers = } {layer = }") - self.layers[layer] = Tikzlayer(layer) - self.layers[layer].add(node) - self._node_counter += 1 - return node - - def add_path(self, nodes, layer=0, **kwargs): - """ - Add a line or path connecting multiple nodes. - - Parameters: - - nodes (list of str): List of node names to connect. - - **kwargs: Additional TikZ path options (e.g., style, color). - - Examples: - - add_path(['A', 'B', 'C'], color='blue') - Connects nodes A -> B -> C with a blue line. - """ - if not isinstance(nodes, list): - raise ValueError("nodes parameter must be a list of node names.") - - nodes = [ - ( - node - if isinstance(node, Node) - else ( - self.get_node(node) - if isinstance(node, str) - else ValueError(f"Invalid node type: {type(node)}") - ) - ) - for node in nodes - ] - path = Path(nodes, **kwargs) - self.paths.append(path) - if layer in self.layers: - self.layers[layer].add(path) - else: - self.layers[layer] = Tikzlayer(layer) - self.layers[layer].add(path) - return path - @property def xmin(self): return self._xmin diff --git a/src/maxplotlib/subfigure/tikz_figure.py b/src/maxplotlib/subfigure/tikz_figure.py deleted file mode 100644 index 4751c51..0000000 --- a/src/maxplotlib/subfigure/tikz_figure.py +++ /dev/null @@ -1,497 +0,0 @@ -import os -import re -import subprocess -import tempfile - -import matplotlib.patches as patches -import numpy as np -from matplotlib.image import imread - -from maxplotlib.colors.colors import Color -from maxplotlib.linestyle.linestyle import Linestyle - - -class Tikzlayer: - def __init__(self, label): - self.label = label - self.items = [] - - def add(self, item): - self.items.append(item) - - def get_reqs(self): - reqs = set() - for item in self.items: - if isinstance(item, Path): - for node in item.nodes: - if not node.layer == self.label: - reqs.add(node.layer) - return reqs - - def generate_tikz(self): - tikz_script = f"\n% Layer {self.label}\n" - tikz_script += f"\\begin{{pgfonlayer}}{{{self.label}}}\n" - for item in self.items: - tikz_script += item.to_tikz() - tikz_script += f"\\end{{pgfonlayer}}{{{self.label}}}\n" - return tikz_script - - -class TikzWrapper: - def __init__(self, raw_tikz, label="", content="", layer=0, **kwargs): - self.raw_tikz = raw_tikz - self.label = label - self.content = content - self.layer = layer - self.options = kwargs - - def to_tikz(self): - return self.raw_tikz - - -class Node: - def __init__(self, x, y, label="", content="", layer=0, **kwargs): - """ - Represents a TikZ node. - - Parameters: - - x (float): X-coordinate of the node. - - y (float): Y-coordinate of the node. - - name (str, optional): Name of the node. If None, a default name will be assigned. - - **kwargs: Additional TikZ node options (e.g., shape, color). - """ - self.x = x - self.y = y - self.label = label - self.content = content - self.layer = layer - self.options = kwargs - - def to_tikz(self): - """ - Generate the TikZ code for this node. - - Returns: - - tikz_str (str): TikZ code string for the node. - """ - options = ", ".join( - f"{k.replace('_', ' ')}={v}" for k, v in self.options.items() - ) - if options: - options = f"[{options}]" - return f"\\node{options} ({self.label}) at ({self.x}, {self.y}) {{{self.content}}};\n" - - -class Path: - def __init__( - self, - nodes, - path_actions=[], - cycle=False, - label="", - layer=0, - **kwargs, - ): - """ - Represents a path (line) connecting multiple nodes. - - Parameters: - - nodes (list of str): List of node names to connect. - - **kwargs: Additional TikZ path options (e.g., style, color). - """ - self.nodes = nodes - self.path_actions = path_actions - self.cycle = cycle - self.layer = layer - self.label = label - self.options = kwargs - - def to_tikz(self): - """ - Generate the TikZ code for this path. - - Returns: - - tikz_str (str): TikZ code string for the path. - """ - options = ", ".join( - f"{k.replace('_', ' ')}={v}" for k, v in self.options.items() - ) - if len(self.path_actions) > 0: - options = ", ".join(self.path_actions) + ", " + options - if options: - options = f"[{options}]" - path_str = " to ".join(f"({node.label}.center)" for node in self.nodes) - if self.cycle: - path_str += " -- cycle" - return f"\\draw{options} {path_str};\n" - - -class TikzFigure: - def __init__(self, **kwargs): - """ - Initialize the TikzFigure class for creating TikZ figures. - - Parameters: - **kwargs: Arbitrary keyword arguments. - - figsize (tuple): Figure size (default is (10, 6)). - - caption (str): Caption for the figure. - - description (str): Description of the figure. - - label (str): Label for the figure. - - grid (bool): Whether to display grid lines (default is False). - TODO: Add all options - """ - # Set default values - self._figsize = kwargs.get("figsize", (10, 6)) - self._caption = kwargs.get("caption", None) - self._description = kwargs.get("description", None) - self._label = kwargs.get("label", None) - self._grid = kwargs.get("grid", False) - - # Initialize lists to hold Node and Path objects - self.nodes = [] - self.paths = [] - self.layers = {} - - # Counter for unnamed nodes - self._node_counter = 0 - - def add_node(self, x, y, label=None, content="", layer=0, **kwargs): - """ - Add a node to the TikZ figure. - - Parameters: - - x (float): X-coordinate of the node. - - y (float): Y-coordinate of the node. - - label (str, optional): Label of the node. If None, a default label will be assigned. - - **kwargs: Additional TikZ node options (e.g., shape, color). - - Returns: - - node (Node): The Node object that was added. - """ - if label is None: - label = f"node{self._node_counter}" - node = Node(x=x, y=y, label=label, layer=layer, content=content, **kwargs) - self.nodes.append(node) - if layer in self.layers: - self.layers[layer].add(node) - else: - self.layers[layer] = Tikzlayer(layer) - self.layers[layer].add(node) - self._node_counter += 1 - return node - - def add_path(self, nodes, layer=0, **kwargs): - """ - Add a line or path connecting multiple nodes. - - Parameters: - - nodes (list of str): List of node names to connect. - - **kwargs: Additional TikZ path options (e.g., style, color). - - Examples: - - add_path(['A', 'B', 'C'], color='blue') - Connects nodes A -> B -> C with a blue line. - """ - if not isinstance(nodes, list): - raise ValueError("nodes parameter must be a list of node names.") - - nodes = [ - ( - node - if isinstance(node, Node) - else ( - self.get_node(node) - if isinstance(node, str) - else ValueError(f"Invalid node type: {type(node)}") - ) - ) - for node in nodes - ] - path = Path(nodes, **kwargs) - self.paths.append(path) - if layer in self.layers: - self.layers[layer].add(path) - else: - self.layers[layer] = Tikzlayer(layer) - self.layers[layer].add(path) - return path - - def add_raw(self, raw_tikz, layer=0, **kwargs): - tikz = TikzWrapper(raw_tikz) - if layer in self.layers: - self.layers[layer].add(tikz) - else: - self.layers[layer] = Tikzlayer(layer) - self.layers[layer].add(tikz) - return tikz - - def get_node(self, node_label): - for node in self.nodes: - if node.label == node_label: - return node - - def get_layer(self, item): - for layer, layer_items in self.layers.items(): - if item in [layer_item.label for layer_item in layer_items]: - return layer - print(f"Item {item} not found in any layer!") - - def add_tabs(self, tikz_script): - tikz_script_new = "" - tab_str = " " - num_tabs = 0 - for line in tikz_script.split("\n"): - if "\\end" in line: - num_tabs = max(num_tabs - 1, 0) - tikz_script_new += f"{tab_str*num_tabs}{line}\n" - if "\\begin" in line: - num_tabs += 1 - return tikz_script_new - - def generate_tikz(self): - """ - Generate the TikZ script for the figure. - - Returns: - - tikz_script (str): The TikZ script as a string. - """ - tikz_script = "\\begin{tikzpicture}\n" - tikz_script += "% Define the layers library\n" - layers = sorted([str(layer) for layer in self.layers.keys()]) - for layer in layers: - tikz_script += f"\\pgfdeclarelayer{{{layer}}}\n" - tikz_script += f"\\pgfsetlayers{{{','.join(layers)}}}\n" - - # Add grid if enabled - # TODO: Create a Grid class - if self._grid: - tikz_script += ( - " \\draw[step=1cm, gray, very thin] (-10,-10) grid (10,10);\n" - ) - ordered_layers = [] - buffered_layers = set() - - for key, layer in self.layers.items(): - # layer_order, buffered_layers = update_layer_order(layer, layer_order, buffered_layers) - reqs = layer.get_reqs() - if all([r == layer.label for r in reqs]): - ordered_layers.append(layer) - elif all([r in [l.label for l in ordered_layers] for r in reqs]): - ordered_layers.append(layer) - else: - buffered_layers.add(layer) - - for buffered_layer in buffered_layers: - buff_reqs = buffered_layer.get_reqs() - if all([r in [l.label for l in ordered_layers] for r in buff_reqs]): - print("Move layer from buffer") - ordered_layers.append(key) - buffered_layers.remove(key) - assert ( - len(buffered_layers) == 0 - ), f"Layer order is impossible for layer {[layer.label for layer in buffered_layers]}" - for layer in ordered_layers: - tikz_script += layer.generate_tikz() - - tikz_script += "\\end{tikzpicture}" - - # Wrap in figure environment if necessary - if self._caption or self._description or self._label: - figure_env = "\\begin{figure}\n" + tikz_script + "\n" - if self._caption: - figure_env += f" \\caption{{{self._caption}}}\n" - if self._label: - figure_env += f" \\label{{{self._label}}}\n" - figure_env += "\\end{figure}" - tikz_script = figure_env - tikz_script = self.add_tabs(tikz_script) - return tikz_script - - def savefig(self, filepath): - tikz_code = self.generate_tikz() - with open(filepath, "w") as f: - f.write(tikz_code) - - def generate_standalone(self): - tikz_code = self.generate_tikz() - - # Create a minimal LaTeX document - latex_document = ( - "\\documentclass[border=10pt]{standalone}\n" - "\\usepackage{tikz}\n" - "\\begin{document}\n" - f"{tikz_code}\n" - "\\end{document}" - ) - return latex_document - - def compile_pdf(self, filename="output.pdf"): - """ - Compile the TikZ script into a PDF using pdflatex. - - Parameters: - - filename (str): The name of the output PDF file (default is 'output.pdf'). - - Notes: - - Requires 'pdflatex' to be installed and accessible from the command line. - """ - latex_document = self.generate_standalone() - - # Use a temporary directory to store the LaTeX files - with tempfile.TemporaryDirectory() as tempdir: - tex_file = os.path.join(tempdir, "figure.tex") - with open(tex_file, "w") as f: - f.write(latex_document) - - # Run pdflatex - try: - subprocess.run( - ["pdflatex", "-interaction=nonstopmode", tex_file], - cwd=tempdir, - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - except subprocess.CalledProcessError as e: - print("An error occurred while compiling the LaTeX document:") - print(e.stderr.decode()) - return - - # Move the output PDF to the desired location - pdf_output = os.path.join(tempdir, "figure.pdf") - if os.path.exists(pdf_output): - os.rename(pdf_output, filename) - print(f"PDF successfully compiled and saved as '{filename}'.") - else: - print("PDF compilation failed. Please check the LaTeX log for details.") - - def plot_matplotlib(self, ax, layers=None): - """ - Plot all nodes and paths on the provided axis using Matplotlib. - - Parameters: - - ax (matplotlib.axes.Axes): Axis on which to plot the figure. - """ - - # Plot paths first so they appear behind nodes - for path in self.paths: - x_coords = [node.x for node in path.nodes] - y_coords = [node.y for node in path.nodes] - - # Parse path color - path_color_spec = path.options.get("color", "black") - try: - color = Color(path_color_spec).to_rgb() - except ValueError as e: - print(e) - color = "black" - - # Parse line width - line_width_spec = path.options.get("line_width", 1) - if isinstance(line_width_spec, str): - match = re.match(r"([\d.]+)(pt)?", line_width_spec) - if match: - line_width = float(match.group(1)) - else: - print( - f"Invalid line width specification: '{line_width_spec}', defaulting to 1", - ) - line_width = 1 - else: - line_width = float(line_width_spec) - - # Parse line style using Linestyle class - style_spec = path.options.get("style", "solid") - linestyle = Linestyle(style_spec).to_matplotlib() - - ax.plot( - x_coords, - y_coords, - color=color, - linewidth=line_width, - linestyle=linestyle, - zorder=1, # Lower z-order to place behind nodes - ) - - # Plot nodes after paths so they appear on top - for node in self.nodes: - # Determine shape and size - shape = node.options.get("shape", "circle") - fill_color_spec = node.options.get("fill", "white") - edge_color_spec = node.options.get("draw", "black") - linewidth = float(node.options.get("line_width", 1)) - size = float(node.options.get("size", 1)) - - # Parse colors using the Color class - try: - facecolor = Color(fill_color_spec).to_rgb() - except ValueError as e: - print(e) - facecolor = "white" - - try: - edgecolor = Color(edge_color_spec).to_rgb() - except ValueError as e: - print(e) - edgecolor = "black" - - # Plot shapes - if shape == "circle": - radius = size / 2 - circle = patches.Circle( - (node.x, node.y), - radius, - facecolor=facecolor, - edgecolor=edgecolor, - linewidth=linewidth, - zorder=2, # Higher z-order to place on top of paths - ) - ax.add_patch(circle) - elif shape == "rectangle": - width = height = size - rect = patches.Rectangle( - (node.x - width / 2, node.y - height / 2), - width, - height, - facecolor=facecolor, - edgecolor=edgecolor, - linewidth=linewidth, - zorder=2, # Higher z-order - ) - ax.add_patch(rect) - else: - # Default to circle if shape is unknown - radius = size / 2 - circle = patches.Circle( - (node.x, node.y), - radius, - facecolor=facecolor, - edgecolor=edgecolor, - linewidth=linewidth, - zorder=2, - ) - ax.add_patch(circle) - - # Add text inside the shape - if node.content: - ax.text( - node.x, - node.y, - node.content, - fontsize=10, - ha="center", - va="center", - wrap=True, - zorder=3, # Even higher z-order for text - ) - - # Remove axes, ticks, and legend - ax.axis("off") - - # Adjust plot limits - all_x = [node.x for node in self.nodes] - all_y = [node.y for node in self.nodes] - padding = 1 # Adjust padding as needed - ax.set_xlim(min(all_x) - padding, max(all_x) + padding) - ax.set_ylim(min(all_y) - padding, max(all_y) + padding) - ax.set_aspect("equal", adjustable="datalim") diff --git a/src/maxplotlib/tests/test_canvas.py b/src/maxplotlib/tests/test_canvas.py index 5b7592d..847a2d5 100644 --- a/src/maxplotlib/tests/test_canvas.py +++ b/src/maxplotlib/tests/test_canvas.py @@ -1,6 +1,5 @@ def test(): - import maxplotlib.canvas.canvas - import maxplotlib.subfigure.line_plot + pass if __name__ == "__main__": diff --git a/src/maxplotlib/tests/test_imports.py b/src/maxplotlib/tests/test_imports.py index 302d52d..4eb1309 100644 --- a/src/maxplotlib/tests/test_imports.py +++ b/src/maxplotlib/tests/test_imports.py @@ -3,9 +3,8 @@ @pytest.mark.parametrize("x", [0]) def import_modules(x): - import matplotlib - import maxplotlib + pass if __name__ == "__main__": diff --git a/src/maxplotlib/utils/options.py b/src/maxplotlib/utils/options.py index 78d5482..6666e4d 100644 --- a/src/maxplotlib/utils/options.py +++ b/src/maxplotlib/utils/options.py @@ -1,3 +1,3 @@ from typing import Literal -Backends = Literal["matplotlib", "plotly"] +Backends = Literal["matplotlib", "plotly", "tikzpics"] diff --git a/tutorials/tutorial_02.ipynb b/tutorials/tutorial_02.ipynb index 62110a4..f2e3528 100644 --- a/tutorials/tutorial_02.ipynb +++ b/tutorials/tutorial_02.ipynb @@ -32,14 +32,14 @@ "tikz = c.add_tikzfigure(grid=False)\n", "\n", "# Add nodes\n", - "tikz.add_node(0, 0, \"A\", shape=\"circle\", draw=\"black\", fill=\"blue\", layer=0)\n", - "tikz.add_node(1, 0, \"B\", shape=\"circle\", draw=\"black\", fill=\"blue\", layer=0)\n", - "tikz.add_node(1, 1, \"C\", shape=\"circle\", draw=\"black\", fill=\"blue\", layer=0)\n", - "tikz.add_node(0, 1, \"D\", shape=\"circle\", draw=\"black\", fill=\"blue\", layer=2)\n", + "tikz.add_node(0, 0, label=\"A\", shape=\"circle\", draw=\"black\", fill=\"blue\", layer=0)\n", + "tikz.add_node(1, 0, label=\"B\", shape=\"circle\", draw=\"black\", fill=\"blue\", layer=0)\n", + "tikz.add_node(1, 1, label=\"C\", shape=\"circle\", draw=\"black\", fill=\"blue\", layer=0)\n", + "tikz.add_node(0, 1, label=\"D\", shape=\"circle\", draw=\"black\", fill=\"blue\", layer=2)\n", "\n", "\n", "# Add a line between nodes\n", - "tikz.add_path(\n", + "tikz.draw(\n", " [\"A\", \"B\", \"C\", \"D\"],\n", " path_actions=[\"draw\", \"rounded corners\"],\n", " fill=\"red\",\n", @@ -51,7 +51,7 @@ "tikz.add_node(0.5, 0.5, content=\"Cube\", layer=10)\n", "\n", "# tikz.compile_pdf(\"tutorial_02_01.pdf\")\n", - "#" + "c.plot(backend=\"matplotlib\")" ] }, { @@ -66,23 +66,29 @@ "\n", "# Add nodes\n", "node_a = tikz.add_node(\n", - " -5, 0, \"A\", content=\"Origin node\", shape=\"circle\", draw=\"black\", fill=\"blue!20\"\n", + " -5,\n", + " 0,\n", + " label=\"A\",\n", + " content=\"Origin node\",\n", + " shape=\"circle\",\n", + " draw=\"black\",\n", + " fill=\"blue!20\",\n", ")\n", "tikz.add_node(\n", " 2,\n", " 2,\n", - " \"B\",\n", + " label=\"B\",\n", " content=\"$a^2 + b^2 = c^2$\",\n", " shape=\"rectangle\",\n", " draw=\"red\",\n", " fill=\"white\",\n", " layer=1,\n", ")\n", - "tikz.add_node(2, 5, \"C\", shape=\"rectangle\", draw=\"red\", fill=\"red\")\n", + "tikz.add_node(2, 5, label=\"C\", shape=\"rectangle\", draw=\"red\", fill=\"red\")\n", "last_node = tikz.add_node(-1, 5, shape=\"rectangle\", draw=\"red\", fill=\"red\", layer=-10)\n", "\n", - "# Add a line between nodes\n", - "tikz.add_path(\n", + "# # Add a line between nodes\n", + "tikz.draw(\n", " [node_a.label, \"B\", \"C\", \"A\", last_node],\n", " color=\"green\",\n", " style=\"solid\",\n", @@ -115,12 +121,12 @@ "tikz = c.add_tikzfigure(grid=False)\n", "\n", "# Add nodes\n", - "tikz.add_node(0, 0, \"A\")\n", - "tikz.add_node(10, 0, \"B\")\n", + "tikz.add_node(0, 0, label=\"A\")\n", + "tikz.add_node(10, 0, label=\"B\")\n", "\n", "\n", "# Add a line between nodes\n", - "tikz.add_path([\"A\", \"B\"], path_actions=[\"->\"], out=30)\n", + "tikz.draw([\"A\", \"B\"], path_actions=[\"->\"], out=30)\n", "\n", "# Generate the TikZ script\n", "# script = tikz.generate_tikz()\n", @@ -132,7 +138,7 @@ ], "metadata": { "kernelspec": { - "display_name": "env_maxplotlib", + "display_name": "env_maxpic", "language": "python", "name": "python3" }, @@ -146,7 +152,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.13.3" } }, "nbformat": 4, diff --git a/tutorials/tutorial_04.ipynb b/tutorials/tutorial_04.ipynb index bf46f10..ef2dc8b 100644 --- a/tutorials/tutorial_04.ipynb +++ b/tutorials/tutorial_04.ipynb @@ -10,10 +10,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "1", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'\\nTutorial 4.\\n\\nAdd raw tikz code to the tikz subplot.\\n'" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "\"\"\"\n", "Tutorial 4.\n", @@ -24,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "2", "metadata": {}, "outputs": [], @@ -34,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "3", "metadata": {}, "outputs": [], @@ -45,29 +56,51 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "4", "metadata": { "lines_to_next_cell": 2 }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Add nodes\n", - "tikz.add_node(0, 0, \"A\", shape=\"circle\", draw=\"black\", fill=\"blue\", layer=0)\n", - "tikz.add_node(10, 0, \"B\", shape=\"circle\", draw=\"black\", fill=\"blue\", layer=0)\n", - "tikz.add_node(10, 10, \"C\", shape=\"circle\", draw=\"black\", fill=\"blue\", layer=0)\n", - "tikz.add_node(0, 10, \"D\", shape=\"circle\", draw=\"black\", fill=\"blue\", layer=2)" + "tikz.add_node(0, 0, label=\"A\", shape=\"circle\", draw=\"black\", fill=\"blue\", layer=0)\n", + "tikz.add_node(10, 0, label=\"B\", shape=\"circle\", draw=\"black\", fill=\"blue\", layer=0)\n", + "tikz.add_node(10, 10, label=\"C\", shape=\"circle\", draw=\"black\", fill=\"blue\", layer=0)\n", + "tikz.add_node(0, 10, label=\"D\", shape=\"circle\", draw=\"black\", fill=\"blue\", layer=2)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "5", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Add a line between nodes\n", - "tikz.add_path(\n", + "tikz.draw(\n", " [\"A\", \"B\", \"C\", \"D\"],\n", " path_actions=[\"draw\", \"rounded corners\"],\n", " fill=\"red\",\n", @@ -79,7 +112,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "6", "metadata": {}, "outputs": [], @@ -99,20 +132,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "7", "metadata": {}, "outputs": [], "source": [ - "tikz.add_raw(raw_tikz)" + "# TODO: Not implemented in tikzpics yet\n", + "# tikz.add_raw(raw_tikz)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "id": "8", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "tikz.add_node(0.5, 0.5, content=\"Cube\", layer=10)" ] @@ -122,14 +167,58 @@ "execution_count": null, "id": "9", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\n", + "\n", + "% --------------------------------------------- %\n", + "% Tikzfigure generated by tikzpics v0.1.1 %\n", + "% https://github.com/max-models/tikzpics %\n", + "% --------------------------------------------- %\n", + "\\begin{tikzpicture}\n", + " \n", + " % Define the layers library\n", + " \\pgfdeclarelayer{0}\n", + " \\pgfdeclarelayer{1}\n", + " \\pgfdeclarelayer{10}\n", + " \\pgfdeclarelayer{2}\n", + " \\pgfsetlayers{0,1,10,2}\n", + " \n", + " % Layer 0\n", + " \\begin{pgfonlayer}{0}\n", + " \\node[shape=circle, draw=black, fill=blue] (A) at (0, 0) {};\n", + " \\node[shape=circle, draw=black, fill=blue] (B) at (10, 0) {};\n", + " \\node[shape=circle, draw=black, fill=blue] (C) at (10, 10) {};\n", + " \\end{pgfonlayer}{0}\n", + " \n", + " % Layer 2\n", + " \\begin{pgfonlayer}{2}\n", + " \\node[shape=circle, draw=black, fill=blue] (D) at (0, 10) {};\n", + " \\end{pgfonlayer}{2}\n", + " \n", + " % Layer 1\n", + " \\begin{pgfonlayer}{1}\n", + " \\draw[path actions=['draw', 'rounded corners'], fill=red, opacity=0.5] (A) to (B) to (C) to (D) -- cycle;\n", + " \\end{pgfonlayer}{1}\n", + " \n", + " % Layer 10\n", + " \\begin{pgfonlayer}{10}\n", + " \\node (node4) at (0.5, 0.5) {Cube};\n", + " \\end{pgfonlayer}{10}\n", + "\\end{tikzpicture}\n", + "\n" + ] + } + ], "source": [ "# Generate the TikZ script\n", "script = tikz.generate_tikz()\n", - "print(script)\n", - "# print(tikz.generate_standalone())\n", - "# tikz.compile_pdf(\"tutorial_04_01.pdf\")\n", - "#" + "print(script)" ] } ], @@ -140,7 +229,7 @@ "notebook_metadata_filter": "-all" }, "kernelspec": { - "display_name": "env_maxplotlib", + "display_name": "env_maxpic", "language": "python", "name": "python3" }, @@ -154,7 +243,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.13.3" } }, "nbformat": 4, diff --git a/tutorials/tutorial_05.ipynb b/tutorials/tutorial_05.ipynb index 55334f8..2f70552 100644 --- a/tutorials/tutorial_05.ipynb +++ b/tutorials/tutorial_05.ipynb @@ -44,7 +44,7 @@ "# last_node = sp.add_node(-1, 5, shape='rectangle', draw='red', fill='red', layer=-10)\n", "\n", "# Add a line between nodes\n", - "# sp.add_path([\"A\", \"B\"], color=\"green\", style=\"solid\", line_width=\"2\", layer=-5)\n", + "# sp.draw([\"A\", \"B\"], color=\"green\", style=\"solid\", line_width=\"2\", layer=-5)\n", "\n", "x = np.arange(0, 2 * np.pi, 0.01)\n", "y = np.sin(x)\n", diff --git a/tutorials/tutorial_06.ipynb b/tutorials/tutorial_06.ipynb index d783520..514274a 100644 --- a/tutorials/tutorial_06.ipynb +++ b/tutorials/tutorial_06.ipynb @@ -16,7 +16,6 @@ "outputs": [], "source": [ "from maxplotlib import Canvas\n", - "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "%load_ext autoreload\n", @@ -38,19 +37,11 @@ "\n", "c.show()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { "kernelspec": { - "display_name": "env_maxplotlib", + "display_name": "env_maxpic", "language": "python", "name": "python3" }, @@ -64,7 +55,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.13.3" } }, "nbformat": 4, diff --git a/tutorials/tutorial_07_tikzpics.ipynb b/tutorials/tutorial_07_tikzpics.ipynb new file mode 100644 index 0000000..d3bc587 --- /dev/null +++ b/tutorials/tutorial_07_tikzpics.ipynb @@ -0,0 +1,62 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Tutorial 6" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "from maxplotlib import Canvas\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "c = Canvas(width=\"17cm\", ratio=0.5)\n", + "sp = c.add_subplot(grid=False, xlabel=\"x\", ylabel=\"y\")\n", + "sp.add_line([0, 1, 2, 3], [0, 1, 0, 2], label=\"Line 1\", layer=1, line_width=2.0)\n", + "\n", + "\n", + "# TODO: Uncomment if pdflatex is installed\n", + "# c.show(backend=\"tikzpics\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "env_maxpic", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}