"""Windows / DirectML CLIP worker for occlusion scoring. Reads a queue.json staged by /opt/face-sets/work/filter_occlusions.py (WSL side), runs open_clip ViT-L-14 (dfn2b_s39b) on each PNG via torch-directml on the AMD Vega, and writes a scores.json with mask + sunglasses softmax probabilities. CLI: py -3.12 clip_worker.py [--limit N] [--batch 8] queue.json shape: list of objects {"wsl_path": "...", "win_path": "E:\\...\\faceset_NNN\\faces\\NNNN.png", "faceset": "faceset_NNN", "file": "NNNN.png"} scores.json shape: {"model": "ViT-L-14/dfn2b_s39b", "logit_scale": 100.0, "prompts": {...}, "results": [{"wsl_path": "...", "faceset": "...", "file": "...", "mask": float, "sunglasses": float}], "processed": [wsl_path, ...]} """ from __future__ import annotations import argparse import json import os import sys import time import warnings from pathlib import Path # DML emits a verbose UserWarning per attention call -- silence at import time warnings.filterwarnings("ignore", category=UserWarning) import torch import torch_directml import open_clip from PIL import Image MODEL_NAME = "ViT-L-14" PRETRAINED = "dfn2b_s39b" # kept in sync with /opt/face-sets/work/filter_occlusions.py PROMPTS PROMPTS = { "mask": { "pos": [ "a photo of a person wearing a surgical face mask", "a photo of a person wearing an FFP2 respirator covering mouth and nose", "a photo of a person wearing a cloth face mask", "a face partially covered by a medical mask", "a person whose mouth and nose are hidden by a face mask", ], "neg": [ "a photo of a person's face with mouth and nose clearly visible", "a clear, unobstructed photo of a face", "a photo of a face without any mask or covering", "a portrait of a person showing their full face", "a photo of a person with a beard and visible mouth", ], }, "sunglasses": { "pos": [ "a face with dark sunglasses covering the eyes", "a portrait with the eyes hidden behind opaque sunglasses", "a person wearing dark sunglasses over their eyes, eyes not visible", "a face where the eyes are completely concealed by tinted lenses", "a close-up portrait wearing aviator sunglasses on the eyes", ], "neg": [ "a portrait with both eyes clearly visible and uncovered", "a face with sunglasses pushed up on the forehead, eyes visible below", "a face with sunglasses resting on top of the head, eyes visible", "a person with sunglasses hanging from their shirt, eyes visible", "a face wearing clear prescription eyeglasses with visible eyes", "a portrait with no eyewear and visible eyes", ], }, } FLUSH_EVERY = 100 def load_existing(out_path: Path): if not out_path.exists(): return None, set() try: d = json.loads(out_path.read_text()) processed = set(d.get("processed", [])) return d, processed except Exception as e: print(f"[warn] could not parse existing {out_path}: {e}; starting fresh", file=sys.stderr) return None, set() def save_atomic(out_path: Path, data: dict): tmp = out_path.with_suffix(".tmp.json") tmp.write_text(json.dumps(data, indent=2)) os.replace(tmp, out_path) @torch.no_grad() def build_text_features(model, tokenizer, device): out = {} for attr, sides in PROMPTS.items(): feats = {} for side in ("pos", "neg"): tokens = tokenizer(sides[side]).to(device) f = model.encode_text(tokens) f = f / f.norm(dim=-1, keepdim=True) mean = f.mean(dim=0) feats[side] = mean / mean.norm() out[attr] = (feats["pos"], feats["neg"]) return out def main(): ap = argparse.ArgumentParser() ap.add_argument("queue", type=Path) ap.add_argument("out", type=Path) ap.add_argument("--limit", type=int, default=None) ap.add_argument("--batch", type=int, default=8) args = ap.parse_args() queue = json.loads(args.queue.read_text()) print(f"[queue] {len(queue)} entries from {args.queue}") args.out.parent.mkdir(parents=True, exist_ok=True) existing, processed = load_existing(args.out) if existing: print(f"[resume] {len(processed)} entries already scored") results = existing.get("results", []) else: results = [] pending = [e for e in queue if e["wsl_path"] not in processed] if args.limit is not None: pending = pending[: args.limit] print(f"[pending] {len(pending)} entries to score") if not pending: print("[done] nothing to do") return device = torch_directml.device() print(f"[load] {MODEL_NAME}/{PRETRAINED} on {torch_directml.device_name(0)}") t0 = time.time() model, _, preprocess = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=PRETRAINED) tokenizer = open_clip.get_tokenizer(MODEL_NAME) model = model.to(device).eval() logit_scale = float(model.logit_scale.exp().detach().cpu()) print(f"[load] ready in {time.time()-t0:.1f}s logit_scale={logit_scale:.2f}") text_feats = build_text_features(model, tokenizer, device) def flush(): save_atomic(args.out, { "model": f"{MODEL_NAME}/{PRETRAINED}", "logit_scale": logit_scale, "prompts": PROMPTS, "results": results, "processed": sorted(processed), }) n_done_this_run = 0 n_load_err = 0 last_flush = time.time() t_start = time.time() for i in range(0, len(pending), args.batch): chunk = pending[i:i + args.batch] imgs = [] keep = [] for entry in chunk: try: img = Image.open(entry["win_path"]).convert("RGB") imgs.append(preprocess(img)) keep.append(entry) except Exception as e: print(f"[skip] {entry['win_path']}: {e}", file=sys.stderr) n_load_err += 1 processed.add(entry["wsl_path"]) if not imgs: continue x = torch.stack(imgs).to(device) with torch.no_grad(): feats = model.encode_image(x) feats = feats / feats.norm(dim=-1, keepdim=True) scores_per_attr = {} for attr, (pos, neg) in text_feats.items(): sims = torch.stack([feats @ pos, feats @ neg], dim=1) * logit_scale probs = sims.softmax(dim=1)[:, 0].detach().cpu().tolist() scores_per_attr[attr] = probs for j, entry in enumerate(keep): results.append({ "wsl_path": entry["wsl_path"], "faceset": entry["faceset"], "file": entry["file"], "mask": round(scores_per_attr["mask"][j], 4), "sunglasses": round(scores_per_attr["sunglasses"][j], 4), }) processed.add(entry["wsl_path"]) n_done_this_run += 1 if (n_done_this_run % FLUSH_EVERY < args.batch) or (time.time() - last_flush) > 30.0: flush() last_flush = time.time() elapsed = time.time() - t_start rate = n_done_this_run / max(0.1, elapsed) eta_min = (len(pending) - n_done_this_run) / max(0.1, rate) / 60.0 print(f"[score] {n_done_this_run}/{len(pending)} " f"rate={rate:.2f} img/s eta={eta_min:.1f}min " f"load_err={n_load_err}", flush=True) flush() elapsed = time.time() - t_start print(f"[done] {n_done_this_run} scored, {n_load_err} load errors, " f"{elapsed:.1f}s ({n_done_this_run/max(0.1,elapsed):.2f} img/s) -> {args.out}") if __name__ == "__main__": main()