Minor code cleanup.

This commit is contained in:
Nolan Darilek 2025-06-16 17:36:50 -04:00
parent 46bee6ee60
commit 2b1a470196

View file

@ -8,9 +8,10 @@ use arrow::{
datatypes::{DataType, Field, Schema, TimeUnit},
};
use futures::TryStreamExt;
use mysql_async::{Conn, Pool, params::Params, prelude::*};
use mysql_async::{Conn, Pool, Row, params::Params, prelude::*};
use parquet::arrow::AsyncArrowWriter;
use rayon::prelude::*;
use tokio::fs::File;
#[cfg(test)]
mod test;
@ -18,11 +19,9 @@ mod test;
type Error = Box<dyn std::error::Error + Send + Sync>;
pub async fn load_data(pool: &Pool) -> Result<(), Error> {
// Obviously use something better/more robust here if you're a) loading
// other data sources and b) aren't so lucky as to have a single `;` after
// each block of SQL. :)
let data = String::from_utf8_lossy(include_bytes!("../data.sql"));
for stmt in data.split(";") {
// Obviously do something better if you don't control the input data.
for stmt in data.trim().split(";") {
if !stmt.trim().is_empty() {
let mut conn = pool.get_conn().await?;
conn.exec_drop(stmt, Params::Empty).await?;
@ -34,10 +33,10 @@ pub async fn load_data(pool: &Pool) -> Result<(), Error> {
pub async fn convert_data(pool: &Pool, table: &str) -> Result<(), Error> {
let mut conn = pool.get_conn().await?;
let schema = discover_schema(&mut conn, table).await?;
let file = tokio::fs::File::create(format!("{table}.parquet")).await?;
let file = File::create(format!("{table}.parquet")).await?;
let mut writer = AsyncArrowWriter::try_new(file, schema.clone(), Default::default())?;
const BATCH_SIZE: usize = 50000;
let mut rows = Vec::with_capacity(BATCH_SIZE); // Pre-allocate for efficiency
let mut rows = Vec::with_capacity(BATCH_SIZE);
let mut stream = conn
.exec_stream(format!("select * from {table}"), Params::Empty)
.await?;
@ -55,8 +54,8 @@ pub async fn convert_data(pool: &Pool, table: &str) -> Result<(), Error> {
}
async fn discover_schema(conn: &mut Conn, table: &str) -> Result<Arc<Schema>, Error> {
let query = format!("DESCRIBE {}", table);
let rows: Vec<mysql_async::Row> = conn.exec(query, ()).await?;
let query = format!("describe {}", table);
let rows: Vec<Row> = conn.exec(query, ()).await?;
let mut fields = Vec::new();
for row in rows {
let name: String = row.get("Field").unwrap();
@ -77,72 +76,65 @@ async fn discover_schema(conn: &mut Conn, table: &str) -> Result<Arc<Schema>, Er
}
async fn process_parquet(
rows: &mut Vec<mysql_async::Row>,
rows: &mut Vec<Row>,
schema: Arc<Schema>,
writer: &mut AsyncArrowWriter<tokio::fs::File>,
writer: &mut AsyncArrowWriter<File>,
) -> Result<(), Error> {
if rows.is_empty() {
return Ok(());
}
let epoch = mysql_common::chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
let mut column_data: Vec<Box<dyn Array>> = Vec::new();
let mut data: Vec<Box<dyn Array>> = Vec::new();
for field in schema.fields() {
let name = field.name();
println!("Importing {name}");
let name = field.name().as_str();
match field.data_type() {
DataType::Int32 => {
let values: Vec<Option<i32>> = rows
.par_iter()
.map(|row| row.get::<i32, _>(name.as_str()))
.collect();
column_data.push(Box::new(Int32Array::from(values)));
let values: Vec<Option<i32>> =
rows.par_iter().map(|row| row.get::<i32, _>(name)).collect();
data.push(Box::new(Int32Array::from(values)));
}
DataType::Int64 => {
let values: Vec<Option<i64>> = rows
.par_iter()
.map(|row| row.get::<i64, _>(name.as_str()))
.collect();
column_data.push(Box::new(Int64Array::from(values)));
let values: Vec<Option<i64>> =
rows.par_iter().map(|row| row.get::<i64, _>(name)).collect();
data.push(Box::new(Int64Array::from(values)));
}
DataType::Utf8 => {
let values: Vec<Option<String>> = rows
.par_iter()
.map(|row| row.get::<String, _>(name.as_str()))
.map(|row| row.get::<String, _>(name))
.collect();
column_data.push(Box::new(StringArray::from(values)));
data.push(Box::new(StringArray::from(values)));
}
DataType::Timestamp(TimeUnit::Microsecond, None) => {
let values: Vec<Option<i64>> = rows
.par_iter()
.map(|row| {
// Convert MySQL datetime to microseconds since epoch
row.get::<mysql_common::chrono::NaiveDateTime, _>(name.as_str())
row.get::<mysql_common::chrono::NaiveDateTime, _>(name)
.map(|dt| dt.and_utc().timestamp_micros())
})
.collect();
column_data.push(Box::new(TimestampMicrosecondArray::from(values)));
data.push(Box::new(TimestampMicrosecondArray::from(values)));
}
DataType::Date32 => {
let values: Vec<Option<i32>> = rows
.par_iter()
.map(|row| {
let date = row.get::<mysql_common::chrono::NaiveDate, _>(name.as_str());
let date = row.get::<mysql_common::chrono::NaiveDate, _>(name);
Some((date.unwrap() - epoch).num_days() as i32)
})
.collect();
column_data.push(Box::new(Date32Array::from(values)));
data.push(Box::new(Date32Array::from(values)));
}
_ => {
// Fallback to string for unknown types
let values: Vec<Option<String>> = rows
.par_iter()
.map(|row| row.get::<String, _>(name.as_str()))
.map(|row| row.get::<String, _>(name))
.collect();
column_data.push(Box::new(StringArray::from(values)));
data.push(Box::new(StringArray::from(values)));
}
}
}
let columns: Vec<Arc<dyn Array>> = column_data.into_iter().map(|arr| arr.into()).collect();
let columns: Vec<Arc<dyn Array>> = data.into_iter().map(|arr| arr.into()).collect();
let batch = RecordBatch::try_new(schema.clone(), columns)?;
writer.write(&batch).await?;
rows.clear();