import bpy, mathutils, bmesh
import random, math, re
from . import Node, DATA_GETTER_NODE_Material, DATA_Material, DATA_Image, SOCKET_Image, SOCKET_Material
from .. import my_globals, utils
from typing import List, Tuple


class Get_Material(bpy.types.Node, Node, DATA_GETTER_NODE_Material):
	"""Node exposing the material chosen in its UI as its single output.

	It has no input sockets; the material comes from the ``material`` property.
	"""
	bl_idname = 'sc_node_49dd4wznqwcew6w0z98e'
	bl_label = 'Get material'
	bl_icon = 'SHADING_RENDERED'
	material: bpy.props.PointerProperty(name='Material', type=bpy.types.Material, update=Node.prop_updated)
	at_least_one_input_socket_required = False

	def sc_init(self, context):
		# An existing material is passed through, so the output is not flagged as new data.
		self.create_output(SOCKET_Material, is_new_data_output=False)

	def are_all_inputs_correct(self):
		"""Return True when a material has been picked."""
		return self.material is not None

	def sc_draw_buttons(self, context, layout):
		layout.prop(self, 'material', text='')

	@Node.get_data_first
	def get_materials(self, *args, **kwargs):
		"""Return a one-element list wrapping the selected material.

		:raises ValueError: if no material is selected.
		"""
		if not self.material:
			self.print('No material specified')
			raise ValueError('No material specified')
		material = DATA_Material(self.material)
		self.print(f'Material "{self.material.name}" found')
		return [material]


def _socket_index(socket):
	"""Return the position of *socket* within its node's ``inputs``/``outputs``.

	Sockets are addressed by position rather than by name because one node may have
	several sockets with the same name (e.g. the two "Shader" inputs of a Mix Shader).
	The position is read from the socket's RNA path, which ends in ``[N]`` (for
	example ``nodes["Mix Shader"].inputs[2]``); the whole trailing number is parsed,
	so indices >= 10 work too.
	"""
	path = socket.path_from_id()
	match = re.search(r"\[(\d+)\]$", path)
	if match is None:
		raise ValueError(f"Cannot determine the index of node socket {path!r}")
	return int(match.group(1))


def _copy_node_settings(source_node, new_node):
	"""Best-effort copy of *source_node*'s attributes and input default values onto *new_node*.

	*new_node* must have been created with the same ``bl_idname`` as *source_node*.
	Every attribute listed by ``dir()`` is re-assigned; attributes that cannot be
	assigned (read-only properties, methods, values Blender rejects) are skipped.
	"""
	for attr_name in dir(source_node):
		if attr_name.startswith("__"):
			continue  # Python internals, never node settings
		try:
			setattr(new_node, attr_name, getattr(source_node, attr_name))
		except Exception:
			pass  # read-only / not assignable: copying is deliberately best-effort
	# Pair input sockets by position: several inputs may share a name (e.g. the
	# "Value" inputs of a Math node), and a lookup by name would write every value
	# into the first socket of that name. This runs after the attribute copy so that
	# settings which define the sockets (presumably e.g. a group node's node_tree)
	# are already in place.
	for source_socket, new_socket in zip(source_node.inputs, new_node.inputs):
		try:
			new_socket.default_value = source_socket.default_value
		except Exception:
			pass  # e.g. sockets without a default_value


def _copy_node_tree(source_tree, target_tree):
	"""Copy all nodes and links of *source_tree* into *target_tree*.

	:return: ``(node_map, bbox)``. ``node_map`` maps each source node name to its
		copy; Blender may rename copies when names clash, so always go through the map.
		``bbox`` is the ``minX/maxX/minY/maxY`` extent of the source nodes, computed
		from their ``location`` (top-left corner) and ``dimensions``.
	"""
	node_map = {}
	bbox = {"minX": math.inf, "maxX": -math.inf, "minY": math.inf, "maxY": -math.inf}
	for source_node in source_tree.nodes:
		new_node = target_tree.nodes.new(source_node.bl_idname)
		_copy_node_settings(source_node, new_node)
		node_map[source_node.name] = new_node

		left, top = source_node.location[0], source_node.location[1]
		width, height = source_node.dimensions[0], source_node.dimensions[1]
		bbox["minX"] = min(bbox["minX"], left)
		bbox["maxX"] = max(bbox["maxX"], left + width)
		bbox["minY"] = min(bbox["minY"], top - height)
		bbox["maxY"] = max(bbox["maxY"], top)

	for link in source_tree.links:
		from_socket = node_map[link.from_node.name].outputs[_socket_index(link.from_socket)]
		to_socket = node_map[link.to_node.name].inputs[_socket_index(link.to_socket)]
		target_tree.links.new(from_socket, to_socket)

	return node_map, bbox


def _get_surface_output_node(material):
	"""Return the Material Output node of *material* whose "Surface" input will be stacked.

	The active output (the one Blender renders with) is preferred; otherwise the last
	Material Output node found is used.

	:raises ValueError: if *material* does not use nodes, has no Material Output
		node, or nothing is linked to that node's "Surface" input.
	"""
	if not material.use_nodes:
		raise ValueError(f'Material "{material.name}" does not use nodes')
	output_nodes = [node for node in material.node_tree.nodes if node.bl_idname == "ShaderNodeOutputMaterial"]
	if not output_nodes:
		raise ValueError(f'Material "{material.name}" has no Material Output node')
	output_node = next(
		(node for node in output_nodes if getattr(node, "is_active_output", False)),
		output_nodes[-1],
	)
	if not output_node.inputs["Surface"].links:
		raise ValueError(f'Nothing is linked to the "Surface" input of material "{material.name}"')
	return output_node


def stack_2_materials(bottom_mat, top_mat, new_mat_name, mix_image):
	"""Create a new material that blends the shaders of *top_mat* and *bottom_mat*.

	Both node trees are copied into a fresh material named *new_mat_name*. The
	bottom copy is laid out below the top copy, and the pair is centred on the
	origin. The shaders feeding each material's "Surface" output are combined by a
	Mix Shader, whose factor is driven by an Image Texture node showing *mix_image*.
	The top material's shader goes to the first shader input (factor 0) and the
	bottom material's shader to the second (factor 1). The bottom material's output
	node is removed; the top material's output node receives the mix.

	:raises ValueError: if either material does not use nodes, has no Material
		Output node, or has nothing linked to that node's "Surface" input. These
		checks run before anything is created. If building fails part-way, the new
		material is removed from ``bpy.data`` before the error propagates.
	:return: the new material.
	"""
	# Validate both inputs before creating anything in bpy.data.
	bottom_output_source = _get_surface_output_node(bottom_mat)
	top_output_source = _get_surface_output_node(top_mat)

	to_material = bpy.data.materials.new(new_mat_name)
	try:
		to_material.use_nodes = True
		node_tree = to_material.node_tree
		node_tree.nodes.clear()

		bottom_nodes, bbox_bottom_mat = _copy_node_tree(bottom_mat.node_tree, node_tree)
		top_nodes, bbox_top_mat = _copy_node_tree(top_mat.node_tree, node_tree)
		output_node_bottom_mat = bottom_nodes[bottom_output_source.name]
		output_node_top_mat = top_nodes[top_output_source.name]

		# Stack the two node groups vertically (top group above, 100 units of gap),
		# centring each group horizontally and the pair vertically on the origin.
		bottom_width = bbox_bottom_mat["maxX"] - bbox_bottom_mat["minX"]
		bottom_height = bbox_bottom_mat["maxY"] - bbox_bottom_mat["minY"]
		top_width = bbox_top_mat["maxX"] - bbox_top_mat["minX"]
		top_height = bbox_top_mat["maxY"] - bbox_top_mat["minY"]
		total_height = (100 + bottom_height) + top_height
		for node in top_nodes.values():
			node.location[0] = node.location[0] - top_width / 2 - bbox_top_mat["minX"]
			node.location[1] = node.location[1] - bbox_top_mat["minY"] + (100 + bottom_height) - total_height / 2
		for node in bottom_nodes.values():
			node.location[0] = node.location[0] - bottom_width / 2 - bbox_bottom_mat["minX"]
			node.location[1] = node.location[1] - bbox_bottom_mat["minY"] - total_height / 2

		# Keep a single output: feed both original surface shaders into a Mix Shader
		# connected to the top material's output, and drop the bottom material's output.
		surface_output_socket_top_mat = output_node_top_mat.inputs["Surface"].links[0].from_socket
		print("Top mat final node:", surface_output_socket_top_mat.node.name)
		surface_output_socket_bottom_mat = output_node_bottom_mat.inputs["Surface"].links[0].from_socket
		print("Bottom mat final node:", surface_output_socket_bottom_mat.node.name)
		node_tree.nodes.remove(output_node_bottom_mat)
		mix_shader_node = node_tree.nodes.new("ShaderNodeMixShader")
		# Overridden by the mask texture linked below; only used if that link is removed.
		mix_shader_node.inputs[0].default_value = random.random()
		node_tree.links.new(surface_output_socket_top_mat, mix_shader_node.inputs[1])
		node_tree.links.new(surface_output_socket_bottom_mat, mix_shader_node.inputs[2])
		node_tree.links.new(mix_shader_node.outputs[0], output_node_top_mat.inputs[0])
		mix_shader_node.location[0] = 300 + max(bottom_width / 2, top_width / 2)
		mix_shader_node.location[1] = -100
		output_node_top_mat.location[0] = mix_shader_node.location[0] + 300
		output_node_top_mat.location[1] = 0

		# Image texture whose colour output drives the mix factor.
		mix_texture_node = node_tree.nodes.new("ShaderNodeTexImage")
		mix_texture_node.image = mix_image
		mix_texture_node.location[0] = mix_shader_node.location[0] - 300
		mix_texture_node.location[1] = mix_shader_node.location[1] + 300
		node_tree.links.new(mix_texture_node.outputs[0], mix_shader_node.inputs[0])
	except Exception:
		# Do not leave a half-built material behind in bpy.data.
		bpy.data.materials.remove(to_material)
		raise

	return to_material


class Stack_Materials(bpy.types.Node, Node, DATA_GETTER_NODE_Material):
	"""Node that builds a new material by blending two input materials through a mask image.

	The node-tree merge itself is performed by ``stack_2_materials``.
	"""
	bl_idname = 'sc_node_r1mro9b04dnf33crlh85'
	bl_label = 'Stack materials'
	bl_icon = 'SHADING_RENDERED'

	def sc_init(self, context):
		# Inputs are read back by position in _get_materials_necessary_data,
		# so their creation order matters.
		self.create_output(SOCKET_Material, is_new_data_output=True)
		self.create_input(SOCKET_Image, is_required=True, label="Mask image")
		self.create_input(SOCKET_Material, is_required=True, label="Top material")
		self.create_input(SOCKET_Material, is_required=True, label="Bottom material")

	def _get_materials_necessary_data(self, *args, **kwargs):
		"""Collect ``(mask images, top materials, bottom materials)`` from the three inputs.

		NOTE(review): presumably called by ``Node.get_data_first`` to build the
		``data`` argument of ``get_materials`` — confirm.
		"""
		input_socket_img_mask: SOCKET_Image = self.inputs[0]
		input_socket_mat_top: SOCKET_Material = self.inputs[1]
		input_socket_mat_bottom: SOCKET_Material = self.inputs[2]
		return (
			input_socket_img_mask.get_input_images(),
			input_socket_mat_top.get_input_materials(),
			input_socket_mat_bottom.get_input_materials(),
		)

	@Node.get_data_first
	def get_materials(self, data: Tuple[List[DATA_Image], List[DATA_Material], List[DATA_Material]], *args, **kwargs):
		"""Return a one-element list holding the newly stacked material.

		Only the first item of each input list is used.

		:raises ValueError: if any of the three inputs provided no data.
		"""
		mask_images, top_materials, bottom_materials = data[0], data[1], data[2]
		if not (mask_images and top_materials and bottom_materials):
			self.print('Mask image, top material and bottom material are all required')
			raise ValueError('Mask image, top material and bottom material are all required')
		# NOTE(review): the "Top material" input is passed as stack_2_materials'
		# bottom_mat argument (and "Bottom material" as top_mat). Its shader therefore
		# lands on the Mix Shader's second input, selected where the mask is white
		# (factor 1). Presumably intentional, so white mask areas show the top
		# material — confirm before swapping.
		new_material = stack_2_materials(
			top_materials[0].bl_material,
			bottom_materials[0].bl_material,
			'Stacked material',
			mask_images[0].bl_img,
		)
		return [DATA_Material(new_material)]
