Source code for text_renderer.dataset

import os
import json
from typing import Dict

import lmdb
import cv2
import numpy as np


[docs]class Dataset: def __init__(self, data_dir: str, jpg_quality: int = 95): self.data_dir = data_dir if not os.path.exists(data_dir): os.makedirs(data_dir) self.jpg_quality = jpg_quality def encode_param(self): return [int(cv2.IMWRITE_JPEG_QUALITY), self.jpg_quality] def write(self, name: str, image: np.ndarray, label: str): pass
[docs] def read(self, name) -> Dict: """ Parameters ---------- name : str 000000001 Returns ------- dict : .. code-block:: bash { "image": ndarray, "label": "label", "size": [int_width, int_height] } """ pass
def read_count(self) -> int: pass def write_count(self, count: int): pass def close(self): pass def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.close()
[docs]class ImgDataset(Dataset): """ Save generated image as jpg file, save label and meta in json json file format: .. code-block:: bash { "labels": { "000000000": "test", "000000001": "text2" }, "sizes": { "000000000": [width, height], "000000001": [width, height], } "num-samples": 2, } """ LABEL_NAME = "labels.json" def __init__(self, data_dir: str): super().__init__(data_dir) self._img_dir = os.path.join(data_dir, "images") if not os.path.exists(self._img_dir): os.makedirs(self._img_dir) self._label_path = os.path.join(data_dir, self.LABEL_NAME) self._data = {"num-samples": 0, "labels": {}, "sizes": {}} if os.path.exists(self._label_path): with open(self._label_path, "r", encoding="utf-8") as f: self._data = json.load(f) def write(self, name: str, image: np.ndarray, label: str): img_path = os.path.join(self._img_dir, name + ".jpg") cv2.imwrite(img_path, image, self.encode_param()) self._data["labels"][name] = label height, width = image.shape[:2] self._data["sizes"][name] = (width, height) def read(self, name: str) -> Dict: img_path = os.path.join(self._img_dir, name + ".jpg") image = cv2.imread(img_path) label = self._data["labels"][name] size = self._data["sizes"][name] return {"image": image, "label": label, "size": size} def read_size(self, name: str) -> [int, int]: return self._data["sizes"][name] def read_count(self) -> int: return self._data.get("num-samples", 0) def write_count(self, count: int): self._data["num-samples"] = count def close(self): with open(self._label_path, "w", encoding="utf-8") as f: json.dump(self._data, f, indent=2, ensure_ascii=False)
[docs]class LmdbDataset(Dataset): """ Save generated image into lmdb. Compatible with https://github.com/PaddlePaddle/PaddleOCR Keys in lmdb: - image-000000001: image raw bytes - label-000000001: string - size-000000001: "width,height" """ def __init__(self, data_dir: str): super().__init__(data_dir) self._lmdb_env = lmdb.open(self.data_dir, map_size=1099511627776) # 1T self._lmdb_txn = self._lmdb_env.begin(write=True) def write(self, name: str, image: np.ndarray, label: str): self._lmdb_txn.put( self.image_key(name), cv2.imencode(".jpg", image, self.encode_param())[1].tobytes(), ) self._lmdb_txn.put(self.label_key(name), label.encode()) height, width = image.shape[:2] self._lmdb_txn.put(self.size_key(name), f"{width},{height}".encode()) def read(self, name: str) -> Dict: label = self._lmdb_txn.get(self.label_key(name)).decode() size_str = self._lmdb_txn.get(self.size_key(name)).decode() size = [int(it) for it in size_str.split(",")] image_bytes = self._lmdb_txn.get(self.image_key(name)) image_buf = np.frombuffer(image_bytes, dtype=np.uint8) image = cv2.imdecode(image_buf, cv2.IMREAD_UNCHANGED) return {"image": image, "label": label, "size": size} def read_size(self, name: str) -> [int, int]: """ Args: name: Returns: (width, height) """ size_key = f"size_{name}" size = self._lmdb_txn.get(size_key.encode()).decode() width = int(size.split[","][0]) height = int(size.split[","][1]) return width, height def read_count(self) -> int: count = self._lmdb_txn.get("num-samples".encode()) if count is None: return 0 return int(count) def write_count(self, count: int): self._lmdb_txn.put("num-samples".encode(), str(count).encode()) def image_key(self, name: str): return f"image-{name}".encode() def label_key(self, name: str): return f"label-{name}".encode() def size_key(self, name: str): return f"size-{name}".encode() def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): self._lmdb_txn.__exit__(exc_type, exc_value, traceback) self._lmdb_env.close()
if __name__ == "__main__": # image = cv2.imread("f_004.jpg") # label = "test" with LmdbDataset("./test/train") as writer: # writer.write("test", image, label) writer.write_count(1) print(writer.read_count()) # with LmdbDataset("train") as ld: # data = ld.read("test") # cv2.imwrite("test.jpg", data["image"])