PyMySql を使って pandsa の DataFrame をまとめてDBに挿入する

IT

はじめに

本記事では PyMySql を使って pandsa の DataFrame をまとめてDBに挿入する方法についてまとめます。

環境

以下が今回の環境です。

$ python -V
Python 3.9.16
$ pip list | grep -e PyMySQL -e pandas
pandas                1.5.2
PyMySQL               1.0.2

$ mysql --version
mysql  Ver 8.0.29 for Linux on x86_64

サンプルコード

以下サンプルコードです。executemanyメソッドを使用しています。
ポイントとしては

  • SQLではプライマリキーが同一のレコードが既に存在するレコードを挿入する場合にエラーが発生しないようにINSERT ... ON DUPLICATE KEY UPDATEを使用している。
    こちらは必要に応じてON DUPLICATE KEY UPDATEを削除する必要がある。
  • まとめて挿入した時に1レコードでも不正なレコードが存在するとエラーになるが、どのレコードが問題でエラーになっているかわからないので、
    エラー発生後に1レコードずつ再度挿入をリトライして、該当レコードを明らかにしている。
  • 上記の1レコード挿入する場合は各挿入ごとにコミットは行わず最後にコミットしているため、DataFrameの一部のみ保存されることはない(全て保存するか全て保存しない)
import os

from dotenv import load_dotenv
import pandas as pd
from pandas.core.frame import DataFrame
import pymysql

# load environment variables
load_dotenv()

# connection parameters
host = os.environ.get('host')
user = os.environ.get('user')
password = os.environ.get('password')
database = os.environ.get('database')
port = os.environ.get('port')

# connect to database
connection = pymysql.connect(
    host=host,
    user=user,
    password=password,
    db=database,
)


def insert_multiple_rows(df: DataFrame, table: str) -> None:
    def executemany_sql(sql: str, records: list[list], commit: bool = True):
        with connection.cursor() as cursor:
            print(sql)
            cursor.executemany(sql, records)
            if commit:
                connection.commit()
    columns = ','.join(df.columns)
    values=','.join(['%s' for i in range(len(df.columns))])
    update_condition = ",".join([f"{column}=VALUES({column})" for column in df.columns])
    sql = f"INSERT INTO {table} ({columns:}) VALUES ({values:}) ON DUPLICATE KEY UPDATE {update_condition};"
    print(sql)
    try:
        executemany_sql(sql, df.values.tolist())
    except Exception as e:
        print(f"An Exception occured during batch insert: {e}")
        for _, row in df.iterrows():
            try:
                executemany_sql(sql, [row.tolist()], False)
            except Exception as e:
                print(f"An Exception occured during sigle insert: {e}")
                print(f"record info: {row}")
                raise
        connection.commit()


# insert to table from dataframe
test_df = pd.read.csv('./test.csv')
insert_multiple_rows(test_df, 'test_table')

おわりに

本記事では PyMySql を使って pandsa の DataFrame をまとめてDBに挿入する方法についてまとめます。
この記事がどなたかの参考になれば幸いです。

参考

コメント