Minor code cleanup.

This commit is contained in:
Nolan Darilek 2025-06-16 17:36:50 -04:00
parent 46bee6ee60
commit 2b1a470196

View file

@@ -8,9 +8,10 @@ use arrow::{
datatypes::{DataType, Field, Schema, TimeUnit}, datatypes::{DataType, Field, Schema, TimeUnit},
}; };
use futures::TryStreamExt; use futures::TryStreamExt;
use mysql_async::{Conn, Pool, params::Params, prelude::*}; use mysql_async::{Conn, Pool, Row, params::Params, prelude::*};
use parquet::arrow::AsyncArrowWriter; use parquet::arrow::AsyncArrowWriter;
use rayon::prelude::*; use rayon::prelude::*;
use tokio::fs::File;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
@@ -18,11 +19,9 @@ mod test;
type Error = Box<dyn std::error::Error + Send + Sync>; type Error = Box<dyn std::error::Error + Send + Sync>;
pub async fn load_data(pool: &Pool) -> Result<(), Error> { pub async fn load_data(pool: &Pool) -> Result<(), Error> {
// Obviously use something better/more robust here if you're a) loading
// other data sources and b) aren't so lucky as to have a single `;` after
// each block of SQL. :)
let data = String::from_utf8_lossy(include_bytes!("../data.sql")); let data = String::from_utf8_lossy(include_bytes!("../data.sql"));
for stmt in data.split(";") { // Obviously do something better if you don't control the input data.
for stmt in data.trim().split(";") {
if !stmt.trim().is_empty() { if !stmt.trim().is_empty() {
let mut conn = pool.get_conn().await?; let mut conn = pool.get_conn().await?;
conn.exec_drop(stmt, Params::Empty).await?; conn.exec_drop(stmt, Params::Empty).await?;
@@ -34,10 +33,10 @@ pub async fn load_data(pool: &Pool) -> Result<(), Error> {
pub async fn convert_data(pool: &Pool, table: &str) -> Result<(), Error> { pub async fn convert_data(pool: &Pool, table: &str) -> Result<(), Error> {
let mut conn = pool.get_conn().await?; let mut conn = pool.get_conn().await?;
let schema = discover_schema(&mut conn, table).await?; let schema = discover_schema(&mut conn, table).await?;
let file = tokio::fs::File::create(format!("{table}.parquet")).await?; let file = File::create(format!("{table}.parquet")).await?;
let mut writer = AsyncArrowWriter::try_new(file, schema.clone(), Default::default())?; let mut writer = AsyncArrowWriter::try_new(file, schema.clone(), Default::default())?;
const BATCH_SIZE: usize = 50000; const BATCH_SIZE: usize = 50000;
let mut rows = Vec::with_capacity(BATCH_SIZE); // Pre-allocate for efficiency let mut rows = Vec::with_capacity(BATCH_SIZE);
let mut stream = conn let mut stream = conn
.exec_stream(format!("select * from {table}"), Params::Empty) .exec_stream(format!("select * from {table}"), Params::Empty)
.await?; .await?;
@@ -55,8 +54,8 @@ pub async fn convert_data(pool: &Pool, table: &str) -> Result<(), Error> {
} }
async fn discover_schema(conn: &mut Conn, table: &str) -> Result<Arc<Schema>, Error> { async fn discover_schema(conn: &mut Conn, table: &str) -> Result<Arc<Schema>, Error> {
let query = format!("DESCRIBE {}", table); let query = format!("describe {}", table);
let rows: Vec<mysql_async::Row> = conn.exec(query, ()).await?; let rows: Vec<Row> = conn.exec(query, ()).await?;
let mut fields = Vec::new(); let mut fields = Vec::new();
for row in rows { for row in rows {
let name: String = row.get("Field").unwrap(); let name: String = row.get("Field").unwrap();
@@ -77,72 +76,65 @@ async fn discover_schema(conn: &mut Conn, table: &str) -> Result<Arc<Schema>, Er
} }
async fn process_parquet( async fn process_parquet(
rows: &mut Vec<mysql_async::Row>, rows: &mut Vec<Row>,
schema: Arc<Schema>, schema: Arc<Schema>,
writer: &mut AsyncArrowWriter<tokio::fs::File>, writer: &mut AsyncArrowWriter<File>,
) -> Result<(), Error> { ) -> Result<(), Error> {
if rows.is_empty() { if rows.is_empty() {
return Ok(()); return Ok(());
} }
let epoch = mysql_common::chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); let epoch = mysql_common::chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
let mut column_data: Vec<Box<dyn Array>> = Vec::new(); let mut data: Vec<Box<dyn Array>> = Vec::new();
for field in schema.fields() { for field in schema.fields() {
let name = field.name(); let name = field.name().as_str();
println!("Importing {name}");
match field.data_type() { match field.data_type() {
DataType::Int32 => { DataType::Int32 => {
let values: Vec<Option<i32>> = rows let values: Vec<Option<i32>> =
.par_iter() rows.par_iter().map(|row| row.get::<i32, _>(name)).collect();
.map(|row| row.get::<i32, _>(name.as_str())) data.push(Box::new(Int32Array::from(values)));
.collect();
column_data.push(Box::new(Int32Array::from(values)));
} }
DataType::Int64 => { DataType::Int64 => {
let values: Vec<Option<i64>> = rows let values: Vec<Option<i64>> =
.par_iter() rows.par_iter().map(|row| row.get::<i64, _>(name)).collect();
.map(|row| row.get::<i64, _>(name.as_str())) data.push(Box::new(Int64Array::from(values)));
.collect();
column_data.push(Box::new(Int64Array::from(values)));
} }
DataType::Utf8 => { DataType::Utf8 => {
let values: Vec<Option<String>> = rows let values: Vec<Option<String>> = rows
.par_iter() .par_iter()
.map(|row| row.get::<String, _>(name.as_str())) .map(|row| row.get::<String, _>(name))
.collect(); .collect();
column_data.push(Box::new(StringArray::from(values))); data.push(Box::new(StringArray::from(values)));
} }
DataType::Timestamp(TimeUnit::Microsecond, None) => { DataType::Timestamp(TimeUnit::Microsecond, None) => {
let values: Vec<Option<i64>> = rows let values: Vec<Option<i64>> = rows
.par_iter() .par_iter()
.map(|row| { .map(|row| {
// Convert MySQL datetime to microseconds since epoch row.get::<mysql_common::chrono::NaiveDateTime, _>(name)
row.get::<mysql_common::chrono::NaiveDateTime, _>(name.as_str())
.map(|dt| dt.and_utc().timestamp_micros()) .map(|dt| dt.and_utc().timestamp_micros())
}) })
.collect(); .collect();
column_data.push(Box::new(TimestampMicrosecondArray::from(values))); data.push(Box::new(TimestampMicrosecondArray::from(values)));
} }
DataType::Date32 => { DataType::Date32 => {
let values: Vec<Option<i32>> = rows let values: Vec<Option<i32>> = rows
.par_iter() .par_iter()
.map(|row| { .map(|row| {
let date = row.get::<mysql_common::chrono::NaiveDate, _>(name.as_str()); let date = row.get::<mysql_common::chrono::NaiveDate, _>(name);
Some((date.unwrap() - epoch).num_days() as i32) Some((date.unwrap() - epoch).num_days() as i32)
}) })
.collect(); .collect();
column_data.push(Box::new(Date32Array::from(values))); data.push(Box::new(Date32Array::from(values)));
} }
_ => { _ => {
// Fallback to string for unknown types
let values: Vec<Option<String>> = rows let values: Vec<Option<String>> = rows
.par_iter() .par_iter()
.map(|row| row.get::<String, _>(name.as_str())) .map(|row| row.get::<String, _>(name))
.collect(); .collect();
column_data.push(Box::new(StringArray::from(values))); data.push(Box::new(StringArray::from(values)));
} }
} }
} }
let columns: Vec<Arc<dyn Array>> = column_data.into_iter().map(|arr| arr.into()).collect(); let columns: Vec<Arc<dyn Array>> = data.into_iter().map(|arr| arr.into()).collect();
let batch = RecordBatch::try_new(schema.clone(), columns)?; let batch = RecordBatch::try_new(schema.clone(), columns)?;
writer.write(&batch).await?; writer.write(&batch).await?;
rows.clear(); rows.clear();