import cv2
import numpy as np
import sys
import os

def imread_chinese(path):
    """解决 OpenCV 无法读取中文路径的问题"""
    img = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_COLOR)
    return img

def image_to_svg(input_path, output_path, threshold=100, simplify_epsilon=0.005):
    """
    将图片转换为黑白轮廓 SVG 矢量图

    参数:
    input_path: 输入图片路径 (jpg/png)
    output_path: 输出 SVG 路径
    threshold: Canny 边缘检测的低阈值 (推荐 50~150)
    simplify_epsilon: 多边形简化的精度 (相对于图片周长的比例)
    """
    # 1. 读取图像（支持中文路径），转换为灰度图
    img = imread_chinese(input_path)
    if img is None:
        raise FileNotFoundError(f"无法读取图片: {input_path}")
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # 2. 降噪并提取边缘
    blurred = cv2.GaussianBlur(gray, (3, 3), 0)
    edges = cv2.Canny(blurred, threshold, threshold * 2)

    # 3. 寻找轮廓（只保留外部轮廓，也可用 RETR_LIST 保留全部）
    contours, hierarchy = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # 4. 简化轮廓点（减少结点数量）
    simplified_contours = []
    for cnt in contours:
        peri = cv2.arcLength(cnt, True)
        approx = cv2.approxPolyDP(cnt, simplify_epsilon * peri, True)
        if len(approx) >= 3:  # 至少是三角形
            simplified_contours.append(approx)

    # 5. 生成 SVG 文件
    height, width = img.shape[:2]
    svg_content = f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" width="{width}" height="{height}">\n'
    svg_content += ' <rect width="100%" height="100%" fill="white"/>\n'
    svg_content += ' <g fill="black">\n'

    for cnt in simplified_contours:
        path_data = ""
        for i, point in enumerate(cnt):
            x, y = point[0]
            if i == 0:
                path_data += f"M{x},{y} "
            else:
                path_data += f"L{x},{y} "
        path_data += "Z"
        svg_content += f' <path d="{path_data}"/>\n'

    svg_content += ' </g>\n</svg>'

    # 写入文件（支持中文路径）
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(svg_content)

    print(f"转换完成，共提取 {len(simplified_contours)} 个轮廓，SVG 已保存至 {output_path}")


def parse_custom_args(argv):
    """
    解析你用的逗号分隔格式：
    python img2svg.py input.jpg,output.svg,threshold=80,simplify_epsilon=0.008
    """
    # 方式一: 用标准的 sys.argv 格式
    # python img2svg.py input.jpg output.svg --threshold 80 --simplify_epsilon 0.008
    if len(argv) >= 3 and not argv[1].endswith(',') and not ',' in argv[1]:
        import argparse
        parser = argparse.ArgumentParser(description='图片转 SVG 矢量图')
        parser.add_argument('input', help='输入图片路径')
        parser.add_argument('output', help='输出 SVG 路径')
        parser.add_argument('--threshold', type=int, default=80, help='Canny 边缘检测低阈值')
        parser.add_argument('--simplify_epsilon', type=float, default=0.008, help='多边形简化精度')
        args = parser.parse_args(argv[1:])
        return args.input, args.output, args.threshold, args.simplify_epsilon

    # 方式二: 你用的逗号分隔格式
    parts = argv[1].split(',')
    input_path = parts[0]
    output_path = parts[1]

    threshold = 80
    simplify_epsilon = 0.008

    for part in parts[2:]:
        if '=' in part:
            key, val = part.split('=', 1)
            key = key.strip().lower()
            val = val.strip()
            if key == 'threshold':
                threshold = int(val)
            elif key in ('simplify_epsilon', 'simplify_epslon'):  # 兼容拼写错误
                simplify_epsilon = float(val)

    return input_path, output_path, threshold, simplify_epsilon


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("用法:")
        print("  python img2svg.py input.jpg output.svg --threshold 80 --simplify_epsilon 0.008")
        print("  或: python img2svg.py input.jpg,output.svg,threshold=80,simplify_epsilon=0.008")
        sys.exit(1)

    input_path, output_path, threshold, simplify_epsilon = parse_custom_args(sys.argv)
    image_to_svg(input_path, output_path, threshold=threshold, simplify_epsilon=simplify_epsilon)
