AI

VercelデプロイのためCDNのTensorflow.jsを使ってみた。または「AIを駆使した動物顔判定アプリの開発(完結編)」

先日、以下の2つの記事で「動物顔判定アプリ」の作成について解説しました。

AIを駆使して動物顔判定アプリを作成してみた。(前編:モデル作成)前回、v0とCursorを使用していくつかのアプリを試作しましたが、今回は機械学習を組み合わせてアプリを作成してみたいと思います。 ...
AIを駆使して動物顔判定アプリを作成してみた。(後編:UI 作成)前回の記事では、動物判定の分類器の作成までを行いました。 https://dodotechno.com/animal-face-j...

このアプリは、画像を解析してユーザーの顔が「犬顔」「猫顔」「猿顔」「豚顔」「蛙顔」など、どの動物に似ているかを判定するサービスです。(ちなみに、実際には「動物」を判定しているだけです。あと、「蛙は動物じゃないやろ」というツッコミはご遠慮ください。)

フロントエンドはNext.js、バックエンドはPythonのPyTorchを使用し、Google Colabで動物の画像から学習したモデルを使って予測を行っています。さらに、ソースコードの大部分をv0とCursor Composerで書いています。(詳細はリンク先の記事をご参照ください。)

そして今回、このアプリをVercelにデプロイしました。

後編の記事でアプリ自体は完成していたので、「あとは設定ファイルをちょこっといじればデプロイできるかな」と軽く考えていました。

超、甘かったです。

最大の壁は、VercelがAmazon Lambdaを使用していることによる容量制限で、250MBの制約に対してPyTorchのライブラリは余裕でオーバーしていることです。(なんと800MBもある…)

最終的に、PyTorchモデルをTensorFlow.jsに変換し、クライアントサイドで動作させる方法でこの問題を解決しました。

一言で済ましましたが、一筋縄では行かなかったので、本記事では、この壁をどのように乗り越えたかについて詳しく解説いたします。

方針

先述の通り、PyTorchライブラリが250MBの容量制限を大幅に超えていたため、アプリのデプロイができませんでした。(実を言うと、この壁にぶち当たる前にも、設定ファイルの微調整などでいろいろとつまずきましたが、今回はその話は省きます。)

対策についてChatGPTに相談しました。

vercelにアプリをデプロイしようとしたらpytorchを使用しているためlambdaの容量を超過しました。代替案含めて良い案を教えてください。

回答結果は以下の通りです。

普通なら1を選ぶところですが、今回は意地でも無料枠にこだわりたいので、別のサービスは使いたくありません。また、「モデルの最適化」という方法も選択肢にありますが、今回の問題はモデルのサイズではなく(これも40MBほどありますが)、あくまでPyTorchライブラリ自体のサイズなのです。

そこで次の質問です。

pytorchのライブラリ容量が大きすぎます。小さくする方法はありますか

回答が長かったので有望な候補だけピックアップします。

  • pytorchのCPU版のみ使用する。
  • pytorchで予測に必要なライブラリのみをインストールする。
  • ONNX 、TorchScript、TensorFlow.jsなどの代替ライブラリを使用する。(pytorchで作成したモデルを変換する。)

この選択肢を検討した結果、代替ライブラリの使用を試してみることにしました。(ちなみに、他の選択肢も試してみましたが、うまくいかず…。)

では上記3つのライブラリの中から何を選ぶか。

PythonとNext.jsの連携は予想以上にハマりどころが多かったため、Next.js単体で完結させたいと考え、TensorFlow.jsが候補に挙がりました。(他のライブラリもNext.jsで使えるかもしれませんが、なんとなくTensorFlow.jsを試してみたくて…。)

試しにChatGPTに「PyTorchモデルをTensorFlow.jsに変換する方法を教えてください」と聞いてみると、“いかにもできそうな感じ”のコードとコマンドを返してくれました。

このような経緯でPyTorchのモデルをTensorFlow.jsに変換し、サーバーサイドをPythonではなくNext.jsで実装する方針を採用することにしました。(「ん?最初に書いてある事と違うのでは?」と思った方、とりあえず読み進めてください。)

次章からは、具体的な環境設定やソースコードについて解説していきます。

実装

モデルの変換

概要

PyTorchで作成した動物顔判定モデル(.pth形式)をTensorFlow.jsで使用するために、まずモデルをONNX形式に変換し、最終的にTensorFlow.js形式へ変換しました。

この変換作業はGoogle Colab(L4インスタンス)上で実施しました。

実は、最初はローカル環境で変換作業を試みたのですが、作成したモデルがうまく読み込めず、最終的にGoogle Colabで同じ手順を試したところ、正常に動作しました。正直、原因はわかりませんが、GPUとCPUの違いが影響したのかもしれません。(もしくは、単にローカル環境で何か間違えていた可能性もありますが…)

google driveへのマウント

モデルの読み取り・保存のためにgoogle driveをマウントします。

from google.colab import drive
drive.mount('/content/drive')

 

ONNX形式のモデルに変換

必要なライブラリをインストールします。

!pip install onnx
!pip install onnx-tf

 

以下のコードを実行してONNX形式のモデルに変換します。

import torch
import torch.onnx
from torchvision import models

# モデルのロード
model_path = "/content/drive/MyDrive/open_images_v6/resnet18_transfer_learning_final.pth"

# torchvisionのResNet18モデルをベースに、転移学習済みのモデルを作成
model = models.resnet18(pretrained=False)

# 保存されたモデルの状態辞書をロード
state_dict = torch.load(model_path)

# fc層の出力サイズを取得(この例では5)
num_classes = state_dict['fc.weight'].size(0)

# モデルのfc層を新しいサイズに変更
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

# 変更後のモデルに保存されたパラメータをロード
model.load_state_dict(state_dict)

# 評価モードに設定
model.eval()

# ダミーの入力テンソル(pytorchでの形状を指定)
dummy_input = torch.randn(1, 3, 224, 224)

# ONNXファイルにエクスポート
onnx_model_path = "/content/drive/MyDrive/open_images_v6/resnet18_transfer_learning.onnx"
torch.onnx.export(
    model,                      # モデル
    dummy_input,                # ダミー入力
    onnx_model_path,            # エクスポート先のファイル名
    export_params=True,         # 学習したパラメータもエクスポートする
    opset_version=11,           # ONNXのバージョン(11が推奨)
    do_constant_folding=True,   # 定数フォールディングを有効にして最適化
    input_names=['input'],      # 入力名
    output_names=['output'],    # 出力名
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}  # 動的バッチサイズに対応
)

print(f"モデルをONNX形式に変換して保存しました: {onnx_model_path}")

 

tensorflow形式のモデルに変換

TensorFlow形式へのモデル変換では、TensorFlowとKerasのバージョンを合わせる点で苦労しました。

デフォルトのバージョンのままだとエラーが発生してしまい、試行錯誤の末、以下のバージョンの組み合わせでようやく成功しました。

!pip install "keras<3.0.0" "tensorflow<2.16"
!pip install tensorflow_probability==0.23

 

以下のコマンドを実行してONNX形式からTensorflow形式に変換します。

!onnx-tf convert -i /content/drive/MyDrive/open_images_v6/resnet18_transfer_learning.onnx -o /content/drive/MyDrive/open_images_v6/resnet18_transfer_learning

 

Tesorflow.js形式に変換

必要なライブラリをインストールしてコマンドを実行します。

!pip install onnx tensorflowjs
!tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model \
/content/drive/MyDrive/open_images_v6/resnet18_transfer_learning /content/drive/MyDrive/open_images_v6/tfjs_model

 

モデルはディレクトリ(tfjs_model)に格納されます。(モデルの構造を示したjsonファイルとパラメータの重みを格納したbinファイルから構成されています。)

PythonからNext.jsへ

概要

サーバーサイドのPythonコードをNext.jsに置き換え、TensorFlow.js形式に変換したモデルを使用するようにします。

今回もCursor Composerをフル活用しました。指示は以下の通りです。

現在はpython(api/index.py)でpthファイルのモデルを呼び出して判定をしていますが、Next.jsでtensorflow.jsを使って処理するように変更してください。モデルのパスはmodel/tfjs_modelです。

この時の回答により以下のライブラリをインストールします。

npm install @tensorflow/tfjs-node
npm install formidable
npm install sharp

 

なお、今回は一発で動くものはできずに、以下の点で想定通りに行かずに何度か調整(AIと数回のやりとり)を行いました。

  • 画像の前処理再現(リサイズ、切り取り、正規化など)
  • PyTorchで対応できていたWebP形式の画像がTensorFlow.jsでは処理できなかったため、PNG形式への変換を実施
  • 入力データの次元の置換(PyTorchはNCHW形式、TensorFlowはNHWC形式)

NCHWの構造において、各要素の意味は以下の通りです:

  • N: バッチ数
  • C: チャネル数(色や特徴の層数)
  • H: 画像の高さ
  • W: 画像の幅

一方、TensorFlow.jsが使用するNHWC形式では、バッチ数、画像の高さ、幅、チャネル数の順にデータが並びます。

上記の修正を経て、ローカル環境で動作するものが完成しました。

なおここで問題が・・

結果がPyTorch使用時と微妙に異なり、正解のパーセンテージが若干下がってしまっています。

この違いが、TensorFlow.jsへの変換でデータ型がint64からint32に変更された影響なのか、前処理の違いによるものなのかは不明です。今回はこの点を深く追求せず、先に進むことにします。

結果

さあ、ローカルでうまく動作したので、いざデプロイ!…と思いきや、サイズ超過のエラーでデプロイできませんでした。

考えてみれば、TensorFlow.jsのライブラリサイズを確認していなかった・・

Tensorflow.jsのライブラリも予想以上に大きかったようです。(いや、ChatGPT、どうしてこれを代替案として出してきたの?とは思いますが、最初に確認すべきでしたね。仕事じゃなくてよかった…)

その後、TensorFlow.jsの予測に必要なパッケージ部分だけをインストールしてみるなど試行錯誤しましたが、うまくいきませんでした。

次の手を考えることにします。

ソースコード

一応、ローカルでは動作するので参考までにソースを掲載します。(page.tsx.layout.tsxは省略します。以下同様)

components/animal-face-detector.tsx
'use client'

import { useState, useCallback } from 'react'
import { useDropzone } from 'react-dropzone'
import { PieChart, Pie, Cell, Legend, ResponsiveContainer } from 'recharts'
import Image from 'next/image'

// 定数の定義
const COLOR_LIST = ['#FF4136', '#2ECC40', '#0074D9', '#FF851B', '#B10DC9']

export function AnimalFaceDetector() {
  // 状態の定義
  const [result, setResult] = useState<{ name: string; value: number; }[] | null>(null)
  const [image, setImage] = useState<string | null>(null)
  const [error, setError] = useState<string | null>(null)

  // 画像分析関数
  const analyzeImage = useCallback(async (file: File) => {
    const formData = new FormData()
    formData.append('file', file)

    try {
      console.log('サーバーにリクエストを送信中...')
      const response = await fetch('/api/predict', {
        method: 'POST',
        body: formData,
      })
      console.log('レスポンス受信:', response.status, response.statusText)
      
      if (!response.ok) {
        const errorText = await response.text()
        console.error('APIエラー:', errorText)
        throw new Error(`HTTPエラー! ステータス: ${response.status}, 詳細: ${errorText}`)
      }
      
      const data = await response.json()
      console.log('サーバーからのレスポンス:', data)
      
      // 結果の調整:確率をパーセンテージに変換
      const animal_prob_list = data.map((item: { name: string; probability: number }) => ({
        name: item.name,
        value: Math.round(item.probability * 100)
      }))
      
      setResult(animal_prob_list)
      setError(null)
      console.log('更新された結果:', animal_prob_list)
    } catch (error) {
      console.error('画像分析エラー:', error)
      setError(error instanceof Error ? error.message : '画像の分析中にエラーが発生しました。')
    }
  }, [])

  // ドロップゾーンのコールバック関数
  const onDrop = useCallback((acceptedFiles: File[]) => {
    const file = acceptedFiles[0]
    if (file) {
      console.log('ファイルがドロップされました:', file.name)
      const reader = new FileReader()

      reader.onload = (e) => {
        if (e.target && e.target.result) {
          setImage(e.target.result as string)
          analyzeImage(file)
        }
      }

      reader.onerror = (e) => {
        console.error('ファイル読み込みエラー:', e)
        setError('ファイルの読み込み中にエラーが発生しました。')
      }

      reader.readAsDataURL(file)
    }
  }, [analyzeImage])

  // react-dropzoneのフック
  const { getRootProps, getInputProps } = useDropzone({ 
    onDrop,
    accept: {'image/*': []}
  })

  // 最大カテゴリ(最も高い確率の動物)を取得する関数
  const getMaxCategory = useCallback(() => {
    if (!result) return null
    return result.reduce((max, item) => item.value > max.value ? item : max)
  }, [result])

  const maxCategory = getMaxCategory()

  return (
    <div className="min-h-screen bg-gradient-to-br from-blue-100 to-purple-100 py-12 px-4 sm:px-6 lg:px-8">
      <div className="max-w-md mx-auto bg-white rounded-xl shadow-2xl overflow-hidden md:max-w-2xl">
        <div className="p-8">
          <h1 className="text-4xl font-bold text-center text-gray-800 mb-8">動物顔判定</h1>
          
          <div className="mt-8">
            {/* ドロップゾーン */}
            <div 
              {...getRootProps()} 
              className="mb-8 relative rounded-lg overflow-hidden shadow-lg cursor-pointer"
              style={{width: '100%', height: '250px'}}
            >
              <input {...getInputProps()} />
              {image ? (
                <Image src={image} alt="分析対象の顔" layout="fill" objectFit="cover" className="rounded-lg transition-transform duration-300 hover:scale-105" />
              ) : (
                <div className="flex items-center justify-center h-full bg-gray-100">
                  <p className="text-gray-500 text-lg font-semibold">ここをクリックまたはドラッグ&ドロップで画像をアップロード</p>
                </div>
              )}
            </div>
            
            {/* エラーメッセージ */}
            {error && (
              <p className="text-red-500 text-center mb-4">{error}</p>
            )}

            {/* 結果の表示 */}
            {result && maxCategory && (
              <>
                <h2 className="text-2xl font-bold text-center text-gray-800 mb-6 animate-fade-in">
                  あなたは{maxCategory.value}%の確率で{maxCategory.name}です。
                </h2>
                
                {/* 円グラフ */}
                <div className="bg-gray-50 rounded-lg p-4 shadow-inner mb-8">
                  <ResponsiveContainer width="100%" height={300}>
                    <PieChart>
                      <Pie
                        data={result}
                        cx="50%"
                        cy="50%"
                        labelLine={false}
                        outerRadius={100}
                        fill="#8884d8"
                        dataKey="value"
                        label={({ name, value }) => `${name} ${value}%`}
                      >
                        {result.map((_, index: number) => (
                          <Cell key={`cell-${index}`} fill={COLOR_LIST[index % COLOR_LIST.length]} />
                        ))}
                      </Pie>
                      <Legend 
                        formatter={(value, _, index) => <span style={{ color: COLOR_LIST[index % COLOR_LIST.length], fontWeight: 'bold' }}>{value}</span>}
                      />
                    </PieChart>
                  </ResponsiveContainer>
                </div>
              </>
            )}
          </div>
          
          {/* フッターの追加 */}
          <footer className="mt-12 pt-4 border-t border-gray-200">
            <p className="text-sm text-gray-600 mb-2">
              <strong>免責事項:</strong>本サービスの判定結果について、いかなる責任も負いかねます。
            </p>
            <p className="text-sm text-gray-600 mb-4">
              <strong>禁止事項:</strong>公序良俗に反する画像のアップロードは固く禁じます。
            </p>
            {/* 著作権表示の追加 */}
            <p className="text-xs text-gray-500 text-center">
              © Copyright 2024 by ドドテクノ
            </p>
          </footer>
        </div>
      </div>
    </div>
  )
}

 

page/api/predicts.ts
import { NextApiRequest, NextApiResponse } from 'next'
import * as tf from '@tensorflow/tfjs-node'
import { IncomingForm, Fields, Files } from 'formidable'
import fs from 'fs'
import sharp from 'sharp'

const ANIMAL_LIST = ['猫', '犬', '蛙', '猿', '豚']

export const config = {
  api: {
    bodyParser: false,
  },
}

export default async function handler(req: NextApiRequest, res: NextApiResponse) {
  if (req.method === 'POST') {
    try {
      const { files } = await parseForm(req)
      if (!files.file) {
        throw new Error('No file uploaded')
      }

      const file = Array.isArray(files.file) ? files.file[0] : files.file
      console.log('File received:', file.originalFilename, 'Size:', file.size)

      const model = await tf.loadGraphModel('file://model/tfjs_model/model.json')
      const imageTensor = await preprocessImage(file.filepath)
      const predictions = model.predict(imageTensor) as tf.Tensor

      const probabilities = tf.softmax(predictions as tf.Tensor1D)
      const probabilitiesData = await probabilities.data()

      const result = ANIMAL_LIST.map((animal, index) => ({
        name: animal,
        probability: probabilitiesData[index]
      }))

      res.status(200).json(result)
    } catch (error) {
      console.error('Prediction error:', error)
      res.status(500).json({ error: '予測中にエラーが発生しました' })
    }
  } else {
    res.setHeader('Allow', ['POST'])
    res.status(405).end(`Method ${req.method} Not Allowed`)
  }
}

function parseForm(req: NextApiRequest): Promise<{ fields: Fields; files: Files }> {
  return new Promise((resolve, reject) => {
    const form = new IncomingForm({
      maxFileSize: 10 * 1024 * 1024, // 10MB limit
    })
    form.parse(req, (err, fields, files) => {
      if (err) reject(err)
      else resolve({ fields, files })
    })
  })
}

async function preprocessImage(filePath: string): Promise<tf.Tensor> {
  const imageBuffer = await fs.promises.readFile(filePath)
  const processedImageBuffer = await sharp(imageBuffer)
    .resize(256, 256, { fit: 'cover' })
    .toFormat('png')
    .toBuffer()

  const decodedImage = tf.node.decodeImage(processedImageBuffer, 3)
  const height = decodedImage.shape[0]
  const width = decodedImage.shape[1]
  const cropSize = 224
  const startHeight = Math.floor((height - cropSize) / 2)
  const startWidth = Math.floor((width - cropSize) / 2)
  const croppedImage = tf.slice(decodedImage, [startHeight, startWidth, 0], [cropSize, cropSize, 3])

  const expandedImage = croppedImage.expandDims(0)
  const normalizedImage = expandedImage.div(255.0)

  const mean = tf.tensor([0.485, 0.456, 0.406])
  const std = tf.tensor([0.229, 0.224, 0.225])
  const normalizedImageWithStats = normalizedImage.sub(mean).div(std)

  return normalizedImageWithStats.transpose([0, 3, 1, 2])
}
next.config.js
/** @type {import('next').NextConfig} */
const nextConfig = {
  reactStrictMode: true,
  webpack: (config, { isServer }) => {
    if (!isServer) {
      config.resolve.fallback = {
        ...config.resolve.fallback,
        fs: false,
        path: false,
        os: false,
      };
    }

    if (isServer) {
      config.externals = [...config.externals, 'sharp'];
    }

    return config;
  },
}

module.exports = nextConfig

 

そしてクライアントサイドへ

概要

試行錯誤の結果、モデルを使用した予測処理をクライアントサイドで実行する方針に変更しました。

理由は、TensorFlow.jsがCDNで利用でき、ライブラリをインストールする必要がないからです。

ただし、クライアントサイドにモデルを配置することはセキュリティ上のリスクがあります。(この程度のアプリのモデルであれば公開されても問題は少ないのですが、一応見られたくない前提で進めます。)

そこで、モデルは完全にpublicにはせず、サーバーサイドに配置し、クライアントサイドからAPI経由でアクセスできるように設定しました。このAPIは、環境変数に格納されたアクセスキーによって保護されており、クライアントはリクエスト時にアクセストークンを付与し、サーバー側で認証することで未承認のアクセスを防止しています。

(実は、この方法でも見破られるリスクは残りますが、「やらないよりはマシ」ということでこの設定を採用しています。)

なお、前章でインストールしたライブラリは不要になったため、npm uninstallコマンドでアンインストールし、「npm install」を実行してpackage.jsonも更新してください。

npm uninstall @tensorflow/tfjs-node
npm uninstall formidable
npm uninstall sharp
npm install

 

また、上述したアクセスキーとして、Vercelの環境変数に「API_KEY」という名前で変数を設定します(値は任意)。ローカル環境では、.env.localファイルを作成し、API_KEY=(任意の値)を記載してください。

結果

今回、無事にデプロイに成功しました。

ただ、クライアントサイドにモデルをロードするために数秒を要するため、画面表示後にすぐに画像を貼り付けると「モデル読み込み中」とのエラーが表示されます。(余裕があればそのうちモデル読み込み中のインジケーターでもつけます。)

デプロイしたアプリのリンクを再掲します。

ソースコード

完成版のソースコードを掲載します。

今回もCursor Coposerを存分に使用してソースを出力しました。

components/animal-face-detector.tsx

今回はこのファイルで画面のUIからモデルを使っての判定まで行っています。

なお、クライアントに処理を移行したら画像のPNG変換がいつの間にか消えてました。ただそのままwebp形式のファイルをアップしてもエラーになりません。結果オーライですが、こういう事がコード自動生成の気持ち悪いところでもあります。

また、今回はVerselで型指定でAnyを使用しているとデプロイ時にエラーになるため、冒頭にESLintを無視する設定を追加しました。

/* eslint-disable @typescript-eslint/no-explicit-any */
'use client'

import { useState, useCallback, useEffect } from 'react'
import { useDropzone } from 'react-dropzone'
import { PieChart, Pie, Cell, Legend, ResponsiveContainer } from 'recharts'
import Image from 'next/image'

// グローバル変数 tf の型を定義
declare global {
  const tf: any;
}

// 定数の定義
const ANIMAL_LIST = ['猫', '犬', '蛙', '猿', '豚']
const COLOR_LIST = ['#FF4136', '#2ECC40', '#0074D9', '#FF851B', '#B10DC9']

export function AnimalFaceDetector() {
  // 状態の定義
  const [result, setResult] = useState<{ name: string; value: number; }[] | null>(null)
  const [image, setImage] = useState<string | null>(null)
  const [error, setError] = useState<string | null>(null)
  const [model, setModel] = useState<any | null>(null)
  const [tfLoaded, setTfLoaded] = useState(false)

  useEffect(() => {
    // TensorFlow.jsを動的に読み込む
    const script = document.createElement('script')
    script.src = 'https://cdn.jsdelivr.net/npm/@tensorflow/tfjs'
    script.async = true
    script.onload = () => {
      console.log('TensorFlow.js loaded')
      setTfLoaded(true)
    }
    document.body.appendChild(script)

    return () => {
      document.body.removeChild(script)
    }
  }, [])

  useEffect(() => {
    if (tfLoaded) {
      loadModel()
    }
  }, [tfLoaded])

  
  async function loadModel() {
    try {
      console.log('Loading model...')
      const loadedModel = await tf.loadGraphModel('/api/getModel', {
        fetchFunc: async (path: string, init?: RequestInit) => {
          const headers = new Headers(init?.headers);
          headers.append('Authorization', `Bearer ${process.env.API_KEY}`);
          const response = await fetch(path, { ...init, headers });
          if (!response.ok) {
            console.error(`HTTP error! status: ${response.status}`, await response.text())  // デバッグ用ログ
            throw new Error(`HTTP error! status: ${response.status}`)
          }
          return response
        }
      })
      console.log('Model loaded')
      setModel(loadedModel)
    } catch (error) {
      console.error('Error loading model:', error)
      setError('モデルの読み込み中にエラーが発生しました')
    }
  }

  // 画像分析関数
  const analyzeImage = useCallback(async (file: File) => {
    if (!model) {
      console.error('モデルがまだ読み込まれていません');
      setError('モデルがまだ読み込まれていません。しばらく待ってから再試行してください。');
      return;
    }

    try {
      const img = await preprocessImage(file)
      const predictions = model.predict(img) as any
      const probabilities = tf.softmax(predictions)
      const probabilitiesData = await probabilities.data()

      const result = ANIMAL_LIST.map((animal, index) => ({
        name: animal,
        value: Math.round(probabilitiesData[index] * 100)
      }))

      setResult(result)
      setError(null)
    } catch (error: unknown) {
      console.error('画像分析エラー:', error);
      if (error instanceof Error) {
        setError(`画像の分析中にエラーが発生しました: ${error.message}`);
      } else {
        setError('画像の分析中に予期せぬエラーが発生しました。');
      }
    }
  }, [model])

  const preprocessImage = async (file: File): Promise<any> => {
    return new Promise((resolve, reject) => {
      const reader = new FileReader()
      reader.onload = async (e) => {
        if (e.target && e.target.result) {
          const img = new (window.Image || Image)() as HTMLImageElement;          
          img.onload = async () => {
            try {
              // Resize to 256x256
              const canvas = document.createElement('canvas')
              canvas.width = 256
              canvas.height = 256
              const ctx = canvas.getContext('2d')
              if (!ctx) throw new Error('Failed to get canvas context')
              ctx.drawImage(img, 0, 0, 256, 256)

              // Convert to tensor
              let tensor = tf.browser.fromPixels(canvas, 3)

              // Center crop to 224x224
              const cropSize = 224
              const startHeight = Math.floor((256 - cropSize) / 2)
              const startWidth = Math.floor((256 - cropSize) / 2)
              tensor = tf.slice(tensor, [startHeight, startWidth, 0], [cropSize, cropSize, 3])

              // Expand dimensions and normalize
              tensor = tensor.expandDims(0).toFloat().div(255.0)

              // Normalize with mean and std
              const mean = tf.tensor([0.485, 0.456, 0.406])
              const std = tf.tensor([0.229, 0.224, 0.225])
              tensor = tensor.sub(mean).div(std)

              // Transpose to match the model's expected input shape
              tensor = tensor.transpose([0, 3, 1, 2])

              resolve(tensor)
            } catch (error) {
              reject(error)
            }
          }
          img.onerror = (err: Event | string) => reject(err)
          img.src = e.target.result as string
        }
      }
      reader.onerror = (err: ProgressEvent<FileReader>) => reject(err)
      reader.readAsDataURL(file)
    })
  }

  // ドロップゾーンのコールバック関数
  const onDrop = useCallback((acceptedFiles: File[]) => {
    const file = acceptedFiles[0]
    if (file) {
      console.log('ファイルがドロップされました:', file.name)
      const reader = new FileReader()

      reader.onload = (e) => {
        if (e.target && e.target.result) {
          setImage(e.target.result as string)
          analyzeImage(file)
        }
      }

      reader.onerror = (e) => {
        console.error('ファイル読み込みエラー:', e)
        setError('ファイルの読み込み中にエラーが発生しました。')
      }

      reader.readAsDataURL(file)
    }
  }, [analyzeImage])

  // react-dropzoneのフック
  const { getRootProps, getInputProps } = useDropzone({ 
    onDrop,
    accept: {'image/*': []}
  })

  // 最大カテゴリ(最も高い確率の動物)を取得する関数
  const getMaxCategory = useCallback(() => {
    if (!result) return null
    return result.reduce((max, item) => item.value > max.value ? item : max)
  }, [result])

  const maxCategory = getMaxCategory()

  return (
    <div className="min-h-screen bg-gradient-to-br from-blue-100 to-purple-100 py-12 px-4 sm:px-6 lg:px-8">
      <div className="max-w-md mx-auto bg-white rounded-xl shadow-2xl overflow-hidden md:max-w-2xl">
        <div className="p-8">
          <h1 className="text-4xl font-bold text-center text-gray-800 mb-8">動物顔判定</h1>
          
          <div className="mt-8">
            {/* ドロップゾーン */}
            <div 
              {...getRootProps()} 
              className="mb-8 relative rounded-lg overflow-hidden shadow-lg cursor-pointer"
              style={{width: '100%', height: '250px'}}
            >
              <input {...getInputProps()} />
              {image ? (
                <Image src={image} alt="分析対象の顔" layout="fill" objectFit="cover" className="rounded-lg transition-transform duration-300 hover:scale-105" />
              ) : (
                <div className="flex items-center justify-center h-full bg-gray-100">
                  <p className="text-gray-500 text-lg font-semibold">ここをクリックまたはドラッグ&ドロップで画像をアップロード</p>
                </div>
              )}
            </div>
            
            {/* エラーメッセージ */}
            {error && (
              <p className="text-red-500 text-center mb-4">{error}</p>
            )}

            {/* 結果の表示 */}
            {result && maxCategory && (
              <>
                <h2 className="text-2xl font-bold text-center text-gray-800 mb-6 animate-fade-in">
                  あなたは{maxCategory.value}%の確率で{maxCategory.name}です。
                </h2>
                
                {/* 円グラフ */}
                <div className="bg-gray-50 rounded-lg p-4 shadow-inner mb-8">
                  <ResponsiveContainer width="100%" height={300}>
                    <PieChart>
                      <Pie
                        data={result}
                        cx="50%"
                        cy="50%"
                        labelLine={false}
                        outerRadius={100}
                        fill="#8884d8"
                        dataKey="value"
                        label={({ name, value }) => `${name} ${value}%`}
                      >
                        {result.map((_, index: number) => (
                          <Cell key={`cell-${index}`} fill={COLOR_LIST[index % COLOR_LIST.length]} />
                        ))}
                      </Pie>
                      <Legend 
                        formatter={(value, _, index) => <span style={{ color: COLOR_LIST[index % COLOR_LIST.length], fontWeight: 'bold' }}>{value}</span>}
                      />
                    </PieChart>
                  </ResponsiveContainer>
                </div>
              </>
            )}
          </div>
          
          {/* フッターの追 */}
          <footer className="mt-12 pt-4 border-t border-gray-200">
            <p className="text-sm text-gray-600 mb-2">
              <strong>免責事項:</strong>本サービスの判定結果について、いかなる責任も負いかねます。
            </p>
            <p className="text-sm text-gray-600 mb-4">
              <strong>禁止事項:</strong>公序良俗に反する画像のアップロードは固く禁じます。
            </p>
            {/* 著作権表示の追加 */}
            <p className="text-xs text-gray-500 text-center">
              © Copyright 2024 by ドドテクノ
            </p>
          </footer>
        </div>
      </div>
    </div>
  )
}
app/api/getModel/route.ts

モデルの構成情報が格納されたjsonファイルを読み込み関連する重み情報を読み込みます。(重み情報の読み込みはgetModelWeightに移譲)

import { NextRequest,NextResponse } from 'next/server'
import fs from 'fs'
import path from 'path'

export async function GET(req: NextRequest) {
  try {
    const authHeader = req.headers.get('Authorization');
    const token = authHeader && authHeader.split(' ')[1];
      
    if (token !== process.env.API_KEY) {
      return NextResponse.json({ error: 'Unauthorized' }, { status: 401 });
    }
    
    const modelPath = path.join(process.cwd(), 'model', 'tfjs_model', 'model.json')
    const modelJson = JSON.parse(fs.readFileSync(modelPath, 'utf-8'))

    // モデルのweightsPathを相対パスに変更
    modelJson.weightsManifest[0].paths = modelJson.weightsManifest[0].paths.map(
      (p: string) => `./getModelWeights?file=${p}`
    )

    //console.log('Model JSON:', JSON.stringify(modelJson, null, 2))  // デバッグ用ログ

    return NextResponse.json(modelJson)
  } catch (error: unknown) {
    console.error('Error loading model:', error)
    const errorMessage = error instanceof Error ? error.message : 'An unknown error occurred'
    return NextResponse.json({ error: 'Failed to load model', details: errorMessage }, { status: 500 })
  }
}
app/api/getModelWeight/route.ts

モデルの重み情報が格納されたbinファイルを読み込みます。

import { NextRequest, NextResponse } from 'next/server'
import fs from 'fs'
import path from 'path'

export async function GET(req: NextRequest) {

    const authHeader = req.headers.get('Authorization');
    const token = authHeader && authHeader.split(' ')[1];

    if (token !== process.env.API_KEY) {
        return NextResponse.json({ error: 'Unauthorized' }, { status: 401 });
    }
    const file = req.nextUrl.searchParams.get('file')
    if (!file) {
        return NextResponse.json({ error: 'No file specified' }, { status: 400 })
    }

    const filePath = path.join(process.cwd(), 'model', 'tfjs_model', file)
    
    if (!fs.existsSync(filePath)) {
        return NextResponse.json({ error: 'File not found' }, { status: 404 })
    }

    const fileBuffer = fs.readFileSync(filePath)
    
    return new NextResponse(fileBuffer, {
        headers: { 'Content-Type': 'application/octet-stream' }
    })
}
next.config.js

アクセスキーの設定とクライアントサイドでビルドするライブラリの制御を記述しています。

/** @type {import('next').NextConfig} */
const nextConfig = {
  reactStrictMode: true,
  env: {
    API_KEY: process.env.API_KEY,
  },
  webpack: (config, { isServer }) => {
    if (!isServer) {
      config.resolve.fallback = {
        ...config.resolve.fallback,
        fs: false,
        path: false,
        os: false,
      };
    }

    return config;
  },
};

module.exports = nextConfig;

課題

今回はデプロイまで成功しましたが、もし実運用を考えるのならば、以下の3つが課題として残ります。

変換したモデルの精度検証

前述の通り、TensorFlow.jsに対応するモデルへ変換した際、PyTorchで使用していた時と同様の結果が得られず、精度が悪化しました(正解ラベルの確率が低下)。

本来であれば、この点を追求すべきですが、今回はこのまま進めることにしますが、商用利用の場合、顧客への説明(説得?)が必要となるため、この精度低下の点は曖昧にできないでしょう。

セキュリティ対策

モデルはサーバーサイドに配置し、アクセスキーによる制御を行っていますが、わかる人ならば簡単にアクセスできる状態です。本来であれば、さらに堅牢なセキュリティ対策が必要です。

しかし、そもそもこれだけの労力をかけるのであれば、サーバーサイドで処理を行う方が適切です。

今回は「無料」にこだわりこの方針を選択しましたが、サーバー処理にCloud Functionなどを利用すれば、アクセスが少なければ”無料みたいな金額”で運用できるため、そちらの方が妥当でしょう。アクセスが増えるようなら、いっそのこと有料化して広告を付ければ利益が出るかもしれません。

コードの品質

これについては前回も触れましたが、ほぼ自動生成されたソースコードを使用しており、また私自身もNext.jsにそれほど詳しくないため、このコードが良いものかどうか判断がつかない状況です。

AIが生成したPythonのコードを見て「この書き方はダメ」と感じることが多々あるので、このNext.jsのコードも専門家が見れば「なんだ、この書き方は!」と思う箇所があるかもしれません。

なお、実際にpages/_doocument.tsxというファイルがいつの間にかできており、一見したところapp/layout.tsxと内容が被るので、最初は「必要ですか?」と聞くと「必要です」と返ってきましたが、「layout.tsxがあるのにpages/_doocument.tsxは必要ですか?」と聞くと「App Routerを使用している場合は必要ありません」と返ってきました。

全くNext.jsについて知らなければ、騙されるところでした。

もし実運用を行うのであれば、最終的にはNext.jsの有識者によるレビューが必要になるでしょう。

まとめ

本記事では、動物顔判定アプリをVercelにデプロイするため、PyTorchモデルをTensorFlow.jsに変換し、クライアントサイドで実行する方法に切り替えることで容量制限を回避しました。

TensorFlow.jsをCDNから呼び出す構成により、サーバーサイドなしでもVercelの無料枠内で動作するアプリを実現しています。

(結果的にはCDN対応で解決しましたが、事前にTensorFlow.jsのライブラリサイズを確認せず、「軽いだろう」と考えたのはダメですね・・)

今回の作業を通じて、AIを活用したコーディングの手軽さと効率性を実感しております。

ただ一方で、これが実際のシステム開発であった場合に自動生成されたコードを安定的に保守できるかという不安もあります。

ただしAI技術は日進月歩で進化し、明日には今の常識が覆るかもしれません。

今後も日々の情報をアップデートしながらAIとどのように共存し、効率よく活用していくかを模索し続けたいと思います。