import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

# Made by ChatGPT
def draw_cone(ax3d, height=2.0, base_radius=1.0, offset_x=0, offset_y=0, offset_z=0, color='#FFFF00', alpha=0.7, angle_x=0, angle_y=0, angle_z=0):
    """Plot the lateral surface of a cone on a 3D axes.

    In local coordinates the base circle (radius ``base_radius``) lies in the
    plane z = 0 and the apex is at z = ``height``. The base is not capped.
    The surface is sampled on a 50 x 50 (angle, height) grid.

    The offsets are added first. The rotation is applied afterwards, about the
    world origin, with R = Rz(angle_z) @ Ry(angle_y) @ Rx(angle_x), so a
    non-zero offset is rotated together with the shape.

    Args:
        ax3d: Axes providing ``plot_surface`` (e.g. a matplotlib 3D axes).
        height: Base-to-apex distance along the local z axis. It is used as a
            divisor, so it must be non-zero.
        base_radius: Radius of the base circle.
        offset_x, offset_y, offset_z: Translation added before rotation.
        color: Surface color forwarded to ``plot_surface``.
        alpha: Surface opacity forwarded to ``plot_surface``.
        angle_x, angle_y, angle_z: Rotation angles in degrees.
    """
    n_samples = 50
    grid_theta, grid_h = np.meshgrid(
        np.linspace(0, 2 * np.pi, n_samples),
        np.linspace(0, height, n_samples),
    )
    # The ring radius shrinks linearly from base_radius at h = 0 to 0 at the apex.
    ring_radius = base_radius * (1 - grid_h / height)
    px = ring_radius * np.cos(grid_theta) + offset_x
    py = ring_radius * np.sin(grid_theta) + offset_y
    pz = grid_h + offset_z

    rad_x, rad_y, rad_z = np.deg2rad([angle_x, angle_y, angle_z])
    cx, sx = np.cos(rad_x), np.sin(rad_x)
    cy, sy = np.cos(rad_y), np.sin(rad_y)
    cz, sz = np.cos(rad_z), np.sin(rad_z)
    rot_x = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]])
    rot_y = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]])
    rot_z = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]])
    rotation = rot_z @ rot_y @ rot_x

    # Rotate every sample as a column vector, then restore the grid shape.
    coords = np.vstack((px.ravel(), py.ravel(), pz.ravel()))  # shape (3, N)
    out_x, out_y, out_z = (rotation @ coords).reshape(3, *px.shape)

    ax3d.plot_surface(out_x, out_y, out_z, color=color, alpha=alpha)

def draw_sphere(ax3d, radius=1.0, offset_x=0, offset_y=0, offset_z=0, color='red', alpha=0.7, angle_x=0, angle_y=0, angle_z=0):
    """Plot a sphere surface on a 3D axes.

    The sphere is parameterised by azimuth in [0, 2*pi] and polar angle in
    [0, pi], each sampled at 50 points.

    The offsets are added first. The rotation is applied afterwards, about the
    world origin, with R = Rz(angle_z) @ Ry(angle_y) @ Rx(angle_x), so an
    offset sphere orbits the origin rather than spinning in place.

    Args:
        ax3d: Axes providing ``plot_surface`` (e.g. a matplotlib 3D axes).
        radius: Sphere radius.
        offset_x, offset_y, offset_z: Centre position before rotation.
        color: Surface color forwarded to ``plot_surface``.
        alpha: Surface opacity forwarded to ``plot_surface``.
        angle_x, angle_y, angle_z: Rotation angles in degrees.
    """
    n_samples = 50
    azimuth, polar = np.meshgrid(
        np.linspace(0, 2 * np.pi, n_samples),
        np.linspace(0, np.pi, n_samples),
    )
    # Radius of the horizontal ring at each polar angle.
    ring = radius * np.sin(polar)
    px = ring * np.cos(azimuth) + offset_x
    py = ring * np.sin(azimuth) + offset_y
    pz = radius * np.cos(polar) + offset_z

    rad_x, rad_y, rad_z = np.deg2rad([angle_x, angle_y, angle_z])
    cx, sx = np.cos(rad_x), np.sin(rad_x)
    cy, sy = np.cos(rad_y), np.sin(rad_y)
    cz, sz = np.cos(rad_z), np.sin(rad_z)
    rot_x = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]])
    rot_y = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]])
    rot_z = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]])
    rotation = rot_z @ rot_y @ rot_x

    # Rotate every sample as a column vector, then restore the grid shape.
    coords = np.vstack((px.ravel(), py.ravel(), pz.ravel()))  # shape (3, N)
    out_x, out_y, out_z = (rotation @ coords).reshape(3, *px.shape)

    ax3d.plot_surface(out_x, out_y, out_z, color=color, alpha=alpha)

def make_cone_sphere(angle_x=0, angle_y=0, angle_z=0, img_size=400):
    """Render two spheres and a cone to an off-screen BGRA image.

    All three shapes receive the same rotation, so the angles rotate the whole
    scene about the origin. The camera is fixed at elev=30, azim=45. The axes
    are hidden and limited to [-1, 1] on every axis with an equal box aspect.

    Args:
        angle_x, angle_y, angle_z: Scene rotation in degrees. Any value
            accepted by ``float()`` works.
        img_size: Requested edge length of the square image in pixels. The
            scene is rendered at 100 dpi.

    Returns:
        A new ``uint8`` array of shape (height, width, 4) in BGRA channel
        order. The figure and axes backgrounds are drawn with alpha 0.
    """
    # Build the Figure directly rather than via pyplot. It is never registered
    # in pyplot's global figure list, so nothing leaks if rendering raises and
    # no interactive/GUI backend is involved.
    from matplotlib.figure import Figure

    angle_x = float(angle_x)
    angle_y = float(angle_y)
    angle_z = float(angle_z)
    rotation = dict(angle_x=angle_x, angle_y=angle_y, angle_z=angle_z)

    dpi = 100
    # NOTE(review): the pixel size works out to int(img_size / dpi * dpi).
    # Float rounding may make that one pixel smaller than img_size for some
    # values, so confirm whether callers require an exact size.
    fig = Figure(figsize=(img_size / dpi, img_size / dpi), dpi=dpi)
    fig.patch.set_alpha(0.0)  # transparent figure background
    canvas = FigureCanvas(fig)
    ax3d = fig.add_subplot(111, projection='3d')
    ax3d.set_facecolor((0, 0, 0, 0))  # transparent axes background
    ax3d.axis('off')
    ax3d.set_box_aspect([1, 1, 1])  # equal aspect ratio for all axes
    ax3d.view_init(elev=30, azim=45)  # fixed camera

    # Draw the spheres and cone at fixed positions; all share one rotation.
    draw_sphere(ax3d, radius=0.5, offset_x=1.0, offset_y=1, offset_z=0, **rotation)
    draw_sphere(ax3d, radius=1.2, offset_x=0.0, offset_y=0, offset_z=0, **rotation)
    draw_cone(ax3d, height=2.0, base_radius=1.0, offset_x=0, offset_y=0, offset_z=0, **rotation)

    ax3d.set_xlim(-1, 1)
    ax3d.set_ylim(-1, 1)
    ax3d.set_zlim(-1, 1)

    canvas.draw()
    # The Agg buffer already exposes its (height, width, 4) RGBA shape, so the
    # pixel size need not be recomputed from inches * dpi.
    rgba = np.asarray(canvas.buffer_rgba())
    # Reorder RGBA -> BGRA. The .copy() gives a C-contiguous array that owns its
    # memory (advanced indexing does not guarantee a memory layout).
    return rgba[:, :, [2, 1, 0, 3]].copy()

def execute(params, inputs, outputs):
    """Render the cone/sphere scene and store the image on ``outputs.m1``.

    Args:
        params: Object whose ``a``, ``b`` and ``c`` attributes are passed to
            ``make_cone_sphere`` as the x, y and z rotation angles (degrees).
        inputs: Not used by this function.
        outputs: Object that receives the BGRA image array as attribute ``m1``.

    Returns:
        A fixed status message.
    """
    rot_x, rot_y, rot_z = params.a, params.b, params.c
    outputs.m1 = make_cone_sphere(rot_x, rot_y, rot_z)
    return "Hello! figure is in m1"

if __name__ == "__main__":
    # Quick visual check: render one pose and open it in the default viewer.
    from PIL import Image

    bgra = make_cone_sphere(30, 130, 30)
    # make_cone_sphere returns BGRA, but PIL reads a 4-channel uint8 array as
    # RGBA. Swap blue and red back, otherwise the preview shows the wrong colors.
    rgba = bgra[:, :, [2, 1, 0, 3]]
    Image.fromarray(rgba).show()