Skip to content

Commit

Permalink
Fix Immer type inference for setState (#2696)
Browse files Browse the repository at this point in the history
* fix(immer): tweak type inference to base `setState` type off of store `setState` instead of `getState`

* fix(immer): instead, infer type directly from StoreApi<T>["setState"]

* fix(immer): instead of using `StoreApi`, extract from A2 the non-functional component of state

* docs: add comment describing why it is not derived from `A1`

* test: add example middleware that modifies getState w/o setState

* fix: add assertion for inner `set` and `get` types

---------

Co-authored-by: Daishi Kato <dai-shi@users.noreply.github.com>
  • Loading branch information
chrisvander and dai-shi committed Aug 27, 2024
1 parent 42bbfcf commit d7345da
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 5 deletions.
15 changes: 12 additions & 3 deletions src/middleware/immer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,33 @@ type SkipTwo<T> = T extends { length: 0 }
? A
: never

type SetStateType<T extends unknown[]> = Exclude<T[0], (...args: any[]) => any>

type WithImmer<S> = Write<S, StoreImmer<S>>

type StoreImmer<S> = S extends {
getState: () => infer T
setState: infer SetState
}
? SetState extends {
(...a: infer A1): infer Sr1
(...a: infer A2): infer Sr2
}
? {
// Ideally, we would want to infer the `nextStateOrUpdater` `T` type from the
// `A1` type, but this is infeasible since it is an intersection with
// a partial type.
setState(
nextStateOrUpdater: T | Partial<T> | ((state: Draft<T>) => void),
nextStateOrUpdater:
| SetStateType<A2>
| Partial<SetStateType<A2>>
| ((state: Draft<SetStateType<A2>>) => void),
shouldReplace?: false,
...a: SkipTwo<A1>
): Sr1
setState(
nextStateOrUpdater: T | ((state: Draft<T>) => void),
nextStateOrUpdater:
| SetStateType<A2>
| ((state: Draft<SetStateType<A2>>) => void),
shouldReplace: true,
...a: SkipTwo<A2>
): Sr2
Expand Down
58 changes: 56 additions & 2 deletions tests/middlewareTypes.test.tsx
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
/* eslint @typescript-eslint/no-unused-expressions: off */ // FIXME
/* eslint react-compiler/react-compiler: off */

import { describe, expect, it } from 'vitest'
import { describe, expect, expectTypeOf, it } from 'vitest'
import { create } from 'zustand'
import type { StoreApi } from 'zustand'
import type { StateCreator, StoreApi, StoreMutatorIdentifier } from 'zustand'
import {
combine,
devtools,
Expand All @@ -19,6 +19,27 @@ type CounterState = {
inc: () => void
}

type ExampleStateCreator<T, A> = <
Mps extends [StoreMutatorIdentifier, unknown][] = [],
Mcs extends [StoreMutatorIdentifier, unknown][] = [],
U = T,
>(
f: StateCreator<T, [...Mps, ['org/example', A]], Mcs>,
) => StateCreator<T, Mps, [['org/example', A], ...Mcs], U & A>

type Write<T, U> = Omit<T, keyof U> & U
type StoreModifyAllButSetState<S, A> = S extends {
getState: () => infer T
}
? Omit<StoreApi<T & A>, 'setState'>
: never

declare module 'zustand/vanilla' {
interface StoreMutators<S, A> {
'org/example': Write<S, StoreModifyAllButSetState<S, A>>
}
}

describe('counter state spec (no middleware)', () => {
it('no middleware', () => {
const useBoundStore = create<CounterState>((set, get) => ({
Expand Down Expand Up @@ -64,6 +85,39 @@ describe('counter state spec (single middleware)', () => {
immer(() => ({ count: 0 })),
)
expect(testSubtyping).toBeDefined()

const exampleMiddleware = ((initializer) =>
initializer) as ExampleStateCreator<CounterState, { additional: number }>

const testDerivedSetStateType = create<CounterState>()(
exampleMiddleware(
immer((set, get) => ({
count: 0,
inc: () =>
set((state) => {
state.count = get().count + 1
type OmitFn<T> = Exclude<T, (...args: any[]) => any>
expectTypeOf<
OmitFn<Parameters<typeof set>[0]>
>().not.toMatchTypeOf<{ additional: number }>()
expectTypeOf<ReturnType<typeof get>>().toMatchTypeOf<{
additional: number
}>()
}),
})),
),
)
expect(testDerivedSetStateType).toBeDefined()
// the type of the `getState` should include our new property
expectTypeOf(testDerivedSetStateType.getState()).toMatchTypeOf<{
additional: number
}>()
// the type of the `setState` should not include our new property
expectTypeOf<
Parameters<typeof testDerivedSetStateType.setState>[0]
>().not.toMatchTypeOf<{
additional: number
}>()
})

it('redux', () => {
Expand Down

0 comments on commit d7345da

Please sign in to comment.