Google Safesearch Mini V2

FredZhang7
Clasificación de imagen

Google Safesearch Mini V2 es un clasificador de imágenes multi-clase ultra preciso que detecta con precisión contenido explícito. El modelo utilizó la arquitectura InceptionResNetV2 y un conjunto de datos de aproximadamente 3,400,000 imágenes obtenidas aleatoriamente de internet, algunas de las cuales fueron generadas mediante argumentación de datos. Los datos de entrenamiento y validación se obtienen de Google Images, Reddit, Kaggle e Imgur, y fueron clasificados como seguros o nsfw por empresas, Google SafeSearch y moderadores. Después de entrenar el modelo durante 5 épocas con pérdida de entropía cruzada y evaluarlo en los conjuntos de datos de entrenamiento y validación, se hicieron correcciones necesarias y el modelo se entrenó durante 8 épocas adicionales. El modelo se perfeccionó posteriormente con 15 conjuntos de datos adicionales de Kaggle durante una época, y luego se entrenó por una última época con una combinación de datos de entrenamiento y prueba. Esto resultó en una precisión del 97% tanto en los datos de entrenamiento como en los de validación.

Como usar

pip install --upgrade torchvision

import torch, os
from torchvision import transforms
from PIL import Image
import urllib.request
import timm

image_path = "https://www.allaboutcats.ca/wp-content/uploads/sites/235/2022/03/shutterstock_320462102-2500-e1647917149997.jpg"
device = "cuda"

def preprocess_image(image_path):
# Define image pre-processing transforms
transform = transforms.Compose([
transforms.Resize(299),
transforms.CenterCrop(299),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
if image_path.startswith('http://') or image_path.startswith('https://'):
import requests
from io import BytesIO
response = requests.get(image_path)
img = Image.open(BytesIO(response.content)).convert('RGB')
else:
img = Image.open(image_path).convert('RGB')
img = transform(img).unsqueeze(0)
img = img.cuda() if device.lower() == "cuda" else img.cpu()
return img

def eval():
model = timm.create_model("hf_hub:FredZhang7/google-safesearch-mini-v2", pretrained=True)
model.to(device)
img = preprocess_image(image_path)

with torch.no_grad():
out = model(img)
_, predicted = torch.max(out.data, 1)
classes = {
0: 'nsfw_gore',
1: 'nsfw_suggestive',
2: 'safe'
}
print('\n\033[1;31m' + classes[predicted.item()] + '\033[0m' if predicted.item() != 2 else '\033[1;32m' + classes[predicted.item()] + '\033[0m\n')

if __name__ == '__main__':
eval()

Funcionalidades

Clasificador de imágenes multi-clase ultra preciso
Detecta contenido explícito
Utiliza la arquitectura InceptionResNetV2
Entrenado con aproximadamente 3,400,000 imágenes
Precisión del 97% en datos de entrenamiento y validación

Casos de uso

Herramienta de moderación de redes sociales
Filtro de conjuntos de datos