在进行有关草图相关模型训练时,如何构造一个符合人类绘制习惯的草图数据集成为了能否让模型学习到一个较好性能的关键因素。对草图而言,一般来说的任务有:草图检索,草图生成(2D or 3D) 等。本篇旨在整理总结截止到当前的一些数据集以及相关的合成方法。
整体而言,构造一个草图数据集会有如下几种角度:
- 人工绘制并收集标注
- 自动合成草图
- 风格迁移与域适配
第一种构建方式带来了一些现有的草图数据集,也是最符合人类绘图习惯的数据集。后两种方式则是目前草图合成的主要实现思路,具体再根据输入模态不同进行进一步拆分。
现有草图数据集
QuickDraw
该草图数据集是由 Google
于 2016 年发布的游戏 Quick Draw!
中整理得到。在这款游戏中,玩家会根据提示的物件名字画出其图像,然后由另一方的玩家来看图猜测所绘制的是什么内容。可以将其理解为涂鸦版的“你来比划我来猜”。总计包含有 345
个类别总计 500M
个图像。该数据集中只包含类别信息与手绘草图,没有其他对应的实际物体图片。常用来作为草图检索任务的训练数据集。但由于全部收集自真实手绘,非常符合人类手绘习惯,与实际物体图存在一定的 deformation
。
SketchCOCO
该数据集采用人工挑选的方法,从其他数据集中挑选出相同类别的草图归到一类,对每一张前景图,采用分割算法进行前景和背景的分割,并挑选与前景最为匹配的草图成对,形成物体级草图与实物的对应。场景级同理实现。包含有两部分:
- 物体级数据:包含有
14
个类别在内的总计20198
对 <前景草图,前景物体图,前景边缘图> 在内的数据对。 - 场景级数据:包含有
14081
对 <前景物体图,背景草图,场景图> 数据对, 以及14081
对 <场景草图, 场景图> 数据对。
对物体级草图而言,由于采用分割算法,因此对应的实物图边缘部分不算很明确,且类别少而单类别下的草图数据足够多,具体可见下表:
cat | dog | zebra | giraffe | horse | cow | elephant | sheep | car |
---|---|---|---|---|---|---|---|---|
659 | 777 | 401 | 246 | 773 | 628 | 398 | 369 | 411 |
Children Drawing
有关儿童手绘草图数据集目前只找到一篇相关工作:Amateur Drawings,主要实现的是草图动画的生成。数据集下载命令:
# download annotations (~275Mb)
wget https://dl.fbaipublicfiles.com/amateur_drawings/amateur_drawings_annotations.json
# download images (~50Gb)
wget https://dl.fbaipublicfiles.com/amateur_drawings/amateur_drawings.tar
从图中不难看出,该数据集的特点是:真实来源于儿童的绘画图片,但是相比其他最为抽象,很多图片并没有实物对应。与此同时还会有一定的色彩信息。没有标签与描述信息,只包含一张图片。因此在该数据集上几乎没有人做草图检索与其他模态生成,大多研究内容为让草图动起来。
DifferSketch
该数据集中稿 SIGGRAPH Asia 2022
,是一个用于研究专业用户和新手用户在绘制 3D 物体时差异的自由手绘草图数据集。以往的草图数据集大多规模较小,涵盖的对象或类别有限,且主要包含专业用户的自由手绘草图,难以比较专业用户和新手用户的绘图差异。该数据集构建了一个包含专业用户和新手用户绘制相同 3D 物体的大型自由手绘草图数据集,以便深入分析两者在绘图过程中的差异,为设计更友好、更有效的草图交互系统提供依据。
[
数据集从九个类别中选取了总计 136
个 3D
模型,这些类别包括动物、动物头部、椅子、人脸、工业部件、灯具、几何体、鞋子和车辆等。每个 3D 模型根据其类别和结构特点,被渲染成 2 到 3 个不同视角的图像,总共生成了 362
张参考图像,多视角选取随机。文件下载后的主要组织格式有:
[Category]
- {obj}: 3D models
- {mv-img}: Multi-view reference images rendered from 3D models
- {tracing}: The tracings over the reference images
- [Type]_png: The drawings rasterized from .json files
- [Type]_json: The sequence data includes pressure, timeslot, drawing path, etc.
* [Category] indicates the category name, e.g. {Animal_Head}.
* [Type] indicates {original} type (user drawings) or registered types ({gloabl} -> sketch-level, {stroke} -> stroke-level and {reg} -> pixel-level) from the user drawings
可以通过谷歌云盘下载,下载链接。
合成数据方法
在很多时候,现有的草图数据集并不能很好地满足实际需求,在此时通过一套合理的草图生产方法造出符合预期的数据集则显得尤为重要。在这里首先需要区分清楚的是,在当前问题内该如何理解草图(侧重点在哪里),需要的草图和真实物体之间的变形程度会有多少。以下整理了一些现有的合成方法并做简单分析。
3D 物体渲染合成
通过这种方式合成草图的论文主要有:
- DeepShapeSketch : Generating hand drawing sketches from 3D objects (IJCNN 2019)
- Neural Contours: Learning to Draw Lines from 3D Shapes. (CVPR 2020)
- Cloud2Curve: Generation and Vectorization of Parametric Sketches (CVPR 2021)
- Neural Strokes: Stylized Line Drawing of 3D Shapes (ICCV 2021)
- Learning a Style Space for Interactive Line Drawing Synthesis from Animated 3D Models (PG 2022)
- CAD2Sketch: Generating Concept Sketches from CAD Sequences (SIGGRAPH Asia 2022)
不同的方法对草图细节的关注点也不同。早期方法(如 DeepShapeSketch
和 Neural Contours
)主要通过 CNN
和图像映射来获得草图图像,侧重于直观还原3D物体的轮廓与细节;后续方法(如 Cloud2Curve
和 CAD2Sketch
)逐渐转向参数化方法,直接从点云或 CAD
数据提取控制点,实现矢量化草图生成,便于编辑和后期处理。而在关注点上也各有不同:
Neural Contours
强调边缘的连续平滑;Neural Strokes
通过对笔触进行建模(厚度,变形与颜色等)来提供更为细腻的手绘感,这种分解方法有助于理解草图的局部结构;CAD2Sketch
在处理精细几何信息时,兼顾了自动抽象和人类草图的随意性之间的平衡。
以上均是不同论文对草图的不同理解。切换到生成领域而言,已经有论文 《 Block and Detail: Scaffolding Sketch-to-Image Generation》 以及 《 DrawingSpinUp: 3D Animation from Single Character Drawings》 等都提及,草图的边缘轮廓部分会十分影响生成的结果质量。如果从这个角度出发的话,那么构建的草图就可以优先表达物体完整的结构,之后选择性表达一些物体的关键信息,而忽视对草图厚度,颜色等的刻画。基于以上思路,可以借助 SVG
实现草图的合成:
首先通过物体级图片(可渲染得到,也可以前后景分割得到)得到对应的线稿图,之后将其矢量化转成 SVG
格式以刻画描述每一笔 stroke
的长度与形状,在此基础上加以筛选保留关键笔画。上图所展示的流程中 Line art model
采用的是 AniDoc 中的一部分,矢量化采用的是当前的 sota 模型 《 Deep Sketch Vectorization via Implicit Surface Extraction》,再根据笔画长度与到边缘的距离进行筛选:
给出筛选笔画的代码(有点长不想看可以折叠 qwq):
import svgwrite
from svgpathtools import svg2paths2, wsvg
import numpy as np
import os
import cv2
import cairosvg
from PIL import Image
from svgpathtools import svg2paths2
from svgpathtools import Path as SVGPath
def extract_edge_distance_map(png_path, size=(512, 512), white_threshold=240):
"""
提取边缘距离图,支持RGBA和RGB图片
参数:
png_path: 输入图片路径
size: 目标尺寸
white_threshold: 判定白色的阈值(0-255)
"""
# Load image
img = Image.open(png_path)
img = img.resize(size)
if img.mode == "RGBA":
# 处理带透明通道的图片
alpha = np.array(img)[:, :, 3]
mask = (alpha > 0).astype(np.uint8) * 255
else:
# 处理不带透明通道的图片
img_array = np.array(img.convert("RGB"))
# 判断白色像素
is_white = np.all(img_array >= white_threshold, axis=2)
mask = (~is_white).astype(np.uint8) * 255
# 保存并返回mask
Image.fromarray(mask, mode='L').save("mask_map.png")
# 计算距离变换
distance_map = cv2.distanceTransform(
mask,
distanceType=cv2.DIST_L2, # 使用欧氏距离
maskSize=cv2.DIST_MASK_PRECISE # 使用精确计算
)
# 归一化距离图到0-255范围
distance_map = cv2.normalize(
distance_map,
None,
0,
255,
cv2.NORM_MINMAX,
dtype=cv2.CV_8U
)
# 保存距离图
Image.fromarray(distance_map).save("distance_map.png")
return distance_map
def sample_path_points(path: SVGPath, num_points=100):
"""Sample `num_points` along an SVG path."""
return [path.point(t) for t in np.linspace(0, 1, num_points)]
def normalize_svg_point(point, viewbox, target_size=(512, 512)):
"""将 SVG 坐标标准化为 512x512 图像坐标"""
min_x, min_y, width, height = viewbox
x = (point.real - min_x) / width * target_size[0]
y = (point.imag - min_y) / height * target_size[1]
return int(x), int(y)
def process_distances(valid_distances, method='minmax'):
"""
处理距离数据,使其分布更均匀
参数:
valid_distances: 有效距离列表
method: 处理方法 ('log', 'robust', 'minmax', 'zscore')
"""
distances = np.array(valid_distances)
if method == 'log':
# 对数变换
processed = np.log1p(distances)
elif method == 'robust':
# 稳健缩放(基于分位数)
q1, q3 = np.percentile(distances, [25, 75])
iqr = q3 - q1
lower_bound = q1 - 1.5 * iqr
upper_bound = q3 + 1.5 * iqr
# 过滤异常值并缩放
mask = (distances >= lower_bound) & (distances <= upper_bound)
processed = distances[mask]
# 重新缩放到0-1范围
processed = (processed - processed.min()) / (processed.max() - processed.min())
elif method == 'minmax':
# 最小最大缩放
processed = (distances - distances.min()) / (distances.max() - distances.min())
elif method == 'zscore':
# Z-score标准化
mean = np.mean(distances)
std = np.std(distances)
processed = (distances - mean) / std
return processed
def filter_by_edge_and_length(svg_path, png_path, k=0.5, s=0.5,
mode='and', output_svg_path="filtered_final.svg",
canvas_size=(512, 512), soft_edge=True):
from svgpathtools import svg2paths2, wsvg
# 读取路径与 SVG 属性
paths, attributes, svg_attr = svg2paths2(svg_path)
viewbox = [float(x) for x in svg_attr['viewBox'].split()] # [min_x, min_y, width, height]
# 距离图
distance_map = extract_edge_distance_map(png_path, size=canvas_size)
edge_filtered_set = set()
length_filtered_set = set()
all_path_distances = []
all_path_lengths = []
for i, path in enumerate(paths):
# ---- 计算平均距离 ----
points = sample_path_points(path, num_points=200)
distances = []
for p in points:
x, y = normalize_svg_point(p, viewbox, canvas_size)
if 0 <= x < canvas_size[0] and 0 <= y < canvas_size[1]:
distances.append(distance_map[y, x])
avg_distance = np.mean(distances) if distances else None
all_path_distances.append(avg_distance)
# ---- 计算路径长度 ----
length = path.length() * 10
all_path_lengths.append(length)
# 计算距离阈值
all_path_distances = process_distances(all_path_distances, method='log')
valid_distances = [d for d in all_path_distances if d is not None]
a_d, b_d = min(valid_distances), max(valid_distances)
distance_threshold = a_d + k * (b_d - a_d)
# print(f"max={b_d}, min={a_d}, thres={distance_threshold}", end='')
# print(len(valid_distances), valid_distances)
# 计算长度阈值
all_path_lengths = process_distances(all_path_lengths)
a_l, b_l = max(all_path_lengths), min(all_path_lengths)
length_threshold = b_l + s * (a_l - b_l)
# lengths = process_distances(np.array([l for l in all_path_lengths if l is not None]))
# mean_length = np.mean(lengths) # 平均长度
# std_length = np.std(lengths) # 标准差
# length_threshold = mean_length + s * std_length
# print(len(all_path_lengths), all_path_lengths)
# 满足距离条件的路径索引
nums = 0
for idx, (d, l) in enumerate(zip(all_path_distances, all_path_lengths)):
if l > b_l + 0.05 * (a_l - b_l):
nums += 1
if l >= length_threshold:
length_filtered_set.add(idx)
if d is not None and d <= distance_threshold:
if soft_edge and l <= b_l + 0.1 * (a_l - b_l):
continue
edge_filtered_set.add(idx)
print(f"nums = {nums}, ", end='')
# 合并结果
if mode == 'and':
final_indices = edge_filtered_set & length_filtered_set
elif mode == 'or':
final_indices = edge_filtered_set | length_filtered_set
else:
raise ValueError("mode 应为 'and' 或 'or'")
# 生成新的路径列表
filtered_paths = [paths[i] for i in final_indices]
filtered_attributes = [attributes[i] for i in final_indices]
# 保存新 SVG
wsvg(filtered_paths, attributes=filtered_attributes, svg_attributes=svg_attr, filename=output_svg_path)
print(f"共保留 {len(filtered_paths)} 条路径,保存至:{output_svg_path}")
def create_image_grid_enhanced(svg_dir, output_path, ks, ss,
single_size=(256, 256), padding=10,
background_color='white'):
"""
网格图像创建函数
功能:
- 图像间距
- 背景颜色选择
- 可选的标签
"""
grid_size = (len(ks), len(ss))
rows, cols = grid_size
width, height = single_size
# 计算含padding的总尺寸
total_width = width * cols + padding * (cols + 1)
total_height = height * rows + padding * (rows + 1)
# 创建空白画布
grid_image = Image.new('RGB',
(total_width, total_height),
color=background_color)
print(f"final ks = {ks} , ss = {ss}")
for i, k in enumerate(ks):
for j, s in enumerate(ss):
svg_name = f"k_{k}_s_{s}.svg"
svg_path = os.path.join(svg_dir, svg_name)
temp_png = f"temp_{i}_{j}.png"
try:
# 转换SVG到PNG
cairosvg.svg2png(
url=svg_path,
write_to=temp_png,
output_width=width,
output_height=height,
# background_color="white"
)
img = Image.open(temp_png)
# 计算带padding的位置
x = j * (width + padding) + padding
y = i * (height + padding) + padding
# 粘贴到网格中
grid_image.paste(img, (x, y))
os.remove(temp_png)
except Exception as e:
print(f"处理 {svg_name} 时出错: {e}")
grid_image.save(output_path)
print(f"网格图像已保存到 {output_path}")
if __name__ == '__main__':
svg_path = f'xxx.svg'
png_path = f'mask.png' # 用于提取边缘图
output_path = "./assets/diff/svg"
ks = [0.9, 0.7, 0.5, 0.3] # 距离
ss = [0.1, 0.2, 0.3, 0.4] # 长度
for k in ks:
for s in ss:
out_path = os.path.join(output_path, f"k_{k}_s_{s}.svg")
print(f"k = {k}, s = {s}, out_path = {out_path}...")
filter_by_edge_and_length(svg_path, png_path, k=k, s=s, mode='or', output_svg_path=out_path)
print("generate the final image...")
create_image_grid_enhanced(output_path, f".diff/tot.png", ks, ss)
当然,除去对两个标签进行筛选外,《 Block and Detail: Scaffolding Sketch-to-Image Generation》中也给出了另一种实现方式:先读取前景掩码和对应SVG
图像,从 SVG
中提取路径信息生成单路径草图,并利用前景 mask
轮廓对草图中的每一笔进行距离计算和排序,随机选取部分线条组合成最终草图,经过后处理后保存为RGB格式的训练数据。相关代码在 这里。