Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/entity-database-adapter-knex/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
},
"dependencies": {
"@expo/entity": "workspace:^",
"knex": "^3.1.0"
"knex": "^3.2.9"
},
"devDependencies": {
"@expo/entity-testing-utils": "workspace:^",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ export class PostgresEntityDatabaseAdapter<
.select()
.from(tableName)
.whereRaw(`(??) = ANY(?)`, [
tableColumns[0],
tableColumns[0]!,
tableTuples.map((tableTuple) => tableTuple[0]),
]),
);
Expand Down
45 changes: 23 additions & 22 deletions packages/entity-database-adapter-knex/src/SQLOperator.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import assert from 'assert';
import type { Knex } from 'knex';

/**
* Supported SQL value types that can be safely parameterized.
Expand All @@ -12,7 +13,6 @@ export type SupportedSQLValue =
| Date
| Buffer
| bigint
| undefined // Will be treated as NULL
| readonly SupportedSQLValue[] // For IN clauses and array types
| Readonly<{ [key: string]: unknown }>; // For JSON/JSONB columns

Expand Down Expand Up @@ -41,15 +41,18 @@ export class SQLFragment<TFields extends Record<string, any>> {
*/
getKnexBindings(
getColumnForField: (fieldName: keyof TFields) => string,
): readonly SupportedSQLValue[] {
): readonly Knex.RawBinding[] {
return this.bindings.map((b) => {
switch (b.type) {
case 'entityField':
return getColumnForField(b.fieldName);
case 'identifier':
return b.name;
case 'value':
return b.value;
// Needs a cast since bigint is supported by knex postgres dialect but not all dialects, and thus isn't included
// in the type. Because we only use the postgres dialect in this adapter, it's safe to allow it here.
// https://github.com/knex/knex/issues/5013#issuecomment-3368744254
return b.value as Knex.RawBinding;
}
});
}
Expand Down Expand Up @@ -133,8 +136,8 @@ export class SQLFragment<TFields extends Record<string, any>> {
* Handles all SupportedSQLValue types.
*/
private static formatDebugValue(value: SupportedSQLValue): string {
// Handle null and undefined
if (value === null || value === undefined) {
// Handle null
if (value === null) {
return 'NULL';
}

Expand Down Expand Up @@ -303,7 +306,7 @@ export function sql<TFields extends Record<string, any>>(
strings.forEach((string, i) => {
sqlString += string;
if (i < values.length) {
const value = values[i];
const value = values[i]!;

if (value instanceof SQLFragment) {
// Handle nested SQL fragments
Expand Down Expand Up @@ -344,7 +347,7 @@ type PickSupportedSQLValueKeys<T> = {
}[keyof T];

type PickStringValueKeys<T> = {
[K in keyof T]: T[K] extends string | null | undefined ? K : never;
[K in keyof T]: T[K] extends string | null ? K : never;
}[keyof T];

type JsonSerializable =
Expand All @@ -368,27 +371,27 @@ export class SQLChainableFragment<
> extends SQLFragment<TFields> {
/**
* Generates an equality condition (`= value`).
* Automatically converts `null`/`undefined` to `IS NULL`.
* Automatically converts `null` to `IS NULL`.
*
* @param value - The value to compare against
* @returns A {@link SQLFragment} representing the equality condition
*/
eq(value: TValue | null | undefined): SQLFragment<TFields> {
if (value === null || value === undefined) {
eq(value: TValue | null): SQLFragment<TFields> {
if (value === null) {
return this.isNull();
}
return sql`${this} = ${value}`;
}

/**
* Generates an inequality condition (`!= value`).
* Automatically converts `null`/`undefined` to `IS NOT NULL`.
* Automatically converts `null` to `IS NOT NULL`.
*
* @param value - The value to compare against
* @returns A {@link SQLFragment} representing the inequality condition
*/
neq(value: TValue | null | undefined): SQLFragment<TFields> {
if (value === null || value === undefined) {
neq(value: TValue | null): SQLFragment<TFields> {
if (value === null) {
return this.isNotNull();
}
return sql`${this} != ${value}`;
Expand Down Expand Up @@ -635,9 +638,7 @@ type ExtractFragmentFields<T> = T extends SQLFragment<infer F> ? F : never;
// Conditional value types for expression overloads.
// Uses SQLChainableFragment<any, ...> so that TExpr alone drives inference (single type param).
type FragmentValueNullable<TFragment> =
TFragment extends SQLChainableFragment<any, infer TValue>
? TValue | null | undefined
: SupportedSQLValue;
TFragment extends SQLChainableFragment<any, infer TValue> ? TValue | null : SupportedSQLValue;

type FragmentValue<TFragment> =
TFragment extends SQLChainableFragment<any, infer TValue> ? TValue : SupportedSQLValue;
Expand Down Expand Up @@ -950,7 +951,7 @@ function isNotNullHelper<TFields extends Record<string, any>>(

/**
* Generates an equality condition (`= value`) from a fragment.
* Automatically converts `null`/`undefined` to `IS NULL`.
* Automatically converts `null` to `IS NULL`.
*
* @param fragment - A SQLFragment or SQLChainableFragment to compare
* @param value - The value to compare against
Expand All @@ -961,7 +962,7 @@ function eqHelper<TFragment extends SQLFragment<any>>(
): SQLFragment<ExtractFragmentFields<TFragment>>;
/**
* Generates an equality condition (`= value`) from a field name.
* Automatically converts `null`/`undefined` to `IS NULL`.
* Automatically converts `null` to `IS NULL`.
*
* @param fieldName - The entity field name to compare
* @param value - The value to compare against
Expand All @@ -979,7 +980,7 @@ function eqHelper<TFields extends Record<string, any>>(

/**
* Generates an inequality condition (`!= value`) from a fragment.
* Automatically converts `null`/`undefined` to `IS NOT NULL`.
* Automatically converts `null` to `IS NOT NULL`.
*
* @param fragment - A SQLFragment or SQLChainableFragment to compare
* @param value - The value to compare against
Expand All @@ -990,7 +991,7 @@ function neqHelper<TFragment extends SQLFragment<any>>(
): SQLFragment<ExtractFragmentFields<TFragment>>;
/**
* Generates an inequality condition (`!= value`) from a field name.
* Automatically converts `null`/`undefined` to `IS NOT NULL`.
* Automatically converts `null` to `IS NOT NULL`.
*
* @param fieldName - The entity field name to compare
* @param value - The value to compare against
Expand Down Expand Up @@ -1521,12 +1522,12 @@ export const SQLExpression = {
isNotNull: isNotNullHelper,

/**
* Equality operator. Automatically converts null/undefined to IS NULL.
* Equality operator. Automatically converts null to IS NULL.
*/
eq: eqHelper,

/**
* Inequality operator. Automatically converts null/undefined to IS NOT NULL.
* Inequality operator. Automatically converts null to IS NOT NULL.
*/
neq: neqHelper,

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,11 @@ describe('SQLOperator', () => {
});

it('handles all SupportedSQLValue types in getDebugString', () => {
const fragment = new SQLFragment('INSERT INTO test VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', [
const fragment = new SQLFragment('INSERT INTO test VALUES (?, ?, ?, ?, ?, ?, ?, ?)', [
{ type: 'value', value: 'string' },
{ type: 'value', value: 123 },
{ type: 'value', value: true },
{ type: 'value', value: null },
{ type: 'value', value: undefined },
{ type: 'value', value: new Date('2024-01-01T00:00:00.000Z') },
{ type: 'value', value: Buffer.from('hello') },
{ type: 'value', value: BigInt(999) },
Expand All @@ -272,7 +271,7 @@ describe('SQLOperator', () => {

const text = fragment.getDebugString();
expect(text).toBe(
"INSERT INTO test VALUES ('string', 123, TRUE, NULL, NULL, '2024-01-01T00:00:00.000Z', '\\x68656c6c6f', 999, ARRAY[1, 2, 3])",
"INSERT INTO test VALUES ('string', 123, TRUE, NULL, '2024-01-01T00:00:00.000Z', '\\x68656c6c6f', 999, ARRAY[1, 2, 3])",
);
});

Expand Down Expand Up @@ -763,13 +762,6 @@ describe('SQLOperator', () => {
expect(fragment.getKnexBindings(getColumnForField)).toEqual(['nullable_field']);
});

it('handles undefined in equality check', () => {
const fragment = SQLExpression.eq('nullableField', undefined);

expect(fragment.sql).toBe('?? IS NULL');
expect(fragment.getKnexBindings(getColumnForField)).toEqual(['nullable_field']);
});

it('accepts a SQLFragment expression', () => {
const fragment = SQLExpression.eq(sql<TestFields>`${entityField('stringField')}`, 'active');
expect(fragment.sql).toBe('?? = ?');
Expand Down Expand Up @@ -801,13 +793,6 @@ describe('SQLOperator', () => {
expect(fragment.getKnexBindings(getColumnForField)).toEqual(['nullable_field']);
});

it('handles undefined in inequality check', () => {
const fragment = SQLExpression.neq('nullableField', undefined);

expect(fragment.sql).toBe('?? IS NOT NULL');
expect(fragment.getKnexBindings(getColumnForField)).toEqual(['nullable_field']);
});

it('accepts a SQLFragment expression', () => {
const fragment = SQLExpression.neq(
sql<TestFields>`${entityField('stringField')}`,
Expand Down Expand Up @@ -1131,12 +1116,6 @@ describe('SQLOperator', () => {
expect(fragment.getKnexBindings(getColumnForField)).toEqual(['string_field']);
});

it('eq(undefined) uses IS NULL', () => {
const fragment = makeExpr<string>(stringFieldFragment()).eq(undefined);
expect(fragment.sql).toBe('?? IS NULL');
expect(fragment.getKnexBindings(getColumnForField)).toEqual(['string_field']);
});

it('neq(value)', () => {
const fragment = makeExpr<string>(stringFieldFragment()).neq('deleted');
expect(fragment.sql).toBe('?? != ?');
Expand All @@ -1149,12 +1128,6 @@ describe('SQLOperator', () => {
expect(fragment.getKnexBindings(getColumnForField)).toEqual(['string_field']);
});

it('neq(undefined) uses IS NOT NULL', () => {
const fragment = makeExpr<string>(stringFieldFragment()).neq(undefined);
expect(fragment.sql).toBe('?? IS NOT NULL');
expect(fragment.getKnexBindings(getColumnForField)).toEqual(['string_field']);
});

it('gt(value)', () => {
const fragment = makeExpr<number>(intFieldFragment()).gt(10);
expect(fragment.sql).toBe('?? > ?');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -480,17 +480,16 @@ export class EntityKnexDataManager<
baseWhere: SQLFragment<TFields> | undefined,
cursorCondition: SQLFragment<TFields> | null,
): SQLFragment<TFields> {
const conditions = [baseWhere, cursorCondition].filter((it) => !!it);
if (conditions.length === 0) {
return sql`TRUE`;
if (!baseWhere) {
return cursorCondition ?? sql`TRUE`;
}
if (conditions.length === 1) {
return conditions[0]!;

if (!cursorCondition) {
return baseWhere;
}
// Wrap baseWhere in parens if combining with cursor condition
// We know we have exactly 2 conditions at this point
const [first, second] = conditions;
return sql`(${first}) AND ${second}`;

// Wrap baseWhere in parens when combining with cursor condition
return sql`(${baseWhere}) AND ${cursorCondition}`;
}

private augmentOrderByIfNecessary(
Expand Down
14 changes: 9 additions & 5 deletions yarn.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.