Skip to content

Commit

Permalink
Enforce stricter types for formatters in MappedMatrixVis
Browse files Browse the repository at this point in the history
  • Loading branch information
axelboc committed Sep 3, 2024
1 parent cb20c1b commit 8746fe1
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import visualizerStyles from '../../../visualizer/Visualizer.module.css';
import { useMappedArray, useSlicedDimsAndMapping } from '../hooks';
import type { MatrixVisConfig } from '../matrix/config';
import MatrixToolbar from '../matrix/MatrixToolbar';
import { getCellWidth, getFormatter } from '../matrix/utils';
import { getCellFormatter, getCellWidth } from '../matrix/utils';
import { getSliceSelection } from '../utils';

interface Props {
Expand Down Expand Up @@ -46,7 +46,7 @@ function MappedCompoundVis(props: Props) {
);

const fieldFormatters = Object.values(fields).map((field) =>
getFormatter(field, notation),
getCellFormatter(mappedArray, field, notation),
);

const { getExportURL } = useDataContext();
Expand All @@ -70,9 +70,7 @@ function MappedCompoundVis(props: Props) {
<MatrixVis
className={visualizerStyles.vis}
dims={mappedArray.shape}
cellFormatter={(row, col) =>
fieldFormatters[col](mappedArray.get(row, col))
}
cellFormatter={(row, col) => fieldFormatters[col](row, col)}
cellWidth={customCellWidth ?? cellWidth}
columnHeaders={fieldNames}
sticky={sticky}
Expand Down
8 changes: 3 additions & 5 deletions packages/app/src/vis-packs/core/matrix/MappedMatrixVis.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import { useMappedArray, useSlicedDimsAndMapping } from '../hooks';
import { getSliceSelection } from '../utils';
import type { MatrixVisConfig } from './config';
import MatrixToolbar from './MatrixToolbar';
import { getCellWidth, getFormatter } from './utils';
import { getCellFormatter, getCellWidth } from './utils';

interface Props {
dataset: Dataset<ArrayShape, PrintableType>;
Expand All @@ -32,7 +32,7 @@ function MappedMatrixVis(props: Props) {
const [slicedDims, slicedMapping] = useSlicedDimsAndMapping(dims, dimMapping);
const mappedArray = useMappedArray(value, slicedDims, slicedMapping);

const formatter = getFormatter(type, notation);
const cellFormatter = getCellFormatter(mappedArray, type, notation);
const cellWidth = getCellWidth(type);

const { getExportURL } = useDataContext();
Expand All @@ -57,9 +57,7 @@ function MappedMatrixVis(props: Props) {
<MatrixVis
className={visualizerStyles.vis}
dims={mappedArray.shape}
cellFormatter={(row: number, col: number) =>
formatter(mappedArray.get(row, col))
}
cellFormatter={cellFormatter}
cellWidth={customCellWidth ?? cellWidth}
sticky={sticky}
/>
Expand Down
28 changes: 20 additions & 8 deletions packages/app/src/vis-packs/core/matrix/utils.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import { Notation } from '@h5web/lib';
import {
assertNdArrayValue,
isBoolType,
isComplexType,
isEnumType,
isNumericType,
} from '@h5web/shared/guards';
import type {
BooleanType,
ArrayValue,
ComplexType,
NumericType,
PrintableCompoundType,
Expand All @@ -20,6 +21,7 @@ import {
formatBool,
} from '@h5web/shared/vis-utils';
import { format } from 'd3-format';
import type { NdArray } from 'ndarray';

export function createNumericFormatter(
notation: Notation,
Expand All @@ -45,27 +47,37 @@ export function createMatrixComplexFormatter(
return createComplexFormatter(formatStr, true);
}

export function getFormatter(
export function getCellFormatter(
dataArray: NdArray<ArrayValue<PrintableType>>,
type: PrintableType,
notation: Notation,
): ValueFormatter<PrintableType> {
): (row: number, col: number) => string {
if (isComplexType(type)) {
return createMatrixComplexFormatter(notation);
assertNdArrayValue(type, dataArray);
const formatter = createMatrixComplexFormatter(notation);
return (row, col) => formatter(dataArray.get(row, col));
}

if (isNumericType(type)) {
return createNumericFormatter(notation);
assertNdArrayValue(type, dataArray);
const formatter = createNumericFormatter(notation);
return (row, col) => formatter(dataArray.get(row, col));
}

if (isBoolType(type)) {
return formatBool as ValueFormatter<BooleanType>;
assertNdArrayValue(type, dataArray);
return (row, col) => formatBool(dataArray.get(row, col));
}

if (isEnumType(type)) {
return createEnumFormatter(type.mapping);
assertNdArrayValue(type, dataArray);
const formatter = createEnumFormatter(type.mapping);
return (row, col) => formatter(dataArray.get(row, col));
}

return (val) => (val as string).toString(); // call `toString()` for safety, in case type cast is wrong
// `StringType`
assertNdArrayValue(type, dataArray);
return (row, col) => dataArray.get(row, col);
}

export function getCellWidth(
Expand Down
10 changes: 10 additions & 0 deletions packages/shared/src/guards.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type { Data, NdArray, TypedArray } from 'ndarray';

import type {
ArrayShape,
ArrayValue,
BooleanType,
ComplexArray,
ComplexType,
Expand Down Expand Up @@ -438,6 +439,15 @@ function assertPrimitiveValue<T extends DType>(
}
}

export function assertNdArrayValue<T extends DType>(
type: T,
value: NdArray<unknown[] | TypedArray>,
): asserts value is NdArray<ArrayValue<T>> {
if (value.size > 0) {
assertPrimitiveValue(type, value.get(0));
}
}

export function assertDatasetValue<D extends Dataset<ScalarShape | ArrayShape>>(
value: unknown,
dataset: D,
Expand Down

0 comments on commit 8746fe1

Please sign in to comment.